Fix int32 overflow in Metal conv_general output offset for large tensors#3294
Closed
agarwalprakhar2511 wants to merge 1 commit intoml-explore:mainfrom
Closed
Fix int32 overflow in Metal conv_general output offset for large tensors#3294agarwalprakhar2511 wants to merge 1 commit intoml-explore:mainfrom
agarwalprakhar2511 wants to merge 1 commit intoml-explore:mainfrom
Conversation
The implicit_gemm_conv_2d_general Metal kernel computed output write offsets using 32-bit int arithmetic. When batch * H_out * W_out * C_out exceeded 2^31, the offset silently wrapped and corrupted tail batch entries (zeroed out or wrote to wrong locations). Promote offset_cm and offset to size_t in the output store loop, matching the already-64-bit out_strides in MLXConvParams. Add a regression test that crosses the 2^31 boundary with varying per-batch input values so wrong-batch writes are also caught. Fixes ml-explore#3248
ea7ef82 to
3f651f3
Compare
Collaborator
|
The test makes no sense to me. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Fixes #3248.
Summary
implicit_gemm_conv_2d_generalkernel from 32-bitinttosize_t, preventing silent output corruption whenbatch * H_out * W_out * C_outexceeds 2^31out_stridesinMLXConvParamswere being multiplied by 32-bitintindices and stored back into anint, causing the address to wrap negative for large tensorsTest plan
python/tests/test_conv.py TestConv.test_conv_general_large_output_offset— passespython/tests/test_conv.pysuite — 19 tests, all pass (9 skipped = no Torch)(30842, 64, 64, 17)output —first_diff = 0.0,last_diff = 0.0Made with Cursor