Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions src/collectives/collectives.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,22 @@
DECL_COLL3(coll, op, f16) \
DECL_COLL3(coll, op, f32) \
DECL_COLL3(coll, op, f64)
// Single-type declaration helper: forwards to DECL_COLL3 with the data type
// fixed to i8, so exactly one variant of (coll, op) is declared.
#define DECL_COLL2A(coll, op) DECL_COLL3(coll, op, i8)

// Declares collective `coll` once per reduction operator.
//   - sum, prod, max, min are declared through DECL_COLL2.
//   - band, bor, bxor are declared through DECL_COLL2A, which only emits the
//     i8 variant.
// Every entry except the last must end with a line continuation; an entry
// without one terminates the macro and leaves the remaining entries outside
// the definition.
#define DECL_COLL(coll) \
  DECL_COLL2(coll, sum) \
  DECL_COLL2(coll, prod) \
  DECL_COLL2(coll, max) \
  DECL_COLL2(coll, min) \
  DECL_COLL2A(coll, band) \
  DECL_COLL2A(coll, bor) \
  DECL_COLL2A(coll, bxor)

// Declares every collective in one expansion. Broadcast and AllGather are
// declared with the `copy` operator; Reduce, ReduceScatter and AllReduce are
// declared through DECL_COLL.
// NOTE(review): Broadcast and AllGather are listed through both DECL_COLL2 and
// DECL_COLL2A -- confirm that both entries are intended.
// NOTE(review): the final entry ends with a line continuation, so the macro
// relies on the following line being blank.
#define DECL_ALL_COLLS \
DECL_COLL2(ncclBroadcast, copy) \
DECL_COLL2A(ncclBroadcast, copy) \
DECL_COLL(ncclReduce) \
DECL_COLL2(ncclAllGather, copy) \
DECL_COLL2A(ncclAllGather, copy) \
DECL_COLL(ncclReduceScatter) \
DECL_COLL(ncclAllReduce) \

Expand Down
16 changes: 14 additions & 2 deletions src/collectives/device/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,22 @@ __global__ void NCCL_KERN_NAME(coll, op, dtype)(struct ncclColl firstColl) { \
IMPL_COLL2(collf, prod, FuncProd, colln, ncclProd);
#elif NCCL_OP == 2
#define IMPL_COLL_R(collf, colln) \
IMPL_COLL2(collf, min, FuncMin, colln, ncclMin);
IMPL_COLL2(collf, max, FuncMax, colln, ncclMax);
#elif NCCL_OP == 3
#define IMPL_COLL_R(collf, colln) \
IMPL_COLL2(collf, max, FuncMax, colln, ncclMax);
IMPL_COLL2(collf, min, FuncMin, colln, ncclMin);
// Bitwise reduction ops are only instantiated for i8
#elif NCCL_OP == 4 && NCCL_TYPE == 0
#define IMPL_COLL_R(collf, colln) \
IMPL_COLL3(collf, band, FuncBitAnd, i8, int8_t, colln, ncclBitAnd, ncclInt8);
#elif NCCL_OP == 5 && NCCL_TYPE == 0
#define IMPL_COLL_R(collf, colln) \
IMPL_COLL3(collf, bor, FuncBitOr, i8, int8_t, colln, ncclBitOr, ncclInt8);
#elif NCCL_OP == 6 && NCCL_TYPE == 0
#define IMPL_COLL_R(collf, colln) \
IMPL_COLL3(collf, bxor, FuncBitXor, i8, int8_t, colln, ncclBitXor, ncclInt8);
#else
#define IMPL_COLL_R(collf, colln)
#endif

// Copy primitives only define one
Expand Down
9 changes: 8 additions & 1 deletion src/collectives/device/functions.cu
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,19 @@
NCCL_FUNC4(coll, op, i8)

// Function-table entries for one collective, one entry group per reduction
// operator. The order must stay consistent with ncclRedOp_t:
// sum, prod, max, min, band, bor, bxor.
// The bitwise operators (band, bor, bxor) can only use i8, so they go through
// NCCL_FUNCS3B instead of NCCL_FUNCS3A.
// Every entry except the last must end with ", \"; an entry without it ends
// the macro early and drops the operators that follow from the table.
#define NCCL_FUNCS2A(coll) \
  NCCL_FUNCS3A(coll, sum ), \
  NCCL_FUNCS3A(coll, prod), \
  NCCL_FUNCS3A(coll, max ), \
  NCCL_FUNCS3A(coll, min ), \
  NCCL_FUNCS3B(coll, band), \
  NCCL_FUNCS3B(coll, bor ), \
  NCCL_FUNCS3B(coll, bxor)
#define NCCL_FUNCS2B(coll) \
NCCL_FUNCS3B(coll, copy), \
NCCL_FUNCS3B(coll, copy), \
NCCL_FUNCS3B(coll, copy), \
NCCL_FUNCS3B(coll, copy), \
NCCL_FUNCS3B(coll, copy), \
NCCL_FUNCS3B(coll, copy), \
Expand Down
2 changes: 1 addition & 1 deletion src/collectives/device/gen_rules.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ targets="GENOBJS := \\\\\n"

for base in all_reduce all_gather broadcast reduce reduce_scatter; do
opn=0
for op in sum prod min max; do
for op in sum prod max min band bor bxor; do
dtn=0
for dt in i8 u8 i32 u32 i64 u64 f16 f32 f64; do
echo "${dir}/${base}_${op}_${dt}.o : ${base}.cu ${dir}/${base}.dep"
Expand Down
25 changes: 25 additions & 0 deletions src/collectives/device/reduce_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,30 @@ struct FuncMin {
}
};

// Bitwise-AND reduction functor. The struct-level type T is not referenced by
// the body; the operand type U is deduced at each call.
template<typename T>
struct FuncBitAnd {
  template<typename U>
  __device__ U operator()(const U lhs, const U rhs) const {
    const U combined = lhs & rhs;
    return combined;
  }
};

// Bitwise-OR reduction functor. The struct-level type T is not referenced by
// the body; the operand type U is deduced at each call.
template<typename T>
struct FuncBitOr {
  template<typename U>
  __device__ U operator()(const U lhs, const U rhs) const {
    const U combined = lhs | rhs;
    return combined;
  }
};

// Bitwise-XOR reduction functor. The struct-level type T is not referenced by
// the body; the operand type U is deduced at each call.
template<typename T>
struct FuncBitXor {
  template<typename U>
  __device__ U operator()(const U lhs, const U rhs) const {
    const U combined = lhs ^ rhs;
    return combined;
  }
};

#define MASK0 0x00ff00ff
#define MASK1 0xff00ff00
static __device__ uint32_t addChar4(const uint32_t x, const uint32_t y) {
Expand Down Expand Up @@ -299,4 +323,5 @@ struct FuncMin<half> {
return __float2half(fm);
}
};

#endif // REDUCE_KERNEL_H_
9 changes: 8 additions & 1 deletion src/enqueue.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,19 @@
(void*)NCCL_FUNC4(coll, op, i8)

// Kernel-table entries for one collective, one entry group per reduction
// operator. The number and order of groups must stay consistent with
// ncclRedOp_t, but kernels are only generated for sums, so every group names
// `sum`.
// The last three groups stand for the bitwise operators, which can only use
// i8 and therefore go through NCCL_FUNCS3B.
// Every entry except the last must end with ", \"; an entry without it ends
// the macro early and drops the groups that follow from the table.
#define NCCL_FUNCS2A(coll) \
  NCCL_FUNCS3A(coll, sum), \
  NCCL_FUNCS3A(coll, sum), \
  NCCL_FUNCS3A(coll, sum), \
  NCCL_FUNCS3A(coll, sum), \
  NCCL_FUNCS3B(coll, sum), \
  NCCL_FUNCS3B(coll, sum), \
  NCCL_FUNCS3B(coll, sum)
#define NCCL_FUNCS2B(coll) \
NCCL_FUNCS3B(coll, copy), \
NCCL_FUNCS3B(coll, copy), \
NCCL_FUNCS3B(coll, copy), \
NCCL_FUNCS3B(coll, copy), \
NCCL_FUNCS3B(coll, copy), \
NCCL_FUNCS3B(coll, copy), \
Expand Down
5 changes: 3 additions & 2 deletions src/misc/argcheck.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,10 @@ ncclResult_t ArgsCheck(struct ncclInfo* info) {
WARN("%s : invalid type %d", info->opName, info->datatype);
return ncclInvalidArgument;
}
// Type is OK, compute nbytes. Convert Allgather/Broadcast calls to chars.
// Type is OK, compute nbytes. Convert Allgather/Broadcast/BitRedOp calls to chars.
info->nBytes = info->count * ncclTypeSize(info->datatype);
if (info->coll == ncclCollAllGather || info->coll == ncclCollBroadcast) {
if (info->coll == ncclCollAllGather || info->coll == ncclCollBroadcast ||
info->op == ncclBitAnd || info->op == ncclBitOr || info->op == ncclBitXor) {
info->count = info->nBytes;
info->datatype = ncclInt8;
}
Expand Down
5 changes: 4 additions & 1 deletion src/nccl.h.in
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,10 @@ typedef enum { ncclSum = 0,
ncclProd = 1,
ncclMax = 2,
ncclMin = 3,
ncclNumOps = 4 } ncclRedOp_t;
ncclBitAnd = 4,
ncclBitOr = 5,
ncclBitXor = 6,
ncclNumOps = 7 } ncclRedOp_t;

/* Data types */
typedef enum { ncclInt8 = 0, ncclChar = 0,
Expand Down