.. _program_listing_file_src_rnn_attention.h:

Program Listing for File attention.h
====================================

|exhale_lsh| :ref:`Return to documentation for file <file_src_rnn_attention.h>` (``src/rnn/attention.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   #pragma once

   #include "marian.h"
   #include "models/states.h"
   #include "rnn/types.h"

   namespace marian {
   namespace rnn {

   Expr attOps(Expr va, Expr context, Expr state);

   // Additive attention used in RNN cells.
   // @TODO: come up with common framework for attention in RNNs and Transformer.
   class GlobalAttention : public CellInput {
   private:
     Expr Wa_, ba_, Ua_, va_;

     Expr gammaContext_;
     Expr gammaState_;

     Ptr<EncoderState> encState_;
     Expr softmaxMask_;
     Expr mappedContext_;
     std::vector<Expr> contexts_;
     std::vector<Expr> alignments_;
     bool layerNorm_;
     float dropout_;

     Expr contextDropped_;
     Expr dropMaskContext_;
     Expr dropMaskState_;

     // for Nematus-style layer normalization
     Expr Wc_att_lns_, Wc_att_lnb_;
     Expr W_comb_att_lns_, W_comb_att_lnb_;
     bool nematusNorm_;

   public:
     GlobalAttention(Ptr<ExpressionGraph> graph,
                     Ptr<Options> options,
                     Ptr<EncoderState> encState)
         : CellInput(options),
           encState_(encState),
           contextDropped_(encState->getContext()) {
       int dimDecState = options_->get<int>("dimState");
       dropout_ = options_->get<float>("dropout", 0);
       layerNorm_ = options_->get<bool>("layer-normalization", false);
       nematusNorm_ = options_->get<bool>("nematus-normalization", false);
       std::string prefix = options_->get<std::string>("prefix");

       int dimEncState = encState_->getContext()->shape()[-1];

       Wa_ = graph->param(prefix + "_W_comb_att",
                          {dimDecState, dimEncState},
                          inits::glorotUniform());
       Ua_ = graph->param(
           prefix + "_Wc_att", {dimEncState, dimEncState}, inits::glorotUniform());
       va_ = graph->param(
           prefix + "_U_att", {dimEncState, 1}, inits::glorotUniform());
       ba_ = graph->param(prefix + "_b_att", {1, dimEncState}, inits::zeros());

       if(dropout_ > 0.0f) {
         dropMaskContext_ = graph->dropoutMask(dropout_, {1, dimEncState});
         dropMaskState_ = graph->dropoutMask(dropout_, {1, dimDecState});
       }

       contextDropped_ = dropout(contextDropped_, dropMaskContext_);

       if(layerNorm_) {
         if(nematusNorm_) {
           // instead of gammaContext_
           Wc_att_lns_ = graph->param(
               prefix + "_Wc_att_lns", {1, dimEncState}, inits::fromValue(1.f));
           Wc_att_lnb_ = graph->param(
               prefix + "_Wc_att_lnb", {1, dimEncState}, inits::zeros());
           // instead of gammaState_
           W_comb_att_lns_ = graph->param(prefix + "_W_comb_att_lns",
                                          {1, dimEncState},
                                          inits::fromValue(1.f));
           W_comb_att_lnb_ = graph->param(
               prefix + "_W_comb_att_lnb", {1, dimEncState}, inits::zeros());

           mappedContext_ = layerNorm(affine(contextDropped_, Ua_, ba_),
                                      Wc_att_lns_,
                                      Wc_att_lnb_,
                                      NEMATUS_LN_EPS);
         } else {
           gammaContext_ = graph->param(
               prefix + "_att_gamma1", {1, dimEncState}, inits::fromValue(1.0));
           gammaState_ = graph->param(
               prefix + "_att_gamma2", {1, dimEncState}, inits::fromValue(1.0));

           mappedContext_
               = layerNorm(dot(contextDropped_, Ua_), gammaContext_, ba_);
         }
       } else {
         mappedContext_ = affine(contextDropped_, Ua_, ba_);
       }

       auto softmaxMask = encState_->getMask();
       if(softmaxMask) {
         Shape shape = {softmaxMask->shape()[-3], softmaxMask->shape()[-2]};
         softmaxMask_ = transpose(reshape(softmaxMask, shape));
       }
     }

     Expr apply(State state) override {
       auto recState = state.output;

       int dimBatch = contextDropped_->shape()[-2];
       int srcWords = contextDropped_->shape()[-3];
       int dimBeam = 1;
       if(recState->shape().size() > 3)
         dimBeam = recState->shape()[-4];

       recState = dropout(recState, dropMaskState_);

       auto mappedState = dot(recState, Wa_);
       if(layerNorm_) {
         if(nematusNorm_) {
           mappedState = layerNorm(mappedState,
                                   W_comb_att_lns_,
                                   W_comb_att_lnb_,
                                   NEMATUS_LN_EPS);
         } else {
           mappedState = layerNorm(mappedState, gammaState_);
         }
       }

       auto attReduce = attOps(va_, mappedContext_, mappedState);

       // @TODO: horrible ->
       auto e = reshape(transpose(softmax(transpose(attReduce), softmaxMask_)),
                        {dimBeam, srcWords, dimBatch, 1});
       // <- horrible

       auto alignedSource
           = scalar_product(encState_->getAttended(), e, /*axis =*/ -3);

       contexts_.push_back(alignedSource);
       alignments_.push_back(e);
       return alignedSource;
     }

     std::vector<Expr>& getContexts() { return contexts_; }

     Expr getContext() { return concatenate(contexts_, /*axis =*/ -3); }

     std::vector<Expr>& getAlignments() { return alignments_; }

     virtual void clear() override {
       contexts_.clear();
       alignments_.clear();
     }

     int dimOutput() override { return encState_->getContext()->shape()[-1]; }
   };

   using Attention = GlobalAttention;

   }  // namespace rnn
   }  // namespace marian
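
The class above implements Bahdanau-style additive attention: for each source position a score ``va · tanh(Wa·state + Ua·context + ba)`` is computed (the ``tanh``/dot reduction is fused inside ``attOps``), softmax-normalized over source positions, and used to weight the encoder context. The standalone sketch below illustrates that score computation with plain ``std::vector``; it is not part of ``attention.h``, does not use the Marian API, and its function name and layout are assumptions for illustration only.

.. code-block:: cpp

   // Illustrative sketch of the additive-attention scoring that attOps and the
   // following softmax perform; plain C++, no Marian types. mappedContext[i]
   // plays the role of Ua*context_i + ba, mappedState of Wa*state, and va of
   // the scoring vector va_. Masking (softmaxMask_) is omitted in this sketch.
   #include <algorithm>
   #include <cmath>
   #include <vector>

   std::vector<float> additiveAttentionWeights(
       const std::vector<std::vector<float>>& mappedContext,  // [srcWords][dim]
       const std::vector<float>& mappedState,                 // [dim]
       const std::vector<float>& va) {                        // [dim]
     // score each source position: va . tanh(mappedContext[i] + mappedState)
     std::vector<float> scores;
     scores.reserve(mappedContext.size());
     for(const auto& ctx : mappedContext) {
       float e = 0.f;
       for(size_t k = 0; k < va.size(); ++k)
         e += va[k] * std::tanh(ctx[k] + mappedState[k]);
       scores.push_back(e);
     }
     // softmax over source positions (numerically stabilized)
     float maxE = *std::max_element(scores.begin(), scores.end());
     float sum = 0.f;
     for(float& e : scores) { e = std::exp(e - maxE); sum += e; }
     for(float& e : scores) e /= sum;
     return scores;
   }

In the real class, the resulting weights ``e`` are used by ``scalar_product`` to form the weighted sum of the attended encoder states, which ``apply()`` returns as the aligned source context for the decoder cell.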