Program Listing for File amun.h

Return to documentation for file (src/models/amun.h)

#pragma once

#include "marian.h"

#include "models/s2s.h"

namespace marian {

class Amun : public EncoderDecoder {
public:
  Amun(Ptr<ExpressionGraph> graph, Ptr<Options> options) : EncoderDecoder(graph, options) {
    ABORT_IF(opt<int>("enc-depth") > 1,
             "--type amun does not support multiple encoder "
             "layers, use --type s2s");
    ABORT_IF(opt<int>("enc-cell-depth") > 1,
             "--type amun does not support stacked encoder "
             "cells, use --type s2s");
    ABORT_IF(opt<bool>("skip"),
             "--type amun does not support skip connections, "
             "use --type s2s");
    ABORT_IF(opt<int>("dec-depth") > 1,
             "--type amun does not support multiple decoder "
             "layers, use --type s2s");
    ABORT_IF(opt<int>("dec-cell-base-depth") != 2,
             "--type amun does not support multiple decoder "
             "base cells, use --type s2s");
    ABORT_IF(opt<int>("dec-cell-high-depth") > 1,
             "--type amun does not support multiple decoder "
             "high cells, use --type s2s");
    ABORT_IF(opt<std::string>("enc-cell") != "gru",
             "--type amun does not support other rnn cells than gru, "
             "use --type s2s");
    ABORT_IF(opt<std::string>("dec-cell") != "gru",
             "--type amun does not support other rnn cells than gru, "
             "use --type s2s");
  }

  void load(Ptr<ExpressionGraph> graph,
            const std::vector<io::Item>& items,
            bool /*markedReloaded*/ = true) override {
    std::map<std::string, std::string> nameMap
        = {{"decoder_U", "decoder_cell1_U"},
           {"decoder_Ux", "decoder_cell1_Ux"},
           {"decoder_W", "decoder_cell1_W"},
           {"decoder_Wx", "decoder_cell1_Wx"},
           {"decoder_b", "decoder_cell1_b"},
           {"decoder_bx", "decoder_cell1_bx"},
           {"decoder_cell1_gamma1", "decoder_cell1_gamma1"},
           {"decoder_cell1_gamma2", "decoder_cell1_gamma2"},
           {"decoder_U_nl", "decoder_cell2_U"},
           {"decoder_Ux_nl", "decoder_cell2_Ux"},
           {"decoder_Wc", "decoder_cell2_W"},
           {"decoder_Wcx", "decoder_cell2_Wx"},
           {"decoder_b_nl", "decoder_cell2_b"},
           {"decoder_bx_nl", "decoder_cell2_bx"},
           {"ff_logit_prev_W", "decoder_ff_logit_l1_W0"},
           {"ff_logit_lstm_W", "decoder_ff_logit_l1_W1"},
           {"ff_logit_ctx_W", "decoder_ff_logit_l1_W2"},
           {"ff_logit_prev_b", "decoder_ff_logit_l1_b0"},
           {"ff_logit_lstm_b", "decoder_ff_logit_l1_b1"},
           {"ff_logit_ctx_b", "decoder_ff_logit_l1_b2"},
           {"ff_logit_l1_gamma0", "decoder_ff_logit_l1_gamma0"},
           {"ff_logit_l1_gamma1", "decoder_ff_logit_l1_gamma1"},
           {"ff_logit_l1_gamma2", "decoder_ff_logit_l1_gamma2"},
           {"ff_logit_W", "decoder_ff_logit_l2_W"},
           {"ff_logit_b", "decoder_ff_logit_l2_b"},
           {"ff_state_W", "decoder_ff_state_W"},
           {"ff_state_b", "decoder_ff_state_b"},
           {"ff_state_gamma", "decoder_ff_state_gamma"},
           {"Wemb_dec", "decoder_Wemb"},
           {"Wemb", "encoder_Wemb"},
           {"encoder_U", "encoder_bi_U"},
           {"encoder_Ux", "encoder_bi_Ux"},
           {"encoder_W", "encoder_bi_W"},
           {"encoder_Wx", "encoder_bi_Wx"},
           {"encoder_b", "encoder_bi_b"},
           {"encoder_bx", "encoder_bi_bx"},
           {"encoder_gamma1", "encoder_bi_gamma1"},
           {"encoder_gamma2", "encoder_bi_gamma2"},
           {"encoder_r_U", "encoder_bi_r_U"},
           {"encoder_r_Ux", "encoder_bi_r_Ux"},
           {"encoder_r_W", "encoder_bi_r_W"},
           {"encoder_r_Wx", "encoder_bi_r_Wx"},
           {"encoder_r_b", "encoder_bi_r_b"},
           {"encoder_r_bx", "encoder_bi_r_bx"},
           {"encoder_r_gamma1", "encoder_bi_r_gamma1"},
           {"encoder_r_gamma2", "encoder_bi_r_gamma2"}};

    if(opt<bool>("tied-embeddings-src") || opt<bool>("tied-embeddings-all"))
      nameMap["Wemb"] = "Wemb";

    auto ioItems = items;
    // map names and remove a dummy matrices
    for(auto it = ioItems.begin(); it != ioItems.end();) {
      // for backwards compatibility, turn one-dimensional vector into two dimensional matrix with first dimension being 1 and second dimension of the original size
      // @TODO: consider dropping support for Nematus models
      if(it->shape.size() == 1) {
        int dim = it->shape[-1];
        it->shape.resize(2);
        it->shape.set(0, 1);
        it->shape.set(1, dim);
      }

      if(it->name == "decoder_c_tt") {
        it = ioItems.erase(it);
      } else if(it->name == "uidx") {
        it = ioItems.erase(it);
      } else if(it->name == "history_errs") {
        it = ioItems.erase(it);
      } else {
        auto pair = nameMap.find(it->name);
        if(pair != nameMap.end())
          it->name = pair->second;
        it++;
      }
    }
    // load items into the graph
    graph->load(ioItems);
  }

  void load(Ptr<ExpressionGraph> graph,
            const std::string& name,
            bool /*markReloaded*/ = true) override {
    LOG(info, "Loading model from {}", name);
    auto ioItems = io::loadItems(name);
    load(graph, ioItems);
  }

  void save(Ptr<ExpressionGraph> graph,
            const std::string& name,
            bool saveTranslatorConfig = false) override {
    LOG(info, "Saving model to {}", name);

    std::map<std::string, std::string> nameMap
        = {{"decoder_cell1_U", "decoder_U"},
           {"decoder_cell1_Ux", "decoder_Ux"},
           {"decoder_cell1_W", "decoder_W"},
           {"decoder_cell1_Wx", "decoder_Wx"},
           {"decoder_cell1_b", "decoder_b"},
           {"decoder_cell1_bx", "decoder_bx"},
           {"decoder_cell2_U", "decoder_U_nl"},
           {"decoder_cell2_Ux", "decoder_Ux_nl"},
           {"decoder_cell2_W", "decoder_Wc"},
           {"decoder_cell2_Wx", "decoder_Wcx"},
           {"decoder_cell2_b", "decoder_b_nl"},
           {"decoder_cell2_bx", "decoder_bx_nl"},
           {"decoder_ff_logit_l1_W0", "ff_logit_prev_W"},
           {"decoder_ff_logit_l1_W1", "ff_logit_lstm_W"},
           {"decoder_ff_logit_l1_W2", "ff_logit_ctx_W"},
           {"decoder_ff_logit_l1_b0", "ff_logit_prev_b"},
           {"decoder_ff_logit_l1_b1", "ff_logit_lstm_b"},
           {"decoder_ff_logit_l1_b2", "ff_logit_ctx_b"},
           {"decoder_ff_logit_l1_gamma0", "ff_logit_l1_gamma0"},
           {"decoder_ff_logit_l1_gamma1", "ff_logit_l1_gamma1"},
           {"decoder_ff_logit_l1_gamma2", "ff_logit_l1_gamma2"},
           {"decoder_ff_logit_l2_W", "ff_logit_W"},
           {"decoder_ff_logit_l2_b", "ff_logit_b"},
           {"decoder_ff_state_W", "ff_state_W"},
           {"decoder_ff_state_b", "ff_state_b"},
           {"decoder_ff_state_gamma", "ff_state_gamma"},
           {"decoder_Wemb", "Wemb_dec"},
           {"encoder_Wemb", "Wemb"},
           {"encoder_bi_U", "encoder_U"},
           {"encoder_bi_Ux", "encoder_Ux"},
           {"encoder_bi_W", "encoder_W"},
           {"encoder_bi_Wx", "encoder_Wx"},
           {"encoder_bi_b", "encoder_b"},
           {"encoder_bi_bx", "encoder_bx"},
           {"encoder_bi_gamma1", "encoder_gamma1"},
           {"encoder_bi_gamma2", "encoder_gamma2"},
           {"encoder_bi_r_U", "encoder_r_U"},
           {"encoder_bi_r_Ux", "encoder_r_Ux"},
           {"encoder_bi_r_W", "encoder_r_W"},
           {"encoder_bi_r_Wx", "encoder_r_Wx"},
           {"encoder_bi_r_b", "encoder_r_b"},
           {"encoder_bi_r_bx", "encoder_r_bx"},
           {"encoder_bi_r_gamma1", "encoder_r_gamma1"},
           {"encoder_bi_r_gamma2", "encoder_r_gamma2"}};

    // get parameters from the graph to items
    std::vector<io::Item> ioItems;
    graph->save(ioItems);
    // replace names to be compatible with Nematus
    for(auto& item : ioItems) {
      auto newItemName = nameMap.find(item.name);
      if(newItemName != nameMap.end())
        item.name = newItemName->second;
    }
    // add a dummy matrix 'decoder_c_tt' required for Amun and Nematus
    ioItems.emplace_back();
    ioItems.back().name = "decoder_c_tt";
    ioItems.back().shape = Shape({1, 0});
    ioItems.back().bytes.emplace_back((char)0);

    io::addMetaToItems(getModelParametersAsString(), "special:model.yml", ioItems);
    io::saveItems(name, ioItems);

    if(saveTranslatorConfig) {
      createAmunConfig(name);
      createDecoderConfig(name);
    }
  }

private:
  void createAmunConfig(const std::string& name) {
    Config::YamlNode amun;
    auto vocabs = options_->get<std::vector<std::string>>("vocabs");

    if(options_->get<bool>("relative-paths")) {
      amun["relative-paths"] = true;
      auto dirPath = filesystem::Path{name}.parentPath();
      amun["source-vocab"] = filesystem::relative(filesystem::Path{vocabs[0]}, dirPath).string();
      amun["target-vocab"] = filesystem::relative(filesystem::Path{vocabs[1]}, dirPath).string();
      amun["scorers"]["F0"]["path"] = filesystem::Path{name}.filename().string();
    } else {
      amun["relative-paths"] = false;
      amun["source-vocab"] = vocabs[0];
      amun["target-vocab"] = vocabs[1];
      amun["scorers"]["F0"]["path"] = name;
    }

    amun["scorers"]["F0"]["type"] = "Nematus";
    amun["weights"]["F0"] = 1.0f;
    amun["normalize"] = opt<float>("normalize") > 0;
    amun["beam-size"] = opt<size_t>("beam-size");

    io::OutputFileStream out(name + ".amun.yml");
    out << amun;
  }
};
}  // namespace marian