Program Listing for File scheduler.h
↰ Return to documentation for file (src/training/scheduler.h)
#pragma once
#include "common/options.h"
#include "common/signal_handling.h"
#include "training/training_state.h"
#include "training/validator.h"
#include "training/communicator.h"
#include "layers/loss.h"
namespace marian {
class Scheduler : public TrainingObserver {
private:
Ptr<Options> options_;
Ptr<TrainingState> state_;
std::vector<Ptr<ValidatorBase>> validators_;
Ptr<IMPIWrapper> mpi_;
bool first_{true}; // true if this is the first update after starting or resuming training
size_t gradientNormAvgWindow_{100}; // window size for the exponential moving average of gradient norms; after this many updates, roughly 90% of the average's mass comes from the most recent gradientNormAvgWindow_ updates
SchedulingParameter logicalEpoch_;
size_t logicalEpochWidth_{0};
timer::Timer timer_;
timer::Timer heartBeatTimer_;
// This flag tracks the end of the current epoch
// (regardless of whether it is the first or the n-th epoch, and whether training is new or resumed),
// which signals the end of the training data stream when reading from STDIN
bool endOfStdin_{false}; // true at the end of an epoch when training from STDIN
// @TODO: figure out how to compute this with regard to updates as well, although maybe harder since no final value
// determine scheduled LR decay factor (--lr-decay-inv-sqrt option)
float getScheduledLRDecayFactor(const TrainingState& state) const {
auto args = options_->get<std::vector<std::string>>("lr-decay-inv-sqrt");
ABORT_IF(args.empty() || args.size() > 2, "--lr-decay-inv-sqrt argument must be one or two numbers with units");
auto decayGoogle = SchedulingParameter::parse(args[0]);
size_t progress = state.getProgressIn(decayGoogle.unit);
size_t start = decayGoogle.n;
if (args.size() > 1) {
auto decayStart = SchedulingParameter::parse(args[1]);
ABORT_IF(decayStart && decayStart.unit != decayGoogle.unit,
"both --lr-decay-inv-sqrt arguments must have the same unit");
start = decayStart.n;
}
if (decayGoogle && progress > start) {
progress = progress - start + decayGoogle.n; // shift so that we get 1 at progress==start
return (float)(std::sqrt((double)decayGoogle.n / (double)progress));
}
else
return 1.f;
}
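// Worked example (illustrative values, not prescriptive): with --lr-decay-inv-sqrt "16000u"
// the factor stays at 1.0 for the first 16,000 updates and then follows sqrt(16000/progress),
// e.g. ~0.71 at 32,000 and 0.5 at 64,000 updates; with a second argument such as "16000u 32000u"
// the decay only kicks in after 32,000 updates, starting again from a factor of 1.0 there.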
void updateLearningRate(TrainingState& state) const {
float baselr = options_->get<float>("learn-rate");
// warm-up factor
float warmupFactor = 1.f;
auto warmupParam = SchedulingParameter::parse(options_->get<std::string>("lr-warmup"));
if(warmupParam) {
ABORT_IF(state.warmupStart && state.warmupStart.unit != warmupParam.unit,
"lr-warmup and warmup-start must have the same unit");
auto bno = state.getProgressIn(warmupParam.unit) - state.warmupStart.n;
warmupFactor = std::min(1.f, (float)bno / (float)warmupParam.n);
}
// TODO: why is lr-warmup-start-rate extracted from options_ instead of using state.warmupStart?
float lrStart = options_->get<float>("lr-warmup-start-rate");
baselr = lrStart + (baselr - lrStart) * warmupFactor; // linear interpolation from lr-warmup-start-rate to learn-rate
// schedule-based decay factor (--lr-decay-inv-sqrt)
float scheduledDecayFactor = getScheduledLRDecayFactor(state);
baselr = baselr * scheduledDecayFactor;
// factor in state-based decay and set final LR as state.eta
state.updateEta(baselr);
}
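// Worked example (illustrative values): with --learn-rate 0.0003, --lr-warmup "16000u" and
// --lr-warmup-start-rate 0, update 4,000 gives warmupFactor = 4000/16000 = 0.25 and hence
// baselr = 0 + (0.0003 - 0) * 0.25 = 7.5e-5, which is then multiplied by the scheduled decay
// factor above and by the state-based factor inside state.updateEta().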
std::string formatLoss(std::string lossType,
bool dispLabelCounts,
size_t batchLabels,
Ptr<TrainingState> state) {
std::stringstream ss;
ss << "Cost ";
ss << std::setprecision(8) << std::fixed;
// @TODO: put a single loss formatting function into loss.h and reuse here to avoid code duplication
// @TODO: use dispLabelCounts with any display type?
// @TODO: bugbug cost-type ce-mean-words with multi-loss-type mean divides too much in display
if(lossType == "ce-mean-words") {
ss << state->costSum / state->costCount;
} else if(lossType == "ce-sum" && dispLabelCounts) {
ss << state->costSum / state->costCount
<< " * " << utils::withCommas((size_t)state->costCount);
if(batchLabels > 0)
ss << " @ " << utils::withCommas(batchLabels);
ss << " after " << utils::withCommas(state->labelsTotal);
} else if(lossType == "ce-sum" && !dispLabelCounts) {
ss << state->costSum / state->updatesDisp; // average over batches
} else if(lossType == "perplexity") {
ss << std::exp(state->costSum / state->costCount);
} else if(lossType == "cross-entropy" || lossType == "ce-mean") { // backwards-compat, @TODO: get rid of this?
ss << state->costSum / state->samplesDisp;
} else {
ABORT("Unknown loss type {}", lossType);
}
return ss.str();
}
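// Output sketch (hypothetical numbers): for cost-type "ce-sum" with disp-label-counts the string
// reads e.g. "Cost 2.34567890 * 1,234,567 @ 2,345 after 98,765,432", i.e. per-label cost times
// the label count since the last display, the current batch's label count, and the total labels
// seen so far; for "perplexity" it is simply exp(costSum / costCount).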
// Here we calculate the logical epoch as defined by the user; by default this is just a traditional data epoch.
// We understand a data epoch as a complete pass through the training data, as far as that information is available.
// By contrast, a logical epoch is defined somewhat independently of the number of data passes, by the number of seen updates or labels
// or as a multiple of data epochs.
float calculateLogicalEpoch() {
if(logicalEpoch_.unit == SchedulingUnit::epochs)
return (float)state_->epochs / (float)logicalEpoch_.n; // logical epoch as multiple of n data epochs
else if(logicalEpoch_.unit == SchedulingUnit::trgLabels)
return (float)state_->labelsTotal / (float)logicalEpoch_.n; // logical epoch as multiple of n labels
else if(logicalEpoch_.unit == SchedulingUnit::updates)
return (float)state_->batches / (float)logicalEpoch_.n; // logical epoch as multiple of n gradient updates (not actually batches @TODO: change name)
else
ABORT("Unknown scheduling unit occurred in logical epoch"); // shouldn't really happen unless we add a new unit in the corresponding enum
}
// Formatting for logical epochs
std::string formatLogicalEpoch() {
return fmt::format("{:." + std::to_string(logicalEpochWidth_) + "f}", calculateLogicalEpoch());
}
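// Worked example (illustrative): with --logical-epoch "1Gt" a logical epoch is one billion target
// labels, so after 2.5 billion labels calculateLogicalEpoch() returns 2.5 and, because no display
// width was given and the unit is not epochs (see the constructor below), formatLogicalEpoch()
// prints "2.500".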
public:
Scheduler(Ptr<Options> options, Ptr<TrainingState> state, Ptr<IMPIWrapper> mpi = nullptr)
: options_(options), state_(state), mpi_(mpi),
gradientNormAvgWindow_(options_->get<size_t>("gradient-norm-average-window", 100)) {
// parse logical-epoch parameters
auto logicalEpochStr = options->get<std::vector<std::string>>("logical-epoch", {"1e", "0"});
ABORT_IF(logicalEpochStr.empty(), "Logical epoch information is missing?");
logicalEpoch_ = SchedulingParameter::parse(logicalEpochStr[0]);
// here we deduce the floating point width to be used in formatLogicalEpoch()
if(logicalEpochStr.size() > 1) { // if the width is given, just use that
logicalEpochWidth_ = std::stoul(logicalEpochStr[1]);
} else { // the width is not given so we deduce a suitable display width
if(logicalEpoch_.unit == SchedulingUnit::epochs && logicalEpoch_.n == 1)
logicalEpochWidth_ = 0; // for a data epoch, output is an integer and looks like before this feature was introduced
else
logicalEpochWidth_ = 3; // all other outputs can be fractional, hence floating point format. We choose
// 3 as a default which corresponds to a multiplier of 1000 (3 orders of magnitude).
}
ABORT_IF(state_->factor != 1, "state.factor unexpectedly not 1 at this point??");
updateLearningRate(*state);
}
// test if any parameters specify dynamic MB scaling
bool isDynamicMBSizeScaling() const {
auto mbWarmup = SchedulingParameter::parse(options_->get<std::string>("mini-batch-warmup"));
auto mbTracking = options_->get<bool>("mini-batch-track-lr");
return mbWarmup || mbTracking;
}
// determine dynamic MB scaling factor
double getDynamicMBSizeMultiplier() const {
double ratio = 1.0;
auto mbWarmup = SchedulingParameter::parse(options_->get<std::string>("mini-batch-warmup"));
if (mbWarmup) {
// mini-batch-warmup
LOG_ONCE(info, "[scheduler] Mini-batch size warmup {}", std::string(mbWarmup));
// This ramps up MB size at start, relative to progress within warm-up period.
size_t progress = state_->getProgressIn(mbWarmup.unit); // number of updates/labels processed
auto progressRatio = (double)progress / (double)mbWarmup.n; // where are we relatively within target warm-up period
// if unit is labels, then account for the fact that our increment itself is not constant
#if 1 // this seems to hurt convergence quite a bit compared to when updates is used
if (mbWarmup.unit == SchedulingUnit::trgLabels)
progressRatio = std::sqrt(progressRatio);
#endif
if (progressRatio < 1)
ratio *= progressRatio;
}
// dynamic MB-size tracking with learning rate
// As LR goes down, MB gets ramped up by the same ratio, which has been found to be safe.
auto mbTracking = options_->get<bool>("mini-batch-track-lr");
if (mbTracking) {
ABORT("Please review this code");
auto lrFactor = getScheduledLRDecayFactor(*state_) * state_->factor; // (don't include lr-warmup)
if (lrFactor != 1)
LOG_ONCE(info, "[scheduler] Dynamic mini-batch size adjustment enabled and kicking in");
ratio /= lrFactor;
}
return ratio;
}
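// Worked example (illustrative values): with --mini-batch-warmup "4000u" the multiplier is
// progress/4000, e.g. 0.25 after 1,000 updates and 1.0 from update 4,000 on; for a label-based
// unit the sqrt above turns a progress ratio of 0.25 into 0.5, compensating for the fact that
// the per-update label increment itself grows during the warm-up.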
std::tuple<size_t, float, float> getGradientNormStats() const {
return std::make_tuple(gradientNormAvgWindow_, state_->gradientNormAvg, state_->gradientNormVar);
}
std::tuple<size_t, float, float> getLogGradientNormStats() const {
return std::make_tuple(gradientNormAvgWindow_, state_->logGradientNormAvg, state_->logGradientNormVar);
}
bool keepGoing() {
if(saveAndExitRequested()) // via SIGTERM
return false;
#if 1 // @TODO: to be removed once we deprecate after-epochs and after-batches
// stop if it reached the maximum number of epochs
size_t stopAfterEpochs = options_->get<size_t>("after-epochs");
if(stopAfterEpochs > 0 && calculateLogicalEpoch() > stopAfterEpochs)
return false;
// stop if it reached the maximum number of batch updates
size_t stopAfterBatches = options_->get<size_t>("after-batches");
if(stopAfterBatches > 0 && state_->batches >= stopAfterBatches)
return false;
#endif
// get list of stopping criteria e.g. "10e,300Ku,20Gt" (10 epochs, 300,000 updates, 20 billion target labels)
// and stop for whatever criterion hits first.
std::vector<std::string> stoppingCriteria = utils::split(options_->get<std::string>("after"), ",");
for(auto stoppingCriterionString : stoppingCriteria) {
SchedulingParameter stoppingCriterion = SchedulingParameter::parse(stoppingCriterionString);
if(stoppingCriterion.n > 0) { // is any stopping criterion defined?
if(stoppingCriterion.unit == SchedulingUnit::epochs && calculateLogicalEpoch() > stoppingCriterion.n) return false;
if(stoppingCriterion.unit == SchedulingUnit::updates && state_->batches >= stoppingCriterion.n) return false;
if(stoppingCriterion.unit == SchedulingUnit::trgLabels && state_->labelsTotal >= stoppingCriterion.n) return false;
}
}
// stop if the first/all/any validators did not improve for a given number of checks
size_t stopAfterStalled = options_->get<size_t>("early-stopping");
if(stopAfterStalled > 0 && stalled() >= stopAfterStalled)
return false;
// stop if data streaming from STDIN is stopped
if(endOfStdin_)
return false;
return true;
}
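// Reading example: --after "300Ku" stops once state_->batches reaches 300,000, while
// --after "10e,300Ku,20Gt" stops at whichever of 10 logical epochs, 300,000 updates or
// 20 billion target labels is reached first; the deprecated after-epochs/after-batches
// options above each act as a single such criterion.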
void increaseEpoch() {
LOG(info, "Seen {} samples", utils::withCommas(state_->samplesEpoch));
state_->newEpoch();
if(std::to_string(logicalEpoch_) == "1e")
LOG(info, "Starting epoch {}", state_->epochs);
else
LOG(info, "Starting data epoch {} in logical epoch {}", state_->epochs, formatLogicalEpoch());
}
void started() { LOG(info, "Training started"); }
void finished() {
if (saveAndExitRequested())
LOG(info, "Training interrupted (via signal).");
else
LOG(info, "Training finished");
}
void addValidator(Ptr<ValidatorBase> validator) {
validators_.push_back(validator);
registerTrainingObserver(validators_.back());
if(!state_->loaded) {
state_->validators[validator->type()]["last-best"] = validator->initScore();
state_->validators[validator->type()]["stalled"] = 0;
}
if(validators_.size() == 1)
state_->validator = validator->type();
}
bool validating() {
return (!validators_.empty()
&& state_->enteredNewPeriodOf(options_->get<std::string>("valid-freq"))
&& keepGoing());
}
bool saving() {
return state_->enteredNewPeriodOf(options_->get<std::string>("save-freq"));
}
bool syncing() {
return state_->enteredNewPeriodOf(options_->get<std::string>("sync-freq", "0"));
}
void validate(const std::vector<Ptr<ExpressionGraph>>& graphs,
bool isFinal = false) {
// Do not validate if already validated (for instance, after the model is loaded)
// or if validation is scheduled for another update, or when a graceful shutdown
// was requested.
if(saveAndExitRequested()
|| state_->validated // already validated (in resumed training, for example)
|| (!state_->enteredNewPeriodOf(options_->get<std::string>("valid-freq")) && !isFinal)) // not now
return;
size_t stalledPrev = stalled();
for(auto validator : validators_) {
if(!validator)
continue;
float value = 0;
if(!mpi_ || mpi_->isMainProcess()) {
// We run validation only in the main process, but this is risky with MPI:
// validators might modify the random state etc. Maybe we should run validators
// everywhere, but neither report nor save on the other processes.
value = validator->validate(graphs, state_);
if(validator->stalled() > 0) {
LOG_VALID(info,
"Ep. {} : Up. {} : {} : {} : stalled {} times (last best: {})",
formatLogicalEpoch(),
state_->batches,
validator->type(),
value,
validator->stalled(), validator->lastBest());
} else {
LOG_VALID(info,
"Ep. {} : Up. {} : {} : {} : new best",
formatLogicalEpoch(),
state_->batches,
validator->type(),
value);
}
}
if(mpi_) {
// collect and broadcast validation result to all processes and bring validator up-to-date
mpi_->bCast(&value, 1, IMPIWrapper::getDataType(&value));
// @TODO: add function to validator?
mpi_->bCast(&validator->stalled(), 1, IMPIWrapper::getDataType(&validator->stalled()));
mpi_->bCast(&validator->lastBest(), 1, IMPIWrapper::getDataType(&validator->lastBest()));
}
state_->validators[validator->type()]["last-best"] = validator->lastBest();
state_->validators[validator->type()]["stalled"] = validator->stalled();
}
// notify training observers about stalled validation
size_t stalledNew = stalled();
if(stalledNew > stalledPrev)
state_->newStalled(stalledNew);
state_->validated = true;
}
// Returns the proper number of stalled validations w.r.t. early-stopping-on
size_t stalled() {
std::string stopOn = options_->get<std::string>("early-stopping-on");
if(stopOn == "any")
return stalledMax();
if(stopOn == "all")
return stalledMin();
return stalled1st();
}
// Returns the number of stalled validations for the first validator
size_t stalled1st() {
if(!validators_.empty())
if(validators_[0])
return validators_[0]->stalled();
return 0;
}
// Returns the largest number of stalled validations across validators or 0 if there are no validators
size_t stalledMax() {
size_t max = 0;
for(auto validator : validators_)
if(validator && validator->stalled() > max)
max = validator->stalled();
return max;
}
// Returns the lowest number of stalled validations across validators or 0 if there are no validators
size_t stalledMin() {
size_t min = std::numeric_limits<std::size_t>::max();
for(auto validator : validators_)
if(validator && validator->stalled() < min)
min = validator->stalled();
return min == std::numeric_limits<std::size_t>::max() ? 0 : min;
}
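// Example (hypothetical validators): with two validators that have stalled 3 and 5 times,
// --early-stopping-on "any" yields 5 (stalledMax), "all" yields 3 (stalledMin), and the
// default of stopping on the first validator yields that validator's own count via stalled1st().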
void update(StaticLoss rationalLoss, Ptr<data::Batch> batch) {
update(rationalLoss, /*numReadBatches=*/1, /*batchSize=*/batch->size(), /*batchLabels=*/batch->wordsTrg(), /*gradientNorm=*/0.f);
}
// @TODO: go back to a function which takes the batch as an argument? The current arguments make it hard
// to choose which subbatch should be used for the speed display. For sequence classifiers it is more interesting
// to see the source words consumed rather than the labels.
void update(StaticLoss rationalLoss,
size_t numReadBatches, // number of batches read by the reader (for seeking in case of restart)
size_t batchSize, // total number of sentences in batch
size_t batchLabels, // total number of target words in batch
float gradientNorm) { // gradientNorm of update
state_->rememberPreviousProgress(); // note: epoch increases happen at the wrong place, hence
// -freq parameters do not support epoch units
state_->validated = false;
// Since batchLabels is counted across all MPI processes, we also should temporarily
// extrapolate cost across MPI processes, to have numbers in the right range.
// When doing the actual log, we then aggregate across MPI processes to get the accurate number.
if(mpi_) {
rationalLoss.loss *= mpi_->numMPIProcesses();
rationalLoss.count *= mpi_->numMPIProcesses();
}
// @BUGBUG: rationalLoss.count is float, not a count. Possible solution: make (costSum, costCount) a StaticLoss object as well
state_->costSum += rationalLoss.loss; // aggregate sum cost since last display
state_->costCount += rationalLoss.count; // cost gets normalized w.r.t. this in display
state_->updatesDisp += 1;
state_->samplesDisp += batchSize;
state_->wordsDisp += batchLabels; // words at given input processed since last display, for speed display
state_->samplesEpoch += batchSize; // sentences processed in this epoch
state_->labelsTotal += batchLabels; // total labels processed
state_->newUpdate(numReadBatches);
if(gradientNorm) {
size_t range = std::min(gradientNormAvgWindow_, state_->batches);
float alpha = 2.f / (float)(range + 1);
float delta = gradientNorm - state_->gradientNormAvg;
state_->gradientNormAvg = state_->gradientNormAvg + alpha * delta;
state_->gradientNormVar = (1.0f - alpha) * (state_->gradientNormVar + alpha * delta * delta);
float logDelta = std::log(gradientNorm) - state_->logGradientNormAvg;
state_->logGradientNormAvg = state_->logGradientNormAvg + alpha * logDelta;
state_->logGradientNormVar = (1.0f - alpha) * (state_->logGradientNormVar + alpha * logDelta * logDelta);
}
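// Note on the smoothing factor: alpha = 2 / (range + 1) is the usual EMA coefficient, so once
// range reaches gradientNormAvgWindow_ = 100 roughly 1 - (1 - alpha)^100 (about 0.87) of the
// average's weight comes from the last 100 updates, in line with the comment on
// gradientNormAvgWindow_ above; the variance and the log-domain statistics are tracked analogously.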
// reconstruct sum cost, for displaying epoch-level averages instead of minibatch-level
auto lossType = options_->get<std::string>("cost-type");
auto dispLabelCounts = options_->get<bool>("disp-label-counts"); // if true then show as "cost per label * number of labels"
if(state_->enteredNewPeriodOf(options_->get<std::string>("disp-freq")) || state_->batches <= options_->get<size_t>("disp-first")) {
// if MPI then aggregate precise cost across workers
if(mpi_) {
state_->costSum /= mpi_->numMPIProcesses(); // undo the extra scaling
state_->costCount /= mpi_->numMPIProcesses(); // undo the extra scaling
mpi_->allReduce(&state_->costSum, &state_->costSum, 1, MPI_FLOAT, MPI_SUM);
mpi_->allReduce(&state_->costCount, &state_->costCount, 1, MPI_FLOAT, MPI_SUM);
}
if(mpi_ && mpi_->myMPIRank() != 0) {
// skip the report on alternate worker processes
} else if(options_->get<bool>("lr-report")) {
LOG(info,
"Ep. {} : Up. {} : Sen. {} : {} : Time {:.2f}s : {:.2f} words/s : gNorm {:.4f} : L.r. {:.4e}",
formatLogicalEpoch(),
state_->batches,
utils::withCommas(state_->samplesEpoch),
formatLoss(lossType, dispLabelCounts, batchLabels, state_),
timer_.elapsed(),
state_->wordsDisp / timer_.elapsed(),
state_->gradientNormAvg,
state_->eta);
} else {
LOG(info,
"Ep. {} : Up. {} : Sen. {} : {} : Time {:.2f}s : {:.2f} words/s : gNorm {:.4f}",
formatLogicalEpoch(),
state_->batches,
utils::withCommas(state_->samplesEpoch),
formatLoss(lossType, dispLabelCounts, batchLabels, state_),
timer_.elapsed(),
state_->wordsDisp / timer_.elapsed(),
state_->gradientNormAvg);
}
timer_.start();
state_->costSum = 0;
state_->costCount = 0;
state_->updatesDisp = 0;
state_->samplesDisp = 0;
state_->wordsDisp = 0;
}
// progress heartbeat for MS-internal Philly compute cluster
// This environment variable exists when running on the cluster.
using namespace std::chrono;
if((!mpi_ || mpi_->myMPIRank() == 0) && getenv("PHILLY_JOB_ID")
&& heartBeatTimer_.elapsed<std::chrono::minutes>() >= 30) {
fprintf(stderr, "PROGRESS: %.2f%%\nEVALERR: %.7f%%\n",
(double)calculateLogicalEpoch(),
state_->costSum / (state_->costCount ? state_->costCount : 1));
fflush(stderr);
heartBeatTimer_.start();
}
}
void load(const std::string& name) {
std::string nameYaml = name + ".progress.yml";
if(filesystem::exists(nameYaml))
state_->load(nameYaml);
if(options_->get<bool>("no-restore-corpus")) {
state_->samplesEpoch = 0;
state_->costSum = 0;
state_->costCount = 0;
state_->updatesDisp = 0;
state_->samplesDisp = 0;
state_->wordsDisp = 0;
}
if(options_->get<bool>("valid-reset-stalled")) {
state_->stalled = 0;
state_->maxStalled = 0;
for(const auto& validator : validators_) {
if(state_->validators[validator->type()])
state_->validators[validator->type()]["stalled"] = 0;
}
}
state_->newLoad();
}
void save(const std::string& name) {
// Save config options
std::ofstream fout(name + ".yml");
fout << options_->asYamlString();
// Save training progress
state_->save(name + ".progress.yml");
}
size_t numberOfBatches() { return state_->batches; }
void registerTrainingObserver(Ptr<TrainingObserver> observer) {
state_->registerObserver(observer);
}
void actAfterEpoch(TrainingState& state) override {
// stop if data streaming from STDIN is stopped for a TSV input
std::string firstPath = options_->get<std::vector<std::string>>("train-sets")[0];
if(options_->get<bool>("tsv", false) && (firstPath == "stdin" || firstPath == "-"))
endOfStdin_ = true;
float factor = options_->get<float>("lr-decay");
updateLearningRate(state);
if(factor > 0.0) {
bool decay = false;
auto strategy = options_->get<std::string>("lr-decay-strategy");
state.reset = false;
if(strategy == "epoch" || strategy == "epoch+batches"
|| strategy == "epoch+stalled") {
size_t startEpoch
= options_->get<std::vector<size_t>>("lr-decay-start").front();
if(startEpoch && state.epochs >= startEpoch)
decay = true;
}
if(strategy == "epoch+batches") {
size_t startBatches
= options_->get<std::vector<size_t>>("lr-decay-start")[1];
if(startBatches && state.batches >= startBatches)
decay = true;
}
if(strategy == "epoch+stalled") {
size_t startStalled
= options_->get<std::vector<size_t>>("lr-decay-start")[1];
if(startStalled && state.maxStalled >= startStalled)
decay = true;
}
if(decay) {
state.factor *= factor;
updateLearningRate(state);
LOG(info, "Decaying learning rate to {} in epoch {}", state.eta, state.epochs);
state.reset = options_->get<bool>("lr-decay-reset-optimizer");
if(state.reset)
LOG(info, "Resetting optimizer statistics");
if(options_->get<bool>("lr-decay-repeat-warmup")) {
LOG(info, "Restarting learning rate warmup");
state.warmupStart.n = state.getProgressIn(SchedulingParameter::parse(options_->get<std::string>("lr-warmup")).unit);
}
}
}
}
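// Reading example for the decay strategies handled above: with --lr-decay 0.5,
// --lr-decay-strategy epoch+stalled and --lr-decay-start 10 2, state.factor is halved once the
// data epoch reaches 10 or maxStalled reaches 2 (either condition sets decay = true), optionally
// resetting optimizer statistics and restarting the LR warmup if the corresponding flags are set.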
void actAfterBatches(TrainingState& state) override {
float factor = options_->get<float>("lr-decay");
state.reset = false;
updateLearningRate(state);
if(factor > 0.0) {
if(options_->get<std::string>("lr-decay-strategy") == "batches") {
size_t start = options_->get<std::vector<size_t>>("lr-decay-start").front();
size_t freq = options_->get<size_t>("lr-decay-freq"); // note: unlike e.g. disp-freq, this is always in batches
if(start > 0 && freq > 0 && state.batches >= start
&& ((state.batches - start) % freq == 0)) {
state.factor *= factor;
updateLearningRate(state);
LOG(info, "Decaying learning rate to {} after {} batches", state.eta, state.batches);
state.reset = options_->get<bool>("lr-decay-reset-optimizer");
if(state.reset)
LOG(info, "Resetting optimizer statistics");
if(options_->get<bool>("lr-decay-repeat-warmup")) {
LOG(info, "Restarting learning rate warmup");
// TODO: avoid repeating this many times and minimize calls to options_->get
state.warmupStart.n = state.getProgressIn(SchedulingParameter::parse(options_->get<std::string>("lr-warmup")).unit);
}
}
}
}
if(first_ && options_->get<bool>("lr-warmup-at-reload")) {
LOG(info, "Restarting learning rate warmup");
state.warmupStart.n = state.getProgressIn(SchedulingParameter::parse(options_->get<std::string>("lr-warmup")).unit);
}
if(options_->get<bool>("lr-warmup-cycle")) {
if(state_->enteredNewPeriodOf(options_->get<std::string>("lr-warmup")))
state.warmupStart.n = state.getProgressIn(SchedulingParameter::parse(options_->get<std::string>("lr-warmup")).unit);
}
first_ = false;
}
void actAfterStalled(TrainingState& state) override {
float factor = options_->get<float>("lr-decay");
state.reset = false;
updateLearningRate(state);
if(factor > 0.0) {
if(options_->get<std::string>("lr-decay-strategy") == "stalled") {
size_t startStalled = options_->get<std::vector<size_t>>("lr-decay-start").front();
if(startStalled && state.stalled && state.stalled % startStalled == 0) {
state.factor *= factor;
updateLearningRate(state);
LOG(info,
"Decaying learning rate to {} after having stalled {} time(s)",
state.eta,
state.stalled);
state.reset = options_->get<bool>("lr-decay-reset-optimizer");
if(state.reset)
LOG(info, "Resetting optimizer statistics");
if(options_->get<bool>("lr-decay-repeat-warmup")) {
LOG(info, "Restarting learning rate warmup");
state.warmupStart.n = state.getProgressIn(
SchedulingParameter::parse(options_->get<std::string>("lr-warmup")).unit);
}
}
}
}
}
};
} // namespace marian