mediapipe检测图片关键点

V0.1.0
jiangxt 2 years ago
parent 465aa26938
commit f8ca37bcb6

@@ -0,0 +1,264 @@
import mediapipe as mp
import math
from typing import List, Mapping, Optional, Tuple, Union
import cv2
import dataclasses
import numpy as np
from mediapipe.framework.formats import landmark_pb2
import os
# Shorthand for MediaPipe's holistic solution module; this file uses its
# Holistic model class and its FACEMESH_CONTOURS / HAND_CONNECTIONS sets.
mp_holistic = mp.solutions.holistic
# draw_landmarks skips any landmark whose `presence` / `visibility` score is
# set and falls below the matching threshold.
_PRESENCE_THRESHOLD = 0.5
_VISIBILITY_THRESHOLD = 0.5
# Channel count draw_landmarks requires of its input image.
_BGR_CHANNELS = 3
# Colour tuples in BGR channel order (note RED_COLOR is (0, 0, 255)).
WHITE_COLOR = (224, 224, 224)
BLACK_COLOR = (0, 0, 0)
RED_COLOR = (0, 0, 255)
GREEN_COLOR = (0, 128, 0)
BLUE_COLOR = (255, 0, 0)
@dataclasses.dataclass
class DrawingSpec:
    """Colour, line thickness and circle radius used when drawing annotations."""

    # Annotation colour as a three-int tuple; the default is GREEN_COLOR.
    color: Tuple[int, int, int] = GREEN_COLOR
    # Line thickness of the annotation; defaults to 2 pixels.
    thickness: int = 2
    # Radius of landmark circles; defaults to 2 pixels.
    circle_radius: int = 2
def _normalized_to_pixel_coordinates(
normalized_x: float, normalized_y: float, image_width: int,
image_height: int) -> Union[None, Tuple[int, int]]:
"""
将关键点坐标转化为图像像素坐标
"""
# Checks if the float value is between 0 and 1.
def is_valid_normalized_value(value: float) -> bool:
return (value > 0 or math.isclose(0, value)) and (value < 1 or
math.isclose(1, value))
if not (is_valid_normalized_value(normalized_x) and
is_valid_normalized_value(normalized_y)):
# TODO: Draw coordinates even if it's outside of the image bounds.
return None
x_px = min(math.floor(normalized_x * image_width), image_width - 1)
y_px = min(math.floor(normalized_y * image_height), image_height - 1)
# return print("转化的真实坐标:",x_px, y_px)
return x_px,y_px
def draw_landmarks(
        image: np.ndarray,
        landmark_list: landmark_pb2.NormalizedLandmarkList,
        connections: Optional[List[Tuple[int, int]]] = None,
        landmark_drawing_spec: Union[DrawingSpec, Mapping[int, DrawingSpec]] = DrawingSpec(color=RED_COLOR),
        connection_drawing_spec: Union[DrawingSpec, Mapping[Tuple[int, int], DrawingSpec]] = DrawingSpec()):
    """Collect the pixel coordinates of landmarks that terminate a connection.

    Despite its name this function does not draw anything: it only converts
    the landmarks to pixel coordinates and returns those that are the end
    point of a connection.

    Args:
        image: Three-channel image; only its shape is read.
        landmark_list: Landmark list with normalized coordinates. A falsy
            value (e.g. ``None`` when nothing was detected) short-circuits.
        connections: Iterable of ``(start_idx, end_idx)`` landmark index pairs.
        landmark_drawing_spec: Unused; kept so existing callers still work.
        connection_drawing_spec: Unused; kept so existing callers still work.

    Returns:
        ``None`` when ``landmark_list`` is falsy; otherwise a list of
        ``(x_px, y_px)`` tuples, one per connection whose end landmark is
        visible, present and inside the image (empty when ``connections`` is
        falsy).

    Raises:
        ValueError: If ``image`` is not a three-channel image.
    """
    if not landmark_list:
        return None
    # Check the rank first so a 2-D (grayscale) image produces the intended
    # ValueError instead of an IndexError on shape[2].
    if len(image.shape) != 3 or image.shape[2] != _BGR_CHANNELS:
        raise ValueError('Input image must contain three channel bgr data.')
    image_rows, image_cols = image.shape[:2]

    # Map landmark index -> pixel coordinates for every usable landmark.
    idx_to_coordinates = {}
    for idx, landmark in enumerate(landmark_list.landmark):
        if ((landmark.HasField('visibility') and
             landmark.visibility < _VISIBILITY_THRESHOLD) or
                (landmark.HasField('presence') and
                 landmark.presence < _PRESENCE_THRESHOLD)):
            continue
        landmark_px = _normalized_to_pixel_coordinates(
            landmark.x, landmark.y, image_cols, image_rows)
        if landmark_px:
            idx_to_coordinates[idx] = landmark_px

    if not connections:
        return []

    # Keep one coordinate per connection end point, in connection order;
    # the dict lookup avoids rebuilding a key list for every connection.
    return [
        idx_to_coordinates[connection[1]]
        for connection in connections
        if connection[1] in idx_to_coordinates
    ]
class mediapipe_detect:
    """Run a MediaPipe model on an image and derive face / hand boxes."""

    def mediapipe_detection(self, image, model):
        """Run ``model`` on a BGR image.

        Args:
            image: BGR image array.
            model: Object exposing ``process(image)``; ``main`` passes a
                MediaPipe Holistic instance.

        Returns:
            ``(image, results)`` where ``image`` is the RGB conversion of the
            input, marked read-only, and ``results`` is whatever
            ``model.process`` returned.
        """
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # The image is marked read-only before being handed to the model.
        image.flags.writeable = False
        results = model.process(image)
        return image, results

    def Drawing_bbox(self, result, bias):
        """Compute the padded bounding rectangle of a set of pixel points.

        Args:
            result: Sequence of ``(x, y)`` integer pixel coordinates, or
                ``None`` / empty when nothing was detected.
            bias: Padding added on every side of the rectangle (a negative
                value shrinks it).

        Returns:
            ``None`` when there are no points; otherwise the four corners
            ``[top_left, top_right, bottom_left, bottom_right]``, each an
            ``[x, y]`` list.
        """
        if result is None:
            return None
        # cv2.boundingRect only accepts int32 / float32 point sets, while
        # NumPy's default integer type is platform dependent.
        points = np.asarray(result, dtype=np.int32)
        # Test for emptiness by size: `.any()` would also reject a valid
        # point set whose coordinates are all zero.
        if points.size == 0:
            return None
        x, y, w, h = cv2.boundingRect(points)
        left, top = x - bias, y - bias
        right, bottom = x + w + bias, y + h + bias
        return [[left, top], [right, top], [left, bottom], [right, bottom]]

    def get_bbox(self, image, results, face_b, left_hand_b, right_hand_b):
        """Draw face and hand bounding boxes and return their corners.

        Args:
            image: RGB image as returned by ``mediapipe_detection``.
            results: Detection results exposing ``face_landmarks``,
                ``left_hand_landmarks`` and ``right_hand_landmarks``.
            face_b: Padding for the face box.
            left_hand_b: Padding for the left-hand box.
            right_hand_b: Padding for the right-hand box.

        Returns:
            ``(image, res)`` where ``image`` is the BGR image with the boxes
            drawn and ``res`` is ``{'face_bbox': ..., 'hand_bbox': [left,
            right]}``; an entry is ``None`` when that part was not detected.
        """
        image.flags.writeable = True
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

        # Pixel coordinates of the face-contour and hand key points.
        face_location = draw_landmarks(
            image, results.face_landmarks, mp_holistic.FACEMESH_CONTOURS)
        right_hand_location = draw_landmarks(
            image, results.right_hand_landmarks, mp_holistic.HAND_CONNECTIONS)
        left_hand_location = draw_landmarks(
            image, results.left_hand_landmarks, mp_holistic.HAND_CONNECTIONS)

        # Padded bounding rectangle for each body part.
        fl_bbox = self.Drawing_bbox(face_location, face_b)
        lh_bbox = self.Drawing_bbox(left_hand_location, left_hand_b)
        rh_bbox = self.Drawing_bbox(right_hand_location, right_hand_b)

        if fl_bbox is not None:
            # Fine-tune the face box: pull the bottom-right corner in by
            # (15, 30) and push the top-left corner in by (30, 5).
            fl_bbox[3][0] -= 15
            fl_bbox[3][1] -= 30
            fl_bbox[0][0] += 30
            fl_bbox[0][1] += 5

        spec = DrawingSpec()
        for bbox in (fl_bbox, lh_bbox, rh_bbox):
            if bbox is not None:
                # Corners 0 and 3 are top-left and bottom-right.
                cv2.rectangle(image, tuple(bbox[0]), tuple(bbox[3]),
                              spec.color, spec.thickness)

        res = {'face_bbox': fl_bbox, 'hand_bbox': [lh_bbox, rh_bbox]}
        return image, res
def main(input_path, face_b, left_hand_b, right_hand_b):
    """Detect face and hand boxes in a single image and display the result.

    Args:
        input_path: Path of the image to analyse.
        face_b: Padding for the face box; adjusts its size.
        left_hand_b: Padding for the left-hand box.
        right_hand_b: Padding for the right-hand box.

    Returns:
        The dictionary of bounding-box corners produced by
        ``mediapipe_detect.get_bbox``.

    Raises:
        OSError: If the image at ``input_path`` cannot be read.
    """
    image = cv2.imread(input_path)
    # cv2.imread signals failure by returning None rather than raising.
    if image is None:
        raise OSError(f"Could not read image: {input_path}")

    detector = mediapipe_detect()
    with mp_holistic.Holistic(model_complexity=2,
                              min_detection_confidence=0.5,
                              min_tracking_confidence=0.5) as holistic:
        image, results = detector.mediapipe_detection(image, holistic)
        image, res = detector.get_bbox(
            image, results, face_b, left_hand_b, right_hand_b)

    # Show the annotated image until a key is pressed.
    cv2.namedWindow("mediapipe_detections", cv2.WINDOW_AUTOSIZE)
    cv2.imshow("mediapipe_detections", image)
    cv2.waitKey()
    cv2.destroyAllWindows()
    return res
if __name__ == "__main__":
    # Image to analyse. Named input_path so the builtin input() is not shadowed.
    input_path = 'D:/inference/mediapipe/mediapipe/python/video/test_picture/test11.jpg'
    # output = 'D:/inference/mediapipe/mediapipe/python/video/output'
    face_b = 50  # padding for the face box
    left_hand_b = 7  # padding for the left-hand box
    right_hand_b = 7  # padding for the right-hand box
    main(input_path, face_b, left_hand_b, right_hand_b)
Loading…
Cancel
Save