Program Listing for File weight.h

Return to documentation for file (src/layers/weight.h)

#pragma once

#include "common/options.h"
#include "data/corpus.h"
#include "graph/expression_graph.h"
#include "graph/expression_operators.h"
#include "graph/node_initializers.h"

#include <utility>

namespace marian {

/**
 * Abstract interface for objects that build a weight expression for a batch.
 *
 * Concrete subclasses must implement getWeights(); debugWeighting() is an
 * optional hook whose default implementation does nothing.
 */
class WeightingBase {
public:
  WeightingBase() = default;

  /**
   * Produce the weight expression for the given batch.
   *
   * @param graph expression graph handed to the implementation
   * @param batch corpus batch handed to the implementation
   * @return the resulting weights as an Expr
   */
  virtual Expr getWeights(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch) = 0;

  /**
   * Optional debugging hook; the base version ignores all arguments.
   * Subclasses may override it.
   */
  virtual void debugWeighting(std::vector<float> /*weightedMask*/,
                              std::vector<float> /*freqMask*/,
                              Ptr<data::CorpusBatch> /*batch*/) {}

  // Virtual so that derived objects can be destroyed through a base pointer.
  virtual ~WeightingBase() = default;
};

/**
 * WeightingBase implementation that is configured with a weighting-type
 * string. getWeights() is only declared here; its definition lives elsewhere.
 */
class DataWeighting : public WeightingBase {
protected:
  // Weighting type name supplied at construction.
  // NOTE(review): the set of accepted values is not visible in this header —
  // confirm against the getWeights() definition.
  std::string weightingType_;

public:
  /**
   * @param weightingType name of the weighting scheme to store; taken by
   *        value and moved into the member so the string is not copied twice.
   */
  DataWeighting(std::string weightingType)
      : WeightingBase(), weightingType_(std::move(weightingType)) {}

  /**
   * Build the weight expression for a batch (see WeightingBase::getWeights).
   *
   * @param graph expression graph handed to the implementation
   * @param batch corpus batch handed to the implementation
   * @return the resulting weights as an Expr
   */
  Expr getWeights(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch) override;
};

/**
 * Create a weighting object from the given options (defined elsewhere).
 *
 * NOTE(review): which option keys are read and which concrete WeightingBase
 * subclass is returned cannot be determined from this header — confirm
 * against the definition.
 *
 * @param options configuration used to construct the weighting object
 * @return pointer to the created WeightingBase instance
 */
Ptr<WeightingBase> WeightingFactory(Ptr<Options> options);
}  // namespace marian