Program Listing for File node_operators_binary.h
↰ Return to documentation for file (src/graph/node_operators_binary.h)
#pragma once
#include <thread>
#include "common/hash.h"
#include "functional/functional.h"
#include "graph/node.h"
#include "tensors/tensor_operators.h"
#ifdef CUDNN
#include "tensors/gpu/cudnn_wrappers.h"
#endif
namespace marian {
class LambdaNodeOp : public NaryNodeOp {
private:
typedef const std::vector<Expr>& Inputs;
typedef std::function<void(Expr, Inputs)> LambdaNodeFunctor;
std::unique_ptr<LambdaNodeFunctor> forward_;
std::unique_ptr<LambdaNodeFunctor> backward_;
size_t externalHash_;
public:
LambdaNodeOp(Inputs inputs, Shape shape, Type type,
LambdaNodeFunctor forward,
size_t externalHash = 0)
: NaryNodeOp(inputs, shape, type),
forward_(new LambdaNodeFunctor(forward)),
externalHash_(externalHash) {
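// No backward lambda is given in this constructor, so backward_ stays null
// and the node is marked as not trainable.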
Node::trainable_ = !!backward_;
}
LambdaNodeOp(Inputs inputs, Shape shape, Type type,
LambdaNodeFunctor forward,
LambdaNodeFunctor backward,
size_t externalHash = 0)
: NaryNodeOp(inputs, shape, type),
forward_(new LambdaNodeFunctor(forward)),
backward_(new LambdaNodeFunctor(backward)),
externalHash_(externalHash) {
}
void forward() override {
(*forward_)(this, children_);
}
void backward() override {
ABORT_IF(!backward_, "No backward lambda given");
(*backward_)(this, children_);
}
const std::string type() override { return "lambda"; }
virtual size_t hash() override {
size_t seed = NaryNodeOp::hash();
if(externalHash_ != 0) {
util::hash_combine(seed, externalHash_);
} else {
util::hash_combine(seed, forward_.get());
util::hash_combine(seed, backward_.get());
}
return seed;
}
virtual bool equal(Expr node) override {
if(!NaryNodeOp::equal(node))
return false;
auto cnode = std::dynamic_pointer_cast<LambdaNodeOp>(node);
if(!cnode)
return false;
if(forward_ != cnode->forward_) // pointer compare on purpose
return false;
if(backward_ != cnode->backward_) // pointer compare on purpose
return false;
return true;
}
};
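// Usage sketch (illustrative only, not part of the original header; assumes an
// existing ExpressionGraph `graph`, marian's Expression<T> factory and the
// inits helpers). The forward lambda simply copies its single input:
//
//   Expr x = graph->constant({2, 3}, inits::ones());
//   Expr y = Expression<LambdaNodeOp>(
//       std::vector<Expr>({x}), x->shape(), x->value_type(),
//       [](Expr out, const std::vector<Expr>& in) {
//         CopyCast(out->val(), in[0]->val());
//       });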
class DotNodeOp : public NaryNodeOp {
private:
friend class SerializationHelpers;
bool transA_;
bool transB_;
float scalar_;
public:
DotNodeOp(Expr a, Expr b, bool transA, bool transB, float scalar)
: NaryNodeOp({a, b}, newShape(a, b, transA, transB)),
transA_(transA),
transB_(transB),
scalar_(scalar) {}
Shape newShape(Expr a, Expr b, bool transA, bool transB) {
auto shapeA = a->shape();
if(transA) {
shapeA.set(shapeA.size() - 2, a->shape()[shapeA.size() - 1]);
shapeA.set(shapeA.size() - 1, a->shape()[shapeA.size() - 2]);
}
auto shapeB = b->shape();
if(transB) {
shapeB.set(shapeB.size() - 2, b->shape()[shapeB.size() - 1]); // @TODO: why not use negative indices?
shapeB.set(shapeB.size() - 1, b->shape()[shapeB.size() - 2]);
}
Shape outShape = shapeA;
outShape.set(outShape.size() - 1, shapeB[shapeB.size() - 1]);
ABORT_IF(shapeA[shapeA.size() - 1] != shapeB[shapeB.size() - 2],
"Matrix product requires inner dimensions to match in {}{} * {}{}", std::string(shapeA), transA, std::string(shapeB), transB);
return outShape;
}
NodeOps forwardOps() override {
// C = alpha * dot(op(A), op(B))
return {NodeOp(Prod(val_,
child(0)->val(),
child(1)->val(),
transA_,
transB_,
0.f,
scalar_))};
}
NodeOps backwardOps() override {
// D is the adjoint, the matrix of derivatives
// df/dA += alpha * dot(D, op(B).T)
// df/dB += alpha * dot(op(A).T, D)
// beta set to 1.0 in gemm, C = alpha * dot(op(A), op(B)) + beta * C
// to sum gradients from different graph parts
auto isParameter = [](Expr p) {
return std::dynamic_pointer_cast<ParamNode>(p) != nullptr;
};
// if child A is not a parameter (i.e. activations) use computeType float32 for accumulation
Type computeTypeA = child(0)->trainable() ? child(0)->grad()->type() : Type::float32;
if(!isParameter(child(0)) && computeTypeA == Type::float16)
computeTypeA = Type::float32;
// if child B is not a parameter (i.e. activations) use computeType float32 for accumulation
Type computeTypeB = child(1)->trainable() ? child(1)->grad()->type() : Type::float32;
if(!isParameter(child(1)) && computeTypeB == Type::float16)
computeTypeB = Type::float32;
if(!transA_ && transB_)
return {NodeOp(Prod(child(0)->grad(),
adj_,
child(1)->val(),
false,
false,
1.0,
scalar_, computeTypeA)),
NodeOp(Prod(child(1)->grad(),
adj_,
child(0)->val(),
true,
false,
1.0,
scalar_, computeTypeB))};
if(transA_ && !transB_)
return {NodeOp(Prod(child(0)->grad(),
child(1)->val(),
adj_,
false,
true,
1.0,
scalar_, computeTypeA)),
NodeOp(Prod(child(1)->grad(),
child(0)->val(),
adj_,
false,
false,
1.0,
scalar_, computeTypeB))};
if(transA_ && transB_)
return {NodeOp(Prod(child(0)->grad(),
child(1)->val(),
adj_,
true,
true,
1.0,
scalar_, computeTypeA)),
NodeOp(Prod(child(1)->grad(),
adj_,
child(0)->val(),
true,
true,
1.0,
scalar_, computeTypeB))};
return {NodeOp(Prod(child(0)->grad(),
adj_,
child(1)->val(),
false,
true,
1.0,
scalar_, computeTypeA)),
NodeOp(Prod(child(1)->grad(),
child(0)->val(),
adj_,
true,
false,
1.0,
scalar_, computeTypeB))};
}
const std::string type() override { return "dot"; }
virtual size_t hash() override {
size_t seed = NaryNodeOp::hash();
util::hash_combine(seed, transA_);
util::hash_combine(seed, transB_);
util::hash_combine(seed, scalar_);
return seed;
}
virtual bool equal(Expr node) override {
if(!NaryNodeOp::equal(node))
return false;
auto cnode = std::dynamic_pointer_cast<DotNodeOp>(node);
if(!cnode)
return false;
if(transA_ != cnode->transA_)
return false;
if(transB_ != cnode->transB_)
return false;
if(scalar_ != cnode->scalar_)
return false;
return true;
}
const std::string color() override { return "orange"; }
};
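// A sketch of the gradients implemented in backwardOps() above, writing
// D = adj_ and s = scalar_ (each case accumulates with beta = 1):
//   C = s * A   B      =>  dF/dA += s * D B^T      dF/dB += s * A^T D
//   C = s * A   B^T    =>  dF/dA += s * D B        dF/dB += s * D^T A
//   C = s * A^T B      =>  dF/dA += s * B D^T      dF/dB += s * A D
//   C = s * A^T B^T    =>  dF/dA += s * B^T D^T    dF/dB += s * D^T A^T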
class AffineNodeOp : public NaryNodeOp {
private:
friend class SerializationHelpers;
bool transA_;
bool transB_;
float scalar_;
public:
AffineNodeOp(const std::vector<Expr>& nodes,
bool transA,
bool transB,
float scalar)
: NaryNodeOp(nodes, newShape(nodes[0], nodes[1], transA, transB)),
transA_(transA),
transB_(transB),
scalar_(scalar) {}
Shape newShape(Expr a, Expr b, bool transA, bool transB) {
auto shapeA = a->shape();
if(transA) {
shapeA.set(shapeA.size() - 2, a->shape()[shapeA.size() - 1]);
shapeA.set(shapeA.size() - 1, a->shape()[shapeA.size() - 2]);
}
auto shapeB = b->shape();
if(transB) {
shapeB.set(shapeB.size() - 2, b->shape()[shapeB.size() - 1]);
shapeB.set(shapeB.size() - 1, b->shape()[shapeB.size() - 2]);
}
Shape outShape = shapeA;
outShape.set(outShape.size() - 1, shapeB[shapeB.size() - 1]);
ABORT_IF(shapeA[shapeA.size() - 1] != shapeB[shapeB.size() - 2],
"Matrix product requires inner dimensions to match in {}{} * {}{}", std::string(shapeA), transA, std::string(shapeB), transB);
return outShape;
}
NodeOps forwardOps() override {
using namespace functional;
return {
NodeOp(Affine(val_,
graph()->allocator(),
child(0)->val(),
child(1)->val(),
child(2)->val(),
transA_,
transB_,
0.f,
scalar_,
/*doRelu=*/false))
};
}
NodeOps backwardOps() override {
// D is the adjoint, the matrix of derivatives
// df/dA += alpha * dot(D, op(B).T)
// df/dB += alpha * dot(op(A).T, D)
// beta set to 1.0 in gemm, C = alpha * dot(op(A), op(B)) + beta * C
// to sum gradients from different graph parts
auto isParameter = [](Expr p) {
return std::dynamic_pointer_cast<ParamNode>(p) != nullptr;
};
// if child A is not a parameter (i.e. activations) use computeType float32 for accumulation
Type computeTypeA = child(0)->trainable() ? child(0)->grad()->type() : Type::float32;
if(!isParameter(child(0)) && computeTypeA == Type::float16)
computeTypeA = Type::float32;
// if child B is not a parameter (i.e. activations) use computeType float32 for accumulation
Type computeTypeB = child(1)->trainable() ? child(1)->grad()->type() : Type::float32;
if(!isParameter(child(1)) && computeTypeB == Type::float16)
computeTypeB = Type::float32;
// if child C (bias) is not a parameter (i.e. activations) use computeType float32 for accumulation
Type computeTypeC = child(2)->trainable() ? child(2)->grad()->type() : Type::float32;
if(!isParameter(child(2)) && computeTypeC == Type::float16)
computeTypeC = Type::float32;
// We reduce bias gradients with a matrix multiply
if(!transA_ && transB_)
return {
NodeOp(Prod(child(0)->grad(),
adj_,
child(1)->val(),
false,
false,
1.0,
scalar_, computeTypeA)),
NodeOp(Prod(child(1)->grad(),
adj_,
child(0)->val(),
true,
false,
1.0,
scalar_, computeTypeB)),
NodeOp(Prod(child(2)->grad(), child(3)->val(), adj_, true, false, 0.f, 1.f, computeTypeC))
};
if(transA_ && !transB_)
return {
NodeOp(Prod(child(0)->grad(),
child(1)->val(),
adj_,
false,
true,
1.0,
scalar_, computeTypeA)),
NodeOp(Prod(child(1)->grad(),
child(0)->val(),
adj_,
false,
false,
1.0,
scalar_, computeTypeB)),
NodeOp(Prod(child(2)->grad(), child(3)->val(), adj_, true, false, 0.f, 1.f, computeTypeC))
};
if(transA_ && transB_)
return {
NodeOp(Prod(child(0)->grad(),
child(1)->val(),
adj_,
true,
true,
1.0,
scalar_, computeTypeA)),
NodeOp(Prod(child(1)->grad(),
adj_,
child(0)->val(),
true,
true,
1.0,
scalar_, computeTypeB)),
NodeOp(Prod(child(2)->grad(), child(3)->val(), adj_, true, false, 0.f, 1.f, computeTypeC))
};
return {
NodeOp(Prod(child(0)->grad(),
adj_,
child(1)->val(),
false,
true,
1.0,
scalar_, computeTypeA)),
NodeOp(Prod(child(1)->grad(),
child(0)->val(),
adj_,
true,
false,
1.0,
scalar_, computeTypeB)),
NodeOp(Prod(child(2)->grad(), child(3)->val(), adj_, true, false, 0.f, 1.f, computeTypeC))
};
}
const std::string type() override { return "affine"; }
virtual size_t hash() override {
size_t seed = NaryNodeOp::hash();
util::hash_combine(seed, transA_);
util::hash_combine(seed, transB_);
util::hash_combine(seed, scalar_);
return seed;
}
virtual bool equal(Expr node) override {
if(!NaryNodeOp::equal(node))
return false;
auto cnode = std::dynamic_pointer_cast<AffineNodeOp>(node);
if(!cnode)
return false;
if(transA_ != cnode->transA_)
return false;
if(transB_ != cnode->transB_)
return false;
if(scalar_ != cnode->scalar_)
return false;
return true;
}
};
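// Note on the bias gradient above: the node carries a fourth child that the
// forward pass never touches. The backward pass computes
//   dF/dbias += child(3)^T * D
// which, when child(3) is a column of ones (as the comment "reduce bias
// gradients with a matrix multiply" suggests), is exactly the column-wise sum
// of the adjoint D over all rows, expressed as a single gemm.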
class AffineWithReluNodeOp : public NaryNodeOp {
private:
friend class SerializationHelpers;
bool transA_;
bool transB_;
float scalar_;
public:
AffineWithReluNodeOp(Expr a,
Expr b,
Expr bias,
bool transA,
bool transB,
float scalar)
: NaryNodeOp({a, b, bias}, newShape(a, b, transA, transB)),
transA_(transA),
transB_(transB),
scalar_(scalar) {
ABORT_IF(!graph()->isInference() || graph()->getDeviceId().type != DeviceType::gpu,
"AffineWithReluNodeOp currently only supported for inference on GPU");
}
Shape newShape(Expr a, Expr b, bool transA, bool transB) {
auto shapeA = a->shape();
if(transA) {
shapeA.set(shapeA.size() - 2, a->shape()[shapeA.size() - 1]);
shapeA.set(shapeA.size() - 1, a->shape()[shapeA.size() - 2]);
}
auto shapeB = b->shape();
if(transB) {
shapeB.set(shapeB.size() - 2, b->shape()[shapeB.size() - 1]);
shapeB.set(shapeB.size() - 1, b->shape()[shapeB.size() - 2]);
}
Shape outShape = shapeA;
outShape.set(outShape.size() - 1, shapeB[shapeB.size() - 1]);
ABORT_IF(shapeA[shapeA.size() - 1] != shapeB[shapeB.size() - 2],
"Matrix product requires inner dimensions to match in {}{} * {}{}", std::string(shapeA), transA, std::string(shapeB), transB);
return outShape;
}
NodeOps forwardOps() override {
ABORT_IF(!graph()->isInference() || graph()->getDeviceId().type != DeviceType::gpu,
"AffineWithReluNodeOp currently only supported for inference on GPU");
return {
NodeOp(Affine(val_,
graph()->allocator(),
child(0)->val(),
child(1)->val(),
child(2)->val(),
transA_,
transB_,
0.f,
scalar_,
/*doRelu=*/true))
};
}
NodeOps backwardOps() override {
ABORT("AffineWithReluNodeOp cannot be used for training??");
return {};
}
const std::string type() override { return "affineWithRelu"; }
virtual size_t hash() override {
size_t seed = NaryNodeOp::hash();
util::hash_combine(seed, transA_);
util::hash_combine(seed, transB_);
util::hash_combine(seed, scalar_);
return seed;
}
virtual bool equal(Expr node) override {
if(!NaryNodeOp::equal(node))
return false;
auto cnode = std::dynamic_pointer_cast<AffineWithReluNodeOp>(node);
if(!cnode)
return false;
if(transA_ != cnode->transA_)
return false;
if(transB_ != cnode->transB_)
return false;
if(scalar_ != cnode->scalar_)
return false;
return true;
}
};
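// This node is the fused, inference-only equivalent of relu(affine(a, b, bias)):
// it issues the same gemm as AffineNodeOp but with /*doRelu=*/true, applying
// the ReLU in the epilogue instead of in a separate pass over the output.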
class DotBatchedNodeOp : public NaryNodeOp {
private:
friend class SerializationHelpers;
bool transA_;
bool transB_;
float scalar_;
public:
DotBatchedNodeOp(Expr a, Expr b, bool transA, bool transB, float scalar)
: NaryNodeOp({a, b}, newShape(a, b, transA, transB)),
transA_(transA),
transB_(transB),
scalar_(scalar) {}
Shape newShape(Expr a, Expr b, bool transA, bool transB) {
auto shapeA = a->shape();
if(transA) {
shapeA.set(-2, a->shape()[-1]);
shapeA.set(-1, a->shape()[-2]);
}
auto shapeB = b->shape();
if(transB) {
shapeB.set(-2, b->shape()[-1]);
shapeB.set(-1, b->shape()[-2]);
}
ABORT_IF(shapeA[-1] != shapeB[-2],
"Batched matrix product requires inner dimensions to match in {}{} * {}{}",
std::string(shapeA), transA, std::string(shapeB), transB);
// create shapes for batch dimensions only
auto shapeBatchA = shapeA;
shapeBatchA.set(-1, 1);
shapeBatchA.set(-2, 1);
auto shapeBatchB = shapeB;
shapeBatchB.set(-1, 1);
shapeBatchB.set(-2, 1);
// broadcast batch dimensions
auto shapeOut = Shape::broadcast({shapeBatchA, shapeBatchB});
// set non-batch dimensions in output
shapeOut.set(-2, shapeA[-2]);
shapeOut.set(-1, shapeB[-1]);
return shapeOut;
}
NodeOps forwardOps() override {
// C = alpha * dot(op(A), op(B))
return {NodeOp(ProdBatched(val_,
graph()->allocator(),
child(0)->val(),
child(1)->val(),
transA_,
transB_,
0.f,
scalar_))};
}
NodeOps backwardOps() override {
// D is the adjoint, the matrix of derivatives
// df/dA += alpha * dot(D, op(B).T)
// df/dB += alpha * dot(op(A).T, D)
// beta set to 1.0 in gemm, C = alpha * dot(op(A), op(B)) + beta * C
// to sum gradients from different graph parts
if(!transA_ && transB_)
return {NodeOp(ProdBatched(child(0)->grad(),
graph()->allocator(),
adj_,
child(1)->val(),
false,
false,
1.0,
scalar_)),
NodeOp(ProdBatched(child(1)->grad(),
graph()->allocator(),
adj_,
child(0)->val(),
true,
false,
1.0,
scalar_))};
if(transA_ && !transB_)
return {NodeOp(ProdBatched(child(0)->grad(),
graph()->allocator(),
child(1)->val(),
adj_,
false,
true,
1.0,
scalar_)),
NodeOp(ProdBatched(child(1)->grad(),
graph()->allocator(),
child(0)->val(),
adj_,
false,
false,
1.0,
scalar_))};
if(transA_ && transB_)
return {NodeOp(ProdBatched(child(0)->grad(),
graph()->allocator(),
child(1)->val(),
adj_,
true,
true,
1.0,
scalar_)),
NodeOp(ProdBatched(child(1)->grad(),
graph()->allocator(),
adj_,
child(0)->val(),
true,
true,
1.0,
scalar_))};
return {NodeOp(ProdBatched(child(0)->grad(),
graph()->allocator(),
adj_,
child(1)->val(),
false,
true,
1.0,
scalar_)),
NodeOp(ProdBatched(child(1)->grad(),
graph()->allocator(),
child(0)->val(),
adj_,
true,
false,
1.0,
scalar_))};
}
const std::string type() override { return "bdot"; }
virtual size_t hash() override {
size_t seed = NaryNodeOp::hash();
util::hash_combine(seed, transA_);
util::hash_combine(seed, transB_);
util::hash_combine(seed, scalar_);
return seed;
}
virtual bool equal(Expr node) override {
if(!NaryNodeOp::equal(node))
return false;
auto cnode = std::dynamic_pointer_cast<DotBatchedNodeOp>(node);
if(!cnode)
return false;
if(transA_ != cnode->transA_)
return false;
if(transB_ != cnode->transB_)
return false;
if(scalar_ != cnode->scalar_)
return false;
return true;
}
const std::string color() override { return "orange"; }
};
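// Shape example for the batched product (a sketch, transA = transB = false):
//   A: (8, 16, 5, 7) and B: (1, 16, 7, 9)
// The batch dimensions (8, 16) and (1, 16) broadcast to (8, 16), the inner
// dimensions (7 and 7) match, and the output shape is (8, 16, 5, 9).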
class DotBatchedLegacyNodeOp : public NaryNodeOp {
private:
friend class SerializationHelpers;
bool transA_;
bool transB_;
float scalar_;
public:
DotBatchedLegacyNodeOp(Expr a, Expr b, bool transA, bool transB, float scalar)
: NaryNodeOp({a, b}, newShape(a, b, transA, transB)),
transA_(transA),
transB_(transB),
scalar_(scalar) {}
Shape newShape(Expr a, Expr b, bool transA, bool transB) {
auto shapeA = a->shape();
if(transA) {
shapeA.set(-2, a->shape()[-1]);
shapeA.set(-1, a->shape()[-2]);
}
auto shapeB = b->shape();
if(transB) {
shapeB.set(-2, b->shape()[-1]);
shapeB.set(-1, b->shape()[-2]);
}
Shape outShape = shapeA;
outShape.set(-1, shapeB[-1]);
ABORT_IF(shapeA[-1] != shapeB[-2],
"Batched matrix product requires inner dimensions to match in {}{} * {}{}", std::string(shapeA), transA, std::string(shapeB), transB);
return outShape;
}
NodeOps forwardOps() override {
// C = alpha * dot(op(A), op(B))
return {NodeOp(ProdBatchedLegacy(val_,
graph()->allocator(),
child(0)->val(),
child(1)->val(),
transA_,
transB_,
0.f,
scalar_))};
}
NodeOps backwardOps() override {
// D is the adjoint, the matrix of derivatives
// df/dA += alpha * dot(D, op(B).T)
// df/dB += alpha * dot(op(A).T, D)
// beta set to 1.0 in gemm, C = alpha * dot(op(A), op(B)) + beta * C
// to sum gradients from different graph parts
if(!transA_ && transB_)
return {NodeOp(ProdBatchedLegacy(child(0)->grad(),
graph()->allocator(),
adj_,
child(1)->val(),
false,
false,
1.0,
scalar_)),
NodeOp(ProdBatchedLegacy(child(1)->grad(),
graph()->allocator(),
adj_,
child(0)->val(),
true,
false,
1.0,
scalar_))};
if(transA_ && !transB_)
return {NodeOp(ProdBatchedLegacy(child(0)->grad(),
graph()->allocator(),
child(1)->val(),
adj_,
false,
true,
1.0,
scalar_)),
NodeOp(ProdBatchedLegacy(child(1)->grad(),
graph()->allocator(),
child(0)->val(),
adj_,
false,
false,
1.0,
scalar_))};
if(transA_ && transB_)
return {NodeOp(ProdBatchedLegacy(child(0)->grad(),
graph()->allocator(),
child(1)->val(),
adj_,
true,
true,
1.0,
scalar_)),
NodeOp(ProdBatchedLegacy(child(1)->grad(),
graph()->allocator(),
adj_,
child(0)->val(),
true,
true,
1.0,
scalar_))};
return {NodeOp(ProdBatchedLegacy(child(0)->grad(),
graph()->allocator(),
adj_,
child(1)->val(),
false,
true,
1.0,
scalar_)),
NodeOp(ProdBatchedLegacy(child(1)->grad(),
graph()->allocator(),
child(0)->val(),
adj_,
true,
false,
1.0,
scalar_))};
}
const std::string type() override { return "bdot_legacy"; }
virtual size_t hash() override {
size_t seed = NaryNodeOp::hash();
util::hash_combine(seed, transA_);
util::hash_combine(seed, transB_);
util::hash_combine(seed, scalar_);
return seed;
}
virtual bool equal(Expr node) override {
if(!NaryNodeOp::equal(node))
return false;
auto cnode = std::dynamic_pointer_cast<DotBatchedLegacyNodeOp>(node);
if(!cnode)
return false;
if(transA_ != cnode->transA_)
return false;
if(transB_ != cnode->transB_)
return false;
if(scalar_ != cnode->scalar_)
return false;
return true;
}
const std::string color() override { return "orange"; }
};
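// Unlike DotBatchedNodeOp above, this legacy variant does not broadcast batch
// dimensions: its newShape() copies A's shape and only replaces the last axis
// with B's, so both operands must already agree on their batch extents.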
// Note: To reduce code duplication, we use the same NodeOp for C = op(S) x D and C = D x op(S).
// Set swapOperands to select the latter.
class CSRDotNodeOp : public NaryNodeOp {
bool transS_;
bool swapOperands_;
public:
CSRDotNodeOp(const Shape& S_shape, Expr S_values, Expr S_indices,
Expr S_offsets, Expr D, bool transS, bool swapOperands)
: NaryNodeOp({ S_values, S_indices, S_offsets, D },
newShape(S_shape, S_values, S_indices, S_offsets, D, transS, swapOperands),
NaryNodeOp::commonType({S_values, D})),
transS_(transS), swapOperands_(swapOperands) {
matchOrAbort<IndexType>(S_indices->value_type());
matchOrAbort<IndexType>(S_offsets->value_type());
ABORT_IF(swapOperands_, "Implementation for this is wonky; if you use this, tell us.");
}
Shape newShape(const Shape& S_shape, Expr S_values, Expr S_indices, Expr S_offsets, Expr D, bool transS, bool swapOperands) {
ABORT_IF(S_values->shape().size() != 1 || S_indices->shape().size() != 1 || S_offsets->shape().size() != 1,
"Sparse matrix components must all be vectors");
ABORT_IF(S_values->shape() != S_indices->shape(),
"Sparse matrix values and indices must have the same shape");
ABORT_IF(S_shape.size() != 2,
"Sparse matrix must have rank 2");
ABORT_IF(S_offsets->shape()[0] - 1 != S_shape[0],
"Sparse matrix offset vector has incorrect size");
auto outShape = D->shape();
ABORT_IF(S_shape[transS == swapOperands ? 1 : 0] != outShape[-(int)swapOperands],
"Matrix product requires inner dimensions to match");
outShape.set(-(int)swapOperands, S_shape[transS != swapOperands]);
return outShape;
}
NodeOps forwardOps() override {
return {NodeOp(CSRProd(val_,
graph()->allocator(),
child(0)->val(), child(1)->val(), child(2)->val(),
child(3)->val(),
/*transS=*/transS_, /*swapOperands=*/swapOperands_, /*beta=*/0))};
}
NodeOps backwardOps() override {
return { nullptr, // can't backprop into the sparse matrix (the gradient is dense)
nullptr,
nullptr,
NodeOp(CSRProd(child(3)->grad(), // child(3) = D
graph()->allocator(),
child(0)->val(), child(1)->val(), child(2)->val(), // children(0..2) = A
adj_,
/*transS=*/!transS_, /*swapOperands=*/swapOperands_, /*beta=*/1))};
}
const std::string type() override { return "csr_dot"; }
virtual size_t hash() override {
size_t seed = NaryNodeOp::hash();
for(auto s : shape())
util::hash_combine(seed, s);
util::hash_combine(seed, transS_);
util::hash_combine(seed, swapOperands_);
return seed;
}
virtual bool equal(Expr node) override {
if(!NaryNodeOp::equal(node))
return false;
auto cnode = std::dynamic_pointer_cast<CSRDotNodeOp>(node);
if(!cnode)
return false;
if(transS_ != cnode->transS_)
return false;
if(shape() != cnode->shape())
return false;
if(swapOperands_ != cnode->swapOperands_)
return false;
return true;
}
const std::string color() override { return "orange"; }
};
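// CSR layout example (illustrative): the 2x3 matrix
//   [ 0 2 0 ]
//   [ 1 0 3 ]
// would be passed in as S_shape = (2, 3) with
//   S_values  = [2, 1, 3]   // non-zero entries, row by row
//   S_indices = [1, 0, 2]   // column index of each value
//   S_offsets = [0, 1, 3]   // row r spans values[offsets[r]..offsets[r+1])
// consistent with the checks in newShape(): values and indices have equal
// length and offsets has S_shape[0] + 1 entries.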
struct ScalarProductNodeOp : public NaryNodeOp {
ScalarProductNodeOp(Expr a, Expr b, int axis)
: NaryNodeOp({a, b}, newShape(a, b, axis)) {}
Shape newShape(Expr a, Expr b, int axis) {
Shape full = Shape::broadcast({a, b});
axis_ = full.axis(axis);
full.set(axis_, 1);
return full;
}
NodeOps forwardOps() override {
using namespace functional;
return {NodeOp(Reduce(_1 * _2, val_, child(0)->val(), child(1)->val()))};
}
NodeOps backwardOps() override {
using namespace functional;
return {NodeOp(Add(_1 * _2, child(0)->grad(), child(1)->val(), adj_)),
NodeOp(Add(_1 * _2, child(1)->grad(), child(0)->val(), adj_))};
}
const std::string type() override { return "scalar-product"; }
const std::string color() override { return "orange"; }
virtual size_t hash() override {
size_t seed = NaryNodeOp::hash();
util::hash_combine(seed, axis_);
return seed;
}
virtual bool equal(Expr node) override {
if(!NaryNodeOp::equal(node))
return false;
auto cnode = std::dynamic_pointer_cast<ScalarProductNodeOp>(node);
if(!cnode)
return false;
if(axis_ != cnode->axis_)
return false;
return true;
}
int axis_;
};
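// Example (a sketch): for inputs a, b of shape (B, T, E) and axis = -1, the
// forward pass computes out[i, j, 0] = sum_e a[i, j, e] * b[i, j, e], a
// per-position dot product; the reduced axis is kept with size 1, so the
// output shape is (B, T, 1).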
struct RowsNodeOp : public NaryNodeOp {
RowsNodeOp(Expr a, Expr indices)
: NaryNodeOp({a, indices}, newShape(a, indices), a->value_type()) {
matchOrAbort<IndexType>(indices->value_type());
}
NodeOps forwardOps() override {
return {NodeOp(
CopyRows(val_, child(0)->val(), child(1)->val()))};
}
NodeOps backwardOps() override {
return {NodeOp(PasteRows(child(0)->grad(), adj_, child(1)->val()))};
}
Shape newShape(Expr a, Expr indices) {
Shape shape = a->shape();
ABORT_IF(shape.size() != 2,
"rows operator can only be used with 2-dimensional tensors");
shape.set(0, (int)indices->shape().elements());
return shape;
}
const std::string type() override { return "rows"; }
const std::string color() override { return "orange"; }
};
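// Example: for a of shape (V, E) and indices = [5, 5, 42], CopyRows yields a
// (3, E) tensor whose rows are copies of rows 5, 5 and 42 of a; the backward
// pass scatter-adds the adjoint rows back into a's gradient via PasteRows.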
// This operation gathers elements of a tensor along an axis.
// This is like PyTorch gather().
// For example, this can be used for:
// - Same index applied to all batch items:
// 'index' has 1 in the axes that match batch axes in the input, and axis set to the one axis that gets selected over.
// Example: Selecting Transformer head 0, i.e. return a[:,0,:,:]
// axis = -3
// a : (B, H , S, T) B=batch dim, H=#heads, S=src length, T=trg length
// idx: ( #1#, 1, 1) #1# denotes 'axis'. All values are zero.
// out: (B, 1 , S, T) out[b, 0, s, t] == a[b, idx[/*0,*/ 0, s, t], s, t]
// - Same data with batched indices (today's rows()):
// 'data' has 1 in the batch axes.
// Example: Embedding lookup as done today using rows():
// axis = -2
// e : ( V , E) V=vocab size, E=embedding dimension
// idx: (#(B*S)#, 1) B=batch size, S=source length, idx values are in range 0..V-1
// out: ( (B*S) , E) out[b, s, e] == e[/*0,*/ idx[b, s, 0], e]
// - Batched selection (x-ent scenario): Both 'index' and 'data' have matching batch axes.
// Example: Cross-entropy loss as -gather(logSoftmax(logits), groundTruth, axis=-1):
// axis = -1
// lp : (B, T, V ) B=batch size, T=trg length, V=vocab size
// idx: (B, T, #1#) idx values are in range 0..V-1
// out: (B, T, 1 ) out[b,t,0] == lp[b, t, idx[b, t, 0]]
// Example for 2D tensor with axis=0:
//   | t[index[0, 0], 0]  t[index[0, 1], 1] |
//   | t[index[1, 0], 0]  t[index[1, 1], 1] |
// And for axis=1:
//   | t[0, index[0, 0]]  t[0, index[0, 1]] |
//   | t[1, index[1, 0]]  t[1, index[1, 1]] |
// For a 3-D tensor the output is specified by:
// out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
// out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
// out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
// 'a' and 'indices' must have the same rank.
struct GatherNodeOp : public NaryNodeOp {
GatherNodeOp(Expr a, int axis, Expr indices)
: NaryNodeOp({a, indices}, newShape(a, axis, indices), a->value_type()),
axis_(a->shape().axis(axis)) {
matchOrAbort<IndexType>(indices->value_type());
}
NodeOps forwardOps() override {
return {NodeOp(
// @TODO: rename to gather
Select(val_, child(0)->val(), child(1)->val(), axis_))};
}
NodeOps backwardOps() override {
return {NodeOp(
// @TODO: rename to scatter
Insert</*add=*/true>(child(0)->grad(), adj_, child(1)->val(), axis_))};
}
Shape newShape(Expr a, int axis, Expr indices) {
Shape shape = a->shape();
axis = shape.axis(axis);
auto rank = shape.size();
ABORT_IF(rank != indices->shape().size(), "Mismatching ranks for input ({}) and indices ({})", std::string(shape), std::string(indices->shape()));
shape.set(axis, indices->shape()[axis]);
for (size_t i = 0; i < rank; ++i) {
if (i != axis) {
ABORT_IF(indices->shape()[i] != shape[i] && indices->shape()[i] != 1,
"Dimensions must match or broadcast for input ({}) and indices ({})", std::string(shape), std::string(indices->shape()));
}
}
return shape;
}
const std::string type() override { return "gather"; }
const std::string color() override { return "orange"; }
virtual size_t hash() override {
if(!hash_) {
size_t seed = NaryNodeOp::hash();
util::hash_combine(seed, axis_);
hash_ = seed;
}
return hash_;
}
virtual bool equal(Expr node) override {
if(!NaryNodeOp::equal(node))
return false;
auto cnode = std::dynamic_pointer_cast<GatherNodeOp>(node);
if(!cnode)
return false;
if(axis_ != cnode->axis_)
return false;
return true;
}
private:
friend class SerializationHelpers;
int axis_;
};
struct ScatterNodeOp : public NaryNodeOp {
ScatterNodeOp(Expr a, int axis, Expr indices, Expr source)
: NaryNodeOp({a, indices, source}, newShape(a, axis, indices, source), a->value_type()),
axis_(a->shape().axis(axis)) {
matchOrAbort<IndexType>(indices->value_type());
}
NodeOps forwardOps() override {
return {NodeOp(
CopyCast(val_, child(0)->val()); // @TODO: use normal copy
Insert</*add=*/false>(val_, child(2)->val(), child(1)->val(), axis_)
)};
}
NodeOps backwardOps() override {
ABORT("backward for ScatterNodeOp not yet implemented");
}
Shape newShape(Expr a, int axis, Expr indices, Expr source) {
ABORT_IF(axis != -1, "Scatter is currently only supported along the last dimension");
ABORT_IF(indices->shape() != source->shape(), "Shapes must match");
Shape shape = a->shape();
// @TODO: do proper checking
return shape;
}
const std::string type() override { return "scatter"; }
const std::string color() override { return "orange"; }
virtual size_t hash() override {
if(!hash_) {
size_t seed = NaryNodeOp::hash();
util::hash_combine(seed, axis_);
hash_ = seed;
}
return hash_;
}
virtual bool equal(Expr node) override {
if(!NaryNodeOp::equal(node))
return false;
auto cnode = std::dynamic_pointer_cast<ScatterNodeOp>(node);
if(!cnode)
return false;
if(axis_ != cnode->axis_)
return false;
return true;
}
private:
friend class SerializationHelpers;
int axis_;
};
struct ColsNodeOp : public NaryNodeOp {
ColsNodeOp(Expr a, Expr indices)
: NaryNodeOp({a, indices}, newShape(a, indices), a->value_type()) {
matchOrAbort<IndexType>(indices->value_type());
}
NodeOps forwardOps() override {
return {NodeOp(CopyCols(val_, child(0)->val(), child(1)->val()))};
}
NodeOps backwardOps() override {
return {NodeOp(PasteCols(child(0)->grad(), adj_, child(1)->val()))};
}
Shape newShape(Expr a, Expr indices) {
Shape shape = a->shape();
shape.set(1, (int)indices->shape().elements());
return shape;
}
const std::string type() override { return "cols"; }
const std::string color() override { return "orange"; }
};
struct ElementBinaryNodeOp : public NaryNodeOp {
ElementBinaryNodeOp(Expr a, Expr b)
: NaryNodeOp({a, b}, newShape(a, b)) {}
Shape newShape(Expr a, Expr b) { return Shape::broadcast({a, b}); }
const std::string color() override { return "yellow"; }
};
struct PlusNodeOp : public ElementBinaryNodeOp {
PlusNodeOp(Expr a, Expr b) : ElementBinaryNodeOp(a, b) {}
NodeOps forwardOps() override {
using namespace functional;
return {
NodeOp(Element(_1 = _2 + _3, val_, child(0)->val(), child(1)->val()))};
}
NodeOps backwardOps() override {
using namespace functional;
return {NodeOp(Add(_1, child(0)->grad(), adj_)),
NodeOp(Add(_1, child(1)->grad(), adj_))};
}
const std::string type() override { return "+"; }
};
struct MinusNodeOp : public ElementBinaryNodeOp {
MinusNodeOp(Expr a, Expr b) : ElementBinaryNodeOp(a, b) {}
NodeOps forwardOps() override {
using namespace functional;
return {
NodeOp(Element(_1 = _2 - _3, val_, child(0)->val(), child(1)->val()))};
}
NodeOps backwardOps() override {
using namespace functional;
return {NodeOp(Add(_1, child(0)->grad(), adj_)),
NodeOp(Add(-_1, child(1)->grad(), adj_))};
}
const std::string type() override { return "-"; }
};
struct MultNodeOp : public ElementBinaryNodeOp {
MultNodeOp(Expr a, Expr b) : ElementBinaryNodeOp(a, b) {}
NodeOps forwardOps() override {
using namespace functional;
return {
NodeOp(Element(_1 = _2 * _3, val_, child(0)->val(), child(1)->val()))};
}
NodeOps backwardOps() override {
using namespace functional;
return {NodeOp(Add(_1 * _2, child(0)->grad(), adj_, child(1)->val())),
NodeOp(Add(_1 * _2, child(1)->grad(), adj_, child(0)->val()))};
}
const std::string type() override { return "*"; }
};
struct DivNodeOp : public ElementBinaryNodeOp {
DivNodeOp(Expr a, Expr b) : ElementBinaryNodeOp(a, b) {}
NodeOps forwardOps() override {
using namespace functional;
return {
NodeOp(Element(_1 = _2 / _3, val_, child(0)->val(), child(1)->val()))};
}
NodeOps backwardOps() override {
using namespace functional;
return {
NodeOp(Add(_1 * 1.0f / _2, child(0)->grad(), adj_, child(1)->val())),
NodeOp(Add(-_1 * _2 / (_3 * _3),
child(1)->grad(),
adj_,
child(0)->val(),
child(1)->val()))};
}
const std::string type() override { return "/"; }
};
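// The quotient rule behind the two gradients above, with D = adj_:
//   d(a/b)/da =  1/b      =>  dF/da += D / b
//   d(a/b)/db = -a/b^2    =>  dF/db += -D * a / (b * b)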
// struct PowNodeOp : public ElementBinaryNodeOp {
// public:
// template <typename... Args>
// PowNodeOp(Args... args) : ElementBinaryNodeOp(args...) {}
//
// NodeOps forwardOps() {
// return {NodeOp(Element(_1 = Pow(_2, _3), val_,
// child(0)->val(), child(1)->val()))};
// }
//
// NodeOps backwardOps() {
// return {
// NodeOp(Add(_2 * Pow(_1, _2 - 1.f) * _3,
// child(0)->grad(), child(0)->val(), child(1)->val(), adj_)),
// NodeOp(Add(Pow(_1, _2) * Log(_1) * _3,
// child(1)->grad(), child(0)->val(), child(1)->val(), adj_))
//
// };
// }
//
// const std::string type() { return "pow"; }
//};
struct LogAddExpNodeOp : public ElementBinaryNodeOp {
LogAddExpNodeOp(Expr a, Expr b) : ElementBinaryNodeOp(a, b) {}
NodeOps forwardOps() override {
using namespace functional;
return {NodeOp(Element(
_1 = logaddexp(_2, _3), val_, child(0)->val(), child(1)->val()))};
}
NodeOps backwardOps() override {
using namespace functional;
// d/dx ln(exp(x) + exp(y)) = exp(x) / (exp(x) + exp(y))
//                          = 1 / (1 + exp(y-x)) = sigmoid(x-y)
return {NodeOp(Add(_1 * sigmoid(_2 - _3),
child(0)->grad(),
adj_,
child(0)->val(),
child(1)->val())),
NodeOp(Add(_1 * sigmoid(_3 - _2),
child(1)->grad(),
adj_,
child(0)->val(),
child(1)->val()))};
}
// TODO: this is not a "type" (as in data type). It's an operator name.
const std::string type() override { return "logaddexp"; }
};
struct MaximumNodeOp : public ElementBinaryNodeOp {
MaximumNodeOp(Expr a, Expr b) : ElementBinaryNodeOp(a, b) {}
NodeOps forwardOps() override {
using namespace functional;
return {NodeOp(
Element(_1 = max(_2, _3), val_, child(0)->val(), child(1)->val()))};
}
NodeOps backwardOps() override {
using namespace functional;
return {NodeOp(Add((_2 > _3) * _1,
child(0)->grad(),
adj_,
child(0)->val(),
child(1)->val())),
NodeOp(Add((_2 <= _3) * _1,
child(1)->grad(),
adj_,
child(0)->val(),
child(1)->val()))};
}
const std::string type() override { return "max"; }
};
// TODO: lotsa code dup here!
struct MinimumNodeOp : public ElementBinaryNodeOp {
MinimumNodeOp(Expr a, Expr b) : ElementBinaryNodeOp(a, b) {}
NodeOps forwardOps() override {
using namespace functional;
return {NodeOp(
Element(_1 = min(_2, _3), val_, child(0)->val(), child(1)->val()))};
}
NodeOps backwardOps() override {
using namespace functional;
return {NodeOp(Add((_2 < _3) * _1,
child(0)->grad(),
adj_,
child(0)->val(),
child(1)->val())),
NodeOp(Add((_2 >= _3) * _1,
child(1)->grad(),
adj_,
child(0)->val(),
child(1)->val()))};
}
const std::string type() override { return "min"; }
};
struct CmpNodeOp : public ElementBinaryNodeOp {
CmpNodeOp(Expr a, Expr b, int cmp_, bool not_) : ElementBinaryNodeOp(a, b), cmp_(cmp_), not_(not_) {
//setTrainable(false); // has no gradient
// Note: ^^ Disabled because it currently causes Marian to choke, for unknown reasons.
// Not setting this will not change the result since the vector of gradient functions is empty.
}
NodeOps forwardOps() override {
using namespace functional;
return {
NodeOp(Element(_1 = ((((_2 > _3) - (_2 < _3)) == (float)cmp_) != not_),
val_, child(0)->val(), child(1)->val()))};
}
NodeOps backwardOps() override { return {}; }
const std::string type() override {
switch (cmp_) {
case -1: return not_ ? "ge" : "lt";
case 0: return not_ ? "ne" : "eq";
case 1: return not_ ? "le" : "gt";
}
ABORT("Should not get here??");
}
virtual size_t hash() override {
if(!hash_) {
size_t seed = NaryNodeOp::hash();
util::hash_combine(seed, cmp_);
util::hash_combine(seed, not_);
hash_ = seed;
}
return hash_;
}
virtual bool equal(Expr node) override {
if(!NaryNodeOp::equal(node))
return false;
auto cnode = std::dynamic_pointer_cast<CmpNodeOp>(node);
if(!cnode)
return false;
if(cmp_ != cnode->cmp_)
return false;
if(not_ != cnode->not_)
return false;
return true;
}
private:
int cmp_; // -1: less; 0: equal; 1: greater
bool not_; // invert result if true
};
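// How the element expression above works: (_2 > _3) - (_2 < _3) is the
// three-way sign of the comparison, i.e. -1, 0 or +1. Testing it against cmp_
// selects lt/eq/gt, and "!= not_" inverts the result to obtain ge/ne/le. For
// example, cmp_ = 0 with not_ = true computes elementwise (a != b) as 1 or 0.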
// For each row j, take the corresponding label index i from indices and compute
// the cross-entropy of that row. For each vocabulary item v, the only non-zero
// term in the sum is the one that matches the label i (the picked element):
//   C = sum_{v in V} -logsoftmax(A)[v] * delta(v, i) = -logsoftmax(A)[i]
class CrossEntropyNodeOp : public NaryNodeOp {
private:
float labelSmoothingAlpha_;
public:
CrossEntropyNodeOp(Expr a, Expr indices, float labelSmoothingAlpha, Type outputType = Type::float32)
: NaryNodeOp({a, indices}, newShape(a), outputType),
labelSmoothingAlpha_(labelSmoothingAlpha) {
matchOrAbort<IndexType>(indices->value_type());
int rows = a->shape().elements() / a->shape()[-1];
int labels = indices->shape().elements();
ABORT_IF(rows != labels, "Number of examples and labels does not match: {} != {}", rows, labels);
}
Shape newShape(Expr a) {
Shape shape1 = a->shape();
shape1.set(a->shape().size() - 1, 1);
return shape1;
}
NodeOps forwardOps() override {
return {NodeOp(CrossEntropyPick(val_, child(0)->val(), child(1)->val(), labelSmoothingAlpha_))};
}
NodeOps backwardOps() override {
return {NodeOp(CrossEntropyPickBackward(
child(0)->grad(), adj_, child(0)->val(), child(1)->val(), labelSmoothingAlpha_))};
}
virtual size_t hash() override {
size_t seed = NaryNodeOp::hash();
util::hash_combine(seed, labelSmoothingAlpha_);
return seed;
}
virtual bool equal(Expr node) override {
if(!NaryNodeOp::equal(node))
return false;
auto cnode = std::dynamic_pointer_cast<CrossEntropyNodeOp>(node);
if(!cnode)
return false;
if(labelSmoothingAlpha_ != cnode->labelSmoothingAlpha_)
return false;
return true;
}
const std::string type() override { return "x-ent"; }
};
struct ConcatenateNodeOp : public NaryNodeOp {
ConcatenateNodeOp(const std::vector<Expr>& nodes, int axis)
: NaryNodeOp(nodes, newShape(nodes, axis)) {
}
Shape newShape(const std::vector<Expr>& nodes, int ax) {
ABORT_IF(nodes.empty(), "No child nodes given");
Shape shape = nodes[0]->shape();
axis_ = shape.axis(ax);
int sum = 0;
auto checkShape = shape;
for(auto child : nodes) {
checkShape.set(axis_, child->shape()[axis_]); // don't abort on different sizes on axis dim.
ABORT_IF(checkShape != child->shape(),
"Child shapes {} and {} cannot be concatenated along axis {}",
shape, child->shape(), ax);
sum += child->shape()[axis_];
}
shape.set(axis_, sum);
return shape;
}
void forward() override {
std::vector<Tensor> concatenees;
for(size_t i = 0; i < children_.size(); ++i)
concatenees.push_back(child(i)->val());
Concatenate(val_, concatenees, axis_);
}
void backward() override {
std::vector<Tensor> deconcatenees;
for(size_t i = 0; i < children_.size(); ++i) {
auto childPtr = child(i);
childPtr->set_zero_adjoint(); // @TODO: this is a hotfix, do this properly
deconcatenees.push_back(childPtr->grad());
}
Deconcatenate(deconcatenees, adj_, axis_);
}
virtual size_t hash() override {
size_t seed = NaryNodeOp::hash();
util::hash_combine(seed, axis_);
return seed;
}
virtual bool equal(Expr node) override {
if(!NaryNodeOp::equal(node))
return false;
auto cnode = std::dynamic_pointer_cast<ConcatenateNodeOp>(node);
if(!cnode)
return false;
if(axis_ != cnode->axis_)
return false;
return true;
}
const std::string type() override { return "concat"; }
private:
friend class SerializationHelpers;
int axis_;
};
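// Shape example: concatenating children of shapes (4, 2, 5), (4, 3, 5) and
// (4, 1, 5) along axis 1 passes the per-child check (shapes agree everywhere
// except on the concatenation axis) and produces the shape (4, 6, 5).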
// layer norm along last axis
struct LayerNormalizationOp : public NaryNodeOp {
public:
LayerNormalizationOp(const std::vector<Expr>& nodes, float eps = 1e-9f)
: NaryNodeOp(nodes), eps_(eps) {
// @TODO: dimension check
}
NodeOps forwardOps() override {
return {NodeOp(
LayerNormalization(val_,
child(0)->val(),
child(1)->val(),
(children_.size() == 3) ? child(2)->val() : nullptr,
eps_))};
}
// @BUGBUG: backward has not been tested for broadcasting gamma/beta
NodeOps backwardOps() override {
return {NodeOp(
LayerNormalizationGrad(
graph()->allocator(),
child(0)->grad(),
child(1)->grad(),
(children_.size() == 3) ? child(2)->grad() : nullptr,
adj_,
val_,
child(0)->val(),
child(1)->val(),
(children_.size() == 3) ? child(2)->val() : nullptr,
eps_))};
}
const std::string type() override { return "layer_normalization"; }
virtual size_t hash() override {
size_t seed = NaryNodeOp::hash();
util::hash_combine(seed, eps_);
return seed;
}
virtual bool equal(Expr node) override {
if(!NaryNodeOp::equal(node))
return false;
auto cnode = std::dynamic_pointer_cast<LayerNormalizationOp>(node);
if(!cnode)
return false;
if(eps_ != cnode->eps_)
return false;
return true;
}
private:
friend class SerializationHelpers; // @TODO: use the same name for this as SqrtNodeOp
float eps_;
};
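// The normalization computed above, per position along the last axis:
//   y = gamma * (x - mean(x)) / sqrt(var(x) + eps) [+ beta]
// with child(0) = x, child(1) = gamma and the optional child(2) = beta; eps_
// guards the division for (near-)constant rows.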
// RMS norm along last axis
struct RMSNormalizationOp : public NaryNodeOp {
public:
RMSNormalizationOp(const std::vector<Expr>& nodes, float eps = 1e-9f)
: NaryNodeOp(nodes), eps_(eps) {
// @TODO: dimension check
}
NodeOps forwardOps() override {
return {NodeOp(
RMSNormalization(val_,
child(0)->val(),
child(1)->val(),
(children_.size() == 3) ? child(2)->val() : nullptr,
eps_))};
}
// @BUGBUG: backward has not been tested for broadcasting gamma/beta
NodeOps backwardOps() override {
return {NodeOp(
RMSNormalizationGrad(
graph()->allocator(),
child(0)->grad(),
child(1)->grad(),
(children_.size() == 3) ? child(2)->grad() : nullptr,
adj_,
val_,
child(0)->val(),
child(1)->val(),
(children_.size() == 3) ? child(2)->val() : nullptr,
eps_))};
}
const std::string type() override { return "rms_normalization"; }
virtual size_t hash() override {
size_t seed = NaryNodeOp::hash();
util::hash_combine(seed, eps_);
return seed;
}
virtual bool equal(Expr node) override {
if(!NaryNodeOp::equal(node))
return false;
auto cnode = std::dynamic_pointer_cast<RMSNormalizationOp>(node);
if(!cnode)
return false;
if(eps_ != cnode->eps_)
return false;
return true;
}
private:
friend class SerializationHelpers; // @TODO: use the same name for this as SqrtNodeOp
float eps_;
};
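// RMS normalization (Zhang & Sennrich, 2019) drops the mean subtraction of
// layer norm; per position along the last axis this computes
//   y = gamma * x / sqrt(mean(x^2) + eps) [+ beta]
// with the same child layout as LayerNormalizationOp: x, gamma, optional beta.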
struct HighwayNodeOp : public NaryNodeOp {
HighwayNodeOp(const std::vector<Expr>& nodes) : NaryNodeOp(nodes) {}
NodeOps forwardOps() override {
return {NodeOp(HighwayForward(
val_, child(0)->val(), child(1)->val(), child(2)->val()))};
}
NodeOps backwardOps() override {
return {NodeOp(HighwayBackward(child(0)->grad(),
child(1)->grad(),
child(2)->grad(),
child(0)->val(),
child(1)->val(),
child(2)->val(),
adj_))};
}
const std::string type() override { return "highway"; }
};
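// Highway combination (Srivastava et al., 2015), assuming the child layout
// (input1, input2, gate) implied by the calls above:
//   y = sigmoid(t) * child(0) + (1 - sigmoid(t)) * child(1),  with t = child(2)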
#ifdef CUDNN
class ConvolutionOp : public NaryNodeOp {
public:
ConvolutionOp(const std::vector<Expr>& nodes,
int hPad = 0,
int wPad = 0,
int hStride = 1,
int wStride = 1)
: NaryNodeOp(nodes),
conv_(nodes[1]->shape(),
nodes[2]->shape(),
hPad,
wPad,
hStride,
wStride) {
conv_.getOutputShape(nodes[0]->shape(), shape_);
}
NodeOps forwardOps() override {
return {NodeOp(conv_.forward(
child(0)->val(), child(1)->val(), child(2)->val(), val_))};
}
NodeOps backwardOps() override {
return {NodeOp(conv_.backward(child(0)->val(),
child(0)->grad(),
child(1)->val(),
child(1)->grad(),
child(2)->grad(),
adj_))};
}
const std::string type() override { return "layer_convolution"; }
protected:
ConvolutionWrapper conv_;
};
#endif
} // namespace marian