.. _program_listing_file_src_translator_beam_search.h:

Program Listing for File beam_search.h
======================================

|exhale_lsh| :ref:`Return to documentation for file <file_src_translator_beam_search.h>` (``src/translator/beam_search.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. NOTE(review) - the template arguments in the listing below and the target of the
   back-reference above were reconstructed after an extraction step dropped every
   angle-bracketed span. Confirm them against src/translator/beam_search.h.

.. code-block:: cpp

   #pragma once

   #include "marian.h"
   #include "translator/history.h"
   #include "translator/scorers.h"

   namespace marian {

   // Beam-search decoder interface. Only chooseInvalidPathScore() and the constructor are
   // defined in this header; the remaining member functions are declared here and
   // implemented elsewhere.
   class BeamSearch {
   private:
     Ptr<Options> options_;              // options object handed to the constructor
     std::vector<Ptr<Scorer>> scorers_;  // scorers handed to the constructor
     size_t beamSize_;                   // value of the "beam-size" option, read once in the constructor
     Ptr<const Vocab> trgVocab_;         // vocabulary handed to the constructor as trgVocab

     const float INVALID_PATH_SCORE;  // fixed at construction, see chooseInvalidPathScore()
     const bool PURGE_BATCH = true;   // @TODO: diagnostic, to-be-removed once confirmed there are no issues.

     // Returns the `lowest` limit of the compute type named by the first entry of the
     // "precision" option; {"float32"} is used when the option is not set.
     static float chooseInvalidPathScore(Ptr<Options> options) {
       auto prec = options->get<std::vector<std::string>>("precision", {"float32"});
       auto computeType = typeFromString(prec[0]);
       return NumericLimits<float>(computeType).lowest;
     }

   public:
     // Stores the arguments, caches the "beam-size" option and derives INVALID_PATH_SCORE
     // from the "precision" option.
     BeamSearch(Ptr<Options> options,
                const std::vector<Ptr<Scorer>>& scorers,
                const Ptr<const Vocab> trgVocab)
         : options_(options),
           scorers_(scorers),
           beamSize_(options_->get<size_t>("beam-size")),
           trgVocab_(trgVocab),
           INVALID_PATH_SCORE{chooseInvalidPathScore(options)} {}

     // combine new expandedPathScores and previous beams into new set of beams
     Beams toHyps(const std::vector<unsigned int>& nBestKeys,  // [currentDimBatch, beamSize] flattened -> ((batchIdx, beamHypIdx) flattened, word idx) flattened
                  const std::vector<float>& nBestPathScores,   // [currentDimBatch, beamSize] flattened
                  const size_t nBestBeamSize,                  // for interpretation of nBestKeys
                  const size_t vocabSize,                      // ditto.
                  const Beams& beams,
                  const std::vector<Ptr<ScorerState>>& states,
                  Ptr<data::CorpusBatch> batch,                // for alignments only
                  Ptr<FactoredVocab> factoredVocab,
                  size_t factorGroup,
                  const std::vector<bool>& dropBatchEntries,   // [origDimBatch] - empty source batch entries are marked with true, should be cleared after first use.
                  const std::vector<IndexType>& batchIdxMap) const;

     std::vector<float> getAlignmentsForHypothesis(  // -> P(s|t) for current t and given beam and batch dim
         const std::vector<float> alignAll,          // [beam depth, max src length, batch size, 1], flattened vector of all attention probabilities
         Ptr<data::CorpusBatch> batch,
         int beamHypIdx,
         int currentBatchIdx,
         int origBatchIdx,
         int currentDimBatch) const;

     // remove all beam entries that have reached EOS
     Beams purgeBeams(const Beams& beams, /*in/out=*/std::vector<IndexType>& batchIdxMap);

     // main decoding function
     Histories search(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch);
   };

   }  // namespace marian