Support pickling array for bfloat16 #2586

Merged
awni merged 3 commits into ml-explore:main from CC-Yeh:pickle_bfloat16 on Sep 23, 2025

Conversation

Contributor

@CC-Yeh CC-Yeh commented Sep 11, 2025

Proposed changes

  • Store bfloat16 data as a uint16 NumPy array.
  • Pickle the dtype to preserve the bfloat16 type.

closes #795

Note that ml_dtypes defines bfloat16 as a NumPy dtype, but integrating it natively into MLX seems like a big change.
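
A minimal sketch of the two proposed changes, written with the Python standard library only instead of MLX or NumPy. The helper names `float_to_bf16_bits` and `bf16_bits_to_float` and the `"bfloat16"` tag are illustrative assumptions, not MLX API, and the conversion truncates instead of rounding to keep the example short.

```python
import pickle
import struct
from array import array


def float_to_bf16_bits(x: float) -> int:
    """Return the upper 16 bits of x's float32 encoding (truncation, no rounding)."""
    (bits32,) = struct.unpack("<I", struct.pack("<f", x))
    return bits32 >> 16


def bf16_bits_to_float(bits16: int) -> float:
    """Expand 16 bfloat16 bits back to a Python float via float32."""
    (value,) = struct.unpack("<f", struct.pack("<I", bits16 << 16))
    return value


# Values chosen to be exactly representable in bfloat16.
values = [1.0, -2.5, 0.15625]

# Keep the raw bits in an unsigned 16-bit container and pickle a dtype tag next
# to it, so the loader can tell a bfloat16 payload from ordinary uint16 data.
payload = array("H", [float_to_bf16_bits(v) for v in values])
blob = pickle.dumps((payload, "bfloat16"))

restored_payload, dtype_tag = pickle.loads(blob)
restored = [bf16_bits_to_float(b) for b in restored_payload]
```

Without the tag, the restored payload would be indistinguishable from a genuine uint16 array, which is why the dtype has to travel with the data.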

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

}
}

mx::array nd_array_to_mlx_as_dtype(

Member

nd_array_to_mlx already takes a dtype parameter. Can we use that?

Contributor Author

I first considered forcing nd_array_to_mlx to use the optional dtype whenever it was set. But to_array and create_array also call this function, and some callers may expect dtype inference or want to control casting themselves. To avoid breaking that, I now create a bfloat16 view of the ndarray in __setstate__ and keep all pickling logic on the pickling side.
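
The shape of that design, with the reinterpretation confined to the pickle hooks and the shared conversion path left alone, can be sketched with a toy class. `ToyArray` and its fields are hypothetical and only mimic the approach; the real hooks are implemented in C++ through nanobind.

```python
import pickle
from array import array


class ToyArray:
    """Hypothetical stand-in for an array type that carries a dtype label."""

    def __init__(self, bits, dtype):
        self.bits = array("H", bits)  # raw 16-bit storage
        self.dtype = dtype            # e.g. "bfloat16" or "uint16"

    def __getstate__(self):
        # Serialize the buffer as plain 16-bit data plus the dtype label.
        return (self.bits.tobytes(), self.dtype)

    def __setstate__(self, state):
        raw, dtype = state
        bits = array("H")
        bits.frombytes(raw)
        # The reinterpretation happens only here: the same 16-bit payload is
        # relabeled with the saved dtype, and no values are converted.
        self.bits = bits
        self.dtype = dtype


original = ToyArray([0x3F80, 0xC020], "bfloat16")

# pickle invokes these hooks itself; they are called by hand here so that each
# step is visible.
state_blob = pickle.dumps(original.__getstate__())
clone = ToyArray.__new__(ToyArray)
clone.__setstate__(pickle.loads(state_blob))
```

Other callers of the conversion routine never see the relabeling, so their dtype inference and casting behavior stay as they were.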

Contributor Author

Improved the view logic, ready for re-review : )

Comment on lines +474 to +475
static_cast<uint8_t>(a.dtype().val()),
a.dtype().size());

Member

I don't think we need to save the val and size. Just save the val.

Contributor Author

Done, thx

Comment on lines +484 to +489
using ND = nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>;
auto nd = nb::cast<ND>(state[0]);
auto val = nb::cast<uint8_t>(state[1]);
auto size = nb::cast<uint8_t>(state[2]);
mx::Dtype dtype(static_cast<mx::Dtype::Val>(val), size);
new (&arr) mx::array(nd_array_to_mlx_as_dtype(nd, dtype));

Member

Then here, check the val. If it's bfloat pass in bfloat16 as the type. Otherwise nullopt.

Contributor Author

done
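
The two suggestions in this thread, saving only the dtype value and passing an explicit dtype only for bfloat16, might look like the sketch below. The `DtypeVal` codes and the `ITEM_SIZE` table are hypothetical and do not match MLX's numbering; the sketch also assumes that the item size is implied by the type code, which would explain why it does not need to be saved.

```python
from enum import IntEnum
from typing import Optional


class DtypeVal(IntEnum):
    """Hypothetical dtype codes; the numbering is illustrative, not MLX's."""
    UINT16 = 0
    FLOAT32 = 1
    BFLOAT16 = 2


# Assumption: the item size follows from the code, so pickling it is redundant.
ITEM_SIZE = {DtypeVal.UINT16: 2, DtypeVal.FLOAT32: 4, DtypeVal.BFLOAT16: 2}


def dtype_override(saved_val: int) -> Optional[DtypeVal]:
    """Return an explicit dtype only for bfloat16.

    None plays the role of nullopt: the loader infers the dtype from the buffer.
    """
    val = DtypeVal(saved_val)
    return DtypeVal.BFLOAT16 if val is DtypeVal.BFLOAT16 else None
```

Every dtype other than bfloat16 round-trips through its native buffer type, so only the one case that was stored under a different label needs the override.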

Comment on lines +151 to +152
  case mx::bfloat16:
-   throw nb::type_error("bfloat16 arrays cannot be converted to NumPy.");
+   return mlx_to_nd_array_impl<uint16_t, NDParams...>(a);

Member

This is slightly problematic because now MLX bfloat16 arrays will convert to numpy with e.g. .numpy. I don't think we want that behavior, it should still raise an exception for that.

Member

Maybe instead of doing this we can view() the array as uint16 before calling this function when its type is bfloat16.

Contributor Author

Done, thanks for the suggestion.
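
A rough illustration of the behavior settled on in this thread: the general export path keeps rejecting bfloat16, and only the pickling path views the data as uint16 before reusing it. All names here (`ToyArray`, `export_buffer`, `view_as_uint16`, `pickle_state`) are hypothetical stand-ins, not functions from the MLX code base.

```python
class ToyArray:
    """Hypothetical array: raw 16-bit words plus a dtype label."""

    def __init__(self, bits, dtype):
        self.bits = list(bits)
        self.dtype = dtype


def export_buffer(arr):
    """Stand-in for the public export path, which still rejects bfloat16."""
    if arr.dtype == "bfloat16":
        raise TypeError("bfloat16 arrays cannot be converted to NumPy.")
    return list(arr.bits)


def view_as_uint16(arr):
    """Relabel the same bits as uint16 without changing them."""
    return ToyArray(arr.bits, "uint16")


def pickle_state(arr):
    """Pickling path: view bfloat16 as uint16 first, then reuse the export path."""
    source = view_as_uint16(arr) if arr.dtype == "bfloat16" else arr
    return export_buffer(source), arr.dtype


bf = ToyArray([0x3F80], "bfloat16")
state = pickle_state(bf)  # works, and records the original dtype

try:
    export_buffer(bf)     # direct export is still an error
    direct_export_raised = False
except TypeError:
    direct_export_raised = True
```

Keeping the special case in the caller means ordinary conversions of bfloat16 arrays fail exactly as before.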

@CC-Yeh CC-Yeh requested a review from awni September 13, 2025 07:23
Contributor Author

CC-Yeh commented Sep 17, 2025

@awni Could you take another look? Thanks!

Member

@awni awni left a comment

Looks great! Let's merge when the tests clear.

@awni awni merged commit fbbf3b9 into ml-explore:main Sep 23, 2025
7 of 8 checks passed
faisalmemon pushed a commit to faisalmemon/mlx that referenced this pull request Oct 30, 2025
* add bfloat16 pickling

* Improvements

* improve

---------

Co-authored-by: Chen-Chen Yeh <ge96noj@mytum.de>
Development

Successfully merging this pull request may close these issues.

[Feature] Support pickling array for bfloat16