Program Listing for File output_collector.cpp

Return to documentation for file (src/translator/output_collector.cpp)

#include "output_collector.h"
#include "common/file_stream.h"
#include "common/logging.h"

#include <cassert>

namespace marian {

OutputCollector::OutputCollector()
  : nextId_(0),
    printing_(new DefaultPrinting()) {}

OutputCollector::OutputCollector(std::string outFile)
  : nextId_(0),
    outStrm_(new std::ostream(std::cout.rdbuf())),
    printing_(new DefaultPrinting()) {
  if (outFile != "stdout")
    outStrm_.reset(new io::OutputFileStream(outFile));
}

void OutputCollector::Write(long sourceId,
                            const std::string& best1,
                            const std::string& bestn,
                            bool nbest) {
  std::lock_guard<std::mutex> lock(mutex_);
  if(sourceId == nextId_) {
    if(printing_->shouldBePrinted(sourceId))
      LOG(info, "Best translation {} : {}", sourceId, best1);

    if(outStrm_) {
      if(nbest)
        *outStrm_ << bestn << std::endl;
      else
        *outStrm_ << best1 << std::endl;
    }

    ++nextId_;

    Outputs::const_iterator iter, iterNext;
    iter = outputs_.begin();
    while(iter != outputs_.end()) {
      long currId = iter->first;

      if(currId == nextId_) {
        // 1st element in the map is the next
        const auto& currOutput = iter->second;
        if(printing_->shouldBePrinted(currId))
          LOG(info, "Best translation {} : {}", currId, currOutput.first);

        if(outStrm_) {
          if(nbest)
            *outStrm_ << currOutput.second << std::endl;
          else
            *outStrm_ << currOutput.first << std::endl;
        }

        ++nextId_;

        // delete current record, move iter on 1
        iterNext = iter;
        ++iterNext;
        outputs_.erase(iter);
        iter = iterNext;
      } else {
        // not the next. stop iterating
        assert(nextId_ < currId);
        break;
      }
    }

    // for 1-best, flush stdout so that we can consume this immediately from an
    // external process
    if(outStrm_ && !nbest)
      *outStrm_ << std::flush;

  } else {
    // save for later
    outputs_[sourceId] = std::make_pair(best1, bestn);
  }
}

StringCollector::StringCollector(bool quiet /*=false*/) : maxId_(-1), quiet_(quiet) {}

// Stores the (1-best, n-best) pair for `sourceId`, overwriting any pair
// already stored under that id, and keeps `maxId_` at the largest id seen.
// Unless the collector is quiet, the 1-best string is also logged.
// Runs with `mutex_` held.
void StringCollector::add(long sourceId,
                          const std::string& best1,
                          const std::string& bestn) {
  std::lock_guard<std::mutex> guard(mutex_);

  const bool verbose = !quiet_;
  if(verbose) {
    LOG(info, "Best translation {} : {}", sourceId, best1);
  }

  outputs_[sourceId] = std::make_pair(best1, bestn);

  // Remember the highest id so collect() knows how far to iterate.
  if(sourceId > maxId_) {
    maxId_ = sourceId;
  }
}

std::vector<std::string> StringCollector::collect(bool nbest) {
  std::vector<std::string> outputs;
  for(int id = 0; id <= maxId_; ++id)
    outputs.emplace_back(nbest ? outputs_[id].second : outputs_[id].first);
  return outputs;
}
}  // namespace marian