You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
XZNSH-Code-AI/tool/PP_TSMv2_infer.py

159 lines
5.7 KiB
Python

2 years ago
import os
import os.path as osp
from paddlevideo.utils.config import get_config
2 years ago
from paddle.inference import Config, create_predictor
from utils import build_inference_helper
import warnings
# Suppress every warning category for the whole process from import time on.
# NOTE(review): this also hides deprecation/runtime warnings raised by paddle and
# paddlevideo, and affects any code that imports this module; consider narrowing
# it to specific categories/modules.
warnings.filterwarnings("ignore")
class PP_TSMv2_predict(object):
    """Inference wrapper for an exported PP-TSMv2 (PaddleVideo) recognition model.

    The constructor only records predictor options. The Paddle predictor is
    built by :meth:`create_paddle_predictor`, which :meth:`predict` calls on
    every invocation.

    Args:
        use_gpu: run on GPU when True, otherwise on CPU.
        ir_optim: enable Paddle IR graph optimisation.
        disable_glog: silence Paddle's glog info output.
        save_name: name for the converted inference model. Stored only; no
            method of this class reads it.
        enable_mklddn: enable MKL-DNN (oneDNN) kernels on CPU. The keyword's
            misspelling is kept for backward compatibility; the value is
            stored as ``self.enable_mkldnn``.
        precision: "fp32" or "fp16". On CPU with MKL-DNN enabled, "fp16"
            switches on MKL-DNN bfloat16 kernels.
        gpu_mem: initial GPU memory pool size (MB) passed to
            ``Config.enable_use_gpu``.
        cpu_threads: CPU math-library thread count; a falsy value keeps
            Paddle's default.
        gpu_id: index of the GPU device used when ``use_gpu`` is True.
    """

    # Lower-case extensions collected when parse_file_paths() is given a directory.
    VIDEO_EXTENSIONS = (".avi", ".mp4")

    def __init__(self, use_gpu=True, ir_optim=True,
                 disable_glog=False, save_name=None, enable_mklddn=False,
                 precision="fp32", gpu_mem=8000, cpu_threads=None, gpu_id=0):
        self.use_gpu = use_gpu
        self.cpu_threads = cpu_threads
        self.ir_optim = ir_optim
        self.disable_glog = disable_glog
        self.gpu_mem = gpu_mem
        self.enable_mkldnn = enable_mklddn
        self.precision = precision
        self.save_name = save_name
        self.gpu_id = gpu_id

    def parse_file_paths(self, input_path: str) -> list:
        """Return the list of video files to run inference on.

        Args:
            input_path: a single file, or a directory whose ``.avi`` / ``.mp4``
                entries (extension matched case-insensitively) are collected.

        Returns:
            ``[input_path]`` when ``input_path`` is a file. Otherwise, the
            matching directory entries joined onto ``input_path``, sorted by
            name so that batching order is deterministic. ``os.listdir``
            itself returns entries in arbitrary order.

        Raises:
            OSError: propagated from ``os.listdir`` when ``input_path`` is
                neither a file nor a readable directory.
        """
        if osp.isfile(input_path):
            return [input_path]
        return [
            osp.join(input_path, name)
            for name in sorted(os.listdir(input_path))
            if name.lower().endswith(self.VIDEO_EXTENSIONS)
        ]

    def create_paddle_predictor(self, model_f, pretr_p, cfg):
        """Build a Paddle Inference ``Config`` and predictor from this object's options.

        Args:
            model_f: path to the exported inference model (``.pdmodel``).
            pretr_p: path to the matching parameters file (``.pdiparams``).
            cfg: parsed model config. Currently unused; kept for interface
                compatibility.

        Returns:
            ``(config, predictor)``: the inference ``Config`` and the
            predictor created from it.
        """
        config = Config(model_f, pretr_p)
        if self.use_gpu:
            config.enable_use_gpu(self.gpu_mem, self.gpu_id)
        else:
            config.disable_gpu()
            if self.cpu_threads:
                config.set_cpu_math_library_num_threads(self.cpu_threads)
            if self.enable_mkldnn:
                # Cap the MKL-DNN per-input-shape cache at 10 entries so varying
                # input shapes cannot grow memory without bound.
                config.set_mkldnn_cache_capacity(10)
                config.enable_mkldnn()
                if self.precision == "fp16":
                    # "fp16" is mapped to MKL-DNN bfloat16 kernels on the CPU path.
                    config.enable_mkldnn_bfloat16()
        config.switch_ir_optim(self.ir_optim)
        config.enable_memory_optim()
        # Zero-copy I/O: data goes through the input/output handles
        # (copy_from_cpu / copy_to_cpu) instead of feed/fetch ops.
        config.switch_use_feed_fetch_ops(False)
        if self.disable_glog:
            config.disable_glog_info()
        predictor = create_predictor(config)
        return config, predictor

    @staticmethod
    def _run_batch(predictor, inference_helper, input_handles, output_handles,
                   batch_files):
        """Preprocess ``batch_files``, run the predictor once, and return its raw outputs."""
        batched_inputs = inference_helper.preprocess_batch(batch_files)
        # Index into batched_inputs so a length mismatch raises instead of being truncated.
        for i, handle in enumerate(input_handles):
            handle.copy_from_cpu(batched_inputs[i])
        predictor.run()
        return [handle.copy_to_cpu() for handle in output_handles]

    @staticmethod
    def _top1_summary(item):
        """Reduce one post-processed entry to its video id and top-ranked class/score.

        ``item["topk_class"]`` and ``item["topk_scores"]`` must support
        ``.tolist()`` (e.g. numpy arrays); only their first element is kept.
        """
        return {
            "video_id": item["video_id"],
            "topk_class": item["topk_class"].tolist()[0],
            "topk_scores": item["topk_scores"].tolist()[0],
        }

    def predict(self, config, input_f, batch_size, model_f, params_f,
                return_all=False):
        """Run the model over one video or a directory of videos.

        Args:
            config: path to the PaddleVideo YAML config of the model.
            input_f: a video file, or a directory scanned by
                :meth:`parse_file_paths`.
            batch_size: number of videos fed to the predictor per run (>= 1).
            model_f: path to the exported inference model (``.pdmodel``).
            params_f: path to the exported parameters (``.pdiparams``).
            return_all: when True, return a summary for every video instead
                of the single legacy summary.

        Returns:
            A summary is a dict with ``video_id``, ``topk_class`` (the first
            entry of the helper's top-k classes) and ``topk_scores`` (the
            first entry of its top-k scores).

            By default (``return_all=False``), returns only the summary of the
            first video in the *last* batch, or ``{}`` if no video was found.
            This is the historical behaviour; for a single input file it is
            that file's result.

            With ``return_all=True``, returns a list with the summary of every
            video in input order (``[]`` if none was found).

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")

        cfg = get_config(config, overrides=None, show=False)
        model_name = cfg.model_name
        print(f"Inference model({model_name})...")
        # Pre/post-processing helper selected by the INFERENCE section of the config.
        inference_helper = build_inference_helper(cfg.INFERENCE)
        _, predictor = self.create_paddle_predictor(model_f, params_f, cfg)

        input_handles = [predictor.get_input_handle(name)
                         for name in predictor.get_input_names()]
        output_handles = [predictor.get_output_handle(name)
                          for name in predictor.get_output_names()]

        files = self.parse_file_paths(input_f)
        summaries = []
        result = {}
        for start in range(0, len(files), batch_size):
            batched_outputs = self._run_batch(
                predictor, inference_helper, input_handles, output_handles,
                files[start:start + batch_size])
            # NOTE(review): postprocess(outputs, False, True) is assumed to return
            # a list with one dict per video in the batch; confirm against the
            # paddlevideo inference helper.
            batch_results = inference_helper.postprocess(batched_outputs, False, True)
            batch_summaries = [self._top1_summary(item) for item in batch_results]
            summaries.extend(batch_summaries)
            result = batch_summaries[0]
        return summaries if return_all else result
2 years ago
def main(argv=None):
    """Command-line entry point: run PP-TSMv2 inference and print the result.

    Every option defaults to the path/value that used to be hard-coded here,
    so running the script without arguments behaves as before.

    Args:
        argv: argument list to parse, without the program name. ``None``
            means ``sys.argv[1:]``.
    """
    import argparse

    parser = argparse.ArgumentParser(description="PP-TSMv2 video recognition inference")
    parser.add_argument(
        "--config",
        default='D:/download/PaddleVideo1/configs/recognition/pptsm/v2/pptsm_lcnet_k400_16frames_uniform.yaml',
        help="PaddleVideo model config file (YAML)")
    parser.add_argument(
        "--input_file",
        default='D:/download/PaddleVideo1/data/dataset/video_seg_no_hand/test02_84.avi',
        help="video file, or directory of .avi/.mp4 videos, to run inference on")
    parser.add_argument(
        "--model_file",
        default='D:/download/PaddleVideo1/output/ppTSMv2.pdmodel',
        help="exported inference model (.pdmodel)")
    parser.add_argument(
        "--params_file",
        default='D:/download/PaddleVideo1/output/ppTSMv2.pdiparams',
        help="exported inference model parameters (.pdiparams)")
    parser.add_argument(
        "--batch_size", type=int, default=1,
        help="number of videos fed to the predictor per run")
    args = parser.parse_args(argv)

    # Build the predictor wrapper with default options and run inference.
    t = PP_TSMv2_predict().predict(args.config, args.input_file, args.batch_size,
                                   args.model_file, args.params_file)
    print(t)


if __name__ == "__main__":
    main()