Program Listing for File beam_search.h
↰ Return to documentation for file (src/translator/beam_search.h)
#pragma once
#include "marian.h"
#include "translator/history.h"
#include "translator/scorers.h"
namespace marian {
/**
 * Beam-search decoder interface.
 *
 * Stores the decoding options, the list of scorers and the target vocabulary,
 * and declares the individual steps of the search. Only the constructor and
 * chooseInvalidPathScore() are defined in this header; all other member
 * functions are declared here and defined elsewhere.
 */
class BeamSearch {
  // NOTE: data members are initialized in declaration order; the constructor's
  // initializer list below is written in the same order.
  Ptr<Options> options_;
  std::vector<Ptr<Scorer>> scorers_;
  size_t beamSize_;                // value of the "beam-size" option
  Ptr<const Vocab> trgVocab_;

  const float INVALID_PATH_SCORE;  // set once in the constructor, see chooseInvalidPathScore()
  const bool PURGE_BATCH = true;   // @TODO: diagnostic, to-be-removed once confirmed there are no issues.

  /**
   * Picks the score used for INVALID_PATH_SCORE.
   *
   * Reads the "precision" option as a list of strings (the {"float32"}
   * argument is presumably the fallback when the option is unset - confirm
   * against Options::get), converts its first entry to a type via
   * typeFromString(), and returns the `lowest` member of
   * NumericLimits<float> for that type.
   */
  static float chooseInvalidPathScore(Ptr<Options> options) {
    const auto precisionNames = options->get<std::vector<std::string>>("precision", {"float32"});
    const auto scoreType      = typeFromString(precisionNames[0]);
    return NumericLimits<float>(scoreType).lowest;
  }

public:
  /**
   * @param options   decoding options; "beam-size" is read here, "precision" via chooseInvalidPathScore()
   * @param scorers   scorers to store (copied into this object)
   * @param trgVocab  target vocabulary
   */
  BeamSearch(Ptr<Options> options, const std::vector<Ptr<Scorer>>& scorers, const Ptr<const Vocab> trgVocab)
      : options_(options),
        scorers_(scorers),
        beamSize_(options->get<size_t>("beam-size")),
        trgVocab_(trgVocab),
        INVALID_PATH_SCORE(chooseInvalidPathScore(options)) {}

  /**
   * Combine new expandedPathScores and previous beams into a new set of beams.
   *
   * @param nBestKeys         [currentDimBatch, beamSize] flattened -> ((batchIdx, beamHypIdx) flattened, word idx) flattened
   * @param nBestPathScores   [currentDimBatch, beamSize] flattened
   * @param nBestBeamSize     for interpretation of nBestKeys
   * @param vocabSize         for interpretation of nBestKeys
   * @param beams             previous beams
   * @param states            scorer states
   * @param batch             for alignments only
   * @param factoredVocab     factored vocabulary
   * @param factorGroup       factor group index
   * @param dropBatchEntries  [origDimBatch] - empty source batch entries are marked with true,
   *                          should be cleared after first use.
   * @param batchIdxMap       batch index map (see purgeBeams(), which updates such a map)
   */
  Beams toHyps(const std::vector<unsigned int>& nBestKeys,
               const std::vector<float>& nBestPathScores,
               const size_t nBestBeamSize,
               const size_t vocabSize,
               const Beams& beams,
               const std::vector<Ptr<ScorerState /*const*/>>& states,
               Ptr<data::CorpusBatch /*const*/> batch,
               Ptr<class FactoredVocab /*const*/> factoredVocab,
               size_t factorGroup,
               const std::vector<bool>& dropBatchEntries,
               const std::vector<IndexType>& batchIdxMap) const;

  /**
   * Returns P(s|t) for the current t and the given beam and batch dimension.
   *
   * @param alignAll  [beam depth, max src length, batch size, 1], flattened vector of all
   *                  attention probabilities. NOTE(review): taken by value, so the vector is
   *                  copied on each call; the signature is kept as-is to match its definition.
   * @param batch            corpus batch
   * @param beamHypIdx       index of the hypothesis within the beam
   * @param currentBatchIdx  batch index in the current batch
   * @param origBatchIdx     batch index in the original batch
   * @param currentDimBatch  current batch dimension
   */
  std::vector<float> getAlignmentsForHypothesis(const std::vector<float> alignAll,
                                                Ptr<data::CorpusBatch> batch,
                                                int beamHypIdx,
                                                int currentBatchIdx,
                                                int origBatchIdx,
                                                int currentDimBatch) const;

  /**
   * Remove all beam entries that have reached EOS.
   *
   * @param beams        beams to purge
   * @param batchIdxMap  in/out parameter
   */
  Beams purgeBeams(const Beams& beams, /*in/out=*/std::vector<IndexType>& batchIdxMap);

  /** Main decoding function. */
  Histories search(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch);
};
} // namespace marian