.. _program_listing_file_src_models_nematus.h:

Program Listing for File nematus.h
==================================

|exhale_lsh| :ref:`Return to documentation for file <file_src_models_nematus.h>` (``src/models/nematus.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   #pragma once

   #include "marian.h"

   #include "models/s2s.h"

   namespace marian {

   class Nematus : public EncoderDecoder {
   public:
     Nematus(Ptr<ExpressionGraph> graph, Ptr<Options> options)
         : EncoderDecoder(graph, options), nameMap_(createNameMap()) {
       ABORT_IF(options_->get<std::string>("enc-type") != "bidirectional",
                "--type nematus does not support other encoder "
                "type than bidirectional, use --type s2s");
       ABORT_IF(options_->get<std::string>("enc-cell") != "gru-nematus",
                "--type nematus does not support other rnn cells "
                "than gru-nematus, use --type s2s");
       ABORT_IF(options_->get<std::string>("dec-cell") != "gru-nematus",
                "--type nematus does not support other rnn cells "
                "than gru-nematus, use --type s2s");
       ABORT_IF(options_->get<int>("dec-cell-high-depth") > 1,
                "--type nematus does not currently support "
                "--dec-cell-high-depth > 1, use --type s2s");
     }

     void load(Ptr<ExpressionGraph> graph,
               const std::vector<io::Item>& items,
               bool /*markReloaded*/ = true) override {
       auto ioItems = items;
       // map names and remove the dummy matrix 'decoder_c_tt' from items to
       // avoid creating an isolated node
       for(auto it = ioItems.begin(); it != ioItems.end();) {
         // for backwards compatibility, turn a one-dimensional vector into a
         // two-dimensional matrix with the first dimension being 1 and the
         // second dimension of the original size
         // @TODO: consider dropping support for Nematus models
         if(it->shape.size() == 1) {
           int dim = it->shape[-1];
           it->shape.resize(2);
           it->shape.set(0, 1);
           it->shape.set(1, dim);
         }
         if(it->name == "decoder_c_tt") {
           it = ioItems.erase(it);
         } else if(it->name == "uidx") {
           it = ioItems.erase(it);
         } else if(it->name == "history_errs") {
           it = ioItems.erase(it);
         } else {
           auto pair = nameMap_.find(it->name);
           if(pair != nameMap_.end())
             it->name = pair->second;
           it++;
         }
       }
       // load items into the graph
       graph->load(ioItems);
     }

     void load(Ptr<ExpressionGraph> graph,
               const std::string& name,
               bool /*markReloaded*/ = true) override {
       LOG(info, "Loading model from {}", name);

       auto ioItems = io::loadItems(name);
       load(graph, ioItems);
     }

     void save(Ptr<ExpressionGraph> graph,
               const std::string& name,
               bool saveTranslatorConfig = false) override {
       LOG(info, "Saving model to {}", name);

       // prepare reversed map
       if(nameMapRev_.empty())
         for(const auto& kv : nameMap_)
           nameMapRev_.insert({kv.second, kv.first});

       // get parameters from the graph to items
       std::vector<io::Item> ioItems;
       graph->save(ioItems);

       // replace names to be compatible with Nematus
       for(auto& item : ioItems) {
         auto newItemName = nameMapRev_.find(item.name);
         if(newItemName != nameMapRev_.end())
           item.name = newItemName->second;
       }

       // add a dummy matrix 'decoder_c_tt' required for Amun and Nematus
       ioItems.emplace_back();
       ioItems.back().name = "decoder_c_tt";
       ioItems.back().shape = Shape({1, 0});
       ioItems.back().bytes.emplace_back((char)0);

       io::addMetaToItems(getModelParametersAsString(), "special:model.yml", ioItems);
       io::saveItems(name, ioItems);

       if(saveTranslatorConfig) {
         createAmunConfig(name);
         createDecoderConfig(name);
       }
     }

   private:
     std::map<std::string, std::string> nameMap_;
     std::map<std::string, std::string> nameMapRev_;

     std::map<std::string, std::string> createNameMap() {
       std::map<std::string, std::string> nameMap
           = {{"decoder_U", "decoder_cell1_U"},
              {"decoder_Ux", "decoder_cell1_Ux"},
              {"decoder_W", "decoder_cell1_W"},
              {"decoder_Wx", "decoder_cell1_Wx"},
              {"decoder_b", "decoder_cell1_b"},
              {"decoder_bx", "decoder_cell1_bx"},
              {"decoder_U_nl", "decoder_cell2_U"},
              {"decoder_Ux_nl", "decoder_cell2_Ux"},
              {"decoder_Wc", "decoder_cell2_W"},
"decoder_cell2_W"}, {"decoder_Wcx", "decoder_cell2_Wx"}, {"decoder_b_nl", "decoder_cell2_b"}, {"decoder_bx_nl", "decoder_cell2_bx"}, {"ff_logit_prev_W", "decoder_ff_logit_l1_W0"}, {"ff_logit_lstm_W", "decoder_ff_logit_l1_W1"}, {"ff_logit_ctx_W", "decoder_ff_logit_l1_W2"}, {"ff_logit_prev_b", "decoder_ff_logit_l1_b0"}, {"ff_logit_lstm_b", "decoder_ff_logit_l1_b1"}, {"ff_logit_ctx_b", "decoder_ff_logit_l1_b2"}, {"ff_logit_W", "decoder_ff_logit_l2_W"}, {"ff_logit_b", "decoder_ff_logit_l2_b"}, {"ff_state_W", "decoder_ff_state_W"}, {"ff_state_b", "decoder_ff_state_b"}, {"Wemb_dec", "decoder_Wemb"}, {"Wemb", "encoder_Wemb"}, {"encoder_U", "encoder_bi_U"}, {"encoder_Ux", "encoder_bi_Ux"}, {"encoder_W", "encoder_bi_W"}, {"encoder_Wx", "encoder_bi_Wx"}, {"encoder_b", "encoder_bi_b"}, {"encoder_bx", "encoder_bi_bx"}, {"encoder_r_U", "encoder_bi_r_U"}, {"encoder_r_Ux", "encoder_bi_r_Ux"}, {"encoder_r_W", "encoder_bi_r_W"}, {"encoder_r_Wx", "encoder_bi_r_Wx"}, {"encoder_r_b", "encoder_bi_r_b"}, {"encoder_r_bx", "encoder_bi_r_bx"}, {"ff_state_ln_s", "decoder_ff_state_ln_s"}, {"ff_state_ln_b", "decoder_ff_state_ln_b"}, {"ff_logit_prev_ln_s", "decoder_ff_logit_l1_ln_s0"}, {"ff_logit_lstm_ln_s", "decoder_ff_logit_l1_ln_s1"}, {"ff_logit_ctx_ln_s", "decoder_ff_logit_l1_ln_s2"}, {"ff_logit_prev_ln_b", "decoder_ff_logit_l1_ln_b0"}, {"ff_logit_lstm_ln_b", "decoder_ff_logit_l1_ln_b1"}, {"ff_logit_ctx_ln_b", "decoder_ff_logit_l1_ln_b2"}}; // add mapping for deep encoder cells std::vector suffixes = {"_U", "_Ux", "_b", "_bx"}; for(int i = 1; i < options_->get("enc-cell-depth"); ++i) { std::string num1 = std::to_string(i); std::string num2 = std::to_string(i + 1); for(auto suf : suffixes) { nameMap.insert({"encoder" + suf + "_drt_" + num1, "encoder_bi_cell" + num2 + suf}); nameMap.insert({"encoder_r" + suf + "_drt_" + num1, "encoder_bi_r_cell" + num2 + suf}); } } // add mapping for deep decoder cells for(int i = 3; i <= options_->get("dec-cell-base-depth"); ++i) { std::string num1 = std::to_string(i - 2); std::string num2 = std::to_string(i); for(auto suf : suffixes) nameMap.insert({"decoder" + suf + "_nl_drt_" + num1, "decoder_cell" + num2 + suf}); } // add mapping for normalization layers std::map nameMapCopy(nameMap); for(auto& kv : nameMapCopy) { std::string prefix = kv.first.substr(0, 7); if(prefix == "encoder" || prefix == "decoder") { nameMap.insert({kv.first + "_lns", kv.second + "_lns"}); nameMap.insert({kv.first + "_lnb", kv.second + "_lnb"}); } } return nameMap; } void createAmunConfig(const std::string& name) { Config::YamlNode amun; // Amun has only CPU decoder for deep Nematus models amun["cpu-threads"] = 16; amun["gpu-threads"] = 0; amun["maxi-batch"] = 1; amun["mini-batch"] = 1; auto vocabs = options_->get>("vocabs"); amun["source-vocab"] = vocabs[0]; amun["target-vocab"] = vocabs[1]; amun["devices"] = options_->get>("devices"); amun["normalize"] = true; amun["beam-size"] = 5; amun["relative-paths"] = false; amun["scorers"]["F0"]["path"] = name; amun["scorers"]["F0"]["type"] = "nematus2"; amun["weights"]["F0"] = 1.0f; io::OutputFileStream out(name + ".amun.yml"); out << amun; } }; } // namespace marian