8 changes: 8 additions & 0 deletions backends/qualcomm/tests/models.py
@@ -2042,6 +2042,14 @@ def forward(self, x, y):
return torch.sub(x, y)


class Sub_y_x_from_x_y(torch.nn.Module):
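    # Returns y - x while taking inputs in (x, y) order; used by the CLI
    # input-assignment test below, where relying on positional order alone
    # would feed the operands in the wrong order.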
def __init__(self):
super().__init__()

def forward(self, x, y):
return torch.sub(y, x)


class SubAlpha(torch.nn.Module):
def __init__(self, alpha):
super().__init__()
83 changes: 82 additions & 1 deletion backends/qualcomm/tests/test_qnn_delegate.py
@@ -8266,6 +8266,9 @@ def test_export_example(self):


class TestUtilsScript(TestQNN):
TestQNN.atol = 1e-1
TestQNN.rtol = 1

def required_envs(self, conditions=None) -> bool:
conditions = [] if conditions is None else conditions
return all(
@@ -8407,13 +8410,91 @@ def test_cli(self):
self.target,
"--device",
self.device,
"--host",
self.host,
"--build_folder",
self.build_folder,
"--input_list",
f"{tmp_dir}/input_list",
]
subprocess.run(cmds, stdout=subprocess.DEVNULL)
self.assertTrue(os.path.isfile(f"{tmp_dir}/e_out/Result_0/output_0.pt"))

def test_cli_with_input_list_assignment(self):
with tempfile.TemporaryDirectory() as tmp_dir:
sample_input = torch.randn(1, 2, 3, 4)
sample_input2 = torch.randn(1, 2, 3, 4)
ep = torch.export.export(
Sub_y_x_from_x_y(), (sample_input, sample_input2) # noqa: F405
)
torch.export.save(ep, f"{tmp_dir}/sub.pt2")
torch.save(sample_input, f"{tmp_dir}/input_0_0.pt")
torch.save(sample_input2, f"{tmp_dir}/input_0_1.pt")
with open(f"{tmp_dir}/input_list", "w") as f:
f.write(f"x:={tmp_dir}/input_0_0.pt y:={tmp_dir}/input_0_1.pt\n")

# quantize
cmds = [
"python",
"-m",
"examples.qualcomm.util_scripts.cli",
"quantize",
"--artifact",
f"{tmp_dir}/sub.pt2",
"--output_folder",
f"{tmp_dir}/q_out",
"--input_list",
f"{tmp_dir}/input_list",
]
subprocess.run(cmds, stdout=subprocess.DEVNULL)
self.assertTrue(os.path.isfile(f"{tmp_dir}/q_out/sub_quantized.pt2"))
# compile
cmds = [
"python",
"-m",
"examples.qualcomm.util_scripts.cli",
"compile",
"--artifact",
f"{tmp_dir}/q_out/sub_quantized.pt2",
"--output_folder",
f"{tmp_dir}/c_out",
"--model",
self.model,
]
subprocess.run(cmds, stdout=subprocess.DEVNULL)
self.assertTrue(os.path.isfile(f"{tmp_dir}/c_out/sub_quantized.pte"))
self.assertTrue(os.path.isfile(f"{tmp_dir}/c_out/sub_quantized.svg"))
# execute
cmds = [
"python",
"-m",
"examples.qualcomm.util_scripts.cli",
"execute",
"--artifact",
f"{tmp_dir}/c_out/sub_quantized.pte",
"--output_folder",
f"{tmp_dir}/e_out",
"--model",
self.model,
"--target",
self.target,
"--device",
self.device,
"--host",
self.host,
"--build_folder",
self.build_folder,
"--input_list",
f"{tmp_dir}/input_list",
]
if self.host:
cmds.extend(["--host", self.host])
subprocess.run(cmds, stdout=subprocess.DEVNULL)
self.assertTrue(os.path.isfile(f"{tmp_dir}/e_out/output_0_0.pt"))
output_file = f"{tmp_dir}/e_out/Result_0/output_0.pt"
self.assertTrue(os.path.isfile(output_file))
device_output = torch.load(output_file, weights_only=True)
golden_output = ep.module()(sample_input, sample_input2)
self._assert_outputs_equal(golden_output, device_output)


def setup_environment():
5 changes: 5 additions & 0 deletions examples/qualcomm/executor_runner/qnn_executor_runner.cpp
@@ -424,6 +424,11 @@ int main(int argc, char** argv) {
int inference_index = 0;
double elapsed_time = 0;
while (std::getline(input_list, file_path)) {
// Strip a trailing '\r' so input lists written with CRLF ("\r\n") line endings parse correctly.
if (!file_path.empty() && file_path.back() == '\r') {
file_path.pop_back();
}

auto input_files = split(file_path, " ");
if (input_files.size() == 0) {
break;
106 changes: 78 additions & 28 deletions examples/qualcomm/util_scripts/cli.py
@@ -4,11 +4,15 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# This tool supports the QC internal QA pipeline by quantizing, compiling,
# and executing models under various configuration flags.
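#
# A minimal usage sketch, based on the flags exercised in the tests
# (artifact and output paths are hypothetical):
#
#   python -m examples.qualcomm.util_scripts.cli quantize \
#       --artifact model.pt2 --output_folder q_out --input_list input_list
#   python -m examples.qualcomm.util_scripts.cli compile \
#       --artifact q_out/model_quantized.pt2 --output_folder c_out --model <soc>
#   python -m examples.qualcomm.util_scripts.cli execute \
#       --artifact c_out/model_quantized.pte --output_folder e_out \
#       --model <soc> --target <target> --device <serial> \
#       --build_folder <build_dir> --input_list input_list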

import argparse
import importlib
import logging
import os
import re
import shutil
from pathlib import Path

import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManagerAdaptor
@@ -34,16 +38,14 @@
to_edge_transform_and_lower_to_qnn,
)
from executorch.examples.qualcomm.qaihub_scripts.utils.utils import preprocess_binary
-from executorch.examples.qualcomm.utils import (
-    make_output_dir,
-    make_quantizer,
-    SimpleADB,
-)
+from executorch.examples.qualcomm.utils import make_quantizer, SimpleADB
from executorch.exir import ExecutorchBackendConfig
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
from torchao.quantization import pt2e
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

INPUT_ORDER = "input_order"
Contributor:
On a side note, it seems this cli tool is for QA and not necessarily for community users, can we document it?

Collaborator (Author):
Done, thanks.



def get_logger():
logger = logging.getLogger("examples.qualcomm.util_scripts.cli")
@@ -74,6 +76,7 @@ def fill_tensor_info(info, qnn_tensors, category):
"offset": encoding.data["offset"].tolist(),
"axis": encoding.axis,
}

info[category].append(
{
"name": tensor.GetName(),
@@ -106,6 +109,26 @@ def fill_tensor_info(info, qnn_tensors, category):
return tensor_info


class InputListParser:
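    """Iterates over an input_list file.

    Each non-empty line describes one set of model inputs, in one of two
    formats (paths hypothetical):

        positional:  inputs/input_0_0.pt inputs/input_0_1.pt
        assignment:  x:=inputs/input_0_0.pt y:=inputs/input_0_1.pt

    Positional lines yield a list of tensors; assignment lines yield a
    dict keyed by graph input name.
    """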
def __init__(self, input_list):
self.input_list = input_list

def __iter__(self):
with open(self.input_list, "r") as f:
for line in re.split(r"\r?\n", f.read()):
if not line:
continue
split_line = line.strip().split(" ")
inputs = {}
if ":=" in line:
for input_assignment in split_line:
name, path = input_assignment.split(":=")
inputs[name] = torch.load(path, weights_only=True)
else:
inputs = [torch.load(t, weights_only=True) for t in split_line]
yield inputs


def quantize(args):
logger = get_logger()

@@ -131,15 +154,21 @@ def quantize(args):
ep_prepared = prepare_pt2e(ep.module(), quantizer)
logger.info(f"perform calibration on {args.artifact}")
# step 2: perform calibration
-    with open(args.input_list, "r") as f:
-        for line in f.read().split("\n")[:-1]:
-            inputs = [torch.load(t, weights_only=True) for t in line.split(" ")]
-            ep_prepared(*inputs)
+    input_list_parser = InputListParser(args.input_list)
+    graph_input_names = [
+        spec.arg.name
+        for spec in ep.graph_signature.input_specs
+        if spec.kind.name == "USER_INPUT"
+    ]
+    for inputs in input_list_parser:
+        if isinstance(inputs, dict):
+            inputs = [inputs[name] for name in graph_input_names]
+        ep_prepared(*inputs)
# step 3: use convert_pt2e to fix encodings of QDQ pairs
logger.info(f"saving calibrated model for {args.artifact}")
ep_converted = convert_pt2e(ep_prepared)
ep_quantized = torch.export.export(ep_converted, tuple(inputs))
-    make_output_dir(args.output_folder)
+    os.makedirs(args.output_folder, exist_ok=True)
torch.export.save(
ep_quantized, f"{args.output_folder}/{Path(args.artifact).stem}_quantized.pt2"
)
@@ -155,7 +184,7 @@ def compile(args):
)

file_name, extension = Path(args.artifact).stem, Path(args.artifact).suffix
-    make_output_dir(args.output_folder)
+    os.makedirs(args.output_folder, exist_ok=True)
# setup compiler spec dedicated to QNN HTP backend
backend_options = generate_htp_compiler_spec(use_fp16=True)
# setup general compiler spec for QNN
@@ -201,12 +230,13 @@

for user_pass in user_passes:
passes[user_pass][QCOM_PASS_ACTIVATE_KEY] = True

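# Embed the user input order as a constant method so the `execute` stage
# can recover it from the compiled .pte and reorder named inputs.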
input_order = {INPUT_ORDER: ep.graph_signature.user_inputs}
edge_prog_mgr = to_edge_transform_and_lower_to_qnn(
module=ep.module(),
inputs=sample_inputs,
compiler_specs=compiler_specs,
passes_job=passes,
constant_methods=input_order,
)
# step 2: write pte files and store final graph
logger.info(f"exporting {file_name}.pte")
@@ -227,15 +257,30 @@ def execute(args):

pte_name = Path(args.artifact).stem

# get input order
from executorch.runtime import Runtime, Verification

et_runtime = Runtime.get()
program = et_runtime.load_program(
args.artifact,
verification=Verification.Minimal,
)
input_order_func = program.load_method(INPUT_ORDER)
input_order = input_order_func.execute([])
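# `input_order` lists the graph's user input names as recorded at compile
# time via `constant_methods`; it is used below to reorder assignment-style
# inputs to match the graph signature.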

# load input files
logger.info("loading user inputs")
+    input_list_parser = InputListParser(args.input_list)
     user_inputs = []
-    with open(args.input_list, "r") as f:
-        for line in f.read().split("\n")[:-1]:
-            inputs, input_names = [], ""
-            for data in line.split(" "):
-                input_names += f"{Path(data).stem}.raw "
-                inputs.append(torch.load(data, weights_only=True))
+    for inputs in input_list_parser:
+        if isinstance(inputs, dict):
+            ordered_inputs = []
+            # input_order preserves the compiled graph's user input order;
+            # use it to reorder named input assignments here
+            for name in input_order:
+                ordered_inputs.append(inputs[name])
+            user_inputs.append(ordered_inputs)
+        else:
+            user_inputs.append(inputs)

logger.info("retrieving graph I/O")
@@ -247,7 +292,6 @@
backend_options=backend_options,
)
io_info = get_io_info(args.artifact, compiler_specs)

logger.info("preparing ADB connection")
# leverage SimpleADB for e2e inference
adb = SimpleADB(
@@ -263,11 +307,16 @@
)

logger.info("pushing QNN libraries & other artifacts")

adb.push(inputs=user_inputs)

logger.info("starting inference")
adb.execute()

tmp_dir = f"{args.output_folder}/tmp_outputs"
os.makedirs(tmp_dir, exist_ok=True)
os.makedirs(args.output_folder, exist_ok=True)

def post_process():
torch_to_numpy_dtype_dict = {
torch.bool: np.dtype("bool"),
@@ -283,11 +332,14 @@ def post_process():
torch.complex128: np.dtype("complex128"),
}
output_info = io_info["outputs"]
-        output_folder = f"{args.output_folder}/outputs"
-        for _, f in enumerate(os.listdir(output_folder)):
-            filename = os.path.join(output_folder, f)
-            match_res = re.match(r".*([0-9]+)_([0-9]+)\.raw$", filename)
+        tmp_output_folder = f"{tmp_dir}/outputs"
+        for _, f in enumerate(os.listdir(tmp_output_folder)):
+            filename = os.path.join(tmp_output_folder, f)
+            match_res = re.match(r".*output_([0-9]+)_([0-9]+)\.raw$", filename)
data_index, output_index = int(match_res.group(1)), int(match_res.group(2))

output_result_folder = f"{args.output_folder}/Result_{data_index}"
os.makedirs(output_result_folder, exist_ok=True)
output = np.fromfile(
filename,
dtype=eval(
Expand All @@ -297,13 +349,11 @@ def post_process():
output = torch.from_numpy(
output.reshape(output_info[output_index]["shape"])
)
-            torch.save(
-                output, f"{args.output_folder}/output_{data_index}_{output_index}.pt"
-            )
+            torch.save(output, f"{output_result_folder}/output_{output_index}.pt")

logger.info("collecting output data")
-    make_output_dir(args.output_folder)
-    adb.pull(args.output_folder, post_process)
+    adb.pull(tmp_dir, post_process)
+    shutil.rmtree(tmp_dir)
logger.info(f"execution finished, please check {args.output_folder} for results")


27 changes: 14 additions & 13 deletions examples/qualcomm/utils.py
@@ -148,23 +148,23 @@ def push(self, inputs=None, input_list=None, files=None, init_env=True):
f"{self.build_path}/backends/qualcomm/libqnn_executorch_backend.so",
f"{self.qnn_sdk}/lib/{self.target}/libQnnModelDlc.so",
]
-        input_list_file, input_files = generate_inputs(
-            self.working_dir, self.input_list_filename, inputs
-        )
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            input_list_file, input_files = generate_inputs(
+                tmp_dir, self.input_list_filename, inputs
+            )

-        if input_list_file is not None:
-            # prepare input list
-            artifacts.append(input_list_file)
+            if input_list_file is not None:
+                # prepare input list
+                artifacts.append(input_list_file)

-        for artifact in artifacts:
-            self._adb(["push", artifact, self.workspace])
+            for artifact in artifacts:
+                self._adb(["push", artifact, self.workspace])

-        # input data
-        for file_name in input_files:
-            self._adb(["push", file_name, self.workspace])
+            # input data
+            for file_name in input_files:
+                self._adb(["push", file_name, self.workspace])

-        # dynamic shape related
-        with tempfile.TemporaryDirectory() as tmp_dir:
+            # dynamic shape related
             if self.expected_input_shape and self.expected_output_shape:
                 shape_info = {
                     "input_shape": self.expected_input_shape,
@@ -956,6 +956,7 @@ def prepare_input_file(tensor, fd, index, sub_index):
# Prepare input data
if inputs is not None:
input_list_file = f"{dest_path}/{file_name}"

with open(input_list_file, "w") as f:
for idx, data in enumerate(inputs):
sub_index = 0