From 7b35961cce05904e03f3ffa4ae25c3fb458fff82 Mon Sep 17 00:00:00 2001 From: zhoujinjuan <1033171360@qq.com> Date: Mon, 29 Jan 2024 16:00:27 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0motorcycle=EF=BC=8Cbus?= =?UTF-8?q?=E5=92=8Ctruck=E6=A3=80=E6=B5=8B=EF=BC=8C=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E8=BD=A6=E7=89=8C=E6=A3=80=E6=B5=8B=EF=BC=8C=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E5=8D=95=E5=85=83=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 10 +++++- src/algs/ocr_alg_test.py | 24 +++++++++++++ src/config.py | 16 ++++++--- src/main_v2.py | 21 ++++++----- src/main_v2_test.py | 36 +++++++++++++++---- src/utils.py | 75 +++++++++++++++++++++++++++++++--------- src/utils_test.py | 37 ++++++++++++++++++++ 7 files changed, 183 insertions(+), 36 deletions(-) create mode 100644 src/algs/ocr_alg_test.py create mode 100644 src/utils_test.py diff --git a/README.md b/README.md index 238fa84..f46f602 100644 --- a/README.md +++ b/README.md @@ -26,4 +26,12 @@ ### 跳帧 - 相比不跳帧,耗时明显减少,筛选得到的结果也明显减少; -- gpu使用率降低; \ No newline at end of file +- gpu使用率降低; + +### 更新 +- 20240125:对被执法车加入truck,motorcycle,bus类别,增加摩托车车牌识别规则(车牌内容格式为两行); + +### 部分视频无法检测人或车原因 +- 人脸相似个数太少,导致合并时间段时被过滤; +- yolo对车目标检测置信度太低,或直接没有检测到,导致不会进行下一步ocr检测; +- 车牌不够清晰导致ocr检测无结果; \ No newline at end of file diff --git a/src/algs/ocr_alg_test.py b/src/algs/ocr_alg_test.py new file mode 100644 index 0000000..e7bb479 --- /dev/null +++ b/src/algs/ocr_alg_test.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- +''' +@Project :Filter_Object +@File :ocr_alg_test.py +@IDE :PyCharm +@Author :zjj +@Date :2024/1/23 13:26 +''' +import unittest + +import cv2 + +from algs.ocr_alg import ocr_predict + + +class OcrTest(unittest.TestCase): + + @staticmethod + def test_ocr_predict(): + frame = cv2.imread('../data/truck3.png') + result, vis_im = ocr_predict(frame) + print(result) + cv2.imwrite('../data/ocr_result.png', vis_im) diff --git a/src/config.py b/src/config.py index 6c40209..2da365c 100644 --- a/src/config.py +++ b/src/config.py @@ -13,7 +13,12 @@ from pathlib import Path PERSON = "person" CAR = "car" -LABEL_NAMES = {0: PERSON, 2: CAR} +TRUCK = 'truck' +MOTORCYCLE = 'motorcycle' +BUS = 'bus' +MOTOR_VEHICLE = 'car_related' +MOTOR_VEHICLE_LIST = [CAR, TRUCK, MOTORCYCLE, BUS] +LABEL_NAMES = {0: PERSON, 2: CAR, 3: MOTORCYCLE, 5: BUS, 7: TRUCK} # 过程中保存图片路径 # SAVE_BASE_DIR = Path('/home/TP/PMP/media/tp_result') @@ -69,11 +74,11 @@ TOPN = {PERSON: 5, CAR: 2} # 人脸检测阈值 0.85 -FACE_THRESHOLD = 0.85 +FACE_THRESHOLD = 0.83 # 目标检测置信度阈值 0.8 -DETECT_CONF_THRESHOLD = 0.8 +DETECT_CONF_THRESHOLD = 0.6 # 框面积阈值 todo 存在有的被执法人就是离镜头远或者模型画的框小 15 -BOX_AREA_THRESHOLD = 15 +BOX_AREA_THRESHOLD = 5 # 每一帧内处理面积top2个大框图 TOPN_AREA = 2 @@ -83,3 +88,6 @@ POLICE_IOU = 0.5 # 人脸清晰度阈值 0-100之间 ARTICULATION_THD = 30 + +# 摩托车两行车牌文字box的接近程度(上面box的ymax和下面box的ymin差值绝对值) +DELTA_Y = 10 \ No newline at end of file diff --git a/src/main_v2.py b/src/main_v2.py index bc33fdd..1d1f7b7 100644 --- a/src/main_v2.py +++ b/src/main_v2.py @@ -19,7 +19,7 @@ from config import ( LABEL_NAMES, BOX_AREA_THRESHOLD, DETECT_CONF_THRESHOLD, - ARTICULATION_THD, + ARTICULATION_THD, MOTOR_VEHICLE, MOTOR_VEHICLE_LIST, ) from fd_face_detection import FaceRecognition from fd_yolo import FdYolov8 @@ -72,7 +72,7 @@ class TrackMain(object): :return: """ xywhs, labels, scores, xy_list = infos - tinfos_each_frame = {PERSON: [], CAR: []} + tinfos_each_frame = {PERSON: [], MOTOR_VEHICLE: []} for xywh, label_id, score, polygon in zip( xywhs, labels, scores, xy_list @@ -93,6 +93,7 @@ class TrackMain(object): conf = round(score, 2) # 通过面积过滤掉一些小框 + # todo 预筛选,阈值不能太高,后面还会有topN过滤 if s <= BOX_AREA_THRESHOLD: continue if conf <= DETECT_CONF_THRESHOLD: @@ -104,7 +105,11 @@ class TrackMain(object): info['conf'] = conf info['box_area'] = s info['polygon_indexs'] = polygon.astype(int) - tinfos_each_frame[label].append(info) + # 机动车包含car,truck,motorcycle,bus + if label in MOTOR_VEHICLE_LIST: + tinfos_each_frame[MOTOR_VEHICLE].append(info) + else: + tinfos_each_frame[label].append(info) return tinfos_each_frame @@ -124,12 +129,12 @@ class TrackMain(object): all_licenses = [] # 存储一帧中所有人脸embedding all_face_embeddings = [] - for label, target_infos in tinfos_each_frame.items(): + for label_alias, target_infos in tinfos_each_frame.items(): # 按照框面积大小降序排序 target_infos.sort(key=lambda x: x['box_area'], reverse=True) # todo 交警 先对整个图片上前topn的大框图中的person,送到交警检测模型中检测是否是交警 - if label == PERSON: + if label_alias == PERSON: police_indexs = get_police( frame_copy, target_infos[:TOPN_AREA], self.model_traffic_police ) @@ -145,7 +150,7 @@ class TrackMain(object): target_img = frame[p1[1]: p2[1], p1[0]: p2[0]] target_img = target_img.astype(np.uint8) - if label == CAR: + if label_alias == MOTOR_VEHICLE: # licenses = predict_ocr(target_img, self.ocr) licenses = predict_ocr(target_img) licenses = list(set(licenses)) @@ -153,7 +158,7 @@ class TrackMain(object): if licenses: is_hit = True - elif label == PERSON: + elif label_alias == PERSON: # 是交警,则不处理 if index in police_indexs: continue @@ -174,7 +179,7 @@ class TrackMain(object): # test if is_hit: frame_copy = draw_rectangle_text( - frame_copy, index, p1, p2, label, info['conf'], -1, info['box_area'] + frame_copy, index, p1, p2, label_alias, info['conf'], -1, info['box_area'] ) return all_face_embeddings, all_licenses, frame_copy diff --git a/src/main_v2_test.py b/src/main_v2_test.py index b1af5a8..595a362 100644 --- a/src/main_v2_test.py +++ b/src/main_v2_test.py @@ -16,10 +16,13 @@ import cv2 from main_v2 import TrackMain from utils import extract_yolo_results +file_name = 'AM03778_09060120200408073818_0015A.MP4' + def make_result_dir(): # 保存图片,可视化 - dir_name = str(uuid.uuid4())[:7] + # dir_name = str(uuid.uuid4())[:7] + dir_name = file_name dirx = Path(r'/mnt/large/zhoujinjuan_data/data/result') / dir_name if not dirx.exists(): dirx.mkdir() @@ -30,17 +33,17 @@ def make_result_dir(): class TestMain(unittest.TestCase): def setUp(self): self.use_trt = True - self.half = True + self.half = False self.obj = TrackMain(device='gpu', use_trt=self.use_trt, half=self.half) - video_path = os.path.join('/mnt/large/zhoujinjuan_data/data/tp_videos/2.mp4') + video_path = os.path.join('/mnt/large/zhoujinjuan_data/data/tp_videos/20240122-test/' + file_name) self.cap = cv2.VideoCapture(video_path) # 记录视频所有满足条件的人脸 self.all_face_embeds = [] # 记录视频所有满足条件的车牌号 self.all_licenses_list = [] self.frame_count = 0 - # 每x帧处理一帧 - self.skip_frame = 9 + # 跳帧:每x帧处理一帧 + self.skip_frame = 1 print('init success') def process_one_frame_v2(self, frame): @@ -48,13 +51,33 @@ class TestMain(unittest.TestCase): # 提取模型检测结果 infos=[xywhs, cls, scores, xy_list] is_hit, infos = extract_yolo_results(results) if not is_hit: + print('no target detect') return [], [], frame + # yolo检测结果可视化 + # annotated_frame = results[0].plot() + # cv2.imwrite('data/yolo_result.png', annotated_frame) + + # 阈值筛选 tinfos_each_frame = self.obj.statistics_one_frame(infos) + # print(tinfos_each_frame) + # 保持原接口不动,比process_one_frame 多返回一个frame_copy face_embeddings, licenses, frame_copy = self.obj.process_topn_in_one_frame(frame, tinfos_each_frame) return face_embeddings, licenses, frame_copy + def test_process_one_frame(self): + """ + 测试一帧的检测结果 + :return: + """ + for index, name in enumerate(['truck2.png']): + pathx = 'data/' + name + frame = cv2.imread(pathx) + _, licenses, frame_copy = self.process_one_frame_v2(frame) + print(f'licenses {licenses}') + cv2.imwrite('data/result_' + str(index) + '.png', frame_copy) + def test_main(self): print('begin test') # 先创建结果目录 @@ -79,7 +102,7 @@ class TestMain(unittest.TestCase): cv2.imwrite((result_dir / ('face_' + name + '.png')).as_posix(), frame_copy) elif not face_embeds and license_list: name = str(uuid.uuid4())[:7] - cv2.imwrite((result_dir / ('car_' + name + '.png')).as_posix(), frame_copy) + cv2.imwrite((result_dir / ('motorVehicle_' + name + '.png')).as_posix(), frame_copy) elif face_embeds and license_list: name = str(uuid.uuid4())[:7] cv2.imwrite((result_dir / ('all_' + name + '.png')).as_posix(), frame_copy) @@ -92,6 +115,7 @@ class TestMain(unittest.TestCase): print(f'frame count {self.frame_count}\nactual deal frame count {actual_frame_number}\ntime {t}s' f'\nface embeddings count {len(self.all_face_embeds)}' f'\nlicenses count {len(self.all_licenses_list)}') + print(self.all_licenses_list) print(f'use_trt={self.use_trt}\nhalf={self.half}\n') diff --git a/src/utils.py b/src/utils.py index dbd5aed..31dec81 100644 --- a/src/utils.py +++ b/src/utils.py @@ -27,7 +27,7 @@ from config import ( FACE_COMPARE_THRESHOLD, ARTICULATION_MODEL_PATH, ARTICULATION_RANGE_PATH, - POLICE_IOU, + POLICE_IOU, DELTA_Y, ) @@ -497,12 +497,65 @@ def predict_ocr(frame): # OCR检测 license_plate_list = [] ocr_result = ocr_alg.ocr_predict(frame) - # 赋予空值 + # 常规车牌号识别 for txt in ocr_result[0].text: # 车牌 - plate_num = parse_plate_number(txt) - if plate_num != None: - license_plate_list.append(plate_num) + if len(txt) in [7, 8, 9]: + plate_num = parse_plate_number(txt) + if plate_num is not None: + # 中间漏点的补上,通常是大卡车车厢上的车牌被识别 + if len(txt) == 7: + plate_num = plate_num[:2] + '·' + plate_num[2:] + license_plate_list.append(plate_num) + + # todo 20240124 摩托车车牌规则匹配,通常是上下两行文本(不管上面是否检测到车牌) + moto_list = parse_moto_number(ocr_result) + + license_plate_list.extend(moto_list) + return license_plate_list + + +# 使用正则解析车牌号 +def parse_plate_number(txt): + pattern = "^[京津沪渝冀豫云辽黑湘皖鲁新苏浙赣鄂桂甘晋蒙陕吉闽贵粤青藏川宁琼使领A-Z]{1}[A-Z]{1}[·]?[A-Z0-9]{4}[A-Z0-9挂学警港澳]{1}$" + pattern1 = "^[京津沪渝冀豫云辽黑湘皖鲁新苏浙赣鄂桂甘晋蒙陕吉闽贵粤青藏川宁琼使领A-Z]{1}[A-Z]{1}[·]?[A-Z0-9]{5}[A-Z0-9挂学警港澳]{1}$" + match = re.match(pattern, txt) + match1 = re.match(pattern1, txt) + if match: + return match.group() + if match1: + return match1.group() + + +def parse_moto_number(ocr_result): + """ + 摩托车车牌规则匹配,通常是上下两行文本 + :param ocr_result: + :return: + """ + license_plate_list = [] + p_before = "^[京津沪渝冀豫云辽黑湘皖鲁新苏浙赣鄂桂甘晋蒙陕吉闽贵粤青藏川宁琼使领A-Z]{1}[·]?[A-Z]{1}$" + p_end = "^[A-Z0-9]{4}[A-Z0-9挂学警港澳]{1}$" + p_end1 = "^[A-Z0-9]{5}[A-Z0-9挂学警港澳]{1}$" + box_text_list = [[box, text] for box, text in zip(ocr_result[0].boxes, ocr_result[0].text)] + for index, box_text in enumerate(box_text_list): + box, text = box_text + # 文本长度预先过滤 + if len(text) in [2, 3]: + match = re.match(p_before, text) + if match: + before = match.group() + if len(before) == 2: + before = before[0] + '·' + before[1] + for box_j, text_j in box_text_list[index + 1:]: + # 先用文本长度过滤,然后用上面box的ymax和下面box的ymin差距小于阈值筛选 + if len(text_j) in [5, 6] and (abs(box_j[1] - box[5]) < DELTA_Y or abs(box_j[3] - box[5]) < DELTA_Y): + match_j = re.match(p_end, text_j) + if not match_j: + match_j = re.match(p_end1, text_j) + if match_j: + license_plate_list.append(before + match_j.group()) + return license_plate_list @@ -573,18 +626,6 @@ def output_car_targets(left_cars, save_dir, output_car_dir): return paths -# 使用正则解析车牌号 -def parse_plate_number(txt): - pattern = "^[京津沪渝冀豫云辽黑湘皖鲁新苏浙赣鄂桂甘晋蒙陕吉闽贵粤青藏川宁琼使领A-Z]{1}[A-Z]{1}[·]{1}[A-Z0-9]{4}[A-Z0-9挂学警港澳]{1}$" - pattern1 = "^[京津沪渝冀豫云辽黑湘皖鲁新苏浙赣鄂桂甘晋蒙陕吉闽贵粤青藏川宁琼使领A-Z]{1}[A-Z]{1}[·]{1}[A-Z0-9]{5}[A-Z0-9挂学警港澳]{1}$" - match = re.match(pattern, txt) - match1 = re.match(pattern1, txt) - if match: - return match.group() - if match1: - return match1.group() - - def det_articulation(image=None, img_path=None): if image is None and not img_path: return 0 diff --git a/src/utils_test.py b/src/utils_test.py new file mode 100644 index 0000000..3a14840 --- /dev/null +++ b/src/utils_test.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- +''' +@Project :Filter_Object +@File :utils_test.py +@IDE :PyCharm +@Author :zjj +@Date :2024/1/25 9:56 +''' +import unittest + +import cv2 + +from algs import ocr_alg +from utils import parse_moto_number, predict_ocr, parse_plate_number + + +class UtilsTest(unittest.TestCase): + + @staticmethod + def test_predict_ocr(): + frame = cv2.imread('data/truck1.png') + result = predict_ocr(frame) + print(result) + + @staticmethod + def test_parse_plate_number(): + text = '苏J7272Z' + result = parse_plate_number(text) + print(result) + + @staticmethod + def test_parse_moto_number(): + frame = cv2.imread('data/moto2.jpg') + ocr_result = ocr_alg.ocr_predict(frame) + moto_list = parse_moto_number(ocr_result) + print(moto_list) \ No newline at end of file