xBank-Code-AI/model_load/model_load.py


import os.path as op

import fastdeploy as fd


def build_option(device, backend, cache_file):
    """
    Create the Runtime options.
    device: "cpu" or "gpu"
    backend: "trt" selects the TensorRT engine
    cache_file: path used to serialize/load the TensorRT engine cache
    """
    option = fd.RuntimeOption()
    option.use_cpu()
    option.trt_option.serialize_file = cache_file
    if device.lower() == "gpu":
        option.use_gpu(0)
    if backend.lower() == "trt":
        assert device.lower() == "gpu", \
            "TensorRT backend requires inference on a GPU device."
        option.use_trt_backend()
    return option


def Load_model(model_file, device, cache_file):
    """
    Load a model with the TensorRT engine.
    model_file: model weights in ".onnx" format; the file name is expected to
                start with the model type, e.g. "yolov5_xxx.onnx"
    device: device selection, "cpu" or "gpu"
    cache_file: path of the serialized TensorRT engine cache
    """
    # The model type is taken from the prefix of the weight file name.
    model_name = op.basename(model_file).split('_')[0]
    runtime_option = build_option(
        device=device, backend="trt", cache_file=cache_file)
    model_inference = None
    if model_name == "yolov5":
        model_inference = fd.vision.detection.YOLOv5(
            model_file, runtime_option=runtime_option)
    elif model_name == "yolov8":
        model_inference = fd.vision.detection.YOLOv8(
            model_file, runtime_option=runtime_option)
    return model_inference
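
A minimal usage sketch for this module. The weight file, cache path, and test image below are placeholder names, and the import path assumes the package layout model_load/model_load.py shown above; adjust them to your own setup.

import cv2
import fastdeploy.vision as vision

from model_load.model_load import Load_model

# Build the detector; with backend "trt", the first run builds the TensorRT
# engine and serializes it to cache_file so later runs start faster.
model = Load_model("weights/yolov5_det.onnx", device="gpu",
                   cache_file="weights/yolov5_det.trt")

im = cv2.imread("test.jpg")
result = model.predict(im)  # FastDeploy DetectionResult
vis = vision.vis_detection(im, result, score_threshold=0.5)
cv2.imwrite("visualized_result.jpg", vis)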