From d7519018674c8b62a1180f3226a5c12f5cae77d0 Mon Sep 17 00:00:00 2001 From: Cheng-Hsin Weng Date: Mon, 18 Aug 2025 16:58:27 +0800 Subject: [PATCH 1/3] Qualcomm AI Engine Direct - Improve CLI tools --- backends/qualcomm/tests/models.py | 8 ++ backends/qualcomm/tests/test_qnn_delegate.py | 83 +++++++++++++- .../executor_runner/qnn_executor_runner.cpp | 5 + examples/qualcomm/util_scripts/cli.py | 103 +++++++++++++----- examples/qualcomm/utils.py | 27 ++--- 5 files changed, 184 insertions(+), 42 deletions(-) diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index cdd0c194fe3..407e73e79ef 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -2042,6 +2042,14 @@ def forward(self, x, y): return torch.sub(x, y) +class Sub_y_x_from_x_y(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.sub(y, x) + + class SubAlpha(torch.nn.Module): def __init__(self, alpha): super().__init__() diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index af79256591d..7fcff7fadd4 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -8266,6 +8266,9 @@ def test_export_example(self): class TestUtilsScript(TestQNN): + TestQNN.atol = 1e-1 + TestQNN.rtol = 1 + def required_envs(self, conditions=None) -> bool: conditions = [] if conditions is None else conditions return all( @@ -8407,13 +8410,91 @@ def test_cli(self): self.target, "--device", self.device, + "--host", + self.host, + "--build_folder", + self.build_folder, + "--input_list", + f"{tmp_dir}/input_list", + ] + subprocess.run(cmds, stdout=subprocess.DEVNULL) + self.assertTrue(os.path.isfile(f"{tmp_dir}/e_out/Result_0/output_0.pt")) + + def test_cli_with_input_list_assignment(self): + with tempfile.TemporaryDirectory() as tmp_dir: + sample_input = torch.randn(1, 2, 3, 4) + sample_input2 = torch.randn(1, 2, 3, 4) + ep = torch.export.export( + Sub_y_x_from_x_y(), (sample_input, sample_input2) + ) # noqa: F405 + torch.export.save(ep, f"{tmp_dir}/sub.pt2") + torch.save(sample_input, f"{tmp_dir}/input_0_0.pt") + torch.save(sample_input2, f"{tmp_dir}/input_0_1.pt") + with open(f"{tmp_dir}/input_list", "w") as f: + f.write(f"x:={tmp_dir}/input_0_0.pt y:={tmp_dir}/input_0_1.pt\n") + + # quantize + cmds = [ + "python", + "-m", + "examples.qualcomm.util_scripts.cli", + "quantize", + "--artifact", + f"{tmp_dir}/sub.pt2", + "--output_folder", + f"{tmp_dir}/q_out", + "--input_list", + f"{tmp_dir}/input_list", + ] + subprocess.run(cmds, stdout=subprocess.DEVNULL) + self.assertTrue(os.path.isfile(f"{tmp_dir}/q_out/sub_quantized.pt2")) + # compile + cmds = [ + "python", + "-m", + "examples.qualcomm.util_scripts.cli", + "compile", + "--artifact", + f"{tmp_dir}/q_out/sub_quantized.pt2", + "--output_folder", + f"{tmp_dir}/c_out", + "--model", + self.model, + ] + subprocess.run(cmds, stdout=subprocess.DEVNULL) + self.assertTrue(os.path.isfile(f"{tmp_dir}/c_out/sub_quantized.pte")) + self.assertTrue(os.path.isfile(f"{tmp_dir}/c_out/sub_quantized.svg")) + # execute + cmds = [ + "python", + "-m", + "examples.qualcomm.util_scripts.cli", + "execute", + "--artifact", + f"{tmp_dir}/c_out/sub_quantized.pte", + "--output_folder", + f"{tmp_dir}/e_out", + "--model", + self.model, + "--target", + self.target, + "--device", + self.device, + "--host", + self.host, "--build_folder", self.build_folder, "--input_list", f"{tmp_dir}/input_list", ] + if self.host: + cmds.extend(["--host", self.host]) subprocess.run(cmds, stdout=subprocess.DEVNULL) - self.assertTrue(os.path.isfile(f"{tmp_dir}/e_out/output_0_0.pt")) + output_file = f"{tmp_dir}/e_out/Result_0/output_0.pt" + self.assertTrue(os.path.isfile(output_file)) + device_output = torch.load(output_file, weights_only=True) + golden_output = ep.module()(sample_input, sample_input2) + self._assert_outputs_equal(golden_output, device_output) def setup_environment(): diff --git a/examples/qualcomm/executor_runner/qnn_executor_runner.cpp b/examples/qualcomm/executor_runner/qnn_executor_runner.cpp index 47f9f0cfb38..a69537fda04 100644 --- a/examples/qualcomm/executor_runner/qnn_executor_runner.cpp +++ b/examples/qualcomm/executor_runner/qnn_executor_runner.cpp @@ -424,6 +424,11 @@ int main(int argc, char** argv) { int inference_index = 0; double elapsed_time = 0; while (std::getline(input_list, file_path)) { + // to avoid case where \r\n is used as EOL + if (!file_path.empty() && file_path.back() == '\r') { + file_path.pop_back(); + } + auto input_files = split(file_path, " "); if (input_files.size() == 0) { break; diff --git a/examples/qualcomm/util_scripts/cli.py b/examples/qualcomm/util_scripts/cli.py index 5cd411ec42f..c0db48a5280 100644 --- a/examples/qualcomm/util_scripts/cli.py +++ b/examples/qualcomm/util_scripts/cli.py @@ -9,6 +9,7 @@ import logging import os import re +import shutil from pathlib import Path import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManagerAdaptor @@ -34,16 +35,14 @@ to_edge_transform_and_lower_to_qnn, ) from executorch.examples.qualcomm.qaihub_scripts.utils.utils import preprocess_binary -from executorch.examples.qualcomm.utils import ( - make_output_dir, - make_quantizer, - SimpleADB, -) +from executorch.examples.qualcomm.utils import make_quantizer, SimpleADB from executorch.exir import ExecutorchBackendConfig from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass from torchao.quantization import pt2e from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +INPUT_ORDER = "input_order" + def get_logger(): logger = logging.getLogger("examples.qualcomm.util_scripts.cli") @@ -74,6 +73,7 @@ def fill_tensor_info(info, qnn_tensors, category): "offset": encoding.data["offset"].tolist(), "axis": encoding.axis, } + info[category].append( { "name": tensor.GetName(), @@ -106,6 +106,26 @@ def fill_tensor_info(info, qnn_tensors, category): return tensor_info +class InputListParser: + def __init__(self, input_list): + self.input_list = input_list + + def __iter__(self): + with open(self.input_list, "r") as f: + for line in re.split(r"\r?\n", f.read()): + if not line: + continue + split_line = line.strip().split(" ") + inputs = {} + if ":=" in line: + for input_assignment in split_line: + name, path = input_assignment.split(":=") + inputs[name] = torch.load(path, weights_only=True) + else: + inputs = [torch.load(t, weights_only=True) for t in split_line] + yield inputs + + def quantize(args): logger = get_logger() @@ -131,15 +151,21 @@ def quantize(args): ep_prepared = prepare_pt2e(ep.module(), quantizer) logger.info(f"perform calibration on {args.artifact}") # step 2: perform calibration - with open(args.input_list, "r") as f: - for line in f.read().split("\n")[:-1]: - inputs = [torch.load(t, weights_only=True) for t in line.split(" ")] - ep_prepared(*inputs) + input_list_parser = InputListParser(args.input_list) + graph_input_names = [ + spec.arg.name + for spec in ep.graph_signature.input_specs + if spec.kind.name == "USER_INPUT" + ] + for inputs in input_list_parser: + if isinstance(inputs, dict): + inputs = [inputs[name] for name in graph_input_names] + ep_prepared(*inputs) # step 3: use convert_pt2e to fix encodings of QDQ pairs logger.info(f"saving calibrated model for {args.artifact}") ep_converted = convert_pt2e(ep_prepared) ep_quantized = torch.export.export(ep_converted, tuple(inputs)) - make_output_dir(args.output_folder) + os.makedirs(args.output_folder, exist_ok=True) torch.export.save( ep_quantized, f"{args.output_folder}/{Path(args.artifact).stem}_quantized.pt2" ) @@ -155,7 +181,7 @@ def compile(args): ) file_name, extension = Path(args.artifact).stem, Path(args.artifact).suffix - make_output_dir(args.output_folder) + os.makedirs(args.output_folder, exist_ok=True) # setup compiler spec dedicated to QNN HTP backend backend_options = generate_htp_compiler_spec(use_fp16=True) # setup general compiler spec for QNN @@ -201,12 +227,13 @@ def compile(args): for user_pass in user_passes: passes[user_pass][QCOM_PASS_ACTIVATE_KEY] = True - + input_order = {INPUT_ORDER: ep.graph_signature.user_inputs} edge_prog_mgr = to_edge_transform_and_lower_to_qnn( module=ep.module(), inputs=sample_inputs, compiler_specs=compiler_specs, passes_job=passes, + constant_methods=input_order, ) # step 2: write pte files and store final graph logger.info(f"exporting {file_name}.pte") @@ -227,15 +254,30 @@ def execute(args): pte_name = Path(args.artifact).stem + # get input order + from executorch.runtime import Runtime, Verification + + et_runtime = Runtime.get() + program = et_runtime.load_program( + args.artifact, + verification=Verification.Minimal, + ) + input_order_func = program.load_method(INPUT_ORDER) + input_order = input_order_func.execute([]) + # load input files logger.info("loading user inputs") + input_list_parser = InputListParser(args.input_list) user_inputs = [] - with open(args.input_list, "r") as f: - for line in f.read().split("\n")[:-1]: - inputs, input_names = [], "" - for data in line.split(" "): - input_names += f"{Path(data).stem}.raw " - inputs.append(torch.load(data, weights_only=True)) + for inputs in input_list_parser: + if isinstance(inputs, dict): + ordered_inputs = [] + # since io_info is dict and it is ordered in python + # we use it to reorder input assignments here + for name in input_order: + ordered_inputs.append(inputs[name]) + user_inputs.append(ordered_inputs) + else: user_inputs.append(inputs) logger.info("retrieving graph I/O") @@ -247,7 +289,6 @@ def execute(args): backend_options=backend_options, ) io_info = get_io_info(args.artifact, compiler_specs) - logger.info("preparing ADB connection") # leverage SimpleADB for e2e inference adb = SimpleADB( @@ -263,11 +304,16 @@ def execute(args): ) logger.info("pushing QNN libraries & other artifacts") + adb.push(inputs=user_inputs) logger.info("starting inference") adb.execute() + tmp_dir = f"{args.output_folder}/tmp_outputs" + os.makedirs(tmp_dir, exist_ok=True) + os.makedirs(args.output_folder, exist_ok=True) + def post_process(): torch_to_numpy_dtype_dict = { torch.bool: np.dtype("bool"), @@ -283,11 +329,14 @@ def post_process(): torch.complex128: np.dtype("complex128"), } output_info = io_info["outputs"] - output_folder = f"{args.output_folder}/outputs" - for _, f in enumerate(os.listdir(output_folder)): - filename = os.path.join(output_folder, f) - match_res = re.match(r".*([0-9]+)_([0-9]+)\.raw$", filename) + tmp_output_folder = f"{tmp_dir}/outputs" + for _, f in enumerate(os.listdir(tmp_output_folder)): + filename = os.path.join(tmp_output_folder, f) + match_res = re.match(r".*output_([0-9]+)_([0-9]+)\.raw$", filename) data_index, output_index = int(match_res.group(1)), int(match_res.group(2)) + + output_result_folder = f"{args.output_folder}/Result_{data_index}" + os.makedirs(output_result_folder, exist_ok=True) output = np.fromfile( filename, dtype=eval( @@ -297,13 +346,11 @@ def post_process(): output = torch.from_numpy( output.reshape(output_info[output_index]["shape"]) ) - torch.save( - output, f"{args.output_folder}/output_{data_index}_{output_index}.pt" - ) + torch.save(output, f"{output_result_folder}/output_{output_index}.pt") logger.info("collecting output data") - make_output_dir(args.output_folder) - adb.pull(args.output_folder, post_process) + adb.pull(tmp_dir, post_process) + shutil.rmtree(tmp_dir) logger.info(f"execution finished, please check {args.output_folder} for results") diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py index bff7a0cb14f..920bad37ac4 100755 --- a/examples/qualcomm/utils.py +++ b/examples/qualcomm/utils.py @@ -148,23 +148,23 @@ def push(self, inputs=None, input_list=None, files=None, init_env=True): f"{self.build_path}/backends/qualcomm/libqnn_executorch_backend.so", f"{self.qnn_sdk}/lib/{self.target}/libQnnModelDlc.so", ] - input_list_file, input_files = generate_inputs( - self.working_dir, self.input_list_filename, inputs - ) + with tempfile.TemporaryDirectory() as tmp_dir: + input_list_file, input_files = generate_inputs( + tmp_dir, self.input_list_filename, inputs + ) - if input_list_file is not None: - # prepare input list - artifacts.append(input_list_file) + if input_list_file is not None: + # prepare input list + artifacts.append(input_list_file) - for artifact in artifacts: - self._adb(["push", artifact, self.workspace]) + for artifact in artifacts: + self._adb(["push", artifact, self.workspace]) - # input data - for file_name in input_files: - self._adb(["push", file_name, self.workspace]) + # input data + for file_name in input_files: + self._adb(["push", file_name, self.workspace]) - # dynamic shape related - with tempfile.TemporaryDirectory() as tmp_dir: + # dynamic shape related if self.expected_input_shape and self.expected_output_shape: shape_info = { "input_shape": self.expected_input_shape, @@ -956,6 +956,7 @@ def prepare_input_file(tensor, fd, index, sub_index): # Prepare input data if inputs is not None: input_list_file = f"{dest_path}/{file_name}" + with open(input_list_file, "w") as f: for idx, data in enumerate(inputs): sub_index = 0 From f243590851d9322b20c4c2c06478dd4d241f4944 Mon Sep 17 00:00:00 2001 From: chenweng-quic <168707118+chenweng-quic@users.noreply.github.com> Date: Tue, 2 Dec 2025 10:00:11 +0800 Subject: [PATCH 2/3] Update cli.py --- examples/qualcomm/util_scripts/cli.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/qualcomm/util_scripts/cli.py b/examples/qualcomm/util_scripts/cli.py index c0db48a5280..e969b66af3f 100644 --- a/examples/qualcomm/util_scripts/cli.py +++ b/examples/qualcomm/util_scripts/cli.py @@ -4,6 +4,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# This tool supports the QC internal QA pipeline by quantizing, compiling, +# and executing models under various configuration flags. + import argparse import importlib import logging From 1d29be5be7610510e2ab55118d0ce92da3984b6e Mon Sep 17 00:00:00 2001 From: chenweng-quic <168707118+chenweng-quic@users.noreply.github.com> Date: Tue, 2 Dec 2025 11:14:34 +0800 Subject: [PATCH 3/3] Update test_qnn_delegate.py --- backends/qualcomm/tests/test_qnn_delegate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 7fcff7fadd4..dc14373aa4f 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -8425,8 +8425,8 @@ def test_cli_with_input_list_assignment(self): sample_input = torch.randn(1, 2, 3, 4) sample_input2 = torch.randn(1, 2, 3, 4) ep = torch.export.export( - Sub_y_x_from_x_y(), (sample_input, sample_input2) - ) # noqa: F405 + Sub_y_x_from_x_y(), (sample_input, sample_input2) # noqa: F405 + ) torch.export.save(ep, f"{tmp_dir}/sub.pt2") torch.save(sample_input, f"{tmp_dir}/input_0_0.pt") torch.save(sample_input2, f"{tmp_dir}/input_0_1.pt")