Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/src/array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ void init_array(nb::module_& m) {
nb::is_weak_referenceable())
.def(
"__init__",
[](mx::array* aptr, ArrayInitType v, std::optional<mx::Dtype> t) {
[](mx::array* aptr, nb::object v, std::optional<mx::Dtype> t) {
new (aptr) mx::array(create_array(v, t));
},
"val"_a,
Expand Down Expand Up @@ -492,7 +492,7 @@ void init_array(nb::module_& m) {
reinterpret_cast<const size_t*>(nd.shape_ptr()),
owner,
nullptr,
nb::bfloat16),
nb::dtype<mx::bfloat16_t>()),
mx::bfloat16));
} else {
new (&arr) mx::array(nd_array_to_mlx(nd, std::nullopt));
Expand Down
73 changes: 37 additions & 36 deletions python/src/convert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,6 @@ enum PyScalarT {
pycomplex = 3,
};

namespace nanobind {
template <>
struct ndarray_traits<mx::float16_t> {
static constexpr bool is_complex = false;
static constexpr bool is_float = true;
static constexpr bool is_bool = false;
static constexpr bool is_int = false;
static constexpr bool is_signed = true;
};
}; // namespace nanobind

int check_shape_dim(int64_t dim) {
if (dim > std::numeric_limits<int>::max()) {
throw std::invalid_argument(
Expand All @@ -46,14 +35,15 @@ mx::array nd_array_to_mlx_contiguous(

mx::array nd_array_to_mlx(
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
std::optional<mx::Dtype> dtype) {
std::optional<mx::Dtype> dtype,
std::optional<nb::dlpack::dtype> nb_dtype) {
// Compute the shape and size
mx::Shape shape;
shape.reserve(nd_array.ndim());
for (int i = 0; i < nd_array.ndim(); i++) {
shape.push_back(check_shape_dim(nd_array.shape(i)));
}
auto type = nd_array.dtype();
auto type = nb_dtype.value_or(nd_array.dtype());

// Copy data and make array
if (type == nb::dtype<bool>()) {
Expand Down Expand Up @@ -86,7 +76,7 @@ mx::array nd_array_to_mlx(
} else if (type == nb::dtype<mx::float16_t>()) {
return nd_array_to_mlx_contiguous<mx::float16_t>(
nd_array, shape, dtype.value_or(mx::float16));
} else if (type == nb::bfloat16) {
} else if (type == nb::dtype<mx::bfloat16_t>()) {
return nd_array_to_mlx_contiguous<mx::bfloat16_t>(
nd_array, shape, dtype.value_or(mx::bfloat16));
} else if (type == nb::dtype<float>()) {
Expand Down Expand Up @@ -454,7 +444,7 @@ mx::array array_from_list_impl(T pl, std::optional<mx::Dtype> dtype) {
// `pl` contains mlx arrays
std::vector<mx::array> arrays;
for (auto l : pl) {
arrays.push_back(create_array(nb::cast<ArrayInitType>(l), dtype));
arrays.push_back(create_array(nb::cast<nb::object>(l), dtype));
}
return mx::stack(arrays);
}
Expand All @@ -467,38 +457,49 @@ mx::array array_from_list(nb::tuple pl, std::optional<mx::Dtype> dtype) {
return array_from_list_impl(pl, dtype);
}

mx::array create_array(ArrayInitType v, std::optional<mx::Dtype> t) {
if (auto pv = std::get_if<nb::bool_>(&v); pv) {
return mx::array(nb::cast<bool>(*pv), t.value_or(mx::bool_));
} else if (auto pv = std::get_if<nb::int_>(&v); pv) {
auto val = nb::cast<int64_t>(*pv);
mx::array create_array(nb::object v, std::optional<mx::Dtype> t) {
if (nb::isinstance<nb::bool_>(v)) {
return mx::array(nb::cast<bool>(v), t.value_or(mx::bool_));
} else if (nb::isinstance<nb::int_>(v)) {
auto val = nb::cast<int64_t>(v);
auto default_type = (val > std::numeric_limits<int>::max() ||
val < std::numeric_limits<int>::min())
? mx::int64
: mx::int32;
return mx::array(val, t.value_or(default_type));
} else if (auto pv = std::get_if<nb::float_>(&v); pv) {
} else if (nb::isinstance<nb::float_>(v)) {
auto out_type = t.value_or(mx::float32);
if (out_type == mx::float64) {
return mx::array(nb::cast<double>(*pv), out_type);
return mx::array(nb::cast<double>(v), out_type);
} else {
return mx::array(nb::cast<float>(*pv), out_type);
return mx::array(nb::cast<float>(v), out_type);
}
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
} else if (PyComplex_Check(v.ptr())) {
return mx::array(
static_cast<mx::complex64_t>(*pv), t.value_or(mx::complex64));
} else if (auto pv = std::get_if<nb::list>(&v); pv) {
return array_from_list(*pv, t);
} else if (auto pv = std::get_if<nb::tuple>(&v); pv) {
return array_from_list(*pv, t);
} else if (auto pv = std::get_if<
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>>(&v);
pv) {
return nd_array_to_mlx(*pv, t);
} else if (auto pv = std::get_if<mx::array>(&v); pv) {
return mx::astype(*pv, t.value_or((*pv).dtype()));
static_cast<mx::complex64_t>(nb::cast<std::complex<float>>(v)),
t.value_or(mx::complex64));
} else if (nb::isinstance<nb::list>(v)) {
return array_from_list(nb::cast<nb::list>(v), t);
} else if (nb::isinstance<nb::tuple>(v)) {
return array_from_list(nb::cast<nb::tuple>(v), t);
} else if (nb::isinstance<mx::array>(v)) {
auto arr = nb::cast<mx::array>(v);
return mx::astype(arr, t.value_or(arr.dtype()));
} else if (nb::ndarray_check(v)) {
using ContigArray = nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>;
ContigArray nd;
std::optional<nb::dlpack::dtype> nb_dtype;
// Nanobind does not recognize bfloat16 numpy array:
// https://github.com/wjakob/nanobind/discussions/560
if (v.attr("dtype").equal(nb::str("bfloat16"))) {
nd = nb::cast<ContigArray>(v.attr("view")("uint16"));
nb_dtype = nb::dtype<mx::bfloat16_t>();
} else {
nd = nb::cast<ContigArray>(v);
}
return nd_array_to_mlx(nd, t, nb_dtype);
} else {
auto arr = to_array_with_accessor(std::get<ArrayLike>(v).obj);
auto arr = to_array_with_accessor(v);
return mx::astype(arr, t.value_or(arr.dtype()));
}
}
62 changes: 45 additions & 17 deletions python/src/convert.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,30 +13,58 @@ namespace mx = mlx::core;
namespace nb = nanobind;

namespace nanobind {
static constexpr dlpack::dtype bfloat16{4, 16, 1};
}; // namespace nanobind

template <>
struct ndarray_traits<mx::float16_t> {
static constexpr bool is_complex = false;
static constexpr bool is_float = true;
static constexpr bool is_bool = false;
static constexpr bool is_int = false;
static constexpr bool is_signed = true;
};

template <>
struct ndarray_traits<mx::bfloat16_t> {
static constexpr bool is_complex = false;
static constexpr bool is_float = true;
static constexpr bool is_bool = false;
static constexpr bool is_int = false;
static constexpr bool is_signed = true;
};

namespace detail {

template <>
struct dtype_traits<mx::float16_t> {
static constexpr dlpack::dtype value{
/* code */ uint8_t(nb::dlpack::dtype_code::Float),
/* bits */ 16,
/* lanes */ 1};
static constexpr const char* name = "float16";
};

template <>
struct dtype_traits<mx::bfloat16_t> {
static constexpr dlpack::dtype value{
/* code */ uint8_t(nb::dlpack::dtype_code::Bfloat),
/* bits */ 16,
/* lanes */ 1};
static constexpr const char* name = "bfloat16";
};

} // namespace detail

} // namespace nanobind

struct ArrayLike {
ArrayLike(nb::object obj) : obj(obj) {};
nb::object obj;
};

using ArrayInitType = std::variant<
nb::bool_,
nb::int_,
nb::float_,
// Must be above ndarray
mx::array,
// Must be above complex
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>,
std::complex<float>,
nb::list,
nb::tuple,
ArrayLike>;

mx::array nd_array_to_mlx(
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
std::optional<mx::Dtype> dtype);
std::optional<mx::Dtype> mx_dtype,
std::optional<nb::dlpack::dtype> nb_dtype = std::nullopt);

nb::ndarray<nb::numpy> mlx_to_np_array(const mx::array& a);
nb::ndarray<> mlx_to_dlpack(const mx::array& a);
Expand All @@ -45,6 +73,6 @@ nb::object to_scalar(mx::array& a);

nb::object tolist(mx::array& a);

mx::array create_array(ArrayInitType v, std::optional<mx::Dtype> t);
mx::array create_array(nb::object v, std::optional<mx::Dtype> t);
mx::array array_from_list(nb::list pl, std::optional<mx::Dtype> dtype);
mx::array array_from_list(nb::tuple pl, std::optional<mx::Dtype> dtype);
2 changes: 1 addition & 1 deletion python/src/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1719,7 +1719,7 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"asarray",
[](const ArrayInitType& a, std::optional<mx::Dtype> dtype) {
[](const nb::object& a, std::optional<mx::Dtype> dtype) {
return create_array(a, dtype);
},
nb::arg(),
Expand Down
32 changes: 32 additions & 0 deletions python/tests/test_bf16.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@
except ImportError as e:
has_torch = False

try:
import ml_dtypes

has_ml_dtypes = True
except ImportError:
has_ml_dtypes = False


class TestBF16(mlx_tests.MLXTestCase):
def __test_ops(
Expand Down Expand Up @@ -191,6 +198,31 @@ def test_conversion(self):
self.assertEqual(a_mx.dtype, mx.bfloat16)
self.assertTrue(mx.array_equal(a_mx, expected))

@unittest.skipIf(not has_ml_dtypes, "requires ml_dtypes")
def test_conversion_ml_dtypes(self):
x_scalar = np.array(1.5, dtype=ml_dtypes.bfloat16)
a_scalar = mx.array(x_scalar)
self.assertEqual(a_scalar.dtype, mx.bfloat16)
self.assertEqual(a_scalar.shape, ())
self.assertEqual(a_scalar.item(), 1.5)

data = [1.5, 2.5, 3.5]
x_vector = np.array(data, dtype=ml_dtypes.bfloat16)
a_vector = mx.array(x_vector)
expected = mx.array(data, dtype=mx.bfloat16)
self.assertEqual(a_vector.dtype, mx.bfloat16)
self.assertEqual(a_vector.shape, (3,))
self.assertTrue(mx.array_equal(a_vector, expected))

a_cast = mx.array(x_scalar, dtype=mx.float32)
self.assertEqual(a_cast.dtype, mx.float32)
self.assertEqual(a_cast.item(), 1.5)

a_asarray = mx.asarray(x_vector)
self.assertEqual(a_asarray.dtype, mx.bfloat16)
self.assertEqual(a_asarray.shape, (3,))
self.assertTrue(mx.array_equal(a_asarray, expected))


if __name__ == "__main__":
mlx_tests.MLXTestRunner()
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ def get_tag(self) -> tuple[str, str, str]:

extras = {
"dev": [
"ml_dtypes",
"numpy>=2",
"pre-commit",
"psutil",
Expand Down
Loading