.. _program_listing_file_src_training_graph_group_async.h:

Program Listing for File graph_group_async.h
============================================

|exhale_lsh| :ref:`Return to documentation for file <file_src_training_graph_group_async.h>` (``src/training/graph_group_async.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   #pragma once

   #include "3rd_party/threadpool.h"
   #include "training/graph_group.h"

   #include <future>
   #include <thread>

   namespace marian {

   // Training graph group derived from GraphGroup. Holds vectors of parameter
   // and gradient tensors (params_, grads_) together with the allocators that
   // back them, plus one mutex per entry in shardSync_.
   // NOTE(review): the per-shard layout is inferred from member names
   // (shardSync_, shardSize_); the update logic itself lives in the .cpp.
   class AsyncGraphGroup : public GraphGroup {
   public:
     // Overrides GraphGroup::setScheduler; defined in the .cpp.
     virtual void setScheduler(Ptr<Scheduler> scheduler) override;

   protected:
     bool first_{true};  // starts true; presumably cleared once init() has run (confirm in .cpp)

     std::mutex sync_;
     std::vector<std::mutex> shardSync_;  // presumably one mutex per parameter/gradient shard

     std::mutex schedulerMutex_;  // presumably serializes scheduler access (name-based; confirm)

     // Parameter tensors and the allocators that provide their memory.
     std::vector<Tensor> params_;
     std::vector<Ptr<TensorAllocator>> paramsAlloc_;

     // Gradient tensors and the allocators that provide their memory.
     std::vector<Tensor> grads_;
     std::vector<Ptr<TensorAllocator>> gradsAlloc_;

     int shardSize_;  // NOTE(review): no in-class initializer; presumably set during init()

     std::unique_ptr<ThreadPool> pool_;

     size_t optimizerDelay_{1};  // defaults to 1

     // Presumably gathers the entries of `params` into the device-local
     // `oldParams` for device `device_id` (name-based; confirm in .cpp).
     virtual void fetchParams(Tensor oldParams,
                              const std::vector<Tensor>& params,
                              int device_id);

     // Presumably scatters `newGrads` from device `device_id` into grads_ and
     // applies the optimizer; `mbSize` is the mini-batch size (name-based).
     virtual void pushGradients(Tensor newGrads, int device_id, size_t mbSize);

     // Setup step that receives a batch; presumably the lazy initialization
     // tracked by first_ (confirm in .cpp).
     virtual void init(Ptr<data::Batch> batch);

     // Processes one batch; this is what update() calls after validate().
     void execute(Ptr<data::Batch> batch);

   public:
     // Constructs the group from the training options and the MPI wrapper.
     AsyncGraphGroup(Ptr<Options> config, Ptr<IMPIWrapper> mpi);

     // Validates state via validate(), then hands the batch to execute().
     void update(Ptr<data::Batch> batch) override {
       validate();
       execute(batch);
     }

     // Delegates to GraphGroup::collectStats using the first graph and model;
     // requires graphs_ and models_ to be non-empty (indexes element 0).
     // @TODO: give it a fake batch generator which owns vocabs instead of passing vocabs
     Ptr<data::BatchStats> collectStats(const std::vector<Ptr<Vocab>>& vocabs) override {
       return GraphGroup::collectStats(graphs_[0], models_[0], vocabs);
     }

     // Overrides GraphGroup::finalize; defined in the .cpp.
     virtual void finalize() override;
   };
   }  // namespace marian