# Copyright 2020 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe solution drawing utils."""

import dataclasses
import math
from typing import List, Mapping, Optional, Tuple, Union

import cv2
import matplotlib.pyplot as plt
import numpy as np

from mediapipe.framework.formats import detection_pb2
from mediapipe.framework.formats import location_data_pb2
from mediapipe.framework.formats import landmark_pb2

# Landmarks whose `presence` / `visibility` score is set and falls below these
# thresholds are treated as not visible and skipped.
_PRESENCE_THRESHOLD = 0.5
_VISIBILITY_THRESHOLD = 0.5
_BGR_CHANNELS = 3

# Colors are (B, G, R) tuples, i.e. OpenCV channel order.
WHITE_COLOR = (224, 224, 224)
BLACK_COLOR = (0, 0, 0)
RED_COLOR = (0, 0, 255)
GREEN_COLOR = (0, 128, 0)
BLUE_COLOR = (255, 0, 0)


@dataclasses.dataclass
class DrawingSpec:
  """Drawing settings (color, thickness, circle radius) for an annotation."""
  # Color for drawing the annotation. Default to the white color.
  color: Tuple[int, int, int] = WHITE_COLOR
  # Thickness for drawing the annotation. Default to 2 pixels.
  thickness: int = 2
  # Circle radius. Default to 2 pixels.
  circle_radius: int = 2


def _normalized_to_pixel_coordinates(
    normalized_x: float, normalized_y: float, image_width: int,
    image_height: int) -> Union[None, Tuple[int, int]]:
  """Converts a normalized (x, y) pair to integer pixel coordinates.

  Args:
    normalized_x: Horizontal position, expected to lie in [0, 1].
    normalized_y: Vertical position, expected to lie in [0, 1].
    image_width: Image width in pixels.
    image_height: Image height in pixels.

  Returns:
    An (x_px, y_px) tuple, clamped so it never exceeds the last pixel
    column/row, or None if either coordinate lies outside [0, 1].
  """

  # Checks if the float value is between 0 and 1.
  def is_valid_normalized_value(value: float) -> bool:
    return (value > 0 or math.isclose(0, value)) and (value < 1 or
                                                      math.isclose(1, value))

  if not (is_valid_normalized_value(normalized_x) and
          is_valid_normalized_value(normalized_y)):
    # TODO: Draw coordinates even if it's outside of the image bounds.
    return None
  # A value of 1.0 would map to image_width / image_height, one past the last
  # valid index, so clamp to the last pixel.
  x_px = min(math.floor(normalized_x * image_width), image_width - 1)
  y_px = min(math.floor(normalized_y * image_height), image_height - 1)
  return x_px, y_px


def draw_landmarks(
    image: np.ndarray,
    landmark_list: landmark_pb2.NormalizedLandmarkList,
    connections: Optional[List[Tuple[int, int]]] = None):
  """Returns the pixel coordinates of the visible end landmarks of connections.

  Despite its name, this function does not draw anything: `image` is only
  validated and used for its height and width. For every connection in
  `connections`, in order, the pixel coordinate of its end landmark
  (`connection[1]`) is appended to the result if that landmark is visible and
  lies inside the image. A landmark that ends several connections appears
  once per connection.

  Args:
    image: A three channel BGR image represented as numpy ndarray.
    landmark_list: A normalized landmark list proto message. If falsy (e.g.
      None), an empty list is returned.
    connections: A list of landmark index tuples that specifies how landmarks
      are connected. Only the second (end) index of each tuple is used.

  Returns:
    A list of (x_px, y_px) tuples. Empty when `landmark_list` or
    `connections` is empty/None. Connections whose end landmark is not
    visible, falls outside the image, or has an out-of-range index are
    skipped rather than reported as errors.

  Raises:
    ValueError: If the input image is not three channel BGR.
  """
  if not landmark_list:
    return []
  # Check ndim first so a 2-D (grayscale) image raises the documented
  # ValueError instead of an IndexError from image.shape[2].
  if image.ndim != 3 or image.shape[2] != _BGR_CHANNELS:
    raise ValueError('Input image must contain three channel bgr data.')
  if not connections:
    return []
  image_rows, image_cols, _ = image.shape

  # Map every visible, in-bounds landmark index to its pixel coordinate.
  idx_to_coordinates = {}
  for idx, landmark in enumerate(landmark_list.landmark):
    if ((landmark.HasField('visibility') and
         landmark.visibility < _VISIBILITY_THRESHOLD) or
        (landmark.HasField('presence') and
         landmark.presence < _PRESENCE_THRESHOLD)):
      continue
    landmark_px = _normalized_to_pixel_coordinates(landmark.x, landmark.y,
                                                   image_cols, image_rows)
    if landmark_px is not None:
      idx_to_coordinates[idx] = landmark_px

  # Only the end landmark of each connection is checked and reported; the
  # start landmark's visibility does not matter here.
  return [
      idx_to_coordinates[connection[1]]
      for connection in connections
      if connection[1] in idx_to_coordinates
  ]