Program Listing for File constructors.h
↰ Return to documentation for file (src/rnn/constructors.h)
#pragma once
#include "layers/factory.h"
#include "marian.h"
#include "rnn/rnn.h"
namespace marian {
namespace rnn {
// StackableFactory is a plain alias of Factory: anything that can be placed
// into a stacked cell is referred to through this name.
using StackableFactory = Factory;
//struct StackableFactory : public Factory {
// using Factory::Factory;
//
// virtual ~StackableFactory() {}
//
// template <typename Cast>
// inline Ptr<Cast> as() {
// return std::dynamic_pointer_cast<Cast>(shared_from_this());
// }
//
// template <typename Cast>
// inline bool is() {
// return as<Cast>() != nullptr;
// }
//};
/**
 * Abstract stackable factory for cell inputs.
 *
 * Implementations override construct() to produce a CellInput for the given
 * expression graph.
 */
struct InputFactory : StackableFactory {
  /**
   * @param graph expression graph handed to the created input
   * @return the newly constructed CellInput
   */
  virtual Ptr<CellInput> construct(Ptr<ExpressionGraph> graph) = 0;
};
/**
 * Factory for a single RNN cell.
 *
 * The concrete cell class is selected at construct() time from the string
 * option "type". Inputs registered through add_input() are stored as
 * callables and passed to every constructed cell via setLazyInputs().
 */
class CellFactory : public StackableFactory {
protected:
  // Lazy inputs: each entry maps an RNN to an expression; handed to the
  // constructed cell through setLazyInputs().
  std::vector<std::function<Expr(Ptr<rnn::RNN>)>> inputs_;

private:
  // Builds a cell of class CellType from this factory's options and attaches
  // the registered lazy inputs. Shared by every branch of construct().
  template <class CellType>
  Ptr<Cell> makeCell(Ptr<ExpressionGraph> graph) {
    auto cell = New<CellType>(graph, options_);
    cell->setLazyInputs(inputs_);
    return cell;
  }

public:
  /**
   * Creates the cell named by the "type" option.
   *
   * Recognized values: "gru", "gru-nematus", "lstm", "mlstm", "mgru",
   * "tanh", "relu", "sru", "ssru". Any other value aborts, reporting the
   * rejected type string.
   *
   * @param graph expression graph passed to the cell's constructor
   * @return the constructed cell with this factory's lazy inputs attached
   */
  virtual Ptr<Cell> construct(Ptr<ExpressionGraph> graph) {
    const std::string type = options_->get<std::string>("type");

    if(type == "gru")
      return makeCell<GRU>(graph);
    if(type == "gru-nematus")
      return makeCell<GRUNematus>(graph);
    if(type == "lstm")
      return makeCell<LSTM>(graph);
    if(type == "mlstm")
      return makeCell<MLSTM>(graph);
    if(type == "mgru")
      return makeCell<MGRU>(graph);
    if(type == "tanh")
      return makeCell<Tanh>(graph);
    if(type == "relu")
      return makeCell<ReLU>(graph);
    if(type == "sru")
      return makeCell<SRU>(graph);
    if(type == "ssru")
      return makeCell<SSRU>(graph);

    // Include the offending value so a misconfigured "type" is easy to spot.
    ABORT("Unknown RNN cell type '{}'", type);
  }

  /**
   * Returns a by-value copy: a fresh CellFactory with this factory's options
   * merged in and the registered inputs copied.
   *
   * NOTE(review): the result is always a plain CellFactory, so calling this
   * on a derived factory keeps only the CellFactory part - confirm callers
   * rely on that.
   */
  CellFactory clone() {
    CellFactory aClone;
    aClone.options_->merge(options_);
    aClone.inputs_ = inputs_;
    return aClone;
  }

  /** Registers a lazy input that is evaluated from the RNN it is given. */
  virtual void add_input(std::function<Expr(Ptr<rnn::RNN>)> func) {
    inputs_.push_back(func);
  }

  /** Registers a fixed expression as an input; the RNN argument is ignored. */
  virtual void add_input(Expr input) {
    inputs_.push_back([input](Ptr<rnn::RNN> /*rnn*/) { return input; });
  }
};
// Short name for a CellFactory wrapped in Accumulator.
using cell = Accumulator<CellFactory>;
/**
 * Cell factory that assembles a StackedCell out of an ordered list of
 * sub-factories. Each sub-factory is either a CellFactory or is treated as
 * an InputFactory.
 */
class StackedCellFactory : public CellFactory {
protected:
  // Sub-factories in the order they were added through push_back().
  std::vector<Ptr<StackableFactory>> stackableFactories_;

public:
  /**
   * Builds the stacked cell.
   *
   * A running input dimension starts at this factory's "dimInput" option.
   * Every input sub-factory adds its constructed input's dimOutput() to it;
   * every cell sub-factory receives the current value as its "dimInput" and
   * resets it to zero. All sub-factories get this factory's options merged
   * in first. If the sub-factory at position 0 is a cell factory, it also
   * receives the lazy inputs registered on this factory.
   *
   * NOTE(review): the lazy inputs are added to the stored first sub-factory
   * on every call, so calling construct() more than once appears to register
   * them repeatedly - confirm construct() is only called once per factory.
   *
   * @param graph expression graph passed on to every sub-factory
   * @return the assembled StackedCell
   */
  Ptr<Cell> construct(Ptr<ExpressionGraph> graph) override {
    auto stack = New<StackedCell>(graph, options_);

    int pendingDimInput = options_->get<int>("dimInput");

    size_t position = 0;
    for(auto member : stackableFactories_) {
      const bool leading = (position == 0);
      ++position;

      // Anything that is not a cell factory is handled as an input factory.
      if(!member->is<CellFactory>()) {
        auto asInput = member->as<InputFactory>();
        asInput->mergeOpts(options_);
        auto builtInput = asInput->construct(graph);
        stack->push_back(builtInput);
        pendingDimInput += builtInput->dimOutput();
        continue;
      }

      auto asCell = member->as<CellFactory>();
      asCell->mergeOpts(options_);
      member->setOpt("dimInput", pendingDimInput);
      pendingDimInput = 0;  // consumed by this cell

      if(leading) {
        for(auto lazyInput : inputs_)
          asCell->add_input(lazyInput);
      }

      stack->push_back(asCell->construct(graph));
    }

    return stack;
  }

  /**
   * Appends a copy of the given sub-factory.
   *
   * @return an Accumulator built from this factory, to allow call chaining
   */
  template <class F>
  Accumulator<StackedCellFactory> push_back(const F& f) {
    stackableFactories_.push_back(New<F>(f));
    return Accumulator<StackedCellFactory>(*this);
  }
};
// Short name for a StackedCellFactory wrapped in Accumulator.
using stacked_cell = Accumulator<StackedCellFactory>;
/**
 * Factory for a multi-layer RNN; each layer is produced by one CellFactory.
 */
class RNNFactory : public Factory {
  using Factory::Factory;

protected:
  // One cell factory per layer, in the order added through push_back().
  std::vector<Ptr<CellFactory>> layerFactories_;

public:
  /**
   * Builds the RNN layer by layer.
   *
   * For every layer this factory's options are merged into the layer
   * factory. Layers above the first get "dimInput" set to the previous
   * layer's "dimState" plus their own "dimInputExtra" (default 0). When this
   * factory's "direction" option (default: forward) is one of the
   * alternating modes, each layer's "direction" is overwritten according to
   * the parity of its index.
   *
   * @param graph expression graph passed to the RNN and to every layer
   * @return the constructed RNN
   */
  Ptr<RNN> construct(Ptr<ExpressionGraph> graph) {
    auto network = New<RNN>(graph, options_);

    for(size_t layer = 0; layer < layerFactories_.size(); ++layer) {
      auto current = layerFactories_[layer];
      current->mergeOpts(options_);

      if(layer != 0) {
        const int belowState = layerFactories_[layer - 1]->opt<int>("dimState");
        const int extraInput = current->opt<int>("dimInputExtra", 0);
        current->setOpt("dimInput", belowState + extraInput);
      }

      const auto mode = static_cast<rnn::dir>(
          opt<int>("direction", static_cast<int>(rnn::dir::forward)));
      const bool evenLayer = (layer % 2 == 0);

      if(mode == rnn::dir::alternating_forward) {
        // even layers run forward, odd layers backward
        current->setOpt("direction",
                        static_cast<int>(evenLayer ? rnn::dir::forward
                                                   : rnn::dir::backward));
      } else if(mode == rnn::dir::alternating_backward) {
        // even layers run backward, odd layers forward
        current->setOpt("direction",
                        static_cast<int>(evenLayer ? rnn::dir::backward
                                                   : rnn::dir::forward));
      }

      network->push_back(current->construct(graph));
    }

    return network;
  }

  /**
   * Appends a copy of the given layer factory.
   *
   * @return an Accumulator built from this factory, to allow call chaining
   */
  template <class F>
  Accumulator<RNNFactory> push_back(const F& f) {
    layerFactories_.push_back(New<F>(f));
    return Accumulator<RNNFactory>(*this);
  }

  /**
   * Returns a by-value copy with this factory's options merged in and every
   * layer factory cloned.
   *
   * NOTE(review): layers are copied through CellFactory::clone(), which
   * returns a plain CellFactory by value; a layer whose dynamic type is a
   * subclass would keep only its CellFactory part - confirm this is intended.
   */
  RNNFactory clone() {
    RNNFactory copy;
    copy.options_->merge(options_);
    for(auto layer : layerFactories_)
      copy.push_back(layer->clone());
    return copy;
  }
};
// Short name for an RNNFactory wrapped in Accumulator.
using rnn = Accumulator<RNNFactory>;
} // namespace rnn
} // namespace marian