Program Listing for File vocab.cpp¶
↰ Return to documentation for file (src/data/vocab.cpp
)
#include "common/utils.h"
#include "data/vocab.h"
#include "data/vocab_base.h"
namespace marian {
Word Word::NONE = Word();
Word Word::ZERO = Word(0);
Word Word::DEFAULT_EOS_ID = Word(0);
Word Word::DEFAULT_UNK_ID = Word(1);
// @TODO: make each vocab peek on type
Ptr<IVocab> createVocab(const std::string& vocabPath, Ptr<Options> options, size_t batchIndex) {
// try SentencePiece
auto vocab = createSentencePieceVocab(vocabPath, options, batchIndex);
if(vocab)
return vocab; // this is defined which means that a sentencepiece vocabulary could be created, so return it
// try factored
vocab = createFactoredVocab(vocabPath);
if (vocab)
return vocab;
// regular vocab
// check type of input, if not given, assume "sequence"
auto inputTypes = options->get<std::vector<std::string>>("input-types", {});
std::string inputType = inputTypes.size() > batchIndex ? inputTypes[batchIndex] : "sequence";
return inputType == "class" ? createClassVocab() : createDefaultVocab();
}
size_t Vocab::loadOrCreate(const std::string& vocabPath,
const std::vector<std::string>& trainPaths,
size_t maxSize) {
size_t size = 0;
if(vocabPath.empty()) {
// No vocabulary path was given, attempt to first find a vocabulary
// for trainPaths[0] + possible suffixes. If not found attempt to create
// as trainPaths[0] + canonical suffix.
// Only search based on first path, maybe disable this at all?
LOG(info,
"No vocabulary path given; "
"trying to find default vocabulary based on data path {}",
trainPaths[0]);
vImpl_ = createDefaultVocab();
size = vImpl_->findAndLoad(trainPaths[0], maxSize);
if(size == 0) {
auto newVocabPath = trainPaths[0] + vImpl_->canonicalExtension();
LOG(info,
"No vocabulary path given; "
"trying to create vocabulary based on data paths {}",
utils::join(trainPaths, ", "));
create(newVocabPath, trainPaths, maxSize);
size = load(newVocabPath, maxSize);
}
} else {
if(!filesystem::exists(vocabPath)) {
// Vocabulary path was given, but no vocabulary present,
// attempt to create in specified location.
create(vocabPath, trainPaths, maxSize);
}
// Vocabulary path exists, attempting to load
size = load(vocabPath, maxSize);
}
LOG(info, "[data] Setting vocabulary size for input {} to {}", batchIndex_, utils::withCommas(size));
return size;
}
size_t Vocab::load(const std::string& vocabPath, size_t maxSize) {
if(!vImpl_)
vImpl_ = createVocab(vocabPath, options_, batchIndex_);
return vImpl_->load(vocabPath, (int)maxSize);
}
void Vocab::create(const std::string& vocabPath,
const std::vector<std::string>& trainPaths,
size_t maxSize) {
if(!vImpl_)
vImpl_ = createVocab(vocabPath, options_, batchIndex_);
vImpl_->create(vocabPath, trainPaths, maxSize);
}
void Vocab::create(const std::string& vocabPath,
const std::string& trainPath,
size_t maxSize) {
create(vocabPath, std::vector<std::string>({trainPath}), maxSize);
}
void Vocab::createFake() {
if(!vImpl_)
vImpl_ = createDefaultVocab(); // DefaultVocab is OK here
vImpl_->createFake();
}
Word Vocab::randWord() {
return vImpl_->randWord();
}
// string token to token id
Word Vocab::operator[](const std::string& word) const {
return vImpl_->operator[](word);
}
// token id to string token
const std::string& Vocab::operator[](Word id) const {
return vImpl_->operator[](id);
}
// line of text to list of token ids, can perform tokenization
Words Vocab::encode(const std::string& line,
bool addEOS,
bool inference) const {
return vImpl_->encode(line, addEOS, inference);
}
// convert sequence of token ids to single line, can perform detokenization
std::string Vocab::decode(const Words& sentence,
bool ignoreEOS) const {
return vImpl_->decode(sentence, ignoreEOS);
}
// convert sequence of token its to surface form (incl. removng spaces, applying factors)
// for in-process BLEU validation
std::string Vocab::surfaceForm(const Words& sentence) const {
return vImpl_->surfaceForm(sentence);
}
// number of vocabulary items
size_t Vocab::size() const { return vImpl_->size(); }
size_t Vocab::lemmaSize() const {
return vImpl_->lemmaSize();
}
// type of vocabulary items
std::string Vocab::type() const { return vImpl_->type(); }
// return EOS symbol id
Word Vocab::getEosId() const { return vImpl_->getEosId(); }
// return UNK symbol id
Word Vocab::getUnkId() const { return vImpl_->getUnkId(); }
std::vector<Word> Vocab::suppressedIds(bool suppressUnk, bool suppressSpecial) const {
std::vector<Word> ids;
if(suppressUnk) {
auto unkId = getUnkId();
if(unkId != Word::NONE)
ids.push_back(unkId);
}
if(suppressSpecial)
vImpl_->addSpecialWords(/*in/out=*/ids);
return ids;
}
std::vector<WordIndex> Vocab::suppressedIndices(bool suppressUnk, bool suppressSpecial) const {
std::vector<WordIndex> indices;
for(Word word : suppressedIds(suppressUnk, suppressSpecial))
indices.push_back(word.toWordIndex());
vImpl_->transcodeToShortlistInPlace(indices.data(), indices.size());
return indices;
}
// for corpus augmentation: convert string to all-caps
std::string Vocab::toUpper(const std::string& line) const { return vImpl_->toUpper(line); }
// for corpus augmentation: convert string to title case
std::string Vocab::toEnglishTitleCase(const std::string& line) const { return vImpl_->toEnglishTitleCase(line); }
// for short-list generation
void Vocab::transcodeToShortlistInPlace(WordIndex* ptr, size_t num) const { vImpl_->transcodeToShortlistInPlace(ptr, num); }
} // namespace marian