Program Listing for File sentencepiece.cpp
↰ Return to documentation for file (src/microsoft/sentencepiece.cpp)
#include <sstream>
#include <memory>
#include <string>
#include <vector>
#ifdef USE_SENTENCEPIECE
#include "sentencepiece.h"
#ifdef __GNUC__
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wsuggest-override"
#endif
#include "sentencepiece/src/builtin_pb/sentencepiece.pb.h"
#ifdef __GNUC__
#pragma GCC diagnostic pop
#endif
#include "sentencepiece/src/sentencepiece_processor.h"
#include "sentencepiece/src/sentencepiece_trainer.h"
#include "unicode_conversions.h"
namespace marian {
namespace spm {
class SentencePieceInternal {
std::unique_ptr<sentencepiece::SentencePieceProcessor> m_processor;
void checkStatus(sentencepiece::util::Status status, const char* what) {
if(status.ok())
return;
std::string err = status.ToString();
std::cerr << err << std::endl;
throw std::runtime_error(std::string("SentencePiece error ") + what + ": " + err);
}
int createNativeSentencePieceText(sentencepiece::SentencePieceText& spt, Native_SentencePieceText** outSpt) {
Native_SentencePieceText* spt_ret = new Native_SentencePieceText();
spt_ret->text = new char[spt.text().size() + 1];
::strcpy(spt_ret->text, spt.text().c_str());
spt_ret->num_pieces = spt.pieces().size();
spt_ret->pieces = new Native_SentencePiecePiece*[spt_ret->num_pieces];
int counter = 0;
for(auto& piece : spt.pieces()) {
spt_ret->pieces[counter] = new Native_SentencePiecePiece();
spt_ret->pieces[counter]->id = piece.id();
spt_ret->pieces[counter]->begin = piece.begin();
spt_ret->pieces[counter]->end = piece.end();
spt_ret->pieces[counter]->surface = new char[piece.surface().size() + 1];
::strcpy((spt_ret->pieces)[counter]->surface, (char*)piece.surface().c_str());
spt_ret->pieces[counter]->piece = new char[piece.piece().size() + 1];
::strcpy((spt_ret->pieces)[counter]->piece, (char*)piece.piece().c_str());
counter++;
}
*outSpt = spt_ret;
return 0;
}
public:
SentencePieceInternal(const uint16_t* modelPath, const uint16_t** vocab, size_t vocabSize) {
m_processor.reset(new sentencepiece::SentencePieceProcessor());
// load the model file
const auto status = m_processor->Load(utf16_to_utf8(utf16string(modelPath)));
// implant the restricted vocabulary, if given
if(vocab && vocabSize > 0) {
std::vector<std::string> vocab_str;
for(size_t i = 0; i < vocabSize; i++)
vocab_str.push_back(utf16_to_utf8(utf16string(vocab[i])));
m_processor->SetVocabulary(vocab_str);
}
checkStatus(status, "loading");
}
int getPieceID(char* sentence) {
std::string sentInUtf8(sentence);
return m_processor->PieceToId(absl::string_view(sentInUtf8));
}
int encodeAligned(char* sentence, Native_SentencePieceText** nSpt) {
sentencepiece::SentencePieceText spt;
std::string sentInUtf8(sentence);
m_processor->Encode(absl::string_view(sentInUtf8), &spt);
return createNativeSentencePieceText(spt, nSpt);
}
int decodeAligned(int num_tokens, char** inp_tokens, Native_SentencePieceText** nSpt) {
sentencepiece::SentencePieceText spt;
std::vector<std::string> tokens;
for(int i = 0; i < num_tokens; i++) {
std::string tok((char*)inp_tokens[i]);
tokens.push_back(tok);
}
m_processor->Decode(tokens, &spt);
return createNativeSentencePieceText(spt, nSpt);
}
};
/**
 * Releases a Native_SentencePieceText together with everything it owns.
 *
 * The release form must match how each member was allocated:
 *   - `text`, `pieces`, and each piece's `surface` / `piece` strings were
 *     allocated with `new[]` and therefore need `delete[]`
 *   - the top-level struct and each piece struct were allocated with scalar
 *     `new` and need scalar `delete`
 *
 * Passing a null pointer is a no-op. Always returns 0.
 */
int SentencePieceInteropFreeNativeSentencePieceText(Native_SentencePieceText* spt) {
  if(spt == nullptr)
    return 0;

  if(spt->pieces != nullptr) {
    const auto num_pieces = spt->num_pieces;
    for(int i = 0; i < num_pieces; i++) {
      Native_SentencePiecePiece* piece = spt->pieces[i];
      if(piece == nullptr)
        continue;
      delete[] piece->surface;  // char array: array delete, not scalar delete
      delete[] piece->piece;    // char array: array delete, not scalar delete
      delete piece;
    }
    delete[] spt->pieces;
  }

  delete[] spt->text;
  delete spt;
  return 0;
}
intptr_t SentencePieceInteropLoadModel(const uint16_t* modelPath,
const uint16_t** vocab,
size_t vocabSize) {
try {
return (intptr_t) new SentencePieceInternal(modelPath, vocab, vocabSize);
}
catch(...) { return (intptr_t) nullptr; }
}
int SentencePieceInteropDecodeAligned(intptr_t object,
int num_tokens,
char** tokens,
Native_SentencePieceText** nSpt) {
try {
return ((SentencePieceInternal*)object)->decodeAligned(num_tokens, tokens, nSpt);
}
catch(...) { return -1; }
}
int SentencePieceInteropEncodeAligned(intptr_t object,
char* word,
Native_SentencePieceText** nSpt) {
try {
return ((SentencePieceInternal*)object)->encodeAligned(word, nSpt);
}
catch(...) { return -1; }
}
int SentencePieceInteropGetPieceID(intptr_t object, char* word) {
try {
return ((SentencePieceInternal*)object)->getPieceID(word);
}
catch(...) { return -1; }
}
int SentencePieceInteropUnloadModel(intptr_t object) {
delete(SentencePieceInternal*)object;
return 0;
}
/**
 * Runs sentencepiece::SentencePieceTrainer::Train with the NUL-terminated
 * argument string `args`.
 *
 * Returns the trainer's status code converted to int. Returns -1 if `args`
 * is null or if an exception is thrown, matching the error convention of the
 * other interop entry points in this file, so that no exception crosses the
 * interop boundary.
 * NOTE(review): presumably -1 does not collide with any sentencepiece status
 * code value -- confirm against sentencepiece's StatusCode enum.
 */
int SentencepieceInteropTrainModel(char* args) {
  if(args == nullptr)
    return -1;

  try {
    const auto status = sentencepiece::SentencePieceTrainer::Train(std::string(args));
    return static_cast<int>(status.code());
  } catch(...) {
    return -1;
  }
}
} // namespace spm
} // namespace marian
#endif