#pragma once
#include "MF_ImageClassificationBase.h"
#include <mutex>

#define NUMCLASSES 5		// 目标种类
#define IMGS_IDX 3


class MF_Resnet34Infer : public MF_ImageClassificationBase
{
public:
	MF_Resnet34Infer(const trtUtils::InitParameter& param);
	~MF_Resnet34Infer();

	// 初始化引擎 engine
	bool initEngine(const std::string& _onnxFileName);
	// 推理
	bool doTRTInfer(const std::vector<MN_VisionImage::MS_ImageParam>& _bufImgs, std::vector<trtUtils::MR_Result>* _detectRes, int* _user);
	// 推理
	bool doTRTInfer(const std::vector<cv::Mat>& _matImgs, std::vector<trtUtils::MR_Result>* _detectRes, int* _user);
	// 获取错误信息
	static std::string getError();
	// 清理数据/内存
	void freeMemeory();


protected:
	// 通过onnx格式文件构建engine
	int buildModel(const std::string& _onnxFile);
	// 预处理
	int preProcess(const std::vector<cv::Mat>& _imgsBatch);
	// 推理
	int infer();
	// 后处理
	int postProcess(const std::vector<cv::Mat>& _imgsBatch);
	// 获取推理结果
	int getDetectResult(std::vector<trtUtils::MR_Result>& _result);


private:
	std::vector<float> softmax(std::vector<float> _input);
	std::vector<std::string> load_labels(const std::string& _labelPath);
	std::string UTF8_2_GB(const char* str);


private:
	int input_numel;
	nvinfer1::Dims input_dims;		// 输入数据维度
	trtUtils::ME_DetectRes detectRes;	// 检测结果
	float confidence;				// 置信度
	std::string predictName;		// 推理类别
	std::mutex m_mutex;

	// 输入数据
	float* input_data_host{ nullptr };
	float* input_data_device{ nullptr };

	// 输出数据
	float output_data_host[NUMCLASSES];
	float* output_data_device{ nullptr };

};