.. _program_listing_file_src_graph_node_operators_unary.h:

Program Listing for File node_operators_unary.h
===============================================

|exhale_lsh| :ref:`Return to documentation for file <file_src_graph_node_operators_unary.h>` (``src/graph/node_operators_unary.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   #pragma once

   #include "tensors/backend.h"
   #include "tensors/tensor.h"

   #include "functional/functional.h"
   #include "graph/node.h"
   #include "tensors/tensor_operators.h"

   #ifdef CUDNN
   #include "tensors/gpu/cudnn_wrappers.h"
   #endif

   namespace marian {

   struct UnaryNodeOp : public NaryNodeOp {
     UnaryNodeOp(Expr a, Shape shape, Type value_type)
         : NaryNodeOp({a}, shape, value_type) {}

     UnaryNodeOp(Expr a, Type value_type)
         : NaryNodeOp({a}, a->shape(), value_type) {}

     UnaryNodeOp(Expr a, Shape shape) : NaryNodeOp({a}, shape, a->value_type()) {}

     UnaryNodeOp(Expr a) : NaryNodeOp({a}, a->shape(), a->value_type()) {}

     const std::string color() override { return "yellow"; }
   };

   struct ScalarAddNodeOp : public UnaryNodeOp {
   private:
     friend class SerializationHelpers;
     float scalar_{0};

   public:
     ScalarAddNodeOp(Expr a, float scalar) : UnaryNodeOp(a), scalar_{scalar} {}

     NodeOps forwardOps() override {
       using namespace functional;
       return {NodeOp(Element(_1 = _2 + scalar_, val_, child(0)->val()))};
     }

     NodeOps backwardOps() override {
       using namespace functional;
       return {NodeOp(Add(_1, child(0)->grad(), adj_))};
     }

     const std::string type() override { return "scalar_add"; }

     virtual size_t hash() override {
       if(!hash_) {
         hash_ = NaryNodeOp::hash();
         util::hash_combine(hash_, scalar_);
       }
       return hash_;
     }

     virtual bool equal(Expr node) override {
       if(!NaryNodeOp::equal(node))
         return false;
       auto cnode = std::dynamic_pointer_cast<ScalarAddNodeOp>(node);
       if(!cnode)
         return false;
       if(scalar_ != cnode->scalar_)
         return false;
       return true;
     }
   };

   // Cast a tensor to a different type
   struct CastNodeOp : public UnaryNodeOp {
   public:
     CastNodeOp(Expr a, Type type) : UnaryNodeOp(a, type) {}

     NodeOps forwardOps() override {
       using namespace functional;
       return { NodeOp(CopyCast(val_, child(0)->val())) };
     }

     NodeOps backwardOps() override {
       using namespace functional;
       return { NodeOp(AddCast(child(0)->grad(), adj_)) };
     }

     const std::string type() override { return "cast"; }
   };

   struct ScalarMultNodeOp : public UnaryNodeOp {
   private:
     friend class SerializationHelpers;
     float scalar_{0};

   public:
     ScalarMultNodeOp(Expr a, float scalar) : UnaryNodeOp(a), scalar_{scalar} {}

     NodeOps forwardOps() override {
       using namespace functional;
       return {NodeOp(Element(_1 = scalar_ * _2, val_, child(0)->val()))};
     }

     NodeOps backwardOps() override {
       using namespace functional;
       return {NodeOp(Add(scalar_ * _1, child(0)->grad(), adj_))};
     }

     const std::string type() override { return "scalar_mult"; }

     virtual size_t hash() override {
       if(!hash_) {
         hash_ = NaryNodeOp::hash();
         util::hash_combine(hash_, scalar_);
       }
       return hash_;
     }

     virtual bool equal(Expr node) override {
       if(!NaryNodeOp::equal(node))
         return false;
       auto cnode = std::dynamic_pointer_cast<ScalarMultNodeOp>(node);
       if(!cnode)
         return false;
       if(scalar_ != cnode->scalar_)
         return false;
       return true;
     }
   };

   struct ClipNodeOp : public UnaryNodeOp {
   private:
     float clip_{0};

   public:
     ClipNodeOp(Expr a, float clip) : UnaryNodeOp(a), clip_{clip} {}

     NodeOps forwardOps() override {
       using namespace functional;
       return {NodeOp(Element(_1 = clip(_2, clip_), val_, child(0)->val()))};
     }

     NodeOps backwardOps() override {
       using namespace functional;
       return {NodeOp(
           Add(bump(_1, clip_) * _2, child(0)->grad(), child(0)->val(), adj_))};
     }

     const std::string type() override { return "clip"; }

     virtual size_t hash() override {
       if(!hash_) {
         hash_ = NaryNodeOp::hash();
         util::hash_combine(hash_, clip_);
       }
       return hash_;
     }

     virtual bool equal(Expr node) override {
       if(!NaryNodeOp::equal(node))
         return false;
       auto cnode = std::dynamic_pointer_cast<ClipNodeOp>(node);
       if(!cnode)
         return false;
       if(clip_ != cnode->clip_)
         return false;
       return true;
     }
   };

   struct SigmoidNodeOp : public UnaryNodeOp {
     SigmoidNodeOp(Expr a) : UnaryNodeOp(a) {}

     NodeOps forwardOps() override {
       using namespace functional;
       return {NodeOp(Element(_1 = sigmoid(_2), val_, child(0)->val()))};
     }

     NodeOps backwardOps() override {
       using namespace functional;
       return {NodeOp(Add(_1 * _2 * (1.0f - _2), child(0)->grad(), adj_, val_))};
     }

     const std::string type() override { return "sigmoid"; }
   };

   // struct Scalar2PowNodeOp : public UnaryNodeOp {
   // private:
   //   float scalar_{0};
   //
   // public:
   //   template <typename... Args>
   //   Scalar2PowNodeOp(Expr a, float scalar, Args... args)
   //       : UnaryNodeOp(a, args...), scalar_{scalar} {}
   //
   //   NodeOps forwardOps() {
   //     return {NodeOp(Element(_1 = Pow(_2, scalar_), val_, child(0)->val()))};
   //   }
   //
   //   NodeOps backwardOps() {
   //     return {NodeOp(Add(scalar_ * Pow(_1, scalar_ - 1.f) * _2,
   //                        child(0)->grad(), child(0)->val(), adj_))};
   //   }
   //
   //   const std::string type() { return "scalar_pow2"; }
   //};
   //
   // struct Scalar1PowNodeOp : public UnaryNodeOp {
   // private:
   //   float scalar_{0};
   //
   // public:
   //   template <typename... Args>
   //   Scalar1PowNodeOp(float scalar, Expr a, Args... args)
   //       : UnaryNodeOp(a, args...), scalar_{scalar} {}
   //
   //   NodeOps forwardOps() {
   //     return {NodeOp(Element(_1 = Pow(scalar_, _2), val_, child(0)->val()))};
   //   }
   //
   //   NodeOps backwardOps() {
   //     return {NodeOp(Add(Pow(scalar_, _1) * log(scalar_) * _2, child(0)->grad(),
   //                        child(0)->val(), adj_))};
   //   }
   //
   //   const std::string type() { return "scalar_pow1"; }
   //};

   struct TanhNodeOp : public NaryNodeOp {
     TanhNodeOp(const std::vector<Expr>& nodes)
         : NaryNodeOp(nodes, newShape(nodes)) {}

     Shape newShape(const std::vector<Expr>& nodes) {
       return Shape::broadcast(nodes);
     }

     NodeOps forwardOps() override {
       using namespace functional;
       switch(children_.size()) {
         case 1: return {NodeOp(Element(_1 = tanh(_2), val_, child(0)->val()))};
         case 2:
           return {NodeOp(Element(
               _1 = tanh(_2 + _3), val_, child(0)->val(), child(1)->val()))};
         case 3:
           return {NodeOp(Element(_1 = tanh(_2 + _3 + _4),
                                  val_,
                                  child(0)->val(),
                                  child(1)->val(),
                                  child(2)->val()))};
         default:
           return {
             NodeOp(Element(_1 = _2 + _3 + _4,
                            val_,
                            child(0)->val(),
                            child(1)->val(),
                            child(2)->val());
                    for(size_t i = 3; i < children_.size(); ++i)
                      Element(_1 = _1 + _2, val_, child(i)->val());
                    Element(_1 = tanh(_1), val_);)
           };
       }
     }

     NodeOps backwardOps() override {
       using namespace functional;
       NodeOps ops;
       for(size_t i = 0; i < children_.size(); i++) {
         ops.push_back(
             NodeOp(Add(_1 * (1.0f - (_2 * _2)), child(i)->grad(), adj_, val_)));
       }
       return ops;
     }

     const std::string color() override { return "yellow"; }

     const std::string type() override { return "tanh"; }
   };

   struct ReLUNodeOp : public UnaryNodeOp {
     ReLUNodeOp(Expr a) : UnaryNodeOp(a) {}

     NodeOps forwardOps() override {
       // f(x) = max(0, x)
       using namespace functional;
       return {NodeOp(Element(_1 = ReLU(_2),
                              val_,            // _1 := f(x) to be calculated
                              child(0)->val()  // _2 := x
                              ))};
     }

     NodeOps backwardOps() override {
       using namespace functional;
       // dJ/dx += dJ/df * binarystep(x)
       return {NodeOp(Add(_1 * ReLUback(_2),
                          child(0)->grad(),  // dJ/dx
                          adj_,              // _1 := dJ/df
                          child(0)->val()    // _2 := f(x) = max(0, x)
                          ))};
     }

     const std::string type() override { return "ReLU"; }
   };
   struct PReLUNodeOp : public UnaryNodeOp {
     PReLUNodeOp(float alpha, Expr a) : UnaryNodeOp(a), alpha_(alpha) {}

     NodeOps forwardOps() override {
       using namespace functional;
       return {NodeOp(Element(_1 = PReLU(_2, alpha_), val_, child(0)->val()))};
     }

     NodeOps backwardOps() override {
       using namespace functional;
       return {NodeOp(Add(
           _1 * PReLUback(_2, alpha_), child(0)->grad(), adj_, child(0)->val()))};
     }

     const std::string type() override { return "PReLU"; }

     virtual size_t hash() override {
       if(!hash_) {
         hash_ = NaryNodeOp::hash();
         util::hash_combine(hash_, alpha_);
       }
       return hash_;
     }

     virtual bool equal(Expr node) override {
       if(!NaryNodeOp::equal(node))
         return false;
       auto cnode = std::dynamic_pointer_cast<PReLUNodeOp>(node);
       if(!cnode)
         return false;
       if(alpha_ != cnode->alpha_)
         return false;
       return true;
     }

   private:
     float alpha_{0.01f};
   };

   struct SwishNodeOp : public UnaryNodeOp {
     SwishNodeOp(Expr a, float b = 1.f) : UnaryNodeOp(a), b_{b} {}

     NodeOps forwardOps() override {
       using namespace functional;
       return {NodeOp(Element(_1 = _2 * sigmoid(b_ * _2), val_, child(0)->val()))};
     }

     NodeOps backwardOps() override {
       using namespace functional;
       // dJ/dx += dJ/df * (b*f(x) + sigmoid(b*x) * (1 - b*f(x)))
       return {NodeOp(Add(_1 * (b_ * _3 + sigmoid(b_ * _2) * (1.f - (b_ * _3))),
                          child(0)->grad(),  // dJ/dx
                          adj_,              // _1 := dJ/df
                          child(0)->val(),   // _2 := x
                          val_               // _3 := f(x) = x*sigmoid(b*x)
                          ))};
     }

     const std::string type() override { return "swish"; }

     virtual size_t hash() override {
       if(!hash_) {
         hash_ = NaryNodeOp::hash();
         util::hash_combine(hash_, b_);
       }
       return hash_;
     }

     virtual bool equal(Expr node) override {
       if(!NaryNodeOp::equal(node))
         return false;
       auto cnode = std::dynamic_pointer_cast<SwishNodeOp>(node);
       if(!cnode)
         return false;
       if(b_ != cnode->b_)
         return false;
       return true;
     }

     float b_;
   };

   struct SoftmaxNodeOp : public UnaryNodeOp {
     SoftmaxNodeOp(Expr a) : UnaryNodeOp(a) {}

     NodeOps forwardOps() override {
       return {NodeOp(Softmax(val_, child(0)->val()))};
     }

     NodeOps backwardOps() override {
       // For each row, the Jacobian times vector is given by:
       // J * dy = p .* (dy - avg*1)
       // where avg = p'*dy and p is the softmax output (probabilities).
       //
       // For more information, see sec. 2.5 of the following reference:
       // André F. T. Martins and Ramon Astudillo.
       // "From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label
       // Classification." ICML 2016.
       // http://jmlr.org/proceedings/papers/v48/martins16.pdf

       // val_ is already masked if there is a mask, so no need to apply here.
       return {NodeOp(SoftmaxGrad(child(0)->grad(), adj_, val_))};
     }

     const std::string type() override { return "softmax"; }
   };

   struct LogSoftmaxNodeOp : public UnaryNodeOp {
     LogSoftmaxNodeOp(Expr a) : UnaryNodeOp(a) {}

     NodeOps forwardOps() override {
       return {NodeOp(LogSoftmax(val_, child(0)->val()))};
     }

     NodeOps backwardOps() override {
       // Based on the description for softmax, we have logsoftmax:
       // J * dy = dy - avg*1
       // where avg = exp(p)'*dy and p is the softmax output (probabilities).
       return {NodeOp(LogSoftmaxGrad(child(0)->grad(), adj_, val_))};
     }

     const std::string type() override { return "logsoftmax"; }
   };

   enum class ReduceNodeOpCode {
     sum, mean, rms, meanSqr, min, max, prod, logSumExp
   };

   struct ReduceNodeOp : public UnaryNodeOp {
     friend class SerializationHelpers;
     int axis_;
     ReduceNodeOpCode opCode_;
     int reducedDim_; // dimension of axis being reduced, e.g. used in mean()

     ReduceNodeOp(Expr a, int axis, ReduceNodeOpCode opCode)
         : UnaryNodeOp(a, newShape(a, axis)), opCode_(opCode) {
       reducedDim_ = a->shape()[axis]; // e.g. used in mean()
       ABORT_IF(reducedDim_ != a->shape().elements() / shape().elements(),
                "Bug in determining reducedDim {} != {}",
                reducedDim_,
                a->shape().elements() / shape().elements());
     }

     NodeOps forwardOps() override {
       using namespace functional;
       switch(opCode_) {
         case ReduceNodeOpCode::sum:
           return {NodeOp(Reduce(_1, val_, child(0)->val()))};
         case ReduceNodeOpCode::mean:
           return {NodeOp(Reduce(_1, 1.0f / (float)reducedDim_, val_, child(0)->val()))};
         case ReduceNodeOpCode::rms:
           return {NodeOp(Reduce(_1 * _1, 1.0f / (float)reducedDim_, val_, child(0)->val());
                          Element(_1 = sqrt(_1), val_))};
         case ReduceNodeOpCode::meanSqr:
           return {NodeOp(Reduce(_1 * _1, 1.0f / (float)reducedDim_, val_, child(0)->val()))};
         case ReduceNodeOpCode::min:
           return {NodeOp(Reduce(_1, min(_1,_2), std::numeric_limits<float>::max(), val_, child(0)->val()))};
         case ReduceNodeOpCode::max:
           return {NodeOp(Reduce(_1, max(_1,_2), std::numeric_limits<float>::lowest(), val_, child(0)->val()))};
         case ReduceNodeOpCode::prod:
           return {NodeOp(Reduce(_1, _1 * _2, 1.0f, val_, child(0)->val()))};
         case ReduceNodeOpCode::logSumExp:
           return {NodeOp(Reduce(_1, logaddexp(_1,_2), std::numeric_limits<float>::lowest(), val_, child(0)->val()))};
         default:
           ABORT("Unexpected reduction op-code {}", (int)opCode_);
       }
     }

     NodeOps backwardOps() override {
       using namespace functional;
   #if 1 // @BUGBUG: This is a workaround for not correctly propagating non-trainable information. @TODO: Do this the right and general way.
       if (adj_ == nullptr)
         return {};
   #endif
       switch(opCode_) {
         case ReduceNodeOpCode::sum:
           return {NodeOp(Add(_1, child(0)->grad(), adj_))};
         case ReduceNodeOpCode::mean:
           return {NodeOp(Add(_1, 1.0f / (float)reducedDim_, child(0)->grad(), adj_))};
         case ReduceNodeOpCode::rms: // WARNING: UNTESTED!!
           // y = (sum_j x_j^2)^0.5
           // dJ/dx_i = dJ/dy * 0.5 (sum_j x_j^2)^-0.5 * 2 x_i = dJ/dy * x_i / y --@REVIEW: is this correct?
           // @TODO: do we need protection against div by 0? L'hospital rule?
           return {NodeOp(Add(_1 * _2 / _3, child(0)->grad(), adj_, child(0)->val(), val_))};
         case ReduceNodeOpCode::meanSqr: // WARNING: UNTESTED!!
           // y = sum_j x_j^2
           // dJ/dx_i = dJ/dy * sum_j dx_j^2/dx_i = dJ/dy * 2 dx_i --@REVIEW: is this correct?
           return {NodeOp(Add(_1 * 2.0f * _2, child(0)->grad(), adj_, child(0)->val()))};
         case ReduceNodeOpCode::min: // WARNING: UNTESTED!!
         case ReduceNodeOpCode::max: // WARNING: UNTESTED!!
           // adj_ gets routed into the min/max value --@REVIEW: is this correct?
           return {NodeOp(Add((_1 == _2) * _3, child(0)->grad(), child(0)->val(), val_, adj_))};
         case ReduceNodeOpCode::logSumExp:
           // y = log(sum_j exp(x_j))
           // dJ/dx_i = dJ/dy * 1/(sum_j exp(x_j)) exp(x_i) = dJ/dy * exp(x_i - y)) --@REVIEW: is this correct?
           return {NodeOp(Add(_1 * exp(_2 - _3), child(0)->grad(), adj_, child(0)->val(), val_))};
         default:
           ABORT("Unexpected reduction op-code {}", (int)opCode_);
       }
     }

     Shape newShape(Expr a, int axis) {
       Shape shape = a->shape();
       axis_ = shape.axis(axis);
       shape.set(axis_, 1);
       return shape;
     }

     const std::string type() override {
       switch(opCode_) {
         case ReduceNodeOpCode::sum:       return "sum";
         case ReduceNodeOpCode::mean:      return "mean";
         case ReduceNodeOpCode::rms:       return "rms";
         case ReduceNodeOpCode::meanSqr:   return "meanSqr";
         case ReduceNodeOpCode::min:       return "min";
         case ReduceNodeOpCode::max:       return "max";
         case ReduceNodeOpCode::prod:      return "prod";
         case ReduceNodeOpCode::logSumExp: return "logSumExp";
         default: ABORT("Unexpected reduction op-code {}", (int)opCode_);
       }
     }

     const std::string color() override { return "orange"; }

     virtual size_t hash() override {
       if(!hash_) {
         hash_ = NaryNodeOp::hash();
         util::hash_combine(hash_, axis_);
         util::hash_combine(hash_, (int)opCode_);
       }
       return hash_;
     }

     virtual bool equal(Expr node) override {
       if(!NaryNodeOp::equal(node))
         return false;
       auto cnode = std::dynamic_pointer_cast<ReduceNodeOp>(node);
       if(!cnode)
         return false;
       if(axis_ != cnode->axis_ || opCode_ != cnode->opCode_)
         return false;
       return true;
     }
   };

   struct LogNodeOp : public UnaryNodeOp {
     LogNodeOp(Expr a) : UnaryNodeOp(a) {}

     NodeOps forwardOps() override {
       using namespace functional;
       return {NodeOp(Element(_1 = log(_2), val_, child(0)->val()))};
     }

     NodeOps backwardOps() override {
       using namespace functional;
       return {// NodeOp(Add(_1 * (1.f / _2), child(0)->grad(), adj_,
               //            child(0)->val()))};
               NodeOp(Add(_1 / _2, child(0)->grad(), adj_, child(0)->val()))};
     }

     const std::string type() override { return "log"; }
   };

   struct ExpNodeOp : public UnaryNodeOp {
     ExpNodeOp(Expr a) : UnaryNodeOp(a) {}

     NodeOps forwardOps() override {
       using namespace functional;
       return {NodeOp(Element(_1 = exp(_2), val_, child(0)->val()))};
     }

     NodeOps backwardOps() override {
       using namespace functional;
       return {NodeOp(Add(_1 * exp(_2), child(0)->grad(), adj_, child(0)->val()))};
     }

     const std::string type() override { return "exp"; }
   };

   struct SinNodeOp : public UnaryNodeOp {
     SinNodeOp(Expr a) : UnaryNodeOp(a) {}

     NodeOps forwardOps() override {
       using namespace functional;
       return {NodeOp(Element(_1 = sin(_2), val_, child(0)->val()))};
     }

     NodeOps backwardOps() override {
       using namespace functional;
       return {NodeOp(Add(_1 * cos(_2), child(0)->grad(), adj_, child(0)->val()))};
     }

     const std::string type() override { return "sin"; }
   };

   struct CosNodeOp : public UnaryNodeOp {
     CosNodeOp(Expr a) : UnaryNodeOp(a) {}

     NodeOps forwardOps() override {
       using namespace functional;
       return {NodeOp(Element(_1 = cos(_2), val_, child(0)->val()))};
     }

     NodeOps backwardOps() override {
       using namespace functional;
       return {NodeOp(Add(_1 * -sin(_2), child(0)->grad(), adj_, child(0)->val()))};
     }

     const std::string type() override { return "cos"; }
   };

   struct TanNodeOp : public UnaryNodeOp {
     TanNodeOp(Expr a) : UnaryNodeOp(a) {}

     NodeOps forwardOps() override {
       using namespace functional;
       return {NodeOp(Element(_1 = tan(_2), val_, child(0)->val()))};
     }

     NodeOps backwardOps() override {
       using namespace functional;
       return {NodeOp(Add(_1 / sqr(cos(_2)), child(0)->grad(), adj_, child(0)->val()))};
     }

     const std::string type() override { return "tan"; }
   };

   struct SqrtNodeOp : public UnaryNodeOp {
     float epsilon_;

     SqrtNodeOp(Expr a, float epsilon) : UnaryNodeOp(a), epsilon_(epsilon) {}

     NodeOps forwardOps() override {
       using namespace functional;
       return {NodeOp(Element(_1 = sqrt(_2 + epsilon_), val_, child(0)->val()))};
     }

     NodeOps backwardOps() override {
       using namespace functional;
       return {NodeOp(Add(0.5f * (1.f / _1) * _2, child(0)->grad(), val_, adj_))};
     }

     const std::string type() override { return "sqrt"; }

     virtual size_t hash() override {
       if(!hash_) {
         size_t seed = NaryNodeOp::hash();
         util::hash_combine(seed, epsilon_);
         hash_ = seed;
       }
       return hash_;
     }

     virtual bool equal(Expr node) override {
       if(!NaryNodeOp::equal(node))
         return false;
       auto cnode = std::dynamic_pointer_cast<SqrtNodeOp>(node);
       if(!cnode)
         return false;
       if(epsilon_ != cnode->epsilon_)
         return false;
       return true;
     }
   };

   struct SquareNodeOp : public UnaryNodeOp {
     SquareNodeOp(Expr a) : UnaryNodeOp(a) {}

     NodeOps forwardOps() override {
       using namespace functional;
       return {NodeOp(Element(_1 = _2 * _2, val_, child(0)->val()))};
     }

     NodeOps backwardOps() override {
       using namespace functional;
       return {NodeOp(Add(2.f * _1 * _2, child(0)->grad(), child(0)->val(), adj_))};
     }

     const std::string type() override { return "square"; }
   };

   struct NegNodeOp : public UnaryNodeOp {
     NegNodeOp(Expr a) : UnaryNodeOp(a) {}

     NodeOps forwardOps() override {
       using namespace functional;
       return {NodeOp(Element(_1 = -_2, val_, child(0)->val()))};
     }

     NodeOps backwardOps() override {
       using namespace functional;
       return {NodeOp(Add(-_1, child(0)->grad(), adj_))};
     }

     const std::string type() override { return "negate"; }
   };

   struct TransposeNodeOp : public UnaryNodeOp {
     TransposeNodeOp(Expr a, const std::vector<int>& axes)
         : UnaryNodeOp(a, newShape(a, axes)), axes_{axes}, axesBw_(axes.size()) {
       for(int i = 0; i < axes_.size(); ++i)
         axesBw_[axes_[i]] = i;
     }

     NodeOps forwardOps() override {
       return {NodeOp(TransposeND(val_, child(0)->val(), axes_))};
     }

     NodeOps backwardOps() override {
       return {NodeOp(TransposeNDGrad(child(0)->grad(), adj_, axesBw_))};
     }

     Shape newShape(Expr a, const std::vector<int>& axes) {
       Shape shape = a->shape();

       ABORT_IF(shape.size() != axes.size(),
                "Shape and transpose axes have different number of dimensions");

       for(size_t i = 0; i < shape.size(); ++i)
         shape.set(i, a->shape()[axes[i]]);

       return shape;
     }

     virtual size_t hash() override {
       if(!hash_) {
         size_t seed = NaryNodeOp::hash();
         for(auto ax : axes_)
           util::hash_combine(seed, ax);
         hash_ = seed;
       }
       return hash_;
     }

     virtual bool equal(Expr node) override {
       if(!NaryNodeOp::equal(node))
         return false;
       auto cnode = std::dynamic_pointer_cast<TransposeNodeOp>(node);
       if(!cnode)
         return false;
       if(axes_ != cnode->axes_)
         return false;
       return true;
     }

     const std::string type() override { return "transpose"; }

     const std::string color() override { return "orange"; }

   private:
     friend class SerializationHelpers;
     std::vector<int> axes_;
     std::vector<int> axesBw_;
   };

   class ReshapeNodeOp : public UnaryNodeOp {
   protected:
     friend class SerializationHelpers;
     Expr reshapee_;

   public:
     ReshapeNodeOp(Expr a, Shape shape) : UnaryNodeOp(a, shape), reshapee_(a) {
       ABORT_IF(a->shape().elements() != shape.elements(),
                "Reshape must not change the number of elements (from {} to {})",
                a->shape().toString(), shape.toString());
       Node::destroy_ = false;
     }

     ~ReshapeNodeOp() {}

     void allocate() override {}
     void free() override {}

     void forward() override {}
     void backward() override {}

     void init_dependent() override { reshapee_->init_dependent(); }

     void set_zero_adjoint() override { reshapee_->set_zero_adjoint(); }

     Tensor& val() override {
       auto childVal = reshapee_->val();
       auto temp = TensorBase::New(childVal->memory(), shape(), childVal->type(), childVal->getBackend());
       val_.swap(temp);
       return val_;
     };

     Tensor& grad() override {
       auto childGrad = reshapee_->grad();
       auto temp = TensorBase::New(childGrad->memory(), shape(), childGrad->type(), childGrad->getBackend());
       adj_.swap(temp);
       return adj_;
     };

     const std::string type() override { return "reshape"; }

     const std::string color() override { return "grey"; }

     virtual size_t hash() override {
       if(!hash_) {
         size_t seed = NaryNodeOp::hash();
         for(auto s : shape())
           util::hash_combine(seed, s);
         hash_ = seed;
       }
       return hash_;
     }

     virtual bool equal(Expr node) override {
       if(!NaryNodeOp::equal(node))
         return false;
       auto cnode = std::dynamic_pointer_cast<ReshapeNodeOp>(node);
       if(!cnode)
         return false;
       if(shape() != cnode->shape())
         return false;
       return true;
     }
   };

   // @TODO: add version with access to backward step
   // This allows to attach a lambda function to any node during the execution. It is a non-operation otherwise
   // i.e. doesn't consume any memory or take any time to execute (it's a reshape onto itself) other than the
   // compute in the lambda function. It gets called after the forward step of the argument node.
   class CallbackNodeOp : public ReshapeNodeOp {
   private:
     typedef std::function<void(Expr node)> LambdaNodeCallback;
     std::unique_ptr<LambdaNodeCallback> callback_;

   public:
     CallbackNodeOp(Expr node, LambdaNodeCallback callback)
         : ReshapeNodeOp(node, node->shape()),
           callback_(new LambdaNodeCallback(callback)) {
     }

     void forward() override {
       (*callback_)(ReshapeNodeOp::reshapee_);
     }

     const std::string type() override { return "callback"; }

     virtual size_t hash() override {
       size_t seed = ReshapeNodeOp::hash();
       util::hash_combine(seed, callback_.get());
       return seed;
     }

     virtual bool equal(Expr node) override {
       if(!ReshapeNodeOp::equal(node))
         return false;
       auto cnode = std::dynamic_pointer_cast<CallbackNodeOp>(node);
       if(!cnode)
         return false;
       if(callback_ != cnode->callback_) // pointer compare on purpose
         return false;
       return true;
     }
   };

   // @TODO: review if still required as this is an ugly hack anyway.
   // Memory less operator that clips gradients during backward step
   // Executes this as an additional operation on the gradient.
   class ClipGradientNodeOp : public UnaryNodeOp {
   private:
     Expr clipee_;
     float clipValue_{0};

   public:
     ClipGradientNodeOp(Expr a, float clipValue)
         : UnaryNodeOp(a), clipee_(a), clipValue_(clipValue) {
       Node::destroy_ = false;
     }

     ~ClipGradientNodeOp() {}

     void allocate() override {}
     void free() override {}

     void forward() override {}

     void backward() override {
       using namespace marian::functional;
       Element(_1 = clip(_1, clipValue_), adj_);
     }

     void init_dependent() override { clipee_->init_dependent(); }

     void set_zero_adjoint() override { clipee_->set_zero_adjoint(); }

     Tensor& val() override {
       auto childVal = clipee_->val();
       auto temp = TensorBase::New(childVal->memory(), shape(), childVal->type(), childVal->getBackend());
       val_.swap(temp);
       return val_;
     };

     Tensor& grad() override {
       auto childGrad = clipee_->grad();
       auto temp = TensorBase::New(childGrad->memory(), shape(), childGrad->type(), childGrad->getBackend());
       adj_.swap(temp);
       return adj_;
     };

     const std::string type() override { return "clipGradient"; }

     const std::string color() override { return "grey"; }

     virtual size_t hash() override {
       if(!hash_) {
         size_t seed = NaryNodeOp::hash();
         util::hash_combine(seed, clipValue_);
         hash_ = seed;
       }
       return hash_;
     }

     virtual bool equal(Expr node) override {
       if(!NaryNodeOp::equal(node))
         return false;
       auto cnode = std::dynamic_pointer_cast<ClipGradientNodeOp>(node);
       if(!cnode)
         return false;
       if(clipValue_ != cnode->clipValue_)
         return false;
       return true;
     }
   };

   // narrow an axis to [begin, end)
   // The resulting object must be consecutive in memory.
   class SliceViewNodeOp : public UnaryNodeOp {
   private:
     friend class SerializationHelpers;
     Expr viewedNode_;               // viewed underlying node
     Slice slice_;                   // index range
     int axis_;                      // and axis along which it is viewed
     size_t byteOffset_, byteSize_;  // viewed segment in bytes (memory-consecutive)

   public:
     SliceViewNodeOp(Expr a, int axis, Slice slice)
         : UnaryNodeOp(a, newShape(a, axis, slice), a->value_type()),
           viewedNode_(a),
           slice_(slice),
           axis_(axis) {
       Node::destroy_ = false;
       auto byteStride = a->shape().stride(axis) * sizeOf(value_type());
       byteOffset_ = slice.begin * byteStride;
       byteSize_ = shape()[axis] * byteStride;
     }

     static Shape newShape(Expr a, int& axis, Slice& slice) { // note: normalizes slice and axis in-place
       const auto& shape = a->shape();
       axis = shape.axis(axis);           // normalize negative axis
       slice = shape.slice(slice, axis);  // normalize negative slice values
       // enforce consecutive memory
       if (slice.begin != 0 || slice.end != shape[axis] || slice.stride != 1) { // unless it's a no-op
         ABORT_IF(slice.stride != 1, "Strides other than 1 are presently not supported by sliceView()");
         for(int i = 0; i < axis; ++i)
           ABORT_IF(shape[i] != 1, "Non-consecutive slices are presently not supported by sliceView()");
       }
       Shape outShape = shape;
       outShape.set(axis, slice.end - slice.begin);
       return outShape;
     }

     void allocate() override {}
     void free() override {}

     void forward() override {}
     void backward() override {}

     void init_dependent() override { viewedNode_->init_dependent(); }

     void set_zero_adjoint() override { viewedNode_->set_zero_adjoint(); }

     // lazily allocate and zero out gradient (only runs once)
     Tensor& val() override {
       auto childVal = viewedNode_->val();
       auto mem = MemoryPiece::New(childVal->memory()->data() + byteOffset_, byteSize_);
       auto temp = TensorBase::New(mem, shape(), childVal->type(), childVal->getBackend());
       val_.swap(temp);
       return val_;
     };

     Tensor& grad() override {
       auto childGrad = viewedNode_->grad();
       auto mem = MemoryPiece::New(childGrad->memory()->data() + byteOffset_, byteSize_);
       auto temp = TensorBase::New(mem, shape(), childGrad->type(), childGrad->getBackend());
       adj_.swap(temp);
       return adj_;
     };

     const std::string type() override { return "sliceView"; }

     const std::string color() override { return "grey"; }

     virtual size_t hash() override {
       if(!hash_) {
         hash_ = NaryNodeOp::hash();
         util::hash_combine(hash_, slice_.begin);
         util::hash_combine(hash_, slice_.end);
         util::hash_combine(hash_, slice_.stride);
         util::hash_combine(hash_, axis_);
       }
       return hash_;
     }

     virtual bool equal(Expr node) override {
       if(!NaryNodeOp::equal(node))
         return false;
       auto cnode = std::dynamic_pointer_cast<SliceViewNodeOp>(node);
       if(!cnode)
         return false;
       if(slice_ != cnode->slice_)
         return false;
       if(axis_ != cnode->axis_)
         return false;
       return true;
     }
   };

   struct ShiftNodeOp : public UnaryNodeOp {
     ShiftNodeOp(Expr a, Shape shift, float padValue)
         : UnaryNodeOp(a, a->shape()), shift_(shift), padValue_(padValue) {}

     NodeOps forwardOps() override {
       return {NodeOp(Shift(val_, child(0)->val(), shift_, padValue_, /*invert=*/false))};
     }

     NodeOps backwardOps() override {
       // last parameter beta=1 says to use += (out = in + beta * out)
       // @TODO: check need for padValue_
       return {NodeOp(ShiftGrad(child(0)->grad(), adj_, shift_, true))};
     }

     const std::string type() override { return "shift"; }

     virtual size_t hash() override {
       if(!hash_) {
         size_t seed = NaryNodeOp::hash();
         for(auto i : shift_)
           util::hash_combine(seed, i);
         util::hash_combine(seed, padValue_);
         hash_ = seed;
       }
       return hash_;
     }

     virtual bool equal(Expr node) override {
       if(!NaryNodeOp::equal(node))
         return false;
       auto cnode = std::dynamic_pointer_cast<ShiftNodeOp>(node);
       if(!cnode)
         return false;
       if(shift_ != cnode->shift_)
         return false;
       if(padValue_ != cnode->padValue_)
         return false;
       return true;
     }

   private:
     friend class SerializationHelpers;
     Shape shift_;    // shift offsets in each dimension
     float padValue_; // what value to shift in
   };

   struct AbsNodeOp : public UnaryNodeOp {
     AbsNodeOp(Expr a) : UnaryNodeOp(a) {}

     NodeOps forwardOps() override {
       using namespace functional;
       return {NodeOp(Element(_1 = abs(_2), val_, child(0)->val()))};
     }

     NodeOps backwardOps() override {
       using namespace functional;
       return {NodeOp(Add(sgn(_1) * _2, child(0)->grad(), child(0)->val(), adj_))};
     }

     const std::string type() override { return "abs"; }
   };

   #ifdef CUDNN
   class PoolingOp : public UnaryNodeOp {
   public:
     PoolingOp(Expr x,
               int height,
               int width,
               int padHeight,
               int padWidth,
               int strideHeight,
               int strideWidth,
               std::string mode)
         : UnaryNodeOp(x),
           pooling_(height, width, padHeight, padWidth, strideHeight, strideWidth, mode) {}

     NodeOps forwardOps() override {
       return {NodeOp(pooling_.forward(child(0)->val(), val_))};
     }

     NodeOps backwardOps() override {
       return {NodeOp(pooling_.backward(child(0)->val(), child(0)->grad(), val_, adj_))};
     }

     const std::string type() override { return "layer_pooling"; }

   protected:
     PoolingWrapper pooling_;
   };
   #endif

   class PoolingWithMaskingOp : public UnaryNodeOp {
   public:
     PoolingWithMaskingOp(Expr x, Expr mask, int width, bool isEven = false)
         : UnaryNodeOp(x), mask_(mask), width_(width), isEven_(isEven) {
       auto xShape = x->shape();
       int dimBatch = xShape[0];
       int dimWord = xShape[1];
       int cols = (isEven_) ? xShape[2] - 1 : xShape[2];
       int dimSentence = (cols / width_) + (cols % width_ != 0);
       shape_ = {dimBatch, dimWord, dimSentence};
     }

     NodeOps forwardOps() override {
       return {NodeOp(PoolingWithMaskingForward(val_, child(0)->val(), mask_->val(), width_, isEven_))};
     }

     NodeOps backwardOps() override {
       return {NodeOp(PoolingWithMaskingBackward(adj_,
                                                 child(0)->grad(),
                                                 child(0)->val(),
                                                 mask_->val(),
                                                 width_,
                                                 isEven_))};
     }

     const std::string type() override { return "layer_pooling"; }

   protected:
     friend class SerializationHelpers;
     Expr mask_;
     int width_;
     bool isEven_;
   };

   }  // namespace marian
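
The Jacobian-times-vector formula quoted in the ``SoftmaxNodeOp::backwardOps`` comment can be checked directly. With row-wise probabilities :math:`p = \operatorname{softmax}(x)`, the softmax Jacobian is :math:`J_{ij} = p_i(\delta_{ij} - p_j)`, so

.. math::

   J\,dy \;=\; p \odot dy - p\,(p^{\top} dy) \;=\; p \odot \bigl(dy - (p^{\top} dy)\,\mathbf{1}\bigr),

which is exactly the ``J * dy = p .* (dy - avg*1)`` form (with ``avg = p'*dy``) that the comment attributes to ``SoftmaxGrad``.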
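The element-wise operators in this file all follow the same pattern: ``forwardOps()`` returns lambdas that fill ``val_`` from ``child(0)->val()``, ``backwardOps()`` accumulates into ``child(0)->grad()`` from the adjoint ``adj_``, and any scalar parameters feed into ``hash()``/``equal()`` so equivalent subgraphs can be matched. The following is a minimal sketch of a hypothetical new operator written in that style; ``CubeNodeOp`` is not part of this header and is shown only to illustrate the pattern.

.. code-block:: cpp

   // Hypothetical example only -- not part of node_operators_unary.h.
   // f(x) = x^3, df/dx = 3 x^2, mirroring SquareNodeOp above.
   struct CubeNodeOp : public UnaryNodeOp {
     CubeNodeOp(Expr a) : UnaryNodeOp(a) {}

     NodeOps forwardOps() override {
       using namespace functional;
       // _1 := f(x) (val_), _2 := x (child(0)->val())
       return {NodeOp(Element(_1 = _2 * _2 * _2, val_, child(0)->val()))};
     }

     NodeOps backwardOps() override {
       using namespace functional;
       // dJ/dx += 3 x^2 * dJ/df; _1 := x, _2 := dJ/df (adj_)
       return {NodeOp(Add(3.f * _1 * _1 * _2, child(0)->grad(), child(0)->val(), adj_))};
     }

     const std::string type() override { return "cube"; }
   };

Exposing such a node to user code would additionally require a free expression-building function (something like ``Expr cube(Expr a)``) in the expression-operator layer, which lives outside this header.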