Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backends/qualcomm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from .decompose_roll import DecomposeRoll
from .decompose_silu import DecomposeSilu
from .decompose_threshold import DecomposeThreshold
from .decompose_triu import DecomposeTriu
from .decompose_wrap_with_autocast import DecomposeWrapWithAutocast
from .expand_broadcast_tensor_shape import ExpandBroadcastTensorShape
from .fixed_linear_keep_dim import FixedLinearKeepDim
Expand Down Expand Up @@ -69,6 +70,7 @@
DecomposeRoll,
DecomposeSilu,
DecomposeThreshold,
DecomposeTriu,
DecomposeWrapWithAutocast,
ExpandBroadcastTensorShape,
FixedLinearKeepDim,
Expand Down
71 changes: 71 additions & 0 deletions backends/qualcomm/_passes/decompose_triu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import torch
from executorch.exir.pass_base import ExportPass, PassResult
from torch._decomp import get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx

from .utils import merge_decomposed_graph


class DecomposeTriu(ExportPass):
"""
Decompose triu during quantization or export stage
This allows LiftConstantScalarOperands to lift the scalar into a scalar_tensor.
Otherwise, after to_edge, the triu operation will be decomposed into several operations that include aten.ge.Scalar.
"""

def __init__(self) -> None:
super().__init__()

def _replace_output(
self, node: torch.fx.Node, output_node: torch.fx.Node, remap: Dict
):
for user in node.users.copy():
# remap
user.replace_input_with(
node,
remap[output_node.args[0]],
)

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
graph = graph_module.graph
decom_mappings = get_decompositions([torch.ops.aten.triu.default])

for node in graph.nodes:
if node.target == torch.ops.aten.triu.default:
input_args = [node.args[0].meta["val"]]
# The args[1], diagonal, is optional
if len(node.args) > 1:
input_args.append(node.args[1])
decomposed_module = make_fx(
node.target,
decomposition_table=decom_mappings,
tracing_mode="fake",
)(*input_args)

with graph.inserting_before(node):
# remap is used to map original node values to new node values,
# which ensures that reference to nodes are correctly updated in the new graph
remap = {}
remap["arg0_1"] = node.args[0]

merge_decomposed_graph(
remap=remap,
target_node=node,
target_graph=graph,
decomposed_graph_module=decomposed_module,
predicate=lambda decomp_node: "arg1_1" not in decomp_node.name,
output_processor=self._replace_output,
)
graph.erase_node(node)

graph.eliminate_dead_code()
graph_module.recompile()
return PassResult(graph_module, True)
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class TensorOpInfo:
aten.where.ScalarOther: TensorOpInfo(aten.where.self, False, True),
aten.where.Scalar: TensorOpInfo(aten.where.self, False, True),
aten.masked_fill.Scalar: TensorOpInfo(aten.masked_fill.Tensor, False, False),
aten.masked_fill_.Scalar: TensorOpInfo(aten.masked_fill.Tensor, False, False),
aten.bitwise_xor.Scalar: TensorOpInfo(aten.bitwise_xor.Tensor, False, False),
}

Expand Down
3 changes: 3 additions & 0 deletions backends/qualcomm/_passes/qnn_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
DecomposeRoll,
DecomposeSilu,
DecomposeThreshold,
DecomposeTriu,
DecomposeWrapWithAutocast,
ExpandBroadcastTensorShape,
FixedLinearKeepDim,
Expand Down Expand Up @@ -204,6 +205,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
self.add_pass(DecomposeRoll())
self.add_pass(DecomposeSilu())
self.add_pass(DecomposeThreshold())
self.add_pass(DecomposeTriu())
self.add_pass(DecomposeWrapWithAutocast())
self.add_pass(DecomposeEinsum())
self.add_pass(DecomposeExpM1())
Expand All @@ -222,6 +224,7 @@ def transform_for_export_pipeline(
self.add_pass(DecomposeScaledDotProductAttention())
self.add_pass(DecomposeRoll())
self.add_pass(DecomposeThreshold())
self.add_pass(DecomposeTriu())
self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True))
self.add_pass(DecomposeExpM1())
# DecomposeFloorDivide does not apply to the annotation pipeline,
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class MyModel(torch.nn.Module):
```
At the time we try to lower it with Qualcomm backend:
```python
from excutorch.examples.qualcomm.utils import build_executorch_binary
from executorch.examples.qualcomm.utils import build_executorch_binary

build_executorch_binary(
model=MyModel(),
Expand Down
19 changes: 6 additions & 13 deletions backends/qualcomm/builders/op_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from .node_visitor import NodeVisitor
from .node_visitor_manager import register_node_visitor
from .qnn_constants import OpFullyConnected, QNN_OP_PACKAGE_NAME_QTI_AISW
from .utils import get_parameter


@register_node_visitor
Expand Down Expand Up @@ -55,32 +54,26 @@ def define_node(
quant_attrs[QCOM_ZERO_POINTS] = quant_attrs[QCOM_ZERO_POINTS].reshape(
[-1, 1]
)

weight_tensor = get_parameter(weight_node, self.edge_program)
weight_tensor = self.get_tensor(weight_node, node)
weight_tensor_wrapper = self.define_tensor(
weight_node,
node,
weight_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
# It will determine correct QNN tensor type in define_tensor.
# This param seems unnecessary, which we could possibly remove this in the future.
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)
linear_input_tensors.append(weight_tensor_wrapper)

if len(node.args) >= 3:
bias_node = self.get_node(node.args[2])

bias_tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC
bias_tensor = get_parameter(bias_node, self.edge_program)
# if bias_node is getitem
if bias_tensor is None:
bias_tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
bias_tensor = bias_node.meta["val"]

bias_tensor = self.get_tensor(bias_node, node)
bias_tensor_wrapper = self.define_tensor(
bias_node,
node,
bias_tensor,
bias_tensor_type,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)
linear_input_tensors.append(bias_tensor_wrapper)
Expand Down
4 changes: 3 additions & 1 deletion backends/qualcomm/builders/op_rms_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,9 @@ def define_node(
nodes_to_wrappers,
)

epsilon = node.args[3]
epsilon = torch.finfo(torch.float32).eps
if len(node.args) > 3:
epsilon = node.args[3]
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
node,
Expand Down
79 changes: 66 additions & 13 deletions backends/qualcomm/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1283,12 +1283,18 @@ class LargeTensorLinear(torch.nn.Module):
def __init__(self):
super().__init__()
hidden_dim = 4096
self.linear1 = torch.nn.Linear(512, hidden_dim)
self.linear1_1 = torch.nn.Linear(512, hidden_dim)
self.linear1_2 = torch.nn.Linear(512, hidden_dim)
self.linear1_3 = torch.nn.Linear(512, hidden_dim)
self.linear2 = torch.nn.Linear(hidden_dim, 512)
self.linear3 = torch.nn.Linear(hidden_dim, 512)
self.linear4 = torch.nn.Linear(hidden_dim, 512)

def forward(self, x):
x1 = self.linear1(x) + self.linear1(x)
return self.linear2(x1)
x1 = self.linear1_1(x) + self.linear1_1(x)
x2 = self.linear1_2(x) + self.linear1_2(x)
x3 = self.linear1_3(x) + self.linear1_3(x)
return self.linear2(x1) * self.linear3(x2) * self.linear4(x3)


class LayerNorm(torch.nn.Module):
Expand Down Expand Up @@ -1371,6 +1377,19 @@ def forward(self, x):
return x + N


class LinalgVectorNorm(torch.nn.Module):
def __init__(self, ord=2.0, dim=None, keepdim=False):
super().__init__()
self.ord = ord
self.dim = dim
self.keepdim = keepdim

def forward(self, x):
return torch.linalg.vector_norm(
x, ord=self.ord, dim=self.dim, keepdim=self.keepdim
)


class Linear(torch.nn.Module):
def __init__(self, use_bias: bool = True):
super().__init__()
Expand All @@ -1380,17 +1399,24 @@ def forward(self, x):
return self.linear(x)


class LinalgVectorNorm(torch.nn.Module):
def __init__(self, ord=2.0, dim=None, keepdim=False):
class LinearNonConstantWeight(torch.nn.Module):
def __init__(self):
super().__init__()
self.ord = ord
self.dim = dim
self.keepdim = keepdim
self.input_dim = 512
self.output_dim = 128
self.linear = torch.nn.Linear(self.input_dim, 3 * self.output_dim, True).eval()

def forward(self, x):
return torch.linalg.vector_norm(
x, ord=self.ord, dim=self.dim, keepdim=self.keepdim
w_q, w_k, w_v = self.linear.weight.split(
[self.output_dim, self.output_dim, self.output_dim]
)
b_q, b_k, b_v = self.linear.bias.split(
[self.output_dim, self.output_dim, self.output_dim]
)
q = torch.nn.functional.linear(x, w_q, b_q)
k = torch.nn.functional.linear(x, w_k, b_k)
v = torch.nn.functional.linear(x, w_v, b_v)
return q * k * v


class Log(torch.nn.Module):
Expand Down Expand Up @@ -1814,10 +1840,11 @@ def forward(self, x):


class RmsNorm(torch.nn.Module):
def __init__(self):
def __init__(self, eps=None):
super().__init__()
self.eps = 1e-5
self.rms = torch.nn.RMSNorm([4], 1e-5)
self.rms = torch.nn.RMSNorm([4])
if eps:
self.rms = torch.nn.RMSNorm([4], eps)

def forward(self, x):
return self.rms(x)
Expand Down Expand Up @@ -2141,6 +2168,32 @@ def forward(self, x):
return a + self.idx_source[b]


class Triu(torch.nn.Module):
def __init__(self, diagonal: Optional[int] = None):
super().__init__()
self.diagonal = diagonal

def forward(self, x):
if self.diagonal:
return torch.triu(x, diagonal=self.diagonal)
return torch.triu(x)


class TriuConstant(torch.nn.Module):
def __init__(self, diagonal, constant_dtype=torch.float32):
super().__init__()
self.diagonal = diagonal
self.constant_dtype = constant_dtype
self.register_buffer("mask", torch.ones((5, 5), dtype=constant_dtype))

def forward(self, x):
mask = torch.triu(self.mask, diagonal=self.diagonal)
if self.constant_dtype == torch.bool:
mask = torch.zeros(x.shape, dtype=x.dtype).masked_fill_(mask, -10000.0)
# Add x to avoid no input in graph
return mask + x


class Unbind(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
Loading
Loading