Program Listing for File lsh.cpp

Return to documentation for file (src/layers/lsh.cpp)

#include "layers/lsh.h"
#include "tensors/tensor_operators.h"
#include "common/utils.h"

#include "3rd_party/faiss/utils/hamming.h"

#if BLAS_FOUND
#include "3rd_party/faiss/VectorTransform.h"
#endif

#include "common/timer.h"

#include "layers/lsh_impl.h"

namespace marian {
namespace lsh {

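// Number of bytes needed to store a bit vector of nBits bits, rounded up to whole bytes.
// For example, nBits = 512 packs into 512 / 8 = 64 bytes, while nBits = 1025 needs
// (1025 + 7) / 8 = 129 bytes, with the last byte only partially used.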
int bytesPerVector(int nBits) {
  return (nBits + 7) / 8;
}

void fillRandomRotationMatrix(Tensor output, Ptr<Allocator> allocator) {
#if BLAS_FOUND
  int nRows = output->shape()[-2];
  int nBits = output->shape()[-1];

  // @TODO re-implement using Marian code so it uses the correct random generator etc.
  faiss::RandomRotationMatrix rrot(nRows, nBits);
  // Then we would not need to use this fixed seed at all.
  rrot.init(5); // currently set to 5 following the default from FAISS; it could be any number really.

  // The faiss random rotation matrix is column-major, hence we create a temporary tensor,
  // copy the rotation matrix into it and then transpose it into `output`.
  Shape tempShape = {nBits, nRows};
  auto memory = allocator->alloc(requiredBytes(tempShape, output->type()));
  auto temp = TensorBase::New(memory,
                              tempShape,
                              output->type(),
                              output->getBackend());
  temp->set(rrot.A);
  TransposeND(output, temp, {0, 1, 3, 2});
  allocator->free(memory);
#else
  output; allocator; // (unused) silence compiler warnings when BLAS is not available
  ABORT("LSH with rotation matrix requires Marian to be compiled with a BLAS library");
#endif
}

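// Binarize a float matrix into packed bit vectors: faiss::fvecs2bitvecs reduces every float
// dimension to a single bit (essentially its sign) and packs 8 bits per byte, so each row of
// `output` occupies bytesPerVector(nBits) bytes.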
void encode(Tensor output, Tensor input) {
  int nBits = input->shape()[-1]; // number of bits equals the last dimension of the float matrix
  int nRows = input->shape().elements() / nBits;
  faiss::fvecs2bitvecs(input->data<float>(), output->data<uint8_t>(), (size_t)nBits, (size_t)nRows);
}

void encodeWithRotation(Tensor output, Tensor input, Tensor rotation, Ptr<Allocator> allocator) {
  int nBits = input->shape()[-1]; // number of bits equals the last dimension of the float matrix unless we rotate
  int nRows = input->shape().elements() / nBits;

  Tensor tempInput = input;
  MemoryPiece::PtrType memory;
  if(rotation) {
    int nBitsRot = rotation->shape()[-1];
    Shape tempShape = {nRows, nBitsRot};
    memory = allocator->alloc(requiredBytes(tempShape, rotation->type()));
    tempInput = TensorBase::New(memory, tempShape, rotation->type(), rotation->getBackend());
    Prod(tempInput, input, rotation, false, false, 0.f, 1.f);
  }
  encode(output, tempInput);

  if(memory)
    allocator->free(memory);
}

Expr encode(Expr input, Expr rotation) {
  auto encodeFwd = [](Expr out, const std::vector<Expr>& inputs) {
    if(inputs.size() == 1) {
      encode(out->val(), inputs[0]->val());
    } else if(inputs.size() == 2) {
      encodeWithRotation(out->val(), inputs[0]->val(), inputs[1]->val(), out->graph()->allocator());
    } else {
      ABORT("Too many inputs to encode??");
    }
  };

  // Use the address of the first lambda function as an immutable hash. Making it static and const makes sure
  // that this hash value will not change. Next, pass the hash into the lambda functor where it will be used
  // to identify this unique operation. Marian's ExpressionGraph can automatically memoize and identify nodes
  // that operate only on immutable nodes (parameters) and have the same hash. This way we make sure that the
  // codes node won't actually get recomputed during the ExpressionGraph's lifetime. `codes` will be reused
  // and the body of the lambda will not be called again. This does, however, build one index per graph.
  static const size_t encodeHash = (size_t)&encodeFwd;

  Shape encodedShape = input->shape();

  int nBits = rotation ? rotation->shape()[-1] : input->shape()[-1];
  encodedShape.set(-1, bytesPerVector(nBits));
  std::vector<Expr> inputs = {input};
  if(rotation)
    inputs.push_back(rotation);
  return lambda(inputs, encodedShape, Type::uint8, encodeFwd, encodeHash);
}
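
// Illustrative (hypothetical) usage of the memoization described above, where `Wout` stands
// for an output weight parameter and `rot` for an optional rotation node:
//
//   auto codes1 = lsh::encode(Wout, rot); // lambda body runs on the first forward pass
//   auto codes2 = lsh::encode(Wout, rot); // same hash, same immutable inputs: the codes
//                                         // node is reused and the body is not executed again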

Expr rotator(Expr weights, int inDim, int nBits) {
  auto rotator = [](Expr out, const std::vector<Expr>& inputs) {
    inputs; // (unused) suppress compiler warning; the input only establishes a graph dependency on the weights
    fillRandomRotationMatrix(out->val(), out->graph()->allocator());
  };

  static const size_t rotatorHash = (size_t)&rotator;
  return lambda({weights}, {inDim, nBits}, Type::float32, rotator, rotatorHash);
}

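// Hamming-distance top-k search over packed bit vectors: for every query row, the result holds
// the row indices (uint32) of the dimK closest rows of encodedWeights, laid out with shape
// {currBeamSize, batchSize, dimK}.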
Expr searchEncoded(Expr encodedQuery, Expr encodedWeights, int dimK, int firstNRows, bool noSort/*= false*/) {
  ABORT_IF(encodedQuery->shape()[-1] != encodedWeights->shape()[-1],
           "Query and index bit vectors need to be of same size ({} != {})", encodedQuery->shape()[-1], encodedWeights->shape()[-1]);

  int currBeamSize = encodedQuery->shape()[0];
  int batchSize    = encodedQuery->shape()[2];

  auto search = [=](Expr out, const std::vector<Expr>& inputs) {
    Expr encodedQuery   = inputs[0];
    Expr encodedWeights = inputs[1];

    int bytesPerVector = encodedWeights->shape()[-1];
    int wRows = encodedWeights->shape().elements() / bytesPerVector;

    // we use this with Factored Segmenter to skip the factor embeddings at the end
    if(firstNRows != 0)
      wRows = firstNRows;

    ABORT_IF(dimK > wRows, "k is larger than number of candidate values?"); // @TODO: use min(k, wRows) silently?

    IndexType* outData = out->val()->data<IndexType>();
    auto gather = [outData, dimK](IndexType rowId, IndexType k, IndexType kthColId, DistType /*dist*/) {
      outData[rowId * dimK + k] = kthColId;
    };

    Parameters params;
    params.k              = dimK;
    params.queryRows      = encodedQuery->val()->data<uint8_t>();
    params.numQueryRows   = encodedQuery->shape().elements() / bytesPerVector;
    params.codeRows       = encodedWeights->val()->data<uint8_t>();
    params.numCodeRows    = wRows;
    params.bytesPerVector = bytesPerVector;

    hammingTopK(params, gather);
  };

  Shape kShape({currBeamSize, batchSize, dimK});
  return lambda({encodedQuery, encodedWeights}, kShape, Type::uint32, search);
}

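// Convenience entry point: reuses the pre-computed parameters "lsh_output_rotation" and
// "lsh_output_codes" if they exist in the graph, otherwise (unless abortIfDynamic is set)
// builds the rotation matrix and code matrix on the fly, then runs the encoded search.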
Expr search(Expr query, Expr weights, int k, int nBits, int firstNRows, bool abortIfDynamic) {
  int dim = weights->shape()[-1];

  Expr rotMat = nullptr;
  if(dim != nBits) {
    rotMat = weights->graph()->get("lsh_output_rotation");
    if(rotMat) {
      LOG_ONCE(info, "Reusing parameter LSH rotation matrix {} with shape {}", rotMat->name(), rotMat->shape());
    } else {
      ABORT_IF(abortIfDynamic, "Dynamic creation of LSH rotation matrix prohibited");
      LOG_ONCE(info, "Creating ad-hoc rotation matrix with shape {}", Shape({dim, nBits}));
      rotMat = rotator(weights, dim, nBits);
    }
  }

  Expr encodedWeights = weights->graph()->get("lsh_output_codes");
  if(encodedWeights) {
    LOG_ONCE(info, "Reusing parameter LSH code matrix {} with shape {}", encodedWeights->name(), encodedWeights->shape());
  } else {
    ABORT_IF(abortIfDynamic, "Dynamic creation of LSH code matrix prohibited");
    LOG_ONCE(info, "Creating ad-hoc code matrix with shape {}", Shape({weights->shape()[-2], lsh::bytesPerVector(nBits)}));
    encodedWeights = encode(weights, rotMat);
  }

  return searchEncoded(encode(query, rotMat), encodedWeights, k, firstNRows);
}

class RandomRotation : public inits::NodeInitializer {
public:
  void apply(Tensor tensor) override {
    auto sharedAllocator = allocator_.lock();
    ABORT_IF(!sharedAllocator, "Allocator in RandomRotation has not been set or expired");
    fillRandomRotationMatrix(tensor, sharedAllocator);
  }
};

Ptr<inits::NodeInitializer> randomRotation() {
  return New<RandomRotation>();
}

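// The two helpers below work as a pair (e.g. when adding LSH parameters to an existing model):
// addDummyParameters registers placeholder parameters with the correct shapes for the LSH code
// matrix and, if needed, the rotation matrix; overwriteDummyParameters then fills them with the
// actual values computed from the weights.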
void addDummyParameters(Ptr<ExpressionGraph> graph, ParamConvInfo paramInfo) {
  auto weights = graph->get(paramInfo.name);
  int nBitsRot = paramInfo.nBits;

  ABORT_IF(!weights, "Trying to encode non-existing weights matrix {}??", paramInfo.name);

  int nBits = weights->shape()[-1]; // embedding dimension of the weights; equals the number of bits when no rotation is used
  if(paramInfo.transpose)
    nBits = weights->shape()[-2];

  int nRows = weights->shape().elements() / nBits;

  Expr rotation;
  if(nBits != nBitsRot) {
    LOG(info, "Adding LSH rotation parameter {} with shape {}", paramInfo.rotationName, Shape({nBits, nBitsRot}));
    rotation = graph->param(paramInfo.rotationName, {nBits, nBitsRot}, inits::dummy(), Type::float32);
    nBits = nBitsRot;
  }

  int bytesPerVector = lsh::bytesPerVector(nBits);
  LOG(info, "Adding LSH encoded weights {} with shape {}", paramInfo.codesName, Shape({nRows, bytesPerVector}));
  auto codes = graph->param(paramInfo.codesName, {nRows, bytesPerVector}, inits::dummy(), Type::uint8);
}

void overwriteDummyParameters(Ptr<ExpressionGraph> graph, ParamConvInfo paramInfo) {
  Expr weights  = graph->get(paramInfo.name);
  Expr codes    = graph->get(paramInfo.codesName);
  Expr rotation = graph->get(paramInfo.rotationName);

  ABORT_IF(!weights, "Trying to encode non-existing weights matrix {}??", paramInfo.name);
  ABORT_IF(!codes, "Trying to overwrite non-existing LSH parameters lsh_output_codes??");

  if(paramInfo.transpose) {
    weights = transpose(weights);
    graph->forward();
  }

  if(rotation) {
    fillRandomRotationMatrix(rotation->val(), weights->graph()->allocator());
    encodeWithRotation(codes->val(), weights->val(), rotation->val(), weights->graph()->allocator());
  } else {
    encode(codes->val(), weights->val());
  }
}

}  // namespace lsh
}  // namespace marian