From 23e7187d7f9433ca91f13cbc86b92283f31586ed Mon Sep 17 00:00:00 2001 From: Trevor Killeen Date: Tue, 11 Apr 2017 11:54:23 -0700 Subject: [PATCH 01/10] block wide reduction with multiple values to reduce at once --- lib/THC/THCReduceApplyUtils.cuh | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/THC/THCReduceApplyUtils.cuh b/lib/THC/THCReduceApplyUtils.cuh index 30325de5..05498899 100644 --- a/lib/THC/THCReduceApplyUtils.cuh +++ b/lib/THC/THCReduceApplyUtils.cuh @@ -105,7 +105,6 @@ __device__ T reduceBlock(T* smem, return threadVal; } - // Block-wide reduction where each thread locally reduces N // values before letting a single warp take over - assumes // threadVals is in registers, not shared memory From 2592252c85118c33a7a88bcc9a697fa5fdc62161 Mon Sep 17 00:00:00 2001 From: Trevor Killeen Date: Thu, 13 Apr 2017 10:01:04 -0700 Subject: [PATCH 02/10] add function for determining the amount of smem required for a reduction, in preparation for warp shuffle --- lib/THC/THCReduce.cuh | 2 +- lib/THC/THCReduceAll.cuh | 4 ++-- lib/THC/THCReduceApplyUtils.cuh | 11 +++++++++++ lib/THC/THCTensorMode.cuh | 9 +++++++++ lib/THC/generic/THCTensorMode.cu | 2 +- lib/THC/generic/THCTensorRandom.cu | 4 +++- 6 files changed, 27 insertions(+), 5 deletions(-) diff --git a/lib/THC/THCReduce.cuh b/lib/THC/THCReduce.cuh index 7f276a2a..5dcc60bd 100644 --- a/lib/THC/THCReduce.cuh +++ b/lib/THC/THCReduce.cuh @@ -198,7 +198,7 @@ bool THC_reduceDim(THCState* state, } block = getContigReduceBlock(outElements, reductionSize); - smemSize = sizeof(typename TensorUtils::DataType) * block.x; + smemSize = reduceSmemSize::DataType, 1>(block.x); } else { if (!getNoncontigReduceGrid(outElements, grid)) { return false; diff --git a/lib/THC/THCReduceAll.cuh b/lib/THC/THCReduceAll.cuh index 1d04e63b..6df649c9 100644 --- a/lib/THC/THCReduceAll.cuh +++ b/lib/THC/THCReduceAll.cuh @@ -200,7 +200,7 @@ void callReduceAll(THCState* state, } getPass1ReduceBlockGrid(state, totalElements, grid, block); - size_t smemSize = block.x * sizeof(AccT); + size_t smemSize = reduceSmemSize(block.x); kernelReduceAllPass1 <<>>( @@ -221,7 +221,7 @@ void callReduceAll(THCState* state, } } else { getSinglePassReduceBlockGrid(totalElements, grid, block); - size_t smemSize = block.x * sizeof(AccT); + size_t smemSize = reduceSmemSize(block.x); kernelReduceAll <<>>( diff --git a/lib/THC/THCReduceApplyUtils.cuh b/lib/THC/THCReduceApplyUtils.cuh index 05498899..28dad27e 100644 --- a/lib/THC/THCReduceApplyUtils.cuh +++ b/lib/THC/THCReduceApplyUtils.cuh @@ -1,6 +1,7 @@ #ifndef THC_REDUCE_APPLY_UTILS_INC #define THC_REDUCE_APPLY_UTILS_INC +#include #include #include #include "THCGeneral.h" @@ -19,6 +20,16 @@ __device__ __forceinline__ IndexType getLinearBlockId() { blockIdx.x; } +// Returns the minimum size of shared memory required of performing N reductions +// across a block, with each reduction having at most numVals individual elements +// to reduce. Note that internally, because reductions operate (at the shared memory +// level) with N elements per thread in the block, so we have to use min(numvals, +// max block size) to determine this count. +template +int reduceSmemSize(int numVals) { + return std::min(numVals, 1024) * N * sizeof(T); +} + // Reduce N values concurrently, i.e. suppose N = 2, and there are 4 threads: // (1, 2), (3, 4), (5, 6), (7, 8), then the return in threadVals for thread 0 // is (1 + 3 + 5 + 7, 2 + 4 + 6 + 8) = (16, 20) diff --git a/lib/THC/THCTensorMode.cuh b/lib/THC/THCTensorMode.cuh index b67ac2a4..efd99274 100644 --- a/lib/THC/THCTensorMode.cuh +++ b/lib/THC/THCTensorMode.cuh @@ -77,6 +77,15 @@ struct MatchReduceOp { } }; +template +int modeSmemSize() { + int sliceElementSize = sizeof(T) * Power2Size; + int sortScanSize = 2 * Power2Size * sizeof(unsigned int); + int reductionSize = reduceSmemSize(Power2Size); + + return sliceElementSize + (sortScanSize > reductionSize ? sortScanSize : reductionSize); +} + // The mode kernel has the following characteristics: It uses internal shared memory // buffers of Power2Size, which must be greater than the number of elements. Additionally, // there is one block for every slice to calculate the mode for, and in each block there diff --git a/lib/THC/generic/THCTensorMode.cu b/lib/THC/generic/THCTensorMode.cu index 031e0acc..7693df81 100644 --- a/lib/THC/generic/THCTensorMode.cu +++ b/lib/THC/generic/THCTensorMode.cu @@ -230,7 +230,7 @@ THC_API void THCTensor_(mode)(THCState *state, { \ dim3 blockSize(SIZE / 2); \ \ - int memsize = (sizeof(real) * SIZE) + (2 * SIZE * sizeof(unsigned int)); \ + int memsize = modeSmemSize(); \ computeMode \ <<>>( \ THCTensor_(data)(state, contiguous), tiValues, tiIndices, sliceSize); \ diff --git a/lib/THC/generic/THCTensorRandom.cu b/lib/THC/generic/THCTensorRandom.cu index 4c6d2fb2..0d214698 100644 --- a/lib/THC/generic/THCTensorRandom.cu +++ b/lib/THC/generic/THCTensorRandom.cu @@ -159,9 +159,11 @@ THC_API void THCTensor_(multinomial)(struct THCState *state, int maxThreads = props->maxThreadsPerBlock; dim3 block(numCategories < maxThreads ? numCategories : maxThreads); dim3 grid(numDist < numSM * 4 ? numDist : numSM * 4); + // smem required for reduction + prefix sums + int smemSize = (block.x * sizeof(real)) + reduceSmemSize(block.x); sampleMultinomialOnce <<>>( THCudaLongTensor_data(state, self), numDist, From e09788f7d920bd196fbf90129ab40448358a1ead Mon Sep 17 00:00:00 2001 From: Trevor Killeen Date: Tue, 18 Apr 2017 06:37:50 -0700 Subject: [PATCH 03/10] move random functions to use reduce smem, round up to smem multiple of warp --- lib/THC/THCReduceApplyUtils.cuh | 2 +- lib/THC/generic/THCTensorRandom.cu | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/lib/THC/THCReduceApplyUtils.cuh b/lib/THC/THCReduceApplyUtils.cuh index 28dad27e..37b23f07 100644 --- a/lib/THC/THCReduceApplyUtils.cuh +++ b/lib/THC/THCReduceApplyUtils.cuh @@ -27,7 +27,7 @@ __device__ __forceinline__ IndexType getLinearBlockId() { // max block size) to determine this count. template int reduceSmemSize(int numVals) { - return std::min(numVals, 1024) * N * sizeof(T); + return THCRoundUp(std::min(numVals, 1024), 32) * N * sizeof(T); } // Reduce N values concurrently, i.e. suppose N = 2, and there are 4 threads: diff --git a/lib/THC/generic/THCTensorRandom.cu b/lib/THC/generic/THCTensorRandom.cu index 0d214698..05e7da2c 100644 --- a/lib/THC/generic/THCTensorRandom.cu +++ b/lib/THC/generic/THCTensorRandom.cu @@ -93,10 +93,12 @@ void THCTensor_(renormRows)(struct THCState* state, int maxThreads = props->maxThreadsPerBlock; dim3 grid(rows < numSM * 4 ? rows : numSM * 4); - dim3 block(cols < maxThreads ? cols : maxThreads); + dim3 block(THCRoundUp(cols < maxThreads ? cols : maxThreads, (long) props->warpSize)); + + int smem = reduceSmemSize(cols); renormRowsL1 - <<>>(THCTensor_(data)(state, t), rows, cols); } @@ -157,10 +159,10 @@ THC_API void THCTensor_(multinomial)(struct THCState *state, THAssert(props != NULL); int numSM = props->multiProcessorCount; int maxThreads = props->maxThreadsPerBlock; - dim3 block(numCategories < maxThreads ? numCategories : maxThreads); + dim3 block(THCRoundUp(numCategories < maxThreads ? numCategories : maxThreads, props->warpSize)); dim3 grid(numDist < numSM * 4 ? numDist : numSM * 4); // smem required for reduction + prefix sums - int smemSize = (block.x * sizeof(real)) + reduceSmemSize(block.x); + int smemSize = (block.x * sizeof(real)) + reduceSmemSize(numCategories); sampleMultinomialOnce << Date: Tue, 18 Apr 2017 07:19:35 -0700 Subject: [PATCH 04/10] use reduceSmemSize in reduceAll, round up to warpsize multiple threads in block --- lib/THC/THCReduceAll.cuh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/THC/THCReduceAll.cuh b/lib/THC/THCReduceAll.cuh index 6df649c9..1eaa8abc 100644 --- a/lib/THC/THCReduceAll.cuh +++ b/lib/THC/THCReduceAll.cuh @@ -162,14 +162,14 @@ inline void getPass2ReduceBlockGrid(THCState* state, ptrdiff_t elements, dim3& grid, dim3& block) { grid = dim3(1); // We only need as many threads as there were blocks originally - block = dim3(getTwoPassBlocks(state, elements)); + block = dim3((int) THCRoundUp(getTwoPassBlocks(state, elements), (ptrdiff_t) 32)); } template inline void getSinglePassReduceBlockGrid(ptrdiff_t elements, dim3& grid, dim3& block) { grid = dim3(1); - block = dim3(THC_REDUCE_ALL_BLOCK_SIZE); + block = dim3(THCRoundUp((int) std::min(elements, (ptrdiff_t) THC_REDUCE_ALL_BLOCK_SIZE), 32)); } template (state, totalElements, grid, block); - smemSize = block.x * sizeof(AccT); + smemSize = reduceSmemSize(block.x); kernelReduceAllPass2 <<>>( From abbdfc55a3eff50e3dee5bc1e2a92164d790d683 Mon Sep 17 00:00:00 2001 From: Trevor Killeen Date: Tue, 18 Apr 2017 07:30:28 -0700 Subject: [PATCH 05/10] pass thcstate to reduceSmemSize --- lib/THC/THCReduce.cuh | 2 +- lib/THC/THCReduceAll.cuh | 6 +++--- lib/THC/THCReduceApplyUtils.cuh | 2 +- lib/THC/THCTensorMode.cuh | 4 ++-- lib/THC/generic/THCTensorMode.cu | 2 +- lib/THC/generic/THCTensorRandom.cu | 4 ++-- 6 files changed, 10 insertions(+), 10 deletions(-) diff --git a/lib/THC/THCReduce.cuh b/lib/THC/THCReduce.cuh index 5dcc60bd..429966f6 100644 --- a/lib/THC/THCReduce.cuh +++ b/lib/THC/THCReduce.cuh @@ -198,7 +198,7 @@ bool THC_reduceDim(THCState* state, } block = getContigReduceBlock(outElements, reductionSize); - smemSize = reduceSmemSize::DataType, 1>(block.x); + smemSize = reduceSmemSize::DataType, 1>(state, block.x); } else { if (!getNoncontigReduceGrid(outElements, grid)) { return false; diff --git a/lib/THC/THCReduceAll.cuh b/lib/THC/THCReduceAll.cuh index 1eaa8abc..73802e60 100644 --- a/lib/THC/THCReduceAll.cuh +++ b/lib/THC/THCReduceAll.cuh @@ -200,7 +200,7 @@ void callReduceAll(THCState* state, } getPass1ReduceBlockGrid(state, totalElements, grid, block); - size_t smemSize = reduceSmemSize(block.x); + size_t smemSize = reduceSmemSize(state, block.x); kernelReduceAllPass1 <<>>( @@ -209,7 +209,7 @@ void callReduceAll(THCState* state, int numPass1Blocks = grid.x; getPass2ReduceBlockGrid(state, totalElements, grid, block); - smemSize = reduceSmemSize(block.x); + smemSize = reduceSmemSize(state, block.x); kernelReduceAllPass2 <<>>( @@ -221,7 +221,7 @@ void callReduceAll(THCState* state, } } else { getSinglePassReduceBlockGrid(totalElements, grid, block); - size_t smemSize = reduceSmemSize(block.x); + size_t smemSize = reduceSmemSize(state, block.x); kernelReduceAll <<>>( diff --git a/lib/THC/THCReduceApplyUtils.cuh b/lib/THC/THCReduceApplyUtils.cuh index 37b23f07..c19c9932 100644 --- a/lib/THC/THCReduceApplyUtils.cuh +++ b/lib/THC/THCReduceApplyUtils.cuh @@ -26,7 +26,7 @@ __device__ __forceinline__ IndexType getLinearBlockId() { // level) with N elements per thread in the block, so we have to use min(numvals, // max block size) to determine this count. template -int reduceSmemSize(int numVals) { +int reduceSmemSize(THCState *state, int numVals) { return THCRoundUp(std::min(numVals, 1024), 32) * N * sizeof(T); } diff --git a/lib/THC/THCTensorMode.cuh b/lib/THC/THCTensorMode.cuh index efd99274..75981599 100644 --- a/lib/THC/THCTensorMode.cuh +++ b/lib/THC/THCTensorMode.cuh @@ -78,10 +78,10 @@ struct MatchReduceOp { }; template -int modeSmemSize() { +int modeSmemSize(THCState *state) { int sliceElementSize = sizeof(T) * Power2Size; int sortScanSize = 2 * Power2Size * sizeof(unsigned int); - int reductionSize = reduceSmemSize(Power2Size); + int reductionSize = reduceSmemSize(state, Power2Size); return sliceElementSize + (sortScanSize > reductionSize ? sortScanSize : reductionSize); } diff --git a/lib/THC/generic/THCTensorMode.cu b/lib/THC/generic/THCTensorMode.cu index 7693df81..875caec2 100644 --- a/lib/THC/generic/THCTensorMode.cu +++ b/lib/THC/generic/THCTensorMode.cu @@ -230,7 +230,7 @@ THC_API void THCTensor_(mode)(THCState *state, { \ dim3 blockSize(SIZE / 2); \ \ - int memsize = modeSmemSize(); \ + int memsize = modeSmemSize(state); \ computeMode \ <<>>( \ THCTensor_(data)(state, contiguous), tiValues, tiIndices, sliceSize); \ diff --git a/lib/THC/generic/THCTensorRandom.cu b/lib/THC/generic/THCTensorRandom.cu index 05e7da2c..759ee095 100644 --- a/lib/THC/generic/THCTensorRandom.cu +++ b/lib/THC/generic/THCTensorRandom.cu @@ -95,7 +95,7 @@ void THCTensor_(renormRows)(struct THCState* state, dim3 grid(rows < numSM * 4 ? rows : numSM * 4); dim3 block(THCRoundUp(cols < maxThreads ? cols : maxThreads, (long) props->warpSize)); - int smem = reduceSmemSize(cols); + int smem = reduceSmemSize(state, cols); renormRowsL1 <<warpSize)); dim3 grid(numDist < numSM * 4 ? numDist : numSM * 4); // smem required for reduction + prefix sums - int smemSize = (block.x * sizeof(real)) + reduceSmemSize(numCategories); + int smemSize = (block.x * sizeof(real)) + reduceSmemSize(state, numCategories); sampleMultinomialOnce << Date: Tue, 18 Apr 2017 07:41:49 -0700 Subject: [PATCH 06/10] small changes to mode --> make sure we always have at least a warp # of threads --- lib/THC/THCTensorMode.cuh | 1 + lib/THC/generic/THCTensorMode.cu | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/lib/THC/THCTensorMode.cuh b/lib/THC/THCTensorMode.cuh index 75981599..d63ed7a4 100644 --- a/lib/THC/THCTensorMode.cuh +++ b/lib/THC/THCTensorMode.cuh @@ -280,6 +280,7 @@ __global__ void computeMode( // Finally, we have the mode, and an index where it occurs. We use a single thread // to place this in the appropriate output position if (tidx == 0) { + assert(match.flag); long index = TH_INDEX_BASE + match.val; unsigned int outputOffset = IndexToOffset::get(blockId, values); diff --git a/lib/THC/generic/THCTensorMode.cu b/lib/THC/generic/THCTensorMode.cu index 875caec2..e8428e87 100644 --- a/lib/THC/generic/THCTensorMode.cu +++ b/lib/THC/generic/THCTensorMode.cu @@ -198,7 +198,7 @@ THC_API void THCTensor_(mode)(THCState *state, // 2. uses one block per slice, so number of slices must be less than the maximum number of blocks for // a kernel launch // 3. Can use 32-bit index math for indexing (mainly just for implementation conciseness, could be changed) - if (sliceSize <= MAX_BLOCK_SIZE && + if (sliceSize <= MAX_BLOCK_SIZE * 2 && slices <= MAX_GRID_SIZE && TensorUtils::canUse32BitIndexMath(state, input)) { // Beginning our optimized implementation. First thing we want to do is to transpose @@ -248,15 +248,15 @@ THC_API void THCTensor_(mode)(THCState *state, HANDLE_MODE(1024) break; case 128: - case 64: HANDLE_MODE(128) break; + case 64: case 32: case 16: case 8: case 4: case 2: - HANDLE_MODE(32) + HANDLE_MODE(64) // block size should be at least 1 full warp break; case 1: default: From 720d24a9279a81cc22bcc47eb6529b507582e222 Mon Sep 17 00:00:00 2001 From: Trevor Killeen Date: Tue, 18 Apr 2017 08:23:51 -0700 Subject: [PATCH 07/10] small fix in nthreadlocal; add doc --- lib/THC/THCReduceApplyUtils.cuh | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/lib/THC/THCReduceApplyUtils.cuh b/lib/THC/THCReduceApplyUtils.cuh index c19c9932..0e68f98c 100644 --- a/lib/THC/THCReduceApplyUtils.cuh +++ b/lib/THC/THCReduceApplyUtils.cuh @@ -118,7 +118,11 @@ __device__ T reduceBlock(T* smem, // Block-wide reduction where each thread locally reduces N // values before letting a single warp take over - assumes -// threadVals is in registers, not shared memory +// threadVals is in registers, not shared memory. Note that +// numVals in this case is the number of values in the overall +// reduction, i.e. if there are 512 threads with N=2, and say +// there are 768 elements in the input block, then numVals is 768, +// not, say, 384 (i.e. 768 / N=2) template __device__ T reduceBlockWithNThreadLocalReductions(T *smem, T threadVals[N], @@ -135,7 +139,7 @@ __device__ T reduceBlockWithNThreadLocalReductions(T *smem, local = reduceOp(local, next); } - return reduceBlock(smem, blockDim.x < numVals ? blockDim.x : numVals, local, reduceOp, init); + return reduceBlock(smem, THCCeilDiv(numVals, N), local, reduceOp, init); } // Make sure the given tensor doesn't have too many dimensions From 875601f2bc05466ed9c44b567c753b0b15c3d542 Mon Sep 17 00:00:00 2001 From: Trevor Killeen Date: Tue, 18 Apr 2017 08:47:26 -0700 Subject: [PATCH 08/10] implement warp shuffle based reduction; enable for arch >= 3.0 --- lib/THC/THCReduceApplyUtils.cuh | 102 +++++++++++++++++++++++++++++++- lib/THC/THCTensorMode.cuh | 18 ++++++ 2 files changed, 118 insertions(+), 2 deletions(-) diff --git a/lib/THC/THCReduceApplyUtils.cuh b/lib/THC/THCReduceApplyUtils.cuh index 0e68f98c..6143084f 100644 --- a/lib/THC/THCReduceApplyUtils.cuh +++ b/lib/THC/THCReduceApplyUtils.cuh @@ -8,6 +8,7 @@ #include "THCTensor.h" #include "THCDeviceUtils.cuh" #include "THCTensorInfo.cuh" +#include "THCAsmUtils.cuh" // Enum that indicates whether tensor arguments are read/write or // read-only @@ -26,8 +27,101 @@ __device__ __forceinline__ IndexType getLinearBlockId() { // level) with N elements per thread in the block, so we have to use min(numvals, // max block size) to determine this count. template -int reduceSmemSize(THCState *state, int numVals) { - return THCRoundUp(std::min(numVals, 1024), 32) * N * sizeof(T); +int reduceSmemSize(THCState *state, long numVals) { + // check if we can use a warp shuffle + cudaDeviceProp *props = THCState_getCurrentDeviceProperties(state); + if (props->major >= 3) { + return props->warpSize * N * sizeof(T); + } else { + return THCRoundUp(std::min(numVals, (long) props->maxThreadsPerBlock), (long) props->warpSize) * N * sizeof(T); + } +} + +template +struct THCWarpUtils { + static __device__ __forceinline__ T shflxor(T val, unsigned int mask) { + return __shfl_xor(val, mask); + } +}; + +template <> +struct THCWarpUtils { + static __device__ __forceinline__ unsigned char shflxor(unsigned char val, unsigned int mask) { + return (unsigned char) __shfl_xor((int) val, mask); + } +}; + +template <> +struct THCWarpUtils { + static __device__ __forceinline__ char shflxor(char val, unsigned int mask) { + return (char) __shfl_xor((int) val, mask); + } +}; + +template <> +struct THCWarpUtils { + static __device__ __forceinline__ short shflxor(short val, unsigned int mask) { + return (short) __shfl_xor((int) val, mask); + } +}; + +template <> +struct THCWarpUtils { + static __device__ __forceinline__ double shflxor(double val, unsigned int mask) { + int2 a = *reinterpret_cast(&val); + a.x = __shfl_xor(a.x, mask); + a.y = __shfl_xor(a.y, mask); + return *reinterpret_cast(&a); + } +}; + +template <> +struct THCWarpUtils { + static __device__ __forceinline__ long shflxor(long val, unsigned int mask) { + int2 a = *reinterpret_cast(&val); + a.x = __shfl_xor(a.x, mask); + a.y = __shfl_xor(a.y, mask); + return *reinterpret_cast(&a); + } +}; + +template +__device__ void warpReduce(T threadVals[N], ReduceOp reduceOp) { +#pragma unroll + for (int mask = 1; mask < warpSize; mask *= 2) { +#pragma unroll + for (int i = 0; i < N; ++i) { + T neighbor = THCWarpUtils::shflxor(threadVals[i], mask); + threadVals[i] = reduceOp(threadVals[i], neighbor); + } + } +} + +template +__device__ void warpReduceBlock(T *smem, T threadVals[N], int numVals, ReduceOp reduceOp, T init) { + assert(blockDim.x % warpSize == 0); + // First, warps cooperate to reduce values within the warp + warpReduce(threadVals, reduceOp); + int lane = getLaneId(); + int warp = threadIdx.x / warpSize; + + if (lane == 0) { + +#pragma unroll + for (int i = 0; i < N; ++i) { + smem[warp + (i * warpSize)] = threadVals[i]; + } + } + __syncthreads(); + +#pragma unroll + for (int i = 0; i < N; ++i) { + threadVals[i] = (threadIdx.x < (blockDim.x / warpSize)) ? smem[lane + (i * warpSize)] : init; + } + + if (warp == 0) { + warpReduce(threadVals, reduceOp); + } } // Reduce N values concurrently, i.e. suppose N = 2, and there are 4 threads: @@ -47,6 +141,9 @@ __device__ void reduceNValuesInBlock(T *smem, return; } +#if __CUDA_ARCH__ >= 300 + warpReduceBlock(smem, threadVals, numVals, reduceOp, init); +#else // We store each of the N values contiguously, so if N = 2, all values for // the first threadVal for each thread in the block are stored followed by // all of the values for the second threadVal for each thread in the block @@ -102,6 +199,7 @@ __device__ void reduceNValuesInBlock(T *smem, } } } +#endif } // Block-wide reduction in shared memory helper; only threadIdx.x == 0 will diff --git a/lib/THC/THCTensorMode.cuh b/lib/THC/THCTensorMode.cuh index d63ed7a4..e2c30cf1 100644 --- a/lib/THC/THCTensorMode.cuh +++ b/lib/THC/THCTensorMode.cuh @@ -56,6 +56,15 @@ struct ModeUnsignedBoolPair { bool flag; }; +template <> +struct THCWarpUtils { + static __device__ __forceinline__ ModeUnsignedBoolPair shflxor(ModeUnsignedBoolPair val, unsigned int mask) { + val.val = __shfl_xor(val.val, mask); + val.flag = (bool) __shfl_xor((int) val.flag, mask); + return val; + } +}; + // In the kernel below, we have a common pattern of reducing (unsigned int, unsigned int) // pairs of data struct ModeUnsignedPair { @@ -63,6 +72,15 @@ struct ModeUnsignedPair { unsigned int index; }; +template <> +struct THCWarpUtils { + static __device__ __forceinline__ ModeUnsignedPair shflxor(ModeUnsignedPair val, unsigned int mask) { + val.val = __shfl_xor(val.val, mask); + val.index = __shfl_xor(val.index, mask); + return val; + } +}; + template struct MaxReduceOp { __host__ __device__ inline T operator()(const T& a, const T& b) { From 28cb9615290cc832bc7120a9e91343a1a0096b00 Mon Sep 17 00:00:00 2001 From: Trevor Killeen Date: Tue, 25 Apr 2017 08:13:54 -0700 Subject: [PATCH 09/10] rebase --- lib/THC/THCTensorTopK.cuh | 17 +++++++++++------ lib/THC/generic/THCTensorTopK.cu | 4 +++- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/lib/THC/THCTensorTopK.cuh b/lib/THC/THCTensorTopK.cuh index 32041e3b..d1c848e8 100644 --- a/lib/THC/THCTensorTopK.cuh +++ b/lib/THC/THCTensorTopK.cuh @@ -166,18 +166,19 @@ __device__ void countRadixUsingMask(CountType counts[RadixSize], #pragma unroll for (unsigned int j = 0; j < RadixSize; ++j) { bool vote = hasVal && (digitInRadix == j); - counts[j] += __popc(__ballot(vote)); + counts[j] += vote; } } - // Now, for each warp, sum values - if (getLaneId() == 0) { + reduceNValuesInBlock, RadixSize>( + smem, counts, blockDim.x, AddOp(), 0); + + if (threadIdx.x == 0) { #pragma unroll for (unsigned int i = 0; i < RadixSize; ++i) { - atomicAdd(&smem[i], counts[i]); + smem[i] = counts[i]; } } - __syncthreads(); // For each thread, read in the total counts @@ -194,6 +195,10 @@ __device__ void countRadixUsingMask(CountType counts[RadixSize], #define RADIX_SIZE 4 // 2 ^ RADIX_BITS #define RADIX_MASK (RADIX_SIZE - 1) +int topkSmemSize(THCState *state, long sliceSize) { + return reduceSmemSize(state, sliceSize); +} + // This finds the unique value `v` that matches the pattern // ((v & desired) == desiredMask) in our sorted int format template @@ -356,7 +361,7 @@ __global__ void gatherTopK(TensorInfo input, IndexType indicesWithinSliceStride) { // Indices are limited to integer fp precision, so counts can fit in // int32, regardless of IndexType - __shared__ int smem[32]; // one per each warp, up to warp limit + extern __shared__ int smem[]; // one per each warp, up to warp limit IndexType slice = getLinearBlockId(); if (slice >= numInputSlices) { diff --git a/lib/THC/generic/THCTensorTopK.cu b/lib/THC/generic/THCTensorTopK.cu index 83ab1e10..ab91cb4a 100644 --- a/lib/THC/generic/THCTensorTopK.cu +++ b/lib/THC/generic/THCTensorTopK.cu @@ -28,9 +28,11 @@ THC_API void THCTensor_(topk)(THCState* state, THCudaLongTensor_resize(state, indices, topKSize, NULL); THLongStorage_free(topKSize); + int smem = topkSmemSize(state, sliceSize); + #define RUN_K(INDEX_T, DIM, DIR) \ gatherTopK \ - <<>>( \ + <<>>( \ inputInfo, \ sliceSize, \ k, \ From 576ae9f1018f4ca52900ebeba647207ee204db67 Mon Sep 17 00:00:00 2001 From: Trevor Killeen Date: Tue, 25 Apr 2017 08:31:49 -0700 Subject: [PATCH 10/10] only first warp executes last stage of warp reduction --- lib/THC/THCReduceApplyUtils.cuh | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/lib/THC/THCReduceApplyUtils.cuh b/lib/THC/THCReduceApplyUtils.cuh index 6143084f..632e12c2 100644 --- a/lib/THC/THCReduceApplyUtils.cuh +++ b/lib/THC/THCReduceApplyUtils.cuh @@ -114,12 +114,11 @@ __device__ void warpReduceBlock(T *smem, T threadVals[N], int numVals, ReduceOp } __syncthreads(); -#pragma unroll - for (int i = 0; i < N; ++i) { - threadVals[i] = (threadIdx.x < (blockDim.x / warpSize)) ? smem[lane + (i * warpSize)] : init; - } - if (warp == 0) { +#pragma unroll + for (int i = 0; i < N; ++i) { + threadVals[i] = (threadIdx.x < (blockDim.x / warpSize)) ? smem[lane + (i * warpSize)] : init; + } warpReduce(threadVals, reduceOp); } }