.. _program_listing_file_src_rnn_constructors.h:

Program Listing for File constructors.h
=======================================

|exhale_lsh| :ref:`Return to documentation for file <file_src_rnn_constructors.h>` (``src/rnn/constructors.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   #pragma once

   #include "layers/factory.h"
   #include "marian.h"
   #include "rnn/rnn.h"

   namespace marian {
   namespace rnn {

   // Factories that may be placed inside a StackedCellFactory (cell factories
   // and input factories). Currently a plain alias for Factory; the former
   // dedicated struct is kept below for reference.
   typedef Factory StackableFactory;
   //struct StackableFactory : public Factory {
   //  using Factory::Factory;
   //
   //  virtual ~StackableFactory() {}
   //
   //  template <typename Cast>
   //  inline Ptr<Cast> as() {
   //    return std::dynamic_pointer_cast<Cast>(shared_from_this());
   //  }
   //
   //  template <typename Cast>
   //  inline bool is() {
   //    return as<Cast>() != nullptr;
   //  }
   //};

   /**
    * Abstract factory for inputs that can be interleaved with cells in a
    * StackedCellFactory. The width reported by the constructed input's
    * dimOutput() is added to the "dimInput" of the next cell in the stack.
    */
   struct InputFactory : public StackableFactory {
     virtual Ptr<CellInput> construct(Ptr<ExpressionGraph> graph) = 0;
   };

   /**
    * Factory for a single RNN cell. The concrete cell is chosen by the "type"
    * option; inputs registered through add_input() are handed to the new cell
    * via setLazyInputs().
    */
   class CellFactory : public StackableFactory {
   protected:
     // Deferred input providers. Each receives a Ptr<rnn::RNN> (presumably the
     // network that owns the cell -- confirm in rnn/rnn.h) and yields an Expr.
     std::vector<std::function<Expr(Ptr<rnn::RNN>)>> inputs_;

     // Creates a cell of type CellType from this factory's options and attaches
     // the registered lazy inputs. Shared by every branch of construct().
     template <class CellType>
     Ptr<Cell> makeCell(Ptr<ExpressionGraph> graph) {
       auto cell = New<CellType>(graph, options_);
       cell->setLazyInputs(inputs_);
       return cell;
     }

   public:
     /**
      * Constructs the cell selected by the "type" option.
      *
      * @param graph expression graph passed to the cell constructor
      * @return the new cell with all registered inputs attached
      * Aborts via ABORT for any "type" not listed below.
      */
     virtual Ptr<Cell> construct(Ptr<ExpressionGraph> graph) {
       std::string type = options_->get<std::string>("type");
       if(type == "gru")
         return makeCell<GRU>(graph);
       if(type == "gru-nematus")
         return makeCell<GRUNematus>(graph);
       if(type == "lstm")
         return makeCell<LSTM>(graph);
       if(type == "mlstm")
         return makeCell<MLSTM>(graph);
       if(type == "mgru")
         return makeCell<MGRU>(graph);
       if(type == "tanh")
         return makeCell<Tanh>(graph);
       if(type == "relu")
         return makeCell<ReLU>(graph);
       if(type == "sru")
         return makeCell<SRU>(graph);
       if(type == "ssru")
         return makeCell<SSRU>(graph);
       ABORT("Unknown RNN cell type");
     }

     /** Returns a new CellFactory with this factory's options merged in and the same inputs. */
     CellFactory clone() {
       CellFactory aClone;
       aClone.options_->merge(options_);
       aClone.inputs_ = inputs_;
       return aClone;
     }

     /** Registers an input computed lazily from the RNN passed at evaluation time. */
     virtual void add_input(std::function<Expr(Ptr<rnn::RNN>)> func) {
       inputs_.push_back(std::move(func));
     }

     /** Registers a fixed expression as an input; the RNN argument is ignored. */
     virtual void add_input(Expr input) {
       inputs_.push_back([input](Ptr<rnn::RNN> /*rnn*/) { return input; });
     }
   };

   typedef Accumulator<CellFactory> cell;

   /**
    * Factory for a StackedCell: an ordered sequence of cell factories,
    * optionally interleaved with input factories, built into one cell.
    */
   class StackedCellFactory : public CellFactory {
   protected:
     std::vector<Ptr<StackableFactory>> stackableFactories_;

   public:
     /**
      * Constructs the stacked cell.
      *
      * Each stacked cell's "dimInput" is the width accumulated since the
      * previous cell. The running width starts at this factory's "dimInput"
      * option, every InputFactory adds its dimOutput(), and the counter resets
      * to 0 after each cell. Inputs registered on this factory via add_input()
      * are forwarded only when the first stacked element is a cell factory.
      * Note that this mutates the stored child factories (mergeOpts/setOpt).
      */
     Ptr<Cell> construct(Ptr<ExpressionGraph> graph) override {
       auto stacked = New<StackedCell>(graph, options_);

       int lastDimInput = options_->get<int>("dimInput");

       for(size_t i = 0; i < stackableFactories_.size(); ++i) {
         const auto& sf = stackableFactories_[i];

         if(sf->is<CellFactory>()) {
           auto cellFactory = sf->as<CellFactory>();
           cellFactory->mergeOpts(options_);

           // Set after mergeOpts so this per-cell width is the final value.
           sf->setOpt("dimInput", lastDimInput);
           lastDimInput = 0;

           if(i == 0)
             for(const auto& f : inputs_)
               cellFactory->add_input(f);

           stacked->push_back(cellFactory->construct(graph));
         } else {
           auto inputFactory = sf->as<InputFactory>();
           inputFactory->mergeOpts(options_);
           auto input = inputFactory->construct(graph);
           stacked->push_back(input);
           lastDimInput += input->dimOutput();
         }
       }
       return stacked;
     }

     /** Appends a heap copy of @p f to the stack and returns an Accumulator built from *this. */
     template <class F>
     Accumulator<StackedCellFactory> push_back(const F& f) {
       stackableFactories_.push_back(New<F>(f));
       return Accumulator<StackedCellFactory>(*this);
     }
   };

   typedef Accumulator<StackedCellFactory> stacked_cell;

   /** Factory for a multi-layer RNN with one cell factory per layer. */
   class RNNFactory : public Factory {
     using Factory::Factory;

   protected:
     std::vector<Ptr<CellFactory>> layerFactories_;

   public:
     /**
      * Constructs the RNN, adding one layer per pushed cell factory.
      *
      * Every layer factory first merges this factory's options. Each layer
      * after the first gets "dimInput" = previous layer's "dimState" plus its
      * own "dimInputExtra" (default 0). When this factory's "direction" is
      * alternating_forward, layers 0, 2, ... run forward and layers 1, 3, ...
      * run backward. alternating_backward is the mirror image.
      */
     Ptr<RNN> construct(Ptr<ExpressionGraph> graph) {
       auto network = New<RNN>(graph, options_);
       for(size_t i = 0; i < layerFactories_.size(); ++i) {
         const auto& lf = layerFactories_[i];

         lf->mergeOpts(options_);
         if(i > 0) {
           int dimInput = layerFactories_[i - 1]->opt<int>("dimState")
                          + lf->opt<int>("dimInputExtra", 0);
           lf->setOpt("dimInput", dimInput);
         }

         const auto direction = static_cast<rnn::dir>(
             opt<int>("direction", static_cast<int>(rnn::dir::forward)));
         const bool evenLayer = (i % 2 == 0);
         if(direction == rnn::dir::alternating_forward) {
           lf->setOpt("direction",
                      static_cast<int>(evenLayer ? rnn::dir::forward : rnn::dir::backward));
         } else if(direction == rnn::dir::alternating_backward) {
           lf->setOpt("direction",
                      static_cast<int>(evenLayer ? rnn::dir::backward : rnn::dir::forward));
         }

         network->push_back(lf->construct(graph));
       }
       return network;
     }

     /** Appends a heap copy of layer factory @p f and returns an Accumulator built from *this. */
     template <class F>
     Accumulator<RNNFactory> push_back(const F& f) {
       layerFactories_.push_back(New<F>(f));
       return Accumulator<RNNFactory>(*this);
     }

     /**
      * Returns a new RNNFactory with this factory's options merged in and a
      * clone of every layer factory.
      *
      * NOTE(review): CellFactory::clone() is not virtual and returns a
      * CellFactory by value. A StackedCellFactory layer therefore appears to
      * be cloned as a plain CellFactory without its stacked factories. Confirm
      * this before cloning RNNs that contain stacked cells.
      */
     RNNFactory clone() {
       RNNFactory aClone;
       aClone.options_->merge(options_);
       for(const auto& lf : layerFactories_)
         aClone.push_back(lf->clone());
       return aClone;
     }
   };

   typedef Accumulator<RNNFactory> rnn;

   } // namespace rnn
   } // namespace marian