// Program listing for file src/data/default_vocab.cpp
// (Return to the documentation for file src/data/default_vocab.cpp.)
#include "data/vocab_base.h"
#include "3rd_party/yaml-cpp/yaml.h"
#include "common/logging.h"
#include "common/regex.h"
#include "common/utils.h"
#include "common/filesystem.h"
#include <algorithm>
#include <fstream>
#include <iostream>
#include <sstream>
#include <unordered_map>
#include <unordered_set>
namespace marian {
class DefaultVocab : public IVocab {
protected:
typedef std::map<std::string, Word> Str2Id;
Str2Id str2id_;
typedef std::vector<std::string> Id2Str;
Id2Str id2str_;
Word eosId_ = Word::NONE;
Word unkId_ = Word::NONE;
std::vector<std::string> suffixes_ = { ".yml", ".yaml", ".json" };
// Contains control characters added to vocab, possibly due to byte-fallback
std::vector<Word> controlChars_;
class VocabFreqOrderer {
private:
const std::unordered_map<std::string, size_t>& counter_;
public:
VocabFreqOrderer(const std::unordered_map<std::string, size_t>& counter)
: counter_(counter) {}
// order first by decreasing frequency,
// if frequencies are the same order lexicographically by vocabulary string
bool operator()(const std::string& a, const std::string& b) const {
return counter_.at(a) > counter_.at(b) || (counter_.at(a) == counter_.at(b) && a < b);
}
};
public:
// @TODO: choose between 'virtual' and 'final'. Can we derive from this class?
virtual ~DefaultVocab() {};
virtual const std::string& canonicalExtension() const override { return suffixes_[0]; }
virtual const std::vector<std::string>& suffixes() const override { return suffixes_; }
virtual Word operator[](const std::string& word) const override {
auto it = str2id_.find(word);
if(it != str2id_.end())
return it->second;
else
return unkId_;
}
Words encode(const std::string& line, bool addEOS, bool /*inference*/) const override {
auto lineTokens = utils::split(line, " ");
return (*this)(lineTokens, addEOS);
}
std::string decode(const Words& sentence, bool ignoreEOS) const override {
auto tokens = (*this)(sentence, ignoreEOS);
return utils::join(tokens, " ");
}
std::string surfaceForm(const Words& sentence) const override {
return decode(sentence, /*ignoreEOS=*/true);
}
// SentencePiece with byte-fallback may generate control symbols with output sampling.
// Let's mark them as special and suppress them later on output. This is generally safe
// for UTF-8 since control chars are not used as partial bytes in multi-byte sequences.
// They only appear in single-byte chars as themselves and this is what we suppress.
void addSpecialWords(std::vector<Word>& special) const override {
special.reserve(special.size() + controlChars_.size());
for(auto c : controlChars_)
special.push_back(c);
}
virtual std::string type() const override { return "DefaultVocab"; }
virtual Word getEosId() const override { return eosId_; }
virtual Word getUnkId() const override { return unkId_; }
const std::string& operator[](Word word) const override {
auto id = word.toWordIndex();
ABORT_IF(id >= id2str_.size(), "Unknown word id: {}", id);
return id2str_[id];
}
size_t size() const override {
return id2str_.size();
}
size_t load(const std::string& vocabPath, size_t maxSize) override {
bool isJson = regex::regex_search(vocabPath, regex::regex("\\.(json|yaml|yml)$"));
LOG(info,
"[data] Loading vocabulary from {} file {}",
isJson ? "JSON/Yaml" : "text",
vocabPath);
ABORT_IF(!filesystem::exists(vocabPath),
"DefaultVocabulary file {} does not exist",
vocabPath);
std::map<std::string, Word> vocab;
// read from JSON (or Yaml) file
if(isJson) {
io::InputFileStream strm(vocabPath);
YAML::Node vocabNode = YAML::Load(strm);
for(auto&& pair : vocabNode)
vocab.insert({pair.first.as<std::string>(), Word::fromWordIndex(pair.second.as<IndexType>())});
}
// read from flat text file
else {
io::InputFileStream in(vocabPath);
std::string line;
while(io::getline(in, line)) {
ABORT_IF(line.empty(),
"DefaultVocabulary file {} must not contain empty lines",
vocabPath);
auto wasInserted = vocab.insert({line, Word::fromWordIndex(vocab.size())}).second;
ABORT_IF(!wasInserted, "Duplicate vocabulary entry {}", line);
}
ABORT_IF(in.bad(), "DefaultVocabulary file {} could not be read", vocabPath);
}
id2str_.reserve(vocab.size());
for(auto&& pair : vocab) {
auto str = pair.first;
auto id = pair.second;
// note: this requires ids to be sorted by frequency
if(!maxSize || id.toWordIndex() < maxSize) {
insertWord(id, str);
}
}
ABORT_IF(id2str_.empty(), "Empty vocabulary: ", vocabPath);
populateControlChars();
addRequiredVocabulary(vocabPath, isJson);
return std::max(id2str_.size(), maxSize);
}
// for fakeBatch()
virtual void createFake() override {
eosId_ = insertWord(Word::DEFAULT_EOS_ID, DEFAULT_EOS_STR);
unkId_ = insertWord(Word::DEFAULT_UNK_ID, DEFAULT_UNK_STR);
}
virtual void create(const std::string& vocabPath,
const std::vector<std::string>& trainPaths,
size_t maxSize = 0) override {
LOG(info, "[data] Creating vocabulary {} from {}",
vocabPath,
utils::join(trainPaths, ", "));
if(vocabPath != "stdout") {
filesystem::Path path(vocabPath);
auto dir = path.parentPath();
if(dir.empty())
dir = filesystem::currentPath();
ABORT_IF(!dir.empty() && !filesystem::isDirectory(dir),
"Specified vocab directory {} does not exist",
dir.string());
ABORT_IF(filesystem::exists(vocabPath),
"Vocabulary file '{}' exists. Not overwriting",
path.string());
}
std::unordered_map<std::string, size_t> counter;
for(const auto& trainPath : trainPaths)
addCounts(counter, trainPath);
create(vocabPath, counter, maxSize);
}
private:
// Creates the first 32 control characters as done in byte-fallback and checks if they exist in the vocab.
// This makes sure that we do not waste computational effort on suppression if they don't actually appear.
void populateControlChars() {
for(int i = 0; i < 32; ++i) {
std::string bytePiece = fmt::format("<0x{:02X}>", i); // 0 becomes <0x00>, 10 becomes <0x0A>, note uppercase A and lowercase x
auto id = (*this)[bytePiece];
if(id != unkId_)
controlChars_.push_back(id);
}
}
virtual void addRequiredVocabulary(const std::string& vocabPath, bool isJson) {
// look up ids for </s> and <unk>, which are required
// The name backCompatStr is alternatively accepted for Yaml vocabs if id
// equals backCompatId.
auto getRequiredWordId = [&](const std::string& str,
const std::string& backCompatStr,
Word backCompatWord) -> Word {
// back compat with Nematus Yaml dicts
if(isJson) {
// if word id 0 or 1 is either empty or has the Nematus-convention string,
// then use it
auto backCompatId = backCompatWord.toWordIndex();
if(backCompatId < id2str_.size()
&& (id2str_[backCompatId].empty()
|| id2str_[backCompatId] == backCompatStr)) {
LOG(info,
"[data] Using unused word id {} for {}",
backCompatStr,
backCompatId,
str);
return backCompatWord;
}
}
auto iter = str2id_.find(str);
ABORT_IF(iter == str2id_.end(),
"DefaultVocabulary file {} is expected to contain an entry for {}",
vocabPath,
str);
return iter->second;
};
eosId_ = getRequiredWordId(DEFAULT_EOS_STR, NEMATUS_EOS_STR, Word::DEFAULT_EOS_ID);
unkId_ = getRequiredWordId(DEFAULT_UNK_STR, NEMATUS_UNK_STR, Word::DEFAULT_UNK_ID);
}
void addCounts(std::unordered_map<std::string, size_t>& counter,
const std::string& trainPath) {
std::unique_ptr<std::istream> trainStrm(
trainPath == "stdin" ? new std::istream(std::cin.rdbuf())
: new io::InputFileStream(trainPath)
);
std::string line;
while(getline(*trainStrm, line)) {
auto toks = utils::split(line, " ");
for(const std::string& tok : toks) {
auto iter = counter.find(tok);
if(iter == counter.end())
counter[tok] = 1;
else
iter->second++;
}
}
}
virtual void create(const std::string& vocabPath,
const std::unordered_map<std::string, size_t>& counter,
size_t maxSize = 0) {
std::vector<std::string> vocabVec;
for(auto& p : counter)
vocabVec.push_back(p.first);
std::sort(vocabVec.begin(), vocabVec.end(), VocabFreqOrderer(counter));
YAML::Node vocabYaml;
vocabYaml.force_insert(DEFAULT_EOS_STR, Word::DEFAULT_EOS_ID.toWordIndex());
vocabYaml.force_insert(DEFAULT_UNK_STR, Word::DEFAULT_UNK_ID.toWordIndex());
WordIndex maxSpec = 1;
auto vocabSize = vocabVec.size();
if(maxSize > maxSpec)
vocabSize = std::min(maxSize - maxSpec - 1, vocabVec.size());
for(size_t i = 0; i < vocabSize; ++i)
vocabYaml.force_insert(vocabVec[i], i + maxSpec + 1);
std::unique_ptr<std::ostream> vocabStrm(
vocabPath == "stdout" ? new std::ostream(std::cout.rdbuf())
: new io::OutputFileStream(vocabPath)
);
*vocabStrm << vocabYaml;
}
Words operator()(const std::vector<std::string>& lineTokens,
bool addEOS) const {
Words words(lineTokens.size());
std::transform(lineTokens.begin(),
lineTokens.end(),
words.begin(),
[&](const std::string& w) { return (*this)[w]; });
if(addEOS)
words.push_back(eosId_);
return words;
}
std::vector<std::string> operator()(const Words& sentence,
bool ignoreEOS) const {
std::vector<std::string> decoded;
for(size_t i = 0; i < sentence.size(); ++i) {
if((sentence[i] != eosId_ || !ignoreEOS)) {
decoded.push_back((*this)[sentence[i]]);
}
}
return decoded;
}
// helper to insert a word into str2id_[] and id2str_[]
Word insertWord(Word word, const std::string& str) {
str2id_[str] = word;
auto id = word.toWordIndex();
if(id >= id2str_.size())
id2str_.resize(id + 1);
id2str_[id] = str;
return word;
};
};
// Vocabulary for the class labels of a classifier. Unlike DefaultVocab it
// neither requires nor creates </s> / <unk> entries: it only stores labels
// that were actually observed.
class ClassVocab : public DefaultVocab {
private:
  // Class vocabularies have no mandatory entries, so there is nothing to look up.
  virtual void addRequiredVocabulary(const std::string& /*vocabPath*/, bool /*isJson*/) override {}

  // Writes the seen labels, most frequent first (ties lexicographic), with ids
  // 0..N-1 as YAML to `vocabPath` ("stdout" writes to standard output).
  // A non-zero `maxSize` must equal the number of distinct labels.
  virtual void create(const std::string& vocabPath,
                      const std::unordered_map<std::string, size_t>& counter,
                      size_t maxSize = 0) override {
    std::vector<std::string> labels;
    labels.reserve(counter.size());
    for(const auto& entry : counter)
      labels.push_back(entry.first);
    std::sort(labels.begin(), labels.end(), VocabFreqOrderer(counter));

    ABORT_IF(maxSize != 0 && labels.size() != maxSize,
             "Class vocab maxSize given ({}) has to match class vocab size ({})",
             maxSize, labels.size());

    YAML::Node vocabYaml;
    size_t nextId = 0;
    for(const auto& label : labels)
      vocabYaml.force_insert(label, nextId++);

    std::unique_ptr<std::ostream> out;
    if(vocabPath == "stdout")
      out.reset(new std::ostream(std::cout.rdbuf()));
    else
      out.reset(new io::OutputFileStream(vocabPath));
    *out << vocabYaml;
  }
};
// Factory: a fresh, empty DefaultVocab exposed through the IVocab interface.
Ptr<IVocab> createDefaultVocab() {
  Ptr<IVocab> vocab = New<DefaultVocab>();
  return vocab;
}
// Factory: a fresh, empty ClassVocab exposed through the IVocab interface.
Ptr<IVocab> createClassVocab() {
  Ptr<IVocab> vocab = New<ClassVocab>();
  return vocab;
}
}