Program Listing for File loss.h

Return to documentation for file (src/layers/loss.h)

#pragma once

#include "data/types.h"
#include "graph/expression_operators.h"
#include "layers/logits.h"  // for Logits (Frank's factor hack)

namespace marian {

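// Base class for losses expressed as a ratio: a loss value (numerator) and a
// label count (denominator), both kept as graph expressions. The effective
// per-label loss is loss/count; keeping the two separate allows correct
// accumulation across batches before dividing.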
class RationalLoss {
protected:
  Expr loss_;   // numerator
  Expr count_;  // denominator

  RationalLoss() = default;  // protected

public:
  RationalLoss(Expr loss, Expr count) : loss_(loss), count_(count) {}

  RationalLoss(Expr loss, float count)
      : loss_(loss), count_(constant_like(loss, inits::fromValue(count))) {}

  RationalLoss(const RationalLoss& other) : loss_(other.loss_), count_(other.count_) {}

  virtual ~RationalLoss() = default;

  Expr loss() const { return loss_; }

  // @TODO: remove this function, as it does not add too much value over loss()->get(...)
  template <typename T>
  void loss(std::vector<T>& losses) const {
    ABORT_IF(!loss_, "Loss has not been defined");
    loss_->val()->get(losses);
  }

  template <typename T>
  T loss() const {  // this will fail if loss is not a single value
    ABORT_IF(!loss_, "Loss has not been defined");
    return loss_->val()->scalar<T>();
  }

  Expr count() const { return count_; }

  // @TODO: remove this function, as it does not add too much value over count()->get(...)
  template <typename T>
  void count(std::vector<T>& labels) const {
    ABORT_IF(!count_, "Labels have not been defined");
    count_->val()->get(labels);
  }

  template <typename T>
  T count() const {  // this will fail if count is not a single value
    ABORT_IF(!count_, "Labels have not been defined");
    return count_->val()->scalar<T>();
  }

  // @TODO: maybe add a function for returning the ratio?

  size_t size() const {
    ABORT_IF(!count_, "Labels have not been defined");
    return count_->shape().elements();
  }
};

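// Plain-float counterpart of RationalLoss, detached from the expression graph;
// used to accumulate loss and label-count statistics across updates.
// A minimal usage sketch (assumes the graph has been executed so that the
// underlying loss and count expressions hold scalar values):
//   StaticLoss total;
//   total += StaticLoss(rationalLoss);  // converts via loss<float>() and count<float>()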
struct StaticLoss {
  float loss;   // numerator
  float count;  // denominator

  StaticLoss() : loss(0.f), count(0.f) {}

  StaticLoss(const RationalLoss& dynamic)
      : loss(dynamic.loss<float>()), count(dynamic.count<float>()) {}

  StaticLoss operator+(const StaticLoss& other) const {
    StaticLoss res(*this);
    res += other;
    return res;
  }

  StaticLoss& operator+=(const StaticLoss& other) {
    loss = loss + other.loss;
    count = count + other.count;
    return *this;
  }

  void reset() {
    loss = 0.f;
    count = 0.f;
  }
};

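// Base class for combining multiple partial losses (e.g. one per objective or
// factor group) into a single RationalLoss. Derived classes define how the
// running loss and label count are accumulated when a partial loss is pushed.
// A minimal sketch, assuming ceLoss and labelCount are existing expressions:
//   SumMultiRationalLoss multi;
//   multi.push_back(RationalLoss(ceLoss, labelCount));  // updates loss_ and count_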
class MultiRationalLoss : public RationalLoss {
protected:
  std::vector<RationalLoss> partialLosses_;

  virtual Expr accumulateLoss(const RationalLoss& current) = 0;

  virtual Expr accumulateCount(const RationalLoss& current) = 0;

public:
  MultiRationalLoss() : RationalLoss() {}

  MultiRationalLoss(const RationalLoss& rl) : RationalLoss() { push_back(rl); }

  virtual void push_back(const RationalLoss& current) {
    loss_ = accumulateLoss(current);
    count_ = accumulateCount(current);
    partialLosses_.push_back(current);
  }

  const RationalLoss& operator[](size_t i) { return partialLosses_[i]; }

  auto begin() const -> decltype(partialLosses_.begin()) { return partialLosses_.begin(); }

  auto end() const -> decltype(partialLosses_.end()) { return partialLosses_.end(); }

  size_t size() const { return partialLosses_.size(); }
};

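// Accumulates by plain summation: loss_ is the sum of the partial losses and
// count_ is the sum of the partial label counts.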
class SumMultiRationalLoss : public MultiRationalLoss {
private:
  virtual Expr accumulateLoss(const RationalLoss& current) override {
    if(loss_)
      return loss_ + current.loss();
    else
      return current.loss();
  }

  virtual Expr accumulateCount(const RationalLoss& current) override {
    if(count_)
      return count_ + current.count();
    else
      return current.count();
  }

public:
  SumMultiRationalLoss() : MultiRationalLoss() {}
  SumMultiRationalLoss(const RationalLoss& rl) : MultiRationalLoss(rl) {}
};

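// Scales each additional loss to the label count of the first (reference) loss:
// loss_ = L1 + sum_i Li * N1 / Ni for i > 1, while count_ = N1 is kept as the
// denominator.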
class ScaledMultiRationalLoss : public MultiRationalLoss {
private:
  virtual Expr accumulateLoss(const RationalLoss& current) override {
    if(loss_) {
      const auto& first = partialLosses_.front();
      return loss_
             + current.loss() * first.count()
                   / current.count();  // scale up/down to match scale of first loss
    } else {
      return current.loss();  // first loss serves as the reference; later losses are scaled to its label count
    }
  }

  virtual Expr accumulateCount(const RationalLoss& current) override {
    if(count_) {
      return count_;  // Keep first label count // or: count_ + first.count() / current.count();
    } else {
      return current.count();  // This is the first loss
    }
  }

public:
  ScaledMultiRationalLoss() : MultiRationalLoss() {}
  ScaledMultiRationalLoss(const RationalLoss& rl) : MultiRationalLoss(rl) {}
};

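// Sums per-label means: loss_ = sum_i Li / Ni, with a constant 1 kept as the
// denominator since the label counts are already folded into the loss.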
class MeanMultiRationalLoss : public MultiRationalLoss {
private:
  virtual Expr accumulateLoss(const RationalLoss& current) override {
    if(loss_)
      return loss_ + current.loss() / current.count();
    else
      return current.loss() / current.count();
  }

  virtual Expr accumulateCount(const RationalLoss& current) override {
    if(count_)
      return count_;  // keep the existing '1'
    else
      return current.count()->graph()->ones(
          {1}, current.loss()->value_type());  // just '1' as labels are factored into loss_
  }

public:
  MeanMultiRationalLoss() : MultiRationalLoss() {}
  MeanMultiRationalLoss(const RationalLoss& rl) : MultiRationalLoss(rl) {}
};

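// Factory: creates one of the MultiRationalLoss variants above (sum, scaled or
// mean) according to the given options.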
Ptr<MultiRationalLoss> newMultiLoss(Ptr<Options> options);

//***********************************************************************************//
// This needs to be refactored. It is currently the easiest route for backwards
// compatibility, but still feels somewhat hacky.

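// Base class for losses computed per label position. Derived classes implement
// compute(), which returns an element-wise loss; reduce() then sums the loss
// and the label counts over the configured axes to form a RationalLoss.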
class LabelwiseLoss {
protected:
  std::vector<int> axes_;

  virtual Expr compute(Logits logits,
                       const Words& labels,
                       Expr mask = nullptr,
                       Expr labelWeights = nullptr)
      = 0;

  // label counts are available, reduce together with loss to obtain counts
  RationalLoss reduce(Expr loss, Expr labels) {
    ABORT_IF(!loss, "Loss has not been computed");
    ABORT_IF(!labels, "Labels have not been computed");

    Expr lossSum = cast(loss, Type::float32);      // accumulate in float32
    Expr labelsSum = cast(labels, Type::float32);  // accumulate in float32
    for(int i = 0; i < axes_.size(); ++i) {
      lossSum = sum(lossSum, axes_[i]);
      labelsSum = sum(labelsSum, axes_[i]);
    }

    return RationalLoss(lossSum, labelsSum);
  }

  // label counts are not available, assume every element of tensor corresponds to label count 1
  RationalLoss reduce(Expr loss) {
    ABORT_IF(!loss, "Loss has not been computed");

    Expr lossSum = cast(loss, Type::float32);
    for(int i = 0; i < axes_.size(); ++i)
      lossSum = sum(lossSum, axes_[i]);

    // the reduction factor tells over how many labels we reduced in total.
    float reducedLabels = (float)loss->shape().elements() / (float)lossSum->shape().elements();
    return RationalLoss(lossSum, reducedLabels);
  }

public:
  LabelwiseLoss(const std::vector<int>& axes) : axes_(axes) {}

  virtual RationalLoss apply(Logits logits,
                             const Words& labels,
                             Expr mask = nullptr,
                             Expr labelWeights = nullptr) {
    Expr loss = compute(logits, labels, mask, labelWeights);

    if(mask)
      return reduce(loss, mask);  // mask can be used as element-wise label count with broadcasting
    else
      return reduce(loss);  // we have no mask, assume all items are labels
  }
};

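// Cross-entropy over (possibly factored) logits, by default summed over the
// time and batch axes, with optional label smoothing and an extra weight for
// factor losses.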
class CrossEntropyLoss : public LabelwiseLoss {
public:
  CrossEntropyLoss(float labelSmoothing, float factorWeight)
      : CrossEntropyLoss(/*axes=*/{-2, -3}, labelSmoothing, factorWeight) {
  }  // cross-entropy already reduces over axis -1

  CrossEntropyLoss(const std::vector<int>& axes, float labelSmoothing, float factorWeight)
      : LabelwiseLoss(axes),  // cross-entropy already reduces over axis -1
        labelSmoothing_(labelSmoothing),
        factorWeight_(factorWeight) {}

  virtual ~CrossEntropyLoss() {}

protected:
  float labelSmoothing_;  // interpolation factor for label smoothing, see below
  float factorWeight_;    // give extra weight to factors

  virtual Expr compute(Logits logits,
                       const Words& labels,
                       Expr mask = nullptr,
                       Expr labelWeights = nullptr) override {
    // logits may be factored; in that case, the getLoss() function computes one loss for each, and
    // sums them up
    bool inFactor = false;
    auto ce = logits.applyLossFunction(labels, [&](Expr logits, Expr indices) {
      logits = atleast_3d(logits);  // we always assume that time and batch dimensions exist.
      // For BERT training or classification the time dimension is lost.
      // This safeguards against 2d classifier output by adding a dimension of 1 on the left
      // (a no-op otherwise).

      Expr ce = cross_entropy(logits, indices, inFactor ? 0.f : labelSmoothing_, Type::float32);
      if(inFactor && factorWeight_ != 1.0f) {
        LOG_ONCE(info, "scaling factor losses with weight {}", factorWeight_);
        ce = ce * factorWeight_;
      }
      inFactor = true;
      return ce;
    });

    if(mask)
      ce = ce * cast(mask, Type::float32);

    if(labelWeights) {
      // We currently do not know how to use target factors and word-level label weights together
      bool wordlevel = labelWeights->shape()[-3]
                       > 1;  // Time-dimension is not trivially 1, hence we have word-level weights.
      ABORT_IF(wordlevel && logits.getNumFactorGroups() > 1,
               "CE loss with word-level label weights is not implemented for factors");
      ce = ce * cast(labelWeights, Type::float32);
    }

    return ce;
  }
};

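// Mixed cross-entropy/unlikelihood training as in https://arxiv.org/abs/1908.04319:
// positions marked as errors via labelWeights are trained with unlikelihood,
// while batch entries without any errors fall back to plain cross-entropy.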
class SequenceUnlikelihoodLoss : public CrossEntropyLoss {
public:
  SequenceUnlikelihoodLoss(float labelSmoothing, float factorWeight)
      : CrossEntropyLoss(labelSmoothing, factorWeight) {
  }  // cross-entropy already reduces over axis -1

  SequenceUnlikelihoodLoss(const std::vector<int>& axes, float labelSmoothing, float factorWeight)
      : CrossEntropyLoss(axes, labelSmoothing, factorWeight) {}

protected:
  virtual Expr compute(Logits logits,
                       const Words& labels,
                       Expr mask = nullptr,
                       Expr labelWeights = nullptr) override {
    auto ce = CrossEntropyLoss::compute(
        logits, labels, mask, /*labelWeights=*/nullptr);  // don't pass label-weights to CE
    if(!labelWeights)
      return ce;  // for validation; @TODO: maybe rather abort or LOG_ONCE(warn, ...)?

    // We currently do not know how to use target factors and word-level label weights together
    ABORT_IF(logits.getNumFactorGroups() > 1, "Unlikelihood loss is not implemented for factors");

    ABORT_IF(!mask, "mask is required");  // @TODO: check this, it seems weights for padding are by
                                          // default 1, which would make this obsolete.
    // use label weights, where 1 is GOOD and 0 is BAD. After inversion here, now 1 marks BAD, mask
    // again to eliminate padding (might be obsolete)
    auto errorMask = (1.f - cast(labelWeights, Type::float32)) * cast(mask, Type::float32);

    auto ceUl = logits.applyLossFunction(labels, [&](Expr logits, Expr indices) {
      return cast(unlikelihood(logits, indices), Type::float32);
    });

    // Decide whether to use CE or UL. If there are no errors, train with CE; otherwise train _only
    // on_ the errors with UL. This is the "mixed" training schedule from
    // https://arxiv.org/abs/1908.04319. By providing labels with or without error scores we can
    // easily switch between CE and UL.
    auto onlyCe = eq(sum(errorMask, /*axis=*/-3),
                     0.f);    // [1, 1, dimBatch, 1] - equals 1 if no errors are present
    ceUl = errorMask * ceUl;  // don't use for correct label or padding

    auto cost = onlyCe * ce + (1.f - onlyCe) * ceUl;  // the CE and UL parts are never used
                                                      // simultaneously as the cost for a batch entry

    return cost;
  }
};

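// Cross-entropy used for rescoring. With wordScores=false the loss is reduced
// over the time axis to yield one score per sentence; with wordScores=true the
// per-word losses are kept and only the mask is reduced to sentence lengths.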
class RescorerLoss : public CrossEntropyLoss {
private:
  bool wordScores_{false};  // compute word-level log probabilities

public:
  // For sentence-wise CE reduce only over time axis.
  // For word-level CE do not reduce over any axis.
  RescorerLoss(bool wordScores)
      : CrossEntropyLoss(/*axes=*/wordScores ? std::vector<int>({}) : std::vector<int>({-3}),
                         /*smoothing=*/0.f,
                         /*factorWeight=*/1.0f),
        wordScores_(wordScores) {}

  virtual RationalLoss apply(Logits logits,
                             const Words& labels,
                             Expr mask = nullptr,
                             Expr labelWeights = nullptr) override {
    auto loss = CrossEntropyLoss::compute(logits, labels, mask, labelWeights);

    if(!wordScores_) {  // for sentence-level CE, reduce loss and labels as in cross-entropy
      return reduce(loss, mask);
    } else {  // for word-level CE, reduce labels only to get sentence lengths
      ABORT_IF(!loss, "Loss has not been computed");
      ABORT_IF(!mask, "Word-level CE from rescorer must have mask");

      Expr labelsSum = cast(mask, Type::float32);  // accumulate in float32
      labelsSum = sum(labelsSum, -3);              // reduce over time axis to get sentence lengths
      return RationalLoss(loss, labelsSum);
    }
  }
};

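// Factory: creates a LabelwiseLoss (one of the cross-entropy variants above)
// according to the given options and to whether it is used for inference.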
Ptr<LabelwiseLoss> newLoss(Ptr<Options> options, bool inference);

}  // namespace marian