import numpy as np
import cv2
import os
import time

from ultralytics import YOLO 
import queue

import threading
from config import Q_SZ

from ModelDet.personDet import analysis_yolov8
from tools import Process_tools
# from ModelDet.holisticDet import MediapipeProcess



class DealVideo():
    '''
    Threaded pipeline that decodes videos into frames and runs person detection.

    Data flow:
        get_video_list -> videoQueue -> get_video_frame -> frameQueue -> person_det
    (``write_video`` can also consume ``frameQueue``; see ``run``).

    After the last input path, ``get_video_list`` enqueues the private ``_STOP``
    sentinel. Each worker forwards or re-queues it and returns, so all threads
    finish once the input is exhausted instead of blocking forever.
    '''

    # End-of-input marker. Consumers put it back after seeing it so every other
    # consumer of the same queue also stops.
    _STOP = object()

    # Accepted video extensions when walking a directory (compared lower-cased).
    VIDEO_EXTS = (".mp4", ".avi")

    def __init__(self, video_file, video_save_file, person_model, mediapipe_model, pptsmv2_model):
        '''
        Store inputs and models, and create the queues and worker threads.

        :param video_file: a single video path, or a directory that is walked
            recursively for files with an extension in ``VIDEO_EXTS``.
        :param video_save_file: output directory used by ``write_video``.
        :param person_model: model object passed to ``analysis_yolov8``.
        :param mediapipe_model: stored only; not used by the methods of this class.
        :param pptsmv2_model: stored only; not used by the methods of this class.
        '''
        self.video_file = video_file
        self.video_save_file = video_save_file

        # Models
        self.person_model = person_model
        self.mediapipe_model = mediapipe_model
        self.pptsmv2_model = pptsmv2_model

        # Queue of video paths (bounded by Q_SZ) and queue of decoded videos.
        # NOTE: frameQueue is unbounded and every item holds all frames of one
        # video in memory, so long videos can use a lot of RAM.
        self.videoQueue = queue.Queue(maxsize=Q_SZ)
        self.frameQueue = queue.Queue(maxsize=0)

        # Worker threads (started by run()).
        self.get_video_listThread = threading.Thread(target=self.get_video_list)
        self.get_video_frameThread = threading.Thread(target=self.get_video_frame)
        self.person_detThread = threading.Thread(target=self.person_det)
        self.write_videoThread = threading.Thread(target=self.write_video)

    def get_video_list(self):
        '''
        Enqueue the video paths to process, then the ``_STOP`` sentinel.

        If ``video_file`` is a directory, every file below it whose extension
        (case-insensitive) is in ``VIDEO_EXTS`` is enqueued in ``os.walk`` order.
        Otherwise ``video_file`` itself is enqueued unchanged.
        '''
        try:
            if os.path.isdir(self.video_file):
                for maindir, _subdirs, file_name_list in os.walk(self.video_file):
                    for filename in file_name_list:
                        if os.path.splitext(filename)[1].lower() in self.VIDEO_EXTS:
                            self.videoQueue.put(os.path.join(maindir, filename))
            else:
                self.videoQueue.put(self.video_file)
        finally:
            # Always signal end of input so the downstream threads can exit.
            self.videoQueue.put(self._STOP)

    @staticmethod
    def _read_video(video_path):
        '''
        Decode every frame of ``video_path`` and always release the capture.

        :returns: dict with keys ``'video_path'``, ``'frame_list'`` (list of
            ``{'fps': 1-based frame index, 'frame': image from cap.read()}``),
            ``'video_fps'`` and ``'frame_size'`` (width, height).
        '''
        cap = cv2.VideoCapture(video_path)
        try:
            # Read the stream properties up front: a released capture no longer
            # reports them, and write_video needs them to create its writer.
            video_fps = cap.get(cv2.CAP_PROP_FPS)
            frame_size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                          int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
            frame_list = []
            while cap.isOpened():
                success, frame = cap.read()
                if not success:
                    print(video_path, "Ignoring empty camera frame.")
                    break
                frame_list.append({'fps': len(frame_list) + 1, 'frame': frame})
        finally:
            cap.release()
        return {'video_path': video_path, 'frame_list': frame_list,
                'video_fps': video_fps, 'frame_size': frame_size}

    def get_video_frame(self):
        '''
        Decode each video from ``videoQueue`` and put the result on ``frameQueue``.

        The item format is described in ``_read_video``. Errors for a single
        video are printed and that video is skipped. On ``_STOP`` the sentinel
        is forwarded to ``frameQueue`` and the method returns.
        '''
        while True:
            video_path = self.videoQueue.get()  # blocks until a path arrives
            if video_path is self._STOP:
                self.frameQueue.put(self._STOP)
                return
            try:
                self.frameQueue.put(self._read_video(video_path))
            except Exception as e:
                print(e)

    def person_det(self):
        '''
        Run person detection on every frame of each decoded video in ``frameQueue``.

        Keeps a per-video list ``frame_result_contact`` of tracked boxes, built
        by the ``Process_tools`` helpers (defined elsewhere). The first frame
        with detections seeds the list. On later frames, entries reported by
        ``statistics_fps`` (time-out) and ``analysis_re01_list`` (targets that
        disappeared) are removed, and newly appeared targets are appended
        tagged with the current frame number. Results are printed only.
        On ``_STOP`` the sentinel is re-queued and the method returns.
        '''
        while True:
            video_frame_dict = self.frameQueue.get()  # blocks until work arrives
            if video_frame_dict is self._STOP:
                # Put it back so a concurrently running write_video also stops.
                self.frameQueue.put(self._STOP)
                return

            frame_result_contact = []

            for frame_dict in video_frame_dict['frame_list']:
                fps_now = frame_dict['fps']

                detections = analysis_yolov8(frame=frame_dict['frame'],
                                             model_coco=self.person_model,
                                             confidence_set=0.5)

                # Detection results of the current frame, bbox lists only.
                person_list = Process_tools.get_dict_values(detections)

                if not frame_result_contact:
                    # Use this frame's results as the reference for comparison.
                    frame_result_contact = Process_tools.change_list_dict(fps1=fps_now, re_list=person_list)
                    print("frame_result_contact:", frame_result_contact)
                else:
                    cut_list, example_lst, re_dict_lst = Process_tools.analysis_re01_list(
                        example_list=frame_result_contact,
                        result_list=person_list)

                    print('cut_list:', cut_list)
                    print('example_sorted_lst:', example_lst)
                    print('re_dict_sorted_lst:', re_dict_lst)

                    # Entries whose tracking time has run out.
                    time_out_list = Process_tools.statistics_fps(fps_now=fps_now, re_list=frame_result_contact)
                    if time_out_list:
                        # TODO: cut and save the clip for these targets
                        # (from their start frame up to fps_now).
                        frame_result_contact = [item for item in frame_result_contact if item not in time_out_list]

                    # Some targets disappeared.
                    if example_lst:
                        # TODO: cut and save the clip for these targets.
                        frame_result_contact = [item for item in frame_result_contact if item not in example_lst]

                    # New targets appeared.
                    if re_dict_lst:
                        update_list = Process_tools.change_list_dict(fps1=fps_now, re_list=re_dict_lst)
                        frame_result_contact = frame_result_contact + update_list

                print('frame_result_contact:', frame_result_contact)

    def write_video(self):
        '''
        Re-encode each decoded video from ``frameQueue`` into ``video_save_file``.

        The output keeps the input's base name and uses the XVID codec with the
        source FPS and frame size. The output directory is created if missing,
        and the writer is always released so the file is finalized.
        On ``_STOP`` the sentinel is re-queued and the method returns.

        NOTE: this shares ``frameQueue`` with ``person_det``; if both threads
        run, each video is consumed by only one of them.
        '''
        while True:
            video_frame_dict = self.frameQueue.get()  # blocks until work arrives
            if video_frame_dict is self._STOP:
                self.frameQueue.put(self._STOP)
                return

            if self.video_save_file:
                os.makedirs(self.video_save_file, exist_ok=True)

            video_basename = os.path.basename(video_frame_dict['video_path'])
            video_name_save = os.path.join(self.video_save_file, video_basename)

            videoWriter = cv2.VideoWriter(video_name_save,
                                          cv2.VideoWriter_fourcc('X', 'V', 'I', 'D'),
                                          video_frame_dict['video_fps'],
                                          video_frame_dict['frame_size'])
            try:
                for frame_dict in video_frame_dict['frame_list']:
                    videoWriter.write(frame_dict['frame'])
            finally:
                videoWriter.release()

    def run(self):
        '''Start the listing, decoding and person-detection threads.'''
        self.get_video_listThread.start()
        self.get_video_frameThread.start()
        self.person_detThread.start()
        # Disabled: write_video would compete with person_det for frameQueue items.
        # self.write_videoThread.start()



if __name__ == '__main__':

    # Input: a single video file, or a directory searched recursively for videos.
    video = "E:/Bank_files/Bank_02/dataset/video_test"
    # Output directory for re-encoded videos (used only by DealVideo.write_video).
    video_save = 'videos_codes_2'

    person_model = YOLO("model_file/yolov8x_person.pt")

    # NOTE(review): mediapipe_model / pptsmv2_model are placeholder paths;
    # DealVideo currently only stores them and never uses them.
    deal = DealVideo(video_file=video,
                     video_save_file=video_save,
                     person_model=person_model,
                     mediapipe_model='model_file/yolov8x_person.pt',
                     pptsmv2_model='model_file/yolov8x_person.pt')
    deal.run()