diff --git a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h
index 1241f77357..83805d4d23 100644
--- a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h
+++ b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h
@@ -200,14 +200,19 @@ implicit_gemm_conv_2d_general(
           (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow;
 
       if (n < params->N && oh < params->oS[0] && ow < params->oS[1]) {
-        int offset_cm = n * params->out_strides[0] +
-            oh * params->out_strides[1] + ow * params->out_strides[2];
+        size_t offset_cm =
+            static_cast<size_t>(n) *
+            static_cast<size_t>(params->out_strides[0]) +
+            static_cast<size_t>(oh) *
+            static_cast<size_t>(params->out_strides[1]) +
+            static_cast<size_t>(ow) *
+            static_cast<size_t>(params->out_strides[2]);
 
         STEEL_PRAGMA_UNROLL
         for (int j = 0; j < mma_t::TN; j++) {
           // Get accumulated result and associated offset in C
           thread const auto& accum = mma_op.Ctile.frag_at(i, j);
-          int offset = offset_cm + (j * mma_t::TN_stride);
+          size_t offset = offset_cm + (j * mma_t::TN_stride);
 
           constexpr short kelems = decltype(mma_op.Ctile)::kElemsPerFrag;
 
diff --git a/python/tests/test_conv.py b/python/tests/test_conv.py
index 80b7d6b613..311f7c5c1d 100644
--- a/python/tests/test_conv.py
+++ b/python/tests/test_conv.py
@@ -1,6 +1,7 @@
 # Copyright © 2023-2024 Apple Inc.
 
 import math
+import os
 import unittest
 from itertools import permutations
 
@@ -1192,6 +1193,43 @@ def test_conv2d_large_filter_small_channels(self):
         y_hat = mx.conv2d(x, w, (1, 1), (1, 1))
         self.assertTrue(mx.allclose(y, y_hat, rtol=1e-3, atol=1e-3))
 
+    @unittest.skipIf(
+        os.getenv("LOW_MEMORY", None) is not None,
+        "This test requires a lot of memory",
+    )
+    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
+    def test_conv_general_large_output_offset(self):
+        H = W = 64
+        O = 17
+
+        per_batch_output = H * W * O
+        # +2 makes the last batch start beyond the signed 32-bit offset limit.
+        batch_size = (2**31) // per_batch_output + 2
+
+        # Vary the per-batch input so mis-addressed writes cannot still pass by
+        # aliasing another batch with identical values.
+        batch_values = (mx.arange(batch_size, dtype=mx.int32) % 251).astype(mx.float16)
+        batch_values = batch_values.reshape((batch_size, 1, 1, 1))
+        x = mx.ones((batch_size, H, W, 1), dtype=mx.float16) * batch_values
+
+        channel_values = mx.arange(1, O + 1, dtype=mx.float16) / 8
+        w = channel_values.reshape((O, 1, 1, 1))
+        channel_values = channel_values.reshape((1, 1, O))
+
+        try:
+            y = mx.conv_general(x, w, stream=mx.gpu)
+            self.assertTrue(
+                mx.allclose(y[0], x[0] * channel_values, atol=1e-3, rtol=1e-3)
+            )
+            self.assertTrue(
+                mx.allclose(y[-1], x[-1] * channel_values, atol=1e-3, rtol=1e-3)
+            )
+        finally:
+            del batch_values, x, w, channel_values
+            if "y" in locals():
+                del y
+            mx.clear_cache()
+
 
 if __name__ == "__main__":
     mlx_tests.MLXTestRunner()