diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 92e54f9991..cd0da1875b 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -345,6 +345,13 @@ bool AddMM::is_equivalent(const Primitive& other) const { return (alpha_ == a_other.alpha_ && beta_ == a_other.beta_); } +std::vector AddMM::output_shapes(const std::vector& inputs) { + // inputs are {a, b, c}, output shape is a's shape with last dim from b + auto out_shape = inputs[0].shape(); + out_shape.back() = inputs[1].shape(-1); + return {std::move(out_shape)}; +} + std::pair, std::vector> AddMM::vmap( const std::vector& inputs, const std::vector& axes) { diff --git a/mlx/primitives.h b/mlx/primitives.h index 4091aafcfb..2ce92b90f3 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -200,6 +200,7 @@ class AddMM : public UnaryPrimitive { DEFINE_NAME(AddMM) bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override; std::pair state() const { return {alpha_, beta_}; }; diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 7db471cc03..20f1145223 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -884,6 +884,34 @@ def test_shapeless_compile_matmul(self): fun = mx.compile(lambda a, b: a @ b, shapeless=True) self.assertTrue(mx.allclose(fun(a, b), a @ b)) + def test_shapeless_compile_addmm(self): + def fun(c, a, b): + return mx.addmm(c, a, b) + + cfun = mx.compile(fun, shapeless=True) + + # First shape + c = mx.ones((2, 4)) + a = mx.ones((2, 3)) + b = mx.ones((3, 4)) + self.assertTrue(mx.allclose(cfun(c, a, b), fun(c, a, b))) + + # Different shape, same ranks — should not recompile + c = mx.ones((3, 5)) + a = mx.ones((3, 6)) + b = mx.ones((6, 5)) + self.assertTrue(mx.allclose(cfun(c, a, b), fun(c, a, b))) + + # With alpha and beta + fun2 = mx.compile( + lambda c, a, b: mx.addmm(c, a, b, alpha=2.0, beta=3.0), shapeless=True + ) + c = mx.ones((2, 4)) + a = mx.ones((2, 3)) + b = mx.ones((3, 4)) + expected = 3.0 * c + 2.0 * (a @ b) + self.assertTrue(mx.allclose(fun2(c, a, b), expected)) + def test_shapeless_compile_slice_update(self): def fun(x): x[2] = mx.array([3.0])