2 changes: 1 addition & 1 deletion backends/qualcomm/builders/node_visitor.py
@@ -162,7 +162,7 @@ def make_qnn_per_block_config(self, node: torch.fx.Node, quant_attrs: Dict):
         for ch in range(num_channels):
             max_scale = scales[ch].reshape(1, -1).amax(dim=-1) / num_steps
             q_scales = torch.clamp(
-                input=scales[ch] / max_scale,
+                input=torch.round(input=scales[ch] / max_scale),
                 min=1,
                 max=2**bitwidth_of_scale,
             ).to(quant_scales_dtype)
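For context, the one-line change above rounds the per-channel scale ratios to the nearest integer before clamping, instead of letting the later dtype cast truncate them. A minimal, self-contained sketch of the per-block scale idea (the tensor shapes and the 4-bit scale width are illustrative assumptions, not taken from the ExecuTorch sources):

```python
import torch

# Illustrative per-block scale quantization: each channel's block scales are
# stored as small integer multiples of a channel-wide maximum scale.
bitwidth_of_scale = 4                     # assumed 4-bit block-scale width
num_steps = 2**bitwidth_of_scale
scales = torch.rand(8, 16) + 0.01         # hypothetical [channels, blocks] fp scales

quantized_scales = []
for ch in range(scales.shape[0]):
    max_scale = scales[ch].reshape(1, -1).amax(dim=-1) / num_steps
    # Rounding (the new behaviour) picks the nearest integer multiple;
    # truncation via the final cast would bias every block scale downward.
    q_scales = torch.clamp(
        input=torch.round(scales[ch] / max_scale),
        min=1,
        max=2**bitwidth_of_scale,
    ).to(torch.uint8)
    quantized_scales.append(q_scales)
    # Reconstructed block scales for an error check: q_scales * max_scale
```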
14 changes: 12 additions & 2 deletions backends/qualcomm/quantizer/custom_annotation.py
@@ -317,6 +317,7 @@ def annotate_matmul_input1(node: Node, is_qat: str):
                 torch.ops.aten.transpose.int,
                 torch.ops.aten.view.default,
                 torch.ops.aten.reshape.default,
+                torch.ops.aten.slice.Tensor,
             ]:
                 annotate_single_in_single_out(node, quantization_config_8a8w)
                 node = node.args[0]
@@ -340,7 +341,11 @@ def annotate_matmul_input1(node: Node, is_qat: str):
                     node, quantization_config=quantization_config_8a4w_per_channel
                 )
                 break
-            elif node.target in [torch.ops.aten.add.Tensor, torch.ops.aten.sub.Tensor]:
+            elif node.target in [
+                torch.ops.aten.add.Tensor,
+                torch.ops.aten.sub.Tensor,
+                torch.ops.aten.matmul.default,
+            ]:
                 break
             else:
                 print(f"The node ({node}) is not expected in the input1 of the matmul")
@@ -356,7 +361,12 @@ def annotate_matmul_input1(node: Node, is_qat: str):
         )
 
     for node in gm.graph.nodes:
-        if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
+        if (
+            node.op == "call_function"
+            and node.target == torch.ops.aten.matmul.default
+            and all(arg.op == "call_function" for arg in node.args)
+        ):
+            # Only apply custom annotation on Q @ K^T @ V
             annotate_matmul(node, quantization_config_16a8w)
             annotate_matmul_input1(node.args[1], is_qat=is_qat)
 
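For readers following the traversal above: `annotate_matmul_input1` starts at the second input of an attention matmul and walks its producer chain upward, annotating shape-only ops with the 8-bit activation config until it reaches an op that produces new values (now including another matmul, so a preceding Q @ K^T terminates the walk). A rough, simplified sketch of that pattern — `annotate_8a8w` is a placeholder callback, not the real ExecuTorch helper:

```python
import torch
from torch.fx import Node

# Ops that only rearrange data; safe to push 8-bit activation annotations through.
_PASS_THROUGH_OPS = {
    torch.ops.aten.permute.default,
    torch.ops.aten.transpose.int,
    torch.ops.aten.view.default,
    torch.ops.aten.reshape.default,
    torch.ops.aten.slice.Tensor,
}

# Ops where the walk stops, e.g. the preceding Q @ K^T matmul in an attention block.
_STOP_OPS = {
    torch.ops.aten.add.Tensor,
    torch.ops.aten.sub.Tensor,
    torch.ops.aten.matmul.default,
}


def walk_matmul_input1(node: Node, annotate_8a8w) -> None:
    """Simplified stand-in mirroring the traversal structure of annotate_matmul_input1."""
    while isinstance(node, Node) and node.op == "call_function":
        if node.target in _PASS_THROUGH_OPS:
            annotate_8a8w(node)  # annotate the shape op, then keep walking upward
            node = node.args[0]
        elif node.target in _STOP_OPS:
            break  # reached a value-producing op; nothing more to annotate here
        else:
            print(f"The node ({node}) is not expected in the input1 of the matmul")
            break
```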
8 changes: 0 additions & 8 deletions backends/qualcomm/runtime/backends/QnnImplementation.cpp
@@ -51,16 +51,8 @@ Error QnnImplementation::StartBackend(
     const std::string& lib_path,
     const QnnSaver_Config_t** saver_config) {
   Qnn_ErrorHandle_t error = QNN_SUCCESS;
-  // RTLD_GLOBAL is needed on x86 as HTP op package has a requirement for the
-  // symbols in backend to be visible. Using RTLD_LOCAL on Android to allow full
-  // unloading of HTP backend shared library on dlclose() as RTLD_GLOBAL isn't
-  // letting it happen.
   void* lib_handle = nullptr;
-#if defined(__ANDROID__)
-  lib_handle = dlopen(lib_path.c_str(), RTLD_NOW | RTLD_LOCAL);
-#else
   lib_handle = dlopen(lib_path.c_str(), RTLD_NOW | RTLD_GLOBAL);
-#endif
   if (lib_handle == nullptr) {
     QNN_EXECUTORCH_LOG_ERROR(
         "Cannot Open QNN library %s, with error: %s",
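The deleted comment above captured the trade-off: `RTLD_GLOBAL` makes the backend's symbols visible to the separately loaded HTP op package, while `RTLD_LOCAL` let the Android build fully unload the backend on `dlclose()`. The same flags are exposed through Python's `ctypes`, which gives a quick way to observe the visibility behaviour; the library paths below are hypothetical placeholders:

```python
import ctypes

# Hypothetical paths; substitute the real QNN backend and op-package libraries.
backend_path = "/path/to/libQnnHtp.so"
op_package_path = "/path/to/libQnnHtpOpPackage.so"

# Loading the backend with RTLD_GLOBAL exports its symbols to libraries loaded
# afterwards, so the op package can resolve them. If the backend were loaded
# with RTLD_LOCAL instead, the second load could fail on unresolved symbols.
backend = ctypes.CDLL(backend_path, mode=ctypes.RTLD_GLOBAL)
op_package = ctypes.CDLL(op_package_path)
```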
79 changes: 76 additions & 3 deletions backends/qualcomm/tests/test_qnn_delegate.py
@@ -1127,7 +1127,8 @@ def test_qnn_backend_where(self):
             (torch.randn(30, 20),),
         ]
         for i, module in enumerate(modules):
-            self.lower_module_and_test_output(module, sample_inputs[i])
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(module, sample_inputs[i])
 
     def test_qnn_backend_masked_fill(self):
         module = MaskedFill()  # noqa: F405
@@ -2556,8 +2557,9 @@ def test_qnn_backend_where(self):
             (torch.randn(30, 20),),
         ]
         for i, module in enumerate(modules):
-            module = self.get_qdq_module(module, sample_inputs[i])
-            self.lower_module_and_test_output(module, sample_inputs[i])
+            with self.subTest(i=i):
+                module = self.get_qdq_module(module, sample_inputs[i])
+                self.lower_module_and_test_output(module, sample_inputs[i])
 
     def test_qnn_backend_masked_fill(self):
         module = MaskedFill()  # noqa: F405
@@ -4527,6 +4529,77 @@ def test_llama_stories_110m(self):
         if not self.compile_only and not self.enable_x86_64:
             self.assertGreaterEqual(msg["inference_speed"], 220)  # Lanai
 
+    def test_static_phi4(self):
+        if not self.required_envs():
+            self.skipTest("missing required envs")
+
+        prompt = "My favourite condiment is "
+        cmds = [
+            "python",
+            f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py",
+            "--artifact",
+            self.artifact_dir,
+            "--build_folder",
+            self.build_folder,
+            "--model",
+            self.model,
+            "--ip",
+            self.ip,
+            "--port",
+            str(self.port),
+            "--prompt",
+            f"{prompt}",
+            "--ptq",
+            "16a4w_block",
+            "--group_size",
+            "16",
+            "--decoder_model",
+            "phi_4_mini",
+            "--model_mode",
+            "kv",
+            "--max_seq_len",
+            "1024",
+            "--num_sharding",
+            "8",
+            "--eval_perplexity",
+            "--tasks",
+            "wikitext",
+            "--limit",
+            "1",
+        ]
+        if self.compile_only:
+            cmds.extend(["--compile_only"])
+        elif self.device:
+            cmds.extend(["--device", self.device])
+            if self.host:
+                cmds.extend(["--host", self.host])
+        elif self.enable_x86_64:
+            cmds.extend(["--enable_x86_64"])
+        if self.pre_gen_pte:
+            cmds.extend(["--pre_gen_pte", self.pre_gen_pte])
+            cmds.extend(
+                [
+                    "--quant_attrs_path",
+                    f"{self.pre_gen_pte}/kv_llama_qnn_quant_attrs.json",
+                ]
+            )
+
+        p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
+        with Listener((self.ip, self.port)) as listener:
+            conn = listener.accept()
+            p.communicate()
+            msg = json.loads(conn.recv())
+            if "Error" in msg:
+                self.fail(msg["Error"])
+            else:
+                inference_speed_ref = {"SM8650": 14, "SM8750": 19}
+                self.assertLessEqual(msg["wiki_ppl"], 12)
+                self.assertLessEqual(msg["pte_size"], 4000000000)  # 4gb
+                if self.model in inference_speed_ref:
+                    self.assertGreaterEqual(
+                        msg["inference_speed"], inference_speed_ref[self.model]
+                    )
+
     def test_static_qwen2_5(self):
         if not self.required_envs():
             self.skipTest("missing required envs")
7 changes: 7 additions & 0 deletions examples/qualcomm/oss_scripts/llama/README.md
@@ -69,6 +69,12 @@ Default example using hybrid mode.
 python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1"
 ```
 
+#### Phi4-mini-instruct
+Default example using hybrid mode.
+```bash
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w_block --group_size 16 --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model phi_4_mini --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --num_sharding 8 --prompt "I would like to learn python, could you teach me with a simple example?"
+```
+
 #### QWEN2.5 0.5B
 Default example using hybrid mode
 ```bash
@@ -81,6 +87,7 @@ Default example using hybrid mode.
 python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -H mlgtw-linux -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a8w --tokenizer_bin tokenizer.bin --decoder_model smollm2 --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?"
 ```
 
+
 ### KV Cache update mechanism
 We have two distinct mechanisms for updating the key-value (KV) cache, which can be selected at runtime. Shift Pointer and Smart Mask.
 
30 changes: 25 additions & 5 deletions examples/qualcomm/oss_scripts/llama/llama.py
@@ -116,6 +116,8 @@
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=FORMAT)
 logging.getLogger().setLevel(logging.INFO)
+# Avoid the error message "Could not initialize NNPACK! Reason: Unsupported hardware."
+torch.backends.nnpack.set_flags(False)
 
 
 def next_power_of_two(n):
@@ -233,10 +235,16 @@ def quantize(
         ).module()
 
         if quant_dtype == QuantDtype.use_16a4w_block:
+            if args.group_size is None:
+                raise ValueError(
+                    "Group size is required when use quant_dtype 16a4w_block"
+                )
             conv_nodes = [
                 n for n in fx_graph_module.graph.nodes if "conv" in n.name
             ]
-            block_size_map = {n.name: (1, 64, 1, 1) for n in conv_nodes}
+            block_size_map = {
+                n.name: (1, args.group_size, 1, 1) for n in conv_nodes
+            }
             quantizer.set_block_size_map(block_size_map)
 
         fx_graph_module = prepare_pt2e(fx_graph_module, quantizer)
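To make the group-size wiring above concrete: every conv weight is given a block size of `(1, args.group_size, 1, 1)`, i.e. quantization parameters are computed per group of `group_size` input channels rather than the previously hard-coded 64. A hedged, standalone sketch of what that block shape means in plain PyTorch (the weight shape and 4-bit range are assumptions for illustration; this is not the QNN quantizer API itself):

```python
import torch

group_size = 16                        # value supplied via --group_size
weight = torch.randn(32, 64, 1, 1)     # hypothetical conv weight [out, in, kH, kW]

# Block size (1, group_size, 1, 1): one scale per output channel per group of
# `group_size` input channels. Reshape so every block sits on its own axis.
out_ch, in_ch, kh, kw = weight.shape
blocks = weight.reshape(out_ch, in_ch // group_size, group_size, kh, kw)

# One scale per block, here for a symmetric 4-bit weight range [-8, 7].
scales = blocks.abs().amax(dim=(2, 3, 4), keepdim=True).clamp(min=1e-8) / 7.0
q_blocks = torch.clamp(torch.round(blocks / scales), min=-8, max=7)
dequant = (q_blocks * scales).reshape(weight.shape)
```

In the flow above this shape is handed to `quantizer.set_block_size_map(...)`, so the group size chosen on the command line presumably has to divide the conv input-channel dimension evenly.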
@@ -584,7 +592,7 @@ def permute(w, heads):
     if args.ptq != "16a8w":
         # 16a8w use 16bit kv io, so skip this custom annotation
         custom_annotations = custom_annotations + (annotate_matmul_16a8w,)
-        if args.decoder_model in {"stories110m", "stories260k"}:
+        if args.decoder_model in {"stories110m", "stories260k", "phi_4_mini"}:
             custom_annotations = custom_annotations + (
                 annotate_linear_16a8w_in_affine_layer,
             )
@@ -801,12 +809,20 @@ def post_process():
 
     seq_len = args.max_seq_len
     multi_prompts = " ".join([f'--prompt "{prompt}"' for prompt in args.prompt])
+    lookahead_args = " ".join(
+        [
+            f"--window {args.window}",
+            f"--gcap {args.gcap}",
+            f"--ngram {args.ngram}",
+        ]
+    )
     runner_args = " ".join(
         [
             multi_prompts,
             f"--eval_mode {EVAL_MODE[args.model_mode]}",
             f"--temperature {args.temperature}",
             f"--system_prompt '{args.system_prompt}'",
+            lookahead_args if args.model_mode == "lookahead" else "",
         ]
     )
 
@@ -856,9 +872,6 @@ def post_process():
             "--output_path outputs/outputs.txt",
             f"--performance_output_path {performance_output_path}",
             f"--kv_updater {'SmartMask' if args.kv_updater == smart_mask_updater else 'ShiftPointer'}",
-            f"--window {args.window}",
-            f"--gcap {args.gcap}",
-            f"--ngram {args.ngram}",
             runner_args,
         ]
     )
@@ -1123,6 +1136,13 @@ def _build_parser():
         action="store_true",
         default=False,
     )
+    parser.add_argument(
+        "-G",
+        "--group_size",
+        type=int,
+        default=None,
+        help="group_size used in block quantization for weight quantization.",
+    )
 
     parser.add_argument("-v", "--verbose", action="store_true")
 