diff --git a/BANK_atm/xml-dimension.py b/BANK_atm/xml-dimension.py
new file mode 100644
index 0000000..d28316f
--- /dev/null
+++ b/BANK_atm/xml-dimension.py
@@ -0,0 +1,110 @@
+import time
+import torch
+from ultralytics import YOLO
+import numpy as np
+import cv2
+import os
+from lxml.etree import Element, SubElement, tostring
+
+
+def create_xml(boxes, img_shape, xml_path):
+    """
+    Create a Pascal VOC style XML annotation file and fill in the required fields.
+    :param boxes: list of detections, each [xmin, ymin, xmax, ymax, class_name]
+    :param img_shape: image info (height, width, depth, filename); W/H/C are written into the XML
+    :param xml_path: directory the XML file is written to
+    :return: None
+    """
+    node_root = Element('annotation')
+    node_folder = SubElement(node_root, 'folder')
+    node_folder.text = 'Images'
+    node_filename = SubElement(node_root, 'filename')
+    node_filename.text = str(img_shape[3])
+    node_size = SubElement(node_root, 'size')
+    node_width = SubElement(node_size, 'width')
+    node_width.text = str(img_shape[1])
+    node_height = SubElement(node_size, 'height')
+    node_height.text = str(img_shape[0])
+    node_depth = SubElement(node_size, 'depth')
+    node_depth.text = str(img_shape[2])
+
+    if len(boxes) >= 1:  # write one <object> node per box
+        for box in boxes:
+            node_object = SubElement(node_root, 'object')
+            node_name = SubElement(node_object, 'name')
+            # To annotate only certain labels (e.g. only "person"), filter here:
+            # if str(box[4]) != "person":
+            #     continue
+            node_name.text = str(box[4])
+            node_difficult = SubElement(node_object, 'difficult')
+            node_difficult.text = '0'
+            node_bndbox = SubElement(node_object, 'bndbox')
+            node_xmin = SubElement(node_bndbox, 'xmin')
+            node_xmin.text = str(box[0])
+            node_ymin = SubElement(node_bndbox, 'ymin')
+            node_ymin.text = str(box[1])
+            node_xmax = SubElement(node_bndbox, 'xmax')
+            node_xmax.text = str(box[2])
+            node_ymax = SubElement(node_bndbox, 'ymax')
+            node_ymax.text = str(box[3])
+
+    xml = tostring(node_root, pretty_print=True)  # pretty-print so each element gets its own line
+
+    file_name = os.path.splitext(img_shape[3])[0]
+    filename = os.path.join(xml_path, "{}.xml".format(file_name))
+
+    with open(filename, "wb") as f:
+        f.write(xml)
+
+
+def draw_bounding_box(img, class_name, confidence, x, y, x_plus_w, y_plus_h, color):
+    label = f'{class_name} ({confidence:.2f})'
+    cv2.rectangle(img, (x, y), (x_plus_w, y_plus_h), color, 2)
+    cv2.putText(img, label, (x - 10, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
+
+
+def main(weights_path, img_path, xml_path, outputs_path):
+    model = YOLO(weights_path)
+    arr = torch.ones(1, 3, 224, 224)
+    result_init = model(arr)[0]  # warm-up pass on a dummy BCHW tensor; also exposes the class-name table
+    colors = np.random.uniform(0, 255, size=(len(result_init.names), 3))  # one drawing colour per class
+
+    for name in os.listdir(img_path):
+        t0 = time.time()
+        original_image = cv2.imread(os.path.join(img_path, name))
+        if original_image is None:  # skip files that are not readable images
+            continue
+        img_shape = (original_image.shape[0], original_image.shape[1], original_image.shape[2], name)
+
+        # Run the model on the image
+        results = model(original_image)[0]
+        boxes = results.boxes.cpu().numpy()
+        CLASSES = results.names
+        xyxy_cls = []
+
+        for i in range(len(boxes.xyxy)):
+            cls_id = int(boxes.cls[i])
+            x1, y1, x2, y2 = (int(v) for v in boxes.xyxy[i])
+            xyxy_cls.append([x1, y1, x2, y2, CLASSES[cls_id]])
+            draw_bounding_box(original_image, CLASSES[cls_id], boxes.conf[i], x1, y1, x2, y2, colors[cls_id])
+
+        if len(xyxy_cls) > 0:
+            create_xml(xyxy_cls, img_shape, xml_path)  # write the XML annotation for this image
+        t1 = time.time()
+
+        print("img name: {}  time: {:.2f} ms".format(name, (t1 - t0) * 1000))
+        cv2.imwrite(os.path.join(outputs_path, name), original_image)  # save the annotated output image
+
+        # cv2.imshow('image', original_image)
+        # cv2.waitKey(0)
+        # cv2.destroyAllWindows()
+
+
+if __name__ == "__main__":
# weights_path =r"C:\Users\Administrator\Desktop\train24\weights\best.pt" + weights_path =r"C:\Users\Administrator\Desktop\ultralytics-main\yolov8s.pt" #权重地址 + imgs_path = r"C:\Users\Administrator\Desktop\yolov5-label-xml-main\inference\images" #图片地址 + xmls_path = r"C:\Users\Administrator\Desktop\yolov5-label-xml-main\inference\xmlss" #输出xml文件地址 + outputs_path = "./" + + + main(weights_path,imgs_path,xmls_path,outputs_path)