Class BertBatch

Inheritance Relationships

Base Type

public marian::data::CorpusBatch

Class Documentation

class BertBatch : public marian::data::CorpusBatch

BERT-specific mini-batch that computes masking for Masked LM training.

Expects the symbols [MASK], [SEP], and [CLS] to be present in the vocabularies unless other symbols are specified in the config.

This class takes a regular CorpusBatch and extends it with additional data. All BERT-specific information can be inferred from the CorpusBatch alone.
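The following is a minimal, standalone sketch of the general idea behind masking for Masked LM training, not Marian's actual implementation. It works on plain strings rather than a CorpusBatch. The type MaskedSequence and the function maskTokens are hypothetical names introduced only for illustration. The selection rule (every selected token is replaced by the mask symbol, and the number of selected tokens is rounded down) is an assumption made to keep the example short.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <random>
#include <string>
#include <vector>

// Hypothetical result type for this sketch; not part of Marian.
struct MaskedSequence {
  std::vector<std::string> tokens;           // input with selected tokens replaced
  std::vector<std::size_t> maskedPositions;  // indices that were masked, ascending
  std::vector<std::string> maskedWords;      // original tokens at those indices
};

// Hypothetical helper: masks a fraction of the non-special tokens.
// Assumption: the count is rounded down and every selected token
// is replaced by maskSymbol.
MaskedSequence maskTokens(const std::vector<std::string>& input,
                          std::mt19937& engine,
                          float maskFraction,
                          const std::string& maskSymbol,
                          const std::string& sepSymbol,
                          const std::string& clsSymbol) {
  assert(maskFraction >= 0.f && maskFraction <= 1.f);

  // Candidate positions: everything except the special symbols.
  std::vector<std::size_t> candidates;
  for(std::size_t i = 0; i < input.size(); ++i)
    if(input[i] != sepSymbol && input[i] != clsSymbol)
      candidates.push_back(i);

  // Pick a random subset by shuffling and truncating.
  std::shuffle(candidates.begin(), candidates.end(), engine);
  auto count = static_cast<std::size_t>(maskFraction * candidates.size());
  candidates.resize(count);
  std::sort(candidates.begin(), candidates.end());

  MaskedSequence out{input, {}, {}};
  for(std::size_t pos : candidates) {
    out.maskedPositions.push_back(pos);     // where the label applies
    out.maskedWords.push_back(input[pos]);  // the label to predict
    out.tokens[pos] = maskSymbol;           // what the model sees
  }
  return out;
}
```

In this sketch, maskedPositions and maskedWords play a role loosely analogous to bertMaskedPositions() and bertMaskedWords(): they record where the training labels apply and what those labels are.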

Public Functions

BertBatch(Ptr<CorpusBatch> batch, std::mt19937 &engine, float maskFraction, const std::string &maskSymbol, const std::string &sepSymbol, const std::string &clsSymbol, int dimTypeVocab)
BertBatch(Ptr<CorpusBatch> batch, const std::string &sepSymbol, const std::string &clsSymbol, int dimTypeVocab)
void annotateSentenceIndices(int dimTypeVocab)
const std::vector<IndexType> &bertMaskedPositions()
const Words &bertMaskedWords()
const std::vector<IndexType> &bertSentenceIndices()
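
As a rough illustration of what per-token sentence indices might look like, the standalone sketch below assigns each token the index of the sentence it belongs to, advancing after every separator symbol. The function sentenceIndices is a hypothetical name, it operates on plain strings rather than a CorpusBatch, and the way dimTypeVocab caps the index is an assumption. Marian's annotateSentenceIndices may behave differently.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical helper: returns one sentence index per token.
// Assumption: a separator belongs to the sentence it closes, and the
// index never reaches dimTypeVocab (it stays at dimTypeVocab - 1).
std::vector<unsigned> sentenceIndices(const std::vector<std::string>& tokens,
                                      const std::string& sepSymbol,
                                      int dimTypeVocab) {
  assert(dimTypeVocab > 0);

  std::vector<unsigned> indices;
  indices.reserve(tokens.size());

  unsigned current = 0;
  for(const auto& token : tokens) {
    indices.push_back(current);
    // The token after a separator starts the next sentence, if allowed.
    if(token == sepSymbol && current + 1 < static_cast<unsigned>(dimTypeVocab))
      ++current;
  }
  return indices;
}
```

For the input [CLS] a b [SEP] c [SEP] with dimTypeVocab set to 2, this sketch yields the indices 0 0 0 0 1 1.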