#include "MF_Resnet34Infer.h"
#include "ML_Log.h"
namespace fs = std::filesystem;
using namespace samplesCommon;
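// The constructor creates the CUDA stream and allocates the pinned host input buffer plus the
// device-side input and output buffers used for one batch of batch_size x channels x dst_h x dst_w floats.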
MF_Resnet34Infer::MF_Resnet34Infer(const trtUtils::InitParameter& param) :MF_ImageClassificationBase(param)
{
checkRuntime(cudaStreamCreate(&mStream));
input_numel = m_param.batch_size * m_param.mImage.m_channels * m_param.dst_h * m_param.dst_w;
checkRuntime(cudaMallocHost(&input_data_host, input_numel * sizeof(float)));
checkRuntime(cudaMalloc(&input_data_device, input_numel * sizeof(float)));
memset(output_data_host, 0, sizeof(output_data_host)); // zero-initialize the host-side output buffer
// the classification output only needs batch_size * NUMCLASSES floats on the device
checkRuntime(cudaMalloc(&output_data_device, m_param.batch_size * NUMCLASSES * sizeof(float)));
}
MF_Resnet34Infer::~MF_Resnet34Infer()
{
// release the CUDA stream and buffers
checkRuntime(cudaStreamDestroy(mStream));
checkRuntime(cudaFreeHost(input_data_host));
checkRuntime(cudaFree(input_data_device));
checkRuntime(cudaFree(output_data_device));
}
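// Engine initialization: check that the onnx file exists, reuse a serialized .trt engine from disk
// when one is already present, otherwise build a new engine from the onnx model.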
bool MF_Resnet34Infer::initEngine(const std::string& _onnxFileName)
{
LOG_INFO("on the init engine, input onnx file : " + _onnxFileName);
std::lock_guard<std::mutex> lock(m_mutex); // RAII lock, released on every return path
// check that the input onnx file exists
fs::path onnx_path(_onnxFileName);
if (!fs::exists(onnx_path))
{
LOG_ERROR("init engine, input onnx file does not exist. \n");
return false;
}
// Replace the .onnx extension with .trt and check whether a serialized TRT engine already exists.
// If a local .trt engine is found, load it directly and deserialize the engine from it.
fs::path extension("trt");
fs::path trt_Path = onnx_path.replace_extension(extension);
// std::string trtFileName = trt_Path.string();
if (fs::exists(trt_Path))
{
LOG_INFO("trt model has existed.\n");
std::vector<uchar> engine_data;
int iRet = loadTRTModelData(trt_Path.string(), engine_data);
if (iRet != 0)
{
LOG_ERROR("load trt model failed.\n");
return false;
}
m_runtime = std::unique_ptr<nvinfer1::IRuntime>(nvinfer1::createInferRuntime(sample::gLogger.getTRTLogger()));
if (!m_runtime)
{
LOG_ERROR("on the init engine, create infer runtime failed.\n");
return false;
}
m_engine = std::unique_ptr<nvinfer1::ICudaEngine>(m_runtime->deserializeCudaEngine(engine_data.data(), engine_data.size()));
if (!m_engine)
{
LOG_ERROR("on the init engine, deserialize engine failed.\n");
return false;
}
m_context = std::unique_ptr<nvinfer1::IExecutionContext>(m_engine->createExecutionContext());
if (!m_context)
{
LOG_ERROR("on the init engine, create excution context failed.\n");
return false;
}
if (m_param.dynamic_batch)
{
m_context->setBindingDimensions(0, nvinfer1::Dims4(m_param.batch_size, m_param.mImage.m_channels, m_param.dst_h, m_param.dst_w));
}
// Pin down the input dimensions actually used for this inference configuration.
input_dims = m_context->getBindingDimensions(0);
input_dims.d[0] = m_param.batch_size;
}
else
{
// No local trt engine found, so build a new engine from the onnx file.
int iRet = this->buildModel(_onnxFileName);
if (iRet != 0)
{
LOG_ERROR("on the init engine, from onnx file build model failed.\n");
return false;
}
}
return true;
}
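// Inference entry point for raw image buffers: convert each buffer to a cv::Mat, then run the
// preProcess -> infer -> postProcess -> getDetectResult pipeline under the instance mutex.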
bool MF_Resnet34Infer::doTRTInfer(const std::vector<MN_VisionImage::MS_ImageParam>& _bufImgs, std::vector<trtUtils::MR_Result>* _detectRes, int* _user)
{
std::lock_guard<std::mutex> lock(m_mutex); // RAII lock, released on every return path
std::vector<cv::Mat> matImgs;
for (const auto& _var : _bufImgs)
{
cv::Mat image;
bool bRet = buffer2Mat(_var, image);
if (!bRet)
{
LOG_ERROR("doinfer(), convert buffer to Mat failed. \n");
return false;
}
matImgs.emplace_back(image);
}
int iRet = 0;
iRet = this->preProcess(matImgs);
if (iRet != 0)
{
LOG_ERROR("doinfer(), preprocess image failed. \n");
return false;
}
iRet = this->infer();
if (iRet != 0)
{
LOG_ERROR("doinfer(), infer failed. \n");
return false;
}
iRet = this->postProcess(matImgs);
if (iRet != 0)
{
LOG_ERROR("doinfer(), postprocess image failed. \n");
return false;
}
iRet = this->getDetectResult(*_detectRes);
if (iRet != 0)
{
LOG_ERROR("doinfer(), get detect result failed. \n");
return false;
}
return true;
}
bool MF_Resnet34Infer::doTRTInfer(const std::vector<cv::Mat>& _matImgs, std::vector<trtUtils::MR_Result>* _detectRes, int* _user)
{
std::lock_guard<std::mutex> lock(m_mutex); // RAII lock, released on every return path
int iRet = 0;
iRet = this->preProcess(_matImgs);
if (iRet != 0)
{
LOG_ERROR("doinfer(), preprocess image failed. \n");
return false;
}
iRet = this->infer();
if (iRet != 0)
{
LOG_ERROR("doinfer(), infer failed. \n");
return false;
}
iRet = this->postProcess(_matImgs);
if (iRet != 0)
{
LOG_ERROR("doinfer(), postprocess image failed. \n");
return false;
}
iRet = this->getDetectResult(*_detectRes);
if (iRet != 0)
{
LOG_ERROR("doinfer(), get detect result failed. \n");
return false;
}
return true;
}
std::string MF_Resnet34Infer::getError()
{
return "";
}
void MF_Resnet34Infer::freeMemeory()
{
return;
}
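// Build a TensorRT engine from the onnx file: create the builder, network and config, parse the
// onnx model, enable FP16, add an optimization profile for the dynamic batch dimension, serialize
// the engine to a .trt file next to the onnx model, then deserialize it for immediate use.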
int MF_Resnet34Infer::buildModel(const std::string& _onnxFile)
{
LOG_INFO("on the build model, input onnx file : " + _onnxFile);
auto builder = SampleUniquePtr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(sample::gLogger.getTRTLogger()));
if (!builder)
{
LOG_ERROR("on the build model, create infer builder failed. \n");
return 1;
}
auto explicitBatch = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
auto network = SampleUniquePtr<nvinfer1::INetworkDefinition>(builder->createNetworkV2(explicitBatch));
if (!network)
{
LOG_ERROR("on the build model, create network failed. \n");
return 1;
}
auto config = SampleUniquePtr<nvinfer1::IBuilderConfig>(builder->createBuilderConfig());
if (!config)
{
LOG_ERROR("on the build model, create builder config failed. \n");
return 1;
}
config->setFlag(nvinfer1::BuilderFlag::kFP16);
samplesCommon::enableDLA(builder.get(), config.get(), -1);
// The onnx parser fills the network definition with the parsed model.
auto parser = SampleUniquePtr<nvonnxparser::IParser>(nvonnxparser::createParser(*network, sample::gLogger.getTRTLogger()));
if (!parser->parseFromFile(_onnxFile.c_str(), static_cast<int>(sample::gLogger.getReportableSeverity())))
{
LOG_ERROR("on the build model, parser onnx file failed. \n");
for (int i = 0; i < parser->getNbErrors(); i++)
{
LOG_WARN(parser->getError(i)->desc());
}
return 1;
}
int maxBatchSize = 1;
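// Note: the line below only reports the intended workspace size; actually limiting it would need
// something like config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, 1 << 28)
// (available from TensorRT 8.4 onwards).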
printf("workspace size = %.2f MB.\n", (1 << 28) / 1024.0f / 1024.0f);
// If the model has several inputs, several profiles are required.
auto profile = builder->createOptimizationProfile();
auto input_tensor = network->getInput(0);
auto input_dims = input_tensor->getDimensions();
// Configure the minimum, optimal and maximum input shapes for the dynamic batch dimension.
input_dims.d[0] = 1; // batchsize
profile->setDimensions(input_tensor->getName(), nvinfer1::OptProfileSelector::kMIN, input_dims);
profile->setDimensions(input_tensor->getName(), nvinfer1::OptProfileSelector::kOPT, input_dims);
input_dims.d[0] = maxBatchSize; // [batchsize, channel, height, width]
profile->setDimensions(input_tensor->getName(), nvinfer1::OptProfileSelector::kMAX, input_dims);
config->addOptimizationProfile(profile);
auto profileStream = samplesCommon::makeCudaStream();
if (!profileStream)
{
LOG_ERROR("on the build model, make cuda stream failed.\n");
return 1;
}
config->setProfileStream(*profileStream);
auto engine_data = std::unique_ptr<nvinfer1::IHostMemory>(builder->buildSerializedNetwork(*network, *config));
if (!engine_data)
{
LOG_ERROR("on the build model, build engine failed.\n");
return 1;
}
// Serialize the engine and save it as a binary .trt file.
fs::path onnx_path(_onnxFile);
fs::path extension("trt");
fs::path trt_Path = onnx_path.replace_extension(extension);
FILE* fp;
errno_t err = fopen_s(&fp, trt_Path.string().c_str(), "wb");
if (err != 0)
{
LOG_ERROR("save engine data failed. \n");
return 1;
}
fwrite(engine_data->data(), 1, engine_data->size(), fp);
fclose(fp);
LOG_INFO("on the build model, save model success. \n");
m_runtime = std::unique_ptr<nvinfer1::IRuntime>(nvinfer1::createInferRuntime(sample::gLogger.getTRTLogger()));
if (!m_runtime)
{
LOG_ERROR("on the build model, create infer runtime failed.\n");
return 1;
}
m_engine = std::unique_ptr<nvinfer1::ICudaEngine>(m_runtime->deserializeCudaEngine(engine_data->data(), engine_data->size()));
if (!m_engine)
{
LOG_ERROR("on the build model, deserialize cuda engine failed.\n");
return 1;
}
m_context = std::unique_ptr<nvinfer1::IExecutionContext>(m_engine->createExecutionContext());
if (!m_context)
{
LOG_ERROR("on the build model, create excution context failed.\n");
return 1;
}
if (m_param.dynamic_batch)
{
m_context->setBindingDimensions(0, nvinfer1::Dims4(m_param.batch_size, m_param.mImage.m_channels, m_param.dst_h, m_param.dst_w));
}
// Pin down the input dimensions actually used for this inference configuration.
input_dims = m_context->getBindingDimensions(0);
input_dims.d[0] = m_param.batch_size;
LOG_INFO("build model success. \n");
return 0;
}
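// Preprocessing: resize the image to the network input size, convert the interleaved BGR pixels
// to a planar layout, scale to [0, 1], normalize with the configured mean/std values, and copy the
// result to the device through the CUDA stream.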
int MF_Resnet34Infer::preProcess(const std::vector<cv::Mat>& _imgsBatch)
{
cv::Mat image;
cv::resize(_imgsBatch[IMGS_IDX], image, cv::Size(m_param.dst_w, m_param.dst_h));
int image_size = image.rows * image.cols;
uchar* pData = image.data;
float* pHost_b = input_data_host + image_size * 0;
float* pHost_g = input_data_host + image_size * 1;
float* pHost_r = input_data_host + image_size * 2;
for (int i = 0; i < image_size; i++, pData += 3)
{
// Note: the rgb channel order is swapped here (the BGR source is written into the planes in reverse order).
*pHost_r++ = (pData[0] / 255.0f - m_param.meanVec[0]) / m_param.stdVec[0];
*pHost_g++ = (pData[1] / 255.0f - m_param.meanVec[1]) / m_param.stdVec[1];
*pHost_b++ = (pData[2] / 255.0f - m_param.meanVec[2]) / m_param.stdVec[2];
}
// Copy the preprocessed input from pinned host memory to the device.
checkRuntime(cudaMemcpyAsync(input_data_device, input_data_host, input_numel * sizeof(float), cudaMemcpyHostToDevice, mStream));
return 0;
}
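// Run one inference pass: set the input binding dimensions, enqueue the execution context with the
// input/output device buffers on the CUDA stream, and wait for it to finish.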
int MF_Resnet34Infer::infer()
{
if (m_context == nullptr)
{
LOG_ERROR("do infer, context dose not init. \n");
return 1;
}
// Set the input dimensions used for this inference call.
m_context->setBindingDimensions(0, input_dims);
float* bindings[] = { input_data_device, output_data_device };
bool bRet = m_context->enqueueV2((void**)bindings, mStream, nullptr);
if (!bRet)
{
LOG_ERROR("infer failed.\n");
return 1;
}
checkRuntime(cudaStreamSynchronize(mStream));
return 0;
}
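// Postprocessing: copy the raw network outputs back to the host, apply softmax, take the argmax
// as the predicted class, and use its probability as the confidence score.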
int MF_Resnet34Infer::postProcess(const std::vector<cv::Mat>& _imgsBatch)
{
// Copy the inference result from the device back to the host.
checkRuntime(cudaMemcpyAsync(output_data_host, output_data_device, sizeof(output_data_host), cudaMemcpyDeviceToHost, mStream));
checkRuntime(cudaStreamSynchronize(mStream));
std::vector<float> output_data_vec;
for (auto _var : output_data_host)
{
output_data_vec.emplace_back(_var);
}
std::vector<float> output_result = this->softmax(output_data_vec);
float output[NUMCLASSES] = { output_result[0], output_result[1], output_result[2], output_result[3], output_result[4] };
float* prob = output;
int predict_label = std::max_element(prob, prob + NUMCLASSES) - prob; // index of the predicted class
auto labels = m_param.class_names;
predictName = labels[predict_label];
confidence = prob[predict_label]; // probability of the predicted class
if (confidence < 0.5)
{
detectRes = trtUtils::ME_DetectRes::E_DETECT_NG;
}
else
{
detectRes = trtUtils::ME_DetectRes::E_DETECT_OK;
}
printf("Predict: %s, confidence: %.3f, label: %d. \n", predictName.c_str(), confidence, predict_label);
return 0;
}
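// Replicate the single classification result once for every image slot in the batch.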
int MF_Resnet34Infer::getDetectResult(std::vector<trtUtils::MR_Result>& _result)
{
trtUtils::MR_Result res;
for (size_t i = 0; i < m_param.batch_size; i++)
{
res.mClassifyDecRes.mDetectRes = detectRes;
res.mClassifyDecRes.mConfidence = confidence;
res.mClassifyDecRes.mLabel = predictName;
_result.emplace_back(res);
}
return 0;
}
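// Numerically stable softmax: subtracting the maximum value before exponentiation keeps exp()
// from overflowing without changing the resulting probabilities.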
std::vector<float> MF_Resnet34Infer::softmax(std::vector<float> _input)
{
std::vector<float> result{};
float total = 0;
float MAX = _input[0];
for (auto x : _input)
{
MAX = (std::max)(x, MAX);
}
for (auto x : _input)
{
total += exp(x - MAX);
}
for (auto x : _input)
{
result.emplace_back(exp(x - MAX) / total);
}
return result;
}
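// Read the label file line by line, converting each UTF-8 line to the local ANSI code page so
// that non-ASCII class names display correctly on Windows.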
std::vector<std::string> MF_Resnet34Infer::load_labels(const std::string& _labelPath)
{
std::vector<std::string> lines;
FILE* fp = nullptr;
errno_t err = fopen_s(&fp, _labelPath.c_str(), "r");
if (err != 0)
{
return std::vector<std::string>{};
}
char buf[1024];
std::string line;
while (fgets(buf, sizeof(buf), fp) != nullptr)
{
line = this->UTF8_2_GB(buf);
lines.emplace_back(line);
}
fclose(fp);
return lines;
}
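// Convert a UTF-8 string to the system ANSI code page (GB on a Chinese locale) by going through
// UTF-16: MultiByteToWideChar (UTF-8 -> wide) followed by WideCharToMultiByte (wide -> CP_ACP).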
std::string MF_Resnet34Infer::UTF8_2_GB(const char* str)
{
std::string result;
WCHAR* strSrc;
LPSTR szRes;
// query the required size of the wide-character buffer
int i = MultiByteToWideChar(CP_UTF8, 0, str, -1, NULL, 0);
strSrc = new WCHAR[i + 1];
MultiByteToWideChar(CP_UTF8, 0, str, -1, strSrc, i);
// query the required size of the ANSI (multi-byte) buffer
i = WideCharToMultiByte(CP_ACP, 0, strSrc, -1, NULL, 0, NULL, NULL);
szRes = new CHAR[i + 1];
WideCharToMultiByte(CP_ACP, 0, strSrc, -1, szRes, i, NULL, NULL);
result = szRes;
delete[] strSrc;
delete[] szRes;
return result;
}