/* * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef ENTROPY_CALIBRATOR_H #define ENTROPY_CALIBRATOR_H #include "BatchStream.h" #include "NvInfer.h" //! \class EntropyCalibratorImpl //! //! \brief Implements common functionality for Entropy calibrators. //! template class EntropyCalibratorImpl { public: EntropyCalibratorImpl(TBatchStream const& stream, int firstBatch, std::string const& networkName, const char* inputBlobName, bool readCache = true) : mStream{stream} , mCalibrationTableName("CalibrationTable" + networkName) , mInputBlobName(inputBlobName) , mReadCache(readCache) { nvinfer1::Dims dims = mStream.getDims(); mInputCount = samplesCommon::volume(dims); CHECK(cudaMalloc(&mDeviceInput, mInputCount * sizeof(float))); mStream.reset(firstBatch); } virtual ~EntropyCalibratorImpl() { CHECK(cudaFree(mDeviceInput)); } int getBatchSize() const noexcept { return mStream.getBatchSize(); } bool getBatch(void* bindings[], const char* names[], int nbBindings) noexcept { if (!mStream.next()) { return false; } CHECK(cudaMemcpy(mDeviceInput, mStream.getBatch(), mInputCount * sizeof(float), cudaMemcpyHostToDevice)); ASSERT(!strcmp(names[0], mInputBlobName)); bindings[0] = mDeviceInput; return true; } const void* readCalibrationCache(size_t& length) noexcept { mCalibrationCache.clear(); std::ifstream input(mCalibrationTableName, std::ios::binary); input >> std::noskipws; if (mReadCache && input.good()) { std::copy(std::istream_iterator(input), std::istream_iterator(), std::back_inserter(mCalibrationCache)); } length = mCalibrationCache.size(); return length ? mCalibrationCache.data() : nullptr; } void writeCalibrationCache(const void* cache, size_t length) noexcept { std::ofstream output(mCalibrationTableName, std::ios::binary); output.write(reinterpret_cast(cache), length); } private: TBatchStream mStream; size_t mInputCount; std::string mCalibrationTableName; const char* mInputBlobName; bool mReadCache{true}; void* mDeviceInput{nullptr}; std::vector mCalibrationCache; }; //! \class Int8EntropyCalibrator2 //! //! \brief Implements Entropy calibrator 2. //! CalibrationAlgoType is kENTROPY_CALIBRATION_2. //! template class Int8EntropyCalibrator2 : public nvinfer1::IInt8EntropyCalibrator2 { public: Int8EntropyCalibrator2(TBatchStream const& stream, int32_t firstBatch, const char* networkName, const char* inputBlobName, bool readCache = true) : mImpl(stream, firstBatch, networkName, inputBlobName, readCache) { } int getBatchSize() const noexcept override { return mImpl.getBatchSize(); } bool getBatch(void* bindings[], const char* names[], int nbBindings) noexcept override { return mImpl.getBatch(bindings, names, nbBindings); } const void* readCalibrationCache(size_t& length) noexcept override { return mImpl.readCalibrationCache(length); } void writeCalibrationCache(const void* cache, size_t length) noexcept override { mImpl.writeCalibrationCache(cache, length); } private: EntropyCalibratorImpl mImpl; }; #endif // ENTROPY_CALIBRATOR_H