You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

64 lines
1.6 KiB
C++

#pragma once
#include "MF_ImageClassificationBase.h"
#include <mutex>
#define NUMCLASSES 5 // 目标种类
#define IMGS_IDX 3
class MF_Resnet34Infer : public MF_ImageClassificationBase
{
public:
MF_Resnet34Infer(const trtUtils::InitParameter& param);
~MF_Resnet34Infer();
// 初始化引擎 engine
bool initEngine(const std::string& _onnxFileName);
// 推理
bool doTRTInfer(const std::vector<MN_VisionImage::MS_ImageParam>& _bufImgs, std::vector<trtUtils::MR_Result>* _detectRes, int* _user);
// 推理
bool doTRTInfer(const std::vector<cv::Mat>& _matImgs, std::vector<trtUtils::MR_Result>* _detectRes, int* _user);
// 获取错误信息
static std::string getError();
// 清理数据/内存
void freeMemeory();
protected:
// 通过onnx格式文件构建engine
int buildModel(const std::string& _onnxFile);
// 预处理
int preProcess(const std::vector<cv::Mat>& _imgsBatch);
// 推理
int infer();
// 后处理
int postProcess(const std::vector<cv::Mat>& _imgsBatch);
// 获取推理结果
int getDetectResult(std::vector<trtUtils::MR_Result>& _result);
private:
std::vector<float> softmax(std::vector<float> _input);
std::vector<std::string> load_labels(const std::string& _labelPath);
std::string UTF8_2_GB(const char* str);
private:
int input_numel;
nvinfer1::Dims input_dims; // 输入数据维度
trtUtils::ME_DetectRes detectRes; // 检测结果
float confidence; // 置信度
std::string predictName; // 推理类别
std::mutex m_mutex;
// 输入数据
float* input_data_host{ nullptr };
float* input_data_device{ nullptr };
// 输出数据
float output_data_host[NUMCLASSES];
float* output_data_device{ nullptr };
};