Program Listing for File shortlist.h
↰ Return to documentation for file (src/data/shortlist.h)
#pragma once
#include "common/config.h"
#include "common/definitions.h"
#include "common/file_stream.h"
#include "data/corpus_base.h"
#include "data/types.h"
#include "mio/mio.hpp"
#include <random>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <iostream>
#include <algorithm>
#include <limits>
namespace faiss {
struct IndexLSH;
}
namespace marian {
namespace data {
class Shortlist {
protected:
std::vector<WordIndex> indices_; // [packed shortlist index] -> word index, used to select columns from output embeddings
Expr indicesExpr_; // cache an expression that contains the short list indices
Expr cachedShortWt_; // short-listed version, cached (cleared by clear())
Expr cachedShortb_; // these match the current value of shortlist_
Expr cachedShortLemmaEt_;
bool initialized_; // used by the batch-level shortlist: initialize on the first call, then skip all subsequent calls for the same batch
void createCachedTensors(Expr weights,
bool isLegacyUntransposedW,
Expr b,
Expr lemmaEt,
int k);
public:
static constexpr WordIndex npos{std::numeric_limits<WordIndex>::max()}; // used to identify invalid shortlist entries similar to std::string::npos
Shortlist(const std::vector<WordIndex>& indices);
virtual ~Shortlist();
virtual bool isDynamic() const { return false; }
virtual WordIndex reverseMap(int beamIdx, int batchIdx, int idx) const;
virtual WordIndex tryForwardMap(WordIndex wIdx) const;
virtual void filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt);
virtual Expr getIndicesExpr() const;
virtual Expr getCachedShortWt() const { return cachedShortWt_; }
virtual Expr getCachedShortb() const { return cachedShortb_; }
virtual Expr getCachedShortLemmaEt() const { return cachedShortLemmaEt_; }
};
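// Illustrative usage sketch (not part of the original header), assuming the base
// Shortlist only stores the packed-index -> word-index mapping described above:
// the output layer is restricted to the shortlisted columns, so a packed position
// chosen by the decoder must be mapped back to a full-vocabulary WordIndex via
// reverseMap() before a token is emitted. All concrete values below are hypothetical.
#if 0
std::vector<WordIndex> candidates = {0, 1, 2, 57, 963, 4012}; // full-vocabulary word ids
auto shortlist = New<Shortlist>(candidates);
// Packed column 4 of the shortlisted output corresponds to vocabulary item 963.
WordIndex vocabId = shortlist->reverseMap(/*beamIdx=*/0, /*batchIdx=*/0, /*idx=*/4);
#endif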
class ShortlistGenerator {
public:
virtual ~ShortlistGenerator() {}
virtual Ptr<Shortlist> generate(Ptr<data::CorpusBatch> batch) const = 0;
// Writes text version of (possibly) pruned short list to file
// with given prefix and implementation-specific suffixes.
virtual void dump(const std::string& /*prefix*/) const {
ABORT("Not implemented");
}
};
// Faster inference, inspired by these two papers:
// https://arxiv.org/pdf/1903.03129.pdf and https://arxiv.org/pdf/1806.00588.pdf
class LSHShortlist: public Shortlist {
private:
int k_; // number of candidates returned from each input
int nbits_; // length of hash
size_t lemmaSize_; // vocab size
bool abortIfDynamic_; // if true disallow dynamic allocation for encoded weights and rotation matrix (only allow use of pre-allocated parameters)
static Ptr<faiss::IndexLSH> index_; // LSH index to store all possible candidates
static std::mutex mutex_;
void createCachedTensors(Expr weights,
bool isLegacyUntransposedW,
Expr b,
Expr lemmaEt,
int k);
public:
LSHShortlist(int k, int nbits, size_t lemmaSize, bool abortIfDynamic = false);
virtual bool isDynamic() const override { return true; }
virtual WordIndex reverseMap(int beamIdx, int batchIdx, int idx) const override;
virtual void filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) override;
virtual Expr getIndicesExpr() const override;
};
class LSHShortlistGenerator : public ShortlistGenerator {
private:
int k_;
int nbits_;
size_t lemmaSize_;
bool abortIfDynamic_;
public:
LSHShortlistGenerator(int k, int nbits, size_t lemmaSize, bool abortIfDynamic = false);
Ptr<Shortlist> generate(Ptr<data::CorpusBatch> batch) const override;
};
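// Illustrative sketch (not part of the original header): constructing the LSH-based
// generator. k is the number of candidates retrieved per input, nbits the hash
// length and lemmaSize the vocabulary size, as documented in LSHShortlist above.
// The concrete numbers below are hypothetical.
#if 0
auto lshGen = New<LSHShortlistGenerator>(/*k=*/100,
                                         /*nbits=*/1024,
                                         /*lemmaSize=*/32000,
                                         /*abortIfDynamic=*/false);
Ptr<Shortlist> shortlist = lshGen->generate(batch); // batch is a Ptr<data::CorpusBatch>
#endif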
// Intended for use during training in the future, currently disabled
#if 0
class SampledShortlistGenerator : public ShortlistGenerator {
private:
Ptr<Options> options_;
size_t maxVocab_{50000};
size_t total_{10000};
size_t firstNum_{1000};
size_t srcIdx_;
size_t trgIdx_;
bool shared_{false};
// static thread_local std::random_device rd_;
static thread_local std::unique_ptr<std::mt19937> gen_;
public:
SampledShortlistGenerator(Ptr<Options> options,
size_t srcIdx = 0,
size_t trgIdx = 1,
bool shared = false)
: options_(options),
srcIdx_(srcIdx),
trgIdx_(trgIdx),
shared_(shared)
{ }
virtual Ptr<Shortlist> generate(Ptr<data::CorpusBatch> batch) const override {
auto srcBatch = (*batch)[srcIdx_];
auto trgBatch = (*batch)[trgIdx_];
// add firstNum most frequent words
std::unordered_set<WordIndex> indexSet;
for(WordIndex i = 0; i < firstNum_ && i < maxVocab_; ++i)
indexSet.insert(i);
// add all words from ground truth
for(auto i : trgBatch->data())
indexSet.insert(i.toWordIndex());
// add all words from source
if(shared_)
for(auto i : srcBatch->data())
indexSet.insert(i.toWordIndex());
std::uniform_int_distribution<> dis((int)firstNum_, (int)maxVocab_);
if (gen_ == NULL)
gen_.reset(new std::mt19937(std::random_device{}()));
while(indexSet.size() < total_ && indexSet.size() < maxVocab_)
indexSet.insert(dis(*gen_));
// turn into vector and sort (selected indices)
std::vector<WordIndex> idx(indexSet.begin(), indexSet.end());
std::sort(idx.begin(), idx.end());
// assign new shifted position
std::unordered_map<WordIndex, WordIndex> pos;
std::vector<WordIndex> reverseMap;
for(WordIndex i = 0; i < idx.size(); ++i) {
pos[idx[i]] = i;
reverseMap.push_back(idx[i]);
}
Words mapped;
for(auto i : trgBatch->data()) {
// mapped positions for cross-entropy
mapped.push_back(Word::fromWordIndex(pos[i.toWordIndex()]));
}
return New<Shortlist>(idx, mapped, reverseMap);
}
};
#endif
class LexicalShortlistGenerator : public ShortlistGenerator {
private:
Ptr<Options> options_;
Ptr<const Vocab> srcVocab_;
Ptr<const Vocab> trgVocab_;
size_t srcIdx_;
bool shared_{false};
size_t firstNum_{100};
size_t bestNum_{100};
std::vector<std::unordered_map<WordIndex, float>> data_; // [WordIndex src] -> [WordIndex tgt] -> P_trans(tgt|src) --@TODO: rename data_ accordingly
void load(const std::string& fname) {
io::InputFileStream in(fname);
std::string src, trg;
float prob;
while(in >> trg >> src >> prob) {
// @TODO: change this to something safer other than NULL
if(src == "NULL" || trg == "NULL")
continue;
auto sId = (*srcVocab_)[src].toWordIndex();
auto tId = (*trgVocab_)[trg].toWordIndex();
if(data_.size() <= sId)
data_.resize(sId + 1);
data_[sId][tId] = prob;
}
}
void prune(float threshold = 0.f) {
size_t i = 0;
for(auto& probs : data_) {
std::vector<std::pair<float, WordIndex>> sorter;
for(auto& it : probs)
sorter.emplace_back(it.second, it.first);
std::sort(
sorter.begin(), sorter.end(), std::greater<std::pair<float, WordIndex>>()); // sort by prob
probs.clear();
for(auto& it : sorter) {
if(probs.size() < bestNum_ && it.first > threshold)
probs[it.second] = it.first;
else
break;
}
++i;
}
}
public:
LexicalShortlistGenerator(Ptr<Options> options,
Ptr<const Vocab> srcVocab,
Ptr<const Vocab> trgVocab,
size_t srcIdx = 0,
size_t /*trgIdx*/ = 1,
bool shared = false)
: options_(options),
srcVocab_(srcVocab),
trgVocab_(trgVocab),
srcIdx_(srcIdx),
shared_(shared) {
std::vector<std::string> vals = options_->get<std::vector<std::string>>("shortlist");
ABORT_IF(vals.empty(), "No path to filter file given");
std::string fname = vals[0];
firstNum_ = vals.size() > 1 ? std::stoi(vals[1]) : 100;
bestNum_ = vals.size() > 2 ? std::stoi(vals[2]) : 100;
float threshold = vals.size() > 3 ? std::stof(vals[3]) : 0;
std::string dumpPath = vals.size() > 4 ? vals[4] : "";
LOG(info,
"[data] Loading lexical shortlist as {} {} {} {}",
fname,
firstNum_,
bestNum_,
threshold);
// @TODO: Load and prune in one go.
load(fname);
prune(threshold);
if(!dumpPath.empty())
dump(dumpPath);
}
virtual void dump(const std::string& prefix) const override {
// Dump top most frequent words from target vocabulary
LOG(info, "[data] Saving shortlist dump to {}", prefix + ".{top,dic}");
io::OutputFileStream outTop(prefix + ".top");
for(WordIndex i = 0; i < firstNum_ && i < trgVocab_->size(); ++i)
outTop << (*trgVocab_)[Word::fromWordIndex(i)] << std::endl;
// Dump translation pairs from dictionary
io::OutputFileStream outDic(prefix + ".dic");
for(WordIndex srcId = 0; srcId < data_.size(); srcId++) {
for(auto& it : data_[srcId]) {
auto trgId = it.first;
outDic << (*srcVocab_)[Word::fromWordIndex(srcId)] << "\t" << (*trgVocab_)[Word::fromWordIndex(trgId)] << std::endl;
}
}
}
virtual Ptr<Shortlist> generate(Ptr<data::CorpusBatch> batch) const override {
auto srcBatch = (*batch)[srcIdx_];
// add firstNum most frequent words
std::unordered_set<WordIndex> indexSet;
for(WordIndex i = 0; i < firstNum_ && i < trgVocab_->size(); ++i)
indexSet.insert(i);
// add all words from ground truth
// for(auto i : trgBatch->data())
// indexSet.insert(i.toWordIndex());
// collect unique words from source
std::unordered_set<WordIndex> srcSet;
for(auto i : srcBatch->data())
srcSet.insert(i.toWordIndex());
// add aligned target words
for(auto i : srcSet) {
if(shared_)
indexSet.insert(i);
for(auto& it : data_[i])
indexSet.insert(it.first);
}
// Ensure that the number of vocabulary items in the generated shortlist is a multiple of eight.
// This is necessary until intgemm supports non-multiple-of-eight matrices.
// TODO better solution here? This could potentially be slow.
WordIndex i = static_cast<WordIndex>(firstNum_);
while (indexSet.size() % 8 != 0) {
indexSet.insert(i);
i++;
}
// turn into vector and sort (selected indices)
std::vector<WordIndex> indices(indexSet.begin(), indexSet.end());
std::sort(indices.begin(), indices.end());
return New<Shortlist>(indices);
}
};
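// Illustrative sketch (not part of the original header): the text lexicon consumed
// by LexicalShortlistGenerator::load() is read line by line as "<trg> <src> <prob>"
// (entries containing NULL are skipped), and the "shortlist" option vector is parsed
// as path, firstNum, bestNum, threshold and an optional dump prefix. The file name,
// entries and numbers below are hypothetical.
#if 0
// lex.s2t, one entry per line:
//   house Haus 0.8421
//   home  Haus 0.0912
//   NULL  Haus 0.0123    <- skipped by load()
//
// matching configuration (path firstNum bestNum threshold):
//   --shortlist lex.s2t 100 100 0
#endif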
class FakeShortlistGenerator : public ShortlistGenerator {
private:
std::vector<WordIndex> indices_;
public:
FakeShortlistGenerator(const std::unordered_set<WordIndex>& indexSet)
: indices_(indexSet.begin(), indexSet.end()) {
std::sort(indices_.begin(), indices_.end());
}
Ptr<Shortlist> generate(Ptr<data::CorpusBatch> /*batch*/) const override {
return New<Shortlist>(indices_);
}
};
/*
Legacy binary shortlist for Microsoft-internal use.
*/
class QuicksandShortlistGenerator : public ShortlistGenerator {
private:
Ptr<Options> options_;
Ptr<const Vocab> srcVocab_;
Ptr<const Vocab> trgVocab_;
size_t srcIdx_;
mio::mmap_source mmap_;
// all the quicksand bits go here
bool use16bit_{false};
int32_t numDefaultIds_;
int32_t idSize_;
const int32_t* defaultIds_{nullptr};
int32_t numSourceIds_{0};
const int32_t* sourceLengths_{nullptr};
const int32_t* sourceOffsets_{nullptr};
int32_t numShortlistIds_{0};
const uint8_t* sourceToShortlistIds_{nullptr};
public:
QuicksandShortlistGenerator(Ptr<Options> options,
Ptr<const Vocab> srcVocab,
Ptr<const Vocab> trgVocab,
size_t srcIdx = 0,
size_t trgIdx = 1,
bool shared = false);
virtual Ptr<Shortlist> generate(Ptr<data::CorpusBatch> batch) const override;
};
/*
Shortlist factory to create the correct type of shortlist. Currently assumes everything is a text shortlist
unless the extension is *.bin, in which case the Microsoft legacy binary shortlist is used.
*/
Ptr<ShortlistGenerator> createShortlistGenerator(Ptr<Options> options,
Ptr<const Vocab> srcVocab,
Ptr<const Vocab> trgVocab,
const std::vector<int> &lshOpts,
size_t srcIdx = 0,
size_t trgIdx = 1,
bool shared = false);
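// Illustrative sketch (not part of the original header): how the factory above might
// be called. Per the comment, a *.bin shortlist path selects the legacy binary reader
// while anything else is treated as a text shortlist; an empty lshOpts vector is
// assumed here to mean "no LSH-based shortlist". Variable names are hypothetical.
#if 0
std::vector<int> lshOpts;  // assumed empty => no LSH-based shortlist
auto generator = createShortlistGenerator(options, srcVocab, trgVocab, lshOpts,
                                          /*srcIdx=*/0, /*trgIdx=*/1, /*shared=*/false);
auto shortlist = generator->generate(batch);
#endif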
// Magic signature for binary shortlist:
// ASCII and Unicode text files never start with the following 64 bits
const uint64_t BINARY_SHORTLIST_MAGIC = 0xF11A48D5013417F5;
bool isBinaryShortlist(const std::string& fileName);
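// Illustrative sketch (not part of the original header) of the idea behind the magic
// number: a file whose first 8 bytes (in host byte order) equal BINARY_SHORTLIST_MAGIC
// can be treated as a binary shortlist, anything else as text. This is a sketch only,
// not the actual implementation of isBinaryShortlist().
#if 0
std::ifstream in(fileName, std::ios::binary);  // requires <fstream>
uint64_t head = 0;
in.read(reinterpret_cast<char*>(&head), sizeof(head));
bool isBinary = in.good() && head == BINARY_SHORTLIST_MAGIC;
#endif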
class BinaryShortlistGenerator : public ShortlistGenerator {
private:
Ptr<Options> options_;
Ptr<const Vocab> srcVocab_;
Ptr<const Vocab> trgVocab_;
size_t srcIdx_;
bool shared_{false};
uint64_t firstNum_{100}; // baked into binary header
uint64_t bestNum_{100}; // baked into binary header
// the shortlist is stored as an offset table plus a concatenated candidate array:
// [&shortLists_[wordToOffset_[word]], &shortLists_[wordToOffset_[word+1]])
// is a sorted array of word indices in the shortlist for word
mio::mmap_source mmapMem_;
uint64_t wordToOffsetSize_;
uint64_t shortListsSize_;
const uint64_t *wordToOffset_;
const WordIndex *shortLists_;
std::vector<char> blob_; // binary blob
struct Header {
uint64_t magic; // BINARY_SHORTLIST_MAGIC
uint64_t checksum; // util::hashMem<uint64_t, uint64_t> from &firstNum to end of file.
uint64_t firstNum; // Limits used to create the shortlist.
uint64_t bestNum;
uint64_t wordToOffsetSize; // Length of wordToOffset_ array.
uint64_t shortListsSize; // Length of shortLists_ array.
};
void contentCheck();
// load shortlist from buffer
void load(const void* ptr_void, size_t blobSize, bool check = true);
// load shortlist from file
void load(const std::string& filename, bool check=true);
// import text shortlist from file
void import(const std::string& filename, double threshold);
// save blob to file (called by dump)
void saveBlobToFile(const std::string& filename) const;
public:
BinaryShortlistGenerator(Ptr<Options> options,
Ptr<const Vocab> srcVocab,
Ptr<const Vocab> trgVocab,
size_t srcIdx = 0,
size_t /*trgIdx*/ = 1,
bool shared = false);
// construct directly from buffer
BinaryShortlistGenerator(const void* ptr_void,
const size_t blobSize,
Ptr<const Vocab> srcVocab,
Ptr<const Vocab> trgVocab,
size_t srcIdx = 0,
size_t /*trgIdx*/ = 1,
bool shared = false,
bool check = true);
~BinaryShortlistGenerator(){
mmapMem_.unmap();
}
virtual Ptr<Shortlist> generate(Ptr<data::CorpusBatch> batch) const override;
virtual void dump(const std::string& fileName) const override;
};
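// Illustrative sketch (not part of the original header): reading the layout described
// by the comments in BinaryShortlistGenerator. For a source word w, its candidate
// target indices are the sorted range
// [shortLists_[wordToOffset_[w]], shortLists_[wordToOffset_[w+1]]).
// The helper below is hypothetical, not a member of the class.
#if 0
static std::vector<WordIndex> candidatesFor(WordIndex w,
                                            const uint64_t* wordToOffset,
                                            const WordIndex* shortLists) {
  return std::vector<WordIndex>(shortLists + wordToOffset[w],
                                shortLists + wordToOffset[w + 1]);
}
#endif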
} // namespace data
} // namespace marian