Skip to content

Fix int32 overflow in Metal conv_general output offset for large tensors#3294

Closed
agarwalprakhar2511 wants to merge 1 commit intoml-explore:mainfrom
agarwalprakhar2511:fix/conv-general-int32-overflow
Closed

Fix int32 overflow in Metal conv_general output offset for large tensors#3294
agarwalprakhar2511 wants to merge 1 commit intoml-explore:mainfrom
agarwalprakhar2511:fix/conv-general-int32-overflow

Conversation

@agarwalprakhar2511
Copy link
Copy Markdown

Fixes #3248.

Summary

  • Promote the output write offset calculation in the Metal implicit_gemm_conv_2d_general kernel from 32-bit int to size_t, preventing silent output corruption when batch * H_out * W_out * C_out exceeds 2^31
  • The already-64-bit out_strides in MLXConvParams were being multiplied by 32-bit int indices and stored back into an int, causing the address to wrap negative for large tensors
  • Add a Metal regression test that crosses the 2^31 output-offset boundary with varying per-batch input values so both zeroed-output and wrong-batch-write failures are caught

Test plan

  • python/tests/test_conv.py TestConv.test_conv_general_large_output_offset — passes
  • Full python/tests/test_conv.py suite — 19 tests, all pass (9 skipped = no Torch)
  • Ad hoc large-shape validation on (30842, 64, 64, 17) output — first_diff = 0.0, last_diff = 0.0

Made with Cursor

The implicit_gemm_conv_2d_general Metal kernel computed output write
offsets using 32-bit int arithmetic. When batch * H_out * W_out * C_out
exceeded 2^31, the offset silently wrapped and corrupted tail batch
entries (zeroed out or wrote to wrong locations).

Promote offset_cm and offset to size_t in the output store loop,
matching the already-64-bit out_strides in MLXConvParams.

Add a regression test that crosses the 2^31 boundary with varying
per-batch input values so wrong-batch writes are also caught.

Fixes ml-explore#3248
@agarwalprakhar2511 agarwalprakhar2511 force-pushed the fix/conv-general-int32-overflow branch from ea7ef82 to 3f651f3 Compare March 22, 2026 10:02
@zcbenz
Copy link
Copy Markdown
Collaborator

zcbenz commented Mar 22, 2026

The test makes no sense to me.

@zcbenz zcbenz closed this Mar 22, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[BUG] mx.conv_general produces wrong results when total output elements exceed 2^31 (~2.15 billion).

2 participants