Program Listing for File shortlist.cpp

Return to documentation for file (src/data/shortlist.cpp)

#include "data/shortlist.h"
#include "microsoft/shortlist/utils/ParameterTree.h"
#include "marian.h"
#include "layers/lsh.h"

#include <queue>

namespace marian {
namespace data {

// Interpret the cursor `current` as pointing at elements of type T: return the position it
// currently points to and advance the cursor past `num` elements of T.
template <typename T>
const T* get(const void*& current, size_t num = 1) {
  const T* const head = static_cast<const T*>(current);
  current = head + num;
  return head;
}

// Builds a shortlist from a copy of the given vocabulary indices. tryForwardMap() runs a
// binary search over indices_, so the indices are expected to be sorted. initialized_
// starts out false and is only set once filter() has run.
Shortlist::Shortlist(const std::vector<WordIndex>& indices)
    : indices_(indices), initialized_(false) {
}

// No explicit cleanup is performed; the members are destroyed by their own destructors.
Shortlist::~Shortlist() = default;

// Maps a position inside the shortlist back to the vocabulary index stored at that position.
// The beam and batch indices are ignored here: the same index list is used for every
// beam/batch entry. No bounds check is performed on idx.
WordIndex Shortlist::reverseMap(int /*beamIdx*/, int /*batchIdx*/, int idx) const { return indices_[idx]; }

// Looks up the position of vocabulary index wIdx inside the shortlist using a binary search
// (indices_ must therefore be sorted). Returns that position, or npos when wIdx is not part
// of the shortlist. @TODO: replace the npos sentinel with std::optional once we switch to C++17?
WordIndex Shortlist::tryForwardMap(WordIndex wIdx) const {
  const auto pos = std::lower_bound(indices_.begin(), indices_.end(), wIdx);

  // lower_bound yields the first element that is not less than wIdx; it is only a hit if equal
  const bool found = pos != indices_.end() && *pos == wIdx;
  if(!found)
    return npos;

  return (int)std::distance(indices_.begin(), pos);
}

// Creates the graph node holding the shortlist indices and the cached shortlisted tensors.
// The work is done only on the first call; once initialized_ is set, later calls return
// without touching anything.
void Shortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) {
  if(!initialized_) {
    const int k = (int)indices_.size();
    Shape indicesShape({k});

    // forward function of the lambda node: write the stored indices into the node's value
    auto writeIndices = [this](Expr out, const std::vector<Expr>& /*inputs*/) {
      out->val()->set(indices_);
    };
    indicesExpr_ = lambda({input, weights}, indicesShape, Type::uint32, writeIndices);

    createCachedTensors(weights, isLegacyUntransposedW, b, lemmaEt, k);
    initialized_ = true;
  }
}

// Returns the index node reshaped from {k} to {1, 1, k}, where k is the first dimension of
// indicesExpr_.
Expr Shortlist::getIndicesExpr() const {
  const int numIndices = indicesExpr_->shape()[0];
  return reshape(indicesExpr_, {1, 1, numIndices});
}

// Selects the shortlisted entries of the output weights (and, when given, of the bias and the
// lemma embedding matrix) via indicesExpr_ and stores them in the cachedShort* members.
// weights:               output weight matrix; entries are selected along axis 0
// isLegacyUntransposedW: aborts when true, that layout is not yet tested
// b:                     optional bias, entries selected along the last axis
// lemmaEt:               optional lemma matrix, entries selected along the last axis
// k:                     shortlist size, used as the last dimension of the cached lemma matrix
void Shortlist::createCachedTensors(Expr weights,
                                    bool isLegacyUntransposedW,
                                    Expr b,
                                    Expr lemmaEt,
                                    int k) {
  ABORT_IF(isLegacyUntransposedW, "Legacy untranspose W not yet tested");

  const int selectAxis = isLegacyUntransposedW ? -1 : 0;
  Expr shortWt = index_select(weights, selectAxis, indicesExpr_);
  // prepend two unit dimensions: {rows, cols} -> {1, 1, rows, cols}
  cachedShortWt_ = reshape(shortWt, {1, 1, shortWt->shape()[0], shortWt->shape()[1]});

  if(b)
    cachedShortb_ = index_select(b, -1, indicesExpr_);

  if(lemmaEt) {
    Expr shortLemmaEt = index_select(lemmaEt, -1, indicesExpr_);
    cachedShortLemmaEt_ = reshape(shortLemmaEt, {1, 1, shortLemmaEt->shape()[0], k});
  }
}


// An LSH shortlist starts from an empty index list; indices_ is filled by the callback that
// filter() attaches to the search node. The arguments are stored and later handed to
// lsh::search() in filter().
LSHShortlist::LSHShortlist(int k, int nbits, size_t lemmaSize, bool abortIfDynamic)
    : Shortlist(std::vector<WordIndex>()),
      k_(k),
      nbits_(nbits),
      lemmaSize_(lemmaSize),
      abortIfDynamic_(abortIfDynamic) {}

// Maps a shortlist position of a given beam/batch entry back to a vocabulary index.
// indices_ is addressed as a flattened [beam][batch][k_] array, where the batch size is
// read from dimension 1 of indicesExpr_.
WordIndex LSHShortlist::reverseMap(int beamIdx, int batchIdx, int idx) const {
  const int batchSize = indicesExpr_->shape()[1];
  const int flatIdx = (k_ * batchSize * beamIdx) + (k_ * batchIdx) + idx;
  assert(flatIdx < indices_.size());
  return indices_[flatIdx];
}

// Returns the index node as-is; unlike Shortlist::getIndicesExpr() no reshape is applied.
Expr LSHShortlist::getIndicesExpr() const {
  return indicesExpr_;
}

// Builds the LSH search node for the given input and output weights and caches the selected
// tensors. Aborts when the graph runs on a GPU device, for which this is not implemented.
void LSHShortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) {
  ABORT_IF(input->graph()->getDeviceId().type == DeviceType::gpu,
           "LSH index (--output-approx-knn) currently not implemented for GPU");

  // whenever the graph traverses the search node, copy its value into the field indices_
  auto storeIndices = [this](Expr node) { node->val()->get(indices_); };

  auto searchNode = lsh::search(input, weights, k_, nbits_, (int)lemmaSize_, abortIfDynamic_);
  indicesExpr_ = callback(searchNode, storeIndices);

  createCachedTensors(weights, isLegacyUntransposedW, b, lemmaEt, k_);
}

// Selects the entries chosen by the LSH search from the output weights (and, when given, from
// the lemma matrix) and stores them in the cachedShort* members.
// weights:               output weight matrix; entries are selected along axis 0
// isLegacyUntransposedW: aborts when true, that layout is not yet tested
// b:                     must be empty; a bias is not supported with LSH and aborts
// lemmaEt:               optional lemma matrix, entries selected along the last axis
// k:                     number of shortlist entries per beam/batch element
void LSHShortlist::createCachedTensors(Expr weights,
                                       bool isLegacyUntransposedW,
                                       Expr b,
                                       Expr lemmaEt,
                                       int k) {
  // indicesExpr_ dimensions 0 and 1 are used as beam size and batch size
  int currBeamSize = indicesExpr_->shape()[0];
  int batchSize = indicesExpr_->shape()[1];
  ABORT_IF(isLegacyUntransposedW, "Legacy untranspose W not yet tested");

  // index_select is applied with a flat list of all indices
  Expr indicesExprFlatten = reshape(indicesExpr_, {indicesExpr_->shape().elements()});

  cachedShortWt_ = index_select(weights, isLegacyUntransposedW ? -1 : 0, indicesExprFlatten);
  cachedShortWt_ = reshape(cachedShortWt_, {currBeamSize, batchSize, k, cachedShortWt_->shape()[1]});

  if (b) {
    // The former bias selection code that followed this abort could never execute and has
    // been removed; cachedShortb_ is left untouched.
    ABORT("Bias not supported with LSH");
  }

  if (lemmaEt) {
    int dim = lemmaEt->shape()[0];
    cachedShortLemmaEt_ = index_select(lemmaEt, -1, indicesExprFlatten);
    // split the flat index axis back into {beam, batch, k}, then move dim behind beam and batch
    cachedShortLemmaEt_ = reshape(cachedShortLemmaEt_, {dim, currBeamSize, batchSize, k});
    cachedShortLemmaEt_ = transpose(cachedShortLemmaEt_, {1, 2, 0, 3});
  }
}

// Stores the LSH parameters; they are passed on unchanged to every LSHShortlist created by
// generate().
LSHShortlistGenerator::LSHShortlistGenerator(int k, int nbits, size_t lemmaSize, bool abortIfDynamic)
    : k_(k),
      nbits_(nbits),
      lemmaSize_(lemmaSize),
      abortIfDynamic_(abortIfDynamic) {}

// Creates a fresh LSHShortlist from the stored parameters. The batch is not consulted.
Ptr<Shortlist> LSHShortlistGenerator::generate(Ptr<data::CorpusBatch> /*batch*/) const {
  Ptr<Shortlist> shortlist = New<LSHShortlist>(k_, nbits_, lemmaSize_, abortIfDynamic_);
  return shortlist;
}

// Memory-maps a binary Quicksand shortlist file and sets up pointers into the mapped data.
// The file path is the first value of the "shortlist" option; any further values are parsed
// but ignored (a warning is logged if they are non-zero).
// Layout as read below: magic number, parameter tree, default ids (count + array), source
// lengths and offsets (count + two arrays), shortlist ids (count + array of 16- or 32-bit ids).
QuicksandShortlistGenerator::QuicksandShortlistGenerator(Ptr<Options> options,
                                                         Ptr<const Vocab> srcVocab,
                                                         Ptr<const Vocab> trgVocab,
                                                         size_t srcIdx,
                                                         size_t /*trgIdx*/,
                                                         bool /*shared*/)
    : options_(options),
      srcVocab_(srcVocab),
      trgVocab_(trgVocab),
      srcIdx_(srcIdx) {
  const auto shortlistArgs = options_->get<std::vector<std::string>>("shortlist");
  ABORT_IF(shortlistArgs.empty(), "No path to filter path given");
  const std::string fname = shortlistArgs[0];

  // The remaining values are only inspected so that the user can be told they have no effect.
  const size_t numArgs    = shortlistArgs.size();
  const int   firstNum    = numArgs > 1 ? std::stoi(shortlistArgs[1]) : 0;
  const int   bestNum     = numArgs > 2 ? std::stoi(shortlistArgs[2]) : 0;
  const float threshold   = numArgs > 3 ? std::stof(shortlistArgs[3]) : 0;
  const bool  hasExtraArgs = firstNum != 0 || bestNum != 0 || threshold != 0;
  if(hasExtraArgs) {
    LOG(warn, "You have provided additional parameters for the Quicksand shortlist, but they are ignored.");
  }

  // memory-map the binary file once; `current` is the read cursor that get<>() advances
  mmap_ = mio::mmap_source(fname);
  const void* current = mmap_.data();

  // the file has to start with the expected magic number
  const int32_t MAGIC_NUMBER = 1234567890;
  const int32_t fileMagic = *get<int32_t>(current);
  ABORT_IF(fileMagic != MAGIC_NUMBER, "Trying to mmap Quicksand shortlist but encountered wrong magic number");

  // the parameter tree follows the magic number; reading it advances the cursor
  auto config = marian::quicksand::ParameterTree::FromBinaryReader(current);
  use16bit_ = config->GetBoolReq("use_16_bit");

  LOG(info, "[data] Mapping Quicksand shortlist from {}", fname);

  // width in bytes of one stored shortlist id
  idSize_ = use16bit_ ? sizeof(uint16_t) : sizeof(int32_t);

  // each section is a count followed by its array(s); the order of these reads defines the layout
  numDefaultIds_        = *get<int32_t>(current);
  defaultIds_           =  get<int32_t>(current, numDefaultIds_);
  numSourceIds_         = *get<int32_t>(current);
  sourceLengths_        =  get<int32_t>(current, numSourceIds_);
  sourceOffsets_        =  get<int32_t>(current, numSourceIds_);
  numShortlistIds_      = *get<int32_t>(current);
  sourceToShortlistIds_ =  get<uint8_t>(current, idSize_ * numShortlistIds_);

  LOG(info,
      "[data] Quicksand shortlist has {} source ids, {} default ids and {} shortlist ids",
      numSourceIds_,
      numDefaultIds_,
      numShortlistIds_);
}

// Builds the shortlist for one batch from the memory-mapped Quicksand tables.
// The result contains the default ids plus, for every source token with a valid source id,
// its mapped shortlist ids, added rank by rank until the target vocabulary size is reached.
// Returns a Shortlist over the sorted, de-duplicated ids.
Ptr<Shortlist> QuicksandShortlistGenerator::generate(Ptr<data::CorpusBatch> batch) const {
  auto srcBatch = (*batch)[srcIdx_];
  const size_t maxShortlistSize = trgVocab_->size();

  std::unordered_set<int32_t> indexSet;
  for(int32_t i = 0; i < numDefaultIds_ && (size_t)i < maxShortlistSize; ++i)
    indexSet.insert(defaultIds_[i]);

  // One (start pointer, length) entry per source token that has a valid source id.
  // This list grows with the number of tokens in the batch. It used to be a fixed array of
  // target-vocabulary size written through an unchecked iterator, which overflowed as soon
  // as a batch had more valid tokens than the target vocabulary has entries.
  std::vector<std::pair<const uint8_t*, int32_t>> curShortlists;

  // Because we might fill up our shortlist before reaching max_shortlist_size, we fill the shortlist in order of rank.
  // E.g., first rank of word 0, first rank of word 1, ... second rank of word 0, ...
  int32_t maxLength = 0;
  for(Word word : srcBatch->data()) {
    int32_t sourceId = (int32_t)word.toWordIndex();
    srcVocab_->transcodeToShortlistInPlace((WordIndex*)&sourceId, 1);

    // skip ids outside the source tables; a negative id would index in front of them
    if(sourceId < 0 || sourceId >= numSourceIds_)
      continue;

    const uint8_t* curShortlistIds = sourceToShortlistIds_ + idSize_ * sourceOffsets_[sourceId]; // start position for mapping
    const int32_t length = sourceLengths_[sourceId];                                             // how many mappings are there
    curShortlists.emplace_back(curShortlistIds, length);

    if(length > maxLength)
      maxLength = length;
  }

  // collect the actual shortlist mappings, stopping once the size limit has been reached
  for(int32_t rank = 0; rank < maxLength && indexSet.size() < maxShortlistSize; rank++) {
    for(size_t j = 0; j < curShortlists.size() && indexSet.size() < maxShortlistSize; j++) {
      if(rank >= curShortlists[j].second)
        continue; // this token has no mapping of this rank

      const uint8_t* idBytes = curShortlists[j].first;
      int32_t id = 0;
      if(use16bit_) {
        id = (int32_t)reinterpret_cast<const uint16_t*>(idBytes)[rank];
      } else {
        id = reinterpret_cast<const int32_t*>(idBytes)[rank];
      }
      indexSet.insert(id);
    }
  }

  // turn into vector and sort (selected indices)
  std::vector<WordIndex> indices;
  indices.reserve(indexSet.size());
  for(auto id : indexSet)
    indices.push_back((WordIndex)id);

  std::sort(indices.begin(), indices.end());
  return New<Shortlist>(indices);
}

// Factory for shortlist generators.
// - If lshOpts is non-empty it must hold exactly two values, which are passed to
//   LSHShortlistGenerator as (k, nbits); the "shortlist" option is not consulted then.
// - Otherwise the first value of the "shortlist" option is the shortlist file: a file starting
//   with the binary shortlist magic yields a BinaryShortlistGenerator, any other file with
//   extension ".bin" a QuicksandShortlistGenerator, everything else a LexicalShortlistGenerator.
// Aborts if lshOpts has a wrong number of values or if no shortlist path is given.
Ptr<ShortlistGenerator> createShortlistGenerator(Ptr<Options> options,
                                                 Ptr<const Vocab> srcVocab,
                                                 Ptr<const Vocab> trgVocab,
                                                 const std::vector<int> &lshOpts,
                                                 size_t srcIdx,
                                                 size_t trgIdx,
                                                 bool shared) {
  if(!lshOpts.empty()) {
    // Checked in every build type: with a plain assert a release build would read
    // lshOpts[1] out of bounds when only one value was supplied.
    ABORT_IF(lshOpts.size() != 2,
             "LSH shortlist options (--output-approx-knn) must have exactly 2 values (k, nbits), got {}",
             lshOpts.size());
    size_t lemmaSize = trgVocab->lemmaSize();
    return New<LSHShortlistGenerator>(lshOpts[0], lshOpts[1], lemmaSize, /*abortIfDynamic=*/false);
  }

  std::vector<std::string> vals = options->get<std::vector<std::string>>("shortlist");
  ABORT_IF(vals.empty(), "No path to shortlist given");
  const std::string& fname = vals[0];

  // the content check (magic number) takes precedence over the file extension
  if(isBinaryShortlist(fname))
    return New<BinaryShortlistGenerator>(options, srcVocab, trgVocab, srcIdx, trgIdx, shared);

  if(filesystem::Path(fname).extension().string() == ".bin")
    return New<QuicksandShortlistGenerator>(options, srcVocab, trgVocab, srcIdx, trgIdx, shared);

  return New<LexicalShortlistGenerator>(options, srcVocab, trgVocab, srcIdx, trgIdx, shared);
}

// Returns true if the first 8 bytes of the file could be read and are equal to
// BINARY_SHORTLIST_MAGIC, false otherwise (including when the read fails).
bool isBinaryShortlist(const std::string& fileName){
  io::InputFileStream in(fileName);

  uint64_t magic = 0;
  in.read((char*)(&magic), sizeof(magic));

  // a failed or short read leaves the stream in a failed state
  if(!in)
    return false;
  return magic == BINARY_SHORTLIST_MAGIC;
}

// Validates the loaded offset table and shortlist array and aborts if anything is out of
// bounds:
//  - the offset table must not be empty,
//  - every offset except the last must be smaller than shortListsSize_,
//  - the last offset must be equal to shortListsSize_,
//  - every stored target index must be smaller than the target vocabulary size.
// All counters are 64-bit unsigned so that large tables are traversed without signed
// overflow or signed/unsigned comparison surprises.
void BinaryShortlistGenerator::contentCheck() {
  bool failFlag = false;

  if(wordToOffsetSize_ == 0) {
    // An empty offset table has no last element to compare against; treat it as invalid
    // rather than reading in front of the table.
    failFlag = true;
  } else {
    // The offset table has to be within the size of shortlists.
    // NOTE(review): an offset equal to shortListsSize_ (an empty shortlist at the very end)
    // is rejected by this check - confirm whether that strictness is intended.
    for(uint64_t i = 0; i + 1 < wordToOffsetSize_; i++)
      failFlag |= wordToOffset_[i] >= shortListsSize_;

    // The last element of wordToOffset_ must equal shortListsSize_
    failFlag |= wordToOffset_[wordToOffsetSize_ - 1] != shortListsSize_;
  }

  // The vocabulary indices have to be within the vocabulary size.
  const size_t vSize = trgVocab_->size();
  for(uint64_t j = 0; j < shortListsSize_; j++)
    failFlag |= shortLists_[j] >= vSize;

  ABORT_IF(failFlag, "Error: shortlist indices are out of bounds");
}

// Load a binary shortlist from a memory buffer. The buffer is not copied: wordToOffset_ and
// shortLists_ point into it afterwards, so it has to stay valid while this object is used.
// ptr_void: start of the blob
// blobSize: size of the blob in bytes
// check:    when true, verify the checksum and run contentCheck()
// Aborts if the blob is too short, has the wrong magic number, if the sizes in the header do
// not match blobSize, or (with check) if the checksum or the content check fails.
void BinaryShortlistGenerator::load(const void* ptr_void, size_t blobSize, bool check /*= true*/) {
  /* File layout:
   * header
   * wordToOffset array
   * shortLists array
   */
  ABORT_IF(blobSize < sizeof(Header), "Shortlist length {} too short to have a header", blobSize);

  const char *ptr = static_cast<const char*>(ptr_void);
  const Header &header = *reinterpret_cast<const Header*>(ptr);
  ptr += sizeof(Header);
  ABORT_IF(header.magic != BINARY_SHORTLIST_MAGIC, "Incorrect magic in binary shortlist");

  // The element counts come straight from the file. Bound them by the bytes that are really
  // available before multiplying, so that a corrupt header cannot make the size computation
  // below wrap around and accidentally match blobSize. The second comparison is only
  // evaluated when the first one holds, hence the subtraction cannot underflow.
  const uint64_t payloadSize = blobSize - sizeof(Header);
  ABORT_IF(header.wordToOffsetSize > payloadSize / sizeof(uint64_t)
           || header.shortListsSize > (payloadSize - header.wordToOffsetSize * sizeof(uint64_t)) / sizeof(WordIndex),
           "Shortlist header claims more entries than fit into a file of size {}", blobSize);

  uint64_t expectedSize = sizeof(Header) + header.wordToOffsetSize * sizeof(uint64_t) + header.shortListsSize * sizeof(WordIndex);
  ABORT_IF(expectedSize != blobSize, "Shortlist header claims file size should be {} but file is {}", expectedSize, blobSize);

  if (check) {
    // the checksum covers everything after the magic and checksum fields, in 64-bit words
    uint64_t checksumActual = util::hashMem<uint64_t, uint64_t>(&header.firstNum, (blobSize - sizeof(header.magic) - sizeof(header.checksum)) / sizeof(uint64_t));
    ABORT_IF(checksumActual != header.checksum, "checksum check failed: this binary shortlist is corrupted");
  }

  firstNum_ = header.firstNum;
  bestNum_ = header.bestNum;
  LOG(info, "[data] Lexical short list firstNum {} and bestNum {}", firstNum_, bestNum_);

  wordToOffsetSize_ = header.wordToOffsetSize;
  shortListsSize_ = header.shortListsSize;

  // Offsets right after header.
  wordToOffset_ = reinterpret_cast<const uint64_t*>(ptr);
  ptr += wordToOffsetSize_ * sizeof(uint64_t);

  // The shortlist entries follow the offset table.
  shortLists_ = reinterpret_cast<const WordIndex*>(ptr);

  // Verify offsets and vocab ids are within bounds if requested by user.
  if(check)
    contentCheck();
}

// load shortlist from file
void BinaryShortlistGenerator::load(const std::string& filename, bool check /*=true*/) {
  std::error_code error;
  mmapMem_.map(filename, error);
  ABORT_IF(error, "Error mapping file: {}", error.message());
  load(mmapMem_.data(), mmapMem_.mapped_length(), check);
}

// Creates the generator from the "shortlist" option, whose first value is the file path.
// - If the file is a binary shortlist it is memory-mapped and loaded; an optional second
//   value switches the checksum/content check on or off (default: on).
// - Otherwise the file is imported as a text lexical shortlist; the optional values are
//   firstNum (default 100), bestNum (default 100) and threshold (default 0).
BinaryShortlistGenerator::BinaryShortlistGenerator(Ptr<Options> options,
                                                   Ptr<const Vocab> srcVocab,
                                                   Ptr<const Vocab> trgVocab,
                                                   size_t srcIdx /*= 0*/,
                                                   size_t /*trgIdx = 1*/,
                                                   bool shared /*= false*/)
    : options_(options),
      srcVocab_(srcVocab),
      trgVocab_(trgVocab),
      srcIdx_(srcIdx),
      shared_(shared) {

  const auto vals = options_->get<std::vector<std::string>>("shortlist");
  ABORT_IF(vals.empty(), "No path to shortlist file given");
  const std::string fname = vals[0];
  const size_t numVals = vals.size();

  if(!isBinaryShortlist(fname)) {
    // text format: remaining values are firstNum, bestNum and threshold
    firstNum_ = numVals > 1 ? std::stoi(vals[1]) : 100;
    bestNum_ = numVals > 2 ? std::stoi(vals[2]) : 100;
    const float threshold = numVals > 3 ? std::stof(vals[3]) : 0;
    LOG(info, "[data] Importing text lexical shortlist as {} {} {} {}",
        fname, firstNum_, bestNum_, threshold);
    import(fname, threshold);
  }
  else {
    // binary format: the only remaining value is the check flag
    const bool check = numVals > 1 ? std::stoi(vals[1]) : 1;
    LOG(info, "[data] Loading binary shortlist as {} {}", fname, check);
    load(fname, check);
  }
}

// Creates the generator from a binary shortlist that is already in memory. The blob is
// handed to load(), which keeps pointers into it rather than copying it. options_ is not
// set by this constructor.
BinaryShortlistGenerator::BinaryShortlistGenerator(const void *ptr_void,
                                                   const size_t blobSize,
                                                   Ptr<const Vocab> srcVocab,
                                                   Ptr<const Vocab> trgVocab,
                                                   size_t srcIdx /*= 0*/,
                                                   size_t /*trgIdx = 1*/,
                                                   bool shared /*= false*/,
                                                   bool check /*= true*/)
    : srcVocab_(srcVocab), trgVocab_(trgVocab), srcIdx_(srcIdx), shared_(shared) {
  // parse and (optionally) validate the blob
  load(ptr_void, blobSize, check);
}

// Builds the shortlist for one batch. Selected are: the first firstNum_ target words, the
// target words listed for every distinct source word of the batch, and (if shared_) the
// source words themselves. The selection is then padded to a multiple of eight.
// Returns a Shortlist over the selected target indices in ascending order.
Ptr<Shortlist> BinaryShortlistGenerator::generate(Ptr<data::CorpusBatch> batch) const {
  auto srcBatch = (*batch)[srcIdx_];
  const size_t numSrcWords = srcVocab_->size();
  const size_t numTrgWords = trgVocab_->size();

  // Since V=trgVocab_->size() is not large, anchor the time and space complexity to O(V).
  // Attempt to squeeze the truth tables into CPU cache
  std::vector<bool> srcSeen(numSrcWords, false);      // source words already processed
  std::vector<bool> trgSelected(numTrgWords, false);  // target words chosen so far

  // the firstNum_ lowest target indices are always part of the shortlist
  for(WordIndex i = 0; i < firstNum_ && i < numTrgWords; ++i)
    trgSelected[i] = true;

  // add the target words aligned to each distinct source word of the batch
  for(auto word : srcBatch->data()) {
    const WordIndex srcIndex = word.toWordIndex();
    if(shared_)
      trgSelected[srcIndex] = true;

    if(srcSeen[srcIndex])
      continue; // the targets of this source word have been added before

    // shortlist of srcIndex occupies [wordToOffset_[srcIndex], wordToOffset_[srcIndex+1])
    const uint64_t listBegin = wordToOffset_[srcIndex];
    const uint64_t listEnd   = wordToOffset_[srcIndex+1];
    for(uint64_t j = listBegin; j < listEnd; j++)
      trgSelected[shortLists_[j]] = true;
    srcSeen[srcIndex] = true;
  }

  // Due to the 'multiple-of-eight' issue, the following O(N) patch is inserted:
  // count how many target words are selected at this point
  size_t numSelected = 0;
  for(bool selected : trgSelected) {
    if(selected)
      numSelected++;
  }

  // Ensure that the generated vocabulary items from a shortlist are a multiple-of-eight
  // This is necessary until intgemm supports non-multiple-of-eight matrices.
  // Padding picks the lowest not yet selected indices, starting at firstNum_.
  for(size_t i = firstNum_; i < numTrgWords && numSelected % 8 != 0; i++) {
    if(trgSelected[i])
      continue;
    trgSelected[i] = true;
    numSelected++;
  }

  // turn selected indices into vector and sort (Bucket sort: O(V))
  std::vector<WordIndex> indices;
  for(WordIndex i = 0; i < numTrgWords; i++) {
    if(trgSelected[i])
      indices.push_back(i);
  }

  return New<Shortlist>(indices);
}

// Writes the in-memory binary blob to fileName via saveBlobToFile().
// Aborts if a file is currently memory-mapped into mmapMem_, i.e. when the shortlist was
// loaded from a binary file rather than imported.
void BinaryShortlistGenerator::dump(const std::string& fileName) const {
  ABORT_IF(mmapMem_.is_open(),"No need to dump again");
  LOG(info, "[data] Saving binary shortlist dump to {}", fileName);
  saveBlobToFile(fileName);
}

void BinaryShortlistGenerator::import(const std::string& filename, double threshold) {
  io::InputFileStream in(filename);
  std::string src, trg;

  // Read text file
  std::vector<std::unordered_map<WordIndex, float>> srcTgtProbTable(srcVocab_->size());
  float prob;

  while(in >> trg >> src >> prob) {
    if(src == "NULL" || trg == "NULL")
      continue;

    auto sId = (*srcVocab_)[src].toWordIndex();
    auto tId = (*trgVocab_)[trg].toWordIndex();

    if(srcTgtProbTable[sId][tId] < prob)
      srcTgtProbTable[sId][tId] = prob;
  }

  // Create priority queue and count
  std::vector<std::priority_queue<std::pair<float, WordIndex>>> vpq;
  uint64_t shortListsSize = 0;

  vpq.resize(srcTgtProbTable.size());
  for(WordIndex sId = 0; sId < srcTgtProbTable.size(); sId++) {
    uint64_t shortListsSizeCurrent = 0;
    for(auto entry : srcTgtProbTable[sId]) {
      if (entry.first>=threshold) {
        vpq[sId].push(std::make_pair(entry.second, entry.first));
        if(shortListsSizeCurrent < bestNum_)
          shortListsSizeCurrent++;
      }
    }
    shortListsSize += shortListsSizeCurrent;
  }

  wordToOffsetSize_ = vpq.size() + 1;
  shortListsSize_ = shortListsSize;

  // Generate a binary blob
  blob_.resize(sizeof(Header) + wordToOffsetSize_ * sizeof(uint64_t) + shortListsSize_ * sizeof(WordIndex));
  struct Header* pHeader = (struct Header *)blob_.data();
  pHeader->magic = BINARY_SHORTLIST_MAGIC;
  pHeader->firstNum = firstNum_;
  pHeader->bestNum = bestNum_;
  pHeader->wordToOffsetSize = wordToOffsetSize_;
  pHeader->shortListsSize = shortListsSize_;
  uint64_t* wordToOffset = (uint64_t*)((char *)pHeader + sizeof(Header));
  WordIndex* shortLists = (WordIndex*)((char*)wordToOffset + wordToOffsetSize_*sizeof(uint64_t));

  uint64_t shortlistIdx = 0;
  for (size_t i = 0; i < wordToOffsetSize_ - 1; i++) {
    wordToOffset[i] = shortlistIdx;
    for(int popcnt = 0; popcnt < bestNum_ && !vpq[i].empty(); popcnt++) {
      shortLists[shortlistIdx] = vpq[i].top().second;
      shortlistIdx++;
      vpq[i].pop();
    }
  }
  wordToOffset[wordToOffsetSize_-1] = shortlistIdx;

  // Sort word indices for each shortlist
  for(int i = 1; i < wordToOffsetSize_; i++) {
    std::sort(&shortLists[wordToOffset[i-1]], &shortLists[wordToOffset[i]]);
  }
  pHeader->checksum = (uint64_t)util::hashMem<uint64_t>((uint64_t *)blob_.data()+2,
                                                        blob_.size()/sizeof(uint64_t)-2);

  wordToOffset_ = wordToOffset;
  shortLists_ = shortLists;
}

void BinaryShortlistGenerator::saveBlobToFile(const std::string& fileName) const {
  io::OutputFileStream outTop(fileName);
  outTop.write(blob_.data(), blob_.size());
}

}  // namespace data
}  // namespace marian