Program Listing for File training.h

Return to documentation for file (src/training/training.h)

#pragma once

#include "common/config.h"
#include "common/utils.h"
#include "data/batch_generator.h"
#ifndef _MSC_VER // @TODO: include SqLite in Visual Studio project
#include "data/corpus_sqlite.h"
#endif
#include "models/model_task.h"
#include "training/scheduler.h"
#include "training/validator.h"

namespace marian {

template <class ModelWrapper>
class Train : public ModelTask {
private:
  Ptr<Options> options_;
  void installCustomSignalHandlers();

public:
  Train(Ptr<Options> options) : options_(options) {}

  void run() override {
    using namespace data;

    // MPI init should be first thing in training
    auto mpi = initMPI(/*multiThreaded=*/!options_->get<bool>("sync-sgd")); // @TODO: do we need the multiThreaded distinction at all?

    if(mpi) { // if we run MPI, then make sure to sync seed across processes as first action
      mpi->bCast(&Config::seed, 1, IMPIWrapper::getDataType(&Config::seed));
      LOG(info, "Synced seed {}", Config::seed);
    }

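    // Build the training corpus: SQLite-backed if --sqlite is given (not
    // available in Windows builds), a plain text Corpus otherwise.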
    Ptr<CorpusBase> dataset;
    auto corpusSeed = Config::seed + (mpi ? mpi->myMPIRank() : 0); // @BUGBUG: no correct resume right now
    if(!options_->get<std::string>("sqlite").empty())
#ifndef _MSC_VER // @TODO: include SqLite in Visual Studio project
      dataset = New<CorpusSQLite>(options_, /*translate=*/false, corpusSeed);
#else
      ABORT("SqLite presently not supported on Windows");
#endif
    else
      dataset = New<Corpus>(options_, /*translate=*/false, corpusSeed);

    dataset->prepare();

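    // With --mini-batch-fit, collect batch statistics up front so that
    // mini-batch sizes can later be chosen automatically to fill the
    // available workspace memory.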
    Ptr<BatchStats> stats;
    if(options_->get<bool>("mini-batch-fit")) {
      LOG(info,
          "[batching] Collecting statistics for batch fitting with step size {}",
          options_->get<size_t>("mini-batch-fit-step"));
      // @TODO this should receive a function object that can generate a fake batch;
      // that way vocabs would not be exposed.
      auto model = New<ModelWrapper>(options_, mpi);

      // use temporary scheduler to make sure everything gets destroyed properly
      // otherwise the scheduler believes that registered objects still exist
      auto tempTrainState = New<TrainingState>(options_->get<float>("learn-rate"));
      auto tempScheduler = New<Scheduler>(options_, tempTrainState, mpi);

      model->setScheduler(tempScheduler); // collectStats() needs to know about dynamic MB scaling
      stats = model->collectStats(dataset->getVocabs());
      LOG(info, "[batching] Done. Typical MB size is {} target words", utils::withCommas(stats->estimateTypicalTrgWords()));
    }

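    // Training state and scheduler for the actual run; the scheduler tracks
    // epochs, updates and the learning rate, and decides when to validate
    // and when to stop training.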
    auto trainState = New<TrainingState>(options_->get<float>("learn-rate"));
    auto scheduler = New<Scheduler>(options_, trainState, mpi);

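    // Add validators only if validation sets or a validation script are
    // configured and a non-zero validation frequency is given.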
    if((options_->hasAndNotEmpty("valid-sets") || options_->hasAndNotEmpty("valid-script-path"))
       && SchedulingParameter::parse(options_->get<std::string>("valid-freq"))) {
      for(auto validator : Validators(dataset->getVocabs(), options_))
        scheduler->addValidator(validator);
    }

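    // The batch generator assembles mini-batches from the corpus, using the
    // statistics collected above (if any) to size them.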
    auto batchGenerator = New<CorpusBatchGenerator>(dataset, options_, stats);

    scheduler->registerTrainingObserver(batchGenerator);

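    // Create the model wrapper used for training, hand it the scheduler,
    // and load existing model parameters if available.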
    auto model = New<ModelWrapper>(options_, mpi);
    model->setScheduler(scheduler);
    model->setTypicalTrgBatchWords(batchGenerator->estimateTypicalTrgBatchWords()); // needed for dynamic MB scaling
    model->load();

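    // Unless --no-restore-corpus is given, restore the corpus position from
    // the saved training state so a resumed run continues where it left off.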
    bool restored = !options_->get<bool>("no-restore-corpus")
                    && batchGenerator->restore(trainState);

    // We only want custom behavior once training starts.
    installCustomSignalHandlers();

    // -- main training loop
    scheduler->started();
    while(scheduler->keepGoing()) {
      if(!restored)
        batchGenerator->prepare();
      restored = false;

      // main training loop for one epoch
      for(auto batch : *batchGenerator) {
        if (!scheduler->keepGoing())
          break;
        model->update(batch);
      }

      if(scheduler->keepGoing())
        scheduler->increaseEpoch();
    }
    scheduler->finished();

    model->finalize(); // allow async to sync before final save   --@TODO: rename, or move into save()

    // Avoid saving the model twice if it has been loaded and training did not progress
    if(!trainState->loaded)
      model->save(true);

    // Signal success to a potential MPI runner
    model = nullptr;     // release any reference to MPI that model may hold
    scheduler = nullptr; // as above
    finalizeMPI(std::move(mpi));
  }
};

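// Install custom handling for SIGTERM: with --sigterm=save-and-exit the model
// is saved before exiting; any value other than "exit-immediately" is rejected.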
template <class ModelWrapper>
void Train<ModelWrapper>::installCustomSignalHandlers(){
  const std::string sigTermAction = options_->get<std::string>("sigterm");
  if (sigTermAction == "save-and-exit") {
    LOG(debug, "Will save before exiting upon SIGTERM.");
    signal(SIGTERM, requestSaveAndExit);
  }
  else if (sigTermAction != "exit-immediately")
    ABORT("Unrecognized value '{}' for --sigterm", sigTermAction);
}

}  // namespace marian
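
For context, a minimal sketch of how this class template is typically driven (not part of training.h): the training executable constructs Train with a concrete graph-group type and calls run(). The names SyncGraphGroup, training/graph_group_sync.h and parseOptions() below are assumptions based on the surrounding Marian code base and may differ between versions.

// Hedged usage sketch; names of the graph-group type and option parser are
// assumptions and should be adjusted to the actual Marian version.
#include "training/graph_group_sync.h" // assumed header providing SyncGraphGroup
#include "training/training.h"

int main(int argc, char** argv) {
  using namespace marian;
  // Parse command line and config files in training mode (assumed helper).
  auto options = parseOptions(argc, argv, cli::mode::training);
  // Instantiate the trainer with a synchronous graph group and start training.
  New<Train<SyncGraphGroup>>(options)->run();
  return 0;
}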