Program Listing for File graph_group_async.h

Return to documentation for file (src/training/graph_group_async.h)

#pragma once

#include "3rd_party/threadpool.h"
#include "training/graph_group.h"

#include <future>
#include <thread>

namespace marian {

/**
 * Graph group derived from GraphGroup.
 *
 * This header only declares the class: the constructor, setScheduler(),
 * fetchParams(), pushGradients(), init(), execute() and finalize() are defined
 * out of line. update() and collectStats() are the only members with bodies
 * here.
 *
 * NOTE(review): going by the names, this is the asynchronous multi-device
 * training variant with sharded parameters/gradients -- confirm against the
 * implementation file before relying on that description.
 */
class AsyncGraphGroup : public GraphGroup {
public:
  // Overrides the base-class scheduler setter; defined out of line.
  virtual void setScheduler(Ptr<Scheduler> scheduler) override;

protected:
  // Starts out true. NOTE(review): presumably flips after the first batch has
  // triggered init() -- confirm in the implementation.
  bool first_{true};

  // General-purpose mutex plus a vector of further mutexes.
  // NOTE(review): shardSync_ presumably holds one mutex per shard -- confirm.
  std::mutex sync_;
  std::vector<std::mutex> shardSync_;

  // NOTE(review): presumably serialises access to the scheduler -- confirm.
  std::mutex schedulerMutex_;

  // Parameter tensors and the allocators paired with them.
  std::vector<Tensor> params_;
  std::vector<Ptr<TensorAllocator>> paramsAlloc_;

  // Gradient tensors and the allocators paired with them.
  std::vector<Tensor> grads_;
  std::vector<Ptr<TensorAllocator>> gradsAlloc_;

  // Zero-initialised so the member never holds an indeterminate value before
  // the out-of-line code assigns it (matches the other defaulted scalars here).
  int shardSize_{0};

  // Owned thread pool; null until the implementation creates one.
  std::unique_ptr<ThreadPool> pool_;

  // Defaults to 1. NOTE(review): presumably the number of batches accumulated
  // per optimizer step -- confirm.
  size_t optimizerDelay_{1};

  // Declared virtual so subclasses can replace it; defined out of line.
  // NOTE(review): presumably copies the current parameters for `device_id`
  // into `oldParams` from `params` -- confirm.
  virtual void fetchParams(Tensor oldParams,
                           const std::vector<Tensor>& params,
                           int device_id);

  // Declared virtual so subclasses can replace it; defined out of line.
  // NOTE(review): presumably submits `newGrads` computed on `device_id` for a
  // mini-batch of size `mbSize` -- confirm.
  virtual void pushGradients(Tensor newGrads,
                             int device_id,
                             size_t mbSize);

  // Both take a batch and are defined out of line; only init() is virtual.
  virtual void init(Ptr<data::Batch> batch);
  void execute(Ptr<data::Batch> batch);

public:
  // Defined out of line.
  AsyncGraphGroup(Ptr<Options> config, Ptr<IMPIWrapper> mpi);

  // Calls validate() (not declared in this class) and then passes the batch
  // on to execute().
  void update(Ptr<data::Batch> batch) override {
    validate();
    execute(batch);
  }

  // Forwards to GraphGroup::collectStats using element 0 of graphs_ and
  // models_, so both containers must be non-empty when this is called.
  // @TODO: give it a fake batch generator that owns the vocabs instead of passing vocabs
  Ptr<data::BatchStats> collectStats(const std::vector<Ptr<Vocab>>& vocabs) override {
    return GraphGroup::collectStats(graphs_[0], models_[0], vocabs);
  }

  // Overrides the base-class hook; defined out of line.
  virtual void finalize() override;
};

}  // namespace marian