From 839585f701022dce3bb8af6b541c815655431ab6 Mon Sep 17 00:00:00 2001 From: Dan Yeh Date: Sun, 2 Nov 2025 17:53:24 +0100 Subject: [PATCH 1/3] shapeless support for zeros/ones_like --- mlx/ops.cpp | 34 ++++++++++++++++++++++++++-------- mlx/ops.h | 11 +++++++++++ mlx/primitives.h | 1 + python/tests/test_compile.py | 13 ++++++++++++- 4 files changed, 50 insertions(+), 9 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 271462d56e..0de9f840ed 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -280,16 +280,20 @@ array copy(array a, StreamOrDevice s /* = {} */) { {std::move(a)}); } +array full_impl(array vals, Dtype dtype, StreamOrDevice s /* = {} */) { + return array( + vals.shape(), + dtype, + std::make_shared(to_stream(s)), + {astype(std::move(vals), dtype, s)}); +} + array full(Shape shape, array vals, Dtype dtype, StreamOrDevice s /* = {} */) { if (std::any_of(shape.begin(), shape.end(), [](auto i) { return i < 0; })) { throw std::invalid_argument("[full] Negative dimensions not allowed."); } - auto copied_shape = shape; // |shape| will be moved - return array( - std::move(copied_shape), - dtype, - std::make_shared(to_stream(s)), - {broadcast_to(astype(std::move(vals), dtype, s), std::move(shape), s)}); + return full_impl( + broadcast_to(std::move(vals), std::move(shape), s), dtype, s); } array full(Shape shape, array vals, StreamOrDevice s /* = {} */) { @@ -297,12 +301,26 @@ array full(Shape shape, array vals, StreamOrDevice s /* = {} */) { return full(std::move(shape), std::move(vals), dtype, to_stream(s)); } +array full_like( + const array& a, + array vals, + Dtype dtype, + StreamOrDevice s /* = {} */) { + auto inputs = broadcast_arrays({a, std::move(vals)}, s); + return full_impl(std::move(inputs[1]), dtype, s); +} + +array full_like(const array& a, array vals, StreamOrDevice s /* = {} */) { + auto dtype = vals.dtype(); + return full_like(a, std::move(vals), dtype, to_stream(s)); +} + array zeros(const Shape& shape, Dtype dtype, StreamOrDevice s /* = {} */) { return full(shape, array(0, dtype), to_stream(s)); } array zeros_like(const array& a, StreamOrDevice s /* = {} */) { - return zeros(a.shape(), a.dtype(), to_stream(s)); + return full_like(a, 0, a.dtype(), to_stream(s)); } array ones(const Shape& shape, Dtype dtype, StreamOrDevice s /* = {} */) { @@ -310,7 +328,7 @@ array ones(const Shape& shape, Dtype dtype, StreamOrDevice s /* = {} */) { } array ones_like(const array& a, StreamOrDevice s /* = {} */) { - return ones(a.shape(), a.dtype(), to_stream(s)); + return full_like(a, 1, a.dtype(), to_stream(s)); } array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s /* = {} */) { diff --git a/mlx/ops.h b/mlx/ops.h index 49c64e74f0..0ae94bc102 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -69,6 +69,17 @@ array full(Shape shape, T val, StreamOrDevice s = {}) { return full(std::move(shape), array(val), to_stream(s)); } +array full_like(const array& a, array vals, Dtype dtype, StreamOrDevice s = {}); +array full_like(const array& a, array vals, StreamOrDevice s = {}); +template +array full_like(const array& a, T val, Dtype dtype, StreamOrDevice s = {}) { + return full_like(a, array(val, dtype), dtype, to_stream(s)); +} +template +array full_like(const array& a, T val, StreamOrDevice s = {}) { + return full_like(a, array(val), to_stream(s)); +} + /** Fill an array of the given shape with zeros. */ array zeros(const Shape& shape, Dtype dtype, StreamOrDevice s = {}); inline array zeros(const Shape& shape, StreamOrDevice s = {}) { diff --git a/mlx/primitives.h b/mlx/primitives.h index a1ad2425c0..d3260f0f4c 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1146,6 +1146,7 @@ class Full : public UnaryPrimitive { DEFINE_GRADS() DEFINE_NAME(Full) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() }; class Gather : public UnaryPrimitive { diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 26132b6285..11256d1307 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -482,6 +482,18 @@ def fun(x): self.assertEqual(mx.compile(fun, shapeless=True)(x).shape, (1, 32)) + def test_shapeless_compile_full_like(self): + x = mx.zeros((1, 1, 32)) + + def zeros_fun(x): + return mx.zeros_like(x) + + def ones_fun(x): + return mx.ones_like(x) + + self.assertEqual(mx.compile(zeros_fun, shapeless=True)(x).shape, (1, 1, 32)) + self.assertEqual(mx.compile(ones_fun, shapeless=True)(x).shape, (1, 1, 32)) + def test_compile_with_constant(self): # Test float @partial(mx.compile) @@ -842,7 +854,6 @@ def fun(inputs): self.assertTrue(mx.allclose(out, expected)) def test_compile_many_outputs(self): - @mx.compile def fun(arr): arrs = [arr] * 64 From ebd669b79f0cf6f5211aa797ea1ee36e0611d2fe Mon Sep 17 00:00:00 2001 From: Dan Yeh Date: Tue, 4 Nov 2025 20:43:00 +0100 Subject: [PATCH 2/3] Improvements --- mlx/ops.cpp | 3 +-- mlx/ops.h | 2 +- python/tests/test_compile.py | 16 +++++++++++++--- tests/ops_tests.cpp | 26 ++++++++++++++++++++++++++ 4 files changed, 41 insertions(+), 6 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 0de9f840ed..c8c6635403 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -311,8 +311,7 @@ array full_like( } array full_like(const array& a, array vals, StreamOrDevice s /* = {} */) { - auto dtype = vals.dtype(); - return full_like(a, std::move(vals), dtype, to_stream(s)); + return full_like(a, std::move(vals), a.dtype(), to_stream(s)); } array zeros(const Shape& shape, Dtype dtype, StreamOrDevice s /* = {} */) { diff --git a/mlx/ops.h b/mlx/ops.h index 0ae94bc102..6c44c032c0 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -77,7 +77,7 @@ array full_like(const array& a, T val, Dtype dtype, StreamOrDevice s = {}) { } template array full_like(const array& a, T val, StreamOrDevice s = {}) { - return full_like(a, array(val), to_stream(s)); + return full_like(a, array(val, a.dtype()), to_stream(s)); } /** Fill an array of the given shape with zeros. */ diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 11256d1307..1ed5b78194 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -483,7 +483,8 @@ def fun(x): self.assertEqual(mx.compile(fun, shapeless=True)(x).shape, (1, 32)) def test_shapeless_compile_full_like(self): - x = mx.zeros((1, 1, 32)) + x_shape = (1, 1, 32) + x = mx.zeros((x_shape)) def zeros_fun(x): return mx.zeros_like(x) @@ -491,8 +492,17 @@ def zeros_fun(x): def ones_fun(x): return mx.ones_like(x) - self.assertEqual(mx.compile(zeros_fun, shapeless=True)(x).shape, (1, 1, 32)) - self.assertEqual(mx.compile(ones_fun, shapeless=True)(x).shape, (1, 1, 32)) + compiled_zero_like = mx.compile(zeros_fun, shapeless=True) + compiled_ones_like = mx.compile(ones_fun, shapeless=True) + + self.assertEqual(compiled_zero_like(x).shape, x_shape) + self.assertEqual(compiled_ones_like(x).shape, x_shape) + + y_shape = (2, 2, 16) + y = mx.zeros(y_shape) + + self.assertEqual(compiled_zero_like(y).shape, y_shape) + self.assertEqual(compiled_ones_like(y).shape, y_shape) def test_compile_with_constant(self): # Test float diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 1b95066228..38e64559b3 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -2826,6 +2826,32 @@ TEST_CASE("test stack") { stack({x, y}, 0), "All arrays must have the same shape and dtype"); } +TEST_CASE("test full_like") { + auto base_int = array({1, 2, 3}, {3}, int16); + + auto from_array_with_dtype = full_like(base_int, array(7.5f), float16); + auto expected_float16 = array({7.5, 7.5, 7.5}, {3}, float16); + CHECK_EQ(from_array_with_dtype.dtype(), float16); + CHECK(array_equal(from_array_with_dtype, expected_float16).item()); + + auto from_array_default_dtype = full_like(base_int, array(4.0f)); + auto expected_int16 = array({4, 4, 4}, {3}, int16); + CHECK_EQ(from_array_default_dtype.dtype(), int16); + CHECK(array_equal(from_array_default_dtype, expected_int16).item()); + + auto from_scalar_with_dtype = full_like(base_int, 3.25f, float32); + auto expected_float32 = array({3.25f, 3.25f, 3.25f}, {3}, float32); + CHECK_EQ(from_scalar_with_dtype.dtype(), float32); + CHECK(array_equal(from_scalar_with_dtype, expected_float32).item()); + + auto base_float = array({1.0f, 2.0f}, {2}, float32); + auto from_scalar_default_dtype = full_like(base_float, 2); + auto expected_base_float = array({2.0f, 2.0f}, {2}, float32); + CHECK_EQ(from_scalar_default_dtype.dtype(), float32); + CHECK( + array_equal(from_scalar_default_dtype, expected_base_float).item()); +} + TEST_CASE("test eye") { auto eye_3 = eye(3); CHECK_EQ(eye_3.shape(), Shape{3, 3}); From f4e34019d4ffc44ea02dbe2dda5232176b778d46 Mon Sep 17 00:00:00 2001 From: Dan Yeh Date: Thu, 6 Nov 2025 23:41:40 +0100 Subject: [PATCH 3/3] fix access after moved --- mlx/ops.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index c8c6635403..f58c73a6ce 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -285,15 +285,14 @@ array full_impl(array vals, Dtype dtype, StreamOrDevice s /* = {} */) { vals.shape(), dtype, std::make_shared(to_stream(s)), - {astype(std::move(vals), dtype, s)}); + {astype(vals, dtype, s)}); } array full(Shape shape, array vals, Dtype dtype, StreamOrDevice s /* = {} */) { if (std::any_of(shape.begin(), shape.end(), [](auto i) { return i < 0; })) { throw std::invalid_argument("[full] Negative dimensions not allowed."); } - return full_impl( - broadcast_to(std::move(vals), std::move(shape), s), dtype, s); + return full_impl(broadcast_to(vals, std::move(shape), s), dtype, s); } array full(Shape shape, array vals, StreamOrDevice s /* = {} */) {