Program Listing for File node_operators_unary.h

Return to documentation for file (src/graph/node_operators_unary.h)

#pragma once

#include "tensors/backend.h"
#include "tensors/tensor.h"

#include "functional/functional.h"
#include "graph/node.h"
#include "tensors/tensor_operators.h"

#ifdef CUDNN
#include "tensors/gpu/cudnn_wrappers.h"
#endif

namespace marian {

// Common base for operators with exactly one input expression `a`.
// The constructor overloads differ only in which of the output shape and
// value type are supplied explicitly; whatever is omitted is taken from `a`.
struct UnaryNodeOp : public NaryNodeOp {
  // Output shape and value type both mirror the input.
  UnaryNodeOp(Expr a)
      : NaryNodeOp({a}, a->shape(), a->value_type()) {}

  // Explicit output shape, input's value type.
  UnaryNodeOp(Expr a, Shape shape)
      : NaryNodeOp({a}, shape, a->value_type()) {}

  // Input's shape, explicit output value type.
  UnaryNodeOp(Expr a, Type value_type)
      : NaryNodeOp({a}, a->shape(), value_type) {}

  // Both output shape and value type given explicitly.
  UnaryNodeOp(Expr a, Shape shape, Type value_type)
      : NaryNodeOp({a}, shape, value_type) {}

  // Color label reported for unary nodes (presumably for graph visualization).
  const std::string color() override { return "yellow"; }
};

// Element-wise y = x + scalar. Since dy/dx = 1, the incoming adjoint is
// accumulated into the input gradient unchanged.
struct ScalarAddNodeOp : public UnaryNodeOp {
public:
  ScalarAddNodeOp(Expr a, float scalar) : UnaryNodeOp(a), scalar_{scalar} {}

  const std::string type() override { return "scalar_add"; }

  NodeOps forwardOps() override {
    using namespace functional;
    // val_ = child + scalar_
    return {NodeOp(Element(_1 = _2 + scalar_, val_, child(0)->val()))};
  }

  NodeOps backwardOps() override {
    using namespace functional;
    // child.grad += adj_
    return {NodeOp(Add(_1, child(0)->grad(), adj_))};
  }

  // The added constant is part of the node identity.
  virtual size_t hash() override {
    if(hash_)
      return hash_;
    size_t seed = NaryNodeOp::hash();
    util::hash_combine(seed, scalar_);
    hash_ = seed;
    return hash_;
  }

  virtual bool equal(Expr node) override {
    if(!NaryNodeOp::equal(node))
      return false;
    auto other = std::dynamic_pointer_cast<ScalarAddNodeOp>(node);
    if(!other)
      return false;
    return scalar_ == other->scalar_;
  }

private:
  friend class SerializationHelpers;
  float scalar_{0};  // constant added to every element
};

// Converts a tensor to the element type `type`: the forward pass copies the
// input with CopyCast(), the backward pass accumulates the adjoint into the
// input's gradient with AddCast().
struct CastNodeOp : public UnaryNodeOp {
public:
  CastNodeOp(Expr a, Type type) : UnaryNodeOp(a, type) {}

  const std::string type() override { return "cast"; }

  NodeOps forwardOps() override {
    using namespace functional;
    return {NodeOp(CopyCast(val_, child(0)->val()))};
  }

  NodeOps backwardOps() override {
    using namespace functional;
    return {NodeOp(AddCast(child(0)->grad(), adj_))};
  }
};

// Element-wise y = scalar * x. The gradient is the adjoint scaled by the same
// constant: dJ/dx += scalar * dJ/dy.
struct ScalarMultNodeOp : public UnaryNodeOp {
public:
  ScalarMultNodeOp(Expr a, float scalar) : UnaryNodeOp(a), scalar_{scalar} {}

  const std::string type() override { return "scalar_mult"; }

  NodeOps forwardOps() override {
    using namespace functional;
    return {NodeOp(Element(_1 = scalar_ * _2, val_, child(0)->val()))};
  }

  NodeOps backwardOps() override {
    using namespace functional;
    return {NodeOp(Add(scalar_ * _1, child(0)->grad(), adj_))};
  }

  // The multiplier is part of the node identity.
  virtual size_t hash() override {
    if(hash_)
      return hash_;
    size_t seed = NaryNodeOp::hash();
    util::hash_combine(seed, scalar_);
    hash_ = seed;
    return hash_;
  }

  virtual bool equal(Expr node) override {
    if(!NaryNodeOp::equal(node))
      return false;
    auto other = std::dynamic_pointer_cast<ScalarMultNodeOp>(node);
    if(!other)
      return false;
    return scalar_ == other->scalar_;
  }

private:
  friend class SerializationHelpers;
  float scalar_{0};  // constant multiplier
};

// Element-wise clip(x, clip_) in the forward pass. In the backward pass the
// adjoint is multiplied by bump(x, clip_) (presumably 1 inside the clipping
// range and 0 outside — see the functional library for its exact definition).
struct ClipNodeOp : public UnaryNodeOp {
public:
  ClipNodeOp(Expr a, float clip) : UnaryNodeOp(a), clip_{clip} {}

  const std::string type() override { return "clip"; }

  NodeOps forwardOps() override {
    using namespace functional;
    return {NodeOp(Element(_1 = clip(_2, clip_), val_, child(0)->val()))};
  }

  NodeOps backwardOps() override {
    using namespace functional;
    // _1 := x (input value), _2 := adj_
    return {NodeOp(
        Add(bump(_1, clip_) * _2, child(0)->grad(), child(0)->val(), adj_))};
  }

  // The clipping threshold is part of the node identity.
  virtual size_t hash() override {
    if(hash_)
      return hash_;
    size_t seed = NaryNodeOp::hash();
    util::hash_combine(seed, clip_);
    hash_ = seed;
    return hash_;
  }

  virtual bool equal(Expr node) override {
    if(!NaryNodeOp::equal(node))
      return false;
    auto other = std::dynamic_pointer_cast<ClipNodeOp>(node);
    if(!other)
      return false;
    return clip_ == other->clip_;
  }

private:
  float clip_{0};  // threshold passed to clip()/bump()
};

// Logistic sigmoid y = sigmoid(x). Its derivative is expressed through the
// output: dy/dx = y * (1 - y), so the backward pass reads val_ instead of x.
struct SigmoidNodeOp : public UnaryNodeOp {
  SigmoidNodeOp(Expr a) : UnaryNodeOp(a) {}

  const std::string type() override { return "sigmoid"; }

  NodeOps forwardOps() override {
    using namespace functional;
    return {NodeOp(Element(_1 = sigmoid(_2), val_, child(0)->val()))};
  }

  NodeOps backwardOps() override {
    using namespace functional;
    // _1 := dJ/dy (adj_), _2 := y (val_)
    return {NodeOp(Add(_1 * _2 * (1.0f - _2), child(0)->grad(), adj_, val_))};
  }
};

// struct Scalar2PowNodeOp : public UnaryNodeOp {
// private:
//  float scalar_{0};
//
// public:
//  template <typename... Args>
//  Scalar2PowNodeOp(Expr a, float scalar, Args... args)
//      : UnaryNodeOp(a, args...), scalar_{scalar} {}
//
//  NodeOps forwardOps() {
//    return {NodeOp(Element(_1 = Pow(_2, scalar_), val_, child(0)->val()))};
//  }
//
//  NodeOps backwardOps() {
//    return {NodeOp(Add(scalar_ * Pow(_1, scalar_ - 1.f) * _2,
//    child(0)->grad(), child(0)->val(), adj_))};
//  }
//
//  const std::string type() { return "scalar_pow2"; }
//};
//
// struct Scalar1PowNodeOp : public UnaryNodeOp {
// private:
//  float scalar_{0};
//
// public:
//  template <typename... Args>
//  Scalar1PowNodeOp(float scalar, Expr a, Args... args)
//      : UnaryNodeOp(a, args...), scalar_{scalar} {}
//
//  NodeOps forwardOps() {
//    return {NodeOp(Element(_1 = Pow(scalar_, _2), val_, child(0)->val()))};
//  }
//
//  NodeOps backwardOps() {
//    return {NodeOp(Add(Pow(scalar_, _1) * log(scalar_) * _2, child(0)->grad(),
//    child(0)->val(), adj_))};
//  }
//
//  const std::string type() { return "scalar_pow1"; }
//};

// y = tanh(x_0 + x_1 + ... + x_{n-1}) over the broadcast shape of all inputs.
// One, two or three inputs are fused into a single element-wise kernel; with
// more inputs the sum is first accumulated into val_ and tanh is then applied
// in place.
struct TanhNodeOp : public NaryNodeOp {
  TanhNodeOp(const std::vector<Expr>& nodes)
      : NaryNodeOp(nodes, newShape(nodes)) {}

  // Output shape is the broadcast of all input shapes.
  Shape newShape(const std::vector<Expr>& nodes) {
    return Shape::broadcast(nodes);
  }

  NodeOps forwardOps() override {
    using namespace functional;
    const size_t numInputs = children_.size();

    if(numInputs == 1)
      return {NodeOp(Element(_1 = tanh(_2), val_, child(0)->val()))};

    if(numInputs == 2)
      return {NodeOp(Element(_1 = tanh(_2 + _3),
                             val_,
                             child(0)->val(),
                             child(1)->val()))};

    if(numInputs == 3)
      return {NodeOp(Element(_1 = tanh(_2 + _3 + _4),
                             val_,
                             child(0)->val(),
                             child(1)->val(),
                             child(2)->val()))};

    // Any other arity: seed val_ with the sum of the first three inputs, add
    // the remaining inputs one at a time, then squash the total in place.
    return {NodeOp(Element(_1 = _2 + _3 + _4,
                           val_,
                           child(0)->val(),
                           child(1)->val(),
                           child(2)->val());
                   for(size_t i = 3; i < children_.size(); ++i)
                     Element(_1 = _1 + _2, val_, child(i)->val());
                   Element(_1 = tanh(_1), val_);)};
  }

  NodeOps backwardOps() override {
    using namespace functional;
    // d tanh(s)/ds = 1 - tanh(s)^2 = 1 - y^2 and ds/dx_k = 1 for every input,
    // so each input gradient receives adj_ * (1 - val_^2).
    NodeOps ops;
    for(size_t k = 0; k < children_.size(); ++k)
      ops.push_back(
          NodeOp(Add(_1 * (1.0f - (_2 * _2)), child(k)->grad(), adj_, val_)));
    return ops;
  }

  const std::string color() override { return "yellow"; }

  const std::string type() override { return "tanh"; }
};

// Rectified linear unit y = max(0, x).
struct ReLUNodeOp : public UnaryNodeOp {
  ReLUNodeOp(Expr a) : UnaryNodeOp(a) {}

  NodeOps forwardOps() override {
    // f(x) = max(0, x)
    using namespace functional;
    return {NodeOp(Element(_1 = ReLU(_2),
                           val_,            // _1 := f(x) to be calculated
                           child(0)->val()  // _2 := x
                           ))};
  }

  NodeOps backwardOps() override {
    using namespace functional;
    // dJ/dx += dJ/df * binarystep(x)
    // Note: the step is evaluated on the input x (child value), not on f(x).
    return {NodeOp(Add(_1 * ReLUback(_2),
                       child(0)->grad(),  // dJ/dx
                       adj_,              // _1 := dJ/df
                       child(0)->val()    // _2 := x
                       ))};
  }

  const std::string type() override { return "ReLU"; }
};

// Parametric ReLU with fixed slope alpha_ for the negative part, evaluated by
// the PReLU()/PReLUback() functors. Note the argument order: (alpha, input).
struct PReLUNodeOp : public UnaryNodeOp {
  PReLUNodeOp(float alpha, Expr a) : UnaryNodeOp(a), alpha_(alpha) {}

  const std::string type() override { return "PReLU"; }

  NodeOps forwardOps() override {
    using namespace functional;
    return {NodeOp(Element(_1 = PReLU(_2, alpha_), val_, child(0)->val()))};
  }

  NodeOps backwardOps() override {
    using namespace functional;
    // _1 := adj_, _2 := x (input value)
    return {NodeOp(Add(_1 * PReLUback(_2, alpha_),
                       child(0)->grad(),
                       adj_,
                       child(0)->val()))};
  }

  // The slope is part of the node identity.
  virtual size_t hash() override {
    if(hash_)
      return hash_;
    size_t seed = NaryNodeOp::hash();
    util::hash_combine(seed, alpha_);
    hash_ = seed;
    return hash_;
  }

  virtual bool equal(Expr node) override {
    if(!NaryNodeOp::equal(node))
      return false;
    auto other = std::dynamic_pointer_cast<PReLUNodeOp>(node);
    if(!other)
      return false;
    return alpha_ == other->alpha_;
  }

private:
  float alpha_{0.01f};  // slope applied to negative inputs
};

// Swish activation f(x) = x * sigmoid(b * x).
// Derivative: f'(x) = sigmoid(bx) + b * f(x) * (1 - sigmoid(bx))
//                   = b*f(x) + sigmoid(bx) * (1 - b*f(x)),
// which lets the backward pass reuse the stored output f(x).
struct SwishNodeOp : public UnaryNodeOp {
  SwishNodeOp(Expr a, float b = 1.f) : UnaryNodeOp(a), b_{b} {}

  const std::string type() override { return "swish"; }

  NodeOps forwardOps() override {
    using namespace functional;
    return {NodeOp(Element(_1 = _2 * sigmoid(b_ * _2), val_, child(0)->val()))};
  }

  NodeOps backwardOps() override {
    using namespace functional;
    // _1 := dJ/df (adj_), _2 := x, _3 := f(x) (val_)
    return {NodeOp(Add(_1 * (b_ * _3 + sigmoid(b_ * _2) * (1.f - (b_ * _3))),
                       child(0)->grad(),
                       adj_,
                       child(0)->val(),
                       val_))};
  }

  // The steepness b_ is part of the node identity.
  virtual size_t hash() override {
    if(hash_)
      return hash_;
    size_t seed = NaryNodeOp::hash();
    util::hash_combine(seed, b_);
    hash_ = seed;
    return hash_;
  }

  virtual bool equal(Expr node) override {
    if(!NaryNodeOp::equal(node))
      return false;
    auto other = std::dynamic_pointer_cast<SwishNodeOp>(node);
    if(!other)
      return false;
    return b_ == other->b_;
  }

  float b_;  // steepness of the sigmoid gate
};

// Softmax of the input, computed by Softmax(); gradients by SoftmaxGrad().
struct SoftmaxNodeOp : public UnaryNodeOp {
  SoftmaxNodeOp(Expr a) : UnaryNodeOp(a) {}

  const std::string type() override { return "softmax"; }

  NodeOps forwardOps() override {
    return {NodeOp(Softmax(val_, child(0)->val()))};
  }

  NodeOps backwardOps() override {
    // Row-wise Jacobian-vector product of softmax:
    //   J * dy = p .* (dy - avg*1),   avg = p'*dy,
    // with p the softmax output (val_).
    // Reference: A. F. T. Martins and R. Astudillo, "From Softmax to
    // Sparsemax: A Sparse Model of Attention and Multi-Label Classification",
    // ICML 2016, sec. 2.5. http://jmlr.org/proceedings/papers/v48/martins16.pdf
    //
    // Any mask has already been applied to val_, so it is not re-applied here.
    return {NodeOp(SoftmaxGrad(child(0)->grad(), adj_, val_))};
  }
};

// Log-softmax of the input, computed by LogSoftmax(); gradients by
// LogSoftmaxGrad(). NOTE(review): presumably row-wise like softmax — confirm
// in the tensor operator.
struct LogSoftmaxNodeOp : public UnaryNodeOp {
  LogSoftmaxNodeOp(Expr a) : UnaryNodeOp(a) {}

  NodeOps forwardOps() override { return {NodeOp(LogSoftmax(val_, child(0)->val()))}; }

  NodeOps backwardOps() override {
    // With y = logsoftmax(x) = x - log(sum_j exp(x_j)) (per row):
    //   dy_i/dx_k = delta_ik - softmax(x)_k = delta_ik - exp(y_k)
    // so the vector-Jacobian product for an incoming dy is
    //   dJ/dx = dy - exp(y) * sum(dy)
    // where dy = adj_ and y = val_ is the log-softmax output (log-probabilities,
    // not probabilities).
    return {NodeOp(LogSoftmaxGrad(child(0)->grad(), adj_, val_))};
  }

  const std::string type() override { return "logsoftmax"; }
};

// Reduction selected by ReduceNodeOp. Enumerators keep their implicit values
// 0..7 in this order (ReduceNodeOp casts them to int for hashing and messages).
enum class ReduceNodeOpCode {
  sum,       // sum over the axis
  mean,      // arithmetic mean over the axis
  rms,       // sqrt(mean(x^2))
  meanSqr,   // mean(x^2)
  min,       // minimum
  max,       // maximum
  prod,      // product
  logSumExp  // log(sum(exp(x)))
};

// Reduces the input along one axis with the reduction selected by opCode_.
// The reduced axis stays in the output shape with size 1 (see newShape()).
// Below, N denotes reducedDim_, the size of the reduced axis.
struct ReduceNodeOp : public UnaryNodeOp {
  friend class SerializationHelpers;
  int axis_;                 // reduction axis, normalized by newShape()
  ReduceNodeOpCode opCode_;  // which reduction to compute
  int reducedDim_; // dimension of axis being reduced, e.g. used in mean()

  ReduceNodeOp(Expr a, int axis, ReduceNodeOpCode opCode)
      : UnaryNodeOp(a, newShape(a, axis)), opCode_(opCode)
  {
    reducedDim_ = a->shape()[axis]; // e.g. used in mean()
    // Sanity check: each output element must collapse exactly reducedDim_
    // input elements.
    ABORT_IF(reducedDim_ != a->shape().elements() / shape().elements(),
             "Bug in determining reducedDim {} != {}",
             reducedDim_,
             a->shape().elements() / shape().elements());
  }

  NodeOps forwardOps() override {
    using namespace functional;

    // Where a float scale follows the functor (mean/rms/meanSqr), it is
    // 1/N, turning the reduced sum into a mean.
    switch (opCode_) {
    case ReduceNodeOpCode::sum:
      return {NodeOp(Reduce(_1, val_, child(0)->val()))};
    case ReduceNodeOpCode::mean:
      return {NodeOp(Reduce(_1, 1.0f / (float)reducedDim_, val_, child(0)->val()))};
    case ReduceNodeOpCode::rms:
      return {NodeOp(Reduce(_1 * _1, 1.0f / (float)reducedDim_, val_, child(0)->val());
                     Element(_1 = sqrt(_1), val_))};
    case ReduceNodeOpCode::meanSqr:
      return {NodeOp(Reduce(_1 * _1, 1.0f / (float)reducedDim_, val_, child(0)->val()))};
    case ReduceNodeOpCode::min:
      return {NodeOp(Reduce(_1, min(_1,_2), std::numeric_limits<float>::max(), val_, child(0)->val()))};
    case ReduceNodeOpCode::max:
      return {NodeOp(Reduce(_1, max(_1,_2), std::numeric_limits<float>::lowest(), val_, child(0)->val()))};
    case ReduceNodeOpCode::prod:
      return {NodeOp(Reduce(_1, _1 * _2, 1.0f, val_, child(0)->val()))};
    case ReduceNodeOpCode::logSumExp:
      return {NodeOp(Reduce(_1, logaddexp(_1,_2), std::numeric_limits<float>::lowest(), val_, child(0)->val()))};
    default:
      ABORT("Unexpected reduction op-code {}", (int)opCode_);
    }
  }

  NodeOps backwardOps() override {
    using namespace functional;
#if 1 // @BUGBUG: This is a workaround for not correctly propagating non-trainable information. @TODO: Do this the right and general way.
    if (adj_ == nullptr)
      return {};
#endif
    switch (opCode_) {
    case ReduceNodeOpCode::sum:
      return {NodeOp(Add(_1, child(0)->grad(), adj_))};
    case ReduceNodeOpCode::mean:
      return {NodeOp(Add(_1, 1.0f / (float)reducedDim_, child(0)->grad(), adj_))};
    case ReduceNodeOpCode::rms: // WARNING: UNTESTED!!
      // y = (1/N sum_j x_j^2)^0.5
      // dJ/dx_i = dJ/dy * 0.5 (1/N sum_j x_j^2)^-0.5 * 2 x_i / N = dJ/dy * x_i / (N y)
      // The 1/N factor is applied through Add's scale argument.
      // @TODO: do we need protection against div by 0 (y == 0 when all x_j == 0)?
      return {NodeOp(Add(_1 * _2 / _3, 1.0f / (float)reducedDim_, child(0)->grad(), adj_, child(0)->val(), val_))};
    case ReduceNodeOpCode::meanSqr: // WARNING: UNTESTED!!
      // y = 1/N sum_j x_j^2
      // dJ/dx_i = dJ/dy * 2 x_i / N   (1/N applied through Add's scale argument)
      return {NodeOp(Add(_1 * 2.0f * _2, 1.0f / (float)reducedDim_, child(0)->grad(), adj_, child(0)->val()))};
    case ReduceNodeOpCode::min:  // WARNING: UNTESTED!!
    case ReduceNodeOpCode::max:  // WARNING: UNTESTED!!
      // adj_ is routed to every input element equal to the reduced min/max
      // (with ties, each tied element receives the full adjoint).
      // _1 := x, _2 := y (val_), _3 := adj_
      return {NodeOp(Add((_1 == _2) * _3, child(0)->grad(), child(0)->val(), val_, adj_))};
    case ReduceNodeOpCode::logSumExp:
      // y = log(sum_j exp(x_j))
      // dJ/dx_i = dJ/dy * exp(x_i) / sum_j exp(x_j) = dJ/dy * exp(x_i - y)
      return {NodeOp(Add(_1 * exp(_2 - _3), child(0)->grad(), adj_, child(0)->val(), val_))};
    case ReduceNodeOpCode::prod:
      // Not implemented: dJ/dx_i = dJ/dy * prod_{j != i} x_j needs special
      // handling of zero inputs. Fail with a specific message rather than the
      // generic "unexpected op-code" one.
      ABORT("Backward step of reduction op-code {} (prod) is not implemented", (int)opCode_);
    default:
      ABORT("Unexpected reduction op-code {}", (int)opCode_);
    }
  }

  // Output shape: input shape with the (normalized) reduction axis set to 1.
  // Side effect: stores the normalized axis in axis_.
  Shape newShape(Expr a, int axis) {
    Shape shape = a->shape();
    axis_ = shape.axis(axis);

    shape.set(axis_, 1);
    return shape;
  }

  const std::string type() override {
    switch (opCode_) {
    case ReduceNodeOpCode::sum:       return "sum";
    case ReduceNodeOpCode::mean:      return "mean";
    case ReduceNodeOpCode::rms:       return "rms";
    case ReduceNodeOpCode::meanSqr:   return "meanSqr";
    case ReduceNodeOpCode::min:       return "min";
    case ReduceNodeOpCode::max:       return "max";
    case ReduceNodeOpCode::prod:      return "prod";
    case ReduceNodeOpCode::logSumExp: return "logSumExp";
    default: ABORT("Unexpected reduction op-code {}", (int)opCode_);
    }
  }

  const std::string color() override { return "orange"; }

  // Axis and op-code are both part of the node identity.
  virtual size_t hash() override {
    if(!hash_) {
      hash_ = NaryNodeOp::hash();
      util::hash_combine(hash_, axis_);
      util::hash_combine(hash_, (int)opCode_);
    }
    return hash_;
  }

  virtual bool equal(Expr node) override {
    if(!NaryNodeOp::equal(node))
      return false;
    auto cnode = std::dynamic_pointer_cast<ReduceNodeOp>(node);
    if(!cnode)
      return false;
    if(axis_ != cnode->axis_ || opCode_ != cnode->opCode_)
      return false;
    return true;
  }
};

// Natural logarithm y = log(x); dy/dx = 1/x.
struct LogNodeOp : public UnaryNodeOp {
  LogNodeOp(Expr a) : UnaryNodeOp(a) {}

  const std::string type() override { return "log"; }

  NodeOps forwardOps() override {
    using namespace functional;
    return {NodeOp(Element(_1 = log(_2), val_, child(0)->val()))};
  }

  NodeOps backwardOps() override {
    using namespace functional;
    // _1 := adj_, _2 := x; the adjoint is divided by x directly.
    return {NodeOp(Add(_1 / _2, child(0)->grad(), adj_, child(0)->val()))};
  }
};

// Exponential y = exp(x); dy/dx = exp(x), recomputed from the input in the
// backward pass.
struct ExpNodeOp : public UnaryNodeOp {
  ExpNodeOp(Expr a) : UnaryNodeOp(a) {}

  const std::string type() override { return "exp"; }

  NodeOps forwardOps() override {
    using namespace functional;
    return {NodeOp(Element(_1 = exp(_2), val_, child(0)->val()))};
  }

  NodeOps backwardOps() override {
    using namespace functional;
    // _1 := adj_, _2 := x
    return {NodeOp(Add(_1 * exp(_2), child(0)->grad(), adj_, child(0)->val()))};
  }
};

// Sine y = sin(x); dy/dx = cos(x).
struct SinNodeOp : public UnaryNodeOp {
  SinNodeOp(Expr a) : UnaryNodeOp(a) {}

  const std::string type() override { return "sin"; }

  NodeOps forwardOps() override {
    using namespace functional;
    return {NodeOp(Element(_1 = sin(_2), val_, child(0)->val()))};
  }

  NodeOps backwardOps() override {
    using namespace functional;
    // _1 := adj_, _2 := x
    return {NodeOp(Add(_1 * cos(_2), child(0)->grad(), adj_, child(0)->val()))};
  }
};

// Cosine y = cos(x); dy/dx = -sin(x).
struct CosNodeOp : public UnaryNodeOp {
  CosNodeOp(Expr a) : UnaryNodeOp(a) {}

  const std::string type() override { return "cos"; }

  NodeOps forwardOps() override {
    using namespace functional;
    return {NodeOp(Element(_1 = cos(_2), val_, child(0)->val()))};
  }

  NodeOps backwardOps() override {
    using namespace functional;
    // _1 := adj_, _2 := x
    return {NodeOp(Add(_1 * -sin(_2), child(0)->grad(), adj_, child(0)->val()))};
  }
};

// Tangent y = tan(x); dy/dx = 1 / cos(x)^2.
struct TanNodeOp : public UnaryNodeOp {
  TanNodeOp(Expr a) : UnaryNodeOp(a) {}

  const std::string type() override { return "tan"; }

  NodeOps forwardOps() override {
    using namespace functional;
    return {NodeOp(Element(_1 = tan(_2), val_, child(0)->val()))};
  }

  NodeOps backwardOps() override {
    using namespace functional;
    // _1 := adj_, _2 := x
    return {NodeOp(Add(_1 / sqr(cos(_2)), child(0)->grad(), adj_, child(0)->val()))};
  }
};

// y = sqrt(x + epsilon). The backward pass uses the stored output:
// dy/dx = 0.5 / y.
struct SqrtNodeOp : public UnaryNodeOp {
  float epsilon_;  // added under the root before taking sqrt

  SqrtNodeOp(Expr a, float epsilon) : UnaryNodeOp(a), epsilon_(epsilon) {}

  const std::string type() override { return "sqrt"; }

  NodeOps forwardOps() override {
    using namespace functional;
    return {NodeOp(Element(_1 = sqrt(_2 + epsilon_), val_, child(0)->val()))};
  }

  NodeOps backwardOps() override {
    using namespace functional;
    // _1 := y (val_), _2 := adj_
    return {NodeOp(Add(0.5f * (1.f / _1) * _2, child(0)->grad(), val_, adj_))};
  }

  // epsilon is part of the node identity.
  virtual size_t hash() override {
    if(hash_)
      return hash_;
    size_t seed = NaryNodeOp::hash();
    util::hash_combine(seed, epsilon_);
    hash_ = seed;
    return hash_;
  }

  virtual bool equal(Expr node) override {
    if(!NaryNodeOp::equal(node))
      return false;
    auto other = std::dynamic_pointer_cast<SqrtNodeOp>(node);
    if(!other)
      return false;
    return epsilon_ == other->epsilon_;
  }
};

// Element-wise square y = x^2; dy/dx = 2x.
struct SquareNodeOp : public UnaryNodeOp {
  SquareNodeOp(Expr a) : UnaryNodeOp(a) {}

  const std::string type() override { return "square"; }

  NodeOps forwardOps() override {
    using namespace functional;
    return {NodeOp(Element(_1 = _2 * _2, val_, child(0)->val()))};
  }

  NodeOps backwardOps() override {
    using namespace functional;
    // _1 := x, _2 := adj_
    return {NodeOp(Add(2.f * _1 * _2, child(0)->grad(), child(0)->val(), adj_))};
  }
};

// Negation y = -x; the adjoint is negated on the way back.
struct NegNodeOp : public UnaryNodeOp {
  NegNodeOp(Expr a) : UnaryNodeOp(a) {}

  const std::string type() override { return "negate"; }

  NodeOps forwardOps() override {
    using namespace functional;
    return {NodeOp(Element(_1 = -_2, val_, child(0)->val()))};
  }

  NodeOps backwardOps() override {
    using namespace functional;
    return {NodeOp(Add(-_1, child(0)->grad(), adj_))};
  }
};

// Permutes the axes of the input: output axis i is input axis axes[i].
// `axes` must be a permutation of [0, rank); this is checked in newShape().
// The backward pass applies the inverse permutation (axesBw_).
struct TransposeNodeOp : public UnaryNodeOp {
  TransposeNodeOp(Expr a, const std::vector<int>& axes)
      : UnaryNodeOp(a, newShape(a, axes)), axes_{axes}, axesBw_(axes.size()) {
    // Inverse permutation: axesBw_[axes_[i]] = i. newShape() has already
    // verified that axes_ is a permutation, so every index here is in range.
    for(size_t i = 0; i < axes_.size(); ++i)
      axesBw_[axes_[i]] = (int)i;
  }

  NodeOps forwardOps() override {
    return {NodeOp(TransposeND(val_, child(0)->val(), axes_))};
  }

  NodeOps backwardOps() override {
    return {NodeOp(TransposeNDGrad(child(0)->grad(), adj_, axesBw_))};
  }

  // Output shape: dimension i is the input's dimension axes[i].
  // Aborts unless `axes` is a permutation of [0, rank); out-of-range or
  // repeated axes would otherwise index out of bounds here and when building
  // the inverse permutation.
  Shape newShape(Expr a, const std::vector<int>& axes) {
    Shape shape = a->shape();

    ABORT_IF(shape.size() != axes.size(),
             "Shape and transpose axes have different number of dimensions");

    std::vector<bool> seen(axes.size(), false);
    for(size_t i = 0; i < axes.size(); ++i) {
      int ax = axes[i];
      ABORT_IF(ax < 0 || (size_t)ax >= axes.size() || seen[ax],
               "Transpose axes must be a permutation of [0, {}), got axis {} at position {}",
               axes.size(),
               ax,
               i);
      seen[ax] = true;
    }

    for(size_t i = 0; i < shape.size(); ++i)
      shape.set(i, a->shape()[axes[i]]);

    return shape;
  }

  // The permutation is part of the node identity.
  virtual size_t hash() override {
    if(!hash_) {
      size_t seed = NaryNodeOp::hash();
      for(auto ax : axes_)
        util::hash_combine(seed, ax);
      hash_ = seed;
    }
    return hash_;
  }

  virtual bool equal(Expr node) override {
    if(!NaryNodeOp::equal(node))
      return false;
    auto cnode = std::dynamic_pointer_cast<TransposeNodeOp>(node);
    if(!cnode)
      return false;
    if(axes_ != cnode->axes_)
      return false;
    return true;
  }

  const std::string type() override { return "transpose"; }

  const std::string color() override { return "orange"; }

private:
  friend class SerializationHelpers;
  std::vector<int> axes_;    // forward permutation
  std::vector<int> axesBw_;  // inverse permutation used for gradients
};

// Reshape without copying: the node exposes the memory of `reshapee_` under a
// new shape with the same element count. It owns no storage, so
// allocate/free/forward/backward are no-ops. Dependent initialization and
// adjoint zeroing are delegated to the reshaped node.
class ReshapeNodeOp : public UnaryNodeOp {
protected:
  friend class SerializationHelpers;
  Expr reshapee_;  // node whose value/gradient memory is reused

public:
  ReshapeNodeOp(Expr a, Shape shape) : UnaryNodeOp(a, shape), reshapee_(a) {
    ABORT_IF(a->shape().elements() != shape.elements(),
             "Reshape must not change the number of elements (from {} to {})", a->shape().toString(), shape.toString());
    // NOTE(review): destroy_ is a Node flag; cleared here presumably because
    // this node does not own the memory it exposes — confirm in node.h.
    Node::destroy_ = false;
  }

  ~ReshapeNodeOp() {}

  // Nothing to allocate, release or compute.
  void allocate() override {}
  void free() override {}
  void forward() override {}
  void backward() override {}

  void init_dependent() override { reshapee_->init_dependent(); }
  void set_zero_adjoint() override { reshapee_->set_zero_adjoint(); }

  // Builds a fresh view on every call so it always reflects the reshaped
  // node's current value tensor.
  Tensor& val() override {
    auto source = reshapee_->val();
    auto view = TensorBase::New(source->memory(), shape(), source->type(), source->getBackend());
    val_.swap(view);
    return val_;
  }

  // Same as val(), but over the reshaped node's gradient tensor.
  Tensor& grad() override {
    auto source = reshapee_->grad();
    auto view = TensorBase::New(source->memory(), shape(), source->type(), source->getBackend());
    adj_.swap(view);
    return adj_;
  }

  const std::string type() override { return "reshape"; }
  const std::string color() override { return "grey"; }

  // Reshapes of the same input are only interchangeable when they target the
  // same shape, so every target dimension is mixed into the hash.
  virtual size_t hash() override {
    if(hash_)
      return hash_;
    size_t seed = NaryNodeOp::hash();
    for(auto dim : shape())
      util::hash_combine(seed, dim);
    hash_ = seed;
    return hash_;
  }

  virtual bool equal(Expr node) override {
    if(!NaryNodeOp::equal(node))
      return false;
    auto other = std::dynamic_pointer_cast<ReshapeNodeOp>(node);
    if(!other)
      return false;
    return !(shape() != other->shape());
  }
};

// @TODO: add version with access to backward step
// This allows to attach a lambda function to any node during the execution. It is a non-operation otherwise
// i.e. doesn't consume any memory or take any time to execute (it's a reshape onto itself) other than the
// compute in the lambda function. It gets called after the forward step of the argument node.
class CallbackNodeOp : public ReshapeNodeOp {
private:
  typedef std::function<void(Expr)> LambdaNodeCallback;
  std::unique_ptr<LambdaNodeCallback> callback_;

public:
  CallbackNodeOp(Expr node, LambdaNodeCallback callback)
  : ReshapeNodeOp(node, node->shape()),
    callback_(new LambdaNodeCallback(callback)) {
  }

  void forward() override {
    (*callback_)(ReshapeNodeOp::reshapee_);
  }

  const std::string type() override { return "callback"; }

  virtual size_t hash() override {
    size_t seed = ReshapeNodeOp::hash();
    util::hash_combine(seed, callback_.get());
    return seed;
  }

  virtual bool equal(Expr node) override {
    if(!ReshapeNodeOp::equal(node))
      return false;
    auto cnode = std::dynamic_pointer_cast<CallbackNodeOp>(node);
    if(!cnode)
      return false;
    if(callback_ != cnode->callback_)   // pointer compare on purpose
      return false;
    return true;
  }
};

// @TODO: review if still required as this is an ugly hack anyway.
// Memory-less pass-through node: val()/grad() are views onto the clipee's own
// tensors and the forward step does nothing. The backward step clips adj_ in
// place with clip(_, clipValue_), as an extra operation on the gradient.
class ClipGradientNodeOp : public UnaryNodeOp {
private:
  Expr clipee_;         // node whose value/gradient this node aliases
  float clipValue_{0};  // threshold passed to clip()

public:
  ClipGradientNodeOp(Expr a, float clipValue)
   : UnaryNodeOp(a), clipee_(a), clipValue_(clipValue) {
    // NOTE(review): destroy_ cleared presumably because the exposed memory
    // belongs to clipee_ — confirm in node.h.
    Node::destroy_ = false;
  }

  ~ClipGradientNodeOp() {}

  // Nothing to allocate, release or compute forward.
  void allocate() override {}
  void free() override {}
  void forward() override {}

  void backward() override {
    using namespace marian::functional;
    Element(_1 = clip(_1, clipValue_), adj_);
  }

  void init_dependent() override { clipee_->init_dependent(); }
  void set_zero_adjoint() override { clipee_->set_zero_adjoint(); }

  // Fresh view onto the clipee's value tensor on every call.
  Tensor& val() override {
    auto source = clipee_->val();
    auto view = TensorBase::New(source->memory(), shape(), source->type(), source->getBackend());
    val_.swap(view);
    return val_;
  }

  // Fresh view onto the clipee's gradient tensor on every call.
  Tensor& grad() override {
    auto source = clipee_->grad();
    auto view = TensorBase::New(source->memory(), shape(), source->type(), source->getBackend());
    adj_.swap(view);
    return adj_;
  }

  const std::string type() override { return "clipGradient"; }
  const std::string color() override { return "grey"; }

  // The clipping threshold is part of the node identity.
  virtual size_t hash() override {
    if(hash_)
      return hash_;
    size_t seed = NaryNodeOp::hash();
    util::hash_combine(seed, clipValue_);
    hash_ = seed;
    return hash_;
  }

  virtual bool equal(Expr node) override {
    if(!NaryNodeOp::equal(node))
      return false;
    auto other = std::dynamic_pointer_cast<ClipGradientNodeOp>(node);
    if(!other)
      return false;
    return clipValue_ == other->clipValue_;
  }
};

// View of the input narrowed along one axis to the index range [begin, end).
// No data is copied: val()/grad() return tensors over a byte range
// [byteOffset_, byteOffset_ + byteSize_) of the viewed node's memory. Because
// the view is a single contiguous byte range, the slice must be consecutive
// in memory; newShape() enforces this.
class SliceViewNodeOp : public UnaryNodeOp {
private:
  friend class SerializationHelpers;
  Expr viewedNode_; // viewed underlying node
  Slice slice_;     // index range
  int axis_;        // and axis along which it is viewed
  size_t byteOffset_, byteSize_; // viewed segment in bytes (memory-consecutive)

public:
  // `axis` and `slice` are normalized in place by newShape(). It runs as part
  // of the base-class initializer, i.e. before slice_ and axis_ are
  // initialized from these parameters, so the members hold normalized values.
  SliceViewNodeOp(Expr a, int axis, Slice slice)
      : UnaryNodeOp(a, newShape(a, axis, slice), a->value_type()), viewedNode_(a), slice_(slice), axis_(axis) {
    // NOTE(review): destroy_ cleared presumably because the viewed memory
    // belongs to viewedNode_ — confirm in node.h.
    Node::destroy_ = false;
    // Bytes covered by one index step along `axis` in the input.
    auto byteStride = a->shape().stride(axis) * sizeOf(value_type());
    byteOffset_ = slice.begin * byteStride;
    // shape()[axis] is the output extent along axis (end - begin, see newShape()).
    byteSize_ = shape()[axis] * byteStride;
  }

  // Computes the output shape (input shape with dimension `axis` set to
  // end - begin). Aborts if the slice has a stride other than 1, or if any
  // axis before `axis` has size > 1 (the slice would not be consecutive),
  // unless the slice covers the whole axis with stride 1 (a no-op view).
  static Shape newShape(Expr a, int& axis, Slice& slice) { // note: normalizes slice and axis in-place
    const auto& shape = a->shape();
    axis  = shape.axis(axis);         // normalize negative axis
    slice = shape.slice(slice, axis); // normalize negative slice values
    // enforce consecutive memory
    if (slice.begin != 0 || slice.end != shape[axis] || slice.stride != 1) { // unless it's a no-op
      ABORT_IF(slice.stride != 1, "Strides other than 1 are presently not supported by sliceView()");
      for(int i = 0; i < axis; ++i)
        ABORT_IF(shape[i] != 1, "Non-consecutive slices are presently not supported by sliceView()");
    }
    Shape outShape = shape;
    outShape.set(axis, slice.end - slice.begin);
    return outShape;
  }

  // The view owns no memory and computes nothing.
  void allocate() override {}
  void free() override {}

  void forward() override {}
  void backward() override {}

  void init_dependent() override { viewedNode_->init_dependent(); }

  void set_zero_adjoint() override { viewedNode_->set_zero_adjoint(); } // lazily allocate and zero out gradient (only runs once)

  // Rebuilds the view on every call from the viewed node's current value:
  // a MemoryPiece over the selected bytes, wrapped in a tensor of this shape.
  Tensor& val() override {
    auto childVal = viewedNode_->val();
    auto mem = MemoryPiece::New(childVal->memory()->data() + byteOffset_, byteSize_);
    auto temp = TensorBase::New(mem, shape(), childVal->type(), childVal->getBackend());
    val_.swap(temp);
    return val_;
  };

  // Same as val(), but over the viewed node's gradient tensor.
  Tensor& grad() override {
    auto childGrad = viewedNode_->grad();
    auto mem = MemoryPiece::New(childGrad->memory()->data() + byteOffset_, byteSize_);
    auto temp = TensorBase::New(mem, shape(), childGrad->type(), childGrad->getBackend());
    adj_.swap(temp);
    return adj_;
  };

  const std::string type() override { return "sliceView"; }

  const std::string color() override { return "grey"; }

  // Slice range, stride and axis are all part of the node identity.
  virtual size_t hash() override {
    if(!hash_) {
      hash_ = NaryNodeOp::hash();
      util::hash_combine(hash_, slice_.begin);
      util::hash_combine(hash_, slice_.end);
      util::hash_combine(hash_, slice_.stride);
      util::hash_combine(hash_, axis_);
    }
    return hash_;
  }

  virtual bool equal(Expr node) override {
    if(!NaryNodeOp::equal(node))
      return false;
    auto cnode = std::dynamic_pointer_cast<SliceViewNodeOp>(node);
    if(!cnode)
      return false;
    if(slice_ != cnode->slice_)
      return false;
    if(axis_ != cnode->axis_)
      return false;
    return true;
  }
};

// Shifts the input by the per-dimension offsets in shift_ (via Shift()),
// filling vacated positions with padValue_. The output has the input's shape.
struct ShiftNodeOp : public UnaryNodeOp {
  ShiftNodeOp(Expr a, Shape shift, float padValue)
      : UnaryNodeOp(a, a->shape()), shift_(shift), padValue_(padValue) {}

  const std::string type() override { return "shift"; }

  NodeOps forwardOps() override {
    return {NodeOp(Shift(val_,
                         child(0)->val(),
                         shift_,
                         padValue_,
                         /*invert=*/false))};
  }

  NodeOps backwardOps() override {
    // Gradient is routed back through ShiftGrad().
    // NOTE(review): an older comment described the last argument as
    // "beta=1, use += (out = in + beta * out)"; Shift() above passes its
    // /*invert=*/ flag in the corresponding position — confirm which meaning
    // ShiftGrad's 4th parameter has.
    // @TODO: check need for padValue_
    return {NodeOp(ShiftGrad(child(0)->grad(), adj_, shift_, true))};
  }

  // Shift offsets and pad value are part of the node identity.
  virtual size_t hash() override {
    if(hash_)
      return hash_;
    size_t seed = NaryNodeOp::hash();
    for(auto offset : shift_)
      util::hash_combine(seed, offset);
    util::hash_combine(seed, padValue_);
    hash_ = seed;
    return hash_;
  }

  virtual bool equal(Expr node) override {
    if(!NaryNodeOp::equal(node))
      return false;
    auto other = std::dynamic_pointer_cast<ShiftNodeOp>(node);
    if(!other)
      return false;
    if(shift_ != other->shift_)
      return false;
    return padValue_ == other->padValue_;
  }

private:
  friend class SerializationHelpers;
  Shape shift_;     // shift offsets in each dimension
  float padValue_;  // what value to shift in
};

// Absolute value y = |x|; dy/dx = sgn(x).
struct AbsNodeOp : public UnaryNodeOp {
  AbsNodeOp(Expr a) : UnaryNodeOp(a) {}

  const std::string type() override { return "abs"; }

  NodeOps forwardOps() override {
    using namespace functional;
    return {NodeOp(Element(_1 = abs(_2), val_, child(0)->val()))};
  }

  NodeOps backwardOps() override {
    using namespace functional;
    // _1 := x, _2 := adj_
    return {NodeOp(Add(sgn(_1) * _2, child(0)->grad(), child(0)->val(), adj_))};
  }
};

#ifdef CUDNN
// Pooling layer delegating forward and backward to a cuDNN PoolingWrapper.
// Only compiled when CUDNN is defined.
class PoolingOp : public UnaryNodeOp {
public:
  PoolingOp(Expr x,
            int height,
            int width,
            int padHeight,
            int padWidth,
            int strideHeight,
            int strideWidth,
            std::string mode)
      : UnaryNodeOp(x),
        pooling_(height, width,
                 padHeight, padWidth,
                 strideHeight, strideWidth,
                 mode) {}

  const std::string type() override { return "layer_pooling"; }

  NodeOps forwardOps() override {
    return {NodeOp(pooling_.forward(child(0)->val(), val_))};
  }

  NodeOps backwardOps() override {
    return {NodeOp(pooling_.backward(child(0)->val(),
                                     child(0)->grad(),
                                     val_,
                                     adj_))};
  }

protected:
  PoolingWrapper pooling_;  // cuDNN pooling descriptor/state
};
#endif

class PoolingWithMaskingOp : public UnaryNodeOp {
public:
  PoolingWithMaskingOp(Expr x, Expr mask, int width, bool isEven = false)
      : UnaryNodeOp(x), mask_(mask), width_(width), isEven_(isEven) {
    auto xShape = x->shape();
    int dimBatch = xShape[0];
    int dimWord = xShape[1];
    int cols = (isEven_) ? xShape[2] - 1 : xShape[2];
    int dimSentence = (cols / width_) + (cols % width_ != 0);
    shape_ = {dimBatch, dimWord, dimSentence};
  }

  NodeOps forwardOps() override {
    return {NodeOp(PoolingWithMaskingForward(
        val_, child(0)->val(), mask_->val(), width_, isEven_))};
  }

  NodeOps backwardOps() override {
    return {NodeOp(PoolingWithMaskingBackward(adj_,
                                              child(0)->grad(),
                                              child(0)->val(),
                                              mask_->val(),
                                              width_,
                                              isEven_))};
  }

  const std::string type() override { return "layer_pooling"; }

protected:
  friend class SerializationHelpers;
  Expr mask_;
  int width_;
  bool isEven_;
};
}  // namespace marian