Program Listing for File beam_search.cpp
↰ Return to documentation for file (src/translator/beam_search.cpp)
#include "translator/beam_search.h"
#include "data/factored_vocab.h"
#include "translator/helpers.h"
#include "translator/nth_element.h"
#include "data/shortlist.h"
#include "common/utils.h"
namespace marian {
// Combine the n-best expansions of the current step (nBestKeys / nBestPathScores) with the
// previous beams into a new set of beams, one per entry of 'beams'.
//
// Each key is interpreted as
//   key = (currentBatchIdx * nBestBeamSize + beamHypIdx) * vocabSize + wordIdx
// and, unless it is skipped, yields one new Hypothesis that points back to a hypothesis in 'beams'.
// A key is skipped when the new beam for its batch entry is already as large as the previous beam,
// or when its path score equals INVALID_PATH_SCORE.
// For factorGroup > 0, hypotheses whose word cannot be expanded with this factor are copied over
// unchanged afterwards, and each beam is cut back to the size of the previous beam, keeping the
// hypotheses with the highest path scores.
//
// Returns the new beams; the result always has beams.size() entries.
Beams BeamSearch::toHyps(const std::vector<unsigned int>& nBestKeys, // [currentDimBatch, beamSize] flattened -> ((batchIdx, beamHypIdx) flattened, word idx) flattened
const std::vector<float>& nBestPathScores, // [currentDimBatch, beamSize] flattened
const size_t nBestBeamSize, // for interpretation of nBestKeys
const size_t vocabSize, // ditto.
const Beams& beams,
const std::vector<Ptr<ScorerState /*const*/>>& states,
Ptr<data::CorpusBatch /*const*/> batch, // for alignments only
Ptr<FactoredVocab/*const*/> factoredVocab, size_t factorGroup,
const std::vector<bool>& dropBatchEntries, // [origDimBatch] - empty source batch entries are marked with true, should be cleared after first use.
const std::vector<IndexType>& batchIdxMap) const { // [origBatchIdx -> currentBatchIdx]
std::vector<float> align; // collects alignment information from the last executed time step
// Alignments are only fetched for the first factor group; for later groups 'align' stays empty
// and the alignment of the previous hypothesis is copied instead (see "Set alignments" below).
if(options_->hasAndNotEmpty("alignment") && factorGroup == 0)
align = scorers_[0]->getAlignment(); // [beam depth * max src length * current batch size] -> P(s|t); use alignments from the first scorer, even if ensemble,
const auto origDimBatch = beams.size(); // see function search for definition of origDimBatch and currentDimBatch etc.
Beams newBeams(origDimBatch); // return value of this function goes here. There are always origDimBatch beams.
// create a reverse batchMap to obtain original batchIdx in the starting batch size
// and calculate the current batch size based on non-empty beams
std::vector<IndexType> reverseBatchIdxMap; // empty if not purging batch entries
size_t currentDimBatch = beams.size();
// PURGE_BATCH is defined outside this file section (presumably a compile-time switch in the header - TODO confirm).
if(PURGE_BATCH) {
reverseBatchIdxMap.resize(batchIdxMap.size()); // adjust size if doing batch purging.
currentDimBatch = 0;
// NOTE(review): 'int i' is compared against a size_t here; beams[i] assumes beams and batchIdxMap have the same length.
for(int i = 0; i < batchIdxMap.size(); ++i) {
reverseBatchIdxMap[batchIdxMap[i]] = i; // reverse batch index mapping, multiple occurrences get overwritten with the last one,
// which is expected due to down-shifting
if(!beams[i].empty())
currentDimBatch++;
}
}
for(size_t i = 0; i < nBestKeys.size(); ++i) { // [currentDimBatch, beamSize] flattened
// Keys encode batchIdx, beamHypIdx, and word index in the entire beam.
// They can be between 0 and (vocabSize * nBestBeamSize * batchSize)-1.
// (beamHypIdx refers to the GPU tensors, *not* the beams[] array; they are not the same in case of purging)
const auto key = nBestKeys[i];
// decompose key into individual indices (batchIdx, beamHypIdx, wordIdx)
const auto beamHypIdx = (key / vocabSize) % nBestBeamSize;
const auto currentBatchIdx = (key / vocabSize) / nBestBeamSize;
const auto origBatchIdx = reverseBatchIdxMap.empty() ? currentBatchIdx : reverseBatchIdxMap[currentBatchIdx]; // map currentBatchIdx back into original position within starting maximal batch size, required to find correct beam
// Forced drop applies only to the first factor group and only to entries flagged in dropBatchEntries.
bool dropHyp = !dropBatchEntries.empty() && dropBatchEntries[origBatchIdx] && factorGroup == 0;
WordIndex wordIdx;
if(dropHyp) { // if we force=drop the hypothesis, assign EOS, otherwise the expected word id.
if(factoredVocab) { // when using factoredVocab, extract the EOS lemma index from the word id, we predicting factors one by one here, hence lemma only
std::vector<size_t> eosFactors;
factoredVocab->word2factors(factoredVocab->getEosId(), eosFactors);
wordIdx = (WordIndex)eosFactors[0];
} else { // without factoredVocab lemma index and word index are the same. Safe cruising.
wordIdx = trgVocab_->getEosId().toWordIndex();
}
} else { // we are not dropping anything, just assign the normal index
wordIdx = (WordIndex)(key % vocabSize);
}
// @TODO: We currently assign a log probability of 0 to all beam entries of the dropped batch entry, instead it might be a good idea to use
// the per Hyp pathScore without the current expansion (a bit hard to obtain).
// For the case where we drop empty inputs, 0 is fine. For other use cases like a forced stop, the penultimate pathScore might be better.
// For the empty hyp this would naturally result in 0, too.
const float pathScore = dropHyp ? 0.f : nBestPathScores[i]; // 0 (Prob = 1, maximum score) if dropped or expanded path score for (batchIdx, beamHypIdx, word)
const auto& beam = beams[origBatchIdx];
auto& newBeam = newBeams[origBatchIdx]; // extended hypotheses are going to be placed in this new beam
if(newBeam.size() >= beam.size()) // getNBestList() generates N for all batch entries incl. those that already have a narrower beam
continue;
if(pathScore == INVALID_PATH_SCORE) // (dummy slot or word that cannot be expanded by current factor)
continue;
ABORT_IF(pathScore < INVALID_PATH_SCORE, "Actual pathScore ({}) is lower than INVALID_PATH_SCORE ({})??", pathScore, INVALID_PATH_SCORE); // This should not happen in valid situations. Currently the only smaller value would be -inf (effect of overflow in summation?)
ABORT_IF(beamHypIdx >= beam.size(), "Out of bounds beamHypIdx??"); // effectively this is equivalent to ABORT_IF(beams[origBatchIdx].empty(), ...)
// map wordIdx to word
auto prevBeamHypIdx = beamHypIdx; // back pointer
auto prevHyp = beam[prevBeamHypIdx];
Word word;
// If short list has been set, then wordIdx is an index into the short-listed word set,
// rather than the true word index.
auto shortlist = scorers_[0]->getShortlist();
if (factoredVocab) {
// For factored decoding, the word is built over multiple decoding steps,
// starting with the lemma, then adding factors one by one.
if (factorGroup == 0) {
word = factoredVocab->lemma2Word(shortlist ? shortlist->reverseMap((int) prevBeamHypIdx, (int) currentBatchIdx, wordIdx) : wordIdx);
// NOTE(review): factorIndices is only consumed by the commented-out LOG below.
std::vector<size_t> factorIndices; factoredVocab->word2factors(word, factorIndices);
//LOG(info, "{} + {} ({}) -> {} -> {}",
// factoredVocab->decode(prevHyp->tracebackWords()),
// factoredVocab->word2string(word), factorIndices[0], prevHyp->getPathScore(), pathScore);
}
else {
//LOG(info, "{} |{} ({}) = {} ({}) -> {} -> {}",
// factoredVocab->decodeForDiagnostics(beam[beamHypIdx]->tracebackWords()),
// factoredVocab->getFactorGroupPrefix(factorGroup), factorGroup,
// factoredVocab->getFactorName(factorGroup, wordIdx), wordIdx,
// prevHyp->getPathScore(), pathScore);
// Secondary factor: take the partially factored word of the previous hyp and add this factor to it.
word = beam[beamHypIdx]->getWord();
ABORT_IF(!factoredVocab->canExpandFactoredWord(word, factorGroup),
"A word without this factor snuck through to here??");
word = factoredVocab->expandFactoredWord(word, factorGroup, wordIdx);
prevBeamHypIdx = prevHyp->getPrevStateIndex();
prevHyp = prevHyp->getPrevHyp(); // short-circuit the backpointer, so that the traceback does not contain partially factored words
}
}
else if (shortlist)
word = Word::fromWordIndex(shortlist->reverseMap((int) prevBeamHypIdx, (int) currentBatchIdx, wordIdx));
else
word = Word::fromWordIndex(wordIdx);
auto hyp = Hypothesis::New(prevHyp, word, prevBeamHypIdx, pathScore);
// Set score breakdown for n-best lists
if(options_->get<bool>("n-best")) {
// Start from the breakdown of the hypothesis at beamHypIdx and add this step's per-scorer logit.
auto breakDown = beam[beamHypIdx]->getScoreBreakdown();
ABORT_IF(factoredVocab && factorGroup > 0 && !factoredVocab->canExpandFactoredWord(word, factorGroup),
"A word without this factor snuck through to here??");
breakDown.resize(states.size(), 0); // at start, this is empty, so this will set the initial score to 0
for(size_t j = 0; j < states.size(); ++j) {
auto lval = states[j]->getLogProbs().getFactoredLogitsTensor(factorGroup); // [maxBeamSize, 1, currentDimBatch, dimFactorVocab]
// The flatting happens based on actual (current) batch size and batch index computed with batch-pruning as we are looking into the pruned tensor
size_t flattenedLogitIndex = (beamHypIdx * currentDimBatch + currentBatchIdx) * vocabSize + wordIdx; // (beam idx, batch idx, word idx); note: beam and batch are transposed, compared to 'key'
// @TODO: use a function on shape() to index, or new method val->at({i1, i2, i3, i4}) with broadcasting
// NOTE(review): as written, this check can only abort when beamHypIdx == 0, because the second
// operand of && requires it; for beamHypIdx > 0 the shape is never validated - confirm this is intended.
ABORT_IF(lval->shape() != Shape({(int)nBestBeamSize, 1, (int)currentDimBatch, (int)vocabSize}) &&
(beamHypIdx == 0 && lval->shape() != Shape({1, 1, (int)currentDimBatch, (int)vocabSize})),
"Unexpected shape of logits?? {} != {}", lval->shape(), Shape({(int)nBestBeamSize, 1, (int)currentDimBatch, (int)vocabSize}));
breakDown[j] += lval->get(flattenedLogitIndex);
}
hyp->setScoreBreakdown(breakDown);
}
// Set alignments
if(!align.empty())
hyp->setAlignment(getAlignmentsForHypothesis(align, batch, (int)beamHypIdx, (int)currentBatchIdx, (int)origBatchIdx, (int)currentDimBatch));
else // not first factor: just copy
hyp->setAlignment(beam[beamHypIdx]->getAlignment());
newBeam.push_back(hyp);
}
// if factored vocab and this is not the first factor, we need to
// also propagate factored hypotheses that do not get expanded in this step because they don't have this factor
if (factorGroup > 0) {
for (size_t batchIdx = 0; batchIdx < beams.size(); batchIdx++) {
const auto& beam = beams[batchIdx];
auto& newBeam = newBeams[batchIdx];
for (const auto& beamHyp : beam) {
auto word = beamHyp->getWord();
//LOG(info, "Checking {}", factoredVocab->word2string(word));
if (factoredVocab->canExpandFactoredWord(word, factorGroup)) // handled above
continue;
//LOG(info, "Forwarded {}", factoredVocab->word2string(word));
newBeam.push_back(beamHyp);
}
// The forwarded hyps may have made the new beam larger than the previous one; keep only the
// beam.size() best. nth_element only partitions around that position, it does not fully sort.
if (newBeam.size() > beam.size()) {
//LOG(info, "Size {}, sorting...", newBeam.size());
std::nth_element(newBeam.begin(), newBeam.begin() + beam.size(), newBeam.end(), [](Hypothesis::PtrType a, Hypothesis::PtrType b) {
return a->getPathScore() > b->getPathScore(); // (sort highest score first)
});
//LOG(info, "Size {}, sorted...", newBeam.size());
newBeam.resize(beam.size());
}
}
}
return newBeams;
}
std::vector<float> BeamSearch::getAlignmentsForHypothesis( // -> P(s|t) for current t and given beam and batch dim
const std::vector<float> alignAll, // [beam depth, max src length, batch size, 1], flattened vector of all attention probablities
Ptr<data::CorpusBatch> batch,
int beamHypIdx,
int currentBatchIdx,
int origBatchIdx,
int currentDimBatch) const {
// Let's B be the beam size, N be the number of batched sentences,
// and L the number of words in the longest sentence in the batch.
// The alignment vector:
//
// if(first)
// * has length of N x L if it's the first beam
// * stores elements in the following order:
// beam1 = [word1-batch1, word1-batch2, ..., word2-batch1, ...]
// else
// * has length of N x L x B
// * stores elements in the following order:
// beams = [beam1, beam2, ..., beam_n]
//
// The mask vector is always of length N x L and has 1/0s stored like
// in a single beam, i.e.:
// * [word1-batch1, word1-batch2, ..., word2-batch1, ...]
//
size_t origDimBatch = batch->size(); // number of sentences in batch
size_t batchWidth = batch->width(); // max src length
// loop over words of batch entry 'currentBatchIdx' and beam entry 'beamHypIdx'
std::vector<float> align;
for(size_t srcPos = 0; srcPos < batchWidth; ++srcPos) { // loop over source positions
// We are looking into the probabilites from an actual tensor, hence we need to use currentDimBatch and currentBatchIdx.
size_t currentAttIdx = (batchWidth * beamHypIdx + srcPos) * currentDimBatch + currentBatchIdx; // = flatten [beam index, s, batch index, 0]
// We are looking into the mask from the orginal batch, hence we need to use origDmBatch and origBatchIdx.
size_t origAttIdx = (batchWidth * beamHypIdx + srcPos) * origDimBatch + origBatchIdx;; // = flatten [beam index, s, batch index, 0]
size_t origMaskIdx = origAttIdx % (batchWidth * origDimBatch); // == batchIdx + (batchSize * srcPos) = flatten [0, s, batch index, 0]
// If the original position is not masked out used the corresponding current attention score.
if(batch->front()->mask()[origMaskIdx] != 0)
align.emplace_back(alignAll[currentAttIdx]);
}
return align;
}
// Remove all beam entries that have reached EOS.
//
// For every beam in 'beams' a new beam is produced that contains only the hypotheses whose last
// word is not the target EOS; the order of the surviving hypotheses is kept. The result has
// exactly beams.size() entries (a fully finished beam becomes an empty beam).
//
// batchIdxMap (in/out): when PURGE_BATCH is enabled and a beam that still had hypotheses becomes
// empty in this call, the map entries of all beams after it are decremented by one, because that
// batch entry is going to be removed from the batch.
Beams BeamSearch::purgeBeams(const Beams& beams, /*in/out=*/std::vector<IndexType>& batchIdxMap) {
  const auto trgEosId = trgVocab_->getEosId();

  Beams newBeams;
  newBeams.reserve(beams.size()); // one output beam per input beam, known up front

  for(size_t beamIdx = 0; beamIdx < beams.size(); ++beamIdx) {
    const auto& beam = beams[beamIdx]; // by reference: no need to copy the whole beam just to read it

    // build the beam of surviving hyps directly in its final place
    newBeams.emplace_back();
    auto& newBeam = newBeams.back();
    for(const auto& hyp : beam)
      if(hyp->getWord() != trgEosId) // if this hyp is not finished,
        newBeam.push_back(hyp);      // carry it over to the beam of surviving hyps

    // previous beam had hyps, but all were finished in this step, newBeam will now stay empty
    if(PURGE_BATCH && newBeam.empty() && !beam.empty()) {
      for(size_t i = beamIdx + 1; i < beams.size(); ++i) // for all entries above this beam
        batchIdxMap[i] = batchIdxMap[i] - 1; // make them look at one batch index below, as the current entry will be removed from the batch.
    }
  }
  return newBeams;
}
//**********************************************************************
// Main decoding function: runs beam search over one batch.
//
// graph - expression graph used to run the scorers
// batch - source batch to translate
//
// Returns one History per original batch entry ([origDimBatch][t][N best hyps]).
//
// The search loops over output time steps and, within each step, over the factor groups of the
// target vocabulary (a single group if the vocabulary is not factored). It stops when every beam
// is empty, or as soon as a history reaches "max-length-factor" times the source batch width.
Histories BeamSearch::search(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch) {
  auto factoredVocab = trgVocab_->tryAs<FactoredVocab>();
  size_t numFactorGroups = factoredVocab ? factoredVocab->getNumGroups() : 1;
  if (numFactorGroups == 1) // if no factors then we didn't need this object in the first place
    factoredVocab.reset();

  // We will use the prefix "origBatch..." whenever we refer to batch dimensions of the original batch. These do not change during search.
  // We will use the prefix "currentBatch.." whenever we refer to batch dimension that can change due to batch-pruning.
  const int origDimBatch = (int)batch->size();
  const auto trgEosId = trgVocab_->getEosId();
  auto getNBestList = createGetNBestListFn(beamSize_, origDimBatch, graph->getDeviceId());

  for(auto scorer : scorers_) {
    scorer->clear(graph);
  }

  Histories histories(origDimBatch);
  for(int i = 0; i < origDimBatch; ++i) {
    size_t sentId = batch->getSentenceIds()[i];
    histories[i] = New<History>(sentId,
                                options_->get<float>("normalize"),
                                options_->get<float>("word-penalty"));
  }

  // start states
  std::vector<Ptr<ScorerState>> states;
  for(auto scorer : scorers_) {
    states.push_back(scorer->startState(graph, batch));
  }

  // create one beam per batch entry with sentence-start hypothesis
  Beams beams(origDimBatch, Beam(beamSize_, Hypothesis::New())); // array [origDimBatch] of array [maxBeamSize] of Hypothesis, keeps full size through search.
                                                                 // batch purging is determined from an empty sub-beam.
  std::vector<IndexType> batchIdxMap(origDimBatch); // Record at which batch entry a beam is looking.
                                                    // By default that corresponds to position in array,
                                                    // but shifts in the course of removing batch entries when they are finished.

  // Records which input batch entries are empty. It is filled once in the loop below and only
  // read afterwards (toHyps() receives it by const reference). It is deliberately NOT declared
  // const: writing to an object that was defined const through a const_cast is undefined behavior.
  std::vector<bool> emptyBatchEntries;
  emptyBatchEntries.reserve((size_t)origDimBatch);

  for(int origBatchIdx = 0; origBatchIdx < origDimBatch; ++origBatchIdx) {
    batchIdxMap[origBatchIdx] = origBatchIdx; // map to same position on initialization
    auto& beam = beams[origBatchIdx];
    histories[origBatchIdx]->add(beam, trgEosId); // add beams with start-hypotheses to traceback grid

    // Mark batch entries that consist only of source <EOS> i.e. these are empty lines. They will be forced to EOS and purged from batch
    const auto& srcEosId = batch->front()->vocab()->getEosId();
    emptyBatchEntries.push_back(batch->front()->data()[origBatchIdx] == srcEosId);
  }

  Expr suppressedWordIndices;
  bool suppressUnk = !options_->get<bool>("allow-unk", false);
  bool suppressSpecial = !options_->get<bool>("allow-special", false);
  if (suppressUnk || suppressSpecial) { // do we need to suppress unk or special?
    std::vector<WordIndex> suppressed = trgVocab_->suppressedIndices(suppressUnk, suppressSpecial);

    auto shortlist = scorers_[0]->getShortlist(); // first shortlist is generally ok, @TODO: make sure they are the same across scorers?
    if(shortlist) // check if suppressed words are allowed by the shortlist, if not, remove
      suppressed.erase(std::remove_if(suppressed.begin(),
                                      suppressed.end(),
                                      [&](WordIndex i) {
                                        return shortlist->tryForwardMap(i) == data::Shortlist::npos;
                                      }),
                       suppressed.end());

    if(!suppressed.empty())
      suppressedWordIndices = graph->indices(suppressed);
  }

  // the decoding process updates the following state information in each output time step:
  //  - beams: array [origDimBatch] of array [maxBeamSize] of Hypothesis
  //     - current output time step's set of active hypotheses, aka active search space
  //  - states[.]: ScorerState
  //     - NN state; one per scorer, e.g. 2 for ensemble of 2
  // and it forms the following return value
  //  - histories: array [origDimBatch] of History
  //    with History: vector [t] of array [maxBeamSize] of Hypothesis
  //    with Hypothesis: (last word, aggregate score, prev Hypothesis)

  IndexType currentDimBatch = origDimBatch;
  auto prevBatchIdxMap = batchIdxMap; // [origBatchIdx -> currentBatchIdx] but shifted by one time step

  // main loop over output time steps
  for (size_t t = 0; ; t++) {
    //std::cerr << "\nstep=" << t << std::endl;
    ABORT_IF((size_t)origDimBatch != beams.size(), "Lost a batch entry??");

    // determine beam size for next output time step, as max over still-active sentences
    // E.g. if all batch entries are down from beam 5 to no more than 4 surviving hyps, then
    // switch to beam of 4 for all. If all are done, then beam ends up being 0, and we are done.
    size_t maxBeamSize = 0; // @TODO: is there some std::algorithm for this?
    for(auto& beam : beams)
      if(beam.size() > maxBeamSize)
        maxBeamSize = beam.size();

    // done if all batch entries have reached EOS on all beam entries
    if (maxBeamSize == 0)
      break;

    for (size_t factorGroup = 0; factorGroup < numFactorGroups; factorGroup++) {
      // for factored vocabs, we do one factor at a time, but without updating the scorer for secondary factors

      //**********************************************************************
      // create constant containing previous path scores for current beam
      // Also create mapping of hyp indices, for reordering the decoder-state tensors.
      std::vector<IndexType> batchIndices; // [1, 1, currentDimBatch, 1] indices of currently used batch indices with regard to current, actual tensors
      std::vector<IndexType> hypIndices;   // [maxBeamSize, 1, currentDimBatch, 1] (flattened) tensor index ((beamHypIdx, batchIdx), flattened) of prev hyp that a hyp originated from
      std::vector<Word> prevWords;         // [maxBeamSize, 1, currentDimBatch, 1] (flattened) word that a hyp ended in, for advancing the decoder-model's history
      Expr prevPathScores;                 // [maxBeamSize, 1, currentDimBatch, 1], path score that a hyp ended in (last axis will broadcast into vocab size when adding expandedPathScores)

      bool anyCanExpand = false; // stays false if all hyps are invalid factor expansions
      if(t == 0 && factorGroup == 0) { // no scores yet
        prevPathScores = graph->constant({1, 1, 1, 1}, inits::fromValue(0));
        anyCanExpand = true;

        // at the beginning all batch entries are used
        batchIndices.resize(origDimBatch);
        std::iota(batchIndices.begin(), batchIndices.end(), 0);
      } else {
        if(factorGroup == 0) // only factorGroup==0 can subselect neural state
          for(int currentBatchIdx = 0; currentBatchIdx < (int)beams.size(); ++currentBatchIdx) // loop over batch entries (active sentences)
            if(!beams[currentBatchIdx].empty() || !PURGE_BATCH) // for each beam check
              batchIndices.push_back(prevBatchIdxMap[currentBatchIdx]); // which batch entries were active in previous step

        std::vector<float> prevScores;
        for(size_t beamHypIdx = 0; beamHypIdx < maxBeamSize; ++beamHypIdx) { // loop over globally maximal beam-size (maxBeamSize)
          for(int origBatchIdx = 0; origBatchIdx < origDimBatch; ++origBatchIdx) { // loop over all batch entries (active and inactive)
            auto& beam = beams[origBatchIdx];
            if(beamHypIdx < beam.size()) {
              auto hyp = beam[beamHypIdx];
              auto word = hyp->getWord();
              auto canExpand = (!factoredVocab || factoredVocab->canExpandFactoredWord(hyp->getWord(), factorGroup));
              //LOG(info, "[{}, {}] Can expand {} with {} -> {}", batchIdx, beamHypIdx, (*batch->back()->vocab())[hyp->getWord()], factorGroup, canExpand);
              anyCanExpand |= canExpand;

              auto currentBatchIdx = origBatchIdx;
              if(PURGE_BATCH) {
                if(factorGroup == 0)
                  currentBatchIdx = prevBatchIdxMap[origBatchIdx]; // subselection may happen for factorGroup == 0
                else
                  currentBatchIdx = batchIdxMap[origBatchIdx];     // no subselection happens for factorGroup > 0,
                                                                   // but we treat it like a next step, since a step
                                                                   // happened for factorGroup == 0
              }

              auto hypIndex = (IndexType)(hyp->getPrevStateIndex() * currentDimBatch + currentBatchIdx); // (beamHypIdx, batchIdx), flattened, for index_select() operation

              hypIndices.push_back(hypIndex); // (beamHypIdx, batchIdx), flattened as said above.
              prevWords .push_back(word);
              prevScores.push_back(canExpand ? hyp->getPathScore() : INVALID_PATH_SCORE);
            } else { // pad to maxBeamSize (dummy hypothesis)
              if(!PURGE_BATCH || !beam.empty()) { // but only if we are not pruning and the beam is not deactivated yet
                hypIndices.push_back(0);
                prevWords.push_back(trgEosId);  // (unused, but must be valid)
                prevScores.push_back((float)INVALID_PATH_SCORE);
              }
            }
          }
        }
        if(factorGroup == 0)
          currentDimBatch = (IndexType) batchIndices.size(); // keep batch size constant for all factor groups in a time step
        prevPathScores = graph->constant({(int)maxBeamSize, 1, (int)currentDimBatch, 1}, inits::fromVector(prevScores));
      }
      if (!anyCanExpand) // all words cannot expand this factor: skip
        continue;

      //**********************************************************************
      // compute expanded path scores with word prediction probs from all scorers
      auto expandedPathScores = prevPathScores; // will become [maxBeamSize, 1, currDimBatch, dimVocab]
      Expr logProbs;
      for(size_t i = 0; i < scorers_.size(); ++i) {
        if (factorGroup == 0) {
          // compute output probabilities for current output time step
          //  - uses hypIndices[index in beam, 1, batch index, 1] to reorder scorer state to reflect the top-N in beams[][]
          //  - adds prevWords [index in beam, 1, batch index, 1] to the scorer's target history
          //  - performs one step of the scorer
          //  - returns new NN state for use in next output time step
          //  - returns vector of prediction probabilities over output vocab via newState
          // update state in-place for next output time step
          //if (t > 0) for (size_t kk = 0; kk < prevWords.size(); kk++)
          //  LOG(info, "prevWords[{},{}]={} -> {}", t/numFactorGroups, factorGroup,
          //      factoredVocab ? factoredVocab->word2string(prevWords[kk]) : (*batch->back()->vocab())[prevWords[kk]],
          //      prevScores[kk]);
          states[i] = scorers_[i]->step(graph, states[i], hypIndices, prevWords, batchIndices, (int)maxBeamSize);
          if (numFactorGroups == 1) { // @TODO: this branch can go away
            logProbs = states[i]->getLogProbs().getLogits(); // [maxBeamSize, 1, currentDimBatch, dimVocab]
          } else {
            auto shortlist = scorers_[i]->getShortlist();
            logProbs = states[i]->getLogProbs().getFactoredLogits(factorGroup, shortlist); // [maxBeamSize, 1, currentDimBatch, dimVocab]
          }
        }
        else {
          // add secondary factors
          // For those, we don't update the decoder-model state in any way.
          // Instead, we just keep expanding with the factors.
          // We will have temporary Word entries in hyps with some factors set to FACTOR_NOT_SPECIFIED.
          // For some lemmas, a factor is not applicable. For those, the factor score is the same (zero)
          // for all factor values. This would thus unnecessarily pollute the beam with identical copies,
          // and push out other hypotheses. Hence, we exclude those here by setting the path score to
          // INVALID_PATH_SCORE. Instead, toHyps() explicitly propagates those hyps by simply copying the
          // previous hypothesis.
          logProbs = states[i]->getLogProbs().getFactoredLogits(factorGroup, /*shortlist=*/ nullptr, hypIndices, maxBeamSize); // [maxBeamSize, 1, currentDimBatch, dimVocab]
        }
        // expand all hypotheses, [maxBeamSize, 1, currentDimBatch, 1] -> [maxBeamSize, 1, currentDimBatch, dimVocab]
        expandedPathScores = expandedPathScores + scorers_[i]->getWeight() * logProbs;
      }

      // make beams continuous
      expandedPathScores = swapAxes(expandedPathScores, 0, 2); // -> [currentDimBatch, 1, maxBeamSize, dimVocab]

      // perform NN computation
      if(t == 0 && factorGroup == 0)
        graph->forward();
      else
        graph->forwardNext();

      //**********************************************************************
      // suppress specific symbols if not at right positions
      if(suppressedWordIndices && factorGroup == 0)
        suppressWords(expandedPathScores, suppressedWordIndices);

      //**********************************************************************
      // perform beam search

      // find N best amongst the (maxBeamSize * dimVocab) hypotheses
      std::vector<unsigned int> nBestKeys; // [currentDimBatch, maxBeamSize] flattened -> (batchIdx, beamHypIdx, word idx) flattened
      std::vector<float> nBestPathScores;  // [currentDimBatch, maxBeamSize] flattened
      getNBestList(/*in*/ expandedPathScores->val(), // [currentDimBatch, 1, maxBeamSize, dimVocab or dimShortlist]
                   /*N=*/ maxBeamSize,               // desired beam size
                   /*out*/ nBestPathScores,
                   /*out*/ nBestKeys,
                   /*first=*/t == 0 && factorGroup == 0); // @TODO: this is only used for checking presently, and should be removed altogether
      // Now, nBestPathScores contain N-best expandedPathScores for each batch and beam,
      // and nBestKeys for each their original location (batchIdx, beamHypIdx, word).

      // combine N-best sets with existing search space (beams) to updated search space
      beams = toHyps(nBestKeys, nBestPathScores,
                     /*nBestBeamSize*/expandedPathScores->shape()[-2], // used for interpretation of keys
                     /*vocabSize=*/expandedPathScores->shape()[-1],    // used for interpretation of keys
                     beams,
                     states,            // used for keeping track of per-ensemble-member path score
                     batch,             // only used for propagating alignment info
                     factoredVocab, factorGroup,
                     emptyBatchEntries, // [origDimBatch] - empty source batch entries are marked with true
                     batchIdxMap);      // used to create a reverse batch index map to recover original batch indices for this step
    } // END FOR factorGroup = 0 .. numFactorGroups-1

    prevBatchIdxMap = batchIdxMap; // save current batchIdx map to be used in next step; we are then going to look one step back

    // remove all hyps that end in EOS
    // The position of a hyp in the beam may change.
    // in/out = shifts the batch index map if a beam gets fully purged
    const auto purgedNewBeams = purgeBeams(beams, /*in/out=*/batchIdxMap);

    // add updated search space (beams) to our return value
    bool maxLengthReached = false;
    for(int batchIdx = 0; batchIdx < origDimBatch; ++batchIdx) {
      // if this batch entry has surviving hyps then add them to the traceback grid
      if(!beams[batchIdx].empty()) { // if the beam is not empty expand the history object associated with the beam
        if (histories[batchIdx]->size() >= options_->get<float>("max-length-factor") * batch->front()->batchWidth())
          maxLengthReached = true;
        // the last argument marks the final addition for this entry: either all hyps finished or the length limit was hit
        histories[batchIdx]->add(beams[batchIdx], trgEosId, purgedNewBeams[batchIdx].empty() || maxLengthReached);
      }
    }
    if (maxLengthReached) // early exit if max length limit was reached
      break;

    // this is the search space for the next output time step
    beams = purgedNewBeams;
  } // end of main loop over output time steps

  return histories; // [origDimBatch][t][N best hyps]
}
} // namespace marian