Program Listing for File pooler.h
↰ Return to documentation for file (src/models/pooler.h)
#pragma once

#include "marian.h"
#include "models/states.h"
#include "layers/constructors.h"
#include "layers/factory.h"

namespace marian {
class PoolerBase : public LayerBase {
  using LayerBase::LayerBase;

protected:
  const std::string prefix_{"pooler"};
  const bool inference_{false};
  const size_t batchIndex_{0};

public:
  PoolerBase(Ptr<ExpressionGraph> graph, Ptr<Options> options)
      : LayerBase(graph, options),
        prefix_(options->get<std::string>("prefix", "pooler")),
        inference_(options->get<bool>("inference", true)),
        batchIndex_(options->get<size_t>("index", 1)) {} // assume that training input has batch index 0 and labels have index 1

  virtual ~PoolerBase() {}

  virtual std::vector<Expr> apply(Ptr<ExpressionGraph>, Ptr<data::CorpusBatch>, const std::vector<Ptr<EncoderState>>&) = 0;

  template <typename T>
  T opt(const std::string& key) const {
    return options_->get<T>(key);
  }

  // Should be used to clear any batch-wise temporary objects if present
  virtual void clear() = 0;
};
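// A minimal sketch (hypothetical, not part of the original file) of what a
// concrete pooler derived from PoolerBase must provide: an apply() that turns
// encoder states into one or more pooled expressions, and a clear() for any
// per-batch temporaries. The masked mean over the time axis (-3) shown here is
// an assumed example, not a pooler that exists in this file:
//
//   class MeanPooler : public PoolerBase {
//   public:
//     MeanPooler(Ptr<ExpressionGraph> graph, Ptr<Options> options)
//         : PoolerBase(graph, options) {}
//
//     std::vector<Expr> apply(Ptr<ExpressionGraph> graph,
//                             Ptr<data::CorpusBatch> batch,
//                             const std::vector<Ptr<EncoderState>>& encoderStates) override {
//       auto context = encoderStates[0]->getContext();
//       auto mask    = encoderStates[0]->getMask();
//       // masked mean: sum of unmasked states divided by the number of tokens
//       return {sum(context * mask, /*axis=*/-3) / sum(mask, /*axis=*/-3)};
//     }
//
//     void clear() override {}
//   };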
class MaxPooler : public PoolerBase {
public:
  MaxPooler(Ptr<ExpressionGraph> graph, Ptr<Options> options)
      : PoolerBase(graph, options) {}

  std::vector<Expr> apply(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch, const std::vector<Ptr<EncoderState>>& encoderStates) override {
    ABORT_IF(encoderStates.size() != 1, "Pooler expects exactly one encoder state");

    auto context   = encoderStates[0]->getContext();
    auto batchMask = encoderStates[0]->getMask();

    // do a max pool here
    Expr logMask = (1.f - batchMask) * -9999.f;
    Expr maxPool = max(context * batchMask + logMask, /*axis=*/-3);

    return {maxPool};
  }

  void clear() override {}
};
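// Why the -9999 trick above works (illustrative values, not part of the
// original file): batchMask is 1 for real tokens and 0 for padding, so per
// position
//   mask = 1:  x * 1 + (1 - 1) * -9999.f == x
//   mask = 0:  x * 0 + (1 - 0) * -9999.f == -9999.f
// and the max over the time axis (-3) can never select a padded position,
// assuming real activations stay well above -9999.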
class SlicePooler : public PoolerBase {
public:
  SlicePooler(Ptr<ExpressionGraph> graph, Ptr<Options> options)
      : PoolerBase(graph, options) {}

  std::vector<Expr> apply(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch, const std::vector<Ptr<EncoderState>>& encoderStates) override {
    ABORT_IF(encoderStates.size() != 1, "Pooler expects exactly one encoder state");

    auto context   = encoderStates[0]->getContext();
    auto batchMask = encoderStates[0]->getMask();

    // Corresponds to the way we do this in transformer.h
    // @TODO: unify this better, this is currently hacky
    Expr slicePool = slice(context * batchMask, /*axis=*/-3, 0);

    return {slicePool};
  }

  void clear() override {}
};
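// Note (not part of the original file): slice(context * batchMask, /*axis=*/-3, 0)
// keeps only the first position of the time axis, i.e. the sentence embedding is
// the encoder state of the first token, comparable to [CLS]-style pooling in
// BERT-like models.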
class SimPooler : public PoolerBase {
public:
  SimPooler(Ptr<ExpressionGraph> graph, Ptr<Options> options)
      : PoolerBase(graph, options) {}

  std::vector<Expr> apply(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch, const std::vector<Ptr<EncoderState>>& encoderStates) override {
    ABORT_IF(encoderStates.size() < 2, "SimPooler expects at least two encoder states, not {}", encoderStates.size());

    std::vector<Expr> vecs;
    for(auto encoderState : encoderStates) {
      auto context   = encoderState->getContext();
      auto batchMask = encoderState->getMask();

      Expr pool;
      auto type = options_->get<std::string>("original-type");
      if(type == "laser" || type == "laser-sim") {
        // LASER models do a max pool here
        Expr logMask = (1.f - batchMask) * -9999.f;
        pool = max(context * batchMask + logMask, /*axis=*/-3);
      } else if(type == "transformer") {
        // Our own implementation in transformer.h uses a slice of the first element
        pool = slice(context, -3, 0);
      } else {
        // @TODO: make SimPooler take Pooler objects as arguments, then it won't need to know this
        ABORT("Don't know what type of pooler to use for model type {}", type);
      }

      vecs.push_back(pool);
    }
    std::vector<Expr> outputs;
    bool trainRank = options_->hasAndNotEmpty("train-embedder-rank");
    if(!trainRank) { // inference, compute one cosine similarity only
      ABORT_IF(vecs.size() != 2, "We are expecting two inputs for similarity computation");

      // efficiently compute vector length with bdot
      auto vnorm = [](Expr e) {
        int dimModel = e->shape()[-1];
        int dimBatch = e->shape()[-2];
        e = reshape(e, {dimBatch, 1, dimModel});
        return reshape(sqrt(bdot(e, e, false, true)), {dimBatch, 1});
      };
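      // Note (not part of the original file): after the reshape to
      // [dimBatch, 1, dimModel], bdot(e, e, false, true) multiplies each row
      // vector with its own transpose, so every batch entry yields a 1x1
      // matrix holding sum_k e_k * e_k; sqrt of that is the Euclidean length.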
      auto dotProduct = scalar_product(vecs[0], vecs[1], /*axis*/-1);
      auto length0 = vnorm(vecs[0]); // will be hashed and reused in the graph
      auto length1 = vnorm(vecs[1]);

      auto cosine = dotProduct / (length0 * length1);
      cosine = maximum(0, cosine); // clip to [0, 1] - should we actually do that?
      outputs.push_back(cosine);
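      // Aside (not part of the original file): the raw cosine lies in [-1, 1];
      // maximum(0, cosine) folds negative similarities to zero, i.e. "opposite"
      // and "unrelated" pairs score the same.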
    } else { // compute outputs for embedding similarity ranking
      if(vecs.size() == 2) { // implies we are sampling negative examples from the batch, since otherwise there is nothing to train
        LOG_ONCE(info, "Sampling negative examples from batch");
        auto src = vecs[0];
        auto trg = vecs[1];

        int dimModel = src->shape()[-1];
        int dimBatch = src->shape()[-2];
        src = reshape(src, {dimBatch, dimModel});
        trg = reshape(trg, {dimBatch, dimModel});

        // compute dot products between every batch entry, this produces the whole dimBatch x dimBatch matrix
        auto dotProduct = dot(src, trg, false, true); // [dimBatch, dimBatch] - computes dot product matrix

        auto positiveMask = dotProduct->graph()->constant({dimBatch, dimBatch}, inits::eye()); // a mask for the diagonal (positive examples are on the diagonal)
        auto negativeMask = 1.f - positiveMask; // mask for all negative examples

        auto positive = sum(dotProduct * positiveMask, /*axis=*/-1); // sum across the last dim to get a column vector of positive examples (everything else is zero)
        outputs.push_back(positive);

        auto negative1 = dotProduct * negativeMask; // negative examples for src -> trg (in a row)
        outputs.push_back(negative1);

        auto negative2 = transpose(negative1); // negative examples for trg -> src via transpose, so they are located in a row as well
        outputs.push_back(negative2);
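        // Illustration (not part of the original file) for dimBatch = 3: with
        // positiveMask = I, the score matrix S = src * trg^T decomposes as
        //
        //       [ p0  n01 n02 ]
        //   S = [ n10 p1  n12 ]   positives p_i = S[i][i], negatives n_ij = S[i][j], i != j
        //       [ n20 n21 p2  ]
        //
        // so row i of negative1 holds the in-batch negatives for source sentence i,
        // and row i of negative2 (the transpose) those for target sentence i.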
      } else {
        LOG_ONCE(info, "Using provided {} negative examples", vecs.size() - 2);
        // For inference and training with a given set of negative examples provided in additional streams.
        // We assume that enc0 is the query, enc1 the positive example, and the remaining encoders
        // optional negative examples. Here we only use column vectors [dimBatch, 1].
        auto positive = scalar_product(vecs[0], vecs[1], /*axis*/-1);
        outputs.push_back(positive); // first tensor contains similarity between anchor and positive example

        std::vector<Expr> dotProductsNegative1, dotProductsNegative2;
        for(size_t i = 2; i < vecs.size(); ++i) {
          // compute similarity with the anchor
          auto negative1 = scalar_product(vecs[0], vecs[i], /*axis*/-1);
          dotProductsNegative1.push_back(negative1);

          // for negative examples, also add the symmetric dot product with the positive example
          auto negative2 = scalar_product(vecs[1], vecs[i], /*axis*/-1);
          dotProductsNegative2.push_back(negative2);
        }

        auto negative1 = concatenate(dotProductsNegative1, /*axis=*/-1);
        outputs.push_back(negative1); // second tensor contains similarities between anchor and all negative examples

        auto negative2 = concatenate(dotProductsNegative2, /*axis=*/-1);
        outputs.push_back(negative2); // third tensor contains similarities between positive and all negative examples (symmetric)
      }
    }

    return outputs;
  }

  void clear() override {}
};
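// A minimal usage sketch (hypothetical, not part of the original file) of how
// a pooler could be constructed and applied; graph, options, batch and
// encoderStates are assumed to be provided by the surrounding model code:
//
//   Ptr<PoolerBase> pooler = New<MaxPooler>(graph, options);
//   std::vector<Expr> pooled = pooler->apply(graph, batch, encoderStates);
//   // pooled[0] holds one sentence embedding per batch entry
//   pooler->clear(); // no-op for MaxPooler, required by the interface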
}  // namespace marian