Program Listing for File char_s2s.h

Return to documentation for file (src/models/char_s2s.h)

#pragma once

#include "marian.h"

#include "layers/convolution.h"
#include "models/s2s.h"

namespace marian {

class CharS2SEncoder : public EncoderS2S {
  using EncoderS2S::EncoderS2S;

public:
  virtual Ptr<EncoderState> build(Ptr<ExpressionGraph> graph,
                                  Ptr<data::CorpusBatch> batch) override {
    graph_ = graph;
    // select embeddings that occur in the batch
    Expr batchEmbeddings, batchMask; std::tie
    (batchEmbeddings, batchMask) = getEmbeddingLayer()->apply(batch->front());

    int dimEmb = opt<int>("dim-emb");
    auto convSizes = options_->get<std::vector<int>>("char-conv-filters-num");
    auto convWidths
        = options_->get<std::vector<int>>("char-conv-filters-widths");
    int stride = opt<int>("char-stride");
    int highwayNum = opt<int>("char-highway");

    auto conved = CharConvPooling(
        prefix_ + "conv_pooling", dimEmb, convWidths, convSizes, stride)(
        batchEmbeddings, batchMask);

    auto inHighway = conved;
    for(int i = 0; i < highwayNum; ++i) {
      inHighway = highway(prefix_ + "_" + std::to_string(i), inHighway);
    }

    Expr stridedMask = getStridedMask(graph, batch, stride);
    Expr context = applyEncoderRNN(
        graph, inHighway, stridedMask, opt<std::string>("enc-type"));

    return New<EncoderState>(context, stridedMask, batch);
  }

protected:
  Expr getStridedMask(Ptr<ExpressionGraph> graph,
                      Ptr<data::CorpusBatch> batch,
                      int stride) {
    auto subBatch = (*batch)[batchIndex_];

    size_t dimBatch = subBatch->batchSize();

    std::vector<float> strided;
    for(size_t wordIdx = 0; wordIdx < subBatch->mask().size();
        wordIdx += stride * dimBatch) {
      for(size_t j = wordIdx; j < wordIdx + dimBatch; ++j) {
        strided.push_back(subBatch->mask()[j]);
      }
    }
    size_t dimWords = strided.size() / dimBatch;
    auto stridedMask
        = graph->constant({(int)dimWords, (int)dimBatch, 1}, inits::fromVector(strided));
    return stridedMask;
  }
};
}  // namespace marian