From 16b9c46f1dce5657e27eadce7f7f1b8baae5ea86 Mon Sep 17 00:00:00 2001
From: shewu
Date: Fri, 28 Nov 2025 10:37:50 +0800
Subject: [PATCH 1/2] Qualcomm AI Engine Direct - Support triu op and linear op with non-constant weights

Summary:
- Add a pass to decompose the triu op during the quantization and export stages so that aten.ge.Scalar does not appear after to_edge.
- Refactor the linear op builder to support non-constant weight and bias
- Fix a typo in the builders README
- Fix the RMSNorm builder when eps is not provided
---
 backends/qualcomm/_passes/__init__.py         |   2 +
 backends/qualcomm/_passes/decompose_triu.py   |  71 +++++++++++
 .../_passes/lift_constant_scalar_operands.py  |   1 +
 backends/qualcomm/_passes/qnn_pass_manager.py |   3 +
 backends/qualcomm/builders/README.md          |   2 +-
 backends/qualcomm/builders/op_linear.py       |  19 +--
 backends/qualcomm/builders/op_rms_norm.py     |   4 +-
 backends/qualcomm/tests/models.py             |  65 +++++++++-
 backends/qualcomm/tests/test_qnn_delegate.py  | 113 +++++++++++++++---
 9 files changed, 243 insertions(+), 37 deletions(-)
 create mode 100644 backends/qualcomm/_passes/decompose_triu.py

diff --git a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py
index 154a360689e..2800156a738 100644
--- a/backends/qualcomm/_passes/__init__.py
+++ b/backends/qualcomm/_passes/__init__.py
@@ -25,6 +25,7 @@
 from .decompose_roll import DecomposeRoll
 from .decompose_silu import DecomposeSilu
 from .decompose_threshold import DecomposeThreshold
+from .decompose_triu import DecomposeTriu
 from .decompose_wrap_with_autocast import DecomposeWrapWithAutocast
 from .expand_broadcast_tensor_shape import ExpandBroadcastTensorShape
 from .fixed_linear_keep_dim import FixedLinearKeepDim
@@ -69,6 +70,7 @@
     DecomposeRoll,
     DecomposeSilu,
     DecomposeThreshold,
+    DecomposeTriu,
     DecomposeWrapWithAutocast,
     ExpandBroadcastTensorShape,
     FixedLinearKeepDim,
diff --git a/backends/qualcomm/_passes/decompose_triu.py b/backends/qualcomm/_passes/decompose_triu.py
new file mode 100644
index 00000000000..cb0450a499d
--- /dev/null
+++ b/backends/qualcomm/_passes/decompose_triu.py
@@ -0,0 +1,71 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Dict
+
+import torch
+from executorch.exir.pass_base import ExportPass, PassResult
+from torch._decomp import get_decompositions
+from torch.fx.experimental.proxy_tensor import make_fx
+
+from .utils import merge_decomposed_graph
+
+
+class DecomposeTriu(ExportPass):
+    """
+    Decompose triu during the quantization or export stage.
+    This allows LiftConstantScalarOperands to lift the scalar into a scalar_tensor.
+    Otherwise, after to_edge, the triu operation will be decomposed into several operations that include aten.ge.Scalar.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def _replace_output(
+        self, node: torch.fx.Node, output_node: torch.fx.Node, remap: Dict
+    ):
+        for user in node.users.copy():
+            # replace uses of the original node with the remapped decomposed output
+            user.replace_input_with(
+                node,
+                remap[output_node.args[0]],
+            )
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        graph = graph_module.graph
+        decom_mappings = get_decompositions([torch.ops.aten.triu.default])
+
+        for node in graph.nodes:
+            if node.target == torch.ops.aten.triu.default:
+                input_args = [node.args[0].meta["val"]]
+                # args[1] (diagonal) is optional
+                if len(node.args) > 1:
+                    input_args.append(node.args[1])
+                decomposed_module = make_fx(
+                    node.target,
+                    decomposition_table=decom_mappings,
+                    tracing_mode="fake",
+                )(*input_args)
+
+                with graph.inserting_before(node):
+                    # remap is used to map original node values to new node values,
+                    # which ensures that references to nodes are correctly updated in the new graph
+                    remap = {}
+                    remap["arg0_1"] = node.args[0]
+
+                    merge_decomposed_graph(
+                        remap=remap,
+                        target_node=node,
+                        target_graph=graph,
+                        decomposed_graph_module=decomposed_module,
+                        predicate=lambda decomp_node: "arg1_1" not in decomp_node.name,
+                        output_processor=self._replace_output,
+                    )
+                graph.erase_node(node)
+
+        graph.eliminate_dead_code()
+        graph_module.recompile()
+        return PassResult(graph_module, True)
diff --git a/backends/qualcomm/_passes/lift_constant_scalar_operands.py b/backends/qualcomm/_passes/lift_constant_scalar_operands.py
index 52bdf7fa090..e5d9371709d 100644
--- a/backends/qualcomm/_passes/lift_constant_scalar_operands.py
+++ b/backends/qualcomm/_passes/lift_constant_scalar_operands.py
@@ -55,6 +55,7 @@ class TensorOpInfo:
     aten.where.ScalarOther: TensorOpInfo(aten.where.self, False, True),
     aten.where.Scalar: TensorOpInfo(aten.where.self, False, True),
     aten.masked_fill.Scalar: TensorOpInfo(aten.masked_fill.Tensor, False, False),
+    aten.masked_fill_.Scalar: TensorOpInfo(aten.masked_fill.Tensor, False, False),
     aten.bitwise_xor.Scalar: TensorOpInfo(aten.bitwise_xor.Tensor, False, False),
 }

diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py
index 80b4675d2f1..95ec2f03b66 100644
--- a/backends/qualcomm/_passes/qnn_pass_manager.py
+++ b/backends/qualcomm/_passes/qnn_pass_manager.py
@@ -30,6 +30,7 @@
     DecomposeRoll,
     DecomposeSilu,
     DecomposeThreshold,
+    DecomposeTriu,
     DecomposeWrapWithAutocast,
     ExpandBroadcastTensorShape,
     FixedLinearKeepDim,
@@ -204,6 +205,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(DecomposeRoll())
         self.add_pass(DecomposeSilu())
         self.add_pass(DecomposeThreshold())
+        self.add_pass(DecomposeTriu())
         self.add_pass(DecomposeWrapWithAutocast())
         self.add_pass(DecomposeEinsum())
         self.add_pass(DecomposeExpM1())
@@ -222,6 +224,7 @@ def transform_for_export_pipeline(
         self.add_pass(DecomposeScaledDotProductAttention())
         self.add_pass(DecomposeRoll())
         self.add_pass(DecomposeThreshold())
+        self.add_pass(DecomposeTriu())
         self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True))
         self.add_pass(DecomposeExpM1())
         # DecomposeFloorDivide does not apply to the annotation pipeline,
diff --git a/backends/qualcomm/builders/README.md b/backends/qualcomm/builders/README.md
index 2f1c2d54828..a0adace81df 100644
--- a/backends/qualcomm/builders/README.md
+++ b/backends/qualcomm/builders/README.md
@@ -41,7 +41,7 @@ class MyModel(torch.nn.Module):
 ```
 At the time we try to lower it with Qualcomm backend:
 ```python
-from excutorch.examples.qualcomm.utils import build_executorch_binary
+from executorch.examples.qualcomm.utils import build_executorch_binary
 
 build_executorch_binary(
     model=MyModel(),
diff --git a/backends/qualcomm/builders/op_linear.py b/backends/qualcomm/builders/op_linear.py
index d5ac153b8d1..cdcd2f62e6e 100644
--- a/backends/qualcomm/builders/op_linear.py
+++ b/backends/qualcomm/builders/op_linear.py
@@ -18,7 +18,6 @@
 from .node_visitor import NodeVisitor
 from .node_visitor_manager import register_node_visitor
 from .qnn_constants import OpFullyConnected, QNN_OP_PACKAGE_NAME_QTI_AISW
-from .utils import get_parameter
 
 
 @register_node_visitor
@@ -55,32 +54,26 @@ def define_node(
             quant_attrs[QCOM_ZERO_POINTS] = quant_attrs[QCOM_ZERO_POINTS].reshape(
                 [-1, 1]
             )
-
-        weight_tensor = get_parameter(weight_node, self.edge_program)
+        weight_tensor = self.get_tensor(weight_node, node)
         weight_tensor_wrapper = self.define_tensor(
             weight_node,
             node,
             weight_tensor,
-            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
+            # define_tensor will determine the correct QNN tensor type.
+            # This param seems unnecessary and could possibly be removed in the future.
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
             nodes_to_wrappers,
         )
         linear_input_tensors.append(weight_tensor_wrapper)
 
         if len(node.args) >= 3:
             bias_node = self.get_node(node.args[2])
-
-            bias_tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC
-            bias_tensor = get_parameter(bias_node, self.edge_program)
-            # if bias_node is getitem
-            if bias_tensor is None:
-                bias_tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
-                bias_tensor = bias_node.meta["val"]
-
+            bias_tensor = self.get_tensor(bias_node, node)
             bias_tensor_wrapper = self.define_tensor(
                 bias_node,
                 node,
                 bias_tensor,
-                bias_tensor_type,
+                PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
                 nodes_to_wrappers,
             )
             linear_input_tensors.append(bias_tensor_wrapper)
diff --git a/backends/qualcomm/builders/op_rms_norm.py b/backends/qualcomm/builders/op_rms_norm.py
index 6d5060f730b..058e536d003 100644
--- a/backends/qualcomm/builders/op_rms_norm.py
+++ b/backends/qualcomm/builders/op_rms_norm.py
@@ -93,7 +93,9 @@ def define_node(
             nodes_to_wrappers,
         )
 
-        epsilon = node.args[3]
+        epsilon = torch.finfo(torch.float32).eps
+        if len(node.args) > 3:
+            epsilon = node.args[3]
         output_tensor = self.get_tensor(node, node)
         output_tensor_wrapper = self.define_tensor(
             node,
diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py
index cdd0c194fe3..3e64199f45f 100644
--- a/backends/qualcomm/tests/models.py
+++ b/backends/qualcomm/tests/models.py
@@ -1283,12 +1283,18 @@ class LargeTensorLinear(torch.nn.Module):
     def __init__(self):
         super().__init__()
         hidden_dim = 4096
-        self.linear1 = torch.nn.Linear(512, hidden_dim)
+        self.linear1_1 = torch.nn.Linear(512, hidden_dim)
+        self.linear1_2 = torch.nn.Linear(512, hidden_dim)
+        self.linear1_3 = torch.nn.Linear(512, hidden_dim)
         self.linear2 = torch.nn.Linear(hidden_dim, 512)
+        self.linear3 = torch.nn.Linear(hidden_dim, 512)
+        self.linear4 = torch.nn.Linear(hidden_dim, 512)
 
     def forward(self, x):
-        x1 = self.linear1(x) + self.linear1(x)
-        return self.linear2(x1)
+        x1 = self.linear1_1(x) + self.linear1_1(x)
+        x2 = self.linear1_2(x) + self.linear1_2(x)
+        x3 = self.linear1_3(x) + self.linear1_3(x)
+        return self.linear2(x1) * self.linear3(x2) * self.linear4(x3)
 
 
 class LayerNorm(torch.nn.Module):
@@ -1380,6 +1386,26 @@ def forward(self, x):
         return self.linear(x)
 
 
+class LinearNonConstantWeight(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.input_dim = 512
+        self.output_dim = 128
+        self.linear = torch.nn.Linear(self.input_dim, 3 * self.output_dim, True).eval()
+
+    def forward(self, x):
+        w_q, w_k, w_v = self.linear.weight.split(
+            [self.output_dim, self.output_dim, self.output_dim]
+        )
+        b_q, b_k, b_v = self.linear.bias.split(
+            [self.output_dim, self.output_dim, self.output_dim]
+        )
+        q = torch.nn.functional.linear(x, w_q, b_q)
+        k = torch.nn.functional.linear(x, w_k, b_k)
+        v = torch.nn.functional.linear(x, w_v, b_v)
+        return q * k * v
+
+
 class LinalgVectorNorm(torch.nn.Module):
     def __init__(self, ord=2.0, dim=None, keepdim=False):
         super().__init__()
@@ -1814,10 +1840,11 @@ def forward(self, x):
 
 
 class RmsNorm(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, eps=None):
         super().__init__()
-        self.eps = 1e-5
-        self.rms = torch.nn.RMSNorm([4], 1e-5)
+        self.rms = torch.nn.RMSNorm([4])
+        if eps is not None:
+            self.rms = torch.nn.RMSNorm([4], eps)
 
     def forward(self, x):
         return self.rms(x)
@@ -2141,6 +2168,32 @@ def forward(self, x):
         return a + self.idx_source[b]
 
 
+class Triu(torch.nn.Module):
+    def __init__(self, diagonal: Optional[int] = None):
+        super().__init__()
+        self.diagonal = diagonal
+
+    def forward(self, x):
+        if self.diagonal:
+            return torch.triu(x, diagonal=self.diagonal)
+        return torch.triu(x)
+
+
+class TriuConstant(torch.nn.Module):
+    def __init__(self, diagonal, constant_dtype=torch.float32):
+        super().__init__()
+        self.diagonal = diagonal
+        self.constant_dtype = constant_dtype
+        self.register_buffer("mask", torch.ones((5, 5), dtype=constant_dtype))
+
+    def forward(self, x):
+        mask = torch.triu(self.mask, diagonal=self.diagonal)
+        if self.constant_dtype == torch.bool:
+            mask = torch.zeros(x.shape, dtype=x.dtype).masked_fill_(mask, -10000.0)
+        # Add x so that the graph has at least one input
+        return mask + x
+
+
 class Unbind(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index af79256591d..469884e351b 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -1351,9 +1351,14 @@ def test_qnn_backend_linalg_vector_norm(self):
         self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_linear(self):
-        module = Linear()  # noqa: F405
+        modules = [
+            Linear(),  # noqa: F405
+            LinearNonConstantWeight(),  # noqa: F405
+        ]
         sample_input = (torch.randn([3, 512]),)
-        self.lower_module_and_test_output(module, sample_input)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_log(self):
         module = Log()  # noqa: F405
@@ -1576,9 +1581,14 @@ def test_qnn_backend_reshape(self):
         self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_rms_norm(self):
-        module = RmsNorm()  # noqa: F405
-        sample_input = (torch.abs(torch.randn([1, 1, 1, 4])),)
-        self.lower_module_and_test_output(module, sample_input)
+        modules = [
+            RmsNorm(),  # noqa: F405
+            RmsNorm(eps=1e-5),  # noqa: F405
+        ]
+        sample_input = (torch.randn([1, 1, 1, 4]),)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_roll(self):
         modules = [
@@ -1735,6 +1745,37 @@ def test_qnn_backend_threshold(self):
             with self.subTest(i=i):
                 self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_triu(self):
+        test_comb = [
+            {
+                QCOM_MODULE: [
+                    Triu(),  # noqa: F405
+                    Triu(diagonal=1),  # noqa: F405
+                ],
+                QCOM_SAMPLE_INPUTS: [
+                    (torch.randn(3, 3),),
+                    (torch.randn(1, 2, 3, 3),),
+                ],
+            },
+            {
+                QCOM_MODULE: [
+                    TriuConstant(1),  # noqa: F405
+                    TriuConstant(1, constant_dtype=torch.bool),  # noqa: F405
+                ],
+                QCOM_SAMPLE_INPUTS: [
+                    (torch.zeros(5, 5),),
+                ],
+            },
+        ]
+
+        index = 0
+        for comb in test_comb:
+            for module in comb[QCOM_MODULE]:
+                for sample_input in comb[QCOM_SAMPLE_INPUTS]:
+                    with self.subTest(i=index):
+                        index += 1
+                        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_unflatten(self):
         module = Unflatten(dim=1, sizes=(2, 3, 4))  # noqa: F405
         sample_input = (torch.randn([1, 24]),)
@@ -3501,10 +3542,15 @@ def test_qnn_backend_linalg_vector_norm(self):
         self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_linear(self):
-        module = Linear()  # noqa: F405
+        modules = [
+            Linear(),  # noqa: F405
+            LinearNonConstantWeight(),  # noqa: F405
+        ]
         sample_input = (torch.randn([3, 512]),)
-        module = self.get_qdq_module(module, sample_input)
-        self.lower_module_and_test_output(module, sample_input)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                module = self.get_qdq_module(module, sample_input)
+                self.lower_module_and_test_output(module, sample_input)
 
     @unittest.skipIf(is_qnn_sdk_version_less_than("2.30"), "UT pass after QNN 2.30")
     def test_qnn_backend_linear_block(self):
@@ -3780,12 +3826,17 @@ def test_qnn_backend_reshape(self):
         self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_rms_norm(self):
-        module = RmsNorm()  # noqa: F405
-        sample_input = (torch.abs(torch.randn([1, 1, 1, 4])),)
-        module = self.get_qdq_module(
-            module, sample_input, quant_dtype=QuantDtype.use_16a4w
-        )
-        self.lower_module_and_test_output(module, sample_input)
+        modules = [
+            RmsNorm(),  # noqa: F405
+            RmsNorm(eps=1e-5),  # noqa: F405
+        ]
+        sample_input = (torch.randn([1, 1, 1, 4]),)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                module = self.get_qdq_module(
+                    module, sample_input, quant_dtype=QuantDtype.use_16a4w
+                )
+                self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_roll(self):
         modules = [
@@ -3967,6 +4018,38 @@ def test_qnn_backend_threshold(self):
             qdq_module = self.get_qdq_module(module, sample_input)
             self.lower_module_and_test_output(qdq_module, sample_input)
 
+    def test_qnn_backend_triu(self):
+        test_comb = [
+            {
+                QCOM_MODULE: [
+                    Triu(),  # noqa: F405
+                    Triu(diagonal=1),  # noqa: F405
+                ],
+                QCOM_SAMPLE_INPUTS: [
+                    (torch.randn(3, 3),),
+                    (torch.randn(1, 2, 3, 3),),
+                ],
+            },
+            {
+                QCOM_MODULE: [
+                    TriuConstant(1),  # noqa: F405
+                    TriuConstant(1, constant_dtype=torch.bool),  # noqa: F405
+                ],
+                QCOM_SAMPLE_INPUTS: [
+                    (torch.zeros((5, 5)),),
+                ],
+            },
+        ]
+
+        index = 0
+        for comb in test_comb:
+            for module in comb[QCOM_MODULE]:
+                for sample_input in comb[QCOM_SAMPLE_INPUTS]:
+                    with self.subTest(i=index):
+                        index += 1
+                        qdq_module = self.get_qdq_module(module, sample_input)
+                        self.lower_module_and_test_output(qdq_module, sample_input)
+
     def test_qnn_backend_unflatten(self):
         module = Unflatten(dim=1, sizes=(2, 3, 4))  # noqa: F405
         sample_input = (torch.randn([1, 24]),)
@@ -4476,9 +4559,7 @@ def test_qnn_backend_skip_node_op(self):
             skip_node_op_set={"aten.add.Tensor"},
         )
 
-    @unittest.expectedFailure
     def test_qnn_backend_spill_fill_buffer_size(self):
-        # TODO: Fix self.assertNotEqual(0, max_sf_size)
         module = LargeTensorLinear()  # noqa: F405
         sample_input = (torch.randn(1, 256, 512),)
         backend_options = generate_htp_compiler_spec(

From 52358a151d562e9dcc2c5c18eef8a1e6e753aa88 Mon Sep 17 00:00:00 2001
From: shewu
Date: Tue, 2 Dec 2025 09:39:07 +0800
Subject: [PATCH 2/2] Fix alphabetical order of test models

---
 backends/qualcomm/tests/models.py | 26 +++++++++++++-------------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py
index 3e64199f45f..4ccf505e010 100644
--- a/backends/qualcomm/tests/models.py
+++ b/backends/qualcomm/tests/models.py
@@ -1377,6 +1377,19 @@ def forward(self, x):
         return x + N
 
 
+class LinalgVectorNorm(torch.nn.Module):
+    def __init__(self, ord=2.0, dim=None, keepdim=False):
+        super().__init__()
+        self.ord = ord
+        self.dim = dim
+        self.keepdim = keepdim
+
+    def forward(self, x):
+        return torch.linalg.vector_norm(
+            x, ord=self.ord, dim=self.dim, keepdim=self.keepdim
+        )
+
+
 class Linear(torch.nn.Module):
     def __init__(self, use_bias: bool = True):
         super().__init__()
@@ -1406,19 +1419,6 @@ def forward(self, x):
         return q * k * v
 
 
-class LinalgVectorNorm(torch.nn.Module):
-    def __init__(self, ord=2.0, dim=None, keepdim=False):
-        super().__init__()
-        self.ord = ord
-        self.dim = dim
-        self.keepdim = keepdim
-
-    def forward(self, x):
-        return torch.linalg.vector_norm(
-            x, ord=self.ord, dim=self.dim, keepdim=self.keepdim
-        )
-
-
 class Log(torch.nn.Module):
     def __init__(self):
         super().__init__()
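
Reviewer note (not part of the patches): the effect of DecomposeTriu can be sanity-checked in isolation. Below is a minimal sketch under stated assumptions: `TriuModel` is a hypothetical stand-in for the new `Triu` test model, and the pass is invoked directly via its `call` method for illustration, whereas the backend actually runs it through `QnnPassManager` in the annotation and export pipelines.

```python
import torch

from executorch.backends.qualcomm._passes import DecomposeTriu


class TriuModel(torch.nn.Module):  # hypothetical stand-in for tests.models.Triu
    def forward(self, x):
        return torch.triu(x, diagonal=1)


# Export so that aten.triu.default shows up in the graph with meta["val"] populated.
exported = torch.export.export(TriuModel(), (torch.randn(3, 3),))
graph_module = exported.graph_module

# Run the pass directly; it rewrites the graph in place.
DecomposeTriu().call(graph_module)

# triu should now be expressed through its torch._decomp decomposition
# (arange/ge/where and friends), so no aten.triu.default node should remain.
assert all(
    node.target is not torch.ops.aten.triu.default
    for node in graph_module.graph.nodes
)
```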
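On the linear refactor: "non-constant weight" here means the weight feeding aten.linear is produced by other ops in the graph (for `LinearNonConstantWeight`, a split of one fused parameter), so the old `get_parameter` path had nothing static to fetch. A rough sketch of how to observe this, assuming the default `torch.export` behavior of keeping `aten.linear.default` undecomposed:

```python
import operator

import torch

from executorch.backends.qualcomm.tests.models import LinearNonConstantWeight

ep = torch.export.export(LinearNonConstantWeight(), (torch.randn(3, 512),))
for node in ep.graph.nodes:
    if node.target == torch.ops.aten.linear.default:
        # The weight argument is a getitem over the split of the fused parameter,
        # not a lifted-parameter placeholder, which is why the builder now tags
        # it as a NATIVE (graph-computed) tensor instead of a STATIC one.
        weight_node = node.args[1]
        print(weight_node.op, weight_node.target, weight_node.target is operator.getitem)
```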
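On the RMSNorm fix: when `torch.nn.RMSNorm` is constructed without `eps`, eager mode falls back to a dtype-dependent default, which for float32 inputs is `torch.finfo(torch.float32).eps`, the same value the builder now assumes when the node carries no explicit epsilon argument. A quick numerical check of that equivalence, as I understand the eager semantics:

```python
import torch

x = torch.randn(1, 1, 1, 4)
eps = torch.finfo(torch.float32).eps  # the default the builder now assumes

rms = torch.nn.RMSNorm([4])  # no eps given, mirroring the new RmsNorm() test model
manual = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * rms.weight

torch.testing.assert_close(rms(x), manual)
```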