Class TrainingState

Class Documentation

class TrainingState

Public Functions

void updateEta(float dynamicBaseLR)
TrainingState(float learnRate)
void registerObserver(Ptr<TrainingObserver> observer)
size_t getProgressIn(SchedulingUnit u) const
void rememberPreviousProgress()
size_t getPreviousProgressIn(SchedulingUnit u) const
bool enteredNewPeriodOf(std::string schedulingParam) const
void newEpoch()
void newUpdate(size_t batchesInUpdate)
void newStalled(size_t num)
void newLoad()
void load(const std::string &name)
void save(const std::string &name) const
std::string fillTemplate(const std::string &templ) const

Public Members

size_t epochs = {1}
size_t batches = {0}
size_t batchesEpoch = {0}
size_t samplesEpoch = {0}
size_t labelsTotal = {0}
size_t prevLabelsTotal = {0}
size_t prevBatches = {0}
size_t prevEpochs = {0}
size_t stalled = {0}
size_t maxStalled = {0}
std::string validator
YAML::Node validators
bool reset = {false}
float eta
float factor = {1.f}
SchedulingParameter warmupStart
float costSum = {0}
float costCount = {0}
size_t wordsDisp = {0}
size_t samplesDisp = {0}
size_t updatesDisp = {0}
float gradientNormAvg = {0}
float gradientNormVar = {0}
float logGradientNormAvg = {0}
float logGradientNormVar = {0}
std::string seedBatch
std::string seedCorpus
bool loaded = {false}
bool validated = {false}