
Commit f8a64c3

kundaMwiza authored and pytorchmergebot committed
Broadcast constants on vectorised stores in CppTile2DKernel (pytorch#140262)
Currently constants are not broadcasted on vectorised stores in `CppTile2DKernel`. This leads to errors like the following:

```shell
error: request for member 'store' in 'tmp1', which is of non-class type 'signed char'
   61 |     tmp1.store(tmp2 + static_cast<int64_t>(8L*x0_inner), static_cast<int64_t>(8));
      |          ^~~~~
```

This PR adds the required broadcasting.

Fixes #ISSUE_NUMBER

Pull Request resolved: pytorch#140262
Approved by: https://github.com/jgong5
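For context, here is a minimal sketch of the pattern the fix enables: the scalar constant is wrapped in an `at::vec::Vectorized`, which broadcasts it across all lanes, before the vectorised `store` is called. This is illustrative only, not the actual Inductor-generated kernel; the function name, output pointer, and lane count of 8 are assumptions made for the example.

```cpp
#include <ATen/cpu/vec/vec.h>
#include <cstdint>

// Illustrative sketch (not real generated code): a scalar constant must be
// broadcast into a vector register before a vectorised store is possible.
void store_broadcast_sketch(int8_t* out_ptr, int8_t tmp1) {
    // Calling `tmp1.store(...)` directly is the ill-formed pattern from the
    // error above, since `tmp1` is a plain `signed char`, not a vector type.
    auto tmp2 = at::vec::Vectorized<int8_t>(tmp1);  // broadcast the constant to every lane
    tmp2.store(out_ptr, 8);                         // vectorised store of 8 elements
}
```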
1 parent e1e3bbc commit f8a64c3

File tree

2 files changed (+58, -0)

test/inductor/test_cpu_repro.py

Lines changed: 53 additions & 0 deletions
```diff
@@ -3394,6 +3394,59 @@ def f(a):
         x = torch.rand(4, 5)
         self.common(f, (x,))
 
+    def test_broadcast_scalar_cpp_tile_2d_kernel(self):
+        # Based on detectron2_maskrcnn backbone (conv2d -> max_pool2d)
+        s0 = 12
+        s1 = 21
+
+        data = torch.randn(
+            [1, 256, 8 * s0, 8 * s1],
+        )
+        weight_one = torch.randn([256, 256, 1, 1], requires_grad=True)
+        weight_two = torch.randn((256, 256, 3, 3), requires_grad=True)
+        bias_one = torch.randn([256], requires_grad=True)
+        bias_two = torch.randn([256], requires_grad=True)
+
+        @torch.compile
+        def fn(data, weight_one, weight_two, bias_one, bias_two):
+            conv_result_one = torch.ops.aten.convolution.default(
+                data,
+                weight_one,
+                bias_one,
+                [1, 1],
+                [1, 1],
+                [1, 1],
+                False,
+                [0, 0],
+                1,
+            )
+
+            conv_result_two = torch.ops.aten.convolution.default(
+                data,
+                weight_two,
+                bias_two,
+                [1, 1],
+                [1, 1],
+                [1, 1],
+                False,
+                [0, 0],
+                1,
+            )
+
+            max_pool_result = torch.nn.functional.max_pool2d(
+                conv_result_one,
+                [1, 1],
+                [2, 2],
+                [0, 0],
+                [1, 1],
+                False,
+            )
+            return conv_result_one, conv_result_two, max_pool_result
+
+        torch._dynamo.mark_dynamic(data, 2)
+        torch._dynamo.mark_dynamic(data, 3)
+        self.common(fn, (data, weight_one, weight_two, bias_one, bias_two))
+
     def test_to_channels_last_lowp_fp(self):
         def f(a):
             return a.to(memory_format=torch.channels_last)
```

torch/_inductor/codegen/cpp.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -3284,6 +3284,11 @@ def load(self, name: str, index: sympy.Expr):
 
     def store(self, name, index, value, mode=None):
         assert "buf" in name
+        assert isinstance(value, CppCSEVariable), value
+        if not value.is_vec:
+            # this happens when we store a scalar into a vectorized buffer like "fill"
+            value = self.broadcast(value)
+
         var = self.args.output(name)
 
         inner = self.inner_itervar()
```

0 commit comments
