.. _program_listing_file_src_models_encoder_classifier.h:

Program Listing for File encoder_classifier.h
=============================================

|exhale_lsh| :ref:`Return to documentation for file <file_src_models_encoder_classifier.h>` (``src/models/encoder_classifier.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   #pragma once

   #include "marian.h"

   #include "models/encoder.h"
   #include "models/classifier.h"
   #include "models/model_base.h"
   #include "models/states.h"

   namespace marian {

   // Abstract interface for models made of one or more encoders whose states are
   // consumed by one or more classifiers.
   class EncoderClassifierBase : public models::IModel {
   public:
     virtual ~EncoderClassifierBase() {}

     virtual void load(Ptr<ExpressionGraph> graph,
                       const std::string& name,
                       bool markedReloaded = true) override
         = 0;

     virtual void mmap(Ptr<ExpressionGraph> graph,
                       const void* ptr,
                       bool markedReloaded = true)
         = 0;

     virtual void save(Ptr<ExpressionGraph> graph,
                       const std::string& name,
                       bool saveTranslatorConfig = false) override
         = 0;

     virtual void clear(Ptr<ExpressionGraph> graph) override = 0;

     // Runs encoders and classifiers on a batch; returns one state per classifier.
     virtual std::vector<Ptr<ClassifierState>> apply(Ptr<ExpressionGraph>,
                                                     Ptr<data::CorpusBatch>,
                                                     bool)
         = 0;

     // Generic-batch overload; overrides models::IModel::build.
     virtual Logits build(Ptr<ExpressionGraph> graph,
                          Ptr<data::Batch> batch,
                          bool clearGraph = true) override
         = 0;

     // Corpus-batch overload; differs from the one above only in the batch type.
     virtual Logits build(Ptr<ExpressionGraph> graph,
                          Ptr<data::CorpusBatch> batch,
                          bool clearGraph = true)
         = 0;

     virtual Ptr<Options> getOptions() = 0;
   };

   // Default EncoderClassifierBase implementation: holds lists of encoders and
   // classifiers, runs them in insertion order, and saves a whitelisted subset of
   // its options next to the model weights.
   class EncoderClassifier : public EncoderClassifierBase {
   protected:
     Ptr<Options> options_;

     std::string prefix_;

     std::vector<Ptr<EncoderBase>> encoders_;
     std::vector<Ptr<ClassifierBase>> classifiers_;

     bool inference_{false};

     // Option keys written into the saved model config (see getModelParameters()).
     std::set<std::string> modelFeatures_;

     // Builds the YAML config stored with the weights:
     //   - every key in modelFeatures_, copied from options_;
     //   - "type" replaced by "original-type" when that option exists;
     //   - "version" set to buildVersion().
     Config::YamlNode getModelParameters() {
       Config::YamlNode modelParams;
       auto clone = options_->cloneToYamlNode();
       for(auto& key : modelFeatures_)
         modelParams[key] = clone[key];

       if(options_->has("original-type"))
         modelParams["type"] = clone["original-type"];

       modelParams["version"] = buildVersion();
       return modelParams;
     }

     // getModelParameters() serialized to a YAML string.
     std::string getModelParametersAsString() {
       auto yaml = getModelParameters();
       YAML::Emitter out;
       cli::OutputYaml(yaml, out);
       return std::string(out.c_str());
     }

   public:
     typedef data::Corpus dataset_type;

     // @TODO: lots of code-duplication with EncoderDecoder
     EncoderClassifier(Ptr<Options> options)
         : options_(options),
           prefix_(options->get<std::string>("prefix", "")),
           inference_(options->get<bool>("inference", false)) {
       // Architecture-defining options that save() persists with the weights.
       modelFeatures_ = {"type",
                         "dim-vocabs",
                         "dim-emb",
                         "dim-rnn",
                         "enc-cell",
                         "enc-type",
                         "enc-cell-depth",
                         "enc-depth",
                         "dec-depth",
                         "dec-cell",
                         "dec-cell-base-depth",
                         "dec-cell-high-depth",
                         "skip",
                         "layer-normalization",
                         "right-left",
                         "input-types",
                         "special-vocab",
                         "tied-embeddings",
                         "tied-embeddings-src",
                         "tied-embeddings-all"};

       modelFeatures_.insert("transformer-heads");
       modelFeatures_.insert("transformer-no-projection");
       modelFeatures_.insert("transformer-dim-ffn");
       modelFeatures_.insert("transformer-ffn-depth");
       modelFeatures_.insert("transformer-ffn-activation");
       modelFeatures_.insert("transformer-dim-aan");
       modelFeatures_.insert("transformer-aan-depth");
       modelFeatures_.insert("transformer-aan-activation");
       modelFeatures_.insert("transformer-aan-nogate");
       modelFeatures_.insert("transformer-preprocess");
       modelFeatures_.insert("transformer-postprocess");
       modelFeatures_.insert("transformer-postprocess-emb");
       modelFeatures_.insert("transformer-postprocess-top");
       modelFeatures_.insert("transformer-decoder-autoreg");
       modelFeatures_.insert("transformer-tied-layers");
       modelFeatures_.insert("transformer-guided-alignment-layer");
       modelFeatures_.insert("transformer-train-position-embeddings");

       modelFeatures_.insert("bert-train-type-embeddings");
       modelFeatures_.insert("bert-type-vocab-size");

       modelFeatures_.insert("ulr");
       modelFeatures_.insert("ulr-trainable-transformation");
       modelFeatures_.insert("ulr-dim-emb");
       modelFeatures_.insert("lemma-dim-emb");
       modelFeatures_.insert("lemma-dependency");
       modelFeatures_.insert("factors-combine");
       modelFeatures_.insert("factors-dim-emb");
     }

     virtual Ptr<Options> getOptions() override { return options_; }

     std::vector<Ptr<EncoderBase>>& getEncoders() { return encoders_; }
     std::vector<Ptr<ClassifierBase>>& getClassifiers() { return classifiers_; }

     void push_back(Ptr<EncoderBase> encoder) { encoders_.push_back(encoder); }
     void push_back(Ptr<ClassifierBase> classifier) { classifiers_.push_back(classifier); }

     // Loads weights into the graph. The markedReloaded flag handed to the graph
     // is forced to false when the "ignore-model-config" option is set.
     void load(Ptr<ExpressionGraph> graph,
               const std::string& name,
               bool markedReloaded) override {
       graph->load(name, markedReloaded && !opt<bool>("ignore-model-config", false));
     }

     // Memory-mapped counterpart of load(); same "ignore-model-config" handling.
     void mmap(Ptr<ExpressionGraph> graph,
               const void* ptr,
               bool markedReloaded) override {
       graph->mmap(ptr, markedReloaded && !opt<bool>("ignore-model-config", false));
     }

     // Saves the weights together with the YAML from getModelParametersAsString().
     // The third argument is ignored: the model config is always written.
     void save(Ptr<ExpressionGraph> graph,
               const std::string& name,
               bool /*saveModelConfig*/) override {
       LOG(info, "Saving model weights and runtime parameters to {}", name);
       graph->save(name, getModelParametersAsString());
     }

     // Clears the graph, then calls clear() on every encoder and classifier.
     void clear(Ptr<ExpressionGraph> graph) override {
       graph->clear();

       for(auto& enc : encoders_)
         enc->clear();
       for(auto& cls : classifiers_)
         cls->clear();
     }

     // Typed shortcuts for reading/writing options_.
     template <typename T>
     T opt(const std::string& key) {
       return options_->get<T>(key);
     }

     template <typename T>
     T opt(const std::string& key, const T& def) {
       return options_->get<T>(key, def);
     }

     template <typename T>
     void set(std::string key, T value) {
       options_->set(key, value);
     }

     /*********************************************************************/

     // Optionally clears the graph, builds every encoder on the batch, then
     // applies every classifier to the full list of encoder states.
     // Returns one state per classifier, in insertion order.
     virtual std::vector<Ptr<ClassifierState>> apply(Ptr<ExpressionGraph> graph,
                                                     Ptr<data::CorpusBatch> batch,
                                                     bool clearGraph) override {
       if(clearGraph)
         clear(graph);

       std::vector<Ptr<EncoderState>> encoderStates;
       for(auto& encoder : encoders_)
         encoderStates.push_back(encoder->build(graph, batch));

       std::vector<Ptr<ClassifierState>> classifierStates;
       for(auto& classifier : classifiers_)
         classifierStates.push_back(classifier->apply(graph, batch, encoderStates));

       return classifierStates;
     }

     // Returns the output of the first classifier only.
     // NOTE(review): states[0] is unchecked; calling this with no classifiers
     // registered indexes an empty vector.
     virtual Logits build(Ptr<ExpressionGraph> graph,
                          Ptr<data::CorpusBatch> batch,
                          bool clearGraph = true) override {
       auto states = apply(graph, batch, clearGraph);
       // returns raw logits
       return Logits(states[0]->getLogProbs());
     }

     // Forwards to the corpus-batch overload. static_pointer_cast does no runtime
     // check, so batch is assumed to actually be a data::CorpusBatch.
     virtual Logits build(Ptr<ExpressionGraph> graph,
                          Ptr<data::Batch> batch,
                          bool clearGraph = true) override {
       auto corpusBatch = std::static_pointer_cast<data::CorpusBatch>(batch);
       return build(graph, corpusBatch, clearGraph);
     }
   };

   }  // namespace marian