Skip to content

Commit 2813d07

Browse files
authored
fix mismatch sub dtype (#13447)
Differential Revision: D80312352
1 parent a5f0cf5 commit 2813d07

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

exir/passes/remove_mixed_type_operators.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def call_operator(self, op, args, kwargs, meta: NodeMetadata): # noqa: C901
2323
promotion_type_allow_list = {
2424
torch.ops.aten.add.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
2525
torch.ops.aten.mul.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
26+
torch.ops.aten.sub.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
2627
# The correct promotion for div depends on the mode! If there is no mode,
2728
# it's INT_TO_FLOAT, otherwise it's default.
2829
torch.ops.aten.div.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,

exir/tests/test_passes.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
159159
return Module
160160

161161
Add = make_module(lambda x, y: (x + y) + x)
162+
Sub = make_module(lambda x, y: (x - y) - x)
162163
Mult = make_module(lambda x, y: x * y)
163164
Minimum = make_module(torch.minimum)
164165
DivWithoutMode = make_module(torch.div)
@@ -177,6 +178,12 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
177178
2,
178179
ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
179180
),
181+
(
182+
Sub,
183+
exir_ops.edge.aten.sub.Tensor,
184+
2,
185+
ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
186+
),
180187
(
181188
Mult,
182189
exir_ops.edge.aten.mul.Tensor,

0 commit comments

Comments
 (0)