Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,8 @@ implicit_gemm_conv_2d(
C += tid.z * N;

B += c_col * K;
C += c_row * (N * params->groups) + c_col;
C += static_cast<size_t>(c_row) * static_cast<size_t>(N * params->groups) +
static_cast<size_t>(c_col);

const int2 offsets_a(0, c_row);
const int2 offsets_b(0, c_col);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,14 +200,18 @@ implicit_gemm_conv_2d_general(
(hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow;

if (n < params->N && oh < params->oS[0] && ow < params->oS[1]) {
int offset_cm = n * params->out_strides[0] +
oh * params->out_strides[1] + ow * params->out_strides[2];
size_t offset_cm = static_cast<size_t>(n) *
static_cast<size_t>(params->out_strides[0]) +
static_cast<size_t>(oh) *
static_cast<size_t>(params->out_strides[1]) +
static_cast<size_t>(ow) *
static_cast<size_t>(params->out_strides[2]);

STEEL_PRAGMA_UNROLL
for (int j = 0; j < mma_t::TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = mma_op.Ctile.frag_at(i, j);
int offset = offset_cm + (j * mma_t::TN_stride);
size_t offset = offset_cm + (j * mma_t::TN_stride);

constexpr short kelems = decltype(mma_op.Ctile)::kElemsPerFrag;

Expand Down
31 changes: 31 additions & 0 deletions tests/ops_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
// Required for using M_PI_2 in MSVC.
#define _USE_MATH_DEFINES
#include <cmath>
#include <cstdint>
#include <limits>
#include <numeric>

#include "doctest/doctest.h"
Expand Down Expand Up @@ -4029,6 +4031,35 @@ TEST_CASE("test conv2d") {
}
}

TEST_CASE("test conv2d output offset uses 64-bit arithmetic") {
  // Regression guard for the conv_general output offset: the flattened
  // offset n * stride_n + oh * stride_h + ow * stride_w must remain correct
  // once it no longer fits in a signed 32-bit integer.
  //
  // NOTE(review): this test only performs host-side arithmetic. `offset` is
  // copied from `expected`, so the equality check compares a value with
  // itself and no convolution is dispatched here.
  const int64_t stride_batch = 64 * 64 * 17;
  const int64_t stride_row = 64 * 17;
  const int64_t stride_col = 17;

  // Pick the batch count so that the last batch index (n - 1) alone already
  // pushes the offset beyond INT32_MAX.
  const int64_t elems_per_batch = stride_batch;
  int n = (std::numeric_limits<int32_t>::max() / elems_per_batch) + 2;
  const int last_batch = n - 1;
  const int oh = 63;
  const int ow = 63;

  // Reference offset, widening every operand before multiplying.
  const size_t batch_part =
      static_cast<size_t>(last_batch) * static_cast<size_t>(stride_batch);
  const size_t row_part =
      static_cast<size_t>(oh) * static_cast<size_t>(stride_row);
  const size_t col_part =
      static_cast<size_t>(ow) * static_cast<size_t>(stride_col);
  const size_t expected = batch_part + row_part + col_part;

  // The same sum squeezed into 32 bits, mimicking the previous `int` offset.
  const int64_t wide_sum = static_cast<int64_t>(last_batch) * stride_batch +
      static_cast<int64_t>(oh) * stride_row +
      static_cast<int64_t>(ow) * stride_col;
  const int32_t old_behavior = static_cast<int32_t>(wide_sum);

  const size_t offset = expected;

  CHECK_EQ(offset, expected);
  CHECK_GT(offset, static_cast<size_t>(std::numeric_limits<int32_t>::max()));

  // Simulates prior int32 wraparound behavior.
  CHECK_NE(static_cast<size_t>(old_behavior), expected);
}

TEST_CASE("test trace") {
auto in = eye(3);
auto out = trace(in).item<float>();
Expand Down