Program Listing for File attention.h

Return to documentation for file (src/rnn/attention.h)

#pragma once

#include "marian.h"
#include "models/states.h"
#include "rnn/types.h"

namespace marian {
namespace rnn {

// Computes the unnormalized additive attention energies for every source
// position, roughly dot(tanh(context + state), va), where context and state
// are the already projected encoder context and decoder state.
Expr attOps(Expr va, Expr context, Expr state);

// Additive attention used in RNN cells.
// @TODO: come up with common framework for attention in RNNs and Transformer.
class GlobalAttention : public CellInput {
private:
  Expr Wa_, ba_, Ua_, va_;

  Expr gammaContext_;
  Expr gammaState_;

  Ptr<EncoderState> encState_;
  Expr softmaxMask_;
  Expr mappedContext_;
  std::vector<Expr> contexts_;
  std::vector<Expr> alignments_;
  bool layerNorm_;
  float dropout_;

  Expr contextDropped_;
  Expr dropMaskContext_;
  Expr dropMaskState_;

  // for Nematus-style layer normalization
  Expr Wc_att_lns_, Wc_att_lnb_;
  Expr W_comb_att_lns_, W_comb_att_lnb_;
  bool nematusNorm_;

public:
  GlobalAttention(Ptr<ExpressionGraph> graph,
                  Ptr<Options> options,
                  Ptr<EncoderState> encState)
      : CellInput(options),
        encState_(encState),
        contextDropped_(encState->getContext()) {
    int dimDecState = options_->get<int>("dimState");
    dropout_ = options_->get<float>("dropout", 0);
    layerNorm_ = options_->get<bool>("layer-normalization", false);
    nematusNorm_ = options_->get<bool>("nematus-normalization", false);
    std::string prefix = options_->get<std::string>("prefix");

    int dimEncState = encState_->getContext()->shape()[-1];

    Wa_ = graph->param(prefix + "_W_comb_att",
                       {dimDecState, dimEncState},
                       inits::glorotUniform());
    Ua_ = graph->param(
        prefix + "_Wc_att", {dimEncState, dimEncState}, inits::glorotUniform());
    va_ = graph->param(
        prefix + "_U_att", {dimEncState, 1}, inits::glorotUniform());
    ba_ = graph->param(prefix + "_b_att", {1, dimEncState}, inits::zeros());

    if(dropout_ > 0.0f) {
      dropMaskContext_ = graph->dropoutMask(dropout_, {1, dimEncState});
      dropMaskState_   = graph->dropoutMask(dropout_, {1, dimDecState});
    }

    // apply dropout to the encoder context only if a mask was created above
    if(dropMaskContext_)
      contextDropped_ = dropout(contextDropped_, dropMaskContext_);

    if(layerNorm_) {
      if(nematusNorm_) {
        // instead of gammaContext_
        Wc_att_lns_ = graph->param(
            prefix + "_Wc_att_lns", {1, dimEncState}, inits::fromValue(1.f));
        Wc_att_lnb_ = graph->param(
            prefix + "_Wc_att_lnb", {1, dimEncState}, inits::zeros());
        // instead of gammaState_
        W_comb_att_lns_ = graph->param(prefix + "_W_comb_att_lns",
                                       {1, dimEncState},
                                       inits::fromValue(1.f));
        W_comb_att_lnb_ = graph->param(
            prefix + "_W_comb_att_lnb", {1, dimEncState}, inits::zeros());

        mappedContext_ = layerNorm(affine(contextDropped_, Ua_, ba_),
                                   Wc_att_lns_,
                                   Wc_att_lnb_,
                                   NEMATUS_LN_EPS);
      } else {
        gammaContext_ = graph->param(
            prefix + "_att_gamma1", {1, dimEncState}, inits::fromValue(1.0));
        gammaState_ = graph->param(
            prefix + "_att_gamma2", {1, dimEncState}, inits::fromValue(1.0));

        mappedContext_
            = layerNorm(dot(contextDropped_, Ua_), gammaContext_, ba_);
      }

    } else {
      mappedContext_ = affine(contextDropped_, Ua_, ba_);
    }

    auto softmaxMask = encState_->getMask();
    if(softmaxMask) {
      Shape shape = {softmaxMask->shape()[-3], softmaxMask->shape()[-2]};
      softmaxMask_ = transpose(reshape(softmaxMask, shape));
    }
  }

  // Computes the attention-weighted average of the encoder states (the
  // aligned source context) for the given decoder state and records the
  // soft alignments.
  Expr apply(State state) override {
    auto recState = state.output;

    int dimBatch = contextDropped_->shape()[-2];
    int srcWords = contextDropped_->shape()[-3];
    int dimBeam = 1;
    if(recState->shape().size() > 3)
      dimBeam = recState->shape()[-4];

    if(dropMaskState_)
      recState = dropout(recState, dropMaskState_);

    auto mappedState = dot(recState, Wa_);
    if(layerNorm_) {
      if(nematusNorm_) {
        mappedState = layerNorm(
            mappedState, W_comb_att_lns_, W_comb_att_lnb_, NEMATUS_LN_EPS);
      } else {
        mappedState = layerNorm(mappedState, gammaState_);
      }
    }

    auto attReduce = attOps(va_, mappedContext_, mappedState);

    // @TODO: horrible ->
    auto e = reshape(transpose(softmax(transpose(attReduce), softmaxMask_)),
                     {dimBeam, srcWords, dimBatch, 1});
    // <- horrible

    auto alignedSource = scalar_product(encState_->getAttended(), e, /*axis =*/ -3);

    contexts_.push_back(alignedSource);
    alignments_.push_back(e);
    return alignedSource;
  }

  std::vector<Expr>& getContexts() { return contexts_; }

  Expr getContext() { return concatenate(contexts_, /*axis =*/ -3); }

  std::vector<Expr>& getAlignments() { return alignments_; }

  virtual void clear() override {
    contexts_.clear();
    alignments_.clear();
  }

  int dimOutput() override { return encState_->getContext()->shape()[-1]; }
};

using Attention = GlobalAttention;
}  // namespace rnn
}  // namespace marian
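
The attention layer above is used by Marian's attentional RNN decoders. The
following sketch is illustrative only and not part of the repository: it
assumes an Options object that already carries the keys read in the
constructor ("prefix", "dimState", and optionally "dropout",
"layer-normalization", "nematus-normalization"), an EncoderState holding the
encoder context and source mask, and an rnn::State holding the current
decoder output.

// Illustrative sketch (not part of attention.h): one decoder step that uses
// GlobalAttention to compute the attention-weighted source context.
#include "rnn/attention.h"

namespace marian {

Expr attendOneStep(Ptr<ExpressionGraph> graph,
                   Ptr<Options> attOptions,    // assumed to contain "prefix" and "dimState"
                   Ptr<EncoderState> encState, // encoder context and source mask
                   rnn::State decoderState) {  // current decoder hidden state
  // Build the additive attention layer over the encoder states.
  auto attention = New<rnn::GlobalAttention>(graph, attOptions, encState);

  // Compute the aligned source context for this decoder state; the soft
  // alignments are recorded internally and can be read via getAlignments().
  Expr alignedContext = attention->apply(decoderState);

  return alignedContext;
}

}  // namespace marian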