You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
54 lines
2.3 KiB
C++
54 lines
2.3 KiB
C++
#include "MF_ObjectDetectBase.h"

#include <cstddef>
|
|
|
|
// Allocates the scratch buffers used by the detection pipeline.
//
// Input (device) buffers:
//   m_input_src_device    : batch_size * 3 * mImage.m_height * mImage.m_width unsigned chars
//   m_input_resize_device,
//   m_input_rgb_device,
//   m_input_norm_device,
//   m_input_hwc_device    : batch_size * 3 * dst_h * dst_w floats each
//
// Output buffers:
//   m_output_objects_device (device) and m_output_objects_host (host, new[]):
//                           batch_size * (1 + topK * m_output_objects_width) floats
//   m_output_idx_device   : batch_size * topK ints (device)
//   m_output_conf_device  : batch_size * topK floats (device)
//
// m_output_src_device is only reset to nullptr here; this constructor does not allocate it.
// m_objectss is resized to batch_size entries.
MF_ObjectDetectBase::MF_ObjectDetectBase(const utils::InitParameter& param) : MA_TRTInferAlgoBase(param)
{
    // input
    m_input_src_device = nullptr;
    m_input_resize_device = nullptr;
    m_input_rgb_device = nullptr;
    m_input_norm_device = nullptr;
    m_input_hwc_device = nullptr;

    // Element counts are computed in std::size_t so the batch * channel * height * width
    // product cannot overflow int before it is scaled by sizeof(...).
    const std::size_t src_elems =
        static_cast<std::size_t>(param.batch_size) * 3 * param.mImage.m_height * param.mImage.m_width;
    const std::size_t dst_elems =
        static_cast<std::size_t>(param.batch_size) * 3 * param.dst_h * param.dst_w;

    checkRuntime(cudaMalloc(&m_input_src_device, src_elems * sizeof(unsigned char)));
    checkRuntime(cudaMalloc(&m_input_resize_device, dst_elems * sizeof(float)));
    checkRuntime(cudaMalloc(&m_input_rgb_device, dst_elems * sizeof(float)));
    checkRuntime(cudaMalloc(&m_input_norm_device, dst_elems * sizeof(float)));
    // Allocate exactly once: a second cudaMalloc into the same pointer would overwrite it
    // and leak the first device buffer (the destructor frees only one).
    checkRuntime(cudaMalloc(&m_input_hwc_device, dst_elems * sizeof(float)));

    // output
    m_output_src_device = nullptr;
    m_output_objects_device = nullptr;
    m_output_objects_host = nullptr;
    m_output_objects_width = 7; // floats per detected object record
    m_output_idx_device = nullptr;
    m_output_conf_device = nullptr;

    // Per batch element: 1 count slot plus topK records of m_output_objects_width floats.
    const std::size_t output_objects_size =
        static_cast<std::size_t>(param.batch_size) * (1 + param.topK * m_output_objects_width); // 1: count
    checkRuntime(cudaMalloc(&m_output_objects_device, output_objects_size * sizeof(float)));

    // NOTE(review): these two sizes read m_param (presumably the copy stored by
    // MA_TRTInferAlgoBase) while everything above reads param -- confirm they are identical.
    checkRuntime(cudaMalloc(&m_output_idx_device,
                            static_cast<std::size_t>(m_param.batch_size) * m_param.topK * sizeof(int)));
    checkRuntime(cudaMalloc(&m_output_conf_device,
                            static_cast<std::size_t>(m_param.batch_size) * m_param.topK * sizeof(float)));

    m_output_objects_host = new float[output_objects_size];

    m_objectss.resize(param.batch_size);
}
|
|
|
|
// Releases every buffer held by the object, device buffers first, then the host buffer.
// Pointers that were never allocated (m_output_src_device is not allocated by this
// class's constructor) are still nullptr, and cudaFree on a null pointer performs no
// operation. Each cudaFree result is passed through checkRuntime.
MF_ObjectDetectBase::~MF_ObjectDetectBase()
{
    /* --- device input buffers --- */
    checkRuntime(cudaFree(m_input_src_device));
    checkRuntime(cudaFree(m_input_resize_device));
    checkRuntime(cudaFree(m_input_rgb_device));
    checkRuntime(cudaFree(m_input_norm_device));
    checkRuntime(cudaFree(m_input_hwc_device));

    /* --- device output buffers --- */
    checkRuntime(cudaFree(m_output_src_device));
    checkRuntime(cudaFree(m_output_objects_device));
    checkRuntime(cudaFree(m_output_idx_device));
    checkRuntime(cudaFree(m_output_conf_device));

    /* --- host buffer allocated with new[] --- */
    delete[] m_output_objects_host;
}
|
|
|
|
// Returns a by-value copy of the stored per-batch detection results; the caller's
// copy is independent of later updates to m_objectss. The constructor sizes the
// outer vector to batch_size entries.
std::vector<std::vector<utils::Box>> MF_ObjectDetectBase::getObjectss() const
{
    std::vector<std::vector<utils::Box>> snapshot(m_objectss);
    return snapshot;
}