Program Listing for File config_validator.cpp¶
↰ Return to documentation for file (src/common/config_validator.cpp
)
#include "common/config_validator.h"
#include "common/logging.h"
#include "common/regex.h"
#include "common/utils.h"
#include "common/filesystem.h"
#include <set>
namespace marian {
bool ConfigValidator::has(const std::string& key) const {
return config_[key];
}
ConfigValidator::ConfigValidator(const YAML::Node& config)
: config_(config),
dumpConfigOnly_(config["dump-config"] && !config["dump-config"].as<std::string>().empty()
&& config["dump-config"].as<std::string>() != "false") {}
ConfigValidator::~ConfigValidator() {}
void ConfigValidator::validateOptions(cli::mode mode) const {
// clang-format off
switch(mode) {
case cli::mode::translation:
validateOptionsTranslation();
break;
case cli::mode::scoring:
validateOptionsParallelData();
validateOptionsScoring();
break;
case cli::mode::embedding:
validateOptionsParallelData();
validateOptionsScoring();
break;
case cli::mode::training:
validateOptionsParallelData();
validateOptionsTraining();
break;
default:
ABORT("wrong CLI mode");
break;
}
// clang-format on
validateModelExtension(mode);
validateDevices(mode);
}
void ConfigValidator::validateOptionsTranslation() const {
auto models = get<std::vector<std::string>>("models");
auto configs = get<std::vector<std::string>>("config");
ABORT_IF(models.empty() && configs.empty(),
"You need to provide at least one model file or a config file");
ABORT_IF(get<bool>("model-mmap") && get<size_t>("cpu-threads") == 0,
"Model MMAP is CPU-only, please use --cpu-threads");
for(const auto& modelFile : models) {
filesystem::Path modelPath(modelFile);
ABORT_IF(!filesystem::exists(modelPath), "Model file does not exist: " + modelFile);
}
auto vocabs = get<std::vector<std::string>>("vocabs");
ABORT_IF(vocabs.empty(), "Translating, but vocabularies are not given");
for(const auto& vocabFile : vocabs) {
filesystem::Path vocabPath(vocabFile);
ABORT_IF(!filesystem::exists(vocabPath), "Vocabulary file does not exist: " + vocabFile);
}
}
void ConfigValidator::validateOptionsParallelData() const {
// Do not check these constraints if only goal is to dump config
if(dumpConfigOnly_)
return;
auto trainSets = get<std::vector<std::string>>("train-sets");
ABORT_IF(trainSets.empty(), "No train sets given in config file or on command line");
auto numVocabs = get<std::vector<std::string>>("vocabs").size();
ABORT_IF(!get<bool>("tsv") && numVocabs > 0 && numVocabs != trainSets.size(),
"There should be as many vocabularies as training files");
// disallow, for example --tsv --train-sets file1.tsv file2.tsv
ABORT_IF(get<bool>("tsv") && trainSets.size() != 1,
"A single file must be provided with --train-sets (or stdin) for a tab-separated input");
// disallow, for example --train-sets stdin stdin or --train-sets stdin file.tsv
ABORT_IF(trainSets.size() > 1
&& std::any_of(trainSets.begin(),
trainSets.end(),
[](const std::string& s) { return (s == "stdin") || (s == "-"); }),
"Only one 'stdin' or '-' in --train-sets is allowed");
}
void ConfigValidator::validateOptionsScoring() const {
filesystem::Path modelPath(get<std::string>("model"));
ABORT_IF(!filesystem::exists(modelPath), "Model file does not exist: " + modelPath.string());
auto vocabs = get<std::vector<std::string>>("vocabs");
ABORT_IF(vocabs.empty(), "Scoring, but vocabularies are not given");
for(const auto& vocabFile : vocabs) {
filesystem::Path vocabPath(vocabFile);
ABORT_IF(!filesystem::exists(vocabPath), "Vocabulary file does not exist: " + vocabFile);
}
}
void ConfigValidator::validateOptionsTraining() const {
auto trainSets = get<std::vector<std::string>>("train-sets");
ABORT_IF(has("embedding-vectors")
&& get<std::vector<std::string>>("embedding-vectors").size() != trainSets.size()
&& !get<std::vector<std::string>>("embedding-vectors").empty(),
"There should be as many embedding vector files as training files");
filesystem::Path modelPath(get<std::string>("model"));
auto modelDir = modelPath.parentPath();
if(modelDir.empty())
modelDir = filesystem::currentPath();
ABORT_IF(!modelDir.empty() && !filesystem::isDirectory(modelDir),
"Model directory does not exist");
std::string errorMsg = "There should be as many validation files as training files";
if(get<bool>("tsv"))
errorMsg += ". If the training set is in the TSV format, validation sets have to also be a single TSV file";
ABORT_IF(has("valid-sets")
&& get<std::vector<std::string>>("valid-sets").size() != trainSets.size()
&& !get<std::vector<std::string>>("valid-sets").empty(),
errorMsg);
// check if --early-stopping-on has proper value
std::set<std::string> supportedStops = {"first", "all", "any"};
ABORT_IF(supportedStops.find(get<std::string>("early-stopping-on")) == supportedStops.end(),
"Supported options for --early-stopping-on are: first, all, any");
// validations for learning rate decaying
ABORT_IF(get<float>("lr-decay") > 1.f, "Learning rate decay factor greater than 1.0 is unusual");
auto strategy = get<std::string>("lr-decay-strategy");
ABORT_IF((strategy == "epoch+batches" || strategy == "epoch+stalled")
&& get<std::vector<size_t>>("lr-decay-start").size() != 2,
"Decay strategies 'epoch+batches' and 'epoch+stalled' require two values specified with "
"--lr-decay-start option");
ABORT_IF((strategy == "epoch" || strategy == "batches" || strategy == "stalled")
&& get<std::vector<size_t>>("lr-decay-start").size() != 1,
"Single decay strategies require only one value specified with --lr-decay-start option");
// validate ULR options
ABORT_IF((has("ulr") && get<bool>("ulr") && (get<std::string>("ulr-query-vectors") == ""
|| get<std::string>("ulr-keys-vectors") == "")),
"ULR requires query and keys vectors specified with --ulr-query-vectors and "
"--ulr-keys-vectors option");
// validate model quantization
size_t bits = get<size_t>("quantize-bits");
ABORT_IF(bits > 32, "Invalid quantization bits. Must be from 0 to 32 bits");
ABORT_IF(bits > 0 && !get<bool>("sync-sgd"), "Model quantization only works with synchronous training (--sync-sgd)");
}
void ConfigValidator::validateModelExtension(cli::mode mode) const {
std::vector<std::string> models;
if(mode == cli::mode::translation)
models = get<std::vector<std::string>>("models");
else
models.push_back(get<std::string>("model"));
for(const auto& modelPath : models) {
bool hasProperExt = utils::endsWith(modelPath, ".npz") || utils::endsWith(modelPath, ".bin");
ABORT_IF(!hasProperExt,
"Unknown model format for file '{}'. Supported file extensions: .npz, .bin",
modelPath);
}
}
void ConfigValidator::validateDevices(cli::mode /*mode*/) const {
std::string devices = utils::join(get<std::vector<std::string>>("devices"));
utils::trim(devices);
regex::regex pattern;
std::string help;
// valid strings: '0', '0 1 2 3', '3 2 0 1'
pattern = "[0-9]+( *[0-9]+)*";
help = "Supported formats: '0 1 2 3'";
ABORT_IF(!regex::regex_match(devices, pattern),
"the argument '{}' for option '--devices' is invalid. {}",
devices,
help);
}
} // namespace marian