Class LabelwiseLoss

Inheritance Relationships

Derived Type

marian::CrossEntropyLoss

Class Documentation

class LabelwiseLoss

Computes the loss for each given ground-truth label and then reduces the result to a RationalLoss.

Subclassed by marian::CrossEntropyLoss

Public Functions

LabelwiseLoss(const std::vector<int> &axes)
virtual RationalLoss apply(Logits logits, const Words &labels, Expr mask = nullptr, Expr labelWeights = nullptr)

Protected Functions

virtual Expr compute(Logits logits, const Words &labels, Expr mask = nullptr, Expr labelWeights = nullptr) = 0
RationalLoss reduce(Expr loss, Expr labels)
RationalLoss reduce(Expr loss)

Protected Attributes

std::vector<int> axes_
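
The member list above suggests a two-step design: a subclass supplies the per-label loss through compute(), and the base class turns it into a RationalLoss through reduce(). The following is a minimal standalone sketch of that pattern, not Marian's implementation. ToyRationalLoss, ToyLabelwiseLoss and ToyIdentityLoss are hypothetical names, flat std::vector<double> values stand in for Expr, and axes_ is left out. The sketch also assumes that the second reduce() argument carries per-element label counts (here taken from the mask), which the signature reduce(Expr loss, Expr labels) hints at but this reference does not state.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for RationalLoss: a summed loss paired with a label count.
struct ToyRationalLoss {
  double lossSum;
  double labelCount;
};

// Hypothetical simplification of LabelwiseLoss on flat vectors (no Expr, no axes_).
class ToyLabelwiseLoss {
public:
  virtual ~ToyLabelwiseLoss() = default;

  // Step 1: compute one loss value per label. Step 2: reduce them.
  // An empty mask plays the role of mask = nullptr in the real signature.
  ToyRationalLoss apply(const std::vector<double>& scores,
                        const std::vector<double>& mask = {}) const {
    std::vector<double> loss = compute(scores, mask);
    return mask.empty() ? reduce(loss) : reduce(loss, mask);
  }

protected:
  // Subclasses define what the per-label loss is.
  virtual std::vector<double> compute(const std::vector<double>& scores,
                                      const std::vector<double>& mask) const = 0;

  // Assumption: label counts are read from a second vector of the same size.
  static ToyRationalLoss reduce(const std::vector<double>& loss,
                                const std::vector<double>& counts) {
    ToyRationalLoss result{0.0, 0.0};
    for (std::size_t i = 0; i < loss.size(); ++i) {
      result.lossSum += loss[i];
      result.labelCount += counts[i];
    }
    return result;
  }

  // Without counts, every element is treated as exactly one label.
  static ToyRationalLoss reduce(const std::vector<double>& loss) {
    ToyRationalLoss result{0.0, static_cast<double>(loss.size())};
    for (double value : loss)
      result.lossSum += value;
    return result;
  }
};

// Toy subclass: the "loss" of a label is its score, zeroed where the mask is 0.
class ToyIdentityLoss : public ToyLabelwiseLoss {
protected:
  std::vector<double> compute(const std::vector<double>& scores,
                              const std::vector<double>& mask) const override {
    std::vector<double> loss(scores);
    if (!mask.empty()) {
      for (std::size_t i = 0; i < loss.size(); ++i)
        loss[i] *= mask[i];
    }
    return loss;
  }
};
```

Calling apply() on a ToyIdentityLoss with scores {1, 2, 3} and no mask yields a loss sum over all three elements and a label count of 3. Passing the mask {1, 1, 0} drops the last element from both the sum and the count. Keeping the sum and the count separate lets the caller decide later how to normalise the loss.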