// Program Listing for File lsh.h
// Return to documentation for file (src/layers/lsh.h)
#pragma once
#include "graph/expression_operators.h"
#include "graph/node_initializers.h"
#include <string>
#include <utility>
#include <vector>
namespace marian {
namespace lsh {
// Encodes an input as a bit vector, with optional rotation.
// @param input    expression to be encoded
// @param rotator  optional rotation used during encoding; defaults to nullptr
//                 (presumably nullptr means "no rotation is applied" - confirm against lsh.cpp)
// @return expression holding the bit-vector encoding of `input`
Expr encode(Expr input, Expr rotator = nullptr);
// Computes the rotation matrix (maps weights->shape()[-1] to nbits floats).
// @param weights  expression whose last dimension is the input size of the mapping
// @param inDim    NOTE(review): presumably expected to equal weights->shape()[-1] - confirm against lsh.cpp
// @param nbits    number of floats produced by the mapping
// @return expression holding the rotation matrix
Expr rotator(Expr weights, int inDim, int nbits);
// Performs the LSH search on fully encoded input and weights and returns k results (indices)
// per input row.
// @param encodedQuery    already encoded query
// @param encodedWeights  already encoded weights
// @param k               number of results (indices) returned per input row
// @param firstNRows      defaults to 0; NOTE(review): presumably limits the search to the first N
//                        rows of the weights, with 0 meaning "all rows" - confirm against lsh.cpp
// @param noSort          defaults to false; NOTE(review): presumably skips sorting of the returned
//                        results - confirm against lsh.cpp
// @TODO: add a top-k like operator that also returns the bitwise computed distances
Expr searchEncoded(Expr encodedQuery, Expr encodedWeights, int k, int firstNRows = 0, bool noSort = false);
// Same as searchEncoded(), but performs the encoding of query and weights on the fly.
// @param query           unencoded query
// @param weights         unencoded weights
// @param k               number of results (indices) returned per input row
// @param nbits           number of bits used for the on-the-fly encoding
// @param firstNRows      defaults to 0; NOTE(review): presumably limits the search to the first N
//                        rows of the weights, with 0 meaning "all rows" - confirm against lsh.cpp
// @param abortIfDynamic  defaults to false; NOTE(review): exact condition that triggers the abort
//                        is not visible in this header - confirm against lsh.cpp
Expr search(Expr query, Expr weights, int k, int nbits, int firstNRows = 0, bool abortIfDynamic = false);
// Struct for parameter conversion used in marian-conv. Bundles the names and settings that are
// handed to addDummyParameters() and overwriteDummyParameters().
struct ParamConvInfo {
  std::string name;          // NOTE(review): presumably the name of the parameter being converted - confirm
  std::string codesName;     // NOTE(review): presumably the name used for the LSH codes - confirm
  std::string rotationName;  // NOTE(review): presumably the name used for the rotation matrix - confirm
  int nBits;                 // number of bits (cf. the nbits argument of rotator()/search())
  bool transpose;            // constructor default is false; exact effect is not visible in this header

  // Stores all settings. The strings are sink parameters: they are taken by value and moved into
  // the members, so temporaries and string literals are not copied a second time, while lvalue
  // arguments are still copied exactly once and left untouched.
  ParamConvInfo(std::string name, std::string codesName, std::string rotationName, int nBits, bool transpose = false)
    : name(std::move(name)),
      codesName(std::move(codesName)),
      rotationName(std::move(rotationName)),
      nBits(nBits),
      transpose(transpose) {}
};
// These are helper functions for encoding the LSH into the binary Marian model, used by marian-conv.
// @param graph      expression graph that is operated on
// @param paramInfo  names and settings of the parameter to convert (see ParamConvInfo)
// NOTE(review): presumably registers placeholder parameters in `graph` under the names given in
// `paramInfo` - confirm against lsh.cpp
void addDummyParameters(Ptr<ExpressionGraph> graph, ParamConvInfo paramInfo);
// Helper for encoding the LSH into the binary Marian model, used by marian-conv.
// @param graph      expression graph that is operated on
// @param paramInfo  names and settings of the parameter to convert (see ParamConvInfo)
// NOTE(review): presumably replaces the placeholders created by addDummyParameters() with the
// actual values - confirm against lsh.cpp
void overwriteDummyParameters(Ptr<ExpressionGraph> graph, ParamConvInfo paramInfo);
// Returns a node initializer.
// NOTE(review): judging by the name it presumably initializes a parameter with a random rotation
// matrix - confirm against lsh.cpp
Ptr<inits::NodeInitializer> randomRotation();
}
}