Class CrossEntropyLoss

Inheritance Relationships

Base Type

public marian::LabelwiseLoss

Derived Types

public marian::RescorerLoss
public marian::SequenceUnlikelihoodLoss

Class Documentation

class CrossEntropyLoss : public marian::LabelwiseLoss

Cross-entropy loss computed across the last axis and summed over the batch and time dimensions.
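To illustrate what this description means, the following standalone C++ sketch computes a numerically stable log-softmax over the last axis of each logit row, then sums the negative log-likelihood of the gold label over the time and batch dimensions. It does not use Marian's Expr graph. The nested std::vector layout (time, batch, classes) and the function names logSoftmax and summedCrossEntropy are hypothetical, chosen only for this example.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Numerically stable log-softmax over one row (the "last axis").
// Assumes the row is non-empty.
std::vector<float> logSoftmax(const std::vector<float>& row) {
  float maxVal = *std::max_element(row.begin(), row.end());
  double sumExp = 0.0;
  for (float x : row) sumExp += std::exp(x - maxVal);
  float logZ = maxVal + static_cast<float>(std::log(sumExp));
  std::vector<float> out(row.size());
  for (std::size_t k = 0; k < row.size(); ++k) out[k] = row[k] - logZ;
  return out;
}

// Hypothetical layout: logits[t][b] holds the class scores for time step t and
// batch entry b; labels[t][b] is the gold class index for that position.
// Returns the cross-entropy summed over the time and batch dimensions.
float summedCrossEntropy(const std::vector<std::vector<std::vector<float>>>& logits,
                         const std::vector<std::vector<int>>& labels) {
  float total = 0.f;
  for (std::size_t t = 0; t < logits.size(); ++t) {
    for (std::size_t b = 0; b < logits[t].size(); ++b) {
      std::vector<float> logp = logSoftmax(logits[t][b]);
      total += -logp[static_cast<std::size_t>(labels[t][b])];
    }
  }
  return total;
}
```

With two equal logits per position, each position contributes log 2, so a single time step with a batch of two sums to 2 log 2.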

Subclassed by marian::RescorerLoss, marian::SequenceUnlikelihoodLoss

Public Functions

CrossEntropyLoss(float labelSmoothing, float factorWeight)
CrossEntropyLoss(const std::vector<int> &axes, float labelSmoothing, float factorWeight)
virtual ~CrossEntropyLoss()
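This page does not document how labelSmoothing is applied. One common formulation of label smoothing mixes the ordinary cross-entropy with the cross-entropy against a uniform distribution over all classes. The sketch below assumes that formulation and assumes labelSmoothing plays the role of epsilon; neither is confirmed here. The name smoothedCrossEntropy is hypothetical, and factorWeight is not modeled.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Label-smoothed cross-entropy for a single position (one row of logits).
// Assumption: epsilon stands in for the labelSmoothing constructor argument;
// Marian's exact formula is not given on this page.
float smoothedCrossEntropy(const std::vector<float>& logits, int label, float epsilon) {
  // Stable log-partition: logZ = log(sum_k exp(logits[k])).
  float maxVal = *std::max_element(logits.begin(), logits.end());
  double sumExp = 0.0;
  for (float x : logits) sumExp += std::exp(x - maxVal);
  float logZ = maxVal + static_cast<float>(std::log(sumExp));

  // -log p(label): the ordinary cross-entropy term.
  float nll = logZ - logits[static_cast<std::size_t>(label)];

  // Mean of -log p(k) over all classes k: cross-entropy against a uniform target.
  float uniformCe = 0.f;
  for (float x : logits) uniformCe += logZ - x;
  uniformCe /= static_cast<float>(logits.size());

  // Interpolate between the one-hot target and the uniform target.
  return (1.f - epsilon) * nll + epsilon * uniformCe;
}
```

With epsilon = 0 this reduces to plain cross-entropy; larger values move probability mass away from the gold label toward all classes equally.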

Protected Functions

virtual Expr compute(Logits logits, const Words &labels, Expr mask = nullptr, Expr labelWeights = nullptr)
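compute() accepts an optional mask and optional labelWeights, both defaulting to nullptr. The hypothetical sketch below shows one plausible way such arguments could act on per-position losses: each loss is multiplied by its mask entry (for example 0 for padding) and by its label weight, and the results are summed. The helper name applyMaskAndWeights, the flattened (time x batch) layout, and the folding of the sum into the same function are assumptions for illustration; this page does not say whether compute() itself performs the reduction.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper: applies an optional mask and optional label weights to
// per-position losses (flattened over time x batch) and returns their sum.
// nullptr mirrors the "no mask" / "no weights" defaults of compute().
float applyMaskAndWeights(const std::vector<float>& perPositionLoss,
                          const std::vector<float>* mask = nullptr,
                          const std::vector<float>* labelWeights = nullptr) {
  float total = 0.f;
  for (std::size_t i = 0; i < perPositionLoss.size(); ++i) {
    float loss = perPositionLoss[i];
    if (mask) loss *= (*mask)[i];                  // e.g. 0 for padding, 1 for real tokens
    if (labelWeights) loss *= (*labelWeights)[i];  // per-label importance weight
    total += loss;
  }
  return total;
}
```

For losses {1, 2, 3}, a mask of {1, 1, 0} drops the last position (sum 3), and adding weights {2, 1, 1} doubles the first loss (sum 4).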

Protected Attributes

float labelSmoothing_
float factorWeight_