Class ClassifierBase

Inheritance Relationships

Base Type

public marian::LayerBase

Derived Types

marian::BertClassifier
marian::BertMaskedLM

Class Documentation

class ClassifierBase : public marian::LayerBase

Simple base class for classifiers used in the EncoderClassifier framework. Currently, the only implementations are in bert.h.

Subclassed by marian::BertClassifier, marian::BertMaskedLM
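`apply()` and `clear()` are pure virtual, so every concrete classifier, including the two subclasses named above, must override both. The sketch below shows that contract in isolation. So that it compiles without Marian, it uses minimal, hypothetical stand-ins for `Ptr`, `Options`, `ExpressionGraph`, `EncoderState`, `ClassifierState`, `data::CorpusBatch` and `LayerBase`. `DummyClassifier` and the `label` field are invented for illustration and say nothing about what `BertClassifier` actually computes.

```cpp
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins for Marian types, reduced to what is needed to compile.
template <class T>
using Ptr = std::shared_ptr<T>;
struct Options {};
struct ExpressionGraph {};
struct EncoderState {};
struct ClassifierState { int label = -1; };  // hypothetical field
namespace data { struct CorpusBatch {}; }

// Stand-in for marian::LayerBase: it only keeps the graph here.
class LayerBase {
public:
  explicit LayerBase(Ptr<ExpressionGraph> graph) : graph_(std::move(graph)) {}
  virtual ~LayerBase() = default;
protected:
  Ptr<ExpressionGraph> graph_;
};

// Mirrors the documented interface of ClassifierBase (opt<T>() omitted).
class ClassifierBase : public LayerBase {
public:
  ClassifierBase(Ptr<ExpressionGraph> graph, Ptr<Options> options)
      : LayerBase(std::move(graph)), options_(std::move(options)) {}
  virtual ~ClassifierBase() {}
  // Pure virtual: each concrete classifier must implement both functions.
  virtual Ptr<ClassifierState> apply(Ptr<ExpressionGraph>, Ptr<data::CorpusBatch>,
                                     const std::vector<Ptr<EncoderState>>&) = 0;
  virtual void clear() = 0;
protected:
  Ptr<Options> options_;
  const std::string prefix_ = {"classifier"};
  const bool inference_ = {false};
  const std::size_t batchIndex_ = {0};
};

// Hypothetical subclass, playing the same structural role as BertClassifier.
class DummyClassifier : public ClassifierBase {
public:
  using ClassifierBase::ClassifierBase;  // inherit the (graph, options) constructor

  Ptr<ClassifierState> apply(Ptr<ExpressionGraph>, Ptr<data::CorpusBatch>,
                             const std::vector<Ptr<EncoderState>>& encoderStates) override {
    auto state = std::make_shared<ClassifierState>();
    state->label = static_cast<int>(encoderStates.size());  // placeholder logic
    ++calls_;  // batch-wise bookkeeping that clear() resets
    return state;
  }

  void clear() override { calls_ = 0; }
  int calls() const { return calls_; }

private:
  int calls_ = 0;
};
```

Because both functions are pure virtual, `ClassifierBase` itself cannot be instantiated. Callers hold a pointer or reference to the base class and dispatch through it, which is presumably what lets the EncoderClassifier framework treat different classifiers uniformly.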

Public Functions

ClassifierBase(Ptr<ExpressionGraph> graph, Ptr<Options> options)
virtual ~ClassifierBase()
virtual Ptr<ClassifierState> apply(Ptr<ExpressionGraph>, Ptr<data::CorpusBatch>, const std::vector<Ptr<EncoderState>>&) = 0
template<typename T>
T opt(const std::string &key) const
virtual void clear() = 0
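This page gives only the signature of `opt<T>(key)`. A plausible reading is a typed accessor into `options_`. The sketch below assumes that `opt<T>` simply forwards to an `Options::get<T>(key)` call. The `Options` class here is a hypothetical stand-in backed by `std::any` (C++17), not Marian's real options type, and `ClassifierOptsSketch` is an invented name that models only the accessor.

```cpp
#include <any>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>

// Hypothetical stand-in for Marian's Options: a string-keyed bag of values.
class Options {
public:
  template <typename T>
  void set(const std::string& key, T value) { values_[key] = std::move(value); }

  template <typename T>
  T get(const std::string& key) const {
    auto it = values_.find(key);
    if (it == values_.end())
      throw std::out_of_range("missing option: " + key);
    return std::any_cast<T>(it->second);  // throws std::bad_any_cast on type mismatch
  }

private:
  std::map<std::string, std::any> values_;
};

// Only the opt<T>() accessor of ClassifierBase is modeled here.
class ClassifierOptsSketch {
public:
  explicit ClassifierOptsSketch(std::shared_ptr<Options> options)
      : options_(std::move(options)) {}

  // Assumed behavior: a thin, typed forwarding call into options_.
  template <typename T>
  T opt(const std::string& key) const { return options_->get<T>(key); }

protected:
  std::shared_ptr<Options> options_;
};
```

With this design, a missing key or a wrong `T` is detected at run time (an exception), not at compile time.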

Protected Attributes

Ptr<Options> options_
const std::string prefix_ = {"classifier"}
const bool inference_ = {false}
const size_t batchIndex_ = {0}
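The `= {...}` values are default member initializers: `prefix_` starts as `"classifier"`, `inference_` as `false`, and `batchIndex_` as `0`. Because these members are `const`, a constructor can supply different values only through its member-initializer list, which takes precedence over the default. The standalone sketch below demonstrates that C++ rule. This page does not say whether or how the real `ClassifierBase` constructor overrides these values from its `options` argument, so the three-argument constructor and the `AttrsSketch` name are hypothetical.

```cpp
#include <cstddef>
#include <string>
#include <utility>

// Same member declarations as the documented protected attributes.
struct AttrsSketch {
  const std::string prefix_ = {"classifier"};
  const bool inference_ = {false};
  const std::size_t batchIndex_ = {0};

  // No member-initializer list: all three default member initializers apply.
  AttrsSketch() = default;

  // Hypothetical override path: a member-initializer list wins over the defaults.
  AttrsSketch(std::string prefix, bool inference, std::size_t batchIndex)
      : prefix_(std::move(prefix)), inference_(inference), batchIndex_(batchIndex) {}
};
```

A side effect of `const` data members is that the implicit copy-assignment operator is deleted. An instance can still be copy-constructed but cannot be reassigned after construction.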