Class NCCLCommunicator

Inheritance Relationships

Base Type

public marian::ICommunicator

Class Documentation

class NCCLCommunicator : public marian::ICommunicator

Public Functions

NCCLCommunicator(const std::vector<Ptr<ExpressionGraph>> &graphs, ShardingMode shardingMode, Ptr<IMPIWrapper> mpi)
~NCCLCommunicator()
template<typename Ret>
Ret foreachAcc(const ForeachFunc<Ret> &func, const AccFunc<Ret> &acc, Ret init, bool parallel = true) const
float foreach(const ForeachFunc<float> &func, AccFunc<float> acc, float init, bool parallel = true) const
bool foreach(const ForeachFunc<bool> &func, bool parallel = true) const
void scatterReduceAndResetGrads() const
void allGatherParams() const
void broadcastParams(bool average = false) const
void broadcastShards(const std::vector<Ptr<OptimizerBase>> &opts, bool average = false) const
void scatterState(const io::Item &data, const OptimizerBase::ScatterStateSetFunc &setFn) const
io::Item gatherState(const OptimizerBase::GatherStateGetFunc &getFn) const