import numpy as np
import cv2
import os
import time
from tqdm import tqdm
from ultralytics import YOLO
from ultralytics.yolo.utils.plotting import Annotator

from yolov8_det import analysis_yolov8

# Weights file for the detector used by get_seg_video.
_YOLO_WEIGHTS = "E:/Bank_files/Bank_02/model_files/all_labels.pt"
# Shared YOLO detector, constructed once when this module is imported.
model_yolo = YOLO(_YOLO_WEIGHTS)


# Recursively collect video files below a directory.
def get_video_list(path, video_ext=(".mp4", ".avi")):
    """Return the paths of all video files found under *path*, recursively.

    Extensions are compared case-insensitively, so ``clip.MP4`` and
    ``clip.Avi`` are found as well as ``clip.mp4``.

    Args:
        path: Root directory to walk. A path that does not exist yields ``[]``
            (``os.walk`` produces no entries for it by default).
        video_ext: Iterable of accepted extensions, each including the leading
            dot. Case does not matter.

    Returns:
        List of ``os.path.join(dirpath, filename)`` paths in ``os.walk`` order.
        That order is not guaranteed to be sorted; callers needing a stable
        order must sort the result.
    """
    wanted = {ext.lower() for ext in video_ext}
    return [
        os.path.join(dirpath, filename)
        for dirpath, _subdirs, filenames in os.walk(path)
        for filename in filenames
        if os.path.splitext(filename)[1].lower() in wanted
    ]


# Copy a frame range of a video, cropped to a box, into an open VideoWriter.
def save_seg_video(video_name, frameToStart, frametoStop, videoWriter, bbox):
    """Write frames ``frameToStart < n <= frametoStop`` of *video_name*, cropped to *bbox*.

    Frame numbers are 1-based and counted from the start of the file (the
    video is re-opened here). Frame ``frameToStart`` itself is therefore not
    written; the first written frame is ``frameToStart + 1``.

    Args:
        video_name: Path of the source video.
        frameToStart: Exclusive lower bound on the frame number.
        frametoStop: Inclusive upper bound on the frame number; may be a float.
        videoWriter: Open ``cv2.VideoWriter`` that receives the crops. It is NOT
            released here; the caller owns it.
        bbox: ``(x1, y1, x2, y2)`` crop box in pixels. Coordinates are truncated
            with ``int()`` and negative values are clamped to 0.
    """
    # Clamp negatives: a negative slice index would silently crop from the
    # wrong end of the frame.
    x1, y1, x2, y2 = (max(int(bbox[i]), 0) for i in range(4))

    cap = cv2.VideoCapture(video_name)
    count = 0
    try:
        while count < frametoStop:
            count += 1
            if count <= frameToStart:
                # Before the segment: only advance, without retrieving the
                # image (cheaper than read()).
                if not cap.grab():
                    break
                continue

            success, frame = cap.read()
            if not success:
                break
            print('correct= ', count)

            # Crop the frame to the box: rows are y, columns are x.
            frame_target = frame[y1:y2, x1:x2]
            # Nothing to write for a degenerate (empty) box.
            if frame_target.size:
                videoWriter.write(frame_target)
    finally:
        cap.release()

    print('end')



# Cut one fixed-length, detection-cropped clip out of each input video.
def get_seg_video(video_file, video_save_path, dertTime):
    """For each input video, save a *dertTime*-second clip cropped to a lone detection.

    Args:
        video_file: Either a single video path, or a directory that is searched
            recursively with ``get_video_list``.
        video_save_path: Output directory; created if it does not exist.
        dertTime: Length of each saved clip, in seconds.
    """
    print("frame image save path:{}".format(video_save_path))
    os.makedirs(video_save_path, exist_ok=True)

    if os.path.isdir(video_file):
        files = get_video_list(video_file)
    else:
        files = [video_file]

    for video_name in sorted(files):
        print(video_name)
        _seg_one_video(video_name, video_save_path, dertTime)


def _seg_one_video(video_name, video_save_path, dertTime):
    """Find the first frame with exactly one detection and save a clip starting there.

    Frames are scanned in order (numbered from 1) and run through
    ``analysis_yolov8``. At the first frame with exactly one detection, the
    following ``dertTime * fps`` frames are written, cropped to that
    detection's box, to ``<video_save_path>/<basename>_<frame>.avi`` (XVID).
    Only the first such frame in each video is used.

    Args:
        video_name: Path of the video to scan.
        video_save_path: Existing output directory.
        dertTime: Clip length in seconds.
    """
    video_basename = os.path.splitext(os.path.basename(video_name))[0]

    cap = cv2.VideoCapture(video_name)
    try:
        if not cap.isOpened():
            print("cannot open video: {}".format(video_name))
            return

        fps = cap.get(cv2.CAP_PROP_FPS)
        # Number of frames in one clip.
        dertF = dertTime * fps

        count_fps = 0
        success, frame = cap.read()
        while success:
            count_fps += 1

            # Run the detector on this frame.
            results_img = analysis_yolov8(frame=frame,
                                          model_coco=model_yolo,
                                          confidence=0.1)

            # Intended: exactly one person detected. Presumably results_img is a
            # list of one-entry dicts {label: (x1, y1, x2, y2)}; TODO confirm
            # against analysis_yolov8.
            if len(results_img) == 1:
                raw_box = list(results_img[0].values())[0]

                # Integer, in-frame box. The writer size below and the crop done
                # in save_seg_video must agree exactly, or the writer drops frames.
                frame_h, frame_w = frame.shape[:2]
                x1, x2 = (min(max(int(raw_box[i]), 0), frame_w) for i in (0, 2))
                y1, y2 = (min(max(int(raw_box[i]), 0), frame_h) for i in (1, 3))
                bbox = [x1, y1, x2, y2]
                size = (x2 - x1, y2 - y1)

                if size[0] > 0 and size[1] > 0:
                    write_fps = count_fps
                    stop_fps = write_fps + dertF
                    video_name_save = os.path.join(
                        video_save_path,
                        video_basename + '_' + str(write_fps) + '.avi')

                    videoWriter = cv2.VideoWriter(
                        video_name_save,
                        cv2.VideoWriter_fourcc('X', 'V', 'I', 'D'),
                        fps, size)
                    try:
                        save_seg_video(video_name, write_fps, stop_fps, videoWriter, bbox)
                    finally:
                        # Releasing finalizes the output container.
                        videoWriter.release()
                    break

            success, frame = cap.read()
    finally:
        cap.release()


if __name__ == '__main__':
    # Clip length in seconds, passed to get_seg_video as dertTime.
    clip_seconds = 5

    source_video = "E:/Bank_files/Bank_02/dataset/vlc_0711/0711-1.mp4"
    output_dir = 'videos_codes_2'

    get_seg_video(
        video_file=source_video,
        video_save_path=output_dir,
        dertTime=clip_seconds,
    )