diff --git a/docs/qnn_backend/aot_execute.rst b/docs/qnn_backend/aot_execute.rst new file mode 100644 index 000000000..55addfef7 --- /dev/null +++ b/docs/qnn_backend/aot_execute.rst @@ -0,0 +1,125 @@ +QNN AOT Execution Flow +================================================================ + +.. note:: + Please refer to the `Environment Setup `_ documentation to configure the QNN and Hexagon SDK environments before proceeding. + +This document explains the main execution flow of QNN AOT (Ahead-of-Time) inference. The implementation fully leverages the offline compilation capabilities of the Qualcomm QNN framework to achieve efficient inference of fully integer-quantized Large Language Models (LLMs) on mobile devices; offline compilation is the de facto workflow for LLM execution on the Hexagon NPU. + +Specifically, our implementation employs a W4A16 quantization scheme: the Key-Value (KV) Cache is quantized to ``uint8``, and the linear weights are quantized using Low-Power Blockwise Quantization (LPBQ). + +The implementation of this module was inspired by the `PyTorch ExecuTorch`_ project, especially its `Hybrid Execution Mode`_ designed for the Qualcomm backend, for which we are grateful. + +.. _PyTorch ExecuTorch: https://pytorch.org/executorch/ +.. _Hybrid Execution Mode: https://github.com/pytorch/executorch/blob/main/examples/qualcomm/oss_scripts/llama/README.md + +Overall Flow +---------------------------------------------------------------- + +The QNN AOT execution flow is divided into three main stages: + +1. **Model Quantization and Export (Python)**: On the host machine, a Python script quantizes the pre-trained floating-point model and exports it (via ``.safetensors``) to the MLLM (``.mllm``) format. +2. **Offline Compilation (C++)**: On the host machine, a C++ compiler program loads the ``.mllm`` file, invokes the QNN toolchain for model compilation, graph optimization, and quantization parameter adjustment, and finally generates a QNN Context Binary. +3. 
**On-Device Execution (C++)**: On the target device (e.g., a mobile phone), the AOT runner program loads the pre-compiled context binary and executes inference. + + +Detailed Steps +---------------------------------------------------------------- + +Taking ``qwen3_qnn_aot`` as an example, the detailed steps are as follows. + +1. **Model Quantization and Export** + + First, run a Python script on the host to quantize the model and export it as a ``.safetensors`` file. + + .. code-block:: shell + + cd ./pymllm/backends/qualcomm/transformers/qwen3 + python train.py --model_path "/your/qwen3/model/path/" --max_length 1024 --num_samples 128 --output_dir "/path/to/output" + + This step generates one key file: + + * ``model.safetensors``: The quantized model file, saved in the specified output directory. + + Next, convert the exported ``.safetensors`` model to the MLLM format (``.mllm``) using the ``mllm-convertor`` script. + + .. code-block:: shell + + pip install pymllm + + mllm-convertor --input_path /path/to/output/model.safetensors --output_path /path/to/output/qwen3_1.7b.mllm + + This generates the ``qwen3_1.7b.mllm`` file, which is used in the subsequent compilation step. + +2. **Offline Compilation to Generate QNN Context** + + Next, we use a C++ compiler program (``compile.cpp``) on the host to generate the QNN context. This process invokes the QNN SDK to convert the MLLM IR into a QNN-supported format and performs optimizations. + + Compile and run the ``compile`` program: + + .. 
code-block:: shell + + # In the mllm-v2 project root directory + python task.py tasks/build_x86_qnn_aot.yaml + + # Run the compiler program + ./build-qnn-aot/bin/mllm-qwen3-aot-sha-c \ + -m /path/to/output/qwen3_1.7b.mllm \ + -c ./examples/qwen3_qnn_aot/config_1.7B.json \ + --aot_config ./examples/qwen3_qnn_aot/qnn_aot_cfg_1.7B.json + + + This program reads the ``.mllm`` model file and the quantization recipe, and finally generates a QNN context binary file named ``qwen3-1.7B-lpbq-sha.bin``. This file contains all the information needed to execute inference on the target device. + + .. note:: + The ``HtpSignedPd`` option in ``qnn_aot_cfg_1.7B.json`` specifies ``QNN_HTP_DEVICE_CONFIG_OPTION_SIGNEDPD`` during QNN initialization, which may cause an "Unsupported config option 2" error with older QNN versions. If you hit this error, change the option in the JSON file to ``HtpUnsignedPd``. + +3. **On-Device AOT Inference** + + Finally, push the generated ``qwen3-1.7B-lpbq-sha.bin`` file and other resources, such as the tokenizer, to the target device. The on-device AOT runner program (``aot_run.cpp``) loads this binary file and executes inference. + + Compile and run the ``aot_run`` program: + + .. 
code-block:: shell + + # Cross-compile the aot_run program for the target device (e.g., Android) + python task.py tasks/build_android_qnn.yaml + + # Push the compiled context file to the device + adb push qwen3-1.7B-lpbq-sha.bin /data/local/tmp/ + + # Push QNN libraries and Op Packages + ANDR_LIB=$QNN_SDK_ROOT/lib/aarch64-android + OP_PATH=mllm/backends/qnn/custom-op-package/LLaMAPackage/build + + adb push $ANDR_LIB/libQnnHtp.so /data/local/tmp + adb push $ANDR_LIB/libQnnHtpV75Stub.so /data/local/tmp + adb push $ANDR_LIB/libQnnHtpPrepare.so /data/local/tmp + adb push $ANDR_LIB/libQnnHtpProfilingReader.so /data/local/tmp + adb push $ANDR_LIB/libQnnHtpOptraceProfilingReader.so /data/local/tmp + adb push $ANDR_LIB/libQnnHtpV75CalculatorStub.so /data/local/tmp + adb push $QNN_SDK_ROOT/lib/hexagon-v75/unsigned/libQnnHtpV75Skel.so /data/local/tmp + adb push $QNN_SDK_ROOT/lib/aarch64-android/libQnnSystem.so /data/local/tmp + + adb push $OP_PATH/aarch64-android/libQnnLLaMAPackage.so /data/local/tmp/libQnnLLaMAPackage_CPU.so + adb push $OP_PATH/hexagon-v75/libQnnLLaMAPackage.so /data/local/tmp/libQnnLLaMAPackage_HTP.so + + # Push the mllm runner and libs to the device + adb push build-android-arm64-v8a-qnn/bin/*.so /data/local/tmp + adb push build-android-arm64-v8a-qnn/bin/mllm-qwen3-aot-runner /data/local/tmp + + # Execute on the device + adb shell "cd /data/local/tmp && export LD_LIBRARY_PATH=. && + ./mllm-qwen3-aot-runner -m qwen3-1.7B-lpbq-sha.bin \ + -t qwen3-tokenizer.json -c config_1.7B.json --ar_len 32" + + The AOT runner program loads the ``.bin`` file to initialize the QNN context, then receives input tokens, performs model inference, and outputs the next token, realizing the autoregressive generation process of the language model. + +Hybrid Mode Explanation +---------------------------------------------------------------- + +Our QNN AOT implementation adopts a Hybrid mode similar to ExecuTorch's to optimize the efficiency of Prompt processing and Token generation. 
+ +* **Prefill Phase**: When processing the user's input (Prompt) for the first time, the model calculates and caches the Key-Value (KV) states for all input tokens at once. This phase is computationally intensive but is performed only once. +* **Decode Phase**: When generating subsequent tokens, the model takes only the previously generated token as input and reuses the cached KV states. This process is computationally light and fast, suitable for token-by-token generation. + +In this way, we combine the strengths of batch-style prompt processing and token-by-token generation, improving overall throughput while keeping per-token latency low. diff --git a/docs/qnn_backend/index.rst b/docs/qnn_backend/index.rst index b7092f938..336ef1845 100644 --- a/docs/qnn_backend/index.rst +++ b/docs/qnn_backend/index.rst @@ -6,4 +6,4 @@ QNN Backend setup_env core_design - qnn_model_convert + aot_execute diff --git a/docs/qnn_backend/setup_env.rst b/docs/qnn_backend/setup_env.rst index 6619a21b3..5d6b6712a 100644 --- a/docs/qnn_backend/setup_env.rst +++ b/docs/qnn_backend/setup_env.rst @@ -98,6 +98,10 @@ Compilation Commands This will build the necessary QNN op packages for both AArch64 and HVX v75 targets. +.. note:: + The Hexagon tools version referenced in the Makefile may change between SDK releases. If compilation fails, update the version number in the Makefile accordingly. 
+ + Development Tips ---------------- diff --git a/examples/llama_qnn_aot/compile.cpp b/examples/llama_qnn_aot/compile.cpp index 3568a2f44..a064af95f 100644 --- a/examples/llama_qnn_aot/compile.cpp +++ b/examples/llama_qnn_aot/compile.cpp @@ -17,6 +17,9 @@ MLLM_MAIN({ auto& model_path = Argparse::add("-m|--model_path").help("Model file path."); auto& model_cfg_path = Argparse::add("-c|--config").help("Model config file path."); auto& qnn_aot_cfg_files = Argparse::add("-aot_cfg|--aot_config").help("AOT Config file path."); + auto& qnn_env_path = Argparse::add("-qnn_env|--qnn_env_path") + .def("/opt/qcom/aistack/qairt/2.41.0.251128/lib/x86_64-linux-clang/") + .help("QNN AOT Environment path."); Argparse::parse(argc, argv); @@ -47,7 +50,7 @@ MLLM_MAIN({ model.load(params); // Create Qnn AOT Model - auto qnn_aot_env = mllm::qnn::aot::QnnAOTEnv("/opt/qcom/aistack/qairt/2.41.0.251128/lib/x86_64-linux-clang/", + auto qnn_aot_env = mllm::qnn::aot::QnnAOTEnv(qnn_env_path.get(), mllm::qnn::aot::parseQcomTargetMachineFromJSONFile(qnn_aot_cfg_files.get())); // Model length 32. 
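The prefill/decode split described in the new ``aot_execute.rst`` documentation can be sketched with a toy example. This is an illustrative sketch only, not the mllm API; the real runner drives a compiled QNN context, and the dummy next-token rules stand in for actual model inference.

```python
# Toy sketch of the hybrid prefill/decode flow (illustrative names, not mllm's API).

def prefill(prompt_tokens, kv_cache):
    """Process the whole prompt once, filling the KV cache for every position."""
    for tok in prompt_tokens:
        kv_cache.append(("k", tok))  # stand-in for the real per-token K/V states
    # Dummy next-token rule standing in for real model inference.
    return max(prompt_tokens) + 1

def decode(last_token, kv_cache):
    """Generate one token, consuming only the previous token plus the cached K/V."""
    kv_cache.append(("k", last_token))
    return last_token + 1  # dummy next-token rule

kv = []
tok = prefill([3, 1, 4], kv)  # compute-heavy, runs once over the whole prompt
generated = [tok]
for _ in range(2):            # light per-token steps reuse the cache
    tok = decode(tok, kv)
    generated.append(tok)
```

Prefill touches every prompt position once; each decode step only appends one cache entry, which is why the decode graph can be much smaller than the prefill graph.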
diff --git a/examples/llama_qnn_aot/compile_sha.cpp b/examples/llama_qnn_aot/compile_sha.cpp index bd938b7a9..bdc66a4a1 100644 --- a/examples/llama_qnn_aot/compile_sha.cpp +++ b/examples/llama_qnn_aot/compile_sha.cpp @@ -25,6 +25,9 @@ MLLM_MAIN({ auto& model_path = Argparse::add("-m|--model_path").help("Model file path."); auto& model_cfg_path = Argparse::add("-c|--config").help("Model config file path."); auto& qnn_aot_cfg_files = Argparse::add("-aot_cfg|--aot_config").help("AOT Config file path."); + auto& qnn_env_path = Argparse::add("-qnn_env|--qnn_env_path") + .def("/opt/qcom/aistack/qairt/2.41.0.251128/lib/x86_64-linux-clang/") + .help("QNN AOT Environment path."); Argparse::parse(argc, argv); @@ -73,7 +76,7 @@ MLLM_MAIN({ model.load(params); // Create Qnn AOT Model - auto qnn_aot_env = mllm::qnn::aot::QnnAOTEnv("/opt/qcom/aistack/qairt/2.41.0.251128/lib/x86_64-linux-clang/", + auto qnn_aot_env = mllm::qnn::aot::QnnAOTEnv(qnn_env_path.get(), mllm::qnn::aot::parseQcomTargetMachineFromJSONFile(qnn_aot_cfg_files.get())); // Model length 32. 
diff --git a/examples/qwen2_qnn_aot/compile.cpp b/examples/qwen2_qnn_aot/compile.cpp index 288501966..a5af957be 100644 --- a/examples/qwen2_qnn_aot/compile.cpp +++ b/examples/qwen2_qnn_aot/compile.cpp @@ -17,6 +17,9 @@ MLLM_MAIN({ auto& model_path = Argparse::add("-m|--model_path").help("Model file path."); auto& model_cfg_path = Argparse::add("-c|--config").help("Model config file path."); auto& qnn_aot_cfg_files = Argparse::add("-aot_cfg|--aot_config").help("AOT Config file path."); + auto& qnn_env_path = Argparse::add("-qnn_env|--qnn_env_path") + .def("/opt/qcom/aistack/qairt/2.41.0.251128/lib/x86_64-linux-clang/") + .help("QNN AOT Environment path."); Argparse::parse(argc, argv); @@ -47,7 +50,7 @@ MLLM_MAIN({ model.load(params); // Create Qnn AOT Model - auto qnn_aot_env = mllm::qnn::aot::QnnAOTEnv("/opt/qcom/aistack/qairt/2.41.0.251128/lib/x86_64-linux-clang/", + auto qnn_aot_env = mllm::qnn::aot::QnnAOTEnv(qnn_env_path.get(), mllm::qnn::aot::parseQcomTargetMachineFromJSONFile(qnn_aot_cfg_files.get())); // Model length 32. 
diff --git a/examples/qwen2_qnn_aot/compile_sha.cpp b/examples/qwen2_qnn_aot/compile_sha.cpp index 50aa9b5e5..cd0ffcb61 100644 --- a/examples/qwen2_qnn_aot/compile_sha.cpp +++ b/examples/qwen2_qnn_aot/compile_sha.cpp @@ -25,6 +25,9 @@ MLLM_MAIN({ auto& model_path = Argparse::add("-m|--model_path").help("Model file path."); auto& model_cfg_path = Argparse::add("-c|--config").help("Model config file path."); auto& qnn_aot_cfg_files = Argparse::add("-aot_cfg|--aot_config").help("AOT Config file path."); + auto& qnn_env_path = Argparse::add("-qnn_env|--qnn_env_path") + .def("/opt/qcom/aistack/qairt/2.41.0.251128/lib/x86_64-linux-clang/") + .help("QNN AOT Environment path."); Argparse::parse(argc, argv); @@ -73,7 +76,7 @@ MLLM_MAIN({ model.load(params); // Create Qnn AOT Model - auto qnn_aot_env = mllm::qnn::aot::QnnAOTEnv("/opt/qcom/aistack/qairt/2.41.0.251128/lib/x86_64-linux-clang/", + auto qnn_aot_env = mllm::qnn::aot::QnnAOTEnv(qnn_env_path.get(), mllm::qnn::aot::parseQcomTargetMachineFromJSONFile(qnn_aot_cfg_files.get())); // Model length 32. 
diff --git a/examples/qwen3_qnn_aot/aot_run.cpp b/examples/qwen3_qnn_aot/aot_run.cpp index b9bee3334..364ed4a06 100644 --- a/examples/qwen3_qnn_aot/aot_run.cpp +++ b/examples/qwen3_qnn_aot/aot_run.cpp @@ -43,10 +43,6 @@ MLLM_MAIN({ auto input_tensor = tokenizer.convertMessage({.prompt = prompt_text}); - // DBG: - mllm::print(input_tensor["sequence"].shape()); - mllm::print(input_tensor["sequence"]); - Runner runner(config, &tokenizer); if (!runner.load()) { std::cerr << "Failed to load model\n"; diff --git a/examples/qwen3_qnn_aot/compile.cpp b/examples/qwen3_qnn_aot/compile.cpp index cc813fe32..6404af3c1 100644 --- a/examples/qwen3_qnn_aot/compile.cpp +++ b/examples/qwen3_qnn_aot/compile.cpp @@ -17,6 +17,9 @@ MLLM_MAIN({ auto& model_path = Argparse::add("-m|--model_path").help("Model file path."); auto& model_cfg_path = Argparse::add("-c|--config").help("Model config file path."); auto& qnn_aot_cfg_files = Argparse::add("-aot_cfg|--aot_config").help("AOT Config file path."); + auto& qnn_env_path = Argparse::add("-qnn_env|--qnn_env_path") + .def("/opt/qcom/aistack/qairt/2.41.0.251128/lib/x86_64-linux-clang/") + .help("QNN AOT Environment path."); Argparse::parse(argc, argv); @@ -47,7 +50,7 @@ MLLM_MAIN({ model.load(params); // Create Qnn AOT Model - auto qnn_aot_env = mllm::qnn::aot::QnnAOTEnv("/opt/qcom/aistack/qairt/2.41.0.251128/lib/x86_64-linux-clang/", + auto qnn_aot_env = mllm::qnn::aot::QnnAOTEnv(qnn_env_path.get(), mllm::qnn::aot::parseQcomTargetMachineFromJSONFile(qnn_aot_cfg_files.get())); // Model length 32. 
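The uint16 debug-print helper added to ``QNNUtils.hpp`` in this diff dequantizes with ``(value + offset) * scale``. A minimal Python sketch of that asymmetric per-tensor scheme, with an inverse for round-tripping (function names are illustrative, not part of mllm):

```python
# Illustrative sketch of asymmetric per-tensor (de)quantization as used by the
# debug helper in QNNUtils.hpp: f = (q + offset) * scale. Not mllm code.

def dequantize_u16(values, scale, offset):
    """Map raw uint16 codes back to floats: f = (q + offset) * scale."""
    return [(q + offset) * scale for q in values]

def quantize_u16(values, scale, offset):
    """Inverse mapping, clamped to the uint16 range [0, 65535]."""
    out = []
    for f in values:
        q = round(f / scale) - offset
        out.append(max(0, min(65535, q)))
    return out
```

With ``scale = 0.5`` and ``offset = -2``, the code 2 dequantizes to 0.0 and the code 4 to 1.0, and quantizing those floats recovers the original codes.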
diff --git a/examples/qwen3_qnn_aot/compile_sha.cpp b/examples/qwen3_qnn_aot/compile_sha.cpp index f6d25894b..9f2629f6f 100644 --- a/examples/qwen3_qnn_aot/compile_sha.cpp +++ b/examples/qwen3_qnn_aot/compile_sha.cpp @@ -25,6 +25,9 @@ MLLM_MAIN({ auto& model_path = Argparse::add("-m|--model_path").help("Model file path."); auto& model_cfg_path = Argparse::add("-c|--config").help("Model config file path."); auto& qnn_aot_cfg_files = Argparse::add("-aot_cfg|--aot_config").help("AOT Config file path."); + auto& qnn_env_path = Argparse::add("-qnn_env|--qnn_env_path") + .def("/opt/qcom/aistack/qairt/2.41.0.251128/lib/x86_64-linux-clang/") + .help("QNN AOT Environment path."); Argparse::parse(argc, argv); @@ -73,7 +76,7 @@ MLLM_MAIN({ model.load(params); // Create Qnn AOT Model - auto qnn_aot_env = mllm::qnn::aot::QnnAOTEnv("/opt/qcom/aistack/qairt/2.41.0.251128/lib/x86_64-linux-clang/", + auto qnn_aot_env = mllm::qnn::aot::QnnAOTEnv(qnn_env_path.get(), mllm::qnn::aot::parseQcomTargetMachineFromJSONFile(qnn_aot_cfg_files.get())); // Model length 32. diff --git a/mllm/CMakeLists.txt b/mllm/CMakeLists.txt index 06fa5aab2..8507df2b4 100644 --- a/mllm/CMakeLists.txt +++ b/mllm/CMakeLists.txt @@ -56,17 +56,6 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "App endif() endif() -# FIXME: @oreomaker Need to remove comma features in slice! -# Suppress comma-subscript warnings (deprecated C++ feature that will be removed in C++26) -# This flag is only available in Clang 13+ and GCC 10+ -if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang") - target_compile_options(MllmRT PUBLIC -Wno-comma-subscript) -elseif(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") - if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL "10.0") - target_compile_options(MllmRT PUBLIC -Wno-comma-subscript) - endif() -endif() - # ONLY APPLE CAN DO ! 
# Processing OpenMP if(MLLM_KERNEL_USE_THREADS AND MLLM_KERNEL_THREADS_VENDOR_OPENMP) @@ -125,16 +114,17 @@ if(MLLM_BUILD_OPENCL_BACKEND) ) endif() -if(MLLM_QUALCOMM_QNN_AOT_ON_X86_ENABLE OR MLLM_BUILD_QNN_BACKEND) - add_subdirectory(backends/qnn) -endif() - +# add definition before including qnn if(MLLM_QUALCOMM_QNN_AOT_ON_X86_ENABLE) add_compile_definitions( MLLM_QUALCOMM_QNN_AOT_ON_X86_ENABLE ) endif() +if(MLLM_QUALCOMM_QNN_AOT_ON_X86_ENABLE OR MLLM_BUILD_QNN_BACKEND) + add_subdirectory(backends/qnn) +endif() + if(MLLM_BUILD_QNN_BACKEND) add_compile_definitions( MLLM_QNN_BACKEND diff --git a/mllm/backends/qnn/QNNModel.cpp b/mllm/backends/qnn/QNNModel.cpp index e99052d9a..6fc6110bf 100644 --- a/mllm/backends/qnn/QNNModel.cpp +++ b/mllm/backends/qnn/QNNModel.cpp @@ -134,8 +134,6 @@ ModelError_t QNNModel::loadGraphTensorInfo(const Qnn_Tensor_t* inputTensors, uin outputTensorWrappers_.push_back(wrapper); tensorWrapperMap_[tensorName] = wrapper; - // Record QNN output order (index in outputTensorWrappers_) - qnnOutputNameToIndex_[tensorName] = static_cast(outputTensorWrappers_.size() - 1); } MLLM_INFO("QNNModel::loadGraphTensorInfo() loaded {} input tensors and {} output tensors for graph: {}", numInputTensors, @@ -182,8 +180,6 @@ ModelError_t QNNModel::addTensorWrapper(const std::shared_ptr& inputTensorWrappers_.push_back(tensorWrapper); } else if (QNN_TENSOR_GET_TYPE(nativeTensor) == QNN_TENSOR_TYPE_APP_READ) { outputTensorWrappers_.push_back(tensorWrapper); - // Record QNN output order (index in outputTensorWrappers_) - qnnOutputNameToIndex_[tensorName] = static_cast(outputTensorWrappers_.size() - 1); } return MODEL_NO_ERROR; diff --git a/mllm/backends/qnn/QNNModel.hpp b/mllm/backends/qnn/QNNModel.hpp index 7c0b38870..49504474c 100644 --- a/mllm/backends/qnn/QNNModel.hpp +++ b/mllm/backends/qnn/QNNModel.hpp @@ -76,21 +76,6 @@ class QNNModel { std::map> getOutputTensorMap() { return modelOutputTensorMap_; } - // Set expected output order (MLLM order) - void 
setExpectedOutputOrder(const std::vector& expectedOrder) { expectedOutputOrder_ = expectedOrder; } - - // Get expected output order - [[nodiscard]] const std::vector& getExpectedOutputOrder() const { return expectedOutputOrder_; } - - // Get QNN output index by tensor name - [[nodiscard]] int getQnnOutputIndex(const std::string& tensorName) const { - auto it = qnnOutputNameToIndex_.find(tensorName); - if (it != qnnOutputNameToIndex_.end()) { - return it->second; - } - return -1; // Not found - } - // Load input/output tensor information from existing graph ModelError_t loadGraphTensorInfo(const Qnn_Tensor_t* inputTensors, uint32_t numInputTensors, const Qnn_Tensor_t* outputTensors, uint32_t numOutputTensors); @@ -118,10 +103,6 @@ class QNNModel { std::map> modelOutputTensorMap_; - // Output order mapping: MLLM expected order and QNN actual order - std::vector expectedOutputOrder_; // MLLM expected output order (tensor names) - std::map qnnOutputNameToIndex_; // QNN output tensor name -> index in outputTensorWrappers_ - // Storage for node string parameters to ensure lifetime struct NodeStringStorage { std::string name; diff --git a/mllm/backends/qnn/QNNUtils.cpp b/mllm/backends/qnn/QNNUtils.cpp index 73d240bbb..318300dbd 100644 --- a/mllm/backends/qnn/QNNUtils.cpp +++ b/mllm/backends/qnn/QNNUtils.cpp @@ -455,7 +455,9 @@ std::shared_ptr QNNTensorWrapper::create(const std::string& na // it will be allocated to QNN shared buffer via QNNTensorWrapper::alloc() later MLLM_RT_ASSERT(!name.empty()); // in AOT case, the tensor is all on CPU (TODO: handle this) - // if (type != QNN_TENSOR_TYPE_STATIC) { MLLM_RT_ASSERT(tensor.device() == kQNN); } +#ifndef MLLM_QUALCOMM_QNN_AOT_ON_X86_ENABLE + if (type != QNN_TENSOR_TYPE_STATIC) { MLLM_RT_ASSERT(tensor.device() == kQNN); } +#endif Qnn_DataType_t dataType = mllmDataTypeToQnnDataType(tensor.dtype()); @@ -466,6 +468,9 @@ std::shared_ptr QNNTensorWrapper::create(const std::string& na tensorWrapper->dataContainer_ = tensor; + // 
when passed allocated tensor, mark isAlloc_ = true + if (!tensor.isNil()) tensorWrapper->isAlloc_ = true; + return tensorWrapper; } diff --git a/mllm/backends/qnn/QNNUtils.hpp b/mllm/backends/qnn/QNNUtils.hpp index 047a79355..36fb6a91c 100644 --- a/mllm/backends/qnn/QNNUtils.hpp +++ b/mllm/backends/qnn/QNNUtils.hpp @@ -302,4 +302,75 @@ QNNParamScalarWrapper::QNNParamScalarWrapper(const std::string& name, T value) : } } +// --------------- QNN Quantization Print Helper (DBG Use) --------------- +inline void __printDequantizedUInt16TensorData(const mllm::Tensor& tensor, int dim, std::vector& indices, float scale, + int32_t offset) { + auto shape = tensor.shape(); + if (dim >= (int)shape.size()) { + uint16_t val = tensor.constAt(indices); + float fval = (static_cast(val) + offset) * scale; + printf("%.4f", fval); + return; + } + + int32_t dim_size = shape[dim]; + printf("["); + + int max_elements_per_dim = 20; + bool is_last_dim = (dim == (int)shape.size() - 1); + + if (dim_size <= max_elements_per_dim) { + for (int32_t i = 0; i < dim_size; ++i) { + if (i > 0) { + printf(", "); + if (!is_last_dim) printf("\n"); + } + indices.push_back(i); + __printDequantizedUInt16TensorData(tensor, dim + 1, indices, scale, offset); + indices.pop_back(); + } + } else { + const int SHOW_ELEMENTS = max_elements_per_dim / 2; + for (int32_t i = 0; i < SHOW_ELEMENTS; ++i) { + if (i > 0) { + printf(", "); + if (!is_last_dim) printf("\n"); + } + indices.push_back(i); + __printDequantizedUInt16TensorData(tensor, dim + 1, indices, scale, offset); + indices.pop_back(); + } + if (!is_last_dim) { + printf(",\n...\n"); + } else { + printf(", ..., "); + } + + for (int32_t i = dim_size - SHOW_ELEMENTS; i < dim_size; ++i) { + if (i > dim_size - SHOW_ELEMENTS) { + printf(", "); + if (!is_last_dim) printf("\n"); + } + indices.push_back(i); + __printDequantizedUInt16TensorData(tensor, dim + 1, indices, scale, offset); + indices.pop_back(); + } + } + printf("]"); +} + +inline void 
printDequantizedTensor(const mllm::Tensor& tensor, float scale, int32_t offset) { + std::vector indices; + // reserve shape size + indices.reserve(tensor.shape().size()); + printf("Dequantized Tensor (scale=%f, offset=%d):\n", scale, offset); + + if (tensor.dtype() == mllm::kUInt16 && tensor.dtype() != mllm::kUInt16PerTensorAsy) { + __printDequantizedUInt16TensorData(tensor, 0, indices, scale, offset); + } else { + printf("Not supported type"); + } + printf("\n"); +} + } // namespace mllm::qnn diff --git a/mllm/backends/qnn/aot_rt/PromptProcessor.cpp b/mllm/backends/qnn/aot_rt/PromptProcessor.cpp index f9eae7157..c276fbf00 100644 --- a/mllm/backends/qnn/aot_rt/PromptProcessor.cpp +++ b/mllm/backends/qnn/aot_rt/PromptProcessor.cpp @@ -148,7 +148,7 @@ int64_t PromptProcessor::prefill(const std::vector& prompt_tokens, i current_pos += chunk_size; } - auto logits = output_tensors_[0].to(kCPU).squeeze(0)[{kAll, (num_tokens + config_.ar_len - 1) % config_.ar_len, kAll}]; + auto logits = output_tensors_[0].to(kCPU).squeeze(0)[{kAll, ((int)num_tokens + config_.ar_len - 1) % config_.ar_len, kAll}]; auto cur_token = module_->sampleGreedy(logits); diff --git a/mllm/backends/qnn/passes/QNNGraphBuildPass.cpp b/mllm/backends/qnn/passes/QNNGraphBuildPass.cpp index 60b6b229d..760614942 100644 --- a/mllm/backends/qnn/passes/QNNGraphBuildPass.cpp +++ b/mllm/backends/qnn/passes/QNNGraphBuildPass.cpp @@ -130,7 +130,8 @@ void QNNGraphBuildPass::buildQnnGraph(const ir::graph::SubGraphOp::ptr_t& sub_gr QNN_QUANTIZATION_ENCODING_SCALE_OFFSET, {.scaleOffsetEncoding = {.scale = scale, .offset = 0}}}; } - ModelError_t err = qnn_model->addTensor(input_tensor->name(), QNN_TENSOR_TYPE_APP_WRITE, input_tensor->tensor_, quantize_param); + ModelError_t err = + qnn_model->addTensor(input_tensor->name(), QNN_TENSOR_TYPE_APP_WRITE, input_tensor->tensor_, quantize_param); if (err != MODEL_NO_ERROR) { MLLM_ERROR("Failed to add input tensor {} to graph '{}'", input_tensor->name(), graph_name); return; 
@@ -139,7 +140,6 @@ void QNNGraphBuildPass::buildQnnGraph(const ir::graph::SubGraphOp::ptr_t& sub_gr // Record MLLM expected output order from ReturnOp std::vector expectedOutputOrder; - ir::cf::ReturnOp::ptr_t return_op = nullptr; // Process each operation in the subgraph for (auto& region_op : graph_region->ops()) { @@ -164,29 +164,12 @@ void QNNGraphBuildPass::buildQnnGraph(const ir::graph::SubGraphOp::ptr_t& sub_gr MLLM_WARN("No pattern registered for op type: {}", optype2Str(op_types)); } } else if (auto ret_op = std::dynamic_pointer_cast(region_op)) { - // Record ReturnOp to extract expected output order - return_op = ret_op; + continue; } else { MLLM_WARN("Unsupported op type in QNN subgraph: {}", (int)region_op->getKind()); } } - // Extract MLLM expected output order from ReturnOp inputs - if (return_op) { - for (auto& input : return_op->inputs()) { - auto output_tensor = input->cast_(); - if (output_tensor) { - expectedOutputOrder.push_back(output_tensor->name()); - } - } - // Set expected output order in QNN model - qnn_model->setExpectedOutputOrder(expectedOutputOrder); - // MLLM_INFO("QNNGraphBuildPass: Recorded MLLM expected output order for graph '{}' with {} outputs", graph_name, - // expectedOutputOrder.size()); - } else { - MLLM_WARN("QNNGraphBuildPass: No ReturnOp found in graph '{}', cannot determine expected output order", graph_name); - } - // Finalize the QNN graph if (!qnn_backend->graphFinalize(graph_name)) { MLLM_ERROR("Failed to finalize QNN graph '{}'", graph_name); diff --git a/mllm/core/SlicePrimitives.hpp b/mllm/core/SlicePrimitives.hpp index 59215737e..d38e8b431 100644 --- a/mllm/core/SlicePrimitives.hpp +++ b/mllm/core/SlicePrimitives.hpp @@ -26,42 +26,4 @@ struct SliceIndicesPair { using SliceIndices = std::vector; -// Helper class for comma operator to enable [1,2,3] syntax -class SliceIndicesBuilder { - public: - // NOLINT for intentional implicit conversion - SliceIndicesBuilder(int32_t first_index) { // 
NOLINT(google-explicit-constructor) - indices_.emplace_back(first_index); - } - - // NOLINT for intentional implicit conversion - SliceIndicesBuilder(const SliceIndicesPair& first_pair) { // NOLINT(google-explicit-constructor) - indices_.emplace_back(first_pair); - } - - // operator, to chain multiple indices - SliceIndicesBuilder operator,(int32_t index) && { - indices_.emplace_back(index); - return std::move(*this); - } - - SliceIndicesBuilder operator,(const SliceIndicesPair& pair) && { - indices_.emplace_back(pair); - return std::move(*this); - } - - // Implicit conversion to SliceIndices - intentional for syntax sugar - operator SliceIndices() const { // NOLINT(google-explicit-constructor) - return indices_; - } - - private: - SliceIndices indices_; -}; - -// Helper function to start the builder chain -inline SliceIndicesBuilder make_slice(int32_t index) { return {index}; } - -inline SliceIndicesBuilder make_slice(const SliceIndicesPair& pair) { return {pair}; } - } // namespace mllm diff --git a/mllm/models/minicpm_o2_6/modeling_resampler.hpp b/mllm/models/minicpm_o2_6/modeling_resampler.hpp index 69795e5a1..f447521bd 100644 --- a/mllm/models/minicpm_o2_6/modeling_resampler.hpp +++ b/mllm/models/minicpm_o2_6/modeling_resampler.hpp @@ -294,7 +294,7 @@ class Resampler : public nn::Module { std::vector outputs; for (int32_t b = 0; b < batch_size; ++b) { // x for this batch - Tensor x_b = x[make_slice(b), kAll, kAll].view({seq_len, embed_dim_}); + Tensor x_b = x[{b, kAll, kAll}].view({seq_len, embed_dim_}); // pos_embed for this batch // Tensor pos_embed_b = Tensor::empty({seq_len, embed_dim_}, kFloat32).alloc(); @@ -308,12 +308,12 @@ class Resampler : public nn::Module { // } // } // TODO: handle 'set 0' - Tensor pos_embed_b = pos_embed_padded[make_slice(b), kAll, kAll].view({seq_len, embed_dim_}); + Tensor pos_embed_b = pos_embed_padded[{b, kAll, kAll}].view({seq_len, embed_dim_}); auto kv_input = x_b + pos_embed_b; // key_padding_mask for this batch - Tensor 
key_padding_mask_b = key_padding_mask[make_slice(b), kAll].view({max_patch_len}); + Tensor key_padding_mask_b = key_padding_mask[{b, kAll}].view({max_patch_len}); bool has_padding = false; for (int i = 0; i < seq_len; i++) { diff --git a/mllm/models/minicpm_o2_6/streaming_generation.cpp b/mllm/models/minicpm_o2_6/streaming_generation.cpp index b1489470c..c4f9902e1 100644 --- a/mllm/models/minicpm_o2_6/streaming_generation.cpp +++ b/mllm/models/minicpm_o2_6/streaming_generation.cpp @@ -31,7 +31,7 @@ void StreamingGenerator::generate_next(OmniOutput& output) { if (spk_embeds_.isNil()) { streamer_ = ++streamer_; - spk_embeds_ = streamer_.getLastHiddenStates()[make_slice(0), spk_start_idx_ + 1, kAll]; + spk_embeds_ = streamer_.getLastHiddenStates().slice({0, spk_start_idx_ + 1, mllm::kAll}); std::string tts_eos_token = preprocessor::wideString2Utf8String(L"<|tts_eos|>"); std::string tts_text = streamer_->text;