.. _program_listing_file_src_training_graph_group.h:

Program Listing for File graph_group.h
======================================

|exhale_lsh| :ref:`Return to documentation for file <file_src_training_graph_group.h>` (``src/training/graph_group.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   #pragma once

   #include "common/definitions.h"
   #include "common/options.h"
   #include "data/batch_generator.h"
   #include "graph/expression_graph.h"
   #include "models/model_base.h"
   #include "optimizers/optimizers.h"
   #include "training/scheduler.h"
   #include "training/communicator.h"

   namespace marian {

   // With -Ofast enabled gcc will fail to identify NaN or Inf. Safeguard here.
   static inline bool isFinite(float x) {
   #ifdef __GNUC__
     ABORT_IF(std::isfinite(0.f / 0.f), "NaN detection unreliable. Disable -Ofast compiler option.");
   #endif
     return std::isfinite(x);
   }

   #ifdef _MSC_VER
   // MS Visual Studio insists that this function is not being referenced, although it is being referenced by name as an argument
   #pragma warning(push)
   #pragma warning(disable: 4505) // Unreferenced local function has been removed
   #endif
   // To accumulate gradient norms: first undo the sqrt, sum, then re-apply the sqrt.
   // If one value is non-finite, propagate NaN into the reduction.
   static inline void accNanOrNorm(float& lhs, float rhs) {
     if(isFinite(lhs) && isFinite(rhs)) {
       lhs = sqrtf(lhs * lhs + rhs * rhs);
     } else
       lhs = std::numeric_limits<float>::quiet_NaN();
   }
   #ifdef _MSC_VER
   #pragma warning(pop)
   #endif
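   // Note on the reduction above: for a gradient g split into disjoint shards
   // g_A and g_B, ||g|| = sqrt(||g_A||^2 + ||g_B||^2). Folding the per-shard
   // norms with accNanOrNorm therefore reconstructs the norm of the full
   // gradient, e.g. shard norms 3 and 4 combine to sqrt(9 + 16) = 5, while a
   // single NaN/Inf shard poisons the result so the bad update can be detected.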
   class GraphGroup {
   protected:
     Ptr<Options> options_;

     Ptr<ICommunicator> comm_; // [not null] communicator, e.g. NCCLCommunicator
     Ptr<IMPIWrapper> mpi_;    // [not null] all MPI-like communication goes through this (a dummy implementation if this is not an MPI run)

     std::vector<DeviceId> devices_;                   // [deviceIndex]
     ShardingMode shardingMode_{ShardingMode::global}; // If local and multi-node training, shard only on local devices and do a full sync (faster). If global, shard across the entire set of GPUs (uses more RAM).

     // common to all graph groups; individual graph groups decide how to fill them
     std::vector<Ptr<ExpressionGraph>> graphs_;            // [deviceIndex]
     std::vector<Ptr<models::ICriterionFunction>> models_; // [deviceIndex]
     std::vector<Ptr<OptimizerBase>> optimizerShards_;     // [deviceIndex]

     Ptr<Scheduler> scheduler_; // scheduler that keeps track of how much has been processed

     bool finalized_{false}; // 'true' if training has completed (further updates are no longer allowed)

     double typicalTrgBatchWords_{0}; // for dynamic batch sizing: typical batch size in words
     bool mbRoundUp_{true}; // round up batches for more efficient training, but this can make the batch size less stable; disable with --mini-batch-round-up=false

     bool costScaling_{false};
     float costScalingFactor_{1.f}; // @TODO: add current costScalingFactor_ to trainingState for serialization
     size_t costScalingFreq_{2000};
     float costScalingMultiplier_{2.f};
     float costScalingFactorMinimum_{1.f};

     size_t noNanSeen_{0}; // @TODO: add current noNanSeen_ to trainingState for serialization
     size_t nanSeen_{0};

     bool checkGradientNan_{false};

     bool dynamicGradientScaling_{false};
     float dynamicGradientScalingFactor_{2.f};
     bool dynamicGradientScalingUseLogs_{false};
     size_t dynamicGradientScalingFadeout_{0ul};

     // determines the number of input streams (i.e. input files or fields in the TSV input) that need
     // to be included in the batch, i.e. without alignments and weights
     size_t numberOfInputFiles();

   public:
     GraphGroup(Ptr<Options> options, Ptr<IMPIWrapper> mpi);
     GraphGroup(Ptr<Options> options);

     void initGraphsAndOpts();

     virtual ~GraphGroup() {}

     virtual void update(Ptr<data::Batch> batch) = 0;

     // Increase the cost-scaling factor if no NaN has been detected for a
     // given number of iterations. Usually we increase by 2, which adds
     // one more bit of precision.
     void increaseCostScaleFactor();

     // call when a NaN was seen, to decrease the cost-scaling factor
     void decreaseCostScaleFactor();

     virtual void load();
     virtual void save(bool isFinal = false);

   private:
     void load(const OptimizerBase::ScatterStateFunc& scatterFn);
     void save(bool isFinal, const OptimizerBase::GatherStateFunc& gatherOptimizerStateFn);

     bool restoreFromCheckpoint(const std::string& modelFileName,
                                const OptimizerBase::ScatterStateFunc& scatterFn);
     void saveCheckpoint(const std::string& modelFileName,
                         const OptimizerBase::GatherStateFunc& gatherFn);

   public:
     void swapWithSmoothed();

     bool isMainProcess() const { return mpi_->isMainProcess(); } // (we need this test a few times)
     void barrier() const { mpi_->barrier(); }                    // (we need this several times)

     void validate();

     virtual void finalize();

     virtual void setScheduler(Ptr<Scheduler> scheduler) = 0;

     float checkNanOrNorm(size_t i, size_t begin, size_t end);
     float executeAndCollectNorm(const std::function<float(size_t, size_t, size_t)>& task);
     float computeNormalizationFactor(float gNorm, size_t updateTrgWords);

     // @TODO: Can this be made const? It seems wrong to have a stateful method that still returns a result.
     virtual Ptr<data::BatchStats> collectStats(Ptr<ExpressionGraph> graph,
                                                Ptr<models::ICriterionFunction> model,
                                                const std::vector<Ptr<Vocab>>& vocabs,
                                                double multiplier = 1.);

     virtual Ptr<data::BatchStats> collectStats(const std::vector<Ptr<Vocab>>& vocabs) = 0;

     void setTypicalTrgBatchWords(size_t typicalTrgBatchWords);
     double getTypicalTrgBatchWords();
     void updateAverageTrgBatchWords(size_t trgBatchWords);
   };

   }  // namespace marian
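The ``accNanOrNorm`` reduction is the piece of this header that is easiest to check in isolation. Below is a minimal standalone sketch of that behavior; the names ``accNanOrNormLocal`` and the sample shard norms are illustrative stand-ins, not part of the Marian API:

.. code-block:: cpp

   // Standalone illustration: fold hypothetical per-shard L2 norms into the
   // global gradient norm, letting a NaN/Inf shard poison the result.
   #include <cmath>
   #include <cstdio>
   #include <limits>
   #include <vector>

   // Same logic as accNanOrNorm in graph_group.h, renamed for this sketch.
   static void accNanOrNormLocal(float& lhs, float rhs) {
     if(std::isfinite(lhs) && std::isfinite(rhs))
       lhs = sqrtf(lhs * lhs + rhs * rhs);
     else
       lhs = std::numeric_limits<float>::quiet_NaN();
   }

   int main() {
     // Hypothetical per-shard gradient norms, one per GPU: ||g_0||=3, ||g_1||=4.
     std::vector<float> shardNorms = {3.f, 4.f};
     float total = 0.f; // sqrt(0 + r*r) == r, so 0 is a neutral starting value
     for(float n : shardNorms)
       accNanOrNormLocal(total, n);
     std::printf("total norm = %f\n", total); // sqrt(3^2 + 4^2) = 5

     // A single non-finite shard poisons the reduction, signalling that the
     // update should be skipped or the cost-scaling factor decreased.
     accNanOrNormLocal(total, std::numeric_limits<float>::infinity());
     std::printf("after Inf  = %f\n", total); // nan
     return 0;
   }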