From b6aa129cf2b2cace02f945613f46916f09411df1 Mon Sep 17 00:00:00 2001
From: jiangxt <1579525634@qq.com>
Date: Mon, 7 Aug 2023 10:15:34 +0800
Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0mediapipe=E6=A0=87=E6=B3=A8?=
 =?UTF-8?q?=E6=A1=86=E4=BF=AE=E6=AD=A3=E9=97=AE=E9=A2=98?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 tool/mediapipe_detection.py | 72 ++++++++++++++++++++++++++++++-------
 1 file changed, 60 insertions(+), 12 deletions(-)

diff --git a/tool/mediapipe_detection.py b/tool/mediapipe_detection.py
index 645e630..cafacc6 100644
--- a/tool/mediapipe_detection.py
+++ b/tool/mediapipe_detection.py
@@ -7,6 +7,7 @@ import numpy as np
 from mediapipe.framework.formats import landmark_pb2
 import os
 
+
 mp_holistic = mp.solutions.holistic
 
 _PRESENCE_THRESHOLD = 0.5
@@ -167,8 +168,10 @@ class mediapipe_detect:
 
             bbox = [[rect[0] - b, rect[1] - b], [rect[0] + rect[2] + b, rect[1] - b],
                     [rect[0] - b, rect[1] + rect[3] + b], [rect[0] + rect[2] + b, rect[1] + rect[3] + b]]  #四个角的坐标
+            
             return bbox
-    def get_bbox(self,image, results,face_b,left_hand_b,right_hand_b):
+        
+    def get_bbox(self,image,results,face_b,left_hand_b,right_hand_b):
 
         '''
             主要是根据关键点坐标,绘制矩形框
@@ -178,6 +181,8 @@ class mediapipe_detect:
 
         image.flags.writeable = True
         image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+        h,w,g = image.shape
+        # print("h:",h,"w:",w,"g:",g)
 
         """获取头部、手部关键点"""
         face_location = draw_landmarks(
@@ -216,16 +221,51 @@ class mediapipe_detect:
             fl_bbox[3][1] = fl_bbox[3][1] - 30
             fl_bbox[0][0] = fl_bbox[0][0] + 30
             fl_bbox[0][1] = fl_bbox[0][1] + 5
+            # print(fl_bbox)
+            for i in range(0,4):
+                for j in range(0,2):
+                    if fl_bbox[i][j] < 0:
+                        fl_bbox[i][j] = 0
+                    elif fl_bbox[i][0] > w:
+                        fl_bbox[i][0] = w
+                    elif fl_bbox[i][1] > h :
+                        fl_bbox[i][1] = h
+                    else:
+                        pass
             cv2.rectangle(image, fl_bbox[0], fl_bbox[3],DrawingSpec.color, DrawingSpec.thickness)
+
         if lh_bbox is not None:
+            for i in range(0,4):
+                for j in range(0,2):
+                    if lh_bbox[i][j] < 0:
+                        lh_bbox[i][j] = 0
+                    elif lh_bbox[i][0] > w:
+                        lh_bbox[i][0] = w
+                    elif lh_bbox[i][1] > h :
+                        lh_bbox[i][1] = h
+                    else:
+                        pass
             cv2.rectangle(image, lh_bbox[0], lh_bbox[3],DrawingSpec.color, DrawingSpec.thickness)
+
         if rh_bbox is not None:
+            for i in range(0,4):
+                for j in range(0,2):
+                    if rh_bbox[i][j] < 0:
+                        rh_bbox[i][j] = 0
+                    elif rh_bbox[i][0] > w:
+                        rh_bbox[i][0] = w
+                    elif rh_bbox[i][1] > h:
+                        rh_bbox[i][1] = h
+                    else:
+                        pass
             cv2.rectangle(image, rh_bbox[0], rh_bbox[3],DrawingSpec.color, DrawingSpec.thickness)
 
+
         res = {'face_bbox': fl_bbox, 'hand_bbox': [lh_bbox,rh_bbox]}
-        cv2.namedWindow("mediapipe_detections", cv2.WINDOW_AUTOSIZE)
-        cv2.imshow("mediapipe_detections",image)
-        # print(result_dict)
+
+
+        # print(res)
+
         return image,res
 
 
@@ -236,8 +276,8 @@ def main(input_path,output_path,face_b,left_hand_b,right_hand_b):
      frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
      fps = int(cap.get(cv2.CAP_PROP_FPS))
      codec = cv2.VideoWriter_fourcc(*'XVID')
-     video_name = os.path.basename(input_path).split('.')[0]
-     out = cv2.VideoWriter(output_path +"/"+ video_name+".avi", codec, fps, (frame_width, frame_height))
+     video_name = os.path.basename(input_path)
+     out = cv2.VideoWriter(output_path +"/"+ video_name, codec, fps, (frame_width, frame_height))
      with mp_holistic.Holistic(min_detection_confidence=0.5, min_tracking_confidence=0.5) as holistic:
          while True:
              ret, frame = cap.read()
@@ -246,20 +286,28 @@ def main(input_path,output_path,face_b,left_hand_b,right_hand_b):
              image, results = mediapipe_detect().mediapipe_detection(frame,holistic)
              image,res = mediapipe_detect().get_bbox(image,results,face_b,left_hand_b,right_hand_b)
              out.write(image)
-
+             cv2.namedWindow("mediapipe_detections", cv2.WINDOW_AUTOSIZE)
+             cv2.imshow("mediapipe_detections", image)
+             # print(res)
              if cv2.waitKey(10) & 0xFF == ord('q'):
                  break
 
-
      cap.release()
      out.release()
      cv2.destroyAllWindows()
 
 if __name__ == "__main__":
-    input = 'D:/inference/mediapipe/mediapipe/python/video/test/test02_2.avi'
-    output = 'D:/inference/mediapipe/mediapipe/python/video/output_video'
+    input = 'D:/download/PaddleVideo1/output/output/after_1/test02_0.avi'
+    # input = 'D:/download/PaddleVideo1/output/output/after_1/0711-1_0_1.avi'
+    # input = 'D:/download/PaddleVideo1/output/output/after_1/0711-3_1400_0.avi'
+    # input = "C:/Users/Administrator/Pictures/video_seg_re_hand/test01_3.avi"
+    # input = 'C:/Users/Administrator/Pictures/video3.0/sleep/0711-3_7_01_5.avi'
+    # input = " D:/download/PaddleVideo1/output/output/after_1/0711-1_199_0.avi"
+    # input = 'D:/download/PaddleVideo1/output/output/after_1/test05_10750_1.avi'
+    output = 'D:/download/PaddleVideo1/output/output/output'
     face_b = 50          #头部标注框修正值
-    left_hand_b = 7      #头部标注框修正值
-    right_hand_b = 7     #头部标注框修正值
+    left_hand_b = 7      #左手部分标注框修正值
+    right_hand_b = 7     #右手部分标注框修正值
     main(input,output,face_b,left_hand_b,right_hand_b)
 
+