Program Listing for File word2vec_reader.h

Return to documentation for file (src/layers/word2vec_reader.h)

#pragma once

#include "marian.h"

#include "common/logging.h"

#include <fstream>
#include <string>
#include <vector>

namespace marian {

class Word2VecReader {
public:
  Word2VecReader() {}

  std::vector<float> read(const std::string& fileName, int dimVoc, int dimEmb) {
    LOG(info, "[data] Loading embedding vectors from {}", fileName);

    io::InputFileStream embFile(fileName);

    std::string line;
    std::vector<std::string> values;
    values.reserve(dimEmb);

    // The first line contains two values: the number of words in the
    // vocabulary and the length of embedding vectors
    io::getline(embFile, line);
    utils::split(line, values);
    ABORT_IF(values.size() != 2,
             "Unexpected format of the first line of the embedding file");
    ABORT_IF(stoi(values[1]) != dimEmb,
             "Unexpected length of embedding vectors");

    // Read embedding vectors into a map
    std::unordered_map<WordIndex, std::vector<float>> word2vec;
    while(io::getline(embFile, line)) {
      values.clear();
      utils::split(line, values);

      WordIndex word = std::stoi(values.front());
      if(word >= (size_t)dimVoc)
        continue;

      word2vec[word].reserve(dimEmb);
      std::transform(values.begin() + 1,
                     values.end(),
                     std::back_inserter(word2vec[word]),
                     [](const std::string& s) { return std::stof(s); });
    }

    // Initialize final flat vector for embeddings
    std::vector<float> embs;
    embs.reserve(dimVoc * dimEmb);

    // Populate output vector with embedding
    for(WordIndex word = 0; word < (WordIndex)dimVoc; ++word) {
      // For words not occuring in the file use uniform distribution
      if(word2vec.find(word) == word2vec.end()) {
        auto randVals = randomEmbeddings(dimVoc, dimEmb);
        embs.insert(embs.end(), randVals.begin(), randVals.end());
      } else {
        embs.insert(embs.end(), word2vec[word].begin(), word2vec[word].end());
      }
    }

    embs.resize(dimVoc * dimEmb, 0); // @TODO: is it correct to zero out the remaining embeddings?
    return embs;
  }

private:
  std::vector<float> randomEmbeddings(int dimVoc, int dimEmb) {
    std::vector<float> values;
    values.resize(dimEmb);
    // Glorot numal distribution
    float scale = sqrtf(2.0f / (dimVoc + dimEmb));

    // @TODO: switch to new random generator back-end.
    // This is rarely used however.
    std::random_device rd;
    std::mt19937 engine(rd());

    std::normal_distribution<float> d(0, scale);
    auto gen = [&d, &engine] () {
       return d(engine);
    };

    std::generate(values.begin(), values.end(), gen);

    return values;
  }
};

}  // namespace marian