Program Listing for File decoder.h
Return to documentation for file (src/models/decoder.h)
#pragma once
#include "marian.h"
#include "states.h"
#include "data/shortlist.h"
#include "layers/constructors.h"
#include "layers/generic.h"
namespace marian {

/**
 * Abstract base class for the decoder half of an encoder-decoder model.
 *
 * Initializes EncoderDecoderLayerBase with the prefix "decoder", batch index 1,
 * the target-side dropout option "dropout-trg" (default 0.0f) and the
 * "embedding-fix-trg" option (default false). Subclasses must implement
 * startState(), step() and clear().
 */
class DecoderBase : public EncoderDecoderLayerBase {
protected:
  // Optional shortlist installed via setShortlist(). embeddingsFromBatch() aborts if it is set.
  Ptr<data::Shortlist> shortlist_;

public:
  DecoderBase(Ptr<ExpressionGraph> graph, Ptr<Options> options)
      : EncoderDecoderLayerBase(graph,
                                options,
                                "decoder",
                                /*batchIndex=*/1,
                                options->get<float>("dropout-trg", 0.0f),
                                options->get<bool>("embedding-fix-trg", false)) {}

  /// Subclass hook: creates the initial DecoderState for `batch` from the encoder states.
  virtual Ptr<DecoderState> startState(Ptr<ExpressionGraph>,
                                       Ptr<data::CorpusBatch> batch,
                                       std::vector<Ptr<EncoderState>>&)
      = 0;

  /// Subclass hook: computes one decoder step from the given state and returns the resulting state.
  virtual Ptr<DecoderState> step(Ptr<ExpressionGraph>, Ptr<DecoderState>) = 0;

  /**
   * Fills `state` with target-side inputs taken from the target sub-batch
   * (`(*batch)[batchIndex_]`): the delayed target embeddings as history
   * embeddings, the target mask, and the target words.
   *
   * Aborts if a shortlist is installed: there is currently no training code
   * path that uses a shortlist. Supporting one would require feeding
   * shortlist_->mappedIndices() instead of the sub-batch words.
   *
   * @param graph Expression graph to build on (stored in graph_).
   * @param state Decoder state to populate.
   * @param batch Corpus batch whose target sub-batch is embedded.
   */
  virtual void embeddingsFromBatch(Ptr<ExpressionGraph> graph,
                                   Ptr<DecoderState> state,
                                   Ptr<data::CorpusBatch> batch) {
    // Validate the precondition before adding any nodes to the graph.
    ABORT_IF(shortlist_, "How did a shortlist make it into training?");

    graph_ = graph;
    auto subBatch = (*batch)[batchIndex_];

    Expr y, yMask;
    std::tie(y, yMask) = getEmbeddingLayer()->apply(subBatch);

    // Insert zero at front; first word gets predicted from a target embedding of 0.
    auto yDelayed = shift(y, {1, 0, 0});
    state->setTargetHistoryEmbeddings(yDelayed);
    state->setTargetMask(yMask);

    const Words& data = subBatch->data();
    state->setTargetWords(data);
  }

  /**
   * Sets the target history embeddings on `state` from previously predicted words.
   *
   * If `words` is empty (presumably the first decoding step), a zero constant of
   * shape {1, 1, dimBatch, dim-emb} is used. Otherwise `words` are embedded
   * with shape {dimBeam, 1, dimBatch, dim-emb}.
   *
   * @param graph    Expression graph to build on (stored in graph_).
   * @param state    Decoder state to update.
   * @param words    Words to embed; may be empty.
   * @param dimBatch Batch dimension of the resulting tensor.
   * @param dimBeam  Beam dimension of the resulting tensor (unused when `words` is empty).
   */
  virtual void embeddingsFromPrediction(Ptr<ExpressionGraph> graph,
                                        Ptr<DecoderState> state,
                                        const Words& words,
                                        int dimBatch,
                                        int dimBeam) {
    graph_ = graph;
    auto embeddingLayer = getEmbeddingLayer();
    const int dimEmb = opt<int>("dim-emb");

    Expr selectedEmbs;
    if(words.empty())
      selectedEmbs = graph_->constant({1, 1, dimBatch, dimEmb}, inits::zeros());
    else
      selectedEmbs = embeddingLayer->apply(words, {dimBeam, 1, dimBatch, dimEmb});
    state->setTargetHistoryEmbeddings(selectedEmbs);
  }

  /**
   * Returns attention alignments, laid out as
   * [tgt index][beam depth, max src length, batch size, 1].
   * The base implementation returns an empty vector.
   *
   * NOTE(review): the `const` by-value return type and the default argument
   * are kept as-is because overriding declarations must match this signature.
   */
  virtual const std::vector<Expr> getAlignments(int /*i*/ = 0) { return {}; }

  /// Returns the currently installed shortlist (may be null).
  virtual Ptr<data::Shortlist> getShortlist() { return shortlist_; }

  /// Installs (or clears, when null) the shortlist.
  virtual void setShortlist(Ptr<data::Shortlist> shortlist) {
    shortlist_ = shortlist;
  }

  /// Subclass hook: resets decoder-internal state.
  virtual void clear() = 0;
};

}  // namespace marian