.. _program_listing_file_src_training_validator.h:

Program Listing for File validator.h
====================================

|exhale_lsh| :ref:`Return to documentation for file ` (``src/training/validator.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   #pragma once

   #include "3rd_party/threadpool.h"
   #include "common/config.h"
   #include "common/timer.h"
   #include "common/utils.h"
   #include "common/regex.h"
   #include "common/utils.h"
   #include "data/batch_generator.h"
   #include "data/corpus.h"
   #include "graph/expression_graph.h"
   #include "training/training_state.h"
   #include "translator/beam_search.h"
   #include "translator/history.h"
   #include "translator/output_collector.h"
   #include "translator/output_printer.h"
   #include "translator/scorers.h"
   #include "models/bert.h"

   // NOTE(review): the three bare #include directives below have lost their header names in this
   // listing (as have all template argument lists further down); restore them from the real header.
   #include
   #include
   #include

   namespace marian {

   /**
    * Abstract base class of all validators; also a TrainingObserver.
    *
    * Stores the best validation score seen so far and a counter of validations without
    * improvement. Concrete validators implement validate() and type().
    */
   class ValidatorBase : public TrainingObserver {
   protected:
     // If true, a smaller score than lastBest_ counts as an improvement; otherwise a larger one does.
     bool lowerIsBetter_{true};
     // Best score seen so far; seeded from initScore() by the constructor.
     float lastBest_;
     // Number of validations without improvement.
     size_t stalled_{0};
     // mutex_ and threadPool_ are declared here but not referenced by any code in this listing.
     std::mutex mutex_;
     ThreadPool threadPool_;

   public:
     // Note: initScore() is virtual but is called from the constructor's initializer list, so the
     // ValidatorBase implementation is the one that runs here, regardless of overrides in
     // derived classes.
     ValidatorBase(bool lowerIsBetter) : lowerIsBetter_(lowerIsBetter), lastBest_{initScore()} {}
     virtual ~ValidatorBase() {}

     // Runs validation on the given graphs and returns the resulting score.
     virtual float validate(const std::vector>& graphs,
                            Ptr state)
         = 0;
     // Returns the name of this validator/metric as a string.
     virtual std::string type() = 0;

     // Mutable access to the best score and the stall counter (callers can overwrite both).
     float& lastBest() { return lastBest_; }
     size_t& stalled() { return stalled_; }

     // Initial value for lastBest_; defined outside this listing.
     virtual float initScore();
     // TrainingObserver hook; defined outside this listing.
     virtual void actAfterLoaded(TrainingState& state) override;
   };

   template // @TODO: BuilderType doesn't really serve a purpose here? Review and remove.
class Validator : public ValidatorBase { public: virtual ~Validator() {} Validator(std::vector> vocabs, Ptr options, bool lowerIsBetter = true) : ValidatorBase(lowerIsBetter), vocabs_(vocabs), // options_ is a clone of global options, so it can be safely modified within the class options_(New(options->clone())) { // set options common for all validators options_->set("inference", true); options_->set("shuffle", "none"); // don't shuffle validation sets if(options_->has("valid-max-length")) { options_->set("max-length", options_->get("valid-max-length")); options_->set("max-length-crop", true); // @TODO: make this configureable } // @TODO: make this work with mini-batch-fit etc. if(options_->has("valid-mini-batch")) { options_->set("mini-batch", options_->get("valid-mini-batch")); options_->set("mini-batch-words", 0); } options_->set("mini-batch-sort", "src"); options_->set("maxi-batch", 10); } typedef typename DataSet::batch_ptr BatchPtr; protected: // Create the BatchGenerator. Note that ScriptValidator does not use batchGenerator_. void createBatchGenerator(bool /*isTranslating*/) { // Create corpus auto validPaths = options_->get>("valid-sets"); auto corpus = New(validPaths, vocabs_, options_); // Create batch generator batchGenerator_ = New>(corpus, options_); } public: virtual float validate(const std::vector>& graphs, Ptr /*ignored*/) override { for(auto graph : graphs) graph->setInference(true); batchGenerator_->prepare(); // Validate on batches float val = validateBG(graphs); updateStalled(graphs, val); for(auto graph : graphs) graph->setInference(false); return val; }; protected: std::vector> vocabs_; Ptr options_; Ptr builder_; // @TODO: remove, this is not guaranteed to be state-free, hence not thread-safe, but we are using validators with multi-threading. 
Ptr> batchGenerator_; virtual float validateBG(const std::vector>&) = 0; void updateStalled(const std::vector>& graphs, float val) { if((lowerIsBetter_ && lastBest_ > val) || (!lowerIsBetter_ && lastBest_ < val)) { stalled_ = 0; lastBest_ = val; if(options_->get("keep-best")) keepBest(graphs); } else /* if (lastBest_ != val) */ { // (special case 0 at start) @TODO: needed? Seems stall count gets reset each time it does improve. If not needed, remove "if(...)" again. stalled_++; } } virtual void keepBest(const std::vector>& graphs) { auto model = options_->get("model"); std::string suffix = model.substr(model.size() - 4); ABORT_IF(suffix != ".npz" && suffix != ".bin", "Unknown model suffix {}", suffix); builder_->save(graphs[0], model + ".best-" + type() + suffix, true); } }; class CrossEntropyValidator : public Validator { using Validator::BatchPtr; public: CrossEntropyValidator(std::vector> vocabs, Ptr options); virtual ~CrossEntropyValidator() {} std::string type() override { return options_->get("cost-type"); } protected: virtual float validateBG(const std::vector>& graphs) override; }; // Used for validating with classifiers. 
Compute prediction accuracy versus ground truth for a set of classes class AccuracyValidator : public Validator { public: AccuracyValidator(std::vector> vocabs, Ptr options); virtual ~AccuracyValidator() {} std::string type() override { return "accuracy"; } protected: virtual float validateBG(const std::vector>& graphs) override; }; class BertAccuracyValidator : public Validator { private: bool evalMaskedLM_{true}; public: BertAccuracyValidator(std::vector> vocabs, Ptr options, bool evalMaskedLM); virtual ~BertAccuracyValidator() {} std::string type() override { if(evalMaskedLM_) return "bert-lm-accuracy"; else return "bert-sentence-accuracy"; } protected: virtual float validateBG(const std::vector>& graphs) override; }; class ScriptValidator : public Validator { public: ScriptValidator(std::vector> vocabs, Ptr options); virtual ~ScriptValidator() {} virtual float validate(const std::vector>& graphs, Ptr /*ignored*/) override; std::string type() override { return "valid-script"; } protected: virtual float validateBG(const std::vector>& /*graphs*/) override { return 0; } }; // validator that translates and computes BLEU (or any metric) with an external script class TranslationValidator : public Validator { public: TranslationValidator(std::vector> vocabs, Ptr options); virtual ~TranslationValidator() {} virtual float validate(const std::vector>& graphs, Ptr state) override; std::string type() override { return "translation"; } protected: bool quiet_{false}; virtual float validateBG(const std::vector>& /*graphs*/) override { return 0; } }; // validator that translates and computes BLEU/ChrF internally, with or without decoding // Aims to follow SacreBLEU as close as possible. 
// @TODO: combine with TranslationValidator (above) to avoid code duplication class SacreBleuValidator : public Validator { public: SacreBleuValidator(std::vector> vocabs, Ptr options, const std::string& metric); virtual ~SacreBleuValidator() {} virtual float validate(const std::vector>& graphs, Ptr state) override; std::string type() override { return metric_; } protected: // Tokenizer function adapted from multi-bleu-detok.pl, corresponds to sacreBLEU.py static std::string tokenize(const std::string& text) { std::string normText = text; // language-independent part: normText = regex::regex_replace(normText, regex::regex(""), ""); // strip "skipped" tags normText = regex::regex_replace(normText, regex::regex("-\\n"), ""); // strip end-of-line hyphenation and join lines normText = regex::regex_replace(normText, regex::regex("\\n"), " "); // join lines normText = regex::regex_replace(normText, regex::regex("""), "\""); // convert SGML tag for quote to " normText = regex::regex_replace(normText, regex::regex("&"), "&"); // convert SGML tag for ampersand to & normText = regex::regex_replace(normText, regex::regex("<"), "<"); //convert SGML tag for less-than to > normText = regex::regex_replace(normText, regex::regex(">"), ">"); //convert SGML tag for greater-than to < // language-dependent part (assuming Western languages): normText = " " + normText + " "; normText = regex::regex_replace(normText, regex::regex("([\\{-\\~\\[-\\` -\\&\\(-\\+\\:-\\@\\/])"), " $1 "); // tokenize punctuation normText = regex::regex_replace(normText, regex::regex("([^0-9])([\\.,])"), "$1 $2 "); // tokenize period and comma unless preceded by a digit normText = regex::regex_replace(normText, regex::regex("([\\.,])([^0-9])"), " $1 $2"); // tokenize period and comma unless followed by a digit normText = regex::regex_replace(normText, regex::regex("([0-9])(-)"), "$1 $2 "); // tokenize dash when preceded by a digit normText = regex::regex_replace(normText, regex::regex("\\s+"), " "); // one 
space only between words normText = regex::regex_replace(normText, regex::regex("^\\s+"), ""); // no leading space normText = regex::regex_replace(normText, regex::regex("\\s+$"), ""); // no trailing space return normText; } static std::string tokenizeContinuousScript(const std::string& sUTF8) { // We want BLEU-like scores that are comparable across different tokenization schemes. // For continuous scripts (Chinese, Japanese, Thai), we would need a language-specific // statistical word segmenter, which is outside the scope of Marian. As a practical // compromise, we segment continuous-script sequences into individual characters, while // leaving Western scripts as words. This way we can use the same settings for Western // languages, where Marian would report SacreBLEU scores, and Asian languages, where // scores are not standard but internally comparable across tokenization schemes. // @TODO: Check what sacrebleu.py is doing, and whether we can replicate that here faithfully. std::u32string in = utils::utf8ToUnicodeString(sUTF8); std::u32string out = in.substr(0, 0); // (out should be same type as in, don't want to bother with exact type) for (auto c : in) { bool isCS = utils::isContinuousScript(c); if (isCS) // surround continuous-script chars by spaces on each side out.push_back(' '); // (duplicate spaces are ignored when splitting later) out.push_back(c); if (isCS) out.push_back(' '); } return utils::utf8FromUnicodeString(out); } static std::vector splitIntoUnicodeChars(const std::string& sUTF8, bool removeWhiteSpace=true) { std::u32string in = utils::utf8ToUnicodeString(sUTF8); std::u32string space = utils::utf8ToUnicodeString(" "); std::vector out; for(char32_t c : in) { std::u32string temp(1, c); if(removeWhiteSpace && temp != space) out.push_back(utils::utf8FromUnicodeString(temp)); } return out; } std::vector decode(const Words& words, bool addEOS = false); // Update document-wide sufficient statistics for BLEU with single sentence n-gram stats. 
template void updateStats(std::vector& stats, const std::vector& cand, const std::vector& ref) { auto countNgrams = [this](const std::vector& tokens) { std::map, size_t> ngramCounts; for(size_t i = 0; i < tokens.size(); ++i) { // template deduction for std::min seems to be weird under VS due to // macros in windows.h hence explicit type to avoid macro parsing. for(size_t len = 1; len <= std::min(order_, tokens.size() - i); ++len) { std::vector ngram(len); std::copy(tokens.begin() + i, tokens.begin() + i + len, ngram.begin()); ngramCounts[ngram]++; } } return ngramCounts; }; auto cgrams = countNgrams(cand); auto rgrams = countNgrams(ref); for(auto& ngramcount : cgrams) { size_t order = ngramcount.first.size() - 1; size_t tc = ngramcount.second; size_t rc = rgrams[ngramcount.first]; stats[statsPerOrder * order + 0] += std::min(tc, rc); // count common ngrams (for BLEU and ChrF) stats[statsPerOrder * order + 1] += tc; // count hypotheses ngrams (for BLEU and ChrF) } if(computeChrF_) { for(auto& ngramcount : rgrams) { size_t order = ngramcount.first.size() - 1; size_t rc = ngramcount.second; stats[statsPerOrder * order + 2] += rc; // count reference ngrams (for ChrF) } } stats[statsPerOrder * order_] += ref.size(); // reference length for BLEU (technically same as stats[2], but let's keep it separate) } // Extract matching target reference from batch and pass on to update BLEU stats void updateStats(std::vector& stats, const Words& cand, const Ptr batch, size_t no); float calcBLEU(const std::vector& stats); float calcChrF(const std::vector& stats); virtual float validateBG(const std::vector>& /*graphs*/) override { return 0; } private: const std::string metric_; // allowed values are: bleu, bleu-detok (same as bleu), bleu-segmented, chrf bool computeChrF_{ false }; // should we compute ChrF instead of BLEU (BLEU by default)? 
size_t order_{ 4 }; // 4-grams for BLEU by default static const size_t statsPerOrder = 3; // 0: common ngrams, 1: candidate ngrams, 2: reference ngrams bool useWordIds_{ false }; // compute BLEU score by matching numeric segment ids bool quiet_{ false }; }; std::vector*/>> Validators( std::vector> vocabs, Ptr config); } // namespace marian