.. _program_listing_file_src_models_states.h:

Program Listing for File states.h
=================================

|exhale_lsh| :ref:`Return to documentation for file <file_src_models_states.h>` (``src/models/states.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   #pragma once

   #include "layers/logits.h"  // @HACK: for factored embeddings only so far
   #include "marian.h"
   #include "rnn/types.h"

   namespace marian {

   // Bundles what an encoder produced for one batch: the context expression, the
   // source mask and a handle to the batch itself.
   class EncoderState {
   private:
     Expr context_;
     Expr mask_;  // [beam depth=1, max length, batch size, vector dim=1] source mask
     Ptr<data::CorpusBatch> batch_;

   public:
     EncoderState(Expr context, Expr mask, Ptr<data::CorpusBatch> batch)
         : context_(context), mask_(mask), batch_(batch) {}

     EncoderState() {}
     virtual ~EncoderState() {}

     virtual Expr getContext() const { return context_; }

     // In this base class the attended expression is the context itself.
     virtual Expr getAttended() const { return context_; }

     // Source batch mask; may have additional positions suppressed
     virtual Expr getMask() const { return mask_; }

     // Returns the word data of the first sub-batch of batch_.
     // NOTE(review): dereferences batch_, so it presumably must not be called on
     // a default-constructed state -- confirm against callers.
     virtual const Words& getSourceWords() { return batch_->front()->data(); }

     // Sub-select active batch entries from encoder context and context mask.
     // batchIndices: [batchIndex] indices of active batch entries.
     // The returned state shares batch_ with this one.
     Ptr<EncoderState> select(const std::vector<IndexType>& batchIndices) {
       // Dimension -2 is OK for both, RNN and Transformer models as the encoder context in
       // Transformer gets transposed to the same dimension layout
       return New<EncoderState>(index_select(context_, -2, batchIndices),
                                index_select(mask_, -2, batchIndices),
                                batch_);
     }
   };

   // Decoder-side state: per-layer RNN states, the current log-probabilities, the
   // encoder states being attended to, plus target-side history, mask, labels and
   // the current target position.
   class DecoderState {
   protected:
     rnn::States states_;  // states of individual decoder layers
     Logits logProbs_;
     std::vector<Ptr<EncoderState>> encStates_;
     Ptr<data::CorpusBatch> batch_;

     Expr targetHistoryEmbeddings_;  // decoder history (teacher-forced or from decoding), embedded
     Expr targetMask_;
     Words targetWords_;  // target labels

     // Keep track of current target token position during translation
     size_t position_{0};

   public:
     DecoderState(const rnn::States& states,
                  Logits logProbs,
                  const std::vector<Ptr<EncoderState>>& encStates,
                  Ptr<data::CorpusBatch> batch)
         : states_(states), logProbs_(logProbs), encStates_(encStates), batch_(batch) {}

     virtual ~DecoderState() {}

     // @TODO: Do we need all these to be virtual?
     virtual const std::vector<Ptr<EncoderState>>& getEncoderStates() const { return encStates_; }

     virtual Logits getLogProbs() const { return logProbs_; }
     virtual void setLogProbs(Logits logProbs) { logProbs_ = logProbs; }

     // Builds a new DecoderState restricted to the given hypotheses / batch entries.
     //   hypIndices:   [beamIndex * activeBatchSize + batchIndex]
     //   batchIndices: [batchIndex]
     // The new state receives the selected RNN states, the current logProbs_,
     // batch_ and position. Target history embeddings, target mask and target
     // words are not carried over by this function.
     // @TODO: should this be a constructor? Then derived classes can call this without the New<> in
     // the loop
     virtual Ptr<DecoderState> select(const std::vector<IndexType>& hypIndices,
                                      const std::vector<IndexType>& batchIndices,
                                      int beamSize) const {
       std::vector<Ptr<EncoderState>> newEncStates;
       newEncStates.reserve(encStates_.size());  // exactly one entry is appended per encoder state
       for(const auto& es : encStates_)
         // If the size of the batch dimension of the encoder state context changed, subselect the
         // correct batch entries
         newEncStates.push_back(
             es->getContext()->shape()[-2] == batchIndices.size() ? es : es->select(batchIndices));

       // hypIndices matches batchIndices in terms of batch dimension, so we only need hypIndices
       auto selectedState
           = New<DecoderState>(states_.select(hypIndices, beamSize, /*isBatchMajor=*/false),
                               logProbs_,
                               newEncStates,
                               batch_);

       // Set position of new state based on the target token position of current state
       selectedState->setPosition(getPosition());
       return selectedState;
     }

     virtual const rnn::States& getStates() const { return states_; }

     virtual Expr getTargetHistoryEmbeddings() const { return targetHistoryEmbeddings_; }
     virtual void setTargetHistoryEmbeddings(Expr targetHistoryEmbeddings) {
       targetHistoryEmbeddings_ = targetHistoryEmbeddings;
     }

     virtual const Words& getTargetWords() const { return targetWords_; }
     virtual void setTargetWords(const Words& targetWords) { targetWords_ = targetWords; }

     virtual Expr getTargetMask() const { return targetMask_; }
     virtual void setTargetMask(Expr targetMask) { targetMask_ = targetMask; }

     // Source words are taken from the first encoder state, so at least one
     // encoder state must be present.
     virtual const Words& getSourceWords() const {
       return getEncoderStates()[0]->getSourceWords();
     }

     Ptr<data::CorpusBatch> getBatch() const { return batch_; }

     // Get current target token position in state when decoding
     size_t getPosition() const { return position_; }

     // Set current target token position in state when decoding
     void setPosition(size_t position) { position_ = position; }

     // Hook with an empty default implementation; both arguments are ignored here.
     virtual void blacklist(Expr /*totalCosts*/, Ptr<data::CorpusBatch> /*batch*/) {}
   };

   // State for classifier-style outputs: log-probabilities plus target labels and
   // target mask.
   class ClassifierState {
   private:
     Expr logProbs_;
     // NOTE(review): encStates_ and batch_ are not read or written by any member
     // visible in this class.
     std::vector<Ptr<EncoderState>> encStates_;
     Ptr<data::CorpusBatch> batch_;

     Expr targetMask_;
     Words targetWords_;

   public:
     virtual ~ClassifierState() {}

     virtual Expr getLogProbs() const { return logProbs_; }
     virtual void setLogProbs(Expr logProbs) { logProbs_ = logProbs; }

     virtual const Words& getTargetWords() const { return targetWords_; }
     virtual void setTargetWords(const Words& targetWords) { targetWords_ = targetWords; }

     virtual Expr getTargetMask() const { return targetMask_; }
     virtual void setTargetMask(Expr targetMask) { targetMask_ = targetMask; }
   };

   }  // namespace marian