import cv2
from tqdm import tqdm
from ultralytics import YOLO
from ultralytics.yolo.utils.plotting import Annotator
import os

import cv2

import time
import os
import queue
import threading

from pipeline_atm_getResult import atm_det
from pipeline_atm_anaysis import analysis_pph,analysis_button


class analysis_atm():
    """Multi-threaded ATM-surveillance pipeline over a folder of images.

    Three worker threads are chained through two FIFO queues:

    1. ``get_img_list``    -- enqueues the full path of every entry in ``imgPath``.
    2. ``analysis_person`` -- runs the detection models on each image and builds
       a list of single-entry ``{label: coordinates}`` result dicts.
    3. ``draw_images``     -- draws those results onto the image and writes it
       into ``savePath``.

    Each producer finishes its stream with ``_END_OF_STREAM`` so the consuming
    thread returns and the process can exit once every image is handled.
    """

    # Marker put on a queue to tell the consuming thread there is no more work.
    _END_OF_STREAM = object()

    # Screen labels shown while the customer is entering a password.
    _SCREEN_PASSWORD_LIST = ('Enter_Password_1', 'Enter_Password_2')

    # Person labels for which the avoidance check (on password screens) reports
    # "avoid_dict_wrong"; any other person label reports "avoid_dict_right".
    # NOTE(review): assumes the person model emits 'wrong'/'correct' (plus
    # possibly other) labels -- confirm against the model's class list.
    _AVOID_PERSON_LABELS = ('wrong', 'correct')

    def __init__(self, imgPath, savePath, modellist):
        """Load the models and prepare the queues and worker threads.

        Args:
            imgPath: folder with the input images; every entry returned by
                ``os.listdir`` is processed.
            savePath: folder the annotated images are written to (created if
                it does not exist).
            modellist: sequence of four YOLO weight paths, in the order
                person, pp/hand, on-screen button, screen.
        """
        self.imgPath = imgPath
        self.savePath = savePath

        self.imgList = os.listdir(self.imgPath)

        # cv2.imwrite fails silently (returns False) when the folder is missing.
        os.makedirs(self.savePath, exist_ok=True)

        # Load the detection models.
        self.model_person = YOLO(modellist[0])
        self.model_pp_hand = YOLO(modellist[1])
        self.model_blue = YOLO(modellist[2])
        self.model_screen = YOLO(modellist[3])

        # Unbounded queues: image paths awaiting detection, and detection
        # results awaiting drawing.
        self.imgQueue1 = queue.Queue(maxsize=0)
        self.imgQueue2 = queue.Queue(maxsize=0)

        # Worker threads; started by run().
        self.get_img_listThread = threading.Thread(target=self.get_img_list)
        self.analysis_personThread = threading.Thread(target=self.analysis_person)
        self.draw_imagesThread = threading.Thread(target=self.draw_images)

    # Read images
    def get_img_list(self):
        """Enqueue the full path of every listed image, then the end marker."""
        for image_name in self.imgList:
            self.imgQueue1.put(os.path.join(self.imgPath, image_name))

        # Tell analysis_person that no more images are coming.
        self.imgQueue1.put(self._END_OF_STREAM)

    # Detect
    def analysis_person(self):
        """Run detection on each queued image and forward the results.

        Images in which both a person and a screen are detected are put on
        ``imgQueue2`` as ``{imagepath: person_re_list}``; other images are
        dropped. Returns when the end marker arrives. The marker is always
        forwarded, even if an exception escapes, so ``draw_images`` is never
        left blocked forever.
        """
        count = 0
        try:
            while True:
                # No emptiness check needed: get() blocks until an item arrives.
                imagepath = self.imgQueue1.get()
                if imagepath is self._END_OF_STREAM:
                    break

                count = count + 1
                print('-----------------------------------------start',count,self.imgPath,'---------------------------------------------------------')

                person_re_list = self._analyse_image(imagepath)
                if person_re_list is None:
                    continue

                img_dict = {imagepath: person_re_list}
                print('img_dict:', img_dict)

                self.imgQueue2.put(img_dict)
        finally:
            self.imgQueue2.put(self._END_OF_STREAM)

    def _analyse_image(self, imagepath):
        """Run every detector on one image and build its result list.

        Returns ``None`` when no person or no screen is detected. Otherwise,
        returns a list of single-entry dicts. ``draw_images`` reads the value of
        a key ending in ``_wrong`` as a box ``(x1, y1, x2, y2)``, and the value
        of any other key as an ``(x, y)`` text position.
        """
        # Staff (person) detection.
        score_person = atm_det.get_person_result(imagepath, self.model_person)
        # Current screen state.
        score_screen = atm_det.get_screen_result(imagepath, self.model_screen)

        print('-----------------------------------------score_person:',score_person,'score_screen:',score_screen,'---------------------------------------------------------')

        if not (score_person and score_screen):
            return None

        # Button detection on the screen.
        score_button = atm_det.get_blue_result(imagepath, self.model_blue)

        # Only the first detection of each kind is used.
        screen_id = list(score_screen[0].keys())[0]
        person_labels = list(score_person[0].keys())[0]
        person_ppox = list(score_person[0].values())[0]

        # The screen and the person detections are always reported.
        person_re_list = [
            {screen_id + "_wrong": list(score_screen[0].values())[0]},
            {person_labels + "_wrong": person_ppox},
        ]

        # Signing screen with a 'wrong' person label: check the pp/hand model.
        if person_labels == 'wrong' and screen_id == 'Sign':
            score_pp_hand = analysis_pph(imagepath, self.model_pp_hand, score_person)
            if score_pp_hand:
                pp_hand_dict = {list(score_pp_hand[0].keys())[0] + '_hand_dict_wrong': list(score_pp_hand[0].values())[0]}
            else:
                pp_hand_dict = {'_hand_dict_right': [50, 50]}
            person_re_list.append(pp_hand_dict)

        # Password-entry screen with a 'wrong' person label.
        if person_labels == 'wrong' and screen_id in self._SCREEN_PASSWORD_LIST:
            score_pp_hand = analysis_pph(imagepath, self.model_pp_hand, score_person)
            # Guard against an empty result (as the Sign branch does); indexing
            # [0] on it would raise.
            if score_pp_hand and list(score_pp_hand[0].keys())[0] == 'password_area':
                pp_hand_dict = {list(score_pp_hand[0].keys())[0] + '_pp_dict_wrong': list(score_pp_hand[0].values())[0]}
            else:
                pp_hand_dict = {'pp_dict_right': [50, 50]}
            person_re_list.append(pp_hand_dict)

        # On-screen buttons versus hands.
        if score_button:
            button_result = analysis_button(imagepath, score_button, score_person)
            if button_result:
                person_re_list.append({'button_dict_wrong': button_result[0]})
            else:
                # Covers a customer's hand being mistaken for a violation.
                person_re_list.append({'button_dict_right': [60, 60]})

        # Avoidance check while a password is being entered. A membership test
        # is required: `x == 'wrong' or 'correct'` is always truthy.
        if screen_id in self._SCREEN_PASSWORD_LIST:
            if person_labels in self._AVOID_PERSON_LABELS:
                person_re_list.append({'avoid_dict_wrong': person_ppox})
            else:
                person_re_list.append({'avoid_dict_right': [70, 70]})

        print('person_re_list_2:', person_re_list)
        return person_re_list

    def draw_images(self):
        """Draw each result list onto its image and save it under ``savePath``.

        Keys ending in ``_wrong`` are drawn as a red (BGR 0,0,255) box labelled
        with the key. Any other key is drawn as magenta text at the given
        position. Returns when the end marker arrives.
        """
        while True:
            imagedict = self.imgQueue2.get()
            if imagedict is self._END_OF_STREAM:
                break

            print('---------------------------------------','imagedict:',imagedict,'---------------------------------------')
            imgpath = list(imagedict.keys())[0]
            img_result = list(imagedict.values())[0]

            imgcv = cv2.imread(imgpath)
            if imgcv is None:
                # imread returns None for unreadable/non-image files.
                print('cannot read image, skipped:', imgpath)
                continue
            imgname = os.path.basename(imgpath)

            print('img_result:', img_result)

            for re_dic in img_result:
                re_txt = list(re_dic.keys())[0]
                re_right = re_txt.split('_')[-1]
                re_bbox = list(re_dic.values())[0]

                print('re_bbox:', re_bbox)
                print('---------------------------------------',re_right,'---------------------------------------')

                if re_right == "wrong":
                    x1, y1 = int(re_bbox[0]), int(re_bbox[1])
                    x2, y2 = int(re_bbox[2]), int(re_bbox[3])
                    cv2.rectangle(imgcv, (x1, y1), (x2, y2), (0, 0, 255), 1)
                    cv2.putText(imgcv, re_txt, (x1 - 10, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (0, 0, 255), 1)
                else:
                    print('re_txt:', re_txt)
                    cv2.putText(imgcv, re_txt, (int(re_bbox[0]), int(re_bbox[1])), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (255, 0, 255), 3)

            save_to = os.path.join(self.savePath, imgname)
            if not cv2.imwrite(save_to, imgcv):
                print('failed to write image:', save_to)

    def run(self):
        """Start the reader, detection and drawing threads (returns immediately)."""
        self.get_img_listThread.start()
        self.analysis_personThread.start()
        self.draw_imagesThread.start()





if __name__ == '__main__':
    # YOLO weight files, in the order analysis_atm expects:
    # person, pp/hand, on-screen button, screen.
    model_weights = (
        "model_files/bk1.pt",
        "model_files/best_pph.pt",
        "model_files/best_butten_h_01.pt",
        "model_files/best_screen_02.pt",
    )

    # Other input folders used previously:
    #   'E:/BANK_XZ_yolov8/data_file'
    #   'E:/Images Data/XZ/DatasetId_1845279_1685326217/DatasetId_1845279_1685326217/Images'
    input_dir = 'E:/Images Data/XZ/DatasetId_1845281_1685322163/Images'
    output_dir = 'E:/BANK_XZ_yolov8/output_data_06'

    # Build the pipeline and start its worker threads.
    pipeline = analysis_atm(input_dir, output_dir, model_weights)
    pipeline.run()