Describe the bug
This library is packaged for nixpkgs (for both darwin and linux). Its test suite is executed as part of the package build. When it is executed on the Apple Silicon machines I have in my possession, the test suite passes fine. When it is executed on the Nix build farm (Hydra) machines, it often fails as follows:
pytest flags: -m pytest python/tests/ -k not\ \(test_siblings_without_eval\)
============================= test session starts ==============================
platform darwin -- Python 3.13.11, pytest-8.4.2, pluggy-1.6.0
rootdir: /nix/var/nix/builds/nix-13401-1802369029/source
configfile: pyproject.toml
collected 672 items / 1 deselected / 671 selected
python/tests/test_array.py .......................................s..... [ 6%]
....................... [ 10%]
python/tests/test_autograd.py ............................... [ 14%]
python/tests/test_bf16.py ...s.. [ 15%]
python/tests/test_blas.py ........................ [ 19%]
python/tests/test_compile.py ........................................... [ 25%]
......... [ 26%]
python/tests/test_constants.py ... [ 27%]
python/tests/test_conv.py .s........ssssssss [ 30%]
python/tests/test_conv_transpose.py sssssssss [ 31%]
python/tests/test_device.py ..s......s [ 32%]
python/tests/test_double.py .......... [ 34%]
python/tests/test_einsum.py .......... [ 35%]
python/tests/test_eval.py ...........s. [ 37%]
python/tests/test_export_import.py ....s........... [ 40%]
python/tests/test_fast.py ssssss.............. [ 43%]
python/tests/test_fast_sdpa.py .....s.......... [ 45%]
python/tests/test_fft.py ...ss........ [ 47%]
python/tests/test_graph.py . [ 47%]
python/tests/test_init.py .......... [ 49%]
python/tests/test_linalg.py ................ [ 51%]
python/tests/test_load.py ............ [ 53%]
python/tests/test_losses.py .............. [ 55%]
python/tests/test_memory.py ..s [ 55%]
python/tests/test_nn.py ................................................ [ 63%]
............. [ 64%]
python/tests/test_ops.py ............................................... [ 71%]
........................................................................ [ 82%]
............. [ 84%]
python/tests/test_optimizers.py ....s................... [ 88%]
python/tests/test_quantized.py ......F................... [ 92%]
python/tests/test_random.py .............. [ 94%]
python/tests/test_reduce.py .......... [ 95%]
python/tests/test_tree.py .... [ 96%]
python/tests/test_upsample.py s [ 96%]
python/tests/test_vmap.py ........................ [100%]
=================================== FAILURES ===================================
_____________________ TestQuantized.test_gather_qmm_sorted _____________________
self = <test_quantized.TestQuantized testMethod=test_gather_qmm_sorted>
def test_gather_qmm_sorted(self):
def quantize(w, transpose=True, group_size=None, mode="affine"):
if mode == "affine":
qw, s, b = mx.quantize(w, group_size=group_size, mode=mode)
else:
qw, s = mx.quantize(w, mode=mode)
b = None
w_hat = mx.dequantize(qw, s, b, group_size=group_size, mode=mode)
if transpose:
w_hat = w_hat.swapaxes(-1, -2)
return w_hat, qw, s, b
def gather_sort(x, indices):
N, M = indices.shape
indices = indices.flatten()
order = mx.argsort(indices)
inv_order = mx.argsort(order)
return x.flatten(0, -3)[order // M], indices[order], inv_order
def scatter_unsort(x, inv_order, shape=None):
x = x[inv_order]
if shape is not None:
x = mx.unflatten(x, 0, shape)
return x
parameters = [
# L, K, D, E, I, transpose
(32, 512, 512, 4, 2, True, "affine"),
(32, 512, 544, 4, 2, True, "mxfp4"),
(32, 512, 544, 4, 2, True, "nvfp4"),
(32, 512, 544, 4, 2, True, "mxfp8"),
(133, 512, 512, 4, 2, True, "affine"),
(133, 512, 555, 4, 2, True, "affine"),
(133, 512, 512, 4, 2, True, "affine"),
(64, 512, 512, 4, 2, False, "affine"),
(64, 512, 544, 4, 2, False, "mxfp4"),
(64, 512, 544, 4, 2, False, "nvfp4"),
(64, 512, 544, 4, 2, False, "mxfp8"),
(133, 512, 512, 4, 2, False, "affine"),
(133, 512, 544, 4, 2, False, "affine"),
(133, 512, 555, 4, 2, False, "affine"),
(64, 512, 512, 4, 2, False, "affine"),
]
key = mx.random.key(0)
k1, k2, k3 = mx.random.split(key, 3)
dtype = mx.float16 if (mx.default_device() == mx.gpu) else mx.float32
for L, K, D, E, I, transpose, mode in parameters:
with self.subTest(L=L, K=K, D=D, E=E, I=I, transpose=transpose, mode=mode):
if mode != "affine":
group_size = None
dtype = (
mx.bfloat16 if (mx.default_device() == mx.gpu) else mx.float32
)
else:
group_size = 64
dtype = (
mx.float16 if (mx.default_device() == mx.gpu) else mx.float32
)
K, D = (K, D) if transpose else (D, K)
ishape = (L, I)
xshape = (L, 1, 1, K)
wshape = (E, D, K) if transpose else (E, K, D)
indices = (mx.random.uniform(shape=ishape, key=k1) * E).astype(
mx.uint32
)
x = mx.random.normal(xshape, key=k2) / K**0.5
w = mx.random.normal(wshape, key=k3) / K**0.5
x = x.astype(dtype)
w = w.astype(dtype)
w, *wq = quantize(
w, group_size=group_size, mode=mode, transpose=transpose
)
y1 = mx.gather_mm(x, w, rhs_indices=indices)
y2 = mx.gather_qmm(
x,
*wq,
group_size=group_size,
mode=mode,
transpose=transpose,
rhs_indices=indices
)
xs, idx, inv_order = gather_sort(x, indices)
y3 = mx.gather_mm(xs, w, rhs_indices=idx, sorted_indices=True)
y4 = mx.gather_qmm(
xs,
*wq,
group_size=group_size,
mode=mode,
rhs_indices=idx,
transpose=transpose,
sorted_indices=True
)
y3 = scatter_unsort(y3, inv_order, indices.shape)
y4 = scatter_unsort(y4, inv_order, indices.shape)
tol = 1.5e-5 if (dtype == mx.float32) else 2.5e-4
self.assertLess((y1 - y2).abs().max(), tol)
> self.assertLess((y1 - y3).abs().max(), tol)
E AssertionError: array(nan, dtype=float32) not less than 1.5e-05
python/tests/test_quantized.py:980: AssertionError
=========================== short test summary info ============================
FAILED python/tests/test_quantized.py::TestQuantized::test_gather_qmm_sorted - AssertionError: array(nan, dtype=float32) not less than 1.5e-05
===== 1 failed, 634 passed, 36 skipped, 1 deselected in 164.14s (0:02:44) ======
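The reported value `array(nan, dtype=float32)` means at least one element of `y1 - y3` was NaN: NaN propagates through the subtraction, `abs()`, and `max()`, and any comparison against NaN is False, so `assertLess` fails. A quick NumPy analogue (NumPy used here as a stand-in for mlx arrays) shows the mechanism:

```python
import numpy as np

# Mimic assertLess((y1 - y3).abs().max(), tol) when y3 contains a NaN.
y1 = np.zeros(4, dtype=np.float32)
y3 = np.array([0.0, np.nan, 0.0, 0.0], dtype=np.float32)

diff = np.abs(y1 - y3).max()   # NaN propagates through abs() and max()
assert np.isnan(diff)          # this is the reported array(nan, dtype=float32)
assert not (diff < 1.5e-5)     # any comparison with NaN is False -> test fails
```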
The test doesn't fail on every run of the suite, but it fails most of the time.
The execution environment may be somewhat different, but regardless of the machine, the package is always built without Metal support.
All machines - both Hydra and my private machines - have Apple Silicon. The Hydra machines may have older M-series CPUs, but I also tried to reproduce the failure on an M1 Air laptop I have, with no luck.
Sadly, we do not have direct access to these machines to run a debugging session, which makes the issue hard to debug.
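For anyone reading the traceback, the failing path depends on the sort/unsort permutation trick in `gather_sort` / `scatter_unsort`: `argsort` of an `argsort` yields the inverse permutation, so rows can be grouped by index and later restored. A minimal NumPy illustration of that identity (NumPy standing in for `mx.argsort`):

```python
import numpy as np

# For a permutation `order` produced by argsort, argsort(order) is its
# inverse: indexing by `order` then by `inv_order` restores the original
# layout, which is what gather_sort/scatter_unsort rely on.
indices = np.array([3, 1, 2, 0, 1, 3], dtype=np.uint32)
order = np.argsort(indices, kind="stable")
inv_order = np.argsort(order, kind="stable")

x = np.arange(6) * 10            # stand-in for the rows being gathered
sorted_x = x[order]              # "gather_sort": rows grouped by index
restored = sorted_x[inv_order]   # "scatter_unsort": original order back

assert (restored == x).all()
assert (indices[order][inv_order] == indices).all()
```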
To Reproduce
As noted above, we are unfortunately only able to reproduce it in a build environment that is not conducive to debugging or reproduction attempts.
Expected behavior
Tests don't fail.
Desktop (please complete the following information):
- OS Version: I believe all macOS Hydra machines were upgraded to 26.x.
- Version: 0.30.5, but it has been happening for as long as I can remember.
Additional context
It may be worth noting that the other tests in the test class always pass. Also, we skipped another test on aarch64-darwin because of a memory-allocator-related failure, for which I have a patch that adjusts the test case. I wonder whether that slight difference in allocator behavior could point to a memory leak and be relevant here as well.
We have means to disable tests in nixpkgs package builds, but I am hesitant to sweep this under the rug, and I wonder whether mlx core team members could give an idea of how to fix or debug it. If not a fix, then perhaps the test case or suite could be improved to collect more information about this failure, so that the next time we hit it we have a better idea of what happened.
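As one hypothetical example of the kind of diagnostic that would help here: instead of asserting on a single reduced scalar, the test could first report how many entries are NaN and where they sit, so a one-shot CI log carries enough context to act on. A sketch using NumPy (the helper name and output format are my invention, not existing mlx or test-suite API):

```python
import numpy as np

def describe_nans(name, arr):
    """Summarize NaN locations in an array for a CI failure message."""
    arr = np.asarray(arr)
    mask = np.isnan(arr)
    if not mask.any():
        return f"{name}: no NaNs"
    rows = sorted(set(np.argwhere(mask)[:, 0].tolist()))
    return f"{name}: {int(mask.sum())}/{arr.size} NaN entries, rows {rows}"

# Example: a difference array with NaNs in rows 0 and 2.
y = np.array([[0.1, np.nan], [0.2, 0.3], [np.nan, np.nan]])
print(describe_nans("y1 - y3", y))  # -> y1 - y3: 3/6 NaN entries, rows [0, 2]
```

Attaching such a summary to the `assertLess` message (e.g. via its `msg` argument) would tell us whether the NaNs are isolated or span whole rows, which in turn hints at whether a specific expert's weights or a whole gather path is corrupted.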