// Program Listing for File corpus_base.h
// Return to documentation for file (src/data/corpus_base.h)
#pragma once
#include "common/definitions.h"
#include "common/file_stream.h"
#include "common/options.h"
#include "common/utils.h"
#include "data/alignment.h"
#include "data/iterator_facade.h"
#include "data/batch.h"
#include "data/dataset.h"
#include "data/rng_engine.h"
#include "data/vocab.h"
#include <future>
namespace marian {
namespace data {
/**
 * @brief Value type holding one sentence tuple: one word sequence per stream,
 * plus optional per-stream weights and a word alignment.
 *
 * The class deliberately declares no destructor and no copy/move operations
 * (Rule of Zero). A user-provided destructor, even an empty one, suppresses the
 * implicitly generated move constructor and move assignment, which would turn
 * every intended move of a tuple into a deep copy of its vectors.
 */
class SentenceTupleImpl {
private:
  size_t id_;                   // identifier passed at construction (0 for the default constructor)
  std::vector<Words> tuple_;    // [stream index][step index]
  std::vector<float> weights_;  // [stream index]
  WordAlignment alignment_;
  bool altered_ = false;        // set by markAltered(); never reset inside this class

public:
  using value_type = Words;

  /** @brief Creates an empty tuple with id 0. */
  SentenceTupleImpl() : id_(0) {}

  /** @brief Creates an empty tuple with the given id. Kept non-explicit for backward compatibility. */
  SentenceTupleImpl(size_t id) : id_(id) {}

  /** @return the id given at construction. */
  size_t getId() const { return id_; }

  /** @return true once markAltered() has been called on this tuple. */
  bool isAltered() const { return altered_; }

  /** @brief Flags the tuple as altered; what counts as "altered" is decided by the caller. */
  void markAltered() { altered_ = true; }

  /** @brief Appends the word sequence of the next stream (copied). */
  void push_back(const Words& words) { tuple_.push_back(words); }

  /** @return number of streams stored so far. */
  size_t size() const { return tuple_.size(); }

  /** @return word sequence of stream @p i; no bounds checking is performed. */
  Words& operator[](size_t i) { return tuple_[i]; }
  const Words& operator[](size_t i) const { return tuple_[i]; }

  /** @return word sequence of the last stream; the tuple must not be empty. */
  Words& back() { return tuple_.back(); }
  const Words& back() const { return tuple_.back(); }

  /** @return true if no stream has been added yet. */
  bool empty() const { return tuple_.empty(); }

  // Read-only iteration over the streams (these yield const iterators).
  auto begin() const -> decltype(tuple_.begin()) { return tuple_.begin(); }
  auto end() const -> decltype(tuple_.end()) { return tuple_.end(); }
  auto rbegin() const -> decltype(tuple_.rbegin()) { return tuple_.rbegin(); }
  auto rend() const -> decltype(tuple_.rend()) { return tuple_.rend(); }

  /** @return the stored weights (empty unless setWeights() stored some). */
  const std::vector<float>& getWeights() const { return weights_; }

  /** @brief Stores weights for this tuple; defined outside this header. */
  void setWeights(const std::vector<float>& weights);

  /** @return the stored word alignment. */
  const WordAlignment& getAlignment() const { return alignment_; }

  /** @brief Replaces the stored word alignment with a copy of @p alignment. */
  void setAlignment(const WordAlignment& alignment) { alignment_ = alignment; }
};
/**
 * @brief Handle to a SentenceTupleImpl that is either already available or
 * still pending behind a std::future.
 *
 * All accessors forward to get(), which resolves the future on first use and
 * caches the result. Copies of a SentenceTuple share both the future wrapper
 * and the cached implementation through shared pointers. No synchronization is
 * performed in this class.
 */
class SentenceTuple {
private:
  std::shared_ptr<std::future<SentenceTupleImpl>> fImpl_;  // pending result, if constructed from a future
  mutable std::shared_ptr<SentenceTupleImpl> impl_;        // resolved tuple; filled lazily by get()

public:
  using value_type = Words;

  /** @brief Creates a handle with neither a tuple nor a future; valid() is false. */
  SentenceTuple() {}

  /** @brief Wraps a copy of an already constructed tuple. */
  SentenceTuple(const SentenceTupleImpl& tupImpl)
      : impl_(std::make_shared<SentenceTupleImpl>(tupImpl)) {}

  /** @brief Takes over a future that will deliver the tuple later. */
  SentenceTuple(std::future<SentenceTupleImpl>&& fImpl)
      : fImpl_(std::make_shared<std::future<SentenceTupleImpl>>(std::move(fImpl))) {}

  /**
   * @brief Returns the underlying tuple, waiting for the future if necessary.
   *
   * Aborts if no tuple has been cached and there is no valid future to obtain
   * one from. Declared const because impl_ is a mutable cache.
   */
  SentenceTupleImpl& get() const {
    if(impl_)
      return *impl_;

    ABORT_IF(!fImpl_ || !fImpl_->valid(), "No future tuple associated with SentenceTuple");
    impl_ = std::make_shared<SentenceTupleImpl>(fImpl_->get());
    return *impl_;
  }

  // Forwarding accessors; each one may block inside get().
  size_t getId() const { return get().getId(); }
  bool isAltered() const { return get().isAltered(); }
  size_t size() const { return get().size(); }

  /** @return true if the handle holds a future wrapper or a resolved tuple. */
  bool valid() const {
    return fImpl_ != nullptr || impl_ != nullptr;
  }

  Words& operator[](size_t i) { return get()[i]; }
  const Words& operator[](size_t i) const { return get()[i]; }

  Words& back() { return get().back(); }
  const Words& back() const { return get().back(); }

  bool empty() const { return get().empty(); }

  // Read-only iteration over the streams of the resolved tuple.
  auto begin() const -> decltype(get().begin()) { return get().begin(); }
  auto end() const -> decltype(get().end()) { return get().end(); }
  auto rbegin() const -> decltype(get().rbegin()) { return get().rbegin(); }
  auto rend() const -> decltype(get().rend()) { return get().rend(); }

  const std::vector<float>& getWeights() const { return get().getWeights(); }
  const WordAlignment& getAlignment() const { return get().getAlignment(); }
};
/**
 * @brief Batch of sentences for a single stream, stored as a flat matrix.
 *
 * Element (batchIdx, wordPos) lives at index wordPos * batchSize + batchIdx,
 * see locate(). The mask is initialised to 0 everywhere; split() treats every
 * non-zero mask entry as a real word.
 */
class SubBatch {
private:
  Words indices_;            // word ids, size_ * width_ entries
  std::vector<float> mask_;  // same layout as indices_
  size_t size_;              // number of sentences
  size_t width_;             // number of word positions
  size_t words_;             // word count as last set through setWords(); 0 after construction
  Ptr<const Vocab> vocab_;   // may be null
  // ... TODO: add the length information (remember it)

public:
  /**
   * @brief Creates a sub-batch of @p size sentences and @p width positions.
   *
   * All indices are pre-filled with the vocabulary's EOS id (or Word::ZERO if
   * no vocabulary is given) and the mask is cleared.
   */
  SubBatch(size_t size, size_t width, const Ptr<const Vocab>& vocab)
      : indices_(size * width, vocab ? vocab->getEosId() : Word::ZERO),  // note: for gaps, we must use a valid index
        mask_(size * width, 0),
        size_(size),
        width_(width),
        words_(0),
        vocab_(vocab) {}

  /** @return flat array of word ids. */
  Words& data() { return indices_; }
  const Words& data() const { return indices_; }

  /** @return flat index of (batchIdx, wordPos) inside this sub-batch. */
  size_t locate(size_t batchIdx, size_t wordPos) const { return locate(batchIdx, wordPos, size_); }

  /** @return flat index of (batchIdx, wordPos) for a matrix with @p batchSize sentences. */
  static size_t locate(size_t batchIdx, size_t wordPos, size_t batchSize) { return wordPos * batchSize + batchIdx; }

  /** @return flat mask array, same layout as data(). */
  std::vector<float>& mask() { return mask_; }
  const std::vector<float>& mask() const { return mask_; }

  const Ptr<const Vocab>& vocab() const { return vocab_; }

  size_t batchSize() const { return size_; }
  size_t batchWidth() const { return width_; }
  size_t batchWords() const { return words_; }

  /**
   * @brief Splits the sub-batch along the sentence dimension.
   *
   * @param n          number of pieces to aim for; must be greater than 0
   * @param sizeLimit  if smaller than the batch size, only the first
   *                   @p sizeLimit sentences are considered (or SIZE_MAX)
   * @return pieces of ceil(size / n) sentences each (the last one may be
   *         smaller). Each piece is only as wide as its last position that
   *         has a non-zero mask entry, and has its word count set.
   *
   * Aborts if this sub-batch is empty or if @p n is 0.
   */
  std::vector<Ptr<SubBatch>> split(size_t n, size_t sizeLimit /*or SIZE_MAX*/) const {
    ABORT_IF(size_ == 0, "Encountered sub-batch size of 0");
    ABORT_IF(n == 0, "Cannot split a sub-batch into 0 pieces");

    // if limit is given then pretend the batch only has that many sentences
    const size_t size = std::min(size_, sizeLimit);

    // Integer ceiling of size / n. Done without floats so that large sizes are
    // not rounded, and without (size + n - 1) so that a huge n cannot overflow.
    const size_t targetSubSize = size / n + (size % n != 0 ? 1 : 0);

    std::vector<Ptr<SubBatch>> splits;
    for(size_t pos = 0; pos < size; pos += targetSubSize) {
      // actual number of sentences can be smaller at the end
      const size_t subSize = std::min(targetSubSize, size - pos);

      // Width of this piece = 1 + last position at which any of its sentences
      // is unmasked. Positions are visited in increasing order, so the first
      // hit at a position already yields the running maximum.
      size_t subWidth = 0;
      for(size_t s = 0; s < width_; ++s) {
        for(size_t b = 0; b < subSize; ++b) {
          if(mask_[locate(/*batchIdx=*/pos + b, /*wordPos=*/s)] != 0) {
            subWidth = s + 1;
            break;
          }
        }
      }

      // create sub-batch and copy the relevant window of indices and mask
      auto sb = New<SubBatch>(subSize, subWidth, vocab_);
      size_t words = 0;
      for(size_t s = 0; s < subWidth; ++s) {
        for(size_t b = 0; b < subSize; ++b) {
          const size_t src = locate(/*batchIdx=*/pos + b, /*wordPos=*/s);                  // s * size_ + (pos + b)
          const size_t dst = locate(/*batchIdx=*/b, /*wordPos=*/s, /*batchSize=*/subSize);  // s * subSize + b
          sb->data()[dst] = indices_[src];
          sb->mask()[dst] = mask_[src];
          if(mask_[src] != 0)
            words++;
        }
      }
      sb->setWords(words);
      splits.push_back(sb);
    }
    return splits;
  }

  /** @brief Stores the word count reported by batchWords(). */
  void setWords(size_t words) { words_ = words; }
};
/**
 * @brief A batch made of one SubBatch per stream, with optional guided
 * alignments and data weights.
 *
 * Sentence ids are kept in sentenceIds_, which is inherited from Batch (not
 * declared in this file) and may be empty, e.g. for batches from fakeBatch().
 */
class CorpusBatch : public Batch {
protected:
  std::vector<Ptr<SubBatch>> subBatches_;       // one entry per stream
  std::vector<WordAlignment> guidedAlignment_;  // [max source len, batch size, max target len] flattened
  std::vector<float> dataWeights_;              // one weight per sentence, or per word position (see split())

public:
  /** @brief Wraps the given per-stream sub-batches (the vector of pointers is copied). */
  CorpusBatch(const std::vector<Ptr<SubBatch>>& subBatches)
      : subBatches_(subBatches) {}

  /** @return sub-batch of stream @p i; no bounds checking. */
  Ptr<SubBatch> operator[](size_t i) const { return subBatches_[i]; }

  /** @return sub-batch of the first stream. */
  Ptr<SubBatch> front() { return subBatches_.front(); }

  /** @return sub-batch of the last stream. */
  Ptr<SubBatch> back() { return subBatches_.back(); }

  /** @return number of sentences in the first stream. */
  size_t size() const override { return subBatches_[0]->batchSize(); }

  /**
   * @return word count of stream @p which; negative values count from the end
   *         (-1 is the last stream).
   */
  size_t words(int which = 0) const override {
    return subBatches_[which >= 0 ? which
                                  : which + (ptrdiff_t)subBatches_.size()]
        ->batchWords();
  }

  /** @return width of the first stream. */
  size_t width() const override { return subBatches_[0]->batchWidth(); }

  // The *Trg accessors report on the last stream.
  size_t sizeTrg() const override { return subBatches_.back()->batchSize(); }
  size_t wordsTrg() const override { return subBatches_.back()->batchWords(); }
  size_t widthTrg() const override { return subBatches_.back()->batchWidth(); }

  /** @return number of streams. */
  size_t sets() const { return subBatches_.size(); }

  /**
   * @brief Builds a batch with random word ids and a mask of all ones.
   *
   * @param lengths    width of each stream
   * @param vocabs     vocabulary per stream, indexed in parallel to @p lengths
   * @param batchSize  number of sentences
   * @param options    may be null; if set, "guided-alignment" and
   *                   "data-weighting" decide whether random alignments and
   *                   unit weights are attached
   *
   * No sentence ids are set on the returned batch.
   */
  static Ptr<CorpusBatch> fakeBatch(const std::vector<size_t>& lengths,
                                    const std::vector<Ptr<Vocab>>& vocabs,
                                    size_t batchSize,
                                    Ptr<Options> options) {
    std::vector<Ptr<SubBatch>> batches;
    batches.reserve(lengths.size());

    size_t batchIndex = 0;
    for(auto len : lengths) {
      auto sb = New<SubBatch>(batchSize, len, vocabs[batchIndex]);
      // set word indices to random values (not actually needed with current version --@marcinjd: please confirm)
      std::transform(sb->data().begin(), sb->data().end(), sb->data().begin(),
                     [&](Word) -> Word { return vocabs[batchIndex]->randWord(); });
      // mask: no items are being masked out
      std::fill(sb->mask().begin(), sb->mask().end(), 1.f);
      batchIndex++;

      batches.push_back(sb);
    }

    auto batch = New<CorpusBatch>(batches);

    if(!options)
      return batch;

    if(options->get("guided-alignment", std::string("none")) != "none") {
      // @TODO: if > 1 encoder, verify that all encoders have the same sentence lengths
      std::vector<data::WordAlignment> alignment;
      alignment.reserve(batchSize);
      for(size_t k = 0; k < batchSize; ++k) {
        data::WordAlignment perSentence;
        // fill with random alignment points, add more twice the number of words to be safe.
        for(size_t j = 0; j < lengths.back() * 2; ++j) {
          size_t i = rand() % lengths.back();
          perSentence.push_back(i, j, 1.0f);
        }
        alignment.push_back(std::move(perSentence));
      }
      batch->setGuidedAlignment(std::move(alignment));
    }

    if(options->hasAndNotEmpty("data-weighting")) {
      auto weightsSize = batchSize;
      // anything but sentence-level weighting gets one weight per word position of the last stream
      if(options->get<std::string>("data-weighting-type") != "sentence")
        weightsSize *= lengths.back();
      std::vector<float> weights(weightsSize, 1.f);
      batch->setDataWeights(weights);
    }

    return batch;
  }

  /**
   * @brief Splits the batch along the sentence dimension.
   *
   * Every stream is split with SubBatch::split(n, sizeLimit) and the pieces are
   * regrouped across streams. Sentence ids, guided alignments and data weights
   * are distributed over the pieces when this batch has them.
   *
   * @note Guided alignments are moved out of this batch, so its own alignment
   *       entries are left in a moved-from state afterwards.
   *
   * Aborts if the batch is empty.
   */
  std::vector<Ptr<Batch>> split(size_t n, size_t sizeLimit /*=SIZE_MAX*/) override {
    ABORT_IF(size() == 0, "Encountered batch size of 0");

    std::vector<std::vector<Ptr<SubBatch>>> subs;  // [subBatchIndex][streamIndex]
    // split each stream separately
    for(const auto& batchStream : subBatches_) {
      size_t i = 0;  // index into split batch
      for(const auto& splitSubBatch : batchStream->split(n, sizeLimit)) {  // splits a batch into pieces, can also change width
        if(subs.size() <= i)
          subs.resize(i + 1);
        subs[i++].push_back(splitSubBatch);  // this forms tuples across streams
      }
    }

    // create batches from split subbatches
    std::vector<Ptr<Batch>> splits;
    splits.reserve(subs.size());
    for(const auto& subBatches : subs)
      splits.push_back(New<CorpusBatch>(subBatches));

    // Set sentence indices in split batches. A batch may carry no ids at all
    // (nothing in this class sets them, and debug() also allows for that), in
    // which case there is nothing to distribute and indexing would be out of
    // bounds.
    size_t pos = 0;
    if(!sentenceIds_.empty()) {
      for(const auto& piece : splits) {
        const size_t dimBatch = piece->size();
        std::vector<size_t> ids;
        ids.reserve(dimBatch);
        for(size_t i = pos; i < pos + dimBatch; ++i)
          ids.push_back(sentenceIds_[i]);
        piece->setSentenceIds(ids);
        pos += dimBatch;
      }
    }

    // hand over guided alignments to the split batches
    if(!guidedAlignment_.empty()) {
      pos = 0;
      for(const auto& piece : splits) {
        auto cb = std::static_pointer_cast<CorpusBatch>(piece);
        const size_t dimBatch = cb->size();
        std::vector<WordAlignment> batchAlignment;
        batchAlignment.reserve(dimBatch);
        for(size_t i = 0; i < dimBatch; ++i)
          batchAlignment.push_back(std::move(guidedAlignment_[i + pos]));
        cb->setGuidedAlignment(std::move(batchAlignment));
        pos += dimBatch;
      }
    }

    // restore data weights in split batches
    pos = 0;
    if(!dataWeights_.empty()) {
      const size_t oldSize = size();
      for(const auto& piece : splits) {
        auto cb = std::static_pointer_cast<CorpusBatch>(piece);
        const size_t dimBatch = piece->size();

        size_t width = 1;                   // One weight per sentence in case of sentence-level weights
        if(dataWeights_.size() != oldSize)  // if number of weights does not correspond to number of sentences we have word-level weights
          width = cb->back()->batchWidth();  // splitting also affects width, hence we need to accommodate this here

        std::vector<float> ws(width * dimBatch, 1.0f);
        // this needs to be split along the batch dimension
        // which is here the innermost dimension.
        // Should work for sentence-based weights, too.
        for(size_t s = 0; s < width; ++s) {
          for(size_t b = 0; b < dimBatch; ++b) {
            ws[SubBatch::locate(/*batchIdx=*/b, /*wordPos=*/s, /*batchSize=*/dimBatch)]
                = dataWeights_[SubBatch::locate(/*batchIdx=*/b + pos, /*wordPos=*/s, /*batchSize=*/oldSize)];
          }
        }
        piece->setDataWeights(ws);
        pos += dimBatch;
      }
    }

    return splits;
  }

  /** @return the stored guided alignments. */
  const std::vector<WordAlignment>& getGuidedAlignment() const { return guidedAlignment_; }  // [dimSrcWords, dimBatch, dimTrgWords] flattened

  /** @brief Takes ownership of the given guided alignments. */
  void setGuidedAlignment(std::vector<WordAlignment>&& aln) override {
    guidedAlignment_ = std::move(aln);
  }

  /** @return mutable reference to the stored data weights. */
  std::vector<float>& getDataWeights() { return dataWeights_; }

  /** @brief Replaces the stored data weights with a copy of @p weights. */
  void setDataWeights(const std::vector<float>& weights) override {
    dataWeights_ = weights;
  }

  /**
   * @brief Dumps the batch to std::cerr.
   *
   * Prints word strings if the sub-batch has a vocabulary and
   * @p printIndices is false; otherwise prints only numeric indices.
   */
  void debug(bool printIndices = false) override {
    std::cerr << "batches: " << sets() << std::endl;

    if(!sentenceIds_.empty()) {
      std::cerr << "indices: ";
      for(auto id : sentenceIds_)
        std::cerr << id << " ";
      std::cerr << std::endl;
    }

    size_t subBatchIndex = 0;
    for(const auto& sb : subBatches_) {
      std::cerr << "stream " << subBatchIndex++ << ": " << std::endl;
      const auto& vocab = sb->vocab();
      for(size_t s = 0; s < sb->batchWidth(); s++) {
        std::cerr << "\t w: ";
        for(size_t b = 0; b < sb->batchSize(); b++) {
          Word w = sb->data()[sb->locate(/*batchIdx=*/b, /*wordPos=*/s)];  // s * sb->batchSize() + b;
          if(vocab && !printIndices)
            std::cerr << (*vocab)[w] << " ";
          else
            std::cerr << w.toString() << " ";  // if not loaded then print numeric id instead
        }
        std::cerr << std::endl;
      }
    }

    if(!dataWeights_.empty()) {
      std::cerr << "weights: ";
      for(auto w : dataWeights_)
        std::cerr << w << " ";
      std::cerr << std::endl;
    }
  }
};
class CorpusIterator;
/**
 * @brief Abstract base for corpus datasets yielding SentenceTuple samples and
 * CorpusBatch batches.
 *
 * Only declarations live here; the constructors and the protected helpers are
 * defined outside this header.
 */
class CorpusBase : public DatasetBase<SentenceTuple, CorpusIterator, CorpusBatch>, public RNGEngine {
public:
  using Sample = SentenceTuple;

  /** @brief Constructs the corpus from options alone. */
  CorpusBase(Ptr<Options> options, bool translate = false, size_t seed = Config::seed);

  /** @brief Constructs the corpus from explicit paths and vocabularies. */
  CorpusBase(const std::vector<std::string>& paths,
             const std::vector<Ptr<Vocab>>& vocabs,
             Ptr<Options> options,
             size_t seed = Config::seed);

  virtual ~CorpusBase() = default;

  /** @return the vocabularies used by the concrete corpus. */
  virtual std::vector<Ptr<Vocab>>& getVocabs() = 0;

protected:
  std::vector<UPtr<std::istream>> files_;  // input streams owned by the corpus
  std::vector<Ptr<Vocab>> vocabs_;
  std::vector<bool> addEOS_;  // filled by initEOS() -- presumably one flag per stream; TODO confirm

  size_t pos_ = 0;
  size_t maxLength_ = 0;
  bool maxLengthCrop_ = false;
  bool rightLeft_ = false;

  bool tsv_ = false;                // true if the input is a single file with tab-separated values
  size_t tsvNumInputFields_ = 0;    // number of fields from the TSV input that are associated
                                    // with vocabs, i.e. excluding fields with alignment or
                                    // weights, only if --tsv

  /** @brief Derives the number of TSV input fields from @p options; defined elsewhere. */
  static size_t getNumberOfTSVInputFields(Ptr<Options> options);

  // -1 until set; presumably the index of the stream carrying weights / alignments -- TODO confirm
  int weightFileIdx_ = -1;
  int alignFileIdx_ = -1;

  void initEOS(bool training);

  // Helpers that parse one input line into the given tuple (defined elsewhere).
  void addWordsToSentenceTuple(const std::string& line, size_t batchIndex, SentenceTupleImpl& tup) const;
  void addAlignmentToSentenceTuple(const std::string& line, SentenceTupleImpl& tup) const;
  void addWeightsToSentenceTuple(const std::string& line, SentenceTupleImpl& tup) const;

  // Helpers that transfer per-sample alignments / weights into a batch (defined elsewhere).
  void addAlignmentsToBatch(Ptr<CorpusBatch> batch, const std::vector<Sample>& batchVector);
  void addWeightsToBatch(Ptr<CorpusBatch> batch, const std::vector<Sample>& batchVector);
};
// Iterator over the SentenceTuple samples of a CorpusBase, built on
// IteratorFacade. Only declarations are given here; all member functions are
// defined outside this header.
class CorpusIterator : public IteratorFacade<CorpusIterator, SentenceTuple> {
public:
// Default-constructed iterator -- presumably the end-of-corpus sentinel; TODO confirm in the definition.
CorpusIterator();
// Iterator bound to the given corpus. The pointer is stored raw, so the corpus
// is not owned by the iterator.
explicit CorpusIterator(CorpusBase* corpus);
private:
// IteratorFacade hooks: advance, compare and access the current sample.
void increment() override;
bool equal(CorpusIterator const& other) const override;
const SentenceTuple& dereference() const override;
// Corpus being iterated (non-owning).
CorpusBase* corpus_;
int64_t pos_; // we use int64_t here because the initial value can be -1
// Sample returned by reference from dereference() -- presumably the current tuple; TODO confirm.
SentenceTuple tup_;
};
} // namespace data
} // namespace marian