You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

313 lines
8.7 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import numpy as np
import cv2
import os
import time
from tqdm import tqdm
from ultralytics import YOLO
from ultralytics.yolo.utils.plotting import Annotator
import queue
from yolov8_det import analysis_yolov8
import threading
from config import Q_SZ
class DealVideo():
    """Re-encode every video found under a path into an output folder.

    Work is split into three threads connected by queues:
      1. ``get_video_list``  -- puts video file paths on ``videoQueue``;
      2. ``get_video_frame`` -- decodes each video fully into memory and puts
         its frames plus the source fps/size on ``frameQueue``;
      3. ``write_video``     -- writes the frames with the XVID fourcc into
         ``video_save_file`` under the source file's basename.

    After the last video a ``None`` sentinel is passed through both queues so
    every thread returns once the work is done.
    """

    # Extensions (compared case-insensitively) that are treated as videos.
    VIDEO_EXTS = (".mp4", ".avi")

    def __init__(self, video_file, video_save_file):
        """
        Args:
            video_file: a single video file, or a directory searched recursively.
            video_save_file: output directory for the re-encoded videos.
        """
        self.video_file = video_file
        self.video_save_file = video_save_file
        # Video paths waiting to be decoded (bounded by Q_SZ).
        self.videoQueue = queue.Queue(maxsize=Q_SZ)
        # Decoded videos waiting to be written.
        # NOTE(review): unbounded, and each item holds every frame of a video,
        # so memory grows if the writer falls behind -- consider bounding it.
        self.frameQueue = queue.Queue(maxsize=0)
        # One worker thread per pipeline stage.
        self.get_video_listThread = threading.Thread(target=self.get_video_list)
        self.get_video_frameThread = threading.Thread(target=self.get_video_frame)
        self.write_videoThread = threading.Thread(target=self.write_video)

    def get_video_list(self):
        """Stage 1: enqueue the path of every video to process, then ``None``.

        If ``video_file`` is a directory it is walked recursively (in sorted
        order) and every file whose lower-cased extension is in ``VIDEO_EXTS``
        is enqueued; otherwise ``video_file`` itself is enqueued.
        """
        if os.path.isdir(self.video_file):
            for maindir, subdirs, file_name_list in os.walk(self.video_file):
                subdirs.sort()  # deterministic traversal order
                for filename in sorted(file_name_list):
                    if os.path.splitext(filename)[1].lower() in self.VIDEO_EXTS:
                        self.videoQueue.put(os.path.join(maindir, filename))
        else:
            self.videoQueue.put(self.video_file)
        # Sentinel: no more paths are coming.
        self.videoQueue.put(None)

    def get_video_frame(self):
        """Stage 2: decode each queued video into a list of frames.

        For every path, puts a dict on ``frameQueue`` with keys
        ``'video_path'``, ``'frame_list'`` (dicts holding the 1-based frame
        number under ``'fps'`` and the image under ``'frame'``),
        ``'video_fps'`` and ``'size'`` (width, height). The capture is released
        here, so the writer never touches it. Forwards the ``None`` sentinel
        and returns when it arrives.
        """
        while True:
            # Blocking get(): wait for the next path instead of polling.
            video_path = self.videoQueue.get()
            if video_path is None:
                self.frameQueue.put(None)
                return
            cap = cv2.VideoCapture(video_path)
            try:
                if not cap.isOpened():
                    print("Cannot open video:", video_path)
                    continue
                # Read the source properties while the capture is still open.
                video_fps = cap.get(cv2.CAP_PROP_FPS)
                size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                        int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
                frame_list = []
                count_fps = 0
                while True:
                    success, frame = cap.read()
                    if not success:
                        print("Ignoring empty camera frame.")
                        break
                    count_fps += 1
                    frame_list.append({'fps': count_fps, 'frame': frame})
            finally:
                cap.release()
            self.frameQueue.put({'video_path': video_path,
                                 'frame_list': frame_list,
                                 'video_fps': video_fps,
                                 'size': size})

    def write_video(self):
        """Stage 3: write each decoded video into ``video_save_file``.

        The output keeps the source basename, fps and frame size and is
        encoded with the XVID fourcc. Returns when the ``None`` sentinel
        arrives.
        """
        # cv2.VideoWriter does not create missing directories.
        os.makedirs(self.video_save_file, exist_ok=True)
        while True:
            video_frame_dict = self.frameQueue.get()
            if video_frame_dict is None:
                return
            video_basename = os.path.basename(video_frame_dict['video_path'])
            video_name_save = os.path.join(self.video_save_file, video_basename)
            videoWriter = cv2.VideoWriter(video_name_save,
                                          cv2.VideoWriter_fourcc('X', 'V', 'I', 'D'),
                                          video_frame_dict['video_fps'],
                                          video_frame_dict['size'])
            try:
                if not videoWriter.isOpened():
                    print("Cannot open video writer:", video_name_save)
                    continue
                for expected, frame_dict in enumerate(video_frame_dict['frame_list'], start=1):
                    # Stop at the first gap in the frame numbering.
                    if frame_dict['fps'] != expected:
                        break
                    videoWriter.write(frame_dict['frame'])
            finally:
                # Close the file now rather than whenever the writer is collected.
                videoWriter.release()

    def run(self):
        """Start the three pipeline threads (does not wait for them)."""
        self.get_video_listThread.start()
        self.get_video_frameThread.start()
        self.write_videoThread.start()
class Process_tools():
    """Stateless helpers for finding videos, cropping clips and merging boxes.

    Every helper is a ``@staticmethod``; call them as ``Process_tools.name(...)``.
    Sibling helpers are always referenced through the class name because
    functions defined in a class body cannot see each other by bare name.
    Bounding boxes are ``[x1, y1, x2, y2]`` sequences.
    """

    # Extensions (compared case-insensitively) that are treated as videos.
    VIDEO_EXTS = (".mp4", ".avi")

    @staticmethod
    def get_video_list(path):
        """Return the paths of all video files under ``path`` (recursive).

        Args:
            path: directory to walk.
        Returns:
            list of file paths whose lower-cased extension is in ``VIDEO_EXTS``,
            in ``os.walk`` order.
        """
        return [os.path.join(maindir, filename)
                for maindir, _subdirs, file_name_list in os.walk(path)
                for filename in file_name_list
                if os.path.splitext(filename)[1].lower() in Process_tools.VIDEO_EXTS]

    @staticmethod
    def save_seg_video(video_name, frameToStart, frametoStop, videoWriter, bbox):
        """Crop frames ``frameToStart+1 .. frametoStop`` of a video and write them.

        Frames are numbered from 1. Each selected frame is cropped to ``bbox``
        (coordinates cast to int) and passed to ``videoWriter.write``. The
        caller owns ``videoWriter`` and is responsible for releasing it.

        Args:
            video_name: path of the source video.
            frameToStart: frames numbered <= this are skipped.
            frametoStop: number of the last frame written; reading stops there.
            videoWriter: an opened ``cv2.VideoWriter``.
            bbox: crop rectangle ``[x1, y1, x2, y2]``.
        """
        cap = cv2.VideoCapture(video_name)
        try:
            count = 0
            while count < frametoStop:
                success, frame = cap.read()
                if not success:
                    break
                count += 1
                if count > frameToStart:  # first frame of the wanted range
                    print('correct= ', count)
                    # Crop the frame: rows are y, columns are x.
                    x1, y1, x2, y2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
                    videoWriter.write(frame[y1:y2, x1:x2])
        finally:
            cap.release()
        print('end')

    @staticmethod
    def get_dict_values(lst):
        """Collect the list-valued values of every dict in ``lst``.

        Args:
            lst: list of dicts.
        Returns:
            flat list of every value (across all dicts) that is a ``list``;
            non-list values are ignored.
        """
        return [value for dictionary in lst for value in dictionary.values()
                if isinstance(value, list)]

    @staticmethod
    def analysis_sort_list(result_dict):
        """Return the boxes of a detection result sorted by ``x1``.

        Args:
            result_dict: dict whose ``'start_bbox'`` entry is a list of dicts;
                every list-valued entry of those dicts is taken as a box.
        Returns:
            the boxes sorted by their first coordinate.
        """
        re_bbox_list = Process_tools.get_dict_values(result_dict['start_bbox'])
        return sorted(re_bbox_list, key=lambda box: box[0])

    @staticmethod
    def contrast_bbox(e_bbox, r_bbox):
        """Return the smallest box ``[x1, y1, x2, y2]`` enclosing both boxes."""
        return [min(e_bbox[0], r_bbox[0]), min(e_bbox[1], r_bbox[1]),
                max(e_bbox[2], r_bbox[2]), max(e_bbox[3], r_bbox[3])]

    @staticmethod
    def analysis_re01_list(example_dict, result_dict):
        """Match the boxes of a reference detection against a current one.

        Both arguments are detection dicts accepted by ``analysis_sort_list``.
        Boxes of each are sorted by ``x1``. Each reference box is paired with
        the first current box overlapping it (IoU > 0), and the pair is merged
        with ``contrast_bbox``.
        NOTE(review): a current box may be paired with several reference boxes.

        Returns:
            ``(cut_list, example_rest, result_rest)`` where ``cut_list`` holds
            ``{index_in_sorted_reference: merged_box}`` dicts, ``example_rest``
            the reference boxes that matched nothing, and ``result_rest`` the
            current boxes that were never matched.
        """
        example_sorted_lst = Process_tools.analysis_sort_list(example_dict)
        re_dict_sorted_lst = Process_tools.analysis_sort_list(result_dict)
        cut_list = []
        example_temp = []
        re_temp = []
        for i, ex_bbox in enumerate(example_sorted_lst):
            for re_bbox in re_dict_sorted_lst:
                if Process_tools.calculate_iou(box1=ex_bbox, box2=re_bbox) > 0:
                    bbox = Process_tools.contrast_bbox(e_bbox=ex_bbox, r_bbox=re_bbox)
                    cut_list.append({i: bbox})
                    example_temp.append(ex_bbox)
                    re_temp.append(re_bbox)
                    break
        # Filtering is by value: boxes equal to a matched box are removed too.
        example_sorted_lst = [item for item in example_sorted_lst if item not in example_temp]
        re_dict_sorted_lst = [item for item in re_dict_sorted_lst if item not in re_temp]
        return cut_list, example_sorted_lst, re_dict_sorted_lst

    @staticmethod
    def calculate_iou(box1, box2):
        """Return the IoU of two boxes given as inclusive pixel coordinates.

        Args:
            box1: ``[x1, y1, x2, y2]``.
            box2: ``[x1, y1, x2, y2]``.
        Returns:
            intersection area / union area. Widths and heights are ``x2-x1+1``
            and ``y2-y1+1``. Returns ``0.0`` when the union area is not
            positive, which only happens for degenerate boxes.
        """
        x1 = max(box1[0], box2[0])
        y1 = max(box1[1], box2[1])
        x2 = min(box1[2], box2[2])
        y2 = min(box1[3], box2[3])
        intersection_area = max(0, x2 - x1 + 1) * max(0, y2 - y1 + 1)
        box1_area = (box1[2] - box1[0] + 1) * (box1[3] - box1[1] + 1)
        box2_area = (box2[2] - box2[0] + 1) * (box2[3] - box2[1] + 1)
        union_area = box1_area + box2_area - intersection_area
        if union_area <= 0:
            # Degenerate boxes: avoid ZeroDivisionError / meaningless ratios.
            return 0.0
        return intersection_area / union_area
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(
        description="Re-encode every video found under a file/directory into an output folder.")
    parser.add_argument("video", nargs="?",
                        default="E:/Bank_files/Bank_02/dataset/video_person/after_1/",
                        help="video file, or directory searched recursively")
    parser.add_argument("video_save", nargs="?", default="videos_codes_2",
                        help="output directory (created if missing)")
    args = parser.parse_args()
    # cv2.VideoWriter does not create missing directories, so make sure it exists.
    os.makedirs(args.video_save, exist_ok=True)
    deal = DealVideo(video_file=args.video, video_save_file=args.video_save)
    deal.run()