From 234073728c9f38c795f184c5d4a8dbd162e450a8 Mon Sep 17 00:00:00 2001 From: Loong <69568351+Looong01@users.noreply.github.com> Date: Fri, 15 Aug 2025 12:51:27 +0200 Subject: [PATCH] Add ROCm 6.4.3+ support --- csrc/cuda/utils.cuh | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/csrc/cuda/utils.cuh b/csrc/cuda/utils.cuh index 747a8e2c..09901238 100644 --- a/csrc/cuda/utils.cuh +++ b/csrc/cuda/utils.cuh @@ -6,19 +6,26 @@ AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor") #define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch") -__device__ __inline__ at::Half __shfl_up_sync(const unsigned mask, +// On ROCm, __shfl_*_sync requires a 64-bit mask; on CUDA it's 32-bit. +#ifdef USE_ROCM + using warp_mask_t = unsigned long long; +#else + using warp_mask_t = unsigned int; +#endif + +__device__ __inline__ at::Half __shfl_up_sync(const warp_mask_t mask, const at::Half var, const unsigned int delta) { return __shfl_up_sync(mask, var.operator __half(), delta); } -__device__ __inline__ at::Half __shfl_down_sync(const unsigned mask, +__device__ __inline__ at::Half __shfl_down_sync(const warp_mask_t mask, const at::Half var, const unsigned int delta) { return __shfl_down_sync(mask, var.operator __half(), delta); } -__device__ __inline__ at::Half __shfl_sync(const unsigned mask, +__device__ __inline__ at::Half __shfl_sync(const warp_mask_t mask, const at::Half var, const int delta) { return __shfl_sync(mask, var.operator __half(), delta);