Program Listing for File training.h¶
↰ Return to documentation for file (src/examples/mnist/training.h
)
#pragma once
#include "common/options.h"
#include "models/model_task.h"
#include "training/scheduler.h"
#include "examples/mnist/dataset.h"
#include "examples/mnist/validator.h"
namespace marian {
template <class ModelWrapper>
class TrainMNIST : public ModelTask {
private:
Ptr<Options> options_;
public:
TrainMNIST(Ptr<Options> options) : options_(options) {}
void run() override {
using namespace data;
// Prepare data set
auto paths = options_->get<std::vector<std::string>>("train-sets");
auto dataset = New<data::MNISTData>(paths);
auto batchGenerator = New<BatchGenerator<data::MNISTData>>(dataset, options_, nullptr);
// Prepare scheduler with validators
auto trainState = New<TrainingState>(options_->get<float>("learn-rate"));
auto scheduler = New<Scheduler>(options_, trainState, nullptr);
scheduler->addValidator(New<MNISTAccuracyValidator>(options_));
// Multi-node training
auto mpi = initMPI(/*multiThreaded=*/false);
// Prepare model
auto model = New<ModelWrapper>(options_, mpi);
model->setScheduler(scheduler);
model->load();
// Run training
while(scheduler->keepGoing()) {
batchGenerator->prepare();
for(auto batch : *batchGenerator) {
if(!scheduler->keepGoing())
break;
model->update(batch);
}
if(scheduler->keepGoing())
scheduler->increaseEpoch();
}
scheduler->finished();
model = nullptr;
finalizeMPI(std::move(mpi));
}
};
} // namespace marian