[BUG] mx.conv_general produces wrong results when total output elements exceed 2^31 (~2.15 billion). #3248

@dmunch

Description

Describe the bug

mx.conv_general produces wrong results when total
output elements exceed 2^31 (~2.15 billion).

Likely cause: 32-bit integer index arithmetic in the Metal conv kernel
overflows, causing tail batch entries to be silently corrupted (zeroed
out, or filled with unprocessed input data).
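To make the suspected failure mode concrete, here is a small pure-Python model of what happens when a flat output offset is computed in signed 32-bit arithmetic. It only illustrates the hypothesis above. `wrap_int32` and `flat_index_nhwc` are hypothetical helpers, not MLX's actual kernel code.

```python
MAX_SIGNED_32 = 2**31 - 1  # largest value a signed 32-bit integer can hold


def wrap_int32(v: int) -> int:
    """Value a signed 32-bit integer would hold after two's-complement wraparound."""
    return (v + 2**31) % 2**32 - 2**31


def flat_index_nhwc(n: int, h: int, w: int, c: int, H: int, W: int, C: int) -> int:
    """Row-major flat offset of element (n, h, w, c) in an [N, H, W, C] tensor."""
    return ((n * H + h) * W + w) * C + c


# Output shape from Test 1 at N=75: [75, 320, 176, 512].
N, H, W, C = 75, 320, 176, 512
last = flat_index_nhwc(N - 1, H - 1, W - 1, C - 1, H, W, C)

print(last)                  # → 2162687999
print(last > MAX_SIGNED_32)  # → True
print(wrap_int32(last))      # → -2132279297
```

The report does not say how the kernel computes its offsets. The threshold sitting at 2^31 rather than 2^32 is consistent with signed 32-bit arithmetic, if the diagnosis is right.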

The bug affects any conv_general call where
batch * height_out * width_out * channels_out > 2^31,
regardless of kernel size (an output of exactly 2^31 elements still
works). The 1x1 kernel below is the simplest trigger, but 3x3 and
other kernels are equally affected once the output tensor is large enough.
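The condition above can be checked before calling conv_general. The sketch below computes the output spatial size with the standard convolution length formula, flags shapes whose output exceeds 2^31 elements, and splits the batch into chunks that each stay at or below that limit. All helper names are hypothetical. The idea that chunking the batch avoids the bug is an assumption based only on the threshold reported in this issue and has not been verified against MLX.

```python
LIMIT = 2**31  # threshold reported in this issue: more output elements than this fail


def conv_out_len(size: int, kernel: int, stride: int = 1,
                 padding: int = 0, dilation: int = 1) -> int:
    """Output length along one spatial axis (standard convolution formula)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def exceeds_limit(n: int, h_out: int, w_out: int, c_out: int) -> bool:
    """True if the conv output would have more than LIMIT elements."""
    return n * h_out * w_out * c_out > LIMIT


def max_safe_batch(h_out: int, w_out: int, c_out: int) -> int:
    """Largest batch whose output stays at or below LIMIT elements."""
    return LIMIT // (h_out * w_out * c_out)


def batch_chunks(n: int, h_out: int, w_out: int, c_out: int) -> list[tuple[int, int]]:
    """(start, stop) ranges along the batch axis, each within LIMIT output elements."""
    step = max_safe_batch(h_out, w_out, c_out)
    if step == 0:
        raise ValueError("a single batch entry already exceeds the limit")
    return [(s, min(s + step, n)) for s in range(0, n, step)]


# Test 1 shapes: a 1x1 kernel leaves 320x176 unchanged; so does 3x3 with padding=1.
h_out, w_out = conv_out_len(320, 1), conv_out_len(176, 1)
print(h_out, w_out, conv_out_len(320, 3, padding=1))  # → 320 176 320
print(exceeds_limit(74, h_out, w_out, 512))           # → False
print(exceeds_limit(75, h_out, w_out, 512))           # → True
print(max_safe_batch(h_out, w_out, 512) + 1)          # first failing N → 75
print(max_safe_batch(2048, 1, 1024) + 1)              # first failing N (Test 2) → 1025
print(batch_chunks(77, h_out, w_out, 512))            # → [(0, 74), (74, 77)]
```

Each slice x[start:stop] could then be passed to mx.conv_general separately and the pieces joined with mx.concatenate(..., axis=0). Whether this fully sidesteps the corruption depends on the per-call output size being the trigger, which is the working hypothesis of this report.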

Workaround: for 1x1 (pointwise) convolutions with stride 1, no padding,
and groups=1, replace conv_general with an equivalent matmul:
y = x @ w[:, 0, 0, :].T + bias
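The matmul workaround relies on a 1x1 convolution being a per-pixel linear map over channels. The NumPy check below compares a naive NHWC cross-correlation (unpadded, stride 1) against the matmul form, using the [C_out, KH, KW, C_in] weight layout that the reproduction script uses. `naive_conv_nhwc` is a hypothetical reference helper, not an MLX function.

```python
import numpy as np


def naive_conv_nhwc(x: np.ndarray, w: np.ndarray) -> np.ndarray:
    """Unpadded, stride-1 cross-correlation.

    x: [N, H, W, C_in]; w: [C_out, KH, KW, C_in]; returns [N, H_out, W_out, C_out].
    """
    n, h, wd, _ = x.shape
    c_out, kh, kw, _ = w.shape
    h_out, w_out = h - kh + 1, wd - kw + 1
    y = np.zeros((n, h_out, w_out, c_out), dtype=x.dtype)
    for i in range(kh):
        for j in range(kw):
            # Shifted input window times the (C_in -> C_out) weight slice at tap (i, j).
            y += x[:, i:i + h_out, j:j + w_out, :] @ w[:, i, j, :].T
    return y


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 5, 4, 8)).astype(np.float32)
w = rng.standard_normal((3, 1, 1, 8)).astype(np.float32)
bias = rng.standard_normal(3).astype(np.float32)

y_conv = naive_conv_nhwc(x, w) + bias
y_mat = x @ w[:, 0, 0, :].T + bias  # same matmul form as the workaround
print(np.allclose(y_conv, y_mat, atol=1e-5))  # → True
```

Indexing with w[:, 0, 0, :] (as the reproduction script does) keeps the weight 2-D even when C_out or C_in is 1, which w.squeeze() would not.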

To Reproduce

#!/usr/bin/env python3
"""
Minimal reproduction: mx.conv_general produces wrong results when total
output elements exceed 2^31 (~2.15 billion).

Likely cause: 32-bit integer index arithmetic in the Metal conv kernel
overflows, causing tail batch entries to be silently corrupted (zeroed
out, or filled with unprocessed input data).

Environment:
  - MLX 0.30.1
  - macOS 26.2, Apple M4 Max 128 GB
  - Also observed on M-series with earlier MLX versions

The bug affects any conv_general call where
  batch * height_out * width_out * channels_out > 2^31
regardless of kernel size.  The 1x1 kernel below is the simplest
trigger, but 3x3 and other kernels are equally affected once the
output tensor is large enough.

Workaround: for 1x1 (pointwise) convolutions, replace conv_general
with an equivalent matmul:  y = x @ w.squeeze().T + bias
"""

import mlx.core as mx
import numpy as np

INT32_MAX = 2**31  # 2,147,483,648: reported threshold (the largest int32 value is 2^31 - 1)


def run_test(N: int, H: int, W: int, C_in: int, C_out: int):
    """Run conv_general and matmul, compare last-batch-entry results."""
    total_out = N * H * W * C_out

    mx.random.seed(42)
    x = mx.random.normal((N, H, W, C_in))
    w = mx.random.normal((C_out, 1, 1, C_in)) * 0.01
    bias = mx.zeros((C_out,))
    mx.eval(x, w, bias)

    # ---- conv_general (buggy when total_out > 2^31) ----
    y_conv = mx.conv_general(x, w) + bias
    mx.eval(y_conv)

    # ---- matmul equivalent (always correct) ----
    w_mat = w[:, 0, 0, :]          # [C_out, C_in]
    y_matmul = x @ w_mat.T + bias  # [N, H, W, C_out]
    mx.eval(y_matmul)

    # Compare statistics of the last batch entry, where corruption appears.
    # Slicing in MLX first avoids copying the full multi-GB outputs to NumPy.
    conv_last = np.array(y_conv[-1])
    mat_last = np.array(y_matmul[-1])

    last_std_conv = conv_last.std()
    last_std_matmul = mat_last.std()

    # Max absolute difference in last batch entry
    max_diff = np.abs(conv_last - mat_last).max()

    over = total_out > INT32_MAX  # exactly 2^31 elements still passes (Test 2, N=1024)
    ok = max_diff < 1e-5

    tag = "✅ PASS" if ok else "❌ FAIL"
    print(
        f"  N={N:4d}  output_elems={total_out:>14,d}  "
        f"{'> 2^31' if over else '< 2^31'}  "
        f"conv_last_std={last_std_conv:.4f}  "
        f"matmul_last_std={last_std_matmul:.4f}  "
        f"max_diff={max_diff:.2e}  {tag}"
    )
    return ok


def main():
    print("=" * 90)
    print("mx.conv_general 2^31 output-element overflow reproduction")
    print("=" * 90)
    print(f"MLX version: {mx.__version__}")
    print(f"2^31 = {INT32_MAX:,}\n")

    # --- Test 1: sweep batch size across threshold ---
    # H=320, W=176, C_out=512 → per-batch output = 28,835,840
    # First failing N = floor(2^31 / 28,835,840) + 1 = 75
    print("Test 1: Sweep batch size (H=320, W=176, C_in=1024, C_out=512)")
    print("-" * 90)
    for N in [50, 74, 75, 76, 77]:
        run_test(N, 320, 176, 1024, 512)

    # --- Test 2: simpler shape, same overflow ---
    # H=2048, W=1, C_out=1024 → per-batch output = 2,097,152
    # First failing N = floor(2^31 / 2,097,152) + 1 = 1025 (N=1024 gives exactly 2^31)
    print("\nTest 2: Simpler shape (H=2048, W=1, C_in=512, C_out=1024)")
    print("-" * 90)
    for N in [1000, 1023, 1024, 1025]:
        run_test(N, 2048, 1, 512, 1024)

    # --- Test 3: large spatial, small batch ---
    # N=2, H=32768, W=1: C_out=32768 → output = 2 * 32768 * 32768 = 2,147,483,648 = 2^31
    # (exactly at the threshold, still passes); C_out=33000 is just above it (fails)
    print("\nTest 3: Large spatial, small batch (N=2, H=32768, W=1, C_in=64)")
    print("-" * 90)
    for C_out in [32000, 32768, 33000]:
        run_test(2, 32768, 1, 64, C_out)

    print("\n" + "=" * 90)
    print("Conclusion: conv_general silently corrupts output when")
    print("  batch * H_out * W_out * C_out > 2^31")
    print("Workaround: use matmul (x @ w.T + bias) for pointwise convolutions.")
    print("=" * 90)


if __name__ == "__main__":
    main()

Expected behavior

mx.conv_general should produce correct results even when the total number of output elements exceeds 2^31 (~2.15 billion).

Environment:

  • MLX 0.30.1
  • macOS 26.2, Apple M4 Max 128 GB
  • Also observed on M-series with earlier MLX versions

Additional context

Console output of the script above:

==========================================================================================
mx.conv_general 2^31 output-element overflow reproduction
==========================================================================================
MLX version: 0.30.1
2^31 = 2,147,483,648

Test 1: Sweep batch size (H=320, W=176, C_in=1024, C_out=512)
------------------------------------------------------------------------------------------
  N=  50  output_elems= 1,441,792,000  < 2^31  conv_last_std=0.3206  matmul_last_std=0.3206  max_diff=0.00e+00  ✅ PASS
  N=  74  output_elems= 2,133,852,160  < 2^31  conv_last_std=0.3205  matmul_last_std=0.3205  max_diff=0.00e+00  ✅ PASS
  N=  75  output_elems= 2,162,688,000  > 2^31  conv_last_std=0.2203  matmul_last_std=0.3205  max_diff=1.76e+00  ❌ FAIL
  N=  76  output_elems= 2,191,523,840  > 2^31  conv_last_std=0.0000  matmul_last_std=0.3206  max_diff=1.95e+00  ❌ FAIL
  N=  77  output_elems= 2,220,359,680  > 2^31  conv_last_std=0.0000  matmul_last_std=0.3204  max_diff=1.87e+00  ❌ FAIL

Test 2: Simpler shape (H=2048, W=1, C_in=512, C_out=1024)
------------------------------------------------------------------------------------------
  N=1000  output_elems= 2,097,152,000  < 2^31  conv_last_std=0.2267  matmul_last_std=0.2267  max_diff=0.00e+00  ✅ PASS
  N=1023  output_elems= 2,145,386,496  < 2^31  conv_last_std=0.2265  matmul_last_std=0.2265  max_diff=0.00e+00  ✅ PASS
  N=1024  output_elems= 2,147,483,648  < 2^31  conv_last_std=0.2265  matmul_last_std=0.2265  max_diff=0.00e+00  ✅ PASS
  N=1025  output_elems= 2,149,580,800  > 2^31  conv_last_std=0.0000  matmul_last_std=0.2266  max_diff=1.22e+00  ❌ FAIL

Test 3: Large spatial, small batch (N=2, H=32768, W=1, C_in=64)
------------------------------------------------------------------------------------------
  N=   2  output_elems= 2,097,152,000  < 2^31  conv_last_std=0.0801  matmul_last_std=0.0801  max_diff=0.00e+00  ✅ PASS
  N=   2  output_elems= 2,147,483,648  < 2^31  conv_last_std=0.0801  matmul_last_std=0.0801  max_diff=0.00e+00  ✅ PASS
  N=   2  output_elems= 2,162,688,000  > 2^31  conv_last_std=0.0791  matmul_last_std=0.0800  max_diff=4.78e-01  ❌ FAIL

==========================================================================================
Conclusion: conv_general silently corrupts output when
  batch * H_out * W_out * C_out > 2^31
Workaround: use matmul (x @ w.T + bias) for pointwise convolutions.
==========================================================================================
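The partial drop at N=75 (conv_last_std=0.2203 versus 0.3205) is itself consistent with the index-overflow hypothesis. If every output element at flat index 2^31 or above were zeroed, the last batch entry in Test 1 would keep only the part that lies below 2^31, and zeroing a fraction of a roughly zero-mean array scales its std by sqrt(kept fraction). The sketch below checks that arithmetic. `kept_fraction_of_last_entry` and `predicted_std` are hypothetical helpers, and the "tail is zeroed" model is an assumption, not a statement about what the kernel actually does.

```python
import math

LIMIT = 2**31  # assumed: output elements at flat index >= 2^31 are corrupted


def kept_fraction_of_last_entry(n: int, h_out: int, w_out: int, c_out: int) -> float:
    """Fraction of the last batch entry whose flat indices stay below LIMIT."""
    per_entry = h_out * w_out * c_out
    start = (n - 1) * per_entry  # flat index of the entry's first element
    end = start + per_entry      # one past its last element
    kept = max(0, min(end, LIMIT) - start)
    return kept / per_entry


def predicted_std(kept: float, clean_std: float) -> float:
    """Std of a ~zero-mean array after zeroing a (1 - kept) fraction of it."""
    return math.sqrt(kept) * clean_std


# Test 1 sweep, using the matmul std of ~0.3205 from the log as the clean value.
for n in (74, 75, 76, 77):
    f = kept_fraction_of_last_entry(n, 320, 176, 512)
    print(f"N={n}  kept={f:.4f}  predicted_conv_std={predicted_std(f, 0.3205):.4f}")
```

For N=75 the model keeps 26/55 of the last entry and predicts a std of about 0.2204, against 0.2203 observed. N=76 and N=77 (and Test 2 at N=1025) keep none of the last entry, matching the 0.0000 values in the log. For Test 3 with C_out=33000 the model predicts about 0.0794 against 0.0791 observed, so it is close but not exact there.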
