You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

54 lines
2.3 KiB
C++

#include "MF_ObjectDetectBase.h"

#include <cstddef>
/**
 * Constructs the detector and allocates its working buffers.
 *
 * Device buffers allocated (sizes in elements):
 *   - m_input_src_device     : batch_size * 3 * mImage.m_height * mImage.m_width  (unsigned char)
 *   - m_input_resize_device  : batch_size * 3 * dst_h * dst_w                       (float)
 *   - m_input_rgb_device     : batch_size * 3 * dst_h * dst_w                       (float)
 *   - m_input_norm_device    : batch_size * 3 * dst_h * dst_w                       (float)
 *   - m_input_hwc_device     : batch_size * 3 * dst_h * dst_w                       (float)
 *   - m_output_objects_device: batch_size * (1 + topK * m_output_objects_width)     (float)
 *   - m_output_idx_device    : batch_size * topK                                    (int)
 *   - m_output_conf_device   : batch_size * topK                                    (float)
 * Host buffer:
 *   - m_output_objects_host  : same element count as m_output_objects_device (float, new[])
 *
 * m_output_src_device is set to nullptr and is NOT allocated here.
 * NOTE(review): presumably a derived class allocates it — confirm.
 *
 * @param param  Initialization parameters; also forwarded to MA_TRTInferAlgoBase.
 */
MF_ObjectDetectBase::MF_ObjectDetectBase(const utils::InitParameter& param) :MA_TRTInferAlgoBase(param)
{
    // Compute all element counts in size_t so large batch/resolution products
    // cannot overflow int before reaching cudaMalloc.
    const size_t batch = static_cast<size_t>(param.batch_size);
    const size_t src_elems = batch * 3
        * static_cast<size_t>(param.mImage.m_height)
        * static_cast<size_t>(param.mImage.m_width);
    const size_t dst_elems = batch * 3
        * static_cast<size_t>(param.dst_h)
        * static_cast<size_t>(param.dst_w);
    const size_t topk_elems = batch * static_cast<size_t>(param.topK);

    // input
    m_input_src_device = nullptr;
    m_input_resize_device = nullptr;
    m_input_rgb_device = nullptr;
    m_input_norm_device = nullptr;
    m_input_hwc_device = nullptr;
    checkRuntime(cudaMalloc(&m_input_src_device, src_elems * sizeof(unsigned char)));
    checkRuntime(cudaMalloc(&m_input_resize_device, dst_elems * sizeof(float)));
    checkRuntime(cudaMalloc(&m_input_rgb_device, dst_elems * sizeof(float)));
    checkRuntime(cudaMalloc(&m_input_norm_device, dst_elems * sizeof(float)));
    // Allocate exactly once: a second cudaMalloc into the same pointer would
    // leak the first allocation (the destructor frees this pointer only once).
    checkRuntime(cudaMalloc(&m_input_hwc_device, dst_elems * sizeof(float)));

    // output
    m_output_src_device = nullptr; // not allocated in this class (see header comment)
    m_output_objects_device = nullptr;
    m_output_objects_host = nullptr;
    m_output_objects_width = 7;    // floats stored per detected object
    m_output_idx_device = nullptr;
    m_output_conf_device = nullptr;
    // Per batch item: 1 slot for the object count, followed by
    // topK * m_output_objects_width floats for the objects themselves.
    const size_t output_objects_size =
        batch * (1 + static_cast<size_t>(param.topK) * static_cast<size_t>(m_output_objects_width));
    checkRuntime(cudaMalloc(&m_output_objects_device, output_objects_size * sizeof(float)));
    checkRuntime(cudaMalloc(&m_output_idx_device, topk_elems * sizeof(int)));
    checkRuntime(cudaMalloc(&m_output_conf_device, topk_elems * sizeof(float)));
    m_output_objects_host = new float[output_objects_size];

    // One result list per batch slot.
    m_objectss.resize(param.batch_size);
}
/**
 * Releases every buffer owned by this detector.
 *
 * Each device pointer is passed to cudaFree through checkRuntime, and the
 * host result buffer is released with delete[]. This matches its new[] in
 * the constructor.
 * NOTE(review): checkRuntime's failure behavior (log/abort/throw) is defined
 * elsewhere — confirm it is safe to call from a destructor.
 */
MF_ObjectDetectBase::~MF_ObjectDetectBase()
{
// input
checkRuntime(cudaFree(m_input_src_device));
checkRuntime(cudaFree(m_input_resize_device));
checkRuntime(cudaFree(m_input_rgb_device));
checkRuntime(cudaFree(m_input_norm_device));
checkRuntime(cudaFree(m_input_hwc_device));
// output
// m_output_src_device is set to nullptr (not allocated) in this class's
// constructor; cudaFree(nullptr) performs no operation.
// NOTE(review): if a derived class allocates it, confirm it does not also free it.
checkRuntime(cudaFree(m_output_src_device));
checkRuntime(cudaFree(m_output_objects_device));
checkRuntime(cudaFree(m_output_idx_device));
checkRuntime(cudaFree(m_output_conf_device));
// Host-side buffer (allocated with new[]).
delete[] m_output_objects_host;
}
/**
 * Returns a by-value snapshot of the stored detection results.
 *
 * The outer vector has one entry per batch slot; the constructor resizes it
 * to batch_size. Each inner vector holds that slot's boxes. The caller
 * receives an independent copy, so later updates to the detector do not
 * affect it.
 */
std::vector<std::vector<utils::Box>> MF_ObjectDetectBase::getObjectss() const
{
    std::vector<std::vector<utils::Box>> snapshot(m_objectss);
    return snapshot;
}