You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
XZNSH-Code-AI/tool/PP_TSMv2_infer.py

165 lines
5.9 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import os
import os.path as osp
from paddlevideo.utils.config import get_config
from paddle.inference import Config, create_predictor
from utils import build_inference_helper
class PP_TSMv2_predict(object):
    """Run a PP-TSMv2 video classifier through Paddle Inference.

    Holds the predictor options, builds a Paddle predictor from an exported
    model, and runs batched inference over one video file or a directory of
    videos.
    """

    # Lower-case extensions collected when ``parse_file_paths`` gets a directory.
    VIDEO_EXTENSIONS = (".avi", ".mp4")

    def __init__(self, use_gpu=True, ir_optim=True,
                 disable_glog=False, save_name=None, enable_mklddn=False,
                 precision="fp32", gpu_mem=8000, cpu_threads=None):
        """Store the predictor options.

        Args:
            use_gpu: run on GPU device 0 when True, otherwise on CPU.
            ir_optim: passed to ``Config.switch_ir_optim`` (IR optimisation).
            disable_glog: when True, Paddle's glog info output is disabled.
            save_name: name for a converted inference model; not used by the
                methods of this class.
            enable_mklddn: enable MKLDNN when running on CPU. The misspelled
                keyword is kept so existing callers keep working.
            precision: "fp32" or "fp16"; with MKLDNN on CPU, "fp16" turns on
                MKLDNN bfloat16 (see ``create_paddle_predictor``).
            gpu_mem: first argument of ``Config.enable_use_gpu`` (presumably
                the initial GPU memory pool size in MB).
            cpu_threads: CPU math-library thread count; falsy keeps the default.
        """
        self.use_gpu = use_gpu
        self.cpu_threads = cpu_threads
        self.ir_optim = ir_optim
        self.disable_glog = disable_glog
        self.gpu_mem = gpu_mem
        self.enable_mkldnn = enable_mklddn
        self.precision = precision
        self.save_name = save_name

    def parse_file_paths(self, input_path: str) -> list:
        """Return the video files to run inference on.

        Args:
            input_path: a single file (returned as-is, whatever its extension)
                or a directory whose ``.avi``/``.mp4`` entries are collected.

        Returns:
            A list of paths. Directory entries are matched case-insensitively
            and sorted, so the processing order is deterministic.
        """
        if osp.isfile(input_path):
            return [input_path]
        # os.listdir returns entries in arbitrary order; sort for reproducibility.
        return [
            osp.join(input_path, name)
            for name in sorted(os.listdir(input_path))
            if name.lower().endswith(self.VIDEO_EXTENSIONS)
        ]

    def create_paddle_predictor(self, model_f, pretr_p, cfg):
        """Build a Paddle Inference predictor from an exported model.

        Args:
            model_f: path of the exported model structure file.
            pretr_p: path of the exported parameters file.
            cfg: parsed model config; currently unused, kept for API compatibility.

        Returns:
            (config, predictor): the ``paddle.inference.Config`` and the
            predictor created from it.
        """
        config = Config(model_f, pretr_p)
        if self.use_gpu:
            config.enable_use_gpu(self.gpu_mem, 0)  # device id 0
        else:
            config.disable_gpu()
            # CPU-specific options: only applied when running on CPU.
            if self.cpu_threads:
                config.set_cpu_math_library_num_threads(self.cpu_threads)
            if self.enable_mkldnn:
                # Limit the MKLDNN cache to 10 entries (presumably input
                # shapes, to bound memory growth) before enabling MKLDNN.
                config.set_mkldnn_cache_capacity(10)
                config.enable_mkldnn()
                if self.precision == "fp16":
                    # NOTE: the "fp16" request is served by MKLDNN bfloat16.
                    config.enable_mkldnn_bfloat16()
        config.switch_ir_optim(self.ir_optim)
        config.enable_memory_optim()
        # predict() feeds/fetches through input/output handles
        # (copy_from_cpu / copy_to_cpu), so feed/fetch ops are switched off.
        config.switch_use_feed_fetch_ops(False)
        if self.disable_glog:
            config.disable_glog_info()
        predictor = create_predictor(config)
        return config, predictor

    def create_inference_model(self, config, model_f, params_f):
        """Parse the model config and build the inference helper and predictor.

        Args:
            config: path of the PP-TSMv2 config file.
            model_f: path of the exported model file.
            params_f: path of the exported parameters file.

        Returns:
            (inference_helper, predictor): the helper built from the config's
            ``INFERENCE`` section, and the Paddle predictor.
        """
        cfg = get_config(config, overrides=None, show=False)
        inference_helper = build_inference_helper(cfg.INFERENCE)
        _, predictor = self.create_paddle_predictor(model_f, params_f, cfg)
        return inference_helper, predictor

    def predict(self, config, input_f, batch_size, predictor, InferenceHelper,
                return_all=False):
        """Run inference on every video found at ``input_f`` and print results.

        Args:
            config: path of the PP-TSMv2 config file (used for the model name).
            input_f: a video file, or a directory of videos.
            batch_size: number of videos per forward pass; must be >= 1.
            predictor: Paddle predictor from ``create_inference_model``.
            InferenceHelper: inference helper from ``create_inference_model``.
            return_all: when True, return the list of every per-video result
                instead of only the last one.

        Returns:
            A dict with ``video_id``, ``topk_class`` and ``topk_scores`` (the
            first entry of each top-k list) for the last processed video, or
            ``{}`` if no video was found. With ``return_all=True``, a list of
            such dicts in processing order.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")

        cfg = get_config(config, overrides=None, show=False)
        model_name = cfg.model_name
        print(f"Inference model({model_name})...")

        input_handles = [predictor.get_input_handle(name)
                         for name in predictor.get_input_names()]
        output_handles = [predictor.get_output_handle(name)
                          for name in predictor.get_output_names()]

        files = self.parse_file_paths(input_f)
        all_results = []
        result = {}
        for start in range(0, len(files), batch_size):
            batch_files = files[start:start + batch_size]
            # Preprocess the batch and copy each input array into its handle.
            batched_inputs = InferenceHelper.preprocess_batch(batch_files)
            for idx, handle in enumerate(input_handles):
                handle.copy_from_cpu(batched_inputs[idx])
            predictor.run()
            batched_outputs = [handle.copy_to_cpu() for handle in output_handles]
            # Assumed: postprocess (print_output=False, return_result=True)
            # returns one dict per video in the batch — confirm against helper.
            for item in InferenceHelper.postprocess(batched_outputs, False, True):
                result = {
                    "video_id": item["video_id"],
                    "topk_class": item["topk_class"].tolist()[0],
                    "topk_scores": item["topk_scores"].tolist()[0],
                }
                print(result)
                all_results.append(result)
        return all_results if return_all else result
def main(config='D:/download/PaddleVideo1/output/output/pptsm_lcnet_k400_16frames_uniform.yaml',
         input_file='C:/Users/Administrator/Pictures/video_seg_re_hand/test01_3.avi',
         model_file='D:/download/PaddleVideo1/output/output/ppTSMv2.pdmodel',
         params_file='D:/download/PaddleVideo1/output/output/ppTSMv2.pdiparams',
         batch_size=1):
    """Build a PP-TSMv2 predictor and run it on ``input_file``.

    Args:
        config: path of the model config file.
        input_file: a video file, or a directory of videos, to classify.
        model_file: path of the exported inference model file.
        params_file: path of the exported inference parameters file.
        batch_size: number of videos per forward pass.

    Returns:
        The result returned by ``PP_TSMv2_predict.predict``.
    """
    # One runner instance serves both predictor construction and inference.
    runner = PP_TSMv2_predict()
    inference_helper, predictor = runner.create_inference_model(
        config, model_file, params_file)
    return runner.predict(config, input_file, batch_size, predictor,
                          inference_helper)


if __name__ == "__main__":
    main()