#include "MF_Yolov8Infer.h"
#include "ML_Log.h"

namespace fs = std::filesystem;

MF_Yolov8Infer::MF_Yolov8Infer(const utils::InitParameter& param) : MF_ObjectDetectBase(param)
{
}

MF_Yolov8Infer::~MF_Yolov8Infer()
{
	checkRuntime(cudaFree(m_output_src_transpose_device));
}

bool MF_Yolov8Infer::initEngine(const std::string& _onnxFileName)
{
	// Check whether the given ONNX file exists
	fs::path onnx_path(_onnxFileName);
	if (!fs::exists(onnx_path))
	{
		LOG_ERROR("init engine, onnx file does not exist. \n");
		return false;
	}

	// Replace the file extension (.onnx -> .trt) and check whether a serialized TRT engine already exists.
	// If a local .trt engine is found, deserialize it directly and build the engine from it.
	fs::path extension("trt");
	fs::path trt_Path = onnx_path.replace_extension(extension);
	// std::string trtFileName = trt_Path.string();
	if (fs::exists(trt_Path))
	{
		LOG_INFO("trt model has existed.\n");

		std::vector<uchar> engine_data;
		int iRet = loadTRTModelData(trt_Path.string(), engine_data);
		if (iRet != 0)
		{
			LOG_ERROR("load trt model failed.\n");
			return false;
		}

		auto runtime = std::unique_ptr<nvinfer1::IRuntime>(nvinfer1::createInferRuntime(sample::gLogger.getTRTLogger()));
		if (!runtime)
		{
			LOG_ERROR("on the init engine, create infer runtime failed.\n");
			return false;
		}

		m_engine = std::unique_ptr<nvinfer1::ICudaEngine>(runtime->deserializeCudaEngine(engine_data.data(), engine_data.size()));
		if (!m_engine)
		{
			LOG_ERROR("on the init engine, deserialize engine failed.\n");
			return false;
		}
		m_context = std::unique_ptr<nvinfer1::IExecutionContext>(m_engine->createExecutionContext());
		if (!m_context)
		{
			LOG_ERROR("on the init engine, create excution context failed.\n");
			return false;
		}

		if (m_param.dynamic_batch)
		{
			m_context->setBindingDimensions(0, nvinfer1::Dims4(m_param.batch_size, 3, m_param.dst_h, m_param.dst_w));
		}

		// Binding 1 is the detection output; d[2] is taken as the number of candidate boxes
		// per image, and the product of the non-batch dimensions gives the per-image output area.
		m_output_dims = this->m_context->getBindingDimensions(1);
		m_total_objects = m_output_dims.d[2];
		assert(m_param.batch_size <= m_output_dims.d[0]);
		m_output_area = 1;
		for (int i = 1; i < m_output_dims.nbDims; i++)
		{
			if (m_output_dims.d[i] != 0)
			{
				m_output_area *= m_output_dims.d[i];
			}
		}
		checkRuntime(cudaMalloc(&m_output_src_device, m_param.batch_size * m_output_area * sizeof(float)));
		checkRuntime(cudaMalloc(&m_output_src_transpose_device, m_param.batch_size * m_output_area * sizeof(float)));
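		// Letterbox mapping from the source image to the network input: scale by the smaller of
		// dst_h / src_h and dst_w / src_w, then center the resized image on the dst_w x dst_h
		// canvas (the "+ scale - 1" term aligns pixel centers). dst2src is the inverse affine,
		// used in postProcess() to map detected boxes back to source-image coordinates.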
		float a = float(m_param.dst_h) / m_param.mImage.m_height;
		float b = float(m_param.dst_w) / m_param.mImage.m_width;
		float scale = a < b ? a : b;
		cv::Mat src2dst = (cv::Mat_<float>(2, 3) << scale, 0.f, (-scale * m_param.mImage.m_width + m_param.dst_w + scale - 1) * 0.5,
			0.f, scale, (-scale * m_param.mImage.m_height + m_param.dst_h + scale - 1) * 0.5);
		cv::Mat dst2src = cv::Mat::zeros(2, 3, CV_32FC1);
		cv::invertAffineTransform(src2dst, dst2src);

		m_dst2src.v0 = dst2src.ptr<float>(0)[0];
		m_dst2src.v1 = dst2src.ptr<float>(0)[1];
		m_dst2src.v2 = dst2src.ptr<float>(0)[2];
		m_dst2src.v3 = dst2src.ptr<float>(1)[0];
		m_dst2src.v4 = dst2src.ptr<float>(1)[1];
		m_dst2src.v5 = dst2src.ptr<float>(1)[2];

		return true;
	}

	// No local .trt engine was found; building the engine from ONNX is not handled here.
	LOG_ERROR("init engine, trt engine file does not exist.\n");
	return false;
}

bool MF_Yolov8Infer::doTRTInfer(const std::vector<MN_VisionImage::MS_ImageParam>& _bufImgs, std::vector<utils::MR_Result>* _detectRes, int* _user)
{
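	// Pipeline: buffer -> cv::Mat conversion, host-to-device copy, GPU preprocessing,
	// TensorRT inference, device-to-host copy, postprocessing (decode + NMS), result collection.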
	std::vector<cv::Mat> matImgs;
	for (const auto& _var : _bufImgs)
	{
		cv::Mat image;
		bool bRet = buffer2Mat(_var, image);
		if (!bRet)
		{
			LOG_ERROR("doinfer(), convert buffer to Mat failed. \n");
			return false;
		}
		matImgs.emplace_back(image);
	}

	int iRet = 0;
	iRet = this->copyToDevice(matImgs);
	if (iRet != 0)
	{
		LOG_ERROR("doinfer(), copy image data from cpu to device failed. \n");
		return false;
	}

	iRet = this->preProcess(matImgs);
	if (iRet != 0)
	{
		LOG_ERROR("doinfer(), preprocess image failed. \n");
		return false;
	}

	iRet = this->infer();
	if (iRet != 0)
	{
		LOG_ERROR("doinfer(), infer failed. \n");
		return false;
	}

	iRet = this->copyFromDevice(matImgs);
	if (iRet != 0)
	{
		LOG_ERROR("doinfer(), copy image data from device to cpu failed. \n");
		return false;
	}

	iRet = this->postProcess(matImgs);
	if (iRet != 0)
	{
		LOG_ERROR("doinfer(), postprocess image failed. \n");
		return false;
	}

	iRet = this->getDetectResult(*_detectRes);
	if (iRet != 0)
	{
		LOG_ERROR("doinfer(), get detect result failed. \n");
		return false;
	}

	return true;
}

bool MF_Yolov8Infer::doTRTInfer(const std::vector<cv::Mat>& _bufImgs, std::vector<utils::MR_Result>* _detectRes, int* _user)
{
	// Not implemented yet; use the buffer-based overload above.
	return false;
}

std::string MF_Yolov8Infer::getError()
{
	return "";
}

void MF_Yolov8Infer::freeMemeory()
{
	// Reset the device-side object buffers (counts included) and clear the host-side
	// results from the previous inference.
	checkRuntime(cudaMemset(m_output_objects_device, 0, sizeof(float) * m_param.batch_size * (1 + 7 * m_param.topK)));
	for (size_t bi = 0; bi < m_param.batch_size; bi++)
	{
		m_objectss[bi].clear();
	}
}

int MF_Yolov8Infer::copyToDevice(const std::vector<cv::Mat>& _imgsBatch)
{
	// update 20230302, faster.
	// 1. moved uint8_to_float into the CUDA kernel. For 8*3*1920*1080, cost time 15ms -> 3.9ms
	// 2. Todo
	// Copy each image's packed BGR (HWC, uint8) pixels into one contiguous device buffer.
	unsigned char* pi = m_input_src_device;
	for (size_t i = 0; i < _imgsBatch.size(); i++)
	{
		checkRuntime(cudaMemcpy(pi, _imgsBatch[i].data, sizeof(unsigned char) * 3 * m_param.mImage.m_height * m_param.mImage.m_width, cudaMemcpyHostToDevice));
		pi += 3 * m_param.mImage.m_height * m_param.mImage.m_width;
	}

	return 0;
}

int MF_Yolov8Infer::preProcess(const std::vector<cv::Mat>& _imgsBatch)
{
	// Letterbox resize on the GPU (pad value 114); m_dst2src maps network-input coordinates
	// back to source-image coordinates.
	resizeDevice(m_param.batch_size, m_input_src_device, m_param.mImage.m_width, m_param.mImage.m_height,
		m_input_resize_device, m_param.dst_w, m_param.dst_h, 114, m_dst2src);
	// BGR -> RGB channel swap.
	bgr2rgbDevice(m_param.batch_size, m_input_resize_device, m_param.dst_w, m_param.dst_h,
		m_input_rgb_device, m_param.dst_w, m_param.dst_h);
	// Normalize pixel values using the parameters in m_param.
	normDevice(m_param.batch_size, m_input_rgb_device, m_param.dst_w, m_param.dst_h,
		m_input_norm_device, m_param.dst_w, m_param.dst_h, m_param);
	// HWC -> CHW layout conversion for the network input.
	hwc2chwDevice(m_param.batch_size, m_input_norm_device, m_param.dst_w, m_param.dst_h,
		m_input_hwc_device, m_param.dst_w, m_param.dst_h);

	return 0;
}

int MF_Yolov8Infer::infer()
{
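	// Binding order is assumed to match the engine: binding 0 is the preprocessed CHW float
	// input, binding 1 is the raw detection output queried in initEngine().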
	float* bindings[] = { m_input_hwc_device, m_output_src_device };
	bool bRet = m_context->executeV2((void**)bindings);
	if (!bRet)
	{
		LOG_ERROR("infer failed.\n");
		return 1;
	}

	return 0;
}

int MF_Yolov8Infer::copyFromDevice(const std::vector<cv::Mat>& _imgsBatch)
{
	// Nothing to do here: the decoded detections are copied device -> host in postProcess().
	return 0;
}

int MF_Yolov8Infer::postProcess(const std::vector<cv::Mat>& _imgsBatch)
{
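	// Decode the raw head output into fixed-width candidate records (at most topK per image,
	// preceded by a count), then let NMS mark which candidates to keep.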
	decodeDevice(m_param, m_output_src_device, 5 + m_param.num_class, m_total_objects,
		m_output_area, m_output_objects_device, m_output_objects_width, m_param.topK);

	// NMS v1 (faster)
	nmsDeviceV1(m_param, m_output_objects_device, m_output_objects_width, m_param.topK, m_param.topK * m_output_objects_width + 1);

	// NMS v2 (sort-based), kept for reference:
	//nmsDeviceV2(m_param, m_output_objects_device, m_output_objects_width, m_param.topK, m_param.topK * m_output_objects_width + 1, m_output_idx_device, m_output_conf_device);

	checkRuntime(cudaMemcpy(m_output_objects_host, m_output_objects_device, m_param.batch_size * sizeof(float) * (1 + 7 * m_param.topK), cudaMemcpyDeviceToHost));
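	// Host buffer layout per image (stride = topK * m_output_objects_width + 1, width assumed to be 7):
	//   [0]                        number of decoded boxes
	//   [1 + i*width .. +6]        left, top, right, bottom, confidence, class id, keep flag
	// Surviving boxes (keep flag != 0) are mapped back to source coordinates via m_dst2src.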
	for (size_t bi = 0; bi < _imgsBatch.size(); bi++)
	{
		float* batch_ptr = m_output_objects_host + bi * (m_param.topK * m_output_objects_width + 1);
		int num_boxes = (std::min)((int)batch_ptr[0], m_param.topK);
		for (int i = 0; i < num_boxes; i++)
		{
			float* ptr = batch_ptr + m_output_objects_width * i + 1;
			int keep_flag = ptr[6];
			if (keep_flag)
			{
				float x_lt = m_dst2src.v0 * ptr[0] + m_dst2src.v1 * ptr[1] + m_dst2src.v2;
				float y_lt = m_dst2src.v3 * ptr[0] + m_dst2src.v4 * ptr[1] + m_dst2src.v5;
				float x_rb = m_dst2src.v0 * ptr[2] + m_dst2src.v1 * ptr[3] + m_dst2src.v2;
				float y_rb = m_dst2src.v3 * ptr[2] + m_dst2src.v4 * ptr[3] + m_dst2src.v5;
				m_objectss[bi].emplace_back(x_lt, y_lt, x_rb, y_rb, ptr[4], (int)ptr[5]);
			}
		}
	}

	return 0;
}

int MF_Yolov8Infer::getDetectResult(std::vector<utils::MR_Result>& _result)
{
	if (_result.empty())
	{
		LOG_ERROR("get detect result failed. \n");
		return 1;
	}

	return 0;
}