.. _program_listing_file_src_rnn_rnn.h:

Program Listing for File rnn.h
==============================

|exhale_lsh| :ref:`Return to documentation for file <file_src_rnn_rnn.h>` (``src/rnn/rnn.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   #pragma once

   #include "layers/generic.h"
   #include "marian.h"
   #include "rnn/cells.h"
   #include "rnn/types.h"

   #include <algorithm>
   #include <chrono>
   #include <cstdio>
   #include <cstdlib>
   #include <string>

   namespace marian {
   namespace rnn {

   enum struct dir : int {
     forward,
     backward,
     alternating_forward,
     alternating_backward
   };
   }
   }  // namespace marian

   YAML_REGISTER_TYPE(marian::rnn::dir, int)

   namespace marian {
   namespace rnn {

   class BaseRNN {
   protected:
     Ptr<ExpressionGraph> graph_;
     Ptr<Options> options_;

   public:
     BaseRNN(Ptr<ExpressionGraph> graph, Ptr<Options> options)
         : graph_(graph), options_(options) {}

     virtual ~BaseRNN() {}

     virtual Expr transduce(Expr, Expr = nullptr) = 0;
     virtual Expr transduce(Expr, State, Expr = nullptr) = 0;
     virtual Expr transduce(Expr, States, Expr = nullptr) = 0;
     virtual States lastCellStates() = 0;
     virtual void push_back(Ptr<Cell>) = 0;
     virtual Ptr<Cell> at(int i) = 0;

     virtual Ptr<Options> getOptions() { return options_; }
   };

   class RNN;

   class SingleLayerRNN : public BaseRNN {
   private:
     Ptr<Cell> cell_;
     dir direction_;
     States last_;

     States apply(const Expr input,
                  const States initialState,
                  const Expr mask = nullptr) {
       last_.clear();

       State state = initialState.front();

       cell_->clear();

       auto xWs = cell_->applyInput({input});

       auto timeSteps = input->shape()[-3];

       States outputs;
       for(int i = 0; i < timeSteps; ++i) {
         int j = i;

         if(direction_ == dir::backward)
           j = timeSteps - i - 1;

         std::vector<Expr> steps(xWs.size());
         std::transform(xWs.begin(), xWs.end(), steps.begin(), [j](Expr e) {
           return slice(e, -3, j);
         });

         if(mask)
           state = cell_->applyState(steps, state, slice(mask, -3, j));
         else
           state = cell_->applyState(steps, state);

         outputs.push_back(state);
       }

       if(direction_ == dir::backward)
         outputs.reverse();

       last_.push_back(outputs.back());

       return outputs;
     }

     States apply(const Expr input, const Expr mask = nullptr) {
       auto graph = input->graph();

       int dimBatch = input->shape()[-2];
       int dimState = cell_->getOptions()->get<int>("dimState");

       auto output = graph->zeros({1, dimBatch, dimState});
       Expr cell = output;
       State startState{output, cell};

       return apply(input, States({startState}), mask);
     }

     SingleLayerRNN(Ptr<ExpressionGraph> graph, Ptr<Options> options)
         : BaseRNN(graph, options),
           direction_((dir)options->get<int>("direction", (int)dir::forward)) {}

   public:
     friend RNN;

     virtual ~SingleLayerRNN() {}

     // @TODO: benchmark whether this concatenation is a good idea
     virtual Expr transduce(Expr input, Expr mask = nullptr) override {
       return apply(input, mask).outputs();
     }

     virtual Expr transduce(Expr input,
                            States states,
                            Expr mask = nullptr) override {
       return apply(input, states, mask).outputs();
     }

     virtual Expr transduce(Expr input,
                            State state,
                            Expr mask = nullptr) override {
       return apply(input, States({state}), mask).outputs();
     }

     States lastCellStates() override { return last_; }

     void push_back(Ptr<Cell> cell) override { cell_ = cell; }

     virtual Ptr<Cell> at(int i) override {
       ABORT_IF(i > 0, "SingleRNN only has one cell");
       return cell_;
     }
   };

   class RNN : public BaseRNN, public std::enable_shared_from_this<RNN> {
   private:
     bool skip_;
     bool skipFirst_;
     std::vector<Ptr<SingleLayerRNN>> rnns_;

   public:
     RNN(Ptr<ExpressionGraph> graph, Ptr<Options> options)
         : BaseRNN(graph, options),
           skip_(options->get<bool>("skip", false)),
           skipFirst_(options->get<bool>("skipFirst", false)) {}

     void push_back(Ptr<Cell> cell) override {
       auto rnn = Ptr<SingleLayerRNN>(new SingleLayerRNN(graph_, cell->getOptions()));
       rnn->push_back(cell);
       rnns_.push_back(rnn);
     }

     Expr transduce(Expr input, Expr mask = nullptr) override {
       ABORT_IF(rnns_.empty(), "0 layers in RNN");

       Expr output;
       Expr layerInput = input;
       for(size_t i = 0; i < rnns_.size(); ++i) {
         auto lazyInput = layerInput;

         auto cell = rnns_[i]->at(0);
         auto lazyInputs = cell->getLazyInputs(shared_from_this());
         if(!lazyInputs.empty()) {
           lazyInputs.push_back(layerInput);
           lazyInput = concatenate(lazyInputs, /*axis =*/ -1);
         }

         auto layerOutput = rnns_[i]->transduce(lazyInput, mask);

         if(skip_ && (skipFirst_ || i > 0))
           output = layerOutput + layerInput;
         else
           output = layerOutput;

         layerInput = output;
       }
       return output;
     }

     Expr transduce(Expr input, States states, Expr mask = nullptr) override {
       ABORT_IF(rnns_.empty(), "0 layers in RNN");

       Expr output;
       Expr layerInput = input;
       for(size_t i = 0; i < rnns_.size(); ++i) {
         Expr lazyInput;
         auto cell = rnns_[i]->at(0);
         auto lazyInputs = cell->getLazyInputs(shared_from_this());
         if(!lazyInputs.empty()) {
           lazyInputs.push_back(layerInput);
           lazyInput = concatenate(lazyInputs, /*axis =*/ -1);
         } else {
           lazyInput = layerInput;
         }

         auto layerOutput = rnns_[i]->transduce(lazyInput, States({states[i]}), mask);

         if(skip_ && (skipFirst_ || i > 0))
           output = layerOutput + layerInput;
         else
           output = layerOutput;

         layerInput = output;
       }
       return output;
     }

     Expr transduce(Expr input, State state, Expr mask = nullptr) override {
       ABORT_IF(rnns_.empty(), "0 layers in RNN");

       Expr output;
       Expr layerInput = input;
       for(size_t i = 0; i < rnns_.size(); ++i) {
         auto lazyInput = layerInput;

         auto cell = rnns_[i]->at(0);
         auto lazyInputs = cell->getLazyInputs(shared_from_this());
         if(!lazyInputs.empty()) {
           lazyInputs.push_back(layerInput);
           lazyInput = concatenate(lazyInputs, /*axis =*/ -1);
         }

         auto layerOutput = rnns_[i]->transduce(lazyInput, States({state}), mask);

         if(skip_ && (skipFirst_ || i > 0))
           output = layerOutput + layerInput;
         else
           output = layerOutput;

         layerInput = output;
       }
       return output;
     }

     States lastCellStates() override {
       States temp;
       for(auto rnn : rnns_)
         temp.push_back(rnn->lastCellStates().back());
       return temp;
     }

     virtual Ptr<Cell> at(int i) override { return rnns_[i]->at(0); }
   };
   }  // namespace rnn
   }  // namespace marian
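
A minimal usage sketch, not part of ``src/rnn/rnn.h``: a multi-layer ``rnn::RNN`` is built by pushing cells into it and is then applied with ``transduce``. The ``rnn::GRU`` cell type, the ``Options::set`` keys, and the pre-existing ``graph``, ``input``, and ``mask`` expressions below are assumptions based on other parts of Marian and may differ between versions.

.. code-block:: cpp

   // Hedged sketch: driving rnn::RNN from the header above. The GRU cell and
   // the option keys are assumptions borrowed from rnn/cells.h; exact names
   // and defaults can differ between Marian versions.
   auto options = New<Options>();
   options->set("prefix", "encoder-rnn");   // parameter name prefix
   options->set("dimInput", 512);           // width of the input at each time step
   options->set("dimState", 512);           // width of the recurrent state
   options->set("skip", true);              // residual connections for layers > 0

   // dimInput == dimState keeps both the layer stacking and the
   // layerOutput + layerInput residual sum shape-compatible.
   auto rnn = New<rnn::RNN>(graph, options);          // graph: Ptr<ExpressionGraph>
   rnn->push_back(New<rnn::GRU>(graph, options));     // layer 0
   rnn->push_back(New<rnn::GRU>(graph, options));     // layer 1, gets the skip connection

   // input: {time, batch, dimInput}; mask (optional): {time, batch, 1}
   Expr context = rnn->transduce(input, mask);
   States last  = rnn->lastCellStates();              // final cell state of every layer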