Commit 0aa8c63

Merge branch 'pytorch:main' into enable_qnn_sm8850

2 parents: 782338d + 1864fb0

File tree: 14 files changed, +499 −143 lines

.github/workflows/pull.yml

Lines changed: 11 additions & 2 deletions
@@ -862,15 +862,24 @@ jobs:
         # Install Node.js and Emscripten
         source .ci/scripts/setup-emscripten.sh
 
+        export PNPM_VERSION=10.24.0
+
+        curl -fsSL https://get.pnpm.io/install.sh | env PNPM_VERSION=$PNPM_VERSION SHELL="$(which bash)" sh -
+
+        export PNPM_HOME="$HOME/.local/share/pnpm"
+        export PATH="$PNPM_HOME:$PATH"
+
+        pnpm --version
+
         # Test selective build
         bash scripts/build_wasm_tests.sh ${{ matrix.enable-etdump }}
 
         # Install Jest
         cd cmake-out-wasm/extension/wasm/test
-        npm install --save-dev jest
+        pnpm add -D jest@30.2.0 --ignore-scripts
 
         # Run unit test
-        npm test
+        pnpm test
 
   unittest-nxp-neutron:
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main

backends/cadence/aot/replace_ops.py

Lines changed: 88 additions & 73 deletions
@@ -21,7 +21,6 @@
 import torch.fx
 from executorch.backends.cadence.aot.compiler_utils import (
     get_shape,
-    get_tensor_from_attr,
     get_zero_point,
     is_node_with_op,
     quantize_tensor_multiplier,
@@ -321,90 +320,106 @@ def call_operator(self, op, args, kwargs, meta):
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
-class ReplaceAddMMWithLinearPass(ExportPass):
+class ReplaceAddMMWithLinearPass(RemoveOrReplacePassInterface):
     """
     This pass replaces addmm with linear op.
+
+    AddMM computes: beta*bias + alpha*mm(mat1, mat2)
+    Linear computes: mat1 @ weight.T + bias
+
     """
 
-    def __init__(self):
-        super().__init__()
-        self.counter = 0
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [exir_ops.edge.aten.addmm.default]
 
-    def replace_addmm_with_linear(self, graph_module: torch.fx.GraphModule):
-        graph = graph_module.graph
-        for node in graph.nodes:
-            # We are only interested in admm nodes
-            if node.target != exir_ops.edge.aten.addmm.default:
-                continue
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        # The addmm op has three concrete args: bias, mat1, mat2
+        assert len(node.args) >= 3
+        (bias, mat1, mat2) = node.args[0:3]
 
-            # The addmm op has three concrete args: input, mat1, mat2
-            assert len(node.args) >= 3
-            (bias, mat1, mat2) = node.args[0:3]
-            # The other two args are optional scale args
-            beta = node.kwargs.get("beta", 1.0)
-            alpha = node.kwargs.get("alpha", 1.0)
-
-            # AddMM performs beta*bias + alpha*mm(mat1, mat2). We can convert
-            # it to linear op by multiplying beta to bias, and alpha to mat2.t().
-            # However, the following two conditions must hold:
-            # a. If bias is not a param, then beta must be 1.0
-            # b. If mat2 is not a param, then mat2 must be a transpose op. Also,
-            # the input to the transpose must be a param, or alpha must be 1.0.
-            fit_bias = is_node_with_op(bias, "get_attr") or beta == 1.0
-            fit_mat2 = is_node_with_op(mat2, "get_attr")
-            transposed_mat2 = False
-            if (
-                not fit_mat2
-                and is_node_with_op(mat2, "call_function")
-                and mat2.target == exir_ops.edge.aten.transpose_copy.int
-            ):
-                mat2, transposed_mat2 = mat2.args[0], True
-                fit_mat2 = is_node_with_op(mat2, "get_attr") or alpha == 1.0
+        # The other two args are optional scale args
+        beta = float(node.kwargs.get("beta", 1.0))
+        alpha = float(node.kwargs.get("alpha", 1.0))
 
-            if not fit_bias or not fit_mat2:
-                continue
+        bias, mat1, mat2 = cast(
+            tuple[torch.fx.Node, torch.fx.Node, torch.fx.Node],
+            (bias, mat1, mat2),
+        )
+
+        graph = node.graph
+
+        # Handle transpose: if mat2 is a transpose op, extract the original tensor
+        transposed_mat2 = False
+        if (
+            mat2.op == "call_function"
+            and mat2.target == exir_ops.edge.aten.transpose_copy.int
+        ):
+            # mat2 is already transposed, so we use the input to the transpose
+            mat2 = cast(torch.fx.Node, mat2.args[0])
+            transposed_mat2 = True
+
+        # Multiply bias by beta if needed
+        if beta != 1.0:
+            # Create a scaled bias using element-wise multiplication in the graph
+            with graph.inserting_before(node):
+                beta_scalar = graph.call_function(
+                    exir_ops.edge.aten.full.default,
+                    args=([1], beta),
+                    kwargs={"dtype": torch.float32},
+                )
+                beta_scalar.meta = node.meta
+                bias = graph.call_function(
+                    exir_ops.edge.aten.mul.Tensor,
+                    args=(bias, beta_scalar),
+                )
 
-            # Multiply bias by beta
-            if beta != 1.0:
-                assert is_node_with_op(bias, "get_attr")
-                bias_tensor = get_tensor_from_attr(graph_module, bias)
-                assert isinstance(bias_tensor, torch.Tensor)
-                bias_tensor = beta * bias_tensor
-                with graph.inserting_before(node):
-                    bias_name = f"_bias_addmm_to_linear_{self.counter}"
-                    graph_module.register_buffer(bias_name, bias_tensor)
-                    bias = graph.get_attr(bias_name)
-
-            # Use associativity of scalar multiplication, and multiply alpha to mat2
-            if is_node_with_op(mat2, "get_attr"):
-                mat2_tensor = get_tensor_from_attr(graph_module, mat2)
-                assert isinstance(mat2_tensor, torch.Tensor)
-                mat2_tensor = alpha * mat2_tensor
-                # transpose mat2
-                mat2_tensor = mat2_tensor if transposed_mat2 else mat2_tensor.t()
-                with graph.inserting_before(node):
-                    mat2_name = f"_mat2_addmm_to_linear_{self.counter}"
-                    graph_module.register_buffer(mat2_name, mat2_tensor)
-                    mat2 = graph.get_attr(mat2_name)
-
-            # Construct the linear node
-            linear_args = (mat1, mat2, bias)
+            # Metadata copy important
+            bias.meta = node.meta
+
+        # Multiply mat2 by alpha if needed
+        if alpha != 1.0:
             with graph.inserting_before(node):
-                linear_node = graph.call_function(
-                    exir_ops.edge.aten.linear.default, args=linear_args
+                alpha_scalar = graph.call_function(
+                    exir_ops.edge.aten.full.default,
+                    args=([1], alpha),
+                    kwargs={"dtype": torch.float32},
+                )
+                alpha_scalar.meta = node.meta
+                mat2 = graph.call_function(
+                    exir_ops.edge.aten.mul.Tensor,
+                    args=(mat2, alpha_scalar),
                 )
-                linear_node.meta = node.meta
-                # Replace all the uses of the addmm op with linear op
-                node.replace_all_uses_with(linear_node)
-                self.counter += 1
 
-        graph_module.recompile()
-        graph_module.graph.eliminate_dead_code()
+            # Metadata copy important
+            mat2.meta = node.meta
 
-    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        self.replace_addmm_with_linear(graph_module)
-        result = super().call(graph_module)
-        return result
+        # Transpose mat2 if it wasn't already transposed
+        if not transposed_mat2:
+            with graph.inserting_before(node):
+                mat2 = graph.call_function(
+                    exir_ops.edge.aten.transpose_copy.int,
+                    args=(mat2, -1, -2),
+                )
+
+            # Metadata copy important
+            mat2.meta = node.meta
+
+        # Construct the linear node: linear(input, weight, bias)
+        # linear computes: input @ weight.T + bias
+        linear_args = (mat1, mat2, bias)
+        with graph.inserting_before(node):
+            linear_node = graph.call_function(
+                exir_ops.edge.aten.linear.default,
+                args=linear_args,
+            )
+
+        # Metadata copy important
+        linear_node.meta = node.meta
+
+        # Replace all uses of the addmm op with linear op
+        node.replace_all_uses_with(linear_node)
+        return True
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
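
For readers unfamiliar with the graph-surgery pattern the rewritten pass relies on (inserting_before, call_function, replace_all_uses_with), here is a minimal self-contained sketch of the same addmm-to-linear rewrite using vanilla torch.fx. It is illustrative only: torch.addmm and torch.nn.functional.linear stand in for the exir edge ops, the alpha/beta scaling is omitted, and the module M is hypothetical, not part of this commit.

    import torch
    import torch.fx
    import torch.nn.functional as F

    class M(torch.nn.Module):
        def forward(self, bias, mat1, mat2):
            return torch.addmm(bias, mat1, mat2)

    gm = torch.fx.symbolic_trace(M())
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and node.target == torch.addmm:
            bias, mat1, mat2 = node.args
            with gm.graph.inserting_before(node):
                # F.linear(input, weight, bias) computes input @ weight.T + bias,
                # so mat2 must be transposed before it becomes the weight.
                weight = gm.graph.call_function(torch.transpose, args=(mat2, -1, -2))
                linear = gm.graph.call_function(F.linear, args=(mat1, weight, bias))
            linear.meta = node.meta  # preserve metadata, as the pass does
            node.replace_all_uses_with(linear)
            gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()

    bias, mat1, mat2 = torch.randn(10), torch.randn(14, 12), torch.randn(12, 10)
    assert torch.allclose(gm(bias, mat1, mat2), torch.addmm(bias, mat1, mat2))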

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 66 additions & 25 deletions
@@ -65,7 +65,19 @@ def validate(
     modified: torch.fx.GraphModule,
     inputs: tuple[torch.Tensor, ...] | list[torch.Tensor],
     pass_name: str,
+    rtol: float = 1e-5,
+    atol: float = 1e-6,
 ) -> None:
+    """Validate that two graph modules produce numerically equivalent outputs.
+
+    Args:
+        original: The original graph module before the pass
+        modified: The modified graph module after the pass
+        inputs: Input tensors to run through both graphs
+        pass_name: Name of the pass being validated (for error messages)
+        rtol: Relative tolerance for allclose comparison
+        atol: Absolute tolerance for allclose comparison
+    """
     original.eval()
     modified.eval()
     with torch.no_grad():
@@ -74,10 +86,17 @@ def validate(
 
     flat_orig_out, _ = pytree.tree_flatten(orig_out)
     flat_mod_out, _ = pytree.tree_flatten(mod_out)
-    if not all(pytree.tree_map(torch.equal, flat_orig_out, flat_mod_out)):
-        raise AssertionError(
-            f"Pass validation failed with exact match for pass {pass_name}. Original graph {original} and modified graph {modified}"
-        )
+
+    # Check that outputs match within tolerance
+    for i, (orig_tensor, mod_tensor) in enumerate(zip(flat_orig_out, flat_mod_out)):
+        if not torch.allclose(orig_tensor, mod_tensor, rtol=rtol, atol=atol):
+            max_diff = torch.max(torch.abs(orig_tensor - mod_tensor)).item()
+            raise AssertionError(
+                f"Pass validation failed for pass {pass_name}. "
+                f"Output tensor {i} differs by max {max_diff:.6e}. "
+                f"Expected rtol={rtol}, atol={atol}. "
+                f"Original output: {orig_tensor}, Modified output: {mod_tensor}"
+            )
 
 
 class TestReplaceOpsPasses(unittest.TestCase):
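
The move from torch.equal to torch.allclose above is deliberate: the rewritten pass applies the alpha/beta scaling through separate full/mul nodes, so outputs can drift from the original addmm by a few ULPs, and a bitwise-exact comparison would fail spuriously. A small standalone illustration (the values are made up):

    import torch

    a = torch.tensor([1.0, 2.0, 3.0])
    b = a + 1e-6  # tiny drift, e.g. from reordering floating-point multiplies

    print(torch.equal(a, b))                           # False: bitwise-exact check
    print(torch.allclose(a, b, rtol=1e-5, atol=1e-6))  # True: within tolerance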
@@ -840,10 +859,10 @@ def test_replace_scalar_tensor_with_full(
     def test_replace_linear_with_fully_connected(self) -> None:
         shape, in_channels, out_channels = (1, 14), 14, 128
         builder = GraphBuilder()
-        x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32))
-        weights = builder.placeholder(
-            "weights", torch.randn([out_channels, in_channels], dtype=torch.float32)
-        )
+        x_input = torch.randn(*shape, dtype=torch.float32)
+        weights_input = torch.randn([out_channels, in_channels], dtype=torch.float32)
+        x = builder.placeholder("x", x_input)
+        weights = builder.placeholder("weights", weights_input)
         permute_copy = builder.call_operator(
             op=exir_ops.edge.aten.permute_copy.default,
             args=(weights, [1, 0]),
@@ -854,14 +873,31 @@ def test_replace_linear_with_fully_connected(self) -> None:
         )
         builder.output([mm])
         original_gm = builder.get_graph_module()
+
         gm = cast(
             PassResult, ReplacePermuteWithTransposePass()(original_gm)
         ).graph_module
         gm = cast(PassResult, ReplaceMMWithAddMMPass()(gm)).graph_module
-        gm = cast(PassResult, ReplaceAddMMWithLinearPass()(gm)).graph_module
+
+        gm_before_linear = copy.deepcopy(gm)
+        pass_result = cast(PassResult, ReplaceAddMMWithLinearPass()(gm))
+        self.assertTrue(pass_result.modified)
+        gm = pass_result.graph_module
+
+        inputs = [x_input, weights_input]
+        validate(gm_before_linear, gm, inputs, "ReplaceAddMMWithLinearPass")
+        gm_before_fc = copy.deepcopy(gm)
         graph_after_passes = cast(
             PassResult, ReplaceLinearWithFullyConnectedOpPass()(gm)
         ).graph_module
+
+        validate(
+            gm_before_fc,
+            graph_after_passes,
+            inputs,
+            "ReplaceLinearWithFullyConnectedOpPass",
+        )
+
         self.assertIsNotNone(graph_after_passes)
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.full.default),
@@ -878,21 +914,17 @@ def test_replace_linear_with_fully_connected(self) -> None:
             0,
         )
 
-    @expand(
-        [
-            [(4, 16, 256), 256, 512, True],
-            [(7, 17, 12), 12, 34, False],
-        ]
-    )
+    @expand([[1.0, 1.0], [2.0, 3.0]])
     @torch.no_grad()
-    def test_replace_addmm_with_linear(
-        self, shape: Tuple[int], in_features: int, out_features: int, bias: bool
-    ) -> None:
-        M, K, N, alpha, beta = 14, 48, 24, 1.0, 1.0
+    def test_replace_addmm_with_linear(self, alpha: float, beta: float) -> None:
+        M, K, N = 14, 12, 10
         builder = GraphBuilder()
-        x = builder.placeholder("x", torch.randn(N, dtype=torch.float32))
-        y = builder.placeholder("y", torch.randn([M, K], dtype=torch.float32))
-        z = builder.placeholder("z", torch.randn([N, K], dtype=torch.float32))
+        x_input = torch.randn(N, dtype=torch.float32)
+        y_input = torch.randn([M, K], dtype=torch.float32)
+        z_input = torch.randn([N, K], dtype=torch.float32)
+        x = builder.placeholder("x", x_input)
+        y = builder.placeholder("y", y_input)
+        z = builder.placeholder("z", z_input)
         permute_copy = builder.call_operator(
             op=exir_ops.edge.aten.permute_copy.default,
             args=(z, [1, 0]),
@@ -904,12 +936,21 @@ def test_replace_addmm_with_linear(
         )
         builder.output([addmm])
         original_gm = builder.get_graph_module()
+
         gm = cast(
             PassResult, ReplacePermuteWithTransposePass()(original_gm)
         ).graph_module
-        graph_after_passes = cast(
-            PassResult, ReplaceAddMMWithLinearPass()(gm)
-        ).graph_module
+
+        gm_before_linear = copy.deepcopy(gm)
+        pass_result = cast(PassResult, ReplaceAddMMWithLinearPass()(gm))
+        self.assertTrue(pass_result.modified)
+        graph_after_passes = pass_result.graph_module
+
+        inputs = [x_input, y_input, z_input]
+        validate(
+            gm_before_linear, graph_after_passes, inputs, "ReplaceAddMMWithLinearPass"
+        )
+
         self.assertIsNotNone(graph_after_passes)
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.linear.default),
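
The parametrized alpha/beta cases exercise the algebraic identity the pass depends on: beta*bias + alpha*(mat1 @ mat2) equals a linear op whose weight absorbs alpha and whose bias absorbs beta. A standalone sketch of that identity in plain PyTorch, using the same shapes as the test (variable names are illustrative only):

    import torch
    import torch.nn.functional as F

    M, K, N = 14, 12, 10
    x = torch.randn(N)        # bias
    y = torch.randn(M, K)     # mat1
    z = torch.randn(N, K)     # permuted into mat2 by the test graph
    alpha, beta = 2.0, 3.0

    addmm_out = torch.addmm(x, y, z.t(), beta=beta, alpha=alpha)
    # F.linear(input, weight, bias) = input @ weight.T + bias, so folding
    # alpha into the weight and beta into the bias reproduces addmm:
    linear_out = F.linear(y, alpha * z, beta * x)
    assert torch.allclose(addmm_out, linear_out, rtol=1e-5, atol=1e-6)

    # The pass scales in-graph with rank-1 full tensors; broadcasting makes
    # full([1], beta) * bias equivalent to the scalar multiply above:
    assert torch.allclose(beta * x, x * torch.full([1], beta))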

backends/cortex_m/test/targets.bzl

Lines changed: 2 additions & 1 deletion
@@ -23,7 +23,8 @@ def define_operator_test_target(op):
             "//executorch/runtime/kernel:kernel_includes",
             "//executorch/kernels/test:test_util",
             "//executorch/backends/cortex_m/ops:op_{}".format(op),
-            "//executorch/backends/cortex_m/ops:cortex_m_generated_lib",
+            "//executorch/backends/cortex_m/ops:op_quantize_per_tensor",
+            "//executorch/backends/cortex_m/ops:op_dequantize_per_tensor",
             "//executorch/backends/cortex_m/ops:cortex_m_generated_lib_headers",
         ]
     )
