.. _program_listing_file_src_layers_embedding.h:

Program Listing for File embedding.h
====================================

|exhale_lsh| :ref:`Return to documentation for file <file_src_layers_embedding.h>` (``src/layers/embedding.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   #pragma once

   #include "generic.h"
   #include "marian.h"

   namespace marian {

   class FactoredVocab;

   class Embedding : public LayerBase, public IEmbeddingLayer {
     Expr E_;
     Expr FactorEmbMatrix_;  // factor embedding matrix, used when combining lemma and factor embeddings by concatenation
     Ptr<FactoredVocab> factoredVocab_;
     Expr multiRows(const Words& data, float dropProb) const;
     Expr embedWithConcat(const Words& data) const;
     bool inference_{false};

   public:
     Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options);

     std::tuple<Expr /*embeddings*/, Expr /*mask*/> apply(
         Ptr<data::SubBatch> subBatch) const override final;

     Expr apply(const Words& words, const Shape& shape) const override final;

     Expr applyIndices(const std::vector<WordIndex>& embIdx,
                       const Shape& shape) const override final;
   };

   class ULREmbedding : public LayerBase, public IEmbeddingLayer {
     std::vector<Expr> ulrEmbeddings_;  // @TODO: These could now better be written as 6 named class members
     bool inference_{false};

   public:
     ULREmbedding(Ptr<ExpressionGraph> graph, Ptr<Options> options)
         : LayerBase(graph, options), inference_(opt<bool>("inference")) {
       std::string name = "url_embed";  // opt<std::string>("prefix");
       int dimKeys = opt<int>("dimTgtVoc");
       int dimQueries = opt<int>("dimSrcVoc");
       int dimEmb = opt<int>("dimEmb");
       int dimUlrEmb = opt<int>("dimUlrEmb");  // ULR monolingual embedding size
       bool fixed = opt<bool>("fixed", false);

       // Embedding layer initialization should depend only on embedding size, hence fanIn=false
       auto initFunc = inits::glorotUniform(/*fanIn=*/false, /*fanOut=*/true);

       std::string queryFile = opt<std::string>("ulrQueryFile");
       std::string keyFile = opt<std::string>("ulrKeysFile");
       bool trainTrans = opt<bool>("ulrTrainTransform", false);
       if(!queryFile.empty() && !keyFile.empty()) {
         // query embeddings, loaded from file and kept fixed
         initFunc = inits::fromWord2vec(queryFile, dimQueries, dimUlrEmb, false);
         name = "ulr_query";
         fixed = true;
         auto query_embed = graph_->param(name, {dimQueries, dimUlrEmb}, initFunc, fixed);
         ulrEmbeddings_.push_back(query_embed);
         // key embeddings, loaded from file and kept fixed
         initFunc = inits::fromWord2vec(keyFile, dimKeys, dimUlrEmb, false);
         name = "ulr_keys";
         fixed = true;
         auto key_embed = graph_->param(name, {dimKeys, dimUlrEmb}, initFunc, fixed);
         ulrEmbeddings_.push_back(key_embed);
         // actual trainable embedding
         initFunc = inits::glorotUniform();
         name = "ulr_embed";
         fixed = false;
         auto ulr_embed = graph_->param(name, {dimKeys, dimEmb}, initFunc, fixed);  // note the reversed dims
         ulrEmbeddings_.push_back(ulr_embed);
         // init trainable src embedding
         name = "ulr_src_embed";
         auto ulr_src_embed = graph_->param(name, {dimQueries, dimEmb}, initFunc, fixed);
         ulrEmbeddings_.push_back(ulr_src_embed);
         // ULR transformation matrix
         // initFunc = inits::eye(1.f);  // identity matrix - is it OK to init with identity,
         // or shall we restrict this to the fixed case only?
         if(trainTrans) {
           initFunc = inits::glorotUniform();
           fixed = false;
         } else {
           initFunc = inits::eye();  // identity matrix
           fixed = true;
         }
         name = "ulr_transform";
         auto ulrTransform = graph_->param(name, {dimUlrEmb, dimUlrEmb}, initFunc, fixed);
         ulrEmbeddings_.push_back(ulrTransform);

         initFunc = inits::fromValue(
             1.f);  // TBD: we should read sharable flags here - 1 means all sharable, 0 means no
                    // universal embeddings - should be zero for top-frequency words only
         fixed = true;
         name = "ulr_shared";
         auto share_embed = graph_->param(name, {dimQueries, 1}, initFunc, fixed);
         ulrEmbeddings_.push_back(share_embed);
       }
     }
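
     // The apply() below mixes a trainable per-word source embedding with an
     // attention-weighted average of universal embeddings:
     //   z   = (Q_i A) K^T / sqrt(dimUlrEmb)   -- query-key similarity
     //   w   = softmax(z / tau)                -- attention weights
     //   e_i = I_i + alpha_i * (w E)           -- sharable-flag-gated mix
     // using the six parameters registered in the constructor above.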

     std::tuple<Expr /*embeddings*/, Expr /*mask*/> apply(
         Ptr<data::SubBatch> subBatch) const override final {
       auto queryEmbed = ulrEmbeddings_[0];    // Q : dimQueries x dimUlrEmb
       auto keyEmbed = ulrEmbeddings_[1];      // K : dimKeys x dimUlrEmb
       auto uniEmbed = ulrEmbeddings_[2];      // E : dimKeys x dimEmb
       auto srcEmbed = ulrEmbeddings_[3];      // I : dimQueries x dimEmb
       auto ulrTransform = ulrEmbeddings_[4];  // A : dimUlrEmb x dimUlrEmb
       auto ulrSharable = ulrEmbeddings_[5];   // alpha : dimQueries x 1
       int dimBatch = (int)subBatch->batchSize();
       int dimEmb = uniEmbed->shape()[-1];
       int dimWords = (int)subBatch->batchWidth();
       // D = K.A.Q^T
       // dim(K) = univ_tok_vocab x uni_embed_size
       // dim(A) = uni_embed_size x uni_embed_size
       // dim(Q) = uni_embed_size x total_merged_vocab_size
       // dim(D) = univ_tok_vocab x total_merged_vocab_size
       // note: all of the above can be precomputed and serialized if A is not trainable
       // and we are decoding (TBD); here we need to handle the mini-batch
       // extract rows corresponding to the Xs in this mini-batch from Q
       auto embIdx = toWordIndexVector(subBatch->data());
       auto queryEmbeddings = rows(queryEmbed, embIdx);
       auto srcEmbeddings = rows(srcEmbed, embIdx);  // extract trainable src embeddings
       auto alpha = rows(ulrSharable, embIdx);       // extract sharable flags
       auto qt = dot(queryEmbeddings, ulrTransform, false, false);  // transform embeddings based on similarity; A : dimUlrEmb x dimUlrEmb
       auto sqrtDim = std::sqrt((float)queryEmbeddings->shape()[-1]);
       qt = qt / sqrtDim;  // normalize by embedding size to keep the dot product from
                           // growing large in magnitude with larger embedding sizes
       auto z = dot(qt, keyEmbed, false, true);  // query-key similarity
       float dropProb = this->options_->get<float>("ulr-dropout", 0.0f);  // default: no dropout
       if(!inference_)
         z = dropout(z, dropProb);

       float tau = this->options_->get<float>("ulr-softmax-temperature", 1.0f);  // default: no temperature
       // the temperature in the softmax controls the randomness of predictions:
       // at high temperatures the softmax outputs move closer to each other,
       // at low temperatures the softmax becomes more similar to a "hardmax"
       auto weights = softmax(z / tau);  // softmax over the last dim (default dim=-1); tau is a scalar
       auto chosenEmbeddings = dot(weights, uniEmbed);  // attention-weighted average of universal embeddings
       auto chosenEmbeddings_mix = srcEmbeddings + alpha * chosenEmbeddings;  // elementwise broadcast
       auto batchEmbeddings = reshape(chosenEmbeddings_mix, {dimWords, dimBatch, dimEmb});
       auto graph = ulrEmbeddings_.front()->graph();
       auto batchMask = graph->constant({dimWords, dimBatch, 1},
                                        inits::fromVector(subBatch->mask()));
       if(!inference_)
         batchEmbeddings = dropout(batchEmbeddings,
                                   options_->get<float>("dropout-embeddings", 0.0f),
                                   {batchEmbeddings->shape()[-3], 1, 1});
       return std::make_tuple(batchEmbeddings, batchMask);
     }

     Expr apply(const Words& words, const Shape& shape) const override final {
       return applyIndices(toWordIndexVector(words), shape);
     }

     Expr applyIndices(const std::vector<WordIndex>& embIdx,
                       const Shape& shape) const override final {
       embIdx; shape;  // suppress unused-parameter warnings
       ABORT("not implemented");  // @TODO: implement me
     }
   };

   }  // namespace marian
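
The mixing performed in ``ULREmbedding::apply`` can be summarized per source word :math:`i` as

.. math::

   z = \frac{(Q_i A)\,K^\top}{\sqrt{d_{\mathrm{ulr}}}}, \qquad
   w = \operatorname{softmax}(z / \tau), \qquad
   e_i = I_i + \alpha_i \,(w E)

where :math:`Q` and :math:`K` are the fixed query and key tables, :math:`A` the (optionally trainable) transform, :math:`E` the trainable universal embedding, :math:`I` the trainable source embedding, :math:`\alpha_i` the per-word sharable flag, and :math:`\tau` the softmax temperature.

As a minimal sketch of this per-word mixing, the following standalone program uses plain ``std::vector`` arithmetic instead of Marian's ``Expr`` graph; the helper names (``softmaxT``, ``ulrMixWord``) and toy dimensions are illustrative and not part of the Marian API.

.. code-block:: cpp

   #include <algorithm>
   #include <cmath>
   #include <cstdio>
   #include <vector>

   // softmax with temperature tau over a vector of similarity scores
   static std::vector<float> softmaxT(const std::vector<float>& z, float tau) {
     std::vector<float> w(z.size());
     float maxZ = *std::max_element(z.begin(), z.end());  // subtract max for stability
     float sum = 0.f;
     for(size_t k = 0; k < z.size(); ++k) {
       w[k] = std::exp((z[k] - maxZ) / tau);
       sum += w[k];
     }
     for(float& v : w) v /= sum;
     return w;
   }

   // e = srcEmbed + alpha * sum_k w_k * uniEmbed[k], for a single source word;
   // qt is the already transformed and scaled query row (Q_i A) / sqrt(dimUlrEmb)
   static std::vector<float> ulrMixWord(const std::vector<float>& qt,
                                        const std::vector<std::vector<float>>& keyEmbed,
                                        const std::vector<std::vector<float>>& uniEmbed,
                                        const std::vector<float>& srcEmbed,
                                        float alpha, float tau) {
     std::vector<float> z(keyEmbed.size(), 0.f);
     for(size_t k = 0; k < keyEmbed.size(); ++k)  // query-key similarity
       for(size_t d = 0; d < qt.size(); ++d)
         z[k] += qt[d] * keyEmbed[k][d];
     auto w = softmaxT(z, tau);                   // attention weights
     std::vector<float> e(srcEmbed);              // start from the trainable src embedding
     for(size_t k = 0; k < uniEmbed.size(); ++k)  // add gated weighted average
       for(size_t d = 0; d < e.size(); ++d)
         e[d] += alpha * w[k] * uniEmbed[k][d];
     return e;
   }

   int main() {
     // toy sizes: 2 universal tokens, dimUlrEmb = 2, dimEmb = 3
     std::vector<std::vector<float>> K = {{1.f, 0.f}, {0.f, 1.f}};
     std::vector<std::vector<float>> E = {{0.1f, 0.2f, 0.3f}, {0.4f, 0.5f, 0.6f}};
     auto e = ulrMixWord({0.7f, 0.3f}, K, E, {0.f, 0.f, 0.f}, /*alpha=*/1.f, /*tau=*/1.f);
     for(float v : e) std::printf("%.3f ", v);  // prints the mixed embedding
     std::printf("\n");
     return 0;
   }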