import numpy as np
import cv2
import os
import time

def get_video_list(path):
    """Recursively collect video files found under *path*.

    A file is kept when its extension is exactly ".mp4", ".avi" or ".MP4".
    The comparison is case-sensitive, so e.g. ".AVI" or ".Mp4" are not matched.

    Args:
        path: root directory to scan (a non-existent path yields no results,
            because ``os.walk`` produces nothing for it).

    Returns:
        list[str]: full paths of the matching files, in ``os.walk`` order.
    """
    accepted = (".mp4", ".avi", ".MP4")
    found = []
    for root, _dirs, names in os.walk(path):
        candidates = (os.path.join(root, name) for name in names)
        found.extend(p for p in candidates if os.path.splitext(p)[1] in accepted)
    return found


def save_seg_video(video_name, frameToStart, frametoStop, videoWriter):
    """Copy one frame range of a video into an already-open writer.

    Frames with 0-based index ``frameToStart <= idx < frametoStop`` are read
    from ``video_name`` and passed to ``videoWriter.write``. Reading stops
    early if the video ends before ``frametoStop``.

    Args:
        video_name: path of the source video.
        frameToStart: first frame index (0-based, inclusive) to copy.
        frametoStop: frame index (0-based, exclusive) at which copying stops.
        videoWriter: an open ``cv2.VideoWriter``; the caller owns it and is
            responsible for releasing it.
    """
    print('frametoStop:', frametoStop, 'frameToStart:', frameToStart)

    cap = cv2.VideoCapture(video_name)
    try:
        idx = 0  # index of the frame about to be read
        while idx < frametoStop:
            if idx < frameToStart:
                # grab() advances past the frame without decoding it, which is
                # much cheaper than read() for frames we are going to discard
                ok = cap.grab()
            else:
                ok, frame = cap.read()
                if ok:
                    videoWriter.write(frame)
            if not ok:
                # end of stream (or read error): nothing more to copy
                break
            idx += 1
    finally:
        # release the decoder/file handle even if writing raises
        cap.release()

    print('end')


def get_seg_video(video_file, video_save_path, dertTime):
    """Split one video, or every video under a directory, into fixed-length segments.

    Each segment is written as ``<basename>_<i>.avi`` (XVID codec, the source's
    fps and frame size) into ``video_save_path``, which is created if needed.

    Args:
        video_file: path of a single video file, or a directory that is
            scanned recursively with ``get_video_list``.
        video_save_path: output directory for the segment files.
        dertTime: length of each segment in seconds.
    """
    import math

    # Make sure the output directory exists
    print("frame image save path:{}".format(video_save_path))
    os.makedirs(video_save_path, exist_ok=True)

    if os.path.isdir(video_file):
        files = get_video_list(video_file)
    else:
        files = [video_file]
    files.sort()

    for video_name in files:
        # splitext only strips the final extension, so "a.b.mp4" -> "a.b";
        # split('.')[0] would yield "a" and let different inputs collide.
        video_basename = os.path.splitext(os.path.basename(video_name))[0]

        # Read only the metadata here; save_seg_video opens its own capture.
        cap = cv2.VideoCapture(video_name)
        try:
            fps = int(cap.get(cv2.CAP_PROP_FPS))
            size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                    int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
            total_frame = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        finally:
            cap.release()

        print(fps)

        # Number of frames per segment
        dertF = dertTime * fps
        print('dertF:', dertF)

        if dertF <= 0:
            # fps == 0 typically means the file could not be opened/decoded;
            # dividing by it below would raise ZeroDivisionError.
            print('skip (invalid fps or segment length):', video_name)
            continue

        # Round up so a trailing partial segment is kept, without producing an
        # extra empty segment when total_frame is an exact multiple of dertF.
        n_segments = math.ceil(total_frame / dertF)
        print(n_segments)

        for i in range(n_segments):
            video_name_save = os.path.join(
                video_save_path, video_basename + '_' + str(i) + '.avi')

            start_time = i * dertF
            # For the last segment, reading simply stops at end of stream.
            stop_time = start_time + dertF

            videoWriter = cv2.VideoWriter(
                video_name_save, cv2.VideoWriter_fourcc('X', 'V', 'I', 'D'), fps, size)
            try:
                print(video_name)
                save_seg_video(video_name=video_name, frameToStart=start_time,
                               frametoStop=stop_time, videoWriter=videoWriter)
            finally:
                # release() finalizes the output container; without it the
                # .avi file may be truncated or unplayable.
                videoWriter.release()




if __name__ == '__main__':
    # Length of every output segment, in seconds.
    segment_seconds = 5
    # Input: a single video file, or a directory scanned recursively.
    source_path = "dataset/video"
    # Destination directory for the generated segment files.
    output_dir = 'dataset/video_seg_5s'

    get_seg_video(
        video_file=source_path,
        video_save_path=output_dir,
        dertTime=segment_seconds,
    )