Class Node

Inheritance Relationships

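Base Type

- public marian::Chainable<Tensor>

Derived Types

- public marian::ConstantNode
- public marian::NaryNodeOp
- public marian::ParamNode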
Class Documentation

class Node : public marian::Chainable<Tensor>

Main node class for the computation graph; it implements the most common functions required by Chainable.

Each operation in a computation graph is a node.

Subclassed by marian::ConstantNode, marian::NaryNodeOp, marian::ParamNode

Public Functions

Node(Ptr<ExpressionGraph> graph, const Shape &shape, const Type &valueType = Type::float32)
virtual ~Node()
float scalar()
virtual NodeOps forwardOps()
virtual NodeOps backwardOps()
virtual void runForward(const NodeOps &ops)
virtual void runBackward(const NodeOps &ops)
void forward()
void backward()
virtual bool trainable()
virtual void setTrainable(bool trainable)
virtual bool memoize()
virtual void setMemoize(bool memoize)
virtual void setId(size_t id)
virtual size_t getId()
virtual void increaseEdges(size_t edges = 1)
virtual void decreaseEdges(size_t edges = 1)
virtual size_t edges()
virtual Ptr<ExpressionGraph> graph()
virtual void debug(const std::string &message)
virtual bool marked_for_debug()
virtual const std::string &debug_message()
void allocate()
void free()
virtual void init()
void init_dependent()

Initialization for the backward step of the top node in the computation graph.

Allocates memory and sets the gradient to 1, since the derivative of the output with respect to itself is 1 (df/df = 1).

void set_zero_adjoint()

Initialization for the backward step of any non-top node in the computation graph.

Allocates memory and sets the gradient to 0 so that the gradients from all parent nodes can subsequently be accumulated into it.

virtual Tensor &val()
virtual Tensor &grad()
virtual const Shape &shape()
virtual const Type &value_type()
void set_name(const std::string &name)
const std::string &name() const
virtual const std::string form()
virtual const std::string color()
virtual const std::string label()
virtual std::string graphviz()
virtual std::vector<Expr> &children()
virtual Expr child(size_t i)
Ptr<Backend> getBackend()
void record(Ptr<AutoTunerRecorder> recorder, size_t recorderHash, bool stop)
virtual void markCheckpoint()
virtual bool isCheckpoint() const
virtual void setSubtape(Ptr<std::list<Expr>> subtape)
virtual Ptr<std::list<Expr>> getSubtape()

Protected Attributes

size_t id_ = {0}
size_t edges_ = {0}
bool trainable_ = {true}
bool destroy_ = {true}
bool memoize_ = {false}
std::vector<Expr> children_
Weak<ExpressionGraph> graph_
Shape shape_ = {1, 1, 1, 1}
Type valueType_ = {Type::float32}
std::string name_ = {"none"}
Tensor val_ = {nullptr}
Tensor adj_ = {nullptr}
bool markedForDebug_ = {false}
std::string debugMessage_
Ptr<std::list<Expr>> subtape_
bool isCheckpoint_ = {false}
Ptr<AutoTunerRecorder> recorder_
size_t recorderHash_
bool recorderStop_