diff --git a/python/src/array.cpp b/python/src/array.cpp index 231474c2d9..838a33a47d 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -296,7 +296,7 @@ void init_array(nb::module_& m) { nb::is_weak_referenceable()) .def( "__init__", - [](mx::array* aptr, ArrayInitType v, std::optional t) { + [](mx::array* aptr, nb::object v, std::optional t) { new (aptr) mx::array(create_array(v, t)); }, "val"_a, @@ -492,7 +492,7 @@ void init_array(nb::module_& m) { reinterpret_cast(nd.shape_ptr()), owner, nullptr, - nb::bfloat16), + nb::dtype()), mx::bfloat16)); } else { new (&arr) mx::array(nd_array_to_mlx(nd, std::nullopt)); diff --git a/python/src/convert.cpp b/python/src/convert.cpp index e004ac9cfd..a6dd7fd10a 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -14,17 +14,6 @@ enum PyScalarT { pycomplex = 3, }; -namespace nanobind { -template <> -struct ndarray_traits { - static constexpr bool is_complex = false; - static constexpr bool is_float = true; - static constexpr bool is_bool = false; - static constexpr bool is_int = false; - static constexpr bool is_signed = true; -}; -}; // namespace nanobind - int check_shape_dim(int64_t dim) { if (dim > std::numeric_limits::max()) { throw std::invalid_argument( @@ -46,14 +35,15 @@ mx::array nd_array_to_mlx_contiguous( mx::array nd_array_to_mlx( nb::ndarray nd_array, - std::optional dtype) { + std::optional dtype, + std::optional nb_dtype) { // Compute the shape and size mx::Shape shape; shape.reserve(nd_array.ndim()); for (int i = 0; i < nd_array.ndim(); i++) { shape.push_back(check_shape_dim(nd_array.shape(i))); } - auto type = nd_array.dtype(); + auto type = nb_dtype.value_or(nd_array.dtype()); // Copy data and make array if (type == nb::dtype()) { @@ -86,7 +76,7 @@ mx::array nd_array_to_mlx( } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(mx::float16)); - } else if (type == nb::bfloat16) { + } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(mx::bfloat16)); } else if (type == nb::dtype()) { @@ -454,7 +444,7 @@ mx::array array_from_list_impl(T pl, std::optional dtype) { // `pl` contains mlx arrays std::vector arrays; for (auto l : pl) { - arrays.push_back(create_array(nb::cast(l), dtype)); + arrays.push_back(create_array(nb::cast(l), dtype)); } return mx::stack(arrays); } @@ -467,38 +457,49 @@ mx::array array_from_list(nb::tuple pl, std::optional dtype) { return array_from_list_impl(pl, dtype); } -mx::array create_array(ArrayInitType v, std::optional t) { - if (auto pv = std::get_if(&v); pv) { - return mx::array(nb::cast(*pv), t.value_or(mx::bool_)); - } else if (auto pv = std::get_if(&v); pv) { - auto val = nb::cast(*pv); +mx::array create_array(nb::object v, std::optional t) { + if (nb::isinstance(v)) { + return mx::array(nb::cast(v), t.value_or(mx::bool_)); + } else if (nb::isinstance(v)) { + auto val = nb::cast(v); auto default_type = (val > std::numeric_limits::max() || val < std::numeric_limits::min()) ? mx::int64 : mx::int32; return mx::array(val, t.value_or(default_type)); - } else if (auto pv = std::get_if(&v); pv) { + } else if (nb::isinstance(v)) { auto out_type = t.value_or(mx::float32); if (out_type == mx::float64) { - return mx::array(nb::cast(*pv), out_type); + return mx::array(nb::cast(v), out_type); } else { - return mx::array(nb::cast(*pv), out_type); + return mx::array(nb::cast(v), out_type); } - } else if (auto pv = std::get_if>(&v); pv) { + } else if (PyComplex_Check(v.ptr())) { return mx::array( - static_cast(*pv), t.value_or(mx::complex64)); - } else if (auto pv = std::get_if(&v); pv) { - return array_from_list(*pv, t); - } else if (auto pv = std::get_if(&v); pv) { - return array_from_list(*pv, t); - } else if (auto pv = std::get_if< - nb::ndarray>(&v); - pv) { - return nd_array_to_mlx(*pv, t); - } else if (auto pv = std::get_if(&v); pv) { - return mx::astype(*pv, t.value_or((*pv).dtype())); + static_cast(nb::cast>(v)), + t.value_or(mx::complex64)); + } else if (nb::isinstance(v)) { + return array_from_list(nb::cast(v), t); + } else if (nb::isinstance(v)) { + return array_from_list(nb::cast(v), t); + } else if (nb::isinstance(v)) { + auto arr = nb::cast(v); + return mx::astype(arr, t.value_or(arr.dtype())); + } else if (nb::ndarray_check(v)) { + using ContigArray = nb::ndarray; + ContigArray nd; + std::optional nb_dtype; + // Nanobind does not recognize bfloat16 numpy array: + // https://github.com/wjakob/nanobind/discussions/560 + if (v.attr("dtype").equal(nb::str("bfloat16"))) { + nd = nb::cast(v.attr("view")("uint16")); + nb_dtype = nb::dtype(); + } else { + nd = nb::cast(v); + } + return nd_array_to_mlx(nd, t, nb_dtype); } else { - auto arr = to_array_with_accessor(std::get(v).obj); + auto arr = to_array_with_accessor(v); return mx::astype(arr, t.value_or(arr.dtype())); } } diff --git a/python/src/convert.h b/python/src/convert.h index 4f8a77abe5..9da7a0f715 100644 --- a/python/src/convert.h +++ b/python/src/convert.h @@ -13,30 +13,58 @@ namespace mx = mlx::core; namespace nb = nanobind; namespace nanobind { -static constexpr dlpack::dtype bfloat16{4, 16, 1}; -}; // namespace nanobind + +template <> +struct ndarray_traits { + static constexpr bool is_complex = false; + static constexpr bool is_float = true; + static constexpr bool is_bool = false; + static constexpr bool is_int = false; + static constexpr bool is_signed = true; +}; + +template <> +struct ndarray_traits { + static constexpr bool is_complex = false; + static constexpr bool is_float = true; + static constexpr bool is_bool = false; + static constexpr bool is_int = false; + static constexpr bool is_signed = true; +}; + +namespace detail { + +template <> +struct dtype_traits { + static constexpr dlpack::dtype value{ + /* code */ uint8_t(nb::dlpack::dtype_code::Float), + /* bits */ 16, + /* lanes */ 1}; + static constexpr const char* name = "float16"; +}; + +template <> +struct dtype_traits { + static constexpr dlpack::dtype value{ + /* code */ uint8_t(nb::dlpack::dtype_code::Bfloat), + /* bits */ 16, + /* lanes */ 1}; + static constexpr const char* name = "bfloat16"; +}; + +} // namespace detail + +} // namespace nanobind struct ArrayLike { ArrayLike(nb::object obj) : obj(obj) {}; nb::object obj; }; -using ArrayInitType = std::variant< - nb::bool_, - nb::int_, - nb::float_, - // Must be above ndarray - mx::array, - // Must be above complex - nb::ndarray, - std::complex, - nb::list, - nb::tuple, - ArrayLike>; - mx::array nd_array_to_mlx( nb::ndarray nd_array, - std::optional dtype); + std::optional mx_dtype, + std::optional nb_dtype = std::nullopt); nb::ndarray mlx_to_np_array(const mx::array& a); nb::ndarray<> mlx_to_dlpack(const mx::array& a); @@ -45,6 +73,6 @@ nb::object to_scalar(mx::array& a); nb::object tolist(mx::array& a); -mx::array create_array(ArrayInitType v, std::optional t); +mx::array create_array(nb::object v, std::optional t); mx::array array_from_list(nb::list pl, std::optional dtype); mx::array array_from_list(nb::tuple pl, std::optional dtype); diff --git a/python/src/ops.cpp b/python/src/ops.cpp index dac4b8f7f7..b1362f6dea 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -1719,7 +1719,7 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "asarray", - [](const ArrayInitType& a, std::optional dtype) { + [](const nb::object& a, std::optional dtype) { return create_array(a, dtype); }, nb::arg(), diff --git a/python/tests/test_bf16.py b/python/tests/test_bf16.py index 2e4e2e0c33..d25364c14f 100644 --- a/python/tests/test_bf16.py +++ b/python/tests/test_bf16.py @@ -15,6 +15,13 @@ except ImportError as e: has_torch = False +try: + import ml_dtypes + + has_ml_dtypes = True +except ImportError: + has_ml_dtypes = False + class TestBF16(mlx_tests.MLXTestCase): def __test_ops( @@ -191,6 +198,31 @@ def test_conversion(self): self.assertEqual(a_mx.dtype, mx.bfloat16) self.assertTrue(mx.array_equal(a_mx, expected)) + @unittest.skipIf(not has_ml_dtypes, "requires ml_dtypes") + def test_conversion_ml_dtypes(self): + x_scalar = np.array(1.5, dtype=ml_dtypes.bfloat16) + a_scalar = mx.array(x_scalar) + self.assertEqual(a_scalar.dtype, mx.bfloat16) + self.assertEqual(a_scalar.shape, ()) + self.assertEqual(a_scalar.item(), 1.5) + + data = [1.5, 2.5, 3.5] + x_vector = np.array(data, dtype=ml_dtypes.bfloat16) + a_vector = mx.array(x_vector) + expected = mx.array(data, dtype=mx.bfloat16) + self.assertEqual(a_vector.dtype, mx.bfloat16) + self.assertEqual(a_vector.shape, (3,)) + self.assertTrue(mx.array_equal(a_vector, expected)) + + a_cast = mx.array(x_scalar, dtype=mx.float32) + self.assertEqual(a_cast.dtype, mx.float32) + self.assertEqual(a_cast.item(), 1.5) + + a_asarray = mx.asarray(x_vector) + self.assertEqual(a_asarray.dtype, mx.bfloat16) + self.assertEqual(a_asarray.shape, (3,)) + self.assertTrue(mx.array_equal(a_asarray, expected)) + if __name__ == "__main__": mlx_tests.MLXTestRunner() diff --git a/setup.py b/setup.py index 12505bd1d8..04a905e178 100644 --- a/setup.py +++ b/setup.py @@ -229,6 +229,7 @@ def get_tag(self) -> tuple[str, str, str]: extras = { "dev": [ + "ml_dtypes", "numpy>=2", "pre-commit", "psutil",