Program Listing for File bert.h
↰ Return to documentation for file (src/models/bert.h)
#pragma once
#include "data/corpus_base.h"
#include "models/encoder_classifier.h"
#include "models/transformer.h" // @BUGBUG: transformer.h is large and was meant to be compiled separately
#include "data/rng_engine.h"
namespace marian {
namespace data {
class BertBatch : public CorpusBatch {
private:
  std::vector<IndexType> maskedPositions_;
  Words maskedWords_;
  std::vector<IndexType> sentenceIndices_;

  std::string maskSymbol_;
  std::string sepSymbol_;
  std::string clsSymbol_;

  // Selects a random word from the vocabulary
  std::unique_ptr<std::uniform_int_distribution<WordIndex>> randomWord_;

  // Samples a random fraction in [0, 1), used to decide how a selected word is masked
  std::unique_ptr<std::uniform_real_distribution<float>> randomPercent_;

  // Word ids of words that should not be masked, e.g. separators, padding
  std::unordered_set<Word> dontMask_;
  // Masking function, i.e. replaces a chosen word with either
  // a [MASK] symbol, itself or a random word
  Word maskOut(Word word, Word mask, std::mt19937& engine) {
    // @TODO: turn these thresholds into parameters, adjustable from the command line
    float r = (*randomPercent_)(engine);
    if(r < 0.1f) {        // for 10% of cases return the same word
      return word;
    } else if(r < 0.2f) { // for 10% return a random word
      Word randWord = Word::fromWordIndex((*randomWord_)(engine));
      if(dontMask_.count(randWord) > 0) // some words, e.g. [CLS] or </s>, may not be used as random words
        return mask;                    // for those, return the mask symbol instead
      else
        return randWord;                // else return the random word
    } else {              // for the remaining 80% of cases apply the mask symbol
      return mask;
    }
  }
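
  // Worked example of the thresholds above: with r drawn uniformly from [0, 1),
  // a word chosen for prediction is kept unchanged if r < 0.1 (10% of cases),
  // replaced by a random vocabulary word if 0.1 <= r < 0.2 (10% of cases), and
  // replaced by the mask symbol otherwise (80% of cases). This reproduces the
  // 80/10/10 masking rule from the BERT paper.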
public:
  // Takes a corpus batch, a random engine (for deterministic behavior) and the masking fraction.
  // Also sets special vocabulary items given on the command line.
  BertBatch(Ptr<CorpusBatch> batch,
            std::mt19937& engine,
            float maskFraction,
            const std::string& maskSymbol,
            const std::string& sepSymbol,
            const std::string& clsSymbol,
            int dimTypeVocab)
      : CorpusBatch(*batch),
        maskSymbol_(maskSymbol), sepSymbol_(sepSymbol), clsSymbol_(clsSymbol) {
    // BERT expects a textual first stream and a second stream with class labels
    auto subBatch = subBatches_.front();
    const auto& vocab = *subBatch->vocab();

    // Initialize to sample a random vocab id; std::uniform_int_distribution includes
    // both bounds, so the upper bound must be vocab.size() - 1 to stay in range
    randomWord_.reset(new std::uniform_int_distribution<WordIndex>(0, (WordIndex)vocab.size() - 1));

    // Initialize to sample a random fraction in [0, 1)
    randomPercent_.reset(new std::uniform_real_distribution<float>(0.f, 1.f));
    auto& words = subBatch->data();

    // Get word ids of the special symbols
    Word maskId = vocab[maskSymbol_];
    Word clsId  = vocab[clsSymbol_];
    Word sepId  = vocab[sepSymbol_];

    ABORT_IF(maskId == vocab.getUnkId(),
             "BERT masking symbol {} not found in vocabulary", maskSymbol_);
    ABORT_IF(sepId == vocab.getUnkId(),
             "BERT separator symbol {} not found in vocabulary", sepSymbol_);
    ABORT_IF(clsId == vocab.getUnkId(),
             "BERT class symbol {} not found in vocabulary", clsSymbol_);

    dontMask_.insert(clsId);            // don't mask the class token
    dontMask_.insert(sepId);            // don't mask the separator token
    dontMask_.insert(vocab.getEosId()); // don't mask </s>
    // it's ok to mask <unk>

    std::vector<int> selected;
    selected.reserve(words.size());
    for(int i = 0; i < (int)words.size(); ++i) // collect words among which we will mask
      if(dontMask_.count(words[i]) == 0)       // do not add indices of special words
        selected.push_back(i);
    std::shuffle(selected.begin(), selected.end(), engine);             // randomize positions
    selected.resize((size_t)std::ceil(selected.size() * maskFraction)); // keep the first maskFraction portion of the shuffled indices

    for(int i : selected) {
      maskedPositions_.push_back(i);                // where is the original word?
      maskedWords_.push_back(words[i]);             // what is the original word?
      words[i] = maskOut(words[i], maskId, engine); // mask that position
    }

    annotateSentenceIndices(dimTypeVocab);
  }
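
  // Illustrative usage sketch (not part of the original file); it mirrors how
  // BertEncoderClassifier::apply() below wraps a corpus batch for pre-training.
  // `corpusBatch`, the seed, the symbol strings and the fraction are assumed
  // for the example:
  //
  //   std::mt19937 engine(1234);
  //   auto bertBatch = New<data::BertBatch>(corpusBatch, engine,
  //                                         /*maskFraction=*/0.15f,
  //                                         /*maskSymbol=*/"[MASK]",
  //                                         /*sepSymbol=*/"[SEP]",
  //                                         /*clsSymbol=*/"[CLS]",
  //                                         /*dimTypeVocab=*/2);
  //   const auto& positions = bertBatch->bertMaskedPositions(); // flattened positions that were masked
  //   const auto& originals = bertBatch->bertMaskedWords();     // gold words for the masked-LM loss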
  // Takes a corpus batch and only annotates sentence indices (no masking),
  // e.g. for fine-tuning a classifier
  BertBatch(Ptr<CorpusBatch> batch,
            const std::string& sepSymbol,
            const std::string& clsSymbol,
            int dimTypeVocab)
      : CorpusBatch(*batch),
        maskSymbol_("dummy"), sepSymbol_(sepSymbol), clsSymbol_(clsSymbol) {
    annotateSentenceIndices(dimTypeVocab);
  }
  void annotateSentenceIndices(int dimTypeVocab) {
    // BERT expects a textual first stream and a second stream with class labels
    auto subBatch = subBatches_.front();
    const auto& vocab = *subBatch->vocab();
    auto& words = subBatch->data();

    // Get word id of the separator symbol
    Word sepId = vocab[sepSymbol_];
    ABORT_IF(sepId == vocab.getUnkId(),
             "BERT separator symbol {} not found in vocabulary", sepSymbol_);

    int dimBatch = (int)subBatch->batchSize();
    int dimWords = (int)subBatch->batchWidth();

    const size_t maxSentPos = dimTypeVocab;

    // create indices for BERT sentence embeddings A and B
    sentenceIndices_.resize(words.size());       // each word is either in sentence A or B
    std::vector<IndexType> sentPos(dimBatch, 0); // initialize each batch entry as sentence A [0]
    for(int i = 0; i < dimWords; ++i) {   // advance word-wise
      for(int j = 0; j < dimBatch; ++j) { // scan batch-wise
        int k = i * dimBatch + j;
        sentenceIndices_[k] = sentPos[j]; // set to the current sentence position for this batch entry, at most maxSentPos - 1
        if(words[k] == sepId && sentPos[j] + 1 < maxSentPos) { // if the current word is a separator and the next position is still in range
          sentPos[j]++; // then increase the sentence position for this batch entry (to B [1])
        }
      }
    }
  }
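
  // Worked example for annotateSentenceIndices() with dimTypeVocab = 2 and a
  // single batch entry:
  //
  //   words:   [CLS] w1 w2 [SEP] w3 w4 [SEP]
  //   indices:   0   0  0    0    1  1    1
  //
  // The separator itself still belongs to the sentence it closes; only the
  // words following it are assigned the next index.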
  const std::vector<IndexType>& bertMaskedPositions() { return maskedPositions_; }
  const Words& bertMaskedWords() { return maskedWords_; }
  const std::vector<IndexType>& bertSentenceIndices() { return sentenceIndices_; }
};
}  // namespace data
class BertEncoderClassifier : public EncoderClassifier,
                              public data::RNGEngine { // @TODO: this random engine is not being serialized right now
public:
  BertEncoderClassifier(Ptr<Options> options)
      : EncoderClassifier(options) {}

  std::vector<Ptr<ClassifierState>> apply(Ptr<ExpressionGraph> graph,
                                          Ptr<data::CorpusBatch> batch,
                                          bool clearGraph) override {
    std::string modelType = opt<std::string>("type");
    int dimTypeVocab = opt<int>("bert-type-vocab-size");

    // intercept the batch and annotate it with BERT-specific concepts
    Ptr<data::BertBatch> bertBatch;
    if(modelType == "bert") { // full BERT pre-training
      bertBatch = New<data::BertBatch>(batch,
                                       eng_,
                                       opt<float>("bert-masking-fraction", 0.15f), // 15% by default, following the BERT paper
                                       opt<std::string>("bert-mask-symbol"),
                                       opt<std::string>("bert-sep-symbol"),
                                       opt<std::string>("bert-class-symbol"),
                                       dimTypeVocab);
    } else if(modelType == "bert-classifier") { // fine-tuning a pre-trained BERT model for a classification task
      bertBatch = New<data::BertBatch>(batch,
                                       opt<std::string>("bert-sep-symbol"),
                                       opt<std::string>("bert-class-symbol"),
                                       dimTypeVocab); // only annotate sentence separators
    } else {
      ABORT("Unknown BERT-style model: {}", modelType);
    }

    return EncoderClassifier::apply(graph, bertBatch, clearGraph);
  }

  // for an externally created BertBatch, for instance in BertValidator
  std::vector<Ptr<ClassifierState>> apply(Ptr<ExpressionGraph> graph,
                                          Ptr<data::BertBatch> bertBatch,
                                          bool clearGraph) {
    return EncoderClassifier::apply(graph, bertBatch, clearGraph);
  }
};
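
// Illustrative configuration sketch (assumptions, not from this file): the
// dispatch above corresponds to command lines along these lines, with the
// symbol values assumed for the example:
//
//   # full BERT pre-training (masked LM plus sentence classification)
//   --type bert --bert-masking-fraction 0.15 --bert-type-vocab-size 2 \
//   --bert-mask-symbol "[MASK]" --bert-sep-symbol "[SEP]" --bert-class-symbol "[CLS]"
//
//   # fine-tuning for classification; batches are only annotated with sentence indices
//   --type bert-classifier --bert-sep-symbol "[SEP]" --bert-class-symbol "[CLS]"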
class BertEncoder : public EncoderTransformer {
  using EncoderTransformer::EncoderTransformer;
public:
  Expr addSentenceEmbeddings(Expr embeddings,
                             Ptr<data::CorpusBatch> batch,
                             bool learnedPosEmbeddings) const {
    Ptr<data::BertBatch> bertBatch = std::dynamic_pointer_cast<data::BertBatch>(batch);
    ABORT_IF(!bertBatch, "Batch must be BertBatch for BERT training or fine-tuning");

    int dimEmb   = embeddings->shape()[-1];
    int dimBatch = embeddings->shape()[-2];
    int dimWords = embeddings->shape()[-3];

    int dimTypeVocab = opt<int>("bert-type-vocab-size", 2);

    Expr signal;
    if(learnedPosEmbeddings) {
      auto sentenceEmbeddings = embedding()
                                ("prefix", "Wtype")
                                ("dimVocab", dimTypeVocab) // sentence A or sentence B
                                ("dimEmb", dimEmb)
                                .construct(graph_);
      signal = sentenceEmbeddings->applyIndices(bertBatch->bertSentenceIndices(), {dimWords, dimBatch, dimEmb});
    } else {
      // @TODO: factory for positional embeddings?
      // constant sinusoidal position embeddings, no backprop
      auto sentenceEmbeddingsExpr = graph_->constant({2, dimEmb}, inits::sinusoidalPositionEmbeddings(0));
      signal = rows(sentenceEmbeddingsExpr, bertBatch->bertSentenceIndices());
      signal = reshape(signal, {dimWords, dimBatch, dimEmb});
    }

    return embeddings + signal;
  }
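
  // Shape sketch for addSentenceEmbeddings() above: with embeddings laid out as
  // [dimWords, dimBatch, dimEmb], bertSentenceIndices() holds dimWords * dimBatch
  // entries in {0, 1}; the lookup produces one dimEmb-vector per entry, and the
  // reshape restores [dimWords, dimBatch, dimEmb] so that `signal` can be added
  // elementwise to the word embeddings.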
  virtual Expr addSpecialEmbeddings(Expr input, int start = 0, Ptr<data::CorpusBatch> batch = nullptr) const override {
    bool trainPosEmbeddings  = opt<bool>("transformer-train-position-embeddings", true);
    bool trainTypeEmbeddings = opt<bool>("bert-train-type-embeddings", true);
    input = addPositionalEmbeddings(input, start, trainPosEmbeddings);
    input = addSentenceEmbeddings(input, batch, trainTypeEmbeddings);
    return input;
  }
};
class BertClassifier : public ClassifierBase {
  using ClassifierBase::ClassifierBase;
public:
  Ptr<ClassifierState> apply(Ptr<ExpressionGraph> graph,
                             Ptr<data::CorpusBatch> batch,
                             const std::vector<Ptr<EncoderState>>& encoderStates) override {
    ABORT_IF(encoderStates.size() != 1, "Currently we only support single-encoder BERT models");

    auto context = encoderStates[0]->getContext();
    auto classEmbeddings = slice(context, /*axis=*/-3, /*i=*/0); // the [CLS] symbol is the first symbol in each sequence

    int dimModel  = classEmbeddings->shape()[-1];
    int dimTrgCls = opt<std::vector<int>>("dim-vocabs")[batchIndex_]; // the target vocab is used as class labels

    auto output = mlp::mlp()                                      //
                  .push_back(mlp::dense()                         //
                             ("prefix", prefix_ + "_ff_logit_l1") //
                             ("dim", dimModel)                    //
                             ("activation", (int)mlp::act::tanh)) // @TODO: do we actually need this?
                  .push_back(mlp::output()                        //
                             ("dim", dimTrgCls))                  //
                  ("prefix", prefix_ + "_ff_logit_l2")            //
                  .construct(graph);

    auto logits = output->apply(classEmbeddings); // class logits for each batch entry

    auto state = New<ClassifierState>();
    state->setLogProbs(logits);

    // Filled externally; for BERT these are the next-sentence prediction labels
    const auto& classLabels = (*batch)[batchIndex_]->data();
    state->setTargetWords(classLabels);

    return state;
  }

  virtual void clear() override {}
};
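
// Shape sketch for BertClassifier::apply() above, following the axis convention
// used elsewhere in this file ([-4: beam depth, -3: max length, -2: batch size,
// -1: model dim]): slicing time position 0 keeps exactly one hidden state per
// batch entry (the [CLS] state), so the MLP produces dimTrgCls class logits
// per sentence.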
class BertMaskedLM : public ClassifierBase {
  using ClassifierBase::ClassifierBase;
public:
  Ptr<ClassifierState> apply(Ptr<ExpressionGraph> graph,
                             Ptr<data::CorpusBatch> batch,
                             const std::vector<Ptr<EncoderState>>& encoderStates) override {
    Ptr<data::BertBatch> bertBatch = std::dynamic_pointer_cast<data::BertBatch>(batch);
    ABORT_IF(!bertBatch, "Batch must be BertBatch for BERT training");
    ABORT_IF(encoderStates.size() != 1, "Currently we only support single-encoder BERT models");

    auto context = encoderStates[0]->getContext();

    auto bertMaskedPositions    = graph->indices(bertBatch->bertMaskedPositions()); // positions in the batch of masked entries
    const auto& bertMaskedWords = bertBatch->bertMaskedWords();                     // vocab ids of the entries that have been masked

    int dimModel = context->shape()[-1];
    int dimBatch = context->shape()[-2];
    int dimTime  = context->shape()[-3];

    auto maskedContext = rows(reshape(context, {dimBatch * dimTime, dimModel}), bertMaskedPositions); // sub-select the states that were actually masked out
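
    // Index sketch: after the reshape, row k of the flattened context corresponds
    // to time step i and batch entry j with k = i * dimBatch + j, which is the
    // same layout BertBatch used when recording maskedPositions_, so rows()
    // selects exactly the hidden states of the masked tokens.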
    int dimVoc = opt<std::vector<int>>("dim-vocabs")[batchIndex_];

    auto layer1 = mlp::mlp()
                  .push_back(mlp::dense()
                             ("prefix", prefix_ + "_ff_logit_l1")
                             ("dim", dimModel))
                  .construct(graph);

    auto intermediate = layer1->apply(maskedContext);

    std::string activationType = opt<std::string>("transformer-ffn-activation");
    if(activationType == "relu")
      intermediate = relu(intermediate);
    else if(activationType == "swish")
      intermediate = swish(intermediate);
    else if(activationType == "gelu")
      intermediate = gelu(intermediate);
    else
      ABORT("Activation function {} not supported in BERT masked LM", activationType);

    auto gamma = graph->param(prefix_ + "_ff_ln_scale", {1, dimModel}, inits::ones());
    auto beta  = graph->param(prefix_ + "_ff_ln_bias",  {1, dimModel}, inits::zeros());
    intermediate = layerNorm(intermediate, gamma, beta);

    auto layer2 = mlp::mlp()
                  .push_back(mlp::output(
                               "prefix", prefix_ + "_ff_logit_l2",
                               "dim", dimVoc)
                             .tieTransposed("Wemb")) // tie the output matrix to the transposed word embedding matrix
                  .construct(graph);

    auto logits = layer2->apply(intermediate); // [-4: beam depth=1, -3: max length, -2: batch size, -1: vocab dim]

    auto state = New<ClassifierState>();
    state->setLogProbs(logits);
    state->setTargetWords(bertMaskedWords);

    return state;
  }

  virtual void clear() override {}
};
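
// Assembly note (an assumption about the surrounding code base, not stated in
// this file): for --type bert, the model factory is expected to combine
// BertEncoder with both classifier heads (BertMaskedLM for the masked-LM
// objective and BertClassifier for next-sentence prediction), while
// --type bert-classifier pairs BertEncoder with a single BertClassifier head.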
}  // namespace marian