Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 29 additions & 4 deletions python/src/array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -466,12 +466,37 @@ void init_array(nb::module_& m) {
})
.def(
"__iter__", [](const mx::array& a) { return ArrayPythonIterator(a); })
.def("__getstate__", &mlx_to_np_array)
.def(
"__getstate__",
[](const mx::array& a) {
auto nd = (a.dtype() == mx::bfloat16)
? mlx_to_np_array(mx::view(a, mx::uint16))
: mlx_to_np_array(a);
return nb::make_tuple(nd, static_cast<uint8_t>(a.dtype().val()));
})
.def(
"__setstate__",
[](mx::array& arr,
const nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>& state) {
new (&arr) mx::array(nd_array_to_mlx(state, std::nullopt));
[](mx::array& arr, const nb::tuple& state) {
if (nb::len(state) != 2) {
throw std::invalid_argument(
"Invalid pickle state: expected (ndarray, Dtype::Val)");
}
using ND = nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>;
ND nd = nb::cast<ND>(state[0]);
auto val = static_cast<mx::Dtype::Val>(nb::cast<uint8_t>(state[1]));
if (val == mx::Dtype::Val::bfloat16) {
auto owner = nb::handle(state[0].ptr());
new (&arr) mx::array(nd_array_to_mlx(
ND(nd.data(),
nd.ndim(),
reinterpret_cast<const size_t*>(nd.shape_ptr()),
owner,
nullptr,
nb::bfloat16),
mx::bfloat16));
} else {
new (&arr) mx::array(nd_array_to_mlx(nd, std::nullopt));
}
})
.def("__dlpack__", [](const mx::array& a) { return mlx_to_dlpack(a); })
.def(
Expand Down
3 changes: 1 addition & 2 deletions python/src/convert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ struct ndarray_traits<mx::float16_t> {
static constexpr bool is_int = false;
static constexpr bool is_signed = true;
};

static constexpr dlpack::dtype bfloat16{4, 16, 1};
}; // namespace nanobind

int check_shape_dim(int64_t dim) {
Expand All @@ -51,6 +49,7 @@ mx::array nd_array_to_mlx(
std::optional<mx::Dtype> dtype) {
// Compute the shape and size
mx::Shape shape;
shape.reserve(nd_array.ndim());
for (int i = 0; i < nd_array.ndim(); i++) {
shape.push_back(check_shape_dim(nd_array.shape(i)));
}
Expand Down
4 changes: 4 additions & 0 deletions python/src/convert.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
namespace mx = mlx::core;
namespace nb = nanobind;

namespace nanobind {
static constexpr dlpack::dtype bfloat16{4, 16, 1};
}; // namespace nanobind

struct ArrayLike {
ArrayLike(nb::object obj) : obj(obj) {};
nb::object obj;
Expand Down
8 changes: 2 additions & 6 deletions python/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ def test_array_repr(self):
self.assertEqual(str(x), expected)

x = mx.array([[1, 2], [1, 2], [1, 2]])
expected = "array([[1, 2],\n" " [1, 2],\n" " [1, 2]], dtype=int32)"
expected = "array([[1, 2],\n [1, 2],\n [1, 2]], dtype=int32)"
self.assertEqual(str(x), expected)

x = mx.array([[[1, 2], [1, 2]], [[1, 2], [1, 2]]])
Expand Down Expand Up @@ -886,6 +886,7 @@ def test_array_pickle(self):
mx.uint64,
mx.float16,
mx.float32,
mx.bfloat16,
mx.complex64,
]

Expand All @@ -895,11 +896,6 @@ def test_array_pickle(self):
y = pickle.loads(state)
self.assertEqualArray(y, x)

# check if it throws an error when dtype is not supported (bfloat16)
x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=mx.bfloat16)
with self.assertRaises(TypeError):
pickle.dumps(x)

def test_array_copy(self):
dtypes = [
mx.int8,
Expand Down
Loading