#include "MF_ObjectDetectBase.h"

#include <cstddef>

/**
 * @brief Builds the detector base and allocates its input/output buffers.
 *
 * Forwards @p param to MA_TRTInferAlgoBase, then allocates (via cudaMalloc,
 * each call wrapped in checkRuntime) one source-image buffer, four
 * destination-sized float buffers for the input side, and the output
 * buffers for detected objects, indices and confidences. A host-side float
 * array of the same element count as the device object buffer is allocated
 * with new[], and m_objectss is resized to one entry per batch item.
 *
 * @param param Initialisation parameters; the fields read here are
 *              batch_size, mImage.m_height, mImage.m_width, dst_h, dst_w
 *              and topK.
 */
MF_ObjectDetectBase::MF_ObjectDetectBase(const trtUtils::InitParameter& param) :MA_TRTInferAlgoBase(param)
{
	// All byte counts are computed in std::size_t so that large
	// batch * height * width products cannot overflow a narrower integer
	// type before they reach cudaMalloc.
	const std::size_t batch = static_cast<std::size_t>(param.batch_size);

	// input
	m_input_src_device = nullptr;
	m_input_resize_device = nullptr;
	m_input_rgb_device = nullptr;
	m_input_norm_device = nullptr;
	m_input_hwc_device = nullptr;

	// The factor 3 is presumably the channel count -- TODO confirm.
	const std::size_t src_bytes = batch * 3
		* static_cast<std::size_t>(param.mImage.m_height)
		* static_cast<std::size_t>(param.mImage.m_width)
		* sizeof(unsigned char);
	const std::size_t dst_bytes = batch * 3
		* static_cast<std::size_t>(param.dst_h)
		* static_cast<std::size_t>(param.dst_w)
		* sizeof(float);

	checkRuntime(cudaMalloc(&m_input_src_device, src_bytes));
	checkRuntime(cudaMalloc(&m_input_resize_device, dst_bytes));
	checkRuntime(cudaMalloc(&m_input_rgb_device, dst_bytes));
	checkRuntime(cudaMalloc(&m_input_norm_device, dst_bytes));
	// Allocated exactly once: a second cudaMalloc into the same pointer
	// would overwrite it and leak the first device allocation.
	checkRuntime(cudaMalloc(&m_input_hwc_device, dst_bytes));

	// output
	// m_output_src_device is only reset here; it is not allocated by this
	// constructor.
	m_output_src_device = nullptr;
	m_output_objects_device = nullptr;
	m_output_objects_host = nullptr;
	m_output_objects_width = 7;
	m_output_idx_device = nullptr;
	m_output_conf_device = nullptr;

	// Per batch item: 1 leading count slot followed by topK records of
	// m_output_objects_width floats each.
	const std::size_t output_objects_size = batch
		* (1 + static_cast<std::size_t>(param.topK) * static_cast<std::size_t>(m_output_objects_width));
	checkRuntime(cudaMalloc(&m_output_objects_device, output_objects_size * sizeof(float)));

	// NOTE(review): these two sizes read m_param rather than param, as in the
	// original code; presumably m_param is the base-class copy of param --
	// confirm they cannot differ.
	const std::size_t top_k_slots = static_cast<std::size_t>(m_param.batch_size)
		* static_cast<std::size_t>(m_param.topK);
	checkRuntime(cudaMalloc(&m_output_idx_device, top_k_slots * sizeof(int)));
	checkRuntime(cudaMalloc(&m_output_conf_device, top_k_slots * sizeof(float)));

	// Host buffer with the same element count as m_output_objects_device.
	m_output_objects_host = new float[output_objects_size];
	m_objectss.resize(param.batch_size);
}

/**
 * Releases the buffers held by this object: each device pointer is passed to
 * cudaFree (every call wrapped in checkRuntime), and the host-side object
 * array is released with delete[].
 */
MF_ObjectDetectBase::~MF_ObjectDetectBase()
{
	// input: device buffers on the input side
	checkRuntime(cudaFree(m_input_src_device));
	checkRuntime(cudaFree(m_input_resize_device));
	checkRuntime(cudaFree(m_input_rgb_device));
	checkRuntime(cudaFree(m_input_norm_device));
	checkRuntime(cudaFree(m_input_hwc_device));

	// output: device buffers on the output side
	// NOTE(review): the constructor in this file only sets m_output_src_device
	// to nullptr; presumably it is allocated elsewhere (e.g. by a derived
	// class) -- confirm. The CUDA runtime documents cudaFree on a null pointer
	// as performing no operation.
	checkRuntime(cudaFree(m_output_src_device));
	checkRuntime(cudaFree(m_output_objects_device));
	checkRuntime(cudaFree(m_output_idx_device));
	checkRuntime(cudaFree(m_output_conf_device));
	// Host array allocated with new[] in the constructor.
	delete[] m_output_objects_host;
}
	
/**
 * @brief Returns the stored detection results by value.
 *
 * @return A copy of m_objectss; the outer vector is resized to batch_size
 *         entries by the constructor. The caller owns the copy, so later
 *         changes to this object do not affect it.
 */
std::vector<std::vector<trtUtils::Box>> MF_ObjectDetectBase::getObjectss() const
{
	// Copy into a named local and return it by value.
	std::vector<std::vector<trtUtils::Box>> objects_copy(m_objectss);
	return objects_copy;
}