You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
152 lines
5.1 KiB
Python
152 lines
5.1 KiB
Python
# Copyright 2020 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe solution drawing utils."""
|
|
|
|
import math
|
|
from typing import List, Mapping, Optional, Tuple, Union
|
|
|
|
import cv2
|
|
import dataclasses
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
|
|
from mediapipe.framework.formats import detection_pb2
|
|
from mediapipe.framework.formats import location_data_pb2
|
|
from mediapipe.framework.formats import landmark_pb2
|
|
|
|
# Landmarks whose optional `presence` / `visibility` field is set below these
# thresholds are skipped when landmark pixel coordinates are collected.
_PRESENCE_THRESHOLD = 0.5
_VISIBILITY_THRESHOLD = 0.5

# Number of channels an input image must have (three channel BGR data).
_BGR_CHANNELS = 3

# Predefined colors. The tuples are in BGR channel order, e.g. red is
# (0, 0, 255) and blue is (255, 0, 0).
WHITE_COLOR = (224, 224, 224)
BLACK_COLOR = (0, 0, 0)
RED_COLOR = (0, 0, 255)
GREEN_COLOR = (0, 128, 0)
BLUE_COLOR = (255, 0, 0)
|
|
|
|
|
|
@dataclasses.dataclass
class DrawingSpec:
    """Settings that describe how an annotation is drawn.

    Attributes:
      color: Color used for drawing the annotation. Defaults to the white
        color.
      thickness: Thickness used for drawing the annotation. Defaults to
        2 pixels.
      circle_radius: Circle radius. Defaults to 2 pixels.
    """

    color: Tuple[int, int, int] = WHITE_COLOR
    thickness: int = 2
    circle_radius: int = 2
|
|
|
|
|
|
def _normalized_to_pixel_coordinates(
        normalized_x: float, normalized_y: float, image_width: int,
        image_height: int) -> Union[None, Tuple[int, int]]:
    """Maps a normalized (x, y) pair onto integer pixel coordinates.

    Args:
      normalized_x: Horizontal position, expected to lie within [0, 1].
      normalized_y: Vertical position, expected to lie within [0, 1].
      image_width: Width of the target image in pixels.
      image_height: Height of the target image in pixels.

    Returns:
      The (x_px, y_px) pixel position, or None when either normalized value
      lies outside the [0, 1] range.
    """
    # Reject the pair as soon as one value falls outside [0, 1]; math.isclose
    # lets values that are only marginally off the bounds through.
    for coordinate in (normalized_x, normalized_y):
        if not (coordinate > 0 or math.isclose(0, coordinate)):
            # TODO: Draw coordinates even if it's outside of the image bounds.
            return None
        if not (coordinate < 1 or math.isclose(1, coordinate)):
            return None

    # A normalized value of 1 would land one past the last pixel, so clamp the
    # result to the largest valid index along each axis.
    last_column = image_width - 1
    last_row = image_height - 1
    column = math.floor(normalized_x * image_width)
    row = math.floor(normalized_y * image_height)
    return min(column, last_column), min(row, last_row)
|
|
|
|
|
|
|
|
def draw_landmarks(
        image: np.ndarray,
        landmark_list: landmark_pb2.NormalizedLandmarkList,
        connections: Optional[List[Tuple[int, int]]] = None
) -> Optional[List[Tuple[int, int]]]:
    """Returns the pixel coordinates of the end landmark of each connection.

    Despite its name, this function does not draw anything: `image` is only
    used for its shape, which is needed to convert the normalized landmark
    coordinates to pixel coordinates.

    Args:
      image: A three channel BGR image represented as numpy ndarray.
      landmark_list: A normalized landmark list proto message whose landmarks
        are converted to pixel coordinates.
      connections: A list of landmark index tuples that specifies how
        landmarks are connected. Only the second (end) index of each tuple is
        looked up.

    Returns:
      A list with one (x_px, y_px) tuple per connection, in the order of
      `connections`, for every connection whose end landmark has pixel
      coordinates. A landmark has no pixel coordinates when its visibility or
      presence is set and below the threshold, or when its normalized
      coordinates fall outside the image. Connections whose end landmark has
      no pixel coordinates (including indices that are not in
      `landmark_list`) are skipped. Returns None if `landmark_list` is falsy
      or if `connections` is None or empty.

    Raises:
      ValueError: If the input image is not three channel BGR.
    """
    if not landmark_list:
        return None
    # Check the rank first so that a 2-D (e.g. grayscale) image raises the
    # documented ValueError instead of an IndexError on `shape[2]`.
    if len(image.shape) != 3 or image.shape[2] != _BGR_CHANNELS:
        raise ValueError('Input image must contain three channel bgr data.')
    image_rows, image_cols, _ = image.shape

    # Map the index of every usable landmark to its pixel coordinates.
    idx_to_coordinates = {}
    for idx, landmark in enumerate(landmark_list.landmark):
        if ((landmark.HasField('visibility') and
             landmark.visibility < _VISIBILITY_THRESHOLD) or
                (landmark.HasField('presence') and
                 landmark.presence < _PRESENCE_THRESHOLD)):
            continue
        landmark_px = _normalized_to_pixel_coordinates(
            landmark.x, landmark.y, image_cols, image_rows)
        if landmark_px:
            idx_to_coordinates[idx] = landmark_px

    if not connections:
        return None

    # Collect the end point of each connection. The dict membership test
    # silently skips end landmarks that were filtered out above.
    return [
        idx_to_coordinates[connection[1]]
        for connection in connections
        if connection[1] in idx_to_coordinates
    ]
|
|
|