Optimize inference engine loading in PP-TSMv2

V0.1.0
jiangxt 2 years ago
parent b6aa129cf2
commit f0ee9320c4

@ -3,9 +3,6 @@ import os.path as osp
from paddlevideo.utils.config import get_config
from paddle.inference import Config, create_predictor
from utils import build_inference_helper
import warnings
warnings.filterwarnings("ignore")
class PP_TSMv2_predict(object):
@ -24,6 +21,8 @@ class PP_TSMv2_predict(object):
        self.precision = precision  # mkldnn precision
        self.save_name = save_name  # name under which the converted inference model is saved
    def parse_file_paths(self, input_path: str) -> list:
        """
@ -77,29 +76,36 @@ class PP_TSMv2_predict(object):
        return config, predictor
    def create_inference_model(self, config, model_f, params_f):
        """
        Create the inference model and the inference engine.
        config: model config file
        model_f: path to the exported inference model
        params_f: parameters of the exported inference model
        """
        cfg = get_config(config, overrides=None, show=False)
        InferenceHelper = build_inference_helper(cfg.INFERENCE)
        _, predictor = self.create_paddle_predictor(model_f, params_f, cfg)
        return InferenceHelper, predictor
    def predict(self, config, input_f, batch_size, model_f, params_f):
    def predict(self, config, input_f, batch_size, predictor, InferenceHelper):
        """
        Run the inference model to predict on the input data.
        config: config file of the PP-TSMv2 model
        input_f: path to the dataset to run inference on
        batch_size: number of samples taken per inference step, default = 1
        model_f: path to the exported inference model plus its config file
        params_f: parameters of the exported inference model
        predictor: inference engine
        InferenceHelper: inference model helper
        """
        result = {}
        cfg = get_config(config, overrides=None, show=False)
        model_name = cfg.model_name
        print(f"Inference model({model_name})...")
        # create the inference model
        InferenceHelper = build_inference_helper(cfg.INFERENCE)
        # create the inference engine
        _, predictor = self.create_paddle_predictor(model_f, params_f, cfg)
        # get input_tensor and output_tensor
        input_names = predictor.get_input_names()
        output_names = predictor.get_output_names()
@ -133,21 +139,21 @@ class PP_TSMv2_predict(object):
result["video_id"] = res[0]["video_id"]
result["topk_class"] = res[0]["topk_class"].tolist()[0]
result["topk_scores"] = res[0]["topk_scores"].tolist()[0]
print(result)
return result
def main():
    config = 'D:/download/PaddleVideo1/configs/recognition/pptsm/v2/pptsm_lcnet_k400_16frames_uniform.yaml'  # path to the config file
    input_file = 'D:/download/PaddleVideo1/data/dataset/video_seg_no_hand/test02_84.avi'  # path to the dataset to run inference on
    model_file = 'D:/download/PaddleVideo1/output/ppTSMv2.pdmodel'  # path to the inference model
    params_file = 'D:/download/PaddleVideo1/output/ppTSMv2.pdiparams'  # path to the inference model parameters
    config = 'D:/download/PaddleVideo1/output/output/pptsm_lcnet_k400_16frames_uniform.yaml'  # path to the config file
    input_file = 'C:/Users/Administrator/Pictures/video_seg_re_hand/test01_3.avi'  # path to the dataset to run inference on
    model_file = 'D:/download/PaddleVideo1/output/output/ppTSMv2.pdmodel'  # path to the inference model
    params_file = 'D:/download/PaddleVideo1/output/output/ppTSMv2.pdiparams'
    batch_size = 1  # number of samples per inference batch
    t = PP_TSMv2_predict().predict(config, input_file, batch_size, model_file, params_file)  # run inference and get predictions
    print(t)
    infer, predictor = PP_TSMv2_predict().create_inference_model(config, model_file, params_file)
    PP_TSMv2_predict().predict(config, input_file, batch_size, predictor, infer)  # run inference and get predictions
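
The point of this change is that create_inference_model builds the InferenceHelper and the predictor once, so predict can be called repeatedly without reloading the inference engine. A minimal usage sketch along those lines (the video list is a placeholder, and config/model_file/params_file/batch_size are assumed to be set as in main above):

    # Minimal sketch, assuming PP_TSMv2_predict is importable in this module's scope;
    # the video paths below are placeholders, not real files.
    runner = PP_TSMv2_predict()
    infer, predictor = runner.create_inference_model(config, model_file, params_file)  # load engine once
    for video in ['clip_01.avi', 'clip_02.avi']:
        result = runner.predict(config, video, batch_size, predictor, infer)  # reuse the same engine
        print(result["topk_class"], result["topk_scores"])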
