增加motorcycle,bus和truck检测,优化车牌检测,增加单元测试

main
zhoujinjuan
parent 2f43ce101e
commit 7b35961cce

@ -26,4 +26,12 @@
### 跳帧
- 相比不跳帧,耗时明显减少,筛选得到的结果也明显减少;
- gpu使用率降低
- gpu使用率降低
### 更新
- 2024-01-25:对被执法车加入 truck、motorcycle、bus 类别;增加摩托车车牌识别规则(车牌内容格式为两行)
### 部分视频无法检测人或车原因
- 人脸相似个数太少,导致合并时间段时被过滤;
- yolo对车目标检测置信度太低或直接没有检测到导致不会进行下一步ocr检测
- 车牌不够清晰导致ocr检测无结果

@ -0,0 +1,24 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
'''
@Project Filter_Object
@File ocr_alg_test.py
@IDE PyCharm
@Author zjj
@Date 2024/1/23 13:26
'''
import unittest
import cv2
from algs.ocr_alg import ocr_predict
class OcrTest(unittest.TestCase):
    """Unit test for the OCR prediction entry point (algs.ocr_alg.ocr_predict)."""

    def test_ocr_predict(self):
        """Run OCR on a sample truck image and save the visualization.

        Asserts the fixture image loads and that ocr_predict returns a
        result — a test with only print() can never fail.
        """
        frame = cv2.imread('../data/truck3.png')
        # cv2.imread returns None (no exception) when the path is missing
        self.assertIsNotNone(frame, 'test image ../data/truck3.png not found')
        result, vis_im = ocr_predict(frame)
        print(result)
        self.assertIsNotNone(result)
        cv2.imwrite('../data/ocr_result.png', vis_im)

@ -13,7 +13,12 @@ from pathlib import Path
PERSON = "person"
CAR = "car"
LABEL_NAMES = {0: PERSON, 2: CAR}
TRUCK = 'truck'
MOTORCYCLE = 'motorcycle'
BUS = 'bus'
MOTOR_VEHICLE = 'car_related'
MOTOR_VEHICLE_LIST = [CAR, TRUCK, MOTORCYCLE, BUS]
LABEL_NAMES = {0: PERSON, 2: CAR, 3: MOTORCYCLE, 5: BUS, 7: TRUCK}
# 过程中保存图片路径
# SAVE_BASE_DIR = Path('/home/TP/PMP/media/tp_result')
@ -69,11 +74,11 @@ TOPN = {PERSON: 5, CAR: 2}
# 人脸检测阈值 0.85
FACE_THRESHOLD = 0.85
FACE_THRESHOLD = 0.83
# 目标检测置信度阈值 0.8
DETECT_CONF_THRESHOLD = 0.8
DETECT_CONF_THRESHOLD = 0.6
# 框面积阈值 todo 存在有的被执法人就是离镜头远或者模型画的框小 15
BOX_AREA_THRESHOLD = 15
BOX_AREA_THRESHOLD = 5
# 每一帧内处理面积top2个大框图
TOPN_AREA = 2
@ -83,3 +88,6 @@ POLICE_IOU = 0.5
# 人脸清晰度阈值 0-100之间
ARTICULATION_THD = 30
# 摩托车两行车牌文字box的接近程度阈值:上面box的ymax与下面box的ymin差值的绝对值需小于该值
DELTA_Y = 10

@ -19,7 +19,7 @@ from config import (
LABEL_NAMES,
BOX_AREA_THRESHOLD,
DETECT_CONF_THRESHOLD,
ARTICULATION_THD,
ARTICULATION_THD, MOTOR_VEHICLE, MOTOR_VEHICLE_LIST,
)
from fd_face_detection import FaceRecognition
from fd_yolo import FdYolov8
@ -72,7 +72,7 @@ class TrackMain(object):
:return:
"""
xywhs, labels, scores, xy_list = infos
tinfos_each_frame = {PERSON: [], CAR: []}
tinfos_each_frame = {PERSON: [], MOTOR_VEHICLE: []}
for xywh, label_id, score, polygon in zip(
xywhs, labels, scores, xy_list
@ -93,6 +93,7 @@ class TrackMain(object):
conf = round(score, 2)
# 通过面积过滤掉一些小框
# todo 预筛选阈值不能太高后面还会有topN过滤
if s <= BOX_AREA_THRESHOLD:
continue
if conf <= DETECT_CONF_THRESHOLD:
@ -104,7 +105,11 @@ class TrackMain(object):
info['conf'] = conf
info['box_area'] = s
info['polygon_indexs'] = polygon.astype(int)
tinfos_each_frame[label].append(info)
# 机动车包含car,truck,motorcycle,bus
if label in MOTOR_VEHICLE_LIST:
tinfos_each_frame[MOTOR_VEHICLE].append(info)
else:
tinfos_each_frame[label].append(info)
return tinfos_each_frame
@ -124,12 +129,12 @@ class TrackMain(object):
all_licenses = []
# 存储一帧中所有人脸embedding
all_face_embeddings = []
for label, target_infos in tinfos_each_frame.items():
for label_alias, target_infos in tinfos_each_frame.items():
# 按照框面积大小降序排序
target_infos.sort(key=lambda x: x['box_area'], reverse=True)
# todo 交警 先对整个图片上前topn的大框图中的person送到交警检测模型中检测是否是交警
if label == PERSON:
if label_alias == PERSON:
police_indexs = get_police(
frame_copy, target_infos[:TOPN_AREA], self.model_traffic_police
)
@ -145,7 +150,7 @@ class TrackMain(object):
target_img = frame[p1[1]: p2[1], p1[0]: p2[0]]
target_img = target_img.astype(np.uint8)
if label == CAR:
if label_alias == MOTOR_VEHICLE:
# licenses = predict_ocr(target_img, self.ocr)
licenses = predict_ocr(target_img)
licenses = list(set(licenses))
@ -153,7 +158,7 @@ class TrackMain(object):
if licenses:
is_hit = True
elif label == PERSON:
elif label_alias == PERSON:
# 是交警,则不处理
if index in police_indexs:
continue
@ -174,7 +179,7 @@ class TrackMain(object):
# test
if is_hit:
frame_copy = draw_rectangle_text(
frame_copy, index, p1, p2, label, info['conf'], -1, info['box_area']
frame_copy, index, p1, p2, label_alias, info['conf'], -1, info['box_area']
)
return all_face_embeddings, all_licenses, frame_copy

@ -16,10 +16,13 @@ import cv2
from main_v2 import TrackMain
from utils import extract_yolo_results
file_name = 'AM03778_09060120200408073818_0015A.MP4'
def make_result_dir():
# 保存图片,可视化
dir_name = str(uuid.uuid4())[:7]
# dir_name = str(uuid.uuid4())[:7]
dir_name = file_name
dirx = Path(r'/mnt/large/zhoujinjuan_data/data/result') / dir_name
if not dirx.exists():
dirx.mkdir()
@ -30,17 +33,17 @@ def make_result_dir():
class TestMain(unittest.TestCase):
def setUp(self):
    """Build the detection pipeline under test and open the input video.

    Runs before every test method (unittest contract).
    """
    self.use_trt = True
    # NOTE(review): duplicate assignments below look like diff residue
    # rendered into the page; at runtime the later value wins (half=False).
    self.half = True
    self.half = False
    self.obj = TrackMain(device='gpu', use_trt=self.use_trt, half=self.half)
    # NOTE(review): the first path is immediately overwritten by the second.
    video_path = os.path.join('/mnt/large/zhoujinjuan_data/data/tp_videos/2.mp4')
    video_path = os.path.join('/mnt/large/zhoujinjuan_data/data/tp_videos/20240122-test/' + file_name)
    self.cap = cv2.VideoCapture(video_path)
    # All qualifying face embeddings collected over the whole video
    self.all_face_embeds = []
    # All qualifying license plate numbers collected over the whole video
    self.all_licenses_list = []
    self.frame_count = 0
    # Process one frame every x frames
    self.skip_frame = 9
    # Frame skipping: process one frame every x frames (later value wins)
    self.skip_frame = 1
    print('init success')
def process_one_frame_v2(self, frame):
@ -48,13 +51,33 @@ class TestMain(unittest.TestCase):
# 提取模型检测结果 infos=[xywhs, cls, scores, xy_list]
is_hit, infos = extract_yolo_results(results)
if not is_hit:
print('no target detect')
return [], [], frame
# yolo检测结果可视化
# annotated_frame = results[0].plot()
# cv2.imwrite('data/yolo_result.png', annotated_frame)
# 阈值筛选
tinfos_each_frame = self.obj.statistics_one_frame(infos)
# print(tinfos_each_frame)
# 保持原接口不动比process_one_frame 多返回一个frame_copy
face_embeddings, licenses, frame_copy = self.obj.process_topn_in_one_frame(frame, tinfos_each_frame)
return face_embeddings, licenses, frame_copy
def test_process_one_frame(self):
    """
    Check license-plate detection on single still images.
    :return:
    """
    for index, name in enumerate(['truck2.png']):
        pathx = 'data/' + name
        frame = cv2.imread(pathx)
        # cv2.imread returns None (no exception) when the file is missing
        self.assertIsNotNone(frame, f'missing test image {pathx}')
        _, licenses, frame_copy = self.process_one_frame_v2(frame)
        print(f'licenses {licenses}')
        # a test without assertions can never fail; at least pin the type
        self.assertIsInstance(licenses, list)
        cv2.imwrite('data/result_' + str(index) + '.png', frame_copy)
def test_main(self):
print('begin test')
# 先创建结果目录
@ -79,7 +102,7 @@ class TestMain(unittest.TestCase):
cv2.imwrite((result_dir / ('face_' + name + '.png')).as_posix(), frame_copy)
elif not face_embeds and license_list:
name = str(uuid.uuid4())[:7]
cv2.imwrite((result_dir / ('car_' + name + '.png')).as_posix(), frame_copy)
cv2.imwrite((result_dir / ('motorVehicle_' + name + '.png')).as_posix(), frame_copy)
elif face_embeds and license_list:
name = str(uuid.uuid4())[:7]
cv2.imwrite((result_dir / ('all_' + name + '.png')).as_posix(), frame_copy)
@ -92,6 +115,7 @@ class TestMain(unittest.TestCase):
print(f'frame count {self.frame_count}\nactual deal frame count {actual_frame_number}\ntime {t}s'
f'\nface embeddings count {len(self.all_face_embeds)}'
f'\nlicenses count {len(self.all_licenses_list)}')
print(self.all_licenses_list)
print(f'use_trt={self.use_trt}\nhalf={self.half}\n')

@ -27,7 +27,7 @@ from config import (
FACE_COMPARE_THRESHOLD,
ARTICULATION_MODEL_PATH,
ARTICULATION_RANGE_PATH,
POLICE_IOU,
POLICE_IOU, DELTA_Y,
)
@ -497,12 +497,65 @@ def predict_ocr(frame):
# OCR检测
license_plate_list = []
ocr_result = ocr_alg.ocr_predict(frame)
# 赋予空值
# 常规车牌号识别
for txt in ocr_result[0].text:
# 车牌
plate_num = parse_plate_number(txt)
if plate_num != None:
license_plate_list.append(plate_num)
if len(txt) in [7, 8, 9]:
plate_num = parse_plate_number(txt)
if plate_num is not None:
# 中间漏点的补上,通常是大卡车车厢上的车牌被识别
if len(txt) == 7:
plate_num = plate_num[:2] + '·' + plate_num[2:]
license_plate_list.append(plate_num)
# todo 20240124 摩托车车牌规则匹配,通常是上下两行文本(不管上面是否检测到车牌)
moto_list = parse_moto_number(ocr_result)
license_plate_list.extend(moto_list)
return license_plate_list
# Parse a standard (single-line) license plate number with a regex
def parse_plate_number(txt):
    """Return *txt* if it is a well-formed Chinese license plate, else None.

    Accepts 7-9 character plates: a province character (or A-Z), a letter,
    an optional '·' separator, a 4- or 5-char alphanumeric body, and a
    final character that may also be 挂/学/警/港/澳 (trailer/learner/
    police/HK/Macao suffixes).

    :param txt: OCR-recognized text for one detected text box
    :return: the matched plate string, or None when txt is not a plate
    """
    # Single pattern replacing two near-identical ones: {4,5} covers both
    # the 4-char and 5-char body variants (redundant {1} quantifiers dropped).
    pattern = "^[京津沪渝冀豫云辽黑湘皖鲁新苏浙赣鄂桂甘晋蒙陕吉闽贵粤青藏川宁琼使领A-Z][A-Z][·]?[A-Z0-9]{4,5}[A-Z0-9挂学警港澳]$"
    match = re.match(pattern, txt)
    # explicit None on no-match (was an implicit fall-through before)
    return match.group() if match else None
def parse_moto_number(ocr_result):
    """Reconstruct motorcycle plates, which OCR usually reads as two
    stacked text lines: a short province+letter prefix on top and the
    serial body below.

    :param ocr_result: OCR output; ocr_result[0].boxes / .text are parallel
    :return: list of reconstructed plate strings (may be empty)
    """
    plates = []
    p_before = "^[京津沪渝冀豫云辽黑湘皖鲁新苏浙赣鄂桂甘晋蒙陕吉闽贵粤青藏川宁琼使领A-Z]{1}[·]?[A-Z]{1}$"
    p_end = "^[A-Z0-9]{4}[A-Z0-9挂学警港澳]{1}$"
    p_end1 = "^[A-Z0-9]{5}[A-Z0-9挂学警港澳]{1}$"
    pairs = [[box, text] for box, text in zip(ocr_result[0].boxes, ocr_result[0].text)]
    for idx, (top_box, top_text) in enumerate(pairs):
        # Cheap length pre-filter before paying for the regex
        if len(top_text) not in (2, 3):
            continue
        head = re.match(p_before, top_text)
        if head is None:
            continue
        prefix = head.group()
        if len(prefix) == 2:
            # Normalize a dot-less prefix ("苏A" -> "苏·A")
            prefix = prefix[0] + '·' + prefix[1]
        # Only boxes after the current one are candidate bottom lines
        for low_box, low_text in pairs[idx + 1:]:
            if len(low_text) not in (5, 6):
                continue
            # Require a candidate's top y (box[1] or box[3]) to sit within
            # DELTA_Y of this box's bottom y (box[5]).
            if abs(low_box[1] - top_box[5]) >= DELTA_Y and abs(low_box[3] - top_box[5]) >= DELTA_Y:
                continue
            tail = re.match(p_end, low_text) or re.match(p_end1, low_text)
            if tail:
                plates.append(prefix + tail.group())
    return plates
@ -573,18 +626,6 @@ def output_car_targets(left_cars, save_dir, output_car_dir):
return paths
# Parse a license plate number with regexes (separator '·' is mandatory here)
def parse_plate_number(txt):
    """Return *txt* unchanged when it is a standard plate with a mandatory
    '·' separator and a 4- or 5-char serial body; otherwise return None.
    """
    patterns = (
        "^[京津沪渝冀豫云辽黑湘皖鲁新苏浙赣鄂桂甘晋蒙陕吉闽贵粤青藏川宁琼使领A-Z]{1}[A-Z]{1}[·]{1}[A-Z0-9]{4}[A-Z0-9挂学警港澳]{1}$",
        "^[京津沪渝冀豫云辽黑湘皖鲁新苏浙赣鄂桂甘晋蒙陕吉闽贵粤青藏川宁琼使领A-Z]{1}[A-Z]{1}[·]{1}[A-Z0-9]{5}[A-Z0-9挂学警港澳]{1}$",
    )
    for pat in patterns:
        hit = re.match(pat, txt)
        if hit:
            return hit.group()
    return None
def det_articulation(image=None, img_path=None):
if image is None and not img_path:
return 0

@ -0,0 +1,37 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
'''
@Project Filter_Object
@File utils_test.py
@IDE PyCharm
@Author zjj
@Date 2024/1/25 9:56
'''
import unittest
import cv2
from algs import ocr_alg
from utils import parse_moto_number, predict_ocr, parse_plate_number
class UtilsTest(unittest.TestCase):
    """Unit tests for the license-plate parsing helpers in utils."""

    def test_predict_ocr(self):
        """End-to-end OCR + plate parsing on a sample truck image."""
        frame = cv2.imread('data/truck1.png')
        # cv2.imread returns None (no exception) when the path is missing
        self.assertIsNotNone(frame, 'test image data/truck1.png not found')
        result = predict_ocr(frame)
        print(result)
        self.assertIsInstance(result, list)

    def test_parse_plate_number(self):
        """A valid 7-char plate without the '·' separator parses as-is."""
        text = '苏J7272Z'
        result = parse_plate_number(text)
        print(result)
        # pin the exact value instead of only printing it
        self.assertEqual(result, text)

    def test_parse_moto_number(self):
        """Two-line motorcycle plate reconstruction from raw OCR output."""
        frame = cv2.imread('data/moto2.jpg')
        self.assertIsNotNone(frame, 'test image data/moto2.jpg not found')
        ocr_result = ocr_alg.ocr_predict(frame)
        moto_list = parse_moto_number(ocr_result)
        print(moto_list)
        self.assertIsInstance(moto_list, list)
Loading…
Cancel
Save