From a04ef3e1c88cafb0409f9aebbd9cc4b95b675001 Mon Sep 17 00:00:00 2001 From: KKkai <1640576073@qq.com> Date: Thu, 22 Jan 2026 16:42:10 +0800 Subject: [PATCH 01/17] feat:add Qwen2.5omni text modal processing --- examples/CMakeLists.txt | 1 + examples/qwen2_5omni/CMakeLists.txt | 3 + examples/qwen2_5omni/config_qwen2_5omni.json | 495 ++++++++++++++++++ examples/qwen2_5omni/text_infer.cpp | 72 +++ .../qwen2_5omni/configuration_qwen2_5omni.hpp | 97 ++++ .../qwen2_5omni/modeling_qwen2_5omni.hpp | 357 +++++++++++++ .../qwen2_5omni/tokenization_qwen2_5omni.hpp | 252 +++++++++ 7 files changed, 1277 insertions(+) create mode 100644 examples/qwen2_5omni/CMakeLists.txt create mode 100644 examples/qwen2_5omni/config_qwen2_5omni.json create mode 100644 examples/qwen2_5omni/text_infer.cpp create mode 100644 mllm/models/qwen2_5omni/configuration_qwen2_5omni.hpp create mode 100644 mllm/models/qwen2_5omni/modeling_qwen2_5omni.hpp create mode 100644 mllm/models/qwen2_5omni/tokenization_qwen2_5omni.hpp diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 180c3cbe6..a2426f229 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -2,6 +2,7 @@ add_subdirectory(qwen2vl) add_subdirectory(qwen2vl_tracer) add_subdirectory(qwen2_5vl) add_subdirectory(qwen2_5vl_tracer) +add_subdirectory(qwen2_5omni) add_subdirectory(llama) add_subdirectory(minicpm_o) add_subdirectory(minicpm4) diff --git a/examples/qwen2_5omni/CMakeLists.txt b/examples/qwen2_5omni/CMakeLists.txt new file mode 100644 index 000000000..3141b56d7 --- /dev/null +++ b/examples/qwen2_5omni/CMakeLists.txt @@ -0,0 +1,3 @@ +add_executable(mllm-qwen2_5-omni-text-runner text_infer.cpp) +target_link_libraries(mllm-qwen2_5-omni-text-runner PRIVATE MllmRT MllmCPUBackend) +target_include_directories(mllm-qwen2_5-omni-text-runner PRIVATE ${MLLM_INCLUDE_DIR}) diff --git a/examples/qwen2_5omni/config_qwen2_5omni.json b/examples/qwen2_5omni/config_qwen2_5omni.json new file mode 100644 index 000000000..633e1b2b1 --- /dev/null +++ b/examples/qwen2_5omni/config_qwen2_5omni.json @@ -0,0 +1,495 @@ +{ + "architectures": [ + "Qwen2_5OmniModel" + ], + "enable_audio_output": true, + "enable_talker": true, + "model_type": "qwen2_5_omni", + "talker_config": { + "_attn_implementation_autoset": true, + "_name_or_path": "Qwen2.5-Omni-7B/talker", + "architectures": [ + "Qwen2OmniTalkerForConditionalGeneration" + ], + "attention_dropout": 0.0, + "audio_end_token_id": 151648, + "audio_start_token_id": 151647, + "audio_token_index": 151646, + "embedding_size": 3584, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 896, + "image_token_index": 151655, + "init_std": 0.02, + "initializer_range": 0.02, + "intermediate_size": 18944, + "max_position_embeddings": 32768, + "max_window_layers": 28, + "model_type": "qwen2_5_omni_talker", + "num_attention_heads": 12, + "num_hidden_layers": 24, + "num_key_value_heads": 4, + "position_id_per_seconds": 25, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "mrope_section": [ + 16, + 24, + 24 + ], + "rope_type": "default", + "type": "default" + }, + "rope_theta": 1000000.0, + "seconds_per_chunk": 2, + "sliding_window": 32768, + "spatial_merge_size": 2, + "torch_dtype": "bfloat16", + "tts_codec_end_token_id": 8294, + "tts_codec_mask_token_id": 8296, + "tts_codec_pad_token_id": 8292, + "tts_codec_start_token_id": 8293, + "tts_text_end_token_id": 151861, + "tts_text_pad_token_id": 151859, + "tts_text_start_token_id": 151860, + "use_cache": true, + "use_sliding_window": false, + "video_token_index": 151656, + "vision_end_token_id": 151653, + "vision_start_token_id": 151652, + "vocab_size": 8448 + }, + "thinker_config": { + "_attn_implementation_autoset": true, + "_name_or_path": "Qwen2.5-Omni-7B/thinker", + "architectures": [ + "Qwen2OmniNaViTThinkerForConditionalGeneration" + ], + "audio_config": { + "_attn_implementation_autoset": true, + "_name_or_path": "", + "activation_dropout": 0.0, + "activation_function": "gelu", + "add_cross_attention": false, + "architectures": null, + "attention_dropout": 0.0, + "bad_words_ids": null, + "begin_suppress_tokens": null, + "bos_token_id": null, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": null, + "d_model": 1280, + "decoder_start_token_id": null, + "diversity_penalty": 0.0, + "do_sample": false, + "dropout": 0.0, + "early_stopping": false, + "encoder_attention_heads": 20, + "encoder_ffn_dim": 5120, + "encoder_layerdrop": 0.0, + "encoder_layers": 32, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": null, + "exponential_decay_length_penalty": null, + "finetuning_task": null, + "forced_bos_token_id": null, + "forced_eos_token_id": null, + "id2label": { + "0": "LABEL_0", + "1": "LABEL_1" + }, + "init_std": 0.02, + "is_decoder": false, + "is_encoder_decoder": false, + "label2id": { + "LABEL_0": 0, + "LABEL_1": 1 + }, + "length_penalty": 1.0, + "max_length": 20, + "max_source_positions": 1500, + "min_length": 0, + "model_type": "qwen2_5_omni_audio_encoder", + "n_window": 100, + "no_repeat_ngram_size": 0, + "num_beam_groups": 1, + "num_beams": 1, + "num_hidden_layers": 32, + "num_mel_bins": 128, + "num_return_sequences": 1, + "output_attentions": false, + "output_dim": 3584, + "output_hidden_states": false, + "output_scores": false, + "pad_token_id": null, + "prefix": null, + "problem_type": null, + "pruned_heads": {}, + "remove_invalid_values": false, + "repetition_penalty": 1.0, + "return_dict": true, + "return_dict_in_generate": false, + "scale_embedding": false, + "sep_token_id": null, + "suppress_tokens": null, + "task_specific_params": null, + "temperature": 1.0, + "tf_legacy_loss": false, + "tie_encoder_decoder": false, + "tie_word_embeddings": true, + "tokenizer_class": null, + "top_k": 50, + "top_p": 1.0, + "torch_dtype": null, + "torchscript": false, + "typical_p": 1.0, + "use_bfloat16": false + }, + "text_config": { + "model_type": "qwen2_5_omni_text", + "hidden_act": "silu", + "hidden_size": 3584, + "init_std": 0.02, + "intermediate_size": 18944, + "vocab_size": 152064, + "num_attention_heads": 28, + "num_hidden_layers": 28, + "num_key_value_heads": 4, + "max_position_embeddings": 32768, + "max_window_layers": 28, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "mrope_section": [ + 16, + 24, + 24 + ], + "rope_type": "default", + "type": "default" + }, + "use_cache": true, + "rope_theta": 1000000.0, + "use_sliding_window": false, + "sliding_window": 32768, + "attention_dropout": 0.0, + "tie_word_embeddings": false + }, + "audio_end_token_id": 151648, + "audio_start_token_id": 151647, + "audio_token_index": 151646, + "bos_token_id": 151644, + "eos_token_id": 151645, + "ignore_index": -100, + "image_token_index": 151655, + "init_std": 0.02, + "model_type": "qwen2_5_omni_thinker", + "pad_token_id": 151643, + "position_id_per_seconds": 25, + "seconds_per_chunk": 2, + "torch_dtype": "bfloat16", + "user_token_id": 872, + "video_token_index": 151656, + "vision_config": { + "_attn_implementation_autoset": true, + "_name_or_path": "", + "add_cross_attention": false, + "architectures": null, + "bad_words_ids": null, + "begin_suppress_tokens": null, + "bos_token_id": null, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": null, + "decoder_start_token_id": null, + "depth": 32, + "diversity_penalty": 0.0, + "do_sample": false, + "early_stopping": false, + "embed_dim": 1280, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": null, + "exponential_decay_length_penalty": null, + "finetuning_task": null, + "forced_bos_token_id": null, + "forced_eos_token_id": null, + "fullatt_block_indexes": [ + 7, + 15, + 23, + 31 + ], + "hidden_act": "silu", + "hidden_size": 1280, + "id2label": { + "0": "LABEL_0", + "1": "LABEL_1" + }, + "in_channels": 3, + "in_chans": 3, + "init_std": 0.02, + "intermediate_size": 3420, + "is_decoder": false, + "is_encoder_decoder": false, + "label2id": { + "LABEL_0": 0, + "LABEL_1": 1 + }, + "length_penalty": 1.0, + "max_length": 20, + "min_length": 0, + "model_type": "qwen2_5_omni_vision_encoder", + "no_repeat_ngram_size": 0, + "num_beam_groups": 1, + "num_beams": 1, + "num_heads": 16, + "num_return_sequences": 1, + "out_hidden_size": 3584, + "output_attentions": false, + "output_hidden_states": false, + "output_scores": false, + "pad_token_id": null, + "patch_size": 14, + "prefix": null, + "problem_type": null, + "pruned_heads": {}, + "remove_invalid_values": false, + "repetition_penalty": 1.0, + "return_dict": true, + "return_dict_in_generate": false, + "sep_token_id": null, + "spatial_merge_size": 2, + "spatial_patch_size": 14, + "suppress_tokens": null, + "task_specific_params": null, + "temperature": 1.0, + "temporal_patch_size": 2, + "tf_legacy_loss": false, + "tie_encoder_decoder": false, + "tie_word_embeddings": true, + "tokenizer_class": null, + "tokens_per_second": 25, + "top_k": 50, + "top_p": 1.0, + "torch_dtype": null, + "torchscript": false, + "typical_p": 1.0, + "use_bfloat16": false, + "window_size": 112 + }, + "vision_end_token_id": 151653, + "vision_start_token_id": 151652, + "vision_token_id": 151654 + }, + "token2wav_config": { + "_attn_implementation_autoset": true, + "bigvgan_config": { + "_attn_implementation_autoset": true, + "_name_or_path": "", + "add_cross_attention": false, + "architectures": null, + "bad_words_ids": null, + "begin_suppress_tokens": null, + "bos_token_id": null, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": null, + "decoder_start_token_id": null, + "diversity_penalty": 0.0, + "do_sample": false, + "early_stopping": false, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": null, + "exponential_decay_length_penalty": null, + "finetuning_task": null, + "forced_bos_token_id": null, + "forced_eos_token_id": null, + "id2label": { + "0": "LABEL_0", + "1": "LABEL_1" + }, + "is_decoder": false, + "is_encoder_decoder": false, + "label2id": { + "LABEL_0": 0, + "LABEL_1": 1 + }, + "length_penalty": 1.0, + "max_length": 20, + "mel_dim": 80, + "min_length": 0, + "model_type": "qwen2_5_omni_bigvgan", + "no_repeat_ngram_size": 0, + "num_beam_groups": 1, + "num_beams": 1, + "num_return_sequences": 1, + "output_attentions": false, + "output_hidden_states": false, + "output_scores": false, + "pad_token_id": null, + "prefix": null, + "problem_type": null, + "pruned_heads": {}, + "remove_invalid_values": false, + "repetition_penalty": 1.0, + "resblock_dilation_sizes": [ + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ] + ], + "resblock_kernel_sizes": [ + 3, + 7, + 11 + ], + "return_dict": true, + "return_dict_in_generate": false, + "sep_token_id": null, + "suppress_tokens": null, + "task_specific_params": null, + "temperature": 1.0, + "tf_legacy_loss": false, + "tie_encoder_decoder": false, + "tie_word_embeddings": true, + "tokenizer_class": null, + "top_k": 50, + "top_p": 1.0, + "torch_dtype": null, + "torchscript": false, + "typical_p": 1.0, + "upsample_initial_channel": 1536, + "upsample_kernel_sizes": [ + 11, + 7, + 4, + 4, + 4, + 4 + ], + "upsample_rates": [ + 5, + 3, + 2, + 2, + 2, + 2 + ], + "use_bfloat16": false, + "use_bias_at_final": false + }, + "dit_config": { + "_attn_implementation_autoset": true, + "_name_or_path": "", + "add_cross_attention": false, + "architectures": null, + "bad_words_ids": null, + "begin_suppress_tokens": null, + "bos_token_id": null, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": null, + "decoder_start_token_id": null, + "depth": 22, + "dim": 1024, + "diversity_penalty": 0.0, + "do_sample": false, + "dropout": 0.1, + "early_stopping": false, + "emb_dim": 512, + "enc_attention_channels": 64, + "enc_channels": [ + 256, + 256, + 256, + 256, + 768 + ], + "enc_dilations": [ + 1, + 2, + 3, + 4, + 1 + ], + "enc_dim": 128, + "enc_emb_dim": 192, + "enc_global_context": true, + "enc_kernel_sizes": [ + 5, + 3, + 3, + 3, + 1 + ], + "enc_lin_neurons": 192, + "enc_res2net_scale": 2, + "enc_se_channels": 64, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": null, + "exponential_decay_length_penalty": null, + "ff_mult": 2, + "finetuning_task": null, + "forced_bos_token_id": null, + "forced_eos_token_id": null, + "head_dim": 64, + "heads": 16, + "id2label": { + "0": "LABEL_0", + "1": "LABEL_1" + }, + "is_decoder": false, + "is_encoder_decoder": false, + "label2id": { + "LABEL_0": 0, + "LABEL_1": 1 + }, + "length_penalty": 1.0, + "max_length": 20, + "mel_dim": 80, + "min_length": 0, + "model_type": "qwen2_5_omni_dit", + "no_repeat_ngram_size": 0, + "num_beam_groups": 1, + "num_beams": 1, + "num_embeds": 8193, + "num_return_sequences": 1, + "output_attentions": false, + "output_hidden_states": false, + "output_scores": false, + "pad_token_id": null, + "prefix": null, + "problem_type": null, + "pruned_heads": {}, + "remove_invalid_values": false, + "repeats": 2, + "repetition_penalty": 1.0, + "return_dict": true, + "return_dict_in_generate": false, + "sep_token_id": null, + "suppress_tokens": null, + "task_specific_params": null, + "temperature": 1.0, + "tf_legacy_loss": false, + "tie_encoder_decoder": false, + "tie_word_embeddings": true, + "tokenizer_class": null, + "top_k": 50, + "top_p": 1.0, + "torch_dtype": "float32", + "torchscript": false, + "typical_p": 1.0, + "use_bfloat16": false + }, + "model_type": "qwen2_5_omni_token2wav" + }, + "torch_dtype": "bfloat16", + "transformers_version": "4.50.0.dev0" +} \ No newline at end of file diff --git a/examples/qwen2_5omni/text_infer.cpp b/examples/qwen2_5omni/text_infer.cpp new file mode 100644 index 000000000..299a0e07d --- /dev/null +++ b/examples/qwen2_5omni/text_infer.cpp @@ -0,0 +1,72 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include +#include +#include +#include +#include +#include + +using mllm::Argparse; + +MLLM_MAIN({ + mllm::Logger::level() = mllm::LogLevel::kError; + + auto& help = Argparse::add("-h|--help").help("Show help message"); + auto& model_path = Argparse::add("-m|--model_path").help("Model path").required(true); + auto& model_version = Argparse::add("-mv|--model_version").help("Model version").required(true); + auto& tokenizer_path = Argparse::add("-t|--tokenizer_path").help("Tokenizer directory").required(true); + auto& config_path = Argparse::add("-c|--config_path").help("Config path").required(true); + + Argparse::parse(argc, argv); + + mllm::ModelFileVersion file_version = mllm::ModelFileVersion::kV1; + if (model_version.get() == "v1") { + file_version = mllm::ModelFileVersion::kV1; + } else if (model_version.get() == "v2") { + file_version = mllm::ModelFileVersion::kV2; + } + + if (help.isSet()) { + Argparse::printHelp(); + mllm::shutdownContext(); + return 0; + } + + { + auto qwen2_5omni_cfg = mllm::models::qwen2_5omni::Qwen2_5OmniConfig(config_path.get()); + auto qwen2_5omni_tokenizer = mllm::models::qwen2_5omni::Qwen2_5OmniTokenizer(tokenizer_path.get()); + auto qwen2_5omni = mllm::models::qwen2_5omni::Qwen2_5OmniForCausalLM(qwen2_5omni_cfg); + + auto param = mllm::load(model_path.get(), file_version); + qwen2_5omni.thinker_.load(param); + + fmt::print("\n{:*^60}\n", " Qwen2.5-Omni Text CLI "); + fmt::print("Enter 'exit' or 'quit' to end the session\n\n"); + + std::string prompt_text; + + fmt::print("💬 Prompt text (or 'exit/quit'): "); + std::getline(std::cin, prompt_text); + + if (prompt_text == "exit" || prompt_text == "quit") { return 0; } + + try { + fmt::print("🔄 Processing...\n"); + auto inputs = qwen2_5omni_tokenizer.convertMessage({.prompt = prompt_text}); + + fmt::print("\n🤖 Response: "); + for (auto& step : qwen2_5omni.chat(inputs)) { + std::wcout << qwen2_5omni_tokenizer.detokenize(step.cur_token_id) << std::flush; + } + + fmt::print("\n{}\n", std::string(60, '-')); + } catch (const std::exception& e) { fmt::print("\n❌ Error: {}\n{}\n", e.what(), std::string(60, '-')); } + + qwen2_5omni.perfSummary(); + } + + mllm::print("\n"); + mllm::memoryReport(); +}) diff --git a/mllm/models/qwen2_5omni/configuration_qwen2_5omni.hpp b/mllm/models/qwen2_5omni/configuration_qwen2_5omni.hpp new file mode 100644 index 000000000..2b0cb1ee8 --- /dev/null +++ b/mllm/models/qwen2_5omni/configuration_qwen2_5omni.hpp @@ -0,0 +1,97 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +#pragma once + +#include +#include + +#include "mllm/core/aops/LinearOp.hpp" +#include "mllm/engine/ConfigFile.hpp" + +namespace mllm::models::qwen2_5omni { + +struct Qwen2_5OmniConfig : protected ConfigFile { + Qwen2_5OmniConfig() = default; + + explicit Qwen2_5OmniConfig(const std::string& file_path) : ConfigFile(file_path) { + auto& root = data(); + + if (root.contains("thinker_config")) { + auto& thinker_cfg = root["thinker_config"]; + auto& text_cfg = thinker_cfg["text_config"]; + + hidden_size = text_cfg["hidden_size"]; + intermediate_size = text_cfg["intermediate_size"]; + num_attention_heads = text_cfg["num_attention_heads"]; + num_key_value_heads = text_cfg["num_key_value_heads"]; + num_hidden_layers = text_cfg["num_hidden_layers"]; + max_position_embeddings = text_cfg["max_position_embeddings"]; + rms_norm_eps = text_cfg["rms_norm_eps"]; + vocab_size = text_cfg["vocab_size"]; + rope_theta = text_cfg["rope_theta"]; + tie_word_embeddings = text_cfg.value("tie_word_embeddings", false); + + if (text_cfg.contains("rope_scaling") && text_cfg["rope_scaling"].contains("mrope_section")) { + mrope_section = text_cfg["rope_scaling"]["mrope_section"].get>(); + } + + bos_token_id = thinker_cfg.value("bos_token_id", bos_token_id); + eos_token_id = thinker_cfg.value("eos_token_id", eos_token_id); + pad_token_id = thinker_cfg.value("pad_token_id", pad_token_id); + image_token_id = thinker_cfg.value("image_token_index", image_token_id); + audio_token_id = thinker_cfg.value("audio_token_index", audio_token_id); + video_token_id = thinker_cfg.value("video_token_index", video_token_id); + } else { + hidden_size = root["hidden_size"]; + intermediate_size = root["intermediate_size"]; + num_attention_heads = root["num_attention_heads"]; + num_key_value_heads = root["num_key_value_heads"]; + num_hidden_layers = root["num_hidden_layers"]; + max_position_embeddings = root["max_position_embeddings"]; + rms_norm_eps = root["rms_norm_eps"]; + vocab_size = root["vocab_size"]; + rope_theta = root["rope_theta"]; + tie_word_embeddings = root.value("tie_word_embeddings", tie_word_embeddings); + if (root.contains("mrope_section")) { + mrope_section = root["mrope_section"].get>(); + } + bos_token_id = root.value("bos_token_id", bos_token_id); + eos_token_id = root.value("eos_token_id", eos_token_id); + pad_token_id = root.value("pad_token_id", pad_token_id); + image_token_id = root.value("image_token_id", image_token_id); + audio_token_id = root.value("audio_token_id", audio_token_id); + video_token_id = root.value("video_token_id", video_token_id); + } + + max_cache_length = root.value("max_cache_length", max_position_embeddings); + + if (root.contains("linear_impl_type")) { + linear_impl_type = aops::str2LinearImplTypes(root["linear_impl_type"]); + } + } + + int32_t hidden_size = 3584; + int32_t intermediate_size = 18944; + int32_t num_attention_heads = 28; + int32_t num_key_value_heads = 4; + int32_t num_hidden_layers = 28; + int32_t max_position_embeddings = 32768; + float rms_norm_eps = 1e-06f; + int32_t vocab_size = 152064; + std::vector mrope_section = {16, 24, 24}; + float rope_theta = 1000000.0f; + bool tie_word_embeddings = false; + + int32_t max_cache_length = 32768; + + int64_t bos_token_id = 151644; + int64_t eos_token_id = 151645; + int64_t pad_token_id = 151643; + int64_t image_token_id = 151655; + int64_t audio_token_id = 151646; + int64_t video_token_id = 151656; + + aops::LinearImplTypes linear_impl_type = aops::LinearImplTypes::kDefault; +}; + +} // namespace mllm::models::qwen2_5omni diff --git a/mllm/models/qwen2_5omni/modeling_qwen2_5omni.hpp b/mllm/models/qwen2_5omni/modeling_qwen2_5omni.hpp new file mode 100644 index 000000000..7bd00baa7 --- /dev/null +++ b/mllm/models/qwen2_5omni/modeling_qwen2_5omni.hpp @@ -0,0 +1,357 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +#pragma once + +#include + +#include "mllm/mllm.hpp" +#include "mllm/nn/Module.hpp" +#include "mllm/nn/Nn.hpp" +#include "mllm/nn/Functional.hpp" +#include "mllm/nn/lmcache/StaticCache.hpp" +#include "mllm/models/ARGeneration.hpp" +#include "mllm/utils/Enumerate.hpp" + +#include "mllm/models/qwen2_5omni/configuration_qwen2_5omni.hpp" + +namespace mllm::models::qwen2_5omni { + +inline auto makeMultimodalRoPEInvFreq(int output_dim, float rope_theta) -> Tensor { + auto inv_freq = Tensor::empty({output_dim / 2}, kFloat32, kCPU).alloc(); + auto inv_freq_ptr = inv_freq.ptr(); + for (int i = 0; i < output_dim / 2; i++) { inv_freq_ptr[i] = 1.0f / std::pow(rope_theta, 2.0f * i / output_dim); } + return inv_freq; +} + +inline auto makeMultimodalPositionEmbedding(Tensor& position_ids, const Tensor& inv_freq, int seq_len, int output_dim, + const std::vector& mrope_section) -> std::pair { + MLLM_RT_ASSERT_EQ(position_ids.shape().size(), 3); + MLLM_RT_ASSERT_EQ(position_ids.shape()[1], 1); + + Tensor tmp_sin = Tensor::empty({3, position_ids.shape()[2], inv_freq.shape()[0] * 2}).alloc(); + Tensor tmp_cos = Tensor::empty({3, position_ids.shape()[2], inv_freq.shape()[0] * 2}).alloc(); + + for (int b = 0; b < 3; ++b) { + for (int d = 0; d < inv_freq.shape()[0]; ++d) { + for (int s = 0; s < position_ids.shape()[2]; ++s) { + auto value = inv_freq.ptr()[d] * (*position_ids.offsettedPtr({b, 0, s})); + *tmp_cos.offsettedPtr({b, s, d}) = cosf(value); + *tmp_cos.offsettedPtr({b, s, d + inv_freq.shape()[0]}) = cosf(value); + *tmp_sin.offsettedPtr({b, s, d}) = sinf(value); + *tmp_sin.offsettedPtr({b, s, d + inv_freq.shape()[0]}) = sinf(value); + } + } + } + + Tensor sin = Tensor::nil(); + Tensor cos = Tensor::nil(); + + if (!mrope_section.empty()) { + auto double_rope_section = mrope_section; + for (int i : mrope_section) { double_rope_section.push_back(i); } + + int num_rows = tmp_sin.shape()[1]; + int num_cols = tmp_sin.shape()[2]; + + sin = Tensor::empty({num_rows, num_cols}, kFloat32, kCPU).alloc(); + cos = Tensor::empty({num_rows, num_cols}, kFloat32, kCPU).alloc(); + + std::vector start_cols; + int current_start = 0; + start_cols.push_back(current_start); + for (int s : double_rope_section) { + current_start += s; + start_cols.push_back(current_start); + } + + for (int j = 0; j < static_cast(double_rope_section.size()); ++j) { + int layer = j % 3; + int s_j = double_rope_section[j]; + int start_col_in = start_cols[j]; + int start_col_out = start_cols[j]; + for (int row = 0; row < num_rows; ++row) { + auto in_cos_row_ptr = tmp_cos.offsettedPtr({layer, row, 0}); + auto out_cos_row_ptr = cos.offsettedPtr({row, 0}); + for (int c = 0; c < s_j; ++c) { out_cos_row_ptr[start_col_out + c] = in_cos_row_ptr[start_col_in + c]; } + + auto in_sin_row_ptr = tmp_sin.offsettedPtr({layer, row, 0}); + auto out_sin_row_ptr = sin.offsettedPtr({row, 0}); + for (int c = 0; c < s_j; ++c) { out_sin_row_ptr[start_col_out + c] = in_sin_row_ptr[start_col_in + c]; } + } + } + } else { + sin = tmp_sin; + cos = tmp_cos; + } + + return {sin, cos}; +} + +class Qwen2_5OmniMLP final : public nn::Module { + nn::Linear gate_proj_; + nn::Linear up_proj_; + nn::Linear down_proj_; + nn::SiLU silu_; + + public: + Qwen2_5OmniMLP() = default; + Qwen2_5OmniMLP(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + gate_proj_ = reg("gate_proj", cfg.hidden_size, cfg.intermediate_size, false, cfg.linear_impl_type); + silu_ = reg("act"); + up_proj_ = reg("up_proj", cfg.hidden_size, cfg.intermediate_size, false, cfg.linear_impl_type); + down_proj_ = reg("down_proj", cfg.intermediate_size, cfg.hidden_size, false, cfg.linear_impl_type); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = gate_proj_(inputs[0]); + x = silu_(x); + auto y = up_proj_(inputs[0]); + x = x * y; + x = down_proj_(x); + return {x}; + } +}; + +class Qwen2_5OmniAttention final : public nn::Module { + nn::Linear q_proj_; + nn::Linear k_proj_; + nn::Linear v_proj_; + nn::Linear o_proj_; + nn::MultimodalRoPE q_rope_; + nn::MultimodalRoPE k_rope_; + nn::CausalMask mask_; + nn::Softmax softmax_; + + int hidden_size_; + int head_dim_; + int num_attention_heads_; + int num_key_value_heads_; + int num_key_value_groups_; + + public: + Qwen2_5OmniAttention() = default; + + Qwen2_5OmniAttention(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + hidden_size_ = cfg.hidden_size; + num_attention_heads_ = cfg.num_attention_heads; + num_key_value_heads_ = cfg.num_key_value_heads; + head_dim_ = hidden_size_ / num_attention_heads_; + num_key_value_groups_ = num_attention_heads_ / num_key_value_heads_; + + q_proj_ = reg("q_proj", hidden_size_, head_dim_ * num_attention_heads_, true, cfg.linear_impl_type); + k_proj_ = reg("k_proj", hidden_size_, head_dim_ * num_key_value_heads_, true, cfg.linear_impl_type); + v_proj_ = reg("v_proj", hidden_size_, head_dim_ * num_key_value_heads_, true, cfg.linear_impl_type); + o_proj_ = reg("o_proj", head_dim_ * num_attention_heads_, hidden_size_, false, cfg.linear_impl_type); + + q_rope_ = reg( + "q_rope", aops::Qwen2VLMultimodalRoPEOpOptions{.rope_theta = cfg.rope_theta, + .max_position_embeddings = cfg.max_position_embeddings, + .mrope_section = cfg.mrope_section}); + k_rope_ = reg( + "k_rope", aops::Qwen2VLMultimodalRoPEOpOptions{.rope_theta = cfg.rope_theta, + .max_position_embeddings = cfg.max_position_embeddings, + .mrope_section = cfg.mrope_section}); + + mask_ = reg("mask"); + softmax_ = reg("softmax", -1); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = inputs[0]; + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto past_kv_cache = args[0].get(); + + auto query_states = q_proj_(x); + auto key_states = k_proj_(x); + auto value_states = v_proj_(x); + + int B = inputs[0].shape()[0]; + int S = inputs[0].shape()[1]; + + query_states = query_states.view({B, S, num_attention_heads_, head_dim_}); + key_states = key_states.view({B, S, num_key_value_heads_, head_dim_}); + value_states = value_states.view({B, S, num_key_value_heads_, head_dim_}); + + query_states = query_states.transpose(1, 2); + key_states = key_states.transpose(1, 2); + value_states = value_states.transpose(1, 2); + + query_states = q_rope_(query_states, llm_embedding_sin, llm_embedding_cos); + key_states = k_rope_(key_states, llm_embedding_sin, llm_embedding_cos); + + auto [k, v] = past_kv_cache->updateKVCache(layer_idx_, key_states, value_states); + key_states = k; + value_states = v; + + Tensor attn; + if (key_states.dtype() == kFloat32) { + attn = nn::functional::matmul(query_states, key_states, false, true) * (1.f / sqrtf(head_dim_)); + attn = mask_(attn); + attn = softmax_(attn); + } else if (key_states.dtype() == kFloat16) { + attn = nn::functional::matmul(query_states.to(kFloat32), key_states.to(kFloat32), false, true) * (1.f / sqrtf(head_dim_)); + attn = mask_(attn); + attn = softmax_(attn); + attn = attn.to(kFloat16); + } + + auto output = nn::functional::matmul(attn, value_states); + output = output.transpose(1, 2).view({B, S, num_attention_heads_ * head_dim_}); + output = o_proj_(output); + return {output}; + } + + int layer_idx_; +}; + +class Qwen2_5OmniDecoder final : public nn::Module { + public: + Qwen2_5OmniAttention self_attn_; + Qwen2_5OmniMLP mlp_; + nn::RMSNorm input_layer_norm_; + nn::RMSNorm post_attention_layer_norm_; + + Qwen2_5OmniDecoder() = default; + + Qwen2_5OmniDecoder(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + self_attn_ = reg("self_attn", cfg); + mlp_ = reg("mlp", cfg); + input_layer_norm_ = reg("input_layernorm", cfg.rms_norm_eps); + post_attention_layer_norm_ = reg("post_attention_layernorm", cfg.rms_norm_eps); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto& kv_cache = args[0]; + + auto x = input_layer_norm_(inputs[0]); + x = self_attn_(x, llm_embedding_sin, llm_embedding_cos, kv_cache)[0]; + auto tmp = x + inputs[0]; + x = post_attention_layer_norm_(tmp); + x = mlp_(x)[0]; + x = x + tmp; + return {x}; + } +}; + +class Qwen2_5OmniText final : public nn::Module { + nn::ModuleList decode_blocks_; + nn::RMSNorm norm_; + + public: + Qwen2_5OmniText() = default; + + Qwen2_5OmniText(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + decode_blocks_ = reg>("layers", cfg.num_hidden_layers, cfg); + for (auto [idx, b] : enumerate(decode_blocks_.list())) { b.self_attn_.layer_idx_ = idx; } + + norm_ = reg("norm", cfg.rms_norm_eps); + embedding_ = reg("embed_tokens", cfg.vocab_size, cfg.hidden_size); + + auto inv = makeMultimodalRoPEInvFreq(cfg.hidden_size / cfg.num_attention_heads, cfg.rope_theta); + registerBuffer("inv_freq", inv); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto& blocks = decode_blocks_.list(); + auto x = inputs[0]; + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto& kv_cache = args[0]; + + for (auto& block : blocks) { x = block(x, llm_embedding_sin, llm_embedding_cos, kv_cache)[0]; } + x = norm_(x); + + return {x}; + } + + nn::Embedding embedding_; +}; + +class Qwen2_5OmniThinker final : public nn::Module { + public: + Qwen2_5OmniThinker() = default; + Qwen2_5OmniThinker(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + model_ = reg("model", cfg); + lm_head_ = reg("lm_head", cfg.hidden_size, cfg.vocab_size, false, cfg.linear_impl_type); + } + + Qwen2_5OmniText model_; + nn::Linear lm_head_; +}; + +class Qwen2_5OmniForCausalLM : public ARGeneration { + public: + explicit Qwen2_5OmniForCausalLM(const Qwen2_5OmniConfig& cfg) : cfg_(cfg), thinker_("thinker", cfg) { + kv_cache_ = nn::StaticCache(cfg.max_cache_length, cfg.num_hidden_layers, + cfg.num_attention_heads, + cfg.num_key_value_heads, + cfg.hidden_size / cfg.num_attention_heads, + kFloat32, + kFloat32, + kCPU, + false); + eos_token_id_ = cfg.eos_token_id; + max_length_ = cfg.max_cache_length; + } + + ARGenerationOutputPast forward(const ARGenerationOutputPast& input, const ARGenerationArgs& args) override { + auto sequence = input.at("sequence"); + + auto input_embeddings = thinker_.model_.embedding_(sequence); + + Tensor position_ids = Tensor::nil(); + if (input.count("position_ids")) { + position_ids = input.at("position_ids"); + } + position_ids = getPositionIds(sequence, position_ids); + + auto [llm_embedding_sin, llm_embedding_cos] = + makeMultimodalPositionEmbedding(position_ids, thinker_.model_.getBuffer("inv_freq"), cfg_.max_position_embeddings, + cfg_.hidden_size / cfg_.num_attention_heads, cfg_.mrope_section); + + auto hidden_states = thinker_.model_(input_embeddings, llm_embedding_sin, llm_embedding_cos, AnyValue(&kv_cache_))[0]; + auto seq_len = hidden_states.shape()[1]; + auto last_hidden = hidden_states[{kAll, {seq_len - 1}, kAll}]; + auto logits = thinker_.lm_head_(last_hidden); + + return { + {"sequence", logits}, + {"position_ids", position_ids}, + }; + } + + Qwen2_5OmniThinker thinker_; + + private: + Tensor getPositionIds(Tensor& input_ids, Tensor& position_ids) const { + MLLM_RT_ASSERT_EQ(input_ids.shape().size(), 2); + + if (!position_ids.isNil()) { + auto last_pos = *position_ids.offsettedPtr({0, 0, position_ids.shape()[2] - 1}); + auto ret_position_ids = Tensor::empty({3, 1, 1}, kInt64, kCPU).alloc(); + *ret_position_ids.offsettedPtr({0, 0, 0}) = last_pos + 1; + *ret_position_ids.offsettedPtr({1, 0, 0}) = last_pos + 1; + *ret_position_ids.offsettedPtr({2, 0, 0}) = last_pos + 1; + return ret_position_ids; + } + + auto B = input_ids.shape()[0]; + auto S = input_ids.shape()[1]; + MLLM_RT_ASSERT_EQ(B, 1); + + Tensor out = Tensor::empty({3, B, S}, kInt64, kCPU).alloc(); + for (int d = 0; d < 3; ++d) { + auto out_ptr = out.offsettedPtr({d, 0, 0}); + for (int64_t s = 0; s < S; ++s) { out_ptr[s] = s; } + } + return out; + } + + const Qwen2_5OmniConfig& cfg_; + nn::StaticCache kv_cache_; +}; + +} // namespace mllm::models::qwen2_5omni diff --git a/mllm/models/qwen2_5omni/tokenization_qwen2_5omni.hpp b/mllm/models/qwen2_5omni/tokenization_qwen2_5omni.hpp new file mode 100644 index 000000000..8674af9f5 --- /dev/null +++ b/mllm/models/qwen2_5omni/tokenization_qwen2_5omni.hpp @@ -0,0 +1,252 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +#pragma once + +#include +#include + +#include "mllm/preprocessor/tokenizers/BPE.hpp" +#include "mllm/preprocessor/tokenizers/Unicode.hpp" +#include "mllm/preprocessor/tokenizers/AutoTokenizer.hpp" +#include "mllm/models/ARGeneration.hpp" + +namespace mllm::models::qwen2_5omni { + +// same regex as Qwen2/Qwen2-VL tokenizers +inline bool qwen2_5OmniTokenizerMatchPattern(const std::wstring& str, size_t& pos, std::wstring& matched) { + if (pos >= str.size()) return false; + + static const std::wstring contractions[] = {L"'s", L"'t", L"'re", L"'ve", L"'m", L"'ll", L"'d"}; + for (const auto& contraction : contractions) { + if (pos + contraction.size() <= str.size() && str.compare(pos, contraction.size(), contraction) == 0) { + matched = contraction; + pos += contraction.size(); + return true; + } + } + + { + size_t original_pos = pos; + bool has_prefix = false; + matched.clear(); + + if (!preprocessor::isLetter(str[pos]) && !preprocessor::isDigit(str[pos]) && str[pos] != L'\r' && str[pos] != L'\n') { + matched += str[pos]; + ++pos; + has_prefix = true; + } + + if (pos < str.size() && preprocessor::isLetter(str[pos])) { + do { + matched += str[pos]; + ++pos; + } while (pos < str.size() && preprocessor::isLetter(str[pos])); + return true; + } else if (has_prefix) { + pos = original_pos; + matched.clear(); + } + } + + if (preprocessor::isDigit(str[pos])) { + matched = str.substr(pos, 1); + ++pos; + return true; + } + + { + size_t original_pos = pos; + matched.clear(); + size_t start = pos; + + if (str[pos] == L' ') { ++pos; } + + if (pos < str.size() && !std::iswspace(str[pos]) && !preprocessor::isLetter(str[pos]) && !preprocessor::isDigit(str[pos])) { + do { + ++pos; + } while (pos < str.size() && !std::iswspace(str[pos]) && !preprocessor::isLetter(str[pos]) + && !preprocessor::isDigit(str[pos])); + + matched = str.substr(start, pos - start); + + while (pos < str.size() && (str[pos] == L'\r' || str[pos] == L'\n')) { + matched += str[pos]; + ++pos; + } + return true; + } else { + pos = original_pos; + } + } + + { + size_t start = pos; + while (pos < str.size() && std::iswspace(str[pos])) ++pos; + if (pos < str.size() && (str[pos] == L'\r' || str[pos] == L'\n')) { + while (pos < str.size() && (str[pos] == L'\r' || str[pos] == L'\n')) ++pos; + matched = str.substr(start, pos - start); + return true; + } else { + pos = start; + } + } + + if (std::iswspace(str[pos])) { + size_t start = pos; + while (pos < str.size() && std::iswspace(str[pos])) ++pos; + if (pos >= str.size() || std::iswspace(str[pos])) { + matched = str.substr(start, pos - start); + return true; + } else { + pos = start; + } + } + + if (std::iswspace(str[pos])) { + size_t start = pos; + while (pos < str.size() && std::iswspace(str[pos])) ++pos; + matched = str.substr(start, pos - start); + return true; + } + + return false; +} + +inline bool qwen2_5OmniRegex(const std::string& str, std::vector& splitted) { + auto w_string = preprocessor::utf8string2WideString(str); + size_t pos = 0; + while (pos < w_string.size()) { + std::wstring matched; + if (qwen2_5OmniTokenizerMatchPattern(w_string, pos, matched)) { + splitted.push_back(matched); + } else { + ++pos; + } + } + return true; +} + +struct Qwen2_5OmniMessage { + std::string prompt; + std::string system_prompt = "You are a helpful assistant."; + + [[nodiscard]] std::string buildChatMessage() const { + std::string result; + if (!system_prompt.empty()) { + result += "<|im_start|>system\n" + system_prompt + "<|im_end|>\n"; + } + result += "<|im_start|>user\n" + prompt + "<|im_end|>\n"; + result += "<|im_start|>assistant\n"; + return result; + } +}; + +class Qwen2_5OmniTokenizer final : public mllm::preprocessor::AutoTokenizer { + public: + explicit Qwen2_5OmniTokenizer(const std::string& file_path) { + preprocessor::initLocal(); + preprocessor::makeBytes2UnicodeMap(bytes_2_unicode_dict_); + for (auto& kv : bytes_2_unicode_dict_) { bytes_2_unicode_dict_inverse_.insert({kv.second, kv.first}); } + bpe_.initFromSentencePieceJson(file_path); + special_tokens_trie_.add(L"<|endoftext|>"); + special_tokens_trie_.add(L"<|im_start|>"); + special_tokens_trie_.add(L"<|im_end|>"); + special_tokens_trie_.add(L"<|object_ref_start|>"); + special_tokens_trie_.add(L"<|object_ref_end|>"); + special_tokens_trie_.add(L"<|box_start|>"); + special_tokens_trie_.add(L"<|box_end|>"); + special_tokens_trie_.add(L"<|quad_start|>"); + special_tokens_trie_.add(L"<|quad_end|>"); + special_tokens_trie_.add(L"<|vision_bos|>"); + special_tokens_trie_.add(L"<|vision_eos|>"); + special_tokens_trie_.add(L"<|vision_pad|>"); + special_tokens_trie_.add(L"<|image_pad|>"); + special_tokens_trie_.add(L"<|video_pad|>"); + special_tokens_trie_.add(L"<|AUDIO|>"); + special_tokens_trie_.add(L"<|audio_bos|>"); + special_tokens_trie_.add(L"<|audio_eos|>"); + special_tokens_trie_.add(L"<|IMAGE|>"); + special_tokens_trie_.add(L"<|VIDEO|>"); + } + + std::vector _tokenize(const std::string& str) override { + std::vector ret; + std::vector splitted; + ::mllm::models::qwen2_5omni::qwen2_5OmniRegex(str, splitted); + for (const auto& s : splitted) { + auto utf_8_str = preprocessor::wideString2Utf8String(s); + std::wstring mapped_str; + for (unsigned char c : utf_8_str) { mapped_str.push_back(bytes_2_unicode_dict_[c]); } + + auto bpe_ts = bpe_._bpe(mapped_str); + + for (const auto& bpe_t : bpe_ts) { ret.push_back(bpe_t); } + } + + return ret; + } + + std::vector tokenize(const std::string& str) override { + auto tokens = special_tokens_trie_.split(preprocessor::utf8string2WideString(str)); + std::vector all_tokens; + for (const auto& token : tokens) { + if (special_tokens_trie_.isSpecialToken(token)) { + all_tokens.emplace_back(token); + continue; + } + auto tmp_tokens = _tokenize(preprocessor::wideString2Utf8String(token)); + all_tokens.insert(all_tokens.end(), tmp_tokens.begin(), tmp_tokens.end()); + } + return all_tokens; + } + + std::wstring _detokenize(int64_t pos_idx) override { return bpe_._lookup_inverse_vocab(pos_idx); } + + std::wstring detokenize(int64_t pos_idx) override { + auto str = _detokenize(pos_idx); + std::string utf_8_str; + for (wchar_t c : str) { utf_8_str.push_back((unsigned char)(bytes_2_unicode_dict_inverse_[c])); } + return {mllm::preprocessor::utf8string2WideString(utf_8_str)}; + } + + Tensor convert2Ids(const std::vector& strs) override { + std::vector ids; + ids.reserve(strs.size()); + for (const auto& str : strs) { ids.emplace_back(bpe_._lookup_vocab(str)); } + Tensor ret = Tensor::empty({1, static_cast(ids.size())}, kInt64, kCPU) + .setMemType(kExtraInput) + .setName("qwen2_5omni-tokenizer-i0") + .alloc(); + + auto ptr = ret.ptr(); + for (size_t i = 0; i < ids.size(); ++i) { ptr[i] = ids[i]; } + + return ret; + } + + ARGenerationOutputPast convertMessage(const Qwen2_5OmniMessage& message) { + auto applied_string = message.buildChatMessage(); + auto sequence_str = tokenize(applied_string); + + std::vector ids; + ids.reserve(sequence_str.size()); + for (const auto& str : sequence_str) { ids.emplace_back(bpe_._lookup_vocab(str)); } + + Tensor sequence = Tensor::empty({1, static_cast(ids.size())}, kInt64, kCPU) + .setMemType(kNormal) + .setName("qwen2_5omni-tokenizer-i0") + .alloc(); + + auto ptr = sequence.ptr(); + for (size_t i = 0; i < ids.size(); ++i) { ptr[i] = ids[i]; } + + return {{"sequence", sequence}}; + } + + private: + preprocessor::BPE bpe_; + std::unordered_map bytes_2_unicode_dict_; + std::unordered_map bytes_2_unicode_dict_inverse_; +}; + +} // namespace mllm::models::qwen2_5omni From c9333abef049e87f02faa8b9edf3943e3e082cc0 Mon Sep 17 00:00:00 2001 From: KKkai <1640576073@qq.com> Date: Fri, 23 Jan 2026 14:50:29 +0800 Subject: [PATCH 02/17] add qwen2.5omni vision, audio modal --- examples/qwen2_5omni/CMakeLists.txt | 8 + examples/qwen2_5omni/audio_infer.cpp | 84 ++ examples/qwen2_5omni/image_infer.cpp | 84 ++ .../audio_preprocessor_qwen2_5omni.hpp | 240 ++++ .../qwen2_5omni/configuration_qwen2_5omni.hpp | 82 ++ .../qwen2_5omni/modeling_qwen2_5omni.hpp | 1033 ++++++++++++++++- .../qwen2_5omni/tokenization_qwen2_5omni.hpp | 135 ++- 7 files changed, 1659 insertions(+), 7 deletions(-) create mode 100644 examples/qwen2_5omni/audio_infer.cpp create mode 100644 examples/qwen2_5omni/image_infer.cpp create mode 100644 mllm/models/qwen2_5omni/audio_preprocessor_qwen2_5omni.hpp diff --git a/examples/qwen2_5omni/CMakeLists.txt b/examples/qwen2_5omni/CMakeLists.txt index 3141b56d7..479c3a635 100644 --- a/examples/qwen2_5omni/CMakeLists.txt +++ b/examples/qwen2_5omni/CMakeLists.txt @@ -1,3 +1,11 @@ add_executable(mllm-qwen2_5-omni-text-runner text_infer.cpp) target_link_libraries(mllm-qwen2_5-omni-text-runner PRIVATE MllmRT MllmCPUBackend) target_include_directories(mllm-qwen2_5-omni-text-runner PRIVATE ${MLLM_INCLUDE_DIR}) + +add_executable(mllm-qwen2_5-omni-image-runner image_infer.cpp) +target_link_libraries(mllm-qwen2_5-omni-image-runner PRIVATE MllmRT MllmCPUBackend) +target_include_directories(mllm-qwen2_5-omni-image-runner PRIVATE ${MLLM_INCLUDE_DIR}) + +add_executable(mllm-qwen2_5-omni-audio-runner audio_infer.cpp) +target_link_libraries(mllm-qwen2_5-omni-audio-runner PRIVATE MllmRT MllmCPUBackend) +target_include_directories(mllm-qwen2_5-omni-audio-runner PRIVATE ${MLLM_INCLUDE_DIR}) diff --git a/examples/qwen2_5omni/audio_infer.cpp b/examples/qwen2_5omni/audio_infer.cpp new file mode 100644 index 000000000..014b4688f --- /dev/null +++ b/examples/qwen2_5omni/audio_infer.cpp @@ -0,0 +1,84 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include +#include +#include +#include +#include +#include + +using mllm::Argparse; + +MLLM_MAIN({ + mllm::Logger::level() = mllm::LogLevel::kError; + + auto& help = Argparse::add("-h|--help").help("Show help message"); + auto& model_path = Argparse::add("-m|--model_path").help("Model path").required(true); + auto& model_version = Argparse::add("-mv|--model_version").help("Model version").required(true); + auto& tokenizer_path = Argparse::add("-t|--tokenizer_path").help("Tokenizer directory").required(true); + auto& config_path = Argparse::add("-c|--config_path").help("Config path").required(true); + + Argparse::parse(argc, argv); + + mllm::ModelFileVersion file_version = mllm::ModelFileVersion::kV1; + if (model_version.get() == "v1") { + file_version = mllm::ModelFileVersion::kV1; + } else if (model_version.get() == "v2") { + file_version = mllm::ModelFileVersion::kV2; + } + + if (help.isSet()) { + Argparse::printHelp(); + mllm::shutdownContext(); + return 0; + } + + { + auto qwen2_5omni_cfg = mllm::models::qwen2_5omni::Qwen2_5OmniConfig(config_path.get()); + auto qwen2_5omni_tokenizer = mllm::models::qwen2_5omni::Qwen2_5OmniTokenizer(tokenizer_path.get()); + auto qwen2_5omni = mllm::models::qwen2_5omni::Qwen2_5OmniForCausalLM(qwen2_5omni_cfg); + + auto param = mllm::load(model_path.get(), file_version); + qwen2_5omni.thinker_.load(param); + + fmt::print("\n{:*^60}\n", " Qwen2.5-Omni Audio CLI "); + fmt::print("Enter 'exit' or 'quit' to end the session\n\n"); + + std::string audio_path; + std::string prompt_text; + + fmt::print("Audio path (or 'exit/quit'): "); + //std::getline(std::cin, audio_path); + //if (audio_path == "exit" || audio_path == "quit") { return 0; } + audio_path = "/Users/kkkai/Desktop/mllm2-former/mllm/rsc/recognize.wav"; + + fmt::print("Prompt text: "); + //std::getline(std::cin, prompt_text); + //if (prompt_text.empty()) { prompt_text = "Please describe the audio."; } + prompt_text = "复述这段音频"; + + try { + fmt::print("Processing...\n"); + auto inputs = qwen2_5omni_tokenizer.convertAudioMessage({.prompt = prompt_text, .audio_file_path = audio_path}); + + fmt::print("\nResponse: "); + qwen2_5omni.streamGenerate(inputs, + { + {"do_sample", mllm::AnyValue(false)}, + {"max_length", mllm::AnyValue(qwen2_5omni_cfg.max_cache_length)}, + }, + [&](int64_t token_id) { + auto str = qwen2_5omni_tokenizer.detokenize(token_id); + std::wcout << str << std::flush; + }); + + fmt::print("\n{}\n", std::string(60, '-')); + } catch (const std::exception& e) { fmt::print("\nError: {}\n{}\n", e.what(), std::string(60, '-')); } + + qwen2_5omni.perfSummary(); + } + + mllm::print("\n"); + mllm::memoryReport(); +}) diff --git a/examples/qwen2_5omni/image_infer.cpp b/examples/qwen2_5omni/image_infer.cpp new file mode 100644 index 000000000..41bf770b1 --- /dev/null +++ b/examples/qwen2_5omni/image_infer.cpp @@ -0,0 +1,84 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include +#include +#include +#include +#include +#include + +using mllm::Argparse; + +MLLM_MAIN({ + mllm::Logger::level() = mllm::LogLevel::kError; + + auto& help = Argparse::add("-h|--help").help("Show help message"); + auto& model_path = Argparse::add("-m|--model_path").help("Model path").required(true); + auto& model_version = Argparse::add("-mv|--model_version").help("Model version").required(true); + auto& tokenizer_path = Argparse::add("-t|--tokenizer_path").help("Tokenizer directory").required(true); + auto& config_path = Argparse::add("-c|--config_path").help("Config path").required(true); + + Argparse::parse(argc, argv); + + mllm::ModelFileVersion file_version = mllm::ModelFileVersion::kV1; + if (model_version.get() == "v1") { + file_version = mllm::ModelFileVersion::kV1; + } else if (model_version.get() == "v2") { + file_version = mllm::ModelFileVersion::kV2; + } + + if (help.isSet()) { + Argparse::printHelp(); + mllm::shutdownContext(); + return 0; + } + + { + auto qwen2_5omni_cfg = mllm::models::qwen2_5omni::Qwen2_5OmniConfig(config_path.get()); + auto qwen2_5omni_tokenizer = + mllm::models::qwen2_5omni::Qwen2_5OmniTokenizer(tokenizer_path.get(), qwen2_5omni_cfg.visual_spatial_merge_size); + auto qwen2_5omni = mllm::models::qwen2_5omni::Qwen2_5OmniForCausalLM(qwen2_5omni_cfg); + + auto param = mllm::load(model_path.get(), file_version); + qwen2_5omni.thinker_.load(param); + + fmt::print("\n{:*^60}\n", " Qwen2.5-Omni Image CLI "); + fmt::print("Enter 'exit' or 'quit' to end the session\n\n"); + + std::string image_path; + std::string prompt_text; + + fmt::print("Image path (or 'exit/quit'): "); + image_path = "../../../mllm2-former/mllm/rsc/pics.jpg"; + //std::getline(std::cin, image_path); + if (image_path == "exit" || image_path == "quit") { return 0; } + + fmt::print("Prompt text: "); + prompt_text = "描述图片中物体"; + //std::getline(std::cin, prompt_text); + + try { + fmt::print("Processing...\n"); + auto inputs = qwen2_5omni_tokenizer.convertVisionMessage({.prompt = prompt_text, .img_file_path = image_path}); + + fmt::print("\nResponse: "); + qwen2_5omni.streamGenerate(inputs, + { + {"do_sample", mllm::AnyValue(false)}, + {"max_length", mllm::AnyValue(qwen2_5omni_cfg.max_cache_length)}, + }, + [&](int64_t token_id) { + auto str = qwen2_5omni_tokenizer.detokenize(token_id); + std::wcout << str << std::flush; + }); + + fmt::print("\n{}\n", std::string(60, '-')); + } catch (const std::exception& e) { fmt::print("\nError: {}\n{}\n", e.what(), std::string(60, '-')); } + + qwen2_5omni.perfSummary(); + } + + mllm::print("\n"); + mllm::memoryReport(); +}) diff --git a/mllm/models/qwen2_5omni/audio_preprocessor_qwen2_5omni.hpp b/mllm/models/qwen2_5omni/audio_preprocessor_qwen2_5omni.hpp new file mode 100644 index 000000000..392bfc17b --- /dev/null +++ b/mllm/models/qwen2_5omni/audio_preprocessor_qwen2_5omni.hpp @@ -0,0 +1,240 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "mllm/core/Tensor.hpp" +#include "mllm/nn/Module.hpp" +#include "mllm/nn/Nn.hpp" +#include "mllm/nn/Functional.hpp" +#include "mllm/preprocessor/audio/Audio.hpp" + +namespace mllm::models::qwen2_5omni { + +inline float hertz_to_mel_slaney(float freq) { + constexpr float kMinLogHertz = 1000.0f; + constexpr float kMinLogMel = 15.0f; + const float logstep = 27.0f / std::log(6.4f); + + if (freq < kMinLogHertz) { + return 3.0f * freq / 200.0f; + } + return kMinLogMel + std::log(freq / kMinLogHertz) * logstep; +} + +inline float mel_to_hertz_slaney(float mel) { + constexpr float kMinLogHertz = 1000.0f; + constexpr float kMinLogMel = 15.0f; + const float logstep = std::log(6.4f) / 27.0f; + + if (mel < kMinLogMel) { + return 200.0f * mel / 3.0f; + } + return kMinLogHertz * std::exp(logstep * (mel - kMinLogMel)); +} + +inline Tensor create_hann_window(int32_t window_length, bool periodic = true) { + int32_t length = periodic ? window_length + 1 : window_length; + auto window = Tensor::empty({1, window_length}, kFloat32, kCPU).alloc(); + float* window_ptr = window.ptr(); + + for (int32_t i = 0; i < window_length; ++i) { + float n = static_cast(i); + float denominator = periodic ? static_cast(length) : static_cast(length - 1); + window_ptr[i] = 0.5f - 0.5f * std::cos(2.0f * M_PI * n / denominator); + } + + return window; +} + +inline Tensor create_mel_filterbank(int32_t num_frequency_bins, int32_t num_mel_filters, float min_frequency, + float max_frequency, int32_t sampling_rate) { + std::vector fft_freqs(num_frequency_bins); + for (int32_t i = 0; i < num_frequency_bins; ++i) { + fft_freqs[i] = static_cast(i) * (sampling_rate / 2.0f) / (num_frequency_bins - 1); + } + + float mel_min = hertz_to_mel_slaney(min_frequency); + float mel_max = hertz_to_mel_slaney(max_frequency); + + std::vector mel_freqs(num_mel_filters + 2); + for (int32_t i = 0; i < num_mel_filters + 2; ++i) { + mel_freqs[i] = mel_min + static_cast(i) * (mel_max - mel_min) / (num_mel_filters + 1); + } + + std::vector filter_freqs(num_mel_filters + 2); + for (int32_t i = 0; i < num_mel_filters + 2; ++i) { filter_freqs[i] = mel_to_hertz_slaney(mel_freqs[i]); } + + auto mel_filters = Tensor::empty({num_frequency_bins, num_mel_filters}, kFloat32, kCPU).alloc(); + float* filters_ptr = mel_filters.ptr(); + std::fill_n(filters_ptr, num_frequency_bins * num_mel_filters, 0.0f); + + for (int32_t mel_idx = 0; mel_idx < num_mel_filters; ++mel_idx) { + float left_freq = filter_freqs[mel_idx]; + float center_freq = filter_freqs[mel_idx + 1]; + float right_freq = filter_freqs[mel_idx + 2]; + + for (int32_t freq_idx = 0; freq_idx < num_frequency_bins; ++freq_idx) { + float freq = fft_freqs[freq_idx]; + float value = 0.0f; + + if (freq >= left_freq && freq <= center_freq && center_freq != left_freq) { + value = (freq - left_freq) / (center_freq - left_freq); + } else if (freq >= center_freq && freq <= right_freq && right_freq != center_freq) { + value = (right_freq - freq) / (right_freq - center_freq); + } + + filters_ptr[freq_idx * num_mel_filters + mel_idx] = value; + } + } + + for (int32_t mel_idx = 0; mel_idx < num_mel_filters; ++mel_idx) { + float enorm = 2.0f / (filter_freqs[mel_idx + 2] - filter_freqs[mel_idx]); + for (int32_t freq_idx = 0; freq_idx < num_frequency_bins; ++freq_idx) { + filters_ptr[freq_idx * num_mel_filters + mel_idx] *= enorm; + } + } + + return mel_filters; +} + +class MelSpectrogramFeatures final : public nn::Module { + int32_t n_fft_; + int32_t hop_length_; + int32_t win_length_; + int32_t n_mels_; + std::string padding_; + int power_; + nn::STFT stft_; + Tensor window_; + Tensor melscale_fbanks_; + + public: + MelSpectrogramFeatures() = default; + + explicit inline MelSpectrogramFeatures(const std::string& name, int32_t sample_rate = 16000, int32_t n_fft = 400, + int32_t hop_length = 160, int32_t n_mels = 128, + const std::string& padding = "center", int power = 2) + : nn::Module(name), n_fft_(n_fft), hop_length_(hop_length), n_mels_(n_mels), padding_(padding), power_(power) { + if (padding != "center" && padding != "same") { throw std::invalid_argument("Padding must be 'center' or 'same'."); } + + win_length_ = n_fft_; + stft_ = reg("stft", n_fft_, hop_length_, win_length_, true, true, "reflect", true); + window_ = create_hann_window(win_length_, true); + melscale_fbanks_ = create_mel_filterbank(n_fft_ / 2 + 1, n_mels_, 0.0f, 8000.0f, sample_rate); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto audio = inputs[0]; // [B, T] + + if (padding_ == "same") { + NYI("apply same padding in MelSpectrogramFeatures not implemented"); + } + + auto stft_result = stft_(audio, window_); + auto specgram = stft_result.abs(); + if (power_ == 2) { + specgram = specgram * specgram; + } else if (power_ != 1) { + NYI("power != 1 and power != 2 not implemented"); + } + + auto mel_specgram = nn::functional::matmul(specgram.T(), melscale_fbanks_).T(); + mel_specgram = nn::functional::clip(mel_specgram, 1e-10f, std::numeric_limits::max()); + mel_specgram = nn::functional::log(mel_specgram) / std::log(10.0f); + auto max_val = mel_specgram.max(); + float threshold = max_val.item() - 8.0f; + mel_specgram = nn::functional::clip(mel_specgram, threshold, std::numeric_limits::max()); + mel_specgram = (mel_specgram + 4.0f) / 4.0f; + + return {mel_specgram}; + } +}; + +struct Qwen2_5OmniAudioFeatures { + Tensor input_features = Tensor::nil(); + int32_t feature_length = 0; +}; + +class Qwen2_5OmniAudioPreprocessor { + MelSpectrogramFeatures mel_extractor_; + int32_t sample_rate_; + int32_t n_mels_; + int32_t hop_length_; + int32_t chunk_length_; + int32_t n_samples_; + + public: + explicit Qwen2_5OmniAudioPreprocessor(int32_t sample_rate = 16000, int32_t n_mels = 128, int32_t hop_length = 160, + int32_t chunk_length = 300) + : mel_extractor_("feature_extractor.mel_spec", sample_rate, 400, hop_length, n_mels, "center", 2), + sample_rate_(sample_rate), + n_mels_(n_mels), + hop_length_(hop_length), + chunk_length_(chunk_length), + n_samples_(chunk_length * sample_rate) {} + + [[nodiscard]] Qwen2_5OmniAudioFeatures processAudioFile(const std::string& audio_file_path) { + auto audio_data = mllm::audio::readWAV(audio_file_path, sample_rate_); + if (audio_data.empty()) { return {}; } + return processAudioData(audio_data.data(), static_cast(audio_data.size())); + } + + [[nodiscard]] Qwen2_5OmniAudioFeatures processAudioData(const float* audio_data, int32_t audio_length) { + Qwen2_5OmniAudioFeatures result; + if (audio_data == nullptr || audio_length <= 0) { return result; } + + int32_t padded_length = n_samples_; + int32_t effective_length = std::min(audio_length, padded_length); + + auto audio_tensor = Tensor::empty({1, padded_length}, kFloat32, kCPU).alloc(); + float* audio_ptr = audio_tensor.ptr(); + + if (audio_length <= padded_length) { + std::memcpy(audio_ptr, audio_data, audio_length * sizeof(float)); + std::fill(audio_ptr + audio_length, audio_ptr + padded_length, 0.0f); + } else { + std::memcpy(audio_ptr, audio_data, padded_length * sizeof(float)); + } + + auto mel_spec = mel_extractor_.forward({audio_tensor}, {})[0]; // [1, n_mels, n_frames] + + int32_t valid_frames = calcFeatureLength(effective_length); + int32_t max_frames = mel_spec.shape()[2]; + if (valid_frames > max_frames) { valid_frames = max_frames; } + if (valid_frames <= 0) { return result; } + + auto trimmed = Tensor::empty({1, n_mels_, valid_frames}, kFloat32, kCPU).alloc(); + for (int32_t m = 0; m < n_mels_; ++m) { + auto src_ptr = mel_spec.offsettedPtr({0, m, 0}); + auto dst_ptr = trimmed.offsettedPtr({0, m, 0}); + std::memcpy(dst_ptr, src_ptr, valid_frames * sizeof(float)); + } + + result.input_features = trimmed; + result.feature_length = valid_frames; + return result; + } + + [[nodiscard]] int32_t calcFeatureLength(int32_t audio_length) const { + if (audio_length <= 0) { return 0; } + return (audio_length + hop_length_ - 1) / hop_length_; + } + + [[nodiscard]] int32_t calcAudioTokenLength(int32_t feature_length) const { + if (feature_length <= 0) { return 0; } + int32_t after_conv = (feature_length - 1) / 2 + 1; + if (after_conv < 2) { return 0; } + int32_t after_pool = (after_conv - 2) / 2 + 1; + return std::max(0, after_pool); + } +}; + +} // namespace mllm::models::qwen2_5omni diff --git a/mllm/models/qwen2_5omni/configuration_qwen2_5omni.hpp b/mllm/models/qwen2_5omni/configuration_qwen2_5omni.hpp index 2b0cb1ee8..d0e000642 100644 --- a/mllm/models/qwen2_5omni/configuration_qwen2_5omni.hpp +++ b/mllm/models/qwen2_5omni/configuration_qwen2_5omni.hpp @@ -35,12 +35,48 @@ struct Qwen2_5OmniConfig : protected ConfigFile { mrope_section = text_cfg["rope_scaling"]["mrope_section"].get>(); } + if (thinker_cfg.contains("vision_config")) { + auto& vision_cfg = thinker_cfg["vision_config"]; + visual_in_chans = vision_cfg.value("in_channels", vision_cfg.value("in_chans", visual_in_chans)); + visual_hidden_size = vision_cfg.value("hidden_size", vision_cfg.value("embed_dim", visual_hidden_size)); + visual_patch_size = vision_cfg.value("patch_size", vision_cfg.value("spatial_patch_size", visual_patch_size)); + visual_temporal_patch_size = vision_cfg.value("temporal_patch_size", visual_temporal_patch_size); + visual_spatial_merge_size = vision_cfg.value("spatial_merge_size", visual_spatial_merge_size); + visual_out_hidden_size = vision_cfg.value("out_hidden_size", visual_out_hidden_size); + visual_num_heads = vision_cfg.value("num_heads", visual_num_heads); + visual_depth = vision_cfg.value("depth", visual_depth); + visual_intermediate_size = vision_cfg.value("intermediate_size", visual_intermediate_size); + if (vision_cfg.contains("fullatt_block_indexes")) { + visual_fullatt_block_indexes = vision_cfg["fullatt_block_indexes"].get>(); + } + visual_window_size = vision_cfg.value("window_size", visual_window_size); + } + + if (thinker_cfg.contains("audio_config")) { + auto& audio_cfg = thinker_cfg["audio_config"]; + audio_d_model = audio_cfg.value("d_model", audio_d_model); + audio_num_mel_bins = audio_cfg.value("num_mel_bins", audio_num_mel_bins); + audio_encoder_layers = audio_cfg.value("encoder_layers", audio_encoder_layers); + audio_encoder_attention_heads = audio_cfg.value("encoder_attention_heads", audio_encoder_attention_heads); + audio_encoder_ffn_dim = audio_cfg.value("encoder_ffn_dim", audio_encoder_ffn_dim); + audio_max_source_positions = audio_cfg.value("max_source_positions", audio_max_source_positions); + audio_n_window = audio_cfg.value("n_window", audio_n_window); + audio_output_dim = audio_cfg.value("output_dim", audio_output_dim); + } + bos_token_id = thinker_cfg.value("bos_token_id", bos_token_id); eos_token_id = thinker_cfg.value("eos_token_id", eos_token_id); pad_token_id = thinker_cfg.value("pad_token_id", pad_token_id); image_token_id = thinker_cfg.value("image_token_index", image_token_id); audio_token_id = thinker_cfg.value("audio_token_index", audio_token_id); video_token_id = thinker_cfg.value("video_token_index", video_token_id); + audio_start_token_id = thinker_cfg.value("audio_start_token_id", audio_start_token_id); + audio_end_token_id = thinker_cfg.value("audio_end_token_id", audio_end_token_id); + vision_start_token_id = thinker_cfg.value("vision_start_token_id", vision_start_token_id); + vision_end_token_id = thinker_cfg.value("vision_end_token_id", vision_end_token_id); + vision_token_id = thinker_cfg.value("vision_token_id", vision_token_id); + position_id_per_seconds = thinker_cfg.value("position_id_per_seconds", position_id_per_seconds); + seconds_per_chunk = thinker_cfg.value("seconds_per_chunk", seconds_per_chunk); } else { hidden_size = root["hidden_size"]; intermediate_size = root["intermediate_size"]; @@ -55,12 +91,30 @@ struct Qwen2_5OmniConfig : protected ConfigFile { if (root.contains("mrope_section")) { mrope_section = root["mrope_section"].get>(); } + if (root.contains("audio_config")) { + auto& audio_cfg = root["audio_config"]; + audio_d_model = audio_cfg.value("d_model", audio_d_model); + audio_num_mel_bins = audio_cfg.value("num_mel_bins", audio_num_mel_bins); + audio_encoder_layers = audio_cfg.value("encoder_layers", audio_encoder_layers); + audio_encoder_attention_heads = audio_cfg.value("encoder_attention_heads", audio_encoder_attention_heads); + audio_encoder_ffn_dim = audio_cfg.value("encoder_ffn_dim", audio_encoder_ffn_dim); + audio_max_source_positions = audio_cfg.value("max_source_positions", audio_max_source_positions); + audio_n_window = audio_cfg.value("n_window", audio_n_window); + audio_output_dim = audio_cfg.value("output_dim", audio_output_dim); + } bos_token_id = root.value("bos_token_id", bos_token_id); eos_token_id = root.value("eos_token_id", eos_token_id); pad_token_id = root.value("pad_token_id", pad_token_id); image_token_id = root.value("image_token_id", image_token_id); audio_token_id = root.value("audio_token_id", audio_token_id); video_token_id = root.value("video_token_id", video_token_id); + audio_start_token_id = root.value("audio_start_token_id", audio_start_token_id); + audio_end_token_id = root.value("audio_end_token_id", audio_end_token_id); + vision_start_token_id = root.value("vision_start_token_id", vision_start_token_id); + vision_end_token_id = root.value("vision_end_token_id", vision_end_token_id); + vision_token_id = root.value("vision_token_id", vision_token_id); + position_id_per_seconds = root.value("position_id_per_seconds", position_id_per_seconds); + seconds_per_chunk = root.value("seconds_per_chunk", seconds_per_chunk); } max_cache_length = root.value("max_cache_length", max_position_embeddings); @@ -82,6 +136,27 @@ struct Qwen2_5OmniConfig : protected ConfigFile { float rope_theta = 1000000.0f; bool tie_word_embeddings = false; + int32_t visual_in_chans = 3; + int32_t visual_hidden_size = 1280; + int32_t visual_patch_size = 14; + int32_t visual_temporal_patch_size = 2; + int32_t visual_spatial_merge_size = 2; + int32_t visual_out_hidden_size = 3584; + int32_t visual_num_heads = 16; + int32_t visual_depth = 32; + int32_t visual_intermediate_size = 3420; + std::vector visual_fullatt_block_indexes = {7, 15, 23, 31}; + int32_t visual_window_size = 112; + + int32_t audio_d_model = 1280; + int32_t audio_num_mel_bins = 128; + int32_t audio_encoder_layers = 32; + int32_t audio_encoder_attention_heads = 20; + int32_t audio_encoder_ffn_dim = 5120; + int32_t audio_max_source_positions = 1500; + int32_t audio_n_window = 100; + int32_t audio_output_dim = 3584; + int32_t max_cache_length = 32768; int64_t bos_token_id = 151644; @@ -90,6 +165,13 @@ struct Qwen2_5OmniConfig : protected ConfigFile { int64_t image_token_id = 151655; int64_t audio_token_id = 151646; int64_t video_token_id = 151656; + int64_t audio_start_token_id = 151647; + int64_t audio_end_token_id = 151648; + int64_t vision_start_token_id = 151652; + int64_t vision_end_token_id = 151653; + int64_t vision_token_id = 151654; + int32_t position_id_per_seconds = 25; + int32_t seconds_per_chunk = 2; aops::LinearImplTypes linear_impl_type = aops::LinearImplTypes::kDefault; }; diff --git a/mllm/models/qwen2_5omni/modeling_qwen2_5omni.hpp b/mllm/models/qwen2_5omni/modeling_qwen2_5omni.hpp index 7bd00baa7..fac087bae 100644 --- a/mllm/models/qwen2_5omni/modeling_qwen2_5omni.hpp +++ b/mllm/models/qwen2_5omni/modeling_qwen2_5omni.hpp @@ -2,9 +2,14 @@ // Licensed under the MIT License. #pragma once +#include #include +#include +#include +#include #include "mllm/mllm.hpp" +#include "mllm/core/SlicePrimitives.hpp" #include "mllm/nn/Module.hpp" #include "mllm/nn/Nn.hpp" #include "mllm/nn/Functional.hpp" @@ -87,6 +92,756 @@ inline auto makeMultimodalPositionEmbedding(Tensor& position_ids, const Tensor& return {sin, cos}; } +inline auto makeWindowIndex(const Tensor& grid_thw, int window_size, int spatial_merge_size, + int patch_size) -> std::pair, std::vector> { + MLLM_RT_ASSERT_EQ(grid_thw.shape().size(), 2); + const int grid_num = grid_thw.shape()[0]; + + const int vit_merger_window_size = window_size / spatial_merge_size / patch_size; + const int spatial_merge_unit = spatial_merge_size * spatial_merge_size; + + std::vector window_index; + std::vector cu_window_seqlens = {0}; + int window_index_id = 0; + + for (int grid_idx = 0; grid_idx < grid_num; ++grid_idx) { + const int grid_t = grid_thw.constAt({grid_idx, 0}); + const int grid_h = grid_thw.constAt({grid_idx, 1}); + const int grid_w = grid_thw.constAt({grid_idx, 2}); + + const int llm_grid_h = grid_h / spatial_merge_size; + const int llm_grid_w = grid_w / spatial_merge_size; + const int pad_h = (vit_merger_window_size - llm_grid_h % vit_merger_window_size) % vit_merger_window_size; + const int pad_w = (vit_merger_window_size - llm_grid_w % vit_merger_window_size) % vit_merger_window_size; + + const int num_windows_h = (llm_grid_h + pad_h) / vit_merger_window_size; + const int num_windows_w = (llm_grid_w + pad_w) / vit_merger_window_size; + const int total_windows = grid_t * num_windows_h * num_windows_w; + + std::vector>> index( + grid_t, std::vector>(llm_grid_h, std::vector(llm_grid_w))); + + int counter = 0; + for (int t = 0; t < grid_t; t++) { + for (int h = 0; h < llm_grid_h; h++) { + for (int w = 0; w < llm_grid_w; w++) { index[t][h][w] = counter++; } + } + } + + std::vector>> index_padded( + grid_t, std::vector>(llm_grid_h + pad_h, std::vector(llm_grid_w + pad_w, -100))); + + for (int t = 0; t < grid_t; t++) { + for (int h = 0; h < llm_grid_h; h++) { + for (int w = 0; w < llm_grid_w; w++) { index_padded[t][h][w] = index[t][h][w]; } + } + } + + std::vector seqlens(total_windows, 0); + for (int t = 0; t < grid_t; t++) { + for (int wh = 0; wh < num_windows_h; wh++) { + for (int ww = 0; ww < num_windows_w; ww++) { + const int window_idx = t * num_windows_h * num_windows_w + wh * num_windows_w + ww; + for (int h = 0; h < vit_merger_window_size; h++) { + for (int w = 0; w < vit_merger_window_size; w++) { + const int orig_h = wh * vit_merger_window_size + h; + const int orig_w = ww * vit_merger_window_size + w; + if (index_padded[t][orig_h][orig_w] != -100) { + window_index.push_back(index_padded[t][orig_h][orig_w] + window_index_id); + seqlens[window_idx]++; + } + } + } + } + } + } + + int cumulative = cu_window_seqlens.back(); + for (int i = 0; i < total_windows; i++) { + cumulative += seqlens[i] * spatial_merge_unit; + cu_window_seqlens.push_back(cumulative); + } + + window_index_id += grid_t * llm_grid_h * llm_grid_w; + } + + return {window_index, cu_window_seqlens}; +} + +inline auto makeVisualRoPEInvFreq(int32_t dims, float theta) -> Tensor { + const int half_dim = dims / (2 * 2); + Tensor inv_freq = Tensor::empty({half_dim}, kFloat32).alloc(); + float* inv_freq_ptr = inv_freq.ptr(); + const float dims_inv = 1.0f / static_cast(dims / 2); + for (int i = 0; i < half_dim; ++i) { + const float exponent = (2.0f * i) * dims_inv; + inv_freq_ptr[i] = 1.0f / std::pow(theta, exponent); + } + return inv_freq; +} + +inline auto makeVisualRotaryPosEmbIds(Tensor& grid_thw, int32_t spatial_merge_size) -> Tensor { + MLLM_RT_ASSERT_EQ(grid_thw.shape().size(), 2); + + const auto img_nums = grid_thw.shape()[0]; + int total_positions = 0; + for (int row = 0; row < img_nums; ++row) { + const int* dims = grid_thw.offsettedPtr({row, 0}); + total_positions += dims[0] * dims[1] * dims[2]; + } + + Tensor out = Tensor::empty({total_positions, 2}, kInt32).alloc(); + int* out_ptr = out.ptr(); + int out_offset = 0; + + for (int row = 0; row < img_nums; ++row) { + const int* dims = grid_thw.offsettedPtr({row, 0}); + const int t = dims[0]; + const int h = dims[1]; + const int w = dims[2]; + + const int num_h_blocks = h / spatial_merge_size; + const int num_w_blocks = w / spatial_merge_size; + const int total_blocks = num_h_blocks * num_w_blocks; + const int block_area = spatial_merge_size * spatial_merge_size; + const int grid_size = h * w; + + std::vector flatten_hpos(grid_size); + std::vector flatten_wpos(grid_size); + + for (int block_idx = 0; block_idx < total_blocks; ++block_idx) { + const int i_h = block_idx / num_w_blocks; + const int i_w = block_idx % num_w_blocks; + const int start_idx = block_idx * block_area; + + const int base_h = i_h * spatial_merge_size; + const int base_w = i_w * spatial_merge_size; + + for (int j_h = 0; j_h < spatial_merge_size; ++j_h) { + const int global_h = base_h + j_h; + for (int j_w = 0; j_w < spatial_merge_size; ++j_w) { + const int global_w = base_w + j_w; + const int pos = start_idx + j_h * spatial_merge_size + j_w; + flatten_hpos[pos] = global_h; + flatten_wpos[pos] = global_w; + } + } + } + + for (int frame = 0; frame < t; ++frame) { + for (int pos = 0; pos < grid_size; ++pos) { + const int out_idx = out_offset + (frame * grid_size + pos) * 2; + out_ptr[out_idx] = flatten_hpos[pos]; + out_ptr[out_idx + 1] = flatten_wpos[pos]; + } + } + out_offset += t * grid_size * 2; + } + + return out; +} + +inline auto makeVisualRotaryPosEmbFull(Tensor& inv_freq, int seq_len) -> Tensor { + MLLM_RT_ASSERT(seq_len > 0); + const int32_t dim = inv_freq.shape()[0]; + Tensor freqs = Tensor::empty({seq_len, dim}, kFloat32, kCPU).alloc(); + float* inv_freq_ptr = inv_freq.ptr(); + float* freqs_ptr = freqs.ptr(); + for (int i = 0; i < seq_len; ++i) { + const float i_val = static_cast(i); + float* row_ptr = freqs_ptr + i * dim; + for (int j = 0; j < dim; ++j) { row_ptr[j] = i_val * inv_freq_ptr[j]; } + } + return freqs; +} + +inline auto makeVisualRotaryPosEmb(Tensor& rotary_pos_emb_full, Tensor& pos_ids, Tensor& grid_thw) -> Tensor { + const int32_t dim = rotary_pos_emb_full.shape()[1]; + const int32_t batch_size = pos_ids.shape()[0]; + const int32_t seq_len = pos_ids.shape()[1]; + + int total_positions = 0; + for (int row = 0; row < grid_thw.shape()[0]; ++row) { + const int* dims = grid_thw.offsettedPtr({row, 0}); + total_positions += dims[0] * dims[1] * dims[2]; + } + + Tensor out = Tensor::empty({batch_size, seq_len * dim}, kFloat32, kCPU).alloc(); + + auto rotary_pos_emb_full_ptr = rotary_pos_emb_full.ptr(); + auto pos_ids_ptr = pos_ids.ptr(); + + if (rotary_pos_emb_full.shape()[0] <= 0 || dim <= 0 || batch_size <= 0) { + MLLM_ERROR_EXIT(ExitCode::kSliceOB, "Invalid tensor dimensions"); + } + + if (total_positions != batch_size) { MLLM_ERROR_EXIT(ExitCode::kSliceOB, "Grid dimensions mismatch with batch size"); } + + for (int i = 0; i < batch_size; ++i) { + for (int j = 0; j < seq_len; ++j) { + const int idx = pos_ids_ptr[i * seq_len + j]; + if (idx < 0 || idx >= rotary_pos_emb_full.shape()[0]) { + MLLM_ERROR_EXIT(ExitCode::kSliceOB, "Position index out of bounds"); + } + } + } + + for (int i = 0; i < batch_size; ++i) { + auto batch_ptr = out.offsettedPtr({i, 0}); + size_t offset = 0; + for (int j = 0; j < seq_len; ++j) { + const int idx = pos_ids_ptr[i * seq_len + j]; + auto emb_ptr = rotary_pos_emb_full_ptr + idx * dim; + std::copy(emb_ptr, emb_ptr + dim, batch_ptr + offset); + offset += dim; + } + } + + return out; +} + +inline auto makeVisualRotarySinCos(Tensor& rotary_pos_emb) -> std::pair { + const auto seq = rotary_pos_emb.shape()[0]; + const auto dim = rotary_pos_emb.shape()[1]; + + auto rotary_pos_emb_ptr = rotary_pos_emb.ptr(); + + Tensor sin_pos_emb = Tensor::empty({seq, dim}, kFloat32, kCPU).alloc(); + Tensor cos_pos_emb = Tensor::empty({seq, dim}, kFloat32, kCPU).alloc(); + + auto sin_pos_emb_ptr = sin_pos_emb.ptr(); + auto cos_pos_emb_ptr = cos_pos_emb.ptr(); + + for (int i = 0; i < seq; i++) { + for (int j = 0; j < dim; j++) { + sin_pos_emb_ptr[i * dim + j] = std::sin(rotary_pos_emb_ptr[i * dim + j]); + cos_pos_emb_ptr[i * dim + j] = std::cos(rotary_pos_emb_ptr[i * dim + j]); + } + } + + return {sin_pos_emb, cos_pos_emb}; +} + +inline auto makeAudioSinusoidalPosEmb(int32_t length, int32_t channels, float max_timescale = 10000.0f) -> Tensor { + MLLM_RT_ASSERT(channels % 2 == 0); + auto pos_emb = Tensor::empty({length, channels}, kFloat32, kCPU).alloc(); + auto pos_ptr = pos_emb.ptr(); + + const int half = channels / 2; + const float log_timescale_increment = std::log(max_timescale) / static_cast(half - 1); + + std::vector inv_timescales(half); + for (int i = 0; i < half; ++i) { + inv_timescales[i] = std::exp(-log_timescale_increment * static_cast(i)); + } + + for (int t = 0; t < length; ++t) { + for (int i = 0; i < half; ++i) { + const float scaled_time = static_cast(t) * inv_timescales[i]; + pos_ptr[t * channels + i] = std::sin(scaled_time); + pos_ptr[t * channels + half + i] = std::cos(scaled_time); + } + } + + return pos_emb; +} + +class Qwen2_5OmniPatchEmbed final : public nn::Module { + int32_t in_chans_; + int32_t embed_dim_; + int32_t patch_size_; + int32_t temporal_patch_size_; + + nn::Conv3D proj_; + + public: + Qwen2_5OmniPatchEmbed() = default; + + explicit Qwen2_5OmniPatchEmbed(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + in_chans_ = cfg.visual_in_chans; + embed_dim_ = cfg.visual_hidden_size; + patch_size_ = cfg.visual_patch_size; + temporal_patch_size_ = cfg.visual_temporal_patch_size; + + proj_ = reg("proj", cfg.visual_in_chans, cfg.visual_hidden_size, + std::vector{cfg.visual_temporal_patch_size, cfg.visual_patch_size, cfg.visual_patch_size}, + std::vector{cfg.visual_temporal_patch_size, cfg.visual_patch_size, cfg.visual_patch_size}, + false); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; + hidden_states = hidden_states.view({-1, in_chans_, temporal_patch_size_, patch_size_, patch_size_}); + hidden_states = proj_(hidden_states).view({-1, embed_dim_}); + return {hidden_states}; + } +}; + +class Qwen2_5OmniPatchMerger final : public nn::Module { + int32_t hidden_size_; + int32_t spatial_merge_size_; + int32_t context_dim_; + + nn::RMSNorm ln_q_; + nn::Linear mlp_0_; + nn::Linear mlp_2_; + nn::GELU mlp_gelu_; + + public: + Qwen2_5OmniPatchMerger() = default; + + explicit Qwen2_5OmniPatchMerger(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + context_dim_ = cfg.visual_hidden_size; + spatial_merge_size_ = cfg.visual_spatial_merge_size; + hidden_size_ = context_dim_ * spatial_merge_size_ * spatial_merge_size_; + + ln_q_ = reg("ln_q", 1e-6); + mlp_0_ = reg("mlp.0", hidden_size_, hidden_size_, true, cfg.linear_impl_type); + mlp_gelu_ = reg("mlp.gelu"); + mlp_2_ = reg("mlp.2", hidden_size_, cfg.visual_out_hidden_size, true, cfg.linear_impl_type); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto o = ln_q_(inputs[0]).view({-1, hidden_size_}); + o = mlp_0_(o); + o = mlp_gelu_(o); + o = mlp_2_(o); + return {o}; + } +}; + +class Qwen2_5OmniVisionMLP final : public nn::Module { + nn::Linear gate_proj_; + nn::Linear up_proj_; + nn::Linear down_proj_; + nn::SiLU silu_; + + public: + Qwen2_5OmniVisionMLP() = default; + explicit Qwen2_5OmniVisionMLP(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + gate_proj_ = reg("gate_proj", cfg.visual_hidden_size, cfg.visual_intermediate_size, true); + silu_ = reg("act"); + up_proj_ = reg("up_proj", cfg.visual_hidden_size, cfg.visual_intermediate_size, true); + down_proj_ = reg("down_proj", cfg.visual_intermediate_size, cfg.visual_hidden_size, true); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = gate_proj_(inputs[0]); + x = silu_(x); + auto y = up_proj_(inputs[0]); + x = x * y; + x = down_proj_(x); + return {x}; + } +}; + +class Qwen2_5OmniVisionAttention final : public nn::Module { + int32_t dim_; + int32_t num_heads_; + int32_t head_dim_; + + nn::Linear q_; + nn::Linear k_; + nn::Linear v_; + nn::Linear proj_; + nn::Softmax softmax_; + nn::VisionRoPE vision_rope_q_; + nn::VisionRoPE vision_rope_k_; + + public: + Qwen2_5OmniVisionAttention() = default; + + explicit Qwen2_5OmniVisionAttention(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + dim_ = cfg.visual_hidden_size; + num_heads_ = cfg.visual_num_heads; + head_dim_ = dim_ / num_heads_; + + q_ = reg("q", dim_, dim_, true, cfg.linear_impl_type); + k_ = reg("k", dim_, dim_, true, cfg.linear_impl_type); + v_ = reg("v", dim_, dim_, true, cfg.linear_impl_type); + proj_ = reg("proj", dim_, dim_, true, cfg.linear_impl_type); + softmax_ = reg("softmax", -1); + + vision_rope_q_ = reg("vision_rope_q", aops::VisionRoPEOpOptionsType::kQwen2VL, + aops::Qwen2VLRoPEOpOptions{ + .dims = head_dim_, + .spatial_merge_size = cfg.visual_spatial_merge_size, + .theta = 10000.0, + }); + vision_rope_k_ = reg("vision_rope_k", aops::VisionRoPEOpOptionsType::kQwen2VL, + aops::Qwen2VLRoPEOpOptions{ + .dims = head_dim_, + .spatial_merge_size = cfg.visual_spatial_merge_size, + .theta = 10000.0, + }); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; + auto visual_embedding_sin = inputs[1]; + auto visual_embedding_cos = inputs[2]; + auto& mask = inputs[3]; + + auto seq_length = hidden_states.shape()[0]; + + auto query_states = q_(hidden_states).view({seq_length, num_heads_, head_dim_}).unsqueeze(0); + auto key_states = k_(hidden_states).view({seq_length, num_heads_, head_dim_}).unsqueeze(0); + auto value_states = v_(hidden_states).view({seq_length, num_heads_, head_dim_}).unsqueeze(0); + + query_states = vision_rope_q_(query_states, visual_embedding_sin, visual_embedding_cos); + key_states = vision_rope_k_(key_states, visual_embedding_sin, visual_embedding_cos); + + query_states = query_states.transpose(1, 2); + key_states = key_states.transpose(1, 2); + value_states = value_states.transpose(1, 2); + + auto attn = nn::functional::matmul(query_states, key_states, false, true) * (1.f / sqrtf(head_dim_)); + if (mask) { attn = attn + mask; } + attn = softmax_(attn); + + auto attn_output = nn::functional::matmul(attn, value_states); + attn_output = attn_output.transpose(1, 2).view({seq_length, -1}); + attn_output = proj_(attn_output); + return {attn_output}; + } +}; + +class Qwen2_5OmniVisionBlock final : public nn::Module { + nn::RMSNorm norm1_; + nn::RMSNorm norm2_; + + Qwen2_5OmniVisionAttention attn_; + Qwen2_5OmniVisionMLP mlp_; + + public: + Qwen2_5OmniVisionBlock() = default; + + explicit Qwen2_5OmniVisionBlock(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + norm1_ = reg("norm1", 1e-6); + norm2_ = reg("norm2", 1e-6); + attn_ = reg("attn", cfg); + mlp_ = reg("mlp", cfg); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; + auto visual_embedding_sin = inputs[1]; + auto visual_embedding_cos = inputs[2]; + auto mask = inputs[3]; + + hidden_states = hidden_states + attn_(norm1_(hidden_states), visual_embedding_sin, visual_embedding_cos, mask)[0]; + hidden_states = hidden_states + mlp_(norm2_(hidden_states))[0]; + return {hidden_states}; + } +}; + +class Qwen2_5OmniVisionEncoder final : public nn::Module { + Qwen2_5OmniPatchEmbed patch_embed_; + Qwen2_5OmniPatchMerger patch_merger_; + nn::ModuleList blocks_; + std::vector visual_fullatt_block_indexes_; + int32_t visual_window_size_ = 0; + int32_t visual_spatial_merge_size_ = 1; + int32_t visual_patch_size_ = 1; + int32_t spatial_merge_unit_ = 1; + + public: + Qwen2_5OmniVisionEncoder() = default; + + explicit Qwen2_5OmniVisionEncoder(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + visual_window_size_ = cfg.visual_window_size; + visual_spatial_merge_size_ = cfg.visual_spatial_merge_size; + visual_patch_size_ = cfg.visual_patch_size; + spatial_merge_unit_ = visual_spatial_merge_size_ * visual_spatial_merge_size_; + visual_fullatt_block_indexes_ = cfg.visual_fullatt_block_indexes; + patch_embed_ = reg("patch_embed", cfg); + patch_merger_ = reg("merger", cfg); + blocks_ = reg>("blocks", cfg.visual_depth, cfg); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; + auto embedding_sin = inputs[1]; + auto embedding_cos = inputs[2]; + auto& grid_thw = inputs[3]; + + hidden_states = patch_embed_(hidden_states)[0]; + auto [window_index, cu_window_seqlens] = + makeWindowIndex(grid_thw, visual_window_size_, visual_spatial_merge_size_, visual_patch_size_); + + auto seq_len = hidden_states.shape()[0]; + hidden_states = hidden_states.view({seq_len / spatial_merge_unit_, spatial_merge_unit_, -1}); + hidden_states = hidden_states[{window_index, {kAll}, {kAll}}]; + hidden_states = hidden_states.view({seq_len, -1}); + + embedding_sin = embedding_sin.view({seq_len / spatial_merge_unit_, spatial_merge_unit_, -1}); + embedding_sin = embedding_sin[{window_index, {kAll}, {kAll}}]; + embedding_sin = embedding_sin.view({seq_len, -1}); + embedding_cos = embedding_cos.view({seq_len / spatial_merge_unit_, spatial_merge_unit_, -1}); + embedding_cos = embedding_cos[{window_index, {kAll}, {kAll}}]; + embedding_cos = embedding_cos.view({seq_len, -1}); + + auto mask = Tensor::empty({1, 1, seq_len, seq_len}, DataTypes::kFloat32, DeviceTypes::kCPU).alloc(); + { + auto mask_ptr = mask.ptr(); + const mllm_fp32_t neg_inf = -1e12f; + for (int i = 0; i < seq_len * seq_len; ++i) { mask_ptr[i] = neg_inf; } + for (int i = 1; i < cu_window_seqlens.size(); ++i) { + const int start = cu_window_seqlens[i - 1]; + const int end = cu_window_seqlens[i]; + for (int r = start; r < end; ++r) { + for (int c = start; c < end; ++c) { mask_ptr[r * seq_len + c] = 0.0f; } + } + } + } + + for (auto [layer_idx, b] : enumerate(blocks_.list())) { + if (std::find(visual_fullatt_block_indexes_.begin(), visual_fullatt_block_indexes_.end(), layer_idx) + != visual_fullatt_block_indexes_.end()) { + hidden_states = b(hidden_states, embedding_sin, embedding_cos, Tensor::nil())[0]; + } else { + hidden_states = b(hidden_states, embedding_sin, embedding_cos, mask)[0]; + } + } + + hidden_states = patch_merger_(hidden_states)[0]; + + std::vector reverse_indices(window_index.size()); + std::iota(reverse_indices.begin(), reverse_indices.end(), 0); + std::sort(reverse_indices.begin(), reverse_indices.end(), + [&window_index](int i, int j) { return window_index[i] < window_index[j]; }); + hidden_states = hidden_states[{reverse_indices, {kAll}}]; + + return {hidden_states}; + } +}; + +class Qwen2_5OmniAudioAttention final : public nn::Module { + int32_t embed_dim_ = 0; + int32_t num_heads_ = 0; + int32_t head_dim_ = 0; + + nn::Linear k_proj_; + nn::Linear v_proj_; + nn::Linear q_proj_; + nn::Linear out_proj_; + + public: + Qwen2_5OmniAudioAttention() = default; + + explicit Qwen2_5OmniAudioAttention(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + embed_dim_ = cfg.audio_d_model; + num_heads_ = cfg.audio_encoder_attention_heads; + head_dim_ = embed_dim_ / num_heads_; + + k_proj_ = reg("k_proj", embed_dim_, embed_dim_, false); + v_proj_ = reg("v_proj", embed_dim_, embed_dim_, true); + q_proj_ = reg("q_proj", embed_dim_, embed_dim_, true); + out_proj_ = reg("out_proj", embed_dim_, embed_dim_, true); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; // [seq_len, embed_dim] + auto seq_len = hidden_states.shape()[0]; + + auto hidden = hidden_states.unsqueeze(0); // [1, seq_len, embed_dim] + auto query_states = q_proj_(hidden); + auto key_states = k_proj_(hidden); + auto value_states = v_proj_(hidden); + + query_states = query_states.view({1, seq_len, num_heads_, head_dim_}).transpose(1, 2); + key_states = key_states.view({1, seq_len, num_heads_, head_dim_}).transpose(1, 2); + value_states = value_states.view({1, seq_len, num_heads_, head_dim_}).transpose(1, 2); + + float scale = 1.0f / std::sqrt(static_cast(head_dim_)); + auto attn_weights = nn::functional::matmul(query_states, key_states.transpose(-2, -1)) * scale; + attn_weights = nn::functional::softmax(attn_weights, -1); + auto attn_output = nn::functional::matmul(attn_weights, value_states); + + attn_output = attn_output.transpose(1, 2).contiguous().view({1, seq_len, embed_dim_}); + attn_output = out_proj_(attn_output); + + return {attn_output.squeeze(0)}; + } +}; + +class Qwen2_5OmniAudioEncoderLayer final : public nn::Module { + Qwen2_5OmniAudioAttention self_attn_; + nn::LayerNorm self_attn_layer_norm_; + nn::Linear fc1_; + nn::Linear fc2_; + nn::LayerNorm final_layer_norm_; + nn::GELU activation_fn_; + + public: + Qwen2_5OmniAudioEncoderLayer() = default; + + explicit Qwen2_5OmniAudioEncoderLayer(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + const int32_t embed_dim = cfg.audio_d_model; + self_attn_ = reg("self_attn", cfg); + self_attn_layer_norm_ = + reg("self_attn_layer_norm", std::vector{embed_dim}, true, true, 1e-5); + fc1_ = reg("fc1", embed_dim, cfg.audio_encoder_ffn_dim, true); + fc2_ = reg("fc2", cfg.audio_encoder_ffn_dim, embed_dim, true); + final_layer_norm_ = reg("final_layer_norm", std::vector{embed_dim}, true, true, 1e-5); + activation_fn_ = reg("activation_fn"); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; + auto residual = hidden_states; + + hidden_states = self_attn_layer_norm_(hidden_states); + hidden_states = self_attn_(hidden_states)[0]; + hidden_states = residual + hidden_states; + + residual = hidden_states; + hidden_states = final_layer_norm_(hidden_states); + hidden_states = fc1_(hidden_states); + hidden_states = activation_fn_(hidden_states); + hidden_states = fc2_(hidden_states); + hidden_states = residual + hidden_states; + + if (hidden_states.dtype() == kFloat16) { + const float clamp_value = 65504.0f - 1000.0f; + hidden_states = nn::functional::clip(hidden_states, -clamp_value, clamp_value); + } + + return {hidden_states}; + } +}; + +class Qwen2_5OmniAudioEncoder final : public nn::Module { + nn::Conv1D conv1_; + nn::Conv1D conv2_; + nn::GELU gelu_; + nn::ModuleList layers_; + nn::LayerNorm ln_post_; + nn::AvgPool1d avg_pooler_; + nn::Linear proj_; + nn::Embedding audio_bos_eos_token_; + + int32_t num_mel_bins_ = 0; + int32_t embed_dim_ = 0; + int32_t n_window_ = 0; + int32_t output_dim_ = 0; + + public: + Qwen2_5OmniAudioEncoder() = default; + + explicit Qwen2_5OmniAudioEncoder(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + num_mel_bins_ = cfg.audio_num_mel_bins; + embed_dim_ = cfg.audio_d_model; + n_window_ = cfg.audio_n_window; + output_dim_ = cfg.audio_output_dim; + + conv1_ = reg("conv1", num_mel_bins_, embed_dim_, 3, 1, 1); + conv2_ = reg("conv2", embed_dim_, embed_dim_, 3, 2, 1); + gelu_ = reg("gelu"); + audio_bos_eos_token_ = reg("audio_bos_eos_token", 2, cfg.audio_output_dim); + layers_ = reg>("layers", cfg.audio_encoder_layers, cfg); + ln_post_ = reg("ln_post", std::vector{embed_dim_}, true, true, 1e-5); + avg_pooler_ = reg("avg_pooler", 2, 2); + proj_ = reg("proj", embed_dim_, cfg.audio_output_dim, true); + + auto pos_emb = makeAudioSinusoidalPosEmb(cfg.audio_max_source_positions, embed_dim_); + registerBuffer("positional_embedding", pos_emb); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto input_features = inputs[0]; // [B, n_mels, T] + MLLM_RT_ASSERT_EQ(input_features.shape().size(), 3); + + const int32_t batch_size = input_features.shape()[0]; + MLLM_RT_ASSERT_EQ(input_features.shape()[1], num_mel_bins_); + const int32_t feature_len = input_features.shape()[2]; + MLLM_RT_ASSERT(feature_len > 0); + + auto pos_emb = getBuffer("positional_embedding"); + + std::vector audio_outputs; + audio_outputs.reserve(batch_size); + + for (int32_t b = 0; b < batch_size; ++b) { + Tensor audio_b = input_features[make_slice(b), kAll, kAll].view({1, num_mel_bins_, feature_len}).contiguous(); + + const int32_t chunk_size = n_window_ * 2; + const int32_t num_chunks = (feature_len + chunk_size - 1) / chunk_size; + + std::vector chunk_outputs; + chunk_outputs.reserve(num_chunks); + + for (int32_t chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) { + const int32_t start = chunk_idx * chunk_size; + const int32_t chunk_len = std::min(chunk_size, feature_len - start); + auto chunk = Tensor::empty({1, num_mel_bins_, chunk_len}, kFloat32, kCPU).alloc(); + for (int32_t m = 0; m < num_mel_bins_; ++m) { + auto src_ptr = audio_b.offsettedPtr({0, m, start}); + auto dst_ptr = chunk.offsettedPtr({0, m, 0}); + std::memcpy(dst_ptr, src_ptr, chunk_len * sizeof(float)); + } + + auto x = conv1_(chunk); + x = gelu_(x); + x = conv2_(x); + x = gelu_(x); + x = x.transpose(1, 2).contiguous(); // [1, T2, D] + + const int32_t t2 = x.shape()[1]; + MLLM_RT_ASSERT(t2 <= pos_emb.shape()[0]); + auto pos_ptr = pos_emb.ptr(); + auto x_ptr = x.ptr(); + for (int32_t t = 0; t < t2; ++t) { + const float* pos_row = pos_ptr + t * embed_dim_; + float* x_row = x_ptr + t * embed_dim_; + for (int32_t d = 0; d < embed_dim_; ++d) { x_row[d] += pos_row[d]; } + } + + auto hidden_states = x.squeeze(0); // [T2, D] + for (auto& layer : layers_.list()) { hidden_states = layer(hidden_states)[0]; } + if (hidden_states.shape()[0] < 2) { continue; } + + auto pooled = hidden_states.unsqueeze(0).transpose(1, 2); // [1, D, T] + pooled = avg_pooler_(pooled); + pooled = pooled.transpose(1, 2).squeeze(0); // [T', D] + pooled = ln_post_(pooled); + pooled = proj_(pooled); + chunk_outputs.push_back(pooled); + } + + int32_t total_len = 0; + for (const auto& chunk : chunk_outputs) { total_len += chunk.shape()[0]; } + + auto merged = Tensor::empty({total_len, output_dim_}, kFloat32, kCPU).alloc(); + int32_t offset = 0; + for (const auto& chunk : chunk_outputs) { + const int32_t len = chunk.shape()[0]; + const float* src_ptr = chunk.ptr(); + float* dst_ptr = merged.offsettedPtr({offset, 0}); + std::memcpy(dst_ptr, src_ptr, len * output_dim_ * sizeof(float)); + offset += len; + } + + audio_outputs.push_back(merged); + } + + int32_t total_audio_tokens = 0; + for (const auto& out : audio_outputs) { total_audio_tokens += out.shape()[0]; } + + auto output = Tensor::empty({total_audio_tokens, output_dim_}, kFloat32, kCPU).alloc(); + int32_t offset = 0; + for (const auto& out : audio_outputs) { + const int32_t len = out.shape()[0]; + const float* src_ptr = out.ptr(); + float* dst_ptr = output.offsettedPtr({offset, 0}); + std::memcpy(dst_ptr, src_ptr, len * output_dim_ * sizeof(float)); + offset += len; + } + + return {output}; + } +}; + class Qwen2_5OmniMLP final : public nn::Module { nn::Linear gate_proj_; nn::Linear up_proj_; @@ -275,10 +1030,14 @@ class Qwen2_5OmniThinker final : public nn::Module { Qwen2_5OmniThinker() = default; Qwen2_5OmniThinker(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { model_ = reg("model", cfg); + audio_tower_ = reg("audio_tower", cfg); + visual_ = reg("visual", cfg); lm_head_ = reg("lm_head", cfg.hidden_size, cfg.vocab_size, false, cfg.linear_impl_type); } Qwen2_5OmniText model_; + Qwen2_5OmniAudioEncoder audio_tower_; + Qwen2_5OmniVisionEncoder visual_; nn::Linear lm_head_; }; @@ -302,11 +1061,97 @@ class Qwen2_5OmniForCausalLM : public ARGeneration { auto input_embeddings = thinker_.model_.embedding_(sequence); - Tensor position_ids = Tensor::nil(); - if (input.count("position_ids")) { - position_ids = input.at("position_ids"); + if (input.count("input_features")) { + auto input_features = input.at("input_features"); + auto audio_embeddings = thinker_.audio_tower_(input_features)[0]; + MLLM_RT_ASSERT_EQ(audio_embeddings.shape()[1], input_embeddings.shape()[2]); + if (audio_embeddings.dtype() != input_embeddings.dtype()) { + audio_embeddings = audio_embeddings.to(input_embeddings.dtype()); + } + + MLLM_RT_ASSERT_EQ(sequence.shape()[0], 1); + auto S = sequence.shape()[1]; + std::vector audio_positions; + audio_positions.reserve(audio_embeddings.shape()[0]); + auto input_ids_ptr = sequence.ptr(); + for (int s = 0; s < S; ++s) { + if (input_ids_ptr[s] == cfg_.audio_token_id) { audio_positions.push_back(s); } + } + MLLM_RT_ASSERT_EQ(static_cast(audio_positions.size()), audio_embeddings.shape()[0]); + + auto D = input_embeddings.shape()[2]; + if (input_embeddings.dtype() == kFloat32) { + for (size_t i = 0; i < audio_positions.size(); ++i) { + auto out_ptr = input_embeddings.offsettedPtr({0, audio_positions[i], 0}); + auto in_ptr = audio_embeddings.offsettedPtr({static_cast(i), 0}); + std::copy(in_ptr, in_ptr + D, out_ptr); + } + } else if (input_embeddings.dtype() == kFloat16) { + for (size_t i = 0; i < audio_positions.size(); ++i) { + auto out_ptr = input_embeddings.offsettedPtr({0, audio_positions[i], 0}); + auto in_ptr = audio_embeddings.offsettedPtr({static_cast(i), 0}); + std::copy(in_ptr, in_ptr + D, out_ptr); + } + } else { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Unsupported embedding dtype for Qwen2.5-Omni audio input."); + } } - position_ids = getPositionIds(sequence, position_ids); + + if (input.count("img")) { + auto img = input.at("img"); + auto grid_thw = input.at("grid_thw"); + + auto inv_freq = makeVisualRoPEInvFreq(cfg_.visual_hidden_size / cfg_.visual_num_heads, 10000.0f); + auto pos_ids = makeVisualRotaryPosEmbIds(grid_thw, cfg_.visual_spatial_merge_size); + + int max_grid = 0; + for (int row = 0; row < grid_thw.shape()[0]; ++row) { + const int* dims = grid_thw.offsettedPtr({row, 0}); + max_grid = std::max({max_grid, dims[1], dims[2]}); + } + MLLM_RT_ASSERT(max_grid > 0); + auto rotary_pos_emb_full = makeVisualRotaryPosEmbFull(inv_freq, max_grid); + auto pos_emb = makeVisualRotaryPosEmb(rotary_pos_emb_full, pos_ids, grid_thw); + auto [visual_embedding_sin, visual_embedding_cos] = makeVisualRotarySinCos(pos_emb); + + auto visual_embeddings = thinker_.visual_(img, visual_embedding_sin, visual_embedding_cos, grid_thw)[0]; + MLLM_RT_ASSERT_EQ(visual_embeddings.shape()[1], input_embeddings.shape()[2]); + if (visual_embeddings.dtype() != input_embeddings.dtype()) { + visual_embeddings = visual_embeddings.to(input_embeddings.dtype()); + } + + MLLM_RT_ASSERT_EQ(sequence.shape()[0], 1); + auto S = sequence.shape()[1]; + std::vector image_positions; + image_positions.reserve(visual_embeddings.shape()[0]); + auto input_ids_ptr = sequence.ptr(); + for (int s = 0; s < S; ++s) { + if (input_ids_ptr[s] == cfg_.image_token_id) { image_positions.push_back(s); } + } + MLLM_RT_ASSERT_EQ(static_cast(image_positions.size()), visual_embeddings.shape()[0]); + + auto D = input_embeddings.shape()[2]; + if (input_embeddings.dtype() == kFloat32) { + for (size_t i = 0; i < image_positions.size(); ++i) { + auto out_ptr = input_embeddings.offsettedPtr({0, image_positions[i], 0}); + auto in_ptr = visual_embeddings.offsettedPtr({static_cast(i), 0}); + std::copy(in_ptr, in_ptr + D, out_ptr); + } + } else if (input_embeddings.dtype() == kFloat16) { + for (size_t i = 0; i < image_positions.size(); ++i) { + auto out_ptr = input_embeddings.offsettedPtr({0, image_positions[i], 0}); + auto in_ptr = visual_embeddings.offsettedPtr({static_cast(i), 0}); + std::copy(in_ptr, in_ptr + D, out_ptr); + } + } else { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Unsupported embedding dtype for Qwen2.5-Omni image input."); + } + } + + Tensor position_ids = input.count("position_ids") ? input.at("position_ids") : Tensor::nil(); + Tensor img = input.count("img") ? input.at("img") : Tensor::nil(); + Tensor grid_thw = input.count("grid_thw") ? input.at("grid_thw") : Tensor::nil(); + position_ids = getPositionIds(img, grid_thw, sequence, position_ids); auto [llm_embedding_sin, llm_embedding_cos] = makeMultimodalPositionEmbedding(position_ids, thinker_.model_.getBuffer("inv_freq"), cfg_.max_position_embeddings, @@ -326,9 +1171,21 @@ class Qwen2_5OmniForCausalLM : public ARGeneration { Qwen2_5OmniThinker thinker_; private: - Tensor getPositionIds(Tensor& input_ids, Tensor& position_ids) const { + Tensor getPositionIds(Tensor& img, Tensor& grid_thw, Tensor& input_ids, Tensor& position_ids) const { MLLM_RT_ASSERT_EQ(input_ids.shape().size(), 2); + bool has_multimodal = false; + auto input_ids_ptr = input_ids.ptr(); + auto seq_len = input_ids.shape()[1]; + for (int s = 0; s < seq_len; ++s) { + if (input_ids_ptr[s] == cfg_.vision_start_token_id || input_ids_ptr[s] == cfg_.audio_start_token_id) { + has_multimodal = true; + break; + } + } + + if (has_multimodal) { return getPositionIdsPrefill(input_ids, grid_thw); } + if (!position_ids.isNil()) { auto last_pos = *position_ids.offsettedPtr({0, 0, position_ids.shape()[2] - 1}); auto ret_position_ids = Tensor::empty({3, 1, 1}, kInt64, kCPU).alloc(); @@ -339,7 +1196,7 @@ class Qwen2_5OmniForCausalLM : public ARGeneration { } auto B = input_ids.shape()[0]; - auto S = input_ids.shape()[1]; + auto S = seq_len; MLLM_RT_ASSERT_EQ(B, 1); Tensor out = Tensor::empty({3, B, S}, kInt64, kCPU).alloc(); @@ -350,6 +1207,170 @@ class Qwen2_5OmniForCausalLM : public ARGeneration { return out; } + Tensor getPositionIdsPrefill(Tensor& input_ids, Tensor& image_grid_thw) const { + MLLM_RT_ASSERT_EQ(input_ids.shape().size(), 2); + + auto B = input_ids.shape()[0]; + auto S = input_ids.shape()[1]; + MLLM_RT_ASSERT_EQ(B, 1); + + Tensor position_ids = Tensor::empty({3, B, S}, kInt64, kCPU).alloc(); + + auto input_ids_ptr = input_ids.ptr(); + + auto fill_text_positions = [&](int start_seq, int len, int64_t start_id) { + for (int d = 0; d < 3; ++d) { + auto out_ptr = position_ids.offsettedPtr({d, 0, 0}); + for (int i = 0; i < len; ++i) { out_ptr[start_seq + i] = start_id + i; } + } + }; + + int seq_idx = 0; + int image_idx = 0; + int64_t current_max_position_id = -1; + const int total_images = image_grid_thw.isNil() ? 0 : image_grid_thw.shape()[0]; + + while (seq_idx < S) { + int next_vision = -1; + int next_audio = -1; + for (int i = seq_idx; i < S; ++i) { + if (input_ids_ptr[i] == cfg_.vision_start_token_id) { + next_vision = i; + break; + } + } + for (int i = seq_idx; i < S; ++i) { + if (input_ids_ptr[i] == cfg_.audio_start_token_id) { + next_audio = i; + break; + } + } + + if (next_vision == -1 && next_audio == -1) { + const int text_len = S - seq_idx; + if (text_len > 0) { fill_text_positions(seq_idx, text_len, current_max_position_id + 1); } + break; + } + + const bool is_vision = (next_vision != -1) && (next_audio == -1 || next_vision < next_audio); + const int segment_start = is_vision ? next_vision : next_audio; + + const int text_len = segment_start - seq_idx; + if (text_len > 0) { + fill_text_positions(seq_idx, text_len, current_max_position_id + 1); + current_max_position_id += text_len; + } + + if (is_vision) { + fill_text_positions(segment_start, 1, current_max_position_id + 1); + current_max_position_id += 1; + + int vision_end = -1; + for (int i = segment_start + 1; i < S; ++i) { + if (input_ids_ptr[i] == cfg_.vision_end_token_id) { + vision_end = i; + break; + } + } + MLLM_RT_ASSERT(vision_end != -1); + MLLM_RT_ASSERT(image_idx < total_images); + if (image_grid_thw.isNil()) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Missing grid_thw for Qwen2.5-Omni vision input."); + } + MLLM_RT_ASSERT_EQ(image_grid_thw.shape().size(), 2); + + std::vector image_positions; + for (int i = segment_start + 1; i < vision_end; ++i) { + if (input_ids_ptr[i] == cfg_.image_token_id) { + image_positions.push_back(i); + } else { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Unsupported token inside vision segment."); + } + } + + const int* grid_dims = image_grid_thw.offsettedPtr({image_idx, 0}); + const int grid_t = grid_dims[0]; + const int grid_h = grid_dims[1]; + const int grid_w = grid_dims[2]; + + const int image_token_len = (grid_t * grid_h * grid_w) + / (cfg_.visual_spatial_merge_size * cfg_.visual_spatial_merge_size); + MLLM_RT_ASSERT_EQ(static_cast(image_positions.size()), image_token_len); + + const int inputs_t = grid_t; + const int inputs_h = grid_h / cfg_.visual_spatial_merge_size; + const int inputs_w = grid_w / cfg_.visual_spatial_merge_size; + + const int64_t vision_start_id = current_max_position_id + 1; + int pos_counter = 0; + for (int ti = 0; ti < inputs_t; ++ti) { + const int64_t t_id = vision_start_id + static_cast(ti) * cfg_.position_id_per_seconds; + for (int hi = 0; hi < inputs_h; ++hi) { + for (int wi = 0; wi < inputs_w; ++wi) { + const auto seq_pos = image_positions[pos_counter++]; + *position_ids.offsettedPtr({0, 0, seq_pos}) = t_id; + *position_ids.offsettedPtr({1, 0, seq_pos}) = vision_start_id + hi; + *position_ids.offsettedPtr({2, 0, seq_pos}) = vision_start_id + wi; + } + } + } + + const int64_t dim_0_tail = vision_start_id + static_cast(inputs_t - 1) * cfg_.position_id_per_seconds; + const int64_t dim_1_tail = vision_start_id + inputs_h - 1; + const int64_t dim_2_tail = vision_start_id + inputs_w - 1; + current_max_position_id = std::max({dim_0_tail, dim_1_tail, dim_2_tail}); + + fill_text_positions(vision_end, 1, current_max_position_id + 1); + current_max_position_id += 1; + + seq_idx = vision_end + 1; + image_idx += 1; + } else { + fill_text_positions(segment_start, 1, current_max_position_id + 1); + current_max_position_id += 1; + + int audio_end = -1; + for (int i = segment_start + 1; i < S; ++i) { + if (input_ids_ptr[i] == cfg_.audio_end_token_id) { + audio_end = i; + break; + } + } + MLLM_RT_ASSERT(audio_end != -1); + + std::vector audio_positions; + for (int i = segment_start + 1; i < audio_end; ++i) { + if (input_ids_ptr[i] == cfg_.audio_token_id) { + audio_positions.push_back(i); + } else { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Unsupported token inside audio segment."); + } + } + + const int audio_len = static_cast(audio_positions.size()); + if (audio_len == 0) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Empty audio tokens inside audio segment."); + } + const int64_t audio_start_id = current_max_position_id + 1; + for (int i = 0; i < audio_len; ++i) { + const int64_t pos_id = audio_start_id + i; + for (int d = 0; d < 3; ++d) { + *position_ids.offsettedPtr({d, 0, audio_positions[i]}) = pos_id; + } + } + current_max_position_id += audio_len; + + fill_text_positions(audio_end, 1, current_max_position_id + 1); + current_max_position_id += 1; + + seq_idx = audio_end + 1; + } + } + + MLLM_RT_ASSERT_EQ(image_idx, total_images); + return position_ids; + } + const Qwen2_5OmniConfig& cfg_; nn::StaticCache kv_cache_; }; diff --git a/mllm/models/qwen2_5omni/tokenization_qwen2_5omni.hpp b/mllm/models/qwen2_5omni/tokenization_qwen2_5omni.hpp index 8674af9f5..961b5c8f2 100644 --- a/mllm/models/qwen2_5omni/tokenization_qwen2_5omni.hpp +++ b/mllm/models/qwen2_5omni/tokenization_qwen2_5omni.hpp @@ -2,6 +2,7 @@ // Licensed under the MIT License. #pragma once +#include #include #include @@ -9,6 +10,9 @@ #include "mllm/preprocessor/tokenizers/Unicode.hpp" #include "mllm/preprocessor/tokenizers/AutoTokenizer.hpp" #include "mllm/models/ARGeneration.hpp" +#include "mllm/models/qwen2vl/image_preprocessor_qwen2vl.hpp" +#include "mllm/models/qwen2_5omni/audio_preprocessor_qwen2_5omni.hpp" +#include "mllm/utils/Common.hpp" namespace mllm::models::qwen2_5omni { @@ -141,9 +145,52 @@ struct Qwen2_5OmniMessage { } }; +struct Qwen2_5OmniVisionMessage { + std::string prompt; + std::string img_file_path; + std::string system_prompt = "You are a helpful assistant."; + + [[nodiscard]] std::string buildChatMessage() const { + std::string result; + if (!system_prompt.empty()) { + result += "<|im_start|>system\n" + system_prompt + "<|im_end|>\n"; + } + result += "<|im_start|>user\n<|vision_bos|><|IMAGE|><|vision_eos|>" + prompt + "<|im_end|>\n"; + result += "<|im_start|>assistant\n"; + return result; + } +}; + +struct Qwen2_5OmniAudioMessage { + std::string prompt; + std::string audio_file_path; + std::string system_prompt = "You are a helpful assistant."; + + [[nodiscard]] std::string buildChatMessage() const { + std::string result; + if (!system_prompt.empty()) { + result += "<|im_start|>system\n" + system_prompt + "<|im_end|>\n"; + } + result += "<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>" + prompt + "<|im_end|>\n"; + result += "<|im_start|>assistant\n"; + return result; + } +}; + class Qwen2_5OmniTokenizer final : public mllm::preprocessor::AutoTokenizer { public: - explicit Qwen2_5OmniTokenizer(const std::string& file_path) { + explicit Qwen2_5OmniTokenizer(const std::string& file_path, + int32_t spatial_merge_size = 2, + int32_t min_pixels = 56 * 56, + int32_t max_pixels = 1280 * 1280, + int32_t audio_sample_rate = 16000, + int32_t audio_n_mels = 128, + int32_t audio_hop_length = 160, + int32_t audio_chunk_length = 300) + //interestingly, the answer went bad when setting max_pixels higher, eg. 3584*3584) + : image_preprocessor_(min_pixels, max_pixels), + audio_preprocessor_(audio_sample_rate, audio_n_mels, audio_hop_length, audio_chunk_length), + spatial_merge_size_(spatial_merge_size) { preprocessor::initLocal(); preprocessor::makeBytes2UnicodeMap(bytes_2_unicode_dict_); for (auto& kv : bytes_2_unicode_dict_) { bytes_2_unicode_dict_inverse_.insert({kv.second, kv.first}); } @@ -243,10 +290,96 @@ class Qwen2_5OmniTokenizer final : public mllm::preprocessor::AutoTokenizer { return {{"sequence", sequence}}; } + ARGenerationOutputPast convertVisionMessage(const Qwen2_5OmniVisionMessage& message) { + auto applied_string = message.buildChatMessage(); + + auto [img, grid_thw] = image_preprocessor_(message.img_file_path); + + auto sequence_str = tokenize(applied_string); + std::vector ids; + ids.reserve(sequence_str.size()); + for (const auto& str : sequence_str) { ids.emplace_back(bpe_._lookup_vocab(str)); } + + auto grid_t = grid_thw.ptr()[0]; + auto grid_h = grid_thw.ptr()[1]; + auto grid_w = grid_thw.ptr()[2]; + int32_t img_token_nums = grid_t * grid_h * grid_w; + img_token_nums /= (spatial_merge_size_ * spatial_merge_size_); + + auto image_token_id = bpe_._lookup_vocab(L"<|IMAGE|>"); + { + auto it = std::find(ids.begin(), ids.end(), image_token_id); + if (it == ids.end()) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Missing <|IMAGE|> token in Qwen2.5-Omni prompt template."); + } + ids.insert(it + 1, img_token_nums - 1, image_token_id); + } + + Tensor sequence = Tensor::empty({1, static_cast(ids.size())}, kInt64, kCPU) + .setMemType(kNormal) + .setName("qwen2_5omni-tokenizer-i0") + .alloc(); + + auto ptr = sequence.ptr(); + for (size_t i = 0; i < ids.size(); ++i) { ptr[i] = ids[i]; } + + return { + {"sequence", sequence}, + {"img", img}, + {"grid_thw", grid_thw}, + }; + } + + ARGenerationOutputPast convertAudioMessage(const Qwen2_5OmniAudioMessage& message) { + auto applied_string = message.buildChatMessage(); + auto sequence_str = tokenize(applied_string); + + std::vector ids; + ids.reserve(sequence_str.size()); + for (const auto& str : sequence_str) { ids.emplace_back(bpe_._lookup_vocab(str)); } + + auto audio_result = audio_preprocessor_.processAudioFile(message.audio_file_path); + if (audio_result.input_features.isNil() || audio_result.feature_length <= 0) { + MLLM_ERROR_EXIT(ExitCode::kIOError, "Failed to extract audio features for Qwen2.5-Omni."); + } + + int32_t audio_token_nums = audio_preprocessor_.calcAudioTokenLength(audio_result.feature_length); + if (audio_token_nums <= 0) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Invalid audio token length for Qwen2.5-Omni."); + } + + auto audio_token_id = bpe_._lookup_vocab(L"<|AUDIO|>"); + { + auto it = std::find(ids.begin(), ids.end(), audio_token_id); + if (it == ids.end()) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Missing <|AUDIO|> token in Qwen2.5-Omni prompt template."); + } + ids.insert(it + 1, audio_token_nums - 1, audio_token_id); + } + + Tensor sequence = Tensor::empty({1, static_cast(ids.size())}, kInt64, kCPU) + .setMemType(kNormal) + .setName("qwen2_5omni-tokenizer-i0") + .alloc(); + + auto ptr = sequence.ptr(); + for (size_t i = 0; i < ids.size(); ++i) { ptr[i] = ids[i]; } + + audio_result.input_features.setName("input_features"); + + return { + {"sequence", sequence}, + {"input_features", audio_result.input_features}, + }; + } + private: preprocessor::BPE bpe_; std::unordered_map bytes_2_unicode_dict_; std::unordered_map bytes_2_unicode_dict_inverse_; + mllm::models::qwen2vl::Qwen2VLImagePreprocessor image_preprocessor_; + Qwen2_5OmniAudioPreprocessor audio_preprocessor_; + int32_t spatial_merge_size_ = 2; }; } // namespace mllm::models::qwen2_5omni From e959822f3f1c09f915a2bc0b4f55c8c90eafcf5c Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Sat, 17 Jan 2026 02:31:56 +0000 Subject: [PATCH 03/17] fix: Enhance quantization modules. Introduced FixedActivationQDQ for fixed quantization parameters, updated ActivationQDQ to use MovingAverageMinMaxObserver, and adjusted eps values for better precision. Modified Qwen3 model to utilize FixedActivationQDQ for sigmoid output and ensured dtype consistency in attention calculations. --- .../qualcomm/transformers/core/qdq.py | 117 +++++++++++++++++- .../qualcomm/transformers/core/rms_norm.py | 4 +- .../transformers/qwen3/modeling_qwen3.py | 34 ++++- .../qualcomm/transformers/qwen3/runner.py | 1 + .../qualcomm/transformers/qwen3/train.py | 1 + 5 files changed, 147 insertions(+), 10 deletions(-) diff --git a/pymllm/backends/qualcomm/transformers/core/qdq.py b/pymllm/backends/qualcomm/transformers/core/qdq.py index ce67729f4..8a4f90687 100644 --- a/pymllm/backends/qualcomm/transformers/core/qdq.py +++ b/pymllm/backends/qualcomm/transformers/core/qdq.py @@ -1,6 +1,13 @@ import torch import torch.nn as nn -from torch.ao.quantization import FakeQuantize, MinMaxObserver +from torch.ao.quantization import ( + FakeQuantize, + MovingAverageMinMaxObserver, +) +from torch.ao.quantization.observer import FixedQParamsObserver + +DEFAULT_EPS_8BIT = 0.0001 / 255 +DEFAULT_EPS_16BIT = 0.0001 / 65535 class ActivationQDQ(nn.Module): @@ -30,16 +37,24 @@ def __init__(self, bits=8, qscheme=torch.per_tensor_affine): self.quant_min = 0 self.quant_max = (2**bits) - 1 + if bits == 8: + eps = DEFAULT_EPS_8BIT + elif bits == 16: + eps = DEFAULT_EPS_16BIT + else: + raise ValueError(f"Unsupported bit width: {bits}") + # 2. Initialize FakeQuantize - # MinMaxObserver calculates scale and zero_point based on observed tensors. + # MovingAverageMinMaxObserver calculates scale and zero_point based on observed tensors. # Passing quant_min/max to the observer ensures consistency. self.fake_quant = FakeQuantize( - observer=MinMaxObserver.with_args( - qscheme=self.qscheme, + observer=MovingAverageMinMaxObserver.with_args( dtype=self.dtype, + qscheme=self.qscheme, quant_min=self.quant_min, quant_max=self.quant_max, reduce_range=False, + eps=eps, ), quant_min=self.quant_min, quant_max=self.quant_max, @@ -72,3 +87,97 @@ def disable_fakequant(self): def extra_repr(self): mode = "Symmetric" if "symmetric" in str(self.qscheme) else "Asymmetric" return f"bits={self.bits}, mode={mode}, q_range=({self.quant_min}, {self.quant_max}), dtype={self.dtype}" + + +class FixedActivationQDQ(nn.Module): + """ + Fixed activation Quantization-DeQuantization (QDQ) module. + Uses pre-determined scale and zero_point instead of dynamic observation. + Supports both Symmetric and Asymmetric (Affine) quantization. + Uses torch.qint32 as a unified type to support various bit-widths. + """ + + def __init__(self, scale, zero_point, bits=8, qscheme=torch.per_tensor_affine): + super().__init__() + self.bits = bits + self.qscheme = qscheme + + # Define the simulation dtype as qint32 to avoid overflow across different bit-widths + self.dtype = torch.qint32 + + # 1. Calculate quantization range based on bits and scheme + if qscheme in [torch.per_tensor_symmetric, torch.per_channel_symmetric]: + # Symmetric: range is [-(2^(bits-1)), 2^(bits-1) - 1] + # e.g., 8-bit: -128 to 127 + self.quant_min = -(2 ** (bits - 1)) + self.quant_max = 2 ** (bits - 1) - 1 + else: + # Asymmetric (Affine): range is [0, 2^bits - 1] + # e.g., 8-bit: 0 to 255 + self.quant_min = 0 + self.quant_max = (2**bits) - 1 + + if bits not in [8, 16]: + raise ValueError(f"Unsupported bit width: {bits}") + + # 2. Convert scale and zero_point to tensors if needed + if not isinstance(scale, torch.Tensor): + scale = torch.tensor(scale, dtype=torch.float32) + if not isinstance(zero_point, torch.Tensor): + zero_point = torch.tensor(zero_point, dtype=torch.int32) + + # 3. Initialize FakeQuantize with fixed parameters + # Use FakeQuantize with FixedQParamsObserver for fixed scale and zero_point + self.fake_quant = FakeQuantize.with_args( + observer=FixedQParamsObserver.with_args( + scale=scale, + zero_point=zero_point, + ), + dtype=self.dtype, + qscheme=self.qscheme, + quant_min=self.quant_min, + quant_max=self.quant_max, + )() + + def forward(self, x): + # Applies fake quantization with fixed scale and zero_point: + # rounds to nearest integer and clamps to [min, max], + # then dequantizes back to float to simulate quantization noise. + return self.fake_quant(x) + + # Control methods for quantization-aware training (QAT) + # Note: FixedActivationQDQ doesn't have observer, so these methods + # only control fake quantization behavior + def enable_observer(self): + """No-op: FixedActivationQDQ doesn't use observer.""" + pass + + def disable_observer(self): + """No-op: FixedActivationQDQ doesn't use observer.""" + pass + + def enable_fakequant(self): + """Enable simulation of quantization error.""" + self.fake_quant.enable_fakequant() + + def disable_fakequant(self): + """Disable quantization simulation (act as identity).""" + self.fake_quant.disable_fakequant() + + @property + def scale(self): + """Get the fixed scale value.""" + return self.fake_quant.scale + + @property + def zero_point(self): + """Get the fixed zero_point value.""" + return self.fake_quant.zero_point + + def extra_repr(self): + mode = "Symmetric" if "symmetric" in str(self.qscheme) else "Asymmetric" + scale_val = self.scale.item() if self.scale.numel() == 1 else self.scale + zp_val = ( + self.zero_point.item() if self.zero_point.numel() == 1 else self.zero_point + ) + return f"bits={self.bits}, mode={mode}, scale={scale_val}, zero_point={zp_val}, q_range=({self.quant_min}, {self.quant_max}), dtype={self.dtype}" diff --git a/pymllm/backends/qualcomm/transformers/core/rms_norm.py b/pymllm/backends/qualcomm/transformers/core/rms_norm.py index 0101d6aee..b3964469f 100644 --- a/pymllm/backends/qualcomm/transformers/core/rms_norm.py +++ b/pymllm/backends/qualcomm/transformers/core/rms_norm.py @@ -21,7 +21,9 @@ def __init__( # Quantization configuration for Weight self.weight_fake_quant = FakeQuantize( observer=MinMaxObserver.with_args( - qscheme=torch.per_tensor_affine, dtype=torch.qint32 + qscheme=torch.per_tensor_affine, + dtype=torch.qint32, + eps=0.0001 / 65535, ), quant_min=0, quant_max=2 ** (quant_bits) - 1, diff --git a/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py b/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py index 9c0696328..0bbcbffd8 100644 --- a/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py +++ b/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py @@ -49,9 +49,11 @@ from pymllm.backends.qualcomm.transformers.core.rms_norm import QRMSNorm from pymllm.backends.qualcomm.transformers.core.qlinear import ( QLinearLPBQ, - QLinearW8A16_PerChannelSym, ) -from pymllm.backends.qualcomm.transformers.core.qdq import ActivationQDQ +from pymllm.backends.qualcomm.transformers.core.qdq import ( + ActivationQDQ, + FixedActivationQDQ, +) class Qwen3MLP(nn.Module): @@ -76,7 +78,12 @@ def __init__(self, config): self.gate_proj_output_qdq = ActivationQDQ(bits=16) self.act_output_qdq = ActivationQDQ(bits=16) self.down_proj_input_qdq = ActivationQDQ(bits=16) - self.sigmoid_output_qdq = ActivationQDQ(bits=16) + # For sigmoid output: scale = 1 / (q_max - q_min + 1), zp = 0 + # For 16-bit: q_min = 0, q_max = 65535 + sigmoid_scale = 1.0 / (65535 - 0 + 1) # 1 / 65536 + self.sigmoid_output_qdq = FixedActivationQDQ( + scale=sigmoid_scale, zero_point=0, bits=16 + ) def forward(self, x): x = self.up_proj_input_qdq(x) @@ -281,7 +288,7 @@ def forward( torch.matmul(query_states, key_states.transpose(2, 3)) ) * self.scaling_qdq( - torch.ones(1, dtype=torch.bfloat16, device=value_states.device) + torch.ones(1, dtype=value_states.dtype, device=value_states.device) * self.scaling ) ) @@ -292,7 +299,8 @@ def forward( attn_vv = self.minus_0_output_qdq( attn_min + self.neg_20_qdq( - torch.ones(1, dtype=torch.bfloat16, device=value_states.device) * (-20) + torch.ones(1, dtype=value_states.dtype, device=value_states.device) + * (-20) ) ) attn_weights = torch.where(attention_mask == 0, attn_weights, attn_vv) @@ -315,6 +323,7 @@ def forward( class Qwen3DecoderLayer(GradientCheckpointingLayer): def __init__(self, config: Qwen3Config, layer_idx: int): super().__init__() + self.layer_dix = layer_idx self.hidden_size = config.hidden_size self.self_attn = Qwen3Attention(config=config, layer_idx=layer_idx) @@ -362,6 +371,15 @@ def forward( position_embeddings=position_embeddings, **kwargs, ) + + if self.layer_dix == 2: + print("1", hidden_states.min(), hidden_states.max()) + print( + "2", + self.add_0_lhs_input_qdq(hidden_states).min(), + self.add_0_lhs_input_qdq(hidden_states).max(), + ) + hidden_states = self.add_0_output_qdq( residual + self.add_0_lhs_input_qdq(hidden_states) ) @@ -567,6 +585,12 @@ def forward( self.mllm_max_cos_embedding, self.mllm_max_sin_embedding = self.rotary_emb( hidden_states, max_position_ids ) + self.mllm_max_cos_embedding = self.mllm_max_cos_embedding.to( + inputs_embeds.dtype + ) + self.mllm_max_sin_embedding = self.mllm_max_sin_embedding.to( + inputs_embeds.dtype + ) self.mllm_max_cos_embedding = self.cos_embedding_input_qdq( self.mllm_max_cos_embedding ) diff --git a/pymllm/backends/qualcomm/transformers/qwen3/runner.py b/pymllm/backends/qualcomm/transformers/qwen3/runner.py index 53ab40a9e..88f5ce84e 100644 --- a/pymllm/backends/qualcomm/transformers/qwen3/runner.py +++ b/pymllm/backends/qualcomm/transformers/qwen3/runner.py @@ -44,6 +44,7 @@ def __init__(self, model_path: str, mllm_qualcomm_max_length=2048): self.model = Qwen3ForCausalLM.from_pretrained( model_path, attn_implementation="eager", + dtype=torch.bfloat16, ) self.model.cuda() self.mllm_qualcomm_max_length = mllm_qualcomm_max_length diff --git a/pymllm/backends/qualcomm/transformers/qwen3/train.py b/pymllm/backends/qualcomm/transformers/qwen3/train.py index 13ad2785a..33351918f 100644 --- a/pymllm/backends/qualcomm/transformers/qwen3/train.py +++ b/pymllm/backends/qualcomm/transformers/qwen3/train.py @@ -44,6 +44,7 @@ def main(): # !!! # Things below is for deploy. We will turn all fp32 weights and some buffers(rope) to quantized dtype. # !!! + # This line maybe error. we need use quantized weight!!! not embed_tokens.weight!!! m.model.lm_head.weight = torch.nn.Parameter( m.model.model.embed_tokens.weight.clone() ) From 0672432d6d7f94a2567dfb2e1dbbb2b2e76985e9 Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Sat, 17 Jan 2026 03:29:56 +0000 Subject: [PATCH 04/17] fix: Suppress deprecated comma-subscript warnings in CMake and remove debug print statements from Qwen3DecoderLayer --- mllm/CMakeLists.txt | 4 ++++ .../qualcomm/transformers/qwen3/modeling_qwen3.py | 10 ++-------- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/mllm/CMakeLists.txt b/mllm/CMakeLists.txt index 9df6b7741..fd796f95a 100644 --- a/mllm/CMakeLists.txt +++ b/mllm/CMakeLists.txt @@ -56,6 +56,10 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "App endif() endif() +# FIXME: @oreomaker Need to remove comma features in slice! +# Suppress comma-subscript warnings (deprecated C++ feature that will be removed in C++26) +target_compile_options(MllmRT PUBLIC -Wno-comma-subscript) + # ONLY APPLE CAN DO ! # Processing OpenMP if(MLLM_KERNEL_USE_THREADS AND MLLM_KERNEL_THREADS_VENDOR_OPENMP) diff --git a/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py b/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py index 0bbcbffd8..dc6486043 100644 --- a/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py +++ b/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py @@ -372,14 +372,6 @@ def forward( **kwargs, ) - if self.layer_dix == 2: - print("1", hidden_states.min(), hidden_states.max()) - print( - "2", - self.add_0_lhs_input_qdq(hidden_states).min(), - self.add_0_lhs_input_qdq(hidden_states).max(), - ) - hidden_states = self.add_0_output_qdq( residual + self.add_0_lhs_input_qdq(hidden_states) ) @@ -388,6 +380,8 @@ def forward( residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) + # if self.layer_dix == 2: + # print(hidden_states.min(), hidden_states.max()) hidden_states = residual + self.add_1_lhs_input_qdq(hidden_states) return hidden_states From 927f7eb8c76afa1664bb482e0b425f04f4f022db Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Mon, 19 Jan 2026 07:45:23 +0000 Subject: [PATCH 05/17] feat(qualcomm): Add installation targets for flatbuffers and MllmQNNBackend in CMake, enhance PTQPass with unsolved tensor value checks, and update quantization specifications in RMSNorm and model file conversion. --- CMakeLists.txt | 7 +++ .../qwen3_qnn_aot/modeling_qwen_qnn_aot.hpp | 3 +- mllm/backends/qnn/CMakeLists.txt | 7 +++ mllm/backends/qnn/aot/passes/PTQPass.cpp | 44 +++++++++++++++++++ mllm/backends/qnn/aot/visitor/RMSNorm.cpp | 5 ++- .../qualcomm/transformers/core/qdq.py | 4 +- .../qualcomm/transformers/core/qlinear.py | 4 +- .../transformers/qwen3/modeling_qwen3.py | 2 - .../qualcomm/transformers/qwen3/runner.py | 2 +- pymllm/convertor/model_file_v2.py | 12 ++++- 10 files changed, 80 insertions(+), 10 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 298b412c0..fca470ee5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -332,6 +332,13 @@ install( ARCHIVE DESTINATION lib RUNTIME DESTINATION bin) +install( + TARGETS flatbuffers + EXPORT MllmTargets + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib + RUNTIME DESTINATION bin) + if(MLLM_BUILD_SDK_C_BINDING) install( TARGETS MllmSdkC diff --git a/examples/qwen3_qnn_aot/modeling_qwen_qnn_aot.hpp b/examples/qwen3_qnn_aot/modeling_qwen_qnn_aot.hpp index 9eed37267..f1b20a1a2 100644 --- a/examples/qwen3_qnn_aot/modeling_qwen_qnn_aot.hpp +++ b/examples/qwen3_qnn_aot/modeling_qwen_qnn_aot.hpp @@ -272,7 +272,8 @@ class Qwen3Attention final : public nn::Module { auto attn_min = ptq::QDQ(this, attn.min(-1, true), "reduce_min_output_qdq"); auto minus_value = Tensor::constant(-20, kFloat32); minus_value = ptq::QDQ(this, minus_value, "neg_20_qdq"); - attn = nn::functional::where(causal_mask.equal(0.f), attn, attn_min.addConstant(minus_value)); + auto attn_vv = ptq::QDQ(this, attn_min.addConstant(minus_value), "minus_0_output_qdq"); + attn = nn::functional::where(causal_mask.equal(0.f), attn, attn_vv); attn = ptq::QDQ(this, nn::functional::softmax(attn, -1), "softmax_output_qdq"); auto y = ptq::QDQ(this, nn::functional::matmul(attn, vh), "attn_value_matmul_output_qdq"); y = y.transpose(1, 2).view({1, 1, -1, num_attention_heads_ * head_dim_}, /*ssa=*/true); diff --git a/mllm/backends/qnn/CMakeLists.txt b/mllm/backends/qnn/CMakeLists.txt index 0ad833792..83b4a43f9 100644 --- a/mllm/backends/qnn/CMakeLists.txt +++ b/mllm/backends/qnn/CMakeLists.txt @@ -44,3 +44,10 @@ get_property(current_includes DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY INC message(STATUS "MLLM_QNN INCLUDES: ${current_includes}") #print include directories target_link_libraries(MllmQNNBackend PUBLIC MllmRT) + +install( + TARGETS MllmQNNBackend + EXPORT MllmTargets + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib + RUNTIME DESTINATION bin) diff --git a/mllm/backends/qnn/aot/passes/PTQPass.cpp b/mllm/backends/qnn/aot/passes/PTQPass.cpp index 1d42d58d3..7172db475 100644 --- a/mllm/backends/qnn/aot/passes/PTQPass.cpp +++ b/mllm/backends/qnn/aot/passes/PTQPass.cpp @@ -300,6 +300,45 @@ void recursiveSolveNormal(const std::shared_ptr& ir_ctx, const ir }); } +void recursiveCheckUnsolved(const std::shared_ptr& ir_ctx, const ir::graph::SubGraphOp::ptr_t& call_op) { + auto wow = ir::IRWriter(ir_ctx, call_op->getTopRegion()); + wow.walk([&](ir::IRWriter& w, const ir::Op::ptr_t& op) -> ir::IRWriter::WalkResult { + if (op->isa_()) { + auto linalg_op = op->cast_(); + std::string op_name = linalg_op->getAOp()->getName(); + + auto inputs = op->inputs(); + auto outputs = op->outputs(); + + for (auto iii : inputs) { + if (!iii->isa_()) continue; + auto tv = iii->cast_(); + if (!tv->getAttr("quant_recipe")) continue; + auto f_spec = tv->getAttr("quant_recipe")->cast_(); + if (!f_spec->spec_->solved) { + MLLM_WARN("PTQPass: TensorValue '{}' is not solved, used by Op: '{}'", tv->name(), op_name); + } + } + + for (auto ooo : outputs) { + if (!ooo->isa_()) continue; + auto tv = ooo->cast_(); + if (!tv->getAttr("quant_recipe")) continue; + auto f_spec = tv->getAttr("quant_recipe")->cast_(); + if (!f_spec->spec_->solved) { + MLLM_WARN("PTQPass: TensorValue '{}' is not solved, produced by Op: '{}'", tv->name(), op_name); + } + } + } + + if (op->isa_()) { + auto ns = op->cast_()->getSymbolAttr()->str(); + recursiveCheckUnsolved(w.getContext(), w.getContext()->lookupSymbolTable(ns)->cast_()); + } + return ir::IRWriter::WALK_CONTINUE; + }); +} + } // namespace uint8_t PTQPass::run(const ir::node_ptr_t& op) { @@ -330,6 +369,11 @@ uint8_t PTQPass::run(const ir::node_ptr_t& op) { getCtx()->lookupSymbolTable(call_main_graph_op->getSymbolAttr()->str())->cast_(), pf); + // Check for unsolved tensorValues and warn + recursiveCheckUnsolved( + writer.getContext(), + getCtx()->lookupSymbolTable(call_main_graph_op->getSymbolAttr()->str())->cast_()); + return ir::PASS_RET_SUCCESS; } diff --git a/mllm/backends/qnn/aot/visitor/RMSNorm.cpp b/mllm/backends/qnn/aot/visitor/RMSNorm.cpp index 27f72e2e2..351e2562a 100644 --- a/mllm/backends/qnn/aot/visitor/RMSNorm.cpp +++ b/mllm/backends/qnn/aot/visitor/RMSNorm.cpp @@ -47,9 +47,12 @@ bool QnnAOTRMSNormPattern::rewrite(ir::IRWriter& writer, const ir::op_ptr_t& op) auto bias_tensor = mllm::Tensor::zeros(weight->tensor_.shape(), weight->tensor_.dtype()); auto bias_node = ir::tensor::TensorValue::build(writer.getContext().get(), bias_tensor); bias_node->tensor_.setName(a->getName() + "_runtime_bias"); + bias_node->name() = a->getName() + "_runtime_bias"; // fake bias quant recipe - auto quant_spec = mllm::ir::linalg::QuantizationSpecSymPerTensor::create(0, 0, kInt32, kFloat32, Tensor::ones({1})); + auto bias_scale = Tensor::ones({1}); + bias_scale.at({0}) = 1.0 / 32767; + auto quant_spec = mllm::ir::linalg::QuantizationSpecSymPerTensor::create(-32768, 32767, kInt16, kFloat32, bias_scale); auto quant_attr = mllm::ir::linalg::LinalgIRQuantizatonSpecAttr::build(writer.getContext().get(), quant_spec); bias_node->setAttr("quant_recipe", quant_attr); diff --git a/pymllm/backends/qualcomm/transformers/core/qdq.py b/pymllm/backends/qualcomm/transformers/core/qdq.py index 8a4f90687..f1c4d20dc 100644 --- a/pymllm/backends/qualcomm/transformers/core/qdq.py +++ b/pymllm/backends/qualcomm/transformers/core/qdq.py @@ -2,7 +2,7 @@ import torch.nn as nn from torch.ao.quantization import ( FakeQuantize, - MovingAverageMinMaxObserver, + MinMaxObserver, ) from torch.ao.quantization.observer import FixedQParamsObserver @@ -48,7 +48,7 @@ def __init__(self, bits=8, qscheme=torch.per_tensor_affine): # MovingAverageMinMaxObserver calculates scale and zero_point based on observed tensors. # Passing quant_min/max to the observer ensures consistency. self.fake_quant = FakeQuantize( - observer=MovingAverageMinMaxObserver.with_args( + observer=MinMaxObserver.with_args( dtype=self.dtype, qscheme=self.qscheme, quant_min=self.quant_min, diff --git a/pymllm/backends/qualcomm/transformers/core/qlinear.py b/pymllm/backends/qualcomm/transformers/core/qlinear.py index d9c55e759..255f52ffb 100644 --- a/pymllm/backends/qualcomm/transformers/core/qlinear.py +++ b/pymllm/backends/qualcomm/transformers/core/qlinear.py @@ -296,7 +296,9 @@ def convert_to_conv2d_deploy_hwio(self): s1_permuted = ( s1.view(self.out_features, -1).t().contiguous() ) # [Out, Blocks] -> [Blocks, Out] - s1_hwio = s1_permuted.view(1, 1, -1, self.out_features) # Shape: [1, 1, Blocks, Out] + s1_hwio = s1_permuted.view( + 1, 1, -1, self.out_features + ) # Shape: [1, 1, Blocks, Out] del self.weight self.register_buffer("weight", w_hwio) diff --git a/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py b/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py index dc6486043..2f099088e 100644 --- a/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py +++ b/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py @@ -380,8 +380,6 @@ def forward( residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) - # if self.layer_dix == 2: - # print(hidden_states.min(), hidden_states.max()) hidden_states = residual + self.add_1_lhs_input_qdq(hidden_states) return hidden_states diff --git a/pymllm/backends/qualcomm/transformers/qwen3/runner.py b/pymllm/backends/qualcomm/transformers/qwen3/runner.py index 88f5ce84e..ed302f215 100644 --- a/pymllm/backends/qualcomm/transformers/qwen3/runner.py +++ b/pymllm/backends/qualcomm/transformers/qwen3/runner.py @@ -44,7 +44,7 @@ def __init__(self, model_path: str, mllm_qualcomm_max_length=2048): self.model = Qwen3ForCausalLM.from_pretrained( model_path, attn_implementation="eager", - dtype=torch.bfloat16, + dtype=torch.float32, ) self.model.cuda() self.mllm_qualcomm_max_length = mllm_qualcomm_max_length diff --git a/pymllm/convertor/model_file_v2.py b/pymllm/convertor/model_file_v2.py index 302e3e21b..976c04411 100644 --- a/pymllm/convertor/model_file_v2.py +++ b/pymllm/convertor/model_file_v2.py @@ -24,6 +24,14 @@ MLLM_MODEL_FILE_V2_TENSOR_SHAPE_LENGTH = 16 +def _torch_tensor_bytes(tensor: "torch.Tensor") -> bytes: + # Use uint8 view to preserve raw bytes for dtypes not supported by numpy. + t = tensor.detach().cpu().contiguous() + if t.dim() == 0: + t = t.reshape(1) + return t.view(torch.uint8).numpy().tobytes() + + class ModelFileV2Descriptor: SIZE = 532 @@ -132,7 +140,7 @@ def streaming_write(self, tensor_name, tensor_obj): if MLLM_FIND_TORCH_AVAILABLE and isinstance(tensor_obj, torch.Tensor): # PyTorch tensor shape = list(tensor_obj.shape) - tensor_data = tensor_obj.detach().cpu().numpy().tobytes() + tensor_data = _torch_tensor_bytes(tensor_obj) true_dtype = MLLM_TYPE_MAPPING[tensor_obj.dtype] elif MLLM_FIND_NUMPY_AVAILABLE and isinstance(tensor_obj, np.ndarray): # Numpy array @@ -203,7 +211,7 @@ def static_write(self, tensor_obj): if MLLM_FIND_TORCH_AVAILABLE and isinstance(tensor, torch.Tensor): # PyTorch tensor shape = list(tensor.shape) - tensor_data = tensor.detach().cpu().numpy().tobytes() + tensor_data = _torch_tensor_bytes(tensor) true_dtype = MLLM_TYPE_MAPPING[tensor.dtype] elif MLLM_FIND_NUMPY_AVAILABLE and isinstance(tensor, np.ndarray): # Numpy array From d2e6b36edf6b799c126fa71c77d090f0a2bcb7bb Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Mon, 19 Jan 2026 13:08:59 +0000 Subject: [PATCH 06/17] feat(qualcomm): Refactor Qwen3 model to integrate ConcatObserver for improved quantization, enhance rotate_half function to utilize observers, and ensure consistent scale and zero_point across concatenated inputs. --- .../qwen3_qnn_aot/modeling_qwen_qnn_aot.hpp | 35 +++---- .../qnn/aot/passes/LLMQuantRecipePass.cpp | 17 +++- mllm/backends/qnn/aot/passes/PTQPass.cpp | 93 +++++++++++++++++++ .../qualcomm/transformers/core/observer.py | 56 +++++++++++ .../qualcomm/transformers/core/qdq.py | 8 +- .../transformers/qwen3/modeling_qwen3.py | 65 ++++++++++++- .../qualcomm/transformers/qwen3/runner.py | 21 ++++- .../qualcomm/transformers/qwen3/train.py | 5 +- 8 files changed, 268 insertions(+), 32 deletions(-) create mode 100644 pymllm/backends/qualcomm/transformers/core/observer.py diff --git a/examples/qwen3_qnn_aot/modeling_qwen_qnn_aot.hpp b/examples/qwen3_qnn_aot/modeling_qwen_qnn_aot.hpp index f1b20a1a2..a2d054bad 100644 --- a/examples/qwen3_qnn_aot/modeling_qwen_qnn_aot.hpp +++ b/examples/qwen3_qnn_aot/modeling_qwen_qnn_aot.hpp @@ -15,14 +15,6 @@ namespace mllm::models::qwen3 { -Tensor rotateHalf(Tensor x) { // NOLINT - // X is [x, x, x, D] - auto D = x.size(-1); - auto x1 = x.slice({kAll, kAll, kAll, {kAll, D / 2}}, /*ssa=*/true); - auto x2 = x.slice({kAll, kAll, kAll, {D / 2, kAll}}, /*ssa=*/true); - return nn::functional::concat({-x2, x1}, -1); -} - namespace ptq { Tensor QDQ(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { @@ -112,6 +104,14 @@ Tensor QDQ_ROPE(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch } // namespace ptq +Tensor rotateHalf(Tensor x, nn::Module* m, const std::string& qdq_name_in_pytorch) { // NOLINT + // X is [x, x, x, D] + auto D = x.size(-1); + auto x1 = x.slice({kAll, kAll, kAll, {kAll, D / 2}}, /*ssa=*/true); + auto x2 = x.slice({kAll, kAll, kAll, {D / 2, kAll}}, /*ssa=*/true); + return nn::functional::concat({ptq::QDQ(m, -x2, qdq_name_in_pytorch), x1}, -1); +} + using vi32 = std::vector; #define CONV2D_PROPERTY vi32{1, 1}, vi32{1, 1}, vi32{0, 0}, vi32{1, 1}, false, aops::Conv2DOpImplType::kQNN_LPBQ_w4a16o16_G32 @@ -232,14 +232,16 @@ class Qwen3Attention final : public nn::Module { // [B, H, S, D] auto cos = llm_embedding_cos.unsqueeze(1); auto sin = llm_embedding_sin.unsqueeze(1); - query_states = ptq::QDQ(this, - ptq::QDQ(this, query_states * cos, "q_rope_mul_0_output_qdq") - + ptq::QDQ(this, rotateHalf(query_states) * sin, "q_rope_mul_1_output_qdq"), - "q_rope_add_0_output_qdq"); - key_states = ptq::QDQ(this, - ptq::QDQ(this, key_states * cos, "k_rope_mul_0_output_qdq") - + ptq::QDQ(this, rotateHalf(key_states) * sin, "k_rope_mul_1_output_qdq"), - "k_rope_add_0_output_qdq"); + query_states = + ptq::QDQ(this, + ptq::QDQ(this, query_states * cos, "q_rope_mul_0_output_qdq") + + ptq::QDQ(this, rotateHalf(query_states, this, "q_rope_neg_half_qdq") * sin, "q_rope_mul_1_output_qdq"), + "q_rope_add_0_output_qdq"); + key_states = + ptq::QDQ(this, + ptq::QDQ(this, key_states * cos, "k_rope_mul_0_output_qdq") + + ptq::QDQ(this, rotateHalf(key_states, this, "k_rope_neg_half_qdq") * sin, "k_rope_mul_1_output_qdq"), + "k_rope_add_0_output_qdq"); // De-quantization and quantization again key_states = key_states.to(kFloat32); @@ -274,6 +276,7 @@ class Qwen3Attention final : public nn::Module { minus_value = ptq::QDQ(this, minus_value, "neg_20_qdq"); auto attn_vv = ptq::QDQ(this, attn_min.addConstant(minus_value), "minus_0_output_qdq"); attn = nn::functional::where(causal_mask.equal(0.f), attn, attn_vv); + attn = ptq::QDQ(this, attn, "where_attn_qdq"); attn = ptq::QDQ(this, nn::functional::softmax(attn, -1), "softmax_output_qdq"); auto y = ptq::QDQ(this, nn::functional::matmul(attn, vh), "attn_value_matmul_output_qdq"); y = y.transpose(1, 2).view({1, 1, -1, num_attention_heads_ * head_dim_}, /*ssa=*/true); diff --git a/mllm/backends/qnn/aot/passes/LLMQuantRecipePass.cpp b/mllm/backends/qnn/aot/passes/LLMQuantRecipePass.cpp index 90ee4ad72..957fdf321 100644 --- a/mllm/backends/qnn/aot/passes/LLMQuantRecipePass.cpp +++ b/mllm/backends/qnn/aot/passes/LLMQuantRecipePass.cpp @@ -369,8 +369,7 @@ bool LLMQuantRecipeNegPattern::isMatch(const mllm::ir::op_ptr_t& op) { } bool LLMQuantRecipeNegPattern::rewrite(ir::IRWriter& writer, const ir::op_ptr_t& node) { - return shareQuantSpecSingleInputToSingleOutputAndSetOpQuantAnnoAttr(writer.getContext(), - node->cast_()); + return noSharingSingleInAndSingleOutQuantAnnoAttr(writer.getContext(), node->cast_()); } //===----------------------------------------------------------------------===// @@ -651,8 +650,15 @@ bool LLMQuantRecipeConcatPattern::rewrite(ir::IRWriter& writer, const ir::op_ptr return false; } - MLLM_RETURN_FALSE_IF_NOT(i_0->getAttr("quant_recipe")); - MLLM_RETURN_FALSE_IF_NOT(i_1->getAttr("quant_recipe")); + // Create quant_recipe if not present + if (!i_0->getAttr("quant_recipe")) { + auto i_0_spec = genSimpleQuantizationSpecAttr(writer.getContext(), i_0->cast_()); + i_0->setAttr("quant_recipe", i_0_spec); + } + if (!i_1->getAttr("quant_recipe")) { + auto i_1_spec = genSimpleQuantizationSpecAttr(writer.getContext(), i_1->cast_()); + i_1->setAttr("quant_recipe", i_1_spec); + } o_0->setAttr("quant_recipe", i_0->getAttr("quant_recipe")); @@ -795,7 +801,8 @@ bool LLMQuantRecipeWherePattern::rewrite(ir::IRWriter& writer, const ir::op_ptr_ MLLM_RETURN_FALSE_IF_NOT(i_1->getAttr("quant_recipe")); MLLM_RETURN_FALSE_IF_NOT(i_2->getAttr("quant_recipe")); - o_0->setAttr("quant_recipe", i_2->getAttr("quant_recipe")); + auto o_0_spec = genSimpleQuantizationSpecAttr(writer.getContext(), o_0->cast_()); + o_0->setAttr("quant_recipe", o_0_spec); auto annotation_attr = writer.create(); annotation_attr->annotation_.inputs.emplace_back( diff --git a/mllm/backends/qnn/aot/passes/PTQPass.cpp b/mllm/backends/qnn/aot/passes/PTQPass.cpp index 7172db475..82869ab16 100644 --- a/mllm/backends/qnn/aot/passes/PTQPass.cpp +++ b/mllm/backends/qnn/aot/passes/PTQPass.cpp @@ -339,6 +339,94 @@ void recursiveCheckUnsolved(const std::shared_ptr& ir_ctx, const }); } +void recursiveCheckConcatInputs(const std::shared_ptr& ir_ctx, const ir::graph::SubGraphOp::ptr_t& call_op) { + auto wow = ir::IRWriter(ir_ctx, call_op->getTopRegion()); + wow.walk([&](ir::IRWriter& w, const ir::Op::ptr_t& op) -> ir::IRWriter::WalkResult { + if (op->isa_()) { + auto concat_op = op->cast_(); + std::string op_name = concat_op->getAOp()->getName(); + + auto inputs = op->inputs(); + if (inputs.empty()) { return ir::IRWriter::WALK_CONTINUE; } + + // Get first input's scale and zero_point as reference + Tensor ref_scale; + Tensor ref_zero_point; + bool has_ref = false; + std::string ref_input_name; + + for (auto iii : inputs) { + if (!iii->isa_()) continue; + auto tv = iii->cast_(); + if (!tv->getAttr("quant_recipe")) continue; + auto f_spec = tv->getAttr("quant_recipe")->cast_(); + + if (f_spec->spec_->type == ir::linalg::QuantizationSpecType::kAsymPerTensor) { + auto this_spec = std::static_pointer_cast(f_spec->spec_); + if (!this_spec->solved) continue; + + if (!has_ref) { + ref_scale = this_spec->scale; + ref_zero_point = this_spec->zero_point; + ref_input_name = tv->name(); + has_ref = true; + } else { + // Check if scale and zero_point match + auto cur_scale = this_spec->scale; + auto cur_zero_point = this_spec->zero_point; + + MLLM_RT_ASSERT_EQ(ref_scale.numel(), 1); + MLLM_RT_ASSERT_EQ(cur_scale.numel(), 1); + MLLM_RT_ASSERT_EQ(ref_zero_point.numel(), 1); + MLLM_RT_ASSERT_EQ(cur_zero_point.numel(), 1); + + auto ref_scale_v = ref_scale.item(); + auto cur_scale_v = cur_scale.item(); + auto ref_zp_v = ref_zero_point.item(); + auto cur_zp_v = cur_zero_point.item(); + + if (std::abs(ref_scale_v - cur_scale_v) > 1e-6 || ref_zp_v != cur_zp_v) { + MLLM_ERROR("PTQPass: ConcatOp '{}' has mismatched scale/zp between inputs. " + "Input '{}': scale={}, zp={}; Input '{}': scale={}, zp={}", + op_name, ref_input_name, ref_scale_v, ref_zp_v, tv->name(), cur_scale_v, cur_zp_v); + } + } + } else if (f_spec->spec_->type == ir::linalg::QuantizationSpecType::kSymPerTensor) { + auto this_spec = std::static_pointer_cast(f_spec->spec_); + if (!this_spec->solved) continue; + + if (!has_ref) { + ref_scale = this_spec->scale; + ref_input_name = tv->name(); + has_ref = true; + } else { + // Check if scale matches + auto cur_scale = this_spec->scale; + + MLLM_RT_ASSERT_EQ(ref_scale.numel(), 1); + MLLM_RT_ASSERT_EQ(cur_scale.numel(), 1); + + auto ref_scale_v = ref_scale.item(); + auto cur_scale_v = cur_scale.item(); + + if (std::abs(ref_scale_v - cur_scale_v) > 1e-6) { + MLLM_ERROR("PTQPass: ConcatOp '{}' has mismatched scale between inputs. " + "Input '{}': scale={}; Input '{}': scale={}", + op_name, ref_input_name, ref_scale_v, tv->name(), cur_scale_v); + } + } + } + } + } + + if (op->isa_()) { + auto ns = op->cast_()->getSymbolAttr()->str(); + recursiveCheckConcatInputs(w.getContext(), w.getContext()->lookupSymbolTable(ns)->cast_()); + } + return ir::IRWriter::WALK_CONTINUE; + }); +} + } // namespace uint8_t PTQPass::run(const ir::node_ptr_t& op) { @@ -374,6 +462,11 @@ uint8_t PTQPass::run(const ir::node_ptr_t& op) { writer.getContext(), getCtx()->lookupSymbolTable(call_main_graph_op->getSymbolAttr()->str())->cast_()); + // Check Concat inputs have consistent scale and zero_point + recursiveCheckConcatInputs( + writer.getContext(), + getCtx()->lookupSymbolTable(call_main_graph_op->getSymbolAttr()->str())->cast_()); + return ir::PASS_RET_SUCCESS; } diff --git a/pymllm/backends/qualcomm/transformers/core/observer.py b/pymllm/backends/qualcomm/transformers/core/observer.py new file mode 100644 index 000000000..67a946b10 --- /dev/null +++ b/pymllm/backends/qualcomm/transformers/core/observer.py @@ -0,0 +1,56 @@ +import torch +from torchao.quantization.pt2e import UniformQuantizationObserverBase + + +class ConcatObserver(UniformQuantizationObserverBase): + """ + Fetch maximum data range of all tensors to be concatenated + """ + + def __init__( + self, + dtype=torch.uint8, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + factory_kwargs=None, + eps=torch.finfo(torch.float32).eps, # noqa: B008 + is_dynamic=False, + **kwargs, + ) -> None: + super().__init__( + dtype=dtype, + qscheme=qscheme, + reduce_range=reduce_range, + quant_min=quant_min, + quant_max=quant_max, + factory_kwargs=factory_kwargs, + eps=eps, + is_dynamic=is_dynamic, + **kwargs, + ) + + factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) + self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs)) + self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs)) + # get concat node and its inputs + self.input_observers = [] + + def add_observer(self, observer): + self.input_observers.append(observer) + + def forward(self, x_orig): + # calculate the min / max first + self.min_val = min(self.min_val, x_orig.min()) + self.max_val = max(self.max_val, x_orig.max()) + + # update min / max for all observers of input nodes + for observers in self.input_observers: + observers.min_val = self.min_val + observers.max_val = self.max_val + + return x_orig + + def calculate_qparams(self): + return self._calculate_qparams(self.min_val, self.max_val) diff --git a/pymllm/backends/qualcomm/transformers/core/qdq.py b/pymllm/backends/qualcomm/transformers/core/qdq.py index f1c4d20dc..c13011a51 100644 --- a/pymllm/backends/qualcomm/transformers/core/qdq.py +++ b/pymllm/backends/qualcomm/transformers/core/qdq.py @@ -78,11 +78,11 @@ def disable_observer(self): def enable_fakequant(self): """Enable simulation of quantization error.""" - self.fake_quant.enable_fakequant() + self.fake_quant.enable_fake_quant() def disable_fakequant(self): """Disable quantization simulation (act as identity).""" - self.fake_quant.disable_fakequant() + self.fake_quant.disable_fake_quant() def extra_repr(self): mode = "Symmetric" if "symmetric" in str(self.qscheme) else "Asymmetric" @@ -158,11 +158,11 @@ def disable_observer(self): def enable_fakequant(self): """Enable simulation of quantization error.""" - self.fake_quant.enable_fakequant() + self.fake_quant.enable_fake_quant() def disable_fakequant(self): """Disable quantization simulation (act as identity).""" - self.fake_quant.disable_fakequant() + self.fake_quant.disable_fake_quant() @property def scale(self): diff --git a/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py b/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py index 2f099088e..92efaa06d 100644 --- a/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py +++ b/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py @@ -54,6 +54,7 @@ ActivationQDQ, FixedActivationQDQ, ) +from pymllm.backends.qualcomm.transformers.core.observer import ConcatObserver class Qwen3MLP(nn.Module): @@ -100,11 +101,13 @@ def forward(self, x): return o -def rotate_half(x): +def rotate_half( + x, x_observer, x2_neg_fake_quant: ActivationQDQ, concat_observer: ConcatObserver +): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) + return concat_observer(torch.cat((x2_neg_fake_quant(-x2), x1), dim=-1)) def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): @@ -214,6 +217,39 @@ def __init__(self, config: Qwen3Config, layer_idx: int): self.k_rope_mul_1_output_qdq = ActivationQDQ(bits=16) self.k_rope_add_0_output_qdq = ActivationQDQ(bits=16) + self.q_rope_concat_observer = ConcatObserver( + dtype=torch.int32, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=0, + quant_max=2**16 - 1, + eps=0.0001 / 65535, + is_dynamic=False, + ) + self.q_rope_neg_half_qdq = ActivationQDQ(bits=16) + self.k_rope_concat_observer = ConcatObserver( + dtype=torch.int32, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=0, + quant_max=2**16 - 1, + eps=0.0001 / 65535, + is_dynamic=False, + ) + self.k_rope_neg_half_qdq = ActivationQDQ(bits=16) + self.k_rope_concat_observer.add_observer( + self.k_norm_output_qdq.fake_quant.activation_post_process + ) + self.k_rope_concat_observer.add_observer( + self.k_rope_neg_half_qdq.fake_quant.activation_post_process + ) + self.q_rope_concat_observer.add_observer( + self.q_norm_output_qdq.fake_quant.activation_post_process + ) + self.q_rope_concat_observer.add_observer( + self.q_rope_neg_half_qdq.fake_quant.activation_post_process + ) + # In qnn, is uint8 sym. self.k_cast_to_int8_qdq = ActivationQDQ( bits=8, qscheme=torch.per_tensor_symmetric @@ -231,6 +267,7 @@ def __init__(self, config: Qwen3Config, layer_idx: int): self.minus_0_output_qdq = ActivationQDQ(bits=16) self.softmax_output_qdq = ActivationQDQ(bits=16) self.attn_value_matmul_output_qdq = ActivationQDQ(bits=16) + self.where_attn_qdq = ActivationQDQ(bits=16) @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( @@ -263,11 +300,27 @@ def forward( sin = sin.unsqueeze(1) query_states = self.q_rope_add_0_output_qdq( self.q_rope_mul_0_output_qdq(query_states * cos) - + self.q_rope_mul_1_output_qdq(rotate_half(query_states) * sin) + + self.q_rope_mul_1_output_qdq( + rotate_half( + query_states, + self.q_norm_output_qdq.fake_quant.activation_post_process, + self.q_rope_neg_half_qdq, + self.q_rope_concat_observer, + ) + * sin + ) ) key_states = self.k_rope_add_0_output_qdq( self.k_rope_mul_0_output_qdq(key_states * cos) - + self.k_rope_mul_1_output_qdq(rotate_half(key_states) * sin) + + self.k_rope_mul_1_output_qdq( + rotate_half( + key_states, + self.k_norm_output_qdq.fake_quant.activation_post_process, + self.k_rope_neg_half_qdq, + self.k_rope_concat_observer, + ) + * sin + ) ) key_states = self.k_cast_to_int8_qdq(key_states) @@ -303,7 +356,9 @@ def forward( * (-20) ) ) - attn_weights = torch.where(attention_mask == 0, attn_weights, attn_vv) + attn_weights = self.where_attn_qdq( + torch.where(attention_mask == 0, attn_weights, attn_vv) + ) attn_weights = self.softmax_output_qdq( nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( diff --git a/pymllm/backends/qualcomm/transformers/qwen3/runner.py b/pymllm/backends/qualcomm/transformers/qwen3/runner.py index ed302f215..6565ca7e6 100644 --- a/pymllm/backends/qualcomm/transformers/qwen3/runner.py +++ b/pymllm/backends/qualcomm/transformers/qwen3/runner.py @@ -2,7 +2,10 @@ from tqdm import tqdm from modelscope.msdatasets import MsDataset from transformers import AutoTokenizer -from pymllm.backends.qualcomm.transformers.core.qdq import ActivationQDQ +from pymllm.backends.qualcomm.transformers.core.qdq import ( + ActivationQDQ, + FixedActivationQDQ, +) from pymllm.backends.qualcomm.transformers.core.rms_norm import QRMSNorm from pymllm.backends.qualcomm.transformers.core.qlinear import ( QLinearLPBQ, @@ -31,6 +34,16 @@ def enable_qdq_observer(m): m.enable_observer() +def enable_fake_quant(m): + if isinstance(m, ActivationQDQ) or isinstance(m, FixedActivationQDQ): + m.enable_fakequant() + + +def disable_fake_quant(m): + if isinstance(m, ActivationQDQ) or isinstance(m, FixedActivationQDQ): + m.disable_fakequant() + + def convert_weight(m): if isinstance(m, QLinearLPBQ) or isinstance(m, QLinearW8A16_PerChannelSym): m.convert_to_conv2d_deploy_hwio() @@ -61,6 +74,12 @@ def freeze_activation(self): def enable_activation_update(self): self.model.apply(enable_qdq_observer) + def enable_fake_quant(self): + self.model.apply(enable_fake_quant) + + def disable_fake_quant(self): + self.model.apply(disable_fake_quant) + def compile(self): print("Compile Start.") self.model = torch.compile( diff --git a/pymllm/backends/qualcomm/transformers/qwen3/train.py b/pymllm/backends/qualcomm/transformers/qwen3/train.py index 33351918f..25361f372 100644 --- a/pymllm/backends/qualcomm/transformers/qwen3/train.py +++ b/pymllm/backends/qualcomm/transformers/qwen3/train.py @@ -37,8 +37,11 @@ def main(): args = parser.parse_args() m = Qwen3Quantizer(args.model_path, mllm_qualcomm_max_length=args.max_length) + + # FIXME: Should disable or not. + m.disable_fake_quant() m.calibrate(num_samples=args.num_samples, max_seq_length=args.max_length) - # m.compile() + m.enable_fake_quant() m.infer(args.infer_text) # !!! From 48c259a8e87b4b0fabb6eaeca7074a1656500e55 Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Tue, 20 Jan 2026 09:16:50 +0000 Subject: [PATCH 07/17] feat(cpu): Implement fill operations for various data types including zeros, ones, specific values, arange, and random fills. Introduce a new fill-inl.hpp file for optimized implementations and update kernel dispatch to include these operations. Enhance CPUFillOp to utilize the new fill functions for better performance and maintainability. --- mllm/backends/cpu/kernels/common/fill-inl.hpp | 363 ++++++++++++++++++ .../cpu/kernels/common/kernel_dispatch.cpp | 180 ++++++++- .../cpu/kernels/common/kernel_dispatch.hpp | 217 +++++++++++ mllm/backends/cpu/ops/FillOp.cpp | 118 +++--- mllm/backends/qnn/aot/passes/PTQPass.cpp | 6 +- mllm/ffi/Extension.cc | 16 + pymllm/__init__.py | 16 +- pymllm/ffi/__init__.py | 67 +++- 8 files changed, 928 insertions(+), 55 deletions(-) create mode 100644 mllm/backends/cpu/kernels/common/fill-inl.hpp diff --git a/mllm/backends/cpu/kernels/common/fill-inl.hpp b/mllm/backends/cpu/kernels/common/fill-inl.hpp new file mode 100644 index 000000000..4c799daf6 --- /dev/null +++ b/mllm/backends/cpu/kernels/common/fill-inl.hpp @@ -0,0 +1,363 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +// NOTE: Do NOT use #pragma once here! +// Highway's foreach_target.h mechanism requires -inl.hpp files to be included +// multiple times, once for each target architecture (AVX3_DL, AVX10_2, etc.). + +#include +#include +#include "mllm/core/DataTypes.hpp" + +HWY_BEFORE_NAMESPACE(); +namespace mllm::cpu::common { // NOLINT +namespace HWY_NAMESPACE { +namespace hn = hwy::HWY_NAMESPACE; + +//===----------------------------------------------------------------------===// +// Fill Zeros +//===----------------------------------------------------------------------===// +template +HWY_INLINE void fill_zeros_impl(T* HWY_RESTRICT dst, size_t count) { + const hn::ScalableTag d; + const size_t N = hn::Lanes(d); + const hn::Vec zero = hn::Zero(d); + size_t idx = 0; + + for (; idx + N <= count; idx += N) { hn::StoreU(zero, d, dst + idx); } + + if (idx < count) { hn::StoreN(zero, d, dst + idx, count - idx); } +} + +// Specialization for types not supported by Highway SIMD, use memset +template +HWY_INLINE void fill_zeros_scalar(T* HWY_RESTRICT dst, size_t count) { + if constexpr (std::is_trivial_v) { + std::memset(dst, 0, count * sizeof(T)); + } else { + T zero_val{}; + for (size_t i = 0; i < count; ++i) { dst[i] = zero_val; } + } +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_zeros_fp32(mllm_fp32_t* HWY_RESTRICT dst, size_t size) { + fill_zeros_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_zeros_fp64(mllm_fp64_t* HWY_RESTRICT dst, size_t size) { + fill_zeros_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_zeros_i32(mllm_int32_t* HWY_RESTRICT dst, size_t size) { + fill_zeros_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_zeros_u32(mllm_uint32_t* HWY_RESTRICT dst, size_t size) { + fill_zeros_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_zeros_i64(mllm_int64_t* HWY_RESTRICT dst, size_t size) { + fill_zeros_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_zeros_u64(mllm_uint64_t* HWY_RESTRICT dst, size_t size) { + fill_zeros_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_zeros_i16(mllm_int16_t* HWY_RESTRICT dst, size_t size) { + fill_zeros_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_zeros_u16(mllm_uint16_t* HWY_RESTRICT dst, size_t size) { + fill_zeros_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_zeros_i8(mllm_int8_t* HWY_RESTRICT dst, size_t size) { + fill_zeros_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_zeros_u8(mllm_uint8_t* HWY_RESTRICT dst, size_t size) { + fill_zeros_impl(dst, size); +} + +//===----------------------------------------------------------------------===// +// Fill Ones +//===----------------------------------------------------------------------===// +template +HWY_INLINE void fill_ones_impl(T* HWY_RESTRICT dst, size_t count) { + const hn::ScalableTag d; + const size_t N = hn::Lanes(d); + const hn::Vec one = hn::Set(d, static_cast(1)); + size_t idx = 0; + + for (; idx + N <= count; idx += N) { hn::StoreU(one, d, dst + idx); } + + if (idx < count) { hn::StoreN(one, d, dst + idx, count - idx); } +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_ones_fp32(mllm_fp32_t* HWY_RESTRICT dst, size_t size) { + fill_ones_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_ones_fp64(mllm_fp64_t* HWY_RESTRICT dst, size_t size) { + fill_ones_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_ones_i32(mllm_int32_t* HWY_RESTRICT dst, size_t size) { + fill_ones_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_ones_u32(mllm_uint32_t* HWY_RESTRICT dst, size_t size) { + fill_ones_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_ones_i64(mllm_int64_t* HWY_RESTRICT dst, size_t size) { + fill_ones_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_ones_u64(mllm_uint64_t* HWY_RESTRICT dst, size_t size) { + fill_ones_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_ones_i16(mllm_int16_t* HWY_RESTRICT dst, size_t size) { + fill_ones_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_ones_u16(mllm_uint16_t* HWY_RESTRICT dst, size_t size) { + fill_ones_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_ones_i8(mllm_int8_t* HWY_RESTRICT dst, size_t size) { + fill_ones_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_ones_u8(mllm_uint8_t* HWY_RESTRICT dst, size_t size) { + fill_ones_impl(dst, size); +} + +//===----------------------------------------------------------------------===// +// Fill Specific Value +//===----------------------------------------------------------------------===// +template +HWY_INLINE void fill_value_impl(T* HWY_RESTRICT dst, size_t count, T value) { + const hn::ScalableTag d; + const size_t N = hn::Lanes(d); + const hn::Vec v = hn::Set(d, value); + size_t idx = 0; + + for (; idx + N <= count; idx += N) { hn::StoreU(v, d, dst + idx); } + + if (idx < count) { hn::StoreN(v, d, dst + idx, count - idx); } +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_value_fp32(mllm_fp32_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t value) { + fill_value_impl(dst, size, value); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_value_fp64(mllm_fp64_t* HWY_RESTRICT dst, size_t size, mllm_fp64_t value) { + fill_value_impl(dst, size, value); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_value_i32(mllm_int32_t* HWY_RESTRICT dst, size_t size, mllm_int32_t value) { + fill_value_impl(dst, size, value); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_value_u32(mllm_uint32_t* HWY_RESTRICT dst, size_t size, mllm_uint32_t value) { + fill_value_impl(dst, size, value); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_value_i64(mllm_int64_t* HWY_RESTRICT dst, size_t size, mllm_int64_t value) { + fill_value_impl(dst, size, value); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_value_u64(mllm_uint64_t* HWY_RESTRICT dst, size_t size, mllm_uint64_t value) { + fill_value_impl(dst, size, value); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_value_i16(mllm_int16_t* HWY_RESTRICT dst, size_t size, mllm_int16_t value) { + fill_value_impl(dst, size, value); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_value_u16(mllm_uint16_t* HWY_RESTRICT dst, size_t size, mllm_uint16_t value) { + fill_value_impl(dst, size, value); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_value_i8(mllm_int8_t* HWY_RESTRICT dst, size_t size, mllm_int8_t value) { + fill_value_impl(dst, size, value); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_value_u8(mllm_uint8_t* HWY_RESTRICT dst, size_t size, mllm_uint8_t value) { + fill_value_impl(dst, size, value); +} + +//===----------------------------------------------------------------------===// +// Fill Arange (start, end, step) +//===----------------------------------------------------------------------===// +template +HWY_INLINE void fill_arange_impl(T* HWY_RESTRICT dst, size_t count, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) { + if (step == 0) { + fill_value_impl(dst, count, static_cast(start)); + return; + } + + // Calculate the actual number of elements to fill + size_t n = 0; + if ((step > 0 && start < end) || (step < 0 && start > end)) { + mllm_fp32_t n_float = (end - start) / step; + if (n_float > 0) { + n = static_cast(std::ceil(n_float)); + if (step > 0) { + if (start + (n - 1) * step >= end) --n; + } else { + if (start + (n - 1) * step <= end) --n; + } + n = std::min(n, count); + } + } + + // Use SIMD for float types where we can vectorize the computation + if constexpr (std::is_same_v) { + const hn::ScalableTag d; + const size_t N = hn::Lanes(d); + + // Create increment vector: [0, 1, 2, 3, ...] * step + const hn::Vec step_vec = hn::Set(d, step); + const hn::Vec n_step_vec = hn::Set(d, step * static_cast(N)); + + // Create base offsets [0, 1, 2, 3, ...] + hn::Vec base = hn::Iota(d, 0); + base = hn::Mul(base, step_vec); + hn::Vec current_start = hn::Add(hn::Set(d, start), base); + + size_t idx = 0; + for (; idx + N <= n; idx += N) { + hn::StoreU(current_start, d, dst + idx); + current_start = hn::Add(current_start, n_step_vec); + } + + // Handle remaining elements + for (; idx < n; ++idx) { dst[idx] = static_cast(start + idx * step); } + } else { + // Scalar fallback for other types + for (size_t i = 0; i < n; ++i) { dst[i] = static_cast(start + i * step); } + } +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_arange_fp32(mllm_fp32_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, mllm_fp32_t step) { + fill_arange_impl(dst, size, start, end, step); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_arange_i32(mllm_int32_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, mllm_fp32_t step) { + fill_arange_impl(dst, size, start, end, step); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_arange_u32(mllm_uint32_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, mllm_fp32_t step) { + fill_arange_impl(dst, size, start, end, step); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_arange_i64(mllm_int64_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, mllm_fp32_t step) { + fill_arange_impl(dst, size, start, end, step); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_arange_u64(mllm_uint64_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, mllm_fp32_t step) { + fill_arange_impl(dst, size, start, end, step); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_arange_i16(mllm_int16_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, mllm_fp32_t step) { + fill_arange_impl(dst, size, start, end, step); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_arange_u16(mllm_uint16_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, mllm_fp32_t step) { + fill_arange_impl(dst, size, start, end, step); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_arange_i8(mllm_int8_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, mllm_fp32_t step) { + fill_arange_impl(dst, size, start, end, step); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_arange_u8(mllm_uint8_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, mllm_fp32_t step) { + fill_arange_impl(dst, size, start, end, step); +} + +//===----------------------------------------------------------------------===// +// Fill Random (using LCG random number generator) +//===----------------------------------------------------------------------===// +template +HWY_INLINE void fill_random_impl(T* HWY_RESTRICT dst, size_t count, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) { + const uint64_t multiplier = 1103515245ULL; + const uint64_t increment = 12345ULL; + const uint64_t modulus = 1ULL << 31; // 2^31 + const mllm_fp32_t range = end - start; + + if (range == 0) { + fill_value_impl(dst, count, static_cast(start)); + return; + } + + uint64_t state = seed; + state = (multiplier * state + increment) % modulus; + + for (size_t i = 0; i < count; ++i) { + state = (multiplier * state + increment) % modulus; + const mllm_fp32_t random_value = static_cast(state) / static_cast(modulus - 1); + dst[i] = static_cast(start + random_value * range); + } +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_random_fp32(mllm_fp32_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, uint64_t seed) { + fill_random_impl(dst, size, start, end, seed); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_random_i32(mllm_int32_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, uint64_t seed) { + fill_random_impl(dst, size, start, end, seed); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_random_u32(mllm_uint32_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, uint64_t seed) { + fill_random_impl(dst, size, start, end, seed); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_random_i64(mllm_int64_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, uint64_t seed) { + fill_random_impl(dst, size, start, end, seed); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_random_u64(mllm_uint64_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, uint64_t seed) { + fill_random_impl(dst, size, start, end, seed); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_random_i16(mllm_int16_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, uint64_t seed) { + fill_random_impl(dst, size, start, end, seed); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_random_u16(mllm_uint16_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, uint64_t seed) { + fill_random_impl(dst, size, start, end, seed); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_random_i8(mllm_int8_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, uint64_t seed) { + fill_random_impl(dst, size, start, end, seed); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_random_u8(mllm_uint8_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, uint64_t seed) { + fill_random_impl(dst, size, start, end, seed); +} + +} // namespace HWY_NAMESPACE +} // namespace mllm::cpu::common +HWY_AFTER_NAMESPACE(); diff --git a/mllm/backends/cpu/kernels/common/kernel_dispatch.cpp b/mllm/backends/cpu/kernels/common/kernel_dispatch.cpp index 1ad3cee93..7e81adfdf 100644 --- a/mllm/backends/cpu/kernels/common/kernel_dispatch.cpp +++ b/mllm/backends/cpu/kernels/common/kernel_dispatch.cpp @@ -17,6 +17,7 @@ // Include all inline implementations here #include "mllm/backends/cpu/kernels/common/elewise-inl.hpp" +#include "mllm/backends/cpu/kernels/common/fill-inl.hpp" #if HWY_ONCE namespace mllm::cpu::common { @@ -69,11 +70,188 @@ HWY_DLLEXPORT void call_elewise_div_scalar_fp32(mllm_fp32_t* out, const mllm_fp3 // GELU //===----------------------------------------------------------------------===// // HWY_EXPORT(gelu_fp32); -// +// // HWY_DLLEXPORT void call_gelu_fp32(mllm_fp32_t* out, const mllm_fp32_t* in, size_t n) { // HWY_DYNAMIC_DISPATCH(gelu_fp32)(out, in, n); // } +//===----------------------------------------------------------------------===// +// Fill Zeros +//===----------------------------------------------------------------------===// +HWY_EXPORT(fill_zeros_fp32); +HWY_EXPORT(fill_zeros_fp64); +HWY_EXPORT(fill_zeros_i32); +HWY_EXPORT(fill_zeros_u32); +HWY_EXPORT(fill_zeros_i64); +HWY_EXPORT(fill_zeros_u64); +HWY_EXPORT(fill_zeros_i16); +HWY_EXPORT(fill_zeros_u16); +HWY_EXPORT(fill_zeros_i8); +HWY_EXPORT(fill_zeros_u8); + +HWY_DLLEXPORT void call_fill_zeros_fp32(mllm_fp32_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_zeros_fp32)(dst, n); } +HWY_DLLEXPORT void call_fill_zeros_fp64(mllm_fp64_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_zeros_fp64)(dst, n); } +HWY_DLLEXPORT void call_fill_zeros_i32(mllm_int32_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_zeros_i32)(dst, n); } +HWY_DLLEXPORT void call_fill_zeros_u32(mllm_uint32_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_zeros_u32)(dst, n); } +HWY_DLLEXPORT void call_fill_zeros_i64(mllm_int64_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_zeros_i64)(dst, n); } +HWY_DLLEXPORT void call_fill_zeros_u64(mllm_uint64_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_zeros_u64)(dst, n); } +HWY_DLLEXPORT void call_fill_zeros_i16(mllm_int16_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_zeros_i16)(dst, n); } +HWY_DLLEXPORT void call_fill_zeros_u16(mllm_uint16_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_zeros_u16)(dst, n); } +HWY_DLLEXPORT void call_fill_zeros_i8(mllm_int8_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_zeros_i8)(dst, n); } +HWY_DLLEXPORT void call_fill_zeros_u8(mllm_uint8_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_zeros_u8)(dst, n); } + +//===----------------------------------------------------------------------===// +// Fill Ones +//===----------------------------------------------------------------------===// +HWY_EXPORT(fill_ones_fp32); +HWY_EXPORT(fill_ones_fp64); +HWY_EXPORT(fill_ones_i32); +HWY_EXPORT(fill_ones_u32); +HWY_EXPORT(fill_ones_i64); +HWY_EXPORT(fill_ones_u64); +HWY_EXPORT(fill_ones_i16); +HWY_EXPORT(fill_ones_u16); +HWY_EXPORT(fill_ones_i8); +HWY_EXPORT(fill_ones_u8); + +HWY_DLLEXPORT void call_fill_ones_fp32(mllm_fp32_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_ones_fp32)(dst, n); } +HWY_DLLEXPORT void call_fill_ones_fp64(mllm_fp64_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_ones_fp64)(dst, n); } +HWY_DLLEXPORT void call_fill_ones_i32(mllm_int32_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_ones_i32)(dst, n); } +HWY_DLLEXPORT void call_fill_ones_u32(mllm_uint32_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_ones_u32)(dst, n); } +HWY_DLLEXPORT void call_fill_ones_i64(mllm_int64_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_ones_i64)(dst, n); } +HWY_DLLEXPORT void call_fill_ones_u64(mllm_uint64_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_ones_u64)(dst, n); } +HWY_DLLEXPORT void call_fill_ones_i16(mllm_int16_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_ones_i16)(dst, n); } +HWY_DLLEXPORT void call_fill_ones_u16(mllm_uint16_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_ones_u16)(dst, n); } +HWY_DLLEXPORT void call_fill_ones_i8(mllm_int8_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_ones_i8)(dst, n); } +HWY_DLLEXPORT void call_fill_ones_u8(mllm_uint8_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_ones_u8)(dst, n); } + +//===----------------------------------------------------------------------===// +// Fill Specific Value +//===----------------------------------------------------------------------===// +HWY_EXPORT(fill_value_fp32); +HWY_EXPORT(fill_value_fp64); +HWY_EXPORT(fill_value_i32); +HWY_EXPORT(fill_value_u32); +HWY_EXPORT(fill_value_i64); +HWY_EXPORT(fill_value_u64); +HWY_EXPORT(fill_value_i16); +HWY_EXPORT(fill_value_u16); +HWY_EXPORT(fill_value_i8); +HWY_EXPORT(fill_value_u8); + +HWY_DLLEXPORT void call_fill_value_fp32(mllm_fp32_t* dst, size_t n, mllm_fp32_t value) { + HWY_DYNAMIC_DISPATCH(fill_value_fp32)(dst, n, value); +} +HWY_DLLEXPORT void call_fill_value_fp64(mllm_fp64_t* dst, size_t n, mllm_fp64_t value) { + HWY_DYNAMIC_DISPATCH(fill_value_fp64)(dst, n, value); +} +HWY_DLLEXPORT void call_fill_value_i32(mllm_int32_t* dst, size_t n, mllm_int32_t value) { + HWY_DYNAMIC_DISPATCH(fill_value_i32)(dst, n, value); +} +HWY_DLLEXPORT void call_fill_value_u32(mllm_uint32_t* dst, size_t n, mllm_uint32_t value) { + HWY_DYNAMIC_DISPATCH(fill_value_u32)(dst, n, value); +} +HWY_DLLEXPORT void call_fill_value_i64(mllm_int64_t* dst, size_t n, mllm_int64_t value) { + HWY_DYNAMIC_DISPATCH(fill_value_i64)(dst, n, value); +} +HWY_DLLEXPORT void call_fill_value_u64(mllm_uint64_t* dst, size_t n, mllm_uint64_t value) { + HWY_DYNAMIC_DISPATCH(fill_value_u64)(dst, n, value); +} +HWY_DLLEXPORT void call_fill_value_i16(mllm_int16_t* dst, size_t n, mllm_int16_t value) { + HWY_DYNAMIC_DISPATCH(fill_value_i16)(dst, n, value); +} +HWY_DLLEXPORT void call_fill_value_u16(mllm_uint16_t* dst, size_t n, mllm_uint16_t value) { + HWY_DYNAMIC_DISPATCH(fill_value_u16)(dst, n, value); +} +HWY_DLLEXPORT void call_fill_value_i8(mllm_int8_t* dst, size_t n, mllm_int8_t value) { + HWY_DYNAMIC_DISPATCH(fill_value_i8)(dst, n, value); +} +HWY_DLLEXPORT void call_fill_value_u8(mllm_uint8_t* dst, size_t n, mllm_uint8_t value) { + HWY_DYNAMIC_DISPATCH(fill_value_u8)(dst, n, value); +} + +//===----------------------------------------------------------------------===// +// Fill Arange +//===----------------------------------------------------------------------===// +HWY_EXPORT(fill_arange_fp32); +HWY_EXPORT(fill_arange_i32); +HWY_EXPORT(fill_arange_u32); +HWY_EXPORT(fill_arange_i64); +HWY_EXPORT(fill_arange_u64); +HWY_EXPORT(fill_arange_i16); +HWY_EXPORT(fill_arange_u16); +HWY_EXPORT(fill_arange_i8); +HWY_EXPORT(fill_arange_u8); + +HWY_DLLEXPORT void call_fill_arange_fp32(mllm_fp32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) { + HWY_DYNAMIC_DISPATCH(fill_arange_fp32)(dst, n, start, end, step); +} +HWY_DLLEXPORT void call_fill_arange_i32(mllm_int32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) { + HWY_DYNAMIC_DISPATCH(fill_arange_i32)(dst, n, start, end, step); +} +HWY_DLLEXPORT void call_fill_arange_u32(mllm_uint32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) { + HWY_DYNAMIC_DISPATCH(fill_arange_u32)(dst, n, start, end, step); +} +HWY_DLLEXPORT void call_fill_arange_i64(mllm_int64_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) { + HWY_DYNAMIC_DISPATCH(fill_arange_i64)(dst, n, start, end, step); +} +HWY_DLLEXPORT void call_fill_arange_u64(mllm_uint64_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) { + HWY_DYNAMIC_DISPATCH(fill_arange_u64)(dst, n, start, end, step); +} +HWY_DLLEXPORT void call_fill_arange_i16(mllm_int16_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) { + HWY_DYNAMIC_DISPATCH(fill_arange_i16)(dst, n, start, end, step); +} +HWY_DLLEXPORT void call_fill_arange_u16(mllm_uint16_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) { + HWY_DYNAMIC_DISPATCH(fill_arange_u16)(dst, n, start, end, step); +} +HWY_DLLEXPORT void call_fill_arange_i8(mllm_int8_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) { + HWY_DYNAMIC_DISPATCH(fill_arange_i8)(dst, n, start, end, step); +} +HWY_DLLEXPORT void call_fill_arange_u8(mllm_uint8_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) { + HWY_DYNAMIC_DISPATCH(fill_arange_u8)(dst, n, start, end, step); +} + +//===----------------------------------------------------------------------===// +// Fill Random +//===----------------------------------------------------------------------===// +HWY_EXPORT(fill_random_fp32); +HWY_EXPORT(fill_random_i32); +HWY_EXPORT(fill_random_u32); +HWY_EXPORT(fill_random_i64); +HWY_EXPORT(fill_random_u64); +HWY_EXPORT(fill_random_i16); +HWY_EXPORT(fill_random_u16); +HWY_EXPORT(fill_random_i8); +HWY_EXPORT(fill_random_u8); + +HWY_DLLEXPORT void call_fill_random_fp32(mllm_fp32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) { + HWY_DYNAMIC_DISPATCH(fill_random_fp32)(dst, n, start, end, seed); +} +HWY_DLLEXPORT void call_fill_random_i32(mllm_int32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) { + HWY_DYNAMIC_DISPATCH(fill_random_i32)(dst, n, start, end, seed); +} +HWY_DLLEXPORT void call_fill_random_u32(mllm_uint32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) { + HWY_DYNAMIC_DISPATCH(fill_random_u32)(dst, n, start, end, seed); +} +HWY_DLLEXPORT void call_fill_random_i64(mllm_int64_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) { + HWY_DYNAMIC_DISPATCH(fill_random_i64)(dst, n, start, end, seed); +} +HWY_DLLEXPORT void call_fill_random_u64(mllm_uint64_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) { + HWY_DYNAMIC_DISPATCH(fill_random_u64)(dst, n, start, end, seed); +} +HWY_DLLEXPORT void call_fill_random_i16(mllm_int16_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) { + HWY_DYNAMIC_DISPATCH(fill_random_i16)(dst, n, start, end, seed); +} +HWY_DLLEXPORT void call_fill_random_u16(mllm_uint16_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) { + HWY_DYNAMIC_DISPATCH(fill_random_u16)(dst, n, start, end, seed); +} +HWY_DLLEXPORT void call_fill_random_i8(mllm_int8_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) { + HWY_DYNAMIC_DISPATCH(fill_random_i8)(dst, n, start, end, seed); +} +HWY_DLLEXPORT void call_fill_random_u8(mllm_uint8_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) { + HWY_DYNAMIC_DISPATCH(fill_random_u8)(dst, n, start, end, seed); +} + } // namespace mllm::cpu::common #endif // HWY_ONCE diff --git a/mllm/backends/cpu/kernels/common/kernel_dispatch.hpp b/mllm/backends/cpu/kernels/common/kernel_dispatch.hpp index eb100ac43..4df34db0e 100644 --- a/mllm/backends/cpu/kernels/common/kernel_dispatch.hpp +++ b/mllm/backends/cpu/kernels/common/kernel_dispatch.hpp @@ -7,6 +7,7 @@ #include "mllm/utils/CPUArchHelper.hpp" #if !(defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM)) +#include #include "mllm/core/DataTypes.hpp" // Platform-specific definitions used for declaring an interface, independent of @@ -30,6 +31,222 @@ HWY_DLLEXPORT void call_elewise_sub_scalar_fp32(mllm_fp32_t* out, const mllm_fp3 HWY_DLLEXPORT void call_elewise_mul_scalar_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n); HWY_DLLEXPORT void call_elewise_div_scalar_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n); +//===----------------------------------------------------------------------===// +// Fill Zeros +//===----------------------------------------------------------------------===// +HWY_DLLEXPORT void call_fill_zeros_fp32(mllm_fp32_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_zeros_fp64(mllm_fp64_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_zeros_i32(mllm_int32_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_zeros_u32(mllm_uint32_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_zeros_i64(mllm_int64_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_zeros_u64(mllm_uint64_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_zeros_i16(mllm_int16_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_zeros_u16(mllm_uint16_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_zeros_i8(mllm_int8_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_zeros_u8(mllm_uint8_t* dst, size_t n); + +//===----------------------------------------------------------------------===// +// Fill Ones +//===----------------------------------------------------------------------===// +HWY_DLLEXPORT void call_fill_ones_fp32(mllm_fp32_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_ones_fp64(mllm_fp64_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_ones_i32(mllm_int32_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_ones_u32(mllm_uint32_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_ones_i64(mllm_int64_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_ones_u64(mllm_uint64_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_ones_i16(mllm_int16_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_ones_u16(mllm_uint16_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_ones_i8(mllm_int8_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_ones_u8(mllm_uint8_t* dst, size_t n); + +//===----------------------------------------------------------------------===// +// Fill Specific Value +//===----------------------------------------------------------------------===// +HWY_DLLEXPORT void call_fill_value_fp32(mllm_fp32_t* dst, size_t n, mllm_fp32_t value); +HWY_DLLEXPORT void call_fill_value_fp64(mllm_fp64_t* dst, size_t n, mllm_fp64_t value); +HWY_DLLEXPORT void call_fill_value_i32(mllm_int32_t* dst, size_t n, mllm_int32_t value); +HWY_DLLEXPORT void call_fill_value_u32(mllm_uint32_t* dst, size_t n, mllm_uint32_t value); +HWY_DLLEXPORT void call_fill_value_i64(mllm_int64_t* dst, size_t n, mllm_int64_t value); +HWY_DLLEXPORT void call_fill_value_u64(mllm_uint64_t* dst, size_t n, mllm_uint64_t value); +HWY_DLLEXPORT void call_fill_value_i16(mllm_int16_t* dst, size_t n, mllm_int16_t value); +HWY_DLLEXPORT void call_fill_value_u16(mllm_uint16_t* dst, size_t n, mllm_uint16_t value); +HWY_DLLEXPORT void call_fill_value_i8(mllm_int8_t* dst, size_t n, mllm_int8_t value); +HWY_DLLEXPORT void call_fill_value_u8(mllm_uint8_t* dst, size_t n, mllm_uint8_t value); + +//===----------------------------------------------------------------------===// +// Fill Arange +//===----------------------------------------------------------------------===// +HWY_DLLEXPORT void call_fill_arange_fp32(mllm_fp32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step); +HWY_DLLEXPORT void call_fill_arange_i32(mllm_int32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step); +HWY_DLLEXPORT void call_fill_arange_u32(mllm_uint32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step); +HWY_DLLEXPORT void call_fill_arange_i64(mllm_int64_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step); +HWY_DLLEXPORT void call_fill_arange_u64(mllm_uint64_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step); +HWY_DLLEXPORT void call_fill_arange_i16(mllm_int16_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step); +HWY_DLLEXPORT void call_fill_arange_u16(mllm_uint16_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step); +HWY_DLLEXPORT void call_fill_arange_i8(mllm_int8_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step); +HWY_DLLEXPORT void call_fill_arange_u8(mllm_uint8_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step); + +//===----------------------------------------------------------------------===// +// Fill Random +//===----------------------------------------------------------------------===// +HWY_DLLEXPORT void call_fill_random_fp32(mllm_fp32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed); +HWY_DLLEXPORT void call_fill_random_i32(mllm_int32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed); +HWY_DLLEXPORT void call_fill_random_u32(mllm_uint32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed); +HWY_DLLEXPORT void call_fill_random_i64(mllm_int64_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed); +HWY_DLLEXPORT void call_fill_random_u64(mllm_uint64_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed); +HWY_DLLEXPORT void call_fill_random_i16(mllm_int16_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed); +HWY_DLLEXPORT void call_fill_random_u16(mllm_uint16_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed); +HWY_DLLEXPORT void call_fill_random_i8(mllm_int8_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed); +HWY_DLLEXPORT void call_fill_random_u8(mllm_uint8_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed); + +//===----------------------------------------------------------------------===// +// Template wrapper for generic fill operations +//===----------------------------------------------------------------------===// +template +inline void fill_zeros_anytype(T* dst, size_t n) { + if constexpr (std::is_same_v) { + call_fill_zeros_fp32(dst, n); + } else if constexpr (std::is_same_v) { + call_fill_zeros_fp64(dst, n); + } else if constexpr (std::is_same_v) { + call_fill_zeros_i32(dst, n); + } else if constexpr (std::is_same_v) { + call_fill_zeros_u32(dst, n); + } else if constexpr (std::is_same_v) { + call_fill_zeros_i64(dst, n); + } else if constexpr (std::is_same_v) { + call_fill_zeros_u64(dst, n); + } else if constexpr (std::is_same_v) { + call_fill_zeros_i16(dst, n); + } else if constexpr (std::is_same_v) { + call_fill_zeros_u16(dst, n); + } else if constexpr (std::is_same_v) { + call_fill_zeros_i8(dst, n); + } else if constexpr (std::is_same_v) { + call_fill_zeros_u8(dst, n); + } else { + // Fallback for unsupported types + std::memset(dst, 0, n * sizeof(T)); + } +} + +template +inline void fill_ones_anytype(T* dst, size_t n) { + if constexpr (std::is_same_v) { + call_fill_ones_fp32(dst, n); + } else if constexpr (std::is_same_v) { + call_fill_ones_fp64(dst, n); + } else if constexpr (std::is_same_v) { + call_fill_ones_i32(dst, n); + } else if constexpr (std::is_same_v) { + call_fill_ones_u32(dst, n); + } else if constexpr (std::is_same_v) { + call_fill_ones_i64(dst, n); + } else if constexpr (std::is_same_v) { + call_fill_ones_u64(dst, n); + } else if constexpr (std::is_same_v) { + call_fill_ones_i16(dst, n); + } else if constexpr (std::is_same_v) { + call_fill_ones_u16(dst, n); + } else if constexpr (std::is_same_v) { + call_fill_ones_i8(dst, n); + } else if constexpr (std::is_same_v) { + call_fill_ones_u8(dst, n); + } else { + // Fallback + for (size_t i = 0; i < n; ++i) { dst[i] = static_cast(1); } + } +} + +template +inline void fill_value_anytype(T* dst, size_t n, mllm_fp32_t value) { + if constexpr (std::is_same_v) { + call_fill_value_fp32(dst, n, value); + } else if constexpr (std::is_same_v) { + call_fill_value_fp64(dst, n, static_cast(value)); + } else if constexpr (std::is_same_v) { + call_fill_value_i32(dst, n, static_cast(value)); + } else if constexpr (std::is_same_v) { + call_fill_value_u32(dst, n, static_cast(value)); + } else if constexpr (std::is_same_v) { + call_fill_value_i64(dst, n, static_cast(value)); + } else if constexpr (std::is_same_v) { + call_fill_value_u64(dst, n, static_cast(value)); + } else if constexpr (std::is_same_v) { + call_fill_value_i16(dst, n, static_cast(value)); + } else if constexpr (std::is_same_v) { + call_fill_value_u16(dst, n, static_cast(value)); + } else if constexpr (std::is_same_v) { + call_fill_value_i8(dst, n, static_cast(value)); + } else if constexpr (std::is_same_v) { + call_fill_value_u8(dst, n, static_cast(value)); + } else { + // Fallback + for (size_t i = 0; i < n; ++i) { dst[i] = static_cast(value); } + } +} + +template +inline void fill_arange_anytype(T* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) { + if constexpr (std::is_same_v) { + call_fill_arange_fp32(dst, n, start, end, step); + } else if constexpr (std::is_same_v) { + call_fill_arange_i32(dst, n, start, end, step); + } else if constexpr (std::is_same_v) { + call_fill_arange_u32(dst, n, start, end, step); + } else if constexpr (std::is_same_v) { + call_fill_arange_i64(dst, n, start, end, step); + } else if constexpr (std::is_same_v) { + call_fill_arange_u64(dst, n, start, end, step); + } else if constexpr (std::is_same_v) { + call_fill_arange_i16(dst, n, start, end, step); + } else if constexpr (std::is_same_v) { + call_fill_arange_u16(dst, n, start, end, step); + } else if constexpr (std::is_same_v) { + call_fill_arange_i8(dst, n, start, end, step); + } else if constexpr (std::is_same_v) { + call_fill_arange_u8(dst, n, start, end, step); + } else { + // Fallback + for (size_t i = 0; i < n; ++i) { dst[i] = static_cast(start + i * step); } + } +} + +template +inline void fill_random_anytype(T* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) { + if constexpr (std::is_same_v) { + call_fill_random_fp32(dst, n, start, end, seed); + } else if constexpr (std::is_same_v) { + call_fill_random_i32(dst, n, start, end, seed); + } else if constexpr (std::is_same_v) { + call_fill_random_u32(dst, n, start, end, seed); + } else if constexpr (std::is_same_v) { + call_fill_random_i64(dst, n, start, end, seed); + } else if constexpr (std::is_same_v) { + call_fill_random_u64(dst, n, start, end, seed); + } else if constexpr (std::is_same_v) { + call_fill_random_i16(dst, n, start, end, seed); + } else if constexpr (std::is_same_v) { + call_fill_random_u16(dst, n, start, end, seed); + } else if constexpr (std::is_same_v) { + call_fill_random_i8(dst, n, start, end, seed); + } else if constexpr (std::is_same_v) { + call_fill_random_u8(dst, n, start, end, seed); + } else { + // Fallback using LCG + const uint64_t multiplier = 1103515245ULL; + const uint64_t increment = 12345ULL; + const uint64_t modulus = 1ULL << 31; + const mllm_fp32_t range = end - start; + uint64_t state = seed; + for (size_t i = 0; i < n; ++i) { + state = (multiplier * state + increment) % modulus; + const mllm_fp32_t random_value = static_cast(state) / static_cast(modulus - 1); + dst[i] = static_cast(start + random_value * range); + } + } +} + } // namespace mllm::cpu::common #endif diff --git a/mllm/backends/cpu/ops/FillOp.cpp b/mllm/backends/cpu/ops/FillOp.cpp index e4d935f51..cf5cee47e 100644 --- a/mllm/backends/cpu/ops/FillOp.cpp +++ b/mllm/backends/cpu/ops/FillOp.cpp @@ -21,7 +21,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& switch (dst.dtype()) { case kFloat32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - x86::fill_zeros(dst.ptr(), dst.numel(), threads); + common::fill_zeros_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_zeros(dst.ptr(), dst.numel(), threads); #endif @@ -29,7 +29,8 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kFloat16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + // FP16 not directly supported by Highway on x86, use scalar fallback + std::memset(dst.ptr(), 0, dst.numel() * sizeof(mllm_fp16_t)); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_zeros_fp16(dst.ptr(), dst.numel(), threads); #endif @@ -37,7 +38,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt64: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_zeros_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_zeros_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -45,7 +46,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_zeros_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_zeros_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -53,7 +54,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_zeros_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_zeros_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -61,7 +62,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt8: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_zeros_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_zeros_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -69,7 +70,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt64: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_zeros_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_zeros_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -77,7 +78,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_zeros_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_zeros_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -85,7 +86,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_zeros_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_zeros_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -93,7 +94,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt8: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_zeros_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_zeros_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -110,7 +111,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& switch (dst.dtype()) { case kFloat32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - x86::fill_ones(dst.ptr(), dst.numel(), threads); + common::fill_ones_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_ones(dst.ptr(), dst.numel(), threads); #endif @@ -118,7 +119,9 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kFloat16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + // FP16 not directly supported by Highway on x86, use scalar fallback + auto ptr = dst.ptr(); + for (size_t i = 0; i < dst.numel(); ++i) { ptr[i] = static_cast(1.0f); } #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_ones_fp16(dst.ptr(), dst.numel(), threads); #endif @@ -126,7 +129,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt64: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_ones_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_ones_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -134,7 +137,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_ones_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_ones_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -142,7 +145,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_ones_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_ones_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -150,7 +153,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt8: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_ones_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_ones_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -158,7 +161,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt64: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_ones_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_ones_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -166,7 +169,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_ones_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_ones_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -174,7 +177,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_ones_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_ones_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -182,7 +185,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt8: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_ones_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_ones_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -199,7 +202,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& switch (dst.dtype()) { case kFloat32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - x86::fill_arange(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, threads); + common::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_arange(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, threads); #endif @@ -207,7 +210,9 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kFloat16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + // FP16 not directly supported by Highway on x86, use scalar fallback + auto ptr = dst.ptr(); + for (size_t i = 0; i < dst.numel(); ++i) { ptr[i] = static_cast(options_.start + i * options_.step); } #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_arange_fp16(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, threads); #endif @@ -215,7 +220,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt64: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, threads); @@ -224,7 +229,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, threads); @@ -233,7 +238,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, threads); @@ -242,7 +247,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt8: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, threads); @@ -251,7 +256,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt64: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, threads); @@ -260,7 +265,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, threads); @@ -269,7 +274,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, threads); @@ -278,7 +283,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt8: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, threads); @@ -295,7 +300,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& switch (dst.dtype()) { case kFloat32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - x86::fill_random(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); + common::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_random(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); #endif @@ -303,7 +308,18 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kFloat16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + // FP16 not directly supported by Highway on x86, use scalar fallback + const uint64_t multiplier = 1103515245ULL; + const uint64_t increment = 12345ULL; + const uint64_t modulus = 1ULL << 31; + const mllm_fp32_t range = options_.end - options_.start; + uint64_t state = options_.seed; + auto ptr = dst.ptr(); + for (size_t i = 0; i < dst.numel(); ++i) { + state = (multiplier * state + increment) % modulus; + const mllm_fp32_t random_value = static_cast(state) / static_cast(modulus - 1); + ptr[i] = static_cast(options_.start + random_value * range); + } #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_random_fp16(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); #endif @@ -311,7 +327,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt64: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); #endif @@ -319,7 +335,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); #endif @@ -327,7 +343,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); #endif @@ -335,7 +351,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt8: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); #endif @@ -343,7 +359,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt64: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); #endif @@ -351,7 +367,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); #endif @@ -359,7 +375,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); #endif @@ -367,7 +383,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt8: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); #endif @@ -383,7 +399,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& switch (dst.dtype()) { case kFloat32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - x86::fill_specific_value(dst.ptr(), dst.numel(), options_.value, threads); + common::fill_value_anytype(dst.ptr(), dst.numel(), options_.value); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_specific_value(dst.ptr(), dst.numel(), options_.value, threads); #endif @@ -391,7 +407,9 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kFloat16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + // FP16 not directly supported by Highway on x86, use scalar fallback + auto ptr = dst.ptr(); + for (size_t i = 0; i < dst.numel(); ++i) { ptr[i] = static_cast(options_.value); } #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_specific_value_fp16(dst.ptr(), dst.numel(), options_.value, threads); #endif @@ -399,7 +417,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt64: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_value_anytype(dst.ptr(), dst.numel(), options_.value); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_specific_value_anytype(dst.ptr(), dst.numel(), options_.value, threads); #endif @@ -407,7 +425,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_value_anytype(dst.ptr(), dst.numel(), options_.value); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_specific_value_anytype(dst.ptr(), dst.numel(), options_.value, threads); #endif @@ -415,7 +433,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_value_anytype(dst.ptr(), dst.numel(), options_.value); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_specific_value_anytype(dst.ptr(), dst.numel(), options_.value, threads); #endif @@ -423,7 +441,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt8: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_value_anytype(dst.ptr(), dst.numel(), options_.value); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_specific_value_anytype(dst.ptr(), dst.numel(), options_.value, threads); #endif @@ -431,7 +449,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt64: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_value_anytype(dst.ptr(), dst.numel(), options_.value); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_specific_value_anytype(dst.ptr(), dst.numel(), options_.value, threads); #endif @@ -439,7 +457,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_value_anytype(dst.ptr(), dst.numel(), options_.value); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_specific_value_anytype(dst.ptr(), dst.numel(), options_.value, threads); #endif @@ -447,7 +465,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_value_anytype(dst.ptr(), dst.numel(), options_.value); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_specific_value_anytype(dst.ptr(), dst.numel(), options_.value, threads); #endif @@ -455,7 +473,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt8: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_value_anytype(dst.ptr(), dst.numel(), options_.value); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_specific_value_anytype(dst.ptr(), dst.numel(), options_.value, threads); #endif diff --git a/mllm/backends/qnn/aot/passes/PTQPass.cpp b/mllm/backends/qnn/aot/passes/PTQPass.cpp index 82869ab16..0d34a51b2 100644 --- a/mllm/backends/qnn/aot/passes/PTQPass.cpp +++ b/mllm/backends/qnn/aot/passes/PTQPass.cpp @@ -387,8 +387,10 @@ void recursiveCheckConcatInputs(const std::shared_ptr& ir_ctx, co if (std::abs(ref_scale_v - cur_scale_v) > 1e-6 || ref_zp_v != cur_zp_v) { MLLM_ERROR("PTQPass: ConcatOp '{}' has mismatched scale/zp between inputs. " - "Input '{}': scale={}, zp={}; Input '{}': scale={}, zp={}", - op_name, ref_input_name, ref_scale_v, ref_zp_v, tv->name(), cur_scale_v, cur_zp_v); + "Input '{}': scale={}, zp={}, scale_name={}, zp_name={}; Input '{}': scale={}, zp={}, scale_name={}, " + "zp_name={}", + op_name, ref_input_name, ref_scale_v, ref_zp_v, ref_scale.name(), ref_zero_point.name(), tv->name(), + cur_scale_v, cur_zp_v, cur_scale.name(), cur_zero_point.name()); } } } else if (f_spec->spec_->type == ir::linalg::QuantizationSpecType::kSymPerTensor) { diff --git a/mllm/ffi/Extension.cc b/mllm/ffi/Extension.cc index 22449f883..cb999191d 100644 --- a/mllm/ffi/Extension.cc +++ b/mllm/ffi/Extension.cc @@ -53,9 +53,25 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("mllm.cpu_", []() -> mllm::ffi::Device { return mllm::ffi::Device(::mllm::DeviceTypes::kCPU); }); refl::GlobalDef().def("mllm.cuda_", []() -> mllm::ffi::Device { return mllm::ffi::Device(::mllm::DeviceTypes::kCUDA); }); refl::GlobalDef().def("mllm.qnn_", []() -> mllm::ffi::Device { return mllm::ffi::Device(::mllm::DeviceTypes::kQNN); }); + // Floating point types refl::GlobalDef().def("mllm.float32_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kFloat32); }); refl::GlobalDef().def("mllm.float16_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kFloat16); }); refl::GlobalDef().def("mllm.bfloat16_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kBFloat16); }); + + // Signed integer types + refl::GlobalDef().def("mllm.int8_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kInt8); }); + refl::GlobalDef().def("mllm.int16_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kInt16); }); + refl::GlobalDef().def("mllm.int32_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kInt32); }); + refl::GlobalDef().def("mllm.int64_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kInt64); }); + + // Unsigned integer types + refl::GlobalDef().def("mllm.uint8_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kUInt8); }); + refl::GlobalDef().def("mllm.uint16_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kUInt16); }); + refl::GlobalDef().def("mllm.uint32_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kUInt32); }); + refl::GlobalDef().def("mllm.uint64_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kUInt64); }); + + // Bool type + refl::GlobalDef().def("mllm.bool_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kUInt8); }); } //===----------------------------------------------------------------------===// diff --git a/pymllm/__init__.py b/pymllm/__init__.py index 66240b714..1bd31cd6c 100644 --- a/pymllm/__init__.py +++ b/pymllm/__init__.py @@ -12,12 +12,27 @@ from . import service from . import backends from .ffi import ( + # Floating point types float32, float16, bfloat16, + # Signed integer types + int8, + int16, + int32, + int64, + # Unsigned integer types + uint8, + uint16, + uint32, + uint64, + # Bool type + boolean, + # Devices cpu, cuda, qnn, + # Tensor and utilities Tensor, empty, echo, @@ -26,7 +41,6 @@ is_numpy_available, from_torch, from_numpy, - empty, zeros, ones, arange, diff --git a/pymllm/ffi/__init__.py b/pymllm/ffi/__init__.py index 17bd04c19..9780eabb0 100644 --- a/pymllm/ffi/__init__.py +++ b/pymllm/ffi/__init__.py @@ -48,6 +48,10 @@ def to_pod(self) -> int: return tvm_ffi.get_global_func("mllm.DType.to_pod")(self) +# ============================================================================= +# DType factory functions +# ============================================================================= +# Floating point types def float32_() -> DType: return _ffi_api.float32_() @@ -60,6 +64,45 @@ def bfloat16_() -> DType: return _ffi_api.bfloat16_() +# Signed integer types +def int8_() -> DType: + return _ffi_api.int8_() + + +def int16_() -> DType: + return _ffi_api.int16_() + + +def int32_() -> DType: + return _ffi_api.int32_() + + +def int64_() -> DType: + return _ffi_api.int64_() + + +# Unsigned integer types +def uint8_() -> DType: + return _ffi_api.uint8_() + + +def uint16_() -> DType: + return _ffi_api.uint16_() + + +def uint32_() -> DType: + return _ffi_api.uint32_() + + +def uint64_() -> DType: + return _ffi_api.uint64_() + + +# Bool type (backed by uint8) +def bool_() -> DType: + return _ffi_api.bool_() + + def cpu_() -> Device: return _ffi_api.cpu_() @@ -219,10 +262,32 @@ def is_contiguous(self): return tvm_ffi.get_global_func("mllm.Tensor.is_contiguous")(self) -# Global dtypes +# ============================================================================= +# Global dtype instances +# ============================================================================= +# Floating point types float32: DType = float32_() float16: DType = float16_() bfloat16: DType = bfloat16_() + +# Signed integer types +int8: DType = int8_() +int16: DType = int16_() +int32: DType = int32_() +int64: DType = int64_() + +# Unsigned integer types +uint8: DType = uint8_() +uint16: DType = uint16_() +uint32: DType = uint32_() +uint64: DType = uint64_() + +# Bool type (use 'boolean' to avoid shadowing Python's built-in 'bool') +boolean: DType = bool_() + +# ============================================================================= +# Global device instances +# ============================================================================= cpu: Device = cpu_() cuda: Device = cuda_() qnn: Device = qnn_() From e976d11e4dbbc6baf7d1717e69aeab2dda7ffbf6 Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Wed, 21 Jan 2026 13:25:41 +0000 Subject: [PATCH 08/17] feat(qnn): Enhance QNNBackend initialization with improved logging and error handling; update default log level to verbose. Add QEmbedding class for quantized embedding operations in PyTorch. Introduce build tasks for Android and x86 QNN AOT SDKs. --- .gitignore | 1 + mllm/CMakeLists.txt | 9 +- mllm/backends/qnn/QNNBackend.cpp | 89 +++++++++--- mllm/backends/qnn/QNNBackend.hpp | 2 +- mllm/backends/qnn/Register.cpp | 15 +- mllm/backends/qnn/aot/QnnWrappersAPI.cpp | 7 + .../qualcomm/transformers/core/embedding.py | 133 ++++++++++++++++++ tasks/build_sdk_android_qnn_aot.yaml | 22 +++ tasks/build_sdk_x86_qnn_aot.yaml | 2 +- 9 files changed, 255 insertions(+), 25 deletions(-) create mode 100644 pymllm/backends/qualcomm/transformers/core/embedding.py create mode 100644 tasks/build_sdk_android_qnn_aot.yaml diff --git a/.gitignore b/.gitignore index 22e2a9a6f..7397d6ecc 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,7 @@ tasks/mllmteam* build*/ install*/ mllm-sdk-*/ +mllm-install-*/ # Pymllm related stubs/ diff --git a/mllm/CMakeLists.txt b/mllm/CMakeLists.txt index fd796f95a..06fa5aab2 100644 --- a/mllm/CMakeLists.txt +++ b/mllm/CMakeLists.txt @@ -58,7 +58,14 @@ endif() # FIXME: @oreomaker Need to remove comma features in slice! # Suppress comma-subscript warnings (deprecated C++ feature that will be removed in C++26) -target_compile_options(MllmRT PUBLIC -Wno-comma-subscript) +# This flag is only available in Clang 13+ and GCC 10+ +if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang") + target_compile_options(MllmRT PUBLIC -Wno-comma-subscript) +elseif(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL "10.0") + target_compile_options(MllmRT PUBLIC -Wno-comma-subscript) + endif() +endif() # ONLY APPLE CAN DO ! # Processing OpenMP diff --git a/mllm/backends/qnn/QNNBackend.cpp b/mllm/backends/qnn/QNNBackend.cpp index 54da97c9d..05ebedfcb 100644 --- a/mllm/backends/qnn/QNNBackend.cpp +++ b/mllm/backends/qnn/QNNBackend.cpp @@ -29,15 +29,28 @@ QNNBackend::QNNBackend() : Backend(kQNN, createQNNAllocator()) { QNNViewOpFactory, QNNRMSNormOpFactory, QNNTransposeOpFactory, QNNX2XOpFactory, QNNCastTypeOpFactory, QNNParamOpFactory, QNNSiLUOpFactory, QNNEmbeddingOpFactory>(); - QnnLog_Level_t qnnLogLevel = QNN_LOG_LEVEL_ERROR; // default QNN log level + QnnLog_Level_t qnnLogLevel = QNN_LOG_LEVEL_VERBOSE; // default QNN log level profilingLevel_ = ProfilingLevel::OFF; debug_ = false; // when set true, NATIVE tensor will be regared as APP_READ tensor - loadQNNSymbol(); - loadQNNSystemSymbol(); + if (!loadQNNSymbol()) { + MLLM_ERROR_EXIT(ExitCode::kQnnError, "Failed to load QNN symbols"); + } else { + MLLM_INFO("QNN symbols loaded successfully"); + } + + if (!loadQNNSystemSymbol()) { + MLLM_ERROR_EXIT(ExitCode::kQnnError, "Failed to load QNN System symbols"); + } else { + MLLM_INFO("QNN System symbols loaded successfully"); + } runtime_ = QNNRuntime::create(profilingLevel_, qnnLogLevel); - if (!runtime_) { MLLM_ERROR_EXIT(1, "Failed to create QNN Runtime"); } + if (!runtime_) { + MLLM_ERROR_EXIT(ExitCode::kQnnError, "Failed to create QNN Runtime"); + } else { + MLLM_INFO("QNN Runtime created successfully"); + } // check QNN capability, detect QNN features for future use char* backendBuildId{nullptr}; @@ -59,6 +72,7 @@ QNNBackend::QNNBackend() : Backend(kQNN, createQNNAllocator()) { perf_ = QNNPerf::create(&runtime_->qnnInterface); perf_->setPowerConfigBurst(); perf_->setRpcLatencyAndPolling(); + MLLM_INFO("QNN Perf created successfully"); } QNNPerf::QNNPerf(const QNN_INTERFACE_VER_TYPE* qnnInterface) { @@ -204,11 +218,13 @@ QNNRuntime* QNNRuntime::initRuntime(ProfilingLevel profilingLevel, QnnLog_Level_ // Create Log Qnn_LogHandle_t logHandle = nullptr; { - QnnLog_Callback_t logCallback = &__mllmQnnLoggerCallback; + QnnLog_Callback_t logCallback = __mllmQnnLoggerCallback; if ((QNN_GET_ERROR_CODE(qnnInterface.logCreate(logCallback, qnnLogLevel, &logHandle)) != QNN_SUCCESS) || (logHandle == nullptr)) { MLLM_ERROR("Failed to initialize logging in the backend."); return nullptr; + } else { + MLLM_INFO("Logging initialized successfully"); } } @@ -220,6 +236,8 @@ QNNRuntime* QNNRuntime::initRuntime(ProfilingLevel profilingLevel, QnnLog_Level_ || (backendHandle == nullptr)) { MLLM_ERROR("Failed to create the backend."); return nullptr; + } else { + MLLM_INFO("Backend created successfully"); } } @@ -227,16 +245,13 @@ QNNRuntime* QNNRuntime::initRuntime(ProfilingLevel profilingLevel, QnnLog_Level_ Qnn_DeviceHandle_t deviceHandle = nullptr; { // Check whether the device API is supported. - if (nullptr != qnnInterface.propertyHasCapability) { - auto qnnStatus = qnnInterface.propertyHasCapability(QNN_PROPERTY_GROUP_DEVICE); - if (QNN_PROPERTY_NOT_SUPPORTED == qnnStatus) { - MLLM_WARN("Device property is not supported"); - return nullptr; - } - if (QNN_PROPERTY_ERROR_UNKNOWN_KEY == qnnStatus) { - MLLM_ERROR("Device property is not known to backend"); + if (nullptr != qnnInterface.deviceCreate) { + auto status = qnnInterface.deviceCreate(logHandle, nullptr, &deviceHandle); + if (QNN_SUCCESS != status) { + MLLM_ERROR("Failed to create device, error: {}", (int)status); return nullptr; } + MLLM_INFO("Device created successfully"); } } @@ -269,9 +284,7 @@ QNNRuntime* QNNRuntime::initRuntime(ProfilingLevel profilingLevel, QnnLog_Level_ std::string target; }; - std::vector opPackages = { - {.path = "libQnnLLaMAPackage_CPU.so", .interfaceProvider = "LLaMAPackageInterfaceProvider", .target = "CPU"}, - {.path = "libQnnLLaMAPackage_HTP.so", .interfaceProvider = "LLaMAPackageInterfaceProvider", .target = "HTP"}}; + std::vector opPackages = {}; for (const auto& pkg : opPackages) { if (!qnnInterface.backendRegisterOpPackage) { @@ -298,6 +311,8 @@ QNNRuntime* QNNRuntime::initRuntime(ProfilingLevel profilingLevel, QnnLog_Level_ != QnnSystemInterface_getProviders((const QnnSystemInterface_t***)&systemInterfaceProviders, &numProviders)) { MLLM_ERROR("Failed to get system interface providers."); return nullptr; + } else { + MLLM_INFO("System interface providers found: {}", numProviders); } if (0 == numProviders) { MLLM_ERROR("Failed to get interface providers: 0 interface providers."); @@ -305,11 +320,17 @@ QNNRuntime* QNNRuntime::initRuntime(ProfilingLevel profilingLevel, QnnLog_Level_ } bool foundValidSystemInterface = false; for (size_t pIdx = 0; pIdx < numProviders; pIdx++) { - foundValidSystemInterface = true; if (QNN_SYSTEM_API_VERSION_MAJOR == systemInterfaceProviders[pIdx]->systemApiVersion.major && QNN_SYSTEM_API_VERSION_MINOR <= systemInterfaceProviders[pIdx]->systemApiVersion.minor) { qnnSystemInterface = systemInterfaceProviders[pIdx]->QNN_SYSTEM_INTERFACE_VER_NAME; + foundValidSystemInterface = true; break; + } else { + // Print system interface provider and self version + MLLM_WARN("System interface provider: {} version: {}", systemInterfaceProviders[pIdx]->systemApiVersion.major, + systemInterfaceProviders[pIdx]->systemApiVersion.minor); + MLLM_WARN("Self version: {} {}", QNN_SYSTEM_API_VERSION_MAJOR, QNN_SYSTEM_API_VERSION_MINOR); + MLLM_WARN("Unable to find a valid system interface."); } } if (!foundValidSystemInterface) { @@ -334,7 +355,14 @@ bool QNNRuntime::retrieveContext(const std::string& contextBinaryPath, Qnn_Conte std::vector>& qnnModels, QnnContext_Config_t** contextConfig) { // Read the binary from qnn_context.bin and get the size in byte std::ifstream file(contextBinaryPath, std::ios::binary | std::ios::ate); + if (!file.is_open() || !file.good()) { + MLLM_ERROR("Could not open context binary file: {}", contextBinaryPath); + return false; + } else { + MLLM_INFO("Context binary file opened successfully: {}", contextBinaryPath); + } std::streamsize size = file.tellg(); + MLLM_INFO("Context binary file size: {} MB", size / 1024 / 1024); file.seekg(0, std::ios::beg); auto binaryBuffer = std::make_unique(size); @@ -344,17 +372,27 @@ bool QNNRuntime::retrieveContext(const std::string& contextBinaryPath, Qnn_Conte // inspect binary info QnnSystemContext_Handle_t sysCtxHandle{nullptr}; + if (!qnnSystemInterface.systemContextCreate) { + MLLM_ERROR("systemContextCreate is nullptr."); + return false; + } if (QNN_SUCCESS != qnnSystemInterface.systemContextCreate(&sysCtxHandle)) { MLLM_ERROR("Could not create system handle."); return false; + } else { + MLLM_INFO("System context created successfully"); } + const QnnSystemContext_BinaryInfo_t* binaryInfo{nullptr}; Qnn_ContextBinarySize_t binaryInfoSize{0}; + if (QNN_SUCCESS != qnnSystemInterface.systemContextGetBinaryInfo(sysCtxHandle, static_cast(binaryBuffer.get()), size, &binaryInfo, &binaryInfoSize)) { MLLM_ERROR("Failed to get context binary info"); return false; + } else { + MLLM_INFO("Context binary info retrieved successfully"); } // Extract graph metadata to create QNNModels instead of GraphInfo_t @@ -365,13 +403,24 @@ bool QNNRuntime::retrieveContext(const std::string& contextBinaryPath, Qnn_Conte MLLM_ERROR("Failed to copy metadata."); return false; } - qnnSystemInterface.systemContextFree(sysCtxHandle); + if (QNN_SUCCESS != qnnSystemInterface.systemContextFree(sysCtxHandle)) { + MLLM_ERROR("Could not free system context."); + return false; + } else { + MLLM_INFO("System context freed successfully"); + } sysCtxHandle = nullptr; // Create context from binary Qnn_ContextBinarySize_t writtenSize = 0; - qnnInterface.contextCreateFromBinary(backendHandle, deviceHandle, (const QnnContext_Config_t**)contextConfig, - binaryBuffer.get(), size, &context, profileHandle); + if (QNN_CONTEXT_NO_ERROR + != qnnInterface.contextCreateFromBinary(backendHandle, deviceHandle, (const QnnContext_Config_t**)contextConfig, + binaryBuffer.get(), size, &context, profileHandle)) { + MLLM_ERROR("Could not create context from binary. Mostly due to binary's qnn version mismatch with backend's qnn version."); + return false; + } else { + MLLM_INFO("Context created from binary successfully"); + } // Create QNNModels for each graph and initialize from context qnnModels.clear(); diff --git a/mllm/backends/qnn/QNNBackend.hpp b/mllm/backends/qnn/QNNBackend.hpp index 49669c7c1..78953f32d 100644 --- a/mllm/backends/qnn/QNNBackend.hpp +++ b/mllm/backends/qnn/QNNBackend.hpp @@ -45,7 +45,7 @@ class QNNRuntime { ~QNNRuntime(); static std::unique_ptr create(ProfilingLevel profilingLevel = ProfilingLevel::OFF, - QnnLog_Level_t qnnLogLevel = QNN_LOG_LEVEL_WARN) { + QnnLog_Level_t qnnLogLevel = QNN_LOG_LEVEL_VERBOSE) { return std::unique_ptr(initRuntime(profilingLevel, qnnLogLevel)); } diff --git a/mllm/backends/qnn/Register.cpp b/mllm/backends/qnn/Register.cpp index 158294f35..88185921e 100644 --- a/mllm/backends/qnn/Register.cpp +++ b/mllm/backends/qnn/Register.cpp @@ -21,9 +21,18 @@ void initQnnBackend(const std::string& context_path) { // 1. Register backend auto backend = std::make_shared(); if (std::filesystem::exists(context_path)) { - if (!backend->loadContext(context_path)) { MLLM_ERROR_EXIT(1, "Failed to load QNN context from {}", context_path); } + MLLM_INFO("QNN context path exists: {}", context_path); + if (!backend->loadContext(context_path)) { + MLLM_ERROR_EXIT(1, "Failed to load QNN context from {}", context_path); + } else { + MLLM_INFO("QNN context loaded successfully from {}", context_path); + } } else { - if (!backend->createContext()) { MLLM_ERROR_EXIT(1, "Failed to create QNN context"); } + if (!backend->createContext()) { + MLLM_ERROR_EXIT(1, "Failed to create QNN context"); + } else { + MLLM_INFO("QNN context created successfully"); + } } ctx.registerBackend(backend); @@ -33,6 +42,8 @@ void initQnnBackend(const std::string& context_path) { .really_large_tensor_threshold = 0, .using_buddy_mem_pool = false, }); + MLLM_INFO("QNN memory manager registered"); + // 3. Initialize dispatcher manager ctx.dispatcherManager()->registerDispatcher( createQNNDispatcher(ctx.dispatcherManager()->getExecutor(), qnn::QNNDispatcherOptions())); diff --git a/mllm/backends/qnn/aot/QnnWrappersAPI.cpp b/mllm/backends/qnn/aot/QnnWrappersAPI.cpp index b2b04fd78..23496591f 100644 --- a/mllm/backends/qnn/aot/QnnWrappersAPI.cpp +++ b/mllm/backends/qnn/aot/QnnWrappersAPI.cpp @@ -107,6 +107,7 @@ std::string QnnAOTNodeTensor::parseQnnTensorNameFromIR(const ir::tensor::TensorV Qnn_QuantizeParams_t QnnAOTNodeTensor::parseQnnQuantizeParamFromIR(const ir::tensor::TensorValue::ptr_t& v) { Qnn_QuantizeParams_t ret = QNN_QUANTIZE_PARAMS_INIT; + MLLM_RT_ASSERT(v); MLLM_RT_ASSERT(v->getAttr("quant_recipe")); auto quant_spec = v->getAttr("quant_recipe")->cast_()->spec_; @@ -120,6 +121,9 @@ Qnn_QuantizeParams_t QnnAOTNodeTensor::parseQnnQuantizeParamFromIR(const ir::ten auto cfg = std::static_pointer_cast(quant_spec); ret.encodingDefinition = QNN_DEFINITION_DEFINED; ret.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET; + if (!cfg->scale || !cfg->zero_point) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "AsymPerTensor quant recipe has no scale or zero point. tensor: {}", v->name()); + } ret.scaleOffsetEncoding = Qnn_ScaleOffset_t{.scale = cfg->scale.item(), .offset = cfg->zero_point.item()}; break; } @@ -127,6 +131,9 @@ Qnn_QuantizeParams_t QnnAOTNodeTensor::parseQnnQuantizeParamFromIR(const ir::ten auto cfg = std::static_pointer_cast(quant_spec); ret.encodingDefinition = QNN_DEFINITION_DEFINED; ret.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET; + if (!cfg->scale) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "SymPerTensor quant recipe has no scale. tensor: {}", v->name()); + } ret.scaleOffsetEncoding = Qnn_ScaleOffset_t{.scale = cfg->scale.item(), .offset = 0}; break; } diff --git a/pymllm/backends/qualcomm/transformers/core/embedding.py b/pymllm/backends/qualcomm/transformers/core/embedding.py new file mode 100644 index 000000000..84c4d61fe --- /dev/null +++ b/pymllm/backends/qualcomm/transformers/core/embedding.py @@ -0,0 +1,133 @@ +import torch +import torch.nn as nn +from torch.ao.quantization import FakeQuantize, MinMaxObserver + + +class QEmbedding(nn.Module): + def __init__( + self, + num_embeddings, + embedding_dim, + padding_idx=None, + quant_bits=16, + ): + super().__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.padding_idx = padding_idx + self.quant_bits = quant_bits + + self.weight = nn.Parameter(torch.empty(num_embeddings, embedding_dim)) + nn.init.normal_(self.weight) + + if padding_idx is not None: + with torch.no_grad(): + self.weight[padding_idx].fill_(0) + + # Quantization configuration for Weight + self.weight_fake_quant = FakeQuantize( + observer=MinMaxObserver.with_args( + qscheme=torch.per_tensor_affine, + dtype=torch.qint32, + eps=0.0001 / 65535, + ), + quant_min=0, + quant_max=2 ** (quant_bits) - 1, + dtype=torch.qint32, + qscheme=torch.per_tensor_affine, + ) + + def forward(self, x): + # 1. Weight fake quantization + # If observer is not closed, this step will continuously update scale/zp + # If freeze_weight() is called, this will just use fixed scale/zp for quantization + w_q = self.weight_fake_quant(self.weight) + + # 2. Embedding lookup (Gather operation) + return nn.functional.embedding( + x, + w_q, + padding_idx=self.padding_idx, + ) + + @torch.no_grad() + def convert_to_deploy(self): + """ + In-place replacement of self.weight: + Float Parameter -> Int Buffer + """ + # 1. Ensure quantization parameters are ready + if self.weight_fake_quant.scale is None: + self.freeze_weight() + + scale = self.weight_fake_quant.scale + zero_point = self.weight_fake_quant.zero_point + quant_min = self.weight_fake_quant.quant_min + quant_max = self.weight_fake_quant.quant_max + + # 2. Calculate integer values + # w_int = round(w / s + zp) + w_int = torch.round(self.weight / scale + zero_point).clamp( + quant_min, quant_max + ) + + # 3. Set target integer type + if self.quant_bits <= 8: + target_dtype = torch.uint8 + elif self.quant_bits <= 16: + target_dtype = torch.uint16 + else: + target_dtype = torch.uint32 + + w_int = w_int.to(target_dtype) + + # === Key steps: Replacement operations === + + # A. Delete original Parameter 'weight' + # Must delete first, otherwise cannot register buffer with same name + del self.weight + + # B. Register Buffer with same name 'weight' + # This makes state_dict['weight'] become Int Tensor + self.register_buffer("weight", w_int) + + # C. Register Scale (usually needed by engine) + self.register_buffer("scale", scale) + self.register_buffer("zero_point", zero_point) + + # D. Clean up unnecessary modules + if hasattr(self, "weight_fake_quant"): + del self.weight_fake_quant + + class_name = self.__class__.__name__ + instance_class_name = type(self).__name__ + print( + f"Class: {class_name}, Instance: {instance_class_name}, Deploy Mode Activated. 'weight' is now {self.weight.dtype} buffer. zp is {zero_point}" + ) + + @torch.no_grad() + def freeze_weight(self): + """ + Manually trigger Observer to observe and calculate scale, then lock it. + Solve the problem of output being 0 on first run. + """ + self.weight_fake_quant.activation_post_process(self.weight) + s, zp = self.weight_fake_quant.activation_post_process.calculate_qparams() + self.weight_fake_quant.scale.copy_(s) + self.weight_fake_quant.zero_point.copy_(zp) + self.weight_fake_quant.disable_observer() + class_name = self.__class__.__name__ + instance_class_name = type(self).__name__ + print( + f"Class: {class_name}, Instance: {instance_class_name}, Weight Quantized: scale={self.weight_fake_quant.scale}, zp={self.weight_fake_quant.zero_point}" + ) + + def disable_quant(self): + """Completely turn off quantization noise and return to floating point mode""" + self.weight_fake_quant.disable_fakequant() + + def extra_repr(self): + s = f"{self.num_embeddings}, {self.embedding_dim}" + if self.padding_idx is not None: + s += f", padding_idx={self.padding_idx}" + return s diff --git a/tasks/build_sdk_android_qnn_aot.yaml b/tasks/build_sdk_android_qnn_aot.yaml new file mode 100644 index 000000000..f0e983b75 --- /dev/null +++ b/tasks/build_sdk_android_qnn_aot.yaml @@ -0,0 +1,22 @@ +Tasks: + - CMakeConfigTask: + cmake_cfg_path: "build-android-arm64-v8a-qnn" + cmake_build_type: "ReleaseWithDebInfo" + cmake_toolchain_file: "$ANDROID_NDK_PATH/build/cmake/android.toolchain.cmake" + cmake_extra_args: + - "-DMLLM_CROSS_COMPILE=ON" + - "-DMLLM_BUILD_ARM_BACKEND=ON" + - "-DMLLM_BUILD_QNN_BACKEND=ON" + - "-DANDROID_PLATFORM=android-28" + - "-DANDROID_ABI=arm64-v8a" + - '-DMLLM_CPU_BACKEND_COMPILE_OPTIONS="-march=armv8.2-a+fp16+fp16fml+dotprod+i8mm;-ffast-math;-Wno-nan-infinity-disabled"' + - "-DCMAKE_INSTALL_PREFIX=mllm-install-android-arm64-v8a-qnn" + - "-DMLLM_KERNEL_USE_THREADS=ON" + - "-DMLLM_KERNEL_THREADS_VENDOR_OPENMP=ON" + - "-DMLLM_KERNEL_USE_THREADS_VENDOR_MLLM=OFF" + + - CMakeBuildTask: + cmake_cfg_path: "build-android-arm64-v8a-qnn" + + - CMakeInstallTask: + cmake_cfg_path: "build-android-arm64-v8a-qnn" diff --git a/tasks/build_sdk_x86_qnn_aot.yaml b/tasks/build_sdk_x86_qnn_aot.yaml index f33281616..fd9131d2e 100644 --- a/tasks/build_sdk_x86_qnn_aot.yaml +++ b/tasks/build_sdk_x86_qnn_aot.yaml @@ -1,7 +1,7 @@ Tasks: - CMakeConfigTask: cmake_cfg_path: "build-qnn-aot" - cmake_build_type: "Release" + cmake_build_type: "ReleaseWithDebInfo" cmake_extra_args: # Optional, If use Highway - "-DHWY_ENABLE_TESTS=OFF" From 224d68e0dcacb337239b4cc769594d4b66178a1f Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Fri, 23 Jan 2026 02:24:07 +0000 Subject: [PATCH 09/17] feat(qnn): Update quantization handling and embedding output data types; ensure position-independent code for flatbuffers. Enhance context creation with existing context checks and improve weight quantization specifications. --- CMakeLists.txt | 1 + mllm/backends/qnn/aot/QnnWrappersAPI.cpp | 6 +++++ .../qnn/aot/passes/LLMQuantRecipePass.cpp | 26 ++++++++++--------- mllm/backends/qnn/aot/passes/PTQPass.cpp | 23 +++++++++++++++- mllm/backends/qnn/aot_rt/PromptProcessor.cpp | 1 - mllm/core/aops/EmbeddingOp.cpp | 6 +++-- 6 files changed, 47 insertions(+), 16 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index fca470ee5..a19e80df3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -262,6 +262,7 @@ add_subdirectory(third_party/fmt) add_subdirectory(third_party/xxHash) set(FLATBUFFERS_BUILD_TESTS OFF) add_subdirectory(third_party/flatbuffers EXCLUDE_FROM_ALL) +set_target_properties(flatbuffers PROPERTIES POSITION_INDEPENDENT_CODE ON) add_subdirectory(mllm) if(MLLM_ENABLE_TEST) diff --git a/mllm/backends/qnn/aot/QnnWrappersAPI.cpp b/mllm/backends/qnn/aot/QnnWrappersAPI.cpp index 23496591f..a79047e78 100644 --- a/mllm/backends/qnn/aot/QnnWrappersAPI.cpp +++ b/mllm/backends/qnn/aot/QnnWrappersAPI.cpp @@ -436,6 +436,12 @@ void QnnAOTEnv::_setup(const std::string& path) { } std::shared_ptr QnnAOTEnv::createContext(const std::string& name, bool weights_sharing) { + // Check if context with this name already exists + if (contexts_.count(name) > 0) { + MLLM_WARN("Context '{}' already exists, reusing the existing context", name); + return contexts_[name]; + } + std::shared_ptr context = std::make_shared(); context->name_ = name; diff --git a/mllm/backends/qnn/aot/passes/LLMQuantRecipePass.cpp b/mllm/backends/qnn/aot/passes/LLMQuantRecipePass.cpp index 957fdf321..18bbb505c 100644 --- a/mllm/backends/qnn/aot/passes/LLMQuantRecipePass.cpp +++ b/mllm/backends/qnn/aot/passes/LLMQuantRecipePass.cpp @@ -986,6 +986,7 @@ bool LLMQuantRecipeEmbeddingPattern::rewrite(ir::IRWriter& writer, const ir::op_ auto annotation_attr = writer.create(); + // i_0 logic stays the same if (!i_0->getAttr("quant_recipe")) { auto i_0_spec = genSimpleQuantizationSpecAttr(writer.getContext(), i_0->cast_()); i_0->setAttr("quant_recipe", i_0_spec); @@ -996,16 +997,7 @@ bool LLMQuantRecipeEmbeddingPattern::rewrite(ir::IRWriter& writer, const ir::op_ i_0->getAttr("quant_recipe")->cast_()->spec_); } - if (!o_0->getAttr("quant_recipe")) { - auto o_0_spec = genSimpleQuantizationSpecAttr(writer.getContext(), o_0->cast_()); - o_0->setAttr("quant_recipe", o_0_spec); - annotation_attr->annotation_.outputs.emplace_back(o_0_spec->spec_); - } else { - annotation_attr->annotation_.outputs.emplace_back( - o_0->getAttr("quant_recipe")->cast_()->spec_); - } - - // Weights + // Weights - must be uint16, force set to kUInt16PerTensorAsy auto weight_name = embedding_op->getAOp()->getName() + ".weight"; auto weight_reg_tensor_ir = writer.getContext()->lookupSymbolTable(weight_name); MLLM_RETURN_FALSE_IF_NOT(weight_reg_tensor_ir); @@ -1013,11 +1005,21 @@ bool LLMQuantRecipeEmbeddingPattern::rewrite(ir::IRWriter& writer, const ir::op_ MLLM_RETURN_FALSE_IF_NOT(weight_reg_tensor_ir->outputs().front()->isa_()); auto weight_tensor = weight_reg_tensor_ir->outputs().front()->cast_(); - // Embedding weight quantization method same as outputs, but not share, just same type - auto weight_spec_attr = genSimpleQuantizationSpecAttr(writer.getContext(), weight_tensor); + // Embedding weight dtype must be uint16, force set to kUInt16PerTensorAsy + MLLM_RETURN_FALSE_IF_NOT(weight_tensor->tensor_.dtype() == kUInt16 || weight_tensor->tensor_.dtype() == kUInt16PerTensorAsy); + weight_tensor->tensor_ = weight_tensor->tensor_.__unsafeSetDType(kUInt16PerTensorAsy); + + // Create weight spec with kUInt16PerTensorAsy (AsymPerTensor) + auto weight_spec = + ir::linalg::QuantizationSpecAsymPerTensor::create(0, 65535, kUInt16, kFloat32, kInt32, Tensor::nil(), Tensor::nil()); + auto weight_spec_attr = writer.getContext()->create(weight_spec); weight_reg_tensor_ir->outputs().front()->setAttr("quant_recipe", weight_spec_attr); annotation_attr->annotation_.weights.insert({"weight", weight_spec_attr->spec_}); + // o_0's quant recipe shares with weight + o_0->setAttr("quant_recipe", weight_spec_attr); + annotation_attr->annotation_.outputs.emplace_back(weight_spec_attr->spec_); + // Attach to quantize node node->setAttr("quant_recipe", annotation_attr); diff --git a/mllm/backends/qnn/aot/passes/PTQPass.cpp b/mllm/backends/qnn/aot/passes/PTQPass.cpp index 0d34a51b2..d9f1d97cb 100644 --- a/mllm/backends/qnn/aot/passes/PTQPass.cpp +++ b/mllm/backends/qnn/aot/passes/PTQPass.cpp @@ -111,6 +111,22 @@ void solveEmbeddingWeight(const ir::IRContext::ptr_t& ctx, const ParameterFile:: weight_spec->solved = true; break; } + case ir::linalg::QuantizationSpecType::kAsymPerTensor: { + auto this_spec = std::static_pointer_cast(weight_spec); + auto scale = pf->pull(mllm_op->getName() + ".scale"); + auto zero_point = pf->pull(mllm_op->getName() + ".zero_point"); + this_spec->scale = scale; + this_spec->zero_point = zero_point; + checkTypeLimits(pf->pull(mllm_op->getName() + ".weight"), this_spec->quant_min, this_spec->quant_max); + MLLM_RT_ASSERT(scale.dtype() == kFloat32); + MLLM_RT_ASSERT(scale.rank() == 1); + MLLM_RT_ASSERT(scale.item() > 0); + MLLM_RT_ASSERT(zero_point.dtype() == kInt32); + MLLM_RT_ASSERT(zero_point.rank() == 1); + MLLM_RT_ASSERT(zero_point.item() >= 0); + weight_spec->solved = true; + break; + } default: { NYI("quant recipe type not support"); } @@ -203,6 +219,9 @@ void _recursiveSolveNormalImpl(const ir::IRContext::ptr_t& ctx, const ir::Val::p auto _attr = ctx->create(std::vector{(uint16_t)ptq_constant_v}); tv->removeAttr("constant"); tv->setAttr("constant", _attr); + + MLLM_INFO("Constant tensor '{}' quantized (AsymPerTensor): before={}, after={}", tv->name(), constant_v, + ptq_constant_v); } this_spec->solved = true; @@ -262,6 +281,8 @@ void _recursiveSolveNormalImpl(const ir::IRContext::ptr_t& ctx, const ir::Val::p auto _attr = ctx->create(std::vector{(uint16_t)ptq_constant_v}); tv->removeAttr("constant"); tv->setAttr("constant", _attr); + + MLLM_INFO("Constant tensor '{}' quantized (SymPerTensor): before={}, after={}", tv->name(), constant_v, ptq_constant_v); } this_spec->solved = true; @@ -273,7 +294,7 @@ void _recursiveSolveNormalImpl(const ir::IRContext::ptr_t& ctx, const ir::Val::p break; } default: { - NYI("quant recipe type not support on tensor: {}", v->name()); + NYI("Quant recipe type not support on tensor: {}", v->name()); } } } diff --git a/mllm/backends/qnn/aot_rt/PromptProcessor.cpp b/mllm/backends/qnn/aot_rt/PromptProcessor.cpp index 50396955d..99cd22db9 100644 --- a/mllm/backends/qnn/aot_rt/PromptProcessor.cpp +++ b/mllm/backends/qnn/aot_rt/PromptProcessor.cpp @@ -1,4 +1,3 @@ - // Copyright (c) MLLM Team. // Licensed under the MIT License. diff --git a/mllm/core/aops/EmbeddingOp.cpp b/mllm/core/aops/EmbeddingOp.cpp index b67eeff7a..a5a6400dd 100644 --- a/mllm/core/aops/EmbeddingOp.cpp +++ b/mllm/core/aops/EmbeddingOp.cpp @@ -70,8 +70,10 @@ void EmbeddingOp::reshape(const std::vector& inputs, std::vector std::vector o_shape{/*batch*/ shape[0], /*seq*/ shape[1], /*feat dim*/ options_.hidden_size}; - // FIXME: We should tell embedding output to use what kinds of data types. Currently it's hardcoded to float32. - outputs.emplace_back(Tensor::empty(o_shape, kFloat32, i.device())); + // Output dtype should match weight dtype (e.g., uint16 for AsymPerTensor quantization) + auto out_dtype = weight_.dtype(); + if (weight_.dtype() == kUInt16) { out_dtype = kUInt16PerTensorAsy; } + outputs.emplace_back(Tensor::empty(o_shape, out_dtype, i.device())); } void EmbeddingOp::setup(const std::vector& inputs, std::vector& outputs) { BaseOp::setup(inputs, outputs); } From d2d5c09ce56c74f38ffa3a38e4273e29e98af1eb Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Fri, 23 Jan 2026 02:42:09 +0000 Subject: [PATCH 10/17] feat(qwen3): Integrate QEmbedding for quantized embeddings and refine input layer normalization handling in Qwen3DecoderLayer. Update weight conversion logic in training script to address model compatibility issues. --- .../qualcomm/transformers/qwen3/modeling_qwen3.py | 12 +++++++----- .../backends/qualcomm/transformers/qwen3/runner.py | 3 +++ pymllm/backends/qualcomm/transformers/qwen3/train.py | 10 +++++++--- 3 files changed, 17 insertions(+), 8 deletions(-) diff --git a/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py b/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py index 92efaa06d..cf71a48ba 100644 --- a/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py +++ b/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py @@ -54,6 +54,7 @@ ActivationQDQ, FixedActivationQDQ, ) +from pymllm.backends.qualcomm.transformers.core.embedding import QEmbedding from pymllm.backends.qualcomm.transformers.core.observer import ConcatObserver @@ -393,7 +394,8 @@ def __init__(self, config: Qwen3Config, layer_idx: int): self.attention_type = config.layer_types[layer_idx] # QDQ - self.input_layernorm_input_qdq = ActivationQDQ(bits=16) + if self.layer_dix != 0: + self.input_layernorm_input_qdq = ActivationQDQ(bits=16) self.add_0_lhs_input_qdq = ActivationQDQ(bits=16) self.add_0_output_qdq = ActivationQDQ(bits=16) self.add_1_lhs_input_qdq = ActivationQDQ(bits=16) @@ -412,7 +414,8 @@ def forward( ] = None, # necessary, but kept here for BC **kwargs: Unpack[TransformersKwargs], ) -> torch.Tensor: - hidden_states = self.input_layernorm_input_qdq(hidden_states) + if self.layer_dix != 0: + hidden_states = self.input_layernorm_input_qdq(hidden_states) residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention @@ -513,9 +516,8 @@ def __init__(self, config: Qwen3Config): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding( - config.vocab_size, config.hidden_size, self.padding_idx + self.embed_tokens = QEmbedding( + config.vocab_size, config.hidden_size, self.padding_idx, quant_bits=16 ) self.layers = nn.ModuleList( [ diff --git a/pymllm/backends/qualcomm/transformers/qwen3/runner.py b/pymllm/backends/qualcomm/transformers/qwen3/runner.py index 6565ca7e6..416816875 100644 --- a/pymllm/backends/qualcomm/transformers/qwen3/runner.py +++ b/pymllm/backends/qualcomm/transformers/qwen3/runner.py @@ -11,6 +11,7 @@ QLinearLPBQ, QLinearW8A16_PerChannelSym, ) +from pymllm.backends.qualcomm.transformers.core.embedding import QEmbedding from pymllm.backends.qualcomm.transformers.qwen3.modeling_qwen3 import Qwen3ForCausalLM @@ -49,6 +50,8 @@ def convert_weight(m): m.convert_to_conv2d_deploy_hwio() if isinstance(m, QRMSNorm): m.convert_to_deploy() + if isinstance(m, QEmbedding): + m.convert_to_deploy() class Qwen3Quantizer: diff --git a/pymllm/backends/qualcomm/transformers/qwen3/train.py b/pymllm/backends/qualcomm/transformers/qwen3/train.py index 25361f372..9c4604d8f 100644 --- a/pymllm/backends/qualcomm/transformers/qwen3/train.py +++ b/pymllm/backends/qualcomm/transformers/qwen3/train.py @@ -48,9 +48,13 @@ def main(): # Things below is for deploy. We will turn all fp32 weights and some buffers(rope) to quantized dtype. # !!! # This line maybe error. we need use quantized weight!!! not embed_tokens.weight!!! - m.model.lm_head.weight = torch.nn.Parameter( - m.model.model.embed_tokens.weight.clone() - ) + # m.model.lm_head.weight = torch.nn.Parameter( + # m.model.model.embed_tokens.weight.clone() + # ) + if "1.7B" in args.model_path: + raise ValueError( + "1.7B model is not supported for now due to tied embedding weights is not supported." + ) m.convert() os.makedirs(args.output_dir, exist_ok=True) From c4f230648ccce567ced2d7c8990d81757a64aa24 Mon Sep 17 00:00:00 2001 From: KKkai <1640576073@qq.com> Date: Fri, 23 Jan 2026 15:35:15 +0800 Subject: [PATCH 11/17] fix --- examples/qwen2_5omni/config_qwen2_5omni.json | 495 ------------------- examples/qwen2_5omni/image_infer.cpp | 4 +- 2 files changed, 2 insertions(+), 497 deletions(-) delete mode 100644 examples/qwen2_5omni/config_qwen2_5omni.json diff --git a/examples/qwen2_5omni/config_qwen2_5omni.json b/examples/qwen2_5omni/config_qwen2_5omni.json deleted file mode 100644 index 633e1b2b1..000000000 --- a/examples/qwen2_5omni/config_qwen2_5omni.json +++ /dev/null @@ -1,495 +0,0 @@ -{ - "architectures": [ - "Qwen2_5OmniModel" - ], - "enable_audio_output": true, - "enable_talker": true, - "model_type": "qwen2_5_omni", - "talker_config": { - "_attn_implementation_autoset": true, - "_name_or_path": "Qwen2.5-Omni-7B/talker", - "architectures": [ - "Qwen2OmniTalkerForConditionalGeneration" - ], - "attention_dropout": 0.0, - "audio_end_token_id": 151648, - "audio_start_token_id": 151647, - "audio_token_index": 151646, - "embedding_size": 3584, - "head_dim": 128, - "hidden_act": "silu", - "hidden_size": 896, - "image_token_index": 151655, - "init_std": 0.02, - "initializer_range": 0.02, - "intermediate_size": 18944, - "max_position_embeddings": 32768, - "max_window_layers": 28, - "model_type": "qwen2_5_omni_talker", - "num_attention_heads": 12, - "num_hidden_layers": 24, - "num_key_value_heads": 4, - "position_id_per_seconds": 25, - "rms_norm_eps": 1e-06, - "rope_scaling": { - "mrope_section": [ - 16, - 24, - 24 - ], - "rope_type": "default", - "type": "default" - }, - "rope_theta": 1000000.0, - "seconds_per_chunk": 2, - "sliding_window": 32768, - "spatial_merge_size": 2, - "torch_dtype": "bfloat16", - "tts_codec_end_token_id": 8294, - "tts_codec_mask_token_id": 8296, - "tts_codec_pad_token_id": 8292, - "tts_codec_start_token_id": 8293, - "tts_text_end_token_id": 151861, - "tts_text_pad_token_id": 151859, - "tts_text_start_token_id": 151860, - "use_cache": true, - "use_sliding_window": false, - "video_token_index": 151656, - "vision_end_token_id": 151653, - "vision_start_token_id": 151652, - "vocab_size": 8448 - }, - "thinker_config": { - "_attn_implementation_autoset": true, - "_name_or_path": "Qwen2.5-Omni-7B/thinker", - "architectures": [ - "Qwen2OmniNaViTThinkerForConditionalGeneration" - ], - "audio_config": { - "_attn_implementation_autoset": true, - "_name_or_path": "", - "activation_dropout": 0.0, - "activation_function": "gelu", - "add_cross_attention": false, - "architectures": null, - "attention_dropout": 0.0, - "bad_words_ids": null, - "begin_suppress_tokens": null, - "bos_token_id": null, - "chunk_size_feed_forward": 0, - "cross_attention_hidden_size": null, - "d_model": 1280, - "decoder_start_token_id": null, - "diversity_penalty": 0.0, - "do_sample": false, - "dropout": 0.0, - "early_stopping": false, - "encoder_attention_heads": 20, - "encoder_ffn_dim": 5120, - "encoder_layerdrop": 0.0, - "encoder_layers": 32, - "encoder_no_repeat_ngram_size": 0, - "eos_token_id": null, - "exponential_decay_length_penalty": null, - "finetuning_task": null, - "forced_bos_token_id": null, - "forced_eos_token_id": null, - "id2label": { - "0": "LABEL_0", - "1": "LABEL_1" - }, - "init_std": 0.02, - "is_decoder": false, - "is_encoder_decoder": false, - "label2id": { - "LABEL_0": 0, - "LABEL_1": 1 - }, - "length_penalty": 1.0, - "max_length": 20, - "max_source_positions": 1500, - "min_length": 0, - "model_type": "qwen2_5_omni_audio_encoder", - "n_window": 100, - "no_repeat_ngram_size": 0, - "num_beam_groups": 1, - "num_beams": 1, - "num_hidden_layers": 32, - "num_mel_bins": 128, - "num_return_sequences": 1, - "output_attentions": false, - "output_dim": 3584, - "output_hidden_states": false, - "output_scores": false, - "pad_token_id": null, - "prefix": null, - "problem_type": null, - "pruned_heads": {}, - "remove_invalid_values": false, - "repetition_penalty": 1.0, - "return_dict": true, - "return_dict_in_generate": false, - "scale_embedding": false, - "sep_token_id": null, - "suppress_tokens": null, - "task_specific_params": null, - "temperature": 1.0, - "tf_legacy_loss": false, - "tie_encoder_decoder": false, - "tie_word_embeddings": true, - "tokenizer_class": null, - "top_k": 50, - "top_p": 1.0, - "torch_dtype": null, - "torchscript": false, - "typical_p": 1.0, - "use_bfloat16": false - }, - "text_config": { - "model_type": "qwen2_5_omni_text", - "hidden_act": "silu", - "hidden_size": 3584, - "init_std": 0.02, - "intermediate_size": 18944, - "vocab_size": 152064, - "num_attention_heads": 28, - "num_hidden_layers": 28, - "num_key_value_heads": 4, - "max_position_embeddings": 32768, - "max_window_layers": 28, - "rms_norm_eps": 1e-06, - "rope_scaling": { - "mrope_section": [ - 16, - 24, - 24 - ], - "rope_type": "default", - "type": "default" - }, - "use_cache": true, - "rope_theta": 1000000.0, - "use_sliding_window": false, - "sliding_window": 32768, - "attention_dropout": 0.0, - "tie_word_embeddings": false - }, - "audio_end_token_id": 151648, - "audio_start_token_id": 151647, - "audio_token_index": 151646, - "bos_token_id": 151644, - "eos_token_id": 151645, - "ignore_index": -100, - "image_token_index": 151655, - "init_std": 0.02, - "model_type": "qwen2_5_omni_thinker", - "pad_token_id": 151643, - "position_id_per_seconds": 25, - "seconds_per_chunk": 2, - "torch_dtype": "bfloat16", - "user_token_id": 872, - "video_token_index": 151656, - "vision_config": { - "_attn_implementation_autoset": true, - "_name_or_path": "", - "add_cross_attention": false, - "architectures": null, - "bad_words_ids": null, - "begin_suppress_tokens": null, - "bos_token_id": null, - "chunk_size_feed_forward": 0, - "cross_attention_hidden_size": null, - "decoder_start_token_id": null, - "depth": 32, - "diversity_penalty": 0.0, - "do_sample": false, - "early_stopping": false, - "embed_dim": 1280, - "encoder_no_repeat_ngram_size": 0, - "eos_token_id": null, - "exponential_decay_length_penalty": null, - "finetuning_task": null, - "forced_bos_token_id": null, - "forced_eos_token_id": null, - "fullatt_block_indexes": [ - 7, - 15, - 23, - 31 - ], - "hidden_act": "silu", - "hidden_size": 1280, - "id2label": { - "0": "LABEL_0", - "1": "LABEL_1" - }, - "in_channels": 3, - "in_chans": 3, - "init_std": 0.02, - "intermediate_size": 3420, - "is_decoder": false, - "is_encoder_decoder": false, - "label2id": { - "LABEL_0": 0, - "LABEL_1": 1 - }, - "length_penalty": 1.0, - "max_length": 20, - "min_length": 0, - "model_type": "qwen2_5_omni_vision_encoder", - "no_repeat_ngram_size": 0, - "num_beam_groups": 1, - "num_beams": 1, - "num_heads": 16, - "num_return_sequences": 1, - "out_hidden_size": 3584, - "output_attentions": false, - "output_hidden_states": false, - "output_scores": false, - "pad_token_id": null, - "patch_size": 14, - "prefix": null, - "problem_type": null, - "pruned_heads": {}, - "remove_invalid_values": false, - "repetition_penalty": 1.0, - "return_dict": true, - "return_dict_in_generate": false, - "sep_token_id": null, - "spatial_merge_size": 2, - "spatial_patch_size": 14, - "suppress_tokens": null, - "task_specific_params": null, - "temperature": 1.0, - "temporal_patch_size": 2, - "tf_legacy_loss": false, - "tie_encoder_decoder": false, - "tie_word_embeddings": true, - "tokenizer_class": null, - "tokens_per_second": 25, - "top_k": 50, - "top_p": 1.0, - "torch_dtype": null, - "torchscript": false, - "typical_p": 1.0, - "use_bfloat16": false, - "window_size": 112 - }, - "vision_end_token_id": 151653, - "vision_start_token_id": 151652, - "vision_token_id": 151654 - }, - "token2wav_config": { - "_attn_implementation_autoset": true, - "bigvgan_config": { - "_attn_implementation_autoset": true, - "_name_or_path": "", - "add_cross_attention": false, - "architectures": null, - "bad_words_ids": null, - "begin_suppress_tokens": null, - "bos_token_id": null, - "chunk_size_feed_forward": 0, - "cross_attention_hidden_size": null, - "decoder_start_token_id": null, - "diversity_penalty": 0.0, - "do_sample": false, - "early_stopping": false, - "encoder_no_repeat_ngram_size": 0, - "eos_token_id": null, - "exponential_decay_length_penalty": null, - "finetuning_task": null, - "forced_bos_token_id": null, - "forced_eos_token_id": null, - "id2label": { - "0": "LABEL_0", - "1": "LABEL_1" - }, - "is_decoder": false, - "is_encoder_decoder": false, - "label2id": { - "LABEL_0": 0, - "LABEL_1": 1 - }, - "length_penalty": 1.0, - "max_length": 20, - "mel_dim": 80, - "min_length": 0, - "model_type": "qwen2_5_omni_bigvgan", - "no_repeat_ngram_size": 0, - "num_beam_groups": 1, - "num_beams": 1, - "num_return_sequences": 1, - "output_attentions": false, - "output_hidden_states": false, - "output_scores": false, - "pad_token_id": null, - "prefix": null, - "problem_type": null, - "pruned_heads": {}, - "remove_invalid_values": false, - "repetition_penalty": 1.0, - "resblock_dilation_sizes": [ - [ - 1, - 3, - 5 - ], - [ - 1, - 3, - 5 - ], - [ - 1, - 3, - 5 - ] - ], - "resblock_kernel_sizes": [ - 3, - 7, - 11 - ], - "return_dict": true, - "return_dict_in_generate": false, - "sep_token_id": null, - "suppress_tokens": null, - "task_specific_params": null, - "temperature": 1.0, - "tf_legacy_loss": false, - "tie_encoder_decoder": false, - "tie_word_embeddings": true, - "tokenizer_class": null, - "top_k": 50, - "top_p": 1.0, - "torch_dtype": null, - "torchscript": false, - "typical_p": 1.0, - "upsample_initial_channel": 1536, - "upsample_kernel_sizes": [ - 11, - 7, - 4, - 4, - 4, - 4 - ], - "upsample_rates": [ - 5, - 3, - 2, - 2, - 2, - 2 - ], - "use_bfloat16": false, - "use_bias_at_final": false - }, - "dit_config": { - "_attn_implementation_autoset": true, - "_name_or_path": "", - "add_cross_attention": false, - "architectures": null, - "bad_words_ids": null, - "begin_suppress_tokens": null, - "bos_token_id": null, - "chunk_size_feed_forward": 0, - "cross_attention_hidden_size": null, - "decoder_start_token_id": null, - "depth": 22, - "dim": 1024, - "diversity_penalty": 0.0, - "do_sample": false, - "dropout": 0.1, - "early_stopping": false, - "emb_dim": 512, - "enc_attention_channels": 64, - "enc_channels": [ - 256, - 256, - 256, - 256, - 768 - ], - "enc_dilations": [ - 1, - 2, - 3, - 4, - 1 - ], - "enc_dim": 128, - "enc_emb_dim": 192, - "enc_global_context": true, - "enc_kernel_sizes": [ - 5, - 3, - 3, - 3, - 1 - ], - "enc_lin_neurons": 192, - "enc_res2net_scale": 2, - "enc_se_channels": 64, - "encoder_no_repeat_ngram_size": 0, - "eos_token_id": null, - "exponential_decay_length_penalty": null, - "ff_mult": 2, - "finetuning_task": null, - "forced_bos_token_id": null, - "forced_eos_token_id": null, - "head_dim": 64, - "heads": 16, - "id2label": { - "0": "LABEL_0", - "1": "LABEL_1" - }, - "is_decoder": false, - "is_encoder_decoder": false, - "label2id": { - "LABEL_0": 0, - "LABEL_1": 1 - }, - "length_penalty": 1.0, - "max_length": 20, - "mel_dim": 80, - "min_length": 0, - "model_type": "qwen2_5_omni_dit", - "no_repeat_ngram_size": 0, - "num_beam_groups": 1, - "num_beams": 1, - "num_embeds": 8193, - "num_return_sequences": 1, - "output_attentions": false, - "output_hidden_states": false, - "output_scores": false, - "pad_token_id": null, - "prefix": null, - "problem_type": null, - "pruned_heads": {}, - "remove_invalid_values": false, - "repeats": 2, - "repetition_penalty": 1.0, - "return_dict": true, - "return_dict_in_generate": false, - "sep_token_id": null, - "suppress_tokens": null, - "task_specific_params": null, - "temperature": 1.0, - "tf_legacy_loss": false, - "tie_encoder_decoder": false, - "tie_word_embeddings": true, - "tokenizer_class": null, - "top_k": 50, - "top_p": 1.0, - "torch_dtype": "float32", - "torchscript": false, - "typical_p": 1.0, - "use_bfloat16": false - }, - "model_type": "qwen2_5_omni_token2wav" - }, - "torch_dtype": "bfloat16", - "transformers_version": "4.50.0.dev0" -} \ No newline at end of file diff --git a/examples/qwen2_5omni/image_infer.cpp b/examples/qwen2_5omni/image_infer.cpp index 41bf770b1..3c0bf214b 100644 --- a/examples/qwen2_5omni/image_infer.cpp +++ b/examples/qwen2_5omni/image_infer.cpp @@ -50,12 +50,12 @@ MLLM_MAIN({ std::string prompt_text; fmt::print("Image path (or 'exit/quit'): "); - image_path = "../../../mllm2-former/mllm/rsc/pics.jpg"; + image_path = ""; //std::getline(std::cin, image_path); if (image_path == "exit" || image_path == "quit") { return 0; } fmt::print("Prompt text: "); - prompt_text = "描述图片中物体"; + prompt_text = ""; //std::getline(std::cin, prompt_text); try { From a235a134ec63263c1bf58c4beb7c5731ac5eec56 Mon Sep 17 00:00:00 2001 From: KKkai <1640576073@qq.com> Date: Fri, 23 Jan 2026 15:35:49 +0800 Subject: [PATCH 12/17] fix --- examples/qwen2_5omni/audio_infer.cpp | 4 +- .../qwen2_5omni/config_qwen2_5omni_7B.json | 495 ++++++++++++++++++ 2 files changed, 497 insertions(+), 2 deletions(-) create mode 100644 examples/qwen2_5omni/config_qwen2_5omni_7B.json diff --git a/examples/qwen2_5omni/audio_infer.cpp b/examples/qwen2_5omni/audio_infer.cpp index 014b4688f..d159c2b3e 100644 --- a/examples/qwen2_5omni/audio_infer.cpp +++ b/examples/qwen2_5omni/audio_infer.cpp @@ -51,12 +51,12 @@ MLLM_MAIN({ fmt::print("Audio path (or 'exit/quit'): "); //std::getline(std::cin, audio_path); //if (audio_path == "exit" || audio_path == "quit") { return 0; } - audio_path = "/Users/kkkai/Desktop/mllm2-former/mllm/rsc/recognize.wav"; + audio_path = ""; fmt::print("Prompt text: "); //std::getline(std::cin, prompt_text); //if (prompt_text.empty()) { prompt_text = "Please describe the audio."; } - prompt_text = "复述这段音频"; + prompt_text = ""; try { fmt::print("Processing...\n"); diff --git a/examples/qwen2_5omni/config_qwen2_5omni_7B.json b/examples/qwen2_5omni/config_qwen2_5omni_7B.json new file mode 100644 index 000000000..8f27b94b9 --- /dev/null +++ b/examples/qwen2_5omni/config_qwen2_5omni_7B.json @@ -0,0 +1,495 @@ +{ + "architectures": [ + "Qwen2_5OmniModel" + ], + "enable_audio_output": true, + "enable_talker": true, + "model_type": "qwen2_5_omni", + "talker_config": { + "_attn_implementation_autoset": true, + "_name_or_path": "Qwen2.5-Omni-7B/talker", + "architectures": [ + "Qwen2OmniTalkerForConditionalGeneration" + ], + "attention_dropout": 0.0, + "audio_end_token_id": 151648, + "audio_start_token_id": 151647, + "audio_token_index": 151646, + "embedding_size": 3584, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 896, + "image_token_index": 151655, + "init_std": 0.02, + "initializer_range": 0.02, + "intermediate_size": 18944, + "max_position_embeddings": 32768, + "max_window_layers": 28, + "model_type": "qwen2_5_omni_talker", + "num_attention_heads": 12, + "num_hidden_layers": 24, + "num_key_value_heads": 4, + "position_id_per_seconds": 25, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "mrope_section": [ + 16, + 24, + 24 + ], + "rope_type": "default", + "type": "default" + }, + "rope_theta": 1000000.0, + "seconds_per_chunk": 2, + "sliding_window": 32768, + "spatial_merge_size": 2, + "torch_dtype": "bfloat16", + "tts_codec_end_token_id": 8294, + "tts_codec_mask_token_id": 8296, + "tts_codec_pad_token_id": 8292, + "tts_codec_start_token_id": 8293, + "tts_text_end_token_id": 151861, + "tts_text_pad_token_id": 151859, + "tts_text_start_token_id": 151860, + "use_cache": true, + "use_sliding_window": false, + "video_token_index": 151656, + "vision_end_token_id": 151653, + "vision_start_token_id": 151652, + "vocab_size": 8448 + }, + "thinker_config": { + "_attn_implementation_autoset": true, + "_name_or_path": "Qwen2.5-Omni-7B/thinker", + "architectures": [ + "Qwen2OmniNaViTThinkerForConditionalGeneration" + ], + "audio_config": { + "_attn_implementation_autoset": true, + "_name_or_path": "", + "activation_dropout": 0.0, + "activation_function": "gelu", + "add_cross_attention": false, + "architectures": null, + "attention_dropout": 0.0, + "bad_words_ids": null, + "begin_suppress_tokens": null, + "bos_token_id": null, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": null, + "d_model": 1280, + "decoder_start_token_id": null, + "diversity_penalty": 0.0, + "do_sample": false, + "dropout": 0.0, + "early_stopping": false, + "encoder_attention_heads": 20, + "encoder_ffn_dim": 5120, + "encoder_layerdrop": 0.0, + "encoder_layers": 32, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": null, + "exponential_decay_length_penalty": null, + "finetuning_task": null, + "forced_bos_token_id": null, + "forced_eos_token_id": null, + "id2label": { + "0": "LABEL_0", + "1": "LABEL_1" + }, + "init_std": 0.02, + "is_decoder": false, + "is_encoder_decoder": false, + "label2id": { + "LABEL_0": 0, + "LABEL_1": 1 + }, + "length_penalty": 1.0, + "max_length": 20, + "max_source_positions": 1500, + "min_length": 0, + "model_type": "qwen2_5_omni_audio_encoder", + "n_window": 100, + "no_repeat_ngram_size": 0, + "num_beam_groups": 1, + "num_beams": 1, + "num_hidden_layers": 32, + "num_mel_bins": 128, + "num_return_sequences": 1, + "output_attentions": false, + "output_dim": 3584, + "output_hidden_states": false, + "output_scores": false, + "pad_token_id": null, + "prefix": null, + "problem_type": null, + "pruned_heads": {}, + "remove_invalid_values": false, + "repetition_penalty": 1.0, + "return_dict": true, + "return_dict_in_generate": false, + "scale_embedding": false, + "sep_token_id": null, + "suppress_tokens": null, + "task_specific_params": null, + "temperature": 1.0, + "tf_legacy_loss": false, + "tie_encoder_decoder": false, + "tie_word_embeddings": true, + "tokenizer_class": null, + "top_k": 50, + "top_p": 1.0, + "torch_dtype": null, + "torchscript": false, + "typical_p": 1.0, + "use_bfloat16": false + }, + "text_config": { + "model_type": "qwen2_5_omni_text", + "hidden_act": "silu", + "hidden_size": 3584, + "init_std": 0.02, + "intermediate_size": 18944, + "vocab_size": 152064, + "num_attention_heads": 28, + "num_hidden_layers": 28, + "num_key_value_heads": 4, + "max_position_embeddings": 32768, + "max_window_layers": 28, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "mrope_section": [ + 16, + 24, + 24 + ], + "rope_type": "default", + "type": "default" + }, + "use_cache": true, + "rope_theta": 1000000.0, + "use_sliding_window": false, + "sliding_window": 32768, + "attention_dropout": 0.0, + "tie_word_embeddings": false + }, + "audio_end_token_id": 151648, + "audio_start_token_id": 151647, + "audio_token_index": 151646, + "bos_token_id": 151644, + "eos_token_id": 151645, + "ignore_index": -100, + "image_token_index": 151655, + "init_std": 0.02, + "model_type": "qwen2_5_omni_thinker", + "pad_token_id": 151643, + "position_id_per_seconds": 25, + "seconds_per_chunk": 2, + "torch_dtype": "bfloat16", + "user_token_id": 872, + "video_token_index": 151656, + "vision_config": { + "_attn_implementation_autoset": true, + "_name_or_path": "", + "add_cross_attention": false, + "architectures": null, + "bad_words_ids": null, + "begin_suppress_tokens": null, + "bos_token_id": null, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": null, + "decoder_start_token_id": null, + "depth": 32, + "diversity_penalty": 0.0, + "do_sample": false, + "early_stopping": false, + "embed_dim": 1280, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": null, + "exponential_decay_length_penalty": null, + "finetuning_task": null, + "forced_bos_token_id": null, + "forced_eos_token_id": null, + "fullatt_block_indexes": [ + 7, + 15, + 23, + 31 + ], + "hidden_act": "silu", + "hidden_size": 1280, + "id2label": { + "0": "LABEL_0", + "1": "LABEL_1" + }, + "in_channels": 3, + "in_chans": 3, + "init_std": 0.02, + "intermediate_size": 3420, + "is_decoder": false, + "is_encoder_decoder": false, + "label2id": { + "LABEL_0": 0, + "LABEL_1": 1 + }, + "length_penalty": 1.0, + "max_length": 20, + "min_length": 0, + "model_type": "qwen2_5_omni_vision_encoder", + "no_repeat_ngram_size": 0, + "num_beam_groups": 1, + "num_beams": 1, + "num_heads": 16, + "num_return_sequences": 1, + "out_hidden_size": 3584, + "output_attentions": false, + "output_hidden_states": false, + "output_scores": false, + "pad_token_id": null, + "patch_size": 14, + "prefix": null, + "problem_type": null, + "pruned_heads": {}, + "remove_invalid_values": false, + "repetition_penalty": 1.0, + "return_dict": true, + "return_dict_in_generate": false, + "sep_token_id": null, + "spatial_merge_size": 2, + "spatial_patch_size": 14, + "suppress_tokens": null, + "task_specific_params": null, + "temperature": 1.0, + "temporal_patch_size": 2, + "tf_legacy_loss": false, + "tie_encoder_decoder": false, + "tie_word_embeddings": true, + "tokenizer_class": null, + "tokens_per_second": 25, + "top_k": 50, + "top_p": 1.0, + "torch_dtype": null, + "torchscript": false, + "typical_p": 1.0, + "use_bfloat16": false, + "window_size": 112 + }, + "vision_end_token_id": 151653, + "vision_start_token_id": 151652, + "vision_token_id": 151654 + }, + "token2wav_config": { + "_attn_implementation_autoset": true, + "bigvgan_config": { + "_attn_implementation_autoset": true, + "_name_or_path": "", + "add_cross_attention": false, + "architectures": null, + "bad_words_ids": null, + "begin_suppress_tokens": null, + "bos_token_id": null, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": null, + "decoder_start_token_id": null, + "diversity_penalty": 0.0, + "do_sample": false, + "early_stopping": false, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": null, + "exponential_decay_length_penalty": null, + "finetuning_task": null, + "forced_bos_token_id": null, + "forced_eos_token_id": null, + "id2label": { + "0": "LABEL_0", + "1": "LABEL_1" + }, + "is_decoder": false, + "is_encoder_decoder": false, + "label2id": { + "LABEL_0": 0, + "LABEL_1": 1 + }, + "length_penalty": 1.0, + "max_length": 20, + "mel_dim": 80, + "min_length": 0, + "model_type": "qwen2_5_omni_bigvgan", + "no_repeat_ngram_size": 0, + "num_beam_groups": 1, + "num_beams": 1, + "num_return_sequences": 1, + "output_attentions": false, + "output_hidden_states": false, + "output_scores": false, + "pad_token_id": null, + "prefix": null, + "problem_type": null, + "pruned_heads": {}, + "remove_invalid_values": false, + "repetition_penalty": 1.0, + "resblock_dilation_sizes": [ + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ] + ], + "resblock_kernel_sizes": [ + 3, + 7, + 11 + ], + "return_dict": true, + "return_dict_in_generate": false, + "sep_token_id": null, + "suppress_tokens": null, + "task_specific_params": null, + "temperature": 1.0, + "tf_legacy_loss": false, + "tie_encoder_decoder": false, + "tie_word_embeddings": true, + "tokenizer_class": null, + "top_k": 50, + "top_p": 1.0, + "torch_dtype": null, + "torchscript": false, + "typical_p": 1.0, + "upsample_initial_channel": 1536, + "upsample_kernel_sizes": [ + 11, + 7, + 4, + 4, + 4, + 4 + ], + "upsample_rates": [ + 5, + 3, + 2, + 2, + 2, + 2 + ], + "use_bfloat16": false, + "use_bias_at_final": false + }, + "dit_config": { + "_attn_implementation_autoset": true, + "_name_or_path": "", + "add_cross_attention": false, + "architectures": null, + "bad_words_ids": null, + "begin_suppress_tokens": null, + "bos_token_id": null, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": null, + "decoder_start_token_id": null, + "depth": 22, + "dim": 1024, + "diversity_penalty": 0.0, + "do_sample": false, + "dropout": 0.1, + "early_stopping": false, + "emb_dim": 512, + "enc_attention_channels": 64, + "enc_channels": [ + 256, + 256, + 256, + 256, + 768 + ], + "enc_dilations": [ + 1, + 2, + 3, + 4, + 1 + ], + "enc_dim": 128, + "enc_emb_dim": 192, + "enc_global_context": true, + "enc_kernel_sizes": [ + 5, + 3, + 3, + 3, + 1 + ], + "enc_lin_neurons": 192, + "enc_res2net_scale": 2, + "enc_se_channels": 64, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": null, + "exponential_decay_length_penalty": null, + "ff_mult": 2, + "finetuning_task": null, + "forced_bos_token_id": null, + "forced_eos_token_id": null, + "head_dim": 64, + "heads": 16, + "id2label": { + "0": "LABEL_0", + "1": "LABEL_1" + }, + "is_decoder": false, + "is_encoder_decoder": false, + "label2id": { + "LABEL_0": 0, + "LABEL_1": 1 + }, + "length_penalty": 1.0, + "max_length": 20, + "mel_dim": 80, + "min_length": 0, + "model_type": "qwen2_5_omni_dit", + "no_repeat_ngram_size": 0, + "num_beam_groups": 1, + "num_beams": 1, + "num_embeds": 8193, + "num_return_sequences": 1, + "output_attentions": false, + "output_hidden_states": false, + "output_scores": false, + "pad_token_id": null, + "prefix": null, + "problem_type": null, + "pruned_heads": {}, + "remove_invalid_values": false, + "repeats": 2, + "repetition_penalty": 1.0, + "return_dict": true, + "return_dict_in_generate": false, + "sep_token_id": null, + "suppress_tokens": null, + "task_specific_params": null, + "temperature": 1.0, + "tf_legacy_loss": false, + "tie_encoder_decoder": false, + "tie_word_embeddings": true, + "tokenizer_class": null, + "top_k": 50, + "top_p": 1.0, + "torch_dtype": "float32", + "torchscript": false, + "typical_p": 1.0, + "use_bfloat16": false + }, + "model_type": "qwen2_5_omni_token2wav" + }, + "torch_dtype": "bfloat16", + "transformers_version": "4.50.0.dev0" +} From adc3b644af619085240f98ecb058f0cbf66da7fc Mon Sep 17 00:00:00 2001 From: KKkai <1640576073@qq.com> Date: Sun, 25 Jan 2026 01:51:00 +0800 Subject: [PATCH 13/17] add ConvTranspose1dOp & TanhOp --- mllm/backends/cpu/CPUBackend.cpp | 14 +-- mllm/backends/cpu/ops/ConvTranspose1DOp.cpp | 91 ++++++++++++++++++++ mllm/backends/cpu/ops/ConvTranspose1DOp.hpp | 25 ++++++ mllm/backends/cpu/ops/TanhOp.cpp | 42 +++++++++ mllm/backends/cpu/ops/TanhOp.hpp | 25 ++++++ mllm/compile/ir/GeneratedRTTIKind.hpp | 2 + mllm/compile/ir/NodeRTTIClassOfImpl.hpp | 6 ++ mllm/compile/ir/linalg/Op.cpp | 2 + mllm/compile/ir/linalg/Op.hpp | 4 + mllm/core/OpTypes.hpp | 4 + mllm/core/aops/ConvTranspose1DOp.cpp | 95 +++++++++++++++++++++ mllm/core/aops/ConvTranspose1DOp.hpp | 52 +++++++++++ mllm/core/aops/TanhOp.cpp | 37 ++++++++ mllm/core/aops/TanhOp.hpp | 33 +++++++ mllm/nn/Nn.hpp | 2 + mllm/nn/layers/ConvTranspose1D.cpp | 32 +++++++ mllm/nn/layers/ConvTranspose1D.hpp | 29 +++++++ mllm/nn/layers/Tanh.cpp | 12 +++ mllm/nn/layers/Tanh.hpp | 21 +++++ 19 files changed, 522 insertions(+), 6 deletions(-) create mode 100644 mllm/backends/cpu/ops/ConvTranspose1DOp.cpp create mode 100644 mllm/backends/cpu/ops/ConvTranspose1DOp.hpp create mode 100644 mllm/backends/cpu/ops/TanhOp.cpp create mode 100644 mllm/backends/cpu/ops/TanhOp.hpp create mode 100644 mllm/core/aops/ConvTranspose1DOp.cpp create mode 100644 mllm/core/aops/ConvTranspose1DOp.hpp create mode 100644 mllm/core/aops/TanhOp.cpp create mode 100644 mllm/core/aops/TanhOp.hpp create mode 100644 mllm/nn/layers/ConvTranspose1D.cpp create mode 100644 mllm/nn/layers/ConvTranspose1D.hpp create mode 100644 mllm/nn/layers/Tanh.cpp create mode 100644 mllm/nn/layers/Tanh.hpp diff --git a/mllm/backends/cpu/CPUBackend.cpp b/mllm/backends/cpu/CPUBackend.cpp index 0964cba0d..f4b909913 100644 --- a/mllm/backends/cpu/CPUBackend.cpp +++ b/mllm/backends/cpu/CPUBackend.cpp @@ -14,6 +14,7 @@ #include "mllm/backends/cpu/ops/ConcatOp.hpp" #include "mllm/backends/cpu/ops/ContiguousOp.hpp" #include "mllm/backends/cpu/ops/Conv1DOp.hpp" +#include "mllm/backends/cpu/ops/ConvTranspose1DOp.hpp" #include "mllm/backends/cpu/ops/Conv2DOp.hpp" #include "mllm/backends/cpu/ops/Conv3DOp.hpp" #include "mllm/backends/cpu/ops/CopyOp.hpp" @@ -52,6 +53,7 @@ #include "mllm/backends/cpu/ops/Scatter2ShardsOp.hpp" #include "mllm/backends/cpu/ops/SiLUOp.hpp" #include "mllm/backends/cpu/ops/SigmoidOp.hpp" +#include "mllm/backends/cpu/ops/TanhOp.hpp" #include "mllm/backends/cpu/ops/SliceOp.hpp" #include "mllm/backends/cpu/ops/SoftmaxOp.hpp" #include "mllm/backends/cpu/ops/SplitOp.hpp" @@ -78,12 +80,12 @@ CPUBackend::CPUBackend() : Backend(kCPU, createCPUAllocator()) { CPUSiLUOpFactory, CPUSigmoidOpFactory, CPURMSNormOpFactory, CPUGELUOpFactory, CPUQuickGELUOpFactory, CPUReLUOpFactory, CPUMatMulOpFactory, CPUFlashAttention2OpFactory, CPUSliceOpFactory, CPUVisionRoPEOpFactory, CPUParamOpFactory, CPUMultimodalRoPEOpFactory, CPURoPEOpFactory, CPUCausalMaskOpFactory, CPUConv1DOpFactory, - CPUConv3DOpFactory, CPUSTFTOpFactory, CPUISTFTOpFactory, CPUIndexOpFactory, CPUTopKOpFactory, CPUClipOpFactory, - CPUMeanOpFactory, CPUKVCacheOpFactory, CPUPagedAttnOpFactory, CPUScatter2ShardsOpFactory, CPURadixAttnOpFactory, - CPUConv2DOpFactory, CPULayerNorm2DOpFactory, CPUInterpolateOpFactory, CPUPadOpFactory, CPUMaskedScatterOpFactory, - CPUArgsortOpFactory, CPUCloneOpFactory, CPUAvgPool1dOpFactory, CPUFlashAttention2SwaSinkOpFactory, - CPURadixAttnRelaxOpFactory, CPURadixAttnSwaSinkOpFactory, CPUEqualOpFactory, CPUWhereOpFactory, - CPUGatherOpFactory>(); + CPUConvTranspose1DOpFactory, CPUConv3DOpFactory, CPUSTFTOpFactory, CPUISTFTOpFactory, CPUIndexOpFactory, + CPUTopKOpFactory, CPUClipOpFactory, CPUMeanOpFactory, CPUKVCacheOpFactory, CPUPagedAttnOpFactory, + CPUScatter2ShardsOpFactory, CPURadixAttnOpFactory, CPUConv2DOpFactory, CPULayerNorm2DOpFactory, + CPUInterpolateOpFactory, CPUPadOpFactory, CPUMaskedScatterOpFactory, CPUArgsortOpFactory, CPUCloneOpFactory, + CPUAvgPool1dOpFactory, CPUFlashAttention2SwaSinkOpFactory, CPURadixAttnRelaxOpFactory, + CPURadixAttnSwaSinkOpFactory, CPUEqualOpFactory, CPUWhereOpFactory, CPUGatherOpFactory, CPUTanhOpFactory>(); } CPUBackend::~CPUBackend() { diff --git a/mllm/backends/cpu/ops/ConvTranspose1DOp.cpp b/mllm/backends/cpu/ops/ConvTranspose1DOp.cpp new file mode 100644 index 000000000..cfa38bf34 --- /dev/null +++ b/mllm/backends/cpu/ops/ConvTranspose1DOp.cpp @@ -0,0 +1,91 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include "mllm/backends/cpu/ops/ConvTranspose1DOp.hpp" +#include "mllm/core/Parallel.hpp" +#include "mllm/utils/Common.hpp" + +namespace mllm::cpu { + +CPUConvTranspose1DOp::CPUConvTranspose1DOp(const aops::ConvTranspose1DOpOptions& options) + : aops::ConvTranspose1DOp(options) {} + +void CPUConvTranspose1DOp::forward(const std::vector& inputs, std::vector& outputs) { + auto& input = inputs[0]; + auto& output = outputs[0]; + + auto i_shape = input.shape(); + auto o_shape = output.shape(); + + // input shape: [batch, in_channels, sequence] + // output shape: [batch, out_channels, out_sequence] + const int batch = i_shape[0]; + const int in_channels = i_shape[1]; + const int sequence = i_shape[2]; + + const int out_channels = o_shape[1]; + const int out_sequence = o_shape[2]; + + const int kernel_size = options_.kernel_size; + const int stride = options_.stride; + const int padding = options_.padding; + const int dilation = options_.dilation; + const int groups = options_.groups; + + const int in_channels_per_group = in_channels / groups; + const int out_channels_per_group = out_channels / groups; + + MLLM_RT_ASSERT(weight_.dtype() == kFloat32); + const auto* weight_ptr = weight_.ptr(); + const auto* input_ptr = input.ptr(); + auto* output_ptr = output.ptr(); + + float* bias_ptr = nullptr; + if (options_.bias && !bias_.isNil()) { bias_ptr = bias_.ptr(); } + + std::fill_n(output_ptr, output.numel(), 0.0f); + + const int total_iterations = batch * out_channels * out_sequence; + + switch (output.dtype()) { + case kFloat32: + MLLM_CONDITIONAL_PARALLEL_FOR(options_.getThreads() > 1, 4, idx, 0, total_iterations, 1, { + int b = idx / (out_channels * out_sequence); + int oc = (idx % (out_channels * out_sequence)) / out_sequence; + int out_pos = idx % out_sequence; + + const int group_idx = oc / out_channels_per_group; + const int oc_in_group = oc % out_channels_per_group; + + float sum = 0.0f; + + for (int ic_in_group = 0; ic_in_group < in_channels_per_group; ++ic_in_group) { + const int ic = group_idx * in_channels_per_group + ic_in_group; + const int base_input_idx = b * (in_channels * sequence) + ic * sequence; + + const int base_weight_idx = (ic * out_channels_per_group + oc_in_group) * kernel_size; + + for (int k = 0; k < kernel_size; ++k) { + int input_pos = out_pos + padding - k * dilation; + if (input_pos % stride != 0) { continue; } + input_pos /= stride; + if (input_pos < 0 || input_pos >= sequence) { continue; } + + const int input_idx = base_input_idx + input_pos; + const int weight_idx = base_weight_idx + k; + + sum += input_ptr[input_idx] * weight_ptr[weight_idx]; + } + } + + if (bias_ptr) { sum += bias_ptr[oc]; } + + const int output_idx = b * (out_channels * out_sequence) + oc * out_sequence + out_pos; + output_ptr[output_idx] = sum; + }); + break; + default: NYI("ConvTranspose1D: unsupported data type"); + } +} + +} // namespace mllm::cpu diff --git a/mllm/backends/cpu/ops/ConvTranspose1DOp.hpp b/mllm/backends/cpu/ops/ConvTranspose1DOp.hpp new file mode 100644 index 000000000..fd1163ed3 --- /dev/null +++ b/mllm/backends/cpu/ops/ConvTranspose1DOp.hpp @@ -0,0 +1,25 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/aops/ConvTranspose1DOp.hpp" + +namespace mllm::cpu { + +class CPUConvTranspose1DOp final : public aops::ConvTranspose1DOp { + public: + explicit CPUConvTranspose1DOp(const aops::ConvTranspose1DOpOptions& options); + + void forward(const std::vector& inputs, std::vector& outputs) override; +}; + +class CPUConvTranspose1DOpFactory : public TypedOpFactory { + public: + std::shared_ptr createOpImpl(const aops::ConvTranspose1DOpOptions& options) override { + return std::make_shared(options); + } +}; + +} // namespace mllm::cpu diff --git a/mllm/backends/cpu/ops/TanhOp.cpp b/mllm/backends/cpu/ops/TanhOp.cpp new file mode 100644 index 000000000..3d8dc6af1 --- /dev/null +++ b/mllm/backends/cpu/ops/TanhOp.cpp @@ -0,0 +1,42 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include + +#include "mllm/backends/cpu/ops/TanhOp.hpp" +#include "mllm/core/Parallel.hpp" +#include "mllm/utils/Common.hpp" + +namespace mllm::cpu { + +CPUTanhOp::CPUTanhOp(const aops::TanhOpOptions& options) : aops::TanhOp(options) {} + +void CPUTanhOp::forward(const std::vector& inputs, std::vector& outputs) { + const auto& X = inputs[0]; + auto& Y = outputs[0]; + + const auto numel = X.numel(); + + switch (X.dtype()) { + case kFloat32: { + const auto* x_ptr = X.ptr(); + auto* y_ptr = Y.ptr(); + MLLM_CONDITIONAL_PARALLEL_FOR(options_.getThreads() > 1, 4, idx, 0, numel, 1, { + y_ptr[idx] = std::tanh(x_ptr[idx]); + }); + break; + } + case kFloat16: { + const auto* x_ptr = X.ptr(); + auto* y_ptr = Y.ptr(); + MLLM_CONDITIONAL_PARALLEL_FOR(options_.getThreads() > 1, 4, idx, 0, numel, 1, { + float v = static_cast(x_ptr[idx]); + y_ptr[idx] = static_cast(std::tanh(v)); + }); + break; + } + default: NYI("CPUTanhOp::forward not support dtype {}", nameOfType(X.dtype())); break; + } +} + +} // namespace mllm::cpu diff --git a/mllm/backends/cpu/ops/TanhOp.hpp b/mllm/backends/cpu/ops/TanhOp.hpp new file mode 100644 index 000000000..c88fae9ce --- /dev/null +++ b/mllm/backends/cpu/ops/TanhOp.hpp @@ -0,0 +1,25 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/aops/TanhOp.hpp" + +namespace mllm::cpu { + +class CPUTanhOp final : public aops::TanhOp { + public: + explicit CPUTanhOp(const aops::TanhOpOptions& options); + + void forward(const std::vector& inputs, std::vector& outputs) override; +}; + +class CPUTanhOpFactory : public TypedOpFactory { + public: + std::shared_ptr createOpImpl(const aops::TanhOpOptions& options) override { + return std::make_shared(options); + } +}; + +} // namespace mllm::cpu diff --git a/mllm/compile/ir/GeneratedRTTIKind.hpp b/mllm/compile/ir/GeneratedRTTIKind.hpp index d100dc621..2d83493d1 100644 --- a/mllm/compile/ir/GeneratedRTTIKind.hpp +++ b/mllm/compile/ir/GeneratedRTTIKind.hpp @@ -44,6 +44,7 @@ enum NodeKind : uint32_t { RK_Op_LinalgIROp_RepeatOp, RK_Op_LinalgIROp_PermuteOp, RK_Op_LinalgIROp_Conv1DOp, + RK_Op_LinalgIROp_ConvTranspose1DOp, RK_Op_LinalgIROp_Conv2DOp, RK_Op_LinalgIROp_Conv3DOp, RK_Op_LinalgIROp_GELUOp, @@ -86,6 +87,7 @@ enum NodeKind : uint32_t { RK_Op_LinalgIROp_EqualOp, RK_Op_LinalgIROp_WhereOp, RK_Op_LinalgIROp_SigmoidOp, + RK_Op_LinalgIROp_TanhOp, RK_Op_LinalgIROp_CustomizedOp, RK_Op_LinalgIROp_Last, RK_Op_GraphIROp, diff --git a/mllm/compile/ir/NodeRTTIClassOfImpl.hpp b/mllm/compile/ir/NodeRTTIClassOfImpl.hpp index 6a98797a9..4c3313cf9 100644 --- a/mllm/compile/ir/NodeRTTIClassOfImpl.hpp +++ b/mllm/compile/ir/NodeRTTIClassOfImpl.hpp @@ -102,6 +102,9 @@ struct NodeRTTIClassOfImpl { #define RTTI_RK_OP_LINALGIROP_CONV1DOP_IMPL(v) \ return (v)->getKind() >= RK_Op_LinalgIROp_Conv1DOp && (v)->getKind() <= RK_Op_LinalgIROp_Conv1DOp +#define RTTI_RK_OP_LINALGIROP_CONVTRANSPOSE1DOP_IMPL(v) \ + return (v)->getKind() >= RK_Op_LinalgIROp_ConvTranspose1DOp && (v)->getKind() <= RK_Op_LinalgIROp_ConvTranspose1DOp + #define RTTI_RK_OP_LINALGIROP_CONV2DOP_IMPL(v) \ return (v)->getKind() >= RK_Op_LinalgIROp_Conv2DOp && (v)->getKind() <= RK_Op_LinalgIROp_Conv2DOp @@ -229,6 +232,9 @@ struct NodeRTTIClassOfImpl { #define RTTI_RK_OP_LINALGIROP_SIGMOIDOP_IMPL(v) \ return (v)->getKind() >= RK_Op_LinalgIROp_SigmoidOp && (v)->getKind() <= RK_Op_LinalgIROp_SigmoidOp +#define RTTI_RK_OP_LINALGIROP_TANHOP_IMPL(v) \ + return (v)->getKind() >= RK_Op_LinalgIROp_TanhOp && (v)->getKind() <= RK_Op_LinalgIROp_TanhOp + #define RTTI_RK_OP_LINALGIROP_CUSTOMIZEDOP_IMPL(v) \ return (v)->getKind() >= RK_Op_LinalgIROp_CustomizedOp && (v)->getKind() <= RK_Op_LinalgIROp_CustomizedOp diff --git a/mllm/compile/ir/linalg/Op.cpp b/mllm/compile/ir/linalg/Op.cpp index bb4e2fb9d..ad05e9437 100644 --- a/mllm/compile/ir/linalg/Op.cpp +++ b/mllm/compile/ir/linalg/Op.cpp @@ -55,6 +55,7 @@ LINALG_AOPS_DECL(OpTypes::kTranspose, TransposeOp); LINALG_AOPS_DECL(OpTypes::kRMSNorm, RMSNormOp); LINALG_AOPS_DECL(OpTypes::kSiLU, SiLUOp); LINALG_AOPS_DECL(OpTypes::kSigmoid, SigmoidOp); +LINALG_AOPS_DECL(OpTypes::kTanh, TanhOp); LINALG_AOPS_DECL(OpTypes::kCastType, CastTypeOp); @@ -70,6 +71,7 @@ LINALG_AOPS_DECL(OpTypes::kRepeat, RepeatOp); LINALG_AOPS_DECL(OpTypes::kPermute, PermuteOp); LINALG_AOPS_DECL(OpTypes::kConv1D, Conv1DOp); +LINALG_AOPS_DECL(OpTypes::kConvTranspose1D, ConvTranspose1DOp); LINALG_AOPS_DECL(OpTypes::kConv2D, Conv2DOp); LINALG_AOPS_DECL(OpTypes::kConv3D, Conv3DOp); diff --git a/mllm/compile/ir/linalg/Op.hpp b/mllm/compile/ir/linalg/Op.hpp index 02d04400b..6e6de4785 100644 --- a/mllm/compile/ir/linalg/Op.hpp +++ b/mllm/compile/ir/linalg/Op.hpp @@ -29,6 +29,7 @@ class TransposeOp; class RMSNormOp; class SiLUOp; class SigmoidOp; +class TanhOp; class CausalMaskOp; class CastTypeOp; class X2XOp; @@ -38,6 +39,7 @@ class FlashAttention2Op; class RepeatOp; class PermuteOp; class Conv1DOp; +class ConvTranspose1DOp; class Conv2DOp; class Conv3DOp; class GELUOp; @@ -188,6 +190,7 @@ LINALG_AOPS_DEFINE(TransposeOp, TRANSPOSEOP); LINALG_AOPS_DEFINE(RMSNormOp, RMSNORMOP); LINALG_AOPS_DEFINE(SiLUOp, SILUOP); LINALG_AOPS_DEFINE(SigmoidOp, SIGMOIDOP); +LINALG_AOPS_DEFINE(TanhOp, TANHOP); LINALG_AOPS_DEFINE(CastTypeOp, CASTTYPEOP); @@ -201,6 +204,7 @@ LINALG_AOPS_DEFINE(RepeatOp, REPEATOP); LINALG_AOPS_DEFINE(PermuteOp, PERMUTEOP); LINALG_AOPS_DEFINE(Conv1DOp, CONV1DOP); +LINALG_AOPS_DEFINE(ConvTranspose1DOp, CONVTRANSPOSE1DOP); LINALG_AOPS_DEFINE(Conv2DOp, CONV2DOP); LINALG_AOPS_DEFINE(Conv3DOp, CONV3DOP); diff --git a/mllm/core/OpTypes.hpp b/mllm/core/OpTypes.hpp index 310b39cd0..d64d484fe 100644 --- a/mllm/core/OpTypes.hpp +++ b/mllm/core/OpTypes.hpp @@ -96,6 +96,8 @@ enum class OpTypes : int32_t { kWhere = 74, kSigmoid = 75, + kTanh = 76, + kConvTranspose1D = 77, // Dynamic Op Start for user to register there own ops. kDynamicOp_Start = 4096, @@ -181,6 +183,8 @@ inline std::string optype2Str(OpTypes type) { case OpTypes::kEqual: return "Equal"; case OpTypes::kWhere: return "Where"; case OpTypes::kSigmoid: return "Sigmoid"; + case OpTypes::kTanh: return "Tanh"; + case OpTypes::kConvTranspose1D: return "ConvTranspose1D"; case OpTypes::kDynamicOp_Start: return "DynamicOp_Start"; case OpTypes::kOpType_End: return "OpType_End"; default: return "Unknown"; diff --git a/mllm/core/aops/ConvTranspose1DOp.cpp b/mllm/core/aops/ConvTranspose1DOp.cpp new file mode 100644 index 000000000..25d1b5935 --- /dev/null +++ b/mllm/core/aops/ConvTranspose1DOp.cpp @@ -0,0 +1,95 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include "mllm/core/aops/ConvTranspose1DOp.hpp" +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/Tensor.hpp" +#include "mllm/utils/Common.hpp" +#include "mllm/compile/ir/linalg/Op.hpp" +#include "mllm/compile/ir/graph/Op.hpp" +#include "mllm/compile/ir/tensor/Op.hpp" + +namespace mllm::aops { + +ConvTranspose1DOp::ConvTranspose1DOp(const ConvTranspose1DOpOptions& options) + : BaseOp(OpTypes::kConvTranspose1D), options_(options) {} + +void ConvTranspose1DOp::load(const ParameterFile::ptr_t& ploader) { + switch (ploader->version()) { + case ModelFileVersion::kV1: { + weight_ = ploader->pull(getName() + ".weight"); + if (options_.bias) { bias_ = ploader->pull(getName() + ".bias"); } + weight_ = weight_.view({options_.in_channels, options_.out_channels / options_.groups, options_.kernel_size}); + if (options_.bias) { bias_ = bias_.view({options_.out_channels}); } + break; + } + case ModelFileVersion::kUserTemporary: + case ModelFileVersion::kV2: { + weight_ = ploader->pull(getName() + ".weight"); + if (options_.bias) { bias_ = ploader->pull(getName() + ".bias"); } + break; + } + default: NYI("Unsupported model file version") + } +} + +void ConvTranspose1DOp::trace(void* trace_context, const std::vector& inputs, std::vector& outputs) { + auto ir_ctx = (ir::IRContext*)trace_context; + + if (weight_ && !ir_ctx->lookupSymbolTable(getName() + ".weight")) { + ir::IRWriterGuard guard(ir_ctx, ir_ctx->lookupSymbolTable("init")->cast_()->getTopRegion()); + ir_ctx->create(ir_ctx->create(weight_)); + if (options_.bias) { ir_ctx->create(ir_ctx->create(bias_)); } + } + + auto i_irs = ir::tensor::wrapTensors2TensorIR(ir_ctx, inputs); + auto o_irs = ir::tensor::wrapTensors2TensorIR(ir_ctx, outputs); + ir_ctx->create(shared_from_this(), i_irs, o_irs); +} + +void ConvTranspose1DOp::forward(const std::vector& inputs, std::vector& outputs) { + NYI("ConvTranspose1DOp::forward not implemented in aops base."); +} + +void ConvTranspose1DOp::reshape(const std::vector& inputs, std::vector& outputs) { + const auto& i = inputs[0]; + const auto& ishape = i.shape(); + + if (ishape.size() != 3) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "ConvTranspose1DOp expects 3D input, got {} D", ishape.size()); + outputs.emplace_back(Tensor::empty(i.shape(), i.dtype(), i.device())); + return; + } + + const int batch = ishape[0]; + const int in_channels = ishape[1]; + const int sequence = ishape[2]; + + MLLM_RT_ASSERT_EQ(in_channels, options_.in_channels); + MLLM_RT_ASSERT_EQ(in_channels % options_.groups, 0); + MLLM_RT_ASSERT_EQ(options_.out_channels % options_.groups, 0); + + const int kernel_size = options_.kernel_size; + const int stride = options_.stride; + const int dilation = options_.dilation; + const int padding = options_.padding; + const int output_padding = options_.output_padding; + + const int seq_out = (sequence - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1; + + auto new_shape = std::vector{batch, options_.out_channels, seq_out}; + outputs.emplace_back(Tensor::empty(new_shape, i.dtype(), i.device())); +} + +void ConvTranspose1DOp::setup(const std::vector& inputs, std::vector& outputs) { + BaseOp::setup(inputs, outputs); +} + +ParameterFile::ptr_t ConvTranspose1DOp::getParams() { + auto p = ParameterFile::create(); + p->push(getName() + ".weight", weight_); + if (options_.bias) { p->push(getName() + ".bias", bias_); } + return p; +} + +} // namespace mllm::aops diff --git a/mllm/core/aops/ConvTranspose1DOp.hpp b/mllm/core/aops/ConvTranspose1DOp.hpp new file mode 100644 index 000000000..daeda0b8e --- /dev/null +++ b/mllm/core/aops/ConvTranspose1DOp.hpp @@ -0,0 +1,52 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/ParameterFile.hpp" + +namespace mllm::aops { + +struct ConvTranspose1DOpOptions : public BaseOpOptions { + int32_t in_channels; + int32_t out_channels; + int32_t kernel_size; + int32_t stride = 1; + int32_t padding = 0; + int32_t output_padding = 0; + int32_t dilation = 1; + int32_t groups = 1; + bool bias = true; +}; + +class ConvTranspose1DOp : public BaseOp { + public: + explicit ConvTranspose1DOp(const ConvTranspose1DOpOptions& options); + + void load(const ParameterFile::ptr_t& ploader) override; + + void trace(void* trace_context, const std::vector& inputs, std::vector& outputs) override; + + void forward(const std::vector& inputs, std::vector& outputs) override; + + void reshape(const std::vector& inputs, std::vector& outputs) override; + + void setup(const std::vector& inputs, std::vector& outputs) override; + + ParameterFile::ptr_t getParams() override; + + inline Tensor& weight() { return weight_; } + + inline Tensor& bias() { return bias_; } + + inline ConvTranspose1DOpOptions& options() { return options_; } + + protected: + Tensor weight_; + Tensor bias_; + ConvTranspose1DOpOptions options_; +}; + +} // namespace mllm::aops diff --git a/mllm/core/aops/TanhOp.cpp b/mllm/core/aops/TanhOp.cpp new file mode 100644 index 000000000..c0938d82f --- /dev/null +++ b/mllm/core/aops/TanhOp.cpp @@ -0,0 +1,37 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include "mllm/core/aops/TanhOp.hpp" +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/Tensor.hpp" +#include "mllm/utils/Common.hpp" +#include "mllm/compile/ir/linalg/Op.hpp" + +namespace mllm::aops { + +TanhOp::TanhOp(const TanhOpOptions& options) : BaseOp(OpTypes::kTanh), options_(options) {} + +void TanhOp::load(const ParameterFile::ptr_t& ploader) { MLLM_EMPTY_SCOPE; } + +void TanhOp::trace(void* trace_context, const std::vector& inputs, std::vector& outputs) { + auto ir_ctx = (ir::IRContext*)trace_context; + auto i_irs = ir::tensor::wrapTensors2TensorIR(ir_ctx, inputs); + auto o_irs = ir::tensor::wrapTensors2TensorIR(ir_ctx, outputs); + ir_ctx->create(shared_from_this(), i_irs, o_irs); +} + +void TanhOp::forward(const std::vector& inputs, std::vector& outputs) { + NYI("TanhOp::forward not implemented in aops base."); +} + +void TanhOp::reshape(const std::vector& inputs, std::vector& outputs) { + if (options_.isInplace()) { + outputs.emplace_back(inputs[0]); + } else { + outputs.emplace_back(Tensor::empty(inputs[0].shape(), inputs[0].dtype(), inputs[0].device())); + } +} + +void TanhOp::setup(const std::vector& inputs, std::vector& outputs) { BaseOp::setup(inputs, outputs); } + +} // namespace mllm::aops diff --git a/mllm/core/aops/TanhOp.hpp b/mllm/core/aops/TanhOp.hpp new file mode 100644 index 000000000..8b2ce4f43 --- /dev/null +++ b/mllm/core/aops/TanhOp.hpp @@ -0,0 +1,33 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/ParameterFile.hpp" + +namespace mllm::aops { + +struct TanhOpOptions : public BaseOpOptions {}; + +class TanhOp : public BaseOp { + public: + explicit TanhOp(const TanhOpOptions& options); + + void load(const ParameterFile::ptr_t& ploader) override; + + void trace(void* trace_context, const std::vector& inputs, std::vector& outputs) override; + + void forward(const std::vector& inputs, std::vector& outputs) override; + + void reshape(const std::vector& inputs, std::vector& outputs) override; + + void setup(const std::vector& inputs, std::vector& outputs) override; + + inline TanhOpOptions& options() { return options_; } + + protected: + TanhOpOptions options_; +}; + +} // namespace mllm::aops diff --git a/mllm/nn/Nn.hpp b/mllm/nn/Nn.hpp index fdb0edc82..160e1cb43 100644 --- a/mllm/nn/Nn.hpp +++ b/mllm/nn/Nn.hpp @@ -11,6 +11,7 @@ #include "mllm/nn/layers/RMSNorm.hpp" // IWYU pragma: export #include "mllm/nn/layers/SiLU.hpp" // IWYU pragma: export #include "mllm/nn/layers/Sigmoid.hpp" // IWYU pragma: export +#include "mllm/nn/layers/Tanh.hpp" // IWYU pragma: export #include "mllm/nn/layers/Embedding.hpp" // IWYU pragma: export #include "mllm/nn/layers/GELU.hpp" // IWYU pragma: export #include "mllm/nn/layers/QuickGELU.hpp" // IWYU pragma: export @@ -26,6 +27,7 @@ #include "mllm/nn/layers/Param.hpp" // IWYU pragma: export #include "mllm/nn/layers/KVCache.hpp" // IWYU pragma: export #include "mllm/nn/layers/Conv1D.hpp" // IWYU pragma: export +#include "mllm/nn/layers/ConvTranspose1D.hpp" // IWYU pragma: export #include "mllm/nn/layers/AvgPool1d.hpp" // IWYU pragma: export #include "mllm/nn/layers/STFT.hpp" // IWYU pragma: export #include "mllm/nn/layers/PagedAttn.hpp" // IWYU pragma: export diff --git a/mllm/nn/layers/ConvTranspose1D.cpp b/mllm/nn/layers/ConvTranspose1D.cpp new file mode 100644 index 000000000..de2a7a5c7 --- /dev/null +++ b/mllm/nn/layers/ConvTranspose1D.cpp @@ -0,0 +1,32 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include "mllm/nn/layers/ConvTranspose1D.hpp" + +namespace mllm::nn { + +ConvTranspose1D::ConvTranspose1D() : Layer(OpTypes::kConvTranspose1D, aops::ConvTranspose1DOpOptions{}) {} + +ConvTranspose1D::ConvTranspose1D(int32_t in_channels, int32_t out_channels, int32_t kernel_size, int32_t stride_size, + int32_t padding, int32_t output_padding, int32_t dilation, int32_t groups, bool bias) + : Layer(OpTypes::kConvTranspose1D, aops::ConvTranspose1DOpOptions{.in_channels = in_channels, + .out_channels = out_channels, + .kernel_size = kernel_size, + .stride = stride_size, + .padding = padding, + .output_padding = output_padding, + .dilation = dilation, + .groups = groups, + .bias = bias}) {} + +ConvTranspose1D::ConvTranspose1D(const aops::ConvTranspose1DOpOptions& options) : Layer(OpTypes::kConvTranspose1D, options) {} + +Tensor ConvTranspose1D::weight() const { + return std::static_pointer_cast(impl()->getInstancedOp())->weight(); +} + +Tensor ConvTranspose1D::bias() const { + return std::static_pointer_cast(impl()->getInstancedOp())->bias(); +} + +} // namespace mllm::nn diff --git a/mllm/nn/layers/ConvTranspose1D.hpp b/mllm/nn/layers/ConvTranspose1D.hpp new file mode 100644 index 000000000..6ddc2fac3 --- /dev/null +++ b/mllm/nn/layers/ConvTranspose1D.hpp @@ -0,0 +1,29 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include +#include "mllm/nn/Layer.hpp" +#include "mllm/core/aops/ConvTranspose1DOp.hpp" + +namespace mllm::nn { + +class ConvTranspose1D : public Layer { + public: + ConvTranspose1D(); + + ConvTranspose1D(int32_t in_channels, int32_t out_channels, int32_t kernel_size, int32_t stride_size = 1, + int32_t padding = 0, int32_t output_padding = 0, int32_t dilation = 1, int32_t groups = 1, + bool bias = true); + + explicit ConvTranspose1D(const aops::ConvTranspose1DOpOptions& options); + + [[nodiscard]] Tensor weight() const; + + [[nodiscard]] Tensor bias() const; + + MLLM_LAYER_ANY_INPUTS_1_OUTPUTS_FORWARD +}; + +} // namespace mllm::nn diff --git a/mllm/nn/layers/Tanh.cpp b/mllm/nn/layers/Tanh.cpp new file mode 100644 index 000000000..dda95f7ae --- /dev/null +++ b/mllm/nn/layers/Tanh.cpp @@ -0,0 +1,12 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include "mllm/nn/layers/Tanh.hpp" + +namespace mllm::nn { + +Tanh::Tanh() : Layer(OpTypes::kTanh, aops::TanhOpOptions{}) {} + +Tanh::Tanh(const aops::TanhOpOptions& options) : Layer(OpTypes::kTanh, options) {} + +} // namespace mllm::nn diff --git a/mllm/nn/layers/Tanh.hpp b/mllm/nn/layers/Tanh.hpp new file mode 100644 index 000000000..ab84e7eeb --- /dev/null +++ b/mllm/nn/layers/Tanh.hpp @@ -0,0 +1,21 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include "mllm/nn/Layer.hpp" +#include "mllm/core/aops/TanhOp.hpp" + +namespace mllm::nn { + +class Tanh : public Layer { + public: + Tanh(); + + explicit Tanh(const aops::TanhOpOptions& options); + + MLLM_LAYER_ANY_INPUTS_1_OUTPUTS_FORWARD + MLLM_LAYER_ENABLE_INPLACE_ATTRIBUTE(Tanh) +}; + +} // namespace mllm::nn From 674f97c4cb02b07a53cb182d27799745c12fa5a5 Mon Sep 17 00:00:00 2001 From: KKkai <1640576073@qq.com> Date: Sun, 25 Jan 2026 19:25:41 +0800 Subject: [PATCH 14/17] fix: fix Tanh op and add test for Tanh Op and ConvTranspose1d Op --- mllm/backends/cpu/ops/ConvTranspose1DOp.cpp | 2 + tests/cpu/ConvTranspose1DKernelTest.hpp | 134 ++++++++++++++++++++ tests/cpu/KernelTest.cpp | 42 ++++++ tests/cpu/TanhKernelTest.hpp | 49 +++++++ 4 files changed, 227 insertions(+) create mode 100644 tests/cpu/ConvTranspose1DKernelTest.hpp create mode 100644 tests/cpu/TanhKernelTest.hpp diff --git a/mllm/backends/cpu/ops/ConvTranspose1DOp.cpp b/mllm/backends/cpu/ops/ConvTranspose1DOp.cpp index cfa38bf34..15a8097d1 100644 --- a/mllm/backends/cpu/ops/ConvTranspose1DOp.cpp +++ b/mllm/backends/cpu/ops/ConvTranspose1DOp.cpp @@ -35,6 +35,8 @@ void CPUConvTranspose1DOp::forward(const std::vector& inputs, std::vecto const int in_channels_per_group = in_channels / groups; const int out_channels_per_group = out_channels / groups; + MLLM_RT_ASSERT_EQ(input.dtype(), kFloat32); + MLLM_RT_ASSERT_EQ(output.dtype(), kFloat32); MLLM_RT_ASSERT(weight_.dtype() == kFloat32); const auto* weight_ptr = weight_.ptr(); const auto* input_ptr = input.ptr(); diff --git a/tests/cpu/ConvTranspose1DKernelTest.hpp b/tests/cpu/ConvTranspose1DKernelTest.hpp new file mode 100644 index 000000000..d7657baf1 --- /dev/null +++ b/tests/cpu/ConvTranspose1DKernelTest.hpp @@ -0,0 +1,134 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +#pragma once + +#include +#include + +#include "KernelTestHelper.hpp" +#include "mllm/core/ParameterFile.hpp" +#include "mllm/mllm.hpp" +#include "mllm/nn/Nn.hpp" + +using namespace mllm; // NOLINT + +void naive_conv_transpose1d(const float* input_data, const float* weight_data, const float* bias_data, float* output_data, + int batch, int in_channels, int sequence, int out_channels, int kernel_size, int stride, + int padding, int dilation, int output_padding, int groups) { + const int out_sequence = (sequence - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1; + std::fill_n(output_data, batch * out_channels * out_sequence, 0.0f); + + const int in_channels_per_group = in_channels / groups; + const int out_channels_per_group = out_channels / groups; + + for (int b = 0; b < batch; ++b) { + for (int oc = 0; oc < out_channels; ++oc) { + const int group_idx = oc / out_channels_per_group; + const int oc_in_group = oc % out_channels_per_group; + for (int out_pos = 0; out_pos < out_sequence; ++out_pos) { + float sum = 0.0f; + for (int ic_in_group = 0; ic_in_group < in_channels_per_group; ++ic_in_group) { + const int ic = group_idx * in_channels_per_group + ic_in_group; + const int base_input_idx = b * (in_channels * sequence) + ic * sequence; + const int base_weight_idx = (ic * out_channels_per_group + oc_in_group) * kernel_size; + + for (int k = 0; k < kernel_size; ++k) { + int input_pos = out_pos + padding - k * dilation; + if (input_pos % stride != 0) { continue; } + input_pos /= stride; + if (input_pos < 0 || input_pos >= sequence) { continue; } + + const int input_idx = base_input_idx + input_pos; + const int weight_idx = base_weight_idx + k; + sum += input_data[input_idx] * weight_data[weight_idx]; + } + } + if (bias_data != nullptr) { sum += bias_data[oc]; } + const int output_idx = b * (out_channels * out_sequence) + oc * out_sequence + out_pos; + output_data[output_idx] = sum; + } + } + } +} + +class ConvTranspose1DModule : public nn::Module { + nn::ConvTranspose1D conv_; + + public: + ConvTranspose1DModule(int in_channel, int out_channel, int kernel_size, int stride, int padding, int output_padding, + int dilation, int groups, bool bias) { + conv_ = reg("conv", in_channel, out_channel, kernel_size, stride, padding, output_padding, dilation, + groups, bias); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + return {conv_(inputs[0])}; + } +}; + +class ConvTranspose1DKernelTest : public KernelTest { + public: + bool testConvTranspose1DOnce(const std::unordered_map& cfg) { + auto batch = cfg.at("batch"); + auto in_channel = cfg.at("in_channel"); + auto out_channel = cfg.at("out_channel"); + auto sequence = cfg.at("sequence"); + auto kernel_size = cfg.at("kernel_size"); + auto stride = cfg.at("stride"); + auto padding = cfg.at("padding"); + auto output_padding = cfg.at("output_padding"); + auto dilation = cfg.at("dilation"); + auto groups = cfg.at("groups"); + auto bias = cfg.at("bias"); + + auto module = ConvTranspose1DModule(in_channel, out_channel, kernel_size, stride, padding, output_padding, dilation, + groups, bias); + + auto weight_param = + Tensor::random({in_channel, out_channel / groups, kernel_size}, -1, 1, kFloat32, kCPU); + auto bias_param = Tensor::random({out_channel}, -1, 1, kFloat32, kCPU); + weight_param.setName("conv.weight"); + bias_param.setName("conv.bias"); + + auto param = ParameterFile::create(); + param->push("conv.weight", weight_param); + if (bias) { param->push("conv.bias", bias_param); } + module.load(param); + + auto input = Tensor::random({batch, in_channel, sequence}, -1, 1, kFloat32, kCPU); + auto predict = module(input)[0]; + + auto expected = Tensor::zeros(predict.shape(), kFloat32, kCPU); + naive_conv_transpose1d(input.ptr(), weight_param.ptr(), bias ? bias_param.ptr() : nullptr, + expected.ptr(), batch, in_channel, sequence, out_channel, kernel_size, stride, padding, + dilation, output_padding, groups); + + auto result = test::allClose(expected, predict, 1e-4f, 1e-4f); + if (!result) { + print(result); + return false; + } + return true; + } + + bool testConvTranspose1D(const std::vector>& cfgs) { + for (auto& cfg : cfgs) { + if (!testConvTranspose1DOnce(cfg)) { + auto batch = cfg.at("batch"); + auto in_channel = cfg.at("in_channel"); + auto out_channel = cfg.at("out_channel"); + auto sequence = cfg.at("sequence"); + auto kernel_size = cfg.at("kernel_size"); + auto stride = cfg.at("stride"); + auto padding = cfg.at("padding"); + auto output_padding = cfg.at("output_padding"); + auto dilation = cfg.at("dilation"); + auto groups = cfg.at("groups"); + auto bias = cfg.at("bias"); + print(batch, in_channel, out_channel, sequence, kernel_size, stride, padding, output_padding, dilation, groups, bias); + return false; + } + } + return true; + } +}; diff --git a/tests/cpu/KernelTest.cpp b/tests/cpu/KernelTest.cpp index 9f8d613ee..575360703 100644 --- a/tests/cpu/KernelTest.cpp +++ b/tests/cpu/KernelTest.cpp @@ -857,6 +857,48 @@ TEST_F(FlashAttn2KernelTest, fwd_bshd) { } #endif +//===----------------------------------------------------------------------===// +// Tanh +//===----------------------------------------------------------------------===// +#include "TanhKernelTest.hpp" +TEST_F(TanhKernelTest, TanhFloat32) { EXPECT_EQ(testTanh({{8}, {2, 3, 4}}), true); } + +//===----------------------------------------------------------------------===// +// ConvTranspose1D +//===----------------------------------------------------------------------===// +#include "ConvTranspose1DKernelTest.hpp" +TEST_F(ConvTranspose1DKernelTest, Basic) { + EXPECT_EQ(testConvTranspose1D({ + { + {"batch", 1}, + {"in_channel", 2}, + {"out_channel", 3}, + {"sequence", 4}, + {"kernel_size", 3}, + {"stride", 2}, + {"padding", 1}, + {"output_padding", 0}, + {"dilation", 1}, + {"groups", 1}, + {"bias", 1}, + }, + { + {"batch", 2}, + {"in_channel", 1}, + {"out_channel", 2}, + {"sequence", 5}, + {"kernel_size", 2}, + {"stride", 1}, + {"padding", 0}, + {"output_padding", 0}, + {"dilation", 1}, + {"groups", 1}, + {"bias", 0}, + }, + }), + true); +} + //===----------------------------------------------------------------------===// // Conv2D Test // diff --git a/tests/cpu/TanhKernelTest.hpp b/tests/cpu/TanhKernelTest.hpp new file mode 100644 index 000000000..ff6762170 --- /dev/null +++ b/tests/cpu/TanhKernelTest.hpp @@ -0,0 +1,49 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +#pragma once + +#include + +#include "KernelTestHelper.hpp" +#include "mllm/mllm.hpp" +#include "mllm/nn/Nn.hpp" + +class TanhModule : public mllm::nn::Module { + mllm::nn::Tanh tanh_; + + public: + TanhModule() { tanh_ = reg("tanh"); } + + std::vector forward(const std::vector& inputs, + const std::vector& args) override { + return {tanh_(inputs[0])}; + } +}; + +class TanhKernelTest : public KernelTest { + public: + bool testTanh(const std::vector& shapes) { + using mllm::Tensor; + using mllm::kCPU; + using mllm::kFloat32; + TanhModule module; + + for (auto& s : shapes) { + auto input = Tensor::random(s, -3, 3, kFloat32, kCPU); + auto output = module(input)[0]; + auto expected = Tensor::empty(s, kFloat32, kCPU).alloc(); + + const auto* in_ptr = input.ptr(); + auto* out_ptr = expected.ptr(); + const auto numel = input.numel(); + for (size_t i = 0; i < numel; ++i) { out_ptr[i] = std::tanh(in_ptr[i]); } + + auto result = mllm::test::allClose(expected, output, 1e-5f, 1e-5f); + if (!result) { + mllm::print(result); + return false; + } + } + return true; + } +}; From af574aec2c8adbbe3e8db0cbe7e13c915bcbd416 Mon Sep 17 00:00:00 2001 From: KKkai <1640576073@qq.com> Date: Tue, 24 Feb 2026 15:42:37 +0800 Subject: [PATCH 15/17] add --- examples/minicpm_o45/CMakeLists.txt | 6 +- examples/qwen2_5omni/audio_out_infer.cpp | 93 ++++++++++++++++++++++++ examples/qwen2_5omni/image_infer_dbg.cpp | 91 +++++++++++++++++++++++ 3 files changed, 187 insertions(+), 3 deletions(-) create mode 100644 examples/qwen2_5omni/audio_out_infer.cpp create mode 100644 examples/qwen2_5omni/image_infer_dbg.cpp diff --git a/examples/minicpm_o45/CMakeLists.txt b/examples/minicpm_o45/CMakeLists.txt index a866fb4ec..a755efda1 100644 --- a/examples/minicpm_o45/CMakeLists.txt +++ b/examples/minicpm_o45/CMakeLists.txt @@ -2,6 +2,6 @@ add_executable(mllm-minicpm-o45-runner main.cpp) target_link_libraries(mllm-minicpm-o45-runner PRIVATE MllmRT MllmCPUBackend) target_include_directories(mllm-minicpm-o45-runner PRIVATE ${MLLM_INCLUDE_DIR}) -add_executable(mllm-minicpm-o45-runner-python main_python.cpp) -target_link_libraries(mllm-minicpm-o45-runner-python PRIVATE MllmRT MllmCPUBackend) -target_include_directories(mllm-minicpm-o45-runner-python PRIVATE ${MLLM_INCLUDE_DIR}) +# add_executable(mllm-minicpm-o45-runner-python main_python.cpp) +# target_link_libraries(mllm-minicpm-o45-runner-python PRIVATE MllmRT MllmCPUBackend) +# target_include_directories(mllm-minicpm-o45-runner-python PRIVATE ${MLLM_INCLUDE_DIR}) diff --git a/examples/qwen2_5omni/audio_out_infer.cpp b/examples/qwen2_5omni/audio_out_infer.cpp new file mode 100644 index 000000000..9e46fcd0e --- /dev/null +++ b/examples/qwen2_5omni/audio_out_infer.cpp @@ -0,0 +1,93 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include +#include +#include +#include +#include +#include +#include "wenet_audio/wav.h" + +using mllm::Argparse; + +MLLM_MAIN({ + mllm::Logger::level() = mllm::LogLevel::kError; + + auto& help = Argparse::add("-h|--help").help("Show help message"); + auto& model_path = Argparse::add("-m|--model_path").help("Model path").required(true); + auto& model_version = Argparse::add("-mv|--model_version").help("Model version").required(true); + auto& tokenizer_path = Argparse::add("-t|--tokenizer_path").help("Tokenizer directory").required(true); + auto& config_path = Argparse::add("-c|--config_path").help("Config path").required(true); + auto& spk_dict_path = Argparse::add("-s|--spk_dict_path").help("Speaker json path").required(true); + auto& prompt = Argparse::add("-p|--prompt").help("Prompt text").def(""); + auto& image_path = Argparse::add("-i|--image_path").help("Image path").def(""); + auto& audio_path = Argparse::add("-a|--audio_path").help("Audio path").def(""); + auto& speaker = Argparse::add("-sp|--speaker").help("Speaker name (default: first entry)").def(""); + auto& output_path = Argparse::add("-o|--output_path").help("Output wav path").def("./qwen2_5omni.wav"); + + Argparse::parse(argc, argv); + + mllm::ModelFileVersion file_version = mllm::ModelFileVersion::kV1; + if (model_version.get() == "v1") { + file_version = mllm::ModelFileVersion::kV1; + } else if (model_version.get() == "v2") { + file_version = mllm::ModelFileVersion::kV2; + } + + if (help.isSet()) { + Argparse::printHelp(); + mllm::shutdownContext(); + return 0; + } + + if (!image_path.get().empty() && !audio_path.get().empty()) { + MLLM_ERROR_EXIT(mllm::ExitCode::kCoreError, "Only one of --image_path or --audio_path can be set."); + } + + auto qwen_cfg = mllm::models::qwen2_5omni::Qwen2_5OmniConfig(config_path.get()); + auto qwen_tokenizer = mllm::models::qwen2_5omni::Qwen2_5OmniTokenizer(tokenizer_path.get()); + auto qwen_omni = mllm::models::qwen2_5omni::Qwen2_5OmniForConditionalGeneration(qwen_cfg); + + auto param = mllm::load(model_path.get(), file_version); + qwen_omni.load(param); + qwen_omni.loadSpeakers(spk_dict_path.get()); + + std::string prompt_text = prompt.get(); + if (prompt_text.empty()) { + fmt::print("Prompt text: "); + std::getline(std::cin, prompt_text); + if (prompt_text.empty()) { prompt_text = "Please respond."; } + } + + mllm::models::ARGenerationOutputPast inputs; + if (!image_path.get().empty()) { + inputs = qwen_tokenizer.convertVisionMessage({.prompt = prompt_text, .img_file_path = image_path.get()}); + } else if (!audio_path.get().empty()) { + inputs = qwen_tokenizer.convertAudioMessage({.prompt = prompt_text, .audio_file_path = audio_path.get()}); + } else { + inputs = qwen_tokenizer.convertMessage({.prompt = prompt_text}); + } + + mllm::models::qwen2_5omni::Qwen2_5OmniAudioGenerationConfig gen_cfg; + auto output = qwen_omni.generateAudio(inputs, gen_cfg, speaker.get()); + + auto input_len = inputs["sequence"].shape()[1]; + auto total_len = output.sequences.shape()[1]; + fmt::print("\nResponse: "); + for (int i = input_len; i < total_len; ++i) { + std::wcout << qwen_tokenizer.detokenize(output.sequences.at({0, i})) << std::flush; + } + fmt::print("\n"); + + auto wav = output.wav * 32767.0f; + wenet::WavWriter wav_writer(wav.ptr(), wav.shape().back(), 1, 24000, 16); + wav_writer.Write(output_path.get()); + + fmt::print("Saved audio to {}\n", output_path.get()); + + qwen_omni.thinker_.perfSummary(); + + mllm::print("\n"); + mllm::memoryReport(); +}) diff --git a/examples/qwen2_5omni/image_infer_dbg.cpp b/examples/qwen2_5omni/image_infer_dbg.cpp new file mode 100644 index 000000000..de21c8ec7 --- /dev/null +++ b/examples/qwen2_5omni/image_infer_dbg.cpp @@ -0,0 +1,91 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include +#include +#include +#include +#include +#include + +using mllm::Argparse; + +//MLLM_MAIN({ +int main(int argc, char** argv) { + ::mllm::__setup_signal_handler(); + ::mllm::initializeContext(); + + mllm::Logger::level() = mllm::LogLevel::kError; + + auto& help = Argparse::add("-h|--help").help("Show help message"); + auto& model_path = Argparse::add("-m|--model_path").help("Model path").required(true); + auto& model_version = Argparse::add("-mv|--model_version").help("Model version").required(true); + auto& tokenizer_path = Argparse::add("-t|--tokenizer_path").help("Tokenizer directory").required(true); + auto& config_path = Argparse::add("-c|--config_path").help("Config path").required(true); + + Argparse::parse(argc, argv); + + mllm::ModelFileVersion file_version = mllm::ModelFileVersion::kV1; + if (model_version.get() == "v1") { + file_version = mllm::ModelFileVersion::kV1; + } else if (model_version.get() == "v2") { + file_version = mllm::ModelFileVersion::kV2; + } + + if (help.isSet()) { + Argparse::printHelp(); + mllm::shutdownContext(); + return 0; + } + + { + auto qwen2_5omni_cfg = mllm::models::qwen2_5omni::Qwen2_5OmniConfig(config_path.get()); + auto qwen2_5omni_tokenizer = + mllm::models::qwen2_5omni::Qwen2_5OmniTokenizer(tokenizer_path.get(), qwen2_5omni_cfg.visual_spatial_merge_size); + auto qwen2_5omni = mllm::models::qwen2_5omni::Qwen2_5OmniForCausalLM(qwen2_5omni_cfg); + + auto param = mllm::load(model_path.get(), file_version); + qwen2_5omni.thinker_.load(param); + + fmt::print("\n{:*^60}\n", " Qwen2.5-Omni Image CLI "); + fmt::print("Enter 'exit' or 'quit' to end the session\n\n"); + + std::string image_path; + std::string prompt_text; + + fmt::print("Image path (or 'exit/quit'): "); + image_path = "../../rsc/pics.jpg"; + //std::getline(std::cin, image_path); + if (image_path == "exit" || image_path == "quit") { return 0; } + + fmt::print("Prompt text: "); + prompt_text = "描述图片中物体"; + //std::getline(std::cin, prompt_text); + + try { + fmt::print("Processing...\n"); + auto inputs = qwen2_5omni_tokenizer.convertVisionMessage({.prompt = prompt_text, .img_file_path = image_path}); + + fmt::print("\nResponse: "); + qwen2_5omni.streamGenerate(inputs, + { + {"do_sample", mllm::AnyValue(false)}, + {"max_length", mllm::AnyValue(qwen2_5omni_cfg.max_cache_length)}, + }, + [&](int64_t token_id) { + auto str = qwen2_5omni_tokenizer.detokenize(token_id); + std::wcout << str << std::flush; + }); + + fmt::print("\n{}\n", std::string(60, '-')); + } catch (const std::exception& e) { fmt::print("\nError: {}\n{}\n", e.what(), std::string(60, '-')); } + + qwen2_5omni.perfSummary(); + } + + mllm::print("\n"); + mllm::memoryReport(); + + ::mllm::shutdownContext(); + return 0; +} From 06b754c4892cc84e0a4ee2878678ac7a0afeb142 Mon Sep 17 00:00:00 2001 From: KKkai <1640576073@qq.com> Date: Thu, 5 Mar 2026 14:37:06 +0800 Subject: [PATCH 16/17] add qwen2.5o talker --- .../modeling_qwen2_5omni_talker.hpp | 626 +++++++ .../modeling_qwen2_5omni_token2wav.hpp | 1508 +++++++++++++++++ 2 files changed, 2134 insertions(+) create mode 100644 mllm/models/qwen2_5omni/modeling_qwen2_5omni_talker.hpp create mode 100644 mllm/models/qwen2_5omni/modeling_qwen2_5omni_token2wav.hpp diff --git a/mllm/models/qwen2_5omni/modeling_qwen2_5omni_talker.hpp b/mllm/models/qwen2_5omni/modeling_qwen2_5omni_talker.hpp new file mode 100644 index 000000000..df8019a84 --- /dev/null +++ b/mllm/models/qwen2_5omni/modeling_qwen2_5omni_talker.hpp @@ -0,0 +1,626 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "mllm/core/Parallel.hpp" +#include "mllm/core/SlicePrimitives.hpp" +#include "mllm/mllm.hpp" +#include "mllm/nn/Functional.hpp" +#include "mllm/nn/Module.hpp" +#include "mllm/nn/Nn.hpp" +#include "mllm/nn/lmcache/StaticCache.hpp" +#include "mllm/utils/Common.hpp" +#include "mllm/utils/Enumerate.hpp" + +#include "mllm/models/qwen2_5omni/configuration_qwen2_5omni.hpp" + +namespace mllm::models::qwen2_5omni { + +constexpr float kPi = 3.14159265358979323846f; + +inline auto makeTalkerRoPEInvFreq(int output_dim, float rope_theta) -> Tensor { + auto inv_freq = Tensor::empty({output_dim / 2}, kFloat32, kCPU).alloc(); + auto inv_freq_ptr = inv_freq.ptr(); + for (int i = 0; i < output_dim / 2; i++) { inv_freq_ptr[i] = 1.0f / std::pow(rope_theta, 2.0f * i / output_dim); } + return inv_freq; +} + +inline auto makeTalkerPositionEmbedding(Tensor& position_ids, const Tensor& inv_freq, const std::vector& mrope_section) + -> std::pair { + MLLM_RT_ASSERT_EQ(position_ids.shape().size(), 3); + MLLM_RT_ASSERT_EQ(position_ids.shape()[1], 1); + + Tensor tmp_sin = Tensor::empty({3, position_ids.shape()[2], inv_freq.shape()[0] * 2}).alloc(); + Tensor tmp_cos = Tensor::empty({3, position_ids.shape()[2], inv_freq.shape()[0] * 2}).alloc(); + + for (int b = 0; b < 3; ++b) { + for (int d = 0; d < inv_freq.shape()[0]; ++d) { + for (int s = 0; s < position_ids.shape()[2]; ++s) { + auto value = inv_freq.ptr()[d] * (*position_ids.offsettedPtr({b, 0, s})); + *tmp_cos.offsettedPtr({b, s, d}) = cosf(value); + *tmp_cos.offsettedPtr({b, s, d + inv_freq.shape()[0]}) = cosf(value); + *tmp_sin.offsettedPtr({b, s, d}) = sinf(value); + *tmp_sin.offsettedPtr({b, s, d + inv_freq.shape()[0]}) = sinf(value); + } + } + } + + Tensor sin = Tensor::nil(); + Tensor cos = Tensor::nil(); + + if (!mrope_section.empty()) { + auto double_rope_section = mrope_section; + for (int i : mrope_section) { double_rope_section.push_back(i); } + + int num_rows = tmp_sin.shape()[1]; + int num_cols = tmp_sin.shape()[2]; + + sin = Tensor::empty({num_rows, num_cols}, kFloat32, kCPU).alloc(); + cos = Tensor::empty({num_rows, num_cols}, kFloat32, kCPU).alloc(); + + std::vector start_cols; + int current_start = 0; + start_cols.push_back(current_start); + for (int s : double_rope_section) { + current_start += s; + start_cols.push_back(current_start); + } + + for (int j = 0; j < static_cast(double_rope_section.size()); ++j) { + int layer = j % 3; + int s_j = double_rope_section[j]; + int start_col_in = start_cols[j]; + int start_col_out = start_cols[j]; + for (int row = 0; row < num_rows; ++row) { + auto in_cos_row_ptr = tmp_cos.offsettedPtr({layer, row, 0}); + auto out_cos_row_ptr = cos.offsettedPtr({row, 0}); + for (int c = 0; c < s_j; ++c) { out_cos_row_ptr[start_col_out + c] = in_cos_row_ptr[start_col_in + c]; } + + auto in_sin_row_ptr = tmp_sin.offsettedPtr({layer, row, 0}); + auto out_sin_row_ptr = sin.offsettedPtr({row, 0}); + for (int c = 0; c < s_j; ++c) { out_sin_row_ptr[start_col_out + c] = in_sin_row_ptr[start_col_in + c]; } + } + } + } else { + sin = tmp_sin; + cos = tmp_cos; + } + + return {sin, cos}; +} + +struct Qwen2_5OmniSpeakerParams { + int64_t bos_token = 0; + Tensor cond = Tensor::nil(); + Tensor ref_mel = Tensor::nil(); +}; + +struct Qwen2_5OmniSpeakerMap { + std::unordered_map speakers; + std::string default_speaker; +}; + +inline Tensor tensorFromJson(const nlohmann::ordered_json& obj) { + if (!obj.contains("shape") || !obj.contains("data")) { + MLLM_ERROR_EXIT(ExitCode::kIOError, "Invalid speaker json entry: missing shape/data."); + } + auto shape = obj["shape"].get>(); + auto data = obj["data"].get>(); + + int64_t expected = 1; + for (auto dim : shape) { expected *= dim; } + MLLM_RT_ASSERT_EQ(expected, static_cast(data.size())); + + Tensor out = Tensor::empty(shape, kFloat32, kCPU).alloc(); + std::copy(data.begin(), data.end(), out.ptr()); + return out; +} + +inline Qwen2_5OmniSpeakerMap loadSpeakerMap(const std::string& path) { + std::ifstream in(path); + if (!in.is_open()) { MLLM_ERROR_EXIT(ExitCode::kIOError, "Failed to open spk_dict.json at {}", path); } + + nlohmann::ordered_json root; + in >> root; + + Qwen2_5OmniSpeakerMap map; + bool first = true; + for (auto it = root.begin(); it != root.end(); ++it) { + const auto& name = it.key(); + const auto& entry = it.value(); + Qwen2_5OmniSpeakerParams params; + params.bos_token = entry.value("bos_token", 0); + params.cond = tensorFromJson(entry["cond"]); + params.ref_mel = tensorFromJson(entry["ref_mel"]); + map.speakers.emplace(name, std::move(params)); + if (first) { + map.default_speaker = name; + first = false; + } + } + + if (map.speakers.empty()) { MLLM_ERROR_EXIT(ExitCode::kIOError, "Empty speaker map in {}", path); } + return map; +} + +class Qwen2_5OmniTalkerMLP final : public nn::Module { + nn::Linear gate_proj_; + nn::Linear up_proj_; + nn::Linear down_proj_; + nn::SiLU silu_; + + public: + Qwen2_5OmniTalkerMLP() = default; + Qwen2_5OmniTalkerMLP(const std::string& name, const Qwen2_5OmniTalkerConfig& cfg) : nn::Module(name) { + gate_proj_ = reg("gate_proj", cfg.hidden_size, cfg.intermediate_size, false); + silu_ = reg("act"); + up_proj_ = reg("up_proj", cfg.hidden_size, cfg.intermediate_size, false); + down_proj_ = reg("down_proj", cfg.intermediate_size, cfg.hidden_size, false); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = gate_proj_(inputs[0]); + x = silu_(x); + auto y = up_proj_(inputs[0]); + x = x * y; + x = down_proj_(x); + return {x}; + } +}; + +class Qwen2_5OmniTalkerAttention final : public nn::Module { + nn::Linear q_proj_; + nn::Linear k_proj_; + nn::Linear v_proj_; + nn::Linear o_proj_; + nn::MultimodalRoPE q_rope_; + nn::MultimodalRoPE k_rope_; + nn::CausalMask mask_; + nn::Softmax softmax_; + + int hidden_size_; + int head_dim_; + int num_attention_heads_; + int num_key_value_heads_; + int num_key_value_groups_; + + public: + Qwen2_5OmniTalkerAttention() = default; + + Qwen2_5OmniTalkerAttention(const std::string& name, const Qwen2_5OmniTalkerConfig& cfg) : nn::Module(name) { + hidden_size_ = cfg.hidden_size; + head_dim_ = cfg.head_dim; + num_attention_heads_ = cfg.num_attention_heads; + num_key_value_heads_ = cfg.num_key_value_heads; + num_key_value_groups_ = num_attention_heads_ / num_key_value_heads_; + + q_proj_ = reg("q_proj", hidden_size_, head_dim_ * num_attention_heads_, true); + k_proj_ = reg("k_proj", hidden_size_, head_dim_ * num_key_value_heads_, true); + v_proj_ = reg("v_proj", hidden_size_, head_dim_ * num_key_value_heads_, true); + o_proj_ = reg("o_proj", head_dim_ * num_attention_heads_, hidden_size_, false); + + q_rope_ = reg( + "q_rope", aops::Qwen2VLMultimodalRoPEOpOptions{.rope_theta = cfg.rope_theta, + .max_position_embeddings = cfg.max_position_embeddings, + .mrope_section = cfg.mrope_section}); + k_rope_ = reg( + "k_rope", aops::Qwen2VLMultimodalRoPEOpOptions{.rope_theta = cfg.rope_theta, + .max_position_embeddings = cfg.max_position_embeddings, + .mrope_section = cfg.mrope_section}); + + mask_ = reg("mask"); + softmax_ = reg("softmax", -1); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = inputs[0]; + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto past_kv_cache = args[0].get(); + + auto query_states = q_proj_(x); + auto key_states = k_proj_(x); + auto value_states = v_proj_(x); + + int B = inputs[0].shape()[0]; + int S = inputs[0].shape()[1]; + + query_states = query_states.view({B, S, num_attention_heads_, head_dim_}); + key_states = key_states.view({B, S, num_key_value_heads_, head_dim_}); + value_states = value_states.view({B, S, num_key_value_heads_, head_dim_}); + + query_states = query_states.transpose(1, 2); + key_states = key_states.transpose(1, 2); + value_states = value_states.transpose(1, 2); + + query_states = q_rope_(query_states, llm_embedding_sin, llm_embedding_cos); + key_states = k_rope_(key_states, llm_embedding_sin, llm_embedding_cos); + + auto [k, v] = past_kv_cache->updateKVCache(layer_idx_, key_states, value_states); + key_states = k; + value_states = v; + + Tensor attn; + if (key_states.dtype() == kFloat32) { + attn = nn::functional::matmul(query_states, key_states, false, true) * (1.f / sqrtf(head_dim_)); + attn = mask_(attn); + attn = softmax_(attn); + } else if (key_states.dtype() == kFloat16) { + attn = nn::functional::matmul(query_states.to(kFloat32), key_states.to(kFloat32), false, true) * (1.f / sqrtf(head_dim_)); + attn = mask_(attn); + attn = softmax_(attn); + attn = attn.to(kFloat16); + } + + auto output = nn::functional::matmul(attn, value_states); + output = output.transpose(1, 2).view({B, S, num_attention_heads_ * head_dim_}); + output = o_proj_(output); + return {output}; + } + + int layer_idx_ = 0; +}; + +class Qwen2_5OmniTalkerDecoder final : public nn::Module { + public: + Qwen2_5OmniTalkerAttention self_attn_; + Qwen2_5OmniTalkerMLP mlp_; + nn::RMSNorm input_layer_norm_; + nn::RMSNorm post_attention_layer_norm_; + + Qwen2_5OmniTalkerDecoder() = default; + + Qwen2_5OmniTalkerDecoder(const std::string& name, const Qwen2_5OmniTalkerConfig& cfg) : nn::Module(name) { + self_attn_ = reg("self_attn", cfg); + mlp_ = reg("mlp", cfg); + input_layer_norm_ = reg("input_layernorm", cfg.rms_norm_eps); + post_attention_layer_norm_ = reg("post_attention_layernorm", cfg.rms_norm_eps); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto& kv_cache = args[0]; + + auto x = input_layer_norm_(inputs[0]); + x = self_attn_(x, llm_embedding_sin, llm_embedding_cos, kv_cache)[0]; + auto tmp = x + inputs[0]; + x = post_attention_layer_norm_(tmp); + x = mlp_(x)[0]; + x = x + tmp; + return {x}; + } +}; + +class Qwen2_5OmniTalkerModel final : public nn::Module { + nn::ModuleList decode_blocks_; + nn::RMSNorm norm_; + + public: + Qwen2_5OmniTalkerModel() = default; + + Qwen2_5OmniTalkerModel(const std::string& name, const Qwen2_5OmniTalkerConfig& cfg) : nn::Module(name) { + decode_blocks_ = reg>("layers", cfg.num_hidden_layers, cfg); + for (auto [idx, b] : enumerate(decode_blocks_.list())) { b.self_attn_.layer_idx_ = idx; } + + norm_ = reg("norm", cfg.rms_norm_eps); + embedding_ = reg("embed_tokens", cfg.vocab_size, cfg.embedding_size); + + auto inv = makeTalkerRoPEInvFreq(cfg.head_dim, cfg.rope_theta); + registerBuffer("inv_freq", inv); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto& blocks = decode_blocks_.list(); + auto x = inputs[0]; + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto& kv_cache = args[0]; + + for (auto& block : blocks) { x = block(x, llm_embedding_sin, llm_embedding_cos, kv_cache)[0]; } + x = norm_(x); + + return {x}; + } + + nn::Embedding embedding_; +}; + +struct Qwen2_5OmniTalkerOutput { + Tensor logits = Tensor::nil(); + Tensor thinker_reply_part = Tensor::nil(); + Tensor position_ids = Tensor::nil(); +}; + +class Qwen2_5OmniTalker final : public nn::Module { + public: + Qwen2_5OmniTalker() = delete; + Qwen2_5OmniTalker(const std::string& name, const Qwen2_5OmniTalkerConfig& cfg) : nn::Module(name), cfg_(cfg) { + thinker_to_talker_proj_ = reg("thinker_to_talker_proj", cfg.embedding_size, cfg.hidden_size, true); + model_ = reg("model", cfg); + codec_head_ = reg("codec_head", cfg.hidden_size, cfg.vocab_size, false); + + kv_cache_ = nn::StaticCache(cfg.max_position_embeddings, cfg.num_hidden_layers, cfg.num_attention_heads, cfg.num_key_value_heads, + cfg.head_dim, kFloat32, kFloat32, kCPU, false); + + codec_bos_token_ = cfg.tts_codec_start_token_id; + codec_eos_token_ = cfg.tts_codec_end_token_id; + codec_pad_token_ = cfg.tts_codec_pad_token_id; + codec_mask_token_ = cfg.tts_codec_mask_token_id; + text_bos_token_ = cfg.tts_text_start_token_id; + text_eos_token_ = cfg.tts_text_end_token_id; + text_pad_token_ = cfg.tts_text_pad_token_id; + } + + void clearCache() { + kv_cache_.clearCache(); + rope_deltas_ = Tensor::nil(); + } + + Qwen2_5OmniTalkerOutput forward(const Tensor& input_ids, const Tensor& input_text_ids, Tensor thinker_reply_part, + Tensor inputs_embeds, const Tensor& attention_mask, const Tensor& image_grid_thw, + Tensor position_ids) { + Tensor ids_for_pos = input_text_ids.isNil() ? input_ids : input_text_ids; + position_ids = getPositionIds(ids_for_pos, image_grid_thw, position_ids); + + const bool prefill = kv_cache_.getCurrentSeqCnt(0) == 0; + if (!inputs_embeds.isNil() && prefill) { + const auto S = inputs_embeds.shape()[1]; + MLLM_RT_ASSERT(S >= 2); + + auto bos_token = Tensor::empty({1, 1}, kInt64, kCPU).alloc(); + bos_token.at({0, 0}) = codec_bos_token_; + auto bos_embed = model_.embedding_(bos_token); + + auto pad_token = Tensor::empty({1, 1}, kInt64, kCPU).alloc(); + pad_token.at({0, 0}) = codec_pad_token_; + auto pad_embed = model_.embedding_(pad_token); + + auto embed_dim = inputs_embeds.shape()[2]; + if (inputs_embeds.dtype() == kFloat32) { + auto* out_ptr = inputs_embeds.offsettedPtr({0, S - 1, 0}); + auto* pad_ptr = inputs_embeds.offsettedPtr({0, S - 2, 0}); + auto* bos_ptr = bos_embed.ptr(); + auto* pad_src_ptr = pad_embed.ptr(); + for (int d = 0; d < embed_dim; ++d) { + out_ptr[d] += bos_ptr[d]; + pad_ptr[d] += pad_src_ptr[d]; + } + } else if (inputs_embeds.dtype() == kFloat16) { + auto* out_ptr = inputs_embeds.offsettedPtr({0, S - 1, 0}); + auto* pad_ptr = inputs_embeds.offsettedPtr({0, S - 2, 0}); + auto* bos_ptr = bos_embed.ptr(); + auto* pad_src_ptr = pad_embed.ptr(); + for (int d = 0; d < embed_dim; ++d) { + out_ptr[d] = static_cast(static_cast(out_ptr[d]) + static_cast(bos_ptr[d])); + pad_ptr[d] = static_cast(static_cast(pad_ptr[d]) + static_cast(pad_src_ptr[d])); + } + } + } + + if (inputs_embeds.isNil()) { + auto codec_embeds = model_.embedding_(input_ids); + inputs_embeds = codec_embeds + thinker_reply_part[{kAll, {0, 1}, kAll}]; + if (thinker_reply_part.shape()[1] > 1) { + thinker_reply_part = thinker_reply_part[{kAll, {1, thinker_reply_part.shape()[1]}, kAll}]; + } + } + + auto [llm_embedding_sin, llm_embedding_cos] = + makeTalkerPositionEmbedding(position_ids, model_.getBuffer("inv_freq"), cfg_.mrope_section); + + auto talker_lm_input = thinker_to_talker_proj_(inputs_embeds); + auto hidden_states = model_(talker_lm_input, llm_embedding_sin, llm_embedding_cos, AnyValue(&kv_cache_))[0]; + auto logits = codec_head_(hidden_states).to(kFloat32); + + return { + .logits = logits, + .thinker_reply_part = thinker_reply_part, + .position_ids = position_ids, + }; + } + + int64_t codec_bos_token() const { return codec_bos_token_; } + int64_t codec_eos_token() const { return codec_eos_token_; } + int64_t codec_pad_token() const { return codec_pad_token_; } + int64_t codec_mask_token() const { return codec_mask_token_; } + int64_t text_eos_token() const { return text_eos_token_; } + int64_t text_pad_token() const { return text_pad_token_; } + int64_t text_bos_token() const { return text_bos_token_; } + + Qwen2_5OmniTalkerModel model_; + + private: + Tensor getPositionIds(const Tensor& input_ids, const Tensor& image_grid_thw, const Tensor& position_ids) const { + MLLM_RT_ASSERT_EQ(input_ids.shape().size(), 2); + + bool has_multimodal = false; + auto input_ids_ptr = input_ids.ptr(); + auto seq_len = input_ids.shape()[1]; + for (int s = 0; s < seq_len; ++s) { + if (input_ids_ptr[s] == cfg_.vision_start_token_id || input_ids_ptr[s] == cfg_.audio_start_token_id) { + has_multimodal = true; + break; + } + } + + if (has_multimodal) { return getPositionIdsPrefill(input_ids, image_grid_thw); } + + if (!position_ids.isNil()) { + auto last_pos = position_ids.constAt({0, 0, position_ids.shape()[2] - 1}); + auto ret_position_ids = Tensor::empty({3, 1, 1}, kInt64, kCPU).alloc(); + *ret_position_ids.offsettedPtr({0, 0, 0}) = last_pos + 1; + *ret_position_ids.offsettedPtr({1, 0, 0}) = last_pos + 1; + *ret_position_ids.offsettedPtr({2, 0, 0}) = last_pos + 1; + return ret_position_ids; + } + + auto B = input_ids.shape()[0]; + auto S = seq_len; + MLLM_RT_ASSERT_EQ(B, 1); + + Tensor out = Tensor::empty({3, B, S}, kInt64, kCPU).alloc(); + for (int d = 0; d < 3; ++d) { + auto out_ptr = out.offsettedPtr({d, 0, 0}); + for (int64_t s = 0; s < S; ++s) { out_ptr[s] = s; } + } + return out; + } + + Tensor getPositionIdsPrefill(const Tensor& input_ids, const Tensor& image_grid_thw) const { + MLLM_RT_ASSERT_EQ(input_ids.shape().size(), 2); + + auto B = input_ids.shape()[0]; + auto S = input_ids.shape()[1]; + MLLM_RT_ASSERT_EQ(B, 1); + + Tensor position_ids = Tensor::empty({3, B, S}, kInt64, kCPU).alloc(); + auto input_ids_ptr = input_ids.ptr(); + + auto fill_text_positions = [&](int start_seq, int len, int64_t start_id) { + for (int d = 0; d < 3; ++d) { + auto out_ptr = position_ids.offsettedPtr({d, 0, 0}); + for (int i = 0; i < len; ++i) { out_ptr[start_seq + i] = start_id + i; } + } + }; + + int seq_idx = 0; + int image_idx = 0; + int64_t current_max_position_id = -1; + const int total_images = image_grid_thw.isNil() ? 0 : image_grid_thw.shape()[0]; + + while (seq_idx < S) { + int next_vision = -1; + int next_audio = -1; + for (int i = seq_idx; i < S; ++i) { + if (input_ids_ptr[i] == cfg_.vision_start_token_id) { + next_vision = i; + break; + } + } + for (int i = seq_idx; i < S; ++i) { + if (input_ids_ptr[i] == cfg_.audio_start_token_id) { + next_audio = i; + break; + } + } + + if (next_vision == -1 && next_audio == -1) { + const int text_len = S - seq_idx; + if (text_len > 0) { fill_text_positions(seq_idx, text_len, current_max_position_id + 1); } + break; + } + + const bool is_vision = (next_vision != -1) && (next_audio == -1 || next_vision < next_audio); + const int segment_start = is_vision ? next_vision : next_audio; + + const int text_len = segment_start - seq_idx; + if (text_len > 0) { + fill_text_positions(seq_idx, text_len, current_max_position_id + 1); + current_max_position_id += text_len; + } + + if (is_vision) { + fill_text_positions(segment_start, 1, current_max_position_id + 1); + current_max_position_id += 1; + + int vision_end = -1; + for (int i = segment_start + 1; i < S; ++i) { + if (input_ids_ptr[i] == cfg_.vision_end_token_id) { + vision_end = i; + break; + } + } + MLLM_RT_ASSERT(vision_end != -1); + + if (image_idx >= total_images) { MLLM_ERROR_EXIT(ExitCode::kCoreError, "Image index out of range."); } + + auto grid_t = image_grid_thw.ptr()[image_idx * 3]; + auto grid_h = image_grid_thw.ptr()[image_idx * 3 + 1]; + auto grid_w = image_grid_thw.ptr()[image_idx * 3 + 2]; + int vision_len = grid_t * grid_h * grid_w; + vision_len /= (cfg_.spatial_merge_size * cfg_.spatial_merge_size); + + for (int i = 0; i < vision_len; ++i) { + const int pos = segment_start + 1 + i; + if (pos >= S) { break; } + for (int d = 0; d < 3; ++d) { + *position_ids.offsettedPtr({d, 0, pos}) = current_max_position_id + 1 + i; + } + } + current_max_position_id += vision_len; + + fill_text_positions(vision_end, 1, current_max_position_id + 1); + current_max_position_id += 1; + + seq_idx = vision_end + 1; + image_idx += 1; + } else { + fill_text_positions(segment_start, 1, current_max_position_id + 1); + current_max_position_id += 1; + + int audio_end = -1; + for (int i = segment_start + 1; i < S; ++i) { + if (input_ids_ptr[i] == cfg_.audio_end_token_id) { + audio_end = i; + break; + } + } + MLLM_RT_ASSERT(audio_end != -1); + + std::vector audio_positions; + for (int i = segment_start + 1; i < audio_end; ++i) { + if (input_ids_ptr[i] == cfg_.audio_token_id) { + audio_positions.push_back(i); + } else { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Unsupported token inside audio segment."); + } + } + const int audio_len = static_cast(audio_positions.size()); + if (audio_len == 0) { MLLM_ERROR_EXIT(ExitCode::kCoreError, "Empty audio tokens inside audio segment."); } + const int64_t audio_start_id = current_max_position_id + 1; + for (int i = 0; i < audio_len; ++i) { + const int64_t pos_id = audio_start_id + i; + for (int d = 0; d < 3; ++d) { + *position_ids.offsettedPtr({d, 0, audio_positions[i]}) = pos_id; + } + } + current_max_position_id += audio_len; + fill_text_positions(audio_end, 1, current_max_position_id + 1); + current_max_position_id += 1; + seq_idx = audio_end + 1; + } + } + + return position_ids; + } + + const Qwen2_5OmniTalkerConfig& cfg_; + nn::Linear thinker_to_talker_proj_; + nn::Linear codec_head_; + nn::StaticCache kv_cache_; + Tensor rope_deltas_ = Tensor::nil(); + + int64_t codec_bos_token_ = 0; + int64_t codec_eos_token_ = 0; + int64_t codec_pad_token_ = 0; + int64_t codec_mask_token_ = 0; + int64_t text_bos_token_ = 0; + int64_t text_eos_token_ = 0; + int64_t text_pad_token_ = 0; +}; + +} // namespace mllm::models::qwen2_5omni diff --git a/mllm/models/qwen2_5omni/modeling_qwen2_5omni_token2wav.hpp b/mllm/models/qwen2_5omni/modeling_qwen2_5omni_token2wav.hpp new file mode 100644 index 000000000..6e5939a44 --- /dev/null +++ b/mllm/models/qwen2_5omni/modeling_qwen2_5omni_token2wav.hpp @@ -0,0 +1,1508 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "mllm/core/Parallel.hpp" +#include "mllm/core/SlicePrimitives.hpp" +#include "mllm/mllm.hpp" +#include "mllm/nn/Functional.hpp" +#include "mllm/nn/Module.hpp" +#include "mllm/nn/Nn.hpp" +#include "mllm/utils/Common.hpp" +#include "mllm/utils/Enumerate.hpp" + +#include "mllm/models/qwen2_5omni/configuration_qwen2_5omni.hpp" + +namespace mllm::models::qwen2_5omni { + +namespace token2wav { + +constexpr float kPi = 3.14159265358979323846f; + +inline Tensor pad1dReflect(const Tensor& x, int32_t pad_left, int32_t pad_right) { + if (pad_left == 0 && pad_right == 0) { return x; } + return nn::functional::pad(x, {pad_left, pad_right}, aops::PadMode::kReflect); +} + +inline Tensor pad1dReplicate(const Tensor& x, int32_t pad_left, int32_t pad_right) { + if (pad_left == 0 && pad_right == 0) { return x; } + return nn::functional::pad(x, {pad_left, pad_right}, aops::PadMode::kReplicate); +} + +inline Tensor clampTensor(const Tensor& x, float min_val, float max_val) { + MLLM_RT_ASSERT_EQ(x.device(), kCPU); + MLLM_RT_ASSERT_EQ(x.dtype(), kFloat32); + + auto out = Tensor::empty(x.shape(), x.dtype(), x.device()).alloc(); + const auto* src = x.ptr(); + auto* dst = out.ptr(); + const auto numel = x.numel(); + + MLLM_CONDITIONAL_PARALLEL_FOR(numel > 1024, 4, idx, 0, numel, 1, { + float v = src[idx]; + v = std::min(std::max(v, min_val), max_val); + dst[idx] = v; + }); + return out; +} + +inline Tensor amplitudeToDb(const Tensor& amplitude, float min_db_level) { + MLLM_RT_ASSERT_EQ(amplitude.device(), kCPU); + MLLM_RT_ASSERT_EQ(amplitude.dtype(), kFloat32); + + const float min_level = std::exp(min_db_level / 20.0f * std::log(10.0f)); + const float log10_scale = 1.0f / std::log(10.0f); + + auto out = Tensor::empty(amplitude.shape(), amplitude.dtype(), amplitude.device()).alloc(); + const auto* src = amplitude.ptr(); + auto* dst = out.ptr(); + const auto numel = amplitude.numel(); + + MLLM_CONDITIONAL_PARALLEL_FOR(numel > 1024, 4, idx, 0, numel, 1, { + float v = std::max(src[idx], min_level); + dst[idx] = 20.0f * std::log(v) * log10_scale; + }); + + return out; +} + +inline Tensor normalizeSpectrogram(const Tensor& spectrogram, float max_value, float min_db) { + MLLM_RT_ASSERT_EQ(spectrogram.device(), kCPU); + MLLM_RT_ASSERT_EQ(spectrogram.dtype(), kFloat32); + + auto out = Tensor::empty(spectrogram.shape(), spectrogram.dtype(), spectrogram.device()).alloc(); + const auto* src = spectrogram.ptr(); + auto* dst = out.ptr(); + const auto numel = spectrogram.numel(); + + const float scale = (2.0f * max_value) / (-min_db); + MLLM_CONDITIONAL_PARALLEL_FOR(numel > 1024, 4, idx, 0, numel, 1, { + float v = scale * (src[idx] - min_db) - max_value; + v = std::min(std::max(v, -max_value), max_value); + dst[idx] = v; + }); + return out; +} + +inline float besselI0(float x) { + const float ax = std::abs(x); + if (ax < 3.75f) { + const float y = (ax / 3.75f); + const float y2 = y * y; + return 1.0f + y2 * (3.5156229f + + y2 * (3.0899424f + + y2 * (1.2067492f + + y2 * (0.2659732f + + y2 * (0.0360768f + + y2 * 0.0045813f))))); + } + + const float y = 3.75f / ax; + const float exp_ax = std::exp(ax); + return (exp_ax / std::sqrt(ax)) * + (0.39894228f + + y * (0.01328592f + + y * (0.00225319f + + y * (-0.00157565f + + y * (0.00916281f + + y * (-0.02057706f + + y * (0.02635537f + + y * (-0.01647633f + + y * 0.00392377f)))))))); +} + +inline Tensor kaiserSincFilter1d(float cutoff, float half_width, int32_t kernel_size) { + const bool is_even = (kernel_size % 2) == 0; + const int32_t half_size = kernel_size / 2; + + if (cutoff == 0.0f) { return Tensor::zeros({1, 1, kernel_size}, kFloat32, kCPU); } + + const float delta_f = 4.0f * half_width; + const float attenuation = 2.285f * static_cast(half_size - 1) * kPi * delta_f + 7.95f; + + float beta = 0.0f; + if (attenuation > 50.0f) { + beta = 0.1102f * (attenuation - 8.7f); + } else if (attenuation >= 21.0f) { + beta = 0.5842f * std::pow(attenuation - 21.0f, 0.4f) + 0.07886f * (attenuation - 21.0f); + } + + const float denom = besselI0(beta); + std::vector window(kernel_size, 1.0f); + for (int32_t n = 0; n < kernel_size; ++n) { + const float ratio = (2.0f * static_cast(n) / static_cast(kernel_size - 1)) - 1.0f; + const float val = std::sqrt(std::max(0.0f, 1.0f - ratio * ratio)); + window[n] = besselI0(beta * val) / denom; + } + + std::vector filter(kernel_size, 0.0f); + float sum = 0.0f; + for (int32_t n = 0; n < kernel_size; ++n) { + float t = static_cast(n) - static_cast(half_size); + if (is_even) { t += 0.5f; } + const float arg = 2.0f * cutoff * t; + const float sinc = (arg == 0.0f) ? 1.0f : std::sin(kPi * arg) / (kPi * arg); + const float v = 2.0f * cutoff * window[n] * sinc; + filter[n] = v; + sum += v; + } + + if (sum != 0.0f) { + for (auto& v : filter) { v /= sum; } + } + + auto out = Tensor::empty({1, 1, kernel_size}, kFloat32, kCPU).alloc(); + std::copy(filter.begin(), filter.end(), out.ptr()); + return out; +} + +inline Tensor convTranspose1dDepthwise(const Tensor& input, const Tensor& filter, int32_t stride) { + MLLM_RT_ASSERT_EQ(input.device(), kCPU); + MLLM_RT_ASSERT_EQ(input.dtype(), kFloat32); + MLLM_RT_ASSERT_EQ(filter.device(), kCPU); + MLLM_RT_ASSERT_EQ(filter.dtype(), kFloat32); + + const auto& in_shape = input.shape(); + const int32_t batch = in_shape[0]; + const int32_t channels = in_shape[1]; + const int32_t in_len = in_shape[2]; + const int32_t kernel = filter.shape()[2]; + + const int32_t out_len = (in_len - 1) * stride + kernel; + auto out = Tensor::zeros({batch, channels, out_len}, kFloat32, kCPU); + + const auto* in_ptr = input.ptr(); + const auto* filt_ptr = filter.ptr(); + auto* out_ptr = out.ptr(); + + const int32_t in_step = channels * in_len; + const int32_t out_step = channels * out_len; + + for (int32_t b = 0; b < batch; ++b) { + const float* in_b = in_ptr + b * in_step; + float* out_b = out_ptr + b * out_step; + for (int32_t c = 0; c < channels; ++c) { + const float* in_c = in_b + c * in_len; + float* out_c = out_b + c * out_len; + const float* f = filt_ptr; + for (int32_t i = 0; i < in_len; ++i) { + const float v = in_c[i]; + const int32_t base = i * stride; + for (int32_t k = 0; k < kernel; ++k) { out_c[base + k] += v * f[k]; } + } + } + } + + return out; +} + +inline Tensor conv1dDepthwise(const Tensor& input, const Tensor& filter, int32_t stride) { + MLLM_RT_ASSERT_EQ(input.device(), kCPU); + MLLM_RT_ASSERT_EQ(input.dtype(), kFloat32); + MLLM_RT_ASSERT_EQ(filter.device(), kCPU); + MLLM_RT_ASSERT_EQ(filter.dtype(), kFloat32); + + const auto& in_shape = input.shape(); + const int32_t batch = in_shape[0]; + const int32_t channels = in_shape[1]; + const int32_t in_len = in_shape[2]; + const int32_t kernel = filter.shape()[2]; + + const int32_t out_len = (in_len - kernel) / stride + 1; + auto out = Tensor::zeros({batch, channels, out_len}, kFloat32, kCPU); + + const auto* in_ptr = input.ptr(); + const auto* filt_ptr = filter.ptr(); + auto* out_ptr = out.ptr(); + + const int32_t in_step = channels * in_len; + const int32_t out_step = channels * out_len; + + for (int32_t b = 0; b < batch; ++b) { + const float* in_b = in_ptr + b * in_step; + float* out_b = out_ptr + b * out_step; + for (int32_t c = 0; c < channels; ++c) { + const float* in_c = in_b + c * in_len; + float* out_c = out_b + c * out_len; + const float* f = filt_ptr; + for (int32_t o = 0; o < out_len; ++o) { + float sum = 0.0f; + const int32_t base = o * stride; + for (int32_t k = 0; k < kernel; ++k) { sum += in_c[base + k] * f[k]; } + out_c[o] = sum; + } + } + } + + return out; +} + +inline Tensor randomNormal(const std::vector& shape, float mean = 0.0f, float std = 1.0f) { + auto out = Tensor::empty(shape, kFloat32, kCPU).alloc(); + auto* ptr = out.ptr(); + const int64_t numel = out.numel(); + std::mt19937 gen(static_cast(mllm::Context::instance().getRandomState())); + std::normal_distribution dist(mean, std); + for (int64_t i = 0; i < numel; ++i) { ptr[i] = dist(gen); } + return out; +} + +inline Tensor linspace(float start, float end, int32_t steps) { + auto out = Tensor::empty({steps}, kFloat32, kCPU).alloc(); + auto* ptr = out.ptr(); + if (steps <= 1) { + if (steps == 1) { ptr[0] = start; } + return out; + } + const float step = (end - start) / static_cast(steps - 1); + for (int32_t i = 0; i < steps; ++i) { ptr[i] = start + step * static_cast(i); } + return out; +} + +inline Tensor repeatInterleave(const Tensor& input, int32_t repeats, int32_t dim) { + MLLM_RT_ASSERT_EQ(input.device(), kCPU); + MLLM_RT_ASSERT_EQ(input.dtype(), kFloat32); + MLLM_RT_ASSERT_EQ(dim, 1); + + if (repeats == 1) { return input; } + + const auto& shape = input.shape(); + const int32_t batch = shape[0]; + const int32_t seq_len = shape[1]; + const int32_t channels = shape[2]; + + auto out = Tensor::empty({batch, seq_len * repeats, channels}, kFloat32, kCPU).alloc(); + const auto* src = input.ptr(); + auto* dst = out.ptr(); + + const int64_t in_stride_b = static_cast(seq_len) * channels; + const int64_t out_stride_b = static_cast(seq_len) * repeats * channels; + + for (int32_t b = 0; b < batch; ++b) { + const float* src_b = src + b * in_stride_b; + float* dst_b = dst + b * out_stride_b; + for (int32_t s = 0; s < seq_len; ++s) { + const float* src_s = src_b + static_cast(s) * channels; + for (int32_t r = 0; r < repeats; ++r) { + float* dst_s = dst_b + (static_cast(s) * repeats + r) * channels; + std::memcpy(dst_s, src_s, sizeof(float) * channels); + } + } + } + + return out; +} + +class SnakeBeta final : public nn::Module { + nn::Param alpha_; + nn::Param beta_; + float no_div_by_zero_ = 1e-9f; + + public: + SnakeBeta() = default; + SnakeBeta(const std::string& name, int32_t in_features) : nn::Module(name) { + alpha_ = reg("alpha", getModuleName() + ".alpha", std::vector{in_features}); + beta_ = reg("beta", getModuleName() + ".beta", std::vector{in_features}); + } + + std::vector forward(const std::vector& inputs, const std::vector&) override { + auto x = inputs[0]; + MLLM_RT_ASSERT_EQ(x.device(), kCPU); + MLLM_RT_ASSERT_EQ(x.dtype(), kFloat32); + if (!x.isContiguous()) { x = x.contiguous(); } + + const auto& shape = x.shape(); + const int32_t batch = shape[0]; + const int32_t channels = shape[1]; + const int32_t seq_len = shape[2]; + + auto y = Tensor::empty(shape, kFloat32, kCPU).alloc(); + const auto* x_ptr = x.ptr(); + auto* y_ptr = y.ptr(); + + auto alpha = alpha_.weight(); + auto beta = beta_.weight(); + const auto* alpha_ptr = alpha.ptr(); + const auto* beta_ptr = beta.ptr(); + + const int32_t stride_c = seq_len; + const int32_t stride_b = channels * seq_len; + + for (int32_t b = 0; b < batch; ++b) { + for (int32_t c = 0; c < channels; ++c) { + const float a = std::exp(alpha_ptr[c]); + const float bb = std::exp(beta_ptr[c]); + const float inv_b = 1.0f / (bb + no_div_by_zero_); + const int32_t base = b * stride_b + c * stride_c; + for (int32_t t = 0; t < seq_len; ++t) { + float v = x_ptr[base + t]; + const float s = std::sin(v * a); + v = v + inv_b * (s * s); + y_ptr[base + t] = v; + } + } + } + + return {y}; + } + +}; + +class TorchActivation1d final : public nn::Module { + public: + TorchActivation1d() = default; + TorchActivation1d(const std::string& name, int32_t channels, int32_t up_ratio = 2, int32_t down_ratio = 2, + int32_t up_kernel_size = 12, int32_t down_kernel_size = 12) + : nn::Module(name), + up_ratio_(up_ratio), + down_ratio_(down_ratio), + up_kernel_size_(up_kernel_size), + down_kernel_size_(down_kernel_size) { + act_ = reg("act", channels); + + up_kernel_size_ = (up_kernel_size_ <= 0) ? static_cast(int(6 * up_ratio_ / 2) * 2) : up_kernel_size_; + up_stride_ = up_ratio_; + up_pad_ = up_kernel_size_ / up_ratio_ - 1; + up_pad_left_ = up_pad_ * up_stride_ + (up_kernel_size_ - up_stride_) / 2; + up_pad_right_ = up_pad_ * up_stride_ + (up_kernel_size_ - up_stride_ + 1) / 2; + + down_kernel_size_ = (down_kernel_size_ <= 0) ? static_cast(int(6 * down_ratio_ / 2) * 2) : down_kernel_size_; + down_stride_ = down_ratio_; + down_even_ = (down_kernel_size_ % 2) == 0; + down_pad_left_ = down_kernel_size_ / 2 - (down_even_ ? 1 : 0); + down_pad_right_ = down_kernel_size_ / 2; + + up_filter_ = kaiserSincFilter1d(0.5f / static_cast(up_ratio_), 0.6f / static_cast(up_ratio_), up_kernel_size_); + down_filter_ = + kaiserSincFilter1d(0.5f / static_cast(down_ratio_), 0.6f / static_cast(down_ratio_), down_kernel_size_); + } + + std::vector forward(const std::vector& inputs, const std::vector&) override { + auto x = inputs[0]; + x = upsample(x); + x = act_(x)[0]; + x = downsample(x); + return {x}; + } + + private: + Tensor upsample(const Tensor& input) const { + auto padded = pad1dReplicate(input, up_pad_, up_pad_); + auto out = convTranspose1dDepthwise(padded, up_filter_, up_stride_); + out = out * static_cast(up_ratio_); + if (up_pad_left_ > 0 || up_pad_right_ > 0) { + auto length = out.shape()[2]; + auto start = up_pad_left_; + auto end = length - up_pad_right_; + out = out[{kAll, kAll, {start, end}}]; + } + return out; + } + + Tensor downsample(const Tensor& input) const { + auto padded = pad1dReplicate(input, down_pad_left_, down_pad_right_); + auto out = conv1dDepthwise(padded, down_filter_, down_stride_); + return out; + } + + SnakeBeta act_; + int32_t up_ratio_ = 2; + int32_t down_ratio_ = 2; + int32_t up_kernel_size_ = 12; + int32_t down_kernel_size_ = 12; + int32_t up_stride_ = 2; + int32_t down_stride_ = 2; + int32_t up_pad_ = 0; + int32_t up_pad_left_ = 0; + int32_t up_pad_right_ = 0; + int32_t down_pad_left_ = 0; + int32_t down_pad_right_ = 0; + bool down_even_ = false; + Tensor up_filter_ = Tensor::nil(); + Tensor down_filter_ = Tensor::nil(); +}; + +class TimeDelayNetBlock final : public nn::Module { + public: + TimeDelayNetBlock() = default; + TimeDelayNetBlock(const std::string& name, int32_t in_channels, int32_t out_channels, int32_t kernel_size, int32_t dilation) + : nn::Module(name), kernel_size_(kernel_size), dilation_(dilation) { + conv_ = reg("conv", in_channels, out_channels, kernel_size_, 1, 0, dilation_, 1, true); + relu_ = reg("relu"); + } + + std::vector forward(const std::vector& inputs, const std::vector&) override { + auto x = inputs[0]; + const int32_t pad_total = dilation_ * (kernel_size_ - 1); + const int32_t pad_left = pad_total / 2; + const int32_t pad_right = pad_total - pad_left; + if (pad_total > 0) { x = pad1dReflect(x, pad_left, pad_right); } + x = conv_(x); + x = relu_(x); + return {x}; + } + + private: + nn::Conv1D conv_; + nn::ReLU relu_; + int32_t kernel_size_ = 1; + int32_t dilation_ = 1; +}; + +class Res2NetBlock final : public nn::Module { + public: + Res2NetBlock() = default; + Res2NetBlock(const std::string& name, int32_t in_channels, int32_t out_channels, int32_t scale, int32_t kernel_size, int32_t dilation) + : nn::Module(name), scale_(scale) { + const int32_t in_channel = in_channels / scale; + const int32_t hidden_channel = out_channels / scale; + blocks_ = reg>("blocks", scale_ - 1, in_channel, hidden_channel, kernel_size, dilation); + } + + std::vector forward(const std::vector& inputs, const std::vector&) override { + auto x = inputs[0]; + const int32_t channels = x.shape()[1]; + const int32_t split = channels / scale_; + + std::vector outputs; + outputs.reserve(scale_); + Tensor output_part = Tensor::nil(); + + for (int32_t i = 0; i < scale_; ++i) { + auto hidden_part = x[{kAll, {i * split, (i + 1) * split}, kAll}]; + if (i == 0) { + output_part = hidden_part; + } else if (i == 1) { + output_part = blocks_.list()[i - 1](hidden_part)[0]; + } else { + output_part = blocks_.list()[i - 1](hidden_part + output_part)[0]; + } + outputs.push_back(output_part); + } + + auto out = nn::functional::concat(outputs, 1); + return {out}; + } + + private: + int32_t scale_ = 1; + nn::ModuleList blocks_; +}; + +class SqueezeExcitationBlock final : public nn::Module { + public: + SqueezeExcitationBlock() = default; + SqueezeExcitationBlock(const std::string& name, int32_t in_channels, int32_t se_channels, int32_t out_channels) + : nn::Module(name) { + conv1_ = reg("conv1", in_channels, se_channels, 1, 1, 0, 1, 1, true); + conv2_ = reg("conv2", se_channels, out_channels, 1, 1, 0, 1, 1, true); + relu_ = reg("relu"); + sigmoid_ = reg("sigmoid"); + } + + std::vector forward(const std::vector& inputs, const std::vector&) override { + auto hidden_states = inputs[0]; + auto hidden_mean = nn::functional::mean(hidden_states, 2, true); + hidden_mean = relu_(conv1_(hidden_mean)); + hidden_mean = sigmoid_(conv2_(hidden_mean)); + hidden_states = hidden_states * hidden_mean; + return {hidden_states}; + } + + private: + nn::Conv1D conv1_; + nn::Conv1D conv2_; + nn::ReLU relu_; + nn::Sigmoid sigmoid_; +}; + +class AttentiveStatisticsPooling final : public nn::Module { + public: + AttentiveStatisticsPooling() = default; + AttentiveStatisticsPooling(const std::string& name, int32_t channels, int32_t attention_channels) + : nn::Module(name), channels_(channels) { + tdnn_ = reg("tdnn", channels * 3, attention_channels, 1, 1); + tanh_ = reg("tanh"); + conv_ = reg("conv", attention_channels, channels, 1, 1, 0, 1, 1, true); + } + + std::vector forward(const std::vector& inputs, const std::vector&) override { + auto hidden_states = inputs[0]; + MLLM_RT_ASSERT_EQ(hidden_states.dtype(), kFloat32); + MLLM_RT_ASSERT_EQ(hidden_states.device(), kCPU); + + const int32_t batch = hidden_states.shape()[0]; + const int32_t channels = hidden_states.shape()[1]; + const int32_t seq_len = hidden_states.shape()[2]; + + auto mean = Tensor::empty({batch, channels}, kFloat32, kCPU).alloc(); + auto std = Tensor::empty({batch, channels}, kFloat32, kCPU).alloc(); + + const auto* x_ptr = hidden_states.ptr(); + auto* mean_ptr = mean.ptr(); + auto* std_ptr = std.ptr(); + + const int32_t stride_c = seq_len; + const int32_t stride_b = channels * seq_len; + + for (int32_t b = 0; b < batch; ++b) { + for (int32_t c = 0; c < channels; ++c) { + const int32_t base = b * stride_b + c * stride_c; + float sum = 0.0f; + for (int32_t t = 0; t < seq_len; ++t) { sum += x_ptr[base + t]; } + float m = sum / static_cast(seq_len); + mean_ptr[b * channels + c] = m; + + float var = 0.0f; + for (int32_t t = 0; t < seq_len; ++t) { + float diff = x_ptr[base + t] - m; + var += diff * diff; + } + var /= static_cast(seq_len); + std_ptr[b * channels + c] = std::sqrt(std::max(var, 1e-12f)); + } + } + + auto mean_rep = mean.view({batch, channels, 1}).repeat(seq_len, 2); + auto std_rep = std.view({batch, channels, 1}).repeat(seq_len, 2); + + auto attention = nn::functional::concat({hidden_states, mean_rep, std_rep}, 1); + attention = tdnn_(attention)[0]; + attention = tanh_(attention); + attention = conv_(attention); + attention = nn::functional::softmax(attention, 2); + + auto out_mean = Tensor::empty({batch, channels}, kFloat32, kCPU).alloc(); + auto out_std = Tensor::empty({batch, channels}, kFloat32, kCPU).alloc(); + auto* out_mean_ptr = out_mean.ptr(); + auto* out_std_ptr = out_std.ptr(); + const auto* attn_ptr = attention.ptr(); + + for (int32_t b = 0; b < batch; ++b) { + for (int32_t c = 0; c < channels; ++c) { + const int32_t base = b * stride_b + c * stride_c; + float m = 0.0f; + for (int32_t t = 0; t < seq_len; ++t) { m += attn_ptr[base + t] * x_ptr[base + t]; } + out_mean_ptr[b * channels + c] = m; + + float var = 0.0f; + for (int32_t t = 0; t < seq_len; ++t) { + float diff = x_ptr[base + t] - m; + var += attn_ptr[base + t] * diff * diff; + } + out_std_ptr[b * channels + c] = std::sqrt(std::max(var, 1e-12f)); + } + } + + auto pooled = nn::functional::concat({out_mean, out_std}, 1).view({batch, channels * 2, 1}); + return {pooled}; + } + + private: + int32_t channels_ = 0; + TimeDelayNetBlock tdnn_; + nn::Tanh tanh_; + nn::Conv1D conv_; +}; + +class SqueezeExcitationRes2NetBlock final : public nn::Module { + public: + SqueezeExcitationRes2NetBlock() = default; + SqueezeExcitationRes2NetBlock(const std::string& name, int32_t in_channels, int32_t out_channels, int32_t res2net_scale, + int32_t se_channels, int32_t kernel_size, int32_t dilation) + : nn::Module(name), out_channels_(out_channels) { + tdnn1_ = reg("tdnn1", in_channels, out_channels, 1, 1); + res2net_block_ = reg("res2net_block", out_channels, out_channels, res2net_scale, kernel_size, dilation); + tdnn2_ = reg("tdnn2", out_channels, out_channels, 1, 1); + se_block_ = reg("se_block", out_channels, se_channels, out_channels); + } + + std::vector forward(const std::vector& inputs, const std::vector&) override { + auto hidden_state = inputs[0]; + auto residual = hidden_state; + + hidden_state = tdnn1_(hidden_state)[0]; + hidden_state = res2net_block_(hidden_state)[0]; + hidden_state = tdnn2_(hidden_state)[0]; + hidden_state = se_block_(hidden_state)[0]; + hidden_state = hidden_state + residual; + return {hidden_state}; + } + + private: + int32_t out_channels_ = 0; + TimeDelayNetBlock tdnn1_; + Res2NetBlock res2net_block_; + TimeDelayNetBlock tdnn2_; + SqueezeExcitationBlock se_block_; +}; + +class ECAPA_TimeDelayNet final : public nn::Module { + public: + ECAPA_TimeDelayNet() = default; + explicit ECAPA_TimeDelayNet(const std::string& name, const Qwen2_5OmniDiTConfig& cfg) : nn::Module(name) { + if (cfg.enc_channels.size() != cfg.enc_kernel_sizes.size() || cfg.enc_channels.size() != cfg.enc_dilations.size()) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "enc_channels, enc_kernel_sizes and enc_dilations should have same length"); + } + + if (cfg.enc_channels.empty()) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "enc_channels should not be empty"); + } + + const int32_t num_blocks = static_cast(cfg.enc_channels.size()); + tdnn0_ = reg("blocks.0", cfg.mel_dim, cfg.enc_channels[0], cfg.enc_kernel_sizes[0], cfg.enc_dilations[0]); + + for (int32_t i = 1; i < num_blocks - 1; ++i) { + se_blocks_.emplace_back(reg( + "blocks." + std::to_string(i), + cfg.enc_channels[i - 1], + cfg.enc_channels[i], + cfg.enc_res2net_scale, + cfg.enc_se_channels, + cfg.enc_kernel_sizes[i], + cfg.enc_dilations[i])); + } + + mfa_ = reg("mfa", cfg.enc_channels.back(), cfg.enc_channels.back(), cfg.enc_kernel_sizes.back(), + cfg.enc_dilations.back()); + asp_ = reg("asp", cfg.enc_channels.back(), cfg.enc_attention_channels); + fc_ = reg("fc", cfg.enc_channels.back() * 2, cfg.enc_dim, 1, 1, 0, 1, 1, true); + } + + std::vector forward(const std::vector& inputs, const std::vector&) override { + auto hidden_states = inputs[0]; + MLLM_RT_ASSERT_EQ(hidden_states.dtype(), kFloat32); + MLLM_RT_ASSERT_EQ(hidden_states.device(), kCPU); + + hidden_states = hidden_states.transpose(1, 2); + + std::vector hidden_states_list; + hidden_states = tdnn0_(hidden_states)[0]; + hidden_states_list.push_back(hidden_states); + + for (auto& block : se_blocks_) { + hidden_states = block(hidden_states)[0]; + hidden_states_list.push_back(hidden_states); + } + + if (hidden_states_list.size() <= 1) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "ECAPA_TimeDelayNet expects at least 2 blocks."); + } + + std::vector mfa_inputs; + for (size_t i = 1; i < hidden_states_list.size(); ++i) { mfa_inputs.push_back(hidden_states_list[i]); } + hidden_states = nn::functional::concat(mfa_inputs, 1); + hidden_states = mfa_(hidden_states)[0]; + hidden_states = asp_(hidden_states)[0]; + hidden_states = fc_(hidden_states); + hidden_states = hidden_states.squeeze(-1); + + return {hidden_states}; + } + + private: + TimeDelayNetBlock tdnn0_; + std::vector se_blocks_; + TimeDelayNetBlock mfa_; + AttentiveStatisticsPooling asp_; + nn::Conv1D fc_; +}; + +class DiTInputEmbedding final : public nn::Module { + public: + DiTInputEmbedding() = default; + explicit DiTInputEmbedding(const std::string& name, const Qwen2_5OmniDiTConfig& cfg) : nn::Module(name) { + const int32_t in_dim = cfg.mel_dim + cfg.enc_dim + cfg.enc_emb_dim + cfg.emb_dim; + proj_ = reg("proj", in_dim, cfg.hidden_size, true); + spk_encoder_ = reg("spk_encoder", cfg); + } + + Tensor forward(const Tensor& hidden_states, const Tensor& speaker_embedding, const Tensor& condition_vector, const Tensor& code_embed, + bool drop_audio_cond, const Tensor& code_embed_uncond, bool apply_cfg) { + auto x = hidden_states; + auto spk = speaker_embedding; + auto cond = condition_vector; + auto code = code_embed; + + if (apply_cfg) { + x = nn::functional::concat({x, x}, 0); + spk = nn::functional::concat({spk, Tensor::zeros(spk.shape(), spk.dtype(), spk.device())}, 0); + cond = nn::functional::concat({cond, Tensor::zeros(cond.shape(), cond.dtype(), cond.device())}, 0); + code = nn::functional::concat({code, code_embed_uncond}, 0); + } else if (drop_audio_cond) { + cond = Tensor::zeros(cond.shape(), cond.dtype(), cond.device()); + spk = Tensor::zeros(spk.shape(), spk.dtype(), spk.device()); + } + + auto cond_embed = spk_encoder_(cond)[0]; + const int32_t seq_len = x.shape()[1]; + cond_embed = cond_embed.view({cond_embed.shape()[0], 1, cond_embed.shape()[1]}).repeat(seq_len, 1); + + auto merged = nn::functional::concat({x, cond_embed, code, spk}, -1); + auto out = proj_(merged); + return out; + } + + private: + nn::Linear proj_; + ECAPA_TimeDelayNet spk_encoder_; +}; + +class DiTCodecEmbedding final : public nn::Module { + public: + DiTCodecEmbedding() = default; + DiTCodecEmbedding(const std::string& name, int32_t codec_num_embeds, int32_t codec_dim, int32_t repeats) + : nn::Module(name), repeats_(repeats) { + codec_embed_ = reg("codec_embed", codec_num_embeds + 1, codec_dim); + } + + Tensor forward(const Tensor& code, bool drop_code) { + Tensor code_ids = code; + if (drop_code) { code_ids = Tensor::zeros(code.shape(), code.dtype(), code.device()); } + auto code_embed = codec_embed_(code_ids); + return repeatInterleave(code_embed, repeats_, 1); + } + + private: + int32_t repeats_ = 1; + nn::Embedding codec_embed_; +}; + +class Qwen2_5_OmniAdaLayerNormZero final : public nn::Module { + public: + Qwen2_5_OmniAdaLayerNormZero() = default; + Qwen2_5_OmniAdaLayerNormZero(const std::string& name, int32_t dim) : nn::Module(name) { + silu_ = reg("silu"); + linear_ = reg("linear", dim, dim * 6, true); + norm_ = reg("norm", std::vector{dim}, false, false, 1e-6f); + } + + std::vector forward(const std::vector& inputs, const std::vector&) override { + auto hidden_states = inputs[0]; + auto emb = inputs[1]; + emb = linear_(silu_(emb)); + + auto chunks = nn::functional::chunk<6>(emb, 1); + auto shift_msa = chunks[0]; + auto scale_msa = chunks[1]; + auto gate_msa = chunks[2]; + auto shift_mlp = chunks[3]; + auto scale_mlp = chunks[4]; + auto gate_mlp = chunks[5]; + + auto normed = norm_(hidden_states); + const int32_t seq_len = hidden_states.shape()[1]; + auto scale = scale_msa.view({scale_msa.shape()[0], 1, scale_msa.shape()[1]}).repeat(seq_len, 1); + auto shift = shift_msa.view({shift_msa.shape()[0], 1, shift_msa.shape()[1]}).repeat(seq_len, 1); + normed = normed * (scale + 1.0f) + shift; + + return {normed, gate_msa, shift_mlp, scale_mlp, gate_mlp}; + } + + private: + nn::SiLU silu_; + nn::Linear linear_; + nn::LayerNorm norm_; +}; + +class Qwen2_5_OmniAdaLayerNormZero_Final final : public nn::Module { + public: + Qwen2_5_OmniAdaLayerNormZero_Final() = default; + Qwen2_5_OmniAdaLayerNormZero_Final(const std::string& name, int32_t dim) : nn::Module(name) { + silu_ = reg("silu"); + linear_ = reg("linear", dim, dim * 2, true); + norm_ = reg("norm", std::vector{dim}, false, false, 1e-6f); + } + + Tensor forward(const Tensor& hidden_states, const Tensor& emb) { + auto emb_out = linear_(silu_(emb)); + auto chunks = nn::functional::chunk<2>(emb_out, 1); + auto scale = chunks[0]; + auto shift = chunks[1]; + + auto normed = norm_(hidden_states); + const int32_t seq_len = hidden_states.shape()[1]; + scale = scale.view({scale.shape()[0], 1, scale.shape()[1]}).repeat(seq_len, 1); + shift = shift.view({shift.shape()[0], 1, shift.shape()[1]}).repeat(seq_len, 1); + normed = normed * (scale + 1.0f) + shift; + return normed; + } + + private: + nn::SiLU silu_; + nn::Linear linear_; + nn::LayerNorm norm_; +}; + +class DiTMLP final : public nn::Module { + public: + DiTMLP() = default; + DiTMLP(const std::string& name, int32_t dim, int32_t mult) : nn::Module(name) { + const int32_t inner_dim = dim * mult; + fc1_ = reg("ff.0", dim, inner_dim, true); + act_ = reg("ff.1"); + fc2_ = reg("ff.3", inner_dim, dim, true); + } + + Tensor forward(const Tensor& hidden_states) { + auto x = fc1_(hidden_states); + x = act_(x); + x = fc2_(x); + return x; + } + + private: + nn::Linear fc1_; + nn::GELU act_; + nn::Linear fc2_; +}; + +inline void applyRotaryPosEmbFirstHead(Tensor& q, Tensor& k, const Tensor& cos, const Tensor& sin) { + MLLM_RT_ASSERT_EQ(q.device(), kCPU); + MLLM_RT_ASSERT_EQ(k.device(), kCPU); + MLLM_RT_ASSERT_EQ(cos.device(), kCPU); + MLLM_RT_ASSERT_EQ(sin.device(), kCPU); + MLLM_RT_ASSERT_EQ(q.dtype(), kFloat32); + MLLM_RT_ASSERT_EQ(k.dtype(), kFloat32); + MLLM_RT_ASSERT_EQ(cos.dtype(), kFloat32); + MLLM_RT_ASSERT_EQ(sin.dtype(), kFloat32); + + const int32_t batch = q.shape()[0]; + const int32_t heads = q.shape()[1]; + const int32_t seq_len = q.shape()[2]; + const int32_t head_dim = q.shape()[3]; + MLLM_RT_ASSERT_EQ(head_dim % 2, 0); + MLLM_RT_ASSERT_EQ(cos.shape()[0], batch); + MLLM_RT_ASSERT_EQ(cos.shape()[1], seq_len); + MLLM_RT_ASSERT_EQ(cos.shape()[2], head_dim); + + const auto* cos_ptr = cos.ptr(); + const auto* sin_ptr = sin.ptr(); + auto* q_ptr = q.ptr(); + auto* k_ptr = k.ptr(); + + const int64_t stride_q_b = static_cast(heads) * seq_len * head_dim; + const int64_t stride_q_h = static_cast(seq_len) * head_dim; + const int64_t stride_q_s = head_dim; + + const int64_t stride_cos_b = static_cast(seq_len) * head_dim; + const int64_t stride_cos_s = head_dim; + + for (int32_t b = 0; b < batch; ++b) { + const int64_t q_base_b = static_cast(b) * stride_q_b; + const int64_t cos_base_b = static_cast(b) * stride_cos_b; + for (int32_t s = 0; s < seq_len; ++s) { + float* q_row = q_ptr + q_base_b + 0 * stride_q_h + static_cast(s) * stride_q_s; + float* k_row = k_ptr + q_base_b + 0 * stride_q_h + static_cast(s) * stride_q_s; + const float* cos_row = cos_ptr + cos_base_b + static_cast(s) * stride_cos_s; + const float* sin_row = sin_ptr + cos_base_b + static_cast(s) * stride_cos_s; + for (int32_t d = 0; d < head_dim; d += 2) { + const float c = cos_row[d]; + const float ss = sin_row[d]; + const float q1 = q_row[d]; + const float q2 = q_row[d + 1]; + const float k1 = k_row[d]; + const float k2 = k_row[d + 1]; + q_row[d] = q1 * c - q2 * ss; + q_row[d + 1] = q1 * ss + q2 * c; + k_row[d] = k1 * c - k2 * ss; + k_row[d + 1] = k1 * ss + k2 * c; + } + } + } +} + +inline Tensor makeBlockDiff(int32_t batch, int32_t heads, int32_t seq_len, int32_t block_size) { + (void)heads; + MLLM_RT_ASSERT(block_size > 0); + std::vector block_indices(seq_len, 0); + for (int32_t i = 0; i < seq_len; ++i) { block_indices[i] = i / block_size; } + + std::vector base(static_cast(seq_len) * seq_len, 0.0f); + for (int32_t i = 0; i < seq_len; ++i) { + for (int32_t j = 0; j < seq_len; ++j) { + base[static_cast(i) * seq_len + j] = static_cast(block_indices[j] - block_indices[i]); + } + } + + // Use a broadcast-friendly shape to avoid materializing head copies while keeping naive broadcast support. + auto out = Tensor::empty({batch, 1, seq_len, seq_len}, kFloat32, kCPU).alloc(); + const int64_t block_stride = static_cast(seq_len) * seq_len; + auto* out_ptr = out.ptr(); + for (int32_t b = 0; b < batch; ++b) { + float* dst = out_ptr + static_cast(b) * block_stride; + std::memcpy(dst, base.data(), sizeof(float) * base.size()); + } + return out; +} + +inline Tensor makeBlockMask(const Tensor& block_diff, int32_t look_backward_block, int32_t look_ahead_block) { + MLLM_RT_ASSERT_EQ(block_diff.device(), kCPU); + MLLM_RT_ASSERT_EQ(block_diff.dtype(), kFloat32); + + auto mask = Tensor::empty(block_diff.shape(), kFloat32, kCPU).alloc(); + const auto* src = block_diff.ptr(); + auto* dst = mask.ptr(); + const int64_t numel = block_diff.numel(); + const float lower = -static_cast(look_backward_block); + const float upper = static_cast(look_ahead_block); + + MLLM_CONDITIONAL_PARALLEL_FOR(numel > 1024, 4, idx, 0, numel, 1, { + const float v = src[idx]; + dst[idx] = (v >= lower && v <= upper) ? 0.0f : -1e4f; + }); + return mask; +} + +class DiTAttention final : public nn::Module { + public: + DiTAttention() = default; + explicit DiTAttention(const std::string& name, const Qwen2_5OmniDiTConfig& cfg) : nn::Module(name), cfg_(cfg) { + dim_ = cfg.hidden_size; + heads_ = cfg.num_attention_heads; + head_dim_ = cfg.head_dim; + inner_dim_ = head_dim_ * heads_; + + to_q_ = reg("to_q", dim_, inner_dim_, true); + to_k_ = reg("to_k", dim_, inner_dim_, true); + to_v_ = reg("to_v", dim_, inner_dim_, true); + to_out_ = reg("to_out.0", inner_dim_, dim_, true); + } + + Tensor forward(const Tensor& hidden_states, const std::pair& position_embeddings, const Tensor& attention_mask) { + auto query = to_q_(hidden_states); + auto key = to_k_(hidden_states); + auto value = to_v_(hidden_states); + + const int32_t batch = hidden_states.shape()[0]; + const int32_t seq_len = hidden_states.shape()[1]; + + query = query.view({batch, seq_len, heads_, head_dim_}).transpose(1, 2); + key = key.view({batch, seq_len, heads_, head_dim_}).transpose(1, 2); + value = value.view({batch, seq_len, heads_, head_dim_}).transpose(1, 2); + + if (!position_embeddings.first.isNil()) { + applyRotaryPosEmbFirstHead(query, key, position_embeddings.first, position_embeddings.second); + } + + auto attn_output = nn::functional::scaledDotProductAttention(query, key, value, attention_mask); + attn_output = attn_output.transpose(1, 2).view({batch, seq_len, inner_dim_}); + attn_output = to_out_(attn_output); + return attn_output; + } + + private: + Qwen2_5OmniDiTConfig cfg_; + int32_t dim_ = 0; + int32_t heads_ = 0; + int32_t head_dim_ = 0; + int32_t inner_dim_ = 0; + nn::Linear to_q_; + nn::Linear to_k_; + nn::Linear to_v_; + nn::Linear to_out_; +}; + +class SinusPositionEmbedding final : public nn::Module { + public: + SinusPositionEmbedding() = default; + explicit SinusPositionEmbedding(const std::string& name, int32_t dim) : nn::Module(name), dim_(dim) {} + + Tensor forward(const Tensor& hidden_states, float scale = 1000.0f) { + MLLM_RT_ASSERT_EQ(hidden_states.device(), kCPU); + MLLM_RT_ASSERT_EQ(hidden_states.dtype(), kFloat32); + + const int32_t batch = hidden_states.shape()[0]; + const int32_t half_dim = dim_ / 2; + auto out = Tensor::empty({batch, dim_}, kFloat32, kCPU).alloc(); + auto* out_ptr = out.ptr(); + const auto* hs_ptr = hidden_states.ptr(); + + const float emb = std::log(10000.0f) / static_cast(half_dim - 1); + std::vector freqs(half_dim); + for (int32_t i = 0; i < half_dim; ++i) { freqs[i] = std::exp(-emb * static_cast(i)); } + + for (int32_t b = 0; b < batch; ++b) { + const float t = hs_ptr[b] * scale; + float* row = out_ptr + static_cast(b) * dim_; + for (int32_t i = 0; i < half_dim; ++i) { + const float val = t * freqs[i]; + row[i] = std::sin(val); + row[i + half_dim] = std::cos(val); + } + } + + return out; + } + + private: + int32_t dim_ = 0; +}; + +class DiTTimestepEmbedding final : public nn::Module { + public: + DiTTimestepEmbedding() = default; + explicit DiTTimestepEmbedding(const std::string& name, int32_t dim, int32_t freq_embed_dim = 256) + : nn::Module(name), freq_embed_dim_(freq_embed_dim) { + time_embed_ = reg("time_embed", freq_embed_dim_); + fc1_ = reg("time_mlp.0", freq_embed_dim_, dim, true); + act_ = reg("time_mlp.1"); + fc2_ = reg("time_mlp.2", dim, dim, true); + } + + Tensor forward(const Tensor& timestep) { + auto time_hidden = time_embed_.forward(timestep); + time_hidden = fc1_(time_hidden); + time_hidden = act_(time_hidden); + time_hidden = fc2_(time_hidden); + return time_hidden; + } + + private: + int32_t freq_embed_dim_ = 256; + SinusPositionEmbedding time_embed_; + nn::Linear fc1_; + nn::SiLU act_; + nn::Linear fc2_; +}; + +class DiTDecoderLayer final : public nn::Module { + public: + DiTDecoderLayer() = default; + DiTDecoderLayer(const std::string& name, const Qwen2_5OmniDiTConfig& cfg, int32_t look_ahead_block, int32_t look_backward_block) + : nn::Module(name), look_ahead_block_(look_ahead_block), look_backward_block_(look_backward_block) { + attn_norm_ = reg("attn_norm", cfg.hidden_size); + attn_ = reg("attn", cfg); + ff_norm_ = reg("ff_norm", std::vector{cfg.hidden_size}, false, false, 1e-6f); + ff_ = reg("ff", cfg.hidden_size, cfg.ff_mult); + } + + Tensor forward(const Tensor& hidden_states, const Tensor& timestep, const std::pair& position_embeddings, + const Tensor& block_diff) { + auto attn_norm_out = attn_norm_(hidden_states, timestep); + auto norm = attn_norm_out[0]; + auto gate_msa = attn_norm_out[1]; + auto shift_mlp = attn_norm_out[2]; + auto scale_mlp = attn_norm_out[3]; + auto gate_mlp = attn_norm_out[4]; + + Tensor attn_mask = Tensor::nil(); + if (!block_diff.isNil()) { attn_mask = makeBlockMask(block_diff, look_backward_block_, look_ahead_block_); } + auto attn_output = attn_.forward(norm, position_embeddings, attn_mask); + + auto gate_msa_rep = gate_msa.view({gate_msa.shape()[0], 1, gate_msa.shape()[1]}).repeat(hidden_states.shape()[1], 1); + auto x = Tensor(hidden_states); + x = x + gate_msa_rep * attn_output; + + auto norm_ff = ff_norm_(x); + auto scale_rep = scale_mlp.view({scale_mlp.shape()[0], 1, scale_mlp.shape()[1]}).repeat(x.shape()[1], 1); + auto shift_rep = shift_mlp.view({shift_mlp.shape()[0], 1, shift_mlp.shape()[1]}).repeat(x.shape()[1], 1); + norm_ff = norm_ff * (scale_rep + 1.0f) + shift_rep; + auto ff_output = ff_.forward(norm_ff); + auto gate_mlp_rep = gate_mlp.view({gate_mlp.shape()[0], 1, gate_mlp.shape()[1]}).repeat(x.shape()[1], 1); + x = x + gate_mlp_rep * ff_output; + return x; + } + + private: + Qwen2_5_OmniAdaLayerNormZero attn_norm_; + DiTAttention attn_; + nn::LayerNorm ff_norm_; + DiTMLP ff_; + int32_t look_ahead_block_ = 0; + int32_t look_backward_block_ = 0; +}; + +class Qwen2_5OmniDiTRotaryEmbedding final : public nn::Module { + public: + Qwen2_5OmniDiTRotaryEmbedding() = default; + explicit Qwen2_5OmniDiTRotaryEmbedding(const std::string& name, const Qwen2_5OmniDiTConfig& cfg) : nn::Module(name), cfg_(cfg) { + const int32_t dim = cfg.head_dim; + inv_freq_ = reg("inv_freq", getModuleName() + ".inv_freq", std::vector{dim / 2}); + attention_scaling_ = 1.0f; + + auto inv = inv_freq_.weight(); + if (!inv.isNil() && inv.numel() == 0) { + inv = Tensor::empty({dim / 2}, kFloat32, kCPU).alloc(); + inv_freq_.weight().copy2(inv); + } + } + + std::pair forward(const Tensor& x, const Tensor& position_ids) { + MLLM_RT_ASSERT_EQ(x.device(), kCPU); + MLLM_RT_ASSERT_EQ(position_ids.device(), kCPU); + MLLM_RT_ASSERT_EQ(position_ids.dtype(), kInt64); + + const int32_t batch = position_ids.shape()[0]; + const int32_t seq_len = position_ids.shape()[1]; + auto inv_freq = inv_freq_.weight(); + if (inv_freq.isNil() || inv_freq.numel() == 0) { + const int32_t dim = cfg_.head_dim; + inv_freq = Tensor::empty({dim / 2}, kFloat32, kCPU).alloc(); + auto* ptr = inv_freq.ptr(); + for (int32_t i = 0; i < dim / 2; ++i) { + ptr[i] = 1.0f / std::pow(cfg_.rope_theta, 2.0f * i / static_cast(dim)); + } + } + + const int32_t half_dim = inv_freq.shape()[0]; + auto cos = Tensor::empty({batch, seq_len, half_dim * 2}, kFloat32, kCPU).alloc(); + auto sin = Tensor::empty({batch, seq_len, half_dim * 2}, kFloat32, kCPU).alloc(); + + const auto* inv_ptr = inv_freq.ptr(); + const auto* pos_ptr = position_ids.ptr(); + auto* cos_ptr = cos.ptr(); + auto* sin_ptr = sin.ptr(); + + const int64_t stride_pos_b = seq_len; + const int64_t stride_cos_b = static_cast(seq_len) * half_dim * 2; + const int64_t stride_cos_s = half_dim * 2; + + for (int32_t b = 0; b < batch; ++b) { + const int64_t pos_base = static_cast(b) * stride_pos_b; + const int64_t out_base = static_cast(b) * stride_cos_b; + for (int32_t s = 0; s < seq_len; ++s) { + const float position = static_cast(pos_ptr[pos_base + s]); + float* cos_row = cos_ptr + out_base + static_cast(s) * stride_cos_s; + float* sin_row = sin_ptr + out_base + static_cast(s) * stride_cos_s; + for (int32_t d = 0; d < half_dim; ++d) { + const float freq = inv_ptr[d] * position; + const float c = std::cos(freq) * attention_scaling_; + const float ss = std::sin(freq) * attention_scaling_; + cos_row[d] = c; + cos_row[d + half_dim] = c; + sin_row[d] = ss; + sin_row[d + half_dim] = ss; + } + } + } + + return {cos, sin}; + } + + private: + Qwen2_5OmniDiTConfig cfg_; + nn::Param inv_freq_; + float attention_scaling_ = 1.0f; +}; + +class RungeKutta4ODESolver { + public: + using Function = std::function; + + RungeKutta4ODESolver(Function function, Tensor initial_value) + : function_(std::move(function)), initial_value_(std::move(initial_value)) {} + + Tensor integrate(const std::vector& time_points) { + auto current_value = initial_value_; + if (time_points.size() < 2) { return current_value; } + + for (size_t i = 0; i + 1 < time_points.size(); ++i) { + const float time_start = time_points[i]; + const float time_end = time_points[i + 1]; + const float time_step = time_end - time_start; + + auto k1 = function_(time_start, current_value); + auto k2 = function_(time_start + time_step * one_third_, current_value + k1 * (time_step * one_third_)); + auto k3 = function_(time_start + time_step * two_thirds_, + current_value + (k2 - k1 * one_third_) * time_step); + auto k4 = function_(time_end, current_value + (k1 - k2 + k3) * time_step); + + auto delta = (k1 + (k2 + k3) * 3.0f + k4) * (time_step / 8.0f); + current_value = current_value + delta; + } + + return current_value; + } + + private: + Function function_; + Tensor initial_value_; + float one_third_ = 1.0f / 3.0f; + float two_thirds_ = 2.0f / 3.0f; +}; + +class Qwen2_5OmniToken2WavDiTModel final : public nn::Module { + public: + Qwen2_5OmniToken2WavDiTModel() = default; + explicit Qwen2_5OmniToken2WavDiTModel(const std::string& name, const Qwen2_5OmniDiTConfig& cfg) : nn::Module(name), cfg_(cfg) { + mel_dim_ = cfg.mel_dim; + repeats_ = cfg.repeats; + block_size_ = cfg.block_size; + num_attention_heads_ = cfg.num_attention_heads; + + time_embed_ = reg("time_embed", cfg.hidden_size); + text_embed_ = reg("text_embed", cfg.num_embeds, cfg.emb_dim, cfg.repeats); + input_embed_ = reg("input_embed", cfg); + rotary_embed_ = reg("rotary_embed", cfg); + + for (int32_t i = 0; i < cfg.num_hidden_layers; ++i) { + const bool look_ahead = std::find(cfg.look_ahead_layers.begin(), cfg.look_ahead_layers.end(), i) != cfg.look_ahead_layers.end(); + const bool look_backward = + std::find(cfg.look_backward_layers.begin(), cfg.look_backward_layers.end(), i) != cfg.look_backward_layers.end(); + transformer_blocks_.emplace_back(reg("transformer_blocks." + std::to_string(i), cfg, look_ahead ? 1 : 0, + look_backward ? 1 : 0)); + } + + norm_out_ = reg("norm_out", cfg.hidden_size); + proj_out_ = reg("proj_out", cfg.hidden_size, cfg.mel_dim, true); + } + + Tensor forward(const Tensor& hidden_states, const Tensor& condition_vector, const Tensor& speaker_embedding, const Tensor& quantized_code, + const Tensor& time_step, bool drop_audio_conditioning, bool drop_code, bool apply_cfg) { + Tensor timestep = time_step; + if (timestep.shape().empty()) { timestep = timestep.view({1}); } + if (timestep.shape().size() == 1 && timestep.shape()[0] == 1 && hidden_states.shape()[0] > 1) { + timestep = timestep.repeat(hidden_states.shape()[0], 0); + } + + auto time_embedding = time_embed_.forward(timestep); + auto text_embedding = text_embed_.forward(quantized_code, apply_cfg ? false : drop_code); + Tensor text_embedding_uncond = Tensor::nil(); + if (apply_cfg) { text_embedding_uncond = text_embed_.forward(quantized_code, true); } + + auto x = input_embed_.forward(hidden_states, speaker_embedding, condition_vector, text_embedding, drop_audio_conditioning, + text_embedding_uncond, apply_cfg); + + const int32_t seq_len = x.shape()[1]; + auto position_ids = Tensor::empty({x.shape()[0], seq_len}, kInt64, kCPU).alloc(); + auto* pos_ptr = position_ids.ptr(); + for (int32_t b = 0; b < position_ids.shape()[0]; ++b) { + for (int32_t s = 0; s < seq_len; ++s) { pos_ptr[b * seq_len + s] = s; } + } + + auto position_embeddings = rotary_embed_.forward(x, position_ids); + auto block_diff = makeBlockDiff(x.shape()[0], num_attention_heads_, seq_len, block_size_); + + for (auto& block : transformer_blocks_) { x = block.forward(x, time_embedding, position_embeddings, block_diff); } + + x = norm_out_.forward(x, time_embedding); + x = proj_out_(x); + return x; + } + + Tensor sample(const Tensor& conditioning_vector, const Tensor& reference_mel, const Tensor& quantized_code, int32_t num_steps, + float guidance_scale, float sway_coefficient) { + const int32_t max_duration = quantized_code.shape()[1] * repeats_; + auto initial_state = randomNormal({1, max_duration, mel_dim_}); + + const int32_t batch = reference_mel.shape()[0]; + if (batch != 1) { MLLM_ERROR_EXIT(ExitCode::kCoreError, "Only batch size = 1 is supported for Qwen2.5-Omni token2wav."); } + + auto cond = Tensor(conditioning_vector); + cond = cond.view({batch, 1, conditioning_vector.shape()[1]}).repeat(max_duration, 1); + + auto ode_function = [&](float time_step, const Tensor& hidden) -> Tensor { + auto t = Tensor::empty({1}, kFloat32, kCPU).alloc(); + t.ptr()[0] = time_step; + + if (guidance_scale < 1e-5f) { + return forward(hidden, reference_mel, cond, quantized_code, t, false, false, false); + } + + auto model_output = forward(hidden, reference_mel, cond, quantized_code, t, false, false, true); + auto outputs = nn::functional::chunk<2>(model_output, 0); + return outputs[0] + (outputs[0] - outputs[1]) * guidance_scale; + }; + + auto time_points_tensor = linspace(0.0f, 1.0f, num_steps); + std::vector time_points(static_cast(num_steps)); + const auto* tp_ptr = time_points_tensor.ptr(); + for (int32_t i = 0; i < num_steps; ++i) { time_points[i] = tp_ptr[i]; } + + if (sway_coefficient != 0.0f) { + for (auto& t : time_points) { + t = t + sway_coefficient * (std::cos(kPi / 2.0f * t) - 1.0f + t); + } + } + + RungeKutta4ODESolver solver(ode_function, initial_state); + auto generated = solver.integrate(time_points); + auto mel = generated.permute({0, 2, 1}); + if (!mel.isContiguous()) { mel = mel.contiguous(); } + return mel; + } + + private: + Qwen2_5OmniDiTConfig cfg_; + int32_t mel_dim_ = 0; + int32_t repeats_ = 1; + int32_t block_size_ = 1; + int32_t num_attention_heads_ = 1; + + DiTTimestepEmbedding time_embed_; + DiTCodecEmbedding text_embed_; + DiTInputEmbedding input_embed_; + Qwen2_5OmniDiTRotaryEmbedding rotary_embed_; + std::vector transformer_blocks_; + Qwen2_5_OmniAdaLayerNormZero_Final norm_out_; + nn::Linear proj_out_; +}; + +class AMPBlock final : public nn::Module { + public: + AMPBlock() = default; + AMPBlock(const std::string& name, int32_t channels, int32_t kernel_size, const std::vector& dilations) + : nn::Module(name) { + if (dilations.size() != 3) { MLLM_ERROR_EXIT(ExitCode::kCoreError, "AMPBlock expects 3 dilation values."); } + + convs1_.emplace_back(reg("convs1.0", channels, channels, kernel_size, 1, getPadding(kernel_size, dilations[0]), + dilations[0], 1, true)); + convs1_.emplace_back(reg("convs1.1", channels, channels, kernel_size, 1, getPadding(kernel_size, dilations[1]), + dilations[1], 1, true)); + convs1_.emplace_back(reg("convs1.2", channels, channels, kernel_size, 1, getPadding(kernel_size, dilations[2]), + dilations[2], 1, true)); + + convs2_.emplace_back(reg("convs2.0", channels, channels, kernel_size, 1, getPadding(kernel_size, 1), 1, 1, true)); + convs2_.emplace_back(reg("convs2.1", channels, channels, kernel_size, 1, getPadding(kernel_size, 1), 1, 1, true)); + convs2_.emplace_back(reg("convs2.2", channels, channels, kernel_size, 1, getPadding(kernel_size, 1), 1, 1, true)); + + const int32_t num_layers = static_cast(convs1_.size() + convs2_.size()); + for (int32_t i = 0; i < num_layers; ++i) { + activations_.emplace_back(reg("activations." + std::to_string(i), channels)); + } + } + + Tensor forward(const Tensor& hidden_states) { + auto out = hidden_states; + const int32_t num_blocks = static_cast(convs1_.size()); + for (int32_t i = 0; i < num_blocks; ++i) { + auto residual = out; + auto x = activations_[i * 2].forward({out}, {})[0]; + x = convs1_[i](x); + x = activations_[i * 2 + 1].forward({x}, {})[0]; + x = convs2_[i](x); + out = residual + x; + } + return out; + } + + private: + static int32_t getPadding(int32_t kernel_size, int32_t dilation) { + return static_cast((kernel_size * dilation - dilation) / 2); + } + + std::vector convs1_; + std::vector convs2_; + std::vector activations_; +}; + +class Qwen2_5OmniToken2WavBigVGANModel final : public nn::Module { + public: + Qwen2_5OmniToken2WavBigVGANModel() = default; + explicit Qwen2_5OmniToken2WavBigVGANModel(const std::string& name, const Qwen2_5OmniBigVGANConfig& cfg) : nn::Module(name), cfg_(cfg) { + num_residual_blocks_ = static_cast(cfg.resblock_kernel_sizes.size()); + num_upsample_layers_ = static_cast(cfg.upsample_rates.size()); + + conv_pre_ = reg("conv_pre", cfg.mel_dim, cfg.upsample_initial_channel, 7, 1, 3, 1, 1, true); + + for (int32_t layer_idx = 0; layer_idx < num_upsample_layers_; ++layer_idx) { + const int32_t stride = cfg.upsample_rates[layer_idx]; + const int32_t kernel = cfg.upsample_kernel_sizes[layer_idx]; + const int32_t in_ch = cfg.upsample_initial_channel / static_cast(std::pow(2, layer_idx)); + const int32_t out_ch = cfg.upsample_initial_channel / static_cast(std::pow(2, layer_idx + 1)); + const int32_t padding = (kernel - stride) / 2; + ups_.emplace_back(reg("ups." + std::to_string(layer_idx) + ".0", in_ch, out_ch, kernel, stride, + padding, 0, 1, 1, true)); + } + + for (int32_t layer_idx = 0; layer_idx < num_upsample_layers_; ++layer_idx) { + const int32_t channels = cfg.upsample_initial_channel / static_cast(std::pow(2, layer_idx + 1)); + for (size_t i = 0; i < cfg.resblock_kernel_sizes.size(); ++i) { + resblocks_.emplace_back(reg("resblocks." + std::to_string(resblocks_.size()), channels, + cfg.resblock_kernel_sizes[i], cfg.resblock_dilation_sizes[i])); + } + } + + activation_post_ = + reg("activation_post", cfg.upsample_initial_channel / static_cast(std::pow(2, num_upsample_layers_))); + conv_post_ = reg("conv_post", + cfg.upsample_initial_channel / static_cast(std::pow(2, num_upsample_layers_)), 1, 7, 1, 3, 1, 1, + false); + } + + Tensor forward(const Tensor& mel_spectrogram) { + auto mel = mel_spectrogram; + if (!mel.isContiguous()) { mel = mel.contiguous(); } + auto processed = processMelSpectrogram(mel); + return forwardProcessed(processed); + } + + private: + Tensor forwardProcessed(const Tensor& processed) { + auto hidden = conv_pre_(processed); + + for (int32_t layer_idx = 0; layer_idx < num_upsample_layers_; ++layer_idx) { + hidden = ups_[layer_idx](hidden); + Tensor residual_sum = Tensor::zeros(hidden.shape(), hidden.dtype(), hidden.device()); + for (int32_t block_idx = 0; block_idx < num_residual_blocks_; ++block_idx) { + residual_sum = residual_sum + resblocks_[layer_idx * num_residual_blocks_ + block_idx].forward(hidden); + } + hidden = residual_sum * (1.0f / static_cast(num_residual_blocks_)); + } + + hidden = activation_post_.forward({hidden}, {})[0]; + auto output = conv_post_(hidden); + output = clampTensor(output, -1.0f, 1.0f); + return output.squeeze(); + } + Tensor processMelSpectrogram(const Tensor& mel_spectrogram) const { + auto amplitude = nn::functional::exp(mel_spectrogram); + auto decibel = amplitudeToDb(amplitude, -115.0f) + (-20.0f); + return normalizeSpectrogram(decibel, 1.0f, -115.0f); + } + + Qwen2_5OmniBigVGANConfig cfg_; + int32_t num_residual_blocks_ = 0; + int32_t num_upsample_layers_ = 0; + nn::Conv1D conv_pre_; + std::vector ups_; + std::vector resblocks_; + TorchActivation1d activation_post_; + nn::Conv1D conv_post_; +}; + +class Qwen2_5OmniToken2WavModel final : public nn::Module { + public: + Qwen2_5OmniToken2WavModel() = default; + explicit Qwen2_5OmniToken2WavModel(const std::string& name, const Qwen2_5OmniToken2WavConfig& cfg) : nn::Module(name), cfg_(cfg) { + code2wav_dit_model_ = reg("code2wav_dit_model", cfg.dit_config); + code2wav_bigvgan_model_ = reg("code2wav_bigvgan_model", cfg.bigvgan_config); + } + + Tensor forward(const Tensor& code, const Tensor& conditioning, const Tensor& reference_mel, int32_t num_steps = 10, + float guidance_scale = 0.5f, float sway_coefficient = -1.0f) { + auto mel = code2wav_dit_model_.sample(conditioning, reference_mel, code, num_steps, guidance_scale, sway_coefficient); + if (!mel.isContiguous()) { mel = mel.contiguous(); } + return code2wav_bigvgan_model_.forward(mel); + } + + Tensor vocodeMel(const Tensor& mel) { + return code2wav_bigvgan_model_.forward(mel); + } + + private: + Qwen2_5OmniToken2WavConfig cfg_; + Qwen2_5OmniToken2WavDiTModel code2wav_dit_model_; + Qwen2_5OmniToken2WavBigVGANModel code2wav_bigvgan_model_; +}; + +} // namespace token2wav + +using token2wav::Qwen2_5OmniToken2WavBigVGANModel; +using token2wav::Qwen2_5OmniToken2WavDiTModel; +using token2wav::Qwen2_5OmniToken2WavModel; + +} // namespace mllm::models::qwen2_5omni From 5676edc00f354a96fad3c1cbc7c3bebd690dd90c Mon Sep 17 00:00:00 2001 From: KKkai <1640576073@qq.com> Date: Thu, 5 Mar 2026 15:32:14 +0800 Subject: [PATCH 17/17] add --- examples/CMakeLists.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index e5501e8cd..0f025fcf6 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -2,7 +2,6 @@ add_subdirectory(qwen2vl) add_subdirectory(qwen2vl_tracer) add_subdirectory(qwen2_5vl) add_subdirectory(qwen2_5vl_tracer) -add_subdirectory(qwen2_5omni) add_subdirectory(minicpm_o45) add_subdirectory(llama) add_subdirectory(minicpm_o)