Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -200,14 +200,19 @@ implicit_gemm_conv_2d_general(
(hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow;

if (n < params->N && oh < params->oS[0] && ow < params->oS[1]) {
int offset_cm = n * params->out_strides[0] +
oh * params->out_strides[1] + ow * params->out_strides[2];
size_t offset_cm =
static_cast<size_t>(n) *
static_cast<size_t>(params->out_strides[0]) +
static_cast<size_t>(oh) *
static_cast<size_t>(params->out_strides[1]) +
static_cast<size_t>(ow) *
static_cast<size_t>(params->out_strides[2]);

STEEL_PRAGMA_UNROLL
for (int j = 0; j < mma_t::TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = mma_op.Ctile.frag_at(i, j);
int offset = offset_cm + (j * mma_t::TN_stride);
size_t offset = offset_cm + (j * mma_t::TN_stride);

constexpr short kelems = decltype(mma_op.Ctile)::kElemsPerFrag;

Expand Down
38 changes: 38 additions & 0 deletions python/tests/test_conv.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright © 2023-2024 Apple Inc.

import math
import os
import unittest
from itertools import permutations

Expand Down Expand Up @@ -1192,6 +1193,43 @@ def test_conv2d_large_filter_small_channels(self):
y_hat = mx.conv2d(x, w, (1, 1), (1, 1))
self.assertTrue(mx.allclose(y, y_hat, rtol=1e-3, atol=1e-3))

@unittest.skipIf(
    os.getenv("LOW_MEMORY", None) is not None,
    "This test requires a lot of memory",
)
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
def test_conv_general_large_output_offset(self):
    """Regression test: conv output offsets must not overflow signed 32-bit.

    Builds a batch large enough that the last batch's output element
    offset exceeds 2**31, then checks the first and last batches are
    written to the correct locations.
    """
    height = width = 64
    out_channels = 17

    elems_per_batch = height * width * out_channels
    # +2 makes the last batch start beyond the signed 32-bit offset limit.
    n_batches = (2**31) // elems_per_batch + 2

    # Vary the per-batch input so mis-addressed writes cannot still pass by
    # aliasing another batch with identical values.
    per_batch = (mx.arange(n_batches, dtype=mx.int32) % 251).astype(mx.float16)
    per_batch = per_batch.reshape((n_batches, 1, 1, 1))
    x = mx.ones((n_batches, height, width, 1), dtype=mx.float16) * per_batch

    scales = mx.arange(1, out_channels + 1, dtype=mx.float16) / 8
    w = scales.reshape((out_channels, 1, 1, 1))
    scales = scales.reshape((1, 1, out_channels))

    try:
        y = mx.conv_general(x, w, stream=mx.gpu)
        self.assertTrue(
            mx.allclose(y[0], x[0] * scales, atol=1e-3, rtol=1e-3)
        )
        self.assertTrue(
            mx.allclose(y[-1], x[-1] * scales, atol=1e-3, rtol=1e-3)
        )
    finally:
        # Free the very large arrays promptly so later tests are not
        # starved of memory.
        del per_batch, x, w, scales
        if "y" in locals():
            del y
        mx.clear_cache()


# Script entry point: run this test module's suite via the project's
# test runner when executed directly (mlx_tests is imported elsewhere
# in the file — not visible in this chunk).
if __name__ == "__main__":
    mlx_tests.MLXTestRunner()