.. _program_listing_file_src_models_bert.h:

Program Listing for File bert.h
===============================

|exhale_lsh| :ref:`Return to documentation for file <file_src_models_bert.h>` (``src/models/bert.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   #pragma once

   #include "data/corpus_base.h"
   #include "models/encoder_classifier.h"
   #include "models/transformer.h" // @BUGBUG: transformer.h is large and was meant to be compiled separately
   #include "data/rng_engine.h"

   #include <algorithm>
   #include <cmath>
   #include <memory>
   #include <random>
   #include <string>
   #include <unordered_set>
   #include <vector>

   namespace marian {
   namespace data {

   // Mini-batch for BERT-style models. Copies a CorpusBatch and annotates its first
   // sub-batch (the text stream) with
   //  * masked-LM information (which flat positions were masked and what the original
   //    words were) when constructed with a masking fraction, and
   //  * a per-token sentence (type) index derived from the separator symbols.
   // Special symbols are looked up in the vocabulary of the first sub-batch; construction
   // aborts if one of them is missing.
   class BertBatch : public CorpusBatch {
   private:
     std::vector<IndexType> maskedPositions_; // flat indices into the first sub-batch's data() that were masked
     Words maskedWords_;                      // original words at maskedPositions_ (masked-LM targets)
     std::vector<IndexType> sentenceIndices_; // sentence/type index per token, same layout as data()

     std::string maskSymbol_;
     std::string sepSymbol_;
     std::string clsSymbol_;

     // Selects a random word id from the vocabulary, in [0, vocab.size() - 1]
     std::unique_ptr<std::uniform_int_distribution<WordIndex>> randomWord_;

     // Selects a random float in [0, 1)
     std::unique_ptr<std::uniform_real_distribution<float>> randomPercent_;

     // Word ids that are never selected for masking and never used as random
     // replacement words: [CLS], [SEP] and </s>
     std::unordered_set<Word> dontMask_;

     // Masking function, i.e. replaces a word chosen for masking with
     //  * the word itself     (10% of cases),
     //  * a random vocab word (10% of cases; falls back to the mask symbol if the
     //                         sampled word is in dontMask_),
     //  * the mask symbol     (80% of cases).
     // @TODO: turn those thresholds into parameters, adjustable from command line
     Word maskOut(Word word, Word mask, std::mt19937& engine) {
       float r = (*randomPercent_)(engine);
       if(r < 0.1f) { // for 10% of cases return same word
         return word;
       } else if(r < 0.2f) { // for 10% return random word
         Word randWord = Word::fromWordIndex((*randomWord_)(engine));
         if(dontMask_.count(randWord) > 0) // some words, e.g. [CLS] or </s>, may not be used as random words
           return mask;                    // for those, return the mask symbol instead
         else
           return randWord;                // else return the random word
       } else { // for 80% of words apply mask symbol
         return mask;
       }
     }

   public:
     // Takes a corpus batch, random engine (for deterministic behavior) and the masking fraction.
     // Also sets special vocabulary items given on command line.
     //
     // Of the N words in the first sub-batch that are not in dontMask_, ceil(maskFraction * N)
     // randomly chosen positions are recorded (position + original word) and rewritten in
     // place via maskOut(). Aborts if maskFraction is outside [0, 1] or a special symbol is
     // missing from the vocabulary.
     BertBatch(Ptr<CorpusBatch> batch,
               std::mt19937& engine,
               float maskFraction,
               const std::string& maskSymbol,
               const std::string& sepSymbol,
               const std::string& clsSymbol,
               int dimTypeVocab)
       : CorpusBatch(*batch),
         maskSymbol_(maskSymbol), sepSymbol_(sepSymbol), clsSymbol_(clsSymbol) {
       ABORT_IF(maskFraction < 0.f || maskFraction > 1.f,
                "BERT masking fraction {} must be in [0, 1]", maskFraction);

       // BERT expects a textual first stream and a second stream with class labels
       auto subBatch = subBatches_.front();
       const auto& vocab = *subBatch->vocab();

       auto& words = subBatch->data();

       // Get word id of special symbols
       Word maskId = vocab[maskSymbol_];
       Word clsId  = vocab[clsSymbol_];
       Word sepId  = vocab[sepSymbol_];

       ABORT_IF(maskId == vocab.getUnkId(), "BERT masking symbol {} not found in vocabulary", maskSymbol_);
       ABORT_IF(sepId == vocab.getUnkId(), "BERT separator symbol {} not found in vocabulary", sepSymbol_);
       ABORT_IF(clsId == vocab.getUnkId(), "BERT class symbol {} not found in vocabulary", clsSymbol_);

       // Initialize to sample a random vocab id. std::uniform_int_distribution is inclusive
       // on both ends, so the upper bound is the last valid id, not vocab.size().
       // (vocab is non-empty here: the special symbols were found above.)
       randomWord_.reset(new std::uniform_int_distribution<WordIndex>(0, (WordIndex)vocab.size() - 1));

       // Initialize to sample random percentage
       randomPercent_.reset(new std::uniform_real_distribution<float>(0.f, 1.f));

       dontMask_.insert(clsId);            // don't mask class token
       dontMask_.insert(sepId);            // don't mask separator token
       dontMask_.insert(vocab.getEosId()); // don't mask </s>
       // it's ok to mask <unk>

       // collect positions among which we will mask, skipping special words
       std::vector<size_t> selected;
       selected.reserve(words.size());
       for(size_t i = 0; i < words.size(); ++i)
         if(dontMask_.count(words[i]) == 0)
           selected.push_back(i);

       std::shuffle(selected.begin(), selected.end(), engine); // randomize positions

       // select first x percent from shuffled indices; std::min guards against float
       // rounding ever growing the vector (which would append position 0)
       size_t numMasked = std::min(selected.size(), (size_t)std::ceil(selected.size() * maskFraction));
       selected.resize(numMasked);

       for(size_t i : selected) {
         maskedPositions_.push_back((IndexType)i);     // where is the original word?
         maskedWords_.push_back(words[i]);             // what is the original word?
         words[i] = maskOut(words[i], maskId, engine); // mask that position
       }

       annotateSentenceIndices(dimTypeVocab);
     }

     // Constructor for fine-tuning: no masking, only sentence (type) indices are annotated.
     BertBatch(Ptr<CorpusBatch> batch,
               const std::string& sepSymbol,
               const std::string& clsSymbol,
               int dimTypeVocab)
       : CorpusBatch(*batch),
         maskSymbol_("dummy"), sepSymbol_(sepSymbol), clsSymbol_(clsSymbol) {
       annotateSentenceIndices(dimTypeVocab);
     }

     // Assigns every token of the first sub-batch a sentence (type) index: each batch entry
     // starts at 0 and the index is incremented after each separator symbol (the separator
     // itself keeps the current index). The index is capped at dimTypeVocab - 1 so every
     // value is a valid row of a type-embedding table with dimTypeVocab rows.
     void annotateSentenceIndices(int dimTypeVocab) {
       ABORT_IF(dimTypeVocab < 1, "BERT type vocabulary size must be positive, got {}", dimTypeVocab);

       // BERT expects a textual first stream and a second stream with class labels
       auto subBatch = subBatches_.front();
       const auto& vocab = *subBatch->vocab();
       const auto& words = subBatch->data();

       // Get word id of special symbols
       Word sepId = vocab[sepSymbol_];
       ABORT_IF(sepId == vocab.getUnkId(), "BERT separator symbol {} not found in vocabulary", sepSymbol_);

       int dimBatch = (int)subBatch->batchSize();
       int dimWords = (int)subBatch->batchWidth();

       // largest valid sentence index (e.g. 1 for the default A/B setup)
       const IndexType maxSentPos = (IndexType)(dimTypeVocab - 1);

       // create indices for BERT sentence embeddings A and B
       sentenceIndices_.resize(words.size());
       std::vector<IndexType> sentPos(dimBatch, 0); // initialize each batch entry with being A [0]
       for(int i = 0; i < dimWords; ++i) {          // advance word-wise
         for(int j = 0; j < dimBatch; ++j) {        // scan batch-wise
           int k = i * dimBatch + j;
           sentenceIndices_[k] = sentPos[j];        // set to current sentence position for batch entry
           if(words[k] == sepId && sentPos[j] < maxSentPos) // separator and next index still in range
             sentPos[j]++;                          // then increase sentence position (e.g. to B [1])
         }
       }
     }

     // Flat positions (into the first sub-batch's data()) that were selected for masking.
     const std::vector<IndexType>& bertMaskedPositions() const { return maskedPositions_; }
     // Original words at bertMaskedPositions(); targets of the masked-LM head.
     const Words& bertMaskedWords() const { return maskedWords_; }
     // Sentence (type) index of every token, same layout as the first sub-batch's data().
     const std::vector<IndexType>& bertSentenceIndices() const { return sentenceIndices_; }
   };

   } // namespace data

   // Encoder-classifier that converts every incoming CorpusBatch into a BertBatch before
   // running the encoders and classifiers. Model type "bert" masks words (pre-training);
   // "bert-classifier" only annotates sentence indices (fine-tuning).
   class BertEncoderClassifier : public EncoderClassifier, public data::RNGEngine { // @TODO: this random engine is not being serialized right now
   public:
     BertEncoderClassifier(Ptr<Options> options) : EncoderClassifier(options) {}

     std::vector<Ptr<ClassifierState>> apply(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch, bool clearGraph) override {
       std::string modelType = opt<std::string>("type");
       int dimTypeVocab = opt<int>("bert-type-vocab-size");

       // intercept batch and annotate with BERT-specific concepts
       Ptr<data::BertBatch> bertBatch;
       if(modelType == "bert") { // full BERT pre-training
         bertBatch = New<data::BertBatch>(batch,
                                          eng_,
                                          opt<float>("bert-masking-fraction", 0.15f), // 15% by default according to paper
                                          opt<std::string>("bert-mask-symbol"),
                                          opt<std::string>("bert-sep-symbol"),
                                          opt<std::string>("bert-class-symbol"),
                                          dimTypeVocab);
       } else if(modelType == "bert-classifier") { // we are probably fine-tuning a BERT model for a classification task
         bertBatch = New<data::BertBatch>(batch,
                                          opt<std::string>("bert-sep-symbol"),
                                          opt<std::string>("bert-class-symbol"),
                                          dimTypeVocab); // only annotate sentence separators
       } else {
         ABORT("Unknown BERT-style model: {}", modelType);
       }

       return EncoderClassifier::apply(graph, bertBatch, clearGraph);
     }

     // for externally created BertBatch for instance in BertValidator
     std::vector<Ptr<ClassifierState>> apply(Ptr<ExpressionGraph> graph, Ptr<data::BertBatch> bertBatch, bool clearGraph) {
       return EncoderClassifier::apply(graph, bertBatch, clearGraph);
     }
   };

   // Transformer encoder that adds sentence (type) embeddings on top of the positional
   // embeddings. Requires the batch to be a BertBatch.
   class BertEncoder : public EncoderTransformer {
     using EncoderTransformer::EncoderTransformer;
   public:
     // Adds to every embedded token the type embedding selected by bertSentenceIndices().
     // If learnedPosEmbeddings is true the table is the learned embedding "Wtype" with
     // dimVocab = bert-type-vocab-size; otherwise a constant sinusoidal table with the
     // same number of rows is used.
     Expr addSentenceEmbeddings(Expr embeddings, Ptr<data::CorpusBatch> batch, bool learnedPosEmbeddings) const {
       Ptr<data::BertBatch> bertBatch = std::dynamic_pointer_cast<data::BertBatch>(batch);

       ABORT_IF(!bertBatch, "Batch must be BertBatch for BERT training or fine-tuning");

       int dimEmb   = embeddings->shape()[-1];
       int dimBatch = embeddings->shape()[-2];
       int dimWords = embeddings->shape()[-3];

       int dimTypeVocab = opt<int>("bert-type-vocab-size", 2);

       Expr signal;
       if(learnedPosEmbeddings) {
         auto sentenceEmbeddings = embedding()
                                   ("prefix", "Wtype")
                                   ("dimVocab", dimTypeVocab) // sentence A or sentence B
                                   ("dimEmb", dimEmb)
                                   .construct(graph_);
         signal = sentenceEmbeddings->applyIndices(bertBatch->bertSentenceIndices(), {dimWords, dimBatch, dimEmb});
       } else {
         // @TODO: factory for positional embeddings?
         // constant sinusoidal position embeddings, no backprop. One row per type id so that
         // every index produced by BertBatch::annotateSentenceIndices() is in range
         // (rows 0 and 1 are the same as with the former fixed 2-row table).
         auto sentenceEmbeddingsExpr = graph_->constant({dimTypeVocab, dimEmb}, inits::sinusoidalPositionEmbeddings(0));
         signal = rows(sentenceEmbeddingsExpr, bertBatch->bertSentenceIndices());
         signal = reshape(signal, {dimWords, dimBatch, dimEmb});
       }

       return embeddings + signal;
     }

     // Adds positional embeddings, then sentence (type) embeddings.
     virtual Expr addSpecialEmbeddings(Expr input, int start = 0, Ptr<data::CorpusBatch> batch = nullptr) const override {
       bool trainPosEmbeddings  = opt<bool>("transformer-train-position-embeddings", true);
       bool trainTypeEmbeddings = opt<bool>("bert-train-type-embeddings", true);
       input = addPositionalEmbeddings(input, start, trainPosEmbeddings);
       input = addSentenceEmbeddings(input, batch, trainTypeEmbeddings);
       return input;
     }
   };

   // Classification head: predicts one class per batch entry from the encoder context at
   // the first position (the [CLS] symbol) via a tanh dense layer and an output layer.
   // Class labels are read from sub-batch batchIndex_.
   class BertClassifier : public ClassifierBase {
     using ClassifierBase::ClassifierBase;
   public:
     Ptr<ClassifierState> apply(Ptr<ExpressionGraph> graph,
                                Ptr<data::CorpusBatch> batch,
                                const std::vector<Ptr<EncoderState>>& encoderStates) override {
       ABORT_IF(encoderStates.size() != 1, "Currently we only support a single encoder BERT model");

       auto context = encoderStates[0]->getContext();
       auto classEmbeddings = slice(context, /*axis=*/-3, /*i=*/0); // [CLS] symbol is first symbol in each sequence

       int dimModel = classEmbeddings->shape()[-1];
       int dimTrgCls = opt<std::vector<int>>("dim-vocabs")[batchIndex_]; // Target vocab is used as class labels

       // NOTE(review): the ("prefix", ..._ff_logit_l2) call below comes after the closing
       // parenthesis of push_back(), so it configures the MLP factory rather than the output
       // layer. Confirm this is intended before changing it; parameter names of existing
       // models may depend on it.
       auto output = mlp::mlp()
                         .push_back(mlp::dense()
                                    ("prefix", prefix_ + "_ff_logit_l1")
                                    ("dim", dimModel)
                                    ("activation", (int)mlp::act::tanh)) // @TODO: do we actually need this?
                         .push_back(mlp::output()
                                    ("dim", dimTrgCls))
                         ("prefix", prefix_ + "_ff_logit_l2")
                         .construct(graph);

       auto logits = output->apply(classEmbeddings); // class logits for each batch entry

       auto state = New<ClassifierState>();
       state->setLogProbs(logits);

       // Filled externally, for BERT these are NextSentence prediction labels
       const auto& classLabels = (*batch)[batchIndex_]->data();
       state->setTargetWords(classLabels);

       return state;
     }

     virtual void clear() override {}
   };

   // Masked-LM head: gathers the encoder context at the masked positions, applies
   // dense -> activation -> layer norm, and projects to the vocabulary with an output
   // layer tied (transposed) to "Wemb". Targets are the original words recorded by BertBatch.
   class BertMaskedLM : public ClassifierBase {
     using ClassifierBase::ClassifierBase;
   public:
     Ptr<ClassifierState> apply(Ptr<ExpressionGraph> graph,
                                Ptr<data::CorpusBatch> batch,
                                const std::vector<Ptr<EncoderState>>& encoderStates) override {
       Ptr<data::BertBatch> bertBatch = std::dynamic_pointer_cast<data::BertBatch>(batch);

       ABORT_IF(!bertBatch, "Batch must be BertBatch for BERT training");
       ABORT_IF(encoderStates.size() != 1, "Currently we only support a single encoder BERT model");

       auto context = encoderStates[0]->getContext();

       auto bertMaskedPositions = graph->indices(bertBatch->bertMaskedPositions()); // positions in batch of masked entries
       const auto& bertMaskedWords = bertBatch->bertMaskedWords(); // vocab ids of entries that have been masked

       int dimModel = context->shape()[-1];
       int dimBatch = context->shape()[-2];
       int dimTime  = context->shape()[-3];

       // flatten to [dimTime * dimBatch, dimModel] and subselect the rows that have actually been masked out
       auto maskedContext = rows(reshape(context, {dimBatch * dimTime, dimModel}), bertMaskedPositions);

       int dimVoc = opt<std::vector<int>>("dim-vocabs")[batchIndex_];

       auto layer1 = mlp::mlp()
                         .push_back(mlp::dense()
                                    ("prefix", prefix_ + "_ff_logit_l1")
                                    ("dim", dimModel))
                         .construct(graph);

       auto intermediate = layer1->apply(maskedContext);

       std::string activationType = opt<std::string>("transformer-ffn-activation");
       if(activationType == "relu")
         intermediate = relu(intermediate);
       else if(activationType == "swish")
         intermediate = swish(intermediate);
       else if(activationType == "gelu")
         intermediate = gelu(intermediate);
       else
         ABORT("Activation function {} not supported in BERT masked LM", activationType);

       auto gamma = graph->param(prefix_ + "_ff_ln_scale", {1, dimModel}, inits::ones());
       auto beta  = graph->param(prefix_ + "_ff_ln_bias",  {1, dimModel}, inits::zeros());
       intermediate = layerNorm(intermediate, gamma, beta);

       auto layer2 = mlp::mlp()
                         .push_back(mlp::output(
                                      "prefix", prefix_ + "_ff_logit_l2",
                                      "dim", dimVoc)
                                    .tieTransposed("Wemb"))
                         .construct(graph);

       auto logits = layer2->apply(intermediate); // [number of masked positions, vocab dim]

       auto state = New<ClassifierState>();
       state->setLogProbs(logits);
       state->setTargetWords(bertMaskedWords);

       return state;
     }

     virtual void clear() override {}
   };

   } // namespace marian