Program Listing for File topk.cpp

Return to documentation for file (src/tensors/cpu/topk.cpp)

#include "tensors/tensor_operators.h"
#include "tensors/allocator.h"

#include <algorithm>
#include <numeric>
#include <vector>

// CPU implementation of the proper Marian top-k operator for TopkNodeOp.
// This file contains a lot of code duplication with src/translator/nth_element.cpp;
// the goal is to replace the beam-search-specific top-k search with this code.
// Currently this is only used in the unit tests, but we will move forward and
// make the beam search more graph- and operator-based.

namespace marian {
namespace cpu {

// Selects the k best entries of every row of `in` along its last axis.
//
// `in` is treated as a row-major matrix of `rows x cols` floats, where `cols` is
// the size of `axis` and `rows` is the product of all remaining dimensions.
// For each row the k selected values are written to `outVal` and their column
// positions (relative to the start of the row) to `outInd`, k entries per row,
// sorted from best to worst.
//
// @param outVal      receives the selected values (float32)
// @param outInd      receives the selected column indices (uint32)
// @param allocator   unused by the CPU implementation
// @param in          input tensor (float32)
// @param k           number of entries to select per row, 0 <= k <= cols
// @param axis        axis to select along; must be the last axis of `in`
// @param descending  true: pick the largest values, false: pick the smallest
//
// Aborts (ABORT_IF) if any of the constraints above is violated.
// NOTE(review): assumes `outVal` and `outInd` each hold at least rows * k
// elements; this is not checked here - confirm against the caller.
//
// Ties between equal values are resolved in favour of the lower column index,
// so the result for a row depends only on that row's contents.
void TopK(Tensor outVal, Tensor outInd, Ptr<Allocator> /*allocator*/, const Tensor in, int k, int axis, bool descending) {

  // Explicit cast keeps the comparison signed on both sides.
  ABORT_IF(axis != static_cast<int>(in->shape().size()) - 1, "Currently only works for last axis");
  ABORT_IF(in->type() != Type::float32, "Input should have type {}", Type::float32);
  ABORT_IF(outVal->type() != Type::float32, "Output values should have type {}", Type::float32);
  ABORT_IF(outInd->type() != Type::uint32, "Output indices should have type {}", Type::uint32);

  int cols = in->shape()[axis];
  // Guard the division below against an empty selection axis.
  ABORT_IF(cols <= 0, "Cannot select from axis {} with {} elements", axis, cols);
  int rows = in->shape().elements() / cols;

  // A negative k would make `idxs.begin() + k` undefined behaviour.
  ABORT_IF(k < 0, "Cannot select a negative number of elements ({})", k);
  ABORT_IF(k > cols, "Cannot select more than {} elements for axis {}", cols, axis);

  // Scratch buffer of column indices. It is reused for every row; it always
  // remains a permutation of 0..cols-1, which is all partial_sort needs.
  std::vector<IndexType> idxs(cols);
  std::iota(idxs.begin(), idxs.end(), 0);

  const float* inDataPtr = in->data<float>();
  IndexType* outIndPtr   = outInd->data<IndexType>();
  float* outValPtr       = outVal->data<float>();

  // Orders column indices by the value they refer to in the current row.
  // Captures inDataPtr by reference so it follows the row pointer as it advances.
  // The index tie-break makes the order total, so the outcome does not depend
  // on how previous rows left `idxs` permuted.
  const auto ranksBefore = [&](IndexType a, IndexType b) -> bool {
    const float va = inDataPtr[a];
    const float vb = inDataPtr[b];
    if(va != vb)
      return descending ? va > vb : va < vb;
    return a < b;
  };

  for(int i = 0; i < rows; ++i) {
    // Moves the k best indices of this row to the front, in sorted order.
    std::partial_sort(idxs.begin(), idxs.begin() + k, idxs.end(), ranksBefore);

    for(int j = 0; j < k; j++) {
      outIndPtr[j] = idxs[j];
      outValPtr[j] = inDataPtr[idxs[j]];
    }

    // Advance to the next row of the input and of both outputs.
    outIndPtr += k;
    outValPtr += k;
    inDataPtr += cols;
  }
}

}
}