.. _program_listing_file_src_microsoft_cosmos.cpp:

Program Listing for File cosmos.cpp
===================================

|exhale_lsh| :ref:`Return to documentation for file <file_src_microsoft_cosmos.cpp>` (``src/microsoft/cosmos.cpp``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   #include "cosmos.h"
   #include "models/model_base.h"
   #include "models/model_factory.h"
   #include "data/text_input.h"

   #if MKL_FOUND
   #include "mkl.h"
   #endif

   namespace marian {

   // Thin wrapper around IModel that makes sure model can be cast to an EncoderPooler
   // These poolers know how to collect embeddings from a seq2seq encoder.
   class EmbedderModel {
   private:
     Ptr<models::IModel> model_;

   public:
     // Creates the underlying model for the `embedding` usage from the given options.
     explicit EmbedderModel(Ptr<Options> options)
       : model_(createModelFromOptions(options, models::usage::embedding)) {}

     // Loads the model parameters stored in `modelFile` into `graph`.
     void load(Ptr<ExpressionGraph> graph, const std::string& modelFile) {
       model_->load(graph, modelFile);
     }

     // Builds the forward expression for `batch` and returns the first expression
     // produced by EncoderPooler::apply (called with clearGraph=true).
     // ABORTs if the wrapped model is not an EncoderPooler.
     Expr build(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch) {
       auto embedder = std::dynamic_pointer_cast<EncoderPooler>(model_);
       ABORT_IF(!embedder, "Could not cast to EncoderPooler");
       return embedder->apply(graph, batch, /*clearGraph=*/true)[0];
     }
   };

   namespace cosmos {

   const size_t MAX_BATCH_SIZE = 32;
   const size_t MAX_LENGTH = 256;

   // Owns the options, vocabulary, CPU inference graph and model used either to
   // embed sentences or (computeSimilarity=true) to score sentence pairs.
   class Embedder {
   private:
     Ptr<Options> options_;
     Ptr<ExpressionGraph> graph_;
     Ptr<Vocab> vocab_;
     Ptr<EmbedderModel> model_;

   public:
     // Loads the vocabulary and model and prepares a CPU inference graph.
     // With computeSimilarity=true the same vocabulary is registered for both
     // input streams and "compute-similarity" is forwarded to the model options.
     Embedder(const std::string& modelPath, const std::string& vocabPath, bool computeSimilarity = false) {
       options_ = New<Options>("inference", true,
                               "shuffle", "none",
                               "mini-batch", MAX_BATCH_SIZE,
                               "maxi-batch", 100,
                               "maxi-batch-sort", "src",
                               "max-length", MAX_LENGTH,
                               "max-length-crop", true,
                               "compute-similarity", computeSimilarity,
                               "vocabs", std::vector<std::string>(computeSimilarity ? 2 : 1, vocabPath));

       vocab_ = New<Vocab>(options_, 0);
       vocab_->load(vocabPath, 0);

       graph_ = New<ExpressionGraph>(/*inference=*/true);
       graph_->setDevice(CPU0);
       graph_->reserveWorkspaceMB(512);

       // The model configuration is read from the "special:model.yml" entry of the
       // model file and merged into the model options after options_.
       YAML::Node config;
       io::getYamlFromModel(config, "special:model.yml", modelPath);

       Ptr<Options> modelOpts = New<Options>();
       modelOpts->merge(options_);
       modelOpts->merge(config);

       model_ = New<EmbedderModel>(modelOpts);
       model_->load(graph_, modelPath);
     }

     // Compute embedding vectors for a batch of sentences.
     // Returns one vector per input sentence, stored at the sentence's id.
     std::vector<std::vector<float>> embed(const std::string& input) {
       auto text = New<data::TextInput>(std::vector<std::string>({input}),
                                        std::vector<Ptr<Vocab>>({vocab_}),
                                        options_);
       // we set runAsync=false as we are throwing exceptions instead of aborts. Exceptions and threading do not mix well.
       data::BatchGenerator<data::TextInput> batchGenerator(text, options_, /*stats=*/nullptr, /*runAsync=*/false);
       batchGenerator.prepare();

       std::vector<std::vector<float>> output;

       for(const auto& batch : batchGenerator) {
         auto embeddings = model_->build(graph_, batch);
         graph_->forward();

         std::vector<float> sentVectors;
         embeddings->val()->get(sentVectors);

         // Loop-invariant per batch: values are laid out sentence after sentence,
         // embSize floats each.
         const size_t embSize = static_cast<size_t>(embeddings->shape()[-1]);
         const auto& sentenceIds = batch->getSentenceIds();

         // collect embedding vector per sentence.
         // if we compute similarities this is only one similarity per sentence pair.
         for(size_t i = 0; i < batch->size(); ++i) {
           const size_t batchIdx = sentenceIds[i];
           if(output.size() <= batchIdx)
             output.resize(batchIdx + 1);

           // Fill the slot in place instead of building and copying a temporary vector.
           const auto beg = sentVectors.begin() + i * embSize;
           output[batchIdx].assign(beg, beg + embSize);
         }
       }

       return output;
     }

     // Compute cosine similarity scores for two batches of corresponding sentences.
     // Returns one score per sentence pair, stored at the pair's sentence id.
     std::vector<float> similarity(const std::string& input1, const std::string& input2) {
       auto text = New<data::TextInput>(std::vector<std::string>({input1, input2}),
                                        std::vector<Ptr<Vocab>>({vocab_, vocab_}),
                                        options_);
       // we set runAsync=false as we are throwing exceptions instead of aborts. Exceptions and threading do not mix well.
       data::BatchGenerator<data::TextInput> batchGenerator(text, options_, /*stats=*/nullptr, /*runAsync=*/false);
       batchGenerator.prepare();

       std::vector<float> output;

       for(const auto& batch : batchGenerator) {
         auto similarities = model_->build(graph_, batch);
         graph_->forward();

         std::vector<float> vSimilarities;
         similarities->val()->get(vSimilarities);

         const auto& sentenceIds = batch->getSentenceIds();

         // collect similarity score per sentence pair.
         for(size_t i = 0; i < batch->size(); ++i) {
           const size_t batchIdx = sentenceIds[i];
           if(output.size() <= batchIdx)
             output.resize(batchIdx + 1);
           output[batchIdx] = vSimilarities[i];
         }
       }

       return output;
     }
   };

   /* Interface functions ***************************************************************************/

   // Limits MKL to one thread (when built with MKL) and switches ABORT to throwing.
   MarianEmbedder::MarianEmbedder() {
   #if MKL_FOUND
     mkl_set_num_threads(1);
   #endif
     marian::setThrowExceptionOnAbort(true); // globally defined to throw now
   }

   // Embeds the sentences in `input`; ABORTs if load() has not been called.
   std::vector<std::vector<float>> MarianEmbedder::embed(const std::string& input) {
     ABORT_IF(!embedder_, "Embedder is not defined??");
     return embedder_->embed(input);
   }

   // Creates an embedding-mode Embedder; returns true whenever it returns normally.
   bool MarianEmbedder::load(const std::string& modelPath, const std::string& vocabPath) {
     embedder_ = New<Embedder>(modelPath, vocabPath, /*computeSimilarity*/false);
     ABORT_IF(!embedder_, "Embedder is not defined??");
     return true;
   }

   // Limits MKL to one thread (when built with MKL) and switches ABORT to throwing.
   MarianCosineScorer::MarianCosineScorer() {
   #if MKL_FOUND
     mkl_set_num_threads(1);
   #endif
     marian::setThrowExceptionOnAbort(true); // globally defined to throw now
   }

   // Scores corresponding sentence pairs of input1/input2; ABORTs if load() has not been called.
   std::vector<float> MarianCosineScorer::score(const std::string& input1, const std::string& input2) {
     ABORT_IF(!embedder_, "Embedder is not defined??");
     return embedder_->similarity(input1, input2);
   }

   // Creates a similarity-mode Embedder; returns true whenever it returns normally.
   bool MarianCosineScorer::load(const std::string& modelPath, const std::string& vocabPath) {
     embedder_ = New<Embedder>(modelPath, vocabPath, /*computeSimilarity*/true);
     ABORT_IF(!embedder_, "Embedder is not defined??");
     return true;
   }

   }  // namespace cosmos
   }  // namespace marian