Support pickling array for bfloat16 #2586
python/src/convert.cpp (outdated)
}
}

mx::array nd_array_to_mlx_as_dtype(
nd_array_to_mlx already takes a dtype parameter. Can we use that?
I first considered forcing nd_array_to_mlx to use the optional dtype whenever it was set. But to_array and create_array also call this function, and some callers may expect dtype inference or want to control casting themselves. To avoid breaking that, I now create a bfloat16 view of the ndarray in __setstate__ and keep all pickling logic on the pickling side.
Improved the view logic, ready for re-review : )
python/src/array.cpp (outdated)
static_cast<uint8_t>(a.dtype().val()),
a.dtype().size());
I don't think we need to save the val and size. Just save the val.
python/src/array.cpp (outdated)
using ND = nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>;
auto nd = nb::cast<ND>(state[0]);
auto val = nb::cast<uint8_t>(state[1]);
auto size = nb::cast<uint8_t>(state[2]);
mx::Dtype dtype(static_cast<mx::Dtype::Val>(val), size);
new (&arr) mx::array(nd_array_to_mlx_as_dtype(nd, dtype));
Then here, check the val. If it's bfloat pass in bfloat16 as the type. Otherwise nullopt.
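For readers following this thread, here is a minimal Python sketch of the pattern being suggested: save only a dtype tag next to the raw 16-bit payload, and on restore pass an explicit dtype only when the tag says bfloat16, otherwise leave it unset so the type is inferred. `ToyArray`, `BFLOAT16`, and `roundtrip` are hypothetical names for illustration; this is not the MLX binding code.

```python
import pickle
from array import array as _buf

BFLOAT16 = "bfloat16"  # hypothetical tag standing in for the saved dtype val


class ToyArray:
    """Toy stand-in for an array type; assumption for illustration only."""

    def __init__(self, bits, dtype):
        self.bits = list(bits)  # raw 16-bit patterns
        self.dtype = dtype

    def __getstate__(self):
        # Save the payload as unsigned 16-bit integers plus the dtype tag.
        # The element size is not saved; the tag alone is enough.
        return (_buf("H", self.bits).tobytes(), self.dtype)

    def __setstate__(self, state):
        payload, tag = state
        # Only bfloat16 needs an explicit override; None means "infer".
        override = tag if tag == BFLOAT16 else None
        restored = _buf("H")
        restored.frombytes(payload)
        self.bits = list(restored)
        # In this toy, a 16-bit unsigned payload is inferred as uint16.
        self.dtype = override if override is not None else "uint16"


def roundtrip(obj):
    # Roughly what pickle does for a class with these hooks:
    # serialize the state, allocate a blank object, then restore the state.
    state = pickle.loads(pickle.dumps(obj.__getstate__()))
    new = ToyArray.__new__(ToyArray)
    new.__setstate__(state)
    return new
```

Keeping the override decision inside `__setstate__` means other callers of the conversion path keep their existing dtype inference.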
python/src/convert.cpp (outdated)
case mx::bfloat16:
throw nb::type_error("bfloat16 arrays cannot be converted to NumPy.");
return mlx_to_nd_array_impl<uint16_t, NDParams...>(a);
This is slightly problematic because MLX bfloat16 arrays will now convert to NumPy with e.g. .numpy. I don't think we want that behavior; it should still raise an exception.
Maybe instead of doing this we can view() the array as uint16 before calling this function when its type is bfloat16.
Done, thanks for the suggestion.
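As a rough illustration of why a uint16 view is a safe carrier for bfloat16 data: a bfloat16 value has the same layout as the upper 16 bits of an IEEE float32, so its bit pattern fits exactly in an unsigned 16-bit integer. The helper names below are hypothetical, and the sketch truncates rather than rounds when narrowing, which is an assumption made for simplicity and not a description of MLX's conversion.

```python
import struct


def float32_to_bfloat16_bits(x: float) -> int:
    """Return the upper 16 bits of x's float32 encoding (truncation, no rounding)."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    return bits >> 16


def bfloat16_bits_to_float32(b: int) -> float:
    """Widen a 16-bit bfloat16 pattern back to a float by zero-filling the low bits."""
    (x,) = struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))
    return x


# Values exactly representable in bfloat16 survive the round trip unchanged.
pattern = float32_to_bfloat16_bits(-2.5)
restored = bfloat16_bits_to_float32(pattern)
```

Because only the bit pattern is stored, nothing about ordinary bfloat16-to-NumPy conversion has to change for this to work.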
@awni Could you take another look? Thanks!
17adcd0 to a2c52b3
awni left a comment
Looks great! Let's merge when the tests clear.
* add bfloat16 pickling
* Improvements
* improve
---------
Co-authored-by: Chen-Chen Yeh <ge96noj@mytum.de>
Proposed changes
Pickle bfloat16 data as a uint16 NumPy array.
Save the dtype to preserve the bfloat16 type.
closes #795
Note that ml_dtypes defines bfloat16 as a NumPy dtype, but integrating it natively into MLX seems like a big change.
Checklist
Put an x in the boxes that apply.
I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes