From 4f2bfc5d4c58f2d7e8121b544f5a848c379fe925 Mon Sep 17 00:00:00 2001 From: Kellen Sun Date: Thu, 19 Feb 2026 09:27:37 -0500 Subject: [PATCH 01/16] np bfloat16 conversion fix --- python/src/array.cpp | 36 ++++++++++++++++++++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/python/src/array.cpp b/python/src/array.cpp index 231474c2d9..61a158a9ab 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -296,8 +296,40 @@ void init_array(nb::module_& m) { nb::is_weak_referenceable()) .def( "__init__", - [](mx::array* aptr, ArrayInitType v, std::optional t) { - new (aptr) mx::array(create_array(v, t)); + [](mx::array* aptr, nb::object v, std::optional t) { + if (nb::hasattr(v, "dtype")) { + if (nb::str(v.attr("dtype")).equal(nb::str("bfloat16"))) { + auto type_mod = nb::str(v.attr("__class__").attr("__module__")); + if (type_mod.equal(nb::str("numpy")) || + type_mod.equal(nb::str("ml_dtypes"))) { + auto np = nb::module_::import_("numpy"); + auto contig_obj = np.attr("ascontiguousarray")(v); + mx::Shape shape; + nb::tuple shape_tuple = nb::cast(v.attr("shape")); + size_t ndim = shape_tuple.size(); + for (size_t i = 0; i < ndim; ++i) { + shape.push_back(nb::cast(shape_tuple[i])); + } + uint64_t ptr_int = nb::cast( + contig_obj.attr("ctypes").attr("data")); + const mx::bfloat16_t* typed_ptr = + reinterpret_cast(ptr_int); + auto res = (ndim == 0) + ? mx::array(*typed_ptr, mx::bfloat16) + : mx::array(typed_ptr, shape, mx::bfloat16); + if (t.has_value()) + res = mx::astype(res, *t); + new (aptr) mx::array(res); + return; + } + } + } + try { + auto v_cast = nb::cast(v); + new (aptr) mx::array(create_array(v_cast, t)); + } catch (const nb::cast_error& e) { + throw std::invalid_argument("Cannot convert to mlx array."); + } }, "val"_a, "dtype"_a = nb::none(), From 52dd487a375303c6931989ce23e4dcfe5801cb61 Mon Sep 17 00:00:00 2001 From: Kellen Sun Date: Thu, 19 Feb 2026 09:44:44 -0500 Subject: [PATCH 02/16] tests --- python/tests/test_bf16.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/python/tests/test_bf16.py b/python/tests/test_bf16.py index 2e4e2e0c33..92e8eb94fb 100644 --- a/python/tests/test_bf16.py +++ b/python/tests/test_bf16.py @@ -15,6 +15,12 @@ except ImportError as e: has_torch = False +try: + import ml_dtypes + + has_ml_dtypes = True +except ImportError: + has_ml_dtypes = False class TestBF16(mlx_tests.MLXTestCase): def __test_ops( @@ -190,6 +196,26 @@ def test_conversion(self): expected = mx.array([1.0, 2.0, 3.0], mx.bfloat16) self.assertEqual(a_mx.dtype, mx.bfloat16) self.assertTrue(mx.array_equal(a_mx, expected)) + + @unittest.skipIf(not has_ml_dtypes, "requires ml_dtypes") + def test_conversion_ml_dtypes(self): + x_scalar = np.array(1.5, dtype=ml_dtypes.bfloat16) + a_scalar = mx.array(x_scalar) + self.assertEqual(a_scalar.dtype, mx.bfloat16) + self.assertEqual(a_scalar.shape, ()) + self.assertEqual(a_scalar.item(), 1.5) + + data = [1.5, 2.5, 3.5] + x_vector = np.array(data, dtype=ml_dtypes.bfloat16) + a_vector = mx.array(x_vector) + expected = mx.array(data, dtype=mx.bfloat16) + self.assertEqual(a_vector.dtype, mx.bfloat16) + self.assertEqual(a_vector.shape, (3,)) + self.assertTrue(mx.array_equal(a_vector, expected)) + + a_cast = mx.array(x_scalar, dtype=mx.float32) + self.assertEqual(a_cast.dtype, mx.float32) + self.assertEqual(a_cast.item(), 1.5) if __name__ == "__main__": From 2f6a0f64525a161f78cc790363a8238fc5c94ab1 Mon Sep 17 00:00:00 2001 From: Kellen Sun Date: Thu, 19 Feb 2026 09:47:04 -0500 Subject: [PATCH 03/16] formatter --- python/tests/test_bf16.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/tests/test_bf16.py b/python/tests/test_bf16.py index 92e8eb94fb..8b5043681c 100644 --- a/python/tests/test_bf16.py +++ b/python/tests/test_bf16.py @@ -22,6 +22,7 @@ except ImportError: has_ml_dtypes = False + class TestBF16(mlx_tests.MLXTestCase): def __test_ops( self, @@ -196,7 +197,7 @@ def test_conversion(self): expected = mx.array([1.0, 2.0, 3.0], mx.bfloat16) self.assertEqual(a_mx.dtype, mx.bfloat16) self.assertTrue(mx.array_equal(a_mx, expected)) - + @unittest.skipIf(not has_ml_dtypes, "requires ml_dtypes") def test_conversion_ml_dtypes(self): x_scalar = np.array(1.5, dtype=ml_dtypes.bfloat16) @@ -215,7 +216,7 @@ def test_conversion_ml_dtypes(self): a_cast = mx.array(x_scalar, dtype=mx.float32) self.assertEqual(a_cast.dtype, mx.float32) - self.assertEqual(a_cast.item(), 1.5) + self.assertEqual(a_cast.item(), 1.5) if __name__ == "__main__": From 1b27fb3e7949a7697204e07c92b795880b96c6f1 Mon Sep 17 00:00:00 2001 From: Kellen Sun Date: Thu, 19 Feb 2026 13:44:53 -0500 Subject: [PATCH 04/16] move bfloat16 handle to create_array --- python/src/array.cpp | 34 +------------- python/src/convert.cpp | 94 ++++++++++++++++++++++++++------------- python/src/convert.h | 2 +- python/src/ops.cpp | 2 +- python/tests/test_bf16.py | 5 +++ 5 files changed, 70 insertions(+), 67 deletions(-) diff --git a/python/src/array.cpp b/python/src/array.cpp index 61a158a9ab..75697ca39f 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -297,39 +297,7 @@ void init_array(nb::module_& m) { .def( "__init__", [](mx::array* aptr, nb::object v, std::optional t) { - if (nb::hasattr(v, "dtype")) { - if (nb::str(v.attr("dtype")).equal(nb::str("bfloat16"))) { - auto type_mod = nb::str(v.attr("__class__").attr("__module__")); - if (type_mod.equal(nb::str("numpy")) || - type_mod.equal(nb::str("ml_dtypes"))) { - auto np = nb::module_::import_("numpy"); - auto contig_obj = np.attr("ascontiguousarray")(v); - mx::Shape shape; - nb::tuple shape_tuple = nb::cast(v.attr("shape")); - size_t ndim = shape_tuple.size(); - for (size_t i = 0; i < ndim; ++i) { - shape.push_back(nb::cast(shape_tuple[i])); - } - uint64_t ptr_int = nb::cast( - contig_obj.attr("ctypes").attr("data")); - const mx::bfloat16_t* typed_ptr = - reinterpret_cast(ptr_int); - auto res = (ndim == 0) - ? mx::array(*typed_ptr, mx::bfloat16) - : mx::array(typed_ptr, shape, mx::bfloat16); - if (t.has_value()) - res = mx::astype(res, *t); - new (aptr) mx::array(res); - return; - } - } - } - try { - auto v_cast = nb::cast(v); - new (aptr) mx::array(create_array(v_cast, t)); - } catch (const nb::cast_error& e) { - throw std::invalid_argument("Cannot convert to mlx array."); - } + new (aptr) mx::array(create_array(v, t)); }, "val"_a, "dtype"_a = nb::none(), diff --git a/python/src/convert.cpp b/python/src/convert.cpp index e004ac9cfd..f5e68dcc25 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -454,7 +454,7 @@ mx::array array_from_list_impl(T pl, std::optional dtype) { // `pl` contains mlx arrays std::vector arrays; for (auto l : pl) { - arrays.push_back(create_array(nb::cast(l), dtype)); + arrays.push_back(create_array(nb::cast(l), dtype)); } return mx::stack(arrays); } @@ -467,38 +467,68 @@ mx::array array_from_list(nb::tuple pl, std::optional dtype) { return array_from_list_impl(pl, dtype); } -mx::array create_array(ArrayInitType v, std::optional t) { - if (auto pv = std::get_if(&v); pv) { - return mx::array(nb::cast(*pv), t.value_or(mx::bool_)); - } else if (auto pv = std::get_if(&v); pv) { - auto val = nb::cast(*pv); - auto default_type = (val > std::numeric_limits::max() || - val < std::numeric_limits::min()) - ? mx::int64 - : mx::int32; - return mx::array(val, t.value_or(default_type)); - } else if (auto pv = std::get_if(&v); pv) { - auto out_type = t.value_or(mx::float32); - if (out_type == mx::float64) { - return mx::array(nb::cast(*pv), out_type); +mx::array create_array(nb::object v, std::optional t) { + if (nb::hasattr(v, "dtype")) { + if (nb::str(v.attr("dtype")).equal(nb::str("bfloat16"))) { + auto type_mod = nb::str(v.attr("__class__").attr("__module__")); + if (type_mod.equal(nb::str("numpy")) || + type_mod.equal(nb::str("ml_dtypes"))) { + auto np = nb::module_::import_("numpy"); + auto contig_obj = np.attr("ascontiguousarray")(v); + mx::Shape shape; + nb::tuple shape_tuple = nb::cast(v.attr("shape")); + size_t ndim = shape_tuple.size(); + for (size_t i = 0; i < ndim; ++i) { + shape.push_back(nb::cast(shape_tuple[i])); + } + uint64_t ptr_int = + nb::cast(contig_obj.attr("ctypes").attr("data")); + const mx::bfloat16_t* typed_ptr = + reinterpret_cast(ptr_int); + auto res = (ndim == 0) ? mx::array(*typed_ptr, mx::bfloat16) + : mx::array(typed_ptr, shape, mx::bfloat16); + if (t.has_value()) + res = mx::astype(res, *t); + return res; + } + } + } + try { + auto v_cast = nb::cast(v); + if (auto pv = std::get_if(&v_cast); pv) { + return mx::array(nb::cast(*pv), t.value_or(mx::bool_)); + } else if (auto pv = std::get_if(&v_cast); pv) { + auto val = nb::cast(*pv); + auto default_type = (val > std::numeric_limits::max() || + val < std::numeric_limits::min()) + ? mx::int64 + : mx::int32; + return mx::array(val, t.value_or(default_type)); + } else if (auto pv = std::get_if(&v_cast); pv) { + auto out_type = t.value_or(mx::float32); + if (out_type == mx::float64) { + return mx::array(nb::cast(*pv), out_type); + } else { + return mx::array(nb::cast(*pv), out_type); + } + } else if (auto pv = std::get_if>(&v_cast); pv) { + return mx::array( + static_cast(*pv), t.value_or(mx::complex64)); + } else if (auto pv = std::get_if(&v_cast); pv) { + return array_from_list(*pv, t); + } else if (auto pv = std::get_if(&v_cast); pv) { + return array_from_list(*pv, t); + } else if (auto pv = std::get_if< + nb::ndarray>(&v_cast); + pv) { + return nd_array_to_mlx(*pv, t); + } else if (auto pv = std::get_if(&v_cast); pv) { + return mx::astype(*pv, t.value_or((*pv).dtype())); } else { - return mx::array(nb::cast(*pv), out_type); + auto arr = to_array_with_accessor(std::get(v_cast).obj); + return mx::astype(arr, t.value_or(arr.dtype())); } - } else if (auto pv = std::get_if>(&v); pv) { - return mx::array( - static_cast(*pv), t.value_or(mx::complex64)); - } else if (auto pv = std::get_if(&v); pv) { - return array_from_list(*pv, t); - } else if (auto pv = std::get_if(&v); pv) { - return array_from_list(*pv, t); - } else if (auto pv = std::get_if< - nb::ndarray>(&v); - pv) { - return nd_array_to_mlx(*pv, t); - } else if (auto pv = std::get_if(&v); pv) { - return mx::astype(*pv, t.value_or((*pv).dtype())); - } else { - auto arr = to_array_with_accessor(std::get(v).obj); - return mx::astype(arr, t.value_or(arr.dtype())); + } catch (const nb::cast_error& e) { + throw std::invalid_argument("Cannot convert to mlx array."); } } diff --git a/python/src/convert.h b/python/src/convert.h index 4f8a77abe5..757695dfa9 100644 --- a/python/src/convert.h +++ b/python/src/convert.h @@ -45,6 +45,6 @@ nb::object to_scalar(mx::array& a); nb::object tolist(mx::array& a); -mx::array create_array(ArrayInitType v, std::optional t); +mx::array create_array(nb::object v, std::optional t); mx::array array_from_list(nb::list pl, std::optional dtype); mx::array array_from_list(nb::tuple pl, std::optional dtype); diff --git a/python/src/ops.cpp b/python/src/ops.cpp index dac4b8f7f7..b1362f6dea 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -1719,7 +1719,7 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "asarray", - [](const ArrayInitType& a, std::optional dtype) { + [](const nb::object& a, std::optional dtype) { return create_array(a, dtype); }, nb::arg(), diff --git a/python/tests/test_bf16.py b/python/tests/test_bf16.py index 8b5043681c..d25364c14f 100644 --- a/python/tests/test_bf16.py +++ b/python/tests/test_bf16.py @@ -218,6 +218,11 @@ def test_conversion_ml_dtypes(self): self.assertEqual(a_cast.dtype, mx.float32) self.assertEqual(a_cast.item(), 1.5) + a_asarray = mx.asarray(x_vector) + self.assertEqual(a_asarray.dtype, mx.bfloat16) + self.assertEqual(a_asarray.shape, (3,)) + self.assertTrue(mx.array_equal(a_asarray, expected)) + if __name__ == "__main__": mlx_tests.MLXTestRunner() From d0d4a236f81841b0c13a89ef3600c241a71de2ec Mon Sep 17 00:00:00 2001 From: Kellen Sun Date: Thu, 19 Feb 2026 14:39:53 -0500 Subject: [PATCH 05/16] ci fix for windows --- python/src/convert.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/src/convert.cpp b/python/src/convert.cpp index f5e68dcc25..188912d9d8 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -470,7 +470,8 @@ mx::array array_from_list(nb::tuple pl, std::optional dtype) { mx::array create_array(nb::object v, std::optional t) { if (nb::hasattr(v, "dtype")) { if (nb::str(v.attr("dtype")).equal(nb::str("bfloat16"))) { - auto type_mod = nb::str(v.attr("__class__").attr("__module__")); + auto type_mod_attr = v.attr("__class__").attr("__module__"); + nb::str type_mod(type_mod_attr); if (type_mod.equal(nb::str("numpy")) || type_mod.equal(nb::str("ml_dtypes"))) { auto np = nb::module_::import_("numpy"); From 15ac8edc6445a60eb91a8caf32629ce0ab39c717 Mon Sep 17 00:00:00 2001 From: Kellen Sun Date: Thu, 19 Feb 2026 15:07:12 -0500 Subject: [PATCH 06/16] another attempt --- python/src/convert.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/src/convert.cpp b/python/src/convert.cpp index 188912d9d8..ecac2348bb 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -470,8 +470,7 @@ mx::array array_from_list(nb::tuple pl, std::optional dtype) { mx::array create_array(nb::object v, std::optional t) { if (nb::hasattr(v, "dtype")) { if (nb::str(v.attr("dtype")).equal(nb::str("bfloat16"))) { - auto type_mod_attr = v.attr("__class__").attr("__module__"); - nb::str type_mod(type_mod_attr); + auto type_mod = nb::cast(v.attr("__class__").attr("__module__")); if (type_mod.equal(nb::str("numpy")) || type_mod.equal(nb::str("ml_dtypes"))) { auto np = nb::module_::import_("numpy"); From 95b87651ef9986d6485792f2828959f3b4fff94b Mon Sep 17 00:00:00 2001 From: Kellen Sun Date: Thu, 19 Feb 2026 15:38:10 -0500 Subject: [PATCH 07/16] another one --- python/src/convert.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/src/convert.cpp b/python/src/convert.cpp index ecac2348bb..037e400b63 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -469,7 +469,7 @@ mx::array array_from_list(nb::tuple pl, std::optional dtype) { mx::array create_array(nb::object v, std::optional t) { if (nb::hasattr(v, "dtype")) { - if (nb::str(v.attr("dtype")).equal(nb::str("bfloat16"))) { + if (nb::cast(v.attr("dtype")).equal(nb::str("bfloat16"))) { auto type_mod = nb::cast(v.attr("__class__").attr("__module__")); if (type_mod.equal(nb::str("numpy")) || type_mod.equal(nb::str("ml_dtypes"))) { From 1ed37d80e76667b0e77c7eba155900daf3127da1 Mon Sep 17 00:00:00 2001 From: Kellen Sun Date: Thu, 19 Feb 2026 16:35:30 -0500 Subject: [PATCH 08/16] another one --- python/src/convert.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/src/convert.cpp b/python/src/convert.cpp index 037e400b63..5acf938ebd 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -528,6 +528,8 @@ mx::array create_array(nb::object v, std::optional t) { auto arr = to_array_with_accessor(std::get(v_cast).obj); return mx::astype(arr, t.value_or(arr.dtype())); } + } catch (const std::bad_cast& e) { + throw std::invalid_argument("Cannot convert to mlx array."); } catch (const nb::cast_error& e) { throw std::invalid_argument("Cannot convert to mlx array."); } From 07223a6eb635f7c75e2caf6444ffed76d3325cff Mon Sep 17 00:00:00 2001 From: Kellen Sun Date: Thu, 19 Feb 2026 16:39:09 -0500 Subject: [PATCH 09/16] another one --- python/src/convert.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/src/convert.cpp b/python/src/convert.cpp index 5acf938ebd..f21f9628af 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -530,7 +530,5 @@ mx::array create_array(nb::object v, std::optional t) { } } catch (const std::bad_cast& e) { throw std::invalid_argument("Cannot convert to mlx array."); - } catch (const nb::cast_error& e) { - throw std::invalid_argument("Cannot convert to mlx array."); } } From 64bafa40de684a7c0585181e67c1468574adda2c Mon Sep 17 00:00:00 2001 From: Kellen Sun Date: Thu, 19 Feb 2026 16:59:37 -0500 Subject: [PATCH 10/16] trying to fix Windows complaints --- python/src/convert.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/src/convert.cpp b/python/src/convert.cpp index f21f9628af..ff9a5a113c 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -469,8 +469,10 @@ mx::array array_from_list(nb::tuple pl, std::optional dtype) { mx::array create_array(nb::object v, std::optional t) { if (nb::hasattr(v, "dtype")) { - if (nb::cast(v.attr("dtype")).equal(nb::str("bfloat16"))) { - auto type_mod = nb::cast(v.attr("__class__").attr("__module__")); + nb::object dtype_obj = v.attr("dtype"); + if (nb::str(dtype_obj).equal(nb::str("bfloat16"))) { + nb::object module_obj = v.attr("__class__").attr("__module__"); + auto type_mod = nb::str(module_obj); if (type_mod.equal(nb::str("numpy")) || type_mod.equal(nb::str("ml_dtypes"))) { auto np = nb::module_::import_("numpy"); From 87c27bdf61e07a2eda2e8011ed073dbbb0f96e55 Mon Sep 17 00:00:00 2001 From: Kellen Sun Date: Sun, 29 Mar 2026 15:13:48 -0400 Subject: [PATCH 11/16] removes arrayinittype, uses nb::isinstance --- python/src/convert.cpp | 88 ++++++++++++++++++++---------------------- python/src/convert.h | 13 ------- setup.py | 1 + 3 files changed, 42 insertions(+), 60 deletions(-) diff --git a/python/src/convert.cpp b/python/src/convert.cpp index ff9a5a113c..e3aeddea57 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -475,62 +475,56 @@ mx::array create_array(nb::object v, std::optional t) { auto type_mod = nb::str(module_obj); if (type_mod.equal(nb::str("numpy")) || type_mod.equal(nb::str("ml_dtypes"))) { - auto np = nb::module_::import_("numpy"); - auto contig_obj = np.attr("ascontiguousarray")(v); - mx::Shape shape; - nb::tuple shape_tuple = nb::cast(v.attr("shape")); - size_t ndim = shape_tuple.size(); - for (size_t i = 0; i < ndim; ++i) { - shape.push_back(nb::cast(shape_tuple[i])); - } - uint64_t ptr_int = - nb::cast(contig_obj.attr("ctypes").attr("data")); + auto uint16_view = v.attr("view")("uint16"); + using ContigArray = + nb::ndarray; + auto nd_arr = nb::cast(uint16_view); + auto shape = nb::cast(v.attr("shape")); const mx::bfloat16_t* typed_ptr = - reinterpret_cast(ptr_int); - auto res = (ndim == 0) ? mx::array(*typed_ptr, mx::bfloat16) - : mx::array(typed_ptr, shape, mx::bfloat16); + reinterpret_cast(nd_arr.data()); + auto res = (shape.empty()) ? mx::array(*typed_ptr, mx::bfloat16) + : mx::array(typed_ptr, shape, mx::bfloat16); if (t.has_value()) res = mx::astype(res, *t); return res; } } } - try { - auto v_cast = nb::cast(v); - if (auto pv = std::get_if(&v_cast); pv) { - return mx::array(nb::cast(*pv), t.value_or(mx::bool_)); - } else if (auto pv = std::get_if(&v_cast); pv) { - auto val = nb::cast(*pv); - auto default_type = (val > std::numeric_limits::max() || - val < std::numeric_limits::min()) - ? mx::int64 - : mx::int32; - return mx::array(val, t.value_or(default_type)); - } else if (auto pv = std::get_if(&v_cast); pv) { - auto out_type = t.value_or(mx::float32); - if (out_type == mx::float64) { - return mx::array(nb::cast(*pv), out_type); - } else { - return mx::array(nb::cast(*pv), out_type); - } - } else if (auto pv = std::get_if>(&v_cast); pv) { - return mx::array( - static_cast(*pv), t.value_or(mx::complex64)); - } else if (auto pv = std::get_if(&v_cast); pv) { - return array_from_list(*pv, t); - } else if (auto pv = std::get_if(&v_cast); pv) { - return array_from_list(*pv, t); - } else if (auto pv = std::get_if< - nb::ndarray>(&v_cast); - pv) { - return nd_array_to_mlx(*pv, t); - } else if (auto pv = std::get_if(&v_cast); pv) { - return mx::astype(*pv, t.value_or((*pv).dtype())); + if (nb::isinstance(v)) { + return mx::array(nb::cast(v), t.value_or(mx::bool_)); + } else if (nb::isinstance(v)) { + auto val = nb::cast(v); + auto default_type = (val > std::numeric_limits::max() || + val < std::numeric_limits::min()) + ? mx::int64 + : mx::int32; + return mx::array(val, t.value_or(default_type)); + } else if (nb::isinstance(v)) { + auto out_type = t.value_or(mx::float32); + if (out_type == mx::float64) { + return mx::array(nb::cast(v), out_type); } else { - auto arr = to_array_with_accessor(std::get(v_cast).obj); + return mx::array(nb::cast(v), out_type); + } + } else if (PyComplex_Check(v.ptr())) { + return mx::array( + static_cast(nb::cast>(v)), + t.value_or(mx::complex64)); + } else if (nb::isinstance(v)) { + return array_from_list(nb::cast(v), t); + } else if (nb::isinstance(v)) { + return array_from_list(nb::cast(v), t); + } else if (nb::isinstance(v)) { + auto arr = nb::cast(v); + return mx::astype(arr, t.value_or(arr.dtype())); + } else { + try { + using ContigArray = nb::ndarray; + auto nd = nb::cast(v); + return nd_array_to_mlx(nd, t); + } catch (const nb::cast_error&) { + auto arr = to_array_with_accessor(v); return mx::astype(arr, t.value_or(arr.dtype())); } - } catch (const std::bad_cast& e) { - throw std::invalid_argument("Cannot convert to mlx array."); } } diff --git a/python/src/convert.h b/python/src/convert.h index 757695dfa9..ebcf598e8e 100644 --- a/python/src/convert.h +++ b/python/src/convert.h @@ -21,19 +21,6 @@ struct ArrayLike { nb::object obj; }; -using ArrayInitType = std::variant< - nb::bool_, - nb::int_, - nb::float_, - // Must be above ndarray - mx::array, - // Must be above complex - nb::ndarray, - std::complex, - nb::list, - nb::tuple, - ArrayLike>; - mx::array nd_array_to_mlx( nb::ndarray nd_array, std::optional dtype); diff --git a/setup.py b/setup.py index 12505bd1d8..e9ef273d9d 100644 --- a/setup.py +++ b/setup.py @@ -234,6 +234,7 @@ def get_tag(self) -> tuple[str, str, str]: "psutil", "torch>=2.9", "typing_extensions", + "ml_dtypes", ], } entry_points = { From 676b9cee589cc890bb88c66b2387643e1f32a254 Mon Sep 17 00:00:00 2001 From: Kellen Sun Date: Sun, 29 Mar 2026 21:03:53 -0400 Subject: [PATCH 12/16] refactor --- python/src/convert.cpp | 54 +++++++++++++++++++----------------------- setup.py | 2 +- 2 files changed, 25 insertions(+), 31 deletions(-) diff --git a/python/src/convert.cpp b/python/src/convert.cpp index e3aeddea57..4382031969 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -468,28 +468,6 @@ mx::array array_from_list(nb::tuple pl, std::optional dtype) { } mx::array create_array(nb::object v, std::optional t) { - if (nb::hasattr(v, "dtype")) { - nb::object dtype_obj = v.attr("dtype"); - if (nb::str(dtype_obj).equal(nb::str("bfloat16"))) { - nb::object module_obj = v.attr("__class__").attr("__module__"); - auto type_mod = nb::str(module_obj); - if (type_mod.equal(nb::str("numpy")) || - type_mod.equal(nb::str("ml_dtypes"))) { - auto uint16_view = v.attr("view")("uint16"); - using ContigArray = - nb::ndarray; - auto nd_arr = nb::cast(uint16_view); - auto shape = nb::cast(v.attr("shape")); - const mx::bfloat16_t* typed_ptr = - reinterpret_cast(nd_arr.data()); - auto res = (shape.empty()) ? mx::array(*typed_ptr, mx::bfloat16) - : mx::array(typed_ptr, shape, mx::bfloat16); - if (t.has_value()) - res = mx::astype(res, *t); - return res; - } - } - } if (nb::isinstance(v)) { return mx::array(nb::cast(v), t.value_or(mx::bool_)); } else if (nb::isinstance(v)) { @@ -517,14 +495,30 @@ mx::array create_array(nb::object v, std::optional t) { } else if (nb::isinstance(v)) { auto arr = nb::cast(v); return mx::astype(arr, t.value_or(arr.dtype())); + } else if ( + nb::hasattr(v, "dtype") && + nb::str(v.attr("dtype")).equal(nb::str("bfloat16")) && + nb::hasattr(v, "__class__") && + (nb::str(v.attr("__class__").attr("__module__")) + .equal(nb::str("numpy")) || + nb::str(v.attr("__class__").attr("__module__")) + .equal(nb::str("ml_dtypes")))) { + auto uint16_view = v.attr("view")("uint16"); + using ContigArray = + nb::ndarray; + auto nd_arr = nb::cast(uint16_view); + auto shape = nb::cast(v.attr("shape")); + const mx::bfloat16_t* typed_ptr = + reinterpret_cast(nd_arr.data()); + auto res = (shape.empty()) ? mx::array(*typed_ptr, mx::bfloat16) + : mx::array(typed_ptr, shape, mx::bfloat16); + return t.has_value() ? mx::astype(res, *t) : res; + } else if (nb::ndarray_check(v)) { + using ContigArray = nb::ndarray; + auto nd = nb::cast(v); + return nd_array_to_mlx(nd, t); } else { - try { - using ContigArray = nb::ndarray; - auto nd = nb::cast(v); - return nd_array_to_mlx(nd, t); - } catch (const nb::cast_error&) { - auto arr = to_array_with_accessor(v); - return mx::astype(arr, t.value_or(arr.dtype())); - } + auto arr = to_array_with_accessor(v); + return mx::astype(arr, t.value_or(arr.dtype())); } } diff --git a/setup.py b/setup.py index e9ef273d9d..04a905e178 100644 --- a/setup.py +++ b/setup.py @@ -229,12 +229,12 @@ def get_tag(self) -> tuple[str, str, str]: extras = { "dev": [ + "ml_dtypes", "numpy>=2", "pre-commit", "psutil", "torch>=2.9", "typing_extensions", - "ml_dtypes", ], } entry_points = { From 32135b7ba4b520dc8c1ce4869a4901926d82c7aa Mon Sep 17 00:00:00 2001 From: Kellen Sun Date: Sun, 29 Mar 2026 21:17:21 -0400 Subject: [PATCH 13/16] windows ci fix --- python/src/convert.cpp | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/python/src/convert.cpp b/python/src/convert.cpp index 4382031969..0e658aad6f 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -495,24 +495,25 @@ mx::array create_array(nb::object v, std::optional t) { } else if (nb::isinstance(v)) { auto arr = nb::cast(v); return mx::astype(arr, t.value_or(arr.dtype())); - } else if ( - nb::hasattr(v, "dtype") && - nb::str(v.attr("dtype")).equal(nb::str("bfloat16")) && - nb::hasattr(v, "__class__") && - (nb::str(v.attr("__class__").attr("__module__")) - .equal(nb::str("numpy")) || - nb::str(v.attr("__class__").attr("__module__")) - .equal(nb::str("ml_dtypes")))) { - auto uint16_view = v.attr("view")("uint16"); - using ContigArray = - nb::ndarray; - auto nd_arr = nb::cast(uint16_view); - auto shape = nb::cast(v.attr("shape")); - const mx::bfloat16_t* typed_ptr = - reinterpret_cast(nd_arr.data()); - auto res = (shape.empty()) ? mx::array(*typed_ptr, mx::bfloat16) - : mx::array(typed_ptr, shape, mx::bfloat16); - return t.has_value() ? mx::astype(res, *t) : res; + } else if (nb::hasattr(v, "dtype")) { + nb::object dtype_obj = v.attr("dtype"); + if (nb::str(dtype_obj).equal(nb::str("bfloat16"))) { + nb::object module_obj = v.attr("__class__").attr("__module__"); + auto type_mod = nb::str(module_obj); + if (type_mod.equal(nb::str("numpy")) || + type_mod.equal(nb::str("ml_dtypes"))) { + auto uint16_view = v.attr("view")("uint16"); + using ContigArray = + nb::ndarray; + auto nd_arr = nb::cast(uint16_view); + auto shape = nb::cast(v.attr("shape")); + const mx::bfloat16_t* typed_ptr = + reinterpret_cast(nd_arr.data()); + auto res = (shape.empty()) ? mx::array(*typed_ptr, mx::bfloat16) + : mx::array(typed_ptr, shape, mx::bfloat16); + return t.has_value() ? mx::astype(res, *t) : res; + } + } } else if (nb::ndarray_check(v)) { using ContigArray = nb::ndarray; auto nd = nb::cast(v); From e4fb3815250386c93216db32b019797516c39226 Mon Sep 17 00:00:00 2001 From: Kellen Sun Date: Sun, 29 Mar 2026 21:24:37 -0400 Subject: [PATCH 14/16] making the compiler happy --- python/src/convert.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/src/convert.cpp b/python/src/convert.cpp index 0e658aad6f..3eb2a932a3 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -514,12 +514,12 @@ mx::array create_array(nb::object v, std::optional t) { return t.has_value() ? mx::astype(res, *t) : res; } } - } else if (nb::ndarray_check(v)) { + } + if (nb::ndarray_check(v)) { using ContigArray = nb::ndarray; auto nd = nb::cast(v); return nd_array_to_mlx(nd, t); - } else { - auto arr = to_array_with_accessor(v); - return mx::astype(arr, t.value_or(arr.dtype())); } + auto arr = to_array_with_accessor(v); + return mx::astype(arr, t.value_or(arr.dtype())); } From c78fe114b1bf0fdbba0de7937768a72684925802 Mon Sep 17 00:00:00 2001 From: Cheng Date: Mon, 30 Mar 2026 14:27:45 +0900 Subject: [PATCH 15/16] Use nd_array_to_mlx for conversion --- python/src/array.cpp | 2 +- python/src/convert.cpp | 58 ++++++++++++++---------------------------- python/src/convert.h | 47 +++++++++++++++++++++++++++++++--- 3 files changed, 64 insertions(+), 43 deletions(-) diff --git a/python/src/array.cpp b/python/src/array.cpp index 75697ca39f..838a33a47d 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -492,7 +492,7 @@ void init_array(nb::module_& m) { reinterpret_cast(nd.shape_ptr()), owner, nullptr, - nb::bfloat16), + nb::dtype()), mx::bfloat16)); } else { new (&arr) mx::array(nd_array_to_mlx(nd, std::nullopt)); diff --git a/python/src/convert.cpp b/python/src/convert.cpp index 3eb2a932a3..fcc88a90df 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -14,17 +14,6 @@ enum PyScalarT { pycomplex = 3, }; -namespace nanobind { -template <> -struct ndarray_traits { - static constexpr bool is_complex = false; - static constexpr bool is_float = true; - static constexpr bool is_bool = false; - static constexpr bool is_int = false; - static constexpr bool is_signed = true; -}; -}; // namespace nanobind - int check_shape_dim(int64_t dim) { if (dim > std::numeric_limits::max()) { throw std::invalid_argument( @@ -46,14 +35,15 @@ mx::array nd_array_to_mlx_contiguous( mx::array nd_array_to_mlx( nb::ndarray nd_array, - std::optional dtype) { + std::optional dtype, + std::optional nb_dtype) { // Compute the shape and size mx::Shape shape; shape.reserve(nd_array.ndim()); for (int i = 0; i < nd_array.ndim(); i++) { shape.push_back(check_shape_dim(nd_array.shape(i))); } - auto type = nd_array.dtype(); + auto type = nb_dtype.value_or(nd_array.dtype()); // Copy data and make array if (type == nb::dtype()) { @@ -86,7 +76,7 @@ mx::array nd_array_to_mlx( } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(mx::float16)); - } else if (type == nb::bfloat16) { + } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(mx::bfloat16)); } else if (type == nb::dtype()) { @@ -495,31 +485,21 @@ mx::array create_array(nb::object v, std::optional t) { } else if (nb::isinstance(v)) { auto arr = nb::cast(v); return mx::astype(arr, t.value_or(arr.dtype())); - } else if (nb::hasattr(v, "dtype")) { - nb::object dtype_obj = v.attr("dtype"); - if (nb::str(dtype_obj).equal(nb::str("bfloat16"))) { - nb::object module_obj = v.attr("__class__").attr("__module__"); - auto type_mod = nb::str(module_obj); - if (type_mod.equal(nb::str("numpy")) || - type_mod.equal(nb::str("ml_dtypes"))) { - auto uint16_view = v.attr("view")("uint16"); - using ContigArray = - nb::ndarray; - auto nd_arr = nb::cast(uint16_view); - auto shape = nb::cast(v.attr("shape")); - const mx::bfloat16_t* typed_ptr = - reinterpret_cast(nd_arr.data()); - auto res = (shape.empty()) ? mx::array(*typed_ptr, mx::bfloat16) - : mx::array(typed_ptr, shape, mx::bfloat16); - return t.has_value() ? mx::astype(res, *t) : res; - } - } - } - if (nb::ndarray_check(v)) { + } else if (nb::ndarray_check(v)) { using ContigArray = nb::ndarray; - auto nd = nb::cast(v); - return nd_array_to_mlx(nd, t); + ContigArray nd; + std::optional nb_dtype; + // Nanobind does not recognize bfloat16 numpy array: + // https://github.com/wjakob/nanobind/discussions/560 + if (nb::str(v.attr("dtype")).equal(nb::str("bfloat16"))) { + nd = nb::cast(v.attr("view")("uint16")); + nb_dtype = nb::dtype(); + } else { + nd = nb::cast(v); + } + return nd_array_to_mlx(nd, t, nb_dtype); + } else { + auto arr = to_array_with_accessor(v); + return mx::astype(arr, t.value_or(arr.dtype())); } - auto arr = to_array_with_accessor(v); - return mx::astype(arr, t.value_or(arr.dtype())); } diff --git a/python/src/convert.h b/python/src/convert.h index ebcf598e8e..9da7a0f715 100644 --- a/python/src/convert.h +++ b/python/src/convert.h @@ -13,8 +13,48 @@ namespace mx = mlx::core; namespace nb = nanobind; namespace nanobind { -static constexpr dlpack::dtype bfloat16{4, 16, 1}; -}; // namespace nanobind + +template <> +struct ndarray_traits { + static constexpr bool is_complex = false; + static constexpr bool is_float = true; + static constexpr bool is_bool = false; + static constexpr bool is_int = false; + static constexpr bool is_signed = true; +}; + +template <> +struct ndarray_traits { + static constexpr bool is_complex = false; + static constexpr bool is_float = true; + static constexpr bool is_bool = false; + static constexpr bool is_int = false; + static constexpr bool is_signed = true; +}; + +namespace detail { + +template <> +struct dtype_traits { + static constexpr dlpack::dtype value{ + /* code */ uint8_t(nb::dlpack::dtype_code::Float), + /* bits */ 16, + /* lanes */ 1}; + static constexpr const char* name = "float16"; +}; + +template <> +struct dtype_traits { + static constexpr dlpack::dtype value{ + /* code */ uint8_t(nb::dlpack::dtype_code::Bfloat), + /* bits */ 16, + /* lanes */ 1}; + static constexpr const char* name = "bfloat16"; +}; + +} // namespace detail + +} // namespace nanobind struct ArrayLike { ArrayLike(nb::object obj) : obj(obj) {}; @@ -23,7 +63,8 @@ struct ArrayLike { mx::array nd_array_to_mlx( nb::ndarray nd_array, - std::optional dtype); + std::optional mx_dtype, + std::optional nb_dtype = std::nullopt); nb::ndarray mlx_to_np_array(const mx::array& a); nb::ndarray<> mlx_to_dlpack(const mx::array& a); From bb8befc669047869d729cabba876d9ca02efbe8d Mon Sep 17 00:00:00 2001 From: Cheng Date: Mon, 30 Mar 2026 08:35:15 +0100 Subject: [PATCH 16/16] Fix Windows CI --- python/src/convert.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/src/convert.cpp b/python/src/convert.cpp index fcc88a90df..a6dd7fd10a 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -491,7 +491,7 @@ mx::array create_array(nb::object v, std::optional t) { std::optional nb_dtype; // Nanobind does not recognize bfloat16 numpy array: // https://github.com/wjakob/nanobind/discussions/560 - if (nb::str(v.attr("dtype")).equal(nb::str("bfloat16"))) { + if (v.attr("dtype").equal(nb::str("bfloat16"))) { nd = nb::cast(v.attr("view")("uint16")); nb_dtype = nb::dtype(); } else {