Program Listing for File node.h¶
↰ Return to documentation for file (src/graph/node.h
)
#pragma once
#include <iostream>
#include <memory>
#include <thread>
#include "common/hash.h"
#include "tensors/backend.h"
#include "tensors/tensor.h"
#include "graph/chainable.h"
namespace marian {
class Node : public Chainable<Tensor> {
protected:
size_t id_{0};
size_t edges_{0};
bool trainable_{true};
bool destroy_{true};
bool memoize_{false};
std::vector<Expr> children_;
Weak<ExpressionGraph> graph_;
Shape shape_{1, 1, 1, 1}; // defines the dimensionality of the node (for tensors)
Type valueType_{Type::float32}; // defines the element type of the node (for tensors)
std::string name_{"none"};
Tensor val_{nullptr}; // the resulting new tensor in forward pass
Tensor adj_{nullptr}; // the accumulated gradients (a tensor) in backward pass
bool markedForDebug_{false};
std::string debugMessage_;
Ptr<std::list<Expr>> subtape_; // a subtape is used to keep track of nodes that need to be freed and recomputed with gradient-checkpointing.
bool isCheckpoint_{false}; // true if this node has been selected to be a checkpoint, currently only done manually.
Ptr<AutoTunerRecorder> recorder_;
size_t recorderHash_;
bool recorderStop_;
public:
Node(Ptr<ExpressionGraph> graph, const Shape& shape, const Type& valueType = Type::float32)
: graph_(graph), shape_(shape), valueType_(valueType) {}
virtual ~Node() {
free();
}
virtual float scalar() override;
virtual NodeOps forwardOps() override { return {}; };
virtual NodeOps backwardOps() override { return {}; };
virtual void runForward(const NodeOps& ops) {
for(auto&& op : ops)
op();
}
virtual void runBackward(const NodeOps& ops) {
size_t i = 0;
for(auto&& op : ops)
if(child(i++)->trainable())
op();
}
virtual void forward() override;
virtual void backward() override;
virtual bool trainable() override { return trainable_; }
virtual void setTrainable(bool trainable) override { trainable_ = trainable; }
virtual bool memoize() override { return memoize_; };
virtual void setMemoize(bool memoize) override { memoize_ = memoize; };
virtual void setId(size_t id) override { id_ = id; }
virtual size_t getId() override { return id_; }
virtual void increaseEdges(size_t edges = 1) { edges_ += edges; };
virtual void decreaseEdges(size_t edges = 1) { edges_ -= edges; };
virtual size_t edges() { return edges_; };
virtual Ptr<ExpressionGraph> graph() override { return graph_.lock(); }
virtual void debug(const std::string& message) override {
debugMessage_ = message;
markedForDebug_ = true;
}
virtual bool marked_for_debug() override { return markedForDebug_; }
virtual const std::string& debug_message() override { return debugMessage_; }
virtual void allocate() override;
virtual void free() override;
virtual void init() override {};
virtual void init_dependent() override;
virtual void set_zero_adjoint() override;
virtual Tensor& val() override { return val_; };
virtual Tensor& grad() override { return adj_; };
virtual const Shape& shape() override { return shape_; }
virtual const Type& value_type() override { return valueType_; }
void set_name(const std::string& name) override { name_ = name; }
const std::string& name() const override { return name_; }
virtual const std::string form() override { return "box"; }
virtual const std::string color() override { return "orange"; }
virtual const std::string label() override {
std::stringstream label;
label << "<" << type();
if(name_ != "none") {
label << "<br/>"
<< "\"" << name_ << "\"";
}
label << " (" << getId() << "/" << trainable() << ")>";
return label.str();
}
virtual std::string graphviz() override {
std::stringstream ss;
ss << "\"" << this << "\" ["
<< "shape=\"" << form() << "\", "
<< "label=" << label() << ", "
<< "style=\"filled\", "
<< (isCheckpoint_ ? "penwidth=3, " : "penwidth=1, ")
<< "fillcolor=\"" << color() << "\"];" << std::endl;
for(auto&& child : children())
ss << "\"" << child << "\" -> \"" << this << "\";" << std::endl;
if(subtape_) {
for(auto&& dep : *subtape_)
ss << "\"" << dep << "\" -> \"" << this << "\" [style=dotted];" << std::endl;
}
ss << std::endl;
return ss.str();
}
virtual std::vector<Expr>& children() override { return children_; }
virtual Expr child(size_t i) override { return children_[i]; }
Ptr<Backend> getBackend();
void record(Ptr<AutoTunerRecorder>, size_t, bool) override;
// this is currently only called manually by checkpoint(Expr). In the future we will figure out a general algorithm
virtual void markCheckpoint() override {
isCheckpoint_ = true;
}
virtual bool isCheckpoint() const override {
return (children_.empty() || isCheckpoint_); // this node is a checkPoint if it's a leaf or if it has been marked.
}
virtual void setSubtape(Ptr<std::list<Expr>> subtape) override {
subtape_ = subtape;
}
virtual Ptr<std::list<Expr>> getSubtape() override {
return subtape_;
};
};
struct NaryNodeOp : public Node {
size_t hash_{0};
// Deduce type automatically, but then all types must be the same
// this is called automatically when no output type is specified.
// If the input types are mixed, the output type needs to be specified
// in the constructor.
static Type commonType(const std::vector<Expr>& nodes) {
ABORT_IF(nodes.size() == 0, "NaryNodeOp has no children");
Type type = nodes[0]->value_type();
for(int i = 1; i < nodes.size(); ++i)
ABORT_IF(nodes[i]->value_type() != type,
"Child {} has different type (first: {} != child: {})",
i, type, nodes[i]->value_type());
return type;
}
NaryNodeOp(const std::vector<Expr>& nodes)
: NaryNodeOp(nodes, nodes[0]->shape()) {}
// this contructor will try to deduce the node type automatically
NaryNodeOp(const std::vector<Expr>& nodes, Shape shape)
: NaryNodeOp(nodes, shape, commonType(nodes)) {}
// this contructor will takes a node type
NaryNodeOp(const std::vector<Expr>& nodes,
Shape shape,
Type value_type)
: Node(nodes.front()->graph(), shape, value_type) {
children_.resize(nodes.size());
for(size_t i = 0; i < nodes.size(); ++i)
children_[i] = nodes[i];
setTrainable(std::any_of(
nodes.begin(), nodes.end(), [](Expr a) { return a->trainable(); }));
// Node is to be memoized if all children are to be memoized.
setMemoize(std::all_of(
nodes.begin(), nodes.end(), [](Expr a) { return a->memoize(); }));
}
virtual ~NaryNodeOp() {}
std::vector<Expr>& children() override { return children_; }
virtual size_t hash() override {
if(!hash_) {
std::size_t seed = util::hash<std::string>()(name());
util::hash_combine(seed, type());
util::hash_combine(seed, (size_t)value_type());
for(size_t i = 0; i < children_.size(); ++i)
util::hash_combine(seed, child(i)->hash());
hash_ = seed;
}
return hash_;
}
virtual bool equal(Expr node) override {
if(type() != node->type())
return false;
else if(name() != node->name())
return false;
else if(value_type() != node->value_type())
return false;
else if(children().size() != node->children().size())
return false;
else {
for(size_t i = 0; i < children().size(); ++i)
if(children()[i]->getId() != node->children()[i]->getId())
return false;
return true;
}
}
};
} // namespace marian