
Commit 3da4862

support fp8 scaled_embedding_bag pattern match (#3406)

* support fp8 scaled_embedding_bag pattern match
* Update torchao/quantization/pt2e/inductor_passes/x86.py
* refine code

Co-authored-by: Xia Weiwen <xia.weiwen@hotmail.com>
1 parent 387b92a commit 3da4862
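
In rough terms, the new Inductor freezing pass looks for an FP8 weight dequantize feeding torch.nn.functional.embedding_bag and rewrites the matched subgraph into a single torch.ops.torchao._scaled_embedding_bag call. Below is a minimal sketch of the kind of module the pattern targets, reconstructed from the test and the pass in the diff; the class name and the constant scale here are illustrative, not part of the commit.

import torch


class FP8EmbeddingBagSketch(torch.nn.Module):
    """Illustrative module: dequantize an fp8 weight, then run embedding_bag."""

    def forward(self, weight_fp8, indices, offsets):
        # FP8 weight -> fp32 weight via the non-decomposed dequant op the pass matches
        weight = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
            tensor=weight_fp8.data,
            scale=torch.tensor([2.0]),  # illustrative per-tensor weight scale
            output_dtype=torch.float,
        )
        # sum-mode embedding_bag over the dequantized weight
        return torch.nn.functional.embedding_bag(
            indices, weight, offsets, mode="sum", include_last_offset=True
        )


# After freezing, the dequant + embedding_bag + getitem nodes are expected to be
# replaced by one fused call whose arguments follow roughly this order (see
# new_args in the x86.py diff): qweight, indices, offsets, weight_scale,
# output_scale, mode, include_last_offset, output_dtype.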

File tree

2 files changed, +181 -0 lines changed

test/quantization/pt2e/test_x86inductor_fusion.py

Lines changed: 72 additions & 0 deletions
@@ -3275,6 +3275,78 @@ def test_fp8_q_attention_block(self):
             annotate_matmul=annotate_matmul, is_fp8=True
         )

+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfNoFloat8Support
+    @unittest.skipIf(
+        "CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"),
+        reason="cpp kernels not built",
+    )
+    def test_fp8_scaled_embedding_bag(self):
+        dtype = torch.float8_e4m3fn
+
+        class FP8QDQEmbeddingBag(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.weight_scale = 2.0
+
+            def forward(
+                self,
+                weight,
+                input,
+                offsets=None,
+            ):
+                weight = (
+                    torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
+                        tensor=weight.data,
+                        scale=torch.tensor([self.weight_scale]),
+                        output_dtype=torch.float,
+                    )
+                )
+
+                return torch.nn.functional.embedding_bag(
+                    input,
+                    weight,
+                    offsets,
+                    mode="sum",
+                    include_last_offset=True,
+                )
+
+        EMBEDINGBAG_MULTIHOT_SIZES = [1, 2, 3, 10]
+        EMBEDINGBAG_BAG_SIZES = [1, 2, 128, 1024]
+        EMBEDINGBAG_VECTOR_SIZES = [1, 128, 512]
+        EMBEDINGBAG_INDEX_DTYPES = [torch.int64, torch.int32]
+
+        EMBEDINGBAG_TEST_PARAMS = list(
+            itertools.product(
+                EMBEDINGBAG_MULTIHOT_SIZES,
+                EMBEDINGBAG_BAG_SIZES,
+                EMBEDINGBAG_VECTOR_SIZES,
+                EMBEDINGBAG_INDEX_DTYPES,
+            )
+        )
+
+        for multi_hot, batch_size, vector_size, index_type in EMBEDINGBAG_TEST_PARAMS:
+            with torch.no_grad():
+                mod = FP8QDQEmbeddingBag()
+
+                weight = torch.randn((1000, vector_size)).to(dtype)
+                indices = torch.randint(1000, (batch_size * multi_hot,)).to(index_type)
+                offsets = torch.arange(0, (batch_size + 1) * multi_hot, multi_hot).to(
+                    index_type
+                )
+
+                def matcher_check_fn():
+                    self.assertEqual(
+                        counters["inductor"]["scaled_embedding_bag_matcher_count"], 1
+                    )
+
+                self._test_common(
+                    mod,
+                    (weight, indices, offsets),
+                    matcher_check_fn,
+                )
+

 instantiate_parametrized_tests(TestPatternMatcher)
 if __name__ == "__main__":

torchao/quantization/pt2e/inductor_passes/x86.py

Lines changed: 109 additions & 0 deletions
@@ -3,6 +3,7 @@
 import copy
 import functools
 import itertools
+import operator
 from typing import Any

 import torch
@@ -2902,6 +2903,113 @@ def _register_qlinear_binary_fusion():
     )


+def _register_scaled_embedding_bag_pass(pattern, pass_number, dtype=torch.float32):
+    @register_freezing_graph_pattern(
+        pattern,
+        pass_number=pass_number,
+    )
+    def scaled_embedding_bag(match: Match, *args, **kwargs):
+        assert dtype in [torch.float32, torch.bfloat16]
+
+        getitem_node = match.output_node()
+        embedding_bag_node = getitem_node.args[0]
+        assert embedding_bag_node.target is aten._embedding_bag_forward_only.default
+
+        embedding_bag_weight_index = 0
+        if dtype == torch.float32:
+            # pattern: embedding_bag -> dequant
+            dequant_node = embedding_bag_node.args[embedding_bag_weight_index]
+        else:
+            # pattern: embedding_bag -> to_bf16 -> dequant
+            weight_to_bf16_node = embedding_bag_node.args[embedding_bag_weight_index]
+            dequant_node = weight_to_bf16_node.args[0]
+
+        assert dequant_node.target in [
+            quantized_decomposed.dequantize_per_tensor.default,
+            quantized_decomposed.dequantize_per_tensor.tensor,
+            torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
+        ]
+
+        # Weight QParams
+        qw, w_scale = kwargs["x"], kwargs["x_scale"]
+
+        # Input Params
+        indices, offsets, mode, include_last_offset = (
+            kwargs["indices"],
+            kwargs["offsets"],
+            kwargs["mode"],
+            kwargs["include_last_offset"],
+        )
+        # only support fp32 output, next step to support more dtype
+        o_scale = 1.0
+
+        graph = match.graph
+        with graph.inserting_before(getitem_node):
+            new_args: tuple[Any, ...] = (
+                qw,
+                indices,
+                offsets,
+                w_scale,
+                o_scale,
+                mode,
+                include_last_offset,
+                torch.float,
+            )
+
+            new_embedding_bag_node = graph.call_function(
+                torch.ops.torchao._scaled_embedding_bag.default, args=new_args
+            )
+
+            getitem_node.replace_all_uses_with(new_embedding_bag_node)
+            new_embedding_bag_node.meta.update(embedding_bag_node.meta)
+
+            graph.erase_node(getitem_node)
+            graph.erase_node(embedding_bag_node)
+            if dtype == torch.bfloat16:
+                graph.erase_node(weight_to_bf16_node)  # type: ignore[possibly-undefined]
+            # Erase the dequant pattern
+            graph.erase_node(dequant_node)
+
+        counters["inductor"]["scaled_embedding_bag_matcher_count"] += 1
+        counters["inductor"]["scaled_embedding_bag_matcher_nodes"] += len(match.nodes)
+
+
+def _generate_scaled_embedding_bag_patterns(dq_pattern):
+    embedding_bag_pattern = CallFunction(
+        torch.ops.aten._embedding_bag_forward_only.default,
+        dq_pattern,
+        KeywordArg("indices"),
+        KeywordArg("offsets"),
+        Arg(),
+        KeywordArg("mode"),
+        KeywordArg("sparse"),
+        Arg(),
+        KeywordArg("include_last_offset"),
+    )
+    return CallFunction(
+        operator.getitem,
+        embedding_bag_pattern,
+        KeywordArg("item"),
+    )
+
+
+def _register_quantization_embeddingbag_pass():
+    for dtype in [torch.float32, torch.bfloat16]:
+        _register_scaled_embedding_bag_pass(
+            _generate_scaled_embedding_bag_patterns(
+                _may_generate_pattern_with_dtype_convert(
+                    get_dequantize_per_tensor_activation_pattern(
+                        is_tensor_overload=False, is_fp8=True
+                    ),
+                    KeywordArg("autocast_act_dtype"),
+                    dtype == torch.bfloat16,
+                ),
+            ),
+            pass_number=1,
+            dtype=dtype,
+        )  # pass_number=0 to run before weight prepack
+
+
 @functools.lru_cache(None)
 def _register_quantization_weight_pack_pass():
     # Step 1: Dequant promotion for int8-mixed-fp32/bf16
@@ -2924,6 +3032,7 @@ def _register_quantization_weight_pack_pass():
     _register_qconv_binary_fusion()
     _register_qlinear_unary_fusion()
     _register_qlinear_binary_fusion()
+    _register_quantization_embeddingbag_pass()


 def quant_lift_up(module_graph: torch.fx.graph.Graph):
