Program Listing for File corpus_nbest.h¶
↰ Return to documentation for file (src/data/corpus_nbest.h
)
#pragma once
#include <fstream>
#include <iostream>
#include <random>
#include "common/definitions.h"
#include "common/file_stream.h"
#include "common/options.h"
#include "data/alignment.h"
#include "data/batch.h"
#include "data/corpus_base.h"
#include "data/dataset.h"
#include "data/vocab.h"
namespace marian {
namespace data {
class CorpusNBest : public CorpusBase {
private:
std::vector<size_t> ids_;
int lastNum_{-1};
std::vector<std::string> lastLines_;
public:
// @TODO: check if translate can be replaced by an option in options
CorpusNBest(Ptr<Options> options, bool translate = false);
CorpusNBest(std::vector<std::string> paths,
std::vector<Ptr<Vocab>> vocabs,
Ptr<Options> options);
Sample next() override;
void shuffle() override {}
void reset() override;
void restore(Ptr<TrainingState>) override {}
iterator begin() override { return iterator(this); }
iterator end() override { return iterator(); }
std::vector<Ptr<Vocab>>& getVocabs() override { return vocabs_; }
batch_ptr toBatch(const std::vector<Sample>& batchVector) override {
size_t batchSize = batchVector.size();
std::vector<size_t> sentenceIds;
std::vector<int> maxDims;
for(auto& ex : batchVector) {
if(maxDims.size() < ex.size())
maxDims.resize(ex.size(), 0);
for(size_t i = 0; i < ex.size(); ++i) {
if(ex[i].size() > (size_t)maxDims[i])
maxDims[i] = (int)ex[i].size();
}
sentenceIds.push_back(ex.getId());
}
std::vector<Ptr<SubBatch>> subBatches;
for(size_t j = 0; j < maxDims.size(); ++j) {
subBatches.emplace_back(New<SubBatch>(batchSize, maxDims[j], vocabs_[j]));
}
std::vector<size_t> words(maxDims.size(), 0);
for(size_t i = 0; i < batchSize; ++i) {
for(size_t j = 0; j < maxDims.size(); ++j) {
for(size_t k = 0; k < batchVector[i][j].size(); ++k) {
subBatches[j]->data()[k * batchSize + i] = batchVector[i][j][k];
subBatches[j]->mask()[k * batchSize + i] = 1.f;
words[j]++;
}
}
}
for(size_t j = 0; j < maxDims.size(); ++j)
subBatches[j]->setWords(words[j]);
auto batch = batch_ptr(new batch_type(subBatches));
batch->setSentenceIds(sentenceIds);
return batch;
}
};
} // namespace data
} // namespace marian