#include "MF_Resnet34Infer.h"

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <iterator>
#include <mutex>
#include <string>
#include <system_error>
#include <utility>
#include <vector>

#include "ML_Log.h"

namespace fs = std::filesystem;
using namespace samplesCommon;

// Creates the CUDA stream and the input/output buffers used by the inference pipeline.
MF_Resnet34Infer::MF_Resnet34Infer(const trtUtils::InitParameter& param) :MF_ImageClassificationBase(param)
{
	// Stream on which the host->device copy, the inference and the device->host copy are queued.
	checkRuntime(cudaStreamCreate(&mStream));

	// Input tensor holds batch_size * channels * dst_h * dst_w floats:
	// a page-locked host staging buffer plus the matching device buffer.
	input_numel = m_param.batch_size * m_param.mImage.m_channels * m_param.dst_h * m_param.dst_w;
	checkRuntime(cudaMallocHost(&input_data_host, input_numel * sizeof(float)));
	checkRuntime(cudaMalloc(&input_data_device, input_numel * sizeof(float)));

	// Zero every element of the host output array. The previous statement
	// `output_data_host[NUMCLASSES] = { 0.0 };` wrote one element past the end of the array.
	for (auto& value : output_data_host)
	{
		value = 0.0f;
	}

	// NOTE(review): the device output buffer is sized like the input tensor, which is much larger
	// than the logits actually copied back in postProcess(); kept so a larger model output still fits.
	checkRuntime(cudaMalloc(&output_data_device, input_numel * sizeof(float)));
}

// Releases the CUDA resources acquired by the constructor.
MF_Resnet34Infer::~MF_Resnet34Infer()
{
	// The stream goes first.
	checkRuntime(cudaStreamDestroy(mStream));

	// Then the page-locked host staging buffer...
	checkRuntime(cudaFreeHost(input_data_host));

	// ...and finally both device-side tensors.
	checkRuntime(cudaFree(input_data_device));
	checkRuntime(cudaFree(output_data_device));
}

// Prepares the TensorRT runtime/engine/execution context for the given ONNX model.
// If "<model>.trt" exists next to the ONNX file, it is deserialized directly;
// otherwise the engine is built from the ONNX file via buildModel().
// Returns true on success, false on any failure (an error is logged).
bool MF_Resnet34Infer::initEngine(const std::string& _onnxFileName)
{
	LOG_INFO("on the init engine, input onnx file : " + _onnxFileName);

	// Hold m_mutex for the whole initialization; the guard releases it on every return path.
	// The previous manual lock()/unlock() pair never released it on the early error returns
	// nor on the build-from-ONNX branch, so later lockers (e.g. doTRTInfer) would block.
	std::lock_guard<decltype(m_mutex)> lock(m_mutex);

	// Check that the ONNX file passed in exists.
	fs::path onnx_path(_onnxFileName);
	if (!fs::exists(onnx_path))
	{
		LOG_ERROR("init engine, input onnx file does not exist. \n");
		return false;
	}

	// The cached engine has the same path with ".onnx" replaced by ".trt".
	// If it already exists locally, load it and deserialize the engine directly.
	fs::path extension("trt");
	fs::path trt_Path = onnx_path.replace_extension(extension);
	if (fs::exists(trt_Path))
	{
		LOG_INFO("trt model has existed.\n");

		std::vector<uchar> engine_data;
		int iRet = loadTRTModelData(trt_Path.string(), engine_data);
		if (iRet != 0)
		{
			LOG_ERROR("load trt model failed.\n");
			return false;
		}

		m_runtime = std::unique_ptr<nvinfer1::IRuntime>(nvinfer1::createInferRuntime(sample::gLogger.getTRTLogger()));
		if (!m_runtime)
		{
			LOG_ERROR("on the init engine, create infer runtime failed.\n");
			return false;
		}

		m_engine = std::unique_ptr<nvinfer1::ICudaEngine>(m_runtime->deserializeCudaEngine(engine_data.data(), engine_data.size()));
		if (!m_engine)
		{
			LOG_ERROR("on the init engine, deserialize engine failed.\n");
			return false;
		}

		m_context = std::unique_ptr<nvinfer1::IExecutionContext>(m_engine->createExecutionContext());
		if (!m_context)
		{
			LOG_ERROR("on the init engine, create excution context failed.\n");
			return false;
		}

		if (m_param.dynamic_batch)
		{
			m_context->setBindingDimensions(0, nvinfer1::Dims4(m_param.batch_size, m_param.mImage.m_channels, m_param.dst_h, m_param.dst_w));
		}

		// Input shape used by every inference call; the batch dimension is forced to batch_size.
		input_dims = m_context->getBindingDimensions(0);
		input_dims.d[0] = m_param.batch_size;
	}
	else
	{
		// No cached .trt engine on disk: build one from the ONNX file.
		int iRet = this->buildModel(_onnxFileName);
		if (iRet != 0)
		{
			LOG_ERROR("on the init engine, from onnx file build model failed.\n");
			return false;
		}
	}

	return true;
}

// Runs the full classification pipeline on raw image buffers:
// buffer -> cv::Mat conversion, preprocess, inference, postprocess, then appends
// the results to *_detectRes. `_user` is not used by this implementation.
// Returns false (after logging) as soon as any stage fails.
bool MF_Resnet34Infer::doTRTInfer(const std::vector<MN_VisionImage::MS_ImageParam>& _bufImgs, std::vector<trtUtils::MR_Result>* _detectRes, int* _user)
{
	// Serialize the whole pipeline. The guard unlocks on every return path;
	// previously each early error return left m_mutex locked.
	std::lock_guard<decltype(m_mutex)> lock(m_mutex);

	if (_detectRes == nullptr)
	{
		LOG_ERROR("doinfer(), detect result pointer is null. \n");
		return false;
	}

	std::vector<cv::Mat> matImgs;
	matImgs.reserve(_bufImgs.size());
	// NOTE(review): iterates by value as before; switch to `const auto&` if buffer2Mat accepts a const reference.
	for (auto _var : _bufImgs)
	{
		cv::Mat image;
		bool bRet = buffer2Mat(_var, image);
		if (!bRet)
		{
			LOG_ERROR("doinfer(), convert buffer to Mat failed. \n");
			return false;
		}
		matImgs.emplace_back(image);
	}

	int iRet = this->preProcess(matImgs);
	if (iRet != 0)
	{
		LOG_ERROR("doinfer(), preprocess image failed. \n");
		return false;
	}

	iRet = this->infer();
	if (iRet != 0)
	{
		LOG_ERROR("doinfer(), infer failed. \n");
		return false;
	}

	iRet = this->postProcess(matImgs);
	if (iRet != 0)
	{
		LOG_ERROR("doinfer(), postprocess image failed. \n");
		return false;
	}

	iRet = this->getDetectResult(*_detectRes);
	if (iRet != 0)
	{
		LOG_ERROR("doinfer(), get detect result failed. \n");
		return false;
	}

	return true;
}

// Runs the full classification pipeline on already-decoded images:
// preprocess, inference, postprocess, then appends the results to *_detectRes.
// `_user` is not used by this implementation.
// Returns false (after logging) as soon as any stage fails.
bool MF_Resnet34Infer::doTRTInfer(const std::vector<cv::Mat>& _matImgs, std::vector<trtUtils::MR_Result>* _detectRes, int* _user)
{
	// Serialize the whole pipeline. The guard unlocks on every return path;
	// previously each early error return left m_mutex locked.
	std::lock_guard<decltype(m_mutex)> lock(m_mutex);

	if (_detectRes == nullptr)
	{
		LOG_ERROR("doinfer(), detect result pointer is null. \n");
		return false;
	}

	int iRet = this->preProcess(_matImgs);
	if (iRet != 0)
	{
		LOG_ERROR("doinfer(), preprocess image failed. \n");
		return false;
	}

	iRet = this->infer();
	if (iRet != 0)
	{
		LOG_ERROR("doinfer(), infer failed. \n");
		return false;
	}

	iRet = this->postProcess(_matImgs);
	if (iRet != 0)
	{
		LOG_ERROR("doinfer(), postprocess image failed. \n");
		return false;
	}

	iRet = this->getDetectResult(*_detectRes);
	if (iRet != 0)
	{
		LOG_ERROR("doinfer(), get detect result failed. \n");
		return false;
	}

	return true;
}

// Error reporting is not implemented for this model: always returns an empty string.
std::string MF_Resnet34Infer::getError()
{
	return std::string{};
}

// Intentionally does nothing: the CUDA stream and buffers are released in the destructor.
void MF_Resnet34Infer::freeMemeory()
{
}

int MF_Resnet34Infer::buildModel(const std::string& _onnxFile)
{
	LOG_INFO("on the build model, input onnx file : " + _onnxFile);

	auto builder = SampleUniquePtr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(sample::gLogger.getTRTLogger()));
	if (!builder)
	{
		LOG_ERROR("on the build model, create infer builder failed. \n");
		return 1;
	}

	auto explicitBatch = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
	auto network = SampleUniquePtr<nvinfer1::INetworkDefinition>(builder->createNetworkV2(explicitBatch));
	if (!network)
	{
		LOG_ERROR("on the build model, create network failed. \n");
		return 1;
	}

	auto config = SampleUniquePtr<nvinfer1::IBuilderConfig>(builder->createBuilderConfig());
	if (!config)
	{
		LOG_ERROR("on the build model, create builder config failed. \n");
		return 1;
	}

	config->setFlag(nvinfer1::BuilderFlag::kFP16);
	samplesCommon::enableDLA(builder.get(), config.get(), -1);

	// ͨ¹ýonnxparser½âÎöÆ÷½âÎöµÄ½á¹û»áÌî³äµ½networkÖÐ
	auto parser = SampleUniquePtr<nvonnxparser::IParser>(nvonnxparser::createParser(*network, sample::gLogger.getTRTLogger()));
	if (!parser->parseFromFile(_onnxFile.c_str(), static_cast<int>(sample::gLogger.getReportableSeverity())))
	{
		LOG_ERROR("on the build model, parser onnx file failed. \n");
		for (int i = 0; i < parser->getNbErrors(); i++)
		{
			LOG_WARN(parser->getError(i)->desc());
		}

		return 1;
	}

	int maxBatchSize = 1;
	printf("workspace size = %.2f MB.\n", (1 << 28) / 1024.0f / 1024.0f);

	// Èç¹ûÄ£ÐÍÓжà¸öÊäÈ룬Ôò±ØÐëÓжà¸öprofile
	auto profile = builder->createOptimizationProfile();
	auto input_tensor = network->getInput(0);
	auto input_dims = input_tensor->getDimensions();

	// ÅäÖÃ×îС¡¢×îÓÅ¡¢×î´ó·¶Î§
	input_dims.d[0] = 1;		// batchsize
	profile->setDimensions(input_tensor->getName(), nvinfer1::OptProfileSelector::kMIN, input_dims);
	profile->setDimensions(input_tensor->getName(), nvinfer1::OptProfileSelector::kOPT, input_dims);
	input_dims.d[0] = maxBatchSize;		// [batchsize, channel, height, width]
	profile->setDimensions(input_tensor->getName(), nvinfer1::OptProfileSelector::kMAX, input_dims);
	config->addOptimizationProfile(profile);

	auto profileStream = samplesCommon::makeCudaStream();
	if (!profileStream)
	{
		LOG_ERROR("on the build model, make cuda stream failed.\n");
		return 1;
	}
	config->setProfileStream(*profileStream);

	auto engine_data = std::unique_ptr<nvinfer1::IHostMemory>(builder->buildSerializedNetwork(*network, *config));
	if (!engine_data)
	{
		LOG_ERROR("on the build model, build engine failed.\n");
		return 1;
	}

	// ½«Ä£Ðͱ£´æÎª¶þ½øÖÆÎļþ
	fs::path onnx_path(_onnxFile);
	fs::path extension("trt");
	fs::path trt_Path = onnx_path.replace_extension(extension);
	FILE* fp;
	errno_t err = fopen_s(&fp, trt_Path.string().c_str(), "wb");
	if (err != 0)
	{
		LOG_ERROR("save engine data failed. \n");
		return 1;
	}
	fwrite(engine_data->data(), 1, engine_data->size(), fp);
	fclose(fp);
	LOG_INFO("on the build model, save model success. \n");

	m_runtime = std::unique_ptr<nvinfer1::IRuntime>(nvinfer1::createInferRuntime(sample::gLogger.getTRTLogger()));
	if (!m_runtime)
	{
		LOG_ERROR("on the build model, create infer runtime failed.\n");
		return 1;
	}

	m_engine = std::unique_ptr<nvinfer1::ICudaEngine>(m_runtime->deserializeCudaEngine(engine_data->data(), engine_data->size()));
	if (!m_runtime)
	{
		LOG_ERROR("on the build model, deserialize cuda engine failed.\n");
		return 1;
	}

	m_context = std::unique_ptr<nvinfer1::IExecutionContext>(m_engine->createExecutionContext());
	if (!m_context)
	{
		LOG_ERROR("on the build model, create excution context failed.\n");
		return 1;
	}

	if (m_param.dynamic_batch)
	{
		m_context->setBindingDimensions(0, nvinfer1::Dims4(m_param.batch_size, m_param.mImage.m_channels, m_param.dst_h, m_param.dst_w));
	}

	// Ã÷È·µ±Ç°ÍÆÀíʱ£¬Ê¹ÓõÄÊäÈëÊý¾Ý´óС
	input_dims = m_context->getBindingDimensions(0);
	input_dims.d[0] = m_param.batch_size;
	LOG_INFO("build model success. \n");

	return 0;
}

// Resizes the image at _imgsBatch[IMGS_IDX] to dst_w x dst_h, converts it from
// interleaved 8-bit pixels to planar normalized floats in input_data_host, and
// queues the host->device copy on mStream. Returns 0 on success, 1 on invalid input.
// NOTE(review): only one image is written even when batch_size > 1.
int MF_Resnet34Infer::preProcess(const std::vector<cv::Mat>& _imgsBatch)
{
	// Only the image at IMGS_IDX is consumed; make sure it is present.
	if (_imgsBatch.size() <= static_cast<size_t>(IMGS_IDX))
	{
		LOG_ERROR("preprocess, no input image at the expected index. \n");
		return 1;
	}

	const cv::Mat& src = _imgsBatch[IMGS_IDX];
	// The conversion loop reads 3 bytes per pixel, so anything other than a non-empty
	// 3-channel, 1-byte-per-channel image would be read out of bounds.
	if (src.empty() || src.channels() != 3 || src.elemSize1() != 1)
	{
		LOG_ERROR("preprocess, input image must be a non-empty 3-channel 8-bit image. \n");
		return 1;
	}

	cv::Mat image;
	cv::resize(src, image, cv::Size(m_param.dst_w, m_param.dst_h));
	const int image_size = image.rows * image.cols;

	// Three planes of image_size floats are written below; they must fit in the host buffer.
	if (static_cast<size_t>(image_size) * 3 > static_cast<size_t>(input_numel))
	{
		LOG_ERROR("preprocess, host input buffer is too small for the resized image. \n");
		return 1;
	}

	const uchar* pData = image.data;
	float* pHost_b = input_data_host + image_size * 0;
	float* pHost_g = input_data_host + image_size * 1;
	float* pHost_r = input_data_host + image_size * 2;
	for (int i = 0; i < image_size; i++, pData += 3)
	{
		// Channel order is swapped on purpose: byte 0 goes to plane 2 and byte 2 to plane 0
		// (BGR -> RGB if the input uses OpenCV's usual BGR order).
		// NOTE(review): meanVec/stdVec are indexed by source byte position, so they are
		// presumably given in the input's channel order — confirm against m_param.
		*pHost_r++ = (pData[0] / 255.0f - m_param.meanVec[0]) / m_param.stdVec[0];
		*pHost_g++ = (pData[1] / 255.0f - m_param.meanVec[1]) / m_param.stdVec[1];
		*pHost_b++ = (pData[2] / 255.0f - m_param.meanVec[2]) / m_param.stdVec[2];
	}

	// Copy the input data from host to device.
	checkRuntime(cudaMemcpyAsync(input_data_device, input_data_host, input_numel * sizeof(float), cudaMemcpyHostToDevice, mStream));

	return 0;
}

// Enqueues inference on mStream using the device input/output buffers and waits for it.
// Returns 0 on success, 1 if the context is missing or enqueueing fails.
int MF_Resnet34Infer::infer()
{
	if (m_context == nullptr)
	{
		LOG_ERROR("do infer, context dose not init. \n");
		return 1;
	}

	// Apply the input shape used for this inference.
	m_context->setBindingDimensions(0, input_dims);

	// Binding 0 = input buffer, binding 1 = output buffer (assumes the engine's binding order — TODO confirm).
	// A genuine void* array replaces the previous C-style (void**) cast of a float* array.
	void* bindings[] = { input_data_device, output_data_device };
	bool bRet = m_context->enqueueV2(bindings, mStream, nullptr);
	if (!bRet)
	{
		LOG_ERROR("infer failed.\n");
		return 1;
	}

	checkRuntime(cudaStreamSynchronize(mStream));

	return 0;
}

// Copies the logits back to the host, applies softmax, and stores the arg-max class name
// (predictName), its probability (confidence) and an OK/NG verdict (detectRes, NG below 0.5).
// `_imgsBatch` is not used here. Returns 0 on success, 1 if the label index has no class name.
int MF_Resnet34Infer::postProcess(const std::vector<cv::Mat>& _imgsBatch)
{
	// Copy the inference result from device to host and wait for the copy.
	checkRuntime(cudaMemcpyAsync(output_data_host, output_data_device, sizeof(output_data_host), cudaMemcpyDeviceToHost, mStream));
	checkRuntime(cudaStreamSynchronize(mStream));

	// Work on the whole output array instead of a hand-written 5-element copy,
	// so this stays correct for any NUMCLASSES.
	const std::vector<float> logits(std::begin(output_data_host), std::end(output_data_host));
	const std::vector<float> probs = this->softmax(logits);

	// Index of the predicted class and its probability.
	const auto best = std::max_element(probs.begin(), probs.end());
	const int predict_label = static_cast<int>(std::distance(probs.begin(), best));

	const auto& labels = m_param.class_names;
	if (best == probs.end() || static_cast<size_t>(predict_label) >= std::size(labels))
	{
		LOG_ERROR("postprocess, predicted label index is out of range of class names. \n");
		return 1;
	}
	predictName = labels[predict_label];
	confidence = *best;
	if (confidence < 0.5)
	{
		detectRes = trtUtils::ME_DetectRes::E_DETECT_NG;
	}
	else
	{
		detectRes = trtUtils::ME_DetectRes::E_DETECT_OK;
	}
	printf("Predict: %s, confidence: %.3f, label: %d. \n", predictName.c_str(), confidence, predict_label);

	return 0;
}

// Appends (without clearing) one MR_Result per configured batch element to _result.
// Each entry carries the detectRes / confidence / predictName set by postProcess().
// Always returns 0.
int MF_Resnet34Infer::getDetectResult(std::vector<trtUtils::MR_Result>& _result)
{
	trtUtils::MR_Result item;
	item.mClassifyDecRes.mDetectRes = detectRes;
	item.mClassifyDecRes.mConfidence = confidence;
	item.mClassifyDecRes.mLabel = predictName;

	for (size_t n = 0; n < m_param.batch_size; ++n)
	{
		_result.push_back(item);
	}

	return 0;
}

// Numerically stable softmax: returns exp(x_i - max) / sum_j exp(x_j - max) for each input.
// Returns an empty vector for empty input (the previous version read _input[0] unconditionally).
std::vector<float> MF_Resnet34Infer::softmax(std::vector<float> _input)
{
	std::vector<float> result;
	if (_input.empty())
	{
		return result;
	}

	// Subtracting the maximum keeps exp() from overflowing for large logits.
	const float max_value = *std::max_element(_input.begin(), _input.end());

	result.reserve(_input.size());
	float total = 0.0f;
	for (const float x : _input)
	{
		const float e = std::exp(x - max_value);
		result.push_back(e);
		total += e;
	}

	for (float& value : result)
	{
		value /= total;
	}
	return result;
}

// Reads a text file of labels, one per line, converting each line with UTF8_2_GB.
// Line terminators are stripped. Returns an empty vector if the file cannot be opened.
// Lines longer than the 1023-byte read buffer are split into several entries.
std::vector<std::string> MF_Resnet34Infer::load_labels(const std::string& _labelPath)
{
	std::vector<std::string> lines;

	FILE* fp = nullptr;
	errno_t err = fopen_s(&fp, _labelPath.c_str(), "r");
	if (err != 0 || fp == nullptr)
	{
		return std::vector<std::string>{};
	}

	char buf[1024];
	// Loop on fgets itself. Testing feof() before reading (as before) ran the body once more
	// after the last line, appending a duplicate line (or an uninitialized buffer for an empty file).
	while (fgets(buf, sizeof(buf), fp) != nullptr)
	{
		std::string line = this->UTF8_2_GB(buf);
		// fgets keeps the '\n'; also drop a '\r' left by CRLF files.
		while (!line.empty() && (line.back() == '\n' || line.back() == '\r'))
		{
			line.pop_back();
		}
		lines.emplace_back(std::move(line));
	}
	fclose(fp);

	return lines;
}

std::string MF_Resnet34Infer::UTF8_2_GB(const char* str)
{
	std::string result;
	WCHAR* strSrc;
	LPSTR szRes;

	//»ñµÃÁÙʱ±äÁ¿µÄ´óС
	int i = MultiByteToWideChar(CP_UTF8, 0, str, -1, NULL, 0);
	strSrc = new WCHAR[i + 1];
	MultiByteToWideChar(CP_UTF8, 0, str, -1, strSrc, i);

	//»ñµÃÁÙʱ±äÁ¿µÄ´óС
	i = WideCharToMultiByte(CP_ACP, 0, strSrc, -1, NULL, 0, NULL, NULL);
	szRes = new CHAR[i + 1];
	WideCharToMultiByte(CP_ACP, 0, strSrc, -1, szRes, i, NULL, NULL);

	result = szRes;
	delete[]strSrc;
	delete[]szRes;

	return result;
}