.. _program_listing_file_src_data_sentencepiece_vocab.cpp:

Program Listing for File sentencepiece_vocab.cpp
================================================

|exhale_lsh| :ref:`Return to documentation for file <file_src_data_sentencepiece_vocab.cpp>` (``src/data/sentencepiece_vocab.cpp``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   #include "data/vocab_base.h"

   #ifdef USE_SENTENCEPIECE
   #include "sentencepiece/src/sentencepiece_processor.h"
   #include "sentencepiece/src/sentencepiece_trainer.h"
   #endif

   #include "common/config.h"
   #include "common/options.h"
   #include "common/logging.h"
   #include "common/filesystem.h"
   #include "common/regex.h"

   #include <sstream>
   #include <random>

   namespace marian {

   #ifdef USE_SENTENCEPIECE

   // Wrapper around https://github.com/google/sentencepiece
   class SentencePieceVocab : public IVocab {
   private:
     // Actual SentencePiece processor object
     UPtr<sentencepiece::SentencePieceProcessor> spm_;

     // Sampling factor for subword regularization, disabled when 0
     float alpha_{0};

     // Allowed suffixes for SentencePiece model
     std::vector<std::string> suffixes_ = {".spm"};

     Ptr<Options> options_;
     size_t batchIndex_{0};

     std::mt19937 generator_;
     std::uniform_int_distribution<int> randInt_; // from 0 to INT_MAX

     // Keeps sentences segmented into subword units
     bool keepEncoded_{false};

     // Contains control characters added to vocab due to byte-fallback
     std::vector<Word> controlChars_;

     // Creates the first 32 control characters as done in byte-fallback and checks if they exist in the vocab.
     // This makes sure that we do not waste computational effort on suppression if they don't actually appear.
     void populateControlChars() {
       for(int i = 0; i < 32; ++i) {
         std::string bytePiece = fmt::format("<0x{:02X}>", i); // 0 becomes <0x00>, 10 becomes <0x0A>, note uppercase A and lowercase x
         auto id = spm_->PieceToId(bytePiece);
         if(id != spm_->unk_id())
           controlChars_.push_back(Word::fromWordIndex(id));
       }
     }

     // Sample from one file, based on first algorithm from:
     // https://en.wikipedia.org/wiki/Reservoir_sampling
     void reservoirSampling(std::vector<std::string>& sample,
                            size_t& seenLines,
                            const std::string& trainPath,
                            size_t maxLines,
                            size_t maxBytes) {
       ABORT_IF(maxLines == 0, "Sample needs to be larger than 0");

       std::unique_ptr<std::istream> trainStrm(
           trainPath == "stdin" ? new std::istream(std::cin.rdbuf())
                                : new io::InputFileStream(trainPath));

       std::string line;
       while(getline(*trainStrm, line)) {
         if(line.size() > 0 && line.size() < maxBytes) {
           if(sample.size() < maxLines) {
             sample.push_back(line);
           } else {
             size_t i = randInt_(generator_) % (seenLines + 1);
             if(i < maxLines)
               sample[i] = line;
           }
           seenLines++;
         }
       }
     }

     // Iterate over all input files and collect a representative sample via reservoir sampling.
     // The sample will first grow to the desired size and next keep sampling with decreasing
     // probability in the hope to get a uniform sample from the union of all files.
     size_t reservoirSamplingAll(io::TemporaryFile& temp,
                                 const std::vector<std::string>& trainPaths,
                                 size_t maxLines,
                                 size_t maxBytes) {
       LOG(info, "[SentencePiece] Sampling at most {} lines from {}", maxLines, utils::join(trainPaths, ", "));

       std::vector<std::string> sample;
       size_t seenLines = 0;
       for(const auto& trainPath : trainPaths)
         reservoirSampling(sample, seenLines, trainPath, maxLines, maxBytes);
       std::shuffle(sample.begin(), sample.end(), generator_);

       for(const auto& line : sample)
         temp << line << std::endl;

       LOG(info, "[SentencePiece] Selected {} lines", sample.size());
       return sample.size();
     }

     // Just concatenate all files to a temporary file so SentencePiece can consume it.
     size_t dumpAll(io::TemporaryFile& temp,
                    const std::vector<std::string>& trainPaths,
                    size_t maxBytes) {
       LOG(info, "[SentencePiece] Selecting all lines from {}", utils::join(trainPaths, ", "));

       size_t seenLines = 0;
       std::string line;
       for(const auto& trainPath : trainPaths) {
         io::InputFileStream in(trainPath);
         while(getline(in, line)) {
           if(line.size() > 0 && line.size() < maxBytes) {
             temp << line << std::endl;
             seenLines++;
           }
         }
       }

       LOG(info, "[SentencePiece] Selected {} lines", seenLines);
       return seenLines;
     }

   public:
     SentencePieceVocab(Ptr<Options> options, size_t batchIndex)
         : options_(options),
           batchIndex_(batchIndex),
           generator_((uint32_t)Config::seed),
           keepEncoded_(options->get<bool>("no-spm-decode", false)) {
       if(options_->has("sentencepiece-alphas")) {
         auto alphas = options_->get<std::vector<float>>("sentencepiece-alphas");
         if(alphas.size() <= batchIndex)
           alpha_ = 0.f;
         else
           alpha_ = alphas[batchIndex_];

         if(alpha_ > 0)
           LOG(debug,
               "Setting SentencePiece vocabulary sampling factor to {} for input {}",
               alpha_,
               batchIndex_);
       }
     }

     virtual const std::string& canonicalExtension() const override { return suffixes_[0]; }
     virtual const std::vector<std::string>& suffixes() const override { return suffixes_; }

     virtual std::string suffix() { return suffixes_[0]; }

     virtual std::string type() const override { return "SentencePieceVocab"; }

     virtual Word getEosId() const override { return Word::fromWordIndex(spm_->eos_id()); }
     virtual Word getUnkId() const override { return Word::fromWordIndex(spm_->unk_id()); }

     void create(const std::string& vocabPath,
                 const std::vector<std::string>& trainPaths,
                 size_t maxSize) override {
       size_t defaultMaxSize = 32000;
       size_t maxLines = options_->get<size_t>("sentencepiece-max-lines");
       size_t maxBytes = 2048;

       LOG(info, "[SentencePiece] Training SentencePiece vocabulary {}", vocabPath);

       if(maxSize == 0) {
         LOG(info, "[SentencePiece] Vocabulary size is undefined (set with --dim-vocabs ...) - setting to {}", defaultMaxSize);
         maxSize = defaultMaxSize;
       }

       // Create temporary file to hold the sample for the SentencePiece trainer
       io::TemporaryFile temp(options_->get<std::string>("tempdir"), false);
       std::string tempFileName = temp.getFileName();
       LOG(info, "[SentencePiece] Creating temporary file {}", tempFileName);

       size_t seenLines = 0;
       if(maxLines == 0)
         seenLines = dumpAll(temp, trainPaths, maxBytes);
       else
         seenLines = reservoirSamplingAll(temp, trainPaths, maxLines, maxBytes);

       // Compose the SentencePiece training command from filenames and parameters
       std::stringstream command;
       command
         << " --bos_id=-1 --eos_id=0 --unk_id=1" // these should not be changed as they match Marian defaults
         << " --input="               << tempFileName
         << " --model_prefix="        << vocabPath
         << " --vocab_size="          << maxSize
         << " --max_sentence_length=" << maxBytes
         << " --input_sentence_size=" << seenLines
         << " " << options_->get<std::string>("sentencepiece-options"); // these are SentencePiece command line options

       // Train the SentencePiece model
       const auto status = sentencepiece::SentencePieceTrainer::Train(command.str());
       ABORT_IF(!status.ok(), "SentencePiece vocabulary error: {}", status.ToString());

       LOG(info, "[SentencePiece] Removing {}", vocabPath + ".vocab");
       ABORT_IF(remove((vocabPath + ".vocab").c_str()) != 0,
                "Could not remove {}",
                vocabPath + ".vocab");

       LOG(info, "[SentencePiece] Renaming {} to {}", vocabPath + ".model", vocabPath);
       ABORT_IF(rename((vocabPath + ".model").c_str(), vocabPath.c_str()) != 0,
                "Could not rename {} to {}",
                vocabPath + ".model",
                vocabPath);
     }

     void createFake() override {
       ABORT("[SentencePiece] Fake SentencePiece vocabulary not supported");
     }

     Word operator[](const std::string& token) const override {
       return Word::fromWordIndex(spm_->PieceToId(token));
     }

     const std::string& operator[](Word id) const override {
       ABORT_IF(id.toWordIndex() >= size(), "Unknown word id: {}", id.toWordIndex());
       return spm_->IdToPiece(id.toWordIndex());
     }

     Words encode(const std::string& line, bool addEOS, bool inference) const override {
       std::vector<int> spmIds;
       if(inference || alpha_ == 0)
         spm_->Encode(line, &spmIds);
       else
         spm_->SampleEncode(line, -1, alpha_, &spmIds);

       Words words;
       words.reserve(spmIds.size() + addEOS);
       for(auto&& spmId : spmIds)
         words.push_back(Word::fromWordIndex(spmId));
       if(addEOS)
         words.push_back(getEosId());
       return words;
     }

     std::string decode(const Words& sentence, bool ignoreEOS) const override {
       std::string line;
       if(keepEncoded_) { // i.e. keep the sentence segmented into subword units
         for(const Word& id : sentence)
           if(!ignoreEOS || id != getEosId())
             line += (*this)[id] + " ";
         line.pop_back(); // trim the trailing whitespace
       } else {
         // convert vector of Word to vector of int
         std::vector<int> spmSentence;
         spmSentence.reserve(sentence.size());
         for(auto&& word : sentence)
           if(!ignoreEOS || word != getEosId())
             spmSentence.push_back(word.toWordIndex());
         spm_->Decode(spmSentence, &line);
       }
       return line;
     }

     std::string surfaceForm(const Words& sentence) const override {
       // with SentencePiece, decoded form and surface form are identical
       return decode(sentence, /*ignoreEOS=*/true);
     }

     size_t size() const override {
       return spm_->GetPieceSize();
     }

     size_t load(const std::string& vocabPath, size_t /*maxSize*/) override {
       LOG(info, "[data] Loading SentencePiece vocabulary from file {}", vocabPath);

       ABORT_IF(!filesystem::exists(vocabPath),
                "SentencePiece vocabulary file {} does not exist",
                vocabPath);

       spm_.reset(new sentencepiece::SentencePieceProcessor());
       const auto status = spm_->Load(vocabPath);
       ABORT_IF(!status.ok(), "SentencePiece vocabulary error: {}", status.ToString());

       populateControlChars();

       return spm_->GetPieceSize();
     }

     std::string toUpper(const std::string& line) const override { return utils::utf8ToUpper(line); }
     std::string toEnglishTitleCase(const std::string& line) const override { return utils::toEnglishTitleCase(line); }

     // SentencePiece with byte-fallback may generate control symbols with output sampling.
     // Let's mark them as special and suppress them later on output. This is generally safe
     // for UTF-8 since control chars are not used as partial bytes in multi-byte sequences.
     // They only appear in single-byte chars as themselves and this is what we suppress.
     void addSpecialWords(std::vector<Word>& special) const override {
       special.reserve(special.size() + controlChars_.size());
       for(auto c : controlChars_)
         special.push_back(c);
     }
   };
   #endif // USE_SENTENCEPIECE

   Ptr<IVocab> createSentencePieceVocab(const std::string& vocabPath, Ptr<Options> options, size_t batchIndex) {
     bool isSentencePiece = regex::regex_search(vocabPath, regex::regex("\\.(spm)$"));
     if(isSentencePiece) {
   #ifdef USE_SENTENCEPIECE
       return New<SentencePieceVocab>(options, batchIndex);
   #else
       batchIndex; options;
       ABORT("*.spm suffix in path {} reserved for SentencePiece models, "
             "but support for SentencePiece is not compiled into Marian. "
             "Try to recompile after `cmake .. -DUSE_SENTENCEPIECE=on [...]`",
             vocabPath);
   #endif
     }
     // Not a SentencePiece model based on the suffix
     return nullptr;
   }

   } // namespace marian
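
The listing above also defines the ``createSentencePieceVocab`` factory at the bottom of the file. The fragment below is a minimal usage sketch and is not part of the file: it assumes Marian was built with ``-DUSE_SENTENCEPIECE=on``, that the factory declaration from ``data/vocab_base.h`` is visible, and that a trained model named ``vocab.spm`` exists (the file name and the empty ``Options`` object are illustrative assumptions; the constructor only reads the optional keys ``no-spm-decode`` and ``sentencepiece-alphas``, so defaults apply here).

.. code-block:: cpp

   // Sketch only: file name and option setup are assumptions, not Marian defaults.
   #include "data/vocab_base.h"
   #include "common/options.h"

   using namespace marian;

   void sketchSentencePieceUsage() {
     // Empty options: "no-spm-decode" falls back to false and no
     // "sentencepiece-alphas" are set, so encoding is deterministic.
     auto options = New<Options>();
     size_t batchIndex = 0; // e.g. the source side of a parallel input

     // The ".spm" suffix routes the request to SentencePieceVocab.
     Ptr<IVocab> vocab = createSentencePieceVocab("vocab.spm", options, batchIndex);
     vocab->load("vocab.spm", /*maxSize=*/0); // maxSize is ignored for SentencePiece

     // Segment a line into subword ids and append EOS; inference=true
     // disables subword-regularization sampling.
     Words ids = vocab->encode("This is a test.", /*addEOS=*/true, /*inference=*/true);

     // Detokenize back to plain text, dropping EOS.
     std::string text = vocab->decode(ids, /*ignoreEOS=*/true);
   }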