// Program Listing for File corpus_sqlite.cpp
// Return to documentation for file (src/data/corpus_sqlite.cpp)
#include <random>
#include "data/corpus_sqlite.h"
namespace marian {
namespace data {
// Constructs the SQLite-backed corpus from options only.
// The base class is initialised with the same arguments, the seed is remembered for
// the seeded shuffle queries, and the database is then opened/filled by fillSQLite().
CorpusSQLite::CorpusSQLite(Ptr<Options> options,
                           bool translate /*= false*/,
                           size_t seed /*= Config::seed*/)
    : CorpusBase(options, translate, seed),
      seed_(seed) {
  this->fillSQLite();
}
// Constructs the SQLite-backed corpus from explicit file paths and vocabularies.
// Everything except the seed bookkeeping is delegated to CorpusBase; afterwards the
// database is opened/filled by fillSQLite().
CorpusSQLite::CorpusSQLite(const std::vector<std::string>& paths,
                           const std::vector<Ptr<Vocab>>& vocabs,
                           Ptr<Options> options,
                           size_t seed)
    : CorpusBase(paths, vocabs, options, seed),
      seed_(seed) {
  this->fillSQLite();
}
/**
 * Opens the SQLite database selected by the "sqlite" option and, if needed, fills it.
 *
 * - "sqlite" == "temporary": a fresh database is created (empty file name) and filled.
 * - otherwise the option value is used as a path: an existing file is reused as-is
 *   unless "sqlite-drop" is set, in which case table `lines` is dropped and refilled;
 *   a missing file is created and filled.
 *
 * Filling creates table `lines` with an integer `_id` column plus one text column
 * `line<i>` per entry of files_, reads all streams in lock-step and inserts one row per
 * line number. Reading stops as soon as any stream fails to deliver a line. Afterwards a
 * unique index on `_id` is created. In every case createRandomFunction() is called last.
 */
void CorpusSQLite::fillSQLite() {
  auto tempDir = options_->get<std::string>("tempdir");
  const auto sqliteOpt = options_->get<std::string>("sqlite");
  bool fill = false;

  // The directory is embedded in a single-quoted SQL string literal, so any single quote
  // in the path has to be doubled; otherwise the PRAGMA statement would be malformed.
  std::string escapedTempDir;
  for(char c : tempDir) {
    escapedTempDir += c;
    if(c == '\'')
      escapedTempDir += '\'';
  }

  // Opens/creates the database `file` and points SQLite's temporary storage at tempDir.
  auto openDatabase = [&](const std::string& file, int flags) {
    db_.reset(new SQLite::Database(file, flags));
    db_->exec("PRAGMA temp_store_directory = '" + escapedTempDir + "';");
  };

  // create a temporary or persistent SQLite database
  if(sqliteOpt == "temporary") {
    LOG(info, "[sqlite] Creating temporary database in {}", tempDir);
    openDatabase("", SQLite::OPEN_READWRITE | SQLite::OPEN_CREATE);
    fill = true;
  } else if(filesystem::exists(sqliteOpt)) {
    LOG(info, "[sqlite] Reusing persistent database {}", sqliteOpt);
    openDatabase(sqliteOpt, SQLite::OPEN_READWRITE);
    if(options_->get<bool>("sqlite-drop")) {
      LOG(info, "[sqlite] Dropping previous data");
      db_->exec("drop table if exists lines");
      fill = true;
    }
  } else {
    LOG(info, "[sqlite] Creating persistent database {}", sqliteOpt);
    openDatabase(sqliteOpt, SQLite::OPEN_READWRITE | SQLite::OPEN_CREATE);
    fill = true;
  }

  // populate tables with lines from text files
  if(fill) {
    std::string createStr = "create table lines (_id integer";
    std::string insertStr = "insert into lines values (?";
    for(size_t i = 0; i < files_.size(); ++i) {
      createStr += ", line" + std::to_string(i) + " text";
      insertStr += ", ?";
    }
    createStr += ");";
    insertStr += ");";

    db_->exec(createStr);

    SQLite::Statement ps(*db_, insertStr);

    int lines = 0;         // number of rows actually inserted so far
    int report = 1000000;  // next progress/commit threshold, doubled after each report
    // Without any input stream no read can ever fail, so the loop below would never
    // terminate; in that case skip it entirely.
    bool cont = !files_.empty();

    db_->exec("begin;");
    while(cont) {
      ps.bind(1, lines);

      std::string line;
      for(size_t i = 0; i < files_.size(); ++i) {
        // once one stream is exhausted the remaining streams are not read any more
        cont = cont && io::getline(*files_[i], line);
        if(cont)
          ps.bind((int)(i + 2), line);
      }

      // A failed read means the row is incomplete: do not insert it and do not count it.
      if(!cont)
        break;

      ps.exec();
      ps.reset();
      ++lines;

      // Commit in growing batches and report progress; only reached after a real insert,
      // so the reported number always matches the rows in the table.
      if(lines % report == 0) {
        LOG(info, "[sqlite] Inserted {} lines", lines);
        db_->exec("commit;");
        db_->exec("begin;");
        report *= 2;
      }
    }
    db_->exec("commit;");
    LOG(info, "[sqlite] Inserted {} lines", lines);
    LOG(info, "[sqlite] Creating primary index");
    db_->exec("create unique index idx_line on lines (_id);");
  }

  // NOTE(review): presumably registers the `random_seed` SQL function used by the
  // shuffle queries - its definition is not visible in this file section; confirm.
  createRandomFunction();
}
// Fetches rows from the current select statement until one yields a usable tuple.
// Column 0 is the sentence id, column i + 1 holds the text for stream i. A row is
// accepted only if every resulting word sequence is non-empty and no longer than
// maxLength_; rejected rows are skipped. When the statement has no more rows an empty
// SentenceTuple is returned.
SentenceTuple CorpusSQLite::next() {
  while(select_->executeStep()) {
    size_t curId = select_->getColumn(0).getInt();
    SentenceTupleImpl tup(curId);

    // route every column to the matching tuple component
    for(size_t stream = 0; stream < files_.size(); ++stream) {
      auto field = select_->getColumn(static_cast<int>(stream + 1));

      // stream 0 is always treated as plain text, even if an index happens to be 0
      const bool isAlignment = stream > 0 && stream == alignFileIdx_;
      const bool isWeights = stream > 0 && stream == weightFileIdx_;

      if(isAlignment)
        addAlignmentToSentenceTuple(field, tup);
      else if(isWeights)
        addWeightsToSentenceTuple(field, tup);
      else
        addWordsToSentenceTuple(field, stream, tup);
    }

    // length filter: every entry must be non-empty and within maxLength_
    bool accepted = true;
    for(const Words& words : tup) {
      const bool lengthOk = words.size() > 0 && words.size() <= maxLength_;
      if(!lengthOk) {
        accepted = false;
        break;
      }
    }

    if(accepted)
      return SentenceTuple(tup);
  }

  // statement exhausted
  return SentenceTuple();
}
void CorpusSQLite::shuffle() {
LOG(info, "[sqlite] Selecting shuffled data");
select_.reset(new SQLite::Statement(
*db_, "select * from lines order by random_seed(" + std::to_string(seed_) + ");"));
}
// Replaces the current select statement with one that returns all rows of `lines`
// in insertion order (ascending _id).
void CorpusSQLite::reset() {
  const char* orderedQuery = "select * from lines order by _id;";
  select_.reset(new SQLite::Statement(*db_, orderedQuery));
}
/**
 * Replays the seeded shuffle query once for every epoch before `ts->epochs`.
 *
 * For each of the `ts->epochs - 1` earlier epochs the shuffled select is prepared and
 * stepped once, then reset() installs the ordered select again.
 * NOTE(review): presumably stepping the query advances the state behind random_seed so
 * that the next shuffle() continues the original sequence - confirm against
 * createRandomFunction(), which is not visible here.
 *
 * @param ts training state; only its `epochs` counter is read
 */
void CorpusSQLite::restore(Ptr<TrainingState> ts) {
  // Count from 1 rather than comparing against `epochs - 1`: with an unsigned counter
  // of 0 the subtraction would wrap around and the loop would run practically forever.
  // For epochs >= 1 the number of iterations is unchanged.
  for(size_t i = 1; i < ts->epochs; ++i) {
    select_.reset(new SQLite::Statement(
        *db_, "select _id from lines order by random_seed(" + std::to_string(seed_) + ");"));
    select_->executeStep();
    reset();
  }
}
} // namespace data
} // namespace marian