Program Listing for File graph_group_sync.h
Return to documentation for file (src/training/graph_group_sync.h)
#pragma once
#include "optimizers/quantizer.h"
#include "training/graph_group.h"
namespace marian {

// Graph group that performs synchronous training updates.
//
// Overrides the GraphGroup hooks setScheduler(), update(), collectStats() and finalize().
// The constructor takes an IMPIWrapper, so this class presumably coordinates updates across
// multiple devices and/or MPI processes -- TODO confirm against the implementation file.
//
// NOTE(review): all member functions are defined out of view (presumably in
// graph_group_sync.cpp). Member names, private signatures and the declaration order of the
// data members (which fixes their initialization order) must be kept in sync with that file.
class SyncGraphGroup : public GraphGroup {
// Everything up to `public:` below is private (default access for `class`).
using Base = GraphGroup; // alias for the base class; presumably used to call Base:: implementations
// Optimizer delay; defaults to 1.
const double delay_{1.}; // optimizer-delay parameter. Fractional means to use a fraction of whatever the MB size is
// @TODO: instead, create an array of ExponentialSmoothing objects, and don't use ExponentialSmoothing as a base class
// NOTE(review): the only base class visible here is GraphGroup; this TODO may refer to a base of
// GraphGroup or be stale -- confirm or drop.
std::vector<Ptr<TensorAllocator>> paramsAllocs_; // [deviceIndex] we must hold a reference to the memory until this class dies
// Model quantizers (presumably one per device -- TODO confirm).
std::vector<Ptr<ModelQuantizer>> quantizers_;
// state for update()
bool first_{ true }; // gets interpreted and cleared by update(); presumably triggers initialize() on the first call -- confirm
// In case of dynamic MB-size scaling, batches are temporarily buffered here across update() calls
// until enough have accumulated for an update (presumably decided by tryGetSubBatches() -- confirm).
std::vector<Ptr<data::Batch>> pendingBatches_;
double updateMultiplier_{1}; // multiplier not applied in collectStats() (no multiplier if not mini-batch-fit)
// One-time setup from a representative batch; presumably called lazily while first_ is set -- confirm.
void initialize(const Ptr<data::Batch>& exampleBatch);
// Tries to form the sub-batches for the next update from newBatch (presumably together with
// pendingBatches_). Writes subBatches and numReadBatches as out-parameters; the bool result
// presumably reports whether enough data is ready for an update -- confirm in the implementation.
bool tryGetSubBatches(Ptr<data::Batch> newBatch, std::vector<Ptr<data::Batch>>& subBatches, size_t& numReadBatches);
// Private overload that performs one update over the given sub-batches; numReadBatches
// presumably comes from tryGetSubBatches().
void update(std::vector<Ptr<data::Batch>> subBatches, size_t numReadBatches);
public:
// Constructs the graph group from the training options and the MPI wrapper.
SyncGraphGroup(Ptr<Options> config, Ptr<IMPIWrapper> mpi);
// GraphGroup interface overrides (implementations are out of view).
void setScheduler(Ptr<Scheduler> scheduler) override;
// Public per-batch entry point; presumably delegates to tryGetSubBatches() and the private update() overload.
void update(Ptr<data::Batch> batch) override;
// Collects batch statistics for the given vocabularies; updateMultiplier_ is not applied here.
Ptr<data::BatchStats> collectStats(const std::vector<Ptr<Vocab>>&) override;
void finalize() override;
// @TODO: consider to make this a virtual as well? Currently it is a template dispatch
// NOTE(review): this TODO follows no declaration, so its subject is unclear (possibly
// collectStats() or a since-removed member) -- confirm, then relocate or delete it.
};
} // namespace marian