Commit cb67b03

Add test for constant folding in pt2e quant (#3420)
Summary:
We want to remove the pt2e quant code from PyTorch, and the constant folding test uses the PyTorch pt2e quant APIs, so we move the test to torchao instead. It will now test only the torchao version of constant folding: https://github.com/pytorch/ao/blob/main/torchao/quantization/pt2e/constant_fold.py

Test Plan:
python test/quantization/pt2e/test_quantize_pt2e.py -k test_constant_folding_pass

Reviewers:

Subscribers:

Tasks:

Tags:
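As orientation, here is a minimal sketch condensed from the new test below (not a separate API): convert_pt2e's fold_quantize flag decides whether the weight quantize ops are constant-folded into frozen graph parameters. The helper name check_constant_folding is hypothetical, and the _frozen_param0 / linear_weight strings are specific to the small linear model used in the test; the prepare_pt2e / convert_pt2e / FileCheck calls are the ones the test itself uses.

    import copy

    import torch
    from torch.testing import FileCheck

    from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e


    def check_constant_folding(model, quantizer, example_inputs):
        # Export and annotate the model with the given pt2e quantizer.
        m = torch.export.export(model.eval(), example_inputs, strict=True).module()
        m = prepare_pt2e(m, quantizer)
        m(*example_inputs)  # calibration

        # fold_quantize=True: quantized weights become frozen graph constants.
        folded = convert_pt2e(copy.deepcopy(m), fold_quantize=True)
        FileCheck().check("_frozen_param0").check_not("linear_weight").run(folded.code)

        # fold_quantize=False: the original weight attribute stays in the graph.
        not_folded = convert_pt2e(copy.deepcopy(m), fold_quantize=False)
        FileCheck().check_not("_frozen_param0").check("linear_weight").run(
            not_folded.code
        )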
1 parent 16aad7c commit cb67b03

File tree

1 file changed: +97 −0 lines changed

test/quantization/pt2e/test_quantize_pt2e.py

Lines changed: 97 additions & 0 deletions
@@ -9,6 +9,7 @@
 # ruff: noqa: F841
 
 
+import copy
 import unittest
 
 import torch
@@ -21,6 +22,7 @@
     weight_observer_range_neg_127_to_127,
 )
 from torch.fx import Node
+from torch.testing import FileCheck
 from torch.testing._internal.common_quantization import (
     NodeSpec as ns,
 )
@@ -1630,6 +1632,101 @@ def forward(self, x):
                 if key != FROM_NODE_KEY:
                     self.assertEqual(n.meta[key], weight_meta[key])
 
+    def test_constant_folding_pass(self):
+        from torchao.quantization import (
+            MappingType,
+            PerGroup,
+            PerToken,
+        )
+        from torchao.quantization.pt2e._affine_quantization import (
+            AffineQuantizedMinMaxObserver,
+        )
+        from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+        from torchao.quantization.pt2e.quantizer import (
+            QuantizationAnnotation,
+            QuantizationSpec,
+            Quantizer,
+        )
+
+        class BackendAQuantizer(Quantizer):
+            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+                for node in model.graph.nodes:
+                    if (
+                        node.op == "call_function"
+                        and node.target == torch.ops.aten.linear.default
+                    ):
+                        input_act = node.args[0]
+                        assert isinstance(input_act, torch.fx.Node)
+                        weight = node.args[1]
+                        assert isinstance(weight, torch.fx.Node)
+
+                        act_qspec = QuantizationSpec(
+                            dtype=torch.uint8,
+                            quant_min=0,
+                            quant_max=255,
+                            qscheme=None,
+                            is_dynamic=False,
+                            observer_or_fake_quant_ctr=AffineQuantizedMinMaxObserver.with_args(
+                                # TODO: maybe align the arg name here
+                                target_dtype=torch.uint8,
+                                mapping_type=MappingType.SYMMETRIC,
+                                granularity=PerToken(),
+                            ),
+                        )
+
+                        weight_qspec = QuantizationSpec(
+                            dtype=torch.uint8,
+                            quant_min=0,
+                            quant_max=255,
+                            qscheme=None,
+                            is_dynamic=False,
+                            observer_or_fake_quant_ctr=AffineQuantizedMinMaxObserver.with_args(
+                                target_dtype=torch.uint8,
+                                mapping_type=MappingType.SYMMETRIC,
+                                granularity=PerGroup(group_size=128),
+                            ),
+                        )
+                        node.meta["quantization_annotation"] = QuantizationAnnotation(
+                            input_qspec_map={
+                                input_act: act_qspec,
+                                weight: weight_qspec,
+                            },
+                            _annotated=True,
+                        )
+
+            def validate(self, model: torch.fx.GraphModule) -> None:
+                pass
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(128, 20)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        example_inputs = (torch.randn(5, 128),)
+        model = M()
+        quantizer = BackendAQuantizer()
+        m = torch.export.export(model.eval(), example_inputs, strict=True).module()
+        m = prepare_pt2e(m, quantizer)
+        # Calibration
+        m(*example_inputs)
+        # Get the quantized model
+        m_fold = copy.deepcopy(m)
+        m_fold = convert_pt2e(m_fold, fold_quantize=True)
+
+        # If folded, the graph should only contain frozen params and no linear_weight
+        FileCheck().check("_frozen_param0").check_not("linear_weight").run(m_fold.code)
+
+        m_not_fold = copy.deepcopy(m)
+        m_not_fold = convert_pt2e(m_not_fold, fold_quantize=False)
+
+        # If not folded, the graph should not contain frozen params and should contain linear_weight
+        FileCheck().check_not("_frozen_param0").check("linear_weight").run(
+            m_not_fold.code
+        )
+
     def test_save_load(self):
         """Test save/load a quantized model"""
         m = self._get_pt2e_quantized_linear()
