Add bitwise reduction operators (band, bor, bxor) to the collectives.

Summary of the changes carried by this patch:
  * src/nccl.h.in: ncclRedOp_t gains ncclBitAnd = 4, ncclBitOr = 5 and
    ncclBitXor = 6; ncclNumOps moves from 4 to 7.
  * src/misc/argcheck.cc: ArgsCheck now rewrites count to the byte count and
    datatype to ncclInt8 for the three bitwise ops, in addition to the
    AllGather and Broadcast collectives.
  * src/collectives/device/reduce_kernel.h: new FuncBitAnd, FuncBitOr and
    FuncBitXor functors returning x & y, x | y and x ^ y.
  * src/collectives/device/common.h: IMPL_COLL_R for NCCL_OP 2 / 3 now selects
    max / min (it was min / max); NCCL_OP 4, 5 and 6 instantiate the bitwise
    ops only when NCCL_TYPE == 0; any combination not matched by an earlier
    branch defines IMPL_COLL_R as empty.
  * src/collectives/collectives.h: new DECL_COLL2A (i8 only), used for the
    bitwise ops and for the Broadcast / AllGather copy declarations; DECL_COLL
    now lists sum, prod, max, min, band, bor, bxor.
  * src/collectives/device/functions.cu and src/enqueue.cc: NCCL_FUNCS2A and
    NCCL_FUNCS2B each grow from four to seven entries; the three new
    NCCL_FUNCS2A entries use NCCL_FUNCS3B.
  * src/collectives/device/gen_rules.sh: the op list becomes
    "sum prod max min band bor bxor".

NOTE(review): this patch was rebuilt from a whitespace-collapsed copy. Line
breaks were checked against the hunk line counts, but leading indentation and
the template parameter lists in reduce_kernel.h (<typename T> / <typename U>)
are reconstructed -- confirm against the tree. Applying may require
"git apply --ignore-whitespace".

diff --git a/src/collectives/collectives.h b/src/collectives/collectives.h
index 73fe7d5c8..6d45fe7ea 100644
--- a/src/collectives/collectives.h
+++ b/src/collectives/collectives.h
@@ -38,17 +38,22 @@
   DECL_COLL3(coll, op, f16) \
   DECL_COLL3(coll, op, f32) \
   DECL_COLL3(coll, op, f64)
+#define DECL_COLL2A(coll, op) \
+  DECL_COLL3(coll, op, i8)
 
 #define DECL_COLL(coll) \
   DECL_COLL2(coll, sum) \
   DECL_COLL2(coll, prod) \
+  DECL_COLL2(coll, max) \
   DECL_COLL2(coll, min) \
-  DECL_COLL2(coll, max)
+  DECL_COLL2A(coll, band) \
+  DECL_COLL2A(coll, bor) \
+  DECL_COLL2A(coll, bxor)
 
 #define DECL_ALL_COLLS \
-  DECL_COLL2(ncclBroadcast, copy) \
+  DECL_COLL2A(ncclBroadcast, copy) \
   DECL_COLL(ncclReduce) \
-  DECL_COLL2(ncclAllGather, copy) \
+  DECL_COLL2A(ncclAllGather, copy) \
   DECL_COLL(ncclReduceScatter) \
   DECL_COLL(ncclAllReduce) \
 
diff --git a/src/collectives/device/common.h b/src/collectives/device/common.h
index 8c336bf94..f1733ff31 100644
--- a/src/collectives/device/common.h
+++ b/src/collectives/device/common.h
@@ -139,10 +139,22 @@ __global__ void NCCL_KERN_NAME(coll, op, dtype)(struct ncclColl firstColl) { \
   IMPL_COLL2(collf, prod, FuncProd, colln, ncclProd);
 #elif NCCL_OP == 2
 #define IMPL_COLL_R(collf, colln) \
-  IMPL_COLL2(collf, min, FuncMin, colln, ncclMin);
+  IMPL_COLL2(collf, max, FuncMax, colln, ncclMax);
 #elif NCCL_OP == 3
 #define IMPL_COLL_R(collf, colln) \
-  IMPL_COLL2(collf, max, FuncMax, colln, ncclMax);
+  IMPL_COLL2(collf, min, FuncMin, colln, ncclMin);
+// Bit RedOp only use i8
+#elif NCCL_OP == 4 && NCCL_TYPE == 0
+#define IMPL_COLL_R(collf, colln) \
+  IMPL_COLL3(collf, band, FuncBitAnd, i8, int8_t, colln, ncclBitAnd, ncclInt8);
+#elif NCCL_OP == 5 && NCCL_TYPE == 0
+#define IMPL_COLL_R(collf, colln) \
+  IMPL_COLL3(collf, bor, FuncBitOr, i8, int8_t, colln, ncclBitOr, ncclInt8);
+#elif NCCL_OP == 6 && NCCL_TYPE == 0
+#define IMPL_COLL_R(collf, colln) \
+  IMPL_COLL3(collf, bxor, FuncBitXor, i8, int8_t, colln, ncclBitXor, ncclInt8);
+#else
+#define IMPL_COLL_R(collf, colln)
 #endif
 
 // Copy primitives only define one
diff --git a/src/collectives/device/functions.cu b/src/collectives/device/functions.cu
index 010c4548c..407ab4d9c 100644
--- a/src/collectives/device/functions.cu
+++ b/src/collectives/device/functions.cu
@@ -39,12 +39,19 @@
   NCCL_FUNC4(coll, op, i8)
 
 // Must be consistent with ncclRedOp_t
+// Bit RedOp can only use i8
 #define NCCL_FUNCS2A(coll) \
   NCCL_FUNCS3A(coll, sum ), \
   NCCL_FUNCS3A(coll, prod), \
   NCCL_FUNCS3A(coll, max ), \
-  NCCL_FUNCS3A(coll, min )
+  NCCL_FUNCS3A(coll, min ), \
+  NCCL_FUNCS3B(coll, band), \
+  NCCL_FUNCS3B(coll, bor ), \
+  NCCL_FUNCS3B(coll, bxor)
 #define NCCL_FUNCS2B(coll) \
+  NCCL_FUNCS3B(coll, copy), \
+  NCCL_FUNCS3B(coll, copy), \
+  NCCL_FUNCS3B(coll, copy), \
   NCCL_FUNCS3B(coll, copy), \
   NCCL_FUNCS3B(coll, copy), \
   NCCL_FUNCS3B(coll, copy), \
diff --git a/src/collectives/device/gen_rules.sh b/src/collectives/device/gen_rules.sh
index 4413213e1..9ac8ba920 100755
--- a/src/collectives/device/gen_rules.sh
+++ b/src/collectives/device/gen_rules.sh
@@ -11,7 +11,7 @@ targets="GENOBJS := \\\\\n"
 
 for base in all_reduce all_gather broadcast reduce reduce_scatter; do
   opn=0
-  for op in sum prod min max; do
+  for op in sum prod max min band bor bxor; do
     dtn=0
     for dt in i8 u8 i32 u32 i64 u64 f16 f32 f64; do
       echo "${dir}/${base}_${op}_${dt}.o : ${base}.cu ${dir}/${base}.dep"
diff --git a/src/collectives/device/reduce_kernel.h b/src/collectives/device/reduce_kernel.h
index 0e907939f..98b658272 100644
--- a/src/collectives/device/reduce_kernel.h
+++ b/src/collectives/device/reduce_kernel.h
@@ -46,6 +46,30 @@ struct FuncMin {
   }
 };
 
+template<typename T>
+struct FuncBitAnd {
+  template<typename U>
+  __device__ U operator()(const U x, const U y) const {
+    return x & y;
+  }
+};
+
+template<typename T>
+struct FuncBitOr {
+  template<typename U>
+  __device__ U operator()(const U x, const U y) const {
+    return x | y;
+  }
+};
+
+template<typename T>
+struct FuncBitXor {
+  template<typename U>
+  __device__ U operator()(const U x, const U y) const {
+    return x ^ y;
+  }
+};
+
 #define MASK0 0x00ff00ff
 #define MASK1 0xff00ff00
 static __device__ uint32_t addChar4(const uint32_t x, const uint32_t y) {
@@ -299,4 +323,5 @@ struct FuncMin {
     return __float2half(fm);
   }
 };
+
 #endif // REDUCE_KERNEL_H_
diff --git a/src/enqueue.cc b/src/enqueue.cc
index b48563456..3743bb677 100644
--- a/src/enqueue.cc
+++ b/src/enqueue.cc
@@ -42,12 +42,19 @@
   (void*)NCCL_FUNC4(coll, op, i8)
 
 // Must be consistent with ncclRedOp_t -- but we only generate kernel for sums.
+// Bit RedOp can only use i8
 #define NCCL_FUNCS2A(coll) \
   NCCL_FUNCS3A(coll, sum), \
   NCCL_FUNCS3A(coll, sum), \
   NCCL_FUNCS3A(coll, sum), \
-  NCCL_FUNCS3A(coll, sum)
+  NCCL_FUNCS3A(coll, sum), \
+  NCCL_FUNCS3B(coll, sum), \
+  NCCL_FUNCS3B(coll, sum), \
+  NCCL_FUNCS3B(coll, sum)
 #define NCCL_FUNCS2B(coll) \
+  NCCL_FUNCS3B(coll, copy), \
+  NCCL_FUNCS3B(coll, copy), \
+  NCCL_FUNCS3B(coll, copy), \
   NCCL_FUNCS3B(coll, copy), \
   NCCL_FUNCS3B(coll, copy), \
   NCCL_FUNCS3B(coll, copy), \
diff --git a/src/misc/argcheck.cc b/src/misc/argcheck.cc
index 364f04152..80f698c70 100644
--- a/src/misc/argcheck.cc
+++ b/src/misc/argcheck.cc
@@ -43,9 +43,10 @@ ncclResult_t ArgsCheck(struct ncclInfo* info) {
     WARN("%s : invalid type %d", info->opName, info->datatype);
     return ncclInvalidArgument;
   }
-  // Type is OK, compute nbytes. Convert Allgather/Broadcast calls to chars.
+  // Type is OK, compute nbytes. Convert Allgather/Broadcast/BitRedOp calls to chars.
   info->nBytes = info->count * ncclTypeSize(info->datatype);
-  if (info->coll == ncclCollAllGather || info->coll == ncclCollBroadcast) {
+  if (info->coll == ncclCollAllGather || info->coll == ncclCollBroadcast ||
+      info->op == ncclBitAnd || info->op == ncclBitOr || info->op == ncclBitXor) {
     info->count = info->nBytes;
     info->datatype = ncclInt8;
   }
diff --git a/src/nccl.h.in b/src/nccl.h.in
index 985274eae..2680033f9 100644
--- a/src/nccl.h.in
+++ b/src/nccl.h.in
@@ -103,7 +103,10 @@ typedef enum { ncclSum = 0,
                ncclProd = 1,
                ncclMax = 2,
                ncclMin = 3,
-               ncclNumOps = 4 } ncclRedOp_t;
+               ncclBitAnd = 4,
+               ncclBitOr = 5,
+               ncclBitXor = 6,
+               ncclNumOps = 7 } ncclRedOp_t;
 
 /* Data types */
 typedef enum { ncclInt8 = 0, ncclChar = 0,