You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
249 lines
8.0 KiB
C++
249 lines
8.0 KiB
C++
/*
|
|
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
* SPDX-License-Identifier: Apache-2.0
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#include "getOptions.h"
|
|
#include "logger.h"
|
|
|
|
#include <algorithm>
|
|
#include <cassert>
|
|
#include <cctype>
|
|
#include <cstring>
|
|
#include <set>
|
|
|
|
namespace nvinfer1
|
|
{
|
|
namespace utility
|
|
{
|
|
|
|
//! Matching for TRTOptions is defined as follows:
|
|
//!
|
|
//! If A and B both have longName set, A matches B if and only if A.longName ==
|
|
//! B.longName and (A.shortName == B.shortName if both have short name set).
|
|
//!
|
|
//! If A only has shortName set and B only has longName set, then A does not
|
|
//! match B. It is assumed that when 2 TRTOptions are compared, one of them is
|
|
//! the definition of a TRTOption in the input to getOptions. As such, if the
|
|
//! definition only has shortName set, it will never be equal to a TRTOption
|
|
//! that does not have shortName set (and same for longName).
|
|
//!
|
|
//! If A and B both have shortName set but B does not have longName set, A
|
|
//! matches B if and only if A.shortName == B.shortName.
|
|
//!
|
|
//! If A has neither long or short name set, A matches B if and only if B has
|
|
//! neither long or short name set.
|
|
bool matches(const TRTOption& a, const TRTOption& b)
|
|
{
|
|
if (!a.longName.empty() && !b.longName.empty())
|
|
{
|
|
if (a.shortName && b.shortName)
|
|
{
|
|
return (a.longName == b.longName) && (a.shortName == b.shortName);
|
|
}
|
|
return a.longName == b.longName;
|
|
}
|
|
|
|
// If only one of them is not set, this will return false anyway.
|
|
return a.shortName == b.shortName;
|
|
}
|
|
|
|
//! getTRTOptionIndex returns the index of a TRTOption in a vector of
|
|
//! TRTOptions, -1 if not found.
|
|
int getTRTOptionIndex(const std::vector<TRTOption>& options, const TRTOption& opt)
|
|
{
|
|
for (size_t i = 0; i < options.size(); ++i)
|
|
{
|
|
if (matches(opt, options[i]))
|
|
{
|
|
return i;
|
|
}
|
|
}
|
|
return -1;
|
|
}
|
|
|
|
//! validateTRTOption will return a string containing an error message if options
|
|
//! contain non-numeric characters, or if there are duplicate option names found.
|
|
//! Otherwise, returns the empty string.
|
|
std::string validateTRTOption(
|
|
const std::set<char>& seenShortNames, const std::set<std::string>& seenLongNames, const TRTOption& opt)
|
|
{
|
|
if (opt.shortName != 0)
|
|
{
|
|
if (!std::isalnum(opt.shortName))
|
|
{
|
|
return "Short name '" + std::to_string(opt.shortName) + "' is non-alphanumeric";
|
|
}
|
|
|
|
if (seenShortNames.find(opt.shortName) != seenShortNames.end())
|
|
{
|
|
return "Short name '" + std::to_string(opt.shortName) + "' is a duplicate";
|
|
}
|
|
}
|
|
|
|
if (!opt.longName.empty())
|
|
{
|
|
for (const char& c : opt.longName)
|
|
{
|
|
if (!std::isalnum(c) && c != '-' && c != '_')
|
|
{
|
|
return "Long name '" + opt.longName + "' contains characters that are not '-', '_', or alphanumeric";
|
|
}
|
|
}
|
|
|
|
if (seenLongNames.find(opt.longName) != seenLongNames.end())
|
|
{
|
|
return "Long name '" + opt.longName + "' is a duplicate";
|
|
}
|
|
}
|
|
return "";
|
|
}
|
|
|
|
//! validateTRTOptions will return a string containing an error message if any
|
|
//! options contain non-numeric characters, or if there are duplicate option
|
|
//! names found. Otherwise, returns the empty string.
|
|
std::string validateTRTOptions(const std::vector<TRTOption>& options)
|
|
{
|
|
std::set<char> seenShortNames;
|
|
std::set<std::string> seenLongNames;
|
|
for (size_t i = 0; i < options.size(); ++i)
|
|
{
|
|
const std::string errMsg = validateTRTOption(seenShortNames, seenLongNames, options[i]);
|
|
if (!errMsg.empty())
|
|
{
|
|
return "Error '" + errMsg + "' at TRTOption " + std::to_string(i);
|
|
}
|
|
|
|
seenShortNames.insert(options[i].shortName);
|
|
seenLongNames.insert(options[i].longName);
|
|
}
|
|
return "";
|
|
}
|
|
|
|
//! parseArgs parses an argument list and returns a TRTParsedArgs with the
|
|
//! fields set accordingly. Assumes that options is validated.
|
|
//! ErrMsg will be set if:
|
|
//! - an argument is null
|
|
//! - an argument is empty
|
|
//! - an argument does not have option (i.e. "-" and "--")
|
|
//! - a short argument has more than 1 character
|
|
//! - the last argument in the list requires a value
|
|
TRTParsedArgs parseArgs(int argc, const char* const* argv, const std::vector<TRTOption>& options)
|
|
{
|
|
TRTParsedArgs parsedArgs;
|
|
parsedArgs.values.resize(options.size());
|
|
|
|
for (int i = 1; i < argc; ++i) // index of current command-line argument
|
|
{
|
|
if (argv[i] == nullptr)
|
|
{
|
|
return TRTParsedArgs{"Null argument at index " + std::to_string(i)};
|
|
}
|
|
|
|
const std::string argStr(argv[i]);
|
|
if (argStr.empty())
|
|
{
|
|
return TRTParsedArgs{"Empty argument at index " + std::to_string(i)};
|
|
}
|
|
|
|
// No starting hyphen means it is a positional argument
|
|
if (argStr[0] != '-')
|
|
{
|
|
parsedArgs.positionalArgs.push_back(argStr);
|
|
continue;
|
|
}
|
|
|
|
if (argStr == "-" || argStr == "--")
|
|
{
|
|
return TRTParsedArgs{"Argument does not specify an option at index " + std::to_string(i)};
|
|
}
|
|
|
|
// If only 1 hyphen, char after is the flag.
|
|
TRTOption opt{' ', "", false, ""};
|
|
std::string value;
|
|
if (argStr[1] != '-')
|
|
{
|
|
// Must only have 1 char after the hyphen
|
|
if (argStr.size() > 2)
|
|
{
|
|
return TRTParsedArgs{"Short arg contains more than 1 character at index " + std::to_string(i)};
|
|
}
|
|
opt.shortName = argStr[1];
|
|
}
|
|
else
|
|
{
|
|
opt.longName = argStr.substr(2);
|
|
|
|
// We need to support --foo=bar syntax, so look for '='
|
|
const size_t eqIndex = opt.longName.find('=');
|
|
if (eqIndex < opt.longName.size())
|
|
{
|
|
value = opt.longName.substr(eqIndex + 1);
|
|
opt.longName = opt.longName.substr(0, eqIndex);
|
|
}
|
|
}
|
|
|
|
const int idx = getTRTOptionIndex(options, opt);
|
|
if (idx < 0)
|
|
{
|
|
continue;
|
|
}
|
|
|
|
if (options[idx].valueRequired)
|
|
{
|
|
if (!value.empty())
|
|
{
|
|
parsedArgs.values[idx].second.push_back(value);
|
|
parsedArgs.values[idx].first = parsedArgs.values[idx].second.size();
|
|
continue;
|
|
}
|
|
|
|
if (i + 1 >= argc)
|
|
{
|
|
return TRTParsedArgs{"Last argument requires value, but none given"};
|
|
}
|
|
|
|
const std::string nextArg(argv[i + 1]);
|
|
if (nextArg.size() >= 1 && nextArg[0] == '-')
|
|
{
|
|
sample::gLogWarning << "Warning: Using '" << nextArg << "' as a value for '" << argStr
|
|
<< "', Should this be its own flag?" << std::endl;
|
|
}
|
|
|
|
parsedArgs.values[idx].second.push_back(nextArg);
|
|
i += 1; // Next argument already consumed
|
|
|
|
parsedArgs.values[idx].first = parsedArgs.values[idx].second.size();
|
|
}
|
|
else
|
|
{
|
|
parsedArgs.values[idx].first += 1;
|
|
}
|
|
}
|
|
return parsedArgs;
|
|
}
|
|
|
|
//! getOptions validates the option definitions in \p options and, when they
//! are well formed, parses \p argv against them with parseArgs. When
//! validation fails, returns a TRTParsedArgs constructed from the validation
//! error message without parsing any arguments.
TRTParsedArgs getOptions(int argc, const char* const* argv, const std::vector<TRTOption>& options)
{
    const std::string validationError = validateTRTOptions(options);
    return validationError.empty() ? parseArgs(argc, argv, options) : TRTParsedArgs{validationError};
}
|
|
} // namespace utility
|
|
} // namespace nvinfer1
|