Program Listing for File output.h¶
↰ Return to documentation for file (src/layers/output.h
)
#pragma once
#include "data/shortlist.h"
#include "generic.h"
#include "layers/factory.h"
#include "logits.h"
#include "marian.h"
namespace marian {
namespace mlp {
class Output : public LayerBase, public IUnaryLogitLayer, public IHasShortList {
private:
// parameters held by this layer
Expr Wt_; // weight matrix is stored transposed for efficiency
Expr b_;
Expr lemmaEt_; // re-embedding matrix for lemmas [lemmaDimEmb x lemmaVocabSize]
bool isLegacyUntransposedW{false}; // legacy-model emulation: W is stored in non-transposed form
bool hasBias_{true};
Ptr<FactoredVocab> factoredVocab_;
// optional parameters set/updated after construction
Expr tiedParam_;
Ptr<data::Shortlist> shortlist_;
void lazyConstruct(int inputDim);
public:
Output(Ptr<ExpressionGraph> graph, Ptr<Options> options)
: LayerBase(graph, options), hasBias_{!options->get<bool>("output-omit-bias", false)} {
clear();
}
void tieTransposed(Expr tied) {
if(Wt_)
ABORT_IF(tiedParam_.get() != tied.get(),
"Tied output projection cannot be changed once weights have been created");
else
tiedParam_ = tied;
}
void setShortlist(Ptr<data::Shortlist> shortlist) override final {
if(shortlist_)
ABORT_IF(shortlist.get() != shortlist_.get(),
"Output shortlist cannot be changed except after clear()");
else {
shortlist_ = shortlist;
}
// cachedShortWt_ and cachedShortb_ will be created lazily inside apply()
}
// this is expected to be called in sync with graph->clear(), which invalidates
// cachedShortWt_ etc. in the graph's short-term cache
void clear() override final {
shortlist_ = nullptr;
}
Logits applyAsLogits(Expr input) override final;
};
} // namespace mlp
} // namespace marian