Class AsyncGraphGroup

Inheritance Relationships

Base Type

- public marian::GraphGroup

Class Documentation

class AsyncGraphGroup : public marian::GraphGroup

Public Functions

void setScheduler(Ptr<Scheduler> scheduler)
AsyncGraphGroup(Ptr<Options> config, Ptr<IMPIWrapper> mpi)
void update(Ptr<data::Batch> batch)
Ptr<data::BatchStats> collectStats(const std::vector<Ptr<Vocab>> &vocabs)
void finalize()

Protected Functions

void fetchParams(Tensor oldParams, const std::vector<Tensor> &params, int device_id)
void pushGradients(Tensor newGrads, int device_id, size_t mbSize)
void init(Ptr<data::Batch> batch)
void execute(Ptr<data::Batch> batch)

Protected Attributes

bool first_ = {true}
std::mutex sync_
std::vector<std::mutex> shardSync_
std::mutex schedulerMutex_
std::vector<Tensor> params_
std::vector<Ptr<TensorAllocator>> paramsAlloc_
std::vector<Tensor> grads_
std::vector<Ptr<TensorAllocator>> gradsAlloc_
int shardSize_
std::unique_ptr<ThreadPool> pool_
size_t optimizerDelay_ = {1}