Program Listing for File encoder_decoder.cpp

Return to documentation for file (src/models/encoder_decoder.cpp)

#include "models/encoder_decoder.h"
#include "common/cli_helper.h"
#include "common/filesystem.h"
#include "common/version.h"

namespace marian {

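// The constructor registers the option keys that describe the model
// architecture. Only these keys are copied into the YAML block that
// getModelParameters() stores alongside the saved model weights.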
EncoderDecoder::EncoderDecoder(Ptr<ExpressionGraph> graph, Ptr<Options> options)
    : LayerBase(graph, options),
      prefix_(options->get<std::string>("prefix", "")),
      inference_(options->get<bool>("inference", false)) {

  std::vector<std::string> encoderDecoderModelFeatures =
    {"type",
     "dim-vocabs",
     "dim-emb",
     "dim-rnn",
     "enc-cell",
     "enc-type",
     "enc-cell-depth",
     "enc-depth",
     "dec-depth",
     "dec-cell",
     "dec-cell-base-depth",
     "dec-cell-high-depth",
     "skip",
     "layer-normalization",
     "right-left",
     "input-types",
     "special-vocab",
     "tied-embeddings",
     "tied-embeddings-src",
     "tied-embeddings-all"};

  for(const auto& feature : encoderDecoderModelFeatures)
    modelFeatures_.insert(feature);

  modelFeatures_.insert("transformer-heads");
  modelFeatures_.insert("transformer-no-projection");
  modelFeatures_.insert("transformer-dim-ffn");
  modelFeatures_.insert("transformer-decoder-dim-ffn");
  modelFeatures_.insert("transformer-ffn-depth");
  modelFeatures_.insert("transformer-decoder-ffn-depth");
  modelFeatures_.insert("transformer-ffn-activation");
  modelFeatures_.insert("transformer-dim-aan");
  modelFeatures_.insert("transformer-aan-depth");
  modelFeatures_.insert("transformer-aan-activation");
  modelFeatures_.insert("transformer-aan-nogate");
  modelFeatures_.insert("transformer-preprocess");
  modelFeatures_.insert("transformer-postprocess");
  modelFeatures_.insert("transformer-postprocess-emb");
  modelFeatures_.insert("transformer-postprocess-top");
  modelFeatures_.insert("transformer-decoder-autoreg");
  modelFeatures_.insert("transformer-tied-layers");
  modelFeatures_.insert("transformer-guided-alignment-layer");
  modelFeatures_.insert("transformer-train-position-embeddings");
  modelFeatures_.insert("transformer-pool");

  modelFeatures_.insert("bert-train-type-embeddings");
  modelFeatures_.insert("bert-type-vocab-size");

  modelFeatures_.insert("ulr");
  modelFeatures_.insert("ulr-trainable-transformation");
  modelFeatures_.insert("ulr-dim-emb");
  modelFeatures_.insert("lemma-dim-emb");
  modelFeatures_.insert("output-omit-bias");
  modelFeatures_.insert("lemma-dependency");
  modelFeatures_.insert("factors-combine");
  modelFeatures_.insert("factors-dim-emb");
}

std::vector<Ptr<EncoderBase>>& EncoderDecoder::getEncoders() {
  return encoders_;
}

void EncoderDecoder::push_back(Ptr<EncoderBase> encoder) {
  encoders_.push_back(encoder);
}

std::vector<Ptr<DecoderBase>>& EncoderDecoder::getDecoders() {
  return decoders_;
}

void EncoderDecoder::push_back(Ptr<DecoderBase> decoder) {
  decoders_.push_back(decoder);
}

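// Writes a translation ("decoder") configuration next to the saved model as
// "<name>.decoder.yml". When the "relative-paths" option is set, the model and
// vocabulary paths are rewritten relative to the checkpoint directory.
//
// Illustrative sketch of the emitted YAML (the keys mirror the assignments
// below; the concrete values are assumptions for the example):
//
//   relative-paths: true
//   models:
//     - model.npz
//   vocabs:
//     - vocab.src.spm
//     - vocab.trg.spm
//   beam-size: 6
//   normalize: 0.6
//   word-penalty: 0
//   mini-batch: 16
//   maxi-batch: 100
//   maxi-batch-sort: src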
void EncoderDecoder::createDecoderConfig(const std::string& name) {
  Config::YamlNode decoder;

  if(options_->get<bool>("relative-paths")) {
    decoder["relative-paths"] = true;
    // we can safely use a bare model file name here, because the config file is created in the same
    // directory as the model file
    auto modelFileName = filesystem::Path{name}.filename().string();
    decoder["models"] = std::vector<std::string>({modelFileName});

    // create relative paths to vocabs with regard to saved model checkpoint
    auto dirPath = filesystem::Path{name}.parentPath();
    std::vector<std::string> relativeVocabs;
    const auto& vocabs = options_->get<std::vector<std::string>>("vocabs");
    std::transform(
        vocabs.begin(),
        vocabs.end(),
        std::back_inserter(relativeVocabs),
        [&](const std::string& p) -> std::string {
          return filesystem::relative(filesystem::Path{p}, dirPath).string();
        });

    decoder["vocabs"] = relativeVocabs;
  } else {
    decoder["relative-paths"] = false;
    decoder["models"] = std::vector<std::string>({name});
    decoder["vocabs"] = options_->get<std::vector<std::string>>("vocabs");
  }

  decoder["beam-size"] = opt<size_t>("beam-size");
  decoder["normalize"] = opt<float>("normalize");
  decoder["word-penalty"] = opt<float>("word-penalty");

  decoder["mini-batch"] = opt<size_t>("valid-mini-batch");
  decoder["maxi-batch"] = opt<size_t>("valid-mini-batch") > 1 ? 100 : 1;
  decoder["maxi-batch-sort"] = opt<size_t>("valid-mini-batch") > 1 ? "src" : "none";

  io::OutputFileStream out(name + ".decoder.yml");
  out << decoder;
}

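// Collects the whitelisted architecture options (modelFeatures_) into a YAML
// node. An "original-type" option, if present, overrides "type", and the
// build version is recorded under "version".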
Config::YamlNode EncoderDecoder::getModelParameters() {
  Config::YamlNode modelParams;
  auto clone = options_->cloneToYamlNode();
  for(auto& key : modelFeatures_)
    modelParams[key] = clone[key];

  if(options_->has("original-type"))
    modelParams["type"] = clone["original-type"];

  modelParams["version"] = buildVersion();
  return modelParams;
}

std::string EncoderDecoder::getModelParametersAsString() {
  auto yaml = getModelParameters();
  YAML::Emitter out;
  cli::OutputYaml(yaml, out);
  return std::string(out.c_str());
}

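// The load()/mmap() overloads delegate to the expression graph; the
// markedReloaded flag is forced to false when "ignore-model-config" is set.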
void EncoderDecoder::load(Ptr<ExpressionGraph> graph,
                          const std::vector<io::Item>& items,
                          bool markedReloaded) {
  graph->load(items, markedReloaded && !opt<bool>("ignore-model-config", false));
}

void EncoderDecoder::load(Ptr<ExpressionGraph> graph,
                          const std::string& name,
                          bool markedReloaded) {
  graph->load(name, markedReloaded && !opt<bool>("ignore-model-config", false));
}

void EncoderDecoder::mmap(Ptr<ExpressionGraph> graph,
                          const void* ptr,
                          bool markedReloaded) {
  graph->mmap(ptr, markedReloaded && !opt<bool>("ignore-model-config", false));
}

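// Saves the graph parameters together with the model-parameter YAML string;
// if saveTranslatorConfig is set, a matching "<name>.decoder.yml" is written
// via createDecoderConfig().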
void EncoderDecoder::save(Ptr<ExpressionGraph> graph,
                          const std::string& name,
                          bool saveTranslatorConfig) {
  // ignore config for now
  LOG(info, "Saving model weights and runtime parameters to {}", name);

  graph->save(name, getModelParametersAsString());

  if(saveTranslatorConfig)
    createDecoderConfig(name);
}

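// Resets the expression graph and the internal state of all encoders and
// decoders.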
void EncoderDecoder::clear(Ptr<ExpressionGraph> graph) {
  graph->clear();

  for(auto& enc : encoders_)
    enc->clear();
  for(auto& dec : decoders_)
    dec->clear();
}

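// Runs every encoder on the batch, optionally generates a lexical shortlist
// for the first decoder, and returns the initial decoder state.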
Ptr<DecoderState> EncoderDecoder::startState(Ptr<ExpressionGraph> graph,
                                             Ptr<data::CorpusBatch> batch) {
  std::vector<Ptr<EncoderState>> encoderStates;
  for(auto& encoder : encoders_)
    encoderStates.push_back(encoder->build(graph, batch));

  // initialize shortlist here
  if(shortlistGenerator_) {
    auto shortlist = shortlistGenerator_->generate(batch);
    decoders_[0]->setShortlist(shortlist);
  }

  return decoders_[0]->startState(graph, batch, encoderStates);
}

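// Performs a single decoding step: reorders/prunes the state according to
// hypIndices and batchIndices, embeds the last predicted words, and advances
// the first decoder by one step.
//
// Illustrative sketch of how startState()/step() are typically driven by a
// search loop (encdec, graph, batch, maxLength and the bookkeeping of
// hypIndices/words/batchIndices are assumptions for the example):
//
//   auto state = encdec->startState(graph, batch);
//   for(int t = 0; t < maxLength; ++t) {
//     state = encdec->step(graph, state, hypIndices, words, batchIndices, beamSize);
//     auto logits = state->getLogProbs();
//     // ... select next words, update hypIndices/words/batchIndices ...
//   }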
Ptr<DecoderState> EncoderDecoder::step(Ptr<ExpressionGraph> graph,
                                       Ptr<DecoderState> state,
                                       const std::vector<IndexType>& hypIndices,   // [beamIndex * activeBatchSize + batchIndex]
                                       const Words& words,                         // [beamIndex * activeBatchSize + batchIndex]
                                       const std::vector<IndexType>& batchIndices, // [batchIndex]
                                       int beamSize) {
  // create updated state that reflects reordering and dropping of hypotheses
  state = hypIndices.empty() ? state : state->select(hypIndices, batchIndices, beamSize);

  // Fill state with embeddings based on last prediction
  decoders_[0]->embeddingsFromPrediction(graph, state, words, (int) batchIndices.size(), beamSize);
  auto nextState = decoders_[0]->step(graph, state);

  return nextState;
}

Ptr<DecoderState> EncoderDecoder::stepAll(Ptr<ExpressionGraph> graph,
                                          Ptr<data::CorpusBatch> batch,
                                          bool clearGraph) {
  if(clearGraph)
    clear(graph);

  // Required first step, also initializes shortlist
  auto state = startState(graph, batch);

  // Fill state with embeddings from batch (ground truth)
  decoders_[0]->embeddingsFromBatch(graph, state, batch);
  auto nextState = decoders_[0]->step(graph, state);
  nextState->setTargetMask(state->getTargetMask());
  nextState->setTargetWords(state->getTargetWords());

  return nextState;
}

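// Full teacher-forced forward pass over a batch: runs stepAll() and returns
// the decoder logits.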
Logits EncoderDecoder::build(Ptr<ExpressionGraph> graph,
                             Ptr<data::CorpusBatch> batch,
                             bool clearGraph) {
  auto state = stepAll(graph, batch, clearGraph);

  // returns raw logits
  return state->getLogProbs();
}

Logits EncoderDecoder::build(Ptr<ExpressionGraph> graph,
                             Ptr<data::Batch> batch,
                             bool clearGraph) {
  auto corpusBatch = std::static_pointer_cast<data::CorpusBatch>(batch);
  return build(graph, corpusBatch, clearGraph);
}

}  // namespace marian