From 44c5bede923ba9c8d3407ffdd07bc8b3dd3a5006 Mon Sep 17 00:00:00 2001 From: Chen-Chen Yeh Date: Thu, 11 Sep 2025 13:02:14 +0200 Subject: [PATCH 1/3] add bfloat16 pickling --- python/src/array.cpp | 23 +++++++++++++++---- python/src/convert.cpp | 45 +++++++++++++++++++++++++++++++++++++- python/src/convert.h | 4 ++++ python/tests/test_array.py | 6 +---- 4 files changed, 68 insertions(+), 10 deletions(-) diff --git a/python/src/array.cpp b/python/src/array.cpp index ae38fa2110..bdc1fdb4eb 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -466,12 +466,27 @@ void init_array(nb::module_& m) { }) .def( "__iter__", [](const mx::array& a) { return ArrayPythonIterator(a); }) - .def("__getstate__", &mlx_to_np_array) + .def( + "__getstate__", + [](const mx::array& a) { + return nb::make_tuple( + mlx_to_np_array(a), + static_cast(a.dtype().val()), + a.dtype().size()); + }) .def( "__setstate__", - [](mx::array& arr, - const nb::ndarray& state) { - new (&arr) mx::array(nd_array_to_mlx(state, std::nullopt)); + [](mx::array& arr, const nb::tuple& state) { + if (nb::len(state) != 3) { + throw std::invalid_argument( + "Invalid pickle state: expected (ndarray, Dtype::Val, size)"); + } + using ND = nb::ndarray; + auto nd = nb::cast(state[0]); + auto val = nb::cast(state[1]); + auto size = nb::cast(state[2]); + mx::Dtype dtype(static_cast(val), size); + new (&arr) mx::array(nd_array_to_mlx_as_dtype(nd, dtype)); }) .def("__dlpack__", [](const mx::array& a) { return mlx_to_dlpack(a); }) .def( diff --git a/python/src/convert.cpp b/python/src/convert.cpp index 1340b663a0..dd569e2c18 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -149,7 +149,7 @@ nb::ndarray mlx_to_nd_array(const mx::array& a) { case mx::float16: return mlx_to_nd_array_impl(a); case mx::bfloat16: - throw nb::type_error("bfloat16 arrays cannot be converted to NumPy."); + return mlx_to_nd_array_impl(a); case mx::float32: return mlx_to_nd_array_impl(a); case mx::float64: @@ -161,6 +161,49 @@ nb::ndarray mlx_to_nd_array(const mx::array& a) { } } +mx::array nd_array_to_mlx_as_dtype( + nb::ndarray nd_array, + mx::Dtype dtype) { + // Compute the shape + mx::Shape shape; + for (int i = 0; i < nd_array.ndim(); i++) { + shape.push_back(check_shape_dim(nd_array.shape(i))); + } + switch (dtype) { + case mx::bool_: + return nd_array_to_mlx_contiguous(nd_array, shape, dtype); + case mx::uint8: + return nd_array_to_mlx_contiguous(nd_array, shape, dtype); + case mx::uint16: + return nd_array_to_mlx_contiguous(nd_array, shape, dtype); + case mx::uint32: + return nd_array_to_mlx_contiguous(nd_array, shape, dtype); + case mx::uint64: + return nd_array_to_mlx_contiguous(nd_array, shape, dtype); + case mx::int8: + return nd_array_to_mlx_contiguous(nd_array, shape, dtype); + case mx::int16: + return nd_array_to_mlx_contiguous(nd_array, shape, dtype); + case mx::int32: + return nd_array_to_mlx_contiguous(nd_array, shape, dtype); + case mx::int64: + return nd_array_to_mlx_contiguous(nd_array, shape, dtype); + case mx::float16: + return nd_array_to_mlx_contiguous(nd_array, shape, dtype); + case mx::bfloat16: + return nd_array_to_mlx_contiguous(nd_array, shape, dtype); + case mx::float32: + return nd_array_to_mlx_contiguous(nd_array, shape, dtype); + case mx::float64: + return nd_array_to_mlx_contiguous(nd_array, shape, dtype); + case mx::complex64: + return nd_array_to_mlx_contiguous( + nd_array, shape, dtype); + default: + throw nb::type_error("type cannot be converted to MLX with reinterpret."); + } +} + nb::ndarray mlx_to_np_array(const mx::array& a) { return mlx_to_nd_array(a); } diff --git a/python/src/convert.h b/python/src/convert.h index f5016c8af8..4c458fcf63 100644 --- a/python/src/convert.h +++ b/python/src/convert.h @@ -34,6 +34,10 @@ mx::array nd_array_to_mlx( nb::ndarray nd_array, std::optional dtype); +mx::array nd_array_to_mlx_as_dtype( + nb::ndarray nd_array, + mx::Dtype dtype); + nb::ndarray mlx_to_np_array(const mx::array& a); nb::ndarray<> mlx_to_dlpack(const mx::array& a); diff --git a/python/tests/test_array.py b/python/tests/test_array.py index ae1cb784ff..02c515c2ad 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -886,6 +886,7 @@ def test_array_pickle(self): mx.uint64, mx.float16, mx.float32, + mx.bfloat16, mx.complex64, ] @@ -895,11 +896,6 @@ def test_array_pickle(self): y = pickle.loads(state) self.assertEqualArray(y, x) - # check if it throws an error when dtype is not supported (bfloat16) - x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=mx.bfloat16) - with self.assertRaises(TypeError): - pickle.dumps(x) - def test_array_copy(self): dtypes = [ mx.int8, From 71b237906e9b9ce3983920713afd48d6d4b2ce88 Mon Sep 17 00:00:00 2001 From: Chen-Chen Yeh Date: Fri, 12 Sep 2025 22:22:22 +0200 Subject: [PATCH 2/3] Improvements --- python/src/array.cpp | 34 ++++++++++++++++++--------- python/src/convert.cpp | 48 ++------------------------------------ python/src/convert.h | 10 ++++---- python/tests/test_array.py | 2 +- 4 files changed, 31 insertions(+), 63 deletions(-) diff --git a/python/src/array.cpp b/python/src/array.cpp index bdc1fdb4eb..061db63f98 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -469,24 +469,36 @@ void init_array(nb::module_& m) { .def( "__getstate__", [](const mx::array& a) { - return nb::make_tuple( - mlx_to_np_array(a), - static_cast(a.dtype().val()), - a.dtype().size()); + auto nd = (a.dtype() == mx::bfloat16) + ? mlx_to_np_array(mx::view(a, mx::uint16)) + : mlx_to_np_array(a); + return nb::make_tuple(nd, static_cast(a.dtype().val())); }) .def( "__setstate__", [](mx::array& arr, const nb::tuple& state) { - if (nb::len(state) != 3) { + if (nb::len(state) != 2) { throw std::invalid_argument( - "Invalid pickle state: expected (ndarray, Dtype::Val, size)"); + "Invalid pickle state: expected (ndarray, Dtype::Val)"); } using ND = nb::ndarray; - auto nd = nb::cast(state[0]); - auto val = nb::cast(state[1]); - auto size = nb::cast(state[2]); - mx::Dtype dtype(static_cast(val), size); - new (&arr) mx::array(nd_array_to_mlx_as_dtype(nd, dtype)); + ND nd = nb::cast(state[0]); + auto val = static_cast(nb::cast(state[1])); + if (val == mx::Dtype::Val::bfloat16) { + std::vector shape; + for (size_t i = 0; i < nd.ndim(); ++i) + shape.push_back((size_t)nd.shape(i)); + new (&arr) mx::array(nd_array_to_mlx( + ND(nd.data(), + nd.ndim(), + shape.data(), + {}, + nullptr, + nb::bfloat16), + mx::bfloat16)); + } else { + new (&arr) mx::array(nd_array_to_mlx(nd, std::nullopt)); + } }) .def("__dlpack__", [](const mx::array& a) { return mlx_to_dlpack(a); }) .def( diff --git a/python/src/convert.cpp b/python/src/convert.cpp index dd569e2c18..88da37103e 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -23,8 +23,6 @@ struct ndarray_traits { static constexpr bool is_int = false; static constexpr bool is_signed = true; }; - -static constexpr dlpack::dtype bfloat16{4, 16, 1}; }; // namespace nanobind int check_shape_dim(int64_t dim) { @@ -51,6 +49,7 @@ mx::array nd_array_to_mlx( std::optional dtype) { // Compute the shape and size mx::Shape shape; + shape.reserve(nd_array.ndim()); for (int i = 0; i < nd_array.ndim(); i++) { shape.push_back(check_shape_dim(nd_array.shape(i))); } @@ -149,7 +148,7 @@ nb::ndarray mlx_to_nd_array(const mx::array& a) { case mx::float16: return mlx_to_nd_array_impl(a); case mx::bfloat16: - return mlx_to_nd_array_impl(a); + throw nb::type_error("bfloat16 arrays cannot be converted to NumPy."); case mx::float32: return mlx_to_nd_array_impl(a); case mx::float64: @@ -161,49 +160,6 @@ nb::ndarray mlx_to_nd_array(const mx::array& a) { } } -mx::array nd_array_to_mlx_as_dtype( - nb::ndarray nd_array, - mx::Dtype dtype) { - // Compute the shape - mx::Shape shape; - for (int i = 0; i < nd_array.ndim(); i++) { - shape.push_back(check_shape_dim(nd_array.shape(i))); - } - switch (dtype) { - case mx::bool_: - return nd_array_to_mlx_contiguous(nd_array, shape, dtype); - case mx::uint8: - return nd_array_to_mlx_contiguous(nd_array, shape, dtype); - case mx::uint16: - return nd_array_to_mlx_contiguous(nd_array, shape, dtype); - case mx::uint32: - return nd_array_to_mlx_contiguous(nd_array, shape, dtype); - case mx::uint64: - return nd_array_to_mlx_contiguous(nd_array, shape, dtype); - case mx::int8: - return nd_array_to_mlx_contiguous(nd_array, shape, dtype); - case mx::int16: - return nd_array_to_mlx_contiguous(nd_array, shape, dtype); - case mx::int32: - return nd_array_to_mlx_contiguous(nd_array, shape, dtype); - case mx::int64: - return nd_array_to_mlx_contiguous(nd_array, shape, dtype); - case mx::float16: - return nd_array_to_mlx_contiguous(nd_array, shape, dtype); - case mx::bfloat16: - return nd_array_to_mlx_contiguous(nd_array, shape, dtype); - case mx::float32: - return nd_array_to_mlx_contiguous(nd_array, shape, dtype); - case mx::float64: - return nd_array_to_mlx_contiguous(nd_array, shape, dtype); - case mx::complex64: - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype); - default: - throw nb::type_error("type cannot be converted to MLX with reinterpret."); - } -} - nb::ndarray mlx_to_np_array(const mx::array& a) { return mlx_to_nd_array(a); } diff --git a/python/src/convert.h b/python/src/convert.h index 4c458fcf63..4e952ee48f 100644 --- a/python/src/convert.h +++ b/python/src/convert.h @@ -12,8 +12,12 @@ namespace mx = mlx::core; namespace nb = nanobind; +namespace nanobind { +static constexpr dlpack::dtype bfloat16{4, 16, 1}; +}; // namespace nanobind + struct ArrayLike { - ArrayLike(nb::object obj) : obj(obj) {}; + ArrayLike(nb::object obj) : obj(obj){}; nb::object obj; }; @@ -34,10 +38,6 @@ mx::array nd_array_to_mlx( nb::ndarray nd_array, std::optional dtype); -mx::array nd_array_to_mlx_as_dtype( - nb::ndarray nd_array, - mx::Dtype dtype); - nb::ndarray mlx_to_np_array(const mx::array& a); nb::ndarray<> mlx_to_dlpack(const mx::array& a); diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 02c515c2ad..e932382b1b 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -532,7 +532,7 @@ def test_array_repr(self): self.assertEqual(str(x), expected) x = mx.array([[1, 2], [1, 2], [1, 2]]) - expected = "array([[1, 2],\n" " [1, 2],\n" " [1, 2]], dtype=int32)" + expected = "array([[1, 2],\n [1, 2],\n [1, 2]], dtype=int32)" self.assertEqual(str(x), expected) x = mx.array([[[1, 2], [1, 2]], [[1, 2], [1, 2]]]) From a2c52b3bdac893ea7adc844a63c78c2350025801 Mon Sep 17 00:00:00 2001 From: Chen-Chen Yeh Date: Thu, 18 Sep 2025 20:02:24 +0200 Subject: [PATCH 3/3] improve --- python/src/array.cpp | 8 +++----- python/src/convert.h | 2 +- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/python/src/array.cpp b/python/src/array.cpp index 061db63f98..9367d4e091 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -485,14 +485,12 @@ void init_array(nb::module_& m) { ND nd = nb::cast(state[0]); auto val = static_cast(nb::cast(state[1])); if (val == mx::Dtype::Val::bfloat16) { - std::vector shape; - for (size_t i = 0; i < nd.ndim(); ++i) - shape.push_back((size_t)nd.shape(i)); + auto owner = nb::handle(state[0].ptr()); new (&arr) mx::array(nd_array_to_mlx( ND(nd.data(), nd.ndim(), - shape.data(), - {}, + reinterpret_cast(nd.shape_ptr()), + owner, nullptr, nb::bfloat16), mx::bfloat16)); diff --git a/python/src/convert.h b/python/src/convert.h index 4e952ee48f..4f8a77abe5 100644 --- a/python/src/convert.h +++ b/python/src/convert.h @@ -17,7 +17,7 @@ static constexpr dlpack::dtype bfloat16{4, 16, 1}; }; // namespace nanobind struct ArrayLike { - ArrayLike(nb::object obj) : obj(obj){}; + ArrayLike(nb::object obj) : obj(obj) {}; nb::object obj; };