Commit 16b9c46

Qualcomm AI Engine Direct - Support triu op and linear op with non-constant weights

Summary:
- Add a pass to decompose the triu op during the quantization and export stages so that aten.ge.Scalar does not appear after to_edge.
- Refactor the linear op builder to support non-constant weights and biases.
- Fix a typo in the builders README (excutorch → executorch).
- Handle RMSNorm nodes without an explicit eps.

1 parent: 1864fb0
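For background, a minimal sketch (toy 4x4 input assumed; not part of this commit) of the behavior the pass works around: if triu is only lowered at to_edge, the core ATen decomposition introduces comparison ops such as aten.ge.Scalar, which is what the new pass avoids by decomposing triu earlier so LiftConstantScalarOperands can lift the scalar.

```python
import torch
from torch.export import export
from executorch.exir import to_edge

class TriuToy(torch.nn.Module):
    def forward(self, x):
        return torch.triu(x, diagonal=1)

# Exporting keeps aten.triu.default; to_edge then decomposes it into
# arange/comparison/where-style ops, typically including aten.ge.Scalar.
edge = to_edge(export(TriuToy(), (torch.randn(4, 4),)))
print(edge.exported_program().graph)
```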

File tree: 9 files changed, +243 −37 lines

backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -25,6 +25,7 @@
 from .decompose_roll import DecomposeRoll
 from .decompose_silu import DecomposeSilu
 from .decompose_threshold import DecomposeThreshold
+from .decompose_triu import DecomposeTriu
 from .decompose_wrap_with_autocast import DecomposeWrapWithAutocast
 from .expand_broadcast_tensor_shape import ExpandBroadcastTensorShape
 from .fixed_linear_keep_dim import FixedLinearKeepDim
@@ -69,6 +70,7 @@
     DecomposeRoll,
     DecomposeSilu,
     DecomposeThreshold,
+    DecomposeTriu,
     DecomposeWrapWithAutocast,
     ExpandBroadcastTensorShape,
     FixedLinearKeepDim,
backends/qualcomm/_passes/decompose_triu.py (new file)

Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Dict
+
+import torch
+from executorch.exir.pass_base import ExportPass, PassResult
+from torch._decomp import get_decompositions
+from torch.fx.experimental.proxy_tensor import make_fx
+
+from .utils import merge_decomposed_graph
+
+
+class DecomposeTriu(ExportPass):
+    """
+    Decompose triu during the quantization or export stage.
+    This allows LiftConstantScalarOperands to lift the scalar into a scalar_tensor.
+    Otherwise, after to_edge, the triu operation will be decomposed into several operations that include aten.ge.Scalar.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def _replace_output(
+        self, node: torch.fx.Node, output_node: torch.fx.Node, remap: Dict
+    ):
+        for user in node.users.copy():
+            # Remap users of the original triu node to the decomposed output.
+            user.replace_input_with(
+                node,
+                remap[output_node.args[0]],
+            )
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        graph = graph_module.graph
+        decom_mappings = get_decompositions([torch.ops.aten.triu.default])
+
+        for node in graph.nodes:
+            if node.target == torch.ops.aten.triu.default:
+                input_args = [node.args[0].meta["val"]]
+                # args[1], diagonal, is optional
+                if len(node.args) > 1:
+                    input_args.append(node.args[1])
+                decomposed_module = make_fx(
+                    node.target,
+                    decomposition_table=decom_mappings,
+                    tracing_mode="fake",
+                )(*input_args)
+
+                with graph.inserting_before(node):
+                    # remap maps original node values to new node values, which
+                    # ensures that references to nodes are correctly updated in the new graph
+                    remap = {}
+                    remap["arg0_1"] = node.args[0]
+
+                    merge_decomposed_graph(
+                        remap=remap,
+                        target_node=node,
+                        target_graph=graph,
+                        decomposed_graph_module=decomposed_module,
+                        predicate=lambda decomp_node: "arg1_1" not in decomp_node.name,
+                        output_processor=self._replace_output,
+                    )
+                    graph.erase_node(node)
+
+        graph.eliminate_dead_code()
+        graph_module.recompile()
+        return PassResult(graph_module, True)
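A minimal sketch of the mechanism call() relies on, using the same make_fx + get_decompositions combination as above (the sample input and diagonal value are arbitrary, not taken from the commit):

```python
import torch
from torch._decomp import get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx

# Build a decomposed FX graph for aten.triu; the pass splices a graph like this
# in place of the original triu node before to_edge runs.
decomp_table = get_decompositions([torch.ops.aten.triu.default])
gm = make_fx(
    torch.ops.aten.triu.default,
    decomposition_table=decomp_table,
    tracing_mode="fake",
)(torch.randn(4, 4), 1)
print(gm.graph)  # arange/comparison/where-style ops instead of a single triu node
```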

backends/qualcomm/_passes/lift_constant_scalar_operands.py

Lines changed: 1 addition & 0 deletions
@@ -55,6 +55,7 @@ class TensorOpInfo:
     aten.where.ScalarOther: TensorOpInfo(aten.where.self, False, True),
     aten.where.Scalar: TensorOpInfo(aten.where.self, False, True),
     aten.masked_fill.Scalar: TensorOpInfo(aten.masked_fill.Tensor, False, False),
+    aten.masked_fill_.Scalar: TensorOpInfo(aten.masked_fill.Tensor, False, False),
     aten.bitwise_xor.Scalar: TensorOpInfo(aten.bitwise_xor.Tensor, False, False),
 }
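The new entry covers the in-place variant. A small illustration (values arbitrary) of the pattern that can surface as aten.masked_fill_.Scalar during capture, for example via the TriuConstant test model added further down:

```python
import torch

# In-place masked_fill_ with a Python scalar; when captured before
# functionalization this can appear as aten.masked_fill_.Scalar, which is now
# mapped to aten.masked_fill.Tensor so the scalar can be lifted.
x = torch.zeros(3, 3)
mask = torch.triu(torch.ones(3, 3, dtype=torch.bool), diagonal=1)
x.masked_fill_(mask, -10000.0)
```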

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 3 additions & 0 deletions
@@ -30,6 +30,7 @@
     DecomposeRoll,
     DecomposeSilu,
     DecomposeThreshold,
+    DecomposeTriu,
     DecomposeWrapWithAutocast,
     ExpandBroadcastTensorShape,
     FixedLinearKeepDim,
@@ -204,6 +205,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(DecomposeRoll())
         self.add_pass(DecomposeSilu())
         self.add_pass(DecomposeThreshold())
+        self.add_pass(DecomposeTriu())
         self.add_pass(DecomposeWrapWithAutocast())
         self.add_pass(DecomposeEinsum())
         self.add_pass(DecomposeExpM1())
@@ -222,6 +224,7 @@ def transform_for_export_pipeline(
         self.add_pass(DecomposeScaledDotProductAttention())
         self.add_pass(DecomposeRoll())
         self.add_pass(DecomposeThreshold())
+        self.add_pass(DecomposeTriu())
         self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True))
         self.add_pass(DecomposeExpM1())
         # DecomposeFloorDivide does not apply to the annotation pipeline,

backends/qualcomm/builders/README.md

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ class MyModel(torch.nn.Module):
 ```
 At the time we try to lower it with Qualcomm backend:
 ```python
-from excutorch.examples.qualcomm.utils import build_executorch_binary
+from executorch.examples.qualcomm.utils import build_executorch_binary
 
 build_executorch_binary(
     model=MyModel(),

backends/qualcomm/builders/op_linear.py

Lines changed: 6 additions & 13 deletions
@@ -18,7 +18,6 @@
 from .node_visitor import NodeVisitor
 from .node_visitor_manager import register_node_visitor
 from .qnn_constants import OpFullyConnected, QNN_OP_PACKAGE_NAME_QTI_AISW
-from .utils import get_parameter


 @register_node_visitor
@@ -55,32 +54,26 @@ def define_node(
             quant_attrs[QCOM_ZERO_POINTS] = quant_attrs[QCOM_ZERO_POINTS].reshape(
                 [-1, 1]
             )
-
-        weight_tensor = get_parameter(weight_node, self.edge_program)
+        weight_tensor = self.get_tensor(weight_node, node)
         weight_tensor_wrapper = self.define_tensor(
             weight_node,
             node,
             weight_tensor,
-            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
+            # define_tensor will determine the correct QNN tensor type.
+            # This param seems unnecessary; we could possibly remove it in the future.
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
             nodes_to_wrappers,
         )
         linear_input_tensors.append(weight_tensor_wrapper)

         if len(node.args) >= 3:
             bias_node = self.get_node(node.args[2])
-
-            bias_tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC
-            bias_tensor = get_parameter(bias_node, self.edge_program)
-            # if bias_node is getitem
-            if bias_tensor is None:
-                bias_tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
-                bias_tensor = bias_node.meta["val"]
-
+            bias_tensor = self.get_tensor(bias_node, node)
             bias_tensor_wrapper = self.define_tensor(
                 bias_node,
                 node,
                 bias_tensor,
-                bias_tensor_type,
+                PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
                 nodes_to_wrappers,
             )
             linear_input_tensors.append(bias_tensor_wrapper)
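With get_parameter gone, weight_node no longer has to be a lifted constant. A sketch (sizes borrowed from the LinearNonConstantWeight test model added below; not part of the builder itself) of the kind of graph the builder now has to handle, where weight and bias arrive as outputs of a split rather than as parameters:

```python
import torch

# Weight/bias are produced by aten.split in the graph, so op_linear resolves
# them with get_tensor instead of assuming a static parameter tensor.
packed = torch.nn.Linear(512, 3 * 128)
x = torch.randn(1, 512)
w_q, _, _ = packed.weight.split([128, 128, 128])
b_q, _, _ = packed.bias.split([128, 128, 128])
q = torch.nn.functional.linear(x, w_q, b_q)  # weight input is a non-constant graph value
```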

backends/qualcomm/builders/op_rms_norm.py

Lines changed: 3 additions & 1 deletion
@@ -93,7 +93,9 @@ def define_node(
             nodes_to_wrappers,
         )

-        epsilon = node.args[3]
+        epsilon = torch.finfo(torch.float32).eps
+        if len(node.args) > 3:
+            epsilon = node.args[3]
         output_tensor = self.get_tensor(node, node)
         output_tensor_wrapper = self.define_tensor(
             node,
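When the rms_norm node carries no eps argument, the builder now falls back to float32 machine epsilon instead of indexing a missing args[3]. For reference:

```python
import torch

# Fallback epsilon used by the builder when the RMSNorm node omits eps.
print(torch.finfo(torch.float32).eps)  # 1.1920928955078125e-07
```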

backends/qualcomm/tests/models.py

Lines changed: 59 additions & 6 deletions
@@ -1283,12 +1283,18 @@ class LargeTensorLinear(torch.nn.Module):
     def __init__(self):
         super().__init__()
         hidden_dim = 4096
-        self.linear1 = torch.nn.Linear(512, hidden_dim)
+        self.linear1_1 = torch.nn.Linear(512, hidden_dim)
+        self.linear1_2 = torch.nn.Linear(512, hidden_dim)
+        self.linear1_3 = torch.nn.Linear(512, hidden_dim)
         self.linear2 = torch.nn.Linear(hidden_dim, 512)
+        self.linear3 = torch.nn.Linear(hidden_dim, 512)
+        self.linear4 = torch.nn.Linear(hidden_dim, 512)

     def forward(self, x):
-        x1 = self.linear1(x) + self.linear1(x)
-        return self.linear2(x1)
+        x1 = self.linear1_1(x) + self.linear1_1(x)
+        x2 = self.linear1_2(x) + self.linear1_2(x)
+        x3 = self.linear1_3(x) + self.linear1_3(x)
+        return self.linear2(x1) * self.linear3(x2) * self.linear4(x3)


 class LayerNorm(torch.nn.Module):
@@ -1380,6 +1386,26 @@ def forward(self, x):
         return self.linear(x)


+class LinearNonConstantWeight(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.input_dim = 512
+        self.output_dim = 128
+        self.linear = torch.nn.Linear(self.input_dim, 3 * self.output_dim, True).eval()
+
+    def forward(self, x):
+        w_q, w_k, w_v = self.linear.weight.split(
+            [self.output_dim, self.output_dim, self.output_dim]
+        )
+        b_q, b_k, b_v = self.linear.bias.split(
+            [self.output_dim, self.output_dim, self.output_dim]
+        )
+        q = torch.nn.functional.linear(x, w_q, b_q)
+        k = torch.nn.functional.linear(x, w_k, b_k)
+        v = torch.nn.functional.linear(x, w_v, b_v)
+        return q * k * v
+
+
 class LinalgVectorNorm(torch.nn.Module):
     def __init__(self, ord=2.0, dim=None, keepdim=False):
         super().__init__()
@@ -1814,10 +1840,11 @@ def forward(self, x):


 class RmsNorm(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, eps=None):
         super().__init__()
-        self.eps = 1e-5
-        self.rms = torch.nn.RMSNorm([4], 1e-5)
+        self.rms = torch.nn.RMSNorm([4])
+        if eps:
+            self.rms = torch.nn.RMSNorm([4], eps)

     def forward(self, x):
         return self.rms(x)
@@ -2141,6 +2168,32 @@ def forward(self, x):
         return a + self.idx_source[b]


+class Triu(torch.nn.Module):
+    def __init__(self, diagonal: Optional[int] = None):
+        super().__init__()
+        self.diagonal = diagonal
+
+    def forward(self, x):
+        if self.diagonal:
+            return torch.triu(x, diagonal=self.diagonal)
+        return torch.triu(x)
+
+
+class TriuConstant(torch.nn.Module):
+    def __init__(self, diagonal, constant_dtype=torch.float32):
+        super().__init__()
+        self.diagonal = diagonal
+        self.constant_dtype = constant_dtype
+        self.register_buffer("mask", torch.ones((5, 5), dtype=constant_dtype))
+
+    def forward(self, x):
+        mask = torch.triu(self.mask, diagonal=self.diagonal)
+        if self.constant_dtype == torch.bool:
+            mask = torch.zeros(x.shape, dtype=x.dtype).masked_fill_(mask, -10000.0)
+        # Add x to avoid having no input in the graph
+        return mask + x
+
+
 class Unbind(torch.nn.Module):
     def __init__(self):
         super().__init__()
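Quick, illustrative shape checks for the new and updated test models (import path assumed; the real coverage belongs to the backend's test suite, not this snippet):

```python
import torch
from executorch.backends.qualcomm.tests.models import (
    LinearNonConstantWeight,
    RmsNorm,
    Triu,
    TriuConstant,
)

# LinearNonConstantWeight multiplies three (1, 128) projections elementwise.
assert LinearNonConstantWeight()(torch.randn(1, 512)).shape == (1, 128)
assert Triu(diagonal=1)(torch.randn(5, 5)).shape == (5, 5)
assert TriuConstant(diagonal=1, constant_dtype=torch.bool)(torch.randn(5, 5)).shape == (5, 5)
_ = RmsNorm()(torch.randn(2, 4))          # default eps from torch.nn.RMSNorm
_ = RmsNorm(eps=1e-6)(torch.randn(2, 4))  # explicit eps forwarded to RMSNorm([4], eps)
```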
