Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
330 changes: 168 additions & 162 deletions mlx/backend/cuda/sort.cu
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ __device__ __forceinline__ __nv_bfloat16 nan_value<__nv_bfloat16>() {
return __float2bfloat16(cuda::std::numeric_limits<float>::quiet_NaN());
}

// Sentinel NaN for complex values: a quiet NaN is placed in both the real
// and the imaginary component, so any NaN check on either part succeeds.
template <>
__device__ __forceinline__ complex64_t nan_value<complex64_t>() {
  auto nan_component = cuda::std::numeric_limits<float>::quiet_NaN();
  return complex64_t{nan_component, nan_component};
}

template <typename T, typename = void>
struct InitValue {
__device__ __forceinline__ static T value() {
Expand All @@ -52,7 +58,9 @@ struct InitValue {
};

template <typename T>
struct InitValue<T, cuda::std::enable_if_t<is_floating_v<T>>> {
struct InitValue<
T,
cuda::std::enable_if_t<is_floating_v<T> || cu::is_complex_v<T>>> {
__device__ __forceinline__ static T value() {
return nan_value<T>();
}
Expand All @@ -65,16 +73,25 @@ __device__ __forceinline__ void thread_swap(T& a, T& b) {
b = w;
}

// Returns true when `v` is NaN. A complex value is treated as NaN when
// either its real or its imaginary component is NaN; all other types defer
// to cuda::std::isnan directly.
template <typename T>
__device__ __forceinline__ bool check_nan(T v) {
  if constexpr (cu::is_complex_v<T>) {
    const bool real_is_nan = cuda::std::isnan(v.real());
    const bool imag_is_nan = cuda::std::isnan(v.imag());
    return real_is_nan || imag_is_nan;
  } else {
    return cuda::std::isnan(v);
  }
}

template <typename T>
struct LessThan {
__device__ __forceinline__ static T init() {
return InitValue<T>::value();
}

__device__ __forceinline__ bool operator()(T a, T b) const {
if constexpr (is_floating_v<T>) {
bool an = cuda::std::isnan(a);
bool bn = cuda::std::isnan(b);
if constexpr (is_floating_v<T> || cu::is_complex_v<T>) {
bool an = check_nan(a);
bool bn = check_nan(b);
if (an | bn) {
return (!an) & bn;
}
Expand Down Expand Up @@ -745,82 +762,76 @@ void single_block_sort(

dispatch_all_types(in.dtype(), [&](auto type_tag) {
using CTYPE = MLX_GET_TYPE(type_tag);
if constexpr (!std::is_same_v<CTYPE, complex64_t>) {
using ValT = cuda_type_t<CTYPE>;
dispatch_block_dim(bn, [&](auto block_dim) {
constexpr int BLOCK_THREADS = block_dim();
if constexpr (BLOCK_THREADS < 1024) {
dim3 grid(1, n_rows, 1);
dim3 block(BLOCK_THREADS, 1, 1);

dispatch_bool(argsort, [&](auto arg_tag) {
constexpr bool ARG_SORT = decltype(arg_tag)::value;
using OutT = std::conditional_t<ARG_SORT, uint32_t, ValT>;

if (contiguous) {
auto kernel = cu::block_sort_kernel<
ValT,
OutT,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD>;
int64_t in_stride_segment_axis = INT64_MAX;
int64_t out_stride_segment_axis = INT64_MAX;
for (int i = 0; i < nc_shape.size(); i++) {
if (nc_shape[i] == 1) {
continue;
}
if (in_nc_str[i] > INT32_MAX || out_nc_str[i] > INT32_MAX) {
throw std::runtime_error(
"[Sort::eval_gpu] Stride too large.");
}
in_stride_segment_axis =
std::min(in_stride_segment_axis, in_nc_str[i]);
out_stride_segment_axis =
std::min(out_stride_segment_axis, out_nc_str[i]);
using ValT = cuda_type_t<CTYPE>;
dispatch_block_dim(bn, [&](auto block_dim) {
constexpr int BLOCK_THREADS = block_dim();
if constexpr (BLOCK_THREADS < 1024) {
dim3 grid(1, n_rows, 1);
dim3 block(BLOCK_THREADS, 1, 1);

dispatch_bool(argsort, [&](auto arg_tag) {
constexpr bool ARG_SORT = decltype(arg_tag)::value;
using OutT = std::conditional_t<ARG_SORT, uint32_t, ValT>;

if (contiguous) {
auto kernel = cu::block_sort_kernel<
ValT,
OutT,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD>;
int64_t in_stride_segment_axis = INT64_MAX;
int64_t out_stride_segment_axis = INT64_MAX;
for (int i = 0; i < nc_shape.size(); i++) {
if (nc_shape[i] == 1) {
continue;
}
encoder.add_kernel_node(
kernel,
grid,
block,
gpu_ptr<ValT>(in),
gpu_ptr<OutT>(out),
size_sorted_axis,
in_stride_sorted_axis,
out_stride_sorted_axis,
in_stride_segment_axis,
out_stride_segment_axis);
} else {
auto kernel = cu::block_sort_nc_kernel<
ValT,
OutT,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD>;
auto nc_shape_param = const_param(nc_shape);
auto in_nc_strides_param = const_param(in_nc_str);
auto out_nc_strides_param = const_param(out_nc_str);
encoder.add_kernel_node(
kernel,
grid,
block,
gpu_ptr<ValT>(in),
gpu_ptr<OutT>(out),
size_sorted_axis,
in_stride_sorted_axis,
out_stride_sorted_axis,
nc_shape_param,
in_nc_strides_param,
out_nc_strides_param,
nc_dim);
if (in_nc_str[i] > INT32_MAX || out_nc_str[i] > INT32_MAX) {
throw std::runtime_error("[Sort::eval_gpu] Stride too large.");
}
in_stride_segment_axis =
std::min(in_stride_segment_axis, in_nc_str[i]);
out_stride_segment_axis =
std::min(out_stride_segment_axis, out_nc_str[i]);
}
});
}
});
} else {
throw std::runtime_error(
"CUDA backend does not support sorting complex numbers");
}
encoder.add_kernel_node(
kernel,
grid,
block,
gpu_ptr<ValT>(in),
gpu_ptr<OutT>(out),
size_sorted_axis,
in_stride_sorted_axis,
out_stride_sorted_axis,
in_stride_segment_axis,
out_stride_segment_axis);
} else {
auto kernel = cu::block_sort_nc_kernel<
ValT,
OutT,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD>;
auto nc_shape_param = const_param(nc_shape);
auto in_nc_strides_param = const_param(in_nc_str);
auto out_nc_strides_param = const_param(out_nc_str);
encoder.add_kernel_node(
kernel,
grid,
block,
gpu_ptr<ValT>(in),
gpu_ptr<OutT>(out),
size_sorted_axis,
in_stride_sorted_axis,
out_stride_sorted_axis,
nc_shape_param,
in_nc_strides_param,
out_nc_strides_param,
nc_dim);
}
});
}
});
});
}

Expand Down Expand Up @@ -870,99 +881,94 @@ void multi_block_sort(

dispatch_all_types(in.dtype(), [&](auto type_tag) {
using CTYPE = MLX_GET_TYPE(type_tag);
if constexpr (!std::is_same_v<CTYPE, complex64_t>) {
using ValT = cuda_type_t<CTYPE>;
using IdxT = uint32_t;
constexpr int BLOCK_THREADS = sizeof(ValT) == 8 ? 256 : 512;
dim3 grid(n_blocks, n_rows, 1);
dim3 block(BLOCK_THREADS, 1, 1);

dispatch_bool(argsort, [&](auto arg_tag) {
constexpr bool ARG_SORT = decltype(arg_tag)::value;
auto nc_shape_param = const_param(nc_shape);
auto nc_strides_param = const_param(nc_str);

auto block_sort_kernel = cu::mb_block_sort_kernel<
using ValT = cuda_type_t<CTYPE>;
using IdxT = uint32_t;
constexpr int BLOCK_THREADS = sizeof(ValT) == 8 ? 256 : 512;
dim3 grid(n_blocks, n_rows, 1);
dim3 block(BLOCK_THREADS, 1, 1);

dispatch_bool(argsort, [&](auto arg_tag) {
constexpr bool ARG_SORT = decltype(arg_tag)::value;
auto nc_shape_param = const_param(nc_shape);
auto nc_strides_param = const_param(nc_str);

auto block_sort_kernel = cu::mb_block_sort_kernel<
ValT,
IdxT,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD>;
encoder.set_input_array(in);
encoder.set_output_array(dev_vals_in);
encoder.set_output_array(dev_idxs_in);
encoder.add_kernel_node(
block_sort_kernel,
grid,
block,
gpu_ptr<ValT>(in),
gpu_ptr<ValT>(dev_vals_in),
gpu_ptr<IdxT>(dev_idxs_in),
size_sorted_axis,
stride_sorted_axis,
nc_shape_param,
nc_strides_param,
nc_dim);

int n_thr_per_group = (n_blocks + 1) < 1024 ? (n_blocks + 1) : 1024;

for (int merge_tiles = 2; (merge_tiles / 2) < n_blocks;
merge_tiles *= 2) {
auto partition_kernel = cu::mb_block_partition_kernel<
ValT,
IdxT,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD>;

encoder.set_input_array(dev_vals_in);
encoder.set_input_array(dev_idxs_in);
encoder.set_output_array(block_partitions);

encoder.add_kernel_node(
partition_kernel,
dim3(1, n_rows, 1),
dim3(n_thr_per_group, 1, 1),
gpu_ptr<IdxT>(block_partitions),
gpu_ptr<ValT>(dev_vals_in),
gpu_ptr<IdxT>(dev_idxs_in),
size_sorted_axis,
merge_tiles,
n_blocks);

auto merge_kernel = cu::mb_block_merge_kernel<
ValT,
IdxT,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD>;
encoder.set_input_array(in);
encoder.set_output_array(dev_vals_in);
encoder.set_output_array(dev_idxs_in);

encoder.set_input_array(dev_vals_in);
encoder.set_input_array(dev_idxs_in);
encoder.set_input_array(block_partitions);
encoder.set_output_array(dev_vals_out);
encoder.set_output_array(dev_idxs_out);

encoder.add_kernel_node(
block_sort_kernel,
grid,
block,
gpu_ptr<ValT>(in),
merge_kernel,
dim3(n_blocks, n_rows, 1),
dim3(BLOCK_THREADS, 1, 1),
gpu_ptr<IdxT>(block_partitions),
gpu_ptr<ValT>(dev_vals_in),
gpu_ptr<IdxT>(dev_idxs_in),
gpu_ptr<ValT>(dev_vals_out),
gpu_ptr<IdxT>(dev_idxs_out),
size_sorted_axis,
stride_sorted_axis,
nc_shape_param,
nc_strides_param,
nc_dim);

int n_thr_per_group = (n_blocks + 1) < 1024 ? (n_blocks + 1) : 1024;

for (int merge_tiles = 2; (merge_tiles / 2) < n_blocks;
merge_tiles *= 2) {
auto partition_kernel = cu::mb_block_partition_kernel<
ValT,
IdxT,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD>;

encoder.set_input_array(dev_vals_in);
encoder.set_input_array(dev_idxs_in);
encoder.set_output_array(block_partitions);

encoder.add_kernel_node(
partition_kernel,
dim3(1, n_rows, 1),
dim3(n_thr_per_group, 1, 1),
gpu_ptr<IdxT>(block_partitions),
gpu_ptr<ValT>(dev_vals_in),
gpu_ptr<IdxT>(dev_idxs_in),
size_sorted_axis,
merge_tiles,
n_blocks);

auto merge_kernel = cu::mb_block_merge_kernel<
ValT,
IdxT,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD>;

encoder.set_input_array(dev_vals_in);
encoder.set_input_array(dev_idxs_in);
encoder.set_input_array(block_partitions);
encoder.set_output_array(dev_vals_out);
encoder.set_output_array(dev_idxs_out);

encoder.add_kernel_node(
merge_kernel,
dim3(n_blocks, n_rows, 1),
dim3(BLOCK_THREADS, 1, 1),
gpu_ptr<IdxT>(block_partitions),
gpu_ptr<ValT>(dev_vals_in),
gpu_ptr<IdxT>(dev_idxs_in),
gpu_ptr<ValT>(dev_vals_out),
gpu_ptr<IdxT>(dev_idxs_out),
size_sorted_axis,
merge_tiles,
n_blocks);
std::swap(dev_vals_in, dev_vals_out);
std::swap(dev_idxs_in, dev_idxs_out);
}
});
} else {
throw std::runtime_error(
"CUDA backend does not support sorting complex numbers");
}
merge_tiles,
n_blocks);
std::swap(dev_vals_in, dev_vals_out);
std::swap(dev_idxs_in, dev_idxs_out);
}
});
});

encoder.add_temporary(dev_vals_out);
Expand Down
Loading
Loading