Program Listing for File lsh.h

Return to documentation for file (src/layers/lsh.h)

#pragma once

#include "graph/expression_operators.h"
#include "graph/node_initializers.h"

#include <string>
#include <utility>
#include <vector>

namespace marian {
namespace lsh {
  /// Encodes `input` as a bit vector, optionally applying the rotation `rotator`
  /// (pass nullptr -- the default -- to encode without a rotation).
  Expr encode(Expr input,
              Expr rotator = nullptr);

  /// Computes the rotation matrix, which maps weights->shape()[-1] to `nbits` floats.
  /// NOTE(review): how `inDim` relates to weights->shape()[-1] is not visible in this
  /// header -- confirm in lsh.cpp.
  Expr rotator(Expr weights,
               int inDim,
               int nbits);

  /// Performs the LSH search on a fully encoded query and fully encoded weights and
  /// returns `k` results (indices) per input row.
  /// NOTE(review): `firstNRows` (default 0) presumably restricts the search to the
  /// leading rows of the weights, and `noSort` (default false) presumably skips
  /// ordering the returned indices -- confirm in lsh.cpp.
  /// @TODO: add a top-k like operator that also returns the bitwise computed distances
  Expr searchEncoded(Expr encodedQuery,
                     Expr encodedWeights,
                     int k,
                     int firstNRows = 0,
                     bool noSort = false);

  /// Like searchEncoded(), but takes the unencoded `query` and `weights` and performs
  /// the encoding on the fly. Unlike searchEncoded() it has no `noSort` option. It
  /// instead takes `nbits` (presumably the code length) and `abortIfDynamic`.
  /// NOTE(review): `abortIfDynamic` (default false) presumably makes the call abort
  /// rather than compute the encoding/rotation at run time -- confirm in lsh.cpp.
  Expr search(Expr query,
              Expr weights,
              int k,
              int nbits,
              int firstNRows = 0,
              bool abortIfDynamic = false);

  // Describes one parameter conversion performed by marian-conv.
  // Field meanings below are inferred from their names -- NOTE(review): confirm in lsh.cpp.
  //   name         - name of the source parameter
  //   codesName    - name under which the LSH codes for `name` are stored
  //   rotationName - name under which the LSH rotation matrix is stored
  //   nBits        - number of LSH bits
  //   transpose    - whether the source parameter is to be treated as transposed (default false)
  struct ParamConvInfo {
    std::string name;
    std::string codesName;
    std::string rotationName;
    int nBits;
    bool transpose;

    // The string arguments are sink parameters: they are taken by value and moved into
    // the members. Temporaries (e.g. string literals) therefore cost one construction
    // instead of a construction plus a copy. An lvalue argument is copied once into
    // the parameter and then moved.
    ParamConvInfo(std::string name,
                  std::string codesName,
                  std::string rotationName,
                  int nBits,
                  bool transpose = false)
      : name(std::move(name)),
        codesName(std::move(codesName)),
        rotationName(std::move(rotationName)),
        nBits(nBits),
        transpose(transpose) {}
  };

  // Helpers used by marian-conv to encode the LSH data into the binary Marian model.
  // NOTE(review): going by their names, addDummyParameters presumably registers
  // placeholder parameters in `graph` for `paramInfo`, and overwriteDummyParameters
  // presumably fills them in later -- confirm in lsh.cpp.
  void addDummyParameters(Ptr<ExpressionGraph> graph,
                          ParamConvInfo paramInfo);
  void overwriteDummyParameters(Ptr<ExpressionGraph> graph,
                                ParamConvInfo paramInfo);

  // Returns a node initializer. NOTE(review): presumably it fills a tensor with a
  // random rotation matrix -- confirm in lsh.cpp.
  Ptr<inits::NodeInitializer> randomRotation();
}

}