/* * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef TENSORRT_SAFE_COMMON_H #define TENSORRT_SAFE_COMMON_H #include "cuda_runtime.h" #include "NvInferRuntimeCommon.h" #include #include #include #include #include #include // For safeLoadLibrary #ifdef _MSC_VER // Needed so that the max/min definitions in windows.h do not conflict with std::max/min. #define NOMINMAX #include #undef NOMINMAX #else #include #endif #undef CHECK #define CHECK(status) \ do \ { \ auto ret = (status); \ if (ret != 0) \ { \ std::cerr << "Cuda failure: " << ret << std::endl; \ abort(); \ } \ } while (0) #undef SAFE_ASSERT #define SAFE_ASSERT(condition) \ do \ { \ if (!(condition)) \ { \ std::cerr << "Assertion failure: " << #condition << std::endl; \ abort(); \ } \ } while (0) namespace samplesCommon { template inline std::shared_ptr infer_object(T* obj) { if (!obj) { throw std::runtime_error("Failed to create object"); } return std::shared_ptr(obj); } inline uint32_t elementSize(nvinfer1::DataType t) { switch (t) { case nvinfer1::DataType::kINT32: case nvinfer1::DataType::kFLOAT: return 4; case nvinfer1::DataType::kHALF: return 2; case nvinfer1::DataType::kINT8: return 1; case nvinfer1::DataType::kUINT8: return 1; case nvinfer1::DataType::kBOOL: return 1; case nvinfer1::DataType::kFP8: return 1; } return 0; } template inline A divUp(A x, B n) { return (x + n - 1) / n; } inline int64_t volume(nvinfer1::Dims const& d) { return std::accumulate(d.d, d.d + d.nbDims, int64_t{1}, std::multiplies{}); } // Return m rounded up to nearest multiple of n template inline T roundUp(T m, T n) { return ((m + n - 1) / n) * n; } //! comps is the number of components in a vector. Ignored if vecDim < 0. inline int64_t volume(nvinfer1::Dims dims, int32_t vecDim, int32_t comps, int32_t batch) { if (vecDim >= 0) { dims.d[vecDim] = roundUp(dims.d[vecDim], comps); } return samplesCommon::volume(dims) * std::max(batch, 1); } //! //! \class TrtCudaGraphSafe //! \brief Managed CUDA graph //! class TrtCudaGraphSafe { public: explicit TrtCudaGraphSafe() = default; TrtCudaGraphSafe(const TrtCudaGraphSafe&) = delete; TrtCudaGraphSafe& operator=(const TrtCudaGraphSafe&) = delete; TrtCudaGraphSafe(TrtCudaGraphSafe&&) = delete; TrtCudaGraphSafe& operator=(TrtCudaGraphSafe&&) = delete; ~TrtCudaGraphSafe() { if (mGraphExec) { cudaGraphExecDestroy(mGraphExec); } } void beginCapture(cudaStream_t& stream) { // cudaStreamCaptureModeGlobal is the only allowed mode in SAFE CUDA CHECK(cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal)); } bool launch(cudaStream_t& stream) { return cudaGraphLaunch(mGraphExec, stream) == cudaSuccess; } void endCapture(cudaStream_t& stream) { CHECK(cudaStreamEndCapture(stream, &mGraph)); CHECK(cudaGraphInstantiate(&mGraphExec, mGraph, nullptr, nullptr, 0)); CHECK(cudaGraphDestroy(mGraph)); } void endCaptureOnError(cudaStream_t& stream) { // There are two possibilities why stream capture would fail: // (1) stream is in cudaErrorStreamCaptureInvalidated state. // (2) TRT reports a failure. // In case (1), the returning mGraph should be nullptr. // In case (2), the returning mGraph is not nullptr, but it should not be used. const auto ret = cudaStreamEndCapture(stream, &mGraph); if (ret == cudaErrorStreamCaptureInvalidated) { SAFE_ASSERT(mGraph == nullptr); } else { SAFE_ASSERT(ret == cudaSuccess); SAFE_ASSERT(mGraph != nullptr); CHECK(cudaGraphDestroy(mGraph)); mGraph = nullptr; } // Clean up any CUDA error. cudaGetLastError(); sample::gLogError << "The CUDA graph capture on the stream has failed." << std::endl; } private: cudaGraph_t mGraph{}; cudaGraphExec_t mGraphExec{}; }; inline void safeLoadLibrary(const std::string& path) { #ifdef _MSC_VER void* handle = LoadLibrary(path.c_str()); #else int32_t flags{RTLD_LAZY}; void* handle = dlopen(path.c_str(), flags); #endif if (handle == nullptr) { #ifdef _MSC_VER sample::gLogError << "Could not load plugin library: " << path << std::endl; #else sample::gLogError << "Could not load plugin library: " << path << ", due to: " << dlerror() << std::endl; #endif } } inline std::vector safeSplitString(std::string str, char delimiter = ',') { std::vector splitVect; std::stringstream ss(str); std::string substr; while (ss.good()) { getline(ss, substr, delimiter); splitVect.emplace_back(std::move(substr)); } return splitVect; } } // namespace samplesCommon #endif // TENSORRT_SAFE_COMMON_H