.. _program_listing_file_src_translator_nth_element.cu:

Program Listing for File nth_element.cu
=======================================

|exhale_lsh| :ref:`Return to documentation for file <file_src_translator_nth_element.cu>` (``src/translator/nth_element.cu``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   /* All or part of this file was contributed by Intel under license:
    * Copyright (C) 2017-2018 Intel Corporation
    * SPDX-License-Identifier: MIT
    */

   #include <cuda.h>
   #include "translator/nth_element.h"
   #include <iostream>
   #include "tensors/gpu/cuda_helpers.h"

   namespace marian {

   #define UNROLL_MAXARG_LOOP(n, max)       \
     if(tid < (n) && tid + (n) < (max)) {   \
       if(sdata[tid + (n)] > sdata[tid]) {  \
         sdata[tid] = sdata[tid + (n)];     \
         indices[tid] = indices[tid + (n)]; \
       }                                    \
     }

   template <typename T>
   __global__ void gMaxElement(float* d_out,
                               int* d_ind,
                               T* d_in, // this is the probs array, only one with type float or half
                               int numBatches,
                               int* batchFirstElementIdxs,
                               float disabledPathScore) // disabledPathScore is used to blank out found values, type-dependent
   {
     extern __shared__ float sdata[];
     __shared__ int indices[512];

     int tid = threadIdx.x;

     for(int batchIdx = 0; batchIdx < numBatches; ++batchIdx) {
       int begin = batchFirstElementIdxs[batchIdx];
       int end = batchFirstElementIdxs[batchIdx + 1];

       int i = begin + blockIdx.x * (blockDim.x * 2) + tid;

       sdata[tid] = disabledPathScore;

       if(i < end) {
         sdata[tid] = (float)d_in[i];
         indices[tid] = i;
       }

       if(i + blockDim.x < end) {
         float a = (float)d_in[i];
         float b = (float)d_in[i + blockDim.x];
         if(a > b) {
           sdata[tid] = a;
           indices[tid] = i;
         } else {
           sdata[tid] = b;
           indices[tid] = i + blockDim.x;
         }
       }

       while(i + 2 * gridDim.x * blockDim.x < end) {
         i += 2 * gridDim.x * blockDim.x;

         float a = (float)d_in[i];
         if(a > sdata[tid]) {
           sdata[tid] = a;
           indices[tid] = i;
         }

         if(i + blockDim.x < end) {
           float b = (float)d_in[i + blockDim.x];
           if(b > sdata[tid]) {
             sdata[tid] = b;
             indices[tid] = i + blockDim.x;
           }
         }
       }

       __syncthreads();

       for(int s = (blockDim.x >> 1); s > 32; s >>= 1) {
         if(tid < s && tid + s < end) {
           if(sdata[tid + s] > sdata[tid]) {
             sdata[tid] = sdata[tid + s];
             indices[tid] = indices[tid + s];
           }
         }
         __syncthreads();
       }

       UNROLL_MAXARG_LOOP(32, end);
       UNROLL_MAXARG_LOOP(16, end);
       UNROLL_MAXARG_LOOP(8, end);
       UNROLL_MAXARG_LOOP(4, end);
       UNROLL_MAXARG_LOOP(2, end);
       UNROLL_MAXARG_LOOP(1, end);

       if(tid == 0) {
         d_out[blockIdx.x + batchIdx * gridDim.x] = sdata[0];
         d_ind[blockIdx.x + batchIdx * gridDim.x] = indices[0];
       }
       __syncthreads();
     }
   }
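
   // gMaxElement is the first phase of a two-phase batched argmax. Each block
   // reduces a strided slice of one batch entry's scores to a single
   // (max, argmax) candidate via the shared-memory tree reduction above and
   // writes one partial result per block and batch into d_out/d_ind.
   // gMaxElementUpdate below then repeatedly reduces those per-block "bins" to
   // the global maximum, emits it, blanks the winning score out with
   // disabledPathScore, and re-reduces only the winner's bin, once per
   // requested hypothesis.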
   template <typename T>
   __global__ void gMaxElementUpdate(float* binCosts,
                                     int* binIdxs,
                                     T* probs, // should work well enough with half, uses float everywhere else
                                     int* batchFirstElements,
                                     float* outCosts,
                                     int* outIdxs,
                                     int* cumulativeBeamSizes,
                                     int NUM_BLOCKS,
                                     float disabledPathScore) {
     extern __shared__ float sdata[];
     __shared__ int indices[512];
     __shared__ float bestBinCost;
     __shared__ int bestBinCostIdx;

     const int tid = threadIdx.x;
     const int batchIdx = blockIdx.x;
     const int N = batchFirstElements[batchIdx + 1] - batchFirstElements[batchIdx];
     int num_bins = int(N / (2 * 512)) + int(N % (2 * 512) != 0);
     if(num_bins > 500) {
       num_bins = 500;
     }

     for(int pos = cumulativeBeamSizes[batchIdx]; pos < cumulativeBeamSizes[batchIdx + 1]; ++pos) {
       int i = tid;

       sdata[tid] = disabledPathScore;

       if(i < num_bins) {
         sdata[tid] = binCosts[batchIdx * NUM_BLOCKS + i];
         indices[tid] = i;
       }

       if(i + blockDim.x < num_bins) {
         float a = binCosts[batchIdx * NUM_BLOCKS + i];
         float b = binCosts[batchIdx * NUM_BLOCKS + i + blockDim.x];
         if(a > b) {
           sdata[tid] = a;
           indices[tid] = i;
         } else {
           sdata[tid] = b;
           indices[tid] = i + blockDim.x;
         }
       }

       while(i + 2 * blockDim.x < num_bins) {
         i += 2 * blockDim.x;

         float a = binCosts[batchIdx * NUM_BLOCKS + i];
         if(a > sdata[tid]) {
           sdata[tid] = a;
           indices[tid] = i;
         }

         if(i + blockDim.x < num_bins) {
           float b = binCosts[batchIdx * NUM_BLOCKS + i + blockDim.x];
           if(b > sdata[tid]) {
             sdata[tid] = b;
             indices[tid] = i + blockDim.x;
           }
         }
       }

       __syncthreads();

       for(int s = (blockDim.x >> 1); s > 32; s >>= 1) {
         if(tid < s && tid + s < num_bins) {
           if(sdata[tid + s] > sdata[tid]) {
             sdata[tid] = sdata[tid + s];
             indices[tid] = indices[tid + s];
           }
         }
         __syncthreads();
       }

       UNROLL_MAXARG_LOOP(32, num_bins);
       UNROLL_MAXARG_LOOP(16, num_bins);
       UNROLL_MAXARG_LOOP(8, num_bins);
       UNROLL_MAXARG_LOOP(4, num_bins);
       UNROLL_MAXARG_LOOP(2, num_bins);
       UNROLL_MAXARG_LOOP(1, num_bins);

       if(tid == 0) {
         bestBinCost = sdata[0];
         bestBinCostIdx = batchIdx * NUM_BLOCKS + indices[0];

         probs[binIdxs[bestBinCostIdx]] = disabledPathScore;

         outIdxs[pos] = binIdxs[bestBinCostIdx];
         outCosts[pos] = bestBinCost;
       }

       __syncthreads();

       i = batchFirstElements[batchIdx]
           + (bestBinCostIdx - batchIdx * NUM_BLOCKS) * (blockDim.x * 2) + tid;
       const int dist = num_bins * 2 * blockDim.x;

       sdata[tid] = disabledPathScore;

       if(i < batchFirstElements[batchIdx + 1]) {
         sdata[tid] = (float)probs[i];
         indices[tid] = i;
       }

       if(i + blockDim.x < batchFirstElements[batchIdx + 1]) {
         float a = (float)probs[i];
         float b = (float)probs[i + blockDim.x];
         if(a > b) {
           sdata[tid] = a;
           indices[tid] = i;
         } else {
           sdata[tid] = b;
           indices[tid] = i + blockDim.x;
         }
       }

       while(i + dist < batchFirstElements[batchIdx + 1]) {
         i += dist;

         float a = (float)probs[i];
         if(a > sdata[tid]) {
           sdata[tid] = a;
           indices[tid] = i;
         }

         if(i + blockDim.x < batchFirstElements[batchIdx + 1]) {
           float b = (float)probs[i + blockDim.x];
           if(b > sdata[tid]) {
             sdata[tid] = b;
             indices[tid] = i + blockDim.x;
           }
         }
       }

       __syncthreads();

       for(int s = (blockDim.x >> 1); s > 32; s >>= 1) {
         if(tid < s && tid + s < batchFirstElements[batchIdx + 1]) {
           if(sdata[tid + s] > sdata[tid]) {
             sdata[tid] = sdata[tid + s];
             indices[tid] = indices[tid + s];
           }
         }
         __syncthreads();
       }

       UNROLL_MAXARG_LOOP(32, batchFirstElements[batchIdx + 1]);
       UNROLL_MAXARG_LOOP(16, batchFirstElements[batchIdx + 1]);
       UNROLL_MAXARG_LOOP(8, batchFirstElements[batchIdx + 1]);
       UNROLL_MAXARG_LOOP(4, batchFirstElements[batchIdx + 1]);
       UNROLL_MAXARG_LOOP(2, batchFirstElements[batchIdx + 1]);
       UNROLL_MAXARG_LOOP(1, batchFirstElements[batchIdx + 1]);

       if(tid == 0) {
         binCosts[bestBinCostIdx] = sdata[0];
         binIdxs[bestBinCostIdx] = indices[0];
       }
       __syncthreads();
     }
   }

   __global__ void gGetValueByKey(float* d_in, float* d_out, int* indeces, int n) {
     int tid = threadIdx.x + blockDim.x * blockIdx.x;
     if(tid < n) {
       int index = indeces[tid];
       d_out[tid] = d_in[index];
     }
   }
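
   // The class below sizes its bin buffers once, in the constructor:
   //   NUM_BLOCKS = min(500, ceil(maxBeamSize * MAX_VOCAB_SIZE / (2 * BLOCK_SIZE)))
   // Each block covers 2 * BLOCK_SIZE = 1024 elements per pass; e.g. for
   // maxBeamSize = 6 and MAX_VOCAB_SIZE = 500000 the unclamped bin count would
   // be ceil(3000000 / 1024) = 2930, so the 500-bin cap applies and
   // gMaxElement's grid-stride loop makes each block scan several chunks.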
   class NthElementGPU {
   public:
     NthElementGPU() = delete;
     NthElementGPU(const NthElementGPU& copy) = delete;

     NthElementGPU(size_t maxBeamSize, size_t maxBatchSize, DeviceId deviceId)
         : deviceId_(deviceId),
           maxBeamSize_(maxBeamSize),
           maxBatchSize_(maxBatchSize),
           NUM_BLOCKS(std::min(
               500,
               int(maxBeamSize * MAX_VOCAB_SIZE / (2 * BLOCK_SIZE))
                   + int(maxBeamSize * MAX_VOCAB_SIZE % (2 * BLOCK_SIZE) != 0))) {
       // std::cerr << "NthElement::NthElement" << std::endl;

       cudaSetDevice(deviceId_.no);

       CUDA_CHECK(cudaMalloc((void**)&d_ind, maxBatchSize * NUM_BLOCKS * sizeof(int)));
       CUDA_CHECK(cudaMalloc((void**)&d_out, maxBatchSize * NUM_BLOCKS * sizeof(float)));

       CUDA_CHECK(cudaMalloc((void**)&d_res_idx, maxBatchSize * maxBeamSize * sizeof(int)));
       CUDA_CHECK(cudaMalloc((void**)&d_res, maxBatchSize * maxBeamSize * sizeof(float)));

       CUDA_CHECK(cudaHostAlloc((void**)&h_res,
                                maxBeamSize * maxBatchSize * sizeof(float),
                                cudaHostAllocDefault));
       CUDA_CHECK(cudaHostAlloc((void**)&h_res_idx,
                                maxBeamSize * maxBatchSize * sizeof(int),
                                cudaHostAllocDefault));

       CUDA_CHECK(cudaMalloc((void**)&d_breakdown, maxBeamSize * sizeof(float)));
       CUDA_CHECK(cudaMalloc((void**)&d_batchPosition, (maxBatchSize + 1) * sizeof(int)));
       CUDA_CHECK(cudaMalloc((void**)&d_cumBeamSizes, (maxBatchSize + 1) * sizeof(int)));
     }

     ~NthElementGPU() {
       // No CUDA error checking as this is a destructor and we cannot do anything about errors anyway.
       cudaSetDevice(deviceId_.no);
       cudaFree(d_cumBeamSizes);
       cudaFree(d_batchPosition);
       cudaFree(d_breakdown);
       cudaFreeHost(h_res_idx);
       cudaFreeHost(h_res);
       cudaFree(d_res);
       cudaFree(d_res_idx);
       cudaFree(d_out);
       cudaFree(d_ind);
     }

   private:
     template <typename T>
     void selectNBest(T* probs,
                      const std::vector<int>& batchFirstElementIdxs,
                      const std::vector<int>& cumulativeBeamSizes,
                      float disabledPathScore) {
       cudaSetDevice(deviceId_.no);
       CUDA_CHECK(cudaMemcpyAsync(d_batchPosition,
                                  batchFirstElementIdxs.data(),
                                  batchFirstElementIdxs.size() * sizeof(int),
                                  cudaMemcpyHostToDevice,
                                  /* stream_ */ 0));
       CUDA_CHECK(cudaMemcpyAsync(d_cumBeamSizes,
                                  cumulativeBeamSizes.data(),
                                  cumulativeBeamSizes.size() * sizeof(int),
                                  cudaMemcpyHostToDevice,
                                  /* stream_ */ 0));

       const int numBatches = batchFirstElementIdxs.size() - 1;

       gMaxElement<<<NUM_BLOCKS,
                     BLOCK_SIZE,
                     BLOCK_SIZE * sizeof(float),
                     /* stream_ */ 0>>>(
           d_out, d_ind, probs, numBatches, d_batchPosition, disabledPathScore);

       gMaxElementUpdate<<<numBatches,
                           BLOCK_SIZE,
                           BLOCK_SIZE * sizeof(float),
                           /* stream_ */ 0>>>(d_out,
                                              d_ind,
                                              probs,
                                              d_batchPosition,
                                              d_res,
                                              d_res_idx,
                                              d_cumBeamSizes,
                                              NUM_BLOCKS,
                                              disabledPathScore);
     }
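
     // getNBestList below flattens the scores tensor and describes its layout
     // to the kernels with two prefix arrays. For example, with dimBatch = 2,
     // beam size N = 4, vocabSize = 8 and isFirst = false (so inputN = 4):
     //   batchFirstElementIdxs = {0, 32, 64} // batch b owns scores [32*b, 32*(b+1))
     //   cumulativeBeamSizes   = {0,  4,  8} // batch b owns output slots [4*b, 4*(b+1))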
   public:
     void getNBestList(Tensor scores,
                       size_t N,
                       std::vector<float>& outCosts,
                       std::vector<unsigned>& outKeys,
                       const bool isFirst) {
       cudaSetDevice(deviceId_.no);

       const auto vocabSize = scores->shape()[-1];
       const auto inputN    = scores->shape()[-2];
       const auto dimBatch  = scores->shape()[-4];
       ABORT_IF(inputN != (isFirst ? 1 : N), "Input tensor has wrong beam dim??"); // @TODO: Remove isFirst argument altogether
       ABORT_IF(vocabSize > MAX_VOCAB_SIZE, "GetNBestList(): actual vocab size {} exceeds MAX_VOCAB_SIZE of {}", vocabSize, MAX_VOCAB_SIZE);
       ABORT_IF(dimBatch > maxBatchSize_, "GetNBestList(): actual batch size {} exceeds initialization parameter {}", dimBatch, maxBatchSize_);
       ABORT_IF(std::max(N, (size_t)inputN) > maxBeamSize_, "GetNBestList(): actual beam size {} exceeds initialization parameter {}", N, maxBeamSize_);

       const std::vector<size_t> beamSizes(dimBatch, N);
       std::vector<int> cumulativeBeamSizes(beamSizes.size() + 1, 0);
       std::vector<int> batchFirstElementIdxs(beamSizes.size() + 1, 0);

       for(size_t batchIdx = 0; batchIdx < beamSizes.size(); ++batchIdx) {
   #if 1
         cumulativeBeamSizes[batchIdx + 1] = (batchIdx + 1) * (int)N;
         batchFirstElementIdxs[batchIdx + 1] += (batchIdx + 1) * inputN * vocabSize;
         ABORT_IF(cumulativeBeamSizes[batchIdx + 1] != cumulativeBeamSizes[batchIdx] + (int)N, "cumulativeBeamSizes wrong??");
         ABORT_IF((isFirst ? batchIdx + 1 : cumulativeBeamSizes[batchIdx + 1]) != (batchIdx + 1) * inputN, "inputN wrong??");
   #else
         cumulativeBeamSizes[batchIdx + 1] = cumulativeBeamSizes[batchIdx] + beamSizes[batchIdx];
         ABORT_IF(cumulativeBeamSizes[batchIdx + 1] != (batchIdx + 1) * N, "cumulativeBeamSizes wrong??");
         batchFirstElementIdxs[batchIdx + 1]
             += ((isFirst) ? (batchIdx + 1) : cumulativeBeamSizes[batchIdx + 1]) * vocabSize;
         ABORT_IF((isFirst ? batchIdx + 1 : cumulativeBeamSizes[batchIdx + 1]) != (batchIdx + 1) * inputN, "inputN wrong??");
   #endif
       }

       if(scores->type() == Type::float32) {
         float disabledPathScore = NumericLimits<float>(scores->type()).lowest;
         selectNBest(scores->data<float>(), batchFirstElementIdxs, cumulativeBeamSizes, disabledPathScore);
   #if COMPILE_FP16
       } else if(scores->type() == Type::float16) {
         float disabledPathScore = NumericLimits<float>(scores->type()).lowest;
         selectNBest(scores->data<half>(), batchFirstElementIdxs, cumulativeBeamSizes, disabledPathScore);
   #endif
       } else {
         ABORT("getNBestList not implemented for type {}", scores->type());
       }

       getPairs(dimBatch * N, outKeys, outCosts);
       ABORT_IF(cumulativeBeamSizes.back() != dimBatch * N, "cumulativeBeamSizes.back() wrong??");
     }

   private:
     void getPairs(size_t number,
                   std::vector<unsigned>& outKeys,
                   std::vector<float>& outValues) {
       cudaSetDevice(deviceId_.no);
       CUDA_CHECK(cudaMemcpyAsync(h_res,
                                  d_res,
                                  number * sizeof(float),
                                  cudaMemcpyDeviceToHost,
                                  /* stream_ */ 0));
       CUDA_CHECK(cudaMemcpyAsync(h_res_idx,
                                  d_res_idx,
                                  number * sizeof(int),
                                  cudaMemcpyDeviceToHost,
                                  /* stream_ */ 0));
       cudaStreamSynchronize(/* stream_ */ 0);

       for(size_t i = 0; i < number; ++i) {
         outKeys.push_back(h_res_idx[i]);
         outValues.push_back(h_res[i]);
       }

       //lastN = number;
     }

     DeviceId deviceId_;

     const int MAX_VOCAB_SIZE = 500000;
     size_t maxBeamSize_;
     size_t maxBatchSize_;

     const int BLOCK_SIZE = 512;
     const int NUM_BLOCKS;

     int* d_ind;   // [maxBatchSize * NUM_BLOCKS]
     float* d_out; // [maxBatchSize * NUM_BLOCKS]

     int* d_res_idx; // [maxBatchSize * maxBeamSize]
     float* d_res;   // [maxBatchSize * maxBeamSize]

     int* h_res_idx; // [maxBeamSize * maxBatchSize]
     float* h_res;   // [maxBeamSize * maxBatchSize]

     float* d_breakdown;   // [maxBeamSize]
     int* d_batchPosition; // [maxBatchSize + 1]
     int* d_cumBeamSizes;  // [maxBatchSize + 1]
     //size_t lastN;
   };

   // factory function
   // Returns a lambda with the same signature as the getNBestList() function.
   GetNBestListFn createGetNBestListGPUFn(size_t beamSize, size_t dimBatch, DeviceId deviceId) {
     auto nth = New<NthElementGPU>(beamSize, dimBatch, deviceId);
     return [nth](Tensor logProbs,
                  size_t N,
                  std::vector<float>& outCosts,
                  std::vector<unsigned>& outKeys,
                  const bool isFirst) {
       return nth->getNBestList(logProbs, N, outCosts, outKeys, isFirst);
     };
   }

   } // namespace marian
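
Taken together, ``gMaxElement`` and ``gMaxElementUpdate`` perform a batched top-``N`` selection over contiguous score slices. The following host-side reference is a minimal sketch of the same computation (the ``referenceNBest`` helper is hypothetical, not part of Marian) that can be used to sanity-check the GPU results:

.. code-block:: cpp

   #include <algorithm>
   #include <cstddef>
   #include <utility>
   #include <vector>

   // For each batch entry b, select the N highest-scoring elements of
   // scores[batchFirstElementIdxs[b] .. batchFirstElementIdxs[b+1]) and return
   // (flat index, score) pairs in descending score order, batch by batch,
   // mirroring what the kernel pair writes to d_res_idx/d_res.
   std::vector<std::pair<int, float>> referenceNBest(
       const std::vector<float>& scores,
       const std::vector<int>& batchFirstElementIdxs,
       std::size_t N) {
     std::vector<std::pair<int, float>> out;
     for(std::size_t b = 0; b + 1 < batchFirstElementIdxs.size(); ++b) {
       std::vector<std::pair<int, float>> slice;
       for(int i = batchFirstElementIdxs[b]; i < batchFirstElementIdxs[b + 1]; ++i)
         slice.emplace_back(i, scores[i]);
       std::size_t n = std::min<std::size_t>(N, slice.size());
       std::partial_sort(slice.begin(),
                         slice.begin() + n,
                         slice.end(),
                         [](const std::pair<int, float>& x, const std::pair<int, float>& y) {
                           return x.second > y.second;
                         });
       out.insert(out.end(), slice.begin(), slice.begin() + n);
     }
     return out;
   }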