Program Listing for File corpus_sqlite.h

Return to documentation for file (src/data/corpus_sqlite.h)

#pragma once

#include <fstream>
#include <iostream>
#include <random>

#include "common/definitions.h"
#include "common/file_stream.h"
#include "common/options.h"
#include "data/alignment.h"
#include "data/batch.h"
#include "data/corpus_base.h"
#include "data/dataset.h"
#include "data/vocab.h"

#include <SQLiteCpp/SQLiteCpp.h>
#include <SQLiteCpp/sqlite3/sqlite3.h>

/**
 * Implementation of the SQL scalar function random_seed(seed).
 *
 * Returns a pseudo-random int in [0, INT_MAX] (the default range of
 * std::uniform_int_distribution<int>) drawn from a function-local static
 * std::default_random_engine.
 *
 * The engine is seeded with the argument of the FIRST call only. Later calls
 * ignore their argument and continue the same random sequence. Repeated
 * queries therefore yield different values, reproducible for a fixed first
 * seed.
 *
 * The engine is unsynchronized static state. Because this function has
 * internal linkage and is defined in a header, every translation unit that
 * includes it gets its own engine.
 *
 * NOTE(review): sqlite3_value_int() keeps only the low 32 bits of a 64-bit
 * integer, so seeds above INT_MAX are truncated. This is left unchanged to
 * keep shuffles reproducible for existing seeds.
 *
 * If the call does not have exactly one integer argument, the SQL error
 * "Invalid" is reported.
 */
static void SQLiteRandomSeed(sqlite3_context* context,
                             int argc,
                             sqlite3_value** argv) {
  if(argc != 1 || sqlite3_value_type(argv[0]) != SQLITE_INTEGER) {
    // A negative length tells SQLite to read the message up to its NUL
    // terminator. A length of 0 would make the reported message empty.
    sqlite3_result_error(context, "Invalid", -1);
    return;
  }

  const int seed = sqlite3_value_int(argv[0]);
  // Initialized exactly once, on the first call; later seeds are ignored.
  static std::default_random_engine eng(seed);
  std::uniform_int_distribution<> unif;
  sqlite3_result_int(context, unif(eng));
}

namespace marian {
namespace data {

class CorpusSQLite : public CorpusBase {
private:
  UPtr<SQLite::Database> db_;
  UPtr<SQLite::Statement> select_;

  void fillSQLite();

  size_t seed_;

public:
  // @TODO: check if translate can be replaced by an option in options
  CorpusSQLite(Ptr<Options> options, bool translate = false, size_t seed = Config::seed);

  CorpusSQLite(const std::vector<std::string>& paths,
               const std::vector<Ptr<Vocab>>& vocabs,
               Ptr<Options> options,
               size_t seed = Config::seed);

  Sample next() override;

  void shuffle() override;

  void reset() override;

  void restore(Ptr<TrainingState>) override;

  iterator begin() override { return iterator(this); }

  iterator end() override { return iterator(); }

  std::vector<Ptr<Vocab>>& getVocabs() override { return vocabs_; }

  batch_ptr toBatch(const std::vector<Sample>& batchVector) override {
    size_t batchSize = batchVector.size();

    std::vector<size_t> sentenceIds;

    std::vector<int> maxDims;
    for(auto& ex : batchVector) {
      if(maxDims.size() < ex.size())
        maxDims.resize(ex.size(), 0);
      for(size_t i = 0; i < ex.size(); ++i) {
        if(ex[i].size() > (size_t)maxDims[i])
          maxDims[i] = (int)ex[i].size();
      }
      sentenceIds.push_back(ex.getId());
    }

    std::vector<Ptr<SubBatch>> subBatches;
    for(size_t j = 0; j < maxDims.size(); ++j) {
      subBatches.emplace_back(New<SubBatch>(batchSize, maxDims[j], vocabs_[j]));
    }

    std::vector<size_t> words(maxDims.size(), 0);
    for(size_t i = 0; i < batchSize; ++i) {
      for(size_t j = 0; j < maxDims.size(); ++j) {
        for(size_t k = 0; k < batchVector[i][j].size(); ++k) {
          subBatches[j]->data()[k * batchSize + i] = batchVector[i][j][k];
          subBatches[j]->mask()[k * batchSize + i] = 1.f;
          words[j]++;
        }
      }
    }

    for(size_t j = 0; j < maxDims.size(); ++j)
      subBatches[j]->setWords(words[j]);

    auto batch = batch_ptr(new batch_type(subBatches));
    batch->setSentenceIds(sentenceIds);

    if(options_->has("guided-alignment") && alignFileIdx_)
      addAlignmentsToBatch(batch, batchVector);
    if(options_->hasAndNotEmpty("data-weighting") && weightFileIdx_)
      addWeightsToBatch(batch, batchVector);

    return batch;
  }

private:
  void createRandomFunction() {
    sqlite3_create_function(db_->getHandle(),
                            "random_seed",
                            1,
                            SQLITE_UTF8,
                            NULL,
                            &SQLiteRandomSeed,
                            NULL,
                            NULL);
  }
};
}  // namespace data
}  // namespace marian