Program Listing for File loss.cpp

Return to documentation for file (src/layers/loss.cpp)

#include "layers/loss.h"

namespace marian {

// @TODO, simplify this. Currently here for back-compat
//
// Creates the label-wise loss selected by the given options:
//  - "cost-type" == "ce-rescore": RescorerLoss, configured with the "word-scores" flag
//  - "unlikelihood-loss" set:     SequenceUnlikelihoodLoss; aborts unless word-level data
//                                 weighting is configured (see check below)
//  - anything else:               CrossEntropyLoss
// When `inference` is true, label smoothing is forced to 0 and "label-smoothing" is not read.
Ptr<LabelwiseLoss> newLoss(Ptr<Options> options, bool inference) {
  float smoothing = inference ? 0.f : options->get<float>("label-smoothing");
  float factorWeight = options->get<float>("factor-weight", 1.0f);
  std::string costType = options->get<std::string>("cost-type", "ce-mean");
  bool unlikelihood = options->get<bool>("unlikelihood-loss", false);

  if(costType == "ce-rescore") {  // per-batch-item scores (while ce-mean reduces over batch)
    bool wordScores = options->get<bool>("word-scores", false);
    return New<RescorerLoss>(wordScores);
  } else if(unlikelihood) {
    // Per-target-label scores are only available when data weighting is enabled AND it is
    // word-level. Both conditions must hold, so abort if either one is violated. The type is
    // only queried once we know that data weighting has been specified at all.
    const bool hasWordLevelWeights
        = options->hasAndNotEmpty("data-weighting")
          && options->get<std::string>("data-weighting-type") == "word";
    ABORT_IF(
        !hasWordLevelWeights,
        "Unlikelihood loss training requires error annotation in form of per-target-label scores");
    return New<SequenceUnlikelihoodLoss>(
        smoothing, factorWeight);  // this is a mix of CE-loss and unlikelihood loss depending on
                                   // values given for data-weighting
  } else {  // same as ce-mean  --@TODO: better check all allowed values, and fail for invalid ones.
            // E.g. what about ce-sum?
    return New<CrossEntropyLoss>(smoothing, factorWeight);
  }
}

// Creates the multi-loss combiner named by "multi-loss-type" (defaults to "sum").
// See loss.h for detailed explanations of each class.
Ptr<MultiRationalLoss> newMultiLoss(Ptr<Options> options) {
  const std::string kind = options->get<std::string>("multi-loss-type", "sum");

  // sum of sums
  if(kind == "sum")
    return New<SumMultiRationalLoss>();

  // sum of scaled sums, first element is reference scale
  if(kind == "scaled")
    return New<ScaledMultiRationalLoss>();

  // sum of means
  if(kind == "mean")
    return New<MeanMultiRationalLoss>();

  // none of the known types matched
  ABORT("Unknown multi-loss-type {}", kind);
}

}  // namespace marian