Program Listing for File node_operators.cpp

Return to documentation for file (src/graph/node_operators.cpp)

#include "node_operators.h"
#include "expression_graph.h"

#include "tensors/tensor_operators.h"

namespace marian {

// Builds a constant (non-trainable) node of the given shape and value type.
// The initializer is stored and only run later by init(); here it is merely
// handed the graph's allocator.
ConstantNode::ConstantNode(Ptr<ExpressionGraph> graph,
                           const Shape& shape,
                           const Ptr<inits::NodeInitializer>& init,
                           Type valueType)
    : Node(graph, shape, valueType), init_(init), initialized_(false) {
  // Give the initializer the owning graph's allocator (presumably for any
  // memory it needs while applying itself -- TODO confirm in NodeInitializer).
  auto alloc = graph->allocator();
  init_->setAllocator(alloc);

  // Constants are never marked trainable.
  setTrainable(false);
}

// Requests forward-value storage from the graph, but only if this node does
// not already hold a value tensor; repeated calls are no-ops.
void ConstantNode::allocate() {
  if(val_)
    return;  // already allocated -- keep the existing tensor

  graph()->allocateForward(this);
}

// Fills val_ using the stored initializer the first time it is called, then
// drops this node's reference to the initializer. Later calls skip the apply
// step (initialized_ stays true). Presumably val_ must already be allocated
// via allocate() before this runs -- TODO confirm call order in the graph.
void ConstantNode::init() {
  const bool firstCall = !initialized_;
  if(firstCall) {
    init_->apply(val_);
    initialized_ = true;
  }

  // The initializer is no longer needed once the value has been produced.
  init_ = nullptr;
}

ParamNode::ParamNode(Ptr<ExpressionGraph> graph,
                     const Shape& shape,
                     const Ptr<inits::NodeInitializer>& init,
                     bool fixed)
    : ParamNode(graph, shape, init, Type::float32, fixed) {}

// Builds a parameter node of the given shape and value type.
//   init      - initializer run later by init() to fill the value
//   valueType - element type of the node's value
//   fixed     - when true the parameter is marked non-trainable
// Memoization is enabled exactly when the owning graph is in inference mode.
ParamNode::ParamNode(Ptr<ExpressionGraph> graph,
                     const Shape& shape,
                     const Ptr<inits::NodeInitializer>& init,
                     Type valueType,
                     bool fixed)
    : Node(graph, shape, valueType), init_(init), initialized_(false) {
  // Give the initializer the owning graph's allocator (presumably for any
  // memory it needs while applying itself -- TODO confirm in NodeInitializer).
  auto alloc = graph->allocator();
  init_->setAllocator(alloc);

  const bool trainable = !fixed;
  setTrainable(trainable);

  const bool inference = graph->isInference();
  setMemoize(inference);
}

// Applies the stored initializer to val_ on the first call only, then drops
// this node's reference to the initializer (on every call). Presumably val_
// has been allocated before this runs -- TODO confirm call order in the graph.
void ParamNode::init() {
  const bool firstCall = !initialized_;
  if(firstCall) {
    init_->apply(val_);
    initialized_ = true;
  }

  // Release the initializer; it is not used again after the value is set.
  init_ = nullptr;
}
}  // namespace marian