Program Listing for File scheduling_parameter.h

Return to documentation for file (src/common/scheduling_parameter.h)

#pragma once

#include "common/logging.h"
#include "common/utils.h"

#include <string>

namespace marian {

// support for scheduling parameters that can be expressed with a unit, such as --lr-decay-inv-sqrt
enum class SchedulingUnit {
  trgLabels, // "t": number of target labels seen so far
  updates,   // "u": number of updates so far (batches)
  epochs     // "e": number of epochs begun so far (very first epoch is 1)
};

struct SchedulingParameter {
  size_t n{0};                                  // number of steps measured in 'unit'
  SchedulingUnit unit{SchedulingUnit::updates}; // unit of value

  // parses scheduling parameters of the form NU where N=unsigned int and U=unit
  // Examples of valid inputs: "16000u" (16000 updates), "32000000t" (32 million target labels),
  // "100e" (100 epochs).
  static SchedulingParameter parse(std::string param) {
    SchedulingParameter res;
    if(!param.empty() && param.back() >= 'a') {
      switch(param.back()) {
        case 't': res.unit = SchedulingUnit::trgLabels; break;
        case 'u': res.unit = SchedulingUnit::updates;   break;
        case 'e': res.unit = SchedulingUnit::epochs;    break;
        default: ABORT("invalid unit '{}' in {}", param.back(), param);
      }
      param.pop_back();
    }
    double number = utils::parseNumber(param);
    res.n = (size_t)number;
    ABORT_IF(number != (double)res.n, "Scheduling parameters must be whole numbers"); // @TODO: do they?
    return res;
  }

  operator bool() const { return n > 0; } // check whether it is specified

  operator std::string() const { // convert back for storing in config
    switch(unit) {
      case SchedulingUnit::trgLabels: return std::to_string(n) + "t";
      case SchedulingUnit::updates  : return std::to_string(n) + "u";
      case SchedulingUnit::epochs   : return std::to_string(n) + "e";
      default: ABORT("corrupt enum value for scheduling unit");
    }
  }
};

}