@@ -174,26 +174,21 @@ def get_dequantize_per_tensor_activation_pattern(
     output_dtype=KeywordArg("w_dtype"),
 )
 
-
 def get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern):
     return _may_generate_pattern_with_dtype_convert(
         dequant_wgt_pattern,
         KeywordArg("autocast_wgt_dtype"),
     )
 
-
 def get_dequantize_clone_weight_pattern(dequant_wgt_pattern):
     return CallFunction(
         aten.clone.default,
         dequant_wgt_pattern,
         memory_format=KeywordArg("memory_format"),
     )
 
-
 def get_dequantize_to_bf16_clone_weight_pattern(dequant_wgt_pattern):
-    return get_dequantize_clone_weight_pattern(
-        get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern)
-    )
+    return get_dequantize_clone_weight_pattern(get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern))
 
 
 def get_qconv_pt2e_pattern(x_scale_zp_are_tensors=False, users=1):
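
The three helpers in this hunk only differ in how they nest one pattern inside another: the bf16 variant wraps the weight-dequant pattern in a dtype-convert node, the clone variant wraps it in aten.clone, and the combined helper applies both. The sketch below is a hypothetical, self-contained illustration of that nesting; `Pattern` is a stand-in for the real pattern-matcher nodes (CallFunction/KeywordArg), and the op strings are labels only, not the actual ATen overloads.

from dataclasses import dataclass
from typing import Optional

@dataclass
class Pattern:
    op: str
    inner: Optional["Pattern"] = None

def dequant_weight_pattern() -> Pattern:
    return Pattern("dequantize_per_channel")

def to_bf16(p: Pattern) -> Pattern:        # plays the role of get_dequantize_to_bf16_weight_pattern
    return Pattern("convert_element_type", p)

def clone(p: Pattern) -> Pattern:          # plays the role of get_dequantize_clone_weight_pattern
    return Pattern("clone", p)

def to_bf16_clone(p: Pattern) -> Pattern:  # plays the role of get_dequantize_to_bf16_clone_weight_pattern
    return clone(to_bf16(p))

print(to_bf16_clone(dequant_weight_pattern()))
# Pattern(op='clone', inner=Pattern(op='convert_element_type',
#         inner=Pattern(op='dequantize_per_channel', inner=None)))
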
@@ -455,18 +450,14 @@ def fn(match):
                 break
         assert extra_input_of_binary_node is not None
         # Extra input of binary node comes from dequant pattern
-        if (
-            not is_fp8
-            and extra_input_from_dequant
-            and (
-                (not isinstance(extra_input_of_binary_node, torch.fx.Node))
-                or (
-                    extra_input_of_binary_node.target
-                    not in [
-                        quantized_decomposed.dequantize_per_tensor.default,
-                        torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
-                    ]
-                )
+        if not is_fp8 and extra_input_from_dequant and (
+            (not isinstance(extra_input_of_binary_node, torch.fx.Node))
+            or (
+                extra_input_of_binary_node.target
+                not in [
+                    quantized_decomposed.dequantize_per_tensor.default,
+                    torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
+                ]
             )
         ):
             return False
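
The condition in this hunk rejects a match when the binary node's extra input is supposed to come from a dequant pattern but its target is not one of the recognized dequantize ops; the fp8 path skips this structural check. Below is a simplified sketch of the same predicate, using hypothetical stand-ins: string names replace the real dequantize op overloads and SimpleNamespace replaces torch.fx.Node.

from types import SimpleNamespace

DEQUANT_TARGETS = {"dequantize_per_tensor", "dequantize_affine_float8_non_decomposed"}

def extra_input_check_passes(extra_input, is_fp8, extra_input_from_dequant):
    # Same shape as the condition above: only the non-fp8 path that expects a
    # dequant-sourced extra input rejects nodes whose target is not a dequantize op.
    if not is_fp8 and extra_input_from_dequant and (
        not isinstance(extra_input, SimpleNamespace)
        or extra_input.target not in DEQUANT_TARGETS
    ):
        return False
    return True

print(extra_input_check_passes(SimpleNamespace(target="dequantize_per_tensor"), False, True))  # True
print(extra_input_check_passes(SimpleNamespace(target="relu"), False, True))                   # False
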
@@ -701,9 +692,7 @@ def _inner(match):
     return _inner
 
 
-def _register_qconv_weight_prepack_pass(
-    pattern, pass_number, dtype=torch.float32, is_fp8=False
-):
+def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float32, is_fp8=False):
     @register_freezing_graph_pattern(
         pattern,
         extra_check=_is_valid_dequant_conv_pattern(dtype),
@@ -787,10 +776,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
         if is_fp8:
             # For float8, we assume the scales are from aten.full.default instead of
             # a constant buffer to avoid constant folding of q/dq before fusion passes.
-            assert (
-                w_scale.target is torch.ops.aten.full.default
-                and x_scale.target is torch.ops.aten.full.default
-            )
+            assert w_scale.target is torch.ops.aten.full.default and x_scale.target is torch.ops.aten.full.default
             with torch.utils._python_dispatch._disable_current_modes():
                 w_scale_tensor = torch.tensor([w_scale.args[1]])
                 match.graph.owning_module.register_buffer("w_scale", w_scale_tensor)
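
For the fp8 path in this hunk, the scales are not constant buffers yet: they are literal fill values carried by aten.full.default nodes, so the pass reads the scalar out of node.args[1] and materializes it as a buffer on the owning module. A small sketch of that step, where `w_scale_node` is a hypothetical SimpleNamespace mimicking the matched fx node (args = (size, fill_value, ...)) and a plain nn.Module plays the graph's owning module:

import torch
from types import SimpleNamespace

mod = torch.nn.Module()
w_scale_node = SimpleNamespace(target=torch.ops.aten.full.default, args=([], 0.0125))

assert w_scale_node.target is torch.ops.aten.full.default
w_scale_tensor = torch.tensor([w_scale_node.args[1]])  # lift the literal scale into a tensor
mod.register_buffer("w_scale", w_scale_tensor)          # now addressable as a module buffer
print(mod.w_scale)  # tensor([0.0125])
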
@@ -1460,12 +1446,8 @@ def _register_dequant_promotion():
 
 
 def _register_qconv_weight_prepack():
-    for dtype, is_fp8 in itertools.product(
-        [torch.float32, torch.bfloat16], [True, False]
-    ):
-        weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(
-            dtype, is_fp8=is_fp8
-        )
+    for dtype, is_fp8 in itertools.product([torch.float32, torch.bfloat16], [True, False]):
+        weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(dtype, is_fp8=is_fp8)
         for weight_prepack_pattern in weight_prepack_patterns:
             # Register to pass_number 1, so we can do dequant promotion in pass_number 0.
             _register_qconv_weight_prepack_pass(
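
The registration loop in this hunk enumerates every combination of weight dtype and fp8 flag exactly once, so each generated weight-prepack pattern is registered under pass_number 1, leaving pass_number 0 free for dequant promotion. A minimal sketch of the enumeration, with a print standing in for the pattern-generation and registration calls:

import itertools
import torch

# Illustration only: the print replaces _generate_qconv_weight_prepack_patterns
# and _register_qconv_weight_prepack_pass.
for dtype, is_fp8 in itertools.product([torch.float32, torch.bfloat16], [True, False]):
    print(f"register weight-prepack patterns for dtype={dtype}, is_fp8={is_fp8}, pass_number=1")
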
@@ -2068,13 +2050,7 @@ def qconv(match: Match, *args, **kwargs):
             kwargs["groups"],
         )
         output_dtype = _get_pattern_output_dtype(match)
-        assert output_dtype in [
-            torch.int8,
-            torch.uint8,
-            torch.float8_e4m3fn,
-            torch.float32,
-            torch.bfloat16,
-        ]
+        assert output_dtype in [torch.int8, torch.uint8, torch.float8_e4m3fn, torch.float32, torch.bfloat16]
         # Output QParams
         if output_dtype == torch.float8_e4m3fn:
             # For float8, we assume the scale is from aten.full.default instead of
@@ -2321,9 +2297,7 @@ def _register_qconv_unary_fusion():
 
 
 def _register_qconv_binary_fusion():
-    for int8_mixed_bf16_with_inplace_add, x_scale_zp_are_tensors in itertools.product(
-        [False, True], [False, True]
-    ):
+    for int8_mixed_bf16_with_inplace_add, x_scale_zp_are_tensors in itertools.product([False, True], [False, True]):
         qconv_binary_op = (
             torch.ops.onednn.qconv2d_pointwise.binary_tensor
             if x_scale_zp_are_tensors
@@ -2332,9 +2306,7 @@ def _register_qconv_binary_fusion():
         # Priority 1 to match: QConv2d Binary or Binary-Unary pattern with int8 output
         swap_binary_inputs_list = [False, True]
         binary_replace_patterns = {}
-        for swap_inputs, is_fp8 in itertools.product(
-            swap_binary_inputs_list, [False, True]
-        ):
+        for swap_inputs, is_fp8 in itertools.product(swap_binary_inputs_list, [False, True]):
             binary_replace_patterns.update(
                 {
                     PostOpAttr(