Program Listing for File nematus.h


#pragma once

#include "marian.h"

#include "models/s2s.h"

namespace marian {

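// Wrapper around EncoderDecoder that reads and writes models using the
// parameter-naming scheme of Nematus and the Amun decoder, translating names
// to and from Marian's own s2s naming convention.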
class Nematus : public EncoderDecoder {
public:
  Nematus(Ptr<ExpressionGraph> graph, Ptr<Options> options) : EncoderDecoder(graph, options), nameMap_(createNameMap()) {
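    // Nematus compatibility restricts the architecture: the encoder must be
    // bidirectional, encoder and decoder must use gru-nematus cells, and deep
    // transition decoder cells beyond depth 1 are not supported.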
    ABORT_IF(options_->get<std::string>("enc-type") != "bidirectional",
             "--type nematus does not support other encoder "
             "type than bidirectional, use --type s2s");

    ABORT_IF(options_->get<std::string>("enc-cell") != "gru-nematus",
             "--type nematus does not support other rnn cells "
             "than gru-nematus, use --type s2s");
    ABORT_IF(options_->get<std::string>("dec-cell") != "gru-nematus",
             "--type nematus does not support other rnn cells "
             "than gru-nematus, use --type s2s");

    ABORT_IF(options_->get<int>("dec-cell-high-depth") > 1,
             "--type nematus does not currently support "
             "--dec-cell-high-depth > 1, use --type s2s");
  }

  void load(Ptr<ExpressionGraph> graph,
            const std::vector<io::Item>& items,
            bool /*markReloaded*/ = true) override {
    auto ioItems = items;
    // map parameter names and remove the dummy matrix 'decoder_c_tt' from items
    // to avoid creating an isolated node
    for(auto it = ioItems.begin(); it != ioItems.end();) {
      // for backwards compatibility, turn a one-dimensional vector into a
      // two-dimensional matrix with first dimension 1 and the original size as
      // the second dimension
      // @TODO: consider dropping support for Nematus models
      if(it->shape.size() == 1) {
        int dim = it->shape[-1];
        it->shape.resize(2);
        it->shape.set(0, 1);
        it->shape.set(1, dim);
      }

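      // 'decoder_c_tt' is a dummy matrix, while 'uidx' (update counter) and
      // 'history_errs' (validation error history) are Nematus training
      // bookkeeping; none correspond to Marian parameters, so drop them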
      if(it->name == "decoder_c_tt") {
        it = ioItems.erase(it);
      } else if(it->name == "uidx") {
        it = ioItems.erase(it);
      } else if(it->name == "history_errs") {
        it = ioItems.erase(it);
      } else {
        auto pair = nameMap_.find(it->name);
        if(pair != nameMap_.end())
          it->name = pair->second;
        it++;
      }
    }
    // load items into the graph
    graph->load(ioItems);
  }

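  // convenience overload: read items from a model file and delegate to the
  // item-based load() above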
  void load(Ptr<ExpressionGraph> graph,
            const std::string& name,
            bool /*markReloaded*/ = true) override {
    LOG(info, "Loading model from {}", name);
    auto ioItems = io::loadItems(name);
    load(graph, ioItems);
  }

  void save(Ptr<ExpressionGraph> graph,
            const std::string& name,
            bool saveTranslatorConfig = false) override {
    LOG(info, "Saving model to {}", name);

    // prepare reversed map
    if(nameMapRev_.empty())
      for(const auto& kv : nameMap_)
        nameMapRev_.insert({kv.second, kv.first});

    // copy the graph's parameters into items
    std::vector<io::Item> ioItems;
    graph->save(ioItems);
    // replace names to be compatible with Nematus
    for(auto& item : ioItems) {
      auto newItemName = nameMapRev_.find(item.name);
      if(newItemName != nameMapRev_.end())
        item.name = newItemName->second;
    }
    // add a dummy matrix 'decoder_c_tt' required for Amun and Nematus
    ioItems.emplace_back();
    ioItems.back().name = "decoder_c_tt";
    ioItems.back().shape = Shape({1, 0});
    ioItems.back().bytes.emplace_back((char)0);

    io::addMetaToItems(getModelParametersAsString(), "special:model.yml", ioItems);
    io::saveItems(name, ioItems);

    if(saveTranslatorConfig) {
      createAmunConfig(name);
      createDecoderConfig(name);
    }
  }

private:
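  // parameter-name translation tables: Nematus -> Marian (nameMap_) and
  // Marian -> Nematus (nameMapRev_, built lazily in save())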
  std::map<std::string, std::string> nameMap_;
  std::map<std::string, std::string> nameMapRev_;

  std::map<std::string, std::string> createNameMap() {
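    // keys are Nematus parameter names, values are the corresponding Marian
    // parameter names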
    std::map<std::string, std::string> nameMap
        = {{"decoder_U", "decoder_cell1_U"},
           {"decoder_Ux", "decoder_cell1_Ux"},
           {"decoder_W", "decoder_cell1_W"},
           {"decoder_Wx", "decoder_cell1_Wx"},
           {"decoder_b", "decoder_cell1_b"},
           {"decoder_bx", "decoder_cell1_bx"},
           {"decoder_U_nl", "decoder_cell2_U"},
           {"decoder_Ux_nl", "decoder_cell2_Ux"},
           {"decoder_Wc", "decoder_cell2_W"},
           {"decoder_Wcx", "decoder_cell2_Wx"},
           {"decoder_b_nl", "decoder_cell2_b"},
           {"decoder_bx_nl", "decoder_cell2_bx"},
           {"ff_logit_prev_W", "decoder_ff_logit_l1_W0"},
           {"ff_logit_lstm_W", "decoder_ff_logit_l1_W1"},
           {"ff_logit_ctx_W", "decoder_ff_logit_l1_W2"},
           {"ff_logit_prev_b", "decoder_ff_logit_l1_b0"},
           {"ff_logit_lstm_b", "decoder_ff_logit_l1_b1"},
           {"ff_logit_ctx_b", "decoder_ff_logit_l1_b2"},
           {"ff_logit_W", "decoder_ff_logit_l2_W"},
           {"ff_logit_b", "decoder_ff_logit_l2_b"},
           {"ff_state_W", "decoder_ff_state_W"},
           {"ff_state_b", "decoder_ff_state_b"},
           {"Wemb_dec", "decoder_Wemb"},
           {"Wemb", "encoder_Wemb"},
           {"encoder_U", "encoder_bi_U"},
           {"encoder_Ux", "encoder_bi_Ux"},
           {"encoder_W", "encoder_bi_W"},
           {"encoder_Wx", "encoder_bi_Wx"},
           {"encoder_b", "encoder_bi_b"},
           {"encoder_bx", "encoder_bi_bx"},
           {"encoder_r_U", "encoder_bi_r_U"},
           {"encoder_r_Ux", "encoder_bi_r_Ux"},
           {"encoder_r_W", "encoder_bi_r_W"},
           {"encoder_r_Wx", "encoder_bi_r_Wx"},
           {"encoder_r_b", "encoder_bi_r_b"},
           {"encoder_r_bx", "encoder_bi_r_bx"},
           {"ff_state_ln_s", "decoder_ff_state_ln_s"},
           {"ff_state_ln_b", "decoder_ff_state_ln_b"},
           {"ff_logit_prev_ln_s", "decoder_ff_logit_l1_ln_s0"},
           {"ff_logit_lstm_ln_s", "decoder_ff_logit_l1_ln_s1"},
           {"ff_logit_ctx_ln_s", "decoder_ff_logit_l1_ln_s2"},
           {"ff_logit_prev_ln_b", "decoder_ff_logit_l1_ln_b0"},
           {"ff_logit_lstm_ln_b", "decoder_ff_logit_l1_ln_b1"},
           {"ff_logit_ctx_ln_b", "decoder_ff_logit_l1_ln_b2"}};

    // add mapping for deep encoder cells
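    // e.g. with --enc-cell-depth 2: "encoder_U_drt_1" -> "encoder_bi_cell2_U"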
    std::vector<std::string> suffixes = {"_U", "_Ux", "_b", "_bx"};
    for(int i = 1; i < options_->get<int>("enc-cell-depth"); ++i) {
      std::string num1 = std::to_string(i);
      std::string num2 = std::to_string(i + 1);
      for(auto suf : suffixes) {
        nameMap.insert({"encoder" + suf + "_drt_" + num1, "encoder_bi_cell" + num2 + suf});
        nameMap.insert({"encoder_r" + suf + "_drt_" + num1, "encoder_bi_r_cell" + num2 + suf});
      }
    }
    // add mapping for deep decoder cells
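    // e.g. with --dec-cell-base-depth 3: "decoder_U_nl_drt_1" -> "decoder_cell3_U"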
    for(int i = 3; i <= options_->get<int>("dec-cell-base-depth"); ++i) {
      std::string num1 = std::to_string(i - 2);
      std::string num2 = std::to_string(i);
      for(auto suf : suffixes)
        nameMap.insert({"decoder" + suf + "_nl_drt_" + num1, "decoder_cell" + num2 + suf});
    }
    // add mapping for normalization layers
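    // e.g. "decoder_U_lns" -> "decoder_cell1_U_lns" (and likewise for "_lnb")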
    std::map<std::string, std::string> nameMapCopy(nameMap);
    for(auto& kv : nameMapCopy) {
      std::string prefix = kv.first.substr(0, 7);

      if(prefix == "encoder" || prefix == "decoder") {
        nameMap.insert({kv.first + "_lns", kv.second + "_lns"});
        nameMap.insert({kv.first + "_lnb", kv.second + "_lnb"});
      }
    }

    return nameMap;
  }

  void createAmunConfig(const std::string& name) {
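    // write a minimal Amun decoder configuration (<name>.amun.yml) that points
    // at the saved model file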
    Config::YamlNode amun;
    // Amun provides only a CPU decoder for deep Nematus models
    amun["cpu-threads"] = 16;
    amun["gpu-threads"] = 0;
    amun["maxi-batch"] = 1;
    amun["mini-batch"] = 1;

    auto vocabs = options_->get<std::vector<std::string>>("vocabs");
    amun["source-vocab"] = vocabs[0];
    amun["target-vocab"] = vocabs[1];
    amun["devices"] = options_->get<std::vector<size_t>>("devices");
    amun["normalize"] = true;
    amun["beam-size"] = 5;
    amun["relative-paths"] = false;

    amun["scorers"]["F0"]["path"] = name;
    amun["scorers"]["F0"]["type"] = "nematus2";
    amun["weights"]["F0"] = 1.0f;

    io::OutputFileStream out(name + ".amun.yml");
    out << amun;
  }
};
}  // namespace marian