Program Listing for File laser.h


#pragma once

#include "marian.h"

#include "layers/constructors.h"
#include "rnn/constructors.h"

namespace marian {

// Re-implements the LASER BiLSTM encoder from:
// Massively Multilingual Sentence Embeddings for Zero-Shot Cross-Lingual Transfer and Beyond
// Mikel Artetxe, Holger Schwenk
// https://arxiv.org/abs/1812.10464

class EncoderLaser : public EncoderBase {
  using EncoderBase::EncoderBase;

public:
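  // Applies a stack of bidirectional RNN layers to the embedded batch and
  // returns the concatenated forward/backward states of the top layer.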
  Expr applyEncoderRNN(Ptr<ExpressionGraph> graph,
                       Expr embeddings,
                       Expr mask) {
    int depth = opt<int>("enc-depth");
    float dropoutRnn = inference_ ? 0 : opt<float>("dropout-rnn");

    Expr output = embeddings;

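    // Builds one unidirectional RNN layer (forward or backward) for the given
    // depth and runs it over the input sequence.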
    auto applyRnn = [&](int layer, rnn::dir direction, Expr input, Expr mask) {

        std::string paramPrefix = prefix_ + "_" + opt<std::string>("enc-cell");
        paramPrefix += "_l" + std::to_string(layer);
        if(direction == rnn::dir::backward)
            paramPrefix += "_reverse";

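        // Configure the RNN layer from the encoder options; parameters are
        // named by cell type, layer index and direction.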
        auto rnnFactory = rnn::rnn()
            ("type", opt<std::string>("enc-cell"))
            ("direction", (int)direction)
            ("dimInput", input->shape()[-1])
            ("dimState", opt<int>("dim-rnn"))
            ("dropout", dropoutRnn)
            ("layer-normalization", opt<bool>("layer-normalization"))
            ("skip", opt<bool>("skip"))
            .push_back(rnn::cell()("prefix", paramPrefix));

        return rnnFactory.construct(graph)->transduce(input, mask);
    };

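    // Stack enc-depth bidirectional layers: each layer concatenates the
    // forward and backward states along the feature axis and feeds the
    // result into the next layer.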
    for(int i = 0; i < depth; ++i) {
        output = concatenate({applyRnn(i, rnn::dir::forward, output, mask),
                              applyRnn(i, rnn::dir::backward, output, mask)},
                              /*axis =*/ -1);
    }

    return output;
  }

  virtual Ptr<EncoderState> build(Ptr<ExpressionGraph> graph,
                                  Ptr<data::CorpusBatch> batch) override {
    graph_ = graph;
    // select embeddings that occur in the batch
    Expr batchEmbeddings, batchMask;
    std::tie(batchEmbeddings, batchMask) = getEmbeddingLayer()->apply((*batch)[batchIndex_]);

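    // run the stacked bidirectional RNN over the embedded batch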
    Expr context = applyEncoderRNN(graph_, batchEmbeddings, batchMask);

    return New<EncoderState>(context, batchMask, batch);
  }

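  // no per-batch state to reset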
  void clear() override {}
};

}  // namespace marian