.. _program_listing_file_src_data_corpus_sqlite.h: Program Listing for File corpus_sqlite.h ======================================== |exhale_lsh| :ref:`Return to documentation for file ` (``src/data/corpus_sqlite.h``) .. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS .. code-block:: cpp #pragma once #include #include #include #include "common/definitions.h" #include "common/file_stream.h" #include "common/options.h" #include "data/alignment.h" #include "data/batch.h" #include "data/corpus_base.h" #include "data/dataset.h" #include "data/vocab.h" #include #include static void SQLiteRandomSeed(sqlite3_context* context, int argc, sqlite3_value** argv) { if(argc == 1 && sqlite3_value_type(argv[0]) == SQLITE_INTEGER) { const int seed = sqlite3_value_int(argv[0]); static std::default_random_engine eng(seed); std::uniform_int_distribution<> unif; const int result = unif(eng); sqlite3_result_int(context, result); } else { sqlite3_result_error(context, "Invalid", 0); } } namespace marian { namespace data { class CorpusSQLite : public CorpusBase { private: UPtr db_; UPtr select_; void fillSQLite(); size_t seed_; public: // @TODO: check if translate can be replaced by an option in options CorpusSQLite(Ptr options, bool translate = false, size_t seed = Config::seed); CorpusSQLite(const std::vector& paths, const std::vector>& vocabs, Ptr options, size_t seed = Config::seed); Sample next() override; void shuffle() override; void reset() override; void restore(Ptr) override; iterator begin() override { return iterator(this); } iterator end() override { return iterator(); } std::vector>& getVocabs() override { return vocabs_; } batch_ptr toBatch(const std::vector& batchVector) override { size_t batchSize = batchVector.size(); std::vector sentenceIds; std::vector maxDims; for(auto& ex : batchVector) { if(maxDims.size() < ex.size()) maxDims.resize(ex.size(), 0); for(size_t i = 0; i < ex.size(); ++i) { if(ex[i].size() > (size_t)maxDims[i]) maxDims[i] = (int)ex[i].size(); } sentenceIds.push_back(ex.getId()); } std::vector> subBatches; for(size_t j = 0; j < maxDims.size(); ++j) { subBatches.emplace_back(New(batchSize, maxDims[j], vocabs_[j])); } std::vector words(maxDims.size(), 0); for(size_t i = 0; i < batchSize; ++i) { for(size_t j = 0; j < maxDims.size(); ++j) { for(size_t k = 0; k < batchVector[i][j].size(); ++k) { subBatches[j]->data()[k * batchSize + i] = batchVector[i][j][k]; subBatches[j]->mask()[k * batchSize + i] = 1.f; words[j]++; } } } for(size_t j = 0; j < maxDims.size(); ++j) subBatches[j]->setWords(words[j]); auto batch = batch_ptr(new batch_type(subBatches)); batch->setSentenceIds(sentenceIds); if(options_->has("guided-alignment") && alignFileIdx_) addAlignmentsToBatch(batch, batchVector); if(options_->hasAndNotEmpty("data-weighting") && weightFileIdx_) addWeightsToBatch(batch, batchVector); return batch; } private: void createRandomFunction() { sqlite3_create_function(db_->getHandle(), "random_seed", 1, SQLITE_UTF8, NULL, &SQLiteRandomSeed, NULL, NULL); } }; } // namespace data } // namespace marian