.. _program_listing_file_src_models_amun.h:

Program Listing for File amun.h
===============================

|exhale_lsh| :ref:`Return to documentation for file <file_src_models_amun.h>` (``src/models/amun.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   #pragma once

   #include "marian.h"

   #include "models/s2s.h"

   namespace marian {

   // Amun-compatible shallow RNN encoder-decoder: a restricted s2s model whose
   // parameters can be saved/loaded under the legacy Amun/Nematus naming scheme.
   class Amun : public EncoderDecoder {
   public:
     // Validates that the configured architecture fits Amun's fixed topology
     // (single-layer GRU encoder/decoder, no skip connections); aborts otherwise.
     Amun(Ptr<ExpressionGraph> graph, Ptr<Options> options)
         : EncoderDecoder(graph, options) {
       ABORT_IF(opt<int>("enc-depth") > 1,
                "--type amun does not support multiple encoder "
                "layers, use --type s2s");
       ABORT_IF(opt<int>("enc-cell-depth") > 1,
                "--type amun does not support stacked encoder "
                "cells, use --type s2s");
       ABORT_IF(opt<bool>("skip"),
                "--type amun does not support skip connections, "
                "use --type s2s");
       ABORT_IF(opt<int>("dec-depth") > 1,
                "--type amun does not support multiple decoder "
                "layers, use --type s2s");
       ABORT_IF(opt<int>("dec-cell-base-depth") != 2,
                "--type amun does not support multiple decoder "
                "base cells, use --type s2s");
       ABORT_IF(opt<int>("dec-cell-high-depth") > 1,
                "--type amun does not support multiple decoder "
                "high cells, use --type s2s");
       ABORT_IF(opt<std::string>("enc-cell") != "gru",
                "--type amun does not support other rnn cells than gru, "
                "use --type s2s");
       ABORT_IF(opt<std::string>("dec-cell") != "gru",
                "--type amun does not support other rnn cells than gru, "
                "use --type s2s");
     }

     // Loads Nematus/Amun-named parameters into the graph, translating names to
     // marian's convention and dropping bookkeeping matrices Amun stores.
     void load(Ptr<ExpressionGraph> graph,
               const std::vector<io::Item>& items,
               bool /*markedReloaded*/ = true) override {
       // Nematus/Amun parameter name -> marian parameter name
       std::map<std::string, std::string> nameMap
           = {{"decoder_U", "decoder_cell1_U"},
              {"decoder_Ux", "decoder_cell1_Ux"},
              {"decoder_W", "decoder_cell1_W"},
              {"decoder_Wx", "decoder_cell1_Wx"},
              {"decoder_b", "decoder_cell1_b"},
              {"decoder_bx", "decoder_cell1_bx"},
              {"decoder_cell1_gamma1", "decoder_cell1_gamma1"},
              {"decoder_cell1_gamma2", "decoder_cell1_gamma2"},
              {"decoder_U_nl", "decoder_cell2_U"},
              {"decoder_Ux_nl", "decoder_cell2_Ux"},
              {"decoder_Wc", "decoder_cell2_W"},
              {"decoder_Wcx", "decoder_cell2_Wx"},
              {"decoder_b_nl", "decoder_cell2_b"},
              {"decoder_bx_nl", "decoder_cell2_bx"},
              {"ff_logit_prev_W", "decoder_ff_logit_l1_W0"},
              {"ff_logit_lstm_W", "decoder_ff_logit_l1_W1"},
              {"ff_logit_ctx_W", "decoder_ff_logit_l1_W2"},
              {"ff_logit_prev_b", "decoder_ff_logit_l1_b0"},
              {"ff_logit_lstm_b", "decoder_ff_logit_l1_b1"},
              {"ff_logit_ctx_b", "decoder_ff_logit_l1_b2"},
              {"ff_logit_l1_gamma0", "decoder_ff_logit_l1_gamma0"},
              {"ff_logit_l1_gamma1", "decoder_ff_logit_l1_gamma1"},
              {"ff_logit_l1_gamma2", "decoder_ff_logit_l1_gamma2"},
              {"ff_logit_W", "decoder_ff_logit_l2_W"},
              {"ff_logit_b", "decoder_ff_logit_l2_b"},
              {"ff_state_W", "decoder_ff_state_W"},
              {"ff_state_b", "decoder_ff_state_b"},
              {"ff_state_gamma", "decoder_ff_state_gamma"},
              {"Wemb_dec", "decoder_Wemb"},
              {"Wemb", "encoder_Wemb"},
              {"encoder_U", "encoder_bi_U"},
              {"encoder_Ux", "encoder_bi_Ux"},
              {"encoder_W", "encoder_bi_W"},
              {"encoder_Wx", "encoder_bi_Wx"},
              {"encoder_b", "encoder_bi_b"},
              {"encoder_bx", "encoder_bi_bx"},
              {"encoder_gamma1", "encoder_bi_gamma1"},
              {"encoder_gamma2", "encoder_bi_gamma2"},
              {"encoder_r_U", "encoder_bi_r_U"},
              {"encoder_r_Ux", "encoder_bi_r_Ux"},
              {"encoder_r_W", "encoder_bi_r_W"},
              {"encoder_r_Wx", "encoder_bi_r_Wx"},
              {"encoder_r_b", "encoder_bi_r_b"},
              {"encoder_r_bx", "encoder_bi_r_bx"},
              {"encoder_r_gamma1", "encoder_bi_r_gamma1"},
              {"encoder_r_gamma2", "encoder_bi_r_gamma2"}};

       // with tied embeddings a single "Wemb" matrix is shared, so keep its name
       if(opt<bool>("tied-embeddings-src") || opt<bool>("tied-embeddings-all"))
         nameMap["Wemb"] = "Wemb";

       auto ioItems = items;
       // map names and remove a dummy matrices
       for(auto it = ioItems.begin(); it != ioItems.end();) {
         // for backwards compatibility, turn one-dimensional vector into two
         // dimensional matrix with first dimension being 1 and second dimension
         // of the original size
         // @TODO: consider dropping support for Nematus models
         if(it->shape.size() == 1) {
           int dim = it->shape[-1];
           it->shape.resize(2);
           it->shape.set(0, 1);
           it->shape.set(1, dim);
         }
         if(it->name == "decoder_c_tt") {
           it = ioItems.erase(it);
         } else if(it->name == "uidx") {
           it = ioItems.erase(it);
         } else if(it->name == "history_errs") {
           it = ioItems.erase(it);
         } else {
           auto pair = nameMap.find(it->name);
           if(pair != nameMap.end())
             it->name = pair->second;
           it++;
         }
       }

       // load items into the graph
       graph->load(ioItems);
     }

     // Loads a model file from disk, then delegates to the item-based overload.
     void load(Ptr<ExpressionGraph> graph,
               const std::string& name,
               bool /*markReloaded*/ = true) override {
       LOG(info, "Loading model from {}", name);
       auto ioItems = io::loadItems(name);
       load(graph, ioItems);
     }

     // Saves the graph's parameters under Amun/Nematus naming, adds required
     // dummy matrices, and optionally emits Amun/decoder config files.
     void save(Ptr<ExpressionGraph> graph,
               const std::string& name,
               bool saveTranslatorConfig = false) override {
       LOG(info, "Saving model to {}", name);

       // marian parameter name -> Nematus/Amun parameter name (inverse of load)
       std::map<std::string, std::string> nameMap
           = {{"decoder_cell1_U", "decoder_U"},
              {"decoder_cell1_Ux", "decoder_Ux"},
              {"decoder_cell1_W", "decoder_W"},
              {"decoder_cell1_Wx", "decoder_Wx"},
              {"decoder_cell1_b", "decoder_b"},
              {"decoder_cell1_bx", "decoder_bx"},
              {"decoder_cell2_U", "decoder_U_nl"},
              {"decoder_cell2_Ux", "decoder_Ux_nl"},
              {"decoder_cell2_W", "decoder_Wc"},
              {"decoder_cell2_Wx", "decoder_Wcx"},
              {"decoder_cell2_b", "decoder_b_nl"},
              {"decoder_cell2_bx", "decoder_bx_nl"},
              {"decoder_ff_logit_l1_W0", "ff_logit_prev_W"},
              {"decoder_ff_logit_l1_W1", "ff_logit_lstm_W"},
              {"decoder_ff_logit_l1_W2", "ff_logit_ctx_W"},
              {"decoder_ff_logit_l1_b0", "ff_logit_prev_b"},
              {"decoder_ff_logit_l1_b1", "ff_logit_lstm_b"},
              {"decoder_ff_logit_l1_b2", "ff_logit_ctx_b"},
              {"decoder_ff_logit_l1_gamma0", "ff_logit_l1_gamma0"},
              {"decoder_ff_logit_l1_gamma1", "ff_logit_l1_gamma1"},
              {"decoder_ff_logit_l1_gamma2", "ff_logit_l1_gamma2"},
              {"decoder_ff_logit_l2_W", "ff_logit_W"},
              {"decoder_ff_logit_l2_b", "ff_logit_b"},
              {"decoder_ff_state_W", "ff_state_W"},
              {"decoder_ff_state_b", "ff_state_b"},
              {"decoder_ff_state_gamma", "ff_state_gamma"},
              {"decoder_Wemb", "Wemb_dec"},
              {"encoder_Wemb", "Wemb"},
              {"encoder_bi_U", "encoder_U"},
              {"encoder_bi_Ux", "encoder_Ux"},
              {"encoder_bi_W", "encoder_W"},
              {"encoder_bi_Wx", "encoder_Wx"},
              {"encoder_bi_b", "encoder_b"},
              {"encoder_bi_bx", "encoder_bx"},
              {"encoder_bi_gamma1", "encoder_gamma1"},
              {"encoder_bi_gamma2", "encoder_gamma2"},
              {"encoder_bi_r_U", "encoder_r_U"},
              {"encoder_bi_r_Ux", "encoder_r_Ux"},
              {"encoder_bi_r_W", "encoder_r_W"},
              {"encoder_bi_r_Wx", "encoder_r_Wx"},
              {"encoder_bi_r_b", "encoder_r_b"},
              {"encoder_bi_r_bx", "encoder_r_bx"},
              {"encoder_bi_r_gamma1", "encoder_r_gamma1"},
              {"encoder_bi_r_gamma2", "encoder_r_gamma2"}};

       // get parameters from the graph to items
       std::vector<io::Item> ioItems;
       graph->save(ioItems);

       // replace names to be compatible with Nematus
       for(auto& item : ioItems) {
         auto newItemName = nameMap.find(item.name);
         if(newItemName != nameMap.end())
           item.name = newItemName->second;
       }

       // add a dummy matrix 'decoder_c_tt' required for Amun and Nematus
       ioItems.emplace_back();
       ioItems.back().name = "decoder_c_tt";
       ioItems.back().shape = Shape({1, 0});
       ioItems.back().bytes.emplace_back((char)0);

       io::addMetaToItems(getModelParametersAsString(), "special:model.yml", ioItems);
       io::saveItems(name, ioItems);

       if(saveTranslatorConfig) {
         createAmunConfig(name);
         createDecoderConfig(name);
       }
     }

   private:
     // Writes "<name>.amun.yml", a translator configuration consumable by Amun,
     // with vocab/model paths made relative when --relative-paths is set.
     void createAmunConfig(const std::string& name) {
       Config::YamlNode amun;
       auto vocabs = options_->get<std::vector<std::string>>("vocabs");

       if(options_->get<bool>("relative-paths")) {
         amun["relative-paths"] = true;
         auto dirPath = filesystem::Path{name}.parentPath();
         amun["source-vocab"]
             = filesystem::relative(filesystem::Path{vocabs[0]}, dirPath).string();
         amun["target-vocab"]
             = filesystem::relative(filesystem::Path{vocabs[1]}, dirPath).string();
         amun["scorers"]["F0"]["path"] = filesystem::Path{name}.filename().string();
       } else {
         amun["relative-paths"] = false;
         amun["source-vocab"] = vocabs[0];
         amun["target-vocab"] = vocabs[1];
         amun["scorers"]["F0"]["path"] = name;
       }

       amun["scorers"]["F0"]["type"] = "Nematus";
       amun["weights"]["F0"] = 1.0f;
       amun["normalize"] = opt<float>("normalize") > 0;
       amun["beam-size"] = opt<size_t>("beam-size");

       io::OutputFileStream out(name + ".amun.yml");
       out << amun;
     }
   };
   }  // namespace marian