Program Listing for File weight.cpp
Return to documentation for file (src/layers/weight.cpp)
#include "layers/weight.h"

#include <cstddef>
namespace marian {

/**
 * Creates the data-weighting object configured in @p options.
 *
 * Aborts unless the "data-weighting" option is present and non-empty. The
 * weighting type handed to DataWeighting is read from the separate
 * "data-weighting-type" option.
 *
 * @param options option set to read the data-weighting settings from
 * @return a DataWeighting instance, returned through the WeightingBase interface
 */
Ptr<WeightingBase> WeightingFactory(Ptr<Options> options) {
  ABORT_IF(!options->hasAndNotEmpty("data-weighting"),
           "No data-weighting specified in options");
  return New<DataWeighting>(options->get<std::string>("data-weighting-type"));
}

/**
 * Wraps the data weights carried by @p batch into a constant graph node.
 *
 * When weightingType_ is "sentence", one weight per sentence is expected
 * (dimWords = 1). Any other weighting type is treated as word-level, with
 * dimWords = batch->back()->batchWidth() weights per sentence.
 *
 * Aborts if the batch carries no weights, or if the number of weights differs
 * from dimWords * dimBatch (dimBatch = batch->size()).
 *
 * @param graph expression graph the constant is created in
 * @param batch batch providing the weights via getDataWeights()
 * @return constant of shape [1, dimWords, dimBatch, 1]
 */
Expr DataWeighting::getWeights(Ptr<ExpressionGraph> graph,
                               Ptr<data::CorpusBatch> batch) {
  // Query the batch once and reuse the result for every check below.
  const auto& dataWeights = batch->getDataWeights();
  ABORT_IF(dataWeights.empty(), "Vector of weights is unexpectedly empty!");

  const bool sentenceWeighting = weightingType_ == "sentence";
  const int dimBatch = static_cast<int>(batch->size());
  const int dimWords =
      sentenceWeighting ? 1 : static_cast<int>(batch->back()->batchWidth());

  // The expected count is computed in size_t so that the comparison with
  // size() is unsigned-vs-unsigned and the product cannot overflow an int.
  // A mismatch would abort anyway in fromVector(...), but this check gives a
  // clearer error message for this particular case.
  const size_t expectedWeights =
      static_cast<size_t>(dimWords) * static_cast<size_t>(dimBatch);
  ABORT_IF(dataWeights.size() != expectedWeights,
           "Number of sentence/word-level weights ({}) does not match tensor size ({})",
           dataWeights.size(), expectedWeights);

  // [1, dimWords, dimBatch, 1] for word-level weights,
  // [1, 1, dimBatch, 1] for sentence-level weights.
  return graph->constant({1, dimWords, dimBatch, 1},
                         inits::fromVector(dataWeights));
}

} // namespace marian