.. _program_listing_file_src_graph_node.h: Program Listing for File node.h =============================== |exhale_lsh| :ref:`Return to documentation for file ` (``src/graph/node.h``) .. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS .. code-block:: cpp #pragma once #include #include #include #include "common/hash.h" #include "tensors/backend.h" #include "tensors/tensor.h" #include "graph/chainable.h" namespace marian { class Node : public Chainable { protected: size_t id_{0}; size_t edges_{0}; bool trainable_{true}; bool destroy_{true}; bool memoize_{false}; std::vector children_; Weak graph_; Shape shape_{1, 1, 1, 1}; // defines the dimensionality of the node (for tensors) Type valueType_{Type::float32}; // defines the element type of the node (for tensors) std::string name_{"none"}; Tensor val_{nullptr}; // the resulting new tensor in forward pass Tensor adj_{nullptr}; // the accumulated gradients (a tensor) in backward pass bool markedForDebug_{false}; std::string debugMessage_; Ptr> subtape_; // a subtape is used to keep track of nodes that need to be freed and recomputed with gradient-checkpointing. bool isCheckpoint_{false}; // true if this node has been selected to be a checkpoint, currently only done manually. Ptr recorder_; size_t recorderHash_; bool recorderStop_; public: Node(Ptr graph, const Shape& shape, const Type& valueType = Type::float32) : graph_(graph), shape_(shape), valueType_(valueType) {} virtual ~Node() { free(); } virtual float scalar() override; virtual NodeOps forwardOps() override { return {}; }; virtual NodeOps backwardOps() override { return {}; }; virtual void runForward(const NodeOps& ops) { for(auto&& op : ops) op(); } virtual void runBackward(const NodeOps& ops) { size_t i = 0; for(auto&& op : ops) if(child(i++)->trainable()) op(); } virtual void forward() override; virtual void backward() override; virtual bool trainable() override { return trainable_; } virtual void setTrainable(bool trainable) override { trainable_ = trainable; } virtual bool memoize() override { return memoize_; }; virtual void setMemoize(bool memoize) override { memoize_ = memoize; }; virtual void setId(size_t id) override { id_ = id; } virtual size_t getId() override { return id_; } virtual void increaseEdges(size_t edges = 1) { edges_ += edges; }; virtual void decreaseEdges(size_t edges = 1) { edges_ -= edges; }; virtual size_t edges() { return edges_; }; virtual Ptr graph() override { return graph_.lock(); } virtual void debug(const std::string& message) override { debugMessage_ = message; markedForDebug_ = true; } virtual bool marked_for_debug() override { return markedForDebug_; } virtual const std::string& debug_message() override { return debugMessage_; } virtual void allocate() override; virtual void free() override; virtual void init() override {}; virtual void init_dependent() override; virtual void set_zero_adjoint() override; virtual Tensor& val() override { return val_; }; virtual Tensor& grad() override { return adj_; }; virtual const Shape& shape() override { return shape_; } virtual const Type& value_type() override { return valueType_; } void set_name(const std::string& name) override { name_ = name; } const std::string& name() const override { return name_; } virtual const std::string form() override { return "box"; } virtual const std::string color() override { return "orange"; } virtual const std::string label() override { std::stringstream label; label << "<" << type(); if(name_ != "none") { label << "
" << "\"" << name_ << "\""; } label << " (" << getId() << "/" << trainable() << ")>"; return label.str(); } virtual std::string graphviz() override { std::stringstream ss; ss << "\"" << this << "\" [" << "shape=\"" << form() << "\", " << "label=" << label() << ", " << "style=\"filled\", " << (isCheckpoint_ ? "penwidth=3, " : "penwidth=1, ") << "fillcolor=\"" << color() << "\"];" << std::endl; for(auto&& child : children()) ss << "\"" << child << "\" -> \"" << this << "\";" << std::endl; if(subtape_) { for(auto&& dep : *subtape_) ss << "\"" << dep << "\" -> \"" << this << "\" [style=dotted];" << std::endl; } ss << std::endl; return ss.str(); } virtual std::vector& children() override { return children_; } virtual Expr child(size_t i) override { return children_[i]; } Ptr getBackend(); void record(Ptr, size_t, bool) override; // this is currently only called manually by checkpoint(Expr). In the future we will figure out a general algorithm virtual void markCheckpoint() override { isCheckpoint_ = true; } virtual bool isCheckpoint() const override { return (children_.empty() || isCheckpoint_); // this node is a checkPoint if it's a leaf or if it has been marked. } virtual void setSubtape(Ptr> subtape) override { subtape_ = subtape; } virtual Ptr> getSubtape() override { return subtape_; }; }; struct NaryNodeOp : public Node { size_t hash_{0}; // Deduce type automatically, but then all types must be the same // this is called automatically when no output type is specified. // If the input types are mixed, the output type needs to be specified // in the constructor. static Type commonType(const std::vector& nodes) { ABORT_IF(nodes.size() == 0, "NaryNodeOp has no children"); Type type = nodes[0]->value_type(); for(int i = 1; i < nodes.size(); ++i) ABORT_IF(nodes[i]->value_type() != type, "Child {} has different type (first: {} != child: {})", i, type, nodes[i]->value_type()); return type; } NaryNodeOp(const std::vector& nodes) : NaryNodeOp(nodes, nodes[0]->shape()) {} // this contructor will try to deduce the node type automatically NaryNodeOp(const std::vector& nodes, Shape shape) : NaryNodeOp(nodes, shape, commonType(nodes)) {} // this contructor will takes a node type NaryNodeOp(const std::vector& nodes, Shape shape, Type value_type) : Node(nodes.front()->graph(), shape, value_type) { children_.resize(nodes.size()); for(size_t i = 0; i < nodes.size(); ++i) children_[i] = nodes[i]; setTrainable(std::any_of( nodes.begin(), nodes.end(), [](Expr a) { return a->trainable(); })); // Node is to be memoized if all children are to be memoized. setMemoize(std::all_of( nodes.begin(), nodes.end(), [](Expr a) { return a->memoize(); })); } virtual ~NaryNodeOp() {} std::vector& children() override { return children_; } virtual size_t hash() override { if(!hash_) { std::size_t seed = util::hash()(name()); util::hash_combine(seed, type()); util::hash_combine(seed, (size_t)value_type()); for(size_t i = 0; i < children_.size(); ++i) util::hash_combine(seed, child(i)->hash()); hash_ = seed; } return hash_; } virtual bool equal(Expr node) override { if(type() != node->type()) return false; else if(name() != node->name()) return false; else if(value_type() != node->value_type()) return false; else if(children().size() != node->children().size()) return false; else { for(size_t i = 0; i < children().size(); ++i) if(children()[i]->getId() != node->children()[i]->getId()) return false; return true; } } }; } // namespace marian