import mediapipe as mp
import math
from typing import List, Mapping, Optional, Tuple, Union
import cv2
import dataclasses
import numpy as np
from mediapipe.framework.formats import landmark_pb2
import os

mp_holistic = mp.solutions.holistic

_PRESENCE_THRESHOLD = 0.5
_VISIBILITY_THRESHOLD = 0.5
_BGR_CHANNELS = 3

WHITE_COLOR = (224, 224, 224)
BLACK_COLOR = (0, 0, 0)
RED_COLOR = (0, 0, 255)
GREEN_COLOR = (0, 128, 0)
BLUE_COLOR = (255, 0, 0)


@dataclasses.dataclass
class DrawingSpec:
    """
        定义标注框的线性、色彩
    """
    # Color for drawing the annotation. Default to the white color.
    color: Tuple[int, int, int] = GREEN_COLOR
    # Thickness for drawing the annotation. Default to 2 pixels.
    thickness: int = 2
    # Circle radius. Default to 2 pixels.
    circle_radius: int = 2


def _normalized_to_pixel_coordinates(
        normalized_x: float, normalized_y: float, image_width: int,
        image_height: int) -> Union[None, Tuple[int, int]]:
    """
        将关键点坐标转化为图像像素坐标
    """

    # Checks if the float value is between 0 and 1.
    def is_valid_normalized_value(value: float) -> bool:
        return (value > 0 or math.isclose(0, value)) and (value < 1 or
                                                          math.isclose(1, value))

    if not (is_valid_normalized_value(normalized_x) and
            is_valid_normalized_value(normalized_y)):
        # TODO: Draw coordinates even if it's outside of the image bounds.
        return None
    x_px = min(math.floor(normalized_x * image_width), image_width - 1)
    y_px = min(math.floor(normalized_y * image_height), image_height - 1)
    # return print("转化的真实坐标:",x_px, y_px)
    return x_px, y_px


def draw_landmarks(
        image: np.ndarray,
        landmark_list: landmark_pb2.NormalizedLandmarkList,
        connections: Optional[List[Tuple[int, int]]] = None,
        landmark_drawing_spec: Union[DrawingSpec, Mapping[int, DrawingSpec]] = DrawingSpec(color=RED_COLOR),
        connection_drawing_spec: Union[DrawingSpec, Mapping[Tuple[int, int], DrawingSpec]] = DrawingSpec()):
    """
        主要是绘制关键点的连接图
        image:输入的数据
        landmark_list:关键点列表
        connections:连接点
    """
    if not landmark_list:
        return
    if image.shape[2] != _BGR_CHANNELS:
        raise ValueError('Input image must contain three channel bgr data.')
    image_rows, image_cols, _ = image.shape

    idx_to_coordinates = {}
    for idx, landmark in enumerate(landmark_list.landmark):
        if ((landmark.HasField('visibility') and
             landmark.visibility < _VISIBILITY_THRESHOLD) or
                (landmark.HasField('presence') and
                 landmark.presence < _PRESENCE_THRESHOLD)):
            continue
        landmark_px = _normalized_to_pixel_coordinates(landmark.x, landmark.y,  # 将归一化坐标值转换为图像坐标值
                                                       image_cols, image_rows)
        # print('图像像素坐标:',landmark_px)
        if landmark_px:
            idx_to_coordinates[idx] = landmark_px
            # print("这是什么:",idx_to_coordinates[idx])
    dot_list = []
    if connections:
        # num_landmarks = len(landmark_list.landmark)
        # connections:keypoint索引元组的列表,用于指定如何在图形中连接地标。
        # Draws the connections if the start and end landmarks are both visible.

        starts = []
        ends = []
        for connection in connections:
            # print("怎样连接的:",connection[0],connection[1])

            start_idx = connection[0]
            end_idx = connection[1]

            starts.append(start_idx)
            ends.append(end_idx)
            """取消注释部分可以绘制关键点连接的图像"""
            # if not (0 <= start_idx < num_landmarks and 0 <= end_idx < num_landmarks):
            #     raise ValueError(f'Landmark index is out of range. Invalid connection '
            #                      f'from landmark #{start_idx} to landmark #{end_idx}.')
            # if start_idx in idx_to_coordinates and end_idx in idx_to_coordinates:
            #     drawing_spec = connection_drawing_spec[connection] if isinstance(
            #         connection_drawing_spec, Mapping) else connection_drawing_spec
            #     cv2.line(image, idx_to_coordinates[start_idx],
            #              idx_to_coordinates[end_idx], drawing_spec.color,
            #              drawing_spec.thickness)

        # print("头节点:",start_list)
        # print("尾结点:",end_list)
        for dot in ends:
            if dot in list(idx_to_coordinates.keys()):
                # print((idx_to_coordinates.keys()))
                dot_list.append(idx_to_coordinates[dot])

    # if landmark_drawing_spec:
    #     for idx, landmark_px in idx_to_coordinates.items():
    #         drawing_spec = landmark_drawing_spec[idx] if isinstance(
    #             landmark_drawing_spec, Mapping) else landmark_drawing_spec
    #         # White circle border
    #         circle_border_radius = max(drawing_spec.circle_radius + 1,
    #                                    int(drawing_spec.circle_radius * 1.2))
    #         cv2.circle(image, landmark_px, circle_border_radius, WHITE_COLOR,
    #                    drawing_spec.thickness)
    #         # Fill color into the circle
    #         cv2.circle(image, landmark_px, drawing_spec.circle_radius,
    #                    drawing_spec.color, drawing_spec.thickness)

    return dot_list


class mediapipe_detect:

    def mediapipe_detection(self, image, model):
        """
            mediapipe检测模块
            image:输入数据集
            model:调用mediapipe模型检测动作的模块
        """
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # COLOR CONVERSION BGR 2 RGB
        image.flags.writeable = False  # Image is no longer writeable
        results = model.process(image)  # 利用模块预测动作并输出坐标
        # image.flags.writeable = True  # Image is now writeable
        # image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)  # COLOR COVERSION RGB 2 BGR
        return image, results

    def Drawing_bbox(self, result, bias):

        '''
            根据关键点坐标,获取最大外接矩形的坐标点
            result:关键点坐标
            b:修正值,增大或者减小矩形框
        '''

        result = np.array(result)
        b = bias
        if result.any():
            rect = cv2.boundingRect(result)  # 返回值, 左上角的坐标[x,y, w,h]

            bbox = [[rect[0] - b, rect[1] - b], [rect[0] + rect[2] + b, rect[1] - b],
                    [rect[0] - b, rect[1] + rect[3] + b], [rect[0] + rect[2] + b, rect[1] + rect[3] + b]]  # 四个角的坐标

            return bbox

    def get_bbox(self, image, results, face_b, left_hand_b, right_hand_b, label):

        '''
            主要是根据关键点坐标,绘制矩形框
            images: 待检测数据
            results: mediapipe检测结果
        '''

        image.flags.writeable = True
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        h, w, g = image.shape
        # print("h:",h,"w:",w,"g:",g)

        """获取头部、手部关键点"""
        face_location = draw_landmarks(
            image,
            results.face_landmarks,
            mp_holistic.FACEMESH_CONTOURS,
            # DrawingSpec(color=(80, 110, 10), thickness=1, circle_radius=1),
            # DrawingSpec(color=(80, 256, 121), thickness=1, circle_radius=1)
        )

        right_hand_location = draw_landmarks(
            image,
            results.right_hand_landmarks,
            mp_holistic.HAND_CONNECTIONS,
            # DrawingSpec(color=(121, 22, 76), thickness=2, circle_radius=4),
            # DrawingSpec(color=(121, 44, 250), thickness=2, circle_radius=2)
        )

        left_hand_location = draw_landmarks(
            image,
            results.left_hand_landmarks,
            mp_holistic.HAND_CONNECTIONS,
            # DrawingSpec(color=(245, 117, 66), thickness=2, circle_radius=4),
            # DrawingSpec(color=(245, 66, 230), thickness=2, circle_radius=2)
        )

        """根据关键点的坐标绘制最大外接矩形"""
        fl_bbox = self.Drawing_bbox(face_location, face_b)
        lh_bbox = self.Drawing_bbox(left_hand_location, left_hand_b)
        rh_bbox = self.Drawing_bbox(right_hand_location, right_hand_b)

        if label == "Nodding" or label == "sleep!" or label == "not playing phone!":
            lh_bbox = None
            rh_bbox = None
        elif label == "not sleep!" or label == "playing phone!":
            fl_bbox = None

        """调整头部检测框的大小"""
        if fl_bbox is not None:
            """对头部动作检测框微调"""
            fl_bbox[3][0] = fl_bbox[3][0] - 15
            fl_bbox[3][1] = fl_bbox[3][1] - 30
            fl_bbox[0][0] = fl_bbox[0][0] + 30
            fl_bbox[0][1] = fl_bbox[0][1] + 5
            # print(fl_bbox)
            for i in range(0, 4):
                for j in range(0, 2):
                    if fl_bbox[i][j] < 0:
                        fl_bbox[i][j] = 0
                    elif fl_bbox[i][0] > w:
                        fl_bbox[i][0] = w
                    elif fl_bbox[i][1] > h:
                        fl_bbox[i][1] = h
                    else:
                        pass
            cv2.rectangle(image, fl_bbox[0], fl_bbox[3], DrawingSpec.color, DrawingSpec.thickness)

        if lh_bbox is not None:
            for i in range(0, 4):
                for j in range(0, 2):
                    if lh_bbox[i][j] < 0:
                        lh_bbox[i][j] = 0
                    elif lh_bbox[i][0] > w:
                        lh_bbox[i][0] = w
                    elif lh_bbox[i][1] > h:
                        lh_bbox[i][1] = h
                    else:
                        pass
            cv2.rectangle(image, lh_bbox[0], lh_bbox[3], DrawingSpec.color, DrawingSpec.thickness)

        if rh_bbox is not None:
            for i in range(0, 4):
                for j in range(0, 2):
                    if rh_bbox[i][j] < 0:
                        rh_bbox[i][j] = 0
                    elif rh_bbox[i][0] > w:
                        rh_bbox[i][0] = w
                    elif rh_bbox[i][1] > h:
                        rh_bbox[i][1] = h
                    else:
                        pass
            cv2.rectangle(image, rh_bbox[0], rh_bbox[3], DrawingSpec.color, DrawingSpec.thickness)

        res = {'face_bbox': fl_bbox, 'hand_bbox': [lh_bbox, rh_bbox]}

        return image, res


def main(input_path, output_path, face_b, left_hand_b, right_hand_b):
    cap = cv2.VideoCapture(input_path)
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    codec = cv2.VideoWriter_fourcc(*'XVID')
    video_name = os.path.basename(input_path)
    out = cv2.VideoWriter(output_path + "/" + video_name, codec, fps, (frame_width, frame_height))
    label = ""
    with mp_holistic.Holistic(min_detection_confidence=0.5, min_tracking_confidence=0.5) as holistic:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            image, results = mediapipe_detect().mediapipe_detection(frame, holistic)
            image, res = mediapipe_detect().get_bbox(image, results, face_b, left_hand_b, right_hand_b, label)
            out.write(image)
            cv2.namedWindow("mediapipe_detections", cv2.WINDOW_AUTOSIZE)
            cv2.imshow("mediapipe_detections", image)
            # print(res)
            if cv2.waitKey(10) & 0xFF == ord('q'):
                break

    cap.release()
    out.release()
    cv2.destroyAllWindows()


if __name__ == "__main__":
    # input = 'D:/download/PaddleVideo1/output/output/after_1/test02_0.avi'
    # input = 'D:/download/PaddleVideo1/output/output/after_1/0711-1_0_1.avi'
    # input = 'D:/download/PaddleVideo1/output/output/after_1/0711-3_1400_0.avi'
    # input = "C:/Users/Administrator/Pictures/video_seg_re_hand/test01_3.avi"
    # input = 'C:/Users/Administrator/Pictures/video3.0/sleep/0711-3_7_01_5.avi'
    input = " D:/download/PaddleVideo1/output/output/after_1/0711-1_199_0.avi"
    # input = 'D:/download/PaddleVideo1/output/output/after_1/test05_10750_1.avi'
    output = 'D:/download/PaddleVideo1/output/output/output'
    face_b = 50  # 头部标注框修正值
    left_hand_b = 7  # 左手部分标注框修正值
    right_hand_b = 7  # 右手部分标注框修正值
    main(input, output, face_b, left_hand_b, right_hand_b)