diff --git a/tool/mediapipe_detection_image.py b/tool/mediapipe_detection_image.py new file mode 100644 index 0000000..c6ed08d --- /dev/null +++ b/tool/mediapipe_detection_image.py @@ -0,0 +1,264 @@ +import mediapipe as mp +import math +from typing import List, Mapping, Optional, Tuple, Union +import cv2 +import dataclasses +import numpy as np +from mediapipe.framework.formats import landmark_pb2 +import os + +mp_holistic = mp.solutions.holistic + +_PRESENCE_THRESHOLD = 0.5 +_VISIBILITY_THRESHOLD = 0.5 +_BGR_CHANNELS = 3 + +WHITE_COLOR = (224, 224, 224) +BLACK_COLOR = (0, 0, 0) +RED_COLOR = (0, 0, 255) +GREEN_COLOR = (0, 128, 0) +BLUE_COLOR = (255, 0, 0) + + +@dataclasses.dataclass +class DrawingSpec: + """ + 定义标注框的线性、色彩 + """ + # Color for drawing the annotation. Default to the white color. + color: Tuple[int, int, int] = GREEN_COLOR + # Thickness for drawing the annotation. Default to 2 pixels. + thickness: int = 2 + # Circle radius. Default to 2 pixels. + circle_radius: int = 2 + + +def _normalized_to_pixel_coordinates( + normalized_x: float, normalized_y: float, image_width: int, + image_height: int) -> Union[None, Tuple[int, int]]: + """ + 将关键点坐标转化为图像像素坐标 + """ + + # Checks if the float value is between 0 and 1. + def is_valid_normalized_value(value: float) -> bool: + return (value > 0 or math.isclose(0, value)) and (value < 1 or + math.isclose(1, value)) + + if not (is_valid_normalized_value(normalized_x) and + is_valid_normalized_value(normalized_y)): + # TODO: Draw coordinates even if it's outside of the image bounds. + return None + x_px = min(math.floor(normalized_x * image_width), image_width - 1) + y_px = min(math.floor(normalized_y * image_height), image_height - 1) + # return print("转化的真实坐标:",x_px, y_px) + return x_px,y_px + + +def draw_landmarks( + image: np.ndarray, + landmark_list: landmark_pb2.NormalizedLandmarkList, + connections: Optional[List[Tuple[int, int]]] = None, + landmark_drawing_spec: Union[DrawingSpec,Mapping[int, DrawingSpec]] = DrawingSpec(color=RED_COLOR), + connection_drawing_spec: Union[DrawingSpec, Mapping[Tuple[int, int],DrawingSpec]] = DrawingSpec()): + """ + 主要是绘制关键点的连接图 + image:输入的数据 + landmark_list:关键点列表 + connections:连接点 + """ + if not landmark_list: + return + if image.shape[2] != _BGR_CHANNELS: + raise ValueError('Input image must contain three channel bgr data.') + image_rows, image_cols, _ = image.shape + + idx_to_coordinates = {} + for idx, landmark in enumerate(landmark_list.landmark): + if ((landmark.HasField('visibility') and + landmark.visibility < _VISIBILITY_THRESHOLD) or + (landmark.HasField('presence') and + landmark.presence < _PRESENCE_THRESHOLD)): + continue + landmark_px = _normalized_to_pixel_coordinates(landmark.x, landmark.y, #将归一化坐标值转换为图像坐标值 + image_cols, image_rows) + # print('图像像素坐标:',landmark_px) + if landmark_px: + idx_to_coordinates[idx] = landmark_px + # print("这是什么:",idx_to_coordinates[idx]) + dot_list = [] + if connections: + # num_landmarks = len(landmark_list.landmark) + #connections:keypoint索引元组的列表,用于指定如何在图形中连接地标。 + # Draws the connections if the start and end landmarks are both visible. + + starts = [] + ends = [] + for connection in connections: + # print("怎样连接的:",connection[0],connection[1]) + + start_idx = connection[0] + end_idx = connection[1] + + starts.append(start_idx) + ends.append(end_idx) + """取消注释部分可以绘制关键点连接的图像""" + # if not (0 <= start_idx < num_landmarks and 0 <= end_idx < num_landmarks): + # raise ValueError(f'Landmark index is out of range. Invalid connection ' + # f'from landmark #{start_idx} to landmark #{end_idx}.') + # if start_idx in idx_to_coordinates and end_idx in idx_to_coordinates: + # drawing_spec = connection_drawing_spec[connection] if isinstance( + # connection_drawing_spec, Mapping) else connection_drawing_spec + # cv2.line(image, idx_to_coordinates[start_idx], + # idx_to_coordinates[end_idx], drawing_spec.color, + # drawing_spec.thickness) + + # print("头节点:",start_list) + # print("尾结点:",end_list) + for dot in ends: + if dot in list(idx_to_coordinates.keys()): + # print((idx_to_coordinates.keys())) + dot_list.append(idx_to_coordinates[dot]) + + # if landmark_drawing_spec: + # for idx, landmark_px in idx_to_coordinates.items(): + # drawing_spec = landmark_drawing_spec[idx] if isinstance( + # landmark_drawing_spec, Mapping) else landmark_drawing_spec + # # White circle border + # circle_border_radius = max(drawing_spec.circle_radius + 1, + # int(drawing_spec.circle_radius * 1.2)) + # cv2.circle(image, landmark_px, circle_border_radius, WHITE_COLOR, + # drawing_spec.thickness) + # # Fill color into the circle + # cv2.circle(image, landmark_px, drawing_spec.circle_radius, + # drawing_spec.color, drawing_spec.thickness) + + return dot_list + + +class mediapipe_detect: + + def mediapipe_detection(self,image, model): + """ + mediapipe检测模块 + image:输入数据集 + model:调用mediapipe模型检测动作的模块 + """ + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # COLOR CONVERSION BGR 2 RGB + image.flags.writeable = False # Image is no longer writeable + results = model.process(image) # 利用模块预测动作并输出坐标 + # image.flags.writeable = True # Image is now writeable + # image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) # COLOR COVERSION RGB 2 BGR + return image, results + + def Drawing_bbox(self,result,bias): + + ''' + 根据关键点坐标,获取最大外接矩形的坐标点 + result:关键点坐标 + b:修正值,增大或者减小矩形框 + ''' + + result = np.array(result) + b = bias + if result.any(): + + rect = cv2.boundingRect(result) #返回值, 左上角的坐标[x,y, w,h] + + bbox = [[rect[0] - b, rect[1] - b], [rect[0] + rect[2] + b, rect[1] - b], + [rect[0] - b, rect[1] + rect[3] + b], [rect[0] + rect[2] + b, rect[1] + rect[3] + b]] #四个角的坐标 + # print(bbox) + return bbox + def get_bbox(self,image, results,face_b,left_hand_b,right_hand_b): + + ''' + 主要是根据关键点坐标,绘制矩形框 + images: 待检测数据 + results: mediapipe检测结果 + ''' + + image.flags.writeable = True + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + + """获取头部、手部关键点""" + face_location = draw_landmarks( + image, + results.face_landmarks, + mp_holistic.FACEMESH_CONTOURS, + # DrawingSpec(color=(80, 110, 10), thickness=1, circle_radius=1), + # DrawingSpec(color=(80, 256, 121), thickness=1, circle_radius=1) + ) + + right_hand_location = draw_landmarks( + image, + results.right_hand_landmarks, + mp_holistic.HAND_CONNECTIONS, + # DrawingSpec(color=(121, 22, 76), thickness=2, circle_radius=4), + # DrawingSpec(color=(121, 44, 250), thickness=2, circle_radius=2) + ) + + left_hand_location = draw_landmarks( + image, + results.left_hand_landmarks, + mp_holistic.HAND_CONNECTIONS, + # DrawingSpec(color=(245, 117, 66), thickness=2, circle_radius=4), + # DrawingSpec(color=(245, 66, 230), thickness=2, circle_radius=2) + ) + + """根据关键点的坐标绘制最大外接矩形""" + fl_bbox = self.Drawing_bbox(face_location,face_b) + lh_bbox = self.Drawing_bbox(left_hand_location,left_hand_b) + rh_bbox = self.Drawing_bbox(right_hand_location,right_hand_b) + # print("fl",fl_bbox) + # print("lh", lh_bbox) + # print("rh",rh_bbox) + """调整头部检测框的大小""" + if fl_bbox is not None: + """对头部动作检测框微调""" + fl_bbox[3][0] = fl_bbox[3][0] - 15 + fl_bbox[3][1] = fl_bbox[3][1] - 30 + fl_bbox[0][0] = fl_bbox[0][0] + 30 + fl_bbox[0][1] = fl_bbox[0][1] + 5 + cv2.rectangle(image, fl_bbox[0], fl_bbox[3],DrawingSpec.color, DrawingSpec.thickness) + if lh_bbox is not None: + cv2.rectangle(image, lh_bbox[0], lh_bbox[3],DrawingSpec.color, DrawingSpec.thickness) + if rh_bbox is not None: + cv2.rectangle(image, rh_bbox[0], rh_bbox[3],DrawingSpec.color, DrawingSpec.thickness) + + res = {'face_bbox': fl_bbox, 'hand_bbox': [lh_bbox,rh_bbox]} + + return image,res + + +def main(input_path,face_b,left_hand_b,right_hand_b): + """" + 图片检测模块 + input_path:检测图片路径 + face_b:头部标注框修正值,可根据输入值调整头部标注框的大小 + left_hand_b:左手检测标注框... + right_hand_b:右手检测标注框... + return:返回是关键点坐标框的坐标值 + """ + image = cv2.imread(input_path) + # image_name = os.path.basename(input_path) + + with mp_holistic.Holistic(model_complexity=2,min_detection_confidence=0.5, min_tracking_confidence=0.5) as holistic: + image,res = mediapipe_detect().mediapipe_detection(image,holistic) + image ,res = mediapipe_detect().get_bbox(image,res,face_b,left_hand_b,right_hand_b) + # cv2.imwrite(output+"/"+image_name,image) #取消注释可以保存处理后的图片 + cv2.namedWindow("mediapipe_detections", cv2.WINDOW_AUTOSIZE) + cv2.imshow("mediapipe_detections", image) + cv2.waitKey() + cv2.destroyAllWindows() + # print("标注框坐标值",res) + return res + + +if __name__ == "__main__": + input = 'D:/inference/mediapipe/mediapipe/python/video/test_picture/test11.jpg' + # output = 'D:/inference/mediapipe/mediapipe/python/video/output' + face_b = 50 #头部标注框修正值 + left_hand_b = 7 #头部标注框修正值 + right_hand_b = 7 #头部标注框修正值 + main(input,face_b,left_hand_b,right_hand_b) +