You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

152 lines
5.1 KiB
Python

# Copyright 2020 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe solution drawing utils."""
import math
from typing import List, Mapping, Optional, Tuple, Union
import cv2
import dataclasses
import matplotlib.pyplot as plt
import numpy as np
from mediapipe.framework.formats import detection_pb2
from mediapipe.framework.formats import location_data_pb2
from mediapipe.framework.formats import landmark_pb2
# Landmarks whose `presence` / `visibility` score is below these thresholds
# are skipped by draw_landmarks.
_PRESENCE_THRESHOLD = 0.5
_VISIBILITY_THRESHOLD = 0.5

# Number of channels an input image must have (three-channel BGR).
_BGR_CHANNELS = 3

# Predefined colors as (blue, green, red) tuples, i.e. OpenCV's BGR order.
BLACK_COLOR = (0, 0, 0)
WHITE_COLOR = (224, 224, 224)  # A light grey rather than pure 255 white.
BLUE_COLOR = (255, 0, 0)
GREEN_COLOR = (0, 128, 0)
RED_COLOR = (0, 0, 255)
@dataclasses.dataclass
class DrawingSpec:
    """Style settings for an annotation: color, line thickness, circle radius."""

    # BGR color of the annotation; defaults to WHITE_COLOR.
    color: Tuple[int, int, int] = dataclasses.field(default=WHITE_COLOR)
    # Line thickness in pixels.
    thickness: int = dataclasses.field(default=2)
    # Circle radius in pixels.
    circle_radius: int = dataclasses.field(default=2)
def _normalized_to_pixel_coordinates(
        normalized_x: float, normalized_y: float, image_width: int,
        image_height: int) -> Union[None, Tuple[int, int]]:
    """Converts a normalized (x, y) pair to integer pixel coordinates.

    Args:
      normalized_x: Horizontal position, expected in [0, 1].
      normalized_y: Vertical position, expected in [0, 1].
      image_width: Image width in pixels.
      image_height: Image height in pixels.

    Returns:
      ``(x_px, y_px)`` clamped to ``[0, image_width - 1]`` and
      ``[0, image_height - 1]``, or None if either value lies outside [0, 1]
      (beyond a tiny floating-point tolerance) or is NaN.
    """
    # math.isclose(0, v) with the default abs_tol=0.0 is True only for v == 0,
    # because a relative tolerance is meaningless against zero. The lower
    # bound therefore needs an explicit absolute tolerance. 1e-9 mirrors the
    # default rel_tol that isclose applies at the upper bound of 1.
    tolerance = 1e-9

    def is_valid_normalized_value(value: float) -> bool:
        return ((value > 0 or math.isclose(0, value, abs_tol=tolerance)) and
                (value < 1 or math.isclose(1, value)))

    if not (is_valid_normalized_value(normalized_x) and
            is_valid_normalized_value(normalized_y)):
        # TODO: Draw coordinates even if it's outside of the image bounds.
        return None
    # Clamp at both ends: values accepted by the tolerance but lying just
    # outside [0, 1] would otherwise map to -1 or to width/height.
    x_px = min(max(math.floor(normalized_x * image_width), 0),
               image_width - 1)
    y_px = min(max(math.floor(normalized_y * image_height), 0),
               image_height - 1)
    return x_px, y_px
def draw_landmarks(
        image: np.ndarray,
        landmark_list: landmark_pb2.NormalizedLandmarkList,
        connections: Optional[List[Tuple[int, int]]] = None):
    """Returns pixel coordinates of the end landmarks of ``connections``.

    Despite its name, this function does not draw anything on ``image``. The
    image is only used for its height and width. A landmark is converted to
    pixel coordinates unless either of these holds:
      * its ``visibility`` or ``presence`` field is set and below the module
        threshold, or
      * its normalized x/y lies outside [0, 1].

    Args:
      image: A three channel BGR image represented as numpy ndarray.
      landmark_list: A normalized landmark list proto message.
      connections: A list of landmark index tuples. Only the second element
        (the end landmark index) of each connection is used.

    Returns:
      None if ``landmark_list`` is falsy. Otherwise, a list of ``(x, y)`` pixel
      coordinates, one per connection whose end landmark was converted, in the
      iteration order of ``connections``. Duplicates are kept when several
      connections share an end landmark. The list is empty when
      ``connections`` is None or empty, or when no end landmark passes the
      filters.

    Raises:
      ValueError: If the input image is not a three channel BGR image.
    """
    if not landmark_list:
        return None
    # Check ndim first, so a 2-D (grayscale) array raises the documented
    # ValueError instead of an IndexError on shape[2].
    if image.ndim != 3 or image.shape[2] != _BGR_CHANNELS:
        raise ValueError('Input image must contain three channel bgr data.')
    image_rows, image_cols, _ = image.shape

    # Map each landmark index to its pixel coordinates, for landmarks that
    # pass the visibility/presence filters and lie inside the image.
    idx_to_coordinates = {}
    for idx, landmark in enumerate(landmark_list.landmark):
        if ((landmark.HasField('visibility') and
             landmark.visibility < _VISIBILITY_THRESHOLD) or
                (landmark.HasField('presence') and
                 landmark.presence < _PRESENCE_THRESHOLD)):
            continue
        landmark_px = _normalized_to_pixel_coordinates(
            landmark.x, landmark.y, image_cols, image_rows)
        if landmark_px:
            idx_to_coordinates[idx] = landmark_px

    # Without connections there are no end landmarks to report. Return an
    # explicit empty list rather than None or an unbound-variable error.
    if not connections:
        return []
    # Plain dict membership is O(1) per lookup.
    return [idx_to_coordinates[connection[1]]
            for connection in connections
            if connection[1] in idx_to_coordinates]