From 29769816f64d2cf4671b3f0181e72a3906728902 Mon Sep 17 00:00:00 2001
From: Cheng
Date: Thu, 26 Mar 2026 00:49:18 -0700
Subject: [PATCH] Remove no longer needed const_cast

---
 mlx/backend/cuda/reduce/all_reduce.cu | 14 +++-----------
 mlx/backend/cuda/reduce/col_reduce.cu |  8 ++------
 mlx/backend/cuda/reduce/row_reduce.cu |  9 +++++++--
 3 files changed, 12 insertions(+), 19 deletions(-)

diff --git a/mlx/backend/cuda/reduce/all_reduce.cu b/mlx/backend/cuda/reduce/all_reduce.cu
index dd4980a971..0659504c5b 100644
--- a/mlx/backend/cuda/reduce/all_reduce.cu
+++ b/mlx/backend/cuda/reduce/all_reduce.cu
@@ -98,9 +98,7 @@ void all_reduce(
   size_t block_step;
   size_t insize = in.size();
   Dtype dt = in.dtype();
-
-  // Cub doesn't like const pointers for load (sigh).
-  void* indata = const_cast<void*>(gpu_ptr(in));
+  void* indata = gpu_ptr(in);
 
   // Large array so allocate an intermediate and accumulate there
   std::tie(blocks, threads, block_step) = get_args(insize, N_READS);
@@ -120,7 +118,7 @@
             kernel,
             blocks,
             threads,
-            static_cast<T*>(indata),
+            indata,
             gpu_ptr(intermediate),
             block_step,
             insize);
@@ -143,13 +141,7 @@
       using U = typename cu::ReduceResult<OP, T>::type;
       auto kernel = cu::all_reduce<T, U, OP, N_READS>;
       encoder.add_kernel_node(
-          kernel,
-          blocks,
-          threads,
-          static_cast<T*>(indata),
-          gpu_ptr(out),
-          block_step,
-          insize);
+          kernel, blocks, threads, indata, gpu_ptr(out), block_step, insize);
     });
   });
 }
diff --git a/mlx/backend/cuda/reduce/col_reduce.cu b/mlx/backend/cuda/reduce/col_reduce.cu
index b3b5a93fd7..8a6327bf38 100644
--- a/mlx/backend/cuda/reduce/col_reduce.cu
+++ b/mlx/backend/cuda/reduce/col_reduce.cu
@@ -282,8 +282,6 @@ void col_reduce_looped(
       using OP = MLX_GET_TYPE(reduce_type_tag);
       using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
       using U = typename cu::ReduceResult<OP, T>::type;
-      // Cub doesn't like const pointers for vectorized loads. (sigh)
-      T* indata = const_cast<T*>(gpu_ptr(in));
 
       constexpr int N_READS = 4;
       constexpr int BM = 32;
@@ -296,7 +294,7 @@
           kernel,
           grid,
           blocks,
-          indata,
+          gpu_ptr(in),
           gpu_ptr(out),
           static_cast<cu::ColReduceArgs>(args),
           out.size() / args.reduction_stride);
@@ -389,8 +387,6 @@ void col_reduce_two_pass(
       using OP = MLX_GET_TYPE(reduce_type_tag);
       using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
       using U = typename cu::ReduceResult<OP, T>::type;
-      // Cub doesn't like const pointers for vectorized loads. (sigh)
-      T* indata = const_cast<T*>(gpu_ptr(in));
 
       constexpr int N_READS = 4;
       constexpr int BM = 32;
@@ -403,7 +399,7 @@
           kernel,
           grid,
           blocks,
-          indata,
+          gpu_ptr(in),
           gpu_ptr(intermediate),
           static_cast<cu::ColReduceArgs>(args),
           out.size() / args.reduction_stride);
diff --git a/mlx/backend/cuda/reduce/row_reduce.cu b/mlx/backend/cuda/reduce/row_reduce.cu
index d09071669c..1dd40d14bd 100644
--- a/mlx/backend/cuda/reduce/row_reduce.cu
+++ b/mlx/backend/cuda/reduce/row_reduce.cu
@@ -268,10 +268,15 @@
         kernel = cu::row_reduce_simple<T, U, OP, N_READS, 2>;
       }
 
-      T* indata = const_cast<T*>(gpu_ptr(in));
       int size = plan.shape.back();
       encoder.add_kernel_node(
-          kernel, grid, block, indata, gpu_ptr(out), out.size(), size);
+          kernel,
+          grid,
+          block,
+          gpu_ptr(in),
+          gpu_ptr(out),
+          out.size(),
+          size);
     });
   });
 }