.. _program_listing_file_src_models_encoder_pooler.h:

Program Listing for File encoder_pooler.h
=========================================

|exhale_lsh| :ref:`Return to documentation for file <file_src_models_encoder_pooler.h>` (``src/models/encoder_pooler.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. NOTE(review): the extracted page had lost every angle-bracketed span (the
   :ref: target above and all C++ template argument lists below). They are
   restored here from the Exhale label convention and from how each value is
   used in the listing; confirm against src/models/encoder_pooler.h.

.. code-block:: cpp

   #pragma once

   #include "marian.h"

   #include "models/encoder.h"
   #include "models/pooler.h"
   #include "models/model_base.h"
   #include "models/states.h"

   // @TODO: this introduces functionality to use LASER in Marian for the filtering workflow or for use in MS-internal
   // COSMOS server-farm. There is a lot of code duplication with Classifier and EncoderDecoder and this needs to be fixed.
   // This will be done after the new layer system has been finished.

   namespace marian {

   /**
    * Abstract interface for models that run encoders and hand the result to a pooler.
    *
    * Derives from models::IModel. Loading, memory-mapping, saving, clearing, applying
    * and option access are left pure virtual. Both build() overloads are implemented
    * here and unconditionally abort, since this model family does not produce Logits.
    */
   class EncoderPoolerBase : public models::IModel {
   public:
     virtual ~EncoderPoolerBase() {}

     /** Load model parameters identified by `name` into `graph`. */
     virtual void load(Ptr<ExpressionGraph> graph,
                       const std::string& name,
                       bool markedReloaded = true) override
         = 0;

     /** Load model parameters into `graph` from the memory region at `ptr`. */
     virtual void mmap(Ptr<ExpressionGraph> graph,
                       const void* ptr,
                       bool markedReloaded = true)
         = 0;

     /** Save the parameters held by `graph` under `name`. */
     virtual void save(Ptr<ExpressionGraph> graph,
                       const std::string& name,
                       bool saveTranslatorConfig = false) override
         = 0;

     /** Reset `graph` and any per-batch state of the model. */
     virtual void clear(Ptr<ExpressionGraph> graph) override = 0;

     /** Run the model on a batch and return the resulting expressions. */
     virtual std::vector<Expr> apply(Ptr<ExpressionGraph>, Ptr<data::CorpusBatch>, bool) = 0;

     /** Not supported: always aborts with "Poolers cannot produce Logits". */
     virtual Logits build(Ptr<ExpressionGraph> graph,
                          Ptr<data::Batch> batch,
                          bool clearGraph = true) override {
       (void)clearGraph;  // explicitly discarded; the parameter is not used by this overload
       ABORT("Poolers cannot produce Logits");
     }

     /** Not supported: always aborts with "Poolers cannot produce Logits". */
     virtual Logits build(Ptr<ExpressionGraph> graph,
                          Ptr<data::CorpusBatch> batch,
                          bool clearGraph = true) {
       (void)clearGraph;  // explicitly discarded; the parameter is not used by this overload
       ABORT("Poolers cannot produce Logits");
     }

     /** Access the options object the model was configured with. */
     virtual Ptr<Options> getOptions() = 0;
   };

   /**
    * Concrete encoder/pooler model.
    *
    * Holds a list of encoders and a list of poolers. apply() builds one state per
    * encoder and passes all of them to the single registered pooler.
    */
   class EncoderPooler : public EncoderPoolerBase {
   protected:
     Ptr<Options> options_;

     std::string prefix_;

     std::vector<Ptr<EncoderBase>> encoders_;
     std::vector<Ptr<PoolerBase>> poolers_;

     // In-class default; the constructor overwrites it with the "inference" option.
     bool inference_{true};

     // Option keys that getModelParameters() copies into the saved model description.
     std::set<std::string> modelFeatures_;

     /**
      * Collect the options listed in modelFeatures_ into a YAML node.
      *
      * If the options contain "original-type", its value replaces "type".
      * "version" is always set to buildVersion().
      */
     Config::YamlNode getModelParameters() {
       Config::YamlNode modelParams;
       auto clone = options_->cloneToYamlNode();
       for(auto& key : modelFeatures_)
         modelParams[key] = clone[key];

       if(options_->has("original-type"))
         modelParams["type"] = clone["original-type"];

       modelParams["version"] = buildVersion();
       return modelParams;
     }

     /** Serialize getModelParameters() to a YAML string. */
     std::string getModelParametersAsString() {
       auto yaml = getModelParameters();
       YAML::Emitter out;
       cli::OutputYaml(yaml, out);
       return std::string(out.c_str());
     }

   public:
     typedef data::Corpus dataset_type;

     /**
      * Construct the model from `options`.
      *
      * Reads "prefix" (default "") and "inference" (default false) and fills
      * modelFeatures_ with the option keys that are stored alongside the model.
      */
     // @TODO: lots of code-duplication with EncoderDecoder
     EncoderPooler(Ptr<Options> options)
         : options_(options),
           prefix_(options->get<std::string>("prefix", "")),
           inference_(options->get<bool>("inference", false)) {
       modelFeatures_ = {"type",
                         "dim-vocabs",
                         "dim-emb",
                         "dim-rnn",
                         "enc-cell",
                         "enc-type",
                         "enc-cell-depth",
                         "enc-depth",
                         "dec-depth",
                         "dec-cell",
                         "dec-cell-base-depth",
                         "dec-cell-high-depth",
                         "skip",
                         "layer-normalization",
                         "right-left",
                         "input-types",
                         "special-vocab",
                         "tied-embeddings",
                         "tied-embeddings-src",
                         "tied-embeddings-all"};

       modelFeatures_.insert("transformer-heads");
       modelFeatures_.insert("transformer-no-projection");
       modelFeatures_.insert("transformer-dim-ffn");
       modelFeatures_.insert("transformer-ffn-depth");
       modelFeatures_.insert("transformer-ffn-activation");
       modelFeatures_.insert("transformer-dim-aan");
       modelFeatures_.insert("transformer-aan-depth");
       modelFeatures_.insert("transformer-aan-activation");
       modelFeatures_.insert("transformer-aan-nogate");
       modelFeatures_.insert("transformer-preprocess");
       modelFeatures_.insert("transformer-postprocess");
       modelFeatures_.insert("transformer-postprocess-emb");
       modelFeatures_.insert("transformer-postprocess-top");
       modelFeatures_.insert("transformer-decoder-autoreg");
       modelFeatures_.insert("transformer-tied-layers");
       modelFeatures_.insert("transformer-guided-alignment-layer");
       modelFeatures_.insert("transformer-train-position-embeddings");
       modelFeatures_.insert("transformer-pool");

       modelFeatures_.insert("bert-train-type-embeddings");
       modelFeatures_.insert("bert-type-vocab-size");

       modelFeatures_.insert("ulr");
       modelFeatures_.insert("ulr-trainable-transformation");
       modelFeatures_.insert("ulr-dim-emb");
       modelFeatures_.insert("lemma-dim-emb");
       modelFeatures_.insert("lemma-dependency");
       modelFeatures_.insert("factors-combine");
       modelFeatures_.insert("factors-dim-emb");
     }

     virtual Ptr<Options> getOptions() override { return options_; }

     std::vector<Ptr<EncoderBase>>& getEncoders() { return encoders_; }
     std::vector<Ptr<PoolerBase>>& getPoolers() { return poolers_; }

     /** Register an encoder; encoders are run in insertion order by apply(). */
     void push_back(Ptr<EncoderBase> encoder) { encoders_.push_back(encoder); }
     /** Register a pooler; apply() requires exactly one to be registered. */
     void push_back(Ptr<PoolerBase> pooler) { poolers_.push_back(pooler); }

     /**
      * Forward to graph->load(). The reload flag passed on is `markedReloaded`
      * unless the "ignore-model-config" option is set, in which case it is false.
      */
     void load(Ptr<ExpressionGraph> graph,
               const std::string& name,
               bool markedReloaded) override {
       graph->load(name, markedReloaded && !opt<bool>("ignore-model-config", false));
     }

     /** Same as load(), but forwards to graph->mmap() with the memory region `ptr`. */
     void mmap(Ptr<ExpressionGraph> graph,
               const void* ptr,
               bool markedReloaded) override {
       graph->mmap(ptr, markedReloaded && !opt<bool>("ignore-model-config", false));
     }

     /**
      * Log the target name and save the graph together with the YAML model
      * description from getModelParametersAsString(). The third argument is ignored.
      */
     void save(Ptr<ExpressionGraph> graph,
               const std::string& name,
               bool /*saveModelConfig*/) override {
       LOG(info, "Saving model weights and runtime parameters to {}", name);
       graph->save(name, getModelParametersAsString());
     }

     /** Clear `graph`, then every registered encoder and pooler. */
     void clear(Ptr<ExpressionGraph> graph) override {
       graph->clear();

       for(auto& enc : encoders_)
         enc->clear();
       for(auto& pooler : poolers_)
         pooler->clear();
     }

     /** Read option `key` as type T. */
     template <typename T>
     T opt(const std::string& key) {
       return options_->get<T>(key);
     }

     /** Read option `key` as type T, passing `def` as the default to Options::get. */
     template <typename T>
     T opt(const std::string& key, const T& def) {
       return options_->get<T>(key, def);
     }

     /** Set option `key` to `value`. */
     template <typename T>
     void set(std::string key, T value) {
       options_->set(key, value);
     }

     /*********************************************************************/

     /**
      * Run all encoders on `batch` and pool their states.
      *
      * @param graph      expression graph to build into
      * @param batch      input batch handed to every encoder and to the pooler
      * @param clearGraph when true, clear() is called before anything is built
      * @return the expressions produced by the single registered pooler
      *
      * Aborts with "Expected exactly one pooler" if the number of poolers is not 1.
      */
     virtual std::vector<Expr> apply(Ptr<ExpressionGraph> graph,
                                     Ptr<data::CorpusBatch> batch,
                                     bool clearGraph) override {
       if(clearGraph)
         clear(graph);

       std::vector<Ptr<EncoderState>> encoderStates;
       for(auto& encoder : encoders_)
         encoderStates.push_back(encoder->build(graph, batch));

       ABORT_IF(poolers_.size() != 1, "Expected exactly one pooler");
       return poolers_[0]->apply(graph, batch, encoderStates);
     }
   };

   }  // namespace marian