diff --git a/tool/demo.py b/tool/demo.py
new file mode 100644
index 0000000..b7227d6
--- /dev/null
+++ b/tool/demo.py
@@ -0,0 +1,89 @@
+from PP_TSMv2_infer import *
+from mediapipe_detection import mediapipe_detect
+import mediapipe as mp
+import cv2
+mp_holistic = mp.solutions.holistic
+
+
+def main(input_path,output_path, face_b, left_hand_b, right_hand_b):
+
+    cap = cv2.VideoCapture(input_path)
+    config = 'D:/download/PaddleVideo1/output/output/pptsm_lcnet_k400_16frames_uniform.yaml'
+    model_file = 'D:/download/PaddleVideo1/output/output/ppTSMv2.pdmodel'  # 推理模型存放地址
+    params_file = 'D:/download/PaddleVideo1/output/output/ppTSMv2.pdiparams'  # 推理模型参数存放地址
+    batch_size = 1  # 输出推理模型
+    infer,predictor = PP_TSMv2_predict().create_inference_model(config,model_file,params_file)
+    res = PP_TSMv2_predict().predict(config, input_path, batch_size, predictor,infer)
+    label = res["topk_class"]
+    if label == 0:
+        label = "Nodding!"
+    elif label == 1:
+        label = "not playing phone!"
+    elif label == 2:
+        label = "not sleep!"
+    elif label == 3:
+        label = "playing phone!"
+    elif label == 4:
+        label = "sleep!"
+    else:
+        pass
+    fps_video = cap.get(cv2.CAP_PROP_FPS)
+    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+    codec = cv2.VideoWriter_fourcc(*'XVID')
+    video_name = os.path.basename(input_path)
+    out = cv2.VideoWriter(output_path + "/" + video_name, codec, fps_video, (frame_width, frame_height))
+    with mp_holistic.Holistic(model_complexity=2,min_detection_confidence=0.5, min_tracking_confidence=0.5) as holistic:
+        while True:
+            ret, frame = cap.read()
+            if not ret:
+                break
+            image, results = mediapipe_detect().mediapipe_detection(frame, holistic)
+            cv2.namedWindow("mediapipe_detections", cv2.WINDOW_AUTOSIZE)
+            if label == "Nodding!":
+                image, res = mediapipe_detect().get_bbox(image, results, face_b, left_hand_b, right_hand_b,label)
+                cv2.putText(image, "the person's head is " + label, (0, 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0),
+                            1)
+            elif label == "sleep!":
+                image, res = mediapipe_detect().get_bbox(image, results, face_b, left_hand_b, right_hand_b,label)
+                cv2.putText(image, "the person is " + label, (0, 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0),
+                            1)
+            elif label == "not sleep!":
+                image, res = mediapipe_detect().get_bbox(image, results, face_b, left_hand_b, right_hand_b,label)
+                cv2.putText(image, "the person is " + label, (0, 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0),
+                            1)
+            elif label == "playing phone!":
+                image, res = mediapipe_detect().get_bbox(image, results, face_b, left_hand_b, right_hand_b,label)
+                cv2.putText(image, "the person'hand is " + label, (0, 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0),
+                            1)
+            elif label == "not playing phone!":
+                image, res = mediapipe_detect().get_bbox(image, results, face_b, left_hand_b, right_hand_b,label)
+                cv2.putText(image, "the person'hand is " + label, (0, 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0),
+                            1)
+
+            cv2.imshow("mediapipe_detections", image)
+            out.write(image)
+            if cv2.waitKey(10) & 0xFF == ord('q'):
+                break
+            out.write(image)
+            # print(res)
+    cap.release()
+    out.release()
+    cv2.destroyAllWindows()
+
+
+if __name__=="__main__":
+    # input = 'D:/download/PaddleVideo1/output/output/after_1/0711-1_0_1.avi'
+    # input = 'D:/download/PaddleVideo1/output/output/after_1/0711-1_398_1.avi'
+    # input = 'D:/download/PaddleVideo1/output/output/after_1/0711-1_597_0.avi'  #正例
+    # input = 'D:/download/PaddleVideo1/output/output/after_1/0711-1_597_1.avi'
+    input = 'D:/download/PaddleVideo1/output/output/after_1/0711-1_796_0.avi'    #正例,推理成功
+    # input = 'D:/download/PaddleVideo1/output/output/after_1/0711-1_796_1.avi'
+    # input = 'D:/download/PaddleVideo1/output/output/after_1/0711-3_0_0.avi'
+    # input = 'D:/download/PaddleVideo1/output/output/after_1/0711-3_1400_0.avi'
+
+    output = "D:/download/PaddleVideo1/output/output1"
+    face_b = 50
+    left_hand_b = 7
+    right_hand_b = 7
+    main(input,output,face_b,left_hand_b,right_hand_b)
\ No newline at end of file
diff --git a/tool/mediapipe_detection.py b/tool/mediapipe_detection.py
index cafacc6..25e1eb7 100644
--- a/tool/mediapipe_detection.py
+++ b/tool/mediapipe_detection.py
@@ -7,7 +7,6 @@ import numpy as np
 from mediapipe.framework.formats import landmark_pb2
 import os
 
-
 mp_holistic = mp.solutions.holistic
 
 _PRESENCE_THRESHOLD = 0.5
@@ -53,15 +52,15 @@ def _normalized_to_pixel_coordinates(
     x_px = min(math.floor(normalized_x * image_width), image_width - 1)
     y_px = min(math.floor(normalized_y * image_height), image_height - 1)
     # return print("转化的真实坐标:",x_px, y_px)
-    return x_px,y_px
+    return x_px, y_px
 
 
 def draw_landmarks(
-    image: np.ndarray,
-    landmark_list: landmark_pb2.NormalizedLandmarkList,
-    connections: Optional[List[Tuple[int, int]]] = None,
-    landmark_drawing_spec: Union[DrawingSpec,Mapping[int, DrawingSpec]] = DrawingSpec(color=RED_COLOR),
-    connection_drawing_spec: Union[DrawingSpec, Mapping[Tuple[int, int],DrawingSpec]] = DrawingSpec()):
+        image: np.ndarray,
+        landmark_list: landmark_pb2.NormalizedLandmarkList,
+        connections: Optional[List[Tuple[int, int]]] = None,
+        landmark_drawing_spec: Union[DrawingSpec, Mapping[int, DrawingSpec]] = DrawingSpec(color=RED_COLOR),
+        connection_drawing_spec: Union[DrawingSpec, Mapping[Tuple[int, int], DrawingSpec]] = DrawingSpec()):
     """
         主要是绘制关键点的连接图
         image:输入的数据
@@ -81,7 +80,7 @@ def draw_landmarks(
                 (landmark.HasField('presence') and
                  landmark.presence < _PRESENCE_THRESHOLD)):
             continue
-        landmark_px = _normalized_to_pixel_coordinates(landmark.x, landmark.y,          #将归一化坐标值转换为图像坐标值
+        landmark_px = _normalized_to_pixel_coordinates(landmark.x, landmark.y,  # 将归一化坐标值转换为图像坐标值
                                                        image_cols, image_rows)
         # print('图像像素坐标:',landmark_px)
         if landmark_px:
@@ -90,7 +89,7 @@ def draw_landmarks(
     dot_list = []
     if connections:
         # num_landmarks = len(landmark_list.landmark)
-        #connections:keypoint索引元组的列表,用于指定如何在图形中连接地标。
+        # connections:keypoint索引元组的列表,用于指定如何在图形中连接地标。
         # Draws the connections if the start and end landmarks are both visible.
 
         starts = []
@@ -139,7 +138,7 @@ def draw_landmarks(
 
 class mediapipe_detect:
 
-    def mediapipe_detection(self,image, model):
+    def mediapipe_detection(self, image, model):
         """
             mediapipe检测模块
             image:输入数据集
@@ -152,7 +151,7 @@ class mediapipe_detect:
         # image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)  # COLOR COVERSION RGB 2 BGR
         return image, results
 
-    def Drawing_bbox(self,result,bias):
+    def Drawing_bbox(self, result, bias):
 
         '''
             根据关键点坐标,获取最大外接矩形的坐标点
@@ -163,15 +162,14 @@ class mediapipe_detect:
         result = np.array(result)
         b = bias
         if result.any():
-
-            rect = cv2.boundingRect(result)                #返回值, 左上角的坐标[x,y, w,h]
+            rect = cv2.boundingRect(result)  # 返回值, 左上角的坐标[x,y, w,h]
 
             bbox = [[rect[0] - b, rect[1] - b], [rect[0] + rect[2] + b, rect[1] - b],
-                    [rect[0] - b, rect[1] + rect[3] + b], [rect[0] + rect[2] + b, rect[1] + rect[3] + b]]  #四个角的坐标
-            
+                    [rect[0] - b, rect[1] + rect[3] + b], [rect[0] + rect[2] + b, rect[1] + rect[3] + b]]  # 四个角的坐标
+
             return bbox
-        
-    def get_bbox(self,image,results,face_b,left_hand_b,right_hand_b):
+
+    def get_bbox(self, image, results, face_b, left_hand_b, right_hand_b, label):
 
         '''
             主要是根据关键点坐标,绘制矩形框
@@ -181,7 +179,7 @@ class mediapipe_detect:
 
         image.flags.writeable = True
         image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
-        h,w,g = image.shape
+        h, w, g = image.shape
         # print("h:",h,"w:",w,"g:",g)
 
         """获取头部、手部关键点"""
@@ -210,9 +208,15 @@ class mediapipe_detect:
         )
 
         """根据关键点的坐标绘制最大外接矩形"""
-        fl_bbox = self.Drawing_bbox(face_location,face_b)
-        lh_bbox = self.Drawing_bbox(left_hand_location,left_hand_b)
-        rh_bbox = self.Drawing_bbox(right_hand_location,right_hand_b)
+        fl_bbox = self.Drawing_bbox(face_location, face_b)
+        lh_bbox = self.Drawing_bbox(left_hand_location, left_hand_b)
+        rh_bbox = self.Drawing_bbox(right_hand_location, right_hand_b)
+
+        if label == "Nodding" or label == "sleep!" or label == "not playing phone!":
+            lh_bbox = None
+            rh_bbox = None
+        elif label == "not sleep!" or label == "playing phone!":
+            fl_bbox = None
 
         """调整头部检测框的大小"""
         if fl_bbox is not None:
@@ -222,34 +226,34 @@ class mediapipe_detect:
             fl_bbox[0][0] = fl_bbox[0][0] + 30
             fl_bbox[0][1] = fl_bbox[0][1] + 5
             # print(fl_bbox)
-            for i in range(0,4):
-                for j in range(0,2):
+            for i in range(0, 4):
+                for j in range(0, 2):
                     if fl_bbox[i][j] < 0:
                         fl_bbox[i][j] = 0
                     elif fl_bbox[i][0] > w:
                         fl_bbox[i][0] = w
-                    elif fl_bbox[i][1] > h :
+                    elif fl_bbox[i][1] > h:
                         fl_bbox[i][1] = h
                     else:
                         pass
-            cv2.rectangle(image, fl_bbox[0], fl_bbox[3],DrawingSpec.color, DrawingSpec.thickness)
+            cv2.rectangle(image, fl_bbox[0], fl_bbox[3], DrawingSpec.color, DrawingSpec.thickness)
 
         if lh_bbox is not None:
-            for i in range(0,4):
-                for j in range(0,2):
+            for i in range(0, 4):
+                for j in range(0, 2):
                     if lh_bbox[i][j] < 0:
                         lh_bbox[i][j] = 0
                     elif lh_bbox[i][0] > w:
                         lh_bbox[i][0] = w
-                    elif lh_bbox[i][1] > h :
+                    elif lh_bbox[i][1] > h:
                         lh_bbox[i][1] = h
                     else:
                         pass
-            cv2.rectangle(image, lh_bbox[0], lh_bbox[3],DrawingSpec.color, DrawingSpec.thickness)
+            cv2.rectangle(image, lh_bbox[0], lh_bbox[3], DrawingSpec.color, DrawingSpec.thickness)
 
         if rh_bbox is not None:
-            for i in range(0,4):
-                for j in range(0,2):
+            for i in range(0, 4):
+                for j in range(0, 2):
                     if rh_bbox[i][j] < 0:
                         rh_bbox[i][j] = 0
                     elif rh_bbox[i][0] > w:
@@ -258,56 +262,51 @@ class mediapipe_detect:
                         rh_bbox[i][1] = h
                     else:
                         pass
-            cv2.rectangle(image, rh_bbox[0], rh_bbox[3],DrawingSpec.color, DrawingSpec.thickness)
-
-
-        res = {'face_bbox': fl_bbox, 'hand_bbox': [lh_bbox,rh_bbox]}
-
+            cv2.rectangle(image, rh_bbox[0], rh_bbox[3], DrawingSpec.color, DrawingSpec.thickness)
+
+        res = {'face_bbox': fl_bbox, 'hand_bbox': [lh_bbox, rh_bbox]}
+
+        return image, res
+
+
+def main(input_path, output_path, face_b, left_hand_b, right_hand_b):
+    cap = cv2.VideoCapture(input_path)
+    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+    fps = int(cap.get(cv2.CAP_PROP_FPS))
+    codec = cv2.VideoWriter_fourcc(*'XVID')
+    video_name = os.path.basename(input_path)
+    out = cv2.VideoWriter(output_path + "/" + video_name, codec, fps, (frame_width, frame_height))
+    label = ""
+    with mp_holistic.Holistic(min_detection_confidence=0.5, min_tracking_confidence=0.5) as holistic:
+        while True:
+            ret, frame = cap.read()
+            if not ret:
+                break
+            image, results = mediapipe_detect().mediapipe_detection(frame, holistic)
+            image, res = mediapipe_detect().get_bbox(image, results, face_b, left_hand_b, right_hand_b, label)
+            out.write(image)
+            cv2.namedWindow("mediapipe_detections", cv2.WINDOW_AUTOSIZE)
+            cv2.imshow("mediapipe_detections", image)
+            # print(res)
+            if cv2.waitKey(10) & 0xFF == ord('q'):
+                break
+
+    cap.release()
+    out.release()
+    cv2.destroyAllWindows()
 
-        # print(res)
-
-        return image,res
-
-
-def main(input_path,output_path,face_b,left_hand_b,right_hand_b):
-
-     cap = cv2.VideoCapture(input_path)
-     frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
-     frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-     fps = int(cap.get(cv2.CAP_PROP_FPS))
-     codec = cv2.VideoWriter_fourcc(*'XVID')
-     video_name = os.path.basename(input_path)
-     out = cv2.VideoWriter(output_path +"/"+ video_name, codec, fps, (frame_width, frame_height))
-     with mp_holistic.Holistic(min_detection_confidence=0.5, min_tracking_confidence=0.5) as holistic:
-         while True:
-             ret, frame = cap.read()
-             if not ret:
-                 break
-             image, results = mediapipe_detect().mediapipe_detection(frame,holistic)
-             image,res = mediapipe_detect().get_bbox(image,results,face_b,left_hand_b,right_hand_b)
-             out.write(image)
-             cv2.namedWindow("mediapipe_detections", cv2.WINDOW_AUTOSIZE)
-             cv2.imshow("mediapipe_detections", image)
-             # print(res)
-             if cv2.waitKey(10) & 0xFF == ord('q'):
-                 break
-
-     cap.release()
-     out.release()
-     cv2.destroyAllWindows()
 
 if __name__ == "__main__":
-    input = 'D:/download/PaddleVideo1/output/output/after_1/test02_0.avi'
+    # input = 'D:/download/PaddleVideo1/output/output/after_1/test02_0.avi'
     # input = 'D:/download/PaddleVideo1/output/output/after_1/0711-1_0_1.avi'
     # input = 'D:/download/PaddleVideo1/output/output/after_1/0711-3_1400_0.avi'
     # input = "C:/Users/Administrator/Pictures/video_seg_re_hand/test01_3.avi"
     # input = 'C:/Users/Administrator/Pictures/video3.0/sleep/0711-3_7_01_5.avi'
-    # input = " D:/download/PaddleVideo1/output/output/after_1/0711-1_199_0.avi"
+    input = " D:/download/PaddleVideo1/output/output/after_1/0711-1_199_0.avi"
     # input = 'D:/download/PaddleVideo1/output/output/after_1/test05_10750_1.avi'
     output = 'D:/download/PaddleVideo1/output/output/output'
-    face_b = 50          #头部标注框修正值
-    left_hand_b = 7      #左手部分标注框修正值
-    right_hand_b = 7     #右手部分标注框修正值
-    main(input,output,face_b,left_hand_b,right_hand_b)
-
-
+    face_b = 50  # 头部标注框修正值
+    left_hand_b = 7  # 左手部分标注框修正值
+    right_hand_b = 7  # 右手部分标注框修正值
+    main(input, output, face_b, left_hand_b, right_hand_b)
\ No newline at end of file