.. _program_listing_file_src_training_training_state.h:

Program Listing for File training_state.h
=========================================

|exhale_lsh| :ref:`Return to documentation for file <file_src_training_training_state.h>` (``src/training/training_state.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   #pragma once

   #include "common/definitions.h"
   #include "common/filesystem.h"
   #include "common/scheduling_parameter.h"
   #include "common/utils.h"

   #include <fstream>
   #include <vector>

   namespace marian {

   class TrainingState;

   class TrainingObserver {
   public:
     virtual ~TrainingObserver() {}
     virtual void init(TrainingState&) {}
     virtual void actAfterEpoch(TrainingState&) {}
     virtual void actAfterBatches(TrainingState&) {}
     virtual void actAfterStalled(TrainingState&) {}
     virtual void actAfterLoaded(TrainingState&) {}
   };

   class TrainingState {
   public:
     // Current epoch
     size_t epochs{1};

     // The total number of updates since beginning of training --@TODO: rename to 'updates'
     size_t batches{0};

     // The number of batches seen in this epoch --note: not updates; an update can consist of multiple batches
     size_t batchesEpoch{0};

     // The number of sentences seen in this epoch --@TODO: rename to 'sentencesEpoch'
     size_t samplesEpoch{0};

     // Number of word labels processed since beginning of training
     size_t labelsTotal{0};

     // Values before previous update() call
     size_t prevLabelsTotal{0}, prevBatches{0}, prevEpochs{0};

     // The number of stalled validations
     size_t stalled{0};

     // The largest number of stalled validations so far
     size_t maxStalled{0};

     std::string validator;

     // List of validators
     YAML::Node validators;

     // Reset optimizer parameters
     bool reset{false};

     // Current learning rate, representing all adjustment processes and factors
     float eta;

     void updateEta(float dynamicBaseLR) { // note: no other function may write to 'eta' (besides load())
       eta = dynamicBaseLR * factor;
     }

     // State-based multiplication factor for learning rate
     float factor{1.f};

     // @TODO: should also have warmup period here?
     SchedulingParameter warmupStart; // has same unit as lr-warmup

     // Sum of costs since last display
     float costSum{0};
     // Number of labels aggregated in
     // costSum since last display
     float costCount{0};

     // Number of words seen since last display
     size_t wordsDisp{0};
     // Number of samples/sentences seen since last display
     size_t samplesDisp{0};
     // Number of updates seen since last display
     size_t updatesDisp{0};

     // Running average of gradient norm
     float gradientNormAvg{0};
     // Running variance of gradient norm
     float gradientNormVar{0};

     // Running average of log gradient norm
     float logGradientNormAvg{0};
     // Running variance of log gradient norm
     float logGradientNormVar{0};

     // The state of the random number generator from a batch generator
     std::string seedBatch;
     // The state of the random number generator from a corpus
     std::string seedCorpus;

     // Set flag if training was resumed
     bool loaded{false};

     // Set flag if the model was validated in the current batch
     bool validated{false};

     TrainingState(float learnRate) {
       updateEta(learnRate);
     }

     void registerObserver(Ptr<TrainingObserver> observer) {
       observer->init(*this);
       wObservers_.push_back(observer);
     }

     // return the totals count that corresponds to the given unit (batches, labels, or epochs)
     size_t getProgressIn(SchedulingUnit u) const {
       switch(u) {
         case SchedulingUnit::trgLabels: return labelsTotal;
         case SchedulingUnit::updates  : return batches;
         case SchedulingUnit::epochs   : return epochs;
         default: ABORT("corrupt enum value");
       }
     }

     // update() first calls this
     // This is to make sure that enteredNewPeriodOf() can detect a transition into a new period
     void rememberPreviousProgress() {
       prevLabelsTotal = labelsTotal;
       prevBatches = batches;
       prevEpochs = epochs;
     }

     size_t getPreviousProgressIn(SchedulingUnit u) const {
       switch(u) {
         case SchedulingUnit::trgLabels: return prevLabelsTotal;
         case SchedulingUnit::updates  : return prevBatches;
         case SchedulingUnit::epochs   : return prevEpochs;
         default: ABORT("corrupt enum value");
       }
     }

     // Tests whether we entered a new period, e.g. disp-freq, according to the
     // unit in which that parameter is given. There are a few edge cases:
     //  - this function will be called many times within the same epoch
     //  - labelsTotal does not increment by 1, so simple modulus does not work
     //
     // So instead of modulus==0, this function compares the previous progress/period
     // to the current, and triggers if they differ (i.e. the border between two
     // periods was crossed). This requires that rememberPreviousProgress() is called
     // between calls to this. We call it from update(). Unfortunately, newEpoch()
     // is called at the wrong place for this to work, so SchedulingUnit::epoch is forbidden
     // for periods.
     bool enteredNewPeriodOf(std::string schedulingParam) const {
       auto period = SchedulingParameter::parse(schedulingParam);
       // @TODO: adapt to logical epochs
       ABORT_IF(period.unit == SchedulingUnit::epochs,
                "Unit {} is not supported for frequency parameters",
                schedulingParam);
       auto previousProgress = getPreviousProgressIn(period.unit);
       auto progress = getProgressIn(period.unit);
       return period && progress / period.n != previousProgress / period.n;
     }

     void newEpoch() {
       ++epochs;
       for(auto wObserver : wObservers_) {
         auto observer = wObserver.lock();
         ABORT_IF(!observer, "Training observer object expired. Make sure all registered observers exist during scheduler life time");
         observer->actAfterEpoch(*this);
       }
       samplesEpoch = 0;
       batchesEpoch = 0;
     }

     void newUpdate(size_t batchesInUpdate) {
       ++batches;
       batchesEpoch += batchesInUpdate;
       loaded = false;
       validated = false;
       for(auto wObserver : wObservers_) {
         auto observer = wObserver.lock();
         ABORT_IF(!observer, "Training observer object expired. Make sure all registered observers exist during scheduler life time");
         observer->actAfterBatches(*this);
       }
     }

     void newStalled(size_t num) {
       stalled = num;
       if(num > maxStalled)
         ++maxStalled;
       for(auto wObserver : wObservers_) {
         auto observer = wObserver.lock();
         ABORT_IF(!observer, "Training observer object expired. Make sure all registered observers exist during scheduler life time");
         observer->actAfterStalled(*this);
       }
     }

     void newLoad() {
       loaded = true;
       for(auto wObserver : wObservers_) {
         auto observer = wObserver.lock();
         ABORT_IF(!observer, "Training observer object expired. Make sure all registered observers exist during scheduler life time");
         observer->actAfterLoaded(*this);
       }
     }

     void load(const std::string& name) {
       if(!filesystem::exists(name))
         return;

       YAML::Node config = YAML::LoadFile(name);

       epochs       = config["epochs"].as<size_t>();
       batches      = config["batches"].as<size_t>();
       batchesEpoch = config["batches-epoch"].as<size_t>();
       // different serialization name for backward compatibility
       samplesEpoch = config["samples"].as<size_t>();
       // clang-format off
       // optional for backward compatibility
       labelsTotal     = config["labels-total"]      ? config["labels-total"].as<size_t>()      : 0;
       prevLabelsTotal = config["prev-labels-total"] ? config["prev-labels-total"].as<size_t>() : 0;
       prevBatches     = config["prev-batches"]      ? config["prev-batches"].as<size_t>()      : 0;
       prevEpochs      = config["prev-epochs"]       ? config["prev-epochs"].as<size_t>()       : 0;
       // clang-format on

       stalled    = config["stalled"].as<size_t>();
       maxStalled = config["stalled-max"].as<size_t>();
       validator  = config["validator"].as<std::string>();
       validators = config["validators"];
       reset      = config["reset"].as<bool>();

       eta         = config["eta"].as<float>();
       factor      = config["eta-factor"].as<float>();
       warmupStart = SchedulingParameter::parse(config["warmup-start"].as<std::string>());

       costSum   = config["cost-sum"].as<float>();
       costCount = config["cost-count"].as<float>();

       wordsDisp   = config["disp-words"].as<size_t>();
       samplesDisp = config["disp-samples"].as<size_t>();
       updatesDisp = config["disp-updates"].as<size_t>();

       gradientNormAvg = config["gradient-norm-avg"].as<float>();
       gradientNormVar = config["gradient-norm-var"].as<float>();

       logGradientNormAvg = config["log-gradient-norm-avg"].as<float>();
       logGradientNormVar = config["log-gradient-norm-var"].as<float>();

       seedBatch  = config["seed-batch"].as<std::string>();
       seedCorpus = config["seed-corpus"].as<std::string>();
     }

     void save(const std::string& name) const {
       std::ofstream fout(name);
       YAML::Node config;

       config["epochs"]        = epochs;
       config["batches"]       = batches;
       config["batches-epoch"] = batchesEpoch;
       config["samples"]       = samplesEpoch;

       config["labels-total"]      = labelsTotal;
       config["prev-labels-total"] = prevLabelsTotal;
       config["prev-batches"]      = prevBatches;
       config["prev-epochs"]       = prevEpochs;

       config["stalled"]     = stalled;
       config["stalled-max"] = maxStalled;
       config["validator"]   = validator;
       config["validators"]  = validators;
       config["reset"]       = reset;

       config["eta"]          = eta;
       config["eta-factor"]   = factor;
       config["warmup-start"] = std::string(warmupStart);

       config["cost-sum"]   = costSum;
       config["cost-count"] = costCount;

       config["disp-updates"] = updatesDisp;
       config["disp-samples"] = samplesDisp;
       config["disp-words"]   = wordsDisp;

       config["gradient-norm-avg"] = gradientNormAvg;
       config["gradient-norm-var"] = gradientNormVar;

       config["log-gradient-norm-avg"] = logGradientNormAvg;
config["log-gradient-norm-var"] = logGradientNormVar; config["seed-batch"] = seedBatch; config["seed-corpus"] = seedCorpus; fout << config; } std::string fillTemplate(const std::string& templ) const { // The formatting below uses fmtlib, which is included with spdlog // and is included via the logger. return fmt::format(templ.c_str(), fmt::arg("E", epochs), fmt::arg("U", batches), fmt::arg("B", batchesEpoch), fmt::arg("T", labelsTotal)); } private: // this needs to be a vector of weak pointers, otherwise // it is likely to cause circular dependencies. std::vector> wObservers_; }; } // namespace marian