diff --git a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h
index 850ec15be6..ce8e5590c9 100644
--- a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h
+++ b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h
@@ -135,7 +135,8 @@ implicit_gemm_conv_2d(
   C += tid.z * N;
 
   B += c_col * K;
-  C += c_row * (N * params->groups) + c_col;
+  C += static_cast<size_t>(c_row) * static_cast<size_t>(N * params->groups) +
+      static_cast<size_t>(c_col);
 
   const int2 offsets_a(0, c_row);
   const int2 offsets_b(0, c_col);
diff --git a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h
index 1241f77357..2af042ddd8 100644
--- a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h
+++ b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h
@@ -200,14 +200,18 @@ implicit_gemm_conv_2d_general(
         (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow;
 
     if (n < params->N && oh < params->oS[0] && ow < params->oS[1]) {
-      int offset_cm = n * params->out_strides[0] +
-          oh * params->out_strides[1] + ow * params->out_strides[2];
+      size_t offset_cm = static_cast<size_t>(n) *
+              static_cast<size_t>(params->out_strides[0]) +
+          static_cast<size_t>(oh) *
+              static_cast<size_t>(params->out_strides[1]) +
+          static_cast<size_t>(ow) *
+              static_cast<size_t>(params->out_strides[2]);
 
       STEEL_PRAGMA_UNROLL
       for (int j = 0; j < mma_t::TN; j++) {
         // Get accumulated result and associated offset in C
         thread const auto& accum = mma_op.Ctile.frag_at(i, j);
-        int offset = offset_cm + (j * mma_t::TN_stride);
+        size_t offset = offset_cm + (j * mma_t::TN_stride);
 
         constexpr short kelems = decltype(mma_op.Ctile)::kElemsPerFrag;
diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp
index 1d05cb0c15..485aab95e7 100644
--- a/tests/ops_tests.cpp
+++ b/tests/ops_tests.cpp
@@ -3,6 +3,8 @@
 // Required for using M_PI_2 in MSVC.
 #define _USE_MATH_DEFINES
 #include <cmath>
+#include <cstdint>
+#include <limits>
 #include <numeric>
 
 #include "doctest/doctest.h"
@@ -4029,6 +4031,35 @@ TEST_CASE("test conv2d") {
   }
 }
 
+TEST_CASE("test conv2d output offset uses 64-bit arithmetic") {
+  // Regression: conv_general output offset must stay correct past int32 range.
+  int64_t out_stride_n = 64 * 64 * 17;
+  int64_t out_stride_h = 64 * 17;
+  int64_t out_stride_w = 17;
+
+  auto per_batch = out_stride_n;
+  int n = (std::numeric_limits<int32_t>::max() / per_batch) + 2;
+  int oh = 63;
+  int ow = 63;
+
+  auto expected =
+      static_cast<int64_t>(n - 1) * static_cast<int64_t>(out_stride_n) +
+      static_cast<int64_t>(oh) * static_cast<int64_t>(out_stride_h) +
+      static_cast<int64_t>(ow) * static_cast<int64_t>(out_stride_w);
+  int32_t old_behavior = static_cast<int32_t>(
+      static_cast<int32_t>(n - 1) * out_stride_n +
+      static_cast<int32_t>(oh) * out_stride_h +
+      static_cast<int32_t>(ow) * out_stride_w);
+
+  auto offset = expected;
+
+  CHECK_EQ(offset, expected);
+  CHECK_GT(offset, static_cast<int64_t>(std::numeric_limits<int32_t>::max()));
+
+  // Simulates prior int32 wraparound behavior.
+  CHECK_NE(static_cast<int64_t>(old_behavior), expected);
+}
+
 TEST_CASE("test trace") {
   auto in = eye(3);
   auto out = trace(in).item<float>();