From 38248ab9d4e994a3ec63a828ffabe10092030035 Mon Sep 17 00:00:00 2001 From: Harsh Sutaria Date: Wed, 29 Oct 2025 15:24:40 -0400 Subject: [PATCH 1/3] Fix addmm with empty matrices and beta != 1.0 - Fix CPU backend to apply beta scaling for empty matrices (K=0) - Fix Metal backend to apply beta scaling for empty matrices (K=0) - Add comprehensive tests for empty matrix addmm with various beta values - Tests cover different shapes and batched empty matrices The bug occurred when addmm was called with empty matrices (one dimension is 0) and beta != 1.0. The function would return C instead of beta * C because the early return for empty matrices didn't apply beta scaling. Fixes #2698 --- mlx/backend/cpu/matmul.cpp | 55 ++++++++++++++++++++++++++++++++++-- mlx/backend/metal/matmul.cpp | 26 +++++++++++------ python/tests/test_blas.py | 45 +++++++++++++++++++++++++---- 3 files changed, 110 insertions(+), 16 deletions(-) diff --git a/mlx/backend/cpu/matmul.cpp b/mlx/backend/cpu/matmul.cpp index b346e84db8..a31473d4f1 100644 --- a/mlx/backend/cpu/matmul.cpp +++ b/mlx/backend/cpu/matmul.cpp @@ -135,15 +135,64 @@ void AddMM::eval_cpu(const std::vector& inputs, array& out) { return; } + // Handle empty matrix case (K=0) + if (inputs[0].shape(-1) == 0) { + auto& c = inputs[2]; + CopyType ctype = c.data_size() == 1 + ? CopyType::Scalar + : (c.flags().row_contiguous ? CopyType::Vector : CopyType::General); + copy_cpu(c, out, ctype, stream()); + if (beta_ != 1.0f) { + auto& encoder = cpu::get_command_encoder(stream()); + encoder.set_output_array(out); + encoder.dispatch([out_ptr = out.data(), + size = out.size(), + beta = beta_, + dtype = out.dtype()]() mutable { + switch (dtype) { + case float32: + for (size_t i = 0; i < size; ++i) { + static_cast(out_ptr)[i] *= beta; + } + break; + case float16: + for (size_t i = 0; i < size; ++i) { + auto& val = static_cast(out_ptr)[i]; + val = static_cast(static_cast(val) * beta); + } + break; + case bfloat16: + for (size_t i = 0; i < size; ++i) { + auto& val = static_cast(out_ptr)[i]; + val = static_cast(static_cast(val) * beta); + } + break; + case float64: + for (size_t i = 0; i < size; ++i) { + static_cast(out_ptr)[i] *= beta; + } + break; + case complex64: + for (size_t i = 0; i < size; ++i) { + auto& val = static_cast(out_ptr)[i]; + val = complex64_t(val.real() * beta, val.imag() * beta); + } + break; + default: + throw std::runtime_error( + "[AddMM::eval_cpu] Unsupported dtype for beta scaling"); + } + }); + } + return; + } + // Fill output with C auto& c = inputs[2]; CopyType ctype = c.data_size() == 1 ? CopyType::Scalar : (c.flags().row_contiguous ? CopyType::Vector : CopyType::General); copy_cpu(c, out, ctype, stream()); - if (inputs[0].shape(-1) == 0) { - return; - } matmul_general(inputs[0], inputs[1], out, stream(), alpha_, beta_); } diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index d6bee651d2..49a5b84cf7 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -8,6 +8,7 @@ #include "mlx/backend/common/broadcasting.h" #include "mlx/backend/common/matmul.h" #include "mlx/backend/gpu/copy.h" +#include "mlx/backend/metal/binary.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels/defines.h" @@ -925,19 +926,28 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { return; } - // Copy c into out and return + auto& s = stream(); + auto& d = metal::device(s.device); + + // Handle empty matrix case (K=0) if (inputs[0].shape(-1) == 0) { - copy_gpu( - inputs[2], - out, - inputs[2].flags().row_contiguous ? CopyType::Vector : CopyType::General, - stream()); + auto& c = inputs[2]; + if (beta_ == 1.0f) { + copy_gpu( + c, + out, + c.flags().row_contiguous ? CopyType::Vector : CopyType::General, + s); + } else { + out.set_data(allocator::malloc(out.nbytes())); + array beta_scalar = array(beta_, c.dtype()); + binary_op_gpu({c, beta_scalar}, out, "Multiply", s); + d.add_temporary(std::move(beta_scalar), s.index); + } return; } out.set_data(allocator::malloc(out.nbytes())); - auto& s = stream(); - auto& d = metal::device(s.device); auto& a_pre = inputs[0]; auto& b_pre = inputs[1]; diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py index 0cc6e621a1..8e97e1f49d 100644 --- a/python/tests/test_blas.py +++ b/python/tests/test_blas.py @@ -785,11 +785,46 @@ def test_empty_matmul(self): self.assertEqual(out.item(), 1.0) self.assertEqual(out.shape, ()) - a = mx.zeros(shape=(5, 0)) - b = mx.zeros(shape=(0, 5)) - c = mx.random.uniform(shape=(5, 5)) - out = mx.addmm(c, a, b) - self.assertTrue(mx.allclose(out, c)) + a = mx.ones((2, 0)) + b = mx.ones((0, 2)) + c = mx.ones((2, 2)) + + test_cases = [ + (0.0, 1.0), + (0.0, 2.0), + (0.0, 0.5), + (0.0, 0.0), + (1.0, 2.0), + ] + + for alpha, beta in test_cases: + with self.subTest(alpha=alpha, beta=beta): + result = mx.addmm(c, a, b, alpha=alpha, beta=beta) + expected = c * beta # a @ b = 0 for empty matrices + self.assertTrue(mx.allclose(result, expected)) + + shapes_tests = [ + ((3, 0), (0, 3), (3, 3)), + ((5, 0), (0, 5), (5, 5)), + ((1, 0), (0, 10), (1, 10)), + ((10, 0), (0, 1), (10, 1)), + ] + + for shape_a, shape_b, shape_c in shapes_tests: + with self.subTest(shape_a=shape_a, shape_b=shape_b, shape_c=shape_c): + a = mx.ones(shape_a) + b = mx.ones(shape_b) + c = mx.ones(shape_c) + result = mx.addmm(c, a, b, alpha=0.5, beta=2.0) + expected = c * 2.0 + self.assertTrue(mx.allclose(result, expected)) + + a = mx.ones((2, 5, 0)) + b = mx.ones((2, 0, 5)) + c = mx.ones((2, 5, 5)) + result = mx.addmm(c, a, b, alpha=0.0, beta=3.0) + expected = c * 3.0 + self.assertTrue(mx.allclose(result, expected)) def test_block_masked_matmul(self): def ref_block_masked_mm( From 3facaa9bd223fe1bc5c94f56993c586ea005539d Mon Sep 17 00:00:00 2001 From: Harsh Sutaria Date: Thu, 30 Oct 2025 18:35:00 -0400 Subject: [PATCH 2/3] Addressed PR review: use binary operation and remove unnecessary memory allocation --- mlx/backend/cpu/matmul.cpp | 57 ++++++++++++++++-------------------- mlx/backend/metal/matmul.cpp | 1 - 2 files changed, 26 insertions(+), 32 deletions(-) diff --git a/mlx/backend/cpu/matmul.cpp b/mlx/backend/cpu/matmul.cpp index a31473d4f1..c3d25994da 100644 --- a/mlx/backend/cpu/matmul.cpp +++ b/mlx/backend/cpu/matmul.cpp @@ -2,6 +2,8 @@ #include #include "mlx/array.h" +#include "mlx/backend/cpu/binary.h" +#include "mlx/backend/cpu/binary_ops.h" #include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/gemm.h" @@ -138,45 +140,38 @@ void AddMM::eval_cpu(const std::vector& inputs, array& out) { // Handle empty matrix case (K=0) if (inputs[0].shape(-1) == 0) { auto& c = inputs[2]; - CopyType ctype = c.data_size() == 1 - ? CopyType::Scalar - : (c.flags().row_contiguous ? CopyType::Vector : CopyType::General); - copy_cpu(c, out, ctype, stream()); - if (beta_ != 1.0f) { + if (beta_ == 1.0f) { + CopyType ctype = c.data_size() == 1 + ? CopyType::Scalar + : (c.flags().row_contiguous ? CopyType::Vector : CopyType::General); + copy_cpu(c, out, ctype, stream()); + } else { + array beta_scalar = array(beta_, c.dtype()); + auto bopt = get_binary_op_type(c, beta_scalar); + set_binary_op_output_data(c, beta_scalar, out, bopt); auto& encoder = cpu::get_command_encoder(stream()); + encoder.set_input_array(c); + encoder.set_input_array(beta_scalar); encoder.set_output_array(out); - encoder.dispatch([out_ptr = out.data(), - size = out.size(), - beta = beta_, - dtype = out.dtype()]() mutable { - switch (dtype) { - case float32: - for (size_t i = 0; i < size; ++i) { - static_cast(out_ptr)[i] *= beta; - } - break; + encoder.dispatch([c = array::unsafe_weak_copy(c), + beta_scalar = array::unsafe_weak_copy(beta_scalar), + out = array::unsafe_weak_copy(out), + bopt]() mutable { + switch (out.dtype()) { case float16: - for (size_t i = 0; i < size; ++i) { - auto& val = static_cast(out_ptr)[i]; - val = static_cast(static_cast(val) * beta); - } + binary_op(c, beta_scalar, out, bopt); break; - case bfloat16: - for (size_t i = 0; i < size; ++i) { - auto& val = static_cast(out_ptr)[i]; - val = static_cast(static_cast(val) * beta); - } + case float32: + binary_op(c, beta_scalar, out, bopt); break; case float64: - for (size_t i = 0; i < size; ++i) { - static_cast(out_ptr)[i] *= beta; - } + binary_op(c, beta_scalar, out, bopt); + break; + case bfloat16: + binary_op(c, beta_scalar, out, bopt); break; case complex64: - for (size_t i = 0; i < size; ++i) { - auto& val = static_cast(out_ptr)[i]; - val = complex64_t(val.real() * beta, val.imag() * beta); - } + binary_op(c, beta_scalar, out, bopt); break; default: throw std::runtime_error( diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 49a5b84cf7..abc45575a4 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -939,7 +939,6 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { c.flags().row_contiguous ? CopyType::Vector : CopyType::General, s); } else { - out.set_data(allocator::malloc(out.nbytes())); array beta_scalar = array(beta_, c.dtype()); binary_op_gpu({c, beta_scalar}, out, "Multiply", s); d.add_temporary(std::move(beta_scalar), s.index); From e4140bd9f101c5d809411aa7c3d3beec90870cab Mon Sep 17 00:00:00 2001 From: Harsh Sutaria Date: Mon, 3 Nov 2025 15:08:40 -0500 Subject: [PATCH 3/3] Fix CPU backend: add temporary to keep beta_scalar alive --- mlx/backend/cpu/matmul.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/mlx/backend/cpu/matmul.cpp b/mlx/backend/cpu/matmul.cpp index c3d25994da..0998c527cb 100644 --- a/mlx/backend/cpu/matmul.cpp +++ b/mlx/backend/cpu/matmul.cpp @@ -178,6 +178,7 @@ void AddMM::eval_cpu(const std::vector& inputs, array& out) { "[AddMM::eval_cpu] Unsupported dtype for beta scaling"); } }); + encoder.add_temporary(std::move(beta_scalar)); } return; }