Program Listing for File validator.cpp¶
↰ Return to documentation for file (src/training/validator.cpp
)
#include "training/validator.h"
namespace marian {
std::vector<Ptr<ValidatorBase/*<data::Corpus>*/>> Validators(
std::vector<Ptr<Vocab>> vocabs,
Ptr<Options> config) {
std::vector<Ptr<ValidatorBase/*<data::Corpus>*/>> validators;
auto validMetrics = config->get<std::vector<std::string>>("valid-metrics");
std::vector<std::string> ceMetrics
= {"cross-entropy", "ce-mean", "ce-sum", "ce-mean-words", "perplexity"};
for(auto metric : validMetrics) {
if(std::find(ceMetrics.begin(), ceMetrics.end(), metric) != ceMetrics.end()) {
Ptr<Options> opts = New<Options>(*config);
opts->set("cost-type", metric);
auto validator = New<CrossEntropyValidator>(vocabs, opts);
validators.push_back(validator);
} else if(metric == "valid-script") {
auto validator = New<ScriptValidator>(vocabs, config);
validators.push_back(validator);
} else if(metric == "translation") {
auto validator = New<TranslationValidator>(vocabs, config);
validators.push_back(validator);
} else if(metric == "bleu" || metric == "bleu-detok" || metric == "bleu-segmented" || metric == "chrf") {
auto validator = New<SacreBleuValidator>(vocabs, config, metric);
validators.push_back(validator);
} else if(metric == "accuracy") {
auto validator = New<AccuracyValidator>(vocabs, config);
validators.push_back(validator);
} else if(metric == "bert-lm-accuracy") {
auto validator = New<BertAccuracyValidator>(vocabs, config, true);
validators.push_back(validator);
} else if(metric == "bert-sentence-accuracy") {
auto validator = New<BertAccuracyValidator>(vocabs, config, false);
validators.push_back(validator);
} else {
ABORT("Unknown validation metric: {}", metric);
}
}
return validators;
}
float ValidatorBase::initScore() {
return lowerIsBetter_ ? std::numeric_limits<float>::max() : std::numeric_limits<float>::lowest();
}
void ValidatorBase::actAfterLoaded(TrainingState& state) {
if(state.validators[type()]) {
lastBest_ = state.validators[type()]["last-best"].as<float>();
stalled_ = state.validators[type()]["stalled"].as<size_t>();
}
}
CrossEntropyValidator::CrossEntropyValidator(std::vector<Ptr<Vocab>> vocabs, Ptr<Options> options)
: Validator(vocabs, options) {
createBatchGenerator(/*isTranslating=*/false);
auto opts = options_->with("inference",
true, // @TODO: check if required
"cost-type",
"ce-sum");
// @TODO: remove, only used for saving?
builder_ = models::createCriterionFunctionFromOptions(opts, models::usage::scoring);
}
float CrossEntropyValidator::validateBG(const std::vector<Ptr<ExpressionGraph>>& graphs) {
auto ctype = options_->get<std::string>("cost-type");
// @TODO: use with(...) everywhere, this will help with creating immutable options.
// Make options const everywhere and get rid of "set"?
auto opts = options_->with("inference", true, "cost-type", "ce-sum");
StaticLoss loss;
size_t samples = 0;
std::deque<Ptr<ExpressionGraph>> graphQueue(graphs.begin(), graphs.end());
auto task = [=, &loss, &samples, &graphQueue](BatchPtr batch) {
thread_local Ptr<ExpressionGraph> graph;
if(!graph) {
std::unique_lock<std::mutex> lock(mutex_);
ABORT_IF(graphQueue.empty(), "Asking for graph, but none left on queue");
graph = graphQueue.front();
graphQueue.pop_front();
}
auto builder = models::createCriterionFunctionFromOptions(options_, models::usage::scoring);
builder->clear(graph);
auto dynamicLoss = builder->build(graph, batch);
graph->forward();
std::unique_lock<std::mutex> lock(mutex_);
loss += *dynamicLoss;
samples += batch->size();
};
{
threadPool_.reserve(graphs.size());
TaskBarrier taskBarrier;
for(auto batch : *batchGenerator_)
taskBarrier.push_back(threadPool_.enqueue(task, batch));
// ~TaskBarrier waits until all are done
}
if(ctype == "perplexity")
return std::exp(loss.loss / loss.count);
if(ctype == "ce-mean-words")
return loss.loss / loss.count;
if(ctype == "ce-sum")
return loss.loss;
else
return loss.loss / samples; // @TODO: back-compat, to be removed
}
AccuracyValidator::AccuracyValidator(std::vector<Ptr<Vocab>> vocabs, Ptr<Options> options)
: Validator(vocabs, options, /*lowerIsBetter=*/false) {
createBatchGenerator(/*isTranslating=*/false);
// @TODO: remove, only used for saving?
builder_ = models::createModelFromOptions(options_, models::usage::raw);
}
float AccuracyValidator::validateBG(const std::vector<Ptr<ExpressionGraph>>& graphs) {
size_t correct = 0;
size_t totalLabels = 0;
std::deque<Ptr<ExpressionGraph>> graphQueue(graphs.begin(), graphs.end());
auto task = [=, &correct, &totalLabels, &graphQueue](BatchPtr batch) {
thread_local Ptr<ExpressionGraph> graph;
if(!graph) {
std::unique_lock<std::mutex> lock(mutex_);
ABORT_IF(graphQueue.empty(), "Asking for graph, but none left on queue");
graph = graphQueue.front();
graphQueue.pop_front();
}
auto builder = models::createModelFromOptions(options_, models::usage::raw);
builder->clear(graph);
Expr logits = builder->build(graph, batch).getLogits();
graph->forward();
std::vector<float> vLogits;
logits->val()->get(vLogits);
const auto& groundTruth = batch->back()->data();
IndexType cols = logits->shape()[-1];
size_t thisCorrect = 0;
size_t thisLabels = groundTruth.size();
for(int i = 0; i < thisLabels; ++i) {
// CPU-side Argmax
Word bestWord = Word::NONE;
float bestValue = std::numeric_limits<float>::lowest();
for(IndexType j = 0; j < cols; ++j) {
float currValue = vLogits[i * cols + j];
if(currValue > bestValue) {
bestValue = currValue;
bestWord = Word::fromWordIndex(j);
}
}
thisCorrect += (size_t)(bestWord == groundTruth[i]);
}
std::unique_lock<std::mutex> lock(mutex_);
totalLabels += thisLabels;
correct += thisCorrect;
};
{
threadPool_.reserve(graphs.size());
TaskBarrier taskBarrier;
for(auto batch : *batchGenerator_)
taskBarrier.push_back(threadPool_.enqueue(task, batch));
// ~TaskBarrier waits until all are done
}
return (float)correct / (float)totalLabels;
}
BertAccuracyValidator::BertAccuracyValidator(std::vector<Ptr<Vocab>> vocabs,
Ptr<Options> options,
bool evalMaskedLM)
: Validator(vocabs, options, /*lowerIsBetter=*/false), evalMaskedLM_(evalMaskedLM) {
createBatchGenerator(/*isTranslating=*/false);
// @TODO: remove, only used for saving?
builder_ = models::createModelFromOptions(options_, models::usage::raw);
}
float BertAccuracyValidator::validateBG(const std::vector<Ptr<ExpressionGraph>>& graphs) {
size_t correct = 0;
size_t totalLabels = 0;
size_t batchId = 0;
std::deque<Ptr<ExpressionGraph>> graphQueue(graphs.begin(), graphs.end());
auto task = [=, &correct, &totalLabels, &graphQueue](BatchPtr batch, size_t batchId) {
thread_local Ptr<ExpressionGraph> graph;
if(!graph) {
std::unique_lock<std::mutex> lock(mutex_);
ABORT_IF(graphQueue.empty(), "Asking for graph, but none left on queue");
graph = graphQueue.front();
graphQueue.pop_front();
}
auto builder = models::createModelFromOptions(options_, models::usage::raw);
thread_local std::unique_ptr<std::mt19937> engine;
if(!engine)
engine.reset(new std::mt19937((unsigned int)(Config::seed + batchId)));
auto bertBatch = New<data::BertBatch>(batch,
*engine,
options_->get<float>("bert-masking-fraction"),
options_->get<std::string>("bert-mask-symbol"),
options_->get<std::string>("bert-sep-symbol"),
options_->get<std::string>("bert-class-symbol"),
options_->get<int>("bert-type-vocab-size"));
builder->clear(graph);
auto classifierStates
= std::dynamic_pointer_cast<BertEncoderClassifier>(builder)->apply(graph, bertBatch, true);
graph->forward();
auto maskedLMLogits = classifierStates[0]->getLogProbs();
const auto& maskedLMLabels = bertBatch->bertMaskedWords();
auto sentenceLogits = classifierStates[1]->getLogProbs();
const auto& sentenceLabels = bertBatch->back()->data();
auto count = [=, &correct, &totalLabels](Expr logits, const Words& labels) {
IndexType cols = logits->shape()[-1];
size_t thisCorrect = 0;
size_t thisLabels = labels.size();
std::vector<float> vLogits;
logits->val()->get(vLogits);
for(int i = 0; i < thisLabels; ++i) {
// CPU-side Argmax
IndexType bestIndex = 0;
float bestValue = std::numeric_limits<float>::lowest();
for(IndexType j = 0; j < cols; ++j) {
float currValue = vLogits[i * cols + j];
if(currValue > bestValue) {
bestValue = currValue;
bestIndex = j;
}
}
thisCorrect += (size_t)(bestIndex == labels[i].toWordIndex());
}
std::unique_lock<std::mutex> lock(mutex_);
totalLabels += thisLabels;
correct += thisCorrect;
};
if(evalMaskedLM_)
count(maskedLMLogits, maskedLMLabels);
else
count(sentenceLogits, sentenceLabels);
};
{
threadPool_.reserve(graphs.size());
TaskBarrier taskBarrier;
for(auto batch : *batchGenerator_) {
taskBarrier.push_back(threadPool_.enqueue(task, batch, batchId));
batchId++;
}
// ~TaskBarrier waits until all are done
}
return (float)correct / (float)totalLabels;
}
ScriptValidator::ScriptValidator(std::vector<Ptr<Vocab>> vocabs, Ptr<Options> options)
: Validator(vocabs, options, false) {
// @TODO: remove, only used for saving?
builder_ = models::createModelFromOptions(options_, models::usage::raw);
ABORT_IF(!options_->hasAndNotEmpty("valid-script-path"),
"valid-script metric but no script given");
}
float ScriptValidator::validate(const std::vector<Ptr<ExpressionGraph>>& graphs,
Ptr<const TrainingState> /*ignored*/) {
using namespace data;
auto model = options_->get<std::string>("model");
std::string suffix = model.substr(model.size() - 4);
ABORT_IF(suffix != ".npz" && suffix != ".bin", "Unknown model suffix {}", suffix);
builder_->save(graphs[0], model + ".dev" + suffix, true);
auto valStr = utils::exec(options_->get<std::string>("valid-script-path"),
options_->get<std::vector<std::string>>("valid-script-args"));
float val = (float)std::atof(valStr.c_str());
updateStalled(graphs, val);
return val;
}
TranslationValidator::TranslationValidator(std::vector<Ptr<Vocab>> vocabs, Ptr<Options> options)
: Validator(vocabs, options, false), quiet_(options_->get<bool>("quiet-translation")) {
// @TODO: remove, only used for saving?
builder_ = models::createModelFromOptions(options_, models::usage::translation);
if(!options_->hasAndNotEmpty("valid-script-path"))
LOG_VALID(warn, "No post-processing script given for validating translator");
createBatchGenerator(/*isTranslating=*/true);
}
float TranslationValidator::validate(const std::vector<Ptr<ExpressionGraph>>& graphs,
Ptr<const TrainingState> state) {
using namespace data;
// Generate batches
batchGenerator_->prepare();
// Create scorer
auto model = options_->get<std::string>("model");
std::vector<Ptr<Scorer>> scorers;
for(auto graph : graphs) {
auto builder = models::createModelFromOptions(options_, models::usage::translation);
Ptr<Scorer> scorer = New<ScorerWrapper>(builder, "", 1.0f, model);
scorers.push_back(scorer); // @TODO: should this be done in the contructor?
}
// Set up output file
std::string fileName;
Ptr<io::TemporaryFile> tempFile;
if(options_->hasAndNotEmpty("valid-translation-output")) {
fileName = options_->get<std::string>("valid-translation-output");
// fileName can be a template with fields for training state parameters:
fileName = state->fillTemplate(fileName);
} else {
tempFile.reset(new io::TemporaryFile(options_->get<std::string>("tempdir"), false));
fileName = tempFile->getFileName();
}
for(auto graph : graphs)
graph->setInference(true);
if(!quiet_)
LOG(info, "Translating validation set...");
timer::Timer timer;
{
auto printer = New<OutputPrinter>(options_, vocabs_.back());
// @TODO: This can be simplified. If there is no "valid-translation-output", fileName already
// contains the name of temporary file that should be used?
auto collector = options_->hasAndNotEmpty("valid-translation-output")
? New<OutputCollector>(fileName)
: New<OutputCollector>(tempFile->getFileName());
if(quiet_)
collector->setPrintingStrategy(New<QuietPrinting>());
else
collector->setPrintingStrategy(New<GeometricPrinting>());
std::deque<Ptr<ExpressionGraph>> graphQueue(graphs.begin(), graphs.end());
std::deque<Ptr<Scorer>> scorerQueue(scorers.begin(), scorers.end());
auto task = [=, &graphQueue, &scorerQueue](BatchPtr batch) {
thread_local Ptr<ExpressionGraph> graph;
thread_local Ptr<Scorer> scorer;
if(!graph) {
std::unique_lock<std::mutex> lock(mutex_);
ABORT_IF(graphQueue.empty(), "Asking for graph, but none left on queue");
graph = graphQueue.front();
graphQueue.pop_front();
ABORT_IF(scorerQueue.empty(), "Asking for scorer, but none left on queue");
scorer = scorerQueue.front();
scorerQueue.pop_front();
}
auto search = New<BeamSearch>(options_, std::vector<Ptr<Scorer>>{scorer}, vocabs_.back());
auto histories = search->search(graph, batch);
for(auto history : histories) {
std::stringstream best1;
std::stringstream bestn;
printer->print(history, best1, bestn);
collector->Write(
(long)history->getLineNum(), best1.str(), bestn.str(), options_->get<bool>("n-best"));
}
};
threadPool_.reserve(graphs.size());
TaskBarrier taskBarrier;
for(auto batch : *batchGenerator_)
taskBarrier.push_back(threadPool_.enqueue(task, batch));
// ~TaskBarrier waits until all are done
}
if(!quiet_)
LOG(info, "Total translation time: {:.5f}s", timer.elapsed());
for(auto graph : graphs)
graph->setInference(false);
float val = 0.0f;
// Run post-processing script if given
if(options_->hasAndNotEmpty("valid-script-path")) {
// auto command = options_->get<std::string>("valid-script-path") + " " + fileName;
// auto valStr = utils::exec(command);
auto valStr = utils::exec(options_->get<std::string>("valid-script-path"),
options_->get<std::vector<std::string>>("valid-script-args"),
fileName);
val = (float)std::atof(valStr.c_str());
updateStalled(graphs, val);
}
return val;
};
SacreBleuValidator::SacreBleuValidator(std::vector<Ptr<Vocab>> vocabs, Ptr<Options> options, const std::string& metric)
: Validator(vocabs, options, /*lowerIsBetter=*/false),
metric_(metric),
computeChrF_(metric == "chrf"),
useWordIds_(metric == "bleu-segmented"),
quiet_(options_->get<bool>("quiet-translation")) {
ABORT_IF(computeChrF_ && useWordIds_, "Cannot compute ChrF on word ids"); // should not really happen, but let's check.
if(computeChrF_) // according to SacreBLEU implementation this is the default for ChrF,
order_ = 6; // we compute stats over character ngrams up to length 6
// @TODO: remove, only used for saving?
builder_ = models::createModelFromOptions(options_, models::usage::translation);
auto vocab = vocabs_.back();
createBatchGenerator(/*isTranslating=*/true);
}
float SacreBleuValidator::validate(const std::vector<Ptr<ExpressionGraph>>& graphs,
Ptr<const TrainingState> state) {
using namespace data;
// Generate batches
batchGenerator_->prepare();
// Create scorer
auto model = options_->get<std::string>("model");
// @TODO: check if required - Temporary options for translation
auto mopts = New<Options>();
mopts->merge(options_);
mopts->set("inference", true);
std::vector<Ptr<Scorer>> scorers;
for(auto graph : graphs) {
auto builder = models::createModelFromOptions(options_, models::usage::translation);
Ptr<Scorer> scorer = New<ScorerWrapper>(builder, "", 1.0f, model);
scorers.push_back(scorer);
}
for(auto graph : graphs)
graph->setInference(true);
if(!quiet_)
LOG(info, "Translating validation set...");
// For BLEU
// 0: 1-grams matched, 1: 1-grams cand total, 2: 1-grams ref total (used in ChrF)
// ...,
// n: reference length (used in BLEU)
std::vector<float> stats(statsPerOrder * order_ + 1, 0.f);
timer::Timer timer;
{
auto printer = New<OutputPrinter>(options_, vocabs_.back());
Ptr<OutputCollector> collector;
if(options_->hasAndNotEmpty("valid-translation-output")) {
auto fileName = options_->get<std::string>("valid-translation-output");
// fileName can be a template with fields for training state parameters:
fileName = state->fillTemplate(fileName);
collector = New<OutputCollector>(fileName); // for debugging
} else {
collector = New<OutputCollector>(/* null */); // don't print, but log
}
if(quiet_)
collector->setPrintingStrategy(New<QuietPrinting>());
else
collector->setPrintingStrategy(New<GeometricPrinting>());
std::deque<Ptr<ExpressionGraph>> graphQueue(graphs.begin(), graphs.end());
std::deque<Ptr<Scorer>> scorerQueue(scorers.begin(), scorers.end());
auto task = [=, &stats, &graphQueue, &scorerQueue](BatchPtr batch) {
thread_local Ptr<ExpressionGraph> graph;
thread_local Ptr<Scorer> scorer;
if(!graph) {
std::unique_lock<std::mutex> lock(mutex_);
ABORT_IF(graphQueue.empty(), "Asking for graph, but none left on queue");
graph = graphQueue.front();
graphQueue.pop_front();
ABORT_IF(scorerQueue.empty(), "Asking for scorer, but none left on queue");
scorer = scorerQueue.front();
scorerQueue.pop_front();
}
auto search = New<BeamSearch>(options_, std::vector<Ptr<Scorer>>{scorer}, vocabs_.back());
auto histories = search->search(graph, batch);
size_t no = 0;
std::lock_guard<std::mutex> statsLock(mutex_);
for(auto history : histories) {
auto result = history->top();
const auto& words = std::get<0>(result);
updateStats(stats, words, batch, no);
std::stringstream best1;
std::stringstream bestn;
printer->print(history, best1, bestn);
collector->Write((long)history->getLineNum(),
best1.str(),
bestn.str(),
/*nbest=*/false);
no++;
}
};
threadPool_.reserve(graphs.size());
TaskBarrier taskBarrier;
for(auto batch : *batchGenerator_)
taskBarrier.push_back(threadPool_.enqueue(task, batch));
// ~TaskBarrier waits until all are done
}
if(!quiet_)
LOG(info, "Total translation time: {:.5f}s", timer.elapsed());
for(auto graph : graphs)
graph->setInference(false);
float val = computeChrF_ ? calcChrF(stats) : calcBLEU(stats);
updateStalled(graphs, val);
return val;
}
std::vector<std::string> SacreBleuValidator::decode(const Words& words, bool addEOS) {
auto vocab = vocabs_.back();
auto tokenString = vocab->surfaceForm(words); // detokenize to surface form
auto vocabType = vocab->type();
if(vocabType == "FactoredVocab" || vocabType == "SentencePieceVocab") {
LOG_VALID_ONCE(info, "Decoding validation set with {} for scoring", vocabType);
tokenString = tokenize(tokenString); // tokenize according to SacreBLEU rules
if(!computeChrF_) // for ChrF, we break into characters below, so no need to do this here
tokenString = tokenizeContinuousScript(tokenString); // CJT scripts only: further break into characters
} else {
LOG_VALID_ONCE(info, "{} keeps original segments for scoring", vocabType);
}
auto tokens = computeChrF_ ? splitIntoUnicodeChars(tokenString, /*removeWhiteSpace=*/true) // break into vector of unicode chars (as utf8 strings) for ChrF
: utils::splitAny(tokenString, " ", /*keepEmpty=*/false); // or just split according to whitespace for BLEU
if(addEOS)
tokens.push_back("</s>");
return tokens;
}
void SacreBleuValidator::updateStats(std::vector<float>& stats,
const Words& cand,
const Ptr<data::Batch> batch,
size_t no) {
auto vocab = vocabs_.back();
auto corpusBatch = std::static_pointer_cast<data::CorpusBatch>(batch);
auto subBatch = corpusBatch->back();
size_t size = subBatch->batchSize();
size_t width = subBatch->batchWidth();
Words ref; // fill ref
for(size_t i = 0; i < width; ++i) {
Word w = subBatch->data()[i * size + no];
if(w == vocab->getEosId())
break;
if(w == vocab->getUnkId())
LOG_VALID_ONCE(info, "References contain unknown word, metric scores may be inaccurate");
ref.push_back(w);
}
LOG_VALID_ONCE(info, "First sentence's tokens as scored:");
LOG_VALID_ONCE(info, " Hyp: {}", utils::join(decode(cand, /*addEOS=*/false)));
LOG_VALID_ONCE(info, " Ref: {}", utils::join(decode(ref, /*addEOS=*/false)));
if(useWordIds_)
updateStats(stats, cand, ref);
else
updateStats(stats, decode(cand, /*addEOS=*/false), decode(ref, /*addEOS=*/false));
}
// Re-implementation of BLEU metric from SacreBLEU
float SacreBleuValidator::calcBLEU(const std::vector<float>& stats) {
float logbleu = 0;
for(int i = 0; i < order_; ++i) {
float commonNgrams = stats[statsPerOrder * i + 0];
float hypothesesNgrams = stats[statsPerOrder * i + 1];
if(commonNgrams == 0.f)
return 0.f;
logbleu += std::log(commonNgrams) - std::log(hypothesesNgrams);
}
logbleu /= order_;
float refLen = stats[statsPerOrder * order_];
float hypUnigrams = stats[1];
float brev_penalty = 1.f - std::max(refLen / hypUnigrams, 1.f);
return std::exp(logbleu + brev_penalty) * 100.f;
}
// Re-implementation of ChrF metric from SacreBLEU, using standard parameters
float SacreBleuValidator::calcChrF(const std::vector<float>& stats) {
float beta = 2.f;
float avgPrecision = 0.f;
float avgRecall = 0.f;
size_t effectiveOrder = 0;
for(size_t i = 0; i < order_; ++i) {
float commonNgrams = stats[statsPerOrder * i + 0];
float hypothesesNgrams = stats[statsPerOrder * i + 1];
float referencesNgrams = stats[statsPerOrder * i + 2];
if(hypothesesNgrams > 0 && referencesNgrams > 0) {
avgPrecision += commonNgrams / hypothesesNgrams;
avgRecall += commonNgrams / referencesNgrams;
effectiveOrder++;
}
}
if(effectiveOrder == 0)
return 0.f;
avgPrecision /= effectiveOrder;
avgRecall /= effectiveOrder;
if(avgPrecision + avgRecall == 0.f)
return 0.f;
auto betaSquare = beta * beta;
auto score = (1.f + betaSquare) * (avgPrecision * avgRecall) / ((betaSquare * avgPrecision) + avgRecall);
return score * 100.f; // we multiply by 100 which is usually not done for ChrF, but this makes it more comparable to BLEU
}
} // namespace marian