You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
137 lines
4.1 KiB
C
137 lines
4.1 KiB
C
1 year ago
|
/*
|
||
|
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||
|
* SPDX-License-Identifier: Apache-2.0
|
||
|
*
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
*/
|
||
|
|
||
|
#ifndef ENTROPY_CALIBRATOR_H
|
||
|
#define ENTROPY_CALIBRATOR_H
|
||
|
|
||
|
#include "BatchStream.h"
#include "NvInfer.h"

#include <algorithm>
#include <cstring>
#include <fstream>
#include <iterator>
#include <memory>
#include <string>
#include <vector>
|
||
|
|
||
|
//! \class EntropyCalibratorImpl
|
||
|
//!
|
||
|
//! \brief Implements common functionality for Entropy calibrators.
|
||
|
//!
|
||
|
template <typename TBatchStream>
|
||
|
class EntropyCalibratorImpl
|
||
|
{
|
||
|
public:
|
||
|
EntropyCalibratorImpl(TBatchStream const& stream, int firstBatch, std::string const& networkName,
|
||
|
const char* inputBlobName, bool readCache = true)
|
||
|
: mStream{stream}
|
||
|
, mCalibrationTableName("CalibrationTable" + networkName)
|
||
|
, mInputBlobName(inputBlobName)
|
||
|
, mReadCache(readCache)
|
||
|
{
|
||
|
nvinfer1::Dims dims = mStream.getDims();
|
||
|
mInputCount = samplesCommon::volume(dims);
|
||
|
CHECK(cudaMalloc(&mDeviceInput, mInputCount * sizeof(float)));
|
||
|
mStream.reset(firstBatch);
|
||
|
}
|
||
|
|
||
|
virtual ~EntropyCalibratorImpl()
|
||
|
{
|
||
|
CHECK(cudaFree(mDeviceInput));
|
||
|
}
|
||
|
|
||
|
int getBatchSize() const noexcept
|
||
|
{
|
||
|
return mStream.getBatchSize();
|
||
|
}
|
||
|
|
||
|
bool getBatch(void* bindings[], const char* names[], int nbBindings) noexcept
|
||
|
{
|
||
|
if (!mStream.next())
|
||
|
{
|
||
|
return false;
|
||
|
}
|
||
|
CHECK(cudaMemcpy(mDeviceInput, mStream.getBatch(), mInputCount * sizeof(float), cudaMemcpyHostToDevice));
|
||
|
ASSERT(!strcmp(names[0], mInputBlobName));
|
||
|
bindings[0] = mDeviceInput;
|
||
|
return true;
|
||
|
}
|
||
|
|
||
|
const void* readCalibrationCache(size_t& length) noexcept
|
||
|
{
|
||
|
mCalibrationCache.clear();
|
||
|
std::ifstream input(mCalibrationTableName, std::ios::binary);
|
||
|
input >> std::noskipws;
|
||
|
if (mReadCache && input.good())
|
||
|
{
|
||
|
std::copy(std::istream_iterator<char>(input), std::istream_iterator<char>(),
|
||
|
std::back_inserter(mCalibrationCache));
|
||
|
}
|
||
|
length = mCalibrationCache.size();
|
||
|
return length ? mCalibrationCache.data() : nullptr;
|
||
|
}
|
||
|
|
||
|
void writeCalibrationCache(const void* cache, size_t length) noexcept
|
||
|
{
|
||
|
std::ofstream output(mCalibrationTableName, std::ios::binary);
|
||
|
output.write(reinterpret_cast<const char*>(cache), length);
|
||
|
}
|
||
|
|
||
|
private:
|
||
|
TBatchStream mStream;
|
||
|
size_t mInputCount;
|
||
|
std::string mCalibrationTableName;
|
||
|
const char* mInputBlobName;
|
||
|
bool mReadCache{true};
|
||
|
void* mDeviceInput{nullptr};
|
||
|
std::vector<char> mCalibrationCache;
|
||
|
};
|
||
|
|
||
|
//! \class Int8EntropyCalibrator2
|
||
|
//!
|
||
|
//! \brief Implements Entropy calibrator 2.
|
||
|
//! CalibrationAlgoType is kENTROPY_CALIBRATION_2.
|
||
|
//!
|
||
|
template <typename TBatchStream>
|
||
|
class Int8EntropyCalibrator2 : public nvinfer1::IInt8EntropyCalibrator2
|
||
|
{
|
||
|
public:
|
||
|
Int8EntropyCalibrator2(TBatchStream const& stream, int32_t firstBatch, const char* networkName,
|
||
|
const char* inputBlobName, bool readCache = true)
|
||
|
: mImpl(stream, firstBatch, networkName, inputBlobName, readCache)
|
||
|
{
|
||
|
}
|
||
|
|
||
|
int getBatchSize() const noexcept override
|
||
|
{
|
||
|
return mImpl.getBatchSize();
|
||
|
}
|
||
|
|
||
|
bool getBatch(void* bindings[], const char* names[], int nbBindings) noexcept override
|
||
|
{
|
||
|
return mImpl.getBatch(bindings, names, nbBindings);
|
||
|
}
|
||
|
|
||
|
const void* readCalibrationCache(size_t& length) noexcept override
|
||
|
{
|
||
|
return mImpl.readCalibrationCache(length);
|
||
|
}
|
||
|
|
||
|
void writeCalibrationCache(const void* cache, size_t length) noexcept override
|
||
|
{
|
||
|
mImpl.writeCalibrationCache(cache, length);
|
||
|
}
|
||
|
|
||
|
private:
|
||
|
EntropyCalibratorImpl<TBatchStream> mImpl;
|
||
|
};
|
||
|
|
||
|
#endif // ENTROPY_CALIBRATOR_H
|