.. _program_listing_file_src_models_model_factory.cpp:

Program Listing for File model_factory.cpp
==========================================

|exhale_lsh| :ref:`Return to documentation for file <file_src_models_model_factory.cpp>` (``src/models/model_factory.cpp``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. NOTE(review): the original extraction of this page collapsed all line breaks
   and stripped every ``<...>`` token, i.e. all C++ template-argument lists
   (``Ptr<EncoderBase>``, ``New<EncoderS2S>``, ``get<std::string>``, ...) as
   well as the ``:ref:`` target above, leaving a non-compilable listing.
   The listing below restores formatting and template arguments from the
   upstream marian source — verify against ``src/models/model_factory.cpp``.

.. code-block:: cpp

   #include "marian.h"

   #include "models/model_factory.h"
   #include "models/encoder_decoder.h"
   #include "models/encoder_classifier.h"

   #include "models/bert.h"
   #include "models/costs.h"

   #include "models/amun.h"
   #include "models/nematus.h"
   #include "models/s2s.h"
   #include "models/laser.h"
   #include "models/transformer_factory.h"

   #ifdef CUDNN
   #include "models/char_s2s.h"
   #endif

   #ifdef COMPILE_EXAMPLES
   #include "examples/mnist/model.h"
   #ifdef CUDNN
   #include "examples/mnist/model_lenet.h"
   #endif
   #endif

   namespace marian {
   namespace models {

   Ptr<EncoderBase> EncoderFactory::construct(Ptr<ExpressionGraph> graph) {
     if(options_->get<std::string>("type") == "s2s")
       return New<EncoderS2S>(graph, options_);

     if(options_->get<std::string>("type") == "laser" || options_->get<std::string>("type") == "laser-sim")
       return New<EncoderLaser>(graph, options_);

   #ifdef CUDNN
     if(options_->get<std::string>("type") == "char-s2s")
       return New<CharS2SEncoder>(graph, options_);
   #endif

     if(options_->get<std::string>("type") == "transformer")
       return NewEncoderTransformer(graph, options_);

     if(options_->get<std::string>("type") == "bert-encoder")
       return New<BertEncoder>(graph, options_);

     ABORT("Unknown encoder type");
   }

   Ptr<DecoderBase> DecoderFactory::construct(Ptr<ExpressionGraph> graph) {
     if(options_->get<std::string>("type") == "s2s")
       return New<DecoderS2S>(graph, options_);

     if(options_->get<std::string>("type") == "transformer")
       return NewDecoderTransformer(graph, options_);

     ABORT("Unknown decoder type");
   }

   Ptr<ClassifierBase> ClassifierFactory::construct(Ptr<ExpressionGraph> graph) {
     if(options_->get<std::string>("type") == "bert-masked-lm")
       return New<BertMaskedLM>(graph, options_);
     else if(options_->get<std::string>("type") == "bert-classifier")
       return New<BertClassifier>(graph, options_);
     else
       ABORT("Unknown classifier type");
   }

   Ptr<PoolerBase> PoolerFactory::construct(Ptr<ExpressionGraph> graph) {
     if(options_->get<std::string>("type") == "max-pooler")
       return New<MaxPooler>(graph, options_);
     if(options_->get<std::string>("type") == "slice-pooler")
       return New<SlicePooler>(graph, options_);
     else if(options_->get<std::string>("type") == "sim-pooler")
       return New<SimPooler>(graph, options_);
     else
       ABORT("Unknown pooler type");
   }

   Ptr<IModel> EncoderDecoderFactory::construct(Ptr<ExpressionGraph> graph) {
     Ptr<EncoderDecoder> encdec;

     if(options_->get<std::string>("type") == "amun")
       encdec = New<Amun>(graph, options_);
     else if(options_->get<std::string>("type") == "nematus")
       encdec = New<Nematus>(graph, options_);
     else
       encdec = New<EncoderDecoder>(graph, options_);

     for(auto& ef : encoders_)
       encdec->push_back(ef(options_).construct(graph));

     for(auto& df : decoders_)
       encdec->push_back(df(options_).construct(graph));

     return encdec;
   }

   Ptr<IModel> EncoderClassifierFactory::construct(Ptr<ExpressionGraph> graph) {
     Ptr<EncoderClassifier> enccls;
     if(options_->get<std::string>("type") == "bert")
       enccls = New<BertEncoderClassifier>(options_);
     else if(options_->get<std::string>("type") == "bert-classifier")
       enccls = New<BertEncoderClassifier>(options_);
     else
       enccls = New<EncoderClassifier>(options_);

     for(auto& ef : encoders_)
       enccls->push_back(ef(options_).construct(graph));

     for(auto& cf : classifiers_)
       enccls->push_back(cf(options_).construct(graph));

     return enccls;
   }

   Ptr<IModel> EncoderPoolerFactory::construct(Ptr<ExpressionGraph> graph) {
     Ptr<EncoderPooler> encpool = New<EncoderPooler>(options_);

     for(auto& ef : encoders_)
       encpool->push_back(ef(options_).construct(graph));

     for(auto& pl : poolers_)
       encpool->push_back(pl(options_).construct(graph));

     return encpool;
   }

   Ptr<IModel> createBaseModelByType(std::string type, usage use, Ptr<Options> options) {
     Ptr<ExpressionGraph> graph = nullptr; // graph unknown at this stage
     // clang-format off
     bool trainEmbedderRank = options->hasAndNotEmpty("train-embedder-rank");
     if(use == usage::embedding || trainEmbedderRank) { // hijacking an EncoderDecoder model for embedding only
       auto dimVocabs = options->get<std::vector<int>>("dim-vocabs");
       size_t fields = trainEmbedderRank ? dimVocabs.size() : 0;
       int dimVocab = dimVocabs[0];

       Ptr<Options> newOptions;
       if(options->get<bool>("compute-similarity", false)) {
         newOptions = options->with("usage", use,
                                    "original-type", type,
                                    "input-types", std::vector<std::string>({"sequence", "sequence"}),
                                    "dim-vocabs", std::vector<int>(2, dimVocab));
       } else if(trainEmbedderRank) {
         newOptions = options->with("usage", use,
                                    "original-type", type,
                                    "input-types", std::vector<std::string>(fields, "sequence"),
                                    "dim-vocabs", std::vector<int>(fields, dimVocab));
       } else {
         newOptions = options->with("usage", use,
                                    "original-type", type,
                                    "input-types", std::vector<std::string>({"sequence"}),
                                    "dim-vocabs", std::vector<int>(1, dimVocab));
       }

       auto res = New<EncoderPooler>(newOptions);
       if(options->get<bool>("compute-similarity", false)) {
         res->push_back(models::encoder(newOptions->with("index", 0)).construct(graph));
         res->push_back(models::encoder(newOptions->with("index", 1)).construct(graph));
         res->push_back(New<SimPooler>(graph, newOptions->with("type", "sim-pooler")));
       } else if(trainEmbedderRank) {
         LOG(info, "Using {} input fields for embedder ranking training", fields);
         for(int i = 0; i < fields; ++i)
           res->push_back(models::encoder(newOptions->with("index", i)).construct(graph));
         res->push_back(New<SimPooler>(graph, newOptions->with("type", "sim-pooler")));
       } else {
         res->push_back(models::encoder(newOptions->with("index", 0)).construct(graph));
         if(type == "laser")
           res->push_back(New<MaxPooler>(graph, newOptions->with("type", "max-pooler")));
         else
           res->push_back(New<SlicePooler>(graph, newOptions->with("type", "slice-pooler")));
       }
       return res;
     }

     if(type == "s2s" || type == "amun" || type == "nematus") {
       return models::encoder_decoder(options->with(
            "usage", use,
            "original-type", type))
           .push_back(models::encoder()("type", "s2s"))
           .push_back(models::decoder()("type", "s2s"))
           .construct(graph);
     }

     else if(type == "transformer") {
   #if 1
       auto newOptions = options->with("usage", use);
       auto res = New<EncoderDecoder>(graph, newOptions);
       res->push_back(New<EncoderTransformer>(graph, newOptions->with("type", "transformer")));
       res->push_back(New<DecoderTransformer>(graph, newOptions->with("type", "transformer")));
       return res;
   #else
       return models::encoder_decoder(options->with(
            "usage", use))
           .push_back(models::encoder()("type", "transformer"))
           .push_back(models::decoder()("type", "transformer"))
           .construct(graph);
   #endif
     }

     else if(type == "transformer_s2s") {
       return models::encoder_decoder()(options)
           ("usage", use)
           ("original-type", type)
           .push_back(models::encoder()("type", "transformer"))
           .push_back(models::decoder()("type", "s2s"))
           .construct(graph);
     }

     else if(type == "lm") {
       auto idx = options->has("index") ? options->get<size_t>("index") : 0;
       std::vector<int> dimVocabs = options->get<std::vector<int>>("dim-vocabs");
       int vocab = dimVocabs[0];
       dimVocabs.resize(idx + 1);
       std::fill(dimVocabs.begin(), dimVocabs.end(), vocab);

       return models::encoder_decoder(options->with(
            "usage", use,
            "type", "s2s",
            "original-type", type))
           .push_back(models::decoder()
                      ("index", idx)
                      ("dim-vocabs", dimVocabs))
           .construct(graph);
     }

     else if(type == "multi-s2s") {
       size_t numEncoders = 2;
       auto ms2sFactory = models::encoder_decoder()(options)
           ("usage", use)
           ("type", "s2s")
           ("original-type", type);

       for(size_t i = 0; i < numEncoders; ++i) {
         auto prefix = "encoder" + std::to_string(i + 1);
         ms2sFactory.push_back(models::encoder()("prefix", prefix)("index", i));
       }

       ms2sFactory.push_back(models::decoder()("index", numEncoders));

       return ms2sFactory.construct(graph);
     }

     else if(type == "shared-multi-s2s") {
       size_t numEncoders = 2;
       auto ms2sFactory = models::encoder_decoder()(options)
           ("usage", use)
           ("type", "s2s")
           ("original-type", type);

       for(size_t i = 0; i < numEncoders; ++i) {
         auto prefix = "encoder";
         ms2sFactory.push_back(models::encoder()("prefix", prefix)("index", i));
       }

       ms2sFactory.push_back(models::decoder()("index", numEncoders));

       return ms2sFactory.construct(graph);
     }

     else if(type == "multi-transformer") {
       size_t numEncoders = 2;
       auto mtransFactory = models::encoder_decoder()(options)
           ("usage", use)
           ("type", "transformer")
           ("original-type", type);

       for(size_t i = 0; i < numEncoders; ++i) {
         auto prefix = "encoder" + std::to_string(i + 1);
         mtransFactory.push_back(models::encoder()("prefix", prefix)("index", i));
       }

       mtransFactory.push_back(models::decoder()("index", numEncoders));

       return mtransFactory.construct(graph);
     }

     else if(type == "shared-multi-transformer") {
       size_t numEncoders = 2;
       auto mtransFactory = models::encoder_decoder()(options)
           ("usage", use)
           ("type", "transformer")
           ("original-type", type);

       for(size_t i = 0; i < numEncoders; ++i) {
         auto prefix = "encoder";
         mtransFactory.push_back(models::encoder()("prefix", prefix)("index", i));
       }

       mtransFactory.push_back(models::decoder()("index", numEncoders));

       return mtransFactory.construct(graph);
     }

     else if(type == "lm-transformer") {
       auto idx = options->has("index") ? options->get<size_t>("index") : 0;
       std::vector<int> dimVocabs = options->get<std::vector<int>>("dim-vocabs");
       int vocab = dimVocabs[0];
       dimVocabs.resize(idx + 1);
       std::fill(dimVocabs.begin(), dimVocabs.end(), vocab);

       return models::encoder_decoder()(options)
           ("usage", use)
           ("type", "transformer")
           ("original-type", type)
           .push_back(models::decoder()
                      ("index", idx)
                      ("dim-vocabs", dimVocabs))
           .construct(graph);
     }

     else if(type == "bert") {                       // for full BERT training
       return models::encoder_classifier()(options)  //
           ("original-type", "bert")                 // so we can query this
           ("usage", use)                            //
           .push_back(models::encoder()              //
                      ("type", "bert-encoder")       // close to original transformer encoder
                      ("index", 0))                  //
           .push_back(models::classifier()           //
                      ("prefix", "masked-lm")        // prefix for parameter names
                      ("type", "bert-masked-lm")     //
                      ("index", 0))                  // multi-task learning with MaskedLM
           .push_back(models::classifier()           //
                      ("prefix", "next-sentence")    // prefix for parameter names
                      ("type", "bert-classifier")    //
                      ("index", 1))                  // next sentence prediction
           .construct(graph);
     }

     else if(type == "bert-classifier") {            // for BERT fine-tuning on non-BERT classification task
       return models::encoder_classifier()(options)  //
           ("original-type", "bert-classifier")      // so we can query this if needed
           ("usage", use)                            //
           .push_back(models::encoder()              //
                      ("type", "bert-encoder")       //
                      ("index", 0))                  // close to original transformer encoder
           .push_back(models::classifier()           //
                      ("type", "bert-classifier")    //
                      ("index", 1))                  // next sentence prediction
           .construct(graph);
     }

   #ifdef COMPILE_EXAMPLES
     else if(type == "mnist-ffnn")
       return New<MnistFeedForwardNet>(options);
   #endif
   #ifdef CUDNN
   #ifdef COMPILE_EXAMPLES
     else if(type == "mnist-lenet")
       return New<MnistLeNetNet>(options);
   #endif
     else if(type == "char-s2s") {
       return models::encoder_decoder()(options)
           ("usage", use)
           ("original-type", type)
           .push_back(models::encoder()("type", "char-s2s"))
           .push_back(models::decoder()("type", "s2s"))
           .construct(graph);
     }
   #endif
     // clang-format on
     else
       ABORT("Unknown model type: {}", type);
   }

   Ptr<IModel> createModelFromOptions(Ptr<Options> options, usage use) {
     std::string type = options->get<std::string>("type");
     auto baseModel = createBaseModelByType(type, use, options);

     // add (log)softmax if requested
     if (use == usage::translation) {
       if(std::dynamic_pointer_cast<EncoderDecoder>(baseModel)) {
         if(options->hasAndNotEmpty("output-sampling")) {
           auto sampling = options->get<std::vector<std::string>>("output-sampling", {});
           std::string method = sampling.size() > 0 ? sampling[0] : "full";

           if(method == "full" || method == "1" /*for backwards-compat when output-sampling: true in yaml file*/) {
             LOG(info, "Output sampling from the full softmax distribution");
             return New<Stepwise>(std::dynamic_pointer_cast<EncoderDecoder>(baseModel), New<GumbelSoftmaxStep>());
           } else if(method == "topk") {
             int k = sampling.size() > 1 ? std::stoi(sampling[1]) : 10;
             if(k == 1)
               LOG(info, "Output sampling with k=1 is equivalent to beam search with beam size 1");
             LOG(info, "Output sampling via top-{} sampling", k);
             return New<Stepwise>(std::dynamic_pointer_cast<EncoderDecoder>(baseModel), New<TopkGumbelSoftmaxStep>(k));
           } else {
             ABORT("Unknown sampling method: {}", method);
           }
         } else {
           return New<Stepwise>(std::dynamic_pointer_cast<EncoderDecoder>(baseModel), New<LogSoftmaxStep>());
         }
       }
   #ifdef COMPILE_EXAMPLES
       // note: 'usage::translation' here means 'inference'
       else if (std::dynamic_pointer_cast<MnistFeedForwardNet>(baseModel))
         return New<Scorer>(baseModel, New<MNISTLogsoftmax>());
   #ifdef CUDNN
       else if (std::dynamic_pointer_cast<MnistLeNetNet>(baseModel))
         return New<Scorer>(baseModel, New<MNISTLogsoftmax>());
   #endif
   #endif
       else
         ABORT("'usage' parameter 'translation' cannot be applied to model type: {}", type);
     }
     else if (use == usage::raw || use == usage::embedding)
       return baseModel;
     else
       ABORT("'Usage' parameter must be 'translation' or 'raw'");
   }

   Ptr<ICriterionFunction> createCriterionFunctionFromOptions(Ptr<Options> options, usage use) {
     std::string type = options->get<std::string>("type");
     auto baseModel = createBaseModelByType(type, use, options);

     // add cost function
     ABORT_IF(use != usage::training && use != usage::scoring, "'Usage' parameter must be 'training' or 'scoring'");
     // note: usage::scoring means "score the loss function", hence it uses a Trainer (not Scorer, which is for decoding)
     // @TODO: Should we define a new class that does not compute gradients?
     if (std::dynamic_pointer_cast<EncoderDecoder>(baseModel))
       return New<Trainer>(baseModel, New<EncoderDecoderCECost>(options));
     else if (std::dynamic_pointer_cast<EncoderClassifier>(baseModel))
       return New<Trainer>(baseModel, New<EncoderClassifierCECost>(options));
   #ifdef COMPILE_EXAMPLES
     // @TODO: examples should be compiled optionally
     else if (std::dynamic_pointer_cast<MnistFeedForwardNet>(baseModel))
       return New<Trainer>(baseModel, New<MNISTCrossEntropyCost>());
   #ifdef CUDNN
     else if (std::dynamic_pointer_cast<MnistLeNetNet>(baseModel))
       return New<Trainer>(baseModel, New<MNISTCrossEntropyCost>());
   #endif
   #endif
     else if (std::dynamic_pointer_cast<EncoderPooler>(baseModel))
       return New<Trainer>(baseModel, New<EncoderPoolerRankCost>(options));
     else
       ABORT("Criterion function unknown for model type: {}", type);
   }

   }  // namespace models
   }  // namespace marian