Program Listing for File attention.cpp¶
↰ Return to documentation for file (src/rnn/attention.cpp
)
#include "attention.h"
#include "graph/node_operators_binary.h"
#include "tensors/tensor_operators.h"
namespace marian {
namespace rnn {
struct AttentionNodeOp : public NaryNodeOp {
AttentionNodeOp(const std::vector<Expr>& nodes)
: NaryNodeOp(nodes, newShape(nodes)) {}
Shape newShape(const std::vector<Expr>& nodes) {
Shape shape = Shape::broadcast({nodes[1], nodes[2]});
Shape vaShape = nodes[0]->shape();
ABORT_IF(vaShape[-2] != shape[-1] || vaShape[-1] != 1, "Wrong size");
shape.set(-1, 1);
return shape;
}
NodeOps forwardOps() override {
return {
NodeOp(Att(val_, child(0)->val(), child(1)->val(), child(2)->val()))};
}
NodeOps backwardOps() override {
return {
NodeOp(AttBack(child(0)->grad(),
child(1)->grad(),
child(2)->grad(),
child(0)->val(),
child(1)->val(),
child(2)->val(),
adj_);)
};
}
// do not check if node is trainable
virtual void runBackward(const NodeOps& ops) override {
for(auto&& op : ops)
op();
}
const std::string type() override { return "Att-ops"; }
const std::string color() override { return "yellow"; }
};
Expr attOps(Expr va, Expr context, Expr state) {
std::vector<Expr> nodes{va, context, state};
int dimBatch = context->shape()[-2];
int dimWords = context->shape()[-3];
int dimBeam = 1;
if(state->shape().size() > 3)
dimBeam = state->shape()[-4];
return reshape(Expression<AttentionNodeOp>(nodes),
{dimBeam, 1, dimWords, dimBatch});
}
} // namespace rnn
} // namespace marian