.. _program_listing_file_src_rnn_types.h: Program Listing for File types.h ================================ |exhale_lsh| :ref:`Return to documentation for file ` (``src/rnn/types.h``) .. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS .. code-block:: cpp #pragma once #include "marian.h" #include #include namespace marian { namespace rnn { struct State { Expr output; Expr cell; State select(const std::vector& selIdx, // [beamIndex * activeBatchSize + batchIndex] int beamSize, bool isBatchMajor) const { return{ select(output, selIdx, beamSize, isBatchMajor), select(cell, selIdx, beamSize, isBatchMajor) }; } // this function is also called by Logits static Expr select(Expr sel, // [beamSize, dimTime, dimBatch, dimDepth] or [beamSize, dimBatch, dimTime, dimDepth] (dimTime = 1 for RNN) const std::vector& selIdx, // [beamIndex * activeBatchSize + batchIndex] int beamSize, bool isBatchMajor) { if (!sel) return sel; // keep nullptr untouched sel = atleast_4d(sel); int dimBatch = (int)selIdx.size() / beamSize; int dimDepth = sel->shape()[-1]; int dimTime = isBatchMajor ? sel->shape()[-2] : sel->shape()[-3]; ABORT_IF(dimTime != 1 && !isBatchMajor, "unexpected time extent for RNN state"); // (the reshape()/rows() trick won't work in this case) int numCols = isBatchMajor ? dimDepth * dimTime : dimDepth; // @TODO: Can this complex operation be more easily written using index_select()? sel = reshape(sel, { sel->shape().elements() / numCols, numCols }); // [beamSize * dimBatch, dimDepth] or [beamSize * dimBatch, dimTime * dimDepth] sel = rows(sel, selIdx); sel = reshape(sel, { beamSize, isBatchMajor ? dimBatch : dimTime, isBatchMajor ? dimTime : dimBatch, dimDepth }); return sel; } }; class States { private: std::vector states_; public: States() {} States(const std::vector& states) : states_(states) {} States(size_t num, State state) : states_(num, state) {} std::vector::iterator begin() { return states_.begin(); } std::vector::iterator end() { return states_.end(); } std::vector::const_iterator begin() const { return states_.begin(); } std::vector::const_iterator end() const { return states_.end(); } Expr outputs() { std::vector outputs; for(auto s : states_) outputs.push_back(atleast_3d(s.output)); if(outputs.size() > 1) return concatenate(outputs, /*axis =*/ -3); else return outputs[0]; } State& operator[](size_t i) { return states_[i]; }; const State& operator[](size_t i) const { return states_[i]; }; State& back() { return states_.back(); } const State& back() const { return states_.back(); } State& front() { return states_.front(); } const State& front() const { return states_.front(); } size_t size() const { return states_.size(); }; void push_back(const State& state) { states_.push_back(state); } // create updated set of states that reflect reordering and dropping of hypotheses States select(const std::vector& selIdx, // [beamIndex * activeBatchSize + batchIndex] int beamSize, bool isBatchMajor) const { States selected; for(auto& state : states_) selected.push_back(state.select(selIdx, beamSize, isBatchMajor)); return selected; } void reverse() { std::reverse(states_.begin(), states_.end()); } void clear() { states_.clear(); } }; class Cell; class CellInput; class Stackable : public std::enable_shared_from_this { protected: Ptr options_; public: Stackable(Ptr options) : options_(options) {} // required for dynamic_pointer_cast to detect polymorphism virtual ~Stackable() {} template inline Ptr as() { return std::dynamic_pointer_cast(shared_from_this()); } template inline bool is() { return as() != nullptr; } Ptr getOptions() { return options_; } template T opt(const std::string& key) { return options_->get(key); } template T opt(const std::string& key, T defaultValue) { return options_->get(key, defaultValue); } virtual void clear() = 0; }; class CellInput : public Stackable { public: CellInput(Ptr options) : Stackable(options) {} virtual Expr apply(State) = 0; virtual int dimOutput() = 0; }; class RNN; class Cell : public Stackable { protected: std::vector)>> lazyInputs_; public: Cell(Ptr options) : Stackable(options) {} State apply(std::vector inputs, State state, Expr mask = nullptr) { return applyState(applyInput(inputs), state, mask); } virtual std::vector getLazyInputs(Ptr parent) { std::vector inputs; for(auto lazy : lazyInputs_) inputs.push_back(lazy(parent)); return inputs; } virtual void setLazyInputs( std::vector)>> lazy) { lazyInputs_ = lazy; } virtual std::vector applyInput(std::vector inputs) = 0; virtual State applyState(std::vector, State, Expr = nullptr) = 0; virtual void clear() override {} }; class MultiCellInput : public CellInput { protected: std::vector> inputs_; public: MultiCellInput(const std::vector>& inputs, Ptr options) : CellInput(options), inputs_(inputs) {} void push_back(Ptr input) { inputs_.push_back(input); } virtual Expr apply(State state) override { std::vector outputs; for(auto input : inputs_) outputs.push_back(input->apply(state)); if(outputs.size() > 1) return concatenate(outputs, /*axis =*/ -1); else return outputs[0]; } virtual int dimOutput() override { int sum = 0; for(auto input : inputs_) sum += input->dimOutput(); return sum; } virtual void clear() override { for(auto i : inputs_) i->clear(); } }; class StackedCell : public Cell { protected: std::vector> stackables_; std::vector lastInputs_; public: StackedCell(Ptr, Ptr options) : Cell(options) {} StackedCell(Ptr, Ptr options, const std::vector>& stackables) : Cell(options), stackables_(stackables) {} void push_back(Ptr stackable) { stackables_.push_back(stackable); } virtual std::vector applyInput(std::vector inputs) override { // lastInputs_ = inputs; return stackables_[0]->as()->applyInput(inputs); } virtual State applyState(std::vector mappedInputs, State state, Expr mask = nullptr) override { State hidden = stackables_[0]->as()->applyState(mappedInputs, state, mask); ; for(size_t i = 1; i < stackables_.size(); ++i) { if(stackables_[i]->is()) { auto hiddenNext = stackables_[i]->as()->apply(lastInputs_, hidden, mask); lastInputs_.clear(); hidden = hiddenNext; } else { lastInputs_.push_back(stackables_[i]->as()->apply(hidden)); // lastInputs_ = { stackables_[i]->as()->apply(hidden) }; } } return hidden; }; Ptr operator[](int i) { return stackables_[i]; } Ptr at(int i) { return stackables_[i]; } virtual void clear() override { for(auto s : stackables_) s->clear(); } virtual std::vector getLazyInputs(Ptr parent) override { ABORT_IF(!stackables_[0]->is(), "First stackable should be of type Cell"); return stackables_[0]->as()->getLazyInputs(parent); } virtual void setLazyInputs( std::vector)>> lazy) override { ABORT_IF(!stackables_[0]->is(), "First stackable should be of type Cell"); stackables_[0]->as()->setLazyInputs(lazy); } }; } // namespace rnn } // namespace marian