Program Listing for File encoder_classifier.h

Return to documentation for file (src/models/encoder_classifier.h)

#pragma once

#include "marian.h"

#include "models/encoder.h"
#include "models/classifier.h"
#include "models/model_base.h"
#include "models/states.h"

namespace marian {

/**
 * Abstract interface of an encoder-classifier model.
 *
 * Extends models::IModel with memory-mapped loading (mmap), a classifier
 * entry point (apply), a CorpusBatch overload of build, and access to the
 * model's options. Every member function is pure virtual.
 */
class EncoderClassifierBase : public models::IModel {
public:
  virtual ~EncoderClassifierBase() = default;

  /** Load the model called `name` into `graph`. */
  void load(Ptr<ExpressionGraph> graph,
            const std::string& name,
            bool markedReloaded = true) override = 0;

  /** Load the model from the memory region starting at `ptr` into `graph`. */
  virtual void mmap(Ptr<ExpressionGraph> graph,
                    const void* ptr,
                    bool markedReloaded = true) = 0;

  /** Save the model held by `graph` under `name`. */
  void save(Ptr<ExpressionGraph> graph,
            const std::string& name,
            bool saveTranslatorConfig = false) override = 0;

  /** Reset `graph` (and any model-side state) before a new computation. */
  void clear(Ptr<ExpressionGraph> graph) override = 0;

  /**
   * Run the model on `batch` and return the resulting classifier states.
   * The boolean is called `clearGraph` by the implementation in this file.
   */
  virtual std::vector<Ptr<ClassifierState>> apply(Ptr<ExpressionGraph> graph,
                                                  Ptr<data::CorpusBatch> batch,
                                                  bool clearGraph) = 0;

  /** Build the model output for a generic batch (IModel entry point). */
  Logits build(Ptr<ExpressionGraph> graph,
               Ptr<data::Batch> batch,
               bool clearGraph = true) override = 0;

  /** Build the model output for a corpus batch. */
  virtual Logits build(Ptr<ExpressionGraph> graph,
                       Ptr<data::CorpusBatch> batch,
                       bool clearGraph = true) = 0;

  /** Return the options object this model was configured with. */
  virtual Ptr<Options> getOptions() = 0;
};

/**
 * Concrete encoder-classifier model: a list of encoders whose states are fed
 * to a list of classifiers.
 *
 * Encoders and classifiers are registered with push_back(). apply() runs every
 * encoder on a batch and hands the collected encoder states to every
 * classifier; build() wraps the first classifier's output in a Logits object.
 */
class EncoderClassifier : public EncoderClassifierBase {
protected:
  Ptr<Options> options_;

  // Value of the "prefix" option (empty string when the option is absent).
  std::string prefix_;

  std::vector<Ptr<EncoderBase>> encoders_;
  std::vector<Ptr<ClassifierBase>> classifiers_;

  // Value of the "inference" option (false when the option is absent).
  bool inference_{false};

  // Option keys that getModelParameters() copies into the YAML saved with the model.
  std::set<std::string> modelFeatures_;

  /**
   * Collect the model-defining options as a YAML node.
   *
   * Copies every key listed in modelFeatures_ from a clone of the options,
   * replaces "type" by "original-type" when that option is present, and adds
   * a "version" entry taken from buildVersion().
   */
  Config::YamlNode getModelParameters() {
    Config::YamlNode modelParams;
    auto clone = options_->cloneToYamlNode();
    for(const auto& key : modelFeatures_)
      modelParams[key] = clone[key];

    // A present "original-type" option takes precedence over the copied "type".
    if(options_->has("original-type"))
      modelParams["type"] = clone["original-type"];

    modelParams["version"] = buildVersion();
    return modelParams;
  }

  /** Serialise getModelParameters() to a YAML string. */
  std::string getModelParametersAsString() {
    auto yaml = getModelParameters();
    YAML::Emitter out;
    cli::OutputYaml(yaml, out);
    return std::string(out.c_str());
  }

public:
  using dataset_type = data::Corpus;

  // @TODO: lots of code-duplication with EncoderDecoder
  /**
   * Create an empty model (no encoders, no classifiers) configured by `options`.
   *
   * Reads the "prefix" and "inference" options and fills modelFeatures_ with
   * the option keys that are stored alongside the model when it is saved.
   */
  EncoderClassifier(Ptr<Options> options)
    : options_(options),
      prefix_(options->get<std::string>("prefix", "")),
      inference_(options->get<bool>("inference", false)) {
    modelFeatures_ = {// general architecture
                      "type",
                      "dim-vocabs",
                      "dim-emb",
                      "dim-rnn",
                      "enc-cell",
                      "enc-type",
                      "enc-cell-depth",
                      "enc-depth",
                      "dec-depth",
                      "dec-cell",
                      "dec-cell-base-depth",
                      "dec-cell-high-depth",
                      "skip",
                      "layer-normalization",
                      "right-left",
                      "input-types",
                      "special-vocab",
                      "tied-embeddings",
                      "tied-embeddings-src",
                      "tied-embeddings-all",
                      // transformer
                      "transformer-heads",
                      "transformer-no-projection",
                      "transformer-dim-ffn",
                      "transformer-ffn-depth",
                      "transformer-ffn-activation",
                      "transformer-dim-aan",
                      "transformer-aan-depth",
                      "transformer-aan-activation",
                      "transformer-aan-nogate",
                      "transformer-preprocess",
                      "transformer-postprocess",
                      "transformer-postprocess-emb",
                      "transformer-postprocess-top",
                      "transformer-decoder-autoreg",
                      "transformer-tied-layers",
                      "transformer-guided-alignment-layer",
                      "transformer-train-position-embeddings",
                      // BERT
                      "bert-train-type-embeddings",
                      "bert-type-vocab-size",
                      // ULR, lemma and factor embeddings
                      "ulr",
                      "ulr-trainable-transformation",
                      "ulr-dim-emb",
                      "lemma-dim-emb",
                      "lemma-dependency",
                      "factors-combine",
                      "factors-dim-emb"};
  }

  /** Return the options object passed to the constructor. */
  virtual Ptr<Options> getOptions() override { return options_; }

  /** Mutable access to the registered encoders. */
  std::vector<Ptr<EncoderBase>>& getEncoders() { return encoders_; }
  /** Mutable access to the registered classifiers. */
  std::vector<Ptr<ClassifierBase>>& getClassifiers() { return classifiers_; }

  /** Append an encoder; encoders are run in insertion order by apply(). */
  void push_back(Ptr<EncoderBase> encoder) { encoders_.push_back(encoder); }
  /** Append a classifier; classifiers are run in insertion order by apply(). */
  void push_back(Ptr<ClassifierBase> classifier) { classifiers_.push_back(classifier); }

  /**
   * Load the model called `name` into `graph`.
   *
   * The flag forwarded to graph->load() is `markedReloaded` unless the
   * "ignore-model-config" option is set, in which case it is forced to false.
   */
  void load(Ptr<ExpressionGraph> graph,
            const std::string& name,
            bool markedReloaded) override {
    graph->load(name, markedReloaded && !opt<bool>("ignore-model-config", false));
  }

  /**
   * Load the model from the memory region at `ptr` into `graph`.
   *
   * The flag is combined with "ignore-model-config" exactly as in load().
   */
  void mmap(Ptr<ExpressionGraph> graph,
            const void* ptr,
            bool markedReloaded) override {
    graph->mmap(ptr, markedReloaded && !opt<bool>("ignore-model-config", false));
  }

  /**
   * Save `graph` under `name` together with the YAML-serialised model
   * parameters. The boolean argument is ignored by this implementation.
   */
  void save(Ptr<ExpressionGraph> graph,
            const std::string& name,
            bool /*saveModelConfig*/) override {
    LOG(info, "Saving model weights and runtime parameters to {}", name);
    graph->save(name, getModelParametersAsString());
  }

  /** Clear `graph`, then clear every registered encoder and classifier. */
  void clear(Ptr<ExpressionGraph> graph) override {
    graph->clear();

    for(auto& enc : encoders_)
      enc->clear();
    for(auto& cls : classifiers_)
      cls->clear();
  }

  /** Read option `key` as type T. */
  template <typename T>
  T opt(const std::string& key) {
    return options_->get<T>(key);
  }

  /** Read option `key` as type T, passing `def` as the fallback value. */
  template <typename T>
  T opt(const std::string& key, const T& def) {
    return options_->get<T>(key, def);
  }

  /** Set option `key` to `value` on the shared options object. */
  template <typename T>
  void set(std::string key, T value) {
    options_->set(key, value);
  }

  /*********************************************************************/

  /**
   * Run all encoders and then all classifiers on `batch`.
   *
   * @param graph       expression graph the model is built into
   * @param batch       input batch given to every encoder and classifier
   * @param clearGraph  when true, clear() is called on `graph` first
   * @return one classifier state per registered classifier, in registration
   *         order (empty when no classifier has been registered)
   */
  virtual std::vector<Ptr<ClassifierState>> apply(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch, bool clearGraph) override {
    if(clearGraph)
      clear(graph);

    std::vector<Ptr<EncoderState>> encoderStates;
    encoderStates.reserve(encoders_.size());
    for(auto& encoder : encoders_)
      encoderStates.push_back(encoder->build(graph, batch));

    // Every classifier receives the states of all encoders.
    std::vector<Ptr<ClassifierState>> classifierStates;
    classifierStates.reserve(classifiers_.size());
    for(auto& classifier : classifiers_)
      classifierStates.push_back(classifier->apply(graph, batch, encoderStates));

    return classifierStates;
  }

  /**
   * Build the model for `batch` and return the output of the first classifier.
   *
   * Only the first classifier state is used; any further states are dropped.
   * NOTE(review): the original comment calls the result "raw logits" although
   * the accessor is named getLogProbs() - confirm which one it is.
   *
   * @throws std::out_of_range if no classifier has been registered
   */
  virtual Logits build(Ptr<ExpressionGraph> graph,
                       Ptr<data::CorpusBatch> batch,
                       bool clearGraph = true) override {
    auto states = apply(graph, batch, clearGraph);
    // at() rather than operator[]: with no classifier registered the vector is
    // empty and operator[] would be undefined behaviour.
    return Logits(states.at(0)->getLogProbs());
  }

  /**
   * IModel entry point: forwards to the CorpusBatch overload.
   *
   * NOTE(review): the downcast is an unchecked static_pointer_cast, so `batch`
   * is assumed to actually be a data::CorpusBatch - confirm against callers.
   */
  virtual Logits build(Ptr<ExpressionGraph> graph,
                       Ptr<data::Batch> batch,
                       bool clearGraph = true) override {
    auto corpusBatch = std::static_pointer_cast<data::CorpusBatch>(batch);
    return build(graph, corpusBatch, clearGraph);
  }
};

}  // namespace marian