Class SequenceUnlikelihoodLoss

Inheritance Relationships

Base Type

public marian::CrossEntropyLoss

Class Documentation

class SequenceUnlikelihoodLoss : public marian::CrossEntropyLoss

Unlikelihood loss across last axis, summed up over batch and time dimensions.

This is an implementation of the sequence-level unlikelihood (SUL) loss from https://arxiv.org/abs/1908.04319. It relies on word-level label weights, where 1 marks a correct word and 0 marks an error. If a sentence contains no zeros, it is trained with the normal cross-entropy (CE) loss. If it contains at least one 0, training switches to SUL for that sentence, penalizing the word(s) marked with 0.
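The per-sentence selection between CE and SUL can be sketched in plain C++ as follows. This is a hypothetical illustration, not Marian's implementation: `sentenceLoss` and its parameters are invented names, the function works on precomputed per-token losses for a single sentence, and it assumes that only the zero-weighted tokens contribute the SUL term.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of Marian's API): combines precomputed
// per-token losses for ONE sentence following the rule described above.
//   ceLoss[t]       : cross-entropy of the reference word at position t
//   sulLoss[t]      : unlikelihood loss of the label word at position t
//   labelWeights[t] : 1 = correct word, 0 = error
float sentenceLoss(const std::vector<float>& ceLoss,
                   const std::vector<float>& sulLoss,
                   const std::vector<float>& labelWeights) {
  assert(ceLoss.size() == labelWeights.size());
  assert(sulLoss.size() == labelWeights.size());

  // Does the sentence contain at least one word marked as an error?
  bool hasError = false;
  for (float w : labelWeights)
    if (w == 0.f) hasError = true;

  float total = 0.f;
  for (std::size_t t = 0; t < labelWeights.size(); ++t) {
    if (!hasError)
      total += ceLoss[t];   // no zeros: ordinary CE over all words
    else if (labelWeights[t] == 0.f)
      total += sulLoss[t];  // at least one zero: SUL on the marked words only (assumption)
  }
  return total;
}
```

For example, with CE values {1, 2}, SUL values {5, 6} and weights {1, 0}, the sentence contributes 6 (the SUL of the second word); with weights {1, 1} it contributes 3, the plain CE sum. Summing these per-sentence values over the batch gives the total described in the class summary.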

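SUL is implemented as: -log(gather(1 - softmax(logits), -1, indices)), where indices are the label word ids. The penalty therefore grows as the model assigns more probability to a word marked as an error.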
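A scalar version of this formula for a single target position shows what the tensor expression computes: the softmax probability of the label word is gathered, and the loss is the negative log of its complement. The function name `unlikelihood` and its signature are illustrative assumptions, not Marian's code.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical standalone helper (not Marian code): SUL for one position,
// i.e. -log(gather(1 - softmax(logits), -1, index)) on one row of logits.
float unlikelihood(const std::vector<float>& logits, std::size_t index) {
  assert(!logits.empty() && index < logits.size());

  // Numerically stable softmax: subtract the max logit before exponentiating.
  float maxLogit = *std::max_element(logits.begin(), logits.end());
  double sum = 0.0;
  for (float z : logits)
    sum += std::exp(static_cast<double>(z - maxLogit));
  double p = std::exp(static_cast<double>(logits[index] - maxLogit)) / sum;

  // Gather (1 - softmax) at the label index and take the negative log.
  return static_cast<float>(-std::log(1.0 - p));
}
```

With two equal logits the label has probability 0.5, so the loss is -log(0.5), about 0.693. The loss grows without bound as the label probability approaches 1; this sketch does not clip the probability.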

Factors (factored vocabularies) are currently not supported.

Public Functions

SequenceUnlikelihoodLoss(float labelSmoothing, float factorWeight)
SequenceUnlikelihoodLoss(const std::vector<int> &axes, float labelSmoothing, float factorWeight)

Protected Functions

virtual Expr compute(Logits logits, const Words &labels, Expr mask = nullptr, Expr labelWeights = nullptr)