Fix sort NaN handling for float16 and bfloat16 #3269

Merged
angeloskath merged 1 commit into ml-explore:main from Lyxot:fix/sort-nan-half-types on Mar 19, 2026

Lyxot (Contributor) commented Mar 17, 2026

Summary

mx.sort and mx.argsort produce incorrect results for float16 and bfloat16 arrays containing NaN values on both CUDA and CPU (x86) backends.

Root cause

The NaN-aware comparator and init value in sort are guarded by std::is_floating_point_v<T>, which returns false for __half/__nv_bfloat16 (CUDA) and _MLX_Float16/_MLX_BFloat16 (x86 CPU), so NaN handling is skipped.
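
The effect of that guard can be reproduced in plain C++. The sketch below is only an illustration: `HalfBits` and `uses_nan_aware_path` are hypothetical names standing in for a half-precision wrapper type and for a trait-gated sort path, and are not taken from the MLX sources.

```cpp
#include <type_traits>

// Hypothetical stand-in for a 16-bit float wrapper such as __half or
// _MLX_Float16. It is a class type, not a built-in floating-point type.
struct HalfBits {
  unsigned short bits;
};

// The standard trait only recognizes float, double, and long double.
static_assert(std::is_floating_point_v<float>);
static_assert(!std::is_floating_point_v<HalfBits>);

// A code path gated on the standard trait: returns true when the
// NaN-aware branch would be compiled in for T.
template <typename T>
constexpr bool uses_nan_aware_path() {
  if constexpr (std::is_floating_point_v<T>) {
    return true;
  } else {
    return false;
  }
}
```

With this gate, `uses_nan_aware_path<float>()` is true while `uses_nan_aware_path<HalfBits>()` is false, so the wrapper type silently falls through to the plain comparison.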

Fix

  • CUDA: Replace std::is_floating_point_v<T> with is_floating_v<T>, and extend the trait to cover __half/__nv_bfloat16.
  • CPU: Add a local is_floating_v<T> trait covering float16_t/bfloat16_t in sort.cpp.
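
The shape of such a fix might look like the following host-only sketch. Everything in it is an assumption made for illustration: `HalfLike` is a hypothetical wrapper that stores a `float` for brevity, the `is_floating_v` shown here is a local re-creation rather than the project's definition, and the comparator assumes NaNs should sort after every number (the NumPy convention).

```cpp
#include <algorithm>
#include <limits>
#include <type_traits>
#include <vector>

// Hypothetical half-precision wrapper. It stores a float to keep the sketch
// short; real types such as __half hold 16 bits.
struct HalfLike {
  float value;
};
constexpr bool operator<(HalfLike a, HalfLike b) { return a.value < b.value; }
constexpr bool operator!=(HalfLike a, HalfLike b) { return a.value != b.value; }

inline constexpr float kQuietNaN = std::numeric_limits<float>::quiet_NaN();

// Local trait: the standard floating-point types plus the wrapper.
template <typename T>
inline constexpr bool is_floating_v =
    std::is_floating_point_v<T> || std::is_same_v<T, HalfLike>;

// Strict weak ordering that treats every NaN as greater than any number.
template <typename T>
struct NanAwareLess {
  constexpr bool operator()(const T& a, const T& b) const {
    if constexpr (is_floating_v<T>) {
      if (a != a) return false;  // a is NaN: it is never "less"
      if (b != b) return true;   // only b is NaN: a sorts first
    }
    return a < b;
  }
};

// Sort a copy of the input so that NaNs end up at the back.
template <typename T>
std::vector<T> sort_nan_last(std::vector<T> values) {
  std::sort(values.begin(), values.end(), NanAwareLess<T>{});
  return values;
}
```

Because the gate now uses the wider trait, `NanAwareLess<HalfLike>` takes the same NaN branch as `NanAwareLess<float>`, and `sort_nan_last` can be called on a `std::vector<HalfLike>` without violating the ordering requirements of `std::sort`.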

Tests

Add test_sort_nan and test_argsort_nan coverage for float16/bfloat16.
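
The property such tests check can be sketched outside the Python suite. The C++ below is a stand-alone illustration, not the code in test_ops.py: `argsort_nan_last` and `nan_last_less` are hypothetical helpers, the input uses `float` rather than a half type, and it assumes the expected behavior is that NaN indices come last.

```cpp
#include <array>
#include <cstddef>
#include <limits>

inline constexpr float kNaN = std::numeric_limits<float>::quiet_NaN();

// True when a should come before b, with NaN ordered after every number.
constexpr bool nan_last_less(float a, float b) {
  if (a != a) return false;
  if (b != b) return true;
  return a < b;
}

// Stable insertion-sort argsort. It is constexpr so the result can be
// checked at compile time. Returns the indices that would sort `values`.
template <std::size_t N>
constexpr std::array<std::size_t, N> argsort_nan_last(
    const std::array<float, N>& values) {
  std::array<std::size_t, N> idx{};
  for (std::size_t i = 0; i < N; ++i) idx[i] = i;
  for (std::size_t i = 1; i < N; ++i) {
    const std::size_t key = idx[i];
    std::size_t j = i;
    while (j > 0 && nan_last_less(values[key], values[idx[j - 1]])) {
      idx[j] = idx[j - 1];
      --j;
    }
    idx[j] = key;
  }
  return idx;
}

// Input {3, NaN, 1, 2}: indices 2, 3, 0 order the numbers and index 1,
// the NaN, comes last.
inline constexpr std::array<float, 4> kInput{3.0f, kNaN, 1.0f, 2.0f};
inline constexpr auto kOrder = argsort_nan_last(kInput);
static_assert(kOrder[0] == 2 && kOrder[1] == 3 && kOrder[2] == 0 &&
              kOrder[3] == 1);
```

A comparator without the NaN branch gives no such guarantee, because every comparison against NaN is false and the final position of the NaN then depends on the sort algorithm.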

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works

Copilot AI review requested due to automatic review settings March 17, 2026 12:33

Copilot AI left a comment

Pull request overview

Improves NaN-aware sorting/argsorting across CPU and CUDA backends, particularly for fp16/bf16 types, and extends the Python test suite to cover these dtypes.

Changes:

  • Extend NaN-aware sort testing to float32, float16, and bfloat16, and add a new argsort NaN test for the same dtypes.
  • Update CUDA sort comparator initialization and comparisons to treat CUDA half/bfloat16 types as floating for NaN handling.
  • Update CPU sort/argsort/argpartition NaN handling to apply to float16_t and bfloat16_t via a shared floating-type trait.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 1 comment.

| File | Description |
| --- | --- |
| python/tests/test_ops.py | Adds fp16/bf16 coverage for NaN behavior in sort and introduces a new argsort NaN test. |
| mlx/backend/cuda/sort.cu | Switches NaN handling gates from std::is_floating_point to the project's is_floating_v to include half/bfloat16. |
| mlx/backend/cuda/kernel_utils.cuh | Expands is_floating_v to include CUDA __half and __nv_bfloat16. |
| mlx/backend/cpu/sort.cpp | Introduces is_floating_v to include fp16/bf16 and applies it to NaN-aware sort/argsort/argpartition logic. |

zcbenz (Collaborator) left a comment

Looks good to me, thanks!

angeloskath merged commit 82809eb into ml-explore:main on Mar 19, 2026
19 of 20 checks passed