Program Listing for File embedder.h¶
↰ Return to documentation for file (src/embedder/embedder.h
)
#pragma once
#include "marian.h"
#include "common/config.h"
#include "common/options.h"
#include "data/batch_generator.h"
#include "data/corpus.h"
#include "data/corpus_nbest.h"
#include "models/costs.h"
#include "models/model_task.h"
#include "embedder/vector_collector.h"
#include "training/scheduler.h"
#include "training/validator.h"
namespace marian {
using namespace data;
/*
* The tool is used to create output sentence embeddings from available
* Marian encoders. With --compute-similiarity and can return the cosine
* similarity between two sentences provided from two sources.
*/
class Embedder {
private:
Ptr<models::IModel> model_;
public:
Embedder(Ptr<Options> options)
: model_(createModelFromOptions(options, models::usage::embedding)) {}
void load(Ptr<ExpressionGraph> graph, const std::string& modelFile) {
model_->load(graph, modelFile);
}
Expr build(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch) {
auto embedder = std::dynamic_pointer_cast<EncoderPooler>(model_);
ABORT_IF(!embedder, "Could not cast to EncoderPooler");
return embedder->apply(graph, batch, /*clearGraph=*/true)[0];
}
};
/*
* Actual Embed task. @TODO: this should be simplified in the future.
*/
template <class Model>
class Embed : public ModelTask {
private:
Ptr<Options> options_;
Ptr<CorpusBase> corpus_;
std::vector<Ptr<ExpressionGraph>> graphs_;
std::vector<Ptr<Model>> models_;
public:
Embed(Ptr<Options> options) : options_(options) {
options_ = options_->with("inference", true,
"shuffle", "none",
"input-types", std::vector<std::string>({"sequence"}));
// if a similarity is computed then double the input types and vocabs for
// the two encoders that are used in the model.
if(options->get<bool>("compute-similarity")) {
auto vVocabs = options_->get<std::vector<std::string>>("vocabs");
auto vDimVocabs = options_->get<std::vector<size_t>>("dim-vocabs");
vVocabs.push_back(vVocabs.back());
vDimVocabs.push_back(vDimVocabs.back());
options_ = options_->with("vocabs", vVocabs,
"dim-vocabs", vDimVocabs,
"input-types", std::vector<std::string>(vVocabs.size(), "sequence"));
}
corpus_ = New<Corpus>(options_);
corpus_->prepare();
auto devices = Config::getDevices(options_);
for(auto device : devices) {
auto graph = New<ExpressionGraph>(true);
auto precison = options_->get<std::vector<std::string>>("precision", {"float32"});
graph->setDefaultElementType(typeFromString(precison[0])); // only use first type, used for parameter type in graph
graph->setDevice(device);
graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
graphs_.push_back(graph);
}
auto modelFile = options_->get<std::string>("model");
models_.resize(graphs_.size());
ThreadPool pool(graphs_.size(), graphs_.size());
for(size_t i = 0; i < graphs_.size(); ++i) {
pool.enqueue(
[=](size_t j) {
models_[j] = New<Model>(options_);
models_[j]->load(graphs_[j], modelFile);
},
i);
}
}
void run() override {
LOG(info, "Embedding");
timer::Timer timer;
auto batchGenerator = New<BatchGenerator<CorpusBase>>(corpus_, options_);
batchGenerator->prepare();
auto output = New<VectorCollector>(options_);
size_t batchId = 0;
{
ThreadPool pool(graphs_.size(), graphs_.size());
for(auto batch : *batchGenerator) {
auto task = [=](size_t id) {
thread_local Ptr<ExpressionGraph> graph;
thread_local Ptr<Model> builder;
if(!graph) {
graph = graphs_[id % graphs_.size()];
builder = models_[id % graphs_.size()];
}
auto embeddings = builder->build(graph, batch);
graph->forward();
std::vector<float> sentVectors;
embeddings->val()->get(sentVectors);
// collect embedding vector per sentence.
// if we compute similarities this is only one similarity per sentence pair.
for(size_t i = 0; i < batch->size(); ++i) {
auto embSize = embeddings->shape()[-1];
auto beg = i * embSize;
auto end = (i + 1) * embSize;
std::vector<float> sentVector(sentVectors.begin() + beg, sentVectors.begin() + end);
output->Write((long)batch->getSentenceIds()[i],
sentVector);
}
// progress heartbeat for MS-internal Philly compute cluster
// otherwise this job may be killed prematurely if no log for 4 hrs
if (getenv("PHILLY_JOB_ID") // this environment variable exists when running on the cluster
&& id % 1000 == 0) // hard beat once every 1000 batches
{
auto progress = id / 10000.f; //fake progress for now, becomes >100 after 1M batches
fprintf(stderr, "PROGRESS: %.2f%%\n", progress);
fflush(stderr);
}
};
pool.enqueue(task, batchId++);
}
}
LOG(info, "Total time: {:.5f}s wall", timer.elapsed());
}
};
} // namespace marian