Program Listing for File graph_group.h

↰ Return to documentation for file (src/training/graph_group.h)
#pragma once
#include "common/definitions.h"
#include "common/options.h"
#include "data/batch_generator.h"
#include "graph/expression_graph.h"
#include "models/model_base.h"
#include "optimizers/optimizers.h"
#include "training/scheduler.h"
#include "training/communicator.h"
namespace marian {
// With -Ofast enabled, gcc will fail to identify NaN or Inf. Safeguard here.
static inline bool isFinite(float x) {
#ifdef __GNUC__
ABORT_IF(std::isfinite(0.f / 0.f), "NaN detection unreliable. Disable -Ofast compiler option.");
#endif
return std::isfinite(x);
}
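// Illustration (hypothetical snippet, not part of this header): under -Ofast,
// gcc enables -ffast-math and may assume NaN can never occur, folding the
// check to a constant at compile time; the ABORT_IF above detects exactly this:
//
//   float nan = 0.f / 0.f;     // NaN at runtime
//   std::isfinite(nan);        // may be constant-folded to 'true' under -Ofast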
#ifdef _MSC_VER // MS Visual Studio insists that this function is not being referenced although it is being referenced by name as an argument
#pragma warning(push)
#pragma warning(disable: 4505) // Unreferenced local function has been removed
#endif
// To accumulate gradient norms: first undo the sqrt, sum the squares, then re-apply the sqrt.
// If one value is non-finite, propagate NaN into the reduction.
static inline void accNanOrNorm(float& lhs, float rhs) {
if(isFinite(lhs) && isFinite(rhs)) {
lhs = sqrtf(lhs * lhs + rhs * rhs);
} else
lhs = std::numeric_limits<float>::quiet_NaN();
}
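// Usage sketch (hypothetical caller code): reducing per-shard gradient norms
// into one global norm; a single non-finite shard poisons the whole reduction:
//
//   float gNorm = 0.f;
//   for(float shardNorm : shardNorms)  // shardNorms: hypothetical vector of per-shard L2 norms
//     accNanOrNorm(gNorm, shardNorm);
//   // gNorm == sqrtf(sum of squared shard norms), or NaN if any shard was non-finite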
#ifdef _MSC_VER
#pragma warning(pop)
#endif
class GraphGroup {
protected:
Ptr<Options> options_;
Ptr<ICommunicator> comm_; // [not null] communicator, e.g. NCCLCommunicator
Ptr<IMPIWrapper> mpi_; // [not null] all MPI-like communication goes through this (a dummy implementation when not running under MPI)
std::vector<DeviceId> devices_; // [deviceIndex]
ShardingMode shardingMode_{ShardingMode::global}; // If local and multi-node training, shard only on local devices and do a full sync across nodes (faster). If global, shard across the entire set of GPUs (more RAM).
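// Illustration (assumed semantics): with 2 nodes of 4 GPUs each,
//   ShardingMode::local  shards state 4 ways within each node, then fully
//                        synchronizes the two nodes with each other;
//   ShardingMode::global shards state 8 ways across all GPUs at once.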
// Common to all graph groups; individual graph groups decide how to fill them.
std::vector<Ptr<ExpressionGraph>> graphs_; // [deviceIndex]
std::vector<Ptr<models::ICriterionFunction>> models_; // [deviceIndex]
std::vector<Ptr<OptimizerBase>> optimizerShards_; // [deviceIndex]
Ptr<Scheduler> scheduler_; // scheduler that keeps track of how much has been processed
bool finalized_{false}; // 'true' if training has completed (further updates are no longer allowed)
double typicalTrgBatchWords_{0}; // for dynamic batch sizing: typical batch size in words
bool mbRoundUp_{true}; // round up batch sizes for more efficient training, though this can make the batch size less stable; disable with --mini-batch-round-up=false
bool costScaling_{false};
float costScalingFactor_{1.f}; // @TODO: add current costScalingFactor_ to trainingState for serialization
size_t costScalingFreq_{2000};
float costScalingMultiplier_{2.f};
float costScalingFactorMinimum_{1.f};
size_t noNanSeen_{0}; // @TODO: add current noNanSeen_ to trainingState for serialization
size_t nanSeen_{0};
bool checkGradientNan_{false};
bool dynamicGradientScaling_{false};
float dynamicGradientScalingFactor_{2.f};
bool dynamicGradientScalingUseLogs_{false};
size_t dynamicGradientScalingFadeout_{0ul};
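// Assumed roles of the uncommented fields above, inferred from their names and
// defaults (the authoritative logic lives in the corresponding .cpp file):
//   costScalingFreq_          - NaN-free updates required before increasing the factor
//   costScalingMultiplier_    - factor applied when increasing/decreasing the scale
//   costScalingFactorMinimum_ - floor below which the factor is never decreased
//   noNanSeen_ / nanSeen_     - counters that drive the increase/decrease decisions
//   dynamicGradientScaling*_  - parameters for damping updates whose gradient norm
//                               deviates strongly from the running average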
// Determines the number of input streams (i.e. input files or fields in the TSV input) that need
// to be included in the batch, that is, excluding alignments and weights.
size_t numberOfInputFiles();
public:
GraphGroup(Ptr<Options> options, Ptr<IMPIWrapper> mpi);
GraphGroup(Ptr<Options> options);
void initGraphsAndOpts();
virtual ~GraphGroup() {}
virtual void update(Ptr<data::Batch> batch) = 0;
// Increase the cost-scaling factor if no NaN has been detected for a
// given number of iterations. Usually we multiply by 2, which adds
// one more bit of precision.
void increaseCostScaleFactor();
// Call this when a NaN has been seen, to decrease the cost-scaling factor.
void decreaseCostScaleFactor();
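// Sketch of the assumed interaction in a subclass's update() (simplified,
// hypothetical code):
//
//   float loss = forward(batch) * costScalingFactor_;  // scale the loss up
//   backward(loss);                                    // gradients carry the scale
//   float gNorm = collectUnscaledGradientNorm();       // hypothetical helper
//   if(isFinite(gNorm))
//     increaseCostScaleFactor();  // grows the factor after costScalingFreq_ clean steps
//   else
//     decreaseCostScaleFactor();  // shrinks the factor; the update is typically skipped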
virtual void load();
virtual void save(bool isFinal = false);
private:
void load(const OptimizerBase::ScatterStateFunc& scatterFn);
void save(bool isFinal,
const OptimizerBase::GatherStateFunc& gatherOptimizerStateFn);
bool restoreFromCheckpoint(const std::string& modelFileName,
const OptimizerBase::ScatterStateFunc& scatterFn);
void saveCheckpoint(const std::string& modelFileName,
const OptimizerBase::GatherStateFunc& gatherFn);
public:
void swapWithSmoothed();
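// Assumed behavior (cf. Marian's --exponential-smoothing option): exchanges the
// current parameters with their exponentially smoothed running averages, e.g. to
// validate or save with smoothed weights, then swap back afterwards.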
bool isMainProcess() const { return mpi_->isMainProcess(); } // (we need this test a few times)
void barrier() const { mpi_->barrier(); } // (we need this several times)
void validate();
virtual void finalize();
virtual void setScheduler(Ptr<Scheduler> scheduler) = 0;
float checkNanOrNorm(size_t i, size_t begin, size_t end);
float executeAndCollectNorm(const std::function<float(size_t, size_t, size_t)>& task);
float computeNormalizationFactor(float gNorm, size_t updateTrgWords);
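// Hedged sketch of the expected call pattern (details live in the .cpp file):
//
//   float gNorm = executeAndCollectNorm(
//       [&](size_t i, size_t begin, size_t end) { return checkNanOrNorm(i, begin, end); });
//   float normalizer = computeNormalizationFactor(gNorm, updateTrgWords);
//   // the gradient is then divided by 'normalizer' before the optimizer step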
// @TODO: Can this be made const? It seems wrong to have a stateful method that still returns a result.
virtual Ptr<data::BatchStats> collectStats(Ptr<ExpressionGraph> graph,
Ptr<models::ICriterionFunction> model,
const std::vector<Ptr<Vocab>>& vocabs,
double multiplier = 1.);
virtual Ptr<data::BatchStats> collectStats(const std::vector<Ptr<Vocab>>& vocabs) = 0;
void setTypicalTrgBatchWords(size_t typicalTrgBatchWords);
double getTypicalTrgBatchWords();
void updateAverageTrgBatchWords(size_t trgBatchWords);
};
} // namespace marian