Class OptimizerBase

Inheritance Relationships

Base Types

- public marian::TrainingObserver
- public marian::ExponentialSmoothing

Derived Types

- public marian::Adagrad
- public marian::Adam
- public marian::Sgd

Class Documentation

class OptimizerBase : public marian::TrainingObserver, public marian::ExponentialSmoothing

Abstract base class for optimizers. Subclasses must implement setParams(), updateImpl() and resetStats().

Subclassed by marian::Adagrad, marian::Adam, marian::Sgd

Public Types

typedef std::function<void(size_t, const char *, const char *)> ScatterStateSetFunc
typedef std::function<io::Item(size_t)> GatherStateGetFunc
typedef std::function<void(const io::Item&, const ScatterStateSetFunc&)> ScatterStateFunc
typedef std::function<io::Item(const GatherStateGetFunc&)> GatherStateFunc

Public Functions

OptimizerBase(Ptr<Options> options)
virtual ~OptimizerBase()
float update(Ptr<ExpressionGraph> graph, size_t mbSize, float costScaleFactor = 1.f)
float update(Tensor params, Tensor grads, size_t mbSize, float costScaleFactor = 1.f)
virtual void init(TrainingState &state)
virtual void actAfterLoaded(TrainingState &state)
virtual void actAfterEpoch(TrainingState &state)
virtual void actAfterBatches(TrainingState &state)
virtual void actAfterStalled(TrainingState &state)
virtual void setParams(const std::vector<float> &params) = 0
void load(std::vector<io::Item> &items, const std::vector<Ptr<OptimizerBase>> &opts, const std::vector<Ptr<Backend>> &backends, const ScatterStateFunc &scatterFn, bool isMainProcess)
void save(std::vector<io::Item> &items, const std::vector<Ptr<OptimizerBase>> &opts, const GatherStateFunc &gatherFn, bool isMainProcess)
void swapWithSmoothed(Tensor params)
virtual std::vector<Tensor> getShards()

Protected Functions

virtual void updateImpl(Tensor params, Tensor grads, size_t actualMBSize) = 0
virtual void resetStats() = 0

Protected Attributes

Ptr<Options> options_
float eta_
size_t refMBWordsParam_ = {0}
size_t batchesSeen_ = {0}
bool normalizedGradient_ = {false}
Type optimizerType_ = {Type::float32}
bool castOptimizerType_ = {false}
Ptr<Clipper> clipper_
Ptr<TensorAllocator> baseAlloc_
Ptr<Allocator> alloc_
Tensor avg_
Tensor pm_
Tensor gd_