9 | 9 | # ruff: noqa: F841
10 | 10 |
11 | 11 |
| 12 | +import copy |
12 | 13 | import unittest |
13 | 14 |
14 | 15 | import torch |
21 | 22 | weight_observer_range_neg_127_to_127, |
22 | 23 | ) |
23 | 24 | from torch.fx import Node |
| 25 | +from torch.testing import FileCheck |
24 | 26 | from torch.testing._internal.common_quantization import ( |
25 | 27 | NodeSpec as ns, |
26 | 28 | ) |
@@ -1630,6 +1632,101 @@ def forward(self, x):
1630 | 1632 | if key != FROM_NODE_KEY: |
1631 | 1633 | self.assertEqual(n.meta[key], weight_meta[key]) |
1632 | 1634 |
| 1635 | +    def test_constant_folding_pass(self):
| 1636 | +        from torchao.quantization import (
| 1637 | +            MappingType,
| 1638 | +            PerGroup,
| 1639 | +            PerToken,
| 1640 | +        )
| 1641 | +        from torchao.quantization.pt2e._affine_quantization import (
| 1642 | +            AffineQuantizedMinMaxObserver,
| 1643 | +        )
| 1644 | +        from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
| 1645 | +        from torchao.quantization.pt2e.quantizer import (
| 1646 | +            QuantizationAnnotation,
| 1647 | +            QuantizationSpec,
| 1648 | +            Quantizer,
| 1649 | +        )
| 1650 | +
| 1651 | +        class BackendAQuantizer(Quantizer):
| 1652 | +            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
| 1653 | +                for node in model.graph.nodes:
| 1654 | +                    if (
| 1655 | +                        node.op == "call_function"
| 1656 | +                        and node.target == torch.ops.aten.linear.default
| 1657 | +                    ):
| 1658 | +                        input_act = node.args[0]
| 1659 | +                        assert isinstance(input_act, torch.fx.Node)
| 1660 | +                        weight = node.args[1]
| 1661 | +                        assert isinstance(weight, torch.fx.Node)
| 1662 | +
| 1663 | +                        act_qspec = QuantizationSpec(
| 1664 | +                            dtype=torch.uint8,
| 1665 | +                            quant_min=0,
| 1666 | +                            quant_max=255,
| 1667 | +                            qscheme=None,
| 1668 | +                            is_dynamic=False,
| 1669 | +                            observer_or_fake_quant_ctr=AffineQuantizedMinMaxObserver.with_args(
| 1670 | +                                # TODO: maybe align the arg name here
| 1671 | +                                target_dtype=torch.uint8,
| 1672 | +                                mapping_type=MappingType.SYMMETRIC,
| 1673 | +                                granularity=PerToken(),
| 1674 | +                            ),
| 1675 | +                        )
| 1676 | +
| 1677 | +                        weight_qspec = QuantizationSpec(
| 1678 | +                            dtype=torch.uint8,
| 1679 | +                            quant_min=0,
| 1680 | +                            quant_max=255,
| 1681 | +                            qscheme=None,
| 1682 | +                            is_dynamic=False,
| 1683 | +                            observer_or_fake_quant_ctr=AffineQuantizedMinMaxObserver.with_args(
| 1684 | +                                target_dtype=torch.uint8,
| 1685 | +                                mapping_type=MappingType.SYMMETRIC,
| 1686 | +                                granularity=PerGroup(group_size=128),
| 1687 | +                            ),
| 1688 | +                        )
| 1689 | +                        node.meta["quantization_annotation"] = QuantizationAnnotation(
| 1690 | +                            input_qspec_map={
| 1691 | +                                input_act: act_qspec,
| 1692 | +                                weight: weight_qspec,
| 1693 | +                            },
| 1694 | +                            _annotated=True,
| 1695 | +                        )
| 1696 | +
| 1697 | +            def validate(self, model: torch.fx.GraphModule) -> None:
| 1698 | +                pass
| 1699 | +
| 1700 | +        class M(torch.nn.Module):
| 1701 | +            def __init__(self):
| 1702 | +                super().__init__()
| 1703 | +                self.linear = torch.nn.Linear(128, 20)
| 1704 | +
| 1705 | +            def forward(self, x):
| 1706 | +                return self.linear(x)
| 1707 | +
| 1708 | +        example_inputs = (torch.randn(5, 128),)
| 1709 | +        model = M()
| 1710 | +        quantizer = BackendAQuantizer()
| 1711 | +        m = torch.export.export(model.eval(), example_inputs, strict=True).module()
| 1712 | +        m = prepare_pt2e(m, quantizer)
| 1713 | +        # Calibration
| 1714 | +        m(*example_inputs)
| 1715 | +        # Get the quantized model
| 1716 | +        m_fold = copy.deepcopy(m)
| 1717 | +        m_fold = convert_pt2e(m_fold, fold_quantize=True)
| 1718 | +
| 1719 | +        # If folded, the graph should contain only frozen params and no linear_weight
| 1720 | +        FileCheck().check("_frozen_param0").check_not("linear_weight").run(m_fold.code)
| 1721 | +
| 1722 | +        m_not_fold = copy.deepcopy(m)
| 1723 | +        m_not_fold = convert_pt2e(m_not_fold, fold_quantize=False)
| 1724 | +
| 1725 | +        # If not folded, the graph should contain no frozen params and still contain linear_weight
| 1726 | +        FileCheck().check_not("_frozen_param0").check("linear_weight").run(
| 1727 | +            m_not_fold.code
| 1728 | +        )
| 1729 | +
1633 | 1730 | def test_save_load(self): |
1634 | 1731 | """Test save/load a quantized model""" |
1635 | 1732 | m = self._get_pt2e_quantized_linear() |