Program Listing for File src/training/graph_group_singleton.cpp

#include "training/graph_group_singleton.h"

namespace marian {

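// Attach the training scheduler and register scheduler and optimizers as
// training observers so they are notified of training-state changes.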
void SingletonGraph::setScheduler(Ptr<Scheduler> scheduler) {
  scheduler_ = scheduler;
  // optimizer has to be registered last to see changes of learning rate
  scheduler_->registerTrainingObserver(scheduler_);
  for(auto opt : optimizerShards_)
    scheduler_->registerTrainingObserver(opt);
}

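// Run one training step on the single device: build the loss for the batch,
// run forward/backward, and apply the parameter update.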
void SingletonGraph::execute(Ptr<data::Batch> batch) {
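  // The singleton group trains on a single device: exactly one graph,
  // one model instance and one optimizer shard.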
  auto graph = graphs_[0];
  auto model = models_[0];
  auto opt   = optimizerShards_[0];

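  // Build the model's computation graph for this batch and get the loss node.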
  auto lossNode = model->build(graph, batch);
  if(costScalingFactor_ != 1.f) {
    // Multiplying attaches the scaled loss to the graph, so backward() propagates
    // gradients scaled by costScalingFactor_. For fp16 training it is fine for
    // scaledLoss to go out of scope; the scaled version is not used for anything else.
    auto scaledLoss = lossNode->loss() * costScalingFactor_;
  }

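  // Forward pass computes the (scaled) loss; backward pass fills in the gradients.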
  graph->forward();
  graph->backward();

  bool noNanOrInf = true;
  if(costScaling_) {
    // Are there NaNs or Infs in the gradient, i.e. did the scaled loss overflow?
    bool hasNan = false, hasInf = false;
    IsNaN(graph->params()->grads(), graph->allocator(), hasNan, hasInf);
    noNanOrInf = !(hasNan || hasInf);

    if(!noNanOrInf) // there was a NaN or Inf, decrease the cost-scaling factor
      GraphGroup::decreaseCostScaleFactor();
  }

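  // Pass costScalingFactor_ to the optimizer so it can undo the gradient
  // scaling before applying the update.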
  if(noNanOrInf) // skip update if NaN was seen @TODO: repeat instead with smaller factor?
    opt->update(graph->params()->vals(),
                graph->params()->grads(),
                batch->wordsTrg(),
                costScalingFactor_);

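  // Report progress to the scheduler, which decides when to validate and save.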
  if(scheduler_) {
    scheduler_->update(*lossNode, batch);

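    // Validate on the smoothed (exponentially averaged) parameters,
    // then swap the training parameters back in.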
    if(scheduler_->validating()) {
      swapWithSmoothed();
      scheduler_->validate(graphs_);
      swapWithSmoothed();
    }

    if(scheduler_->saving())
      this->save();
  }

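  // The step completed without overflow; allow the cost-scaling factor to grow again.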
  if(noNanOrInf)
    GraphGroup::increaseCostScaleFactor();
}

}  // namespace marian
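
The decreaseCostScaleFactor()/increaseCostScaleFactor() calls above implement dynamic cost (loss) scaling for fp16 training. For context, a minimal self-contained sketch of the usual bookkeeping behind such calls follows; the class name LossScaler, the initial scale, and the growth interval are illustrative assumptions, not Marian's actual GraphGroup implementation.

#include <algorithm>
#include <cstdio>

// Hypothetical sketch of dynamic loss scaling, not Marian's GraphGroup code.
// On overflow (NaN/Inf in the gradients) the scale shrinks and the step is
// skipped; after a run of clean steps the scale cautiously grows again.
class LossScaler {
public:
  float factor() const { return factor_; }

  // Call once per step after the backward pass.
  // Returns true if the step was clean and the parameter update may proceed.
  bool observe(bool hasNanOrInf) {
    if(hasNanOrInf) {
      factor_ = std::max(factor_ / 2.f, 1.f);  // halve the scale on overflow
      cleanSteps_ = 0;                         // restart the growth counter
      return false;                            // skip this update
    }
    if(++cleanSteps_ >= growthInterval_) {     // enough clean steps in a row
      factor_ = std::min(factor_ * 2.f, 65536.f);  // grow the scale again
      cleanSteps_ = 0;
    }
    return true;
  }

private:
  float factor_ = 1024.f;        // assumed initial scale
  int   cleanSteps_ = 0;
  int   growthInterval_ = 2000;  // assumed clean steps required before growth
};

int main() {
  LossScaler scaler;
  // Simulate an overflowing step followed by a clean one.
  std::printf("scale=%g, update=%d\n", scaler.factor(), scaler.observe(true));
  std::printf("scale=%g, update=%d\n", scaler.factor(), scaler.observe(false));
  return 0;
}

In the listing above, this state lives on GraphGroup itself (costScaling_, costScalingFactor_ and the increase/decrease helpers) rather than in a separate class.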