2 changes: 1 addition & 1 deletion lib/THC/THCReduce.cuh
@@ -198,7 +198,7 @@ bool THC_reduceDim(THCState* state,
}

block = getContigReduceBlock(outElements, reductionSize);
- smemSize = sizeof(typename TensorUtils<TensorType>::DataType) * block.x;
+ smemSize = reduceSmemSize<typename TensorUtils<TensorType>::DataType, 1>(state, block.x);
} else {
if (!getNoncontigReduceGrid(outElements, grid)) {
return false;
10 changes: 5 additions & 5 deletions lib/THC/THCReduceAll.cuh
@@ -162,14 +162,14 @@ inline void getPass2ReduceBlockGrid(THCState* state, ptrdiff_t elements,
dim3& grid, dim3& block) {
grid = dim3(1);
// We only need as many threads as there were blocks originally
- block = dim3(getTwoPassBlocks<InT, AccT>(state, elements));
+ block = dim3((int) THCRoundUp(getTwoPassBlocks<InT, AccT>(state, elements), (ptrdiff_t) 32));
}

template <typename InT, typename AccT>
inline void getSinglePassReduceBlockGrid(ptrdiff_t elements,
dim3& grid, dim3& block) {
grid = dim3(1);
- block = dim3(THC_REDUCE_ALL_BLOCK_SIZE);
+ block = dim3(THCRoundUp((int) std::min(elements, (ptrdiff_t) THC_REDUCE_ALL_BLOCK_SIZE), 32));
}

template <typename ModifyOp,
@@ -200,7 +200,7 @@ void callReduceAll(THCState* state,
}

getPass1ReduceBlockGrid<InT, AccT>(state, totalElements, grid, block);
- size_t smemSize = block.x * sizeof(AccT);
+ size_t smemSize = reduceSmemSize<AccT, 1>(state, block.x);

kernelReduceAllPass1<ModifyOp, ReduceOp, ReduceAccOp, InT, AccT, IndexType, ADims>
<<<grid, block, smemSize, THCState_getCurrentStream(state)>>>(
@@ -209,7 +209,7 @@

int numPass1Blocks = grid.x;
getPass2ReduceBlockGrid<InT, AccT>(state, totalElements, grid, block);
- smemSize = block.x * sizeof(AccT);
+ smemSize = reduceSmemSize<AccT, 1>(state, block.x);

kernelReduceAllPass2<ReduceAccOp, AccT, IndexType>
<<<grid, block, smemSize, THCState_getCurrentStream(state)>>>(
@@ -221,7 +221,7 @@
}
} else {
getSinglePassReduceBlockGrid<InT, AccT>(totalElements, grid, block);
- size_t smemSize = block.x * sizeof(AccT);
+ size_t smemSize = reduceSmemSize<AccT, 1>(state, block.x);

kernelReduceAll<ModifyOp, ReduceOp, ReduceAccOp, InT, AccT, IndexType, ADims>
<<<grid, block, smemSize, THCState_getCurrentStream(state)>>>(
117 changes: 114 additions & 3 deletions lib/THC/THCReduceApplyUtils.cuh
@@ -1,12 +1,14 @@
#ifndef THC_REDUCE_APPLY_UTILS_INC
#define THC_REDUCE_APPLY_UTILS_INC

#include <algorithm>
#include <cuda.h>
#include <assert.h>
#include "THCGeneral.h"
#include "THCTensor.h"
#include "THCDeviceUtils.cuh"
#include "THCTensorInfo.cuh"
#include "THCAsmUtils.cuh"

// Enum that indicates whether tensor arguments are read/write or
// read-only
@@ -19,6 +21,108 @@ __device__ __forceinline__ IndexType getLinearBlockId() {
blockIdx.x;
}

// Returns the minimum size of shared memory required for performing N reductions
// across a block, with each reduction having at most numVals individual elements
// to reduce. Note that internally, reductions operate (at the shared memory level)
// with N elements per thread in the block, so we have to use min(numVals,
// max block size) to determine this count.
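// For example, with N = 1 and T = float on a compute capability 3.0+ device,
// this is just warpSize * sizeof(float) = 128 bytes: warp shuffles handle the
// intra-warp step, so shared memory only stages one partial per warp.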
template <typename T, int N>
int reduceSmemSize(THCState *state, long numVals) {
// check if we can use a warp shuffle
cudaDeviceProp *props = THCState_getCurrentDeviceProperties(state);
if (props->major >= 3) {
return props->warpSize * N * sizeof(T);
} else {
return THCRoundUp(std::min(numVals, (long) props->maxThreadsPerBlock), (long) props->warpSize) * N * sizeof(T);
}
}

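// Generic wrapper around __shfl_xor so that warp-wide butterfly reductions can
// be written over an arbitrary type T. __shfl_xor moves 32-bit registers, so
// the specializations below widen sub-32-bit integer types to int and split
// 64-bit types into two 32-bit shuffles.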
template <typename T>
struct THCWarpUtils {
static __device__ __forceinline__ T shflxor(T val, unsigned int mask) {
return __shfl_xor(val, mask);
}
};

template <>
struct THCWarpUtils<unsigned char> {
static __device__ __forceinline__ unsigned char shflxor(unsigned char val, unsigned int mask) {
return (unsigned char) __shfl_xor((int) val, mask);
}
};

template <>
struct THCWarpUtils<char> {
static __device__ __forceinline__ char shflxor(char val, unsigned int mask) {
return (char) __shfl_xor((int) val, mask);
}
};

template <>
struct THCWarpUtils<short> {
static __device__ __forceinline__ short shflxor(short val, unsigned int mask) {
return (short) __shfl_xor((int) val, mask);
}
};

template <>
struct THCWarpUtils<double> {
static __device__ __forceinline__ double shflxor(double val, unsigned int mask) {
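// Reinterpret the 64-bit value as an int2 so that each 32-bit half can travel
// through its own shuffle.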
int2 a = *reinterpret_cast<int2*>(&val);
a.x = __shfl_xor(a.x, mask);
a.y = __shfl_xor(a.y, mask);
return *reinterpret_cast<double*>(&a);
}
};

template <>
struct THCWarpUtils<long> {
static __device__ __forceinline__ long shflxor(long val, unsigned int mask) {
int2 a = *reinterpret_cast<int2*>(&val);
a.x = __shfl_xor(a.x, mask);
a.y = __shfl_xor(a.y, mask);
return *reinterpret_cast<long*>(&a);
}
};

template <typename T, typename ReduceOp, int N>
__device__ void warpReduce(T threadVals[N], ReduceOp reduceOp) {
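// Butterfly reduction via XOR shuffles: after log2(warpSize) rounds, every
// lane holds the full warp-wide reduction of each of its N values.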
#pragma unroll
for (int mask = 1; mask < warpSize; mask *= 2) {
#pragma unroll
for (int i = 0; i < N; ++i) {
T neighbor = THCWarpUtils<T>::shflxor(threadVals[i], mask);
threadVals[i] = reduceOp(threadVals[i], neighbor);
}
}
}

template <typename T, typename ReduceOp, int N>
__device__ void warpReduceBlock(T *smem, T threadVals[N], int numVals, ReduceOp reduceOp, T init) {
assert(blockDim.x % warpSize == 0);
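// smem must hold at least N * warpSize elements (see reduceSmemSize): one
// staged partial per warp, per value.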
// First, warps cooperate to reduce values within the warp
warpReduce<T, ReduceOp, N>(threadVals, reduceOp);
int lane = getLaneId();
int warp = threadIdx.x / warpSize;

if (lane == 0) {
#pragma unroll
for (int i = 0; i < N; ++i) {
smem[warp + (i * warpSize)] = threadVals[i];
}
}
__syncthreads();

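// The first warp reduces the per-warp partials; lanes beyond the number of
// warps contribute init.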
if (warp == 0) {
#pragma unroll
for (int i = 0; i < N; ++i) {
threadVals[i] = (threadIdx.x < (blockDim.x / warpSize)) ? smem[lane + (i * warpSize)] : init;
}
warpReduce<T, ReduceOp, N>(threadVals, reduceOp);
}
}

// Reduce N values concurrently, i.e. suppose N = 2, and there are 4 threads:
// (1, 2), (3, 4), (5, 6), (7, 8), then the return in threadVals for thread 0
// is (1 + 3 + 5 + 7, 2 + 4 + 6 + 8) = (16, 20)
@@ -36,6 +140,9 @@ __device__ void reduceNValuesInBlock(T *smem,
return;
}

#if __CUDA_ARCH__ >= 300
warpReduceBlock<T, ReduceOp, N>(smem, threadVals, numVals, reduceOp, init);
#else
// We store each of the N values contiguously, so if N = 2, all values for
// the first threadVal for each thread in the block are stored followed by
// all of the values for the second threadVal for each thread in the block
@@ -91,6 +198,7 @@ __device__ void reduceNValuesInBlock(
}
}
}
#endif
}

// Block-wide reduction in shared memory helper; only threadIdx.x == 0 will
@@ -105,10 +213,13 @@ __device__ T reduceBlock(T* smem,
return threadVal;
}


// Block-wide reduction where each thread locally reduces N
// values before letting a single warp take over - assumes
- // threadVals is in registers, not shared memory
+ // threadVals is in registers, not shared memory. Note that
+ // numVals in this case is the number of values in the overall
+ // reduction, i.e. if there are 512 threads with N=2, and say
+ // there are 768 elements in the input block, then numVals is 768,
+ // not, say, 384 (i.e. 768 / N=2)
template <typename T, typename ReduceOp, int N>
__device__ T reduceBlockWithNThreadLocalReductions(T *smem,
T threadVals[N],
@@ -125,7 +236,7 @@ __device__ T reduceBlockWithNThreadLocalReductions(T *smem,
local = reduceOp(local, next);
}

- return reduceBlock<T, ReduceOp>(smem, blockDim.x < numVals ? blockDim.x : numVals, local, reduceOp, init);
+ return reduceBlock<T, ReduceOp>(smem, THCCeilDiv(numVals, N), local, reduceOp, init);
Contributor

I don't understand the old or new code here.

Isn't numVals always less than blockDim.x, because it is the number of threads with N active values? In other words, it has nothing to do with N, because all numVals threads have N values?

Contributor Author

For the new code: numVals is the overall slice size for the reduction. So if numVals is 512 and N = 2, then the first 256 threads in the block have valid input values. For reduceBlock, numVals represents the number of threads whose values should be considered valid. So in the above example, the first 256 threads have values that should be reduced, hence the division.

The old code was doing something incorrect. However, because the local reduction uses init, it still succeeded.
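A worked sketch of that layout (block size chosen for illustration):

// numVals = 512, N = 2, blockDim.x = 256:
//   threads 0..255 each hold two valid input values and locally reduce
//   them into one partial, so reduceBlock must treat
//   THCCeilDiv(512, 2) = 256 threads as holding valid partials.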

Contributor

But then what about the case where numVals is not a multiple of N?

Contributor Author

In this case THCCeilDiv rounds it up: if we had 513 values in the above case, then the first 257 threads would have valid input values.
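(Concretely, THCCeilDiv(513, 2) = (513 + 2 - 1) / 2 = 257.)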

Contributor

But then only one of the 2 input values for the 257th thread would be valid? Where's the guarantee that all provided input values are either valid reduction values or identity values? If they're identity values for the reduction, then we don't really need numVals at all, except possibly to handle tiny reductions (smaller than the block size) for slightly greater efficiency.

Contributor Author

It just means it's part of the input that we need to consider. The resulting value from the 257th thread is the combination of the input value and an identity value, but the presence of the input value means it must take part in the reduction.

Yes, in theory we could imagine a case where block size > numVals; this is common in the mode code, for example, where we round up.

}

// Make sure the given tensor doesn't have too many dimensions
28 changes: 28 additions & 0 deletions lib/THC/THCTensorMode.cuh
@@ -56,13 +56,31 @@ struct ModeUnsignedBoolPair {
bool flag;
};

template <>
struct THCWarpUtils<ModeUnsignedBoolPair> {
static __device__ __forceinline__ ModeUnsignedBoolPair shflxor(ModeUnsignedBoolPair val, unsigned int mask) {
val.val = __shfl_xor(val.val, mask);
val.flag = (bool) __shfl_xor((int) val.flag, mask);
return val;
}
};

// In the kernel below, we have a common pattern of reducing (unsigned int, unsigned int)
// pairs of data
struct ModeUnsignedPair {
unsigned int val;
unsigned int index;
};

template <>
struct THCWarpUtils<ModeUnsignedPair> {
static __device__ __forceinline__ ModeUnsignedPair shflxor(ModeUnsignedPair val, unsigned int mask) {
val.val = __shfl_xor(val.val, mask);
val.index = __shfl_xor(val.index, mask);
return val;
}
};

template <typename T>
struct MaxReduceOp {
__host__ __device__ inline T operator()(const T& a, const T& b) {
@@ -77,6 +95,15 @@ struct MatchReduceOp {
}
};

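// Shared memory budget for computeMode: space for the slice elements
// themselves, plus scratch that is reused by both the sort/scan phase and the
// reduction phase (hence taking the larger of the two).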
template <typename T, unsigned int Power2Size>
int modeSmemSize(THCState *state) {
int sliceElementSize = sizeof(T) * Power2Size;
int sortScanSize = 2 * Power2Size * sizeof(unsigned int);
int reductionSize = reduceSmemSize<ModeUnsignedPair, 1>(state, Power2Size);

return sliceElementSize + (sortScanSize > reductionSize ? sortScanSize : reductionSize);
}

// The mode kernel has the following characteristics: It uses internal shared memory
// buffers of Power2Size, which must be greater than the number of elements. Additionally,
// there is one block for every slice to calculate the mode for, and in each block there
@@ -271,6 +298,7 @@ __global__ void computeMode(
// Finally, we have the mode, and an index where it occurs. We use a single thread
// to place this in the appropriate output position
if (tidx == 0) {
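// The mode is guaranteed to occur in the input, so the match scan above must
// have flagged an index for it.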
assert(match.flag);
long index = TH_INDEX_BASE + match.val;

unsigned int outputOffset = IndexToOffset<T, unsigned int, -1>::get(blockId, values);
Expand Down
17 changes: 11 additions & 6 deletions lib/THC/THCTensorTopK.cuh
@@ -166,18 +166,19 @@ __device__ void countRadixUsingMask(CountType counts[RadixSize],
#pragma unroll
for (unsigned int j = 0; j < RadixSize; ++j) {
bool vote = hasVal && (digitInRadix == j);
- counts[j] += __popc(__ballot(vote));
+ counts[j] += vote;
}
}

- // Now, for each warp, sum values
- if (getLaneId() == 0) {
+ reduceNValuesInBlock<CountType, AddOp<CountType>, RadixSize>(
+ smem, counts, blockDim.x, AddOp<CountType>(), 0);
+
+ if (threadIdx.x == 0) {
#pragma unroll
for (unsigned int i = 0; i < RadixSize; ++i) {
- atomicAdd(&smem[i], counts[i]);
+ smem[i] = counts[i];
}
}

__syncthreads();

// For each thread, read in the total counts
@@ -194,6 +195,10 @@
#define RADIX_SIZE 4 // 2 ^ RADIX_BITS
#define RADIX_MASK (RADIX_SIZE - 1)

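// Shared memory needed by gatherTopK below: RADIX_SIZE concurrent block-wide
// count reductions, each over at most sliceSize values.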
int topkSmemSize(THCState *state, long sliceSize) {
return reduceSmemSize<unsigned int, RADIX_SIZE>(state, sliceSize);
}

// This finds the unique value `v` that matches the pattern
// ((v & desired) == desiredMask) in our sorted int format
template <typename DataType, typename BitDataType, typename IndexType>
@@ -356,7 +361,7 @@ __global__ void gatherTopK(TensorInfo<T, IndexType> input,
IndexType indicesWithinSliceStride) {
// Indices are limited to integer fp precision, so counts can fit in
// int32, regardless of IndexType
- __shared__ int smem[32]; // one per warp, up to the warp limit
+ extern __shared__ int smem[]; // one per warp, up to the warp limit

IndexType slice = getLinearBlockId<IndexType>();
if (slice >= numInputSlices) {
8 changes: 4 additions & 4 deletions lib/THC/generic/THCTensorMode.cu
@@ -198,7 +198,7 @@ THC_API void THCTensor_(mode)(THCState *state,
// 2. uses one block per slice, so number of slices must be less than the maximum number of blocks for
// a kernel launch
// 3. Can use 32-bit index math for indexing (mainly just for implementation conciseness, could be changed)
- if (sliceSize <= MAX_BLOCK_SIZE &&
+ if (sliceSize <= MAX_BLOCK_SIZE * 2 &&
slices <= MAX_GRID_SIZE &&
TensorUtils<THCTensor>::canUse32BitIndexMath(state, input)) {
// Beginning our optimized implementation. First thing we want to do is to transpose
Expand Down Expand Up @@ -230,7 +230,7 @@ THC_API void THCTensor_(mode)(THCState *state,
{ \
dim3 blockSize(SIZE / 2); \
\
- int memsize = (sizeof(real) * SIZE) + (2 * SIZE * sizeof(unsigned int)); \
+ int memsize = modeSmemSize<real, SIZE>(state); \
computeMode<real, SIZE> \
<<<grid, blockSize, memsize, THCState_getCurrentStream(state)>>>( \
THCTensor_(data)(state, contiguous), tiValues, tiIndices, sliceSize); \
Expand All @@ -248,15 +248,15 @@ THC_API void THCTensor_(mode)(THCState *state,
HANDLE_MODE(1024)
break;
case 128:
- case 64:
HANDLE_MODE(128)
break;
+ case 64:
case 32:
case 16:
case 8:
case 4:
case 2:
- HANDLE_MODE(32)
+ HANDLE_MODE(64) // block size should be at least 1 full warp
break;
case 1:
default:
12 changes: 8 additions & 4 deletions lib/THC/generic/THCTensorRandom.cu
@@ -93,10 +93,12 @@ void THCTensor_(renormRows)(struct THCState* state,
int maxThreads = props->maxThreadsPerBlock;

dim3 grid(rows < numSM * 4 ? rows : numSM * 4);
- dim3 block(cols < maxThreads ? cols : maxThreads);
+ dim3 block(THCRoundUp(cols < maxThreads ? cols : maxThreads, (long) props->warpSize));

+ int smem = reduceSmemSize<real, 1>(state, cols);

renormRowsL1<real>
- <<<grid, block, block.x * sizeof(real),
+ <<<grid, block, smem,
THCState_getCurrentStream(state)>>>(THCTensor_(data)(state, t),
rows, cols);
}
Expand Down Expand Up @@ -157,11 +159,13 @@ THC_API void THCTensor_(multinomial)(struct THCState *state,
THAssert(props != NULL);
int numSM = props->multiProcessorCount;
int maxThreads = props->maxThreadsPerBlock;
- dim3 block(numCategories < maxThreads ? numCategories : maxThreads);
+ dim3 block(THCRoundUp(numCategories < maxThreads ? numCategories : maxThreads, props->warpSize));
dim3 grid(numDist < numSM * 4 ? numDist : numSM * 4);
+ // smem required for reduction + prefix sums
+ int smemSize = (block.x * sizeof(real)) + reduceSmemSize<accreal, 1>(state, numCategories);
sampleMultinomialOnce<real, accreal>
<<<grid, block,
- block.x * (sizeof(real) * sizeof(accreal)),
+ smemSize,
THCState_getCurrentStream(state)>>>(
THCudaLongTensor_data(state, self),
numDist,