You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

111 lines
4.4 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import time
import torch
from ultralytics import YOLO
import numpy as np
import cv2
import os
from lxml.etree import Element, SubElement, tostring
def create_xml(boxs, img_shape, xml_path):
    """
    Write a Pascal VOC-style annotation XML file for one image.

    The document contains, in order: ``folder``, ``filename``, ``size``
    (``width``/``height``/``depth``) and one ``object`` element per box
    (``name``, ``difficult`` and a ``bndbox`` with
    ``xmin``/``ymin``/``xmax``/``ymax``).

    :param boxs: sequence of boxes; each box is indexed as
        ``[xmin, ymin, xmax, ymax, class_name]``.
    :param img_shape: 4-item sequence ``(height, width, depth, file_name)``;
        index 0 is written as height, 1 as width, 2 as depth and 3 as the
        image file name.
    :param xml_path: directory that receives ``<file name without
        extension>.xml``; it is created if it does not exist yet.
    :return: None
    """
    node_root = Element('annotation')
    SubElement(node_root, 'folder').text = 'Images'
    SubElement(node_root, 'filename').text = str(img_shape[3])

    node_size = SubElement(node_root, 'size')
    SubElement(node_size, 'width').text = str(img_shape[1])
    SubElement(node_size, 'height').text = str(img_shape[0])
    SubElement(node_size, 'depth').text = str(img_shape[2])

    # One <object> element per detected box.
    for box in boxs:
        node_object = SubElement(node_root, 'object')
        SubElement(node_object, 'name').text = str(box[4])
        SubElement(node_object, 'difficult').text = '0'
        node_bndbox = SubElement(node_object, 'bndbox')
        for index, tag in enumerate(('xmin', 'ymin', 'xmax', 'ymax')):
            SubElement(node_bndbox, tag).text = str(box[index])

    # pretty_print makes the serializer insert line breaks and indentation.
    xml = tostring(node_root, pretty_print=True)

    # splitext removes only the final extension, so "a.1.jpg" and "a.2.jpg"
    # map to distinct files instead of both collapsing to "a.xml".
    file_stem = os.path.splitext(img_shape[3])[0]
    if xml_path:
        os.makedirs(xml_path, exist_ok=True)
    filename = os.path.join(xml_path, "{}.xml".format(file_stem))
    # The context manager closes the handle even if the write fails.
    with open(filename, "wb") as xml_file:
        xml_file.write(xml)
def draw_bounding_box(img, class_name, confidence, x, y, x_plus_w, y_plus_h,color):
    """
    Draw one detection on ``img``: a rectangle plus a text caption.

    :param img: image handed to the cv2 drawing calls.
    :param class_name: text shown first in the caption.
    :param confidence: number rendered with two decimals, in parentheses,
        after the class name.
    :param x: x coordinate of the first rectangle corner.
    :param y: y coordinate of the first rectangle corner.
    :param x_plus_w: x coordinate of the opposite rectangle corner.
    :param y_plus_h: y coordinate of the opposite rectangle corner.
    :param color: colour used for both the rectangle and the caption.
    :return: None
    """
    caption = f'{class_name} ({confidence:.2f})'
    top_left = (x, y)
    bottom_right = (x_plus_w, y_plus_h)
    # The caption is anchored 10 units up and to the left of the first corner.
    caption_origin = (x - 10, y - 10)
    thickness = 2
    cv2.rectangle(img, top_left, bottom_right, color, thickness)
    cv2.putText(
        img,
        caption,
        caption_origin,
        cv2.FONT_HERSHEY_SIMPLEX,
        0.5,
        color,
        thickness,
    )
def main(weights_path, img_path, xml_path, outputs_path):
    """
    Run a YOLO model over every entry in a directory and save the results.

    For each readable image the detections are drawn onto the image, the
    annotated image is written to ``outputs_path`` under the same file name,
    and, when at least one box was detected, an XML annotation file is
    written to ``xml_path`` via ``create_xml``. Entries that cannot be
    decoded as an image are reported and skipped.

    :param weights_path: path to the model weights passed to ``YOLO``.
    :param img_path: directory whose entries are processed.
    :param xml_path: directory that receives the XML annotation files.
    :param outputs_path: directory that receives the annotated images.
    :return: None
    """
    model = YOLO(weights_path)
    # One dummy forward pass up front: its result exposes the class-name
    # table, which is needed to size the colour palette.
    arr = torch.ones(1, 3, 224, 224)
    result_init = model(arr)[0]
    # One random colour per class, used when drawing boxes.
    colors = np.random.uniform(0, 255, size=(len(result_init.names), 3))

    for name in os.listdir(img_path):
        t0 = time.time()
        original_image = cv2.imread(os.path.join(img_path, name))
        if original_image is None:
            # cv2.imread signals an unreadable or non-image entry by
            # returning None rather than raising, so skip it explicitly.
            print("img name: {} skipped: not a readable image".format(name))
            continue
        img_shape = (
            original_image.shape[0],
            original_image.shape[1],
            original_image.shape[2],
            name,
        )

        # Predict on the image and move the boxes to host-side numpy arrays.
        results = model(original_image)[0]
        boxes = results.boxes.cpu().numpy()
        class_names = results.names

        xyxy_cls = []
        for xyxy, cls_value, conf in zip(boxes.xyxy, boxes.cls, boxes.conf):
            x_min, y_min, x_max, y_max = (int(xyxy[k]) for k in range(4))
            # The class id arrives as a float; use the int form for every
            # lookup so names and colours are indexed the same way.
            class_id = int(cls_value)
            class_name = class_names[class_id]
            xyxy_cls.append([x_min, y_min, x_max, y_max, class_name])
            draw_bounding_box(
                original_image,
                class_name,
                conf,
                x_min,
                y_min,
                x_max,
                y_max,
                colors[class_id],
            )

        # Only images with at least one detection get an annotation file.
        if xyxy_cls:
            create_xml(xyxy_cls, img_shape, xml_path)

        t1 = time.time()
        print("img name: {} infer:{:4f} ms".format(name, (t1 - t0) * 1000))
        # Keep the annotated result image.
        cv2.imwrite(os.path.join(outputs_path, name), original_image)
if __name__ == "__main__":
    # Imported here because argparse is only needed when run as a script.
    import argparse

    # Every path can be overridden on the command line; the defaults are the
    # locations this script was originally hard-coded to use.
    parser = argparse.ArgumentParser(
        description="Run YOLO detection on a folder of images and export "
                    "annotated images plus XML label files."
    )
    parser.add_argument(
        "--weights",
        # Alternative used previously: a custom-trained checkpoint, e.g.
        # C:\Users\Administrator\Desktop\train24\weights\best.pt
        default=r"C:\Users\Administrator\Desktop\ultralytics-main\yolov8s.pt",
        help="path to the model weights file",
    )
    parser.add_argument(
        "--images",
        default=r"C:\Users\Administrator\Desktop\yolov5-label-xml-main\inference\images",
        help="directory containing the input images",
    )
    parser.add_argument(
        "--xmls",
        default=r"C:\Users\Administrator\Desktop\yolov5-label-xml-main\inference\xmlss",
        help="directory that receives the generated XML files",
    )
    parser.add_argument(
        "--outputs",
        default="./",
        help="directory that receives the annotated images",
    )
    args = parser.parse_args()

    weights_path = args.weights  # model weights location
    imgs_path = args.images  # input image directory
    xmls_path = args.xmls  # output directory for XML files
    outputs_path = args.outputs  # output directory for annotated images
    main(weights_path, imgs_path, xmls_path, outputs_path)