You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

247 lines
7.8 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import numpy as np
import cv2
import os
import time
from ultralytics import YOLO
import queue
import threading
from config import Q_SZ
from ModelDet.personDet import analysis_yolov8
from tools import Process_tools
# from ModelDet.holisticDet import MediapipeProcess
class DealVideo():
    """Threaded video pipeline: find videos, decode them, detect/track persons.

    Data flows through two queues:

    * ``videoQueue`` -- video file paths, filled by :meth:`get_video_list`.
    * ``frameQueue`` -- one dict per decoded video, filled by
      :meth:`get_video_frame` and consumed by :meth:`person_det`
      (and by :meth:`write_video` if its thread is started).

    ``person_det`` and ``write_video`` both take items from the same
    ``frameQueue``; if both threads run, each video reaches only one of them.
    The worker loops never terminate on their own.
    """

    def __init__(self, video_file, video_save_file, person_model, mediapipe_model, pptsmv2_model):
        """Store inputs/models and create (but do not start) the worker threads.

        Args:
            video_file: path to a single video, or a directory that is walked
                recursively for video files.
            video_save_file: output directory used by :meth:`write_video`.
            person_model: model passed to ``analysis_yolov8`` as ``model_coco``.
            mediapipe_model: stored on the instance; not used by this class.
            pptsmv2_model: stored on the instance; not used by this class.
        """
        self.video_file = video_file
        self.video_save_file = video_save_file
        # Models
        self.person_model = person_model
        self.mediapipe_model = mediapipe_model
        self.pptsmv2_model = pptsmv2_model
        # Paths of videos waiting to be decoded (bounded by Q_SZ).
        self.videoQueue = queue.Queue(maxsize=Q_SZ)
        # Decoded videos (unbounded; each item holds every frame of one video).
        self.frameQueue = queue.Queue(maxsize=0)
        # Worker threads; started by run().
        self.get_video_listThread = threading.Thread(target=self.get_video_list)
        self.get_video_frameThread = threading.Thread(target=self.get_video_frame)
        self.person_detThread = threading.Thread(target=self.person_det)
        self.write_videoThread = threading.Thread(target=self.write_video)

    def get_video_list(self):
        """Queue the video path(s) to process on ``videoQueue``.

        If ``self.video_file`` is a directory, every file below it (recursively)
        whose extension is ``.mp4`` or ``.avi`` -- in any letter case -- is
        queued in ``os.walk`` order. Otherwise ``self.video_file`` itself is
        queued unchanged.
        """
        if not os.path.isdir(self.video_file):
            self.videoQueue.put(self.video_file)
            return
        # Compared lower-cased so ".AVI", ".Mp4", ... are accepted too.
        video_ext = {".mp4", ".avi"}
        for maindir, _subdirs, file_name_list in os.walk(self.video_file):
            for filename in file_name_list:
                if os.path.splitext(filename)[1].lower() in video_ext:
                    self.videoQueue.put(os.path.join(maindir, filename))

    def get_video_frame(self):
        """Decode every queued video into frames and put the result on ``frameQueue``.

        Runs forever. Each item put on ``frameQueue`` is a dict with keys:

        * ``'video_path'`` -- the source path;
        * ``'frame_list'`` -- ``[{'fps': n, 'frame': image}, ...]`` where ``n``
          is the 1-based frame index (not a frame rate);
        * ``'video_fps'`` / ``'frame_size'`` -- source frame rate and
          ``(width, height)``, read before the capture is released;
        * ``'cap'`` -- the ``cv2.VideoCapture``, kept for backward
          compatibility; it is released after decoding, so read the two keys
          above instead of querying it.

        Errors for one video are printed and the loop continues with the next.
        """
        while True:
            # Block until a path is available. (The former `if ~queue.empty()`
            # guard was always truthy: ~True == -2 and ~False == -1.)
            video_path = self.videoQueue.get()
            cap = None
            try:
                cap = cv2.VideoCapture(video_path)
                video_fps = cap.get(cv2.CAP_PROP_FPS)
                frame_size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                              int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
                frame_list = []
                count_fps = 0
                while cap.isOpened():
                    success, frame = cap.read()
                    if not success:
                        print(video_path, "Ignoring empty camera frame.")
                        break
                    count_fps += 1
                    frame_list.append({'fps': count_fps, 'frame': frame})
                self.frameQueue.put({
                    'video_path': video_path,
                    'frame_list': frame_list,
                    'cap': cap,
                    'video_fps': video_fps,
                    'frame_size': frame_size,
                })
            except Exception as e:
                print(video_path, e)
            finally:
                # Free the decoder handle (it was previously never released).
                if cap is not None:
                    cap.release()

    def person_det(self):
        """Run person detection on each decoded video taken from ``frameQueue``.

        Runs forever. A failure on one video is printed and the worker moves
        on to the next instead of dying.
        """
        while True:
            video_frame_dict = self.frameQueue.get()
            try:
                self._track_persons(video_frame_dict['frame_list'])
            except Exception as e:
                print(video_frame_dict.get('video_path'), e)

    def _track_persons(self, frame_list):
        """Detect persons frame by frame for one video and maintain the tracked list.

        ``frame_result_contact`` is seeded from the first frame's detections
        (via ``Process_tools.change_list_dict``). On each later frame it drops
        entries that ``Process_tools.statistics_fps`` reports as timed out and
        entries in ``example_lst`` (targets that disappeared), and appends
        ``re_dict_lst`` (newly appeared targets). Both lists come from
        ``Process_tools.analysis_re01_list``.

        Returns:
            The tracked list after the last frame.
        """
        frame_result_contact = []
        for i, frame_item in enumerate(frame_list):
            # Only handle frames whose 1-based index matches their position.
            if frame_item["fps"] != i + 1:
                continue
            fps_now = frame_item["fps"]
            det_result = analysis_yolov8(frame=frame_item['frame'],
                                         model_coco=self.person_model,
                                         confidence_set=0.5)
            # Detections of the current frame reduced to their bbox list.
            person_list = Process_tools.get_dict_values(det_result)
            if not frame_result_contact:
                # Seed (or re-seed, once every entry has been dropped) the
                # reference coordinates from this frame.
                frame_result_contact = Process_tools.change_list_dict(fps1=fps_now, re_list=person_list)
                print("frame_result_contact:", frame_result_contact)
                continue
            cut_list, example_lst, re_dict_lst = Process_tools.analysis_re01_list(
                example_list=frame_result_contact, result_list=person_list)
            print('cut_list:', cut_list)
            print('example_sorted_lst:', example_lst)
            print('re_dict_sorted_lst:', re_dict_lst)
            # Entries whose cut-off has been reached at this frame.
            time_out_list = Process_tools.statistics_fps(fps_now=fps_now, re_list=frame_result_contact)
            if time_out_list:
                frame_result_contact = [item for item in frame_result_contact if item not in time_out_list]
            # Targets that disappeared.
            if example_lst:
                frame_result_contact = [item for item in frame_result_contact if item not in example_lst]
            # Newly appeared targets.
            if re_dict_lst:
                update_list = Process_tools.change_list_dict(fps1=fps_now, re_list=re_dict_lst)
                frame_result_contact = frame_result_contact + update_list
            print('frame_result_contact:', frame_result_contact)
        return frame_result_contact

    def write_video(self):
        """Re-encode each decoded video from ``frameQueue`` into ``video_save_file``.

        Runs forever. A failure on one video is printed and the worker moves
        on to the next.
        """
        # Make sure the output directory exists before cv2.VideoWriter opens
        # files in it.
        if self.video_save_file:
            os.makedirs(self.video_save_file, exist_ok=True)
        while True:
            video_frame_dict = self.frameQueue.get()
            try:
                self._write_one_video(video_frame_dict)
            except Exception as e:
                print(video_frame_dict.get('video_path'), e)

    def _write_one_video(self, video_frame_dict):
        """Write one decoded video (XVID) under its source basename.

        Uses the source frame rate and size. Frames are written in order
        until the first gap in the 1-based ``'fps'`` index.
        """
        video_basename = os.path.basename(video_frame_dict['video_path'])
        video_name_save = os.path.join(self.video_save_file, video_basename)
        # Source frame rate and size; fall back to querying the capture for
        # dicts that do not carry them.
        fps = video_frame_dict.get('video_fps')
        size = video_frame_dict.get('frame_size')
        if fps is None or size is None:
            cap = video_frame_dict['cap']
            fps = cap.get(cv2.CAP_PROP_FPS)
            size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                    int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
        videoWriter = cv2.VideoWriter(video_name_save, cv2.VideoWriter_fourcc('X', 'V', 'I', 'D'), fps, size)
        try:
            for i, frame_item in enumerate(video_frame_dict['frame_list']):
                if frame_item["fps"] != i + 1:
                    break
                videoWriter.write(frame_item["frame"])
        finally:
            # Release explicitly so the file is finalized now rather than
            # whenever the writer object happens to be garbage-collected.
            videoWriter.release()

    def run(self):
        """Start the listing, decoding and person-detection threads."""
        self.get_video_listThread.start()
        self.get_video_frameThread.start()
        self.person_detThread.start()
        # write_video also consumes frameQueue; starting it alongside
        # person_det would split the videos between the two consumers.
        # self.write_videoThread.start()
if __name__ == '__main__':
    # Input: one video file, or a directory scanned recursively for videos.
    video = "E:/Bank_files/Bank_02/dataset/video_test"
    # Output directory handed to DealVideo (used by its write_video worker).
    video_save = 'videos_codes_2'
    # Length of each video segment in seconds; only referenced by the
    # disabled get_seg_video call below.
    dertTime = 5
    # The same weights file is used for the person detector and as the
    # placeholder for the two models DealVideo merely stores.
    model_path = "model_file/yolov8x_person.pt"
    person_model = YOLO(model_path)
    # get_seg_video(video_file=video, video_save_path=video_save, dertTime=dertTime)
    deal = DealVideo(
        video_file=video,
        video_save_file=video_save,
        person_model=person_model,
        mediapipe_model=model_path,
        pptsmv2_model=model_path,
    )
    deal.run()