
Commit 0ffbac1

Revert "Fix style after #3261" (#3412)
Revert "Fix style after #3261 (#3397)" This reverts commit 316ef03.
1 parent 5977905 commit 0ffbac1

2 files changed: +30 additions, -92 deletions


test/quantization/pt2e/test_x86inductor_fusion.py

Lines changed: 14 additions & 48 deletions
@@ -139,22 +139,10 @@ def forward(self, input):
 
 
 class FP8QDQConv2d(torch.nn.Module):
-    def __init__(
-        self,
-        in_channels,
-        out_channels,
-        kernel_size,
-        stride=1,
-        padding=0,
-        dilation=1,
-        groups=1,
-        bias=True,
-    ):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
         super().__init__()
         self.qtype = torch.float8_e4m3fn
-        self.weight = torch.randn(
-            (out_channels, in_channels // groups, *kernel_size)
-        ).to(self.qtype)
+        self.weight = torch.randn((out_channels, in_channels // groups, *kernel_size)).to(self.qtype)
         self.weight_scale = 2.0
         self.scale = 2.0
         self.bias = None
@@ -182,16 +170,7 @@ def forward(self, input):
             output_dtype=torch.float,
         )
 
-        return torch.nn.functional.conv2d(
-            dq_input,
-            weight,
-            self.bias,
-            self.stride,
-            self.padding,
-            self.dilation,
-            self.groups,
-        )
-
+        return torch.nn.functional.conv2d(dq_input, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
 
 def qdq(input, scale):
     dtype = input.dtype
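
For context on the hunks above: FP8QDQConv2d and the qdq helper exercise a simple FP8 quantize-dequantize round trip with a fixed per-tensor scale. A minimal sketch of that idea, using only public torch ops and a hypothetical name (it is not the exact helper in this file), assuming a plain divide/cast/multiply QDQ:

import torch

def qdq_roundtrip(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Hypothetical sketch: quantize to float8_e4m3fn with a per-tensor scale,
    # then immediately dequantize back to the original dtype.
    orig_dtype = x.dtype
    q = (x / scale).to(torch.float8_e4m3fn)
    return q.to(orig_dtype) * scale

x = torch.randn(1, 3, 8, 8)
print((x - qdq_roundtrip(x, scale=2.0)).abs().max())  # small, non-zero quantization error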
@@ -226,7 +205,9 @@ def create_mod_info_recursion(parent):
     parent_child_mod_dict = generate_model_info(model)
     for name, mod in model.named_modules():
         mod_type_str = mod.__class__.__name__
-        if mod_type_str not in ["Linear", "Conv2d"]:
+        if mod_type_str not in [
+            "Linear", "Conv2d"
+        ]:
             continue
         param = mod.weight
         xmax = torch.max(param)
@@ -244,16 +225,7 @@ def create_mod_info_recursion(parent):
             patched_mod.weight_scale = weight_scale.item()
             patched_mod.weight.data = q_param
         elif mod_type_str in ["Conv2d"]:
-            patched_mod = FP8QDQConv2d(
-                mod.in_channels,
-                mod.out_channels,
-                mod.kernel_size,
-                mod.stride,
-                mod.padding,
-                mod.dilation,
-                mod.groups,
-                False,
-            )
+            patched_mod = FP8QDQConv2d(mod.in_channels, mod.out_channels, mod.kernel_size, mod.stride, mod.padding, mod.dilation, mod.groups, False)
             patched_mod.bias = mod.bias
             patched_mod.weight_scale = weight_scale.item()
             patched_mod.weight.data = q_param
@@ -638,9 +610,7 @@ def test_qconv2d_relu6_fp8_cpu(self):
         r"""
         This testcase will quantize Conv2d->ReLU6 pattern.
         """
-        self._qconv2d_unary_test_helper(
-            device="cpu", unary_op=torch.nn.ReLU6(), is_fp8=True
-        )
+        self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.ReLU6(), is_fp8=True)
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
@@ -657,9 +627,7 @@ def test_qconv2d_hardtanh_fp8_cpu(self):
         r"""
         This testcase will quantize Conv2d->Hardtanh pattern.
         """
-        self._qconv2d_unary_test_helper(
-            device="cpu", unary_op=torch.nn.Hardtanh(), is_fp8=True
-        )
+        self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.Hardtanh(), is_fp8=True)
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
@@ -710,9 +678,7 @@ def test_qconv2d_hardswish_fp8_cpu(self):
         r"""
         This testcase will quantize Conv2d->Hardswish pattern.
         """
-        self._qconv2d_unary_test_helper(
-            device="cpu", unary_op=torch.nn.Hardswish(), is_fp8=True
-        )
+        self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.Hardswish(), is_fp8=True)
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
@@ -765,9 +731,7 @@ def test_qconv2d_silu_fp8_cpu(self):
         r"""
         This testcase will quantize Conv2d->SiLU pattern.
         """
-        self._qconv2d_unary_test_helper(
-            device="cpu", unary_op=torch.nn.SiLU(), is_fp8=True
-        )
+        self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.SiLU(), is_fp8=True)
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
@@ -947,7 +911,9 @@ def forward(self, x, x2, x3):
         add_fn_list = quantization_add_fn_list
         if not is_fp8:
             add_fn_list = add_fn_list + quantization_inplace_add_fn_list
-        for add_fn, swap_inputs in itertools.product(add_fn_list, [False, True]):
+        for add_fn, swap_inputs in itertools.product(
+            add_fn_list, [False, True]
+        ):
             mod = M(add_fn, use_relu, swap_inputs).eval().to(device=device)
             x = torch.randn(
                 (1, 3, 8, 8), dtype=torch.float32, requires_grad=False, device=device
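
The test_qconv2d_*_fp8_cpu cases above all quantize a "Conv2d -> unary op" pattern via _qconv2d_unary_test_helper. A minimal sketch of the kind of module such a helper would build; the class name and shapes here are assumptions for illustration, not taken from the helper itself:

import torch

class ConvUnary(torch.nn.Module):
    # Hypothetical module: Conv2d followed by one unary post-op, the shape of
    # pattern these tests quantize (ReLU6, Hardtanh, Hardswish, SiLU).
    def __init__(self, unary_op: torch.nn.Module):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.unary = unary_op

    def forward(self, x):
        return self.unary(self.conv(x))

mod = ConvUnary(torch.nn.SiLU()).eval()
out = mod(torch.randn(1, 3, 8, 8))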

torchao/quantization/pt2e/inductor_passes/x86.py

Lines changed: 16 additions & 44 deletions
@@ -174,26 +174,21 @@ def get_dequantize_per_tensor_activation_pattern(
         output_dtype=KeywordArg("w_dtype"),
     )
 
-
 def get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern):
     return _may_generate_pattern_with_dtype_convert(
         dequant_wgt_pattern,
         KeywordArg("autocast_wgt_dtype"),
     )
 
-
 def get_dequantize_clone_weight_pattern(dequant_wgt_pattern):
     return CallFunction(
         aten.clone.default,
         dequant_wgt_pattern,
         memory_format=KeywordArg("memory_format"),
     )
 
-
 def get_dequantize_to_bf16_clone_weight_pattern(dequant_wgt_pattern):
-    return get_dequantize_clone_weight_pattern(
-        get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern)
-    )
+    return get_dequantize_clone_weight_pattern(get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern))
 
 
 def get_qconv_pt2e_pattern(x_scale_zp_are_tensors=False, users=1):
@@ -455,18 +450,14 @@ def fn(match):
                 break
         assert extra_input_of_binary_node is not None
         # Extra input of binary node comes from dequant pattern
-        if (
-            not is_fp8
-            and extra_input_from_dequant
-            and (
-                (not isinstance(extra_input_of_binary_node, torch.fx.Node))
-                or (
-                    extra_input_of_binary_node.target
-                    not in [
-                        quantized_decomposed.dequantize_per_tensor.default,
-                        torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
-                    ]
-                )
+        if not is_fp8 and extra_input_from_dequant and (
+            (not isinstance(extra_input_of_binary_node, torch.fx.Node))
+            or (
+                extra_input_of_binary_node.target
+                not in [
+                    quantized_decomposed.dequantize_per_tensor.default,
+                    torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
+                ]
             )
         ):
             return False
@@ -701,9 +692,7 @@ def _inner(match):
     return _inner
 
 
-def _register_qconv_weight_prepack_pass(
-    pattern, pass_number, dtype=torch.float32, is_fp8=False
-):
+def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float32, is_fp8=False):
     @register_freezing_graph_pattern(
         pattern,
         extra_check=_is_valid_dequant_conv_pattern(dtype),
@@ -787,10 +776,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
         if is_fp8:
             # For float8, we assume the scales are from aten.full.default instead of
             # a constant buffer to avoid constant folding of q/dq before fusion passes.
-            assert (
-                w_scale.target is torch.ops.aten.full.default
-                and x_scale.target is torch.ops.aten.full.default
-            )
+            assert w_scale.target is torch.ops.aten.full.default and x_scale.target is torch.ops.aten.full.default
             with torch.utils._python_dispatch._disable_current_modes():
                 w_scale_tensor = torch.tensor([w_scale.args[1]])
                 match.graph.owning_module.register_buffer("w_scale", w_scale_tensor)
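
The assert in this hunk encodes the pass's assumption that FP8 scales arrive as aten.full.default nodes (literal fills) rather than constant buffers, so constant folding cannot eliminate the q/dq nodes before the fusion pass runs. A hypothetical stand-alone check along those lines, not part of the pass itself:

import torch
import torch.fx

def scale_comes_from_full(node) -> bool:
    # Hypothetical helper: True if an FX node was produced by aten.full.default,
    # i.e. the scale is a literal fill rather than a folded constant buffer.
    return (
        isinstance(node, torch.fx.Node)
        and node.op == "call_function"
        and node.target is torch.ops.aten.full.default
    )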
@@ -1460,12 +1446,8 @@ def _register_dequant_promotion():
 
 
 def _register_qconv_weight_prepack():
-    for dtype, is_fp8 in itertools.product(
-        [torch.float32, torch.bfloat16], [True, False]
-    ):
-        weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(
-            dtype, is_fp8=is_fp8
-        )
+    for dtype, is_fp8 in itertools.product([torch.float32, torch.bfloat16], [True, False]):
+        weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(dtype, is_fp8=is_fp8)
         for weight_prepack_pattern in weight_prepack_patterns:
             # Register to pass_number 1, so we can do dequant promotion in pass_number 0.
             _register_qconv_weight_prepack_pass(
@@ -2068,13 +2050,7 @@ def qconv(match: Match, *args, **kwargs):
             kwargs["groups"],
         )
         output_dtype = _get_pattern_output_dtype(match)
-        assert output_dtype in [
-            torch.int8,
-            torch.uint8,
-            torch.float8_e4m3fn,
-            torch.float32,
-            torch.bfloat16,
-        ]
+        assert output_dtype in [torch.int8, torch.uint8, torch.float8_e4m3fn, torch.float32, torch.bfloat16]
         # Output QParams
         if output_dtype == torch.float8_e4m3fn:
             # For float8, we assume the scale is from aten.full.default instead of
@@ -2321,9 +2297,7 @@ def _register_qconv_unary_fusion():
 
 
 def _register_qconv_binary_fusion():
-    for int8_mixed_bf16_with_inplace_add, x_scale_zp_are_tensors in itertools.product(
-        [False, True], [False, True]
-    ):
+    for int8_mixed_bf16_with_inplace_add, x_scale_zp_are_tensors in itertools.product([False, True], [False, True]):
         qconv_binary_op = (
             torch.ops.onednn.qconv2d_pointwise.binary_tensor
             if x_scale_zp_are_tensors
@@ -2332,9 +2306,7 @@ def _register_qconv_binary_fusion():
         # Priority 1 to match: QConv2d Binary or Binary-Unary pattern with int8 output
         swap_binary_inputs_list = [False, True]
         binary_replace_patterns = {}
-        for swap_inputs, is_fp8 in itertools.product(
-            swap_binary_inputs_list, [False, True]
-        ):
+        for swap_inputs, is_fp8 in itertools.product(swap_binary_inputs_list, [False, True]):
             binary_replace_patterns.update(
                 {
                     PostOpAttr(
