.. _program_listing_file_src_examples_mnist_training.h:

Program Listing for File training.h
===================================

|exhale_lsh| :ref:`Return to documentation for file <file_src_examples_mnist_training.h>` (``src/examples/mnist/training.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   #pragma once

   #include "common/options.h"
   #include "models/model_task.h"
   #include "training/scheduler.h"

   #include "examples/mnist/dataset.h"
   #include "examples/mnist/validator.h"

   namespace marian {

   template <class ModelWrapper>
   class TrainMNIST : public ModelTask {
   private:
     Ptr<Options> options_;

   public:
     TrainMNIST(Ptr<Options> options) : options_(options) {}

     void run() override {
       using namespace data;

       // Prepare data set
       auto paths = options_->get<std::vector<std::string>>("train-sets");
       auto dataset = New<MNISTData>(paths);
       auto batchGenerator = New<BatchGenerator<MNISTData>>(dataset, options_, nullptr);

       // Prepare scheduler with validators
       auto trainState = New<TrainingState>(options_->get<float>("learn-rate"));
       auto scheduler = New<Scheduler>(options_, trainState, nullptr);
       scheduler->addValidator(New<MNISTAccuracyValidator>(options_));

       // Multi-node training
       auto mpi = initMPI(/*multiThreaded=*/false);

       // Prepare model
       auto model = New<ModelWrapper>(options_, mpi);
       model->setScheduler(scheduler);
       model->load();

       // Run training
       while(scheduler->keepGoing()) {
         batchGenerator->prepare();
         for(auto batch : *batchGenerator) {
           if(!scheduler->keepGoing())
             break;
           model->update(batch);
         }
         if(scheduler->keepGoing())
           scheduler->increaseEpoch();
       }
       scheduler->finished();

       model = nullptr;
       finalizeMPI(std::move(mpi));
     }
   };
   } // namespace marian