// Program Listing for File batch_stats.h
// Return to documentation for file (src/data/batch_stats.h)
#pragma once
#include <cmath>
#include <deque>
#include <iostream>
#include <map>
#include <queue>
#include <vector>
#include "data/corpus.h"
#include "data/vocab.h"
namespace marian {
namespace data {
// Table mapping a tuple of per-stream sentence lengths to a batch size.
// Entries are recorded with add() (keeping the largest batch size seen for each length tuple)
// and queried with findBatchSize(). flatten() and the vector constructor convert the table
// to and from a flat std::vector<size_t>.
class BatchStats {
private:
  std::map<std::vector<size_t>, size_t> map_; // [(src len, tgt len)] -> batch size

public:
  // Creates an empty table.
  BatchStats() { }

  using const_iterator = std::map<std::vector<size_t>, size_t>::const_iterator;

  // Iterator to the first entry (smallest key in lexicographic order).
  const_iterator begin() const { return map_.begin(); }

  // Iterator to the first entry whose key is not lexicographically less than 'lengths'.
  const_iterator lower_bound(const std::vector<size_t>& lengths) const { return map_.lower_bound(lengths); }

  // Returns the batch size of the first entry, at or after 'it', where all item.first[i] >= lengths[i],
  // i.e. that can fit sentence tuples of lengths[].
  // This is expected to be called multiple times with increasing sentence lengths.
  // To get an initial value for 'it', call lower_bound() or begin().
  // 'it' is advanced in place to the entry that was found, so it can be passed to the next call.
  // Aborts with "Missing batch statistics" if no such entry exists.
  // Every key reached must have at least lengths.size() elements; this is not checked.
  size_t findBatchSize(const std::vector<size_t>& lengths, const_iterator& it) const {
    bool done = false;
    while(!done && it != map_.end()) {
      done = true;
      for(size_t i = 0; i < lengths.size(); ++i)
        while(it != map_.end() && it->first[i] < lengths[i]) {
          ++it;
          done = false; // ++it might have decreased a key[<i], so we must check once again
        }
    }
    ABORT_IF(it == map_.end(), "Missing batch statistics");
    return it->second;
  }

  // Records a batch: the key is the batchWidth() of each of the batch's sets, the value is
  // ceil(batch->size() * multiplier). If the key already exists, the larger value is kept.
  void add(Ptr<data::CorpusBatch> batch, double multiplier = 1.) {
    std::vector<size_t> lengths;
    lengths.reserve(batch->sets());
    for(size_t i = 0; i < batch->sets(); ++i)
      lengths.push_back((*batch)[i]->batchWidth());

    const size_t batchSize = static_cast<size_t>(std::ceil(static_cast<double>(batch->size()) * multiplier));

    // operator[] value-initializes a missing entry to 0, so one lookup covers insert and update
    size_t& recorded = map_[lengths];
    if(recorded < batchSize)
      recorded = batchSize;
  }

  // return a rough minibatch size in labels
  // We average over all (batch sizes * max trg length).
  // The last element of each key is taken as the max trg length.
  // Returns 0 if no statistics have been recorded.
  size_t estimateTypicalTrgWords() const {
    if(map_.empty()) // nothing recorded; the average below would divide by zero
      return 0;
    size_t sum = 0;
    for(const auto& entry : map_) {
      if(entry.first.empty()) // a key without streams has no target length; back() would be undefined
        continue;
      const size_t maxTrgLength = entry.first.back();
      const size_t numSentences = entry.second;
      sum += numSentences * maxTrgLength;
    }
    return sum / map_.size();
  }

  // helpers for multi-node --note: presently unused, but keeping them around for later use
  // serialize into a flat vector, for MPI data exchange
  // An empty table yields an empty vector. Aborts if the keys differ in their number of streams.
  std::vector<size_t> flatten() const {
    std::vector<size_t> res;
    if(map_.empty())
      return res;
    const size_t numStreams = map_.begin()->first.size();
    // format:
    //  - num streams
    //  - tuples ((stream sizes), batch size)
    res.reserve(1 + map_.size() * (numStreams + 1));
    res.push_back(numStreams);
    for(const auto& entry : map_) {
      ABORT_IF(entry.first.size() != numStreams, "inconsistent number of streams??");
      for(auto streamLen : entry.first)
        res.push_back(streamLen);
      res.push_back(entry.second);
    }
    return res;
  }

  // deserialize a flattened batchStats
  // used as part of MPI data exchange
  // Expects the layout written by flatten(); an empty vector yields an empty table.
  // Aborts with "invalid flattenedVector??" if the vector ends in the middle of a tuple.
  BatchStats(const std::vector<size_t>& flattenedStats) {
    if(flattenedStats.empty())
      return;
    size_t i = 0;
    const size_t numStreams = flattenedStats[i++];
    std::vector<size_t> lengths(numStreams);
    while(i < flattenedStats.size()) {
      // A complete tuple needs numStreams lengths plus one batch size. Validate before reading so
      // a truncated vector is rejected instead of being read out of bounds. The comparison is
      // written this way so that numStreams + 1 is never computed and cannot overflow.
      ABORT_IF(flattenedStats.size() - i <= numStreams, "invalid flattenedVector??");
      for(auto& length : lengths)
        length = flattenedStats[i++];
      const size_t batchSize = flattenedStats[i++];
      map_[lengths] = batchSize;
    }
    //dump();
  }

  // Prints every entry to std::cerr as "<len> <len> ... : <batch size>", one per line.
  void dump() const { // (for debugging)
    for(const auto& entry : map_) {
      for(auto streamLen : entry.first)
        std::cerr << streamLen << " ";
      std::cerr << ": " << entry.second << std::endl;
    }
  }
};
} // namespace data
} // namespace marian