Program Listing for File chainable.h

Return to documentation for file (src/graph/chainable.h)

#pragma once

#include "common/definitions.h"

#include <memory>
#include <vector>
#include <list>

namespace marian {

// NodeOp(op) expands to a parameterless lambda that captures by copy ([=])
// and executes the statement(s) `op` when invoked. Note that inside a member
// function a [=] capture also captures `this` by pointer.
#define NodeOp(op) [=]() { op; }

// A sequence of deferred, argument-less callables (e.g. built with NodeOp).
using NodeOps = std::vector<std::function<void()>>;

// Forward declaration; only used by pointer in this header.
class AutoTunerRecorder;

// Forward declaration of the node interface so the aliases below can name it.
template <typename DataType>
class Chainable;

// Handle aliases for tensor-valued nodes, built on the IPtr / IWeak pointer
// templates (declared elsewhere in the project).
using Expr = IPtr<Chainable<Tensor>>;
using WExpr = IWeak<Chainable<Tensor>>;

// Forward declaration; only used by pointer in this header.
class ExpressionGraph;

/**
 * Abstract interface template parameterised on the value type `DataType`.
 *
 * Apart from the constructor, the destructor and the two no-op hooks
 * init_dependent() and set_zero_adjoint(), every member function is pure
 * virtual, so all concrete behaviour is supplied by derived classes.
 */
template <typename DataType>
class Chainable {
private:
  // Project macro defined elsewhere; presumably injects the reference-counting
  // support required by IPtr -- confirm against its definition.
  ENABLE_INTRUSIVE_PTR(Chainable<DataType>)

public:
  Chainable() = default;
  virtual ~Chainable() = default;

  // --- Computation -----------------------------------------------------------
  // forwardOps()/backwardOps() return lists of deferred callables (NodeOps);
  // how they relate to forward()/backward() is decided by the implementer.
  virtual void forward() = 0;
  virtual void backward() = 0;
  virtual NodeOps forwardOps() = 0;
  virtual NodeOps backwardOps() = 0;

  // --- Lifecycle hooks -------------------------------------------------------
  virtual void allocate() = 0;
  virtual void free() = 0;
  virtual void init() = 0;
  // The next two hooks are optional: their default implementation does nothing.
  virtual void init_dependent() {}
  virtual void set_zero_adjoint() {}

  // --- Flags -----------------------------------------------------------------
  virtual bool trainable() = 0;
  virtual void setTrainable(bool) = 0;

  virtual bool memoize() = 0;
  virtual void setMemoize(bool) = 0;

  // --- Identifier ------------------------------------------------------------
  virtual void setId(size_t) = 0;
  virtual size_t getId() = 0;

  // --- Associated graph ------------------------------------------------------
  // virtual const std::string& type() = 0;
  virtual Ptr<ExpressionGraph> graph() = 0;

  // --- Shape and element type ------------------------------------------------
  virtual const Shape& shape() = 0;
  virtual const Type& value_type() = 0;

  // --- Children and data -----------------------------------------------------
  // children(), val() and grad() hand out mutable references.
  virtual std::vector<Expr>& children() = 0;
  virtual Expr child(size_t) = 0;
  virtual DataType& val() = 0;
  virtual DataType& grad() = 0;
  virtual float scalar() = 0;

  // --- String descriptors ----------------------------------------------------
  // All returned by value. Presumably consumed when rendering the graph via
  // graphviz() -- confirm in the implementing classes.
  virtual const std::string type() = 0;
  virtual const std::string color() = 0;
  virtual const std::string form() = 0;
  virtual const std::string label() = 0;
  virtual std::string graphviz() = 0;

  // --- Name ------------------------------------------------------------------
  virtual void set_name(const std::string&) = 0;
  virtual const std::string& name() const = 0;

  // --- Debugging -------------------------------------------------------------
  virtual void debug(const std::string& message) = 0;
  virtual bool marked_for_debug() = 0;
  virtual const std::string& debug_message() = 0;

  // --- Hashing and comparison ------------------------------------------------
  virtual size_t hash() = 0;
  virtual bool equal(Expr) = 0;

  // --- Auto-tuner recording --------------------------------------------------
  // Meaning of the size_t and bool arguments is not visible in this header.
  virtual void record(Ptr<AutoTunerRecorder>, size_t, bool) = 0;

  // --- Checkpoint flag and sub-tape ------------------------------------------
  virtual void markCheckpoint() = 0;
  virtual bool isCheckpoint() const = 0;
  virtual void setSubtape(Ptr<std::list<Expr>>) = 0;
  virtual Ptr<std::list<Expr>> getSubtape() = 0;
};
}  // namespace marian