import cv2
import time
from tqdm import tqdm
from ultralytics import YOLO
from ultralytics.yolo.utils.plotting import Annotator
from d_face import face_detection

def _frames_to_seconds(frame_numbers, fps):
    """Convert frame numbers to timestamps in seconds, rounded to 2 decimals."""
    return [round(n / fps, 2) for n in frame_numbers]


def analysis_video(source_path,output_path,people_modle_path,face_modle_path,action_modle_path):
    """Run people/action detection over a video and write an annotated copy.

    Every frame of ``source_path`` is passed through two YOLO models (loaded
    from ``people_modle_path`` and ``action_modle_path``). Detections with a
    rounded confidence of at least 0.5 are drawn on the frame, and the
    annotated frame is appended to ``output_path`` (``mp4v`` codec, same size
    and frame rate as the source). For class-1 detections of the first model
    the cropped region is additionally handed to ``face_detection`` together
    with ``face_modle_path``.

    Args:
        source_path: Path (or other source accepted by ``cv2.VideoCapture``)
            of the input video.
        output_path: Path of the annotated video to write.
        people_modle_path: Weights for the first YOLO model.
        face_modle_path: Forwarded unchanged to ``face_detection``.
        action_modle_path: Weights for the second (action) YOLO model.

    Returns:
        A 5-tuple ``(diff, face_flag, s, action_flag, s_action)``:

        * ``diff`` - seconds between the first and last frame containing a
          class-1 detection of the first model (0.0 if seen in at most one
          frame or never).
        * ``face_flag`` - dict ``{"face": ..., "frame": [...]}``; it is passed
          to ``face_detection``, which presumably updates it (verify against
          ``d_face``).
        * ``s`` - ``face_flag["frame"]`` converted to seconds, joined by
          ``", "``.
        * ``action_flag`` - dict with the number of class-1 action detections
          (``"action"``) and their frame numbers (``"action_frame"``).
        * ``s_action`` - ``action_flag["action_frame"]`` converted to seconds,
          joined by ``", "``.

    Raises:
        OSError: If the video source cannot be opened.
        ValueError: If the source does not report a positive frame rate
            (timestamps could not be computed).
    """
    model_coco = YOLO(people_modle_path)
    action_model = YOLO(action_modle_path)

    conf_threshold = 0.5
    # Class id 1 is the "inspection" class of the first model and the "check"
    # class of the action model (per the original author's comments).
    inspection_cls = 1
    check_cls = 1
    # Index of the first frame to process.
    frame_to_start = 0

    # Records whether (and in which frames) the door-pulling action occurred.
    action_flag = {
        "action":0,
        "action_frame":[]
        }
    # Records whether (and in which frames) a face appeared.
    face_flag = {
        "face":0,
        "frame":[]
        }
    # First / last frame number in which the inspection class was detected.
    XJ_dict = {
        "head":0,
        "tail":0
        }

    cap = cv2.VideoCapture(source_path)
    output_movie = None
    try:
        if not cap.isOpened():
            raise OSError(f"Cannot open video source: {source_path!r}")

        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_to_start)

        # Keep the frame rate as a float: truncating e.g. 29.97 to 29 would
        # skew both the output video and every reported timestamp.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            raise ValueError(
                f"Video source {source_path!r} reports an invalid frame rate: {fps!r}"
            )
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        output_movie = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

        # 1-based number of the frame currently being processed.
        count = frame_to_start
        while True:
            success, frame = cap.read()
            if not success:
                # End of the video (or a read error): stop processing.
                break
            count += 1

            # Run both models on the clean frame before anything is drawn.
            results_coco = model_coco(frame)
            action_result = action_model(frame)

            for r in results_coco:
                annotator = Annotator(frame, line_width=1)
                for box in r.boxes:
                    b = box.xyxy[0]  # box coordinates in (x1, y1, x2, y2) format
                    b_i = b.int() + 1
                    c = box.cls
                    confidence = round(float(box.conf), 2)
                    # Skip detections below the confidence threshold.
                    if confidence < conf_threshold:
                        continue
                    if int(c) == inspection_cls:
                        if XJ_dict['head'] == 0:
                            XJ_dict['head'] = count
                        # Always track the latest frame so that a single
                        # detection yields a duration of 0 rather than a
                        # negative value.
                        XJ_dict['tail'] = count
                        crop_img = frame[b_i[1]:b_i[3],b_i[0]:b_i[2]]
                        # Face detection on the cropped region.
                        # NOTE(review): if face_detection returns a new array
                        # instead of drawing in place, labels drawn through
                        # `annotator` afterwards will not reach the output
                        # frame - confirm against d_face.
                        frame = face_detection(face_modle_path,frame,crop_img,b_i[0],b_i[1],b_i[2],b_i[3],face_flag,count)
                    annotator.box_label(b, model_coco.names[int(c)]+str(confidence),(0,0,255))

            # Fall back to the current frame if the action model yields no
            # result objects at all.
            output_frame = frame
            for r_a in action_result:
                annotator_a = Annotator(frame, line_width=1)
                for box_a in r_a.boxes:
                    b_a = box_a.xyxy[0]  # box coordinates in (x1, y1, x2, y2) format
                    c_a = box_a.cls
                    confidence_a = round(float(box_a.conf), 2)
                    # Skip detections below the confidence threshold.
                    if confidence_a < conf_threshold:
                        continue
                    if int(c_a) == check_cls:
                        action_flag["action"] += 1
                        action_flag["action_frame"].append(count)
                    annotator_a.box_label(b_a, action_model.names[int(c_a)]+str(confidence_a),(255,0,0))
                output_frame = annotator_a.result()

            output_movie.write(output_frame)
    finally:
        # Release the capture and writer even if inference raised.
        cap.release()
        if output_movie is not None:
            output_movie.release()

    # Inspection duration in seconds.
    diff = round((XJ_dict["tail"]-XJ_dict["head"])/fps,2)
    # Timestamps (seconds) of the frames recorded in face_flag.
    s = ', '.join(map(str, _frames_to_seconds(face_flag["frame"], fps)))
    # Timestamps (seconds) of the door-pulling action.
    s_action = ', '.join(map(str, _frames_to_seconds(action_flag["action_frame"], fps)))

    return diff,face_flag,s,action_flag,s_action