Description
Describe the bug
mx.conv_general produces wrong results when total
output elements exceed 2^31 (~2.15 billion).
Likely cause: 32-bit integer index arithmetic in the Metal conv kernel
overflows, causing tail batch entries to be silently corrupted (zeroed
out, or filled with unprocessed input data).
The bug affects any conv_general call where
batch * height_out * width_out * channels_out > 2^31
regardless of kernel size. The 1x1 kernel below is the simplest
trigger, but 3x3 and other kernels are equally affected once the
output tensor is large enough.
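To make the suspected mechanism concrete, here is a minimal sketch in plain Python. It assumes, as this report guesses but has not confirmed from the kernel source, that the flat output index is held in a signed 32-bit integer. The helper names `wrap_int32` and `last_flat_index` are hypothetical and are not part of MLX.

```python
def wrap_int32(value: int) -> int:
    """Reinterpret an integer as a two's-complement signed 32-bit value."""
    return (value + 2**31) % 2**32 - 2**31


def last_flat_index(n: int, h_out: int, w_out: int, c_out: int) -> int:
    """Flat index of the final element of a contiguous (N, H, W, C) output."""
    return n * h_out * w_out * c_out - 1


# Shapes taken from the reproduction script: H=320, W=176, C_out=512.
for n in (74, 75):
    idx = last_flat_index(n, 320, 176, 512)
    # When the index fits in int32 the two printed values agree;
    # otherwise the wrapped value becomes negative.
    print(n, idx, wrap_int32(idx))
```

Under this assumption, an output of exactly 2^31 elements has a largest index of 2^31 - 1, which is still representable. That would be consistent with the reported pass at exactly 2^31 elements and the failure just above it.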
Workaround: for 1x1 (pointwise) convolutions, replace conv_general
with an equivalent matmul: y = x @ w.squeeze().T + bias
To Reproduce
#!/usr/bin/env python3
"""
Minimal reproduction: mx.conv_general produces wrong results when total
output elements exceed 2^31 (~2.15 billion).
Likely cause: 32-bit integer index arithmetic in the Metal conv kernel
overflows, causing tail batch entries to be silently corrupted (zeroed
out, or filled with unprocessed input data).
Environment:
- MLX 0.30.1
- macOS 26.2, Apple M4 Max 128 GB
- Also observed on M-series with earlier MLX versions
The bug affects any conv_general call where
batch * height_out * width_out * channels_out > 2^31
regardless of kernel size. The 1x1 kernel below is the simplest
trigger, but 3x3 and other kernels are equally affected once the
output tensor is large enough.
Workaround: for 1x1 (pointwise) convolutions, replace conv_general
with an equivalent matmul: y = x @ w.squeeze().T + bias
"""
import mlx.core as mx
import numpy as np
INT32_MAX = 2**31  # 2,147,483,648: the observed threshold (the true int32 maximum is 2**31 - 1)


def run_test(N: int, H: int, W: int, C_in: int, C_out: int):
    """Run conv_general and matmul, compare last-batch-entry results."""
    total_out = N * H * W * C_out
    mx.random.seed(42)
    x = mx.random.normal((N, H, W, C_in))
    w = mx.random.normal((C_out, 1, 1, C_in)) * 0.01
    bias = mx.zeros((C_out,))
    mx.eval(x, w, bias)

    # ---- conv_general (buggy when total_out > 2^31) ----
    y_conv = mx.conv_general(x, w) + bias
    mx.eval(y_conv)

    # ---- matmul equivalent (always correct) ----
    w_mat = w[:, 0, 0, :]  # [C_out, C_in]
    y_matmul = x @ w_mat.T + bias  # [N, H, W, C_out]
    mx.eval(y_matmul)

    # Compare per-batch-entry statistics
    conv_arr = np.array(y_conv)
    mat_arr = np.array(y_matmul)
    first_std_conv = conv_arr[0].std()
    last_std_conv = conv_arr[-1].std()
    last_std_matmul = mat_arr[-1].std()

    # Max absolute difference in last batch entry
    max_diff = np.abs(conv_arr[-1] - mat_arr[-1]).max()
    over = total_out > INT32_MAX  # exactly 2^31 is fine; 2^31+1 fails
    ok = max_diff < 1e-5
    tag = "✅ PASS" if ok else "❌ FAIL"
    print(
        f" N={N:4d} output_elems={total_out:>14,d} "
        f"{'> 2^31' if over else '< 2^31'} "
        f"conv_last_std={last_std_conv:.4f} "
        f"matmul_last_std={last_std_matmul:.4f} "
        f"max_diff={max_diff:.2e} {tag}"
    )
    return ok


def main():
    print("=" * 90)
    print("mx.conv_general 2^31 output-element overflow reproduction")
    print("=" * 90)
    print(f"MLX version: {mx.__version__}")
    print(f"2^31 = {INT32_MAX:,}\n")

    # --- Test 1: sweep batch size across threshold ---
    # H=320, W=176, C_out=512 → per-batch output = 28,835,840
    # First N above the threshold = ceil(2^31 / 28,835,840) = 75
    print("Test 1: Sweep batch size (H=320, W=176, C_in=1024, C_out=512)")
    print("-" * 90)
    for N in [50, 74, 75, 76, 77]:
        run_test(N, 320, 176, 1024, 512)

    # --- Test 2: simpler shape, same overflow ---
    # H=2048, W=1, C_out=1024 → per-batch output = 2,097,152
    # 2^31 / 2,097,152 = 1024 exactly: N=1024 hits 2^31, N=1025 is the first to exceed it
    print("\nTest 2: Simpler shape (H=2048, W=1, C_in=512, C_out=1024)")
    print("-" * 90)
    for N in [1000, 1023, 1024, 1025]:
        run_test(N, 2048, 1, 512, 1024)

    # --- Test 3: large spatial, small batch ---
    # N=2, H=32768, W=1, C_out=32768 → output = 2 * 32768 * 32768 = 2,147,483,648 = 2^31
    # C_out=32768 sits exactly at the threshold; C_out=33000 is above it
    print("\nTest 3: Large spatial, small batch (N=2, H=32768, W=1, C_in=64)")
    print("-" * 90)
    for C_out in [32000, 32768, 33000]:
        run_test(2, 32768, 1, 64, C_out)

    print("\n" + "=" * 90)
    print("Conclusion: conv_general silently corrupts output when")
    print(" batch * H_out * W_out * C_out > 2^31")
    print("Workaround: use matmul (x @ w.T + bias) for pointwise convolutions.")
    print("=" * 90)


if __name__ == "__main__":
    main()

Expected behavior
mx.conv_general should produce correct results even when total output elements exceed 2^31 (~2.15 billion).
Environment:
- MLX 0.30.1
- macOS 26.2, Apple M4 Max 128 GB
- Also observed on M-series with earlier MLX versions
Additional context
Console output of above script:
==========================================================================================
mx.conv_general 2^31 output-element overflow reproduction
==========================================================================================
MLX version: 0.30.1
2^31 = 2,147,483,648
Test 1: Sweep batch size (H=320, W=176, C_in=1024, C_out=512)
------------------------------------------------------------------------------------------
 N= 50 output_elems= 1,441,792,000 < 2^31 conv_last_std=0.3206 matmul_last_std=0.3206 max_diff=0.00e+00 ✅ PASS
N= 74 output_elems= 2,133,852,160 < 2^31 conv_last_std=0.3205 matmul_last_std=0.3205 max_diff=0.00e+00 ✅ PASS
N= 75 output_elems= 2,162,688,000 > 2^31 conv_last_std=0.2203 matmul_last_std=0.3205 max_diff=1.76e+00 ❌ FAIL
N= 76 output_elems= 2,191,523,840 > 2^31 conv_last_std=0.0000 matmul_last_std=0.3206 max_diff=1.95e+00 ❌ FAIL
N= 77 output_elems= 2,220,359,680 > 2^31 conv_last_std=0.0000 matmul_last_std=0.3204 max_diff=1.87e+00 ❌ FAIL
Test 2: Simpler shape (H=2048, W=1, C_in=512, C_out=1024)
------------------------------------------------------------------------------------------
N=1000 output_elems= 2,097,152,000 < 2^31 conv_last_std=0.2267 matmul_last_std=0.2267 max_diff=0.00e+00 ✅ PASS
N=1023 output_elems= 2,145,386,496 < 2^31 conv_last_std=0.2265 matmul_last_std=0.2265 max_diff=0.00e+00 ✅ PASS
N=1024 output_elems= 2,147,483,648 < 2^31 conv_last_std=0.2265 matmul_last_std=0.2265 max_diff=0.00e+00 ✅ PASS
N=1025 output_elems= 2,149,580,800 > 2^31 conv_last_std=0.0000 matmul_last_std=0.2266 max_diff=1.22e+00 ❌ FAIL
Test 3: Large spatial, small batch (N=2, H=32768, W=1, C_in=64)
------------------------------------------------------------------------------------------
N= 2 output_elems= 2,097,152,000 < 2^31 conv_last_std=0.0801 matmul_last_std=0.0801 max_diff=0.00e+00 ✅ PASS
N= 2 output_elems= 2,147,483,648 < 2^31 conv_last_std=0.0801 matmul_last_std=0.0801 max_diff=0.00e+00 ✅ PASS
N= 2 output_elems= 2,162,688,000 > 2^31 conv_last_std=0.0791 matmul_last_std=0.0800 max_diff=4.78e-01 ❌ FAIL
==========================================================================================
Conclusion: conv_general silently corrupts output when
batch * H_out * W_out * C_out > 2^31
Workaround: use matmul (x @ w.T + bias) for pointwise convolutions.
==========================================================================================
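
The matmul workaround covers only pointwise kernels. For other kernel sizes, one option that follows from the reported threshold is to split the batch so that each call stays at or below 2^31 output elements. This has not been tested here, and it assumes each conv_general call is correct as long as its own output is within that limit. The sketch below only computes the chunk boundaries in plain Python; `max_safe_batch` and `batch_chunks` are hypothetical helper names, not MLX APIs.

```python
LIMIT = 2**31  # largest output element count reported to work in this issue


def max_safe_batch(h_out: int, w_out: int, c_out: int, limit: int = LIMIT) -> int:
    """Largest batch size whose output has at most `limit` elements."""
    per_batch = h_out * w_out * c_out
    if per_batch > limit:
        raise ValueError("a single batch entry already exceeds the limit")
    return limit // per_batch


def batch_chunks(n: int, h_out: int, w_out: int, c_out: int, limit: int = LIMIT):
    """Split range(n) into (start, stop) pairs that each stay within the limit."""
    step = max_safe_batch(h_out, w_out, c_out, limit)
    return [(start, min(start + step, n)) for start in range(0, n, step)]


# Shapes from Test 1 of the reproduction: H=320, W=176, C_out=512, N=77.
print(max_safe_batch(320, 176, 512))
print(batch_chunks(77, 320, 176, 512))
```

Each `(start, stop)` pair could then drive a call such as `mx.conv_general(x[start:stop], w)`, with the pieces joined by `mx.concatenate(..., axis=0)`. Whether the concatenated result is handled correctly at this size has not been verified.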