Program Listing for File translator.h

#pragma once

#include <string>

#include "data/batch_generator.h"
#include "data/corpus.h"
#include "data/shortlist.h"
#include "data/text_input.h"

#include "common/scheduling_parameter.h"
#include "common/timer.h"

#include "3rd_party/threadpool.h"

#include "translator/history.h"
#include "translator/output_collector.h"
#include "translator/output_printer.h"

#include "models/model_task.h"
#include "translator/scorers.h"

// currently for diagnostics only; when enabled, model files ending in the *.bin suffix will be mmapped.
#include "3rd_party/mio/mio.hpp"

namespace marian {

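// Batch translation task: loads the model(s) once, replicates them onto every
// configured device, and translates a corpus in parallel with one worker thread
// per device. The command-line translator typically instantiates it with
// BeamSearch as the Search parameter:
//   New<Translate<BeamSearch>>(options)->run();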
template <class Search>
class Translate : public ModelTask {
private:
  Ptr<Options> options_;
  std::vector<Ptr<ExpressionGraph>> graphs_;
  std::vector<std::vector<Ptr<Scorer>>> scorers_;

  Ptr<data::Corpus> corpus_;
  Ptr<Vocab> trgVocab_;
  Ptr<const data::ShortlistGenerator> shortlistGenerator_;

  size_t numDevices_;

  std::vector<mio::mmap_source> model_mmaps_; // mmap
  std::vector<std::vector<io::Item>> model_items_; // non-mmap

public:
  Translate(Ptr<Options> options)
    : options_(New<Options>(options->clone())) { // @TODO: clone should return Ptr<Options>, same as "with"?
    // This is currently safe, as the translator is either created stand-alone
    // or the config is created anew from Options in the validator.

    options_->set("inference", true,
                  "shuffle", "none");

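    // Input corpus with the text to translate; the 'true' flag creates it in
    // translation mode (source side only, read from the file(s) given via --input).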
    corpus_ = New<data::Corpus>(options_, true);

    auto vocabs = options_->get<std::vector<std::string>>("vocabs");
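    // By convention, the last entry of --vocabs is the target vocabulary; the
    // source vocabularies are loaded by the corpus itself.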
    trgVocab_ = New<Vocab>(options_, vocabs.size() - 1);
    trgVocab_->load(vocabs.back());
    auto srcVocab = corpus_->getVocabs()[0];

    std::vector<int> lshOpts = options_->get<std::vector<int>>("output-approx-knn", {});
    ABORT_IF(lshOpts.size() != 0 && lshOpts.size() != 2, "--output-approx-knn takes 2 parameters");

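    // A shortlist restricts decoding to a small set of candidate target words per
    // batch; the two --output-approx-knn parameters select an LSH-based approximate
    // nearest-neighbor variant instead of a precomputed lexical shortlist.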
    if (lshOpts.size() == 2 || options_->hasAndNotEmpty("shortlist")) {
      shortlistGenerator_ = data::createShortlistGenerator(options_, srcVocab, trgVocab_, lshOpts, 0, 1, vocabs.front() == vocabs.back());
    }

    auto devices = Config::getDevices(options_);
    numDevices_ = devices.size();

    ThreadPool threadPool(numDevices_, numDevices_);
    scorers_.resize(numDevices_);
    graphs_.resize(numDevices_);

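    // Load the model weights once up front, either by memory-mapping the file
    // (binarized *.bin models only) or by reading the items into regular memory;
    // the per-device scorers below are all created from this single copy.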
    auto models = options->get<std::vector<std::string>>("models");
    if(options_->get<bool>("model-mmap", false)) {
      for(auto model : models) {
        ABORT_IF(!io::isBin(model), "Non-binarized models cannot be mmapped");
        LOG(info, "Loading model from {}", model);
        model_mmaps_.push_back(mio::mmap_source(model));
      }
    }
    else {
      for(auto model : models) {
        LOG(info, "Loading model from {}", model);
        auto items = io::loadItems(model);
        model_items_.push_back(std::move(items));
      }
    }

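    // Build one graph and one set of scorers per device. Initialization runs in
    // parallel; the thread pool's destructor joins all workers before the
    // constructor returns.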
    size_t id = 0;
    for(auto device : devices) {
      auto task = [&](DeviceId device, size_t id) {
        auto graph = New<ExpressionGraph>(true);
        auto prec = options_->get<std::vector<std::string>>("precision", {"float32"});
        graph->setDefaultElementType(typeFromString(prec[0]));
        graph->setDevice(device);
        if (device.type == DeviceType::cpu) {
          graph->getBackend()->setOptimized(options_->get<bool>("optimize"));
          graph->getBackend()->setGemmType(options_->get<std::string>("gemm-type"));
          graph->getBackend()->setQuantizeRange(options_->get<float>("quantize-range"));
        }
        graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
        graphs_[id] = graph;

        std::vector<Ptr<Scorer>> scorers;
        if(options_->get<bool>("model-mmap", false)) {
          scorers = createScorers(options_, model_mmaps_);
        }
        else {
          scorers = createScorers(options_, model_items_);
        }

        for(auto scorer : scorers) {
          scorer->init(graph);
          if(shortlistGenerator_)
            scorer->setShortlistGenerator(shortlistGenerator_);
        }

        scorers_[id] = scorers;
        graph->forward();
      };

      threadPool.enqueue(task, device, id++);
    }

    if(options_->hasAndNotEmpty("output-sampling")) {
      if(options_->get<size_t>("beam-size") > 1)
        LOG(warn,
            "[warning] Output sampling and beam search (beam-size > 1) are contradictory methods "
            "and using them together is not recommended. Set beam-size to 1");
      if(options_->get<std::vector<std::string>>("models").size() > 1)
        LOG(warn,
            "[warning] Output sampling and model ensembling are contradictory methods and using "
            "them together is not recommended. Use a single model");
    }
  }

  void run() override {
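    // Stream batches from the corpus and translate them in parallel; the
    // OutputCollector reorders finished translations by line number so the
    // output order matches the input order.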
    data::BatchGenerator<data::Corpus> bg(corpus_, options_);

    ThreadPool threadPool(numDevices_, numDevices_);

    size_t batchId = 0;
    auto collector = New<OutputCollector>(options_->get<std::string>("output"));
    auto printer = New<OutputPrinter>(options_, trgVocab_);
    if(options_->get<bool>("quiet-translation"))
      collector->setPrintingStrategy(New<QuietPrinting>());

    // mutex for syncing counter and timer updates
    std::mutex syncCounts;

    // timer and counters for total elapsed time and statistics
    std::unique_ptr<timer::Timer> totTimer(new timer::Timer());
    size_t totBatches      = 0;
    size_t totLines        = 0;
    size_t totSourceTokens = 0;

    // timer and counters for elapsed time and statistics between updates
    std::unique_ptr<timer::Timer> curTimer(new timer::Timer());
    size_t curBatches      = 0;
    size_t curLines        = 0;
    size_t curSourceTokens = 0;

    // determine if we want to display timer statistics, by default off
    auto statFreq = SchedulingParameter::parse(options_->get<std::string>("stat-freq", "0u"));
    // abort early to avoid potentially costly batching and translation before error message
    ABORT_IF(statFreq.unit != SchedulingUnit::updates, "Units other than 'u' are not supported for --stat-freq value {}", statFreq);

    // Override the display frequency to provide a progress heartbeat for the MS-internal
    // Philly compute cluster; otherwise the job may be killed prematurely if it logs nothing for 4 hrs
    if(getenv("PHILLY_JOB_ID")) { // this environment variable exists when running on the cluster
      if(statFreq.n == 0) {
        statFreq.n = 10000;
        statFreq.unit = SchedulingUnit::updates;
      }
    }

    bool doNbest = options_->get<bool>("n-best");

    bg.prepare();
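    // Translate each batch as a pooled task. A worker binds to one device's graph
    // and scorers on first use (thread_local below) and keeps that device for all
    // subsequent batches it processes.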
    for(auto batch : bg) {
      auto task = [=, &syncCounts,
                      &totBatches, &totLines, &totSourceTokens, &totTimer,
                      &curBatches, &curLines, &curSourceTokens, &curTimer](size_t id) {
        thread_local Ptr<ExpressionGraph> graph;
        thread_local std::vector<Ptr<Scorer>> scorers;

        if(!graph) {
          graph = graphs_[id % numDevices_];
          scorers = scorers_[id % numDevices_];
        }

        auto search = New<Search>(options_, scorers, trgVocab_);
        auto histories = search->search(graph, batch);

        for(auto history : histories) {
          std::stringstream best1;
          std::stringstream bestn;
          printer->print(history, best1, bestn);
          collector->Write((long)history->getLineNum(),
                           best1.str(),
                           bestn.str(),
                           doNbest);
        }

        // if speed information was requested, display it
        if(statFreq.n > 0) {
          std::lock_guard<std::mutex> lock(syncCounts);
          totBatches++;
          totLines        += batch->size();
          totSourceTokens += batch->front()->batchWords();

          curBatches++;
          curLines        += batch->size();
          curSourceTokens += batch->front()->batchWords();

          if(totBatches % statFreq.n == 0) {
            double totTime = totTimer->elapsed();
            double curTime = curTimer->elapsed();

            LOG(info,
                "Processed {} batches, {} lines, {} source tokens in {:.2f}s - Speed (since last): {:.2f} batches/s - {:.2f} lines/s - {:.2f} tokens/s",
                totBatches, totLines, totSourceTokens, totTime, curBatches / curTime, curLines / curTime, curSourceTokens / curTime);

            // reset stats between updates
            curBatches = curLines = curSourceTokens = 0;
            curTimer.reset(new timer::Timer());
          }
        }
      };

      threadPool.enqueue(task, batchId++);
    }

    // make sure threads are joined before other local variables get de-allocated
    threadPool.join_all();

    // display final speed numbers over total translation if intermediate displays were requested
    if(statFreq.n > 0) {
      double totTime = totTimer->elapsed();
      LOG(info,
          "Processed {} batches, {} lines, {} source tokens in {:.2f}s - Speed (total): {:.2f} batches/s - {:.2f} lines/s - {:.2f} tokens/s",
          totBatches, totLines, totSourceTokens, totTime, totBatches / totTime, totLines / totTime, totSourceTokens / totTime);
    }
  }
};

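// Translation as a service: same per-device setup as Translate, but run() takes
// the input as a string and returns the translations as a string, so a single
// instance can serve repeated requests (e.g. from marian-server).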
template <class Search>
class TranslateService : public ModelServiceTask {
private:
  Ptr<Options> options_;
  std::vector<Ptr<ExpressionGraph>> graphs_;
  std::vector<std::vector<Ptr<Scorer>>> scorers_;

  std::vector<Ptr<Vocab>> srcVocabs_;
  Ptr<Vocab> trgVocab_;
  Ptr<const data::ShortlistGenerator> shortlistGenerator_;

  size_t numDevices_;

public:
  virtual ~TranslateService() {}

  TranslateService(Ptr<Options> options)
    : options_(New<Options>(options->clone())) {
    // initialize vocabs
    options_->set("inference", true);
    options_->set("shuffle", "none");

    auto vocabPaths = options_->get<std::vector<std::string>>("vocabs");
    std::vector<int> maxVocabs = options_->get<std::vector<int>>("dim-vocabs");

    for(size_t i = 0; i < vocabPaths.size() - 1; ++i) {
      Ptr<Vocab> vocab = New<Vocab>(options_, i);
      vocab->load(vocabPaths[i], maxVocabs[i]);
      srcVocabs_.emplace_back(vocab);
    }

    trgVocab_ = New<Vocab>(options_, vocabPaths.size() - 1);
    trgVocab_->load(vocabPaths.back());
    auto srcVocab = srcVocabs_.front();

    std::vector<int> lshOpts = options_->get<std::vector<int>>("output-approx-knn", {});
    ABORT_IF(lshOpts.size() != 0 && lshOpts.size() != 2, "--output-approx-knn takes 2 parameters");

    // load lexical shortlist
    if (lshOpts.size() == 2 || options_->hasAndNotEmpty("shortlist")) {
      shortlistGenerator_ = data::createShortlistGenerator(options_, srcVocab, trgVocab_, lshOpts, 0, 1, vocabPaths.front() == vocabPaths.back());
    }

    // get device IDs
    auto devices = Config::getDevices(options_);
    numDevices_ = devices.size();

    // preload models
    std::vector<std::vector<io::Item>> model_items_;
    auto models = options->get<std::vector<std::string>>("models");
    for(auto model : models) {
      auto items = io::loadItems(model);
      model_items_.push_back(std::move(items));
    }

    // initialize scorers
    for(auto device : devices) {
      auto graph = New<ExpressionGraph>(true);

      auto precision = options_->get<std::vector<std::string>>("precision", {"float32"});
      graph->setDefaultElementType(typeFromString(precision[0])); // only the first type is used; it determines the parameter type in the graph
      graph->setDevice(device);
      if (device.type == DeviceType::cpu) {
        graph->getBackend()->setOptimized(options_->get<bool>("optimize"));
        graph->getBackend()->setGemmType(options_->get<std::string>("gemm-type"));
        graph->getBackend()->setQuantizeRange(options_->get<float>("quantize-range"));
      }
      graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
      graphs_.push_back(graph);

      auto scorers = createScorers(options_, model_items_);
      for(auto scorer : scorers) {
        scorer->init(graph);
        if(shortlistGenerator_)
          scorer->setShortlistGenerator(shortlistGenerator_);
      }
      scorers_.push_back(scorers);
    }
  }

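  // Translates one request. The input may contain multiple newline-separated
  // sentences and, if --tsv is set, multiple tab-separated fields per line.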
  std::string run(const std::string& input) override {
    // split tab-separated input into fields if necessary
    auto inputs = options_->get<bool>("tsv", false)
                      ? convertTsvToLists(input, options_->get<size_t>("tsv-fields", 1))
                      : std::vector<std::string>({input});
    auto corpus_ = New<data::TextInput>(inputs, srcVocabs_, options_);
    data::BatchGenerator<data::TextInput> batchGenerator(corpus_, options_, nullptr, /*runAsync=*/false);

    auto collector = New<StringCollector>(options_->get<bool>("quiet-translation", false));
    auto printer = New<OutputPrinter>(options_, trgVocab_);
    size_t batchId = 0;

    batchGenerator.prepare();

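    // Extra scope: the thread pool's destructor joins all workers, so every
    // batch is finished before the results are collected below.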
    {
      ThreadPool threadPool_(numDevices_, numDevices_);

      for(auto batch : batchGenerator) {
        auto task = [=](size_t id) {
          thread_local Ptr<ExpressionGraph> graph;
          thread_local std::vector<Ptr<Scorer>> scorers;

          if(!graph) {
            graph = graphs_[id % numDevices_];
            scorers = scorers_[id % numDevices_];
          }

          auto search = New<Search>(options_, scorers, trgVocab_);
          auto histories = search->search(graph, batch);

          for(auto history : histories) {
            std::stringstream best1;
            std::stringstream bestn;
            printer->print(history, best1, bestn);
            collector->add((long)history->getLineNum(), best1.str(), bestn.str());
          }
        };

        threadPool_.enqueue(task, batchId);
        batchId++;
      }
    }

    auto translations = collector->collect(options_->get<bool>("n-best"));
    return utils::join(translations, "\n");
  }

private:
  // Converts a multi-line input with tab-separated source(s) and target sentences into separate lists
  // of sentences from source(s) and target sides, e.g.
  // "src1 \t trg1 \n src2 \t trg2" -> ["src1 \n src2", "trg1 \n trg2"]
  std::vector<std::string> convertTsvToLists(const std::string& inputText, size_t numFields) {
    std::vector<std::string> outputFields(numFields);

    std::string line;
    std::vector<std::string> lineFields(numFields);
    std::istringstream inputStream(inputText);
    bool first = true;
    while(std::getline(inputStream, line)) {
      utils::splitTsv(line, lineFields, numFields);
      for(size_t i = 0; i < numFields; ++i) {
        if(!first)
          outputFields[i] += "\n";  // join sentences with a new line sign
        outputFields[i] += lineFields[i];
      }
      if(first)
        first = false;
    }

    return outputFields;
  }
};
}  // namespace marian