Program Listing for File loss.h
↰ Return to documentation for file (src/layers/loss.h)
#pragma once
#include "data/types.h"
#include "graph/expression_operators.h"
#include "layers/logits.h" // for Logits (Frank's factor hack)
namespace marian {
// A loss represented as a fraction: a numerator (the accumulated loss value)
// and a denominator (the label count the loss was accumulated over). Keeping
// both around lets consumers decide when and how to normalize, e.g. for
// correct scaling across batches with different numbers of labels.
class RationalLoss {
protected:
  Expr loss_;   // numerator
  Expr count_;  // denominator

  RationalLoss() = default;  // protected: only derived classes may start empty

public:
  RationalLoss(Expr loss, Expr count) : loss_(loss), count_(count) {}

  // Convenience overload: wraps a scalar count in a constant node that lives
  // on the same graph (and has the same shape/type) as the loss expression.
  RationalLoss(Expr loss, float count)
      : loss_(loss), count_(constant_like(loss, inits::fromValue(count))) {}

  // Memberwise copy is exactly what the hand-written version did; let the
  // compiler generate it (Rule of Five: prefer =default).
  RationalLoss(const RationalLoss& other) = default;

  virtual ~RationalLoss() = default;

  Expr loss() const { return loss_; }

  // Copies the loss tensor's values into `losses`. Requires a prior forward pass.
  // @TODO: remove this function, as it does not add too much value over loss()->get(...)
  template <typename T>
  void loss(std::vector<T>& losses) const {
    ABORT_IF(!loss_, "Loss has not been defined");
    loss_->val()->get(losses);
  }

  // Returns the loss as a single scalar; fails if the loss is not a single value.
  template <typename T>
  T loss() const {
    ABORT_IF(!loss_, "Loss has not been defined");
    return loss_->val()->scalar<T>();
  }

  Expr count() const { return count_; }

  // Copies the label-count tensor's values into `labels`.
  // @TODO: remove this function, as it does not add too much value over count()->get(...)
  template <typename T>
  void count(std::vector<T>& labels) const {
    ABORT_IF(!count_, "Labels have not been defined");
    count_->val()->get(labels);
  }

  // Returns the label count as a single scalar; fails if it is not a single value.
  template <typename T>
  T count() const {
    ABORT_IF(!count_, "Labels have not been defined");
    return count_->val()->scalar<T>();
  }

  // Number of elements in the count tensor (NOT the summed label count).
  // @TODO: add a function for returning maybe ratio?
  size_t size() const {
    ABORT_IF(!count_, "Labels have not been defined");
    return count_->shape().elements();
  }
};
// Plain float-valued counterpart of RationalLoss: a loss/count pair detached
// from the computation graph, used to accumulate statistics across updates.
struct StaticLoss {
  float loss;   // numerator
  float count;  // denominator

  StaticLoss() : loss(0.f), count(0.f) {}

  // Extracts the scalar loss and count from a graph-based RationalLoss;
  // requires that both have been computed (forward pass done).
  StaticLoss(const RationalLoss& dynamic)
      : loss(dynamic.loss<float>()), count(dynamic.count<float>()) {}

  // Element-wise accumulation of numerator and denominator.
  StaticLoss& operator+=(const StaticLoss& other) {
    loss += other.loss;
    count += other.count;
    return *this;
  }

  StaticLoss operator+(const StaticLoss& other) const {
    StaticLoss result = *this;
    result += other;
    return result;
  }

  // Zeroes both numerator and denominator.
  void reset() { *this = StaticLoss(); }
};
class MultiRationalLoss : public RationalLoss {
protected:
std::vector<RationalLoss> partialLosses_;
virtual Expr accumulateLoss(const RationalLoss& current) = 0;
virtual Expr accumulateCount(const RationalLoss& current) = 0;
public:
MultiRationalLoss() : RationalLoss() {}
MultiRationalLoss(const RationalLoss& rl) : RationalLoss() { push_back(rl); }
virtual void push_back(const RationalLoss& current) {
loss_ = accumulateLoss(current);
count_ = accumulateCount(current);
partialLosses_.push_back(current);
}
const RationalLoss& operator[](size_t i) { return partialLosses_[i]; }
auto begin() -> decltype(partialLosses_.begin()) const { return partialLosses_.begin(); }
auto end() -> decltype(partialLosses_.end()) const { return partialLosses_.end(); }
size_t size() const { return partialLosses_.size(); }
};
// Plain sum of partial losses: both numerators and denominators are added.
class SumMultiRationalLoss : public MultiRationalLoss {
private:
  // The first pushed loss initializes loss_; later ones are added to it.
  virtual Expr accumulateLoss(const RationalLoss& current) override {
    return loss_ ? loss_ + current.loss() : current.loss();
  }

  // Same accumulation rule for the label counts.
  virtual Expr accumulateCount(const RationalLoss& current) override {
    return count_ ? count_ + current.count() : current.count();
  }

public:
  SumMultiRationalLoss() : MultiRationalLoss() {}
  SumMultiRationalLoss(const RationalLoss& rl) : MultiRationalLoss(rl) {}
};
// Sum of partial losses where every additional loss is rescaled to the label
// count of the FIRST loss, so all partial losses contribute on one scale.
class ScaledMultiRationalLoss : public MultiRationalLoss {
private:
  virtual Expr accumulateLoss(const RationalLoss& current) override {
    if(!loss_)
      return current.loss();  // first reference loss, keeps to scale with this one
    // Scale the new loss up/down by firstCount / currentCount so it matches
    // the scale of the first (reference) loss.
    const auto& reference = partialLosses_.front();
    return loss_ + current.loss() * reference.count() / current.count();
  }

  virtual Expr accumulateCount(const RationalLoss& current) override {
    // Keep the first label count (alternative would be
    // count_ + first.count() / current.count(), intentionally not used).
    return count_ ? count_ : current.count();
  }

public:
  ScaledMultiRationalLoss() : MultiRationalLoss() {}
  ScaledMultiRationalLoss(const RationalLoss& rl) : MultiRationalLoss(rl) {}
};
// Mean-style aggregation: each partial loss is normalized by its own label
// count before being added, so the aggregate count is the constant '1'.
class MeanMultiRationalLoss : public MultiRationalLoss {
private:
  virtual Expr accumulateLoss(const RationalLoss& current) override {
    // Normalize the incoming loss by its own label count.
    Expr normalized = current.loss() / current.count();
    return loss_ ? loss_ + normalized : normalized;
  }

  virtual Expr accumulateCount(const RationalLoss& current) override {
    if(count_)
      return count_;  // keep the existing '1'
    // Labels are already factored into loss_, so the denominator is just '1'.
    return current.count()->graph()->ones({1}, current.loss()->value_type());
  }

public:
  MeanMultiRationalLoss() : MultiRationalLoss() {}
  MeanMultiRationalLoss(const RationalLoss& rl) : MultiRationalLoss(rl) {}
};
Ptr<MultiRationalLoss> newMultiLoss(Ptr<Options> options);
//***********************************************************************************//
// This needs to be refactored. Currently easiest route for backwards compat, but
// still feels somewhat hacky.
// Base class for losses that are computed per label and then reduced over a
// configurable set of axes into a RationalLoss (loss sum / label count).
class LabelwiseLoss {
protected:
  std::vector<int> axes_;  // axes to reduce over in reduce() below

  // Computes the per-label loss tensor; implemented by subclasses.
  virtual Expr compute(Logits logits,
                       const Words& labels,
                       Expr mask = nullptr,
                       Expr labelWeights = nullptr)
      = 0;

  // Label counts are available: reduce them together with the loss so
  // numerator and denominator stay aligned.
  RationalLoss reduce(Expr loss, Expr labels) {
    ABORT_IF(!loss, "Loss has not been computed");
    ABORT_IF(!labels, "Labels have not been computed");

    Expr lossSum = cast(loss, Type::float32);      // accumulate in float32
    Expr labelsSum = cast(labels, Type::float32);  // accumulate in float32
    for(int axis : axes_) {  // range-for: avoids signed/unsigned comparison of the old index loop
      lossSum = sum(lossSum, axis);
      labelsSum = sum(labelsSum, axis);
    }

    return RationalLoss(lossSum, labelsSum);
  }

  // Label counts are not available: assume every tensor element corresponds
  // to label count 1, so the count is the total reduction factor.
  RationalLoss reduce(Expr loss) {
    ABORT_IF(!loss, "Loss has not been computed");

    Expr lossSum = cast(loss, Type::float32);
    for(int axis : axes_)
      lossSum = sum(lossSum, axis);

    // reduction factor tells over how many labels we reduced in total.
    float reducedLabels = (float)loss->shape().elements() / (float)lossSum->shape().elements();
    return RationalLoss(lossSum, reducedLabels);
  }

public:
  LabelwiseLoss(const std::vector<int>& axes) : axes_(axes) {}

  // Computes the loss and reduces it; if a mask is given it doubles as an
  // element-wise label count (with broadcasting), otherwise all items count.
  virtual RationalLoss apply(Logits logits,
                             const Words& labels,
                             Expr mask = nullptr,
                             Expr labelWeights = nullptr) {
    Expr loss = compute(logits, labels, mask, labelWeights);

    if(mask)
      return reduce(loss, mask);  // mask can be used as element-wise label count with broadcasting
    else
      return reduce(loss);  // we have no mask, assume all items are labels
  }
};
// Standard (possibly factored) cross-entropy loss with optional label
// smoothing and an optional extra weight applied to factor losses.
class CrossEntropyLoss : public LabelwiseLoss {
public:
  CrossEntropyLoss(float labelSmoothing, float factorWeight)
      : CrossEntropyLoss(/*axes=*/{-2, -3}, labelSmoothing, factorWeight) {
  }  // cross-entropy already reduces over axis -1

  CrossEntropyLoss(const std::vector<int>& axes, float labelSmoothing, float factorWeight)
      : LabelwiseLoss(axes),  // cross-entropy already reduces over axis -1
        labelSmoothing_(labelSmoothing),
        factorWeight_(factorWeight) {}

  virtual ~CrossEntropyLoss() {}

protected:
  float labelSmoothing_;  // interpolation factor for label smoothing, see below
  float factorWeight_;    // give extra weight to factors

  // Computes the per-label CE loss. If logits are factored, the callback runs
  // once per factor group and the per-group losses are summed up.
  virtual Expr compute(Logits logits,
                       const Words& labels,
                       Expr mask = nullptr,
                       Expr labelWeights = nullptr) override {
    // logits may be factored; in that case, the getLoss() function computes one loss for each,
    // and sums them up.
    // Fixed: was `int inFactor = false;` — a boolean stored in an int.
    bool inFactor = false;
    // Lambda parameter renamed from `logits` to avoid shadowing the outer Logits argument.
    auto ce = logits.applyLossFunction(labels, [&](Expr factorLogits, Expr indices) {
      factorLogits = atleast_3d(factorLogits);  // we always assume a time and batch dimension exists.
      // for bert training or classification the time dimension is lost.
      // Here safeguard against 2d classifier output, adds 1 on the left, non-op.
      // Only the first (lemma) group gets label smoothing; factor groups do not.
      Expr ce = cross_entropy(factorLogits, indices, inFactor ? 0.f : labelSmoothing_, Type::float32);
      if(inFactor && factorWeight_ != 1.0f) {
        LOG_ONCE(info, "scaling factor losses with weight {}", factorWeight_);
        ce = ce * factorWeight_;
      }
      inFactor = true;
      return ce;
    });

    if(mask)
      ce = ce * cast(mask, Type::float32);

    if(labelWeights) {
      // We currently do not know how to use target factors and word-level label weights together
      bool wordlevel = labelWeights->shape()[-3]
                       > 1;  // Time-dimension is not trivially 1, hence we have word-level weights.
      ABORT_IF(wordlevel && logits.getNumFactorGroups() > 1,
               "CE loss with word-level label weights is not implemented for factors");
      ce = ce * cast(labelWeights, Type::float32);
    }

    return ce;
  }
};
// Mixed cross-entropy (CE) / unlikelihood (UL) loss following the "mixed"
// training schedule of https://arxiv.org/abs/1908.04319: batch entries
// without error-marked labels train with CE, entries with errors train
// only on the errors with UL.
class SequenceUnlikelihoodLoss : public CrossEntropyLoss {
public:
  SequenceUnlikelihoodLoss(float labelSmoothing, float factorWeight)
      : CrossEntropyLoss(labelSmoothing, factorWeight) {
  }  // cross-entropy already reduces over axis -1

  SequenceUnlikelihoodLoss(const std::vector<int>& axes, float labelSmoothing, float factorWeight)
      : CrossEntropyLoss(axes, labelSmoothing, factorWeight) {}

protected:
  // Computes the mixed CE/UL loss. Here labelWeights carry error scores per
  // label: 1 marks GOOD, 0 marks BAD. Without labelWeights this degenerates
  // to plain cross-entropy (e.g. during validation).
  virtual Expr compute(Logits logits,
                       const Words& labels,
                       Expr mask = nullptr,
                       Expr labelWeights = nullptr) override {
    auto ce = CrossEntropyLoss::compute(
        logits, labels, mask, /*labelWeights=*/nullptr);  // don't pass label-weights to CE
    if(!labelWeights)
      return ce;  // for validation, @TODO: maybe put rather abort or LOG_ONCE(warn, ...)?

    // We currently do not know how to use target factors and word-level label weights together
    ABORT_IF(logits.getNumFactorGroups() > 1, "Unlikelihood loss is not implemented for factors");

    ABORT_IF(!mask, "mask is required");  // @TODO: check this, it seems weights for padding are by
                                          // default 1, which would make this obsolete.

    // use label weights, where 1 is GOOD and 0 is BAD. After inversion here, now 1 marks BAD; mask
    // again to eliminate padding (might be obsolete)
    auto errorMask = (1.f - cast(labelWeights, Type::float32)) * cast(mask, Type::float32);

    // per-label unlikelihood loss over the same labels
    auto ceUl = logits.applyLossFunction(labels, [&](Expr logits, Expr indices) {
      return cast(unlikelihood(logits, indices), Type::float32);
    });

    // compute if want to use CE or UL. If there are no errors train with CE, otherwise train _only
    // on_ the errors with UL. This is the "mixed" training schedule from
    // https://arxiv.org/abs/1908.04319. Providing labels with or without error scores we can easily
    // switch between CE and UL.
    auto onlyCe = eq(sum(errorMask, /*axis=*/-3),
                     0.f);  // [1, 1, dimBatch, 1] - equal 1 if no errors are present

    ceUl = errorMask * ceUl;  // don't use for correct label or padding

    auto cost = onlyCe * ce + (1.f - onlyCe) * ceUl;  // ce or unlikelihood part are never
                                                      // simultanously used as cost per batch entry

    return cost;
  }
};
// Cross-entropy in rescoring mode: computes sentence-level (reduced over the
// time axis) or word-level (unreduced) cross-entropy scores.
class RescorerLoss : public CrossEntropyLoss {
private:
  bool wordScores_{false};  // compute word-level log probabilities

public:
  // For sentence-wise CE reduce only over time axis.
  // For word-level CE do not reduce over any axis.
  RescorerLoss(bool wordScores)
      : CrossEntropyLoss(/*axes=*/wordScores ? std::vector<int>({}) : std::vector<int>({-3}),
                         /*smoothing=*/0.f,
                         /*factorWeight=*/1.0f),
        wordScores_(wordScores) {}

  // Returns CE loss over label count. For word-level scoring the loss stays
  // unreduced while the count holds per-sentence lengths.
  virtual RationalLoss apply(Logits logits,
                             const Words& labels,
                             Expr mask = nullptr,
                             Expr labelWeights = nullptr) override {
    auto loss = CrossEntropyLoss::compute(logits, labels, mask, labelWeights);

    if(!wordScores_) {  // for sentence-level CE, reduce loss and labels as in cross-entropy
      return reduce(loss, mask);
    } else {  // for word-level CE, reduce labels only to get sentence lengths
      ABORT_IF(!loss, "Loss has not been computed");
      ABORT_IF(!mask, "Word-level CE from rescorer must have mask");
      Expr labelsSum = cast(mask, Type::float32);  // accumulate in float32
      labelsSum = sum(labelsSum, -3);  // reduce over time axis to get sentence lengths
      return RationalLoss(loss, labelsSum);
    }
  }
};
Ptr<LabelwiseLoss> newLoss(Ptr<Options> options, bool inference);
} // namespace marian