Program Listing for File topk.cu

Return to documentation for file (src/tensors/gpu/topk.cu)

#include "tensors/tensor_operators.h"
#include "tensors/gpu/cuda_helpers.h"
#include "tensors/allocator.h"

#include <cuda.h>

// GPU implementation of the proper Marian top-k operator for TopkNodeOp.
// This file contains a lot of code duplication with src/translator/nth_element.cu;
// the goal is to replace the beam-search-specific top-k search with this code.
// Currently this is only used in the unit tests, but we will move forward and
// make the beam search more graph- and operator-based.

namespace marian {
namespace gpu {

// Upper limit for the number of bins (blocks of the first kernel) used per row.
constexpr int MAX_BINS = 500;
// Threads per block for both kernels; also the static size of the shared index buffers.
constexpr int BLOCK_SIZE = 512;

// One step of the arg-max tree reduction over shared memory: slot `tid` takes over value and
// index of slot `tid + n` if that value is strictly larger and `tid + n` is below `max`.
//
// Expects `tid`, `sharedValues` and `sharedIndices` to be in scope at the expansion site and
// has to be reached by all threads of a warp (do not expand it inside a divergent branch).
//
// The trailing __syncwarp() orders the shared-memory writes of this step before the reads of
// the next step. Threads of a warp are not guaranteed to run in lock-step (independent thread
// scheduling on Volta and newer), so chaining steps without it is a data race.
// The do { } while(0) wrapper makes the expansion behave like a single statement.
#define UNROLL_MAXARG_LOOP(n, max)                      \
  do {                                                  \
    if(tid < (n) && tid + (n) < (max)) {                \
      if(sharedValues[tid + (n)] > sharedValues[tid]) { \
        sharedIndices[tid] = sharedIndices[tid + (n)];  \
        sharedValues[tid]  = sharedValues[tid + (n)];   \
      }                                                 \
    }                                                   \
    __syncwarp();                                       \
  } while(0)

// Finds the maximum element per row and per block (first step of the top-k search).
//
// Launch layout: 1D grid; every block produces one bin (index + value) per row, i.e. the outputs
// have layout [rows, gridDim.x]. The block must not have more than BLOCK_SIZE threads because
// sharedIndices is statically sized; sharedValues needs blockDim.x * sizeof(float) bytes of
// dynamic shared memory.
//
// All values are multiplied by `flip` (+1 for descending, -1 otherwise) so that the code only
// ever has to search for a maximum. The values written to binValues are the flipped values.
template <typename T>
__global__ void gMaxElement(IndexType* binIndices, // out: per row and block, position of the best element (offset relative to inValues start)
                            T* binValues,          // out: per row and block, flipped value of the best element
                            const T* inValues,     // this is the probs array, only one with type float or half
                            int rows,              // we iterate over this many rows, row-major layout
                            int cols,              // a row has that many columns, row-major layout
                            float minimal,         // value used to initialize the search; compared against the flipped values
                            bool descending)       // true: look for the largest values; false: look for the smallest values (via sign flip)
{
  extern __shared__ float sharedValues[];
  __shared__ IndexType sharedIndices[BLOCK_SIZE];

  // id of current thread within block
  int tid = threadIdx.x;

  float flip = descending ? 1.f : -1.f;

  // Number of shared-memory slots that receive an actual input value below: thread `tid` loads
  // data if and only if blockIdx.x * (blockDim.x * 2) + tid < cols (independent of the row).
  // All other slots only hold `minimal` together with an undefined index, hence they must not
  // take part in the reduction. Otherwise an input below `minimal` (e.g. -inf) would lose
  // against such a slot and an undefined index would be written to binIndices.
  const int numValid = cols - (int)(blockIdx.x * (blockDim.x * 2));

  // Roll over every row in row-major 2D representation of the data
  for(int rowIdx = 0; rowIdx < rows; ++rowIdx) {
    int begin = rowIdx * cols;        // start index of a row
    int end   = rowIdx * cols + cols; // end index of a row

    // Each thread reduces two values per sweep, i.e. a block looks at blockDim.x * 2 values at once.
    // Here we set the position to begin + blockId * (blockDim.x * 2) + threadId. If a row has more
    // values we partition the row according to chunks of blockDim.x * 2 values.
    int i = begin + blockIdx.x * (blockDim.x * 2) + tid;

    // Initialize shared values to minimal value.
    sharedValues[tid] = minimal;

    // Do first set of comparisons outside loop, saves one iteration.
    if(i + blockDim.x < end) { // Can we access and compare two values in a row partition (shifted by block size)?
      // yes, hence compare:
      float a = flip * (float)inValues[i];              // value from first half of row partition for this block
      float b = flip * (float)inValues[i + blockDim.x]; // value from second half of row partition for this block
      if(a > b) { // just a max
        sharedIndices[tid] = i;
        sharedValues[tid]  = a;
      } else {
        sharedIndices[tid] = i + blockDim.x;
        sharedValues[tid]  = b;
      }
    } else if(i < end) { // Only one value accessible (shifting by block size would be out of bounds)?
      // Yes, hence save the current value and index as new max, no need to compare.
      sharedIndices[tid] = i;
      sharedValues[tid]  = flip * (float)inValues[i];
    } // nothing else to do here

    // We move to the next chunk, shifted by two times block size times number of blocks,
    // and look at two values according to thread id.
    while(i + 2 * gridDim.x * blockDim.x < end) {
      i += 2 * gridDim.x * blockDim.x;

      // Check if first value is larger than what we have seen so far
      float a = flip * (float)inValues[i];
      if(a > sharedValues[tid]) {
        // Yes, hence save index and value
        sharedIndices[tid] = i;
        sharedValues[tid]  = a;
      }

      // Check if second value is larger than what we have seen so far
      if(i + blockDim.x < end) {
        float b = flip * (float)inValues[i + blockDim.x];
        if(b > sharedValues[tid]) {
          // Yes, hence save index and value
          sharedIndices[tid] = i + blockDim.x;
          sharedValues[tid]  = b;
        }
      }
    }

    // We are done with the first sweep and have populated shared memory, time to wait for the other threads and reduce it all
    __syncthreads();

    // Reduce over shared memory, here per loop until we hit the last 64 unreduced elements.
    // Only slots below numValid carry data, see above.
    for(int s = (blockDim.x >> 1); s > 32; s >>= 1) {
      if(tid < s && tid + s < numValid) {
        if(sharedValues[tid + s] > sharedValues[tid]) {
          // keep the max
          sharedIndices[tid] = sharedIndices[tid + s];
          sharedValues[tid]  = sharedValues[tid + s];
        }
      }
      __syncthreads();
    }

    // Remaining reduction steps for strides 32, 16, ..., 1; these only involve threads of the first warp.
    UNROLL_MAXARG_LOOP(32, numValid);
    UNROLL_MAXARG_LOOP(16, numValid);
    UNROLL_MAXARG_LOOP(8, numValid);
    UNROLL_MAXARG_LOOP(4, numValid);
    UNROLL_MAXARG_LOOP(2, numValid);
    UNROLL_MAXARG_LOOP(1, numValid);

    // OK, we are done with the reduction and in the first thread
    if(tid == 0) {
      // assign the final maximal value to the bin, one bin per row and block
      binIndices[rowIdx * gridDim.x + blockIdx.x] = sharedIndices[0]; // [rows, num_blocks]
      binValues[rowIdx * gridDim.x + blockIdx.x]  = sharedValues[0];  // [rows, num_blocks]
    }
    // make sure thread 0 has read slot 0 before shared memory gets re-initialized for the next row
    __syncthreads();
  }
}

// Second step of the top-k search. It consumes the per-row, per-bin maxima (binIndices/binValues,
// layout [rows, numBlocks], values already flipped) and extracts K results per row.
// As in the first step we pretend that this only does a maximum search (values are flipped by sign
// if `descending` is false).
//
// Launch layout: one block per row (rowIdx = blockIdx.x). The block must not have more than
// BLOCK_SIZE threads because sharedIndices is statically sized; sharedValues needs
// blockDim.x * sizeof(float) bytes of dynamic shared memory. The refill step below assumes that
// the bins were produced with the same block size, i.e. bin b covers chunks of blockDim.x * 2
// elements starting at offset b * blockDim.x * 2 with a stride of numBlocks * blockDim.x * 2.
//
// For every k the kernel
//   (1) reduces over the bins of the row to find the bin holding the current best element,
//   (2) emits that element and overwrites it in inValues with `flip * minimal` to mark it as used,
//   (3) re-computes the maximum of that bin from inValues (skipped for the last k).
// At the end the overwritten input values are restored from outValues.
template <typename T>
__global__ void gMaxElementUpdate(IndexType* binIndices, // memory for bin indices
                                  T* binValues,          // memory for bin costs
                                  IndexType* outIndices, // result indices
                                  T* outValues,          // result costs
                                  T* inValues,           // should work well enough with half, uses float everywhere else
                                  const int cols,        // size of continous memory we search over
                                  const int K,           // how many top-K elements?
                                  int numBlocks,        // number of blocks/bins used in above function (per row)
                                  float minimal,         // value for minimal element
                                  bool descending)
{
  extern __shared__ float sharedValues[];
  __shared__ int    sharedIndices[BLOCK_SIZE];
  __shared__ float  bestBinCost;
  __shared__ int    bestBinCostIdx;

  const int tid    = threadIdx.x;

  float flip = descending ? 1.f : -1.f;

  // we only look at one row in this kernel
  const int rowIdx = blockIdx.x;           // index of the row corresponds to block index
  const int begin  = rowIdx * cols;        // start offset for this row relative to inValues tensor start
  const int end    = rowIdx * cols + cols; // end offset for this row relative to inValues tensor start

  // iterate over top-k results
  for(int k = 0; k < K; ++k) {

    int kthOutIdx = rowIdx * K + k;  // offset into output tensor relative to outIndices/outValues tensor start
    int i = tid;

    sharedValues[tid] = minimal; // initialize to smallest value, everything else will be larger

    // (1) Tree reduction over the bins of this row to find the single best bin. Thread `tid`
    // populates its slot if and only if tid < numBlocks, hence numBlocks bounds the reduction.
    if(i + blockDim.x < numBlocks) {
      float a = binValues[rowIdx * numBlocks + i];
      float b = binValues[rowIdx * numBlocks + i + blockDim.x];
      if(a > b) {
        sharedValues[tid] = a;
        sharedIndices[tid] = i;
      } else {
        sharedValues[tid] = b;
        sharedIndices[tid] = i + blockDim.x;
      }
    } else if(i < numBlocks) {
      sharedValues[tid] = binValues[rowIdx * numBlocks + i];
      sharedIndices[tid] = i;
    }

    while(i + 2 * blockDim.x < numBlocks) {
      i += 2 * blockDim.x;

      float a = binValues[rowIdx * numBlocks + i];
      if(a > sharedValues[tid]) {
        sharedValues[tid] = a;
        sharedIndices[tid] = i;
      }

      if(i + blockDim.x < numBlocks) {
        float b = binValues[rowIdx * numBlocks + i + blockDim.x];
        if(b > sharedValues[tid]) {
          sharedValues[tid] = b;
          sharedIndices[tid] = i + blockDim.x;
        }
      }
    }

    __syncthreads();

    for(int s = (blockDim.x >> 1); s > 32; s >>= 1) {
      if(tid < s && tid + s < numBlocks) {
        if(sharedValues[tid + s] > sharedValues[tid]) {
          sharedValues[tid] = sharedValues[tid + s];
          sharedIndices[tid] = sharedIndices[tid + s];
        }
      }
      __syncthreads();
    }

    UNROLL_MAXARG_LOOP(32, numBlocks);
    UNROLL_MAXARG_LOOP(16, numBlocks);
    UNROLL_MAXARG_LOOP(8, numBlocks);
    UNROLL_MAXARG_LOOP(4, numBlocks);
    UNROLL_MAXARG_LOOP(2, numBlocks);
    UNROLL_MAXARG_LOOP(1, numBlocks);

    // (2) Emit the k-th result and mark the element as used.
    if(tid == 0) {
      bestBinCost = sharedValues[0];
      bestBinCostIdx = rowIdx * numBlocks + sharedIndices[0];

      inValues[binIndices[bestBinCostIdx]] = flip * minimal; // this is restored in the last lines of this function

      outIndices[kthOutIdx] = binIndices[bestBinCostIdx] - begin; // relative to beginning of row hence substract `begin`
      outValues[kthOutIdx]  = flip * bestBinCost; // undo flip by flipping again
    }

    // publishes bestBinCostIdx and the marked input element to all threads of the block
    __syncthreads();

    // (3) After marking the previously largest element with "flip * minimal" we re-compute the
    // maximum of the bin it came from, walking over the same elements as gMaxElement(...) did for
    // that bin. Not required after the last result has been emitted.
    if(k < K - 1) {
      // offset of the first chunk of the selected bin relative to the start of the row
      const int binOffset = (bestBinCostIdx - rowIdx * numBlocks) * (blockDim.x * 2);
      i = begin + binOffset + tid;
      const int dist = numBlocks * 2 * blockDim.x;

      // Number of shared-memory slots that receive an actual input value below: thread `tid`
      // loads data if and only if binOffset + tid < cols. All other slots only hold `minimal`
      // together with a stale index from step (1), hence they must not take part in the
      // reduction. Otherwise an input below `minimal` (e.g. -inf) would lose against such a slot
      // and a meaningless index would end up in binIndices, which is later used to write to inValues.
      const int numValid = cols - binOffset;

      sharedValues[tid] = minimal;

      if(i + blockDim.x < end) {
        float a = flip * (float)inValues[i];
        float b = flip * (float)inValues[i + blockDim.x];
        if(a > b) {
          sharedIndices[tid] = i;
          sharedValues[tid]  = a;
        } else {
          sharedIndices[tid] = i + blockDim.x;
          sharedValues[tid]  = b;
        }
      } else if(i < end) {
        sharedIndices[tid] = i;
        sharedValues[tid] = flip * (float)inValues[i];
      }

      while(i + dist < end) {
        i += dist;

        float a = flip * (float)inValues[i];
        if(a > sharedValues[tid]) {
          sharedIndices[tid] = i;
          sharedValues[tid]  = a;
        }

        if(i + blockDim.x < end) {
          float b = flip * (float)inValues[i + blockDim.x];
          if(b > sharedValues[tid]) {
            sharedIndices[tid] = i + blockDim.x;
            sharedValues[tid] =  b;
          }
        }
      }

      __syncthreads();

      for(int s = (blockDim.x >> 1); s > 32; s >>= 1) {
        if(tid < s && tid + s < numValid) {
          if(sharedValues[tid + s] > sharedValues[tid]) {
            sharedIndices[tid] = sharedIndices[tid + s];
            sharedValues[tid] = sharedValues[tid + s];
          }
        }
        __syncthreads();
      }

      UNROLL_MAXARG_LOOP(32, numValid);
      UNROLL_MAXARG_LOOP(16, numValid);
      UNROLL_MAXARG_LOOP(8, numValid);
      UNROLL_MAXARG_LOOP(4, numValid);
      UNROLL_MAXARG_LOOP(2, numValid);
      UNROLL_MAXARG_LOOP(1, numValid);

      // store the new maximum of the bin for the next iteration over k
      if(tid == 0) {
        binIndices[bestBinCostIdx] = sharedIndices[0];
        binValues[bestBinCostIdx]  = sharedValues[0];
      }
      __syncthreads();
    }
  }

  // final operation to restore blanked-out input values. They were blanked out for marking
  // already found values. Since we want input values to be invariant we restore here.
  // @TODO: The lack of constness here might be a problem for concurrent processing (which we currently don't have)
  for(int k = tid; k < K; k += blockDim.x) {
    int kthOutIdx = rowIdx * K + k;
    inValues[begin + outIndices[kthOutIdx]] = outValues[kthOutIdx];
  }
}

// Host entry point of the top-k operator: selects the k largest (descending == true) or the k
// smallest (descending == false) elements along the last axis of `in`.
//
//   outVal     - receives the selected values, has to have the same type as `in`
//   outInd     - receives the positions of the selected values relative to the start of their
//                row, has to be of type uint32
//   allocator  - used for the temporary per-row, per-bin buffers; they are freed before returning
//   in         - input tensor of float type (float32, or float16 if compiled with COMPILE_FP16).
//                Elements are temporarily overwritten by the second kernel and restored at its
//                end, hence the tensor must not be accessed concurrently
//   k          - number of elements to select per row, must not exceed the size of the last axis
//   axis       - has to be the index of the last axis
//   descending - sort order, see above
//
// Both kernels are launched on the default stream. Aborts on violated preconditions and on CUDA
// errors reported for the device selection or the kernel launches.
void TopK(Tensor outVal, Tensor outInd, Ptr<Allocator> allocator, const Tensor in, int k, int axis, bool descending) {

  ABORT_IF(axis != in->shape().size() - 1, "Currently only works for last axis");
  ABORT_IF(!isFloat(in->type()), "Input should be float type and not {}", in->type());
  ABORT_IF(outInd->type() != Type::uint32, "Output should be have type {}", Type::uint32);
  ABORT_IF(outVal->type() != in->type(), "Output should be have type {}", in->type());

  cudaError_t status = cudaSetDevice(outInd->getDeviceId().no);
  ABORT_IF(status != cudaSuccess, "CUDA error in TopK while selecting device: {}", cudaGetErrorString(status));

  int cols = in->shape()[-1];               // e.g. in beam search that would be [beam * dimVoc]
  // guard the division below and the grid size computation against an empty last axis
  ABORT_IF(cols <= 0, "Cannot compute top-k over empty axis {}", axis);
  int rows = in->shape().elements() / cols; // e.g. in beam search that would be [time * batch]

  ABORT_IF(k > cols, "Cannot select more than {} elements for axis {}", cols, axis);

  // Nothing to select; also avoids launching the second kernel with an invalid grid of 0 blocks.
  if(rows == 0 || k <= 0)
    return;

  float minimal = NumericLimits<float>(in->type()).lowest; // lowest if looking for max

  // one bin per chunk of 2 * BLOCK_SIZE columns (rounded up), but not more than MAX_BINS
  const int numBlocks = std::min(MAX_BINS, int(cols / (2 * BLOCK_SIZE)) + int(cols % (2 * BLOCK_SIZE) != 0));
  auto tempMemInd = allocator->alloc<IndexType>(rows * numBlocks);

  MemoryPiece::PtrType tempMemVal;
  if(in->type() == Type::float32) {
    tempMemVal = allocator->alloc<float>(rows * numBlocks);
    // first find the maximum value per row and block and save indices and values to temporary memory
    gMaxElement<<<numBlocks, // blocks
                  BLOCK_SIZE, // threads
                  BLOCK_SIZE * sizeof(float), // shared memory size
                  /* stream_ */ 0>>>(
      tempMemInd->data<IndexType>(), tempMemVal->data<float>(),
      in->data<float>(), rows, cols, minimal, descending);
    // then extract the k best elements per row, one block per row
    gMaxElementUpdate<<<rows,       // blocks ... seems we can have up to 2^31-1 of these, so we are safe?
                        BLOCK_SIZE, // threads
                        BLOCK_SIZE * sizeof(float),  // shared memory size
                        /* stream_ */ 0>>>(
      tempMemInd->data<IndexType>(), tempMemVal->data<float>(),
      outInd->data<IndexType>(), outVal->data<float>(),
      in->data<float>(), cols, k, numBlocks, minimal, descending);
#if COMPILE_FP16
  } else if(in->type() == Type::float16) {
    tempMemVal = allocator->alloc<__half>(rows * numBlocks);
    // first find the maximum value per row and block and save indices and values to temporary memory
    gMaxElement<<<numBlocks, // blocks
                  BLOCK_SIZE, // threads
                  BLOCK_SIZE * sizeof(float), // shared memory size
                  /* stream_ */ 0>>>(
      tempMemInd->data<IndexType>(), tempMemVal->data<__half>(),
      in->data<__half>(), rows, cols, minimal, descending);
    // then extract the k best elements per row, one block per row
    gMaxElementUpdate<<<rows,       // blocks ... seems we can have up to 2^31-1 of these, so we are safe?
                        BLOCK_SIZE, // threads
                        BLOCK_SIZE * sizeof(float),  // shared memory size
                        /* stream_ */ 0>>>(
      tempMemInd->data<IndexType>(), tempMemVal->data<__half>(),
      outInd->data<IndexType>(), outVal->data<__half>(),
      in->data<__half>(), cols, k, numBlocks, minimal, descending);
#endif
  } else {
    ABORT("Topk not implemented for type {}", in->type());
  }

  // Kernel launches do not return a status themselves. A failed launch (e.g. invalid launch
  // configuration) is only visible via cudaGetLastError(); a successful launch does not reset
  // the recorded error, so one check covers both launches.
  status = cudaGetLastError();
  ABORT_IF(status != cudaSuccess, "CUDA error in TopK while launching kernels: {}", cudaGetErrorString(status));

  allocator->free(tempMemInd);
  allocator->free(tempMemVal);
}

}
}