.. _program_listing_file_src_data_corpus_nbest.h: Program Listing for File corpus_nbest.h ======================================= |exhale_lsh| :ref:`Return to documentation for file ` (``src/data/corpus_nbest.h``) .. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS .. code-block:: cpp #pragma once #include #include #include #include "common/definitions.h" #include "common/file_stream.h" #include "common/options.h" #include "data/alignment.h" #include "data/batch.h" #include "data/corpus_base.h" #include "data/dataset.h" #include "data/vocab.h" namespace marian { namespace data { class CorpusNBest : public CorpusBase { private: std::vector ids_; int lastNum_{-1}; std::vector lastLines_; public: // @TODO: check if translate can be replaced by an option in options CorpusNBest(Ptr options, bool translate = false); CorpusNBest(std::vector paths, std::vector> vocabs, Ptr options); Sample next() override; void shuffle() override {} void reset() override; void restore(Ptr) override {} iterator begin() override { return iterator(this); } iterator end() override { return iterator(); } std::vector>& getVocabs() override { return vocabs_; } batch_ptr toBatch(const std::vector& batchVector) override { size_t batchSize = batchVector.size(); std::vector sentenceIds; std::vector maxDims; for(auto& ex : batchVector) { if(maxDims.size() < ex.size()) maxDims.resize(ex.size(), 0); for(size_t i = 0; i < ex.size(); ++i) { if(ex[i].size() > (size_t)maxDims[i]) maxDims[i] = (int)ex[i].size(); } sentenceIds.push_back(ex.getId()); } std::vector> subBatches; for(size_t j = 0; j < maxDims.size(); ++j) { subBatches.emplace_back(New(batchSize, maxDims[j], vocabs_[j])); } std::vector words(maxDims.size(), 0); for(size_t i = 0; i < batchSize; ++i) { for(size_t j = 0; j < maxDims.size(); ++j) { for(size_t k = 0; k < batchVector[i][j].size(); ++k) { subBatches[j]->data()[k * batchSize + i] = batchVector[i][j][k]; subBatches[j]->mask()[k * batchSize + i] = 1.f; words[j]++; } } } for(size_t j = 0; j < maxDims.size(); ++j) subBatches[j]->setWords(words[j]); auto batch = batch_ptr(new batch_type(subBatches)); batch->setSentenceIds(sentenceIds); return batch; } }; } // namespace data } // namespace marian