Program Listing for File cosmos.cpp

Return to documentation for file (src/microsoft/cosmos.cpp)

#include "cosmos.h"

#include "models/model_base.h"
#include "models/model_factory.h"
#include "data/text_input.h"

#if MKL_FOUND
#include "mkl.h"
#endif

namespace marian {

// Small adapter over a models::IModel created for embedding usage.
// At build time the wrapped model is required to be an EncoderPooler;
// the first output of the pooler is what callers receive.
class EmbedderModel {
public:
  // Creates the underlying model from the given options with usage "embedding".
  EmbedderModel(Ptr<Options> options)
    : model_(createModelFromOptions(options, models::usage::embedding)) {}

  // Forwards to the wrapped model's load() for the given graph and model file.
  void load(Ptr<ExpressionGraph> graph, const std::string& modelFile) {
    model_->load(graph, modelFile);
  }

  // Applies the model to the batch (clearing the graph first) and returns
  // the first expression produced. Aborts if the model is not an EncoderPooler.
  Expr build(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch) {
    auto pooler = std::dynamic_pointer_cast<EncoderPooler>(model_);
    ABORT_IF(!pooler, "Could not cast to EncoderPooler");
    auto results = pooler->apply(graph, batch, /*clearGraph=*/true);
    return results[0];
  }

private:
  Ptr<models::IModel> model_;
};

namespace cosmos {

// Value passed as the "mini-batch" option when an Embedder is constructed
// (presumably counted in sentences per batch -- confirm against the option docs).
const size_t MAX_BATCH_SIZE =  32;
// Value passed as the "max-length" option when an Embedder is constructed;
// "max-length-crop" is enabled alongside it
// (presumably counted in tokens per sentence -- confirm against the option docs).
const size_t MAX_LENGTH     = 256;

// Owns everything needed to run an embedding model on the CPU: the inference
// options, a vocabulary, an expression graph and the wrapped EmbedderModel.
// Constructed with computeSimilarity=false it is used through embed(); with
// computeSimilarity=true it is used through similarity().
class Embedder {
private:
  Ptr<Options> options_;
  Ptr<ExpressionGraph> graph_;
  Ptr<Vocab> vocab_;

  Ptr<EmbedderModel> model_;

public:
  // modelPath:         model file; its embedded "special:model.yml" config is read
  //                    and merged with the inference options built here.
  // vocabPath:         vocabulary file; listed once in "vocabs", or twice when
  //                    computeSimilarity is set.
  // computeSimilarity: forwarded as the "compute-similarity" option.
  Embedder(const std::string& modelPath, const std::string& vocabPath, bool computeSimilarity = false) {
    options_ = New<Options>("inference", true,
                            "shuffle", "none",
                            "mini-batch", MAX_BATCH_SIZE,
                            "maxi-batch", 100,
                            "maxi-batch-sort", "src",
                            "max-length", MAX_LENGTH,
                            "max-length-crop", true,
                            "compute-similarity", computeSimilarity,
                            "vocabs", std::vector<std::string>(computeSimilarity ? 2 : 1, vocabPath));

    vocab_ = New<Vocab>(options_, 0);
    vocab_->load(vocabPath, 0);

    // inference-only graph on the first CPU device with a 512 MB workspace reservation
    graph_ = New<ExpressionGraph>(/*inference=*/true);
    graph_->setDevice(CPU0);
    graph_->reserveWorkspaceMB(512);

    YAML::Node config;
    io::getYamlFromModel(config, "special:model.yml", modelPath);

    // combine the options built above with the configuration stored in the model file
    Ptr<Options> modelOpts = New<Options>();
    modelOpts->merge(options_);
    modelOpts->merge(config);

    model_ = New<EmbedderModel>(modelOpts);
    model_->load(graph_, modelPath);
  }

  // Compute embedding vectors for a batch of sentences.
  // Returns one vector per sentence, stored at the position given by the
  // sentence id reported by the batch.
  std::vector<std::vector<float>> embed(const std::string& input) {
    auto text = New<data::TextInput>(std::vector<std::string>({input}),
                                     std::vector<Ptr<Vocab>>({vocab_}),
                                     options_);
    // we set runAsync=false as we are throwing exceptions instead of aborts. Exceptions and threading do not mix well.
    data::BatchGenerator<data::TextInput> batchGenerator(text, options_, /*stats=*/nullptr, /*runAsync=*/false);
    batchGenerator.prepare();

    std::vector<std::vector<float>> output;

    for(auto batch : batchGenerator) {
      auto embeddings = model_->build(graph_, batch);
      graph_->forward();

      // flat copy of the result tensor; one slice of embSize values per sentence
      std::vector<float> sentVectors;
      embeddings->val()->get(sentVectors);

      // both are the same for every sentence of this batch, so look them up once
      const size_t embSize = static_cast<size_t>(embeddings->shape()[-1]);
      const auto& sentenceIds = batch->getSentenceIds();

      // collect embedding vector per sentence.
      // if we compute similarities this is only one similarity per sentence pair.
      for(size_t i = 0; i < batch->size(); ++i) {
        const auto batchIdx = sentenceIds[i];
        if(output.size() <= batchIdx)
          output.resize(batchIdx + 1);

        // write the slice straight into its output slot (no temporary vector)
        const auto sliceBegin = sentVectors.begin() + i * embSize;
        output[batchIdx].assign(sliceBegin, sliceBegin + embSize);
      }
    }

    return output;
  }

  // Compute cosine similarity scores for two batches of corresponding sentences.
  // Returns one score per sentence pair, stored at the position given by the
  // sentence id reported by the batch.
  std::vector<float> similarity(const std::string& input1, const std::string& input2) {
    auto text = New<data::TextInput>(std::vector<std::string>({input1, input2}),
                                     std::vector<Ptr<Vocab>>({vocab_, vocab_}),
                                     options_);
    // we set runAsync=false as we are throwing exceptions instead of aborts. Exceptions and threading do not mix well.
    data::BatchGenerator<data::TextInput> batchGenerator(text, options_, /*stats=*/nullptr, /*runAsync=*/false);
    batchGenerator.prepare();

    std::vector<float> output;

    for(auto batch : batchGenerator) {
      auto similarities = model_->build(graph_, batch);
      graph_->forward();

      std::vector<float> vSimilarities;
      similarities->val()->get(vSimilarities);

      const auto& sentenceIds = batch->getSentenceIds();

      // collect similarity score per sentence pair.
      for(size_t i = 0; i < batch->size(); ++i) {
        const auto batchIdx = sentenceIds[i];
        if(output.size() <= batchIdx)
          output.resize(batchIdx + 1);
        output[batchIdx] = vSimilarities[i];
      }
    }

    return output;
  }
};

/* Interface functions ***************************************************************************/

// Sets up process-wide state for embedding: aborts are turned into exceptions
// and, when built with MKL, MKL is limited to a single thread.
MarianEmbedder::MarianEmbedder() {
  // this switch is global, i.e. it affects all of marian from here on
  marian::setThrowExceptionOnAbort(true);
#if MKL_FOUND
  mkl_set_num_threads(1);
#endif
}

// Returns the embedding vectors computed by the loaded Embedder for the given input.
// Aborts if load() has not set up an embedder yet.
std::vector<std::vector<float>> MarianEmbedder::embed(const std::string& input) {
  ABORT_IF(!embedder_, "Embedder is not defined??");
  auto vectors = embedder_->embed(input);
  return vectors;
}

// Creates the underlying Embedder from the model and vocabulary files with
// similarity computation switched off. Returns true once the embedder exists.
bool MarianEmbedder::load(const std::string& modelPath, const std::string& vocabPath) {
  const bool computeSimilarity = false; // plain embedding mode
  embedder_ = New<Embedder>(modelPath, vocabPath, computeSimilarity);
  ABORT_IF(!embedder_, "Embedder is not defined??");
  return true;
}

// Sets up process-wide state for scoring: aborts are turned into exceptions
// and, when built with MKL, MKL is limited to a single thread.
MarianCosineScorer::MarianCosineScorer() {
  // this switch is global, i.e. it affects all of marian from here on
  marian::setThrowExceptionOnAbort(true);
#if MKL_FOUND
  mkl_set_num_threads(1);
#endif
}

// Returns the similarity scores computed by the loaded Embedder for the two inputs.
// Aborts if load() has not set up an embedder yet.
std::vector<float> MarianCosineScorer::score(const std::string& input1, const std::string& input2) {
  ABORT_IF(!embedder_, "Embedder is not defined??");
  auto scores = embedder_->similarity(input1, input2);
  return scores;
}

// Creates the underlying Embedder from the model and vocabulary files with
// similarity computation switched on. Returns true once the embedder exists.
bool MarianCosineScorer::load(const std::string& modelPath, const std::string& vocabPath) {
  const bool computeSimilarity = true; // sentence-pair scoring mode
  embedder_ = New<Embedder>(modelPath, vocabPath, computeSimilarity);
  ABORT_IF(!embedder_, "Embedder is not defined??");
  return true;
}

} // namespace cosmos
} // namespace marian