.. _program_listing_file_src_translator_translator.h:

Program Listing for File translator.h
=====================================

|exhale_lsh| :ref:`Return to documentation for file <file_src_translator_translator.h>` (``src/translator/translator.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   #pragma once

   #include <future>

   #include "data/batch_generator.h"
   #include "data/corpus.h"
   #include "data/shortlist.h"
   #include "data/text_input.h"

   #include "common/scheduling_parameter.h"
   #include "common/timer.h"

   #include "3rd_party/threadpool.h"

   #include "translator/history.h"
   #include "translator/output_collector.h"
   #include "translator/output_printer.h"

   #include "models/model_task.h"
   #include "translator/scorers.h"

   // currently for diagnostics only, will try to mmap files ending in *.bin suffix when enabled.
   #include "3rd_party/mio/mio.hpp"

   namespace marian {

   template <class Search>
   class Translate : public ModelTask {
   private:
     Ptr<Options> options_;
     std::vector<Ptr<ExpressionGraph>> graphs_;
     std::vector<std::vector<Ptr<Scorer>>> scorers_;

     Ptr<data::Corpus> corpus_;
     Ptr<Vocab> trgVocab_;
     Ptr<const data::ShortlistGenerator> shortlistGenerator_;

     size_t numDevices_;

     std::vector<mio::mmap_source> model_mmaps_;      // map
     std::vector<std::vector<io::Item>> model_items_; // non-mmap

   public:
     Translate(Ptr<Options> options)
       : options_(New<Options>(options->clone())) { // @TODO: clone should return Ptr<Options> same as "with"?
       // This is currently safe as the translator is either created stand-alone
       // or config is created anew from Options in the validator

       options_->set("inference", true,
                     "shuffle", "none");

       corpus_ = New<data::Corpus>(options_, true);

       auto vocabs = options_->get<std::vector<std::string>>("vocabs");
       trgVocab_ = New<Vocab>(options_, vocabs.size() - 1);
       trgVocab_->load(vocabs.back());
       auto srcVocab = corpus_->getVocabs()[0];

       std::vector<int> lshOpts = options_->get<std::vector<int>>("output-approx-knn", {});
       ABORT_IF(lshOpts.size() != 0 && lshOpts.size() != 2, "--output-approx-knn takes 2 parameters");

       if (lshOpts.size() == 2 || options_->hasAndNotEmpty("shortlist")) {
         shortlistGenerator_ = data::createShortlistGenerator(options_, srcVocab, trgVocab_, lshOpts, 0, 1, vocabs.front() == vocabs.back());
       }

       auto devices = Config::getDevices(options_);
       numDevices_ = devices.size();

       ThreadPool threadPool(numDevices_, numDevices_);
       scorers_.resize(numDevices_);
       graphs_.resize(numDevices_);

       auto models = options->get<std::vector<std::string>>("models");
       if(options_->get<bool>("model-mmap", false)) {
         for(auto model : models) {
           ABORT_IF(!io::isBin(model), "Non-binarized models cannot be mmapped");
           LOG(info, "Loading model from {}", model);
           model_mmaps_.push_back(mio::mmap_source(model));
         }
       }
       else {
         for(auto model : models) {
           LOG(info, "Loading model from {}", model);
           auto items = io::loadItems(model);
           model_items_.push_back(std::move(items));
         }
       }

       size_t id = 0;
       for(auto device : devices) {
         auto task = [&](DeviceId device, size_t id) {
           auto graph = New<ExpressionGraph>(true);
           auto prec = options_->get<std::vector<std::string>>("precision", {"float32"});
           graph->setDefaultElementType(typeFromString(prec[0]));
           graph->setDevice(device);
           if (device.type == DeviceType::cpu) {
             graph->getBackend()->setOptimized(options_->get<bool>("optimize"));
             graph->getBackend()->setGemmType(options_->get<std::string>("gemm-type"));
             graph->getBackend()->setQuantizeRange(options_->get<float>("quantize-range"));
           }
           graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
           graphs_[id] = graph;

           std::vector<Ptr<Scorer>> scorers;
           if(options_->get<bool>("model-mmap", false)) {
             scorers = createScorers(options_, model_mmaps_);
           }
           else {
             scorers = createScorers(options_, model_items_);
           }

           for(auto scorer : scorers) {
             scorer->init(graph);
             if(shortlistGenerator_)
               scorer->setShortlistGenerator(shortlistGenerator_);
           }

           scorers_[id] = scorers;
           graph->forward();
         };

         threadPool.enqueue(task, device, id++);
       }
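       // NOTE (editorial, not in the original source): the ThreadPool above joins its
       // worker threads in its destructor at the end of this constructor, so every
       // per-device graph and scorer set is fully initialized before the object is used.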
       if(options_->hasAndNotEmpty("output-sampling")) {
         if(options_->get<size_t>("beam-size") > 1)
           LOG(warn,
               "[warning] Output sampling and beam search (beam-size > 1) are contradictory methods "
               "and using them together is not recommended. Set beam-size to 1");
         if(options_->get<std::vector<std::string>>("models").size() > 1)
           LOG(warn,
               "[warning] Output sampling and model ensembling are contradictory methods and using "
               "them together is not recommended. Use a single model");
       }
     }

     void run() override {
       data::BatchGenerator<data::Corpus> bg(corpus_, options_);

       ThreadPool threadPool(numDevices_, numDevices_);

       size_t batchId = 0;
       auto collector = New<OutputCollector>(options_->get<std::string>("output"));
       auto printer = New<OutputPrinter>(options_, trgVocab_);
       if(options_->get<bool>("quiet-translation"))
         collector->setPrintingStrategy(New<QuietPrinting>());

       // mutex for syncing counter and timer updates
       std::mutex syncCounts;

       // timer and counters for total elapsed time and statistics
       std::unique_ptr<timer::Timer> totTimer(new timer::Timer());
       size_t totBatches = 0;
       size_t totLines = 0;
       size_t totSourceTokens = 0;

       // timer and counters for elapsed time and statistics between updates
       std::unique_ptr<timer::Timer> curTimer(new timer::Timer());
       size_t curBatches = 0;
       size_t curLines = 0;
       size_t curSourceTokens = 0;

       // determine if we want to display timer statistics, by default off
       auto statFreq = SchedulingParameter::parse(options_->get<std::string>("stat-freq", "0u"));
       // abort early to avoid potentially costly batching and translation before error message
       ABORT_IF(statFreq.unit != SchedulingUnit::updates, "Units other than 'u' are not supported for --stat-freq value {}", statFreq);

       // Override display for progress heartbeat for MS-internal Philly compute cluster
       // otherwise this job may be killed prematurely if no log for 4 hrs
       if(getenv("PHILLY_JOB_ID")) { // this environment variable exists when running on the cluster
         if(statFreq.n == 0) {
           statFreq.n = 10000;
           statFreq.unit = SchedulingUnit::updates;
         }
       }

       bool doNbest = options_->get<bool>("n-best");

       bg.prepare();

       for(auto batch : bg) {
         auto task = [=, &syncCounts, &totBatches, &totLines, &totSourceTokens, &totTimer, &curBatches, &curLines, &curSourceTokens, &curTimer](size_t id) {
           thread_local Ptr<ExpressionGraph> graph;
           thread_local std::vector<Ptr<Scorer>> scorers;

           if(!graph) {
             graph = graphs_[id % numDevices_];
             scorers = scorers_[id % numDevices_];
           }

           auto search = New<Search>(options_, scorers, trgVocab_);
           auto histories = search->search(graph, batch);

           for(auto history : histories) {
             std::stringstream best1;
             std::stringstream bestn;
             printer->print(history, best1, bestn);
             collector->Write((long)history->getLineNum(),
                              best1.str(),
                              bestn.str(),
                              doNbest);
           }

           // if we asked for speed information display this
           if(statFreq.n > 0) {
             std::lock_guard<std::mutex> lock(syncCounts);
             totBatches++;
             totLines += batch->size();
             totSourceTokens += batch->front()->batchWords();

             curBatches++;
             curLines += batch->size();
             curSourceTokens += batch->front()->batchWords();

             if(totBatches % statFreq.n == 0) {
               double totTime = totTimer->elapsed();
               double curTime = curTimer->elapsed();
               LOG(info, "Processed {} batches, {} lines, {} source tokens in {:.2f}s - Speed (since last): {:.2f} batches/s - {:.2f} lines/s - {:.2f} tokens/s",
                   totBatches, totLines, totSourceTokens, totTime, curBatches / curTime, curLines / curTime, curSourceTokens / curTime);

               // reset stats between updates
               curBatches = curLines = curSourceTokens = 0;
               curTimer.reset(new timer::Timer());
             }
           }
         };

         threadPool.enqueue(task, batchId++);
       }

       // make sure threads are joined before other local variables get de-allocated
       threadPool.join_all();
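       // NOTE (editorial, not in the original source): join_all() blocks until every
       // enqueued batch task has finished, so the totals below can be read safely
       // without taking syncCounts again.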
       // display final speed numbers over total translation if intermediate displays were requested
       if(statFreq.n > 0) {
         double totTime = totTimer->elapsed();
         LOG(info, "Processed {} batches, {} lines, {} source tokens in {:.2f}s - Speed (total): {:.2f} batches/s - {:.2f} lines/s - {:.2f} tokens/s",
             totBatches, totLines, totSourceTokens, totTime, totBatches / totTime, totLines / totTime, totSourceTokens / totTime);
       }
     }
   };

   template <class Search>
   class TranslateService : public ModelServiceTask {
   private:
     Ptr<Options> options_;
     std::vector<Ptr<ExpressionGraph>> graphs_;
     std::vector<std::vector<Ptr<Scorer>>> scorers_;

     std::vector<Ptr<Vocab>> srcVocabs_;
     Ptr<Vocab> trgVocab_;
     Ptr<const data::ShortlistGenerator> shortlistGenerator_;

     size_t numDevices_;

   public:
     virtual ~TranslateService() {}

     TranslateService(Ptr<Options> options)
       : options_(New<Options>(options->clone())) {
       // initialize vocabs
       options_->set("inference", true);
       options_->set("shuffle", "none");

       auto vocabPaths = options_->get<std::vector<std::string>>("vocabs");
       std::vector<int> maxVocabs = options_->get<std::vector<int>>("dim-vocabs");

       for(size_t i = 0; i < vocabPaths.size() - 1; ++i) {
         Ptr<Vocab> vocab = New<Vocab>(options_, i);
         vocab->load(vocabPaths[i], maxVocabs[i]);
         srcVocabs_.emplace_back(vocab);
       }

       trgVocab_ = New<Vocab>(options_, vocabPaths.size() - 1);
       trgVocab_->load(vocabPaths.back());

       auto srcVocab = srcVocabs_.front();

       std::vector<int> lshOpts = options_->get<std::vector<int>>("output-approx-knn");
       ABORT_IF(lshOpts.size() != 0 && lshOpts.size() != 2, "--output-approx-knn takes 2 parameters");

       // load lexical shortlist
       if (lshOpts.size() == 2 || options_->hasAndNotEmpty("shortlist")) {
         shortlistGenerator_ = data::createShortlistGenerator(options_, srcVocab, trgVocab_, lshOpts, 0, 1, vocabPaths.front() == vocabPaths.back());
       }

       // get device IDs
       auto devices = Config::getDevices(options_);
       numDevices_ = devices.size();

       // preload models
       std::vector<std::vector<io::Item>> model_items_;
       auto models = options->get<std::vector<std::string>>("models");
       for(auto model : models) {
         auto items = io::loadItems(model);
         model_items_.push_back(std::move(items));
       }

       // initialize scorers
       for(auto device : devices) {
         auto graph = New<ExpressionGraph>(true);

         auto precision = options_->get<std::vector<std::string>>("precision", {"float32"});
         graph->setDefaultElementType(typeFromString(precision[0])); // only use first type, used for parameter type in graph
         graph->setDevice(device);
         if (device.type == DeviceType::cpu) {
           graph->getBackend()->setOptimized(options_->get<bool>("optimize"));
           graph->getBackend()->setGemmType(options_->get<std::string>("gemm-type"));
           graph->getBackend()->setQuantizeRange(options_->get<float>("quantize-range"));
         }
         graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
         graphs_.push_back(graph);

         auto scorers = createScorers(options_, model_items_);
         for(auto scorer : scorers) {
           scorer->init(graph);
           if(shortlistGenerator_)
             scorer->setShortlistGenerator(shortlistGenerator_);
         }
         scorers_.push_back(scorers);
       }
     }
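     // NOTE (editorial, not in the original source): each run() call builds a fresh
     // TextInput corpus and batch generator for the request, while the per-device
     // graphs and scorers created in the constructor are reused across calls.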
     std::string run(const std::string& input) override {
       // split tab-separated input into fields if necessary
       auto inputs = options_->get<bool>("tsv", false)
                         ? convertTsvToLists(input, options_->get<size_t>("tsv-fields", 1))
                         : std::vector<std::string>({input});
       auto corpus_ = New<data::TextInput>(inputs, srcVocabs_, options_);
       data::BatchGenerator<data::TextInput> batchGenerator(corpus_, options_, nullptr, /*runAsync=*/false);

       auto collector = New<StringCollector>(options_->get<bool>("quiet-translation", false));
       auto printer = New<OutputPrinter>(options_, trgVocab_);
       size_t batchId = 0;

       batchGenerator.prepare();

       {
         ThreadPool threadPool_(numDevices_, numDevices_);

         for(auto batch : batchGenerator) {
           auto task = [=](size_t id) {
             thread_local Ptr<ExpressionGraph> graph;
             thread_local std::vector<Ptr<Scorer>> scorers;

             if(!graph) {
               graph = graphs_[id % numDevices_];
               scorers = scorers_[id % numDevices_];
             }

             auto search = New<Search>(options_, scorers, trgVocab_);
             auto histories = search->search(graph, batch);

             for(auto history : histories) {
               std::stringstream best1;
               std::stringstream bestn;
               printer->print(history, best1, bestn);
               collector->add((long)history->getLineNum(), best1.str(), bestn.str());
             }
           };

           threadPool_.enqueue(task, batchId);
           batchId++;
         }
       }

       auto translations = collector->collect(options_->get<bool>("n-best"));
       return utils::join(translations, "\n");
     }

   private:
     // Converts a multi-line input with tab-separated source(s) and target sentences into separate lists
     // of sentences from source(s) and target sides, e.g.
     // "src1 \t trg1 \n src2 \t trg2" -> ["src1 \n src2", "trg1 \n trg2"]
     std::vector<std::string> convertTsvToLists(const std::string& inputText, size_t numFields) {
       std::vector<std::string> outputFields(numFields);

       std::string line;
       std::vector<std::string> lineFields(numFields);
       std::istringstream inputStream(inputText);
       bool first = true;
       while(std::getline(inputStream, line)) {
         utils::splitTsv(line, lineFields, numFields);
         for(size_t i = 0; i < numFields; ++i) {
           if(!first)
             outputFields[i] += "\n"; // join sentences with a new line sign
           outputFields[i] += lineFields[i];
         }
         if(first)
           first = false;
       }

       return outputFields;
     }
   };
   }  // namespace marian
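The classes in this listing are normally driven through the ``ModelTask`` and
``ModelServiceTask`` interfaces rather than used directly. Below is a minimal usage
sketch, not part of the listed file: it assumes the ``marian::parseOptions`` entry
point and the ``BeamSearch`` class from ``translator/beam_search.h``, following the
pattern of Marian's decoder frontend.

.. code-block:: cpp

   #include "marian.h"

   #include "translator/beam_search.h"
   #include "translator/translator.h"

   int main(int argc, char** argv) {
     using namespace marian;

     // Parse command-line options in translation mode.
     auto options = parseOptions(argc, argv, cli::mode::translation);

     // Batch-translate the configured input with beam search: Translate<BeamSearch>
     // reads the corpus, translates on all configured devices, and writes the output.
     auto task = New<Translate<BeamSearch>>(options);
     task->run();

     return 0;
   }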