Program Listing for File cells.cpp

Return to documentation for file (src/rnn/cells.cpp)

#include "rnn/cells.h"

#include "graph/node_operators_binary.h"
#include "tensors/tensor_operators.h"

namespace marian {
namespace rnn {

struct GRUFastNodeOp : public NaryNodeOp {
  bool final_;

  GRUFastNodeOp(const std::vector<Expr>& nodes, bool final)
      : NaryNodeOp(nodes), final_(final) {}

  NodeOps forwardOps() override {
    std::vector<Tensor> inputs;
    for(size_t i = 0; i < children_.size(); ++i)
      inputs.push_back(child(i)->val());

    return {NodeOp(GRUFastForward(val_, inputs, final_))};
  }

  NodeOps backwardOps() override {
    std::vector<Tensor> inputs;
    std::vector<Tensor> outputs;
    for(auto child : children_) {
      inputs.push_back(child->val());
      if(child->trainable())
        outputs.push_back(child->grad());
      else
        outputs.push_back(nullptr);
    }

    return {NodeOp(GRUFastBackward(graph()->allocator(), outputs, inputs, adj_, final_))};
  }

  // do not check if node is trainable
  virtual void runBackward(const NodeOps& ops) override {
    for(auto&& op : ops)
      op();
  }

  const std::string type() override { return "GRU-ops"; }

  const std::string color() override { return "yellow"; }

  virtual size_t hash() override {
    size_t seed = NaryNodeOp::hash();
    util::hash_combine(seed, final_);
    return seed;
  }

  virtual bool equal(Expr node) override {
    if(!NaryNodeOp::equal(node))
      return false;
    auto cnode = std::dynamic_pointer_cast<GRUFastNodeOp>(node);
    if(!cnode)
      return false;
    if(final_ != cnode->final_)
      return false;
    return true;
  }
};

/// Creates a GRUFastNodeOp expression over `nodes`; `final` is passed on to
/// the node's constructor unchanged.
Expr gruOps(const std::vector<Expr>& nodes, bool final) {
  Expr gruNode = Expression<GRUFastNodeOp>(nodes, final);
  return gruNode;
}

/******************************************************************************/

struct LSTMCellNodeOp : public NaryNodeOp {
  LSTMCellNodeOp(const std::vector<Expr>& nodes) : NaryNodeOp(nodes) {}

  NodeOps forwardOps() override {
    std::vector<Tensor> inputs;
    for(size_t i = 0; i < children_.size(); ++i)
      inputs.push_back(child(i)->val());

    return {NodeOp(LSTMCellForward(val_, inputs))};
  }

  NodeOps backwardOps() override {
    std::vector<Tensor> inputs;
    std::vector<Tensor> outputs;
    for(auto child : children_) {
      inputs.push_back(child->val());
      if(child->trainable())
        outputs.push_back(child->grad());
      else
        outputs.push_back(nullptr);
    }

    return {NodeOp(LSTMCellBackward(outputs, inputs, adj_))};
  }

  // do not check if node is trainable
  virtual void runBackward(const NodeOps& ops) override {
    for(auto&& op : ops)
      op();
  }

  const std::string type() override { return "LSTM-cell-ops"; }

  const std::string color() override { return "yellow"; }
};

struct LSTMOutputNodeOp : public NaryNodeOp {
  LSTMOutputNodeOp(const std::vector<Expr>& nodes) : NaryNodeOp(nodes) {}

  NodeOps forwardOps() override {
    std::vector<Tensor> inputs;
    for(size_t i = 0; i < children_.size(); ++i)
      inputs.push_back(child(i)->val());

    return {NodeOp(LSTMOutputForward(val_, inputs))};
  }

  NodeOps backwardOps() override {
    std::vector<Tensor> inputs;
    std::vector<Tensor> outputs;
    for(auto child : children_) {
      inputs.push_back(child->val());
      if(child->trainable())
        outputs.push_back(child->grad());
      else
        outputs.push_back(nullptr);
    }

    return {NodeOp(LSTMOutputBackward(outputs, inputs, adj_))};
  }

  // do not check if node is trainable
  virtual void runBackward(const NodeOps& ops) override {
    for(auto&& op : ops)
      op();
  }

  const std::string type() override { return "LSTM-output-ops"; }

  const std::string color() override { return "yellow"; }
};

/// Creates an LSTMCellNodeOp expression whose children are `nodes`.
Expr lstmOpsC(const std::vector<Expr>& nodes) {
  Expr cellNode = Expression<LSTMCellNodeOp>(nodes);
  return cellNode;
}

/// Creates an LSTMOutputNodeOp expression whose children are `nodes`.
Expr lstmOpsO(const std::vector<Expr>& nodes) {
  Expr outputNode = Expression<LSTMOutputNodeOp>(nodes);
  return outputNode;
}
}  // namespace rnn
}  // namespace marian