/* * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "getOptions.h" #include "logger.h" #include #include #include #include #include namespace nvinfer1 { namespace utility { //! Matching for TRTOptions is defined as follows: //! //! If A and B both have longName set, A matches B if and only if A.longName == //! B.longName and (A.shortName == B.shortName if both have short name set). //! //! If A only has shortName set and B only has longName set, then A does not //! match B. It is assumed that when 2 TRTOptions are compared, one of them is //! the definition of a TRTOption in the input to getOptions. As such, if the //! definition only has shortName set, it will never be equal to a TRTOption //! that does not have shortName set (and same for longName). //! //! If A and B both have shortName set but B does not have longName set, A //! matches B if and only if A.shortName == B.shortName. //! //! If A has neither long or short name set, A matches B if and only if B has //! neither long or short name set. bool matches(const TRTOption& a, const TRTOption& b) { if (!a.longName.empty() && !b.longName.empty()) { if (a.shortName && b.shortName) { return (a.longName == b.longName) && (a.shortName == b.shortName); } return a.longName == b.longName; } // If only one of them is not set, this will return false anyway. return a.shortName == b.shortName; } //! getTRTOptionIndex returns the index of a TRTOption in a vector of //! TRTOptions, -1 if not found. int getTRTOptionIndex(const std::vector& options, const TRTOption& opt) { for (size_t i = 0; i < options.size(); ++i) { if (matches(opt, options[i])) { return i; } } return -1; } //! validateTRTOption will return a string containing an error message if options //! contain non-numeric characters, or if there are duplicate option names found. //! Otherwise, returns the empty string. std::string validateTRTOption( const std::set& seenShortNames, const std::set& seenLongNames, const TRTOption& opt) { if (opt.shortName != 0) { if (!std::isalnum(opt.shortName)) { return "Short name '" + std::to_string(opt.shortName) + "' is non-alphanumeric"; } if (seenShortNames.find(opt.shortName) != seenShortNames.end()) { return "Short name '" + std::to_string(opt.shortName) + "' is a duplicate"; } } if (!opt.longName.empty()) { for (const char& c : opt.longName) { if (!std::isalnum(c) && c != '-' && c != '_') { return "Long name '" + opt.longName + "' contains characters that are not '-', '_', or alphanumeric"; } } if (seenLongNames.find(opt.longName) != seenLongNames.end()) { return "Long name '" + opt.longName + "' is a duplicate"; } } return ""; } //! validateTRTOptions will return a string containing an error message if any //! options contain non-numeric characters, or if there are duplicate option //! names found. Otherwise, returns the empty string. std::string validateTRTOptions(const std::vector& options) { std::set seenShortNames; std::set seenLongNames; for (size_t i = 0; i < options.size(); ++i) { const std::string errMsg = validateTRTOption(seenShortNames, seenLongNames, options[i]); if (!errMsg.empty()) { return "Error '" + errMsg + "' at TRTOption " + std::to_string(i); } seenShortNames.insert(options[i].shortName); seenLongNames.insert(options[i].longName); } return ""; } //! parseArgs parses an argument list and returns a TRTParsedArgs with the //! fields set accordingly. Assumes that options is validated. //! ErrMsg will be set if: //! - an argument is null //! - an argument is empty //! - an argument does not have option (i.e. "-" and "--") //! - a short argument has more than 1 character //! - the last argument in the list requires a value TRTParsedArgs parseArgs(int argc, const char* const* argv, const std::vector& options) { TRTParsedArgs parsedArgs; parsedArgs.values.resize(options.size()); for (int i = 1; i < argc; ++i) // index of current command-line argument { if (argv[i] == nullptr) { return TRTParsedArgs{"Null argument at index " + std::to_string(i)}; } const std::string argStr(argv[i]); if (argStr.empty()) { return TRTParsedArgs{"Empty argument at index " + std::to_string(i)}; } // No starting hyphen means it is a positional argument if (argStr[0] != '-') { parsedArgs.positionalArgs.push_back(argStr); continue; } if (argStr == "-" || argStr == "--") { return TRTParsedArgs{"Argument does not specify an option at index " + std::to_string(i)}; } // If only 1 hyphen, char after is the flag. TRTOption opt{' ', "", false, ""}; std::string value; if (argStr[1] != '-') { // Must only have 1 char after the hyphen if (argStr.size() > 2) { return TRTParsedArgs{"Short arg contains more than 1 character at index " + std::to_string(i)}; } opt.shortName = argStr[1]; } else { opt.longName = argStr.substr(2); // We need to support --foo=bar syntax, so look for '=' const size_t eqIndex = opt.longName.find('='); if (eqIndex < opt.longName.size()) { value = opt.longName.substr(eqIndex + 1); opt.longName = opt.longName.substr(0, eqIndex); } } const int idx = getTRTOptionIndex(options, opt); if (idx < 0) { continue; } if (options[idx].valueRequired) { if (!value.empty()) { parsedArgs.values[idx].second.push_back(value); parsedArgs.values[idx].first = parsedArgs.values[idx].second.size(); continue; } if (i + 1 >= argc) { return TRTParsedArgs{"Last argument requires value, but none given"}; } const std::string nextArg(argv[i + 1]); if (nextArg.size() >= 1 && nextArg[0] == '-') { sample::gLogWarning << "Warning: Using '" << nextArg << "' as a value for '" << argStr << "', Should this be its own flag?" << std::endl; } parsedArgs.values[idx].second.push_back(nextArg); i += 1; // Next argument already consumed parsedArgs.values[idx].first = parsedArgs.values[idx].second.size(); } else { parsedArgs.values[idx].first += 1; } } return parsedArgs; } TRTParsedArgs getOptions(int argc, const char* const* argv, const std::vector& options) { const std::string errMsg = validateTRTOptions(options); if (!errMsg.empty()) { return TRTParsedArgs{errMsg}; } return parseArgs(argc, argv, options); } } // namespace utility } // namespace nvinfer1