.. _program_listing_file_src_data_batch_generator.h:

Program Listing for File batch_generator.h
==========================================

|exhale_lsh| :ref:`Return to documentation for file <file_src_data_batch_generator.h>` (``src/data/batch_generator.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   #pragma once

   #include "common/options.h"
   #include "common/signal_handling.h"
   #include "common/timer.h"
   #include "data/batch_stats.h"
   #include "data/rng_engine.h"
   #include "training/training_state.h"

   #include "data/iterator_facade.h"
   #include "3rd_party/threadpool.h"

   #include <algorithm>
   #include <deque>
   #include <functional>
   #include <future>
   #include <queue>

   namespace marian {
   namespace data {

   // Iterator over batches generated by a BatchGenerator. Meant to be the only
   // interface to create batches.
   template <class BatchGenerator>
   class BatchIterator : public IteratorFacade<BatchIterator<BatchGenerator>,
                                               typename BatchGenerator::BatchPtr> {
   private:
     BatchGenerator* bg_;
     typename BatchGenerator::BatchPtr current_;

     friend BatchGenerator;

     // private: only BatchGenerator is allowed to create pointers,
     // hence the friend declaration above.
     BatchIterator(BatchGenerator* bg, typename BatchGenerator::BatchPtr current)
         : bg_(bg), current_(current) {}

   public:
     virtual bool equal(const BatchIterator& other) const override {
       // iterators are only equal if they point at the same batch or both hold nullptr
       return current_ == other.current_;
     }

     // Just returns the batch pointer
     virtual const typename BatchGenerator::BatchPtr& dereference() const override {
       return current_;
     }

     // Sets the current pointer to the next batch pointer; this will be nullptr when
     // no batches are available, which evaluates to false in a for loop.
     virtual void increment() override { current_ = bg_->next(); };
   };

   template <class DataSet>
   class BatchGenerator : public RNGEngine {
   public:
     typedef typename DataSet::batch_ptr BatchPtr;

     typedef typename DataSet::Sample Sample;
     typedef std::vector<Sample> Samples;

     typedef BatchIterator<BatchGenerator> iterator;
     friend iterator;

   protected:
     Ptr<DataSet> data_;
     Ptr<Options> options_;

     bool restored_{false};

     // Replaces the old shuffle_ with two variants that allow more fine-grained
     // shuffling behavior. Both set to false is equivalent to the old shuffle_ == false.
     // We can now leave the data unshuffled but still shuffle the batches, which is
     // useful for linear reading of very large data sets with pre-reading. Parameters
     // like maxi-batch determine how much data is pre-read and sorted by length or
     // other criteria.
     bool shuffleData_{false};    // determines if the full data should be shuffled before reading and batching
     bool shuffleBatches_{false}; // determines if batches should be shuffled after batching

   private:
     Ptr<BatchStats> stats_;

     bool runAsync_{true}; // use asynchronous batch pre-fetching by default. We want to be able to
                           // disable this when running in library mode and for exception-safety.

     // state of fetching
     std::deque<BatchPtr> bufferedBatches_; // current swath of batches that next() reads from

     // state of reading
     typename DataSet::iterator current_;
     bool newlyPrepared_{true}; // prepare() was just called: we need to reset current_ --@TODO: can we just reset it directly?
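
     // The members below implement a simple single-producer/single-consumer
     // double buffer: the one-thread pool runs fetchBatches() in the background
     // while next() drains bufferedBatches_; once the buffer runs dry, next()
     // waits on the future, swaps in the new swath, and immediately schedules
     // the following fetch.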

     // variables for multi-threaded pre-fetching
     mutable UPtr<ThreadPool> threadPool_; // (we only use one thread, but keep it around)
     std::future<std::deque<BatchPtr>> futureBufferedBatches_; // next swath of batches is returned via this

     // This runs on a background thread; sequencing is handled by the caller, but locking is done in here.
     std::deque<BatchPtr> fetchBatches() {
       timer::Timer total;

       typedef typename Sample::value_type Item;
       auto itemCmp = [](const Item& sa, const Item& sb) {
         return sa.size() < sb.size(); // sort by element length, not content
       };

       auto cmpSrc = [itemCmp](const Sample& a, const Sample& b) {
         return std::lexicographical_compare(
             a.begin(), a.end(), b.begin(), b.end(), itemCmp);
       };

       auto cmpTrg = [itemCmp](const Sample& a, const Sample& b) {
         return std::lexicographical_compare(
             a.rbegin(), a.rend(), b.rbegin(), b.rend(), itemCmp);
       };

       auto cmpNone = [](const Sample& a, const Sample& b) {
         return a.getId() < b.getId(); // sort in order of original ids = original data order unless shuffling
       };

       typedef std::function<bool(const Sample&, const Sample&)> cmp_type;
       typedef std::priority_queue<Sample, Samples, cmp_type> sample_queue;

       std::unique_ptr<sample_queue> maxiBatch; // priority queue, shortest first

       if(options_->has("maxi-batch-sort")) {
         if(options_->get<std::string>("maxi-batch-sort") == "src")
           maxiBatch.reset(new sample_queue(cmpSrc));
         else if(options_->get<std::string>("maxi-batch-sort") == "none")
           maxiBatch.reset(new sample_queue(cmpNone));
         else
           maxiBatch.reset(new sample_queue(cmpTrg));
       } else {
         maxiBatch.reset(new sample_queue(cmpNone));
       }

       size_t maxBatchSize = options_->get<int>("mini-batch");
       size_t maxSize = maxBatchSize * options_->get<int>("maxi-batch");

       // consume data from the corpus into the maxi-batch (single sentences),
       // sorted into the specified order (due to the queue)
       if(newlyPrepared_) {
         current_ = data_->begin();
         newlyPrepared_ = false;
       } else {
         if(current_ != data_->end())
           ++current_;
       }

       Samples maxiBatchTemp;
       while(current_ != data_->end() && maxiBatchTemp.size() < maxSize) { // loop over data
         if(saveAndExitRequested()) // stop generating batches
           return std::deque<BatchPtr>();

         maxiBatchTemp.push_back(*current_);

         // Do not consume more than required for the maxi-batch, as doing so would
         // delay line-by-line translation by one sentence.
         bool last = maxiBatchTemp.size() == maxSize;
         if(!last)
           ++current_; // this actually reads the next line and pre-processes it
       }
       size_t numSentencesRead = maxiBatchTemp.size();

       size_t sets = 0;
       for(auto&& s : maxiBatchTemp) {
         if(!s.empty()) {
           sets = s.size();
           maxiBatch->push(s);
         }
       }

       // construct the actual batches and place them in the queue
       Samples batchVector;
       size_t currentWords = 0;
       std::vector<size_t> lengths(sets, 0); // records maximum length observed within the current batch

       std::deque<BatchPtr> tempBatches;

       // process all loaded sentences in order of increasing length
       // @TODO: we could just use a vector and do a sort() here; would make the cost more explicit
       const size_t mbWords = options_->get<size_t>("mini-batch-words", 0);
       const bool useDynamicBatching = options_->has("mini-batch-fit");
       BatchStats::const_iterator cachedStatsIter;
       if(stats_)
         cachedStatsIter = stats_->begin();
       while(!maxiBatch->empty()) { // while there are sentences in the queue
         if(saveAndExitRequested()) // stop generating batches
           return std::deque<BatchPtr>();

         // push item onto batch
         batchVector.push_back(maxiBatch->top());
         maxiBatch->pop(); // fetch next-shortest

         // have we reached a sufficient amount of data to form a batch?
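         // Three sizing modes apply here:
         //  - mini-batch-fit: ask BatchStats for the largest sentence count that
         //    fits the current per-stream maximum lengths;
         //  - mini-batch-words: cap the running source-word count at mbWords;
         //  - otherwise: a fixed number of sentences (mini-batch).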
         bool makeBatch;
         if(useDynamicBatching && stats_) { // batch size based on dynamic batching
           for(size_t i = 0; i < sets; ++i)
             if(batchVector.back()[i].size() > lengths[i])
               lengths[i] = batchVector.back()[i].size(); // record max lengths so far

           maxBatchSize = stats_->findBatchSize(lengths, cachedStatsIter);

           makeBatch = batchVector.size() >= maxBatchSize;
           // If the last added sentence caused a bump then we likely have bad padding,
           // so rather move it into the next batch.
           if(batchVector.size() > maxBatchSize) {
             maxiBatch->push(batchVector.back());
             batchVector.pop_back();
           }
         } else if(mbWords > 0) {
           currentWords += batchVector.back()[0].size(); // count words based on the first stream = source --@TODO: shouldn't we count based on labels?
           makeBatch = currentWords > mbWords; // batch size based on words
         } else {
           makeBatch = batchVector.size() == maxBatchSize; // batch size based on sentences
         }

         // if we reached the desired batch size then create a real batch
         if(makeBatch) {
           tempBatches.push_back(data_->toBatch(batchVector));

           // prepare for the next batch
           batchVector.clear();
           currentWords = 0;
           lengths.assign(sets, 0);
           if(stats_)
             cachedStatsIter = stats_->begin();
         }
       }

       // turn the rest into a batch
       // @BUGBUG: This can create a very small batch, which with ce-mean-words can artificially
       // inflate the contribution of the samples in the batch, causing instability.
       // I think a good alternative would be to carry over the left-over sentences into the next round.
       if(!batchVector.empty())
         tempBatches.push_back(data_->toBatch(batchVector));

       // shuffle the batches
       if(shuffleBatches_) {
         std::shuffle(tempBatches.begin(), tempBatches.end(), eng_);
       }

       double totalSent{}, totalLabels{};
       for(auto& b : tempBatches) {
         totalSent += (double)b->size();
         totalLabels += (double)b->words(-1);
       }
       auto totalDenom = tempBatches.empty() ? 1 : tempBatches.size(); // (make 0/0 = 0)
       LOG(debug, "[data] fetched {} batches with {} sentences. Per batch: {} sentences, {} labels.",
           tempBatches.size(), numSentencesRead,
           (double)totalSent / (double)totalDenom, (double)totalLabels / (double)totalDenom);
       LOG(debug, "[data] fetching batches took {:.2f} seconds, {:.2f} sents/s",
           total.elapsed(), (double)numSentencesRead / total.elapsed());

       return tempBatches;
     }

     // this starts fetchBatches() as a background operation
     void fetchBatchesAsync() {
       ABORT_IF(futureBufferedBatches_.valid(), "Attempted to restart futureBufferedBatches_ while still running");
       ABORT_IF(!runAsync_, "Trying to run fetchBatchesAsync() but runAsync_ is false??");
       ABORT_IF(!threadPool_, "Trying to run fetchBatchesAsync() without initialized threadPool_??");
       futureBufferedBatches_ = threadPool_->enqueue([this]() {
         return fetchBatches();
       });
     }

     BatchPtr next() {
       if(bufferedBatches_.empty()) {
         if(runAsync_) { // by default we will run in asynchronous mode
           // Out of data: need to get the next batch from the background thread.
           // We only get here if the future has been scheduled to run; it must be valid.
           ABORT_IF(!futureBufferedBatches_.valid(),
                    "Attempted to wait for futureBufferedBatches_ when none pending.\n"
                    "This error often occurs when Marian tries to restore the training data iterator, but the corpus has been changed or replaced.\n"
                    "If you have changed the training corpus, add --no-restore-corpus to the training command and run it again.");
           bufferedBatches_ = std::move(futureBufferedBatches_.get());
           // if the background thread returns an empty swath, we hit the end of the epoch
           if(bufferedBatches_.empty() || saveAndExitRequested()) {
             return nullptr;
           }
           // and kick off the next background operation
           fetchBatchesAsync();
         } else { // don't spawn any threads, i.e. batch fetching is blocking
           bufferedBatches_ = fetchBatches();
           // if bufferedBatches_ is empty, we hit the end of the epoch
           if(bufferedBatches_.empty() || saveAndExitRequested()) {
             return nullptr;
           }
         }
       }

       auto batch = bufferedBatches_.front();
       bufferedBatches_.pop_front();
       return batch;
     }

   public:
     BatchGenerator(Ptr<DataSet> data,
                    Ptr<Options> options,
                    Ptr<BatchStats> stats = nullptr,
                    bool runAsync = true)
         : data_(data),
           options_(options),
           stats_(stats),
           runAsync_(runAsync),
           threadPool_(runAsync ? new ThreadPool(1) : nullptr) {
       auto shuffle = options_->get<std::string>("shuffle", "none");
       shuffleData_ = shuffle == "data";
       shuffleBatches_ = shuffleData_ || shuffle == "batches";
     }

     ~BatchGenerator() {
       if(futureBufferedBatches_.valid()) // the background thread holds a reference to 'this',
         futureBufferedBatches_.get();    // so we must wait for it to complete
     }

     iterator begin() { return iterator(this, next()); }
     iterator end() { return iterator(this, nullptr); }

     // @TODO: get rid of this function; begin() or the constructor should figure this out
     void prepare() {
       if(shuffleData_)
         data_->shuffle();
       else
         data_->reset();
       newlyPrepared_ = true;
       // Start the background pre-fetch operation when running in asynchronous mode;
       // otherwise we will fetch on demand.
       if(runAsync_)
         fetchBatchesAsync();
     }

     // Used to restore the state of a BatchGenerator after
     // an interrupted and resumed training.
     bool restore(Ptr<TrainingState> state) {
       if(state->epochs == 1 && state->batchesEpoch == 0)
         return false;

       LOG(info,
           "[data] Restoring the corpus state to epoch {}, batch {}",
           state->epochs,
           state->batches);

       if(state->epochs > 1) {
         data_->restore(state);
         setRNGState(state->seedBatch);
       }

       prepare();
       for(size_t i = 0; i < state->batchesEpoch; ++i)
         next();

       return true;
     }

     // This is needed for dynamic MB scaling. Returns 0 if the size is not known in words.
     size_t estimateTypicalTrgBatchWords() const {
       const size_t mbWords = options_->get<size_t>("mini-batch-words", 0);
       const bool useDynamicBatching = options_->has("mini-batch-fit");
       if(useDynamicBatching && stats_)
         return stats_->estimateTypicalTrgWords();
       else if(mbWords)
         return mbWords;
       else
         return 0;
     }
   };

   class CorpusBatchGenerator : public BatchGenerator<CorpusBase>,
                                public TrainingObserver {
   public:
     CorpusBatchGenerator(Ptr<CorpusBase> data,
                          Ptr<Options> options,
                          Ptr<BatchStats> stats = nullptr)
         : BatchGenerator(data, options, stats) {}

     void actAfterEpoch(TrainingState& state) override {
       state.seedBatch = getRNGState();
       state.seedCorpus = data_->getRNGState();
     }
   };

   }  // namespace data
   }  // namespace marian
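
The generator is driven epoch by epoch: ``prepare()`` shuffles or resets the underlying data and, in asynchronous mode, starts the first background fetch; iterating with ``begin()``/``end()`` then yields batches until ``next()`` returns ``nullptr``. The following is a minimal consumer sketch, not part of this header; the include path for ``CorpusBase`` and the way the corpus and options are obtained are assumptions.

.. code-block:: cpp

   #include "data/batch_generator.h"
   #include "data/corpus_base.h" // assumed location of CorpusBase

   using namespace marian;
   using namespace marian::data;

   // Hypothetical driver: corpus and options are assumed to be
   // constructed elsewhere with the usual Marian machinery.
   void consumeBatches(Ptr<CorpusBase> corpus, Ptr<Options> options, size_t numEpochs) {
     auto generator = New<CorpusBatchGenerator>(corpus, options);
     for(size_t epoch = 0; epoch < numEpochs; ++epoch) {
       generator->prepare(); // shuffle/reset the data and start async pre-fetching
       for(auto batch : *generator) { // BatchIterator ends once next() yields nullptr
         // consume the batch here, e.g. feed it to the training graph;
         // batch->size() is the number of sentences it contains.
         (void)batch;
       }
     }
   }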