You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

51 lines
1.6 KiB
C++

#pragma once
#include "MI_Interface.h"
class ML_Log;
/// @brief Base class for TensorRT-backed vision inference algorithms.
///
/// Holds the TensorRT runtime / engine / execution context, a CUDA stream and
/// the initialization parameters, and exposes the MI_VisionInterface entry
/// points plus overridable pipeline stages
/// (copyToDevice -> preProcess -> infer -> copyFromDevice -> postProcess).
/// NOTE(review): the exact call order of the stages lives in the .cpp — confirm there.
///
/// The class owns GPU/TensorRT resources and is therefore neither copyable nor
/// movable.
class MA_TRTInferAlgoBase : public MI_VisionInterface
{
public:
    /// @param param Initialization parameters; a copy is stored in m_param.
    ///        `explicit` prevents an InitParameter from silently converting
    ///        into a full inference object.
    explicit MA_TRTInferAlgoBase(const trtUtils::InitParameter& param);

    /// Virtual so that deleting a derived algorithm through a base pointer
    /// runs the derived destructor as well.
    virtual ~MA_TRTInferAlgoBase();

    // Owns TensorRT objects (unique_ptr) and a raw CUDA stream handle:
    // copying or moving would either be ill-formed or double-release the
    // stream. These were already implicitly unavailable; stated explicitly here.
    MA_TRTInferAlgoBase(const MA_TRTInferAlgoBase&) = delete;
    MA_TRTInferAlgoBase& operator=(const MA_TRTInferAlgoBase&) = delete;
    MA_TRTInferAlgoBase(MA_TRTInferAlgoBase&&) = delete;
    MA_TRTInferAlgoBase& operator=(MA_TRTInferAlgoBase&&) = delete;

    /// @brief Build/load the inference engine.
    /// @param _onnxFileName Model file path (named as an ONNX file — confirm
    ///        whether a serialized TRT engine is also accepted).
    /// @return Presumably true on success — confirm in the implementation.
    virtual bool initEngine(const std::string& _onnxFileName);

    /// @brief Validity check of the algorithm state.
    /// @return Presumably true when the engine is ready — confirm.
    virtual bool check();

    /// @brief Run inference on a batch of raw image buffers.
    /// @param _bufImgs    Input images as MS_ImageParam buffers.
    /// @param _detectRes  Output: inference results (caller-provided vector).
    /// @param _user       Opaque user pointer — NOTE(review): semantics not visible here.
    /// @return Presumably true on success — confirm.
    virtual bool doTRTInfer(const std::vector<MN_VisionImage::MS_ImageParam>& _bufImgs, std::vector<trtUtils::MR_Result>* _detectRes, int* _user);

    /// @brief Run inference on a batch of cv::Mat images.
    /// @param _matImgs    Input images.
    /// @param _detectRes  Output: inference results (caller-provided vector).
    /// @param _user       Opaque user pointer — NOTE(review): semantics not visible here.
    /// @return Presumably true on success — confirm.
    virtual bool doTRTInfer(const std::vector<cv::Mat>& _matImgs, std::vector<trtUtils::MR_Result>* _detectRes, int* _user);

    /// @return The last error message.
    virtual std::string getError();

    /// @brief Release memory held by the algorithm.
    /// The name is misspelled (sic) but is kept because it is part of the
    /// public virtual interface and may be overridden or called elsewhere.
    virtual void freeMemeory();

    /// @brief Axis measurement on a single image.
    /// @param measureRes Output: measurement values.
    /// @param _bufImg    Input image buffer.
    /// @return Presumably true on success — confirm.
    virtual bool measureAxis(std::vector<double>& measureRes, const MN_VisionImage::MS_ImageParam& _bufImg);

protected:
    // NOTE(review): the int return codes of the stages below are not documented
    // in this header — confirm their meaning (e.g. 0 == success) in the .cpp.

    /// Load TRT model data from @p _trtFile into @p _modelData.
    virtual int loadTRTModelData(const std::string& _trtFile, std::vector<uchar>& _modelData);
    /// Copy image data to device (GPU) memory.
    virtual int copyToDevice(const std::vector<cv::Mat>& _imgsBatch);
    /// Preprocessing.
    virtual int preProcess(const std::vector<cv::Mat>& _imgsBatch);
    /// Inference.
    virtual int infer();
    /// Copy inference results from device (GPU) memory back to the CPU.
    virtual int copyFromDevice(const std::vector<cv::Mat>& _imgsBatch);
    /// Post-processing.
    virtual int postProcess(const std::vector<cv::Mat>& _imgsBatch);
    /// Convert a buffer image into cv::Mat format.
    bool buffer2Mat(const MN_VisionImage::MS_ImageParam& _inImg, cv::Mat& _mat);

protected:
    // Declaration order matters: members are destroyed in reverse order, so the
    // execution context is released before the engine, and the engine before
    // the runtime (each TensorRT object must outlive the objects created from it).
    // NOTE(review): the default unique_ptr deleter calls `delete`, which is the
    // supported way to release these objects from TensorRT 8 on — confirm the
    // TensorRT version in use (older versions required destroy()).
    std::unique_ptr<nvinfer1::IRuntime> m_runtime{ nullptr };
    std::unique_ptr<nvinfer1::ICudaEngine> m_engine{ nullptr };
    std::unique_ptr<nvinfer1::IExecutionContext> m_context{ nullptr };
    // Raw CUDA stream handle (not RAII-managed here).
    // NOTE(review): presumably created/destroyed in the .cpp — confirm.
    cudaStream_t mStream{ nullptr };
    // Copy of the construction parameters.
    trtUtils::InitParameter m_param;
    // Logger instance (shared ownership).
    std::shared_ptr<ML_Log> mLogPtr{ nullptr };
};