from analysis_result.get_model_result import det_img
from analysis_result.same_model_img import same_model_img_analysis_labels, model_labels_selet
from model_load.model_load import Load_model
from drawing_img.drawing_img import drawing_frame
from analysis_data.data_rtsp import rtsp_para
from analysis_data.data_dir_file import get_dir_file
from analysis_data.config_load import get_configs
from add_xml import add_xml
from create_xml import create_xml
from analysis_data.change_video import mp4_to_H264
import cv2
import os
import time
from datetime import datetime
import json
from loguru import logger
import logging
import logstash
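# Logstash logger used to push detection events to the log collector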
host = '192.168.10.96'
xbank_logger = logging.getLogger('python-logstash-logger')
xbank_logger.setLevel(logging.INFO)
xbank_logger.addHandler(logstash.LogstashHandler(host, 5959, version=1))
def data_load(args):
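    """Dispatch the input source to the matching processing pipeline.

    args[0] is the data source (an RTSP URL, a directory, or a single image file)
    and args[1] is the path of the model YAML configuration.
    """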
    source = args[0]
    model_ymal = args[1]

    # Determine the type of the input source
    rtsp_source = rtsp_para(source)
    dir_source = os.path.isdir(source)
    file_source = os.path.isfile(source)

    # Load the model configuration and weights
    model_data = get_configs(model_ymal)
    model_inference = Load_model(model_file=model_data["model"],
                                 device=model_data["model_parameter"]['device'],
                                 cache_file=model_data["model_cache"])

    if rtsp_source:
        rtsp_detect_process(source=source, model_data=model_data,
                            model_inference=model_inference)
    if dir_source:
        dir_source_process(source, model_inference, model_data)
    if file_source:
        file_source_process(source, model_inference, model_data)

def rtsp_detect_process(source, model_data, model_inference):
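    """Run continuous detection on an RTSP stream.

    Frames are buffered into a temporary video. Detections are evaluated over a
    short window of `detect_time_small` seconds; every `detect_time` seconds the
    buffered video, key frames and a JSON summary are saved (or the temporary
    video is deleted if nothing was detected) and a Logstash message is sent.
    """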
    cap = cv2.VideoCapture(source)
    logger.info(f"Reading video stream {source}...")

    # Video stream properties
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    fps_num = fps * model_data['detect_time']
    fps_num_small = fps * model_data['detect_time_small']
    size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
            int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))

    i = 0
    j = 0
    n = 0
    det_t_num = 0
    nodet_t_num = 0
    det_img = []
    video_name_time = 0
    det_fps_time = []

    while True:
        try:
            ret, frame = cap.read()
            i += 1
            j += 1

            # Timestamp of the current frame
            data_now = datetime.now()
            get_time = str(data_now.strftime("%H")) + \
                str(data_now.strftime("%M")) + str(data_now.strftime("%S")) + \
                str(data_now.strftime("%f"))
            imgframe_dict = {"path": source, 'frame': frame,
                             'get_fps': j, 'get_time': get_time}

            # Temporary save path for the buffered video
            if video_name_time == 0:
                video_name_time = get_time
                video_path = video_name(
                    video_name_base=video_name_time, save_path=model_data['save_videos'], save_file='temp')
                out_video = cv2.VideoWriter(
                    video_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, size)
                logger.info(f"Temporary video saved to {video_path}...")

            # Model inference
            images_det_result = img_process(
                imgframe_dict, model_inference, model_data)
            images_update = save_process(
                imgframe_dict, images_det_result, model_data)

            # Evaluate detections over the short window t
            if images_det_result:
                det_t_num += 1
                if len(det_img) == 0:
                    img_dict = images_update.copy()
                    det_img.append(img_dict)
            if not images_det_result and len(det_img) > 0:
                nodet_t_num += 1
                if (det_t_num + nodet_t_num) >= fps_num_small:
                    para = determine_time(
                        det_num=det_t_num, nodet_num=nodet_t_num, ratio_set=model_data['detect_ratio'])
                    if para:
                        first_fps_time = det_img[0]
                        first_fps_time.update(
                            {"dert_fps": (j - int(first_fps_time['get_fps']) + 1)})
                        det_fps_time.append(first_fps_time)
                    det_img.clear()
                    det_t_num = 0
                    nodet_t_num = 0

            # Write the frame to the buffered video
            out_video.write(images_update['frame'])

            # Evaluate detections over the full window T
            if j >= fps_num:
                try:
                    out_video.release()
                except Exception:
                    logger.exception("Failed to release the video writer")
                else:
                    logger.info("Video writer released successfully")

                # When window T ends, evaluate the remaining short-window results
                if det_img:
                    para = determine_time(
                        det_num=det_t_num, nodet_num=nodet_t_num, ratio_set=model_data['detect_ratio'])
                    first_fps_time = det_img[0]
                    time_1 = (j - int(first_fps_time['get_fps']) + 1)
                    if para and time_1 >= (fps_num_small / 2):
                        first_fps_time.update(
                            {"dert_fps": (j - int(first_fps_time['get_fps']) + 1)})
                        det_fps_time.append(first_fps_time)

                if det_fps_time:
                    det_fps_time = determine_duration(result_list=det_fps_time)

                    # Destination path for the converted video
                    save_video_name = os.path.basename(video_path)
                    only_video_name = save_video_name.split('.')[0]
                    save_video_path = os.path.join(model_data['save_videos'], (str(data_now.strftime(
                        "%Y")) + str(data_now.strftime("%m")) + str(data_now.strftime("%d"))))
                    save_video_path = os.path.join(
                        save_video_path, only_video_name)

                    # Output paths
                    save_video = os.path.join(save_video_path, save_video_name)
                    json_path = os.path.join(
                        save_video_path, only_video_name + '.json')
                    images_path = os.path.join(save_video_path, 'images')

                    # Convert the buffered video to H.264 and save it
                    change_video = mp4_to_H264()
                    change_video.convert_byfile(video_path, save_video)

                    # Save key frames
                    update_det_fps = video_cut_images_save(
                        det_list=det_fps_time, images_path=images_path)
                    # print(update_det_fps)

                    # Save the JSON summary and send the notification
                    re_list, result_lables = json_get(
                        time_list=update_det_fps, video_path=save_video, fps=fps)
                    result_path = json_save(re_list, json_path)
                    send_message(update_det_fps=update_det_fps, result_path=result_path,
                                 source=source, result_lables=result_lables)
                else:
                    # print(video_path)
                    os.remove(video_path)
                    logger.info(f"No targets detected, deleted temporary video {video_path}")

                logger.info("Resetting detection state")
                det_img.clear()
                det_fps_time.clear()
                det_t_num = 0
                nodet_t_num = 0
                video_name_time = 0
                j = 0
            # print('det_fps_time:', det_fps_time, 'det_img:', det_img)
            # t2 = time.time()
            # tx = t2 - t1
            # logger.info(f"Detection time for one frame: {tx}s")
        except Exception as e:
            logger.debug(f"Failed to read a frame from {source}...")
            logger.debug(e)
            cap.release()
            cap = cv2.VideoCapture(source)
            logger.info(f"Reconnecting to camera {source}")
            # break

def video_name(video_name_base, save_path, save_file):
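    """Build the temporary video path `<save_path>/<save_file>/<video_name_base>.mp4`, creating the directory if needed."""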
    savePath = os.path.join(save_path, save_file)
    if not os.path.exists(savePath):
        os.makedirs(savePath)
    video_path = os.path.join(savePath, video_name_base + '.mp4')
    return video_path

def dir_source_process(source, model_inference, model_data):
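    """Run detection on every image in a directory; video files are collected but not yet processed."""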
    img_ext = [".jpg", ".JPG", ".bmp"]
    video_ext = [".mp4", ".avi", ".MP4"]
    img_list = get_dir_file(source, img_ext)
    video_list = get_dir_file(source, video_ext)

    if img_list:
        for img in img_list:
            t1 = time.time()
            images = cv2.imread(img)
            imgframe_dict = {"path": img, 'frame': images}
            images_update = img_process(
                imgframe_dict, model_inference, model_data)
            t2 = time.time()
            tx = t2 - t1
            print('Detection time for one image:', tx)

    if video_list:
        pass

def file_source_process(source, model_inference, model_data):
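    """Run detection on a single image file."""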
    img_para = True
    if img_para:
        images = cv2.imread(source)
        imgframe_dict = {"path": source, 'frame': images}
        images_update = img_process(
            imgframe_dict, model_inference, model_data)

def img_process(images, model_inference, model_data):
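    """Run inference on one frame and filter the detections.

    Returns the bounding boxes that survive label selection, the optional
    relevancy check between labels, and the object-count rule.
    """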
    # t1 = time.time()
    # Run detection on the frame and get the raw inference results
    results = det_img(model_inference=model_inference,
                      images_frame=images['frame'],
                      confidence=model_data["model_parameter"]['confidence'],
                      label_name_list=model_data["model_parameter"]['label_names'])
    # print(results)
    # print(images['path'])

    # Keep only the labels of interest
    select_labels_list = model_labels_selet(example_list=model_data["model_parameter"]['compara_label_names'],
                                            result_dict_list=results)

    if model_data["model_parameter"]['compara_relevancy']:
        # Apply the relevancy rule between the selected labels
        determine_bbox = same_model_img_analysis_labels(example_list=model_data["model_parameter"]['compara_label_names'],
                                                        result_dicts_list=select_labels_list,
                                                        relevancy=model_data["model_parameter"]['compara_relevancy'],
                                                        relevancy_para=model_data["model_parameter"]['relevancy_para'])
    else:
        determine_bbox = select_labels_list
    # print(determine_bbox)

    # Discard the result if the number of detected objects reaches the configured threshold
    if model_data['model_parameter']['object_num_min']:
        if len(determine_bbox) >= model_data["model_parameter"]['object_num_min']:
            determine_bbox.clear()
    # logger.debug(f"Detection results after filtering: {determine_bbox}...")

    # Return the filtered detections
    return determine_bbox

def save_process(images, determine_bbox, model_data):
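    """Attach detection results to the frame dict, draw the boxes, and save images/annotations according to the config."""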
    if determine_bbox:
        images.update({"results": determine_bbox})

        if model_data['save_path_original']:
            imgname_original = images_save(images=images,
                                           save_path=model_data["save_path_original"])

        img_save = drawing_frame(
            images_frame=images['frame'], result_list=determine_bbox)
        images.update({"frame": img_save})

        if model_data["save_path"]:
            imgname = images_save(
                images=images, save_path=model_data["save_path"])

        if model_data["save_annotations"]:
            if not os.path.exists(model_data["save_annotations"]):
                os.makedirs(model_data["save_annotations"])
            save_annotations_xml(
                xml_save_file=model_data["save_annotations"], save_infors=determine_bbox, images=images['path'])
        else:
            pass
    else:
        # Optionally save frames with no detections
        if model_data["test_path"]:
            imgname = images_save(
                images=images, save_path=model_data["test_path"])
            # print('no:', images['path'], imgname)

    # Display (debug)
    # if images['path'] == 'rtsp://admin:@192.168.10.11':
    #     cv2.namedWindow('11', cv2.WINDOW_NORMAL)
    #     cv2.imshow('11', images['frame'])
    #     cv2.waitKey(1)
    #     cv2.destroyAllWindows()
    # t2 = time.time()

    return images

def images_save(images, save_path):
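    """Save the frame as `<get_time>.jpg` under `save_path` and return the full file path."""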
    # Use the capture time as the image file name
    images_name = images['get_time'] + '.jpg'
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    full_name = os.path.join(save_path, images_name)
    cv2.imwrite(full_name, images['frame'])
    return full_name

def save_annotations_xml(xml_save_file, save_infors, images):
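    """Write the detections to an XML annotation file, appending to it if one already exists for the image."""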
    results = save_infors
    img = os.path.basename(images)
    img_frame = cv2.imread(images)
    xml_save_path = os.path.join(xml_save_file, img.split('.')[0] + '.xml')
    # cv2 image shape is (height, width, channels)
    h, w, d = img_frame.shape
    img_shape = (w, h, d, img)

    if os.path.isfile(xml_save_path):
        add_xml(inforsDict=results,
                xmlFilePath=xml_save_path)
    else:
        create_xml(boxs=results,
                   img_shape=img_shape,
                   xml_path=xml_save_path)

def determine_time(det_num, nodet_num, ratio_set):
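    """Return True when the ratio of detected frames to total frames in the window is at least `ratio_set`."""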
    ratio = det_num / (det_num + nodet_num)
    return ratio >= ratio_set

def determine_duration(result_list):
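    """Merge adjacent detection segments: if one segment ends exactly where the next starts, combine their durations."""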
    i = 0
    while i < len(result_list) - 1:
        dict_i = result_list[i]
        dict_j = result_list[i + 1]
        if 'get_fps' in dict_i and 'dert_fps' in dict_i and 'get_fps' in dict_j:
            num_i = int(dict_i['get_fps'])
            dura_i = int(dict_i['dert_fps'])
            num_j = int(dict_j['get_fps'])
            if num_i + dura_i == num_j:
                dura_j = int(dict_j['dert_fps'])
                dura_update = dura_i + dura_j
                dict_i['dert_fps'] = dura_update
                result_list.pop(i + 1)
            else:
                i += 1
        else:
            i += 1
    # print('2:', result_list)
    return result_list

def json_get(time_list, video_path, fps):
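    """Build the result dictionary for the JSON summary and derive the event type from the first detected label."""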
    result_dict = {'info': {'video_path': video_path, 'fps': fps}}
    re_dict = {}
    result_lables = None
    list_hands = ["Keypad", "hands", "keyboard", "mouse", "phone"]
    list_sleep = ["sleep"]
    list_person = ["person"]

    for i, det_dict in enumerate(time_list):
        first_label = list(det_dict['results'][0].keys())[0]
        if first_label in list_hands:
            result_lables = 'playing_phone'
        elif first_label in list_sleep:
            result_lables = "sleep"
        elif first_label in list_person:
            result_lables = "person"

        fps_dict = {'time': det_dict['get_fps'],
                    'duration': det_dict['dert_fps'],
                    'images_path': det_dict['images_path']}
        re_dict.update({('id_' + str(i)): fps_dict})

    result_dict.update({'result': re_dict})
    return result_dict, result_lables

def json_save(result_dict, json_path):
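    """Serialize the result dictionary to `json_path` and return the path."""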
    result = json.dumps(result_dict)
    with open(json_path, 'w') as f:
        f.write(result + '\n')
    return json_path

def video_cut_images_save(det_list, images_path):
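    """Save the key frame of each detection segment, drop the raw frame data, and record the saved image path."""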
    for det_dict in det_list:
        images_path_full = images_save(images=det_dict, save_path=images_path)
        del det_dict['frame']
        del det_dict['get_time']
        det_dict.update({'images_path': images_path_full})
    return det_list

def send_message(update_det_fps, result_path, source, result_lables):
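    """Send one Logstash message per detection segment with the event time, result file, source and event type."""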
    for det_dict in update_det_fps:
        extra = {
            'worker': 'xbank',
            'time': det_dict['get_fps'],
            'config_file': result_path,
            'source': source,
            'type': result_lables
        }
        xbank_logger.info('xBank_infer', extra=extra)
        logger.info(f"Message sent: {extra}")

# if __name__ == '__main__':
# data_load(['rtsp://admin:@192.168.10.203',
# 'E:/Bank_files/Bank_03/xbank_poc_test_use/config_phone.yaml'])