.. _program_listing_file_src_layers_lsh.cpp:

Program Listing for File lsh.cpp
================================

|exhale_lsh| :ref:`Return to documentation for file <file_src_layers_lsh.cpp>` (``src/layers/lsh.cpp``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   #include "layers/lsh.h"
   #include "tensors/tensor_operators.h"
   #include "common/utils.h"

   #include "3rd_party/faiss/utils/hamming.h"

   #if BLAS_FOUND
   #include "3rd_party/faiss/VectorTransform.h"
   #endif

   #include "common/timer.h"

   #include "layers/lsh_impl.h"

   namespace marian {
   namespace lsh {

   int bytesPerVector(int nBits) {
     return (nBits + 7) / 8;
   }

   void fillRandomRotationMatrix(Tensor output, Ptr<Allocator> allocator) {
   #if BLAS_FOUND
     int nRows = output->shape()[-2];
     int nBits = output->shape()[-1];

     // @TODO re-implement using Marian code so it uses the correct random generator etc.
     faiss::RandomRotationMatrix rrot(nRows, nBits);
     // Then we do not need to use this seed at all
     rrot.init(5); // currently set to 5 following the default from FAISS, this could be any number really.

     // The faiss random rotation matrix is column major, hence we create a temporary tensor,
     // copy the rotation matrix into it and transpose to output.
     Shape tempShape = {nBits, nRows};
     auto memory = allocator->alloc(requiredBytes(tempShape, output->type()));
     auto temp   = TensorBase::New(memory, tempShape, output->type(), output->getBackend());
     temp->set(rrot.A);

     TransposeND(output, temp, {0, 1, 3, 2});
     allocator->free(memory);
   #else
     output; allocator; // unused
     ABORT("LSH with rotation matrix requires Marian to be compiled with a BLAS library");
   #endif
   }

   void encode(Tensor output, Tensor input) {
     int nBits = input->shape()[-1]; // number of bits is equal to the last dimension of the float matrix
     int nRows = input->shape().elements() / nBits;
     faiss::fvecs2bitvecs(input->data(), output->data<uint8_t>(), (size_t)nBits, (size_t)nRows);
   }

   void encodeWithRotation(Tensor output, Tensor input, Tensor rotation, Ptr<Allocator> allocator) {
     int nBits = input->shape()[-1]; // number of bits is equal to the last dimension of the float matrix unless we rotate
     int nRows = input->shape().elements() / nBits;

     Tensor tempInput = input;
     MemoryPiece::PtrType memory;
     if(rotation) {
       int nBitsRot = rotation->shape()[-1];
       Shape tempShape = {nRows, nBitsRot};

       memory    = allocator->alloc(requiredBytes(tempShape, rotation->type()));
       tempInput = TensorBase::New(memory, tempShape, rotation->type(), rotation->getBackend());
       Prod(tempInput, input, rotation, false, false, 0.f, 1.f);
     }

     encode(output, tempInput);

     if(memory)
       allocator->free(memory);
   }

   Expr encode(Expr input, Expr rotation) {
     auto encodeFwd = [](Expr out, const std::vector<Expr>& inputs) {
       if(inputs.size() == 1) {
         encode(out->val(), inputs[0]->val());
       } else if(inputs.size() == 2) {
         encodeWithRotation(out->val(), inputs[0]->val(), inputs[1]->val(), out->graph()->allocator());
       } else {
         ABORT("Too many inputs to encode??");
       }
     };

     // Use the address of the first lambda function as an immutable hash. Making it static and const makes sure
     // that this hash value will not change. Next pass the hash into the lambda functor where it will be used
     // to identify this unique operation. Marian's ExpressionGraph can automatically memoize and identify nodes
     // that operate only on immutable nodes (parameters) and have the same hash. This way we make sure that the
     // codes node won't actually get recomputed throughout the ExpressionGraph's lifetime. `codes` will be reused
     // and the body of the lambda will not be called again. This does, however, build one index per graph.
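     //
     // Illustrative sketch (hypothetical call sites, not part of the original file):
     // within one graph, both calls below carry the same hash over the same immutable
     // inputs, so the graph memoizes the first lambda node and reuses it for the second:
     //
     //   auto codes1 = encode(weights, rot); // computes the packed codes once
     //   auto codes2 = encode(weights, rot); // same hash -> memoized node is reused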
     static const size_t encodeHash = (size_t)&encodeFwd;

     Shape encodedShape = input->shape();
     int nBits = rotation ? rotation->shape()[-1] : input->shape()[-1];
     encodedShape.set(-1, bytesPerVector(nBits));

     std::vector<Expr> inputs = {input};
     if(rotation)
       inputs.push_back(rotation);

     return lambda(inputs, encodedShape, Type::uint8, encodeFwd, encodeHash);
   }

   Expr rotator(Expr weights, int inDim, int nBits) {
     auto rotator = [](Expr out, const std::vector<Expr>& inputs) {
       inputs; // unused
       fillRandomRotationMatrix(out->val(), out->graph()->allocator());
     };

     static const size_t rotatorHash = (size_t)&rotator;
     return lambda({weights}, {inDim, nBits}, Type::float32, rotator, rotatorHash);
   }

   Expr searchEncoded(Expr encodedQuery, Expr encodedWeights, int dimK, int firstNRows, bool noSort/*= false*/) {
     ABORT_IF(encodedQuery->shape()[-1] != encodedWeights->shape()[-1],
              "Query and index bit vectors need to be of same size ({} != {})",
              encodedQuery->shape()[-1], encodedWeights->shape()[-1]);

     int currBeamSize = encodedQuery->shape()[0];
     int batchSize    = encodedQuery->shape()[2];

     auto search = [=](Expr out, const std::vector<Expr>& inputs) {
       Expr encodedQuery   = inputs[0];
       Expr encodedWeights = inputs[1];

       int bytesPerVector = encodedWeights->shape()[-1];
       int wRows = encodedWeights->shape().elements() / bytesPerVector;

       // we use this with Factored Segmenter to skip the factor embeddings at the end
       if(firstNRows != 0)
         wRows = firstNRows;

       ABORT_IF(dimK > wRows, "k is larger than number of candidate values?"); // @TODO: use min(k, wRows) silently?

       IndexType* outData = out->val()->data<IndexType>();
       auto gather = [outData, dimK](IndexType rowId, IndexType k, IndexType kthColId, DistType /*dist*/) {
         outData[rowId * dimK + k] = kthColId;
       };

       Parameters params;
       params.k              = dimK;
       params.queryRows      = encodedQuery->val()->data<uint8_t>();
       params.numQueryRows   = encodedQuery->shape().elements() / bytesPerVector;
       params.codeRows       = encodedWeights->val()->data<uint8_t>();
       params.numCodeRows    = wRows;
       params.bytesPerVector = bytesPerVector;

       hammingTopK(params, gather);
     };

     Shape kShape({currBeamSize, batchSize, dimK});
     return lambda({encodedQuery, encodedWeights}, kShape, Type::uint32, search);
   }

   Expr search(Expr query, Expr weights, int k, int nBits, int firstNRows, bool abortIfDynamic) {
     int dim = weights->shape()[-1];

     Expr rotMat = nullptr;
     if(dim != nBits) {
       rotMat = weights->graph()->get("lsh_output_rotation");
       if(rotMat) {
         LOG_ONCE(info, "Reusing parameter LSH rotation matrix {} with shape {}", rotMat->name(), rotMat->shape());
       } else {
         ABORT_IF(abortIfDynamic, "Dynamic creation of LSH rotation matrix prohibited");
         LOG_ONCE(info, "Creating ad-hoc rotation matrix with shape {}", Shape({dim, nBits}));
         rotMat = rotator(weights, dim, nBits);
       }
     }

     Expr encodedWeights = weights->graph()->get("lsh_output_codes");
     if(encodedWeights) {
       LOG_ONCE(info, "Reusing parameter LSH code matrix {} with shape {}", encodedWeights->name(), encodedWeights->shape());
     } else {
       ABORT_IF(abortIfDynamic, "Dynamic creation of LSH code matrix prohibited");
       LOG_ONCE(info, "Creating ad-hoc code matrix with shape {}", Shape({weights->shape()[-2], lsh::bytesPerVector(nBits)}));
       encodedWeights = encode(weights, rotMat);
     }

     return searchEncoded(encode(query, rotMat), encodedWeights, k, firstNRows);
   }

   class RandomRotation : public inits::NodeInitializer {
   public:
     void apply(Tensor tensor) override {
       auto sharedAllocator = allocator_.lock();
       ABORT_IF(!sharedAllocator, "Allocator in RandomRotation has not been set or expired");
       fillRandomRotationMatrix(tensor, sharedAllocator);
     }
   };

   Ptr<inits::NodeInitializer> randomRotation() {
     return New<RandomRotation>();
   }
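   // Editor's note (not part of the original source): the two helpers below form a
   // two-phase model conversion. addDummyParameters() first registers placeholder
   // parameters for the rotation matrix and the packed codes so that they are
   // allocated and saved with the model; overwriteDummyParameters() then fills them
   // with the actual values computed from the weight matrix.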
   void addDummyParameters(Ptr<ExpressionGraph> graph, ParamConvInfo paramInfo) {
     auto weights = graph->get(paramInfo.name);
     int nBitsRot = paramInfo.nBits;

     ABORT_IF(!weights, "Trying to encode non-existing weights matrix {}??", paramInfo.name);

     int nBits = weights->shape()[-1];
     if(paramInfo.transpose)
       nBits = weights->shape()[-2];

     int nRows = weights->shape().elements() / nBits;

     Expr rotation;
     if(nBits != nBitsRot) {
       LOG(info, "Adding LSH rotation parameter {} with shape {}", paramInfo.rotationName, Shape({nBits, nBitsRot}));
       rotation = graph->param(paramInfo.rotationName, {nBits, nBitsRot}, inits::dummy(), Type::float32);
       nBits = nBitsRot;
     }

     int bytesPerVector = lsh::bytesPerVector(nBits);
     LOG(info, "Adding LSH encoded weights {} with shape {}", paramInfo.codesName, Shape({nRows, bytesPerVector}));
     auto codes = graph->param(paramInfo.codesName, {nRows, bytesPerVector}, inits::dummy(), Type::uint8);
   }

   void overwriteDummyParameters(Ptr<ExpressionGraph> graph, ParamConvInfo paramInfo) {
     Expr weights  = graph->get(paramInfo.name);
     Expr codes    = graph->get(paramInfo.codesName);
     Expr rotation = graph->get(paramInfo.rotationName);

     ABORT_IF(!weights, "Trying to encode non-existing weights matrix {}??", paramInfo.name);
     ABORT_IF(!codes,   "Trying to overwrite non-existing LSH parameters lsh_output_codes??");

     if(paramInfo.transpose) {
       weights = transpose(weights);
       graph->forward();
     }

     if(rotation) {
       fillRandomRotationMatrix(rotation->val(), weights->graph()->allocator());
       encodeWithRotation(codes->val(), weights->val(), rotation->val(), weights->graph()->allocator());
     } else {
       encode(codes->val(), weights->val());
     }
   }

   }  // namespace lsh
   }  // namespace marian
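The entry point for inference is ``lsh::search``, which ties the pieces above together: it encodes (and, when ``dim != nBits``, rotates) both the query and the weight matrix into packed bit vectors, then runs a Hamming-distance top-k search over the codes. A minimal call-site sketch follows; the helper name ``shortlistIndices``, the parameter name ``Wout``, and the ``queries`` expression are hypothetical, and only ``lsh::search`` itself comes from the listing above.

.. code-block:: cpp

   #include "layers/lsh.h"

   using namespace marian;

   // Hypothetical helper: return the indices of the k rows of the output matrix
   // that are closest to each query row in Hamming space over nBits-bit codes.
   Expr shortlistIndices(Ptr<ExpressionGraph> graph, Expr queries) {
     auto Wout = graph->get("Wout"); // assumed {vocabSize, dim} output parameter

     // Result is Type::uint32 with shape {beam, batch, k}. firstNRows == 0
     // searches all rows; abortIfDynamic == false permits building the rotation
     // and code matrices on the fly when they are not stored in the model.
     return lsh::search(queries, Wout, /*k=*/100, /*nBits=*/1024,
                        /*firstNRows=*/0, /*abortIfDynamic=*/false);
   }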