Class ReshapeNodeOp

Inheritance Relationships

Base Type

public marian::UnaryNodeOp

Derived Type

marian::CallbackNodeOp

Class Documentation

class ReshapeNodeOp : public marian::UnaryNodeOp

Subclassed by marian::CallbackNodeOp

Public Functions

ReshapeNodeOp(Expr a, Shape shape)
~ReshapeNodeOp()
void allocate()
void free()
void forward()
void backward()
void init_dependent()

Initialization for the backward step of the top node in the computation graph.

Allocates memory and sets the gradient to 1 (df/df == 1).

void set_zero_adjoint()

Initialization for the backward step of any non-top node in the computation graph.

Allocates memory and sets the gradient to 0 so that the gradients arriving from all parents can be accumulated into it.

Tensor &val()
Tensor &grad()
const std::string type()
const std::string color()
virtual size_t hash()
virtual bool equal(Expr node)

Protected Attributes

Expr reshapee_

Friends

friend marian::ReshapeNodeOp::SerializationHelpers