[BUG] test_gather_qmm_sorted consistently fails in some aarch64-darwin environments #3200

@booxter

Description

Describe the bug

This library is packaged for nixpkgs (for both darwin and linux), and its test suite is executed as part of the package build. When the suite runs on the Apple Silicon machines in my possession, it passes. When it runs on Nix build farm (Hydra) machines, it often fails as follows:

pytest flags: -m pytest python/tests/ -k not\ \(test_siblings_without_eval\)
============================= test session starts ==============================
platform darwin -- Python 3.13.11, pytest-8.4.2, pluggy-1.6.0
rootdir: /nix/var/nix/builds/nix-13401-1802369029/source
configfile: pyproject.toml
collected 672 items / 1 deselected / 671 selected                              

python/tests/test_array.py .......................................s..... [  6%]
.......................                                                  [ 10%]
python/tests/test_autograd.py ...............................            [ 14%]
python/tests/test_bf16.py ...s..                                         [ 15%]
python/tests/test_blas.py ........................                       [ 19%]
python/tests/test_compile.py ........................................... [ 25%]
.........                                                                [ 26%]
python/tests/test_constants.py ...                                       [ 27%]
python/tests/test_conv.py .s........ssssssss                             [ 30%]
python/tests/test_conv_transpose.py sssssssss                            [ 31%]
python/tests/test_device.py ..s......s                                   [ 32%]
python/tests/test_double.py ..........                                   [ 34%]
python/tests/test_einsum.py ..........                                   [ 35%]
python/tests/test_eval.py ...........s.                                  [ 37%]
python/tests/test_export_import.py ....s...........                      [ 40%]
python/tests/test_fast.py ssssss..............                           [ 43%]
python/tests/test_fast_sdpa.py .....s..........                          [ 45%]
python/tests/test_fft.py ...ss........                                   [ 47%]
python/tests/test_graph.py .                                             [ 47%]
python/tests/test_init.py ..........                                     [ 49%]
python/tests/test_linalg.py ................                             [ 51%]
python/tests/test_load.py ............                                   [ 53%]
python/tests/test_losses.py ..............                               [ 55%]
python/tests/test_memory.py ..s                                          [ 55%]
python/tests/test_nn.py ................................................ [ 63%]
.............                                                            [ 64%]
python/tests/test_ops.py ............................................... [ 71%]
........................................................................ [ 82%]
.............                                                            [ 84%]
python/tests/test_optimizers.py ....s...................                 [ 88%]
python/tests/test_quantized.py ......F...................                [ 92%]
python/tests/test_random.py ..............                               [ 94%]
python/tests/test_reduce.py ..........                                   [ 95%]
python/tests/test_tree.py ....                                           [ 96%]
python/tests/test_upsample.py s                                          [ 96%]
python/tests/test_vmap.py ........................                       [100%]

=================================== FAILURES ===================================
_____________________ TestQuantized.test_gather_qmm_sorted _____________________

self = <test_quantized.TestQuantized testMethod=test_gather_qmm_sorted>

    def test_gather_qmm_sorted(self):
        def quantize(w, transpose=True, group_size=None, mode="affine"):
            if mode == "affine":
                qw, s, b = mx.quantize(w, group_size=group_size, mode=mode)
            else:
                qw, s = mx.quantize(w, mode=mode)
                b = None
    
            w_hat = mx.dequantize(qw, s, b, group_size=group_size, mode=mode)
            if transpose:
                w_hat = w_hat.swapaxes(-1, -2)
            return w_hat, qw, s, b
    
        def gather_sort(x, indices):
            N, M = indices.shape
            indices = indices.flatten()
            order = mx.argsort(indices)
            inv_order = mx.argsort(order)
            return x.flatten(0, -3)[order // M], indices[order], inv_order
    
        def scatter_unsort(x, inv_order, shape=None):
            x = x[inv_order]
            if shape is not None:
                x = mx.unflatten(x, 0, shape)
            return x
    
        parameters = [
            # L, K, D, E, I, transpose
            (32, 512, 512, 4, 2, True, "affine"),
            (32, 512, 544, 4, 2, True, "mxfp4"),
            (32, 512, 544, 4, 2, True, "nvfp4"),
            (32, 512, 544, 4, 2, True, "mxfp8"),
            (133, 512, 512, 4, 2, True, "affine"),
            (133, 512, 555, 4, 2, True, "affine"),
            (133, 512, 512, 4, 2, True, "affine"),
            (64, 512, 512, 4, 2, False, "affine"),
            (64, 512, 544, 4, 2, False, "mxfp4"),
            (64, 512, 544, 4, 2, False, "nvfp4"),
            (64, 512, 544, 4, 2, False, "mxfp8"),
            (133, 512, 512, 4, 2, False, "affine"),
            (133, 512, 544, 4, 2, False, "affine"),
            (133, 512, 555, 4, 2, False, "affine"),
            (64, 512, 512, 4, 2, False, "affine"),
        ]
    
        key = mx.random.key(0)
        k1, k2, k3 = mx.random.split(key, 3)
        dtype = mx.float16 if (mx.default_device() == mx.gpu) else mx.float32
    
        for L, K, D, E, I, transpose, mode in parameters:
            with self.subTest(L=L, K=K, D=D, E=E, I=I, transpose=transpose, mode=mode):
                if mode != "affine":
                    group_size = None
                    dtype = (
                        mx.bfloat16 if (mx.default_device() == mx.gpu) else mx.float32
                    )
                else:
                    group_size = 64
                    dtype = (
                        mx.float16 if (mx.default_device() == mx.gpu) else mx.float32
                    )
    
                K, D = (K, D) if transpose else (D, K)
                ishape = (L, I)
                xshape = (L, 1, 1, K)
                wshape = (E, D, K) if transpose else (E, K, D)
    
                indices = (mx.random.uniform(shape=ishape, key=k1) * E).astype(
                    mx.uint32
                )
                x = mx.random.normal(xshape, key=k2) / K**0.5
                w = mx.random.normal(wshape, key=k3) / K**0.5
    
                x = x.astype(dtype)
                w = w.astype(dtype)
    
                w, *wq = quantize(
                    w, group_size=group_size, mode=mode, transpose=transpose
                )
    
                y1 = mx.gather_mm(x, w, rhs_indices=indices)
                y2 = mx.gather_qmm(
                    x,
                    *wq,
                    group_size=group_size,
                    mode=mode,
                    transpose=transpose,
                    rhs_indices=indices
                )
                xs, idx, inv_order = gather_sort(x, indices)
                y3 = mx.gather_mm(xs, w, rhs_indices=idx, sorted_indices=True)
    
                y4 = mx.gather_qmm(
                    xs,
                    *wq,
                    group_size=group_size,
                    mode=mode,
                    rhs_indices=idx,
                    transpose=transpose,
                    sorted_indices=True
                )
                y3 = scatter_unsort(y3, inv_order, indices.shape)
                y4 = scatter_unsort(y4, inv_order, indices.shape)
    
                tol = 1.5e-5 if (dtype == mx.float32) else 2.5e-4
    
                self.assertLess((y1 - y2).abs().max(), tol)
>               self.assertLess((y1 - y3).abs().max(), tol)
E               AssertionError: array(nan, dtype=float32) not less than 1.5e-05

python/tests/test_quantized.py:980: AssertionError
=========================== short test summary info ============================
FAILED python/tests/test_quantized.py::TestQuantized::test_gather_qmm_sorted - AssertionError: array(nan, dtype=float32) not less than 1.5e-05
===== 1 failed, 634 passed, 36 skipped, 1 deselected in 164.14s (0:02:44) ======

The test doesn't fail every time the suite runs, but it does most of the time.

The execution environment may be somewhat different, but regardless of the machine, the package is always built without Metal support.

All machines, both Hydra's and my own, have Apple Silicon. The Hydra machines may have older M-series CPUs, but I also tried to reproduce the failure on an M1 Air laptop I have, with no luck.

Sadly, we do not have direct access to these machines to run a debugging session, which makes the issue hard to debug.

To Reproduce

As noted above, we are sadly only able to reproduce this in a build environment that is not conducive to debugging or reproduction attempts.
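In the meantime, the sort/unsort permutation logic that the test exercises can be checked in isolation. Below is a hypothetical NumPy re-implementation of the test's `gather_sort`/`scatter_unsort` helpers (a sketch for illustration, not the mlx API; the shapes mirror the test, where `x` is `(L, 1, K)` after `mx.flatten(0, -3)` and `indices` is `(L, I)`):

```python
# Hypothetical NumPy re-implementation of the gather_sort/scatter_unsort
# helpers from the failing test, to sanity-check the permutation logic
# independently of mlx.
import numpy as np

def gather_sort(x, indices):
    # Flatten the (L, I) expert ids, sort them so equal experts become
    # contiguous, and remember the inverse permutation for later.
    N, M = indices.shape
    flat = indices.flatten()
    order = np.argsort(flat, kind="stable")
    inv_order = np.argsort(order, kind="stable")
    # order // M maps each sorted slot back to its source row in x
    return x[order // M], flat[order], inv_order

def scatter_unsort(y, inv_order, shape=None):
    # Undo the permutation applied by gather_sort, then restore (L, I, ...)
    y = y[inv_order]
    if shape is not None:
        y = y.reshape(*shape, *y.shape[1:])
    return y

rng = np.random.default_rng(0)
L, I, K, E = 5, 2, 4, 3
x = rng.normal(size=(L, 1, K))
indices = rng.integers(0, E, size=(L, I)).astype(np.uint32)

xs, idx, inv = gather_sort(x, indices)
assert np.all(np.diff(idx.astype(np.int64)) >= 0)   # experts now contiguous
restored = scatter_unsort(xs, inv, indices.shape)
# each input row should come back I times, in the original order
assert np.allclose(restored, x[np.arange(L).repeat(I)].reshape(L, I, 1, K))
print("permutation round-trip OK")
```

Since this round-trip holds in pure NumPy, a NaN in `y3` would point at the `mx.gather_mm(..., sorted_indices=True)` kernel rather than at the test's index bookkeeping.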

Expected behavior
Tests don't fail.

Desktop (please complete the following information):

  • OS Version: I believe all macOS Hydra machines were upgraded to 26.x.
  • Version: 0.30.5, but the failure has been happening for as long as I can remember.

Additional context

It may be worth noting that the other tests in the same test class always pass. Also, we skip another test on aarch64-darwin because of a memory-allocator-related failure, for which I have a patch that adjusts the test case. I wonder whether this slight difference in allocator behavior could point to a memory issue that is relevant here.

We have the means to disable tests in nixpkgs package builds, but I am hesitant to sweep this under the rug, and I wonder whether mlx core team members could suggest how to fix or debug it. Failing a fix, perhaps the test case or suite could be improved to collect more information about this failure, so that the next time we hit it we have a better idea of what happened.
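As one possible starting point for collecting more information, a small helper along these lines (pure NumPy for illustration; mlx arrays would convert via `np.array`, and all names here are hypothetical) could be dropped into the test to report which output rows went NaN and which expert indices they were gathered from:

```python
# Hypothetical debugging helper: report which rows of an unsorted gather
# output contain NaNs, and which expert ids those rows map to.
import numpy as np

def report_nans(y, indices, label="gather output"):
    # y: output after scatter_unsort, shape (L, I, ...);
    # indices: the (L, I) expert ids used for the gather.
    y = np.asarray(y, dtype=np.float64)
    indices = np.asarray(indices)
    # collapse all trailing dims: a row is "bad" if any element is NaN
    bad = np.isnan(y).reshape(y.shape[0], y.shape[1], -1).any(axis=-1)
    if bad.any():
        rows, cols = np.nonzero(bad)
        experts = sorted(set(indices[rows, cols].tolist()))
        print(f"{label}: {int(bad.sum())} NaN row(s); experts involved: {experts}")
    return bad

# Example with a deliberately poisoned row:
y = np.zeros((4, 2, 1, 3))
y[1, 0, 0, 1] = np.nan
idx = np.array([[0, 1], [2, 1], [0, 0], [3, 2]], dtype=np.uint32)
bad = report_nans(y, idx)
assert bad[1, 0] and bad.sum() == 1
```

Printing this for `y3` right before the failing assertion would at least show whether the NaNs cluster on a particular expert or row range.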
