.. _program_listing_file_src_training_validator.cpp: Program Listing for File validator.cpp ====================================== |exhale_lsh| :ref:`Return to documentation for file ` (``src/training/validator.cpp``) .. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS .. code-block:: cpp #include "training/validator.h" namespace marian { std::vector*/>> Validators( std::vector> vocabs, Ptr config) { std::vector*/>> validators; auto validMetrics = config->get>("valid-metrics"); std::vector ceMetrics = {"cross-entropy", "ce-mean", "ce-sum", "ce-mean-words", "perplexity"}; for(auto metric : validMetrics) { if(std::find(ceMetrics.begin(), ceMetrics.end(), metric) != ceMetrics.end()) { Ptr opts = New(*config); opts->set("cost-type", metric); auto validator = New(vocabs, opts); validators.push_back(validator); } else if(metric == "valid-script") { auto validator = New(vocabs, config); validators.push_back(validator); } else if(metric == "translation") { auto validator = New(vocabs, config); validators.push_back(validator); } else if(metric == "bleu" || metric == "bleu-detok" || metric == "bleu-segmented" || metric == "chrf") { auto validator = New(vocabs, config, metric); validators.push_back(validator); } else if(metric == "accuracy") { auto validator = New(vocabs, config); validators.push_back(validator); } else if(metric == "bert-lm-accuracy") { auto validator = New(vocabs, config, true); validators.push_back(validator); } else if(metric == "bert-sentence-accuracy") { auto validator = New(vocabs, config, false); validators.push_back(validator); } else { ABORT("Unknown validation metric: {}", metric); } } return validators; } float ValidatorBase::initScore() { return lowerIsBetter_ ? std::numeric_limits::max() : std::numeric_limits::lowest(); } void ValidatorBase::actAfterLoaded(TrainingState& state) { if(state.validators[type()]) { lastBest_ = state.validators[type()]["last-best"].as(); stalled_ = state.validators[type()]["stalled"].as(); } } CrossEntropyValidator::CrossEntropyValidator(std::vector> vocabs, Ptr options) : Validator(vocabs, options) { createBatchGenerator(/*isTranslating=*/false); auto opts = options_->with("inference", true, // @TODO: check if required "cost-type", "ce-sum"); // @TODO: remove, only used for saving? builder_ = models::createCriterionFunctionFromOptions(opts, models::usage::scoring); } float CrossEntropyValidator::validateBG(const std::vector>& graphs) { auto ctype = options_->get("cost-type"); // @TODO: use with(...) everywhere, this will help with creating immutable options. // Make options const everywhere and get rid of "set"? auto opts = options_->with("inference", true, "cost-type", "ce-sum"); StaticLoss loss; size_t samples = 0; std::deque> graphQueue(graphs.begin(), graphs.end()); auto task = [=, &loss, &samples, &graphQueue](BatchPtr batch) { thread_local Ptr graph; if(!graph) { std::unique_lock lock(mutex_); ABORT_IF(graphQueue.empty(), "Asking for graph, but none left on queue"); graph = graphQueue.front(); graphQueue.pop_front(); } auto builder = models::createCriterionFunctionFromOptions(options_, models::usage::scoring); builder->clear(graph); auto dynamicLoss = builder->build(graph, batch); graph->forward(); std::unique_lock lock(mutex_); loss += *dynamicLoss; samples += batch->size(); }; { threadPool_.reserve(graphs.size()); TaskBarrier taskBarrier; for(auto batch : *batchGenerator_) taskBarrier.push_back(threadPool_.enqueue(task, batch)); // ~TaskBarrier waits until all are done } if(ctype == "perplexity") return std::exp(loss.loss / loss.count); if(ctype == "ce-mean-words") return loss.loss / loss.count; if(ctype == "ce-sum") return loss.loss; else return loss.loss / samples; // @TODO: back-compat, to be removed } AccuracyValidator::AccuracyValidator(std::vector> vocabs, Ptr options) : Validator(vocabs, options, /*lowerIsBetter=*/false) { createBatchGenerator(/*isTranslating=*/false); // @TODO: remove, only used for saving? builder_ = models::createModelFromOptions(options_, models::usage::raw); } float AccuracyValidator::validateBG(const std::vector>& graphs) { size_t correct = 0; size_t totalLabels = 0; std::deque> graphQueue(graphs.begin(), graphs.end()); auto task = [=, &correct, &totalLabels, &graphQueue](BatchPtr batch) { thread_local Ptr graph; if(!graph) { std::unique_lock lock(mutex_); ABORT_IF(graphQueue.empty(), "Asking for graph, but none left on queue"); graph = graphQueue.front(); graphQueue.pop_front(); } auto builder = models::createModelFromOptions(options_, models::usage::raw); builder->clear(graph); Expr logits = builder->build(graph, batch).getLogits(); graph->forward(); std::vector vLogits; logits->val()->get(vLogits); const auto& groundTruth = batch->back()->data(); IndexType cols = logits->shape()[-1]; size_t thisCorrect = 0; size_t thisLabels = groundTruth.size(); for(int i = 0; i < thisLabels; ++i) { // CPU-side Argmax Word bestWord = Word::NONE; float bestValue = std::numeric_limits::lowest(); for(IndexType j = 0; j < cols; ++j) { float currValue = vLogits[i * cols + j]; if(currValue > bestValue) { bestValue = currValue; bestWord = Word::fromWordIndex(j); } } thisCorrect += (size_t)(bestWord == groundTruth[i]); } std::unique_lock lock(mutex_); totalLabels += thisLabels; correct += thisCorrect; }; { threadPool_.reserve(graphs.size()); TaskBarrier taskBarrier; for(auto batch : *batchGenerator_) taskBarrier.push_back(threadPool_.enqueue(task, batch)); // ~TaskBarrier waits until all are done } return (float)correct / (float)totalLabels; } BertAccuracyValidator::BertAccuracyValidator(std::vector> vocabs, Ptr options, bool evalMaskedLM) : Validator(vocabs, options, /*lowerIsBetter=*/false), evalMaskedLM_(evalMaskedLM) { createBatchGenerator(/*isTranslating=*/false); // @TODO: remove, only used for saving? builder_ = models::createModelFromOptions(options_, models::usage::raw); } float BertAccuracyValidator::validateBG(const std::vector>& graphs) { size_t correct = 0; size_t totalLabels = 0; size_t batchId = 0; std::deque> graphQueue(graphs.begin(), graphs.end()); auto task = [=, &correct, &totalLabels, &graphQueue](BatchPtr batch, size_t batchId) { thread_local Ptr graph; if(!graph) { std::unique_lock lock(mutex_); ABORT_IF(graphQueue.empty(), "Asking for graph, but none left on queue"); graph = graphQueue.front(); graphQueue.pop_front(); } auto builder = models::createModelFromOptions(options_, models::usage::raw); thread_local std::unique_ptr engine; if(!engine) engine.reset(new std::mt19937((unsigned int)(Config::seed + batchId))); auto bertBatch = New(batch, *engine, options_->get("bert-masking-fraction"), options_->get("bert-mask-symbol"), options_->get("bert-sep-symbol"), options_->get("bert-class-symbol"), options_->get("bert-type-vocab-size")); builder->clear(graph); auto classifierStates = std::dynamic_pointer_cast(builder)->apply(graph, bertBatch, true); graph->forward(); auto maskedLMLogits = classifierStates[0]->getLogProbs(); const auto& maskedLMLabels = bertBatch->bertMaskedWords(); auto sentenceLogits = classifierStates[1]->getLogProbs(); const auto& sentenceLabels = bertBatch->back()->data(); auto count = [=, &correct, &totalLabels](Expr logits, const Words& labels) { IndexType cols = logits->shape()[-1]; size_t thisCorrect = 0; size_t thisLabels = labels.size(); std::vector vLogits; logits->val()->get(vLogits); for(int i = 0; i < thisLabels; ++i) { // CPU-side Argmax IndexType bestIndex = 0; float bestValue = std::numeric_limits::lowest(); for(IndexType j = 0; j < cols; ++j) { float currValue = vLogits[i * cols + j]; if(currValue > bestValue) { bestValue = currValue; bestIndex = j; } } thisCorrect += (size_t)(bestIndex == labels[i].toWordIndex()); } std::unique_lock lock(mutex_); totalLabels += thisLabels; correct += thisCorrect; }; if(evalMaskedLM_) count(maskedLMLogits, maskedLMLabels); else count(sentenceLogits, sentenceLabels); }; { threadPool_.reserve(graphs.size()); TaskBarrier taskBarrier; for(auto batch : *batchGenerator_) { taskBarrier.push_back(threadPool_.enqueue(task, batch, batchId)); batchId++; } // ~TaskBarrier waits until all are done } return (float)correct / (float)totalLabels; } ScriptValidator::ScriptValidator(std::vector> vocabs, Ptr options) : Validator(vocabs, options, false) { // @TODO: remove, only used for saving? builder_ = models::createModelFromOptions(options_, models::usage::raw); ABORT_IF(!options_->hasAndNotEmpty("valid-script-path"), "valid-script metric but no script given"); } float ScriptValidator::validate(const std::vector>& graphs, Ptr /*ignored*/) { using namespace data; auto model = options_->get("model"); std::string suffix = model.substr(model.size() - 4); ABORT_IF(suffix != ".npz" && suffix != ".bin", "Unknown model suffix {}", suffix); builder_->save(graphs[0], model + ".dev" + suffix, true); auto valStr = utils::exec(options_->get("valid-script-path"), options_->get>("valid-script-args")); float val = (float)std::atof(valStr.c_str()); updateStalled(graphs, val); return val; } TranslationValidator::TranslationValidator(std::vector> vocabs, Ptr options) : Validator(vocabs, options, false), quiet_(options_->get("quiet-translation")) { // @TODO: remove, only used for saving? builder_ = models::createModelFromOptions(options_, models::usage::translation); if(!options_->hasAndNotEmpty("valid-script-path")) LOG_VALID(warn, "No post-processing script given for validating translator"); createBatchGenerator(/*isTranslating=*/true); } float TranslationValidator::validate(const std::vector>& graphs, Ptr state) { using namespace data; // Generate batches batchGenerator_->prepare(); // Create scorer auto model = options_->get("model"); std::vector> scorers; for(auto graph : graphs) { auto builder = models::createModelFromOptions(options_, models::usage::translation); Ptr scorer = New(builder, "", 1.0f, model); scorers.push_back(scorer); // @TODO: should this be done in the contructor? } // Set up output file std::string fileName; Ptr tempFile; if(options_->hasAndNotEmpty("valid-translation-output")) { fileName = options_->get("valid-translation-output"); // fileName can be a template with fields for training state parameters: fileName = state->fillTemplate(fileName); } else { tempFile.reset(new io::TemporaryFile(options_->get("tempdir"), false)); fileName = tempFile->getFileName(); } for(auto graph : graphs) graph->setInference(true); if(!quiet_) LOG(info, "Translating validation set..."); timer::Timer timer; { auto printer = New(options_, vocabs_.back()); // @TODO: This can be simplified. If there is no "valid-translation-output", fileName already // contains the name of temporary file that should be used? auto collector = options_->hasAndNotEmpty("valid-translation-output") ? New(fileName) : New(tempFile->getFileName()); if(quiet_) collector->setPrintingStrategy(New()); else collector->setPrintingStrategy(New()); std::deque> graphQueue(graphs.begin(), graphs.end()); std::deque> scorerQueue(scorers.begin(), scorers.end()); auto task = [=, &graphQueue, &scorerQueue](BatchPtr batch) { thread_local Ptr graph; thread_local Ptr scorer; if(!graph) { std::unique_lock lock(mutex_); ABORT_IF(graphQueue.empty(), "Asking for graph, but none left on queue"); graph = graphQueue.front(); graphQueue.pop_front(); ABORT_IF(scorerQueue.empty(), "Asking for scorer, but none left on queue"); scorer = scorerQueue.front(); scorerQueue.pop_front(); } auto search = New(options_, std::vector>{scorer}, vocabs_.back()); auto histories = search->search(graph, batch); for(auto history : histories) { std::stringstream best1; std::stringstream bestn; printer->print(history, best1, bestn); collector->Write( (long)history->getLineNum(), best1.str(), bestn.str(), options_->get("n-best")); } }; threadPool_.reserve(graphs.size()); TaskBarrier taskBarrier; for(auto batch : *batchGenerator_) taskBarrier.push_back(threadPool_.enqueue(task, batch)); // ~TaskBarrier waits until all are done } if(!quiet_) LOG(info, "Total translation time: {:.5f}s", timer.elapsed()); for(auto graph : graphs) graph->setInference(false); float val = 0.0f; // Run post-processing script if given if(options_->hasAndNotEmpty("valid-script-path")) { // auto command = options_->get("valid-script-path") + " " + fileName; // auto valStr = utils::exec(command); auto valStr = utils::exec(options_->get("valid-script-path"), options_->get>("valid-script-args"), fileName); val = (float)std::atof(valStr.c_str()); updateStalled(graphs, val); } return val; }; SacreBleuValidator::SacreBleuValidator(std::vector> vocabs, Ptr options, const std::string& metric) : Validator(vocabs, options, /*lowerIsBetter=*/false), metric_(metric), computeChrF_(metric == "chrf"), useWordIds_(metric == "bleu-segmented"), quiet_(options_->get("quiet-translation")) { ABORT_IF(computeChrF_ && useWordIds_, "Cannot compute ChrF on word ids"); // should not really happen, but let's check. if(computeChrF_) // according to SacreBLEU implementation this is the default for ChrF, order_ = 6; // we compute stats over character ngrams up to length 6 // @TODO: remove, only used for saving? builder_ = models::createModelFromOptions(options_, models::usage::translation); auto vocab = vocabs_.back(); createBatchGenerator(/*isTranslating=*/true); } float SacreBleuValidator::validate(const std::vector>& graphs, Ptr state) { using namespace data; // Generate batches batchGenerator_->prepare(); // Create scorer auto model = options_->get("model"); // @TODO: check if required - Temporary options for translation auto mopts = New(); mopts->merge(options_); mopts->set("inference", true); std::vector> scorers; for(auto graph : graphs) { auto builder = models::createModelFromOptions(options_, models::usage::translation); Ptr scorer = New(builder, "", 1.0f, model); scorers.push_back(scorer); } for(auto graph : graphs) graph->setInference(true); if(!quiet_) LOG(info, "Translating validation set..."); // For BLEU // 0: 1-grams matched, 1: 1-grams cand total, 2: 1-grams ref total (used in ChrF) // ..., // n: reference length (used in BLEU) std::vector stats(statsPerOrder * order_ + 1, 0.f); timer::Timer timer; { auto printer = New(options_, vocabs_.back()); Ptr collector; if(options_->hasAndNotEmpty("valid-translation-output")) { auto fileName = options_->get("valid-translation-output"); // fileName can be a template with fields for training state parameters: fileName = state->fillTemplate(fileName); collector = New(fileName); // for debugging } else { collector = New(/* null */); // don't print, but log } if(quiet_) collector->setPrintingStrategy(New()); else collector->setPrintingStrategy(New()); std::deque> graphQueue(graphs.begin(), graphs.end()); std::deque> scorerQueue(scorers.begin(), scorers.end()); auto task = [=, &stats, &graphQueue, &scorerQueue](BatchPtr batch) { thread_local Ptr graph; thread_local Ptr scorer; if(!graph) { std::unique_lock lock(mutex_); ABORT_IF(graphQueue.empty(), "Asking for graph, but none left on queue"); graph = graphQueue.front(); graphQueue.pop_front(); ABORT_IF(scorerQueue.empty(), "Asking for scorer, but none left on queue"); scorer = scorerQueue.front(); scorerQueue.pop_front(); } auto search = New(options_, std::vector>{scorer}, vocabs_.back()); auto histories = search->search(graph, batch); size_t no = 0; std::lock_guard statsLock(mutex_); for(auto history : histories) { auto result = history->top(); const auto& words = std::get<0>(result); updateStats(stats, words, batch, no); std::stringstream best1; std::stringstream bestn; printer->print(history, best1, bestn); collector->Write((long)history->getLineNum(), best1.str(), bestn.str(), /*nbest=*/false); no++; } }; threadPool_.reserve(graphs.size()); TaskBarrier taskBarrier; for(auto batch : *batchGenerator_) taskBarrier.push_back(threadPool_.enqueue(task, batch)); // ~TaskBarrier waits until all are done } if(!quiet_) LOG(info, "Total translation time: {:.5f}s", timer.elapsed()); for(auto graph : graphs) graph->setInference(false); float val = computeChrF_ ? calcChrF(stats) : calcBLEU(stats); updateStalled(graphs, val); return val; } std::vector SacreBleuValidator::decode(const Words& words, bool addEOS) { auto vocab = vocabs_.back(); auto tokenString = vocab->surfaceForm(words); // detokenize to surface form auto vocabType = vocab->type(); if(vocabType == "FactoredVocab" || vocabType == "SentencePieceVocab") { LOG_VALID_ONCE(info, "Decoding validation set with {} for scoring", vocabType); tokenString = tokenize(tokenString); // tokenize according to SacreBLEU rules if(!computeChrF_) // for ChrF, we break into characters below, so no need to do this here tokenString = tokenizeContinuousScript(tokenString); // CJT scripts only: further break into characters } else { LOG_VALID_ONCE(info, "{} keeps original segments for scoring", vocabType); } auto tokens = computeChrF_ ? splitIntoUnicodeChars(tokenString, /*removeWhiteSpace=*/true) // break into vector of unicode chars (as utf8 strings) for ChrF : utils::splitAny(tokenString, " ", /*keepEmpty=*/false); // or just split according to whitespace for BLEU if(addEOS) tokens.push_back(""); return tokens; } void SacreBleuValidator::updateStats(std::vector& stats, const Words& cand, const Ptr batch, size_t no) { auto vocab = vocabs_.back(); auto corpusBatch = std::static_pointer_cast(batch); auto subBatch = corpusBatch->back(); size_t size = subBatch->batchSize(); size_t width = subBatch->batchWidth(); Words ref; // fill ref for(size_t i = 0; i < width; ++i) { Word w = subBatch->data()[i * size + no]; if(w == vocab->getEosId()) break; if(w == vocab->getUnkId()) LOG_VALID_ONCE(info, "References contain unknown word, metric scores may be inaccurate"); ref.push_back(w); } LOG_VALID_ONCE(info, "First sentence's tokens as scored:"); LOG_VALID_ONCE(info, " Hyp: {}", utils::join(decode(cand, /*addEOS=*/false))); LOG_VALID_ONCE(info, " Ref: {}", utils::join(decode(ref, /*addEOS=*/false))); if(useWordIds_) updateStats(stats, cand, ref); else updateStats(stats, decode(cand, /*addEOS=*/false), decode(ref, /*addEOS=*/false)); } // Re-implementation of BLEU metric from SacreBLEU float SacreBleuValidator::calcBLEU(const std::vector& stats) { float logbleu = 0; for(int i = 0; i < order_; ++i) { float commonNgrams = stats[statsPerOrder * i + 0]; float hypothesesNgrams = stats[statsPerOrder * i + 1]; if(commonNgrams == 0.f) return 0.f; logbleu += std::log(commonNgrams) - std::log(hypothesesNgrams); } logbleu /= order_; float refLen = stats[statsPerOrder * order_]; float hypUnigrams = stats[1]; float brev_penalty = 1.f - std::max(refLen / hypUnigrams, 1.f); return std::exp(logbleu + brev_penalty) * 100.f; } // Re-implementation of ChrF metric from SacreBLEU, using standard parameters float SacreBleuValidator::calcChrF(const std::vector& stats) { float beta = 2.f; float avgPrecision = 0.f; float avgRecall = 0.f; size_t effectiveOrder = 0; for(size_t i = 0; i < order_; ++i) { float commonNgrams = stats[statsPerOrder * i + 0]; float hypothesesNgrams = stats[statsPerOrder * i + 1]; float referencesNgrams = stats[statsPerOrder * i + 2]; if(hypothesesNgrams > 0 && referencesNgrams > 0) { avgPrecision += commonNgrams / hypothesesNgrams; avgRecall += commonNgrams / referencesNgrams; effectiveOrder++; } } if(effectiveOrder == 0) return 0.f; avgPrecision /= effectiveOrder; avgRecall /= effectiveOrder; if(avgPrecision + avgRecall == 0.f) return 0.f; auto betaSquare = beta * beta; auto score = (1.f + betaSquare) * (avgPrecision * avgRecall) / ((betaSquare * avgPrecision) + avgRecall); return score * 100.f; // we multiply by 100 which is usually not done for ChrF, but this makes it more comparable to BLEU } } // namespace marian