Class DecoderState

Inheritance Relationships

Derived Type

marian::TransformerState

Class Documentation

class DecoderState

Subclassed by marian::TransformerState

Public Functions

DecoderState(const rnn::States &states, Logits logProbs, const std::vector<Ptr<EncoderState>> &encStates, Ptr<data::CorpusBatch> batch)
virtual ~DecoderState()
virtual const std::vector<Ptr<EncoderState>> &getEncoderStates() const
virtual Logits getLogProbs() const
virtual void setLogProbs(Logits logProbs)
virtual Ptr<DecoderState> select(const std::vector<IndexType> &hypIndices, const std::vector<IndexType> &batchIndices, int beamSize) const
virtual const rnn::States &getStates() const
virtual Expr getTargetHistoryEmbeddings() const
virtual void setTargetHistoryEmbeddings(Expr targetHistoryEmbeddings)
virtual const Words &getTargetWords() const
virtual void setTargetWords(const Words &targetWords)
virtual Expr getTargetMask() const
virtual void setTargetMask(Expr targetMask)
virtual const Words &getSourceWords() const
Ptr<data::CorpusBatch> getBatch() const
size_t getPosition() const
void setPosition(size_t position)
virtual void blacklist(Expr, Ptr<data::CorpusBatch>)

Protected Attributes

rnn::States states_
Logits logProbs_
std::vector<Ptr<EncoderState>> encStates_
Ptr<data::CorpusBatch> batch_
Expr targetHistoryEmbeddings_
Expr targetMask_
Words targetWords_
size_t position_ = {0}