You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

208 lines
8.2 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import os
import math
import time
import cv2
import json
from queue import Queue, Empty
from threading import Thread
from log import logger
from ultralytics import YOLO
from yolov8_det import analysis_yolov8
from capture_queue import CAMERA_QUEUE, camera_mul_thread
# Global configuration shared by every thread in this module.
# JSON text is UTF-8 by specification; decode explicitly so loading does not
# depend on the OS locale (e.g. GBK on Chinese Windows), which would corrupt or
# reject non-ASCII values such as paths in the config.
with open('cfg.json', 'r', encoding='utf-8') as f:
    cfg_dict = json.load(f)
class FrameToVideo(Thread):
    """
    Per-camera analysis thread for detecting bank staff dozing off or using phones.

    Each frame is run through a single YOLO model that yields person, head and
    phone detections. Detections are tracked across consecutive frames by
    near-identical box coordinates. Every track collects one fixed-size crop
    per frame. Once a track has cfg_dict['video_length'] crops, they are
    written out as a clip under cfg_dict['video_path'][label].
    """

    def __init__(self, camera_name):
        """
        :param camera_name: camera identifier. It is used in output clip names,
            and as the CAMERA_QUEUE key once the queue source is re-enabled.
        """
        super(FrameToVideo, self).__init__()
        self.camera = camera_name
        # Production frame source; currently replaced by a local test video in
        # frame_analysis().
        # self.queue_img = CAMERA_QUEUE[camera_name]
        # One combined model detects person, head and phone.
        self.yolo_model = YOLO(cfg_dict['model_path']['all'])
        # Live tracks per label, carried over from one frame to the next.
        self.person_target_list = []
        self.head_target_list = []
        self.phone_target_list = []

    @staticmethod
    def save_video(frames, fps, fourcc, video_path, w_h_size):
        """
        Write a sequence of cropped frames to a video file.

        :param frames: frames to write, in order; each must be w_h_size in size.
        :param fps: frame rate of the output video.
        :param fourcc: codec code; its first four characters are used.
        :param video_path: output file path.
        :param w_h_size: (width, height) of every frame.
        """
        encoding = list(fourcc)
        video_fourcc = cv2.VideoWriter_fourcc(encoding[0], encoding[1], encoding[2], encoding[3])
        video_write_obj = cv2.VideoWriter(video_path, video_fourcc, fps, w_h_size)
        try:
            for frame in frames:
                video_write_obj.write(frame)
        finally:
            # Without release() the container is never finalised (the clip can
            # be truncated/unplayable) and the encoder handle leaks.
            video_write_obj.release()
        logger.info(f'生成视频:{video_path}')

    @staticmethod
    def boundary_treat(frame_x, frame_y, coord, error_x=None, error_y=None):
        """
        Enlarge a box by a margin and clamp it to the frame.

        The crop is made slightly larger than the detection so it keeps some
        context. It is then clipped to [0, frame_x] x [0, frame_y], so slicing
        never goes out of range.

        :param frame_x: frame width in pixels.
        :param frame_y: frame height in pixels.
        :param coord: (x_min, y_min, x_max, y_max) of the detection.
        :param error_x: horizontal margin; defaults to cfg_dict['error_x'].
        :param error_y: vertical margin; defaults to cfg_dict['error_y'].
        :return: (split_x, split_y), each a {'min': int, 'max': int} dict.
        """
        if error_x is None:
            error_x = cfg_dict['error_x']
        if error_y is None:
            error_y = cfg_dict['error_y']
        y_min = coord[1] - error_y
        y_max = coord[3] + error_y
        x_min = coord[0] - error_x
        x_max = coord[2] + error_x
        split_y = {'min': int(y_min) if y_min >= 0 else 0,
                   'max': int(y_max) if y_max <= frame_y else frame_y}
        # Each bound is clamped against its own axis. x_min used to be tested
        # against y_min, so a negative x_min became a wrap-around slice index.
        split_x = {'min': int(x_min) if x_min >= 0 else 0,
                   'max': int(x_max) if x_max <= frame_x else frame_x}
        return split_x, split_y

    @staticmethod
    def target_match(target_list, coord, frame_img, new_target_list):
        """
        Try to attach a detection to an existing, not-yet-matched track.

        A track matches when each of its four stored coordinates is within 5
        pixels of ``coord``. On a match, the track's fixed crop region is cut
        from ``frame_img`` and appended to its frames, its count is incremented,
        and it is flagged so it cannot be matched twice in the same frame.

        :param target_list: tracks carried over from the previous frame.
        :param coord: (x_min, y_min, x_max, y_max) of the current detection.
        :param frame_img: current frame as an image array indexed [y, x].
        :param new_target_list: unused; kept for interface compatibility.
        :return: True if a track was matched, otherwise False.
        """
        for target in target_list:
            if target['flag']:
                continue
            if all(abs(coord[n] - target['coord'][n]) <= 5 for n in range(4)):  # tolerance check
                split_x = target['split_x']
                split_y = target['split_y']
                target['frame'].append(
                    frame_img[split_y['min']:split_y['max'], split_x['min']:split_x['max']])
                target['count'] += 1
                target['flag'] = True
                return True
        return False

    def _unique_video_path(self, label):
        """Return a clip path for ``label`` that does not overwrite an existing file."""
        base = cfg_dict['video_path'][label] + self.camera + str(int(time.time()))
        video_path = base + '.mp4v'
        suffix = 1
        # Several tracks of one label can finish within the same second; without
        # a suffix they would all be written to (and overwrite) the same file.
        while os.path.exists(video_path):
            video_path = f'{base}_{suffix}.mp4v'
            suffix += 1
        return video_path

    def target_analysis(self, target_list, new_target_list, person_coord_list, frame_x, frame_y, frame_img, label):
        """
        Update one label's tracks with this frame's detections and emit finished clips.

        :param target_list: tracks of this label from the previous frame.
        :param new_target_list: empty list that receives the tracks to keep.
            It becomes the stored track list for ``label``.
        :param person_coord_list: this frame's detections; each item is a dict
            whose ``label`` entry holds (x_min, y_min, x_max, y_max).
        :param frame_x: frame width in pixels.
        :param frame_y: frame height in pixels.
        :param frame_img: current frame.
        :param label: 'person', 'head' or 'phone'.
        """
        for line in person_coord_list:
            coord = line[label]
            if not self.target_match(target_list, coord, frame_img, new_target_list):
                # Unmatched detection: start a new track. New tracks go straight
                # into new_target_list (not target_list). Otherwise the flag
                # filter below would drop them, and no track would ever survive
                # past its first frame.
                split_x, split_y = self.boundary_treat(frame_x, frame_y, coord)
                # Crop an enlarged region and freeze it, so every frame of the
                # clip has the same size.
                frame_split = frame_img[split_y['min']:split_y['max'], split_x['min']:split_x['max']]
                new_target_list.append({'frame': [frame_split], 'coord': coord, 'count': 0,
                                        'split_x': split_x, 'split_y': split_y, 'flag': False})
        # Emit clips for tracks that reached the configured length.
        for target in target_list:
            if len(target['frame']) == cfg_dict['video_length']:
                frame_w = target['split_x']['max'] - target['split_x']['min']
                frame_h = target['split_y']['max'] - target['split_y']['min']
                logger.info(f'开始输出视频:{label}')
                self.save_video(target['frame'], cfg_dict['fps'], cfg_dict['video_encoding'],
                                self._unique_video_path(label), (frame_w, frame_h))
                logger.info(f'输出视频结束:{label}')
                # A finished track is retired rather than carried over.
                continue
            # Keep only tracks matched in this frame. An unmatched track means
            # the target was lost, so its partial clip is discarded.
            if target['flag']:
                target['flag'] = False
                new_target_list.append(target)
        # Store the result under the matching label. Previously the phone tracks
        # overwrote the head tracks, and phone tracks were never kept.
        if label == 'person':
            self.person_target_list = new_target_list
        elif label == 'head':
            self.head_target_list = new_target_list
        else:
            self.phone_target_list = new_target_list

    def frame_analysis(self):
        """
        Read frames, run detection and update the person/head/phone tracks.

        For local testing, frames come from cfg_dict['test_video_path']. The
        loop stops when no more frames can be read, and the capture is always
        released on exit.
        """
        video_capture = cv2.VideoCapture(cfg_dict['test_video_path'])  # local testing
        try:
            while True:
                result, frame_img = video_capture.read()  # local testing
                if not result:
                    # End of video or read failure: frame_img is None here, so
                    # stop instead of crashing on frame_img.shape below.
                    logger.info(f'{self.camera}: no more frames, analysis stopped')
                    break
                # Production frame source (camera queue), disabled for testing:
                # try:
                #     frame_img = self.queue_img.get_nowait()
                # except Empty:
                #     time.sleep(0.01)
                #     continue
                new_person_target_list = []
                new_head_target_list = []
                new_phone_target_list = []
                # Run the model on this frame.
                person_coord_list, head_coord_list, phone_coord_list = analysis_yolov8(
                    frame=frame_img, model_coco=self.yolo_model, confidence=cfg_dict['confidence'])
                frame_y, frame_x, _ = frame_img.shape
                logger.debug(f'帧尺寸y:{frame_y},x:{frame_x}')
                self.target_analysis(self.person_target_list, new_person_target_list, person_coord_list,
                                     frame_x, frame_y, frame_img, 'person')
                self.target_analysis(self.head_target_list, new_head_target_list, head_coord_list,
                                     frame_x, frame_y, frame_img, 'head')
                self.target_analysis(self.phone_target_list, new_phone_target_list, phone_coord_list,
                                     frame_x, frame_y, frame_img, 'phone')
        finally:
            video_capture.release()

    def run(self):
        """Thread entry point."""
        self.frame_analysis()
class ViolationJudgmentSend(Thread):
    """
    Result analysis and sending for bank staff dozing / phone-use detection.

    Runs in its own thread. The analysis/sending loop is still a placeholder.
    """

    def __init__(self):
        # This was misspelled ``__int__``, so it never ran: Thread's own
        # __init__ was used instead, and ``action_model`` was never set.
        super(ViolationJudgmentSend, self).__init__()
        # Value of cfg_dict['model_path']['action'] (presumably a model path;
        # it is not loaded here).
        self.action_model = cfg_dict['model_path']['action']

    def video_analysis_sand(self):
        """
        Placeholder for the clip-analysis / result-sending loop; never returns.

        It sleeps instead of spinning, so the idle thread does not burn a CPU
        core or compete for the GIL with the detection threads.
        """
        while True:
            time.sleep(1)

    def run(self):
        """Thread entry point."""
        self.video_analysis_sand()
# Program entry point for the full multi-camera pipeline.
def process_run():
    """
    Start the whole pipeline and block on the result-sending thread.

    Order of startup:
    1. camera ingestion (camera_mul_thread);
    2. one FrameToVideo thread per camera name in CAMERA_QUEUE (all are
       created first, then started);
    3. the ViolationJudgmentSend thread, which the main thread then joins.
    """
    logger.info('程序启动')
    # Connect to the camera streams.
    camera_mul_thread()
    # Clip-cutting threads, one per camera.
    workers = []
    for camera_name in CAMERA_QUEUE:
        workers.append(FrameToVideo(camera_name))
    for worker in workers:
        worker.start()
    # Result-sending thread; the main thread waits on it.
    sender = ViolationJudgmentSend()
    sender.start()
    sender.join()
if __name__ == '__main__':
    # Local test entry: analyse one camera synchronously in the main thread,
    # using the video configured in cfg_dict['test_video_path'].
    # Switch to process_run() for the full multi-camera, multi-thread pipeline.
    fv = FrameToVideo(camera_name='camera_01')
    fv.frame_analysis()
    # process_run()