Program Listing for File model_factory.h

Return to documentation for file (src/models/model_factory.h)

#pragma once

#include "marian.h"

#include "layers/factory.h"
#include "models/encoder_decoder.h"
#include "models/encoder_classifier.h"
#include "models/encoder_pooler.h"

namespace marian {
namespace models {

/// Factory for a single encoder. construct() is defined out of line and
/// returns the encoder built on the supplied expression graph.
class EncoderFactory : public Factory {
public:
  // Re-export every Factory constructor. Inherited constructors keep the
  // access they have in the base class, so listing this under `public:`
  // only makes the existing behaviour explicit.
  using Factory::Factory;

  /// Builds the encoder on `graph` (implementation lives in the .cpp file).
  virtual Ptr<EncoderBase> construct(Ptr<ExpressionGraph> graph);
};

/// Accumulator wrapper around EncoderFactory (Accumulator is presumably the
/// chainable-configuration helper from layers/factory.h — confirm there).
using encoder = Accumulator<EncoderFactory>;

/// Factory for a single decoder. construct() is defined out of line and
/// returns the decoder built on the supplied expression graph.
class DecoderFactory : public Factory {
public:
  // Re-export every Factory constructor (access follows the base class).
  using Factory::Factory;

  /// Builds the decoder on `graph` (implementation lives in the .cpp file).
  virtual Ptr<DecoderBase> construct(Ptr<ExpressionGraph> graph);
};

/// Accumulator wrapper around DecoderFactory, used as the element type that
/// EncoderDecoderFactory::push_back accepts for decoders.
using decoder = Accumulator<DecoderFactory>;

/// Factory for a single classifier head. construct() is defined out of line
/// and returns the classifier built on the supplied expression graph.
class ClassifierFactory : public Factory {
public:
  // Re-export every Factory constructor (access follows the base class).
  using Factory::Factory;

  /// Builds the classifier on `graph` (implementation lives in the .cpp file).
  virtual Ptr<ClassifierBase> construct(Ptr<ExpressionGraph> graph);
};

/// Accumulator wrapper around ClassifierFactory, used as the element type
/// that EncoderClassifierFactory::push_back accepts for classifiers.
using classifier = Accumulator<ClassifierFactory>;

/// Factory for a single pooler. construct() is defined out of line and
/// returns the pooler built on the supplied expression graph.
class PoolerFactory : public Factory {
public:
  // Re-export every Factory constructor (access follows the base class).
  using Factory::Factory;

  /// Builds the pooler on `graph` (implementation lives in the .cpp file).
  virtual Ptr<PoolerBase> construct(Ptr<ExpressionGraph> graph);
};

/// Accumulator wrapper around PoolerFactory, used as the element type that
/// EncoderPoolerFactory::push_back accepts for poolers.
using pooler = Accumulator<PoolerFactory>;

/// Factory that collects encoder and decoder configurations and assembles
/// them into a complete model in construct() (defined out of line).
class EncoderDecoderFactory : public Factory {
public:
  // Re-export every Factory constructor (access follows the base class).
  using Factory::Factory;

  /// Appends an encoder configuration, then returns an
  /// Accumulator<EncoderDecoderFactory> constructed from *this so that
  /// further push_back calls can be chained.
  Accumulator<EncoderDecoderFactory> push_back(encoder encoderCfg) {
    encoders_.push_back(encoderCfg);
    return Accumulator<EncoderDecoderFactory>(*this);
  }

  /// Appends a decoder configuration; chaining works as for encoders.
  Accumulator<EncoderDecoderFactory> push_back(decoder decoderCfg) {
    decoders_.push_back(decoderCfg);
    return Accumulator<EncoderDecoderFactory>(*this);
  }

  /// Builds the full model on `graph` from the collected configurations.
  virtual Ptr<IModel> construct(Ptr<ExpressionGraph> graph);

private:
  std::vector<encoder> encoders_;  // encoder configurations, in push order
  std::vector<decoder> decoders_;  // decoder configurations, in push order
};

/// Accumulator wrapper around EncoderDecoderFactory.
using encoder_decoder = Accumulator<EncoderDecoderFactory>;

/// Factory that collects encoder and classifier configurations and assembles
/// them into a complete model in construct() (defined out of line).
class EncoderClassifierFactory : public Factory {
public:
  // Re-export every Factory constructor (access follows the base class).
  using Factory::Factory;

  /// Appends an encoder configuration, then returns an
  /// Accumulator<EncoderClassifierFactory> constructed from *this so that
  /// further push_back calls can be chained.
  Accumulator<EncoderClassifierFactory> push_back(encoder encoderCfg) {
    encoders_.push_back(encoderCfg);
    return Accumulator<EncoderClassifierFactory>(*this);
  }

  /// Appends a classifier configuration; chaining works as for encoders.
  Accumulator<EncoderClassifierFactory> push_back(classifier classifierCfg) {
    classifiers_.push_back(classifierCfg);
    return Accumulator<EncoderClassifierFactory>(*this);
  }

  /// Builds the full model on `graph` from the collected configurations.
  virtual Ptr<IModel> construct(Ptr<ExpressionGraph> graph);

private:
  std::vector<encoder> encoders_;        // encoder configurations, in push order
  std::vector<classifier> classifiers_;  // classifier configurations, in push order
};

/// Accumulator wrapper around EncoderClassifierFactory.
using encoder_classifier = Accumulator<EncoderClassifierFactory>;

/// Factory that collects encoder and pooler configurations and assembles
/// them into a complete model in construct() (defined out of line).
class EncoderPoolerFactory : public Factory {
  // Re-export every Factory constructor. Inherited constructors keep the
  // access they have in the base class, regardless of this section.
  using Factory::Factory;
private:
  std::vector<encoder> encoders_;  // encoder configurations, in push order
  std::vector<pooler> poolers_;    // pooler configurations, in push order

public:
  /// Appends an encoder configuration, then returns an
  /// Accumulator<EncoderPoolerFactory> constructed from *this so that
  /// further push_back calls can be chained.
  Accumulator<EncoderPoolerFactory> push_back(encoder enc) {
    encoders_.push_back(enc);
    return Accumulator<EncoderPoolerFactory>(*this);
  }

  /// Appends a pooler configuration; chaining works as for encoders.
  // The parameter was previously named `cls` (copy-paste from the classifier
  // factory); it is a pooler, so it is named accordingly.
  Accumulator<EncoderPoolerFactory> push_back(pooler pool) {
    poolers_.push_back(pool);
    return Accumulator<EncoderPoolerFactory>(*this);
  }

  /// Builds the full model on `graph` from the collected configurations.
  virtual Ptr<IModel> construct(Ptr<ExpressionGraph> graph);
};

/// Accumulator wrapper around EncoderPoolerFactory.
typedef Accumulator<EncoderPoolerFactory> encoder_pooler;

/// Creates the base model selected by `type` for the given `use`, configured
/// by `options`. Defined out of line in the corresponding .cpp file.
Ptr<IModel> createBaseModelByType(std::string type,
                                  usage use,
                                  Ptr<Options> options);

/// Creates a model for the given `use` from `options` (presumably reading the
/// model type from the options — confirm in the .cpp definition).
Ptr<IModel> createModelFromOptions(Ptr<Options> options, usage use);

/// Creates a criterion function for the given `use` from `options`.
/// Defined out of line in the corresponding .cpp file.
Ptr<ICriterionFunction> createCriterionFunctionFromOptions(Ptr<Options> options,
                                                           usage use);
}  // namespace models
}  // namespace marian