import os
import cv2
from lxml.etree import Element, SubElement, tostring
import xml.etree.ElementTree as ET


# Parse a Pascal-VOC style XML annotation file and collect the boxes of one class.
def parse_xml(xml_path, labels):
    """Return the bounding boxes of every object named ``labels`` in ``xml_path``.

    Args:
        xml_path: path of the XML annotation file.
        labels: object class to select; compared with each ``<object>/<name>``.

    Returns:
        A list of dicts with keys ``name``, ``xmin``, ``ymin``, ``xmax`` and
        ``ymax`` (coordinates as ``int``; numeric text such as ``"12.0"`` is
        accepted and truncated). The list is empty, never ``None``, when no
        matching well-formed object exists, so callers can always iterate it.
    """
    root = ET.parse(xml_path).getroot()

    boxes = []
    for obj in root.findall('object'):
        name_node = obj.find('name')
        if name_node is None or str(name_node.text) != labels:
            continue
        try:
            coords = {
                key: int(float(obj.find(f'bndbox/{key}').text))
                for key in ('xmin', 'ymin', 'xmax', 'ymax')
            }
        except (AttributeError, TypeError, ValueError):
            # Missing <bndbox>/<coord> element or non-numeric text:
            # skip this malformed object instead of aborting the whole file.
            continue
        boxes.append({'name': labels, **coords})
    return boxes


# Collect the phone and head boxes that lie inside one person box.
def select_infors(xml_path, person):
    """Return the ``phone`` and ``head`` boxes of ``xml_path`` that fit in ``person``.

    Phone boxes come first, then head boxes. Each kept box is the dict
    returned by ``change_bbox``, i.e. translated so that its coordinates are
    relative to the top-left corner of ``person``. Boxes that are not fully
    inside ``person`` are dropped.
    """
    # Parse both classes up front, phone first and head second.
    per_label = [parse_xml(xml_path=xml_path, labels=label) for label in ('phone', 'head')]

    selected = []
    for boxes in per_label:
        # A falsy result (no matching objects) simply contributes nothing.
        for box in boxes or []:
            relative = change_bbox(infor=box, contrast=person)
            if relative:
                selected.append(relative)
    return selected


# Translate a box into the coordinate frame of a box that contains it.
def change_bbox(infor, contrast):
    """Shift ``infor`` so its coordinates are relative to ``contrast``.

    When ``infor`` lies entirely inside ``contrast`` (edges inclusive), its
    ``ymin``/``ymax`` are reduced by ``contrast['ymin']`` and its
    ``xmin``/``xmax`` by ``contrast['xmin']``. The dict is modified in place
    and the same object is returned. Otherwise ``infor`` is left untouched
    and ``None`` is returned.
    """
    # (key, True when contrast's edge must be <= infor's edge, else >=);
    # checked in order and short-circuiting on the first failure.
    containment = (('ymin', True), ('ymax', False), ('xmin', True), ('xmax', False))
    fits = all(
        contrast[key] <= infor[key] if is_lower_edge else contrast[key] >= infor[key]
        for key, is_lower_edge in containment
    )
    if not fits:
        return None

    # Each coordinate is shifted by the origin on its own axis.
    for key, origin_key in (('ymin', 'ymin'), ('ymax', 'ymin'), ('xmin', 'xmin'), ('xmax', 'xmin')):
        infor[key] = infor[key] - contrast[origin_key]
    return infor


# Read an image from disk, failing loudly instead of propagating ``None``.
def _read_image(image_path):
    """Return ``cv2.imread(image_path)``; raise ``ValueError`` if it cannot be read."""
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise ValueError(f"Could not read image: {image_path}")
    return image


# Crop a padded box out of an image and save it as a JPEG.
def cut_img_bbox(image_path, infors, num, output_img_folder=None, pad=25, image=None):
    """Crop the region ``infors`` grown by ``pad`` pixels and save it.

    Args:
        image_path: path of the source image; its base name (without the
            extension) prefixes the output file name.
        infors: dict with int ``xmin``/``ymin``/``xmax``/``ymax``. It is
            updated IN PLACE to the padded box that was actually cropped
            (clamped to the image borders); callers rely on this to get the
            crop's origin and size.
        num: index that makes the output name unique within one image.
        output_img_folder: directory the crop is written to. When omitted,
            the module-level ``output_img_folder`` global is used.
        pad: margin in pixels added on every side of the box.
        image: already-decoded image to crop from. It is read from
            ``image_path`` when not supplied, so callers can decode once for
            many crops.

    Returns:
        The file name of the saved crop, or ``None`` when the clamped box is
        empty (the box lies outside the image) and nothing was written.

    Raises:
        ValueError: the image cannot be read, or no output folder is known.
        OSError: OpenCV failed to write the crop.
    """
    if output_img_folder is None:
        # Backward compatibility: the original code always read this global,
        # which is only defined when the file runs as a script.
        output_img_folder = globals().get('output_img_folder')
        if output_img_folder is None:
            raise ValueError("output_img_folder was not given and no module-level default exists")

    if image is None:
        image = _read_image(image_path)
    height, width = image.shape[:2]

    # Grow the box by ``pad`` but clamp it to the image. A negative slice
    # start would wrap around to the far end of the array and give a wrong or
    # empty crop, and an unclamped end would misreport the crop size.
    infors.update(
        ymin=max(infors['ymin'] - pad, 0),
        ymax=min(infors['ymax'] + pad, height),
        xmin=max(infors['xmin'] - pad, 0),
        xmax=min(infors['xmax'] + pad, width),
    )

    person_image = image[infors['ymin']:infors['ymax'], infors['xmin']:infors['xmax']]
    if person_image.size == 0:
        return None

    # splitext keeps dots inside the stem (e.g. "a.b.jpg" -> "a.b").
    stem = os.path.splitext(os.path.basename(image_path))[0]
    person_filename = f"{stem}_person_{num}.jpg"
    person_image_path = os.path.join(output_img_folder, person_filename)
    # cv2.imwrite reports failure through a False return value, not an exception.
    if not cv2.imwrite(person_image_path, person_image):
        raise OSError(f"Could not write image: {person_image_path}")

    return person_filename




# Crop every person of one image and write a matching VOC annotation per crop.
def extract_people(image_path, xml_path, output_img_folder, output_xml_folder):
    """Save one crop and one VOC XML file per ``person`` box in ``xml_path``.

    Each crop is the person box padded (and clamped) by ``cut_img_bbox``.
    Its XML file lists the ``phone``/``head`` boxes returned by
    ``select_infors``, i.e. those fully inside the padded box, in
    crop-relative coordinates.
    """
    # Treat a falsy result (no person objects) as an empty list.
    people = parse_xml(xml_path=xml_path, labels='person') or []
    if not people:
        return

    # Decode the image once instead of once per person.
    image = _read_image(image_path)

    for i, person in enumerate(people):
        # Crop and save; ``person`` is updated in place to the padded/clamped box.
        person_filename = cut_img_bbox(image_path=image_path, infors=person, num=i,
                                       output_img_folder=output_img_folder, image=image)
        if person_filename is None:
            print(f"Skipped person {i} of {image_path}: box lies outside the image")
            continue

        node_root = Element('annotation')
        SubElement(node_root, 'folder').text = 'Images'
        SubElement(node_root, 'filename').text = str(person_filename)
        node_size = SubElement(node_root, 'size')
        # After clamping, the box extent equals the saved crop's size.
        SubElement(node_size, 'width').text = str(person['xmax'] - person['xmin'])
        SubElement(node_size, 'height').text = str(person['ymax'] - person['ymin'])
        SubElement(node_size, 'depth').text = str(3)

        # Head/phone boxes translated into the crop's coordinate frame.
        for infors in select_infors(xml_path=xml_path, person=person):
            node_object = SubElement(node_root, 'object')
            SubElement(node_object, 'name').text = str(infors['name'])
            SubElement(node_object, 'difficult').text = '0'
            node_bndbox = SubElement(node_object, 'bndbox')
            for key in ('xmin', 'ymin', 'xmax', 'ymax'):
                SubElement(node_bndbox, key).text = str(int(infors[key]))

        # Pretty-print so the XML has one element per line.
        xml = tostring(node_root, pretty_print=True)

        # Save the new XML annotation next to the crop's base name.
        new_xml_filename = os.path.splitext(person_filename)[0] + '.xml'
        new_xml_path = os.path.join(output_xml_folder, new_xml_filename)
        with open(new_xml_path, "wb") as new_xml:
            new_xml.write(xml)

        print(f"Saved person image {person_filename} and corresponding annotation {new_xml_filename}")


# Walk the image folder and process every image that has a same-named XML file.
def main(image_path, xml_path, output_img_folder, output_xml_folder):
    """Create the output folders and run ``extract_people`` on each annotated image.

    An image ``<stem>.<ext>`` in ``image_path`` is processed only when
    ``<stem>.xml`` exists in ``xml_path``. Other images are skipped.
    """
    # Report (and create) both output locations.
    print("frame image save path:{}, annotation save path:{}".format(output_img_folder, output_xml_folder))
    os.makedirs(output_img_folder, exist_ok=True)
    os.makedirs(output_xml_folder, exist_ok=True)

    # A set gives O(1) membership tests for the per-image lookup below.
    xml_file_set = set(os.listdir(xml_path))

    for img_file in os.listdir(image_path):
        # splitext keeps dots inside the stem, so "a.b.jpg" pairs with "a.b.xml".
        xml_file_name = os.path.splitext(img_file)[0] + '.xml'
        if xml_file_name not in xml_file_set:
            continue

        extract_people(image_path=os.path.join(image_path, img_file),
                       xml_path=os.path.join(xml_path, xml_file_name),
                       output_img_folder=output_img_folder,
                       output_xml_folder=output_xml_folder
                       )



if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(
        description="Crop every 'person' box from VOC-annotated images and write "
                    "per-crop annotations for the 'phone'/'head' boxes inside it.")
    # Defaults are the dataset locations this script was originally written for.
    parser.add_argument('--image-path', default="E:/Images_Data/XZ/bank2/yolov8/images",
                        help='folder of the original (large) images')
    parser.add_argument('--xml-path', default="E:/Images_Data/XZ/bank2/yolov8/XML_collect",
                        help='folder of the original XML annotation files')
    parser.add_argument('--output-img-folder', default="E:/Images_Data/XZ/bank2/yolov8/output_files/images",
                        help='folder for the cropped person images')
    parser.add_argument('--output-xml-folder', default='E:/Images_Data/XZ/bank2/yolov8/output_files/Annotations',
                        help='folder for the per-crop XML annotations')
    args = parser.parse_args()

    # Keep these as module-level names: cut_img_bbox can read the global
    # ``output_img_folder``.
    image_path = args.image_path
    xml_path = args.xml_path
    output_img_folder = args.output_img_folder
    output_xml_folder = args.output_xml_folder

    main(image_path=image_path,
         xml_path=xml_path,
         output_img_folder=output_img_folder,
         output_xml_folder=output_xml_folder
         )