.. _program_listing_file_src_rnn_cells.cpp:

Program Listing for File cells.cpp
==================================

|exhale_lsh| :ref:`Return to documentation for file <file_src_rnn_cells.cpp>` (``src/rnn/cells.cpp``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   #include "rnn/cells.h"

   #include "graph/node_operators_binary.h"
   #include "tensors/tensor_operators.h"

   namespace marian {
   namespace rnn {

   // N-ary graph node whose forward/backward passes are delegated to
   // GRUFastForward / GRUFastBackward over the tensors of all its children.
   struct GRUFastNodeOp : public NaryNodeOp {
     // Forwarded unchanged to both kernels; also part of hash() and equal().
     bool final_;

     GRUFastNodeOp(const std::vector<Expr>& nodes, bool final)
         : NaryNodeOp(nodes), final_(final) {}

     // Gathers the value tensor of every child (in child order) and returns a
     // single NodeOp wrapping the GRUFastForward call that writes into val_.
     NodeOps forwardOps() override {
       std::vector<Tensor> inputs;
       for(size_t i = 0; i < children_.size(); ++i)
         inputs.push_back(child(i)->val());

       return {NodeOp(GRUFastForward(val_, inputs, final_))};
     }

     // Returns a single NodeOp wrapping GRUFastBackward. `outputs` holds one
     // gradient slot per child, index-aligned with `inputs`; children that are
     // not trainable get a nullptr slot instead of their gradient tensor.
     NodeOps backwardOps() override {
       std::vector<Tensor> inputs;
       std::vector<Tensor> outputs;
       for(auto child : children_) {
         inputs.push_back(child->val());
         if(child->trainable())
           outputs.push_back(child->grad());
         else
           outputs.push_back(nullptr);
       }

       return {NodeOp(GRUFastBackward(graph()->allocator(), outputs, inputs, adj_, final_))};
     }

     // do not check if node is trainable
     // (every op in `ops` is invoked unconditionally, in order)
     virtual void runBackward(const NodeOps& ops) override {
       for(auto&& op : ops)
         op();
     }

     const std::string type() override { return "GRU-ops"; }

     const std::string color() override { return "yellow"; }

     // Extends the NaryNodeOp hash with final_.
     virtual size_t hash() override {
       size_t seed = NaryNodeOp::hash();
       util::hash_combine(seed, final_);
       return seed;
     }

     // True only if the base-class comparison passes, `node` is itself a
     // GRUFastNodeOp, and both nodes carry the same final_ flag.
     virtual bool equal(Expr node) override {
       if(!NaryNodeOp::equal(node))
         return false;
       auto cnode = std::dynamic_pointer_cast<GRUFastNodeOp>(node);
       if(!cnode)
         return false;
       if(final_ != cnode->final_)
         return false;
       return true;
     }
   };

   // Creates a GRUFastNodeOp expression over `nodes`; `final` is stored on the
   // node and passed through to the GRU kernels.
   Expr gruOps(const std::vector<Expr>& nodes, bool final) {
     return Expression<GRUFastNodeOp>(nodes, final);
   }

   /******************************************************************************/

   // N-ary graph node whose forward/backward passes are delegated to
   // LSTMCellForward / LSTMCellBackward over the tensors of all its children.
   struct LSTMCellNodeOp : public NaryNodeOp {
     LSTMCellNodeOp(const std::vector<Expr>& nodes) : NaryNodeOp(nodes) {}

     // Gathers the value tensor of every child (in child order) and returns a
     // single NodeOp wrapping the LSTMCellForward call that writes into val_.
     NodeOps forwardOps() override {
       std::vector<Tensor> inputs;
       for(size_t i = 0; i < children_.size(); ++i)
         inputs.push_back(child(i)->val());

       return {NodeOp(LSTMCellForward(val_, inputs))};
     }

     // Returns a single NodeOp wrapping LSTMCellBackward. `outputs` holds one
     // gradient slot per child, index-aligned with `inputs`; children that are
     // not trainable get a nullptr slot instead of their gradient tensor.
     NodeOps backwardOps() override {
       std::vector<Tensor> inputs;
       std::vector<Tensor> outputs;
       for(auto child : children_) {
         inputs.push_back(child->val());
         if(child->trainable())
           outputs.push_back(child->grad());
         else
           outputs.push_back(nullptr);
       }

       return {NodeOp(LSTMCellBackward(outputs, inputs, adj_))};
     }

     // do not check if node is trainable
     // (every op in `ops` is invoked unconditionally, in order)
     virtual void runBackward(const NodeOps& ops) override {
       for(auto&& op : ops)
         op();
     }

     const std::string type() override { return "LSTM-cell-ops"; }

     const std::string color() override { return "yellow"; }
   };

   // N-ary graph node whose forward/backward passes are delegated to
   // LSTMOutputForward / LSTMOutputBackward over the tensors of all its children.
   struct LSTMOutputNodeOp : public NaryNodeOp {
     LSTMOutputNodeOp(const std::vector<Expr>& nodes) : NaryNodeOp(nodes) {}

     // Gathers the value tensor of every child (in child order) and returns a
     // single NodeOp wrapping the LSTMOutputForward call that writes into val_.
     NodeOps forwardOps() override {
       std::vector<Tensor> inputs;
       for(size_t i = 0; i < children_.size(); ++i)
         inputs.push_back(child(i)->val());

       return {NodeOp(LSTMOutputForward(val_, inputs))};
     }

     // Returns a single NodeOp wrapping LSTMOutputBackward. `outputs` holds one
     // gradient slot per child, index-aligned with `inputs`; children that are
     // not trainable get a nullptr slot instead of their gradient tensor.
     NodeOps backwardOps() override {
       std::vector<Tensor> inputs;
       std::vector<Tensor> outputs;
       for(auto child : children_) {
         inputs.push_back(child->val());
         if(child->trainable())
           outputs.push_back(child->grad());
         else
           outputs.push_back(nullptr);
       }

       return {NodeOp(LSTMOutputBackward(outputs, inputs, adj_))};
     }

     // do not check if node is trainable
     // (every op in `ops` is invoked unconditionally, in order)
     virtual void runBackward(const NodeOps& ops) override {
       for(auto&& op : ops)
         op();
     }

     const std::string type() override { return "LSTM-output-ops"; }

     const std::string color() override { return "yellow"; }
   };

   // Creates an LSTMCellNodeOp expression over `nodes`.
   Expr lstmOpsC(const std::vector<Expr>& nodes) {
     return Expression<LSTMCellNodeOp>(nodes);
   }

   // Creates an LSTMOutputNodeOp expression over `nodes`.
   Expr lstmOpsO(const std::vector<Expr>& nodes) {
     return Expression<LSTMOutputNodeOp>(nodes);
   }

   }  // namespace rnn
   }  // namespace marian