Class MultiRationalLoss

Inheritance Relationships

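Base Type

public marian::RationalLoss

Derived Types

public marian::MeanMultiRationalLoss
public marian::ScaledMultiRationalLoss
public marian::SumMultiRationalLoss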

Class Documentation

class MultiRationalLoss : public marian::RationalLoss

Base class for multi-objective losses. A MultiRationalLoss is a list of RationalLoss objects that also defines how to accumulate that list into a single RationalLoss.

Subclassed by marian::MeanMultiRationalLoss, marian::ScaledMultiRationalLoss, marian::SumMultiRationalLoss
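The structure described above, a class that is itself a rational loss while also holding the list of partial losses it was built from, can be sketched in plain C++. This is a minimal sketch under stated assumptions, not marian's source. Losses and label counts are plain doubles rather than Expr graph nodes. The names ScalarRationalLoss, ScalarMultiRationalLoss, and ScalarSumMultiRationalLoss are hypothetical. We also assume that push_back folds each new partial loss into the accumulated value through the two pure-virtual hooks before storing it in partialLosses_.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for marian::RationalLoss: a numerator (summed loss)
// and a denominator (label count), held as plain doubles instead of Expr nodes.
struct ScalarRationalLoss {
  double loss = 0.0;
  double count = 0.0;
};

// Hypothetical analogue of MultiRationalLoss: the object is the accumulated
// rational loss and also keeps the list of partial losses.
class ScalarMultiRationalLoss : public ScalarRationalLoss {
public:
  virtual ~ScalarMultiRationalLoss() = default;

  // Assumed behavior: both hooks see the state *before* this call,
  // then the accumulated values are updated and the partial loss is stored.
  virtual void push_back(const ScalarRationalLoss& current) {
    double newLoss = accumulateLoss(current);
    double newCount = accumulateCount(current);
    loss = newLoss;
    count = newCount;
    partialLosses_.push_back(current);
  }

  const ScalarRationalLoss& operator[](std::size_t i) const {
    assert(i < partialLosses_.size());
    return partialLosses_[i];
  }
  auto begin() const { return partialLosses_.begin(); }
  auto end() const { return partialLosses_.end(); }
  std::size_t size() const { return partialLosses_.size(); }

protected:
  // Accumulation rules supplied by each subclass.
  virtual double accumulateLoss(const ScalarRationalLoss& current) = 0;
  virtual double accumulateCount(const ScalarRationalLoss& current) = 0;

  std::vector<ScalarRationalLoss> partialLosses_;
};

// Default rule: sum numerators and denominators independently.
class ScalarSumMultiRationalLoss : public ScalarMultiRationalLoss {
protected:
  double accumulateLoss(const ScalarRationalLoss& current) override {
    return loss + current.loss;  // loss starts at 0.0
  }
  double accumulateCount(const ScalarRationalLoss& current) override {
    return count + current.count;  // count starts at 0.0
  }
};
```

Pushing a loss of 6 over 3 labels and then 4 over 2 labels into a ScalarSumMultiRationalLoss leaves loss == 10, count == 5, and size() == 2, while operator[] and range-for still expose the two partial losses. Because virtual calls made from a base-class constructor do not dispatch to derived overrides, this sketch adds the first partial loss with push_back after construction instead of from a base-class constructor.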

Public Functions

MultiRationalLoss(const RationalLoss &rl)
virtual void push_back(const RationalLoss &current)
const RationalLoss &operator[](size_t i)
auto begin()
auto end()
size_t size() const

Protected Functions

virtual Expr accumulateLoss(const RationalLoss &current) = 0

Accumulation rule for losses. In the default case this is just a sum (see SumMultiRationalLoss), but there are special cases such as ScaledMultiRationalLoss (scale the other losses according to the label count of the first loss) or MeanMultiRationalLoss (sum of means), where the accumulation is more complex.
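The three loss-accumulation rules can be compared by folding them over a list of partial losses. This is a hedged sketch, not marian's code: Partial and the three function names are hypothetical, and values are plain doubles. The scaled rule's exact form is an assumption: the first loss is kept as the reference, and every later loss l_i is multiplied by n_1 / n_i, where n is the label count, so that it sits on the first objective's label scale.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical scalar stand-in for one marian::RationalLoss:
// loss = summed numerator, count = label count (denominator).
struct Partial {
  double loss;
  double count;
};

// SumMultiRationalLoss rule: add the numerators.
double sumRuleLoss(const std::vector<Partial>& parts) {
  double total = 0.0;
  for (const Partial& p : parts) total += p.loss;
  return total;
}

// ScaledMultiRationalLoss rule (assumed form): keep the first loss as the
// reference and rescale each later loss by firstCount / ownCount.
double scaledRuleLoss(const std::vector<Partial>& parts) {
  assert(!parts.empty());
  const double firstCount = parts.front().count;
  double total = parts.front().loss;
  for (std::size_t i = 1; i < parts.size(); ++i) {
    assert(parts[i].count > 0.0);
    total += parts[i].loss * firstCount / parts[i].count;
  }
  return total;
}

// MeanMultiRationalLoss rule: add the per-objective means loss / count.
double meanRuleLoss(const std::vector<Partial>& parts) {
  double total = 0.0;
  for (const Partial& p : parts) {
    assert(p.count > 0.0);
    total += p.loss / p.count;
  }
  return total;
}
```

For two partial losses of 6 over 3 labels and 4 over 2 labels, these rules give 10, 6 + 4 * 3 / 2 = 12, and 6 / 3 + 4 / 2 = 4 respectively.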

virtual Expr accumulateCount(const RationalLoss &current) = 0

Accumulation rule for label counts. As above, the naive case is summation. MeanMultiRationalLoss, for instance, already folds all label counts into the loss, so its label count is always just 1 and is passed through without summation or other modification.
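The matching label-count rules can be sketched the same way. As before, this is an illustrative sketch under assumptions, not marian's implementation: Partial and the function names are hypothetical. The scaled rule is assumed to keep the first label count, since the later losses were rescaled to it. The mean rule returns a constant 1 as described above.

```cpp
#include <cassert>
#include <vector>

// Hypothetical scalar stand-in for one marian::RationalLoss.
struct Partial {
  double loss;
  double count;
};

// SumMultiRationalLoss rule: label counts add up, like the losses.
double sumRuleCount(const std::vector<Partial>& parts) {
  double total = 0.0;
  for (const Partial& p : parts) total += p.count;
  return total;
}

// ScaledMultiRationalLoss rule (assumed): keep the first label count,
// because the later losses were rescaled onto that count.
double scaledRuleCount(const std::vector<Partial>& parts) {
  assert(!parts.empty());
  return parts.front().count;
}

// MeanMultiRationalLoss rule: counts are already divided out of the loss,
// so the count is a constant 1 passed through unchanged.
double meanRuleCount(const std::vector<Partial>& parts) {
  assert(!parts.empty());
  return 1.0;
}
```

For partial losses of 6 over 3 labels and 4 over 2 labels, the counts are 5, 3, and 1. Paired with the corresponding accumulated losses (10, 6 + 4 * 3 / 2 = 12, and 6 / 3 + 4 / 2 = 4), the final ratios are 2, 4, and 4. Under these assumed rules, the scaled and mean variants therefore normalize to the same sum of per-objective means, (l_1 + sum of l_i * n_1 / n_i) / n_1 = sum of l_i / n_i. They differ only in whether the numerator stays on the first objective's label scale or is already a plain sum of means.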

Protected Attributes

std::vector<RationalLoss> partialLosses_