#pragma once #include "MF_ImageClassificationBase.h" #include #define NUMCLASSES 5 // 目标种类 #define IMGS_IDX 3 class MF_Resnet34Infer : public MF_ImageClassificationBase { public: MF_Resnet34Infer(const trtUtils::InitParameter& param); ~MF_Resnet34Infer(); // 初始化引擎 engine bool initEngine(const std::string& _onnxFileName); // 推理 bool doTRTInfer(const std::vector& _bufImgs, std::vector* _detectRes, int* _user); // 推理 bool doTRTInfer(const std::vector& _matImgs, std::vector* _detectRes, int* _user); // 获取错误信息 static std::string getError(); // 清理数据/内存 void freeMemeory(); protected: // 通过onnx格式文件构建engine int buildModel(const std::string& _onnxFile); // 预处理 int preProcess(const std::vector& _imgsBatch); // 推理 int infer(); // 后处理 int postProcess(const std::vector& _imgsBatch); // 获取推理结果 int getDetectResult(std::vector& _result); private: std::vector softmax(std::vector _input); std::vector load_labels(const std::string& _labelPath); std::string UTF8_2_GB(const char* str); private: int input_numel; nvinfer1::Dims input_dims; // 输入数据维度 trtUtils::ME_DetectRes detectRes; // 检测结果 float confidence; // 置信度 std::string predictName; // 推理类别 std::mutex m_mutex; // 输入数据 float* input_data_host{ nullptr }; float* input_data_device{ nullptr }; // 输出数据 float output_data_host[NUMCLASSES]; float* output_data_device{ nullptr }; };