diff --git a/backends/qualcomm/builders/node_visitor.py b/backends/qualcomm/builders/node_visitor.py
index ae3c99ff523..e81a80b3517 100644
--- a/backends/qualcomm/builders/node_visitor.py
+++ b/backends/qualcomm/builders/node_visitor.py
@@ -162,7 +162,7 @@ def make_qnn_per_block_config(self, node: torch.fx.Node, quant_attrs: Dict):
         for ch in range(num_channels):
             max_scale = scales[ch].reshape(1, -1).amax(dim=-1) / num_steps
             q_scales = torch.clamp(
-                input=scales[ch] / max_scale,
+                input=torch.round(input=scales[ch] / max_scale),
                 min=1,
                 max=2**bitwidth_of_scale,
             ).to(quant_scales_dtype)
diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py
index 5b69ae5ac3c..5dcb664bb9d 100644
--- a/backends/qualcomm/quantizer/custom_annotation.py
+++ b/backends/qualcomm/quantizer/custom_annotation.py
@@ -317,6 +317,7 @@ def annotate_matmul_input1(node: Node, is_qat: str):
             torch.ops.aten.transpose.int,
             torch.ops.aten.view.default,
             torch.ops.aten.reshape.default,
+            torch.ops.aten.slice.Tensor,
         ]:
             annotate_single_in_single_out(node, quantization_config_8a8w)
             node = node.args[0]
@@ -340,7 +341,11 @@ def annotate_matmul_input1(node: Node, is_qat: str):
                 node, quantization_config=quantization_config_8a4w_per_channel
             )
             break
-        elif node.target in [torch.ops.aten.add.Tensor, torch.ops.aten.sub.Tensor]:
+        elif node.target in [
+            torch.ops.aten.add.Tensor,
+            torch.ops.aten.sub.Tensor,
+            torch.ops.aten.matmul.default,
+        ]:
             break
         else:
             print(f"The node ({node}) is not expected in the input1 of the matmul")
@@ -356,7 +361,12 @@ def annotate_matmul_input1(node: Node, is_qat: str):
     )
 
     for node in gm.graph.nodes:
-        if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
+        if (
+            node.op == "call_function"
+            and node.target == torch.ops.aten.matmul.default
+            and all(arg.op == "call_function" for arg in node.args)
+        ):
+            # Only apply custom annotation on Q @ K^T @ V
             annotate_matmul(node, quantization_config_16a8w)
             annotate_matmul_input1(node.args[1], is_qat=is_qat)
diff --git a/backends/qualcomm/runtime/backends/QnnImplementation.cpp b/backends/qualcomm/runtime/backends/QnnImplementation.cpp
index a9136a83c9c..42f866d22cc 100644
--- a/backends/qualcomm/runtime/backends/QnnImplementation.cpp
+++ b/backends/qualcomm/runtime/backends/QnnImplementation.cpp
@@ -51,16 +51,8 @@ Error QnnImplementation::StartBackend(
     const std::string& lib_path,
     const QnnSaver_Config_t** saver_config) {
   Qnn_ErrorHandle_t error = QNN_SUCCESS;
-  // RTLD_GLOBAL is needed on x86 as HTP op package has a requirement for the
-  // symbols in backend to be visible. Using RTLD_LOCAL on Android to allow full
-  // unloading of HTP backend shared library on dlclose() as RTLD_GLOBAL isn't
-  // letting it happen.
   void* lib_handle = nullptr;
-#if defined(__ANDROID__)
-  lib_handle = dlopen(lib_path.c_str(), RTLD_NOW | RTLD_LOCAL);
-#else
   lib_handle = dlopen(lib_path.c_str(), RTLD_NOW | RTLD_GLOBAL);
-#endif
   if (lib_handle == nullptr) {
     QNN_EXECUTORCH_LOG_ERROR(
         "Cannot Open QNN library %s, with error: %s",
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index 9c06b5e34f3..d64122ab12a 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -1127,7 +1127,8 @@ def test_qnn_backend_where(self):
             (torch.randn(30, 20),),
         ]
         for i, module in enumerate(modules):
-            self.lower_module_and_test_output(module, sample_inputs[i])
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(module, sample_inputs[i])
 
     def test_qnn_backend_masked_fill(self):
         module = MaskedFill()  # noqa: F405
@@ -2556,8 +2557,9 @@ def test_qnn_backend_where(self):
             (torch.randn(30, 20),),
         ]
         for i, module in enumerate(modules):
-            module = self.get_qdq_module(module, sample_inputs[i])
-            self.lower_module_and_test_output(module, sample_inputs[i])
+            with self.subTest(i=i):
+                module = self.get_qdq_module(module, sample_inputs[i])
+                self.lower_module_and_test_output(module, sample_inputs[i])
 
     def test_qnn_backend_masked_fill(self):
         module = MaskedFill()  # noqa: F405
@@ -4527,6 +4529,77 @@ def test_llama_stories_110m(self):
         if not self.compile_only and not self.enable_x86_64:
             self.assertGreaterEqual(msg["inference_speed"], 220)  # Lanai
 
+    def test_static_phi4(self):
+        if not self.required_envs():
+            self.skipTest("missing required envs")
+
+        prompt = "My favourite condiment is "
+        cmds = [
+            "python",
+            f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py",
+            "--artifact",
+            self.artifact_dir,
+            "--build_folder",
+            self.build_folder,
+            "--model",
+            self.model,
+            "--ip",
+            self.ip,
+            "--port",
+            str(self.port),
+            "--prompt",
+            f"{prompt}",
+            "--ptq",
+            "16a4w_block",
+            "--group_size",
+            "16",
+            "--decoder_model",
+            "phi_4_mini",
+            "--model_mode",
+            "kv",
+            "--max_seq_len",
+            "1024",
+            "--num_sharding",
+            "8",
+            "--eval_perplexity",
+            "--tasks",
+            "wikitext",
+            "--limit",
+            "1",
+        ]
+        if self.compile_only:
+            cmds.extend(["--compile_only"])
+        elif self.device:
+            cmds.extend(["--device", self.device])
+            if self.host:
+                cmds.extend(["--host", self.host])
+        elif self.enable_x86_64:
+            cmds.extend(["--enable_x86_64"])
+        if self.pre_gen_pte:
+            cmds.extend(["--pre_gen_pte", self.pre_gen_pte])
+            cmds.extend(
+                [
+                    "--quant_attrs_path",
+                    f"{self.pre_gen_pte}/kv_llama_qnn_quant_attrs.json",
+                ]
+            )
+
+        p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
+        with Listener((self.ip, self.port)) as listener:
+            conn = listener.accept()
+            p.communicate()
+            msg = json.loads(conn.recv())
+            if "Error" in msg:
+                self.fail(msg["Error"])
+            else:
+                inference_speed_ref = {"SM8650": 14, "SM8750": 19}
+                self.assertLessEqual(msg["wiki_ppl"], 12)
+                self.assertLessEqual(msg["pte_size"], 4000000000)  # 4GB
+                if self.model in inference_speed_ref:
+                    self.assertGreaterEqual(
+                        msg["inference_speed"], inference_speed_ref[self.model]
+                    )
+
     def test_static_qwen2_5(self):
         if not self.required_envs():
             self.skipTest("missing required envs")
diff --git a/examples/qualcomm/oss_scripts/llama/README.md b/examples/qualcomm/oss_scripts/llama/README.md
index b76a3584479..87bd46ce407 100644
--- a/examples/qualcomm/oss_scripts/llama/README.md
+++ b/examples/qualcomm/oss_scripts/llama/README.md
@@ -69,6 +69,12 @@ Default example using hybrid mode.
 python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1"
 ```
 
+#### Phi4-mini-instruct
+Default example using hybrid mode.
+```bash
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w_block --group_size 16 --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model phi_4_mini --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --num_sharding 8 --prompt "I would like to learn python, could you teach me with a simple example?"
+```
+
 #### QWEN2.5 0.5B
 Default example using hybrid mode
 ```bash
@@ -81,6 +87,7 @@ Default example using hybrid mode.
 python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -H mlgtw-linux -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a8w --tokenizer_bin tokenizer.bin --decoder_model smollm2 --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?"
 ```
+
 
 ### KV Cache update mechanism
 We have two distinct mechanisms for updating the key-value (KV) cache, which can be selected at runtime. Shift Pointer and Smart Mask.
diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py
index 2ce49c61cf6..df9d782ffba 100755
--- a/examples/qualcomm/oss_scripts/llama/llama.py
+++ b/examples/qualcomm/oss_scripts/llama/llama.py
@@ -116,6 +116,8 @@
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=FORMAT)
 logging.getLogger().setLevel(logging.INFO)
+# Avoid the error message "Could not initialize NNPACK! Reason: Unsupported hardware."
+torch.backends.nnpack.set_flags(False)
 
 
 def next_power_of_two(n):
@@ -233,10 +235,16 @@ def quantize(
         ).module()
 
         if quant_dtype == QuantDtype.use_16a4w_block:
+            if args.group_size is None:
+                raise ValueError(
+                    "Group size is required when using quant_dtype 16a4w_block"
+                )
             conv_nodes = [
                 n for n in fx_graph_module.graph.nodes if "conv" in n.name
             ]
-            block_size_map = {n.name: (1, 64, 1, 1) for n in conv_nodes}
+            block_size_map = {
+                n.name: (1, args.group_size, 1, 1) for n in conv_nodes
+            }
             quantizer.set_block_size_map(block_size_map)
 
             fx_graph_module = prepare_pt2e(fx_graph_module, quantizer)
@@ -584,7 +592,7 @@ def permute(w, heads):
         if args.ptq != "16a8w":
             # 16a8w use 16bit kv io, so skip this custom annotation
             custom_annotations = custom_annotations + (annotate_matmul_16a8w,)
-            if args.decoder_model in {"stories110m", "stories260k"}:
+            if args.decoder_model in {"stories110m", "stories260k", "phi_4_mini"}:
                 custom_annotations = custom_annotations + (
                     annotate_linear_16a8w_in_affine_layer,
                 )
@@ -801,12 +809,20 @@ def post_process():
 
         seq_len = args.max_seq_len
         multi_prompts = " ".join([f'--prompt "{prompt}"' for prompt in args.prompt])
+        lookahead_args = " ".join(
+            [
+                f"--window {args.window}",
+                f"--gcap {args.gcap}",
+                f"--ngram {args.ngram}",
+            ]
+        )
         runner_args = " ".join(
             [
                 multi_prompts,
                 f"--eval_mode {EVAL_MODE[args.model_mode]}",
                 f"--temperature {args.temperature}",
                 f"--system_prompt '{args.system_prompt}'",
+                lookahead_args if args.model_mode == "lookahead" else "",
             ]
         )
 
@@ -856,9 +872,6 @@ def post_process():
                 "--output_path outputs/outputs.txt",
                 f"--performance_output_path {performance_output_path}",
                 f"--kv_updater {'SmartMask' if args.kv_updater == smart_mask_updater else 'ShiftPointer'}",
-                f"--window {args.window}",
-                f"--gcap {args.gcap}",
-                f"--ngram {args.ngram}",
                 runner_args,
             ]
         )
@@ -1123,6 +1136,13 @@ def _build_parser():
         action="store_true",
         default=False,
     )
+    parser.add_argument(
+        "-G",
+        "--group_size",
+        type=int,
+        default=None,
+        help="Group size used in block quantization of weights.",
+    )
 
     parser.add_argument("-v", "--verbose", action="store_true")
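A note on the `node_visitor.py` change: without `torch.round`, the ratio `scales[ch] / max_scale` is clamped and then cast, and if `quant_scales_dtype` is an integer dtype that cast truncates toward zero, biasing every stored scale low by up to one step. Below is a minimal sketch of the round-then-clamp behavior; the function name, the 4-bit scale width, and the `uint8` storage dtype are assumptions for the demo, not values taken from the backend.

```python
import torch


# Illustrative sketch of the scale rounding in make_qnn_per_block_config.
# The 4-bit width and uint8 storage dtype are assumed for the demo.
def quantize_block_scales(scales: torch.Tensor, bitwidth_of_scale: int = 4):
    num_steps = 2**bitwidth_of_scale
    max_scale = scales.reshape(1, -1).amax(dim=-1) / num_steps
    q_scales = torch.clamp(
        input=torch.round(scales / max_scale),  # round to nearest, as in the fix
        min=1,
        max=2**bitwidth_of_scale,
    ).to(torch.uint8)  # a bare integer cast without round would truncate
    return q_scales, max_scale


scales = torch.tensor([0.9, 1.7, 2.6, 3.2])
q_scales, max_scale = quantize_block_scales(scales)
print(q_scales)              # tensor([ 4,  8, 13, 16], dtype=torch.uint8)
print(q_scales * max_scale)  # reconstructed scales, each within half a step
```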
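On the `test_qnn_delegate.py` changes: wrapping each `(module, input)` pair in `self.subTest` means one failing case is recorded without aborting the remaining iterations, and each failure is reported with its index. A self-contained illustration of the pattern (the input list and the assertion are placeholders standing in for `lower_module_and_test_output`):

```python
import unittest

import torch


class SubTestPattern(unittest.TestCase):
    def test_many_variants(self):
        sample_inputs = [torch.randn(3, 4), torch.randn(30, 20)]
        for i, x in enumerate(sample_inputs):
            # Each iteration is reported as its own subtest; a failure here
            # is logged with i and the loop continues with the next input.
            with self.subTest(i=i):
                self.assertTrue(torch.isfinite(x).all())


if __name__ == "__main__":
    unittest.main()
```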
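On the new `--group_size` flag feeding `block_size_map`: for a conv weight of shape (O, I, kH, kW), a block shape of `(1, group_size, 1, 1)` means each run of `group_size` input channels shares one quantization scale. A hedged sketch of what that grouping implies; symmetric 4-bit math is assumed, and `per_block_scales` is illustrative rather than the quantizer's API:

```python
import torch


def per_block_scales(weight: torch.Tensor, group_size: int) -> torch.Tensor:
    # Block shape (1, group_size, 1, 1): split the input-channel axis into
    # groups of `group_size`; every group gets its own scale.
    out_ch, in_ch, kh, kw = weight.shape
    assert in_ch % group_size == 0, "--group_size must divide input channels"
    blocks = weight.reshape(out_ch, in_ch // group_size, group_size, kh, kw)
    # Assuming a symmetric 4-bit range of [-7, 7]: scale = max|w| / 7 per block.
    return blocks.abs().amax(dim=(2, 3, 4)) / 7.0


weight = torch.randn(8, 32, 3, 3)
print(per_block_scales(weight, group_size=16).shape)  # torch.Size([8, 2])
```

Smaller groups track local weight magnitudes more closely at the cost of storing more scales, which is why the hardcoded 64 was replaced by a user-tunable value (the Phi4 recipe above uses 16).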