Program Listing for File chainable.h
↰ Return to documentation for file (src/graph/chainable.h)
#pragma once
#include "common/definitions.h"
#include <memory>
#include <vector>
#include <list>
namespace marian {
// Wraps the statement(s) `op` in a lambda that takes no arguments and
// captures by copy ([=]). A `;` is appended after `op` inside the body, so
// callers pass the statement without a trailing semicolon.
#define NodeOp(op) [=]() { op; }

// Sequence of type-erased callables, each taking no arguments and returning
// void.
using NodeOps = std::vector<std::function<void()>>;
// Forward declarations: this header only refers to these types by name.
class AutoTunerRecorder;
class ExpressionGraph;

template <class DataType>
class Chainable;

// Handle aliases over Chainable<Tensor>. IPtr and IWeak are declared outside
// this file; by their names they look like a strong and a weak intrusive
// handle respectively -- confirm against their definitions.
using Expr = IPtr<Chainable<Tensor>>;
using WExpr = IWeak<Chainable<Tensor>>;
// Abstract interface parameterized on DataType, the type returned by
// reference from val() and grad(). Apart from the constructor, the destructor
// and the two no-op hooks init_dependent() and set_zero_adjoint(), every
// member is pure virtual and must be supplied by a concrete subclass.
//
// The grouping notes below are derived from the declarations alone; the exact
// semantics of each operation are defined by the implementing classes.
template <class DataType>
class Chainable {
private:
  // Macro defined outside this file. By its name it wires this type up for
  // use with intrusive pointers (IPtr / IWeak) -- confirm against its
  // definition.
  ENABLE_INTRUSIVE_PTR(Chainable<DataType>)

public:
  Chainable() = default;
  // Virtual so that objects can be destroyed through a base-class handle.
  virtual ~Chainable() = default;

  // --- Forward / backward hooks ---------------------------------------------
  virtual void forward() = 0;
  virtual void backward() = 0;
  // Each returns a NodeOps, i.e. a vector of std::function<void()> callables.
  virtual NodeOps forwardOps() = 0;
  virtual NodeOps backwardOps() = 0;

  // --- Storage and initialization hooks -------------------------------------
  virtual void allocate() = 0;
  virtual void free() = 0;
  virtual void init() = 0;
  // Optional hooks: the default implementations do nothing.
  virtual void init_dependent() {}
  virtual void set_zero_adjoint() {}

  // --- Flags and identity ---------------------------------------------------
  virtual bool trainable() = 0;
  virtual void setTrainable(bool) = 0;
  virtual bool memoize() = 0;
  virtual void setMemoize(bool) = 0;
  virtual void setId(size_t) = 0;
  virtual size_t getId() = 0;

  // Disabled reference-returning variant of type(); the by-value type()
  // declared further down is the active declaration.
  // virtual const std::string& type() = 0;

  // --- Structure and value access -------------------------------------------
  virtual Ptr<ExpressionGraph> graph() = 0;
  virtual const Shape& shape() = 0;
  virtual const Type& value_type() = 0;
  // Returned by non-const reference, so callers can modify the child list.
  virtual std::vector<Expr>& children() = 0;
  virtual Expr child(size_t) = 0;
  virtual DataType& val() = 0;
  virtual DataType& grad() = 0;
  virtual float scalar() = 0;

  // --- Descriptive strings (each returned by value) -------------------------
  virtual const std::string type() = 0;
  virtual const std::string color() = 0;
  virtual const std::string form() = 0;
  virtual const std::string label() = 0;
  virtual std::string graphviz() = 0;

  // --- Naming and debugging -------------------------------------------------
  virtual void set_name(const std::string&) = 0;
  virtual const std::string& name() const = 0;
  virtual void debug(const std::string& message) = 0;
  virtual bool marked_for_debug() = 0;
  virtual const std::string& debug_message() = 0;

  // --- Hashing and comparison -----------------------------------------------
  virtual size_t hash() = 0;
  virtual bool equal(Expr) = 0;

  // --- Auto-tuner recording -------------------------------------------------
  virtual void record(Ptr<AutoTunerRecorder>, size_t, bool) = 0;

  // --- Checkpoints and subtape ----------------------------------------------
  virtual void markCheckpoint() = 0;
  virtual bool isCheckpoint() const = 0;
  // The subtape is passed and returned as a shared list of Expr handles.
  virtual void setSubtape(Ptr<std::list<Expr>>) = 0;
  virtual Ptr<std::list<Expr>> getSubtape() = 0;
};
} // namespace marian