You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

537 lines
16 KiB
Python

2 years ago
from analysis_result.get_model_result import det_img
from analysis_result.same_model_img import same_model_img_analysis_labels, model_labels_selet
2 years ago
from model_load.model_load import Load_model
2 years ago
from drawing_img.drawing_img import drawing_frame
2 years ago
from analysis_data.data_rtsp import rtsp_para
2 years ago
from analysis_data.data_dir_file import get_dir_file
2 years ago
from analysis_data.config_load import get_configs
2 years ago
from add_xml import add_xml
from create_xml import create_xml
2 years ago
from analysis_data.change_video import mp4_to_H264
2 years ago
import cv2
import os
import time
from datetime import datetime
2 years ago
import json
2 years ago
from loguru import logger
import logging
import logstash
# Logstash sink used by send_message(). Host and port default to the values
# that used to be hard-coded and can be overridden per deployment through the
# LOGSTASH_HOST / LOGSTASH_PORT environment variables instead of editing code.
host = os.environ.get('LOGSTASH_HOST', '192.168.10.96')
logstash_port = int(os.environ.get('LOGSTASH_PORT', '5959'))
xbank_logger = logging.getLogger('python-logstash-logger')
xbank_logger.setLevel(logging.INFO)
xbank_logger.addHandler(logstash.LogstashHandler(host, logstash_port, version=1))


def data_load(args):
    """Load the model described by a YAML config and run it on a source.

    Args:
        args: sequence whose first element is the input source (a stream URL
            recognised by ``rtsp_para``, a directory, or a single file) and
            whose second element is the path of the model YAML config.

    The stream branch does not return: ``rtsp_detect_process`` loops forever.
    """
    source = args[0]
    model_yaml = args[1]

    # Classify the source first so an unusable source does not pay for
    # model loading and is reported instead of silently doing nothing.
    rtsp_source = rtsp_para(source)
    dir_source = os.path.isdir(source)
    file_source = os.path.isfile(source)
    if not (rtsp_source or dir_source or file_source):
        logger.warning(f"Unrecognized source {source}: not a stream, "
                       f"directory or file; nothing to process")
        return

    # Model loading.
    model_data = get_configs(model_yaml)
    model_inference = Load_model(model_file=model_data["model"],
                                 device=model_data["model_parameter"]['device'],
                                 cache_file=model_data["model_cache"])

    # The branches are mutually exclusive in practice: a path cannot be both a
    # directory and a file, and the stream branch never returns.
    if rtsp_source:
        rtsp_detect_process(source=source, model_data=model_data,
                            model_inference=model_inference)
    elif dir_source:
        dir_source_process(source, model_inference, model_data)
    elif file_source:
        file_source_process(source, model_inference, model_data)


def rtsp_detect_process(source, model_data, model_inference):
    """Run detection on a video stream forever and archive clips with hits.

    The stream is cut into clips of ``fps * detect_time`` frames ("T"). Inside
    a clip, detections open short windows that are evaluated after
    ``fps * detect_time_small`` frames ("t"); a window is kept when
    ``determine_time`` accepts its detected/undetected frame ratio against
    ``detect_ratio``. At the end of each clip the temporary video is either
    archived (see ``_archive_clip``) or deleted when no window was kept.

    Any exception inside the loop (including a failed frame read) is logged
    and the capture is reopened; the function never returns.

    Args:
        source: stream URL passed to ``cv2.VideoCapture``.
        model_data: parsed config; keys read here are ``detect_time``,
            ``detect_time_small``, ``detect_ratio`` and ``save_videos``.
        model_inference: loaded model forwarded to ``img_process``.
    """
    cap = cv2.VideoCapture(source)
    logger.info(f"视频流{source}读取中...")

    # Stream properties.
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    fps_num = fps * model_data['detect_time']              # frames per clip ("T")
    fps_num_small = fps * model_data['detect_time_small']  # frames per window ("t")
    size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
            int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))

    j = 0                   # 1-based index of the current frame within the clip
    det_t_num = 0           # frames with detections in the open window
    nodet_t_num = 0         # frames without detections in the open window
    # First detected frame of the open window (at most one entry). Not named
    # det_img so it does not shadow the imported det_img() inference helper.
    window_first = []
    det_fps_time = []       # windows kept so far in the current clip
    video_name_time = None  # timestamp naming the open clip; None = no clip open

    while True:
        try:
            ret, frame = cap.read()
            if not ret:
                # Route read failures to the reconnect handler below instead
                # of pushing frame=None through inference and the writer.
                raise RuntimeError(f"cap.read() returned no frame from {source}")
            j += 1

            # Wall-clock time of this frame as HHMMSSffffff; used for file names.
            data_now = datetime.now()
            get_time = data_now.strftime("%H%M%S%f")

            imgframe_dict = {"path": source, 'frame': frame,
                             'get_fps': j, 'get_time': get_time}

            # Open a temporary writer at the first frame of each clip.
            if video_name_time is None:
                video_name_time = get_time
                video_path = video_name(video_name_base=video_name_time,
                                        save_path=model_data['save_videos'],
                                        save_file='temp')
                out_video = cv2.VideoWriter(
                    video_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, size)
                logger.info(f"视频{video_path}已经暂时保存...")

            # Inference, then optional image / annotation saving.
            images_det_result = img_process(imgframe_dict, model_inference, model_data)
            images_update = save_process(imgframe_dict, images_det_result, model_data)

            # Short-window ("t") bookkeeping: a window opens on the first hit
            # and is evaluated once it has spanned fps_num_small frames.
            if images_det_result:
                det_t_num += 1
                if not window_first:
                    # Shallow copy: later mutations of the frame dict do not
                    # leak into the stored window start.
                    window_first.append(images_update.copy())
            elif window_first:
                nodet_t_num += 1
                if det_t_num + nodet_t_num >= fps_num_small:
                    if determine_time(det_num=det_t_num, nodet_num=nodet_t_num,
                                      ratio_set=model_data['detect_ratio']):
                        first = window_first[0]
                        first["dert_fps"] = j - int(first['get_fps']) + 1
                        det_fps_time.append(first)
                    window_first.clear()
                    det_t_num = 0
                    nodet_t_num = 0

            # Append the (possibly box-drawn) frame to the temporary clip.
            out_video.write(images_update['frame'])

            # End of clip ("T").
            if j >= fps_num:
                try:
                    try:
                        out_video.release()
                    except Exception:
                        logger.exception("视频release失败")
                    else:
                        logger.info("视频release成功")

                    # Close a window still open at the end of the clip; keep it
                    # only if it passes the ratio and spans at least t/2.
                    if window_first:
                        para = determine_time(det_num=det_t_num, nodet_num=nodet_t_num,
                                              ratio_set=model_data['detect_ratio'])
                        first = window_first[0]
                        duration = j - int(first['get_fps']) + 1
                        if para and duration >= fps_num_small / 2:
                            first["dert_fps"] = duration
                            det_fps_time.append(first)

                    if det_fps_time:
                        _archive_clip(det_fps_time, video_path, model_data,
                                      data_now, fps, source)
                    else:
                        os.remove(video_path)
                        logger.info(f"未检测到目标信息的视频{video_path}删除成功")
                finally:
                    # Always start a fresh clip, even if archiving failed;
                    # otherwise j stays >= fps_num and every later frame would
                    # retry finishing the same clip.
                    logger.info('开始信息重置')
                    window_first.clear()
                    det_fps_time.clear()
                    det_t_num = 0
                    nodet_t_num = 0
                    video_name_time = None
                    j = 0
        except Exception as e:
            logger.debug(f"读帧率失败{source}未读到...")
            logger.debug(e)
            cap.release()
            cap = cv2.VideoCapture(source)
            logger.info(f"摄像头{source}重新读取")


def _archive_clip(det_fps_time, video_path, model_data, data_now, fps, source):
    """Archive a finished clip that has kept detection windows.

    Writes under ``<save_videos>/<YYYYMMDD>/<clip name>/``: the H.264-converted
    video, an ``images`` folder with one snapshot per merged window, and a
    ``<clip name>.json`` summary; then sends one logstash message per window.
    ``det_fps_time`` is merged and mutated in place.
    """
    det_fps_time = determine_duration(result_list=det_fps_time)

    save_video_name = os.path.basename(video_path)
    only_video_name = save_video_name.split('.')[0]
    save_video_path = os.path.join(model_data['save_videos'],
                                   data_now.strftime("%Y%m%d"), only_video_name)
    # Create the folder up front: the converter runs before images_save()
    # creates the images/ sub-folder, and it may not create directories itself.
    os.makedirs(save_video_path, exist_ok=True)

    save_video = os.path.join(save_video_path, save_video_name)
    json_path = os.path.join(save_video_path, only_video_name + '.json')
    images_path = os.path.join(save_video_path, 'images')

    # Re-encode the temporary mp4 into the archive folder.
    mp4_to_H264().convert_byfile(video_path, save_video)
    # One snapshot per window; this also drops the frame data from the dicts.
    update_det_fps = video_cut_images_save(det_list=det_fps_time, images_path=images_path)
    re_list, result_lables = json_get(time_list=update_det_fps, video_path=save_video, fps=fps)
    result_path = json_save(re_list, json_path)
    send_message(update_det_fps=update_det_fps, result_path=result_path,
                 source=source, result_lables=result_lables)


def video_name(video_name_base, save_path, save_file):
    """Return ``<save_path>/<save_file>/<video_name_base>.mp4``.

    The ``<save_path>/<save_file>`` directory is created if it is missing.
    """
    save_dir = os.path.join(save_path, save_file)
    # exist_ok avoids the race of a separate exists() check followed by makedirs().
    os.makedirs(save_dir, exist_ok=True)
    return os.path.join(save_dir, video_name_base + '.mp4')


def dir_source_process(source, model_inference, model_data):
    """Run detection on every image of directory *source* and log the timing.

    Images are collected with ``get_dir_file`` for the extensions below.
    Detection results are currently only timed, not saved. Video files in the
    directory are detected but not processed yet.
    """
    img_ext = [".jpg", ".JPG", ".bmp"]
    video_ext = [".mp4", ".avi", ".MP4"]

    img_list = get_dir_file(source, img_ext)
    video_list = get_dir_file(source, video_ext)

    if img_list:
        for img in img_list:
            t1 = time.perf_counter()
            images = cv2.imread(img)
            if images is None:
                # cv2.imread reports unreadable/corrupt files by returning None.
                logger.warning(f"Could not read image {img}, skipped")
                continue
            imgframe_dict = {"path": img, 'frame': images}
            img_process(imgframe_dict, model_inference, model_data)
            tx = time.perf_counter() - t1
            logger.info(f'检测一张图片的时间为:{tx}')

    if video_list:
        logger.warning(f"Video files in {source} are not processed: {video_list}")


def file_source_process(source, model_inference, model_data):
    """Run detection on the single image file *source*.

    Returns:
        The detection result returned by ``img_process``, or ``None`` when the
        file could not be decoded as an image.
    """
    images = cv2.imread(source)
    if images is None:
        # cv2.imread reports unreadable/corrupt files by returning None.
        logger.warning(f"Could not read image {source}")
        return None
    imgframe_dict = {"path": source, 'frame': images}
    return img_process(imgframe_dict, model_inference, model_data)


def img_process(images, model_inference, model_data):
    """Detect objects on one frame and apply the configured box filters.

    Args:
        images: frame dict; only ``images['frame']`` is read here.
        model_inference: model handle forwarded to ``det_img``.
        model_data: parsed config; only its ``model_parameter`` section is read.

    Returns:
        The filtered detection result. When ``object_num_min`` is set and at
        least that many boxes remain, the result is emptied, so the caller
        treats the frame as "no detection" (this reads like an
        alert-on-shortfall rule; NOTE(review): confirm intent).
    """
    params = model_data["model_parameter"]

    # Raw inference on this frame.
    raw_results = det_img(model_inference=model_inference,
                          images_frame=images['frame'],
                          confidence=params['confidence'],
                          label_name_list=params['label_names'])

    # Keep only the labels configured for comparison.
    selected = model_labels_selet(example_list=params['compara_label_names'],
                                  result_dict_list=raw_results)

    relevancy = params['compara_relevancy']
    if not relevancy:
        determine_bbox = selected
    else:
        # Apply the configured relation logic between the selected boxes.
        determine_bbox = same_model_img_analysis_labels(
            example_list=params['compara_label_names'],
            result_dicts_list=selected,
            relevancy=relevancy,
            relevancy_para=params['relevancy_para'])

    minimum = params['object_num_min']
    if minimum and len(determine_bbox) >= minimum:
        determine_bbox.clear()

    return determine_bbox


def save_process(images, determine_bbox, model_data):
    """Save and annotate one frame according to its detection result.

    On a detection, ``images`` is mutated in place: it gains a ``results``
    entry and its ``frame`` is replaced by the box-drawn frame. The configured
    outputs are written: the untouched frame to ``save_path_original``, the
    drawn frame to ``save_path``, and an XML annotation into
    ``save_annotations``. Frames without detections are saved to ``test_path``
    when that is configured.

    Returns:
        The same ``images`` dict.
    """
    if not determine_bbox:
        # Optionally keep frames without detections.
        if model_data["test_path"]:
            images_save(images=images, save_path=model_data["test_path"])
        return images

    images["results"] = determine_bbox

    imgname_original = None
    if model_data['save_path_original']:
        # Saved before drawing so this copy has no boxes on it.
        imgname_original = images_save(images=images,
                                       save_path=model_data["save_path_original"])

    images["frame"] = drawing_frame(images_frame=images['frame'],
                                    result_list=determine_bbox)

    if model_data["save_path"]:
        images_save(images=images, save_path=model_data["save_path"])

    if model_data["save_annotations"]:
        os.makedirs(model_data["save_annotations"], exist_ok=True)
        # Prefer the saved clean image: for a stream, images['path'] is the
        # stream URL, which cv2.imread cannot open and whose basename would
        # funnel every frame's boxes into the same XML file.
        annotated_image = imgname_original or images['path']
        save_annotations_xml(xml_save_file=model_data["save_annotations"],
                             save_infors=determine_bbox, images=annotated_image)

    return images


def images_save(images, save_path):
    """Write ``images['frame']`` to ``<save_path>/<images['get_time']>.jpg``.

    *save_path* is created if needed. A failed write is logged, because
    ``cv2.imwrite`` reports failure by returning False rather than raising.

    Returns:
        The full path of the image file.
    """
    # The capture timestamp is used as the file name.
    full_name = os.path.join(save_path, images['get_time'] + '.jpg')
    os.makedirs(save_path, exist_ok=True)
    if not cv2.imwrite(full_name, images['frame']):
        logger.warning(f"Failed to write image {full_name}")
    return full_name


def save_annotations_xml(xml_save_file, save_infors, images):
    """Create or extend the XML annotation file for the image at *images*.

    Args:
        xml_save_file: directory that holds the XML files.
        save_infors: detection results, forwarded to ``add_xml`` / ``create_xml``.
        images: path of the annotated image; its base name (without extension)
            names the XML file, and it is decoded to obtain the image shape.

    Returns:
        The XML path written, or ``None`` when a new file was needed but the
        image could not be read.
    """
    img = os.path.basename(images)
    # splitext keeps dotted names such as "a.b.jpg" distinct ("a.b", not "a").
    xml_save_path = os.path.join(xml_save_file, os.path.splitext(img)[0] + '.xml')

    if os.path.isfile(xml_save_path):
        add_xml(inforsDict=save_infors, xmlFilePath=xml_save_path)
        return xml_save_path

    img_frame = cv2.imread(images)
    if img_frame is None:
        logger.warning(f"Could not read image {images}; annotation not written")
        return None
    # ndarray.shape is (rows, cols, channels), i.e. (height, width, depth).
    # The tuple order handed to create_xml is unchanged from before.
    # NOTE(review): confirm create_xml expects height first.
    rows, cols, depth = img_frame.shape
    create_xml(boxs=save_infors,
               img_shape=(rows, cols, depth, img),
               xml_path=xml_save_path)
    return xml_save_path


def determine_time(det_num, nodet_num, ratio_set):
    """Return True when ``det_num / (det_num + nodet_num) >= ratio_set``.

    A window with no counted frames returns False instead of raising
    ZeroDivisionError.
    """
    total = det_num + nodet_num
    if total <= 0:
        return False
    return det_num / total >= ratio_set


def determine_duration(result_list):
    """Merge back-to-back detection windows in place and return the same list.

    Each entry starts at frame ``get_fps`` and lasts ``dert_fps`` frames. When
    an entry ends exactly where the next one starts
    (``get_fps + dert_fps == next['get_fps']``), the next entry's duration is
    added to the current one and the next entry is removed. The same entry is
    then compared again, so chains of adjacent windows collapse into one.
    Entries missing either key are left untouched.
    """
    i = 0
    while i < len(result_list) - 1:
        cur, nxt = result_list[i], result_list[i + 1]
        mergeable = ('get_fps' in cur and 'dert_fps' in cur
                     and 'get_fps' in nxt and 'dert_fps' in nxt)
        if mergeable and int(cur['get_fps']) + int(cur['dert_fps']) == int(nxt['get_fps']):
            cur['dert_fps'] = int(cur['dert_fps']) + int(nxt['dert_fps'])
            del result_list[i + 1]
        else:
            i += 1
    return result_list


# First detected label of a window -> event type reported downstream.
_LABEL_TO_TYPE = {
    **dict.fromkeys(["Keypad", "hands", "keyboard", "mouse", "phone"], 'playing_phone'),
    "sleep": "sleep",
    "person": "person",
}


def json_get(time_list, video_path, fps):
    """Build the JSON summary of a clip and pick its event type.

    Args:
        time_list: kept windows; each needs ``get_fps``, ``dert_fps`` and
            ``images_path``, and normally a non-empty ``results`` list of dicts.
        video_path: archived video path stored under ``info``.
        fps: stream frame rate stored under ``info``.

    Returns:
        ``(result_dict, result_lables)``. ``result_dict`` is
        ``{'info': {'video_path', 'fps'}, 'result': {'id_<n>': {'time',
        'duration', 'images_path'}}}``. ``result_lables`` is the type mapped
        from the first key of the first result of the last window whose label
        is known, or ``None`` when no window has a known label.
    """
    result_lables = None
    re_dict = {}
    for i, det_dict in enumerate(time_list):
        results = det_dict.get('results') or []
        if results:
            first_label = next(iter(results[0]), None)
            # Unknown labels keep the previously determined type.
            result_lables = _LABEL_TO_TYPE.get(first_label, result_lables)

        re_dict['id_' + str(i)] = {'time': det_dict['get_fps'],
                                   'duration': det_dict['dert_fps'],
                                   'images_path': det_dict['images_path']}

    result_dict = {'info': {'video_path': video_path, 'fps': fps},
                   'result': re_dict}
    return result_dict, result_lables


def json_save(result_dict, json_path):
    """Write *result_dict* as a single JSON line to *json_path*.

    The file is closed deterministically by the ``with`` block (previously
    ``f.close`` was referenced but never called).

    Returns:
        *json_path*.
    """
    with open(json_path, 'w', encoding='utf-8') as f:
        f.write(json.dumps(result_dict) + '\n')
    return json_path


def video_cut_images_save(det_list, images_path):
    """Save one snapshot per window and slim the window dicts in place.

    Each entry's frame is written with ``images_save`` into *images_path*.
    The entry's ``frame`` and ``get_time`` keys are then removed, and the
    written file path is stored under ``images_path``.

    Returns:
        The same (mutated) list.
    """
    for entry in det_list:
        saved_path = images_save(images=entry, save_path=images_path)
        entry.pop('frame')
        entry.pop('get_time')
        entry['images_path'] = saved_path
    return det_list


def send_message(update_det_fps, result_path, source, result_lables):
    """Emit one logstash event, plus a local log line, per detection window.

    Each event carries the window's start frame (``get_fps``), the JSON
    summary path, the stream source and the event type.
    """
    payloads = ({'worker': 'xbank',
                 'time': window['get_fps'],
                 'config_file': result_path,
                 'source': source,
                 'type': result_lables}
                for window in update_det_fps)
    for extra in payloads:
        xbank_logger.info('xBank_infer', extra=extra)
        logger.info(f'发送信息{extra}')


# if __name__ == '__main__':
# data_load(['rtsp://admin:@192.168.10.203',
# 'E:/Bank_files/Bank_03/xbank_poc_test_use/config_phone.yaml'])