diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 3df37bddc..0f025fcf6 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -2,6 +2,7 @@ add_subdirectory(qwen2vl) add_subdirectory(qwen2vl_tracer) add_subdirectory(qwen2_5vl) add_subdirectory(qwen2_5vl_tracer) +add_subdirectory(minicpm_o45) add_subdirectory(llama) add_subdirectory(minicpm_o) add_subdirectory(minicpm4) diff --git a/examples/minicpm_o45/CMakeLists.txt b/examples/minicpm_o45/CMakeLists.txt new file mode 100644 index 000000000..a755efda1 --- /dev/null +++ b/examples/minicpm_o45/CMakeLists.txt @@ -0,0 +1,7 @@ +add_executable(mllm-minicpm-o45-runner main.cpp) +target_link_libraries(mllm-minicpm-o45-runner PRIVATE MllmRT MllmCPUBackend) +target_include_directories(mllm-minicpm-o45-runner PRIVATE ${MLLM_INCLUDE_DIR}) + +# add_executable(mllm-minicpm-o45-runner-python main_python.cpp) +# target_link_libraries(mllm-minicpm-o45-runner-python PRIVATE MllmRT MllmCPUBackend) +# target_include_directories(mllm-minicpm-o45-runner-python PRIVATE ${MLLM_INCLUDE_DIR}) diff --git a/examples/minicpm_o45/config_minicpm_o45.json b/examples/minicpm_o45/config_minicpm_o45.json new file mode 100644 index 000000000..e432e2355 --- /dev/null +++ b/examples/minicpm_o45/config_minicpm_o45.json @@ -0,0 +1,285 @@ +{ + "architectures": [ + "MiniCPMO" + ], + "version": "4.5", + "attention_bias": false, + "attention_dropout": 0.0, + "audio_chunk_length": 1.0, + "audio_config": { + "_attn_implementation_autoset": true, + "_name_or_path": "openai/whisper-medium", + "activation_dropout": 0.0, + "activation_function": "gelu", + "apply_spec_augment": false, + "architectures": [ + "MiniCPMWhisperEncoder" + ], + "attention_dropout": 0.0, + "begin_suppress_tokens": [ + 220, + 50257 + ], + "bos_token_id": 50257, + "classifier_proj_size": 256, + "d_model": 1024, + "decoder_attention_heads": 16, + "decoder_ffn_dim": 4096, + "decoder_layerdrop": 0.0, + "decoder_layers": 24, + "decoder_start_token_id": 50258, + 
"dropout": 0.0, + "encoder_attention_heads": 16, + "encoder_ffn_dim": 4096, + "encoder_layerdrop": 0.0, + "encoder_layers": 24, + "eos_token_id": 50257, + "forced_decoder_ids": [ + [ + 1, + 50259 + ], + [ + 2, + 50359 + ], + [ + 3, + 50363 + ] + ], + "init_std": 0.02, + "mask_feature_length": 10, + "mask_feature_min_masks": 0, + "mask_feature_prob": 0.0, + "mask_time_length": 10, + "mask_time_min_masks": 2, + "mask_time_prob": 0.05, + "max_length": 448, + "max_source_positions": 1500, + "max_target_positions": 448, + "median_filter_width": 7, + "model_type": "whisper", + "num_hidden_layers": 24, + "num_mel_bins": 80, + "pad_token_id": 50257, + "scale_embedding": false, + "suppress_tokens": [ + 1, + 2, + 7, + 8, + 9, + 10, + 14, + 25, + 26, + 27, + 28, + 29, + 31, + 58, + 59, + 60, + 61, + 62, + 63, + 90, + 91, + 92, + 93, + 359, + 503, + 522, + 542, + 873, + 893, + 902, + 918, + 922, + 931, + 1350, + 1853, + 1982, + 2460, + 2627, + 3246, + 3253, + 3268, + 3536, + 3846, + 3961, + 4183, + 4667, + 6585, + 6647, + 7273, + 9061, + 9383, + 10428, + 10929, + 11938, + 12033, + 12331, + 12562, + 13793, + 14157, + 14635, + 15265, + 15618, + 16553, + 16604, + 18362, + 18956, + 20075, + 21675, + 22520, + 26130, + 26161, + 26435, + 28279, + 29464, + 31650, + 32302, + 32470, + 36865, + 42863, + 47425, + 49870, + 50254, + 50258, + 50358, + 50359, + 50360, + 50361, + 50362 + ], + "torch_dtype": "float32", + "use_cache": true, + "use_weighted_layer_sum": false, + "vocab_size": 51865 + }, + "audio_pool_step": 5, + "auto_map": { + "AutoConfig": "configuration_minicpmo.MiniCPMOConfig", + "AutoModel": "modeling_minicpmo.MiniCPMO", + "AutoModelForCausalLM": "modeling_minicpmo.MiniCPMO" + }, + "batch_vision_input": true, + "bos_token_id": 151643, + "drop_vision_last_layer": false, + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 4096, + "image_size": 448, + "init_audio": true, + "init_tts": true, + "init_vision": true, + "initializer_range": 0.02, + 
"intermediate_size": 12288, + "listen_speak_type": "asr", + "max_position_embeddings": 40960, + "max_window_layers": 36, + "model_type": "minicpmo", + "num_attention_heads": 32, + "num_hidden_layers": 36, + "num_key_value_heads": 8, + "patch_size": 14, + "query_num": 64, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 1000000, + "slice_config": { + "max_slice_nums": 1, + "model_type": "minicpmv", + "patch_size": 14, + "scale_resolution": 448 + }, + "slice_mode": true, + "sliding_window": null, + "stream_input": true, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.51.0", + "tts_config": { + "_attn_implementation_autoset": true, + "attention_type": "full_attention", + "attn_implementation": "sdpa", + "audio_bos_token_id": 151687, + "audio_tokenizer_sample_rate": 16000, + "audio_tokenizer_type": "s3tokenizer", + "aug_layer_loss_weight": false, + "aug_loss_weight": false, + "backbone_model": "llama", + "condition_type": "hidden_text_merge", + "cosyvoice_config_path": null, + "cosyvoice_model_dir": null, + "filter_tts_loss": false, + "hidden_act": "silu", + "hidden_size": 768, + "interleaved": false, + "intermediate_size": 3072, + "llm_dim": 4096, + "llm_dim_model_base": 256, + "llm_down_scale": false, + "llm_hidden_size": 4096, + "llm_intermediate_size": 768, + "long_weight": 0.1, + "max_position_embeddings": 4096, + "model_type": "minicpmtts", + "normalize_projected_hidden": true, + "num_attention_heads": 12, + "num_audio_tokens": 6562, + "num_hidden_layers": 20, + "num_key_value_heads": 12, + "num_mel_bins": 100, + "num_text_tokens": 152064, + "num_vq": 1, + "projector_type": "mlp", + "recomputed_chunks": 1, + "s3_stream_chunk_size": 25, + "s3_stream_generate": false, + "s3_stream_n_timesteps": 10, + "s3_stream_prelook_size": 3, + "short_weight": 0.1, + "streaming": false, + "streaming_audio_chunk_size": 50, + "streaming_sliding_window": false, + "streaming_sliding_window_audio_frame_rate": 50, + 
"streaming_sliding_window_audio_init_text_length": 10, + "streaming_sliding_window_audio_window_size": 300, + "streaming_sliding_window_average_speed": 5, + "streaming_sliding_window_fast_speed": 7, + "streaming_sliding_window_max_text_len": 500, + "streaming_sliding_window_slow_speed": 3, + "streaming_sliding_window_text_window_size": 50, + "streaming_text_chunk_max": 7, + "streaming_text_chunk_min": 3, + "streaming_text_reserved_len": 300, + "text_eos_token_id": 151692, + "tts_filter_loss_fix": false, + "use_llm_hidden_state": false, + "use_text": true, + "window_size": 2 + }, + "use_cache": true, + "use_image_id": true, + "use_sliding_window": false, + "vision_batch_size": 16, + "vision_config": { + "_attn_implementation_autoset": true, + "attention_dropout": 0.0, + "hidden_act": "gelu_pytorch_tanh", + "hidden_size": 1152, + "image_size": 980, + "intermediate_size": 4304, + "layer_norm_eps": 1e-06, + "model_type": "siglip_vision_model", + "num_attention_heads": 16, + "num_channels": 3, + "num_hidden_layers": 27, + "patch_size": 14 + }, + "vocab_size": 151748 +} diff --git a/examples/minicpm_o45/main.cpp b/examples/minicpm_o45/main.cpp new file mode 100644 index 000000000..482428038 --- /dev/null +++ b/examples/minicpm_o45/main.cpp @@ -0,0 +1,300 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+ +#include +#include +#include +#include +#include +#include + +#include + +#include "mllm/mllm.hpp" +#include "mllm/models/minicpm_o45/configuration_minicpm_o45.hpp" +#include "mllm/models/minicpm_o45/modeling_minicpm_o45.hpp" +#include "mllm/models/minicpm_o45/modeling_minicpm_o45_token2wav.hpp" +#include "mllm/models/minicpm_o45/tokenization_minicpm_o45.hpp" +#include "mllm/models/minicpm_o45/token2wav_prompt_cache.hpp" + +#include "wenet_audio/wav.h" + +using mllm::Argparse; + +MLLM_MAIN({ + mllm::Logger::level() = mllm::LogLevel::kError; + + auto& help = Argparse::add("-h|--help").help("Show help message"); + auto& model_path = Argparse::add("-m|--model_path").help("Model path").def(""); + auto& model_version = Argparse::add("-mv|--model_version").help("Model version: v1/v2").def("v1"); + auto& tokenizer_path = Argparse::add("-t|--tokenizer_path").help("Tokenizer path (tokenizer.json)").def(""); + auto& config_path = Argparse::add("-c|--config_path").help("Config path").def(""); + auto& prompt = Argparse::add("-p|--prompt").help("Prompt text").def("Describe the input."); + auto& image_path = Argparse::add("-i|--image").help("Optional image path").def(""); + auto& audio_path = Argparse::add("-a|--audio").help("Optional audio path (wav)").def(""); + auto& generate_tts_tokens = Argparse::add("-gt|--generate_tts_tokens") + .help("Generate TTS tokens (text->tts-token stage, no waveform)") + .def(false); + auto& text_max_new_tokens = Argparse::add("--text_max_new_tokens").help("Max new text tokens").def(512); + auto& tts_max_new_tokens = Argparse::add("--tts_max_new_tokens").help("Max new TTS tokens").def(1024); + auto& tts_min_new_tokens = Argparse::add("--tts_min_new_tokens").help("Min new TTS tokens").def(50); + auto& tts_force_no_stop = Argparse::add("--tts_force_no_stop").help("Disable TTS EOS stopping").def(false); + auto& tts_temperature = Argparse::add("--tts_temperature").help("TTS sampling temperature").def(0.8f); + auto& tts_top_k = 
Argparse::add("--tts_top_k").help("TTS top-k sampling (<=0 disables)").def(25); + auto& tts_top_p = Argparse::add("--tts_top_p").help("TTS top-p sampling (<=0 or >=1 disables)").def(0.85f); + auto& tts_repetition_penalty = + Argparse::add("--tts_repetition_penalty").help("TTS repetition penalty (1.0 disables)").def(1.05f); + auto& tts_repetition_window = + Argparse::add("--tts_repetition_window").help("TTS repetition window size in generated tokens").def(16); + auto& tts_greedy = Argparse::add("--tts_greedy").help("Use greedy decoding for TTS tokens").def(false); + auto& tts_tokens_out = Argparse::add("--tts_tokens_out").help("Output path for generated TTS token ids").def(""); + auto& tts_tokens_in = + Argparse::add("--tts_tokens_in").help("Input path for pre-generated TTS token ids (one per line or whitespace).").def(""); + auto& tts_wav_out = Argparse::add("--tts_wav_out") + .help("Output wav path. If set, run native C++ token2wav.") + .def(""); + auto& tts_token2wav_model_path = Argparse::add("--tts_token2wav_model_path") + .help("Path to token2wav .mllm (if empty, fallback to --model_path).") + .def(""); + auto& tts_token2wav_model_version = Argparse::add("--tts_token2wav_model_version") + .help("token2wav model version: v1/v2") + .def("v1"); + auto& tts_prompt_cache = Argparse::add("--tts_prompt_cache") + .help("Path to fixed prompt cache generated by export_prompt_cache.py") + .def(""); + auto& tts_token2wav_n_timesteps = Argparse::add("--tts_token2wav_n_timesteps") + .help("Flow diffusion steps for native token2wav") + .def(10); + auto& debug_progress = Argparse::add("--debug_progress").help("Print step-level debug progress.").def(false); + auto& debug_interval = + Argparse::add("--debug_interval").help("Token step interval for debug progress logs.").def(16); + + Argparse::parse(argc, argv); + + if (help.isSet()) { + Argparse::printHelp(); + mllm::shutdownContext(); + return 0; + } + + mllm::ModelFileVersion file_version = mllm::ModelFileVersion::kV1; + if 
(model_version.get() == "v2") { file_version = mllm::ModelFileVersion::kV2; } + + auto token2wav_model_path = tts_token2wav_model_path.get().empty() ? model_path.get() : tts_token2wav_model_path.get(); + mllm::ModelFileVersion token2wav_file_version = mllm::ModelFileVersion::kV1; + if (tts_token2wav_model_version.get() == "v2") { token2wav_file_version = mllm::ModelFileVersion::kV2; } + + auto run_native_token2wav = !tts_wav_out.get().empty(); + if (run_native_token2wav && tts_prompt_cache.get().empty()) { + MLLM_ERROR_EXIT(mllm::ExitCode::kCoreError, "--tts_prompt_cache is required when --tts_wav_out is set."); + } + + auto debug_t0 = std::chrono::steady_clock::now(); + auto debug_log = [&](const std::string& msg) { + if (!debug_progress.get()) { return; } + auto now = std::chrono::steady_clock::now(); + auto sec = std::chrono::duration_cast(now - debug_t0).count() / 1000.0; + fmt::print("[debug +{:.3f}s] {}\n", sec, msg); + }; + + if (!tts_tokens_in.get().empty()) { + if (!run_native_token2wav) { + MLLM_ERROR_EXIT(mllm::ExitCode::kCoreError, "--tts_wav_out is required when --tts_tokens_in is set."); + } + if (token2wav_model_path.empty()) { + MLLM_ERROR_EXIT(mllm::ExitCode::kCoreError, "Missing token2wav model path (--tts_token2wav_model_path or --model_path)."); + } + + std::ifstream ifs(tts_tokens_in.get()); + if (!ifs.is_open()) { MLLM_ERROR_EXIT(mllm::ExitCode::kIOError, "Failed to open token file: {}", tts_tokens_in.get()); } + std::vector token_ids; + for (std::string line; std::getline(ifs, line);) { + if (line.empty()) { continue; } + std::stringstream ss(line); + while (!ss.eof()) { + int64_t token = 0; + ss >> token; + if (!ss.fail()) { token_ids.push_back(token); } + } + } + if (token_ids.empty()) { MLLM_ERROR_EXIT(mllm::ExitCode::kCoreError, "No token id found in {}", tts_tokens_in.get()); } + + fmt::print("Loaded {} TTS token IDs from {}\n", token_ids.size(), tts_tokens_in.get()); + debug_log("Loading token2wav model and prompt cache..."); + auto 
token2wav_param = mllm::load(token2wav_model_path, token2wav_file_version); + auto prompt_cache = mllm::models::minicpm_o45::loadMiniCPMO45Token2WavPromptCache(tts_prompt_cache.get()); + + mllm::models::minicpm_o45::MiniCPMO45Token2WavModel token2wav("token2wav", {}); + token2wav.loadFromParameter(token2wav_param); + debug_log("Native token2wav model loaded."); + + debug_log("Running native flow + HiFT..."); + auto wav = token2wav.infer(token_ids, prompt_cache, std::max(1, tts_token2wav_n_timesteps.get())); + auto wav_i16 = wav * 32767.0f; + wenet::WavWriter wav_writer(wav_i16.ptr(), wav_i16.shape().back(), 1, 24000, 16); + wav_writer.Write(tts_wav_out.get()); + fmt::print("Saved TTS waveform to {}\n", tts_wav_out.get()); + debug_log("Native token2wav finished."); + mllm::shutdownContext(); + return 0; + } + + if (model_path.get().empty() || tokenizer_path.get().empty() || config_path.get().empty()) { + Argparse::printHelp(); + MLLM_ERROR_EXIT(mllm::ExitCode::kCoreError, + "Missing required arguments: --model_path, --tokenizer_path, --config_path"); + } + + auto cfg = mllm::models::minicpm_o45::MiniCPMO45Config(config_path.get()); + + debug_log("Loading tokenizer and model modules..."); + auto tokenizer = mllm::models::minicpm_o45::MiniCPMO45Tokenizer(tokenizer_path.get(), cfg.vision_patch_size, cfg.audio_pool_step); + auto model = mllm::models::minicpm_o45::MiniCPMO45ForCausalLM(cfg); + + debug_log("Loading model parameters..."); + auto param = mllm::load(model_path.get(), file_version); + model.llm_.load(param); + model.vpm_.load(param); + model.resampler_.load(param); + model.apm_.load(param); + model.audio_projection_layer_.load(param); + if (generate_tts_tokens.get()) { model.tts_.loadFromParameter(param); } + debug_log("Model parameters loaded."); + + mllm::models::minicpm_o45::MiniCPMO45Message message; + message.prompt = prompt.get(); + message.img_file_path = image_path.get(); + message.audio_file_path = audio_path.get(); + + auto inputs = 
tokenizer.convertMessage(message, generate_tts_tokens.get()); + debug_log("Tokenizer convertMessage finished."); + + fmt::print("\n{:*^60}\n", " MiniCPM-o-4_5 CLI "); + fmt::print("Prompt: {}\n", message.prompt); + if (!message.img_file_path.empty()) { fmt::print("Image : {}\n", message.img_file_path); } + if (!message.audio_file_path.empty()) { fmt::print("Audio : {}\n", message.audio_file_path); } + + if (!generate_tts_tokens.get()) { + fmt::print("\nResponse: "); + for (auto& step : model.chat(inputs)) { + std::wcout << tokenizer.detokenize(step.cur_token_id) << std::flush; + } + fmt::print("\n"); + } else { + auto tts_eos_id = tokenizer.lookupTokenId(L"<|tts_eos|>"); + auto im_end_id = tokenizer.lookupTokenId(L"<|im_end|>"); + auto eot_id = tokenizer.lookupTokenId(L"<|endoftext|>"); + + std::vector stop_token_ids = { + tts_eos_id, + im_end_id, + eot_id, + cfg.eos_token_id, + }; + + debug_log("Start text generation for TTS conditioning..."); + auto text_out = model.generateTextWithHidden( + inputs, text_max_new_tokens.get(), stop_token_ids, false, 1.0f, 0, 0.0f, + [&](int32_t step, int64_t token_id) { + auto interval = std::max(debug_interval.get(), 1); + if (debug_progress.get() && (step == 1 || (step % interval) == 0)) { + debug_log(fmt::format("Text generation step {} (token_id={})", step, token_id)); + } + }); + debug_log(fmt::format("Text generation done, generated_tokens={}", text_out.generated_tokens.size())); + + fmt::print("\nGenerated text tokens: {}\n", text_out.generated_tokens.size()); + fmt::print("Text (for TTS conditioning): "); + + std::vector tts_text_tokens; + std::vector tts_hidden_states; + for (size_t i = 0; i < text_out.aligned_tokens.size() && i < text_out.aligned_hidden_states.size(); ++i) { + auto token_id = text_out.aligned_tokens[i]; + if (token_id == tts_eos_id || token_id == im_end_id || token_id == eot_id || token_id == cfg.eos_token_id) { break; } + tts_text_tokens.push_back(token_id); + 
tts_hidden_states.push_back(text_out.aligned_hidden_states[i]); + std::wcout << tokenizer.detokenize(token_id) << std::flush; + } + fmt::print("\n"); + + if (tts_text_tokens.empty()) { + MLLM_ERROR_EXIT(mllm::ExitCode::kCoreError, + "No text token available before <|tts_eos|>/<|im_end|>; cannot build TTS condition."); + } + + auto condition_embeds = model.tts_.makeConditionEmbeddings(tts_text_tokens, tts_hidden_states); + if (condition_embeds.isNil()) { + MLLM_ERROR_EXIT(mllm::ExitCode::kCoreError, "Failed to build TTS conditioning embeddings."); + } + debug_log(fmt::format("Built TTS condition embeddings from {} text tokens.", tts_text_tokens.size())); + + mllm::models::minicpm_o45::MiniCPMO45TTSGenerationConfig tts_cfg; + tts_cfg.max_new_tokens = tts_max_new_tokens.get(); + tts_cfg.min_new_tokens = tts_min_new_tokens.get(); + tts_cfg.force_no_stop = tts_force_no_stop.get(); + tts_cfg.do_sample = !tts_greedy.get(); + tts_cfg.temperature = {tts_temperature.get()}; + tts_cfg.top_k = tts_top_k.get(); + tts_cfg.top_p = tts_top_p.get(); + tts_cfg.repetition_penalty = tts_repetition_penalty.get(); + tts_cfg.repetition_penalty_window = tts_repetition_window.get(); + tts_cfg.debug_interval = std::max(debug_interval.get(), 1); + if (debug_progress.get()) { + tts_cfg.step_callback = [&](int32_t step, const std::vector& tokens, bool has_eos) { + auto first_token = tokens.empty() ? -1 : tokens[0]; + debug_log(fmt::format("TTS generation step {} (first_vq_token={}, has_eos={})", step, first_token, + has_eos ? "true" : "false")); + }; + } + + debug_log("Start TTS token generation..."); + auto tts_out = model.tts_.generate(condition_embeds, tts_cfg); + debug_log("TTS token generation finished."); + if (tts_out.new_ids.isNil()) { + fmt::print("Generated TTS tokens: 0\n"); + } else { + auto token_count = tts_out.new_ids.shape()[1]; + fmt::print("Generated TTS tokens: {} (finished={})\n", token_count, tts_out.finished ? 
"true" : "false"); + + std::vector token_ids; + token_ids.reserve(token_count); + for (int32_t i = 0; i < token_count; ++i) { token_ids.push_back(tts_out.new_ids.at({0, i, 0})); } + + fmt::print("TTS token IDs:\n"); + for (size_t i = 0; i < token_ids.size(); ++i) { + fmt::print("{}{}", token_ids[i], (i + 1 == token_ids.size() ? "\n" : " ")); + } + + if (!tts_tokens_out.get().empty()) { + std::ofstream ofs(tts_tokens_out.get()); + if (!ofs.is_open()) { + MLLM_ERROR_EXIT(mllm::ExitCode::kIOError, "Failed to open output file: {}", tts_tokens_out.get()); + } + for (auto id : token_ids) { ofs << std::to_string(id) << '\n'; } + fmt::print("Saved TTS token ids to {}\n", tts_tokens_out.get()); + debug_log(fmt::format("Saved token ids to {}", tts_tokens_out.get())); + } + + if (!tts_wav_out.get().empty()) { + debug_log("Loading token2wav model and prompt cache..."); + auto token2wav_param = mllm::load(token2wav_model_path, token2wav_file_version); + auto prompt_cache = mllm::models::minicpm_o45::loadMiniCPMO45Token2WavPromptCache(tts_prompt_cache.get()); + + mllm::models::minicpm_o45::MiniCPMO45Token2WavModel token2wav("token2wav", {}); + token2wav.loadFromParameter(token2wav_param); + debug_log("Native token2wav model loaded."); + + debug_log("Running native flow + HiFT..."); + auto wav = token2wav.infer(token_ids, prompt_cache, std::max(1, tts_token2wav_n_timesteps.get())); + auto wav_i16 = wav * 32767.0f; + wenet::WavWriter wav_writer(wav_i16.ptr(), wav_i16.shape().back(), 1, 24000, 16); + wav_writer.Write(tts_wav_out.get()); + fmt::print("Saved TTS waveform to {}\n", tts_wav_out.get()); + debug_log("Native token2wav finished."); + } + } + } + + model.perfSummary(); + mllm::memoryReport(); +}) diff --git a/examples/qwen2_5omni/CMakeLists.txt b/examples/qwen2_5omni/CMakeLists.txt new file mode 100644 index 000000000..2fdd3690f --- /dev/null +++ b/examples/qwen2_5omni/CMakeLists.txt @@ -0,0 +1,19 @@ +add_executable(mllm-qwen2_5-omni-text-runner text_infer.cpp) 
+target_link_libraries(mllm-qwen2_5-omni-text-runner PRIVATE MllmRT MllmCPUBackend) +target_include_directories(mllm-qwen2_5-omni-text-runner PRIVATE ${MLLM_INCLUDE_DIR}) + +add_executable(mllm-qwen2_5-omni-image-runner image_infer.cpp) +target_link_libraries(mllm-qwen2_5-omni-image-runner PRIVATE MllmRT MllmCPUBackend) +target_include_directories(mllm-qwen2_5-omni-image-runner PRIVATE ${MLLM_INCLUDE_DIR}) + +add_executable(mllm-qwen2_5-omni-audio-runner audio_infer.cpp) +target_link_libraries(mllm-qwen2_5-omni-audio-runner PRIVATE MllmRT MllmCPUBackend) +target_include_directories(mllm-qwen2_5-omni-audio-runner PRIVATE ${MLLM_INCLUDE_DIR}) + +add_executable(mllm-qwen2_5-omni-audio-out-runner audio_out_infer.cpp) +target_link_libraries(mllm-qwen2_5-omni-audio-out-runner PRIVATE MllmRT MllmCPUBackend) +target_include_directories(mllm-qwen2_5-omni-audio-out-runner PRIVATE ${MLLM_INCLUDE_DIR}) + +add_executable(mllm-qwen2_5-omni-image-runner-dbg image_infer_dbg.cpp) +target_link_libraries(mllm-qwen2_5-omni-image-runner-dbg PRIVATE MllmRT MllmCPUBackend) +target_include_directories(mllm-qwen2_5-omni-image-runner-dbg PRIVATE ${MLLM_INCLUDE_DIR}) diff --git a/examples/qwen2_5omni/audio_infer.cpp b/examples/qwen2_5omni/audio_infer.cpp new file mode 100644 index 000000000..d159c2b3e --- /dev/null +++ b/examples/qwen2_5omni/audio_infer.cpp @@ -0,0 +1,84 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+ +#include +#include +#include +#include +#include +#include + +using mllm::Argparse; + +MLLM_MAIN({ + mllm::Logger::level() = mllm::LogLevel::kError; + + auto& help = Argparse::add("-h|--help").help("Show help message"); + auto& model_path = Argparse::add("-m|--model_path").help("Model path").required(true); + auto& model_version = Argparse::add("-mv|--model_version").help("Model version").required(true); + auto& tokenizer_path = Argparse::add("-t|--tokenizer_path").help("Tokenizer directory").required(true); + auto& config_path = Argparse::add("-c|--config_path").help("Config path").required(true); + + Argparse::parse(argc, argv); + + mllm::ModelFileVersion file_version = mllm::ModelFileVersion::kV1; + if (model_version.get() == "v1") { + file_version = mllm::ModelFileVersion::kV1; + } else if (model_version.get() == "v2") { + file_version = mllm::ModelFileVersion::kV2; + } + + if (help.isSet()) { + Argparse::printHelp(); + mllm::shutdownContext(); + return 0; + } + + { + auto qwen2_5omni_cfg = mllm::models::qwen2_5omni::Qwen2_5OmniConfig(config_path.get()); + auto qwen2_5omni_tokenizer = mllm::models::qwen2_5omni::Qwen2_5OmniTokenizer(tokenizer_path.get()); + auto qwen2_5omni = mllm::models::qwen2_5omni::Qwen2_5OmniForCausalLM(qwen2_5omni_cfg); + + auto param = mllm::load(model_path.get(), file_version); + qwen2_5omni.thinker_.load(param); + + fmt::print("\n{:*^60}\n", " Qwen2.5-Omni Audio CLI "); + fmt::print("Enter 'exit' or 'quit' to end the session\n\n"); + + std::string audio_path; + std::string prompt_text; + + fmt::print("Audio path (or 'exit/quit'): "); + //std::getline(std::cin, audio_path); + //if (audio_path == "exit" || audio_path == "quit") { return 0; } + audio_path = ""; + + fmt::print("Prompt text: "); + //std::getline(std::cin, prompt_text); + //if (prompt_text.empty()) { prompt_text = "Please describe the audio."; } + prompt_text = ""; + + try { + fmt::print("Processing...\n"); + auto inputs = qwen2_5omni_tokenizer.convertAudioMessage({.prompt 
= prompt_text, .audio_file_path = audio_path}); + + fmt::print("\nResponse: "); + qwen2_5omni.streamGenerate(inputs, + { + {"do_sample", mllm::AnyValue(false)}, + {"max_length", mllm::AnyValue(qwen2_5omni_cfg.max_cache_length)}, + }, + [&](int64_t token_id) { + auto str = qwen2_5omni_tokenizer.detokenize(token_id); + std::wcout << str << std::flush; + }); + + fmt::print("\n{}\n", std::string(60, '-')); + } catch (const std::exception& e) { fmt::print("\nError: {}\n{}\n", e.what(), std::string(60, '-')); } + + qwen2_5omni.perfSummary(); + } + + mllm::print("\n"); + mllm::memoryReport(); +}) diff --git a/examples/qwen2_5omni/audio_out_infer.cpp b/examples/qwen2_5omni/audio_out_infer.cpp new file mode 100644 index 000000000..9e46fcd0e --- /dev/null +++ b/examples/qwen2_5omni/audio_out_infer.cpp @@ -0,0 +1,93 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include +#include +#include +#include +#include +#include +#include "wenet_audio/wav.h" + +using mllm::Argparse; + +MLLM_MAIN({ + mllm::Logger::level() = mllm::LogLevel::kError; + + auto& help = Argparse::add("-h|--help").help("Show help message"); + auto& model_path = Argparse::add("-m|--model_path").help("Model path").required(true); + auto& model_version = Argparse::add("-mv|--model_version").help("Model version").required(true); + auto& tokenizer_path = Argparse::add("-t|--tokenizer_path").help("Tokenizer directory").required(true); + auto& config_path = Argparse::add("-c|--config_path").help("Config path").required(true); + auto& spk_dict_path = Argparse::add("-s|--spk_dict_path").help("Speaker json path").required(true); + auto& prompt = Argparse::add("-p|--prompt").help("Prompt text").def(""); + auto& image_path = Argparse::add("-i|--image_path").help("Image path").def(""); + auto& audio_path = Argparse::add("-a|--audio_path").help("Audio path").def(""); + auto& speaker = Argparse::add("-sp|--speaker").help("Speaker name (default: first entry)").def(""); + auto& output_path = 
Argparse::add("-o|--output_path").help("Output wav path").def("./qwen2_5omni.wav"); + + Argparse::parse(argc, argv); + + mllm::ModelFileVersion file_version = mllm::ModelFileVersion::kV1; + if (model_version.get() == "v1") { + file_version = mllm::ModelFileVersion::kV1; + } else if (model_version.get() == "v2") { + file_version = mllm::ModelFileVersion::kV2; + } + + if (help.isSet()) { + Argparse::printHelp(); + mllm::shutdownContext(); + return 0; + } + + if (!image_path.get().empty() && !audio_path.get().empty()) { + MLLM_ERROR_EXIT(mllm::ExitCode::kCoreError, "Only one of --image_path or --audio_path can be set."); + } + + auto qwen_cfg = mllm::models::qwen2_5omni::Qwen2_5OmniConfig(config_path.get()); + auto qwen_tokenizer = mllm::models::qwen2_5omni::Qwen2_5OmniTokenizer(tokenizer_path.get()); + auto qwen_omni = mllm::models::qwen2_5omni::Qwen2_5OmniForConditionalGeneration(qwen_cfg); + + auto param = mllm::load(model_path.get(), file_version); + qwen_omni.load(param); + qwen_omni.loadSpeakers(spk_dict_path.get()); + + std::string prompt_text = prompt.get(); + if (prompt_text.empty()) { + fmt::print("Prompt text: "); + std::getline(std::cin, prompt_text); + if (prompt_text.empty()) { prompt_text = "Please respond."; } + } + + mllm::models::ARGenerationOutputPast inputs; + if (!image_path.get().empty()) { + inputs = qwen_tokenizer.convertVisionMessage({.prompt = prompt_text, .img_file_path = image_path.get()}); + } else if (!audio_path.get().empty()) { + inputs = qwen_tokenizer.convertAudioMessage({.prompt = prompt_text, .audio_file_path = audio_path.get()}); + } else { + inputs = qwen_tokenizer.convertMessage({.prompt = prompt_text}); + } + + mllm::models::qwen2_5omni::Qwen2_5OmniAudioGenerationConfig gen_cfg; + auto output = qwen_omni.generateAudio(inputs, gen_cfg, speaker.get()); + + auto input_len = inputs["sequence"].shape()[1]; + auto total_len = output.sequences.shape()[1]; + fmt::print("\nResponse: "); + for (int i = input_len; i < total_len; ++i) { + 
std::wcout << qwen_tokenizer.detokenize(output.sequences.at({0, i})) << std::flush; + } + fmt::print("\n"); + + auto wav = output.wav * 32767.0f; + wenet::WavWriter wav_writer(wav.ptr(), wav.shape().back(), 1, 24000, 16); + wav_writer.Write(output_path.get()); + + fmt::print("Saved audio to {}\n", output_path.get()); + + qwen_omni.thinker_.perfSummary(); + + mllm::print("\n"); + mllm::memoryReport(); +}) diff --git a/examples/qwen2_5omni/config_qwen2_5omni_7B.json b/examples/qwen2_5omni/config_qwen2_5omni_7B.json new file mode 100644 index 000000000..8f27b94b9 --- /dev/null +++ b/examples/qwen2_5omni/config_qwen2_5omni_7B.json @@ -0,0 +1,495 @@ +{ + "architectures": [ + "Qwen2_5OmniModel" + ], + "enable_audio_output": true, + "enable_talker": true, + "model_type": "qwen2_5_omni", + "talker_config": { + "_attn_implementation_autoset": true, + "_name_or_path": "Qwen2.5-Omni-7B/talker", + "architectures": [ + "Qwen2OmniTalkerForConditionalGeneration" + ], + "attention_dropout": 0.0, + "audio_end_token_id": 151648, + "audio_start_token_id": 151647, + "audio_token_index": 151646, + "embedding_size": 3584, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 896, + "image_token_index": 151655, + "init_std": 0.02, + "initializer_range": 0.02, + "intermediate_size": 18944, + "max_position_embeddings": 32768, + "max_window_layers": 28, + "model_type": "qwen2_5_omni_talker", + "num_attention_heads": 12, + "num_hidden_layers": 24, + "num_key_value_heads": 4, + "position_id_per_seconds": 25, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "mrope_section": [ + 16, + 24, + 24 + ], + "rope_type": "default", + "type": "default" + }, + "rope_theta": 1000000.0, + "seconds_per_chunk": 2, + "sliding_window": 32768, + "spatial_merge_size": 2, + "torch_dtype": "bfloat16", + "tts_codec_end_token_id": 8294, + "tts_codec_mask_token_id": 8296, + "tts_codec_pad_token_id": 8292, + "tts_codec_start_token_id": 8293, + "tts_text_end_token_id": 151861, + "tts_text_pad_token_id": 151859, + 
"tts_text_start_token_id": 151860, + "use_cache": true, + "use_sliding_window": false, + "video_token_index": 151656, + "vision_end_token_id": 151653, + "vision_start_token_id": 151652, + "vocab_size": 8448 + }, + "thinker_config": { + "_attn_implementation_autoset": true, + "_name_or_path": "Qwen2.5-Omni-7B/thinker", + "architectures": [ + "Qwen2OmniNaViTThinkerForConditionalGeneration" + ], + "audio_config": { + "_attn_implementation_autoset": true, + "_name_or_path": "", + "activation_dropout": 0.0, + "activation_function": "gelu", + "add_cross_attention": false, + "architectures": null, + "attention_dropout": 0.0, + "bad_words_ids": null, + "begin_suppress_tokens": null, + "bos_token_id": null, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": null, + "d_model": 1280, + "decoder_start_token_id": null, + "diversity_penalty": 0.0, + "do_sample": false, + "dropout": 0.0, + "early_stopping": false, + "encoder_attention_heads": 20, + "encoder_ffn_dim": 5120, + "encoder_layerdrop": 0.0, + "encoder_layers": 32, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": null, + "exponential_decay_length_penalty": null, + "finetuning_task": null, + "forced_bos_token_id": null, + "forced_eos_token_id": null, + "id2label": { + "0": "LABEL_0", + "1": "LABEL_1" + }, + "init_std": 0.02, + "is_decoder": false, + "is_encoder_decoder": false, + "label2id": { + "LABEL_0": 0, + "LABEL_1": 1 + }, + "length_penalty": 1.0, + "max_length": 20, + "max_source_positions": 1500, + "min_length": 0, + "model_type": "qwen2_5_omni_audio_encoder", + "n_window": 100, + "no_repeat_ngram_size": 0, + "num_beam_groups": 1, + "num_beams": 1, + "num_hidden_layers": 32, + "num_mel_bins": 128, + "num_return_sequences": 1, + "output_attentions": false, + "output_dim": 3584, + "output_hidden_states": false, + "output_scores": false, + "pad_token_id": null, + "prefix": null, + "problem_type": null, + "pruned_heads": {}, + "remove_invalid_values": false, + "repetition_penalty": 1.0, + 
"return_dict": true, + "return_dict_in_generate": false, + "scale_embedding": false, + "sep_token_id": null, + "suppress_tokens": null, + "task_specific_params": null, + "temperature": 1.0, + "tf_legacy_loss": false, + "tie_encoder_decoder": false, + "tie_word_embeddings": true, + "tokenizer_class": null, + "top_k": 50, + "top_p": 1.0, + "torch_dtype": null, + "torchscript": false, + "typical_p": 1.0, + "use_bfloat16": false + }, + "text_config": { + "model_type": "qwen2_5_omni_text", + "hidden_act": "silu", + "hidden_size": 3584, + "init_std": 0.02, + "intermediate_size": 18944, + "vocab_size": 152064, + "num_attention_heads": 28, + "num_hidden_layers": 28, + "num_key_value_heads": 4, + "max_position_embeddings": 32768, + "max_window_layers": 28, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "mrope_section": [ + 16, + 24, + 24 + ], + "rope_type": "default", + "type": "default" + }, + "use_cache": true, + "rope_theta": 1000000.0, + "use_sliding_window": false, + "sliding_window": 32768, + "attention_dropout": 0.0, + "tie_word_embeddings": false + }, + "audio_end_token_id": 151648, + "audio_start_token_id": 151647, + "audio_token_index": 151646, + "bos_token_id": 151644, + "eos_token_id": 151645, + "ignore_index": -100, + "image_token_index": 151655, + "init_std": 0.02, + "model_type": "qwen2_5_omni_thinker", + "pad_token_id": 151643, + "position_id_per_seconds": 25, + "seconds_per_chunk": 2, + "torch_dtype": "bfloat16", + "user_token_id": 872, + "video_token_index": 151656, + "vision_config": { + "_attn_implementation_autoset": true, + "_name_or_path": "", + "add_cross_attention": false, + "architectures": null, + "bad_words_ids": null, + "begin_suppress_tokens": null, + "bos_token_id": null, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": null, + "decoder_start_token_id": null, + "depth": 32, + "diversity_penalty": 0.0, + "do_sample": false, + "early_stopping": false, + "embed_dim": 1280, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": 
null, + "exponential_decay_length_penalty": null, + "finetuning_task": null, + "forced_bos_token_id": null, + "forced_eos_token_id": null, + "fullatt_block_indexes": [ + 7, + 15, + 23, + 31 + ], + "hidden_act": "silu", + "hidden_size": 1280, + "id2label": { + "0": "LABEL_0", + "1": "LABEL_1" + }, + "in_channels": 3, + "in_chans": 3, + "init_std": 0.02, + "intermediate_size": 3420, + "is_decoder": false, + "is_encoder_decoder": false, + "label2id": { + "LABEL_0": 0, + "LABEL_1": 1 + }, + "length_penalty": 1.0, + "max_length": 20, + "min_length": 0, + "model_type": "qwen2_5_omni_vision_encoder", + "no_repeat_ngram_size": 0, + "num_beam_groups": 1, + "num_beams": 1, + "num_heads": 16, + "num_return_sequences": 1, + "out_hidden_size": 3584, + "output_attentions": false, + "output_hidden_states": false, + "output_scores": false, + "pad_token_id": null, + "patch_size": 14, + "prefix": null, + "problem_type": null, + "pruned_heads": {}, + "remove_invalid_values": false, + "repetition_penalty": 1.0, + "return_dict": true, + "return_dict_in_generate": false, + "sep_token_id": null, + "spatial_merge_size": 2, + "spatial_patch_size": 14, + "suppress_tokens": null, + "task_specific_params": null, + "temperature": 1.0, + "temporal_patch_size": 2, + "tf_legacy_loss": false, + "tie_encoder_decoder": false, + "tie_word_embeddings": true, + "tokenizer_class": null, + "tokens_per_second": 25, + "top_k": 50, + "top_p": 1.0, + "torch_dtype": null, + "torchscript": false, + "typical_p": 1.0, + "use_bfloat16": false, + "window_size": 112 + }, + "vision_end_token_id": 151653, + "vision_start_token_id": 151652, + "vision_token_id": 151654 + }, + "token2wav_config": { + "_attn_implementation_autoset": true, + "bigvgan_config": { + "_attn_implementation_autoset": true, + "_name_or_path": "", + "add_cross_attention": false, + "architectures": null, + "bad_words_ids": null, + "begin_suppress_tokens": null, + "bos_token_id": null, + "chunk_size_feed_forward": 0, + 
"cross_attention_hidden_size": null, + "decoder_start_token_id": null, + "diversity_penalty": 0.0, + "do_sample": false, + "early_stopping": false, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": null, + "exponential_decay_length_penalty": null, + "finetuning_task": null, + "forced_bos_token_id": null, + "forced_eos_token_id": null, + "id2label": { + "0": "LABEL_0", + "1": "LABEL_1" + }, + "is_decoder": false, + "is_encoder_decoder": false, + "label2id": { + "LABEL_0": 0, + "LABEL_1": 1 + }, + "length_penalty": 1.0, + "max_length": 20, + "mel_dim": 80, + "min_length": 0, + "model_type": "qwen2_5_omni_bigvgan", + "no_repeat_ngram_size": 0, + "num_beam_groups": 1, + "num_beams": 1, + "num_return_sequences": 1, + "output_attentions": false, + "output_hidden_states": false, + "output_scores": false, + "pad_token_id": null, + "prefix": null, + "problem_type": null, + "pruned_heads": {}, + "remove_invalid_values": false, + "repetition_penalty": 1.0, + "resblock_dilation_sizes": [ + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ] + ], + "resblock_kernel_sizes": [ + 3, + 7, + 11 + ], + "return_dict": true, + "return_dict_in_generate": false, + "sep_token_id": null, + "suppress_tokens": null, + "task_specific_params": null, + "temperature": 1.0, + "tf_legacy_loss": false, + "tie_encoder_decoder": false, + "tie_word_embeddings": true, + "tokenizer_class": null, + "top_k": 50, + "top_p": 1.0, + "torch_dtype": null, + "torchscript": false, + "typical_p": 1.0, + "upsample_initial_channel": 1536, + "upsample_kernel_sizes": [ + 11, + 7, + 4, + 4, + 4, + 4 + ], + "upsample_rates": [ + 5, + 3, + 2, + 2, + 2, + 2 + ], + "use_bfloat16": false, + "use_bias_at_final": false + }, + "dit_config": { + "_attn_implementation_autoset": true, + "_name_or_path": "", + "add_cross_attention": false, + "architectures": null, + "bad_words_ids": null, + "begin_suppress_tokens": null, + "bos_token_id": null, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": 
null, + "decoder_start_token_id": null, + "depth": 22, + "dim": 1024, + "diversity_penalty": 0.0, + "do_sample": false, + "dropout": 0.1, + "early_stopping": false, + "emb_dim": 512, + "enc_attention_channels": 64, + "enc_channels": [ + 256, + 256, + 256, + 256, + 768 + ], + "enc_dilations": [ + 1, + 2, + 3, + 4, + 1 + ], + "enc_dim": 128, + "enc_emb_dim": 192, + "enc_global_context": true, + "enc_kernel_sizes": [ + 5, + 3, + 3, + 3, + 1 + ], + "enc_lin_neurons": 192, + "enc_res2net_scale": 2, + "enc_se_channels": 64, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": null, + "exponential_decay_length_penalty": null, + "ff_mult": 2, + "finetuning_task": null, + "forced_bos_token_id": null, + "forced_eos_token_id": null, + "head_dim": 64, + "heads": 16, + "id2label": { + "0": "LABEL_0", + "1": "LABEL_1" + }, + "is_decoder": false, + "is_encoder_decoder": false, + "label2id": { + "LABEL_0": 0, + "LABEL_1": 1 + }, + "length_penalty": 1.0, + "max_length": 20, + "mel_dim": 80, + "min_length": 0, + "model_type": "qwen2_5_omni_dit", + "no_repeat_ngram_size": 0, + "num_beam_groups": 1, + "num_beams": 1, + "num_embeds": 8193, + "num_return_sequences": 1, + "output_attentions": false, + "output_hidden_states": false, + "output_scores": false, + "pad_token_id": null, + "prefix": null, + "problem_type": null, + "pruned_heads": {}, + "remove_invalid_values": false, + "repeats": 2, + "repetition_penalty": 1.0, + "return_dict": true, + "return_dict_in_generate": false, + "sep_token_id": null, + "suppress_tokens": null, + "task_specific_params": null, + "temperature": 1.0, + "tf_legacy_loss": false, + "tie_encoder_decoder": false, + "tie_word_embeddings": true, + "tokenizer_class": null, + "top_k": 50, + "top_p": 1.0, + "torch_dtype": "float32", + "torchscript": false, + "typical_p": 1.0, + "use_bfloat16": false + }, + "model_type": "qwen2_5_omni_token2wav" + }, + "torch_dtype": "bfloat16", + "transformers_version": "4.50.0.dev0" +} diff --git 
a/examples/qwen2_5omni/image_infer.cpp b/examples/qwen2_5omni/image_infer.cpp new file mode 100644 index 000000000..473f7de60 --- /dev/null +++ b/examples/qwen2_5omni/image_infer.cpp @@ -0,0 +1,84 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include +#include +#include +#include +#include +#include + +using mllm::Argparse; + +MLLM_MAIN({ + mllm::Logger::level() = mllm::LogLevel::kError; + + auto& help = Argparse::add("-h|--help").help("Show help message"); + auto& model_path = Argparse::add("-m|--model_path").help("Model path").required(true); + auto& model_version = Argparse::add("-mv|--model_version").help("Model version").required(true); + auto& tokenizer_path = Argparse::add("-t|--tokenizer_path").help("Tokenizer directory").required(true); + auto& config_path = Argparse::add("-c|--config_path").help("Config path").required(true); + + Argparse::parse(argc, argv); + + mllm::ModelFileVersion file_version = mllm::ModelFileVersion::kV1; + if (model_version.get() == "v1") { + file_version = mllm::ModelFileVersion::kV1; + } else if (model_version.get() == "v2") { + file_version = mllm::ModelFileVersion::kV2; + } + + if (help.isSet()) { + Argparse::printHelp(); + mllm::shutdownContext(); + return 0; + } + + { + auto qwen2_5omni_cfg = mllm::models::qwen2_5omni::Qwen2_5OmniConfig(config_path.get()); + auto qwen2_5omni_tokenizer = + mllm::models::qwen2_5omni::Qwen2_5OmniTokenizer(tokenizer_path.get(), qwen2_5omni_cfg.visual_spatial_merge_size); + auto qwen2_5omni = mllm::models::qwen2_5omni::Qwen2_5OmniForCausalLM(qwen2_5omni_cfg); + + auto param = mllm::load(model_path.get(), file_version); + qwen2_5omni.thinker_.load(param); + + fmt::print("\n{:*^60}\n", " Qwen2.5-Omni Image CLI "); + fmt::print("Enter 'exit' or 'quit' to end the session\n\n"); + + std::string image_path; + std::string prompt_text; + + fmt::print("Image path (or 'exit/quit'): "); + image_path = "../../rsc/pics.jpg"; + //std::getline(std::cin, image_path); + if (image_path == 
"exit" || image_path == "quit") { return 0; } + + fmt::print("Prompt text: "); + prompt_text = "描述图片中物体"; + //std::getline(std::cin, prompt_text); + + try { + fmt::print("Processing...\n"); + auto inputs = qwen2_5omni_tokenizer.convertVisionMessage({.prompt = prompt_text, .img_file_path = image_path}); + + fmt::print("\nResponse: "); + qwen2_5omni.streamGenerate(inputs, + { + {"do_sample", mllm::AnyValue(false)}, + {"max_length", mllm::AnyValue(qwen2_5omni_cfg.max_cache_length)}, + }, + [&](int64_t token_id) { + auto str = qwen2_5omni_tokenizer.detokenize(token_id); + std::wcout << str << std::flush; + }); + + fmt::print("\n{}\n", std::string(60, '-')); + } catch (const std::exception& e) { fmt::print("\nError: {}\n{}\n", e.what(), std::string(60, '-')); } + + qwen2_5omni.perfSummary(); + } + + mllm::print("\n"); + mllm::memoryReport(); +}) diff --git a/examples/qwen2_5omni/image_infer_dbg.cpp b/examples/qwen2_5omni/image_infer_dbg.cpp new file mode 100644 index 000000000..de21c8ec7 --- /dev/null +++ b/examples/qwen2_5omni/image_infer_dbg.cpp @@ -0,0 +1,91 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+ +#include +#include +#include +#include +#include +#include + +using mllm::Argparse; + +//MLLM_MAIN({ +int main(int argc, char** argv) { + ::mllm::__setup_signal_handler(); + ::mllm::initializeContext(); + + mllm::Logger::level() = mllm::LogLevel::kError; + + auto& help = Argparse::add("-h|--help").help("Show help message"); + auto& model_path = Argparse::add("-m|--model_path").help("Model path").required(true); + auto& model_version = Argparse::add("-mv|--model_version").help("Model version").required(true); + auto& tokenizer_path = Argparse::add("-t|--tokenizer_path").help("Tokenizer directory").required(true); + auto& config_path = Argparse::add("-c|--config_path").help("Config path").required(true); + + Argparse::parse(argc, argv); + + mllm::ModelFileVersion file_version = mllm::ModelFileVersion::kV1; + if (model_version.get() == "v1") { + file_version = mllm::ModelFileVersion::kV1; + } else if (model_version.get() == "v2") { + file_version = mllm::ModelFileVersion::kV2; + } + + if (help.isSet()) { + Argparse::printHelp(); + mllm::shutdownContext(); + return 0; + } + + { + auto qwen2_5omni_cfg = mllm::models::qwen2_5omni::Qwen2_5OmniConfig(config_path.get()); + auto qwen2_5omni_tokenizer = + mllm::models::qwen2_5omni::Qwen2_5OmniTokenizer(tokenizer_path.get(), qwen2_5omni_cfg.visual_spatial_merge_size); + auto qwen2_5omni = mllm::models::qwen2_5omni::Qwen2_5OmniForCausalLM(qwen2_5omni_cfg); + + auto param = mllm::load(model_path.get(), file_version); + qwen2_5omni.thinker_.load(param); + + fmt::print("\n{:*^60}\n", " Qwen2.5-Omni Image CLI "); + fmt::print("Enter 'exit' or 'quit' to end the session\n\n"); + + std::string image_path; + std::string prompt_text; + + fmt::print("Image path (or 'exit/quit'): "); + image_path = "../../rsc/pics.jpg"; + //std::getline(std::cin, image_path); + if (image_path == "exit" || image_path == "quit") { return 0; } + + fmt::print("Prompt text: "); + prompt_text = "描述图片中物体"; + //std::getline(std::cin, prompt_text); + + try { + 
fmt::print("Processing...\n"); + auto inputs = qwen2_5omni_tokenizer.convertVisionMessage({.prompt = prompt_text, .img_file_path = image_path}); + + fmt::print("\nResponse: "); + qwen2_5omni.streamGenerate(inputs, + { + {"do_sample", mllm::AnyValue(false)}, + {"max_length", mllm::AnyValue(qwen2_5omni_cfg.max_cache_length)}, + }, + [&](int64_t token_id) { + auto str = qwen2_5omni_tokenizer.detokenize(token_id); + std::wcout << str << std::flush; + }); + + fmt::print("\n{}\n", std::string(60, '-')); + } catch (const std::exception& e) { fmt::print("\nError: {}\n{}\n", e.what(), std::string(60, '-')); } + + qwen2_5omni.perfSummary(); + } + + mllm::print("\n"); + mllm::memoryReport(); + + ::mllm::shutdownContext(); + return 0; +} diff --git a/examples/qwen2_5omni/text_infer.cpp b/examples/qwen2_5omni/text_infer.cpp new file mode 100644 index 000000000..299a0e07d --- /dev/null +++ b/examples/qwen2_5omni/text_infer.cpp @@ -0,0 +1,72 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+ +#include +#include +#include +#include +#include +#include + +using mllm::Argparse; + +MLLM_MAIN({ + mllm::Logger::level() = mllm::LogLevel::kError; + + auto& help = Argparse::add("-h|--help").help("Show help message"); + auto& model_path = Argparse::add("-m|--model_path").help("Model path").required(true); + auto& model_version = Argparse::add("-mv|--model_version").help("Model version").required(true); + auto& tokenizer_path = Argparse::add("-t|--tokenizer_path").help("Tokenizer directory").required(true); + auto& config_path = Argparse::add("-c|--config_path").help("Config path").required(true); + + Argparse::parse(argc, argv); + + mllm::ModelFileVersion file_version = mllm::ModelFileVersion::kV1; + if (model_version.get() == "v1") { + file_version = mllm::ModelFileVersion::kV1; + } else if (model_version.get() == "v2") { + file_version = mllm::ModelFileVersion::kV2; + } + + if (help.isSet()) { + Argparse::printHelp(); + mllm::shutdownContext(); + return 0; + } + + { + auto qwen2_5omni_cfg = mllm::models::qwen2_5omni::Qwen2_5OmniConfig(config_path.get()); + auto qwen2_5omni_tokenizer = mllm::models::qwen2_5omni::Qwen2_5OmniTokenizer(tokenizer_path.get()); + auto qwen2_5omni = mllm::models::qwen2_5omni::Qwen2_5OmniForCausalLM(qwen2_5omni_cfg); + + auto param = mllm::load(model_path.get(), file_version); + qwen2_5omni.thinker_.load(param); + + fmt::print("\n{:*^60}\n", " Qwen2.5-Omni Text CLI "); + fmt::print("Enter 'exit' or 'quit' to end the session\n\n"); + + std::string prompt_text; + + fmt::print("💬 Prompt text (or 'exit/quit'): "); + std::getline(std::cin, prompt_text); + + if (prompt_text == "exit" || prompt_text == "quit") { return 0; } + + try { + fmt::print("🔄 Processing...\n"); + auto inputs = qwen2_5omni_tokenizer.convertMessage({.prompt = prompt_text}); + + fmt::print("\n🤖 Response: "); + for (auto& step : qwen2_5omni.chat(inputs)) { + std::wcout << qwen2_5omni_tokenizer.detokenize(step.cur_token_id) << std::flush; + } + + fmt::print("\n{}\n", 
std::string(60, '-')); + } catch (const std::exception& e) { fmt::print("\n❌ Error: {}\n{}\n", e.what(), std::string(60, '-')); } + + qwen2_5omni.perfSummary(); + } + + mllm::print("\n"); + mllm::memoryReport(); +}) diff --git a/mllm/backends/cpu/CPUBackend.cpp b/mllm/backends/cpu/CPUBackend.cpp index 0964cba0d..f4b909913 100644 --- a/mllm/backends/cpu/CPUBackend.cpp +++ b/mllm/backends/cpu/CPUBackend.cpp @@ -14,6 +14,7 @@ #include "mllm/backends/cpu/ops/ConcatOp.hpp" #include "mllm/backends/cpu/ops/ContiguousOp.hpp" #include "mllm/backends/cpu/ops/Conv1DOp.hpp" +#include "mllm/backends/cpu/ops/ConvTranspose1DOp.hpp" #include "mllm/backends/cpu/ops/Conv2DOp.hpp" #include "mllm/backends/cpu/ops/Conv3DOp.hpp" #include "mllm/backends/cpu/ops/CopyOp.hpp" @@ -52,6 +53,7 @@ #include "mllm/backends/cpu/ops/Scatter2ShardsOp.hpp" #include "mllm/backends/cpu/ops/SiLUOp.hpp" #include "mllm/backends/cpu/ops/SigmoidOp.hpp" +#include "mllm/backends/cpu/ops/TanhOp.hpp" #include "mllm/backends/cpu/ops/SliceOp.hpp" #include "mllm/backends/cpu/ops/SoftmaxOp.hpp" #include "mllm/backends/cpu/ops/SplitOp.hpp" @@ -78,12 +80,12 @@ CPUBackend::CPUBackend() : Backend(kCPU, createCPUAllocator()) { CPUSiLUOpFactory, CPUSigmoidOpFactory, CPURMSNormOpFactory, CPUGELUOpFactory, CPUQuickGELUOpFactory, CPUReLUOpFactory, CPUMatMulOpFactory, CPUFlashAttention2OpFactory, CPUSliceOpFactory, CPUVisionRoPEOpFactory, CPUParamOpFactory, CPUMultimodalRoPEOpFactory, CPURoPEOpFactory, CPUCausalMaskOpFactory, CPUConv1DOpFactory, - CPUConv3DOpFactory, CPUSTFTOpFactory, CPUISTFTOpFactory, CPUIndexOpFactory, CPUTopKOpFactory, CPUClipOpFactory, - CPUMeanOpFactory, CPUKVCacheOpFactory, CPUPagedAttnOpFactory, CPUScatter2ShardsOpFactory, CPURadixAttnOpFactory, - CPUConv2DOpFactory, CPULayerNorm2DOpFactory, CPUInterpolateOpFactory, CPUPadOpFactory, CPUMaskedScatterOpFactory, - CPUArgsortOpFactory, CPUCloneOpFactory, CPUAvgPool1dOpFactory, CPUFlashAttention2SwaSinkOpFactory, - CPURadixAttnRelaxOpFactory, 
CPURadixAttnSwaSinkOpFactory, CPUEqualOpFactory, CPUWhereOpFactory, - CPUGatherOpFactory>(); + CPUConvTranspose1DOpFactory, CPUConv3DOpFactory, CPUSTFTOpFactory, CPUISTFTOpFactory, CPUIndexOpFactory, + CPUTopKOpFactory, CPUClipOpFactory, CPUMeanOpFactory, CPUKVCacheOpFactory, CPUPagedAttnOpFactory, + CPUScatter2ShardsOpFactory, CPURadixAttnOpFactory, CPUConv2DOpFactory, CPULayerNorm2DOpFactory, + CPUInterpolateOpFactory, CPUPadOpFactory, CPUMaskedScatterOpFactory, CPUArgsortOpFactory, CPUCloneOpFactory, + CPUAvgPool1dOpFactory, CPUFlashAttention2SwaSinkOpFactory, CPURadixAttnRelaxOpFactory, + CPURadixAttnSwaSinkOpFactory, CPUEqualOpFactory, CPUWhereOpFactory, CPUGatherOpFactory, CPUTanhOpFactory>(); } CPUBackend::~CPUBackend() { diff --git a/mllm/backends/cpu/ops/ConvTranspose1DOp.cpp b/mllm/backends/cpu/ops/ConvTranspose1DOp.cpp new file mode 100644 index 000000000..15a8097d1 --- /dev/null +++ b/mllm/backends/cpu/ops/ConvTranspose1DOp.cpp @@ -0,0 +1,93 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+ +#include "mllm/backends/cpu/ops/ConvTranspose1DOp.hpp" +#include "mllm/core/Parallel.hpp" +#include "mllm/utils/Common.hpp" + +namespace mllm::cpu { + +CPUConvTranspose1DOp::CPUConvTranspose1DOp(const aops::ConvTranspose1DOpOptions& options) + : aops::ConvTranspose1DOp(options) {} + +void CPUConvTranspose1DOp::forward(const std::vector& inputs, std::vector& outputs) { + auto& input = inputs[0]; + auto& output = outputs[0]; + + auto i_shape = input.shape(); + auto o_shape = output.shape(); + + // input shape: [batch, in_channels, sequence] + // output shape: [batch, out_channels, out_sequence] + const int batch = i_shape[0]; + const int in_channels = i_shape[1]; + const int sequence = i_shape[2]; + + const int out_channels = o_shape[1]; + const int out_sequence = o_shape[2]; + + const int kernel_size = options_.kernel_size; + const int stride = options_.stride; + const int padding = options_.padding; + const int dilation = options_.dilation; + const int groups = options_.groups; + + const int in_channels_per_group = in_channels / groups; + const int out_channels_per_group = out_channels / groups; + + MLLM_RT_ASSERT_EQ(input.dtype(), kFloat32); + MLLM_RT_ASSERT_EQ(output.dtype(), kFloat32); + MLLM_RT_ASSERT(weight_.dtype() == kFloat32); + const auto* weight_ptr = weight_.ptr(); + const auto* input_ptr = input.ptr(); + auto* output_ptr = output.ptr(); + + float* bias_ptr = nullptr; + if (options_.bias && !bias_.isNil()) { bias_ptr = bias_.ptr(); } + + std::fill_n(output_ptr, output.numel(), 0.0f); + + const int total_iterations = batch * out_channels * out_sequence; + + switch (output.dtype()) { + case kFloat32: + MLLM_CONDITIONAL_PARALLEL_FOR(options_.getThreads() > 1, 4, idx, 0, total_iterations, 1, { + int b = idx / (out_channels * out_sequence); + int oc = (idx % (out_channels * out_sequence)) / out_sequence; + int out_pos = idx % out_sequence; + + const int group_idx = oc / out_channels_per_group; + const int oc_in_group = oc % out_channels_per_group; + + 
float sum = 0.0f; + + for (int ic_in_group = 0; ic_in_group < in_channels_per_group; ++ic_in_group) { + const int ic = group_idx * in_channels_per_group + ic_in_group; + const int base_input_idx = b * (in_channels * sequence) + ic * sequence; + + const int base_weight_idx = (ic * out_channels_per_group + oc_in_group) * kernel_size; + + for (int k = 0; k < kernel_size; ++k) { + int input_pos = out_pos + padding - k * dilation; + if (input_pos % stride != 0) { continue; } + input_pos /= stride; + if (input_pos < 0 || input_pos >= sequence) { continue; } + + const int input_idx = base_input_idx + input_pos; + const int weight_idx = base_weight_idx + k; + + sum += input_ptr[input_idx] * weight_ptr[weight_idx]; + } + } + + if (bias_ptr) { sum += bias_ptr[oc]; } + + const int output_idx = b * (out_channels * out_sequence) + oc * out_sequence + out_pos; + output_ptr[output_idx] = sum; + }); + break; + default: NYI("ConvTranspose1D: unsupported data type"); + } +} + +} // namespace mllm::cpu diff --git a/mllm/backends/cpu/ops/ConvTranspose1DOp.hpp b/mllm/backends/cpu/ops/ConvTranspose1DOp.hpp new file mode 100644 index 000000000..fd1163ed3 --- /dev/null +++ b/mllm/backends/cpu/ops/ConvTranspose1DOp.hpp @@ -0,0 +1,25 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+ +#pragma once + +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/aops/ConvTranspose1DOp.hpp" + +namespace mllm::cpu { + +class CPUConvTranspose1DOp final : public aops::ConvTranspose1DOp { + public: + explicit CPUConvTranspose1DOp(const aops::ConvTranspose1DOpOptions& options); + + void forward(const std::vector& inputs, std::vector& outputs) override; +}; + +class CPUConvTranspose1DOpFactory : public TypedOpFactory { + public: + std::shared_ptr createOpImpl(const aops::ConvTranspose1DOpOptions& options) override { + return std::make_shared(options); + } +}; + +} // namespace mllm::cpu diff --git a/mllm/backends/cpu/ops/EmbeddingOp.cpp b/mllm/backends/cpu/ops/EmbeddingOp.cpp index 71af75f68..f25849f7a 100644 --- a/mllm/backends/cpu/ops/EmbeddingOp.cpp +++ b/mllm/backends/cpu/ops/EmbeddingOp.cpp @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include +#include #include "mllm/backends/cpu/ops/EmbeddingOp.hpp" #include "mllm/core/DataTypes.hpp" @@ -22,30 +23,40 @@ void CPUEmbeddingOp::forward(const std::vector& inputs, std::vector 0); + MLLM_RT_ASSERT(options_.hidden_size > 0); + + static std::atomic warned_token_oob{false}; const bool use_parallel = options_.getThreads() > 1; const int thread_count = options_.getThreads(); for (int b = 0; b < B; ++b) { MLLM_CONDITIONAL_PARALLEL_FOR(use_parallel, thread_count, s, 0, S, 1, { - switch (weight_dtype) { - case kFloat32: - std::memcpy(ous.coffsettedPtr({b, (int)s, 0}), - weight_.ptr() + options_.hidden_size * (*ins.coffsettedPtr({b, (int)s})), - options_.hidden_size * sizeof(float)); - break; - case kFloat16: - std::memcpy(ous.coffsettedPtr({b, (int)s, 0}), - weight_.ptr() + options_.hidden_size * (*ins.coffsettedPtr({b, (int)s})), - options_.hidden_size * sizeof(mllm_fp16_t)); - break; - case kGGUF_Q4_K: { - auto token_idx = *ins.coffsettedPtr({b, (int)s}); - if (token_idx >= 0) { + const auto token_idx = *ins.coffsettedPtr({b, (int)s}); + auto* out_ptr = ous.coffsettedPtr({b, (int)s, 0}); + if (token_idx < 0 
|| token_idx >= options_.vocab_size) { + std::memset(out_ptr, 0, options_.hidden_size * bytesOfType(ous.dtype())); + bool expected = false; + if (warned_token_oob.compare_exchange_strong(expected, true)) { + MLLM_WARN("Embedding token index out of range (idx={}, vocab={}), output row is zero-filled.", + token_idx, options_.vocab_size); + } + } else { + switch (weight_dtype) { + case kFloat32: + std::memcpy(out_ptr, weight_.ptr() + options_.hidden_size * token_idx, + options_.hidden_size * sizeof(float)); + break; + case kFloat16: + std::memcpy(out_ptr, weight_.ptr() + options_.hidden_size * token_idx, + options_.hidden_size * sizeof(mllm_fp16_t)); + break; + case kGGUF_Q4_K: { dequantize_row_q4_K(weight_.ptr() + token_idx * options_.hidden_size / QK_K, ous.coffsettedPtr({b, (int)s, 0}), options_.hidden_size); + break; } - break; - } + case kGGUF_Q4_0: { auto token_idx = *ins.coffsettedPtr({b, (int)s}); if (token_idx >= 0) { @@ -56,6 +67,7 @@ void CPUEmbeddingOp::forward(const std::vector& inputs, std::vector + +#include "mllm/backends/cpu/ops/TanhOp.hpp" +#include "mllm/core/Parallel.hpp" +#include "mllm/utils/Common.hpp" + +namespace mllm::cpu { + +CPUTanhOp::CPUTanhOp(const aops::TanhOpOptions& options) : aops::TanhOp(options) {} + +void CPUTanhOp::forward(const std::vector& inputs, std::vector& outputs) { + const auto& X = inputs[0]; + auto& Y = outputs[0]; + + const auto numel = X.numel(); + + switch (X.dtype()) { + case kFloat32: { + const auto* x_ptr = X.ptr(); + auto* y_ptr = Y.ptr(); + MLLM_CONDITIONAL_PARALLEL_FOR(options_.getThreads() > 1, 4, idx, 0, numel, 1, { + y_ptr[idx] = std::tanh(x_ptr[idx]); + }); + break; + } + case kFloat16: { + const auto* x_ptr = X.ptr(); + auto* y_ptr = Y.ptr(); + MLLM_CONDITIONAL_PARALLEL_FOR(options_.getThreads() > 1, 4, idx, 0, numel, 1, { + float v = static_cast(x_ptr[idx]); + y_ptr[idx] = static_cast(std::tanh(v)); + }); + break; + } + default: NYI("CPUTanhOp::forward not support dtype {}", nameOfType(X.dtype())); break; 
+ } +} + +} // namespace mllm::cpu diff --git a/mllm/backends/cpu/ops/TanhOp.hpp b/mllm/backends/cpu/ops/TanhOp.hpp new file mode 100644 index 000000000..c88fae9ce --- /dev/null +++ b/mllm/backends/cpu/ops/TanhOp.hpp @@ -0,0 +1,25 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/aops/TanhOp.hpp" + +namespace mllm::cpu { + +class CPUTanhOp final : public aops::TanhOp { + public: + explicit CPUTanhOp(const aops::TanhOpOptions& options); + + void forward(const std::vector& inputs, std::vector& outputs) override; +}; + +class CPUTanhOpFactory : public TypedOpFactory { + public: + std::shared_ptr createOpImpl(const aops::TanhOpOptions& options) override { + return std::make_shared(options); + } +}; + +} // namespace mllm::cpu diff --git a/mllm/compile/ir/GeneratedRTTIKind.hpp b/mllm/compile/ir/GeneratedRTTIKind.hpp index d100dc621..2d83493d1 100644 --- a/mllm/compile/ir/GeneratedRTTIKind.hpp +++ b/mllm/compile/ir/GeneratedRTTIKind.hpp @@ -44,6 +44,7 @@ enum NodeKind : uint32_t { RK_Op_LinalgIROp_RepeatOp, RK_Op_LinalgIROp_PermuteOp, RK_Op_LinalgIROp_Conv1DOp, + RK_Op_LinalgIROp_ConvTranspose1DOp, RK_Op_LinalgIROp_Conv2DOp, RK_Op_LinalgIROp_Conv3DOp, RK_Op_LinalgIROp_GELUOp, @@ -86,6 +87,7 @@ enum NodeKind : uint32_t { RK_Op_LinalgIROp_EqualOp, RK_Op_LinalgIROp_WhereOp, RK_Op_LinalgIROp_SigmoidOp, + RK_Op_LinalgIROp_TanhOp, RK_Op_LinalgIROp_CustomizedOp, RK_Op_LinalgIROp_Last, RK_Op_GraphIROp, diff --git a/mllm/compile/ir/NodeRTTIClassOfImpl.hpp b/mllm/compile/ir/NodeRTTIClassOfImpl.hpp index 6a98797a9..4c3313cf9 100644 --- a/mllm/compile/ir/NodeRTTIClassOfImpl.hpp +++ b/mllm/compile/ir/NodeRTTIClassOfImpl.hpp @@ -102,6 +102,9 @@ struct NodeRTTIClassOfImpl { #define RTTI_RK_OP_LINALGIROP_CONV1DOP_IMPL(v) \ return (v)->getKind() >= RK_Op_LinalgIROp_Conv1DOp && (v)->getKind() <= RK_Op_LinalgIROp_Conv1DOp +#define RTTI_RK_OP_LINALGIROP_CONVTRANSPOSE1DOP_IMPL(v) \ + return 
(v)->getKind() >= RK_Op_LinalgIROp_ConvTranspose1DOp && (v)->getKind() <= RK_Op_LinalgIROp_ConvTranspose1DOp + #define RTTI_RK_OP_LINALGIROP_CONV2DOP_IMPL(v) \ return (v)->getKind() >= RK_Op_LinalgIROp_Conv2DOp && (v)->getKind() <= RK_Op_LinalgIROp_Conv2DOp @@ -229,6 +232,9 @@ struct NodeRTTIClassOfImpl { #define RTTI_RK_OP_LINALGIROP_SIGMOIDOP_IMPL(v) \ return (v)->getKind() >= RK_Op_LinalgIROp_SigmoidOp && (v)->getKind() <= RK_Op_LinalgIROp_SigmoidOp +#define RTTI_RK_OP_LINALGIROP_TANHOP_IMPL(v) \ + return (v)->getKind() >= RK_Op_LinalgIROp_TanhOp && (v)->getKind() <= RK_Op_LinalgIROp_TanhOp + #define RTTI_RK_OP_LINALGIROP_CUSTOMIZEDOP_IMPL(v) \ return (v)->getKind() >= RK_Op_LinalgIROp_CustomizedOp && (v)->getKind() <= RK_Op_LinalgIROp_CustomizedOp diff --git a/mllm/compile/ir/linalg/Op.cpp b/mllm/compile/ir/linalg/Op.cpp index bb4e2fb9d..ad05e9437 100644 --- a/mllm/compile/ir/linalg/Op.cpp +++ b/mllm/compile/ir/linalg/Op.cpp @@ -55,6 +55,7 @@ LINALG_AOPS_DECL(OpTypes::kTranspose, TransposeOp); LINALG_AOPS_DECL(OpTypes::kRMSNorm, RMSNormOp); LINALG_AOPS_DECL(OpTypes::kSiLU, SiLUOp); LINALG_AOPS_DECL(OpTypes::kSigmoid, SigmoidOp); +LINALG_AOPS_DECL(OpTypes::kTanh, TanhOp); LINALG_AOPS_DECL(OpTypes::kCastType, CastTypeOp); @@ -70,6 +71,7 @@ LINALG_AOPS_DECL(OpTypes::kRepeat, RepeatOp); LINALG_AOPS_DECL(OpTypes::kPermute, PermuteOp); LINALG_AOPS_DECL(OpTypes::kConv1D, Conv1DOp); +LINALG_AOPS_DECL(OpTypes::kConvTranspose1D, ConvTranspose1DOp); LINALG_AOPS_DECL(OpTypes::kConv2D, Conv2DOp); LINALG_AOPS_DECL(OpTypes::kConv3D, Conv3DOp); diff --git a/mllm/compile/ir/linalg/Op.hpp b/mllm/compile/ir/linalg/Op.hpp index 02d04400b..6e6de4785 100644 --- a/mllm/compile/ir/linalg/Op.hpp +++ b/mllm/compile/ir/linalg/Op.hpp @@ -29,6 +29,7 @@ class TransposeOp; class RMSNormOp; class SiLUOp; class SigmoidOp; +class TanhOp; class CausalMaskOp; class CastTypeOp; class X2XOp; @@ -38,6 +39,7 @@ class FlashAttention2Op; class RepeatOp; class PermuteOp; class Conv1DOp; +class 
ConvTranspose1DOp; class Conv2DOp; class Conv3DOp; class GELUOp; @@ -188,6 +190,7 @@ LINALG_AOPS_DEFINE(TransposeOp, TRANSPOSEOP); LINALG_AOPS_DEFINE(RMSNormOp, RMSNORMOP); LINALG_AOPS_DEFINE(SiLUOp, SILUOP); LINALG_AOPS_DEFINE(SigmoidOp, SIGMOIDOP); +LINALG_AOPS_DEFINE(TanhOp, TANHOP); LINALG_AOPS_DEFINE(CastTypeOp, CASTTYPEOP); @@ -201,6 +204,7 @@ LINALG_AOPS_DEFINE(RepeatOp, REPEATOP); LINALG_AOPS_DEFINE(PermuteOp, PERMUTEOP); LINALG_AOPS_DEFINE(Conv1DOp, CONV1DOP); +LINALG_AOPS_DEFINE(ConvTranspose1DOp, CONVTRANSPOSE1DOP); LINALG_AOPS_DEFINE(Conv2DOp, CONV2DOP); LINALG_AOPS_DEFINE(Conv3DOp, CONV3DOP); diff --git a/mllm/core/OpTypes.hpp b/mllm/core/OpTypes.hpp index 310b39cd0..d64d484fe 100644 --- a/mllm/core/OpTypes.hpp +++ b/mllm/core/OpTypes.hpp @@ -96,6 +96,8 @@ enum class OpTypes : int32_t { kWhere = 74, kSigmoid = 75, + kTanh = 76, + kConvTranspose1D = 77, // Dynamic Op Start for user to register there own ops. kDynamicOp_Start = 4096, @@ -181,6 +183,8 @@ inline std::string optype2Str(OpTypes type) { case OpTypes::kEqual: return "Equal"; case OpTypes::kWhere: return "Where"; case OpTypes::kSigmoid: return "Sigmoid"; + case OpTypes::kTanh: return "Tanh"; + case OpTypes::kConvTranspose1D: return "ConvTranspose1D"; case OpTypes::kDynamicOp_Start: return "DynamicOp_Start"; case OpTypes::kOpType_End: return "OpType_End"; default: return "Unknown"; diff --git a/mllm/core/aops/ConvTranspose1DOp.cpp b/mllm/core/aops/ConvTranspose1DOp.cpp new file mode 100644 index 000000000..25d1b5935 --- /dev/null +++ b/mllm/core/aops/ConvTranspose1DOp.cpp @@ -0,0 +1,95 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+ +#include "mllm/core/aops/ConvTranspose1DOp.hpp" +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/Tensor.hpp" +#include "mllm/utils/Common.hpp" +#include "mllm/compile/ir/linalg/Op.hpp" +#include "mllm/compile/ir/graph/Op.hpp" +#include "mllm/compile/ir/tensor/Op.hpp" + +namespace mllm::aops { + +ConvTranspose1DOp::ConvTranspose1DOp(const ConvTranspose1DOpOptions& options) + : BaseOp(OpTypes::kConvTranspose1D), options_(options) {} + +void ConvTranspose1DOp::load(const ParameterFile::ptr_t& ploader) { + switch (ploader->version()) { + case ModelFileVersion::kV1: { + weight_ = ploader->pull(getName() + ".weight"); + if (options_.bias) { bias_ = ploader->pull(getName() + ".bias"); } + weight_ = weight_.view({options_.in_channels, options_.out_channels / options_.groups, options_.kernel_size}); + if (options_.bias) { bias_ = bias_.view({options_.out_channels}); } + break; + } + case ModelFileVersion::kUserTemporary: + case ModelFileVersion::kV2: { + weight_ = ploader->pull(getName() + ".weight"); + if (options_.bias) { bias_ = ploader->pull(getName() + ".bias"); } + break; + } + default: NYI("Unsupported model file version") + } +} + +void ConvTranspose1DOp::trace(void* trace_context, const std::vector& inputs, std::vector& outputs) { + auto ir_ctx = (ir::IRContext*)trace_context; + + if (weight_ && !ir_ctx->lookupSymbolTable(getName() + ".weight")) { + ir::IRWriterGuard guard(ir_ctx, ir_ctx->lookupSymbolTable("init")->cast_()->getTopRegion()); + ir_ctx->create(ir_ctx->create(weight_)); + if (options_.bias) { ir_ctx->create(ir_ctx->create(bias_)); } + } + + auto i_irs = ir::tensor::wrapTensors2TensorIR(ir_ctx, inputs); + auto o_irs = ir::tensor::wrapTensors2TensorIR(ir_ctx, outputs); + ir_ctx->create(shared_from_this(), i_irs, o_irs); +} + +void ConvTranspose1DOp::forward(const std::vector& inputs, std::vector& outputs) { + NYI("ConvTranspose1DOp::forward not implemented in aops base."); +} + +void ConvTranspose1DOp::reshape(const std::vector& inputs, 
std::vector& outputs) { + const auto& i = inputs[0]; + const auto& ishape = i.shape(); + + if (ishape.size() != 3) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "ConvTranspose1DOp expects 3D input, got {} D", ishape.size()); + outputs.emplace_back(Tensor::empty(i.shape(), i.dtype(), i.device())); + return; + } + + const int batch = ishape[0]; + const int in_channels = ishape[1]; + const int sequence = ishape[2]; + + MLLM_RT_ASSERT_EQ(in_channels, options_.in_channels); + MLLM_RT_ASSERT_EQ(in_channels % options_.groups, 0); + MLLM_RT_ASSERT_EQ(options_.out_channels % options_.groups, 0); + + const int kernel_size = options_.kernel_size; + const int stride = options_.stride; + const int dilation = options_.dilation; + const int padding = options_.padding; + const int output_padding = options_.output_padding; + + const int seq_out = (sequence - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1; + + auto new_shape = std::vector{batch, options_.out_channels, seq_out}; + outputs.emplace_back(Tensor::empty(new_shape, i.dtype(), i.device())); +} + +void ConvTranspose1DOp::setup(const std::vector& inputs, std::vector& outputs) { + BaseOp::setup(inputs, outputs); +} + +ParameterFile::ptr_t ConvTranspose1DOp::getParams() { + auto p = ParameterFile::create(); + p->push(getName() + ".weight", weight_); + if (options_.bias) { p->push(getName() + ".bias", bias_); } + return p; +} + +} // namespace mllm::aops diff --git a/mllm/core/aops/ConvTranspose1DOp.hpp b/mllm/core/aops/ConvTranspose1DOp.hpp new file mode 100644 index 000000000..daeda0b8e --- /dev/null +++ b/mllm/core/aops/ConvTranspose1DOp.hpp @@ -0,0 +1,52 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+ +#pragma once + +#include +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/ParameterFile.hpp" + +namespace mllm::aops { + +struct ConvTranspose1DOpOptions : public BaseOpOptions { + int32_t in_channels; + int32_t out_channels; + int32_t kernel_size; + int32_t stride = 1; + int32_t padding = 0; + int32_t output_padding = 0; + int32_t dilation = 1; + int32_t groups = 1; + bool bias = true; +}; + +class ConvTranspose1DOp : public BaseOp { + public: + explicit ConvTranspose1DOp(const ConvTranspose1DOpOptions& options); + + void load(const ParameterFile::ptr_t& ploader) override; + + void trace(void* trace_context, const std::vector& inputs, std::vector& outputs) override; + + void forward(const std::vector& inputs, std::vector& outputs) override; + + void reshape(const std::vector& inputs, std::vector& outputs) override; + + void setup(const std::vector& inputs, std::vector& outputs) override; + + ParameterFile::ptr_t getParams() override; + + inline Tensor& weight() { return weight_; } + + inline Tensor& bias() { return bias_; } + + inline ConvTranspose1DOpOptions& options() { return options_; } + + protected: + Tensor weight_; + Tensor bias_; + ConvTranspose1DOpOptions options_; +}; + +} // namespace mllm::aops diff --git a/mllm/core/aops/TanhOp.cpp b/mllm/core/aops/TanhOp.cpp new file mode 100644 index 000000000..c0938d82f --- /dev/null +++ b/mllm/core/aops/TanhOp.cpp @@ -0,0 +1,37 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+ +#include "mllm/core/aops/TanhOp.hpp" +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/Tensor.hpp" +#include "mllm/utils/Common.hpp" +#include "mllm/compile/ir/linalg/Op.hpp" + +namespace mllm::aops { + +TanhOp::TanhOp(const TanhOpOptions& options) : BaseOp(OpTypes::kTanh), options_(options) {} + +void TanhOp::load(const ParameterFile::ptr_t& ploader) { MLLM_EMPTY_SCOPE; } + +void TanhOp::trace(void* trace_context, const std::vector& inputs, std::vector& outputs) { + auto ir_ctx = (ir::IRContext*)trace_context; + auto i_irs = ir::tensor::wrapTensors2TensorIR(ir_ctx, inputs); + auto o_irs = ir::tensor::wrapTensors2TensorIR(ir_ctx, outputs); + ir_ctx->create(shared_from_this(), i_irs, o_irs); +} + +void TanhOp::forward(const std::vector& inputs, std::vector& outputs) { + NYI("TanhOp::forward not implemented in aops base."); +} + +void TanhOp::reshape(const std::vector& inputs, std::vector& outputs) { + if (options_.isInplace()) { + outputs.emplace_back(inputs[0]); + } else { + outputs.emplace_back(Tensor::empty(inputs[0].shape(), inputs[0].dtype(), inputs[0].device())); + } +} + +void TanhOp::setup(const std::vector& inputs, std::vector& outputs) { BaseOp::setup(inputs, outputs); } + +} // namespace mllm::aops diff --git a/mllm/core/aops/TanhOp.hpp b/mllm/core/aops/TanhOp.hpp new file mode 100644 index 000000000..8b2ce4f43 --- /dev/null +++ b/mllm/core/aops/TanhOp.hpp @@ -0,0 +1,33 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+ +#pragma once + +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/ParameterFile.hpp" + +namespace mllm::aops { + +struct TanhOpOptions : public BaseOpOptions {}; + +class TanhOp : public BaseOp { + public: + explicit TanhOp(const TanhOpOptions& options); + + void load(const ParameterFile::ptr_t& ploader) override; + + void trace(void* trace_context, const std::vector& inputs, std::vector& outputs) override; + + void forward(const std::vector& inputs, std::vector& outputs) override; + + void reshape(const std::vector& inputs, std::vector& outputs) override; + + void setup(const std::vector& inputs, std::vector& outputs) override; + + inline TanhOpOptions& options() { return options_; } + + protected: + TanhOpOptions options_; +}; + +} // namespace mllm::aops diff --git a/mllm/models/minicpm_o2_6/modeling_chattts.hpp b/mllm/models/minicpm_o2_6/modeling_chattts.hpp index 3190210a7..4814d0eb7 100644 --- a/mllm/models/minicpm_o2_6/modeling_chattts.hpp +++ b/mllm/models/minicpm_o2_6/modeling_chattts.hpp @@ -401,7 +401,6 @@ class ConditionalChatTTS : public nn::Module { // Apply softmax to get probabilities: [num_vq, codebook_size] auto scores = nn::functional::softmax(logits.view({1, 1, logits.shape()[0], logits.shape()[1]}), -1).squeeze(); - logits.delete_(); // Free memory // Sample from each VQ codebook independently using multinomial sampling // This matches PyTorch's torch.multinomial(scores, num_samples=1) behavior @@ -418,7 +417,6 @@ class ConditionalChatTTS : public nn::Module { if (sampled_token == eos_token) { finished = true; } } - scores.delete_(); // Free memory progress++; audio_bos = false; diff --git a/mllm/models/minicpm_o45/configuration_minicpm_o45.hpp b/mllm/models/minicpm_o45/configuration_minicpm_o45.hpp new file mode 100644 index 000000000..92d9106b8 --- /dev/null +++ b/mllm/models/minicpm_o45/configuration_minicpm_o45.hpp @@ -0,0 +1,177 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+#pragma once + +#include +#include + +#include "mllm/core/aops/LinearOp.hpp" +#include "mllm/engine/ConfigFile.hpp" + +namespace mllm::models::minicpm_o45 { + +struct MiniCPMO45Config : protected ConfigFile { + MiniCPMO45Config() = default; + + explicit MiniCPMO45Config(const std::string& file_path) : ConfigFile(file_path) { + auto& cfg = data(); + auto get_or = [&](const std::string& key, auto fallback) { + using T = decltype(fallback); + return cfg.contains(key) ? cfg[key].get() : fallback; + }; + + auto vision_cfg = cfg.contains("vision_config") ? cfg["vision_config"] : nlohmann::json::object(); + auto audio_cfg = cfg.contains("audio_config") ? cfg["audio_config"] : nlohmann::json::object(); + auto tts_cfg = cfg.contains("tts_config") ? cfg["tts_config"] : nlohmann::json::object(); + + auto get_vision = [&](const std::string& key, auto fallback) { + using T = decltype(fallback); + if (vision_cfg.contains(key)) { return vision_cfg[key].get(); } + if (cfg.contains(key)) { return cfg[key].get(); } + return fallback; + }; + + auto get_audio = [&](const std::string& key, auto fallback) { + using T = decltype(fallback); + if (audio_cfg.contains(key)) { return audio_cfg[key].get(); } + if (cfg.contains(key)) { return cfg[key].get(); } + return fallback; + }; + + auto get_tts = [&](const std::string& key, auto fallback) { + using T = decltype(fallback); + if (tts_cfg.contains(key)) { return tts_cfg[key].get(); } + if (cfg.contains(key)) { return cfg[key].get(); } + return fallback; + }; + + // Vision config. 
+ vision_hidden_size = get_vision("vision_hidden_size", get_vision("hidden_size", vision_hidden_size)); + vision_intermediate_size = get_vision("vision_intermediate_size", get_vision("intermediate_size", vision_intermediate_size)); + vision_num_hidden_layers = get_vision("vision_num_hidden_layers", get_vision("num_hidden_layers", vision_num_hidden_layers)); + vision_num_attention_heads = get_vision("vision_num_attention_heads", get_vision("num_attention_heads", vision_num_attention_heads)); + vision_num_channels = get_vision("vision_num_channels", get_vision("num_channels", vision_num_channels)); + vision_image_size = get_vision("vision_image_size", get_vision("image_size", vision_image_size)); + vision_patch_size = get_vision("vision_patch_size", get_vision("patch_size", vision_patch_size)); + + // LLM config (Qwen3). + attention_bias = get_or("attention_bias", attention_bias); + hidden_size = get_or("hidden_size", hidden_size); + num_attention_heads = get_or("num_attention_heads", num_attention_heads); + num_key_value_heads = get_or("num_key_value_heads", num_key_value_heads); + head_dim = get_or("head_dim", hidden_size / std::max(num_attention_heads, 1)); + intermediate_size = get_or("intermediate_size", intermediate_size); + num_hidden_layers = get_or("num_hidden_layers", num_hidden_layers); + max_position_embeddings = get_or("max_position_embeddings", max_position_embeddings); + rms_norm_eps = get_or("rms_norm_eps", rms_norm_eps); + vocab_size = get_or("vocab_size", vocab_size); + + // Resampler config. + query_num = get_or("query_num", query_num); + + // Audio config (Whisper encoder). 
+ audio_hidden_size = get_audio("audio_hidden_size", get_audio("d_model", audio_hidden_size)); + audio_num_hidden_layers = get_audio("audio_num_hidden_layers", get_audio("num_hidden_layers", audio_num_hidden_layers)); + audio_num_attention_heads = get_audio("audio_num_attention_heads", get_audio("encoder_attention_heads", audio_num_attention_heads)); + audio_max_position_embeddings = + get_audio("audio_max_position_embeddings", get_audio("max_source_positions", audio_max_position_embeddings)); + audio_chunk_length = get_audio("audio_chunk_length", audio_chunk_length); + audio_pool_step = get_or("audio_pool_step", audio_pool_step); + + // TTS config (token generation stage). + tts_llm_dim = get_tts("tts_llm_dim", get_tts("llm_dim", tts_llm_dim)); + tts_llm_intermediate_size = get_tts("tts_llm_intermediate_size", get_tts("llm_intermediate_size", tts_llm_intermediate_size)); + tts_hidden_size = get_tts("tts_hidden_size", get_tts("hidden_size", tts_hidden_size)); + tts_intermediate_size = get_tts("tts_intermediate_size", get_tts("intermediate_size", tts_intermediate_size)); + tts_num_attention_heads = get_tts("tts_num_attention_heads", get_tts("num_attention_heads", tts_num_attention_heads)); + tts_num_key_value_heads = get_tts("tts_num_key_value_heads", get_tts("num_key_value_heads", tts_num_key_value_heads)); + tts_num_hidden_layers = get_tts("tts_num_hidden_layers", get_tts("num_hidden_layers", tts_num_hidden_layers)); + tts_max_position_embeddings = get_tts("tts_max_position_embeddings", get_tts("max_position_embeddings", tts_max_position_embeddings)); + tts_num_audio_tokens = get_tts("tts_num_audio_tokens", get_tts("num_audio_tokens", tts_num_audio_tokens)); + tts_num_text_tokens = get_tts("tts_num_text_tokens", get_tts("num_text_tokens", tts_num_text_tokens)); + tts_num_vq = get_tts("tts_num_vq", get_tts("num_vq", tts_num_vq)); + tts_audio_bos_token_id = get_tts("tts_audio_bos_token_id", get_tts("audio_bos_token_id", tts_audio_bos_token_id)); + 
tts_text_eos_token_id = get_tts("tts_text_eos_token_id", get_tts("text_eos_token_id", tts_text_eos_token_id)); + tts_backbone_vocab_size = tts_cfg.contains("vocab_size") ? tts_cfg["vocab_size"].get() : tts_backbone_vocab_size; + tts_rms_norm_eps = get_tts("tts_rms_norm_eps", get_tts("rms_norm_eps", tts_rms_norm_eps)); + tts_rope_theta = get_tts("tts_rope_theta", get_tts("rope_theta", tts_rope_theta)); + tts_hidden_act = get_tts("tts_hidden_act", get_tts("hidden_act", tts_hidden_act)); + tts_projector_type = get_tts("tts_projector_type", get_tts("projector_type", tts_projector_type)); + tts_condition_type = get_tts("tts_condition_type", get_tts("condition_type", tts_condition_type)); + tts_normalize_projected_hidden = get_tts("tts_normalize_projected_hidden", get_tts("normalize_projected_hidden", tts_normalize_projected_hidden)); + + // Common config. + max_cache_length = get_or("max_cache_length", max_cache_length); + eos_token_id = get_or("eos_token_id", eos_token_id); + bos_token_id = get_or("bos_token_id", bos_token_id); + rope_theta = get_or("rope_theta", rope_theta); + tie_word_embeddings = get_or("tie_word_embeddings", tie_word_embeddings); + + linear_impl_type = cfg.contains("linear_impl_type") ? aops::str2LinearImplTypes(cfg["linear_impl_type"]) : linear_impl_type; + } + + // Vision config (SigLIP). + int32_t vision_hidden_size = 1152; + int32_t vision_intermediate_size = 4304; + int32_t vision_num_hidden_layers = 27; + int32_t vision_num_attention_heads = 16; + int32_t vision_num_channels = 3; + int32_t vision_image_size = 980; + int32_t vision_patch_size = 14; + + // LLM config (Qwen3-8B). + bool attention_bias = false; + int32_t hidden_size = 4096; + int32_t head_dim = 128; + int32_t intermediate_size = 12288; + int32_t num_attention_heads = 32; + int32_t num_key_value_heads = 8; + int32_t num_hidden_layers = 36; + int32_t max_position_embeddings = 40960; + float rms_norm_eps = 1e-06f; + int32_t vocab_size = 151748; + + // Resampler config. 
+ int32_t query_num = 64; + + // Audio config (Whisper-medium). + int32_t audio_hidden_size = 1024; + int32_t audio_num_hidden_layers = 24; + int32_t audio_num_attention_heads = 16; + int32_t audio_max_position_embeddings = 1500; + float audio_chunk_length = 1.0f; + int32_t audio_pool_step = 5; + + // TTS config (MiniCPMTTS in MiniCPM-o-4_5). + int32_t tts_llm_dim = 4096; + int32_t tts_llm_intermediate_size = 768; + int32_t tts_hidden_size = 768; + int32_t tts_intermediate_size = 3072; + int32_t tts_num_attention_heads = 12; + int32_t tts_num_key_value_heads = 12; + int32_t tts_num_hidden_layers = 20; + int32_t tts_max_position_embeddings = 4096; + int32_t tts_num_audio_tokens = 6562; + int32_t tts_num_text_tokens = 152064; + int32_t tts_num_vq = 1; + int32_t tts_audio_bos_token_id = 151687; + int32_t tts_text_eos_token_id = 151692; + int32_t tts_backbone_vocab_size = 32000; + float tts_rms_norm_eps = 1e-06f; + float tts_rope_theta = 10000.0f; + std::string tts_hidden_act = "silu"; + std::string tts_projector_type = "mlp"; + std::string tts_condition_type = "hidden_text_merge"; + bool tts_normalize_projected_hidden = true; + + // Common config. + int32_t max_cache_length = 4096; + int64_t eos_token_id = 151645; + int64_t bos_token_id = 151643; + float rope_theta = 1000000.0f; + bool tie_word_embeddings = false; + + aops::LinearImplTypes linear_impl_type = aops::LinearImplTypes::kDefault; +}; + +} // namespace mllm::models::minicpm_o45 diff --git a/mllm/models/minicpm_o45/convert_token2wav_pt_to_mllm.py b/mllm/models/minicpm_o45/convert_token2wav_pt_to_mllm.py new file mode 100644 index 000000000..760869611 --- /dev/null +++ b/mllm/models/minicpm_o45/convert_token2wav_pt_to_mllm.py @@ -0,0 +1,414 @@ +#!/usr/bin/env python3 +# Copyright (c) MLLM Team. +# Licensed under the MIT License. + +"""Lightweight MiniCPM-o-4_5 token2wav converter. + +This script merges `flow.pt` + `hift.pt` into one `.mllm` file without +depending on `pymllm`/`tvm_ffi`. 
+""" + +from __future__ import annotations + +import argparse +import gc +import os +import struct +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, Iterable, List, Tuple + +try: + import torch +except ImportError as exc: + raise ImportError("PyTorch is required. Please install torch in your Python env.") from exc + + +# ----------------------------- MLLM constants ----------------------------- # + +MLLM_MODEL_FILE_V1_MAGIC_NUMBER = 20012 +MLLM_MODEL_FILE_V2_MAGIC_NUMBER = 0x519A +MLLM_MODEL_FILE_V2_VERSION = 2 +MLLM_MODEL_FILE_V2_MODEL_NAME_LENGTH = 512 +MLLM_MODEL_FILE_V2_PARAMS_NAME_LENGTH = 256 +MLLM_MODEL_FILE_V2_TENSOR_SHAPE_LENGTH = 16 + +MODEL_FILE_V2_DESC_SIZE = 532 +MODEL_FILE_V2_PARAM_DESC_SIZE = 352 + + +def _build_torch_type_mapping() -> Dict[torch.dtype, int]: + mapping = { + torch.float32: 0, # kFloat32 + torch.float16: 1, # kFloat16 + torch.bfloat16: 128, # kBFloat16 + torch.int8: 16, # kInt8 + torch.int16: 17, # kInt16 + torch.int32: 18, # kInt32 + torch.int64: 132, # kInt64 + torch.uint8: 129, # kUInt8 + torch.bool: 129, # kUInt8 + } + if hasattr(torch, "uint16"): + mapping[torch.uint16] = 130 # kUInt16 + return mapping + + +TORCH_TYPE_MAPPING = _build_torch_type_mapping() + + +# ----------------------------- Helpers ----------------------------- # + + +@dataclass +class TensorMeta: + raw_name: str + full_name: str + dtype_id: int + data_len: int + + +def _load_pt(path: Path) -> Dict[str, torch.Tensor]: + if not path.exists(): + raise FileNotFoundError(f"Checkpoint not found: {path}") + + try: + obj = torch.load(path, map_location="cpu", weights_only=True) + except TypeError: + obj = torch.load(path, map_location="cpu") + + if isinstance(obj, dict): + if obj and all(torch.is_tensor(v) for v in obj.values()): + return obj + for candidate in ("state_dict", "model", "module"): + cand = obj.get(candidate) + if isinstance(cand, dict) and cand and any(torch.is_tensor(v) for v in cand.values()): + return {k: v for k, 
v in cand.items() if torch.is_tensor(v)} + tensor_only = {k: v for k, v in obj.items() if torch.is_tensor(v)} + if tensor_only: + return tensor_only + + raise ValueError(f"Unsupported checkpoint layout: {path}") + + +def _normalized_tensor(t: torch.Tensor) -> torch.Tensor: + x = t.detach().cpu().contiguous() + if x.dim() == 0: + x = x.reshape(1) + return x + + +def _tensor_to_bytes(t: torch.Tensor) -> bytes: + x = _normalized_tensor(t) + return x.view(torch.uint8).numpy().tobytes() + + +def _tensor_dtype_id(dtype: torch.dtype) -> int: + if dtype not in TORCH_TYPE_MAPPING: + raise ValueError(f"Unsupported tensor dtype for .mllm export: {dtype}") + return TORCH_TYPE_MAPPING[dtype] + + +def _collect_source_meta( + ckpt_path: Path, + out_prefix: str, + strip_prefix: str, + preview_limit: int, +) -> List[TensorMeta]: + state = _load_pt(ckpt_path) + keys = list(state.keys()) + print(f"[inspect] {ckpt_path}: {len(keys)} tensors") + + metas: List[TensorMeta] = [] + for i, raw_name in enumerate(keys): + t = state[raw_name] + out_name = raw_name[len(strip_prefix) :] if (strip_prefix and raw_name.startswith(strip_prefix)) else raw_name + full_name = f"{out_prefix}{out_name}" + x = _normalized_tensor(t) + dtype_id = _tensor_dtype_id(x.dtype) + data_len = int(x.numel()) * int(x.element_size()) + metas.append(TensorMeta(raw_name=raw_name, full_name=full_name, dtype_id=dtype_id, data_len=data_len)) + if i < max(preview_limit, 0): + print(f" - {raw_name} shape={tuple(x.shape)} dtype={x.dtype}") + + del state + gc.collect() + return metas + + +def _check_duplicate_names(metas: Iterable[TensorMeta]) -> None: + seen = set() + for m in metas: + if m.full_name in seen: + raise ValueError(f"Duplicated tensor name after rename: {m.full_name}") + seen.add(m.full_name) + + +def _stream_source_tensors( + ckpt_path: Path, + metas: List[TensorMeta], +) -> Iterable[Tuple[TensorMeta, torch.Tensor]]: + state = _load_pt(ckpt_path) + try: + for m in metas: + if m.raw_name not in state: + raise 
KeyError(f"Tensor missing in checkpoint: {ckpt_path} -> {m.raw_name}") + t = _normalized_tensor(state[m.raw_name]) + yield m, t + finally: + del state + gc.collect() + + +# ----------------------------- V1 writer ----------------------------- # + + +def _write_v1( + output: Path, + model_name: str, + flow_pt: Path, + flow_metas: List[TensorMeta], + hift_pt: Path, + hift_metas: List[TensorMeta], +) -> None: + del model_name # v1 header has no model name + + all_metas = flow_metas + hift_metas + _check_duplicate_names(all_metas) + + desc_size = 0 + for m in all_metas: + name_bytes = m.full_name.encode("utf-8") + desc_size += 4 + len(name_bytes) + 8 + 8 + 4 + + output.parent.mkdir(parents=True, exist_ok=True) + with open(output, "wb") as f: + f.write(struct.pack(" {output}") + + +# ----------------------------- V2 writer ----------------------------- # + + +def _pack_v2_file_desc(model_name: str, num_params: int) -> bytes: + name_bytes = model_name.encode("utf-8") + name_bytes = name_bytes.ljust(MLLM_MODEL_FILE_V2_MODEL_NAME_LENGTH, b"\0")[:MLLM_MODEL_FILE_V2_MODEL_NAME_LENGTH] + return struct.pack( + f" bytes: + if len(shape) > MLLM_MODEL_FILE_V2_TENSOR_SHAPE_LENGTH: + raise ValueError(f"Tensor rank > {MLLM_MODEL_FILE_V2_TENSOR_SHAPE_LENGTH} is not supported: {name}") + + shape_padded = list(shape) + [0] * (MLLM_MODEL_FILE_V2_TENSOR_SHAPE_LENGTH - len(shape)) + name_bytes = name.encode("utf-8") + name_bytes = name_bytes.ljust(MLLM_MODEL_FILE_V2_PARAMS_NAME_LENGTH, b"\0")[:MLLM_MODEL_FILE_V2_PARAMS_NAME_LENGTH] + return struct.pack( + f" None: + if self.num_params >= self.max_params: + raise ValueError(f"Descriptor buffer exceeded: {self.num_params} >= {self.max_params}") + + dtype_id = _tensor_dtype_id(tensor.dtype) + shape = tuple(int(v) for v in tensor.shape) + data = _tensor_to_bytes(tensor) + data_offset = self.f.tell() + data_len = len(data) + + self.f.write(data) + + desc_off = MODEL_FILE_V2_DESC_SIZE + self.num_params * MODEL_FILE_V2_PARAM_DESC_SIZE + 
self.f.seek(desc_off, os.SEEK_SET) + self.f.write( + _pack_v2_param_desc( + param_id=self.num_params, + param_type=dtype_id, + param_size=data_len, + param_offset=data_offset, + shape=shape, + name=name, + ) + ) + self.f.seek(0, os.SEEK_END) + self.num_params += 1 + + def finalize(self) -> None: + self.f.seek(0, os.SEEK_SET) + self.f.write(_pack_v2_file_desc(self.model_name, self.num_params)) + self.f.flush() + + def close(self) -> None: + if not self.f.closed: + self.f.close() + + +def _write_v2( + output: Path, + model_name: str, + flow_pt: Path, + flow_metas: List[TensorMeta], + hift_pt: Path, + hift_metas: List[TensorMeta], + max_param_desc: int, +) -> None: + all_metas = flow_metas + hift_metas + _check_duplicate_names(all_metas) + + if max_param_desc <= 0: + max_param_desc = len(all_metas) + if max_param_desc < len(all_metas): + raise ValueError(f"--max-param-desc ({max_param_desc}) < total tensors ({len(all_metas)})") + + output.parent.mkdir(parents=True, exist_ok=True) + writer = _V2StreamingWriter(output=output, model_name=model_name, max_params=max_param_desc) + written = 0 + try: + for m, t in _stream_source_tensors(flow_pt, flow_metas): + writer.write_tensor(m.full_name, t) + written += 1 + for m, t in _stream_source_tensors(hift_pt, hift_metas): + writer.write_tensor(m.full_name, t) + written += 1 + writer.finalize() + finally: + writer.close() + + print(f"[done:v2] wrote {written} tensors -> {output}") + + +# ----------------------------- Main ----------------------------- # + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Convert MiniCPM-o-4_5 token2wav flow.pt + hift.pt into one .mllm file." 
+ ) + parser.add_argument( + "--flow-pt", + default="mllm/models/minicpm_o45/python_src_code/assets/token2wav/flow.pt", + help="Path to flow.pt", + ) + parser.add_argument( + "--hift-pt", + default="mllm/models/minicpm_o45/python_src_code/assets/token2wav/hift.pt", + help="Path to hift.pt", + ) + parser.add_argument("--output", required=True, help="Output .mllm path") + parser.add_argument("--model-name", default="minicpm_o45_token2wav", help="Model name (used in v2 header)") + parser.add_argument("--format", choices=["v1", "v2"], default="v1", help="Output model format") + parser.add_argument("--flow-prefix", default="token2wav.flow_model.", help="Prefix for flow tensor names") + parser.add_argument("--hift-prefix", default="token2wav.hift_model.", help="Prefix for hift tensor names") + parser.add_argument( + "--strip-hift-prefix", + default="generator.", + help="Strip this prefix from hift tensor names before adding --hift-prefix", + ) + parser.add_argument( + "--max-param-desc", + type=int, + default=0, + help="Only for v2: max descriptor buffer size, 0 means auto", + ) + parser.add_argument("--inspect-only", action="store_true", help="Only inspect checkpoints and quit") + parser.add_argument("--preview-limit", type=int, default=8, help="How many tensors to print per checkpoint") + args = parser.parse_args() + + flow_pt = Path(args.flow_pt).expanduser().resolve() + hift_pt = Path(args.hift_pt).expanduser().resolve() + output = Path(args.output).expanduser().resolve() + + flow_metas = _collect_source_meta(flow_pt, args.flow_prefix, "", args.preview_limit) + hift_metas = _collect_source_meta(hift_pt, args.hift_prefix, args.strip_hift_prefix, args.preview_limit) + + total = len(flow_metas) + len(hift_metas) + print(f"[count] flow={len(flow_metas)}, hift={len(hift_metas)}, total={total}") + + if args.inspect_only: + print("[inspect-only] done") + return + + if args.format == "v1": + _write_v1( + output=output, + model_name=args.model_name, + flow_pt=flow_pt, + 
flow_metas=flow_metas, + hift_pt=hift_pt, + hift_metas=hift_metas, + ) + else: + _write_v2( + output=output, + model_name=args.model_name, + flow_pt=flow_pt, + flow_metas=flow_metas, + hift_pt=hift_pt, + hift_metas=hift_metas, + max_param_desc=args.max_param_desc, + ) + + +if __name__ == "__main__": + main() diff --git a/mllm/models/minicpm_o45/export_prompt_cache.py b/mllm/models/minicpm_o45/export_prompt_cache.py new file mode 100644 index 000000000..041cc0c3f --- /dev/null +++ b/mllm/models/minicpm_o45/export_prompt_cache.py @@ -0,0 +1,241 @@ +#!/usr/bin/env python3 +# Copyright (c) MLLM Team. +# Licensed under the MIT License. + +"""Export fixed MiniCPM-o-4_5 token2wav prompt cache for native C++ runtime. + +This script extracts prompt_speech_tokens / prompt_mels / speaker_embedding from +one reference wav and writes a compact binary cache. +""" + +from __future__ import annotations + +import argparse +import struct +import sys +import time +import types +from pathlib import Path + +import numpy as np + + +def _setup_cosyvoice2_alias() -> None: + if "cosyvoice2.flow.flow" in sys.modules: + return + + import stepaudio2.cosyvoice2.flow.decoder_dit as _step_decoder_dit + import stepaudio2.cosyvoice2.flow.flow as _step_flow + import stepaudio2.cosyvoice2.flow.flow_matching as _step_flow_matching + import stepaudio2.cosyvoice2.transformer.upsample_encoder_v2 as _step_upsample + + cosyvoice2_pkg = types.ModuleType("cosyvoice2") + cosyvoice2_flow_pkg = types.ModuleType("cosyvoice2.flow") + cosyvoice2_transformer_pkg = types.ModuleType("cosyvoice2.transformer") + + cosyvoice2_flow_pkg.flow = _step_flow + cosyvoice2_flow_pkg.flow_matching = _step_flow_matching + cosyvoice2_flow_pkg.decoder_dit = _step_decoder_dit + cosyvoice2_transformer_pkg.upsample_encoder_v2 = _step_upsample + + cosyvoice2_pkg.flow = cosyvoice2_flow_pkg + cosyvoice2_pkg.transformer = cosyvoice2_transformer_pkg + + sys.modules["cosyvoice2"] = cosyvoice2_pkg + sys.modules["cosyvoice2.flow"] = 
cosyvoice2_flow_pkg + sys.modules["cosyvoice2.flow.flow"] = _step_flow + sys.modules["cosyvoice2.flow.flow_matching"] = _step_flow_matching + sys.modules["cosyvoice2.flow.decoder_dit"] = _step_decoder_dit + sys.modules["cosyvoice2.transformer"] = cosyvoice2_transformer_pkg + sys.modules["cosyvoice2.transformer.upsample_encoder_v2"] = _step_upsample + + +def _resolve_device(torch_mod, req: str): + req = req.lower() + if req == "cuda": + if not torch_mod.cuda.is_available(): + raise RuntimeError("Requested --device=cuda but CUDA is unavailable") + return torch_mod.device("cuda") + if req == "mps": + if not getattr(torch_mod.backends, "mps", None) or not torch_mod.backends.mps.is_available(): + raise RuntimeError("Requested --device=mps but MPS is unavailable") + return torch_mod.device("mps") + if req == "cpu": + return torch_mod.device("cpu") + if req != "auto": + raise ValueError(f"Unsupported --device: {req}") + + if torch_mod.cuda.is_available(): + return torch_mod.device("cuda") + if getattr(torch_mod.backends, "mps", None) and torch_mod.backends.mps.is_available(): + return torch_mod.device("mps") + return torch_mod.device("cpu") + + +def _move_model(model, device): + if device.type == "cuda" and hasattr(model, "cuda"): + return model.cuda() + if device.type == "cpu" and hasattr(model, "cpu"): + return model.cpu() + if hasattr(model, "to"): + return model.to(device) + return model + + +class _StageLogger: + def __init__(self, verbose: bool): + self.verbose = verbose + self.t0 = time.time() + + def log(self, msg: str) -> None: + if not self.verbose: + return + dt = time.time() - self.t0 + print(f"[prompt-cache +{dt:.3f}s] {msg}", flush=True) + + +def _write_cache(path: Path, prompt_tokens: np.ndarray, prompt_mels: np.ndarray, spk_emb: np.ndarray) -> None: + # File layout (little-endian): + # magic[8] = "M45PC1\\0\\0" + # u32 version = 1 + # i32 prompt_token_len + # i32 prompt_mel_frames + # i32 mel_dim + # i32 spk_dim + # i32[prompt_token_len] + # 
f32[prompt_mel_frames * mel_dim] + # f32[spk_dim] + magic = b"M45PC1\0\0" + version = 1 + token_len = int(prompt_tokens.shape[0]) + mel_frames = int(prompt_mels.shape[0]) + mel_dim = int(prompt_mels.shape[1]) + spk_dim = int(spk_emb.shape[0]) + + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "wb") as f: + f.write(magic) + f.write(struct.pack(" None: + parser = argparse.ArgumentParser(description="Export MiniCPM-o-4_5 fixed prompt cache for native C++ token2wav") + parser.add_argument("--ref_wav", required=True, help="Reference wav path used for voice style") + parser.add_argument("--token2wav_dir", required=True, help="Path to assets/token2wav directory") + parser.add_argument("--python_src_root", required=True, help="Path to MiniCPM-o-4_5 python_src_code directory") + parser.add_argument("--out_cache", required=True, help="Output cache path (.bin)") + parser.add_argument("--device", default="auto", choices=["auto", "cpu", "mps", "cuda"], help="Runtime device") + parser.add_argument("--verbose", action="store_true", help="Print detailed stage logs") + args = parser.parse_args() + + python_src_root = Path(args.python_src_root).expanduser().resolve() + ref_wav = Path(args.ref_wav).expanduser().resolve() + token2wav_dir = Path(args.token2wav_dir).expanduser().resolve() + out_cache = Path(args.out_cache).expanduser().resolve() + + if str(python_src_root) not in sys.path: + sys.path.insert(0, str(python_src_root)) + + import onnxruntime + import s3tokenizer + import torch + import torchaudio + import torchaudio.compliance.kaldi as kaldi + from hyperpyyaml import load_hyperpyyaml + from stepaudio2.flashcosyvoice.utils.audio import mel_spectrogram + + logger = _StageLogger(args.verbose) + + logger.log("Resolving runtime device...") + device = _resolve_device(torch, args.device) + print(f"[prompt-cache] device={device}", flush=True) + print(f"[prompt-cache] ref_wav={ref_wav}", flush=True) + print(f"[prompt-cache] token2wav_dir={token2wav_dir}", 
flush=True) + + _setup_cosyvoice2_alias() + + logger.log("Loading speech tokenizer ONNX...") + audio_tokenizer = s3tokenizer.load_model(str(token2wav_dir / "speech_tokenizer_v2_25hz.onnx")) + audio_tokenizer = _move_model(audio_tokenizer, device) + if hasattr(audio_tokenizer, "eval"): + audio_tokenizer = audio_tokenizer.eval() + + logger.log("Loading campplus.onnx...") + option = onnxruntime.SessionOptions() + option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + option.intra_op_num_threads = 1 + spk_model = onnxruntime.InferenceSession( + str(token2wav_dir / "campplus.onnx"), + sess_options=option, + providers=["CPUExecutionProvider"], + ) + + logger.log("Reading flow.yaml for up_rate...") + with open(token2wav_dir / "flow.yaml", "r", encoding="utf-8") as f: + cfg = load_hyperpyyaml(f) + up_rate = int(cfg["flow"].up_rate) + print(f"[prompt-cache] flow.up_rate={up_rate}", flush=True) + + logger.log("Preparing prompt speech tokens (16k)...") + audio = s3tokenizer.load_audio(str(ref_wav), sr=16000) + mels = s3tokenizer.log_mel_spectrogram(audio) + mels, mels_lens = s3tokenizer.padding([mels]) + + quantize_device = device + try: + prompt_tokens, prompt_tokens_lens = audio_tokenizer.quantize(mels.to(quantize_device), mels_lens.to(quantize_device)) + except Exception: + quantize_device = torch.device("cpu") + audio_tokenizer = _move_model(audio_tokenizer, quantize_device) + if hasattr(audio_tokenizer, "eval"): + audio_tokenizer = audio_tokenizer.eval() + prompt_tokens, prompt_tokens_lens = audio_tokenizer.quantize(mels.to(quantize_device), mels_lens.to(quantize_device)) + + prompt_tokens = prompt_tokens.to(device) + prompt_tokens_lens = prompt_tokens_lens.to(device) + logger.log(f"prompt_tokens shape={tuple(prompt_tokens.shape)}, lens={prompt_tokens_lens.tolist()}") + + logger.log("Preparing speaker embedding...") + spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000) + spk_feat = spk_feat - 
spk_feat.mean(dim=0, keepdim=True) + spk_emb_np = spk_model.run( + None, + {spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}, + )[0] + spk_emb = torch.tensor(spk_emb_np, device=device, dtype=torch.float32) + logger.log(f"spk_emb shape={tuple(spk_emb.shape)}") + + logger.log("Preparing prompt mel (24k)...") + audio_24k, sample_rate = torchaudio.load(str(ref_wav), backend="soundfile") + audio_24k = audio_24k.mean(dim=0, keepdim=True) + if sample_rate != 24000: + audio_24k = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio_24k) + prompt_mel = mel_spectrogram(audio_24k).transpose(1, 2).squeeze(0) # [T, 80] + prompt_mels = prompt_mel.unsqueeze(0).to(device) + target_len = int(prompt_tokens.shape[1]) * up_rate + if target_len > prompt_mels.shape[1]: + prompt_mels = torch.nn.functional.pad( + prompt_mels, + (0, 0, 0, target_len - prompt_mels.shape[1]), + mode="replicate", + ) + logger.log(f"prompt_mels shape={tuple(prompt_mels.shape)}") + + logger.log("Writing cache...") + token_np = prompt_tokens[0].detach().cpu().numpy().astype(np.int32) + mel_np = prompt_mels[0].detach().cpu().numpy().astype(np.float32) + spk_np = spk_emb[0].detach().cpu().numpy().astype(np.float32) + _write_cache(out_cache, token_np, mel_np, spk_np) + + print(f"[prompt-cache] wrote: {out_cache}", flush=True) + print(f"[prompt-cache] token_len={token_np.shape[0]}, mel_frames={mel_np.shape[0]}, mel_dim={mel_np.shape[1]}, spk_dim={spk_np.shape[0]}", + flush=True) + + +if __name__ == "__main__": + main() + diff --git a/mllm/models/minicpm_o45/modeling_minicpm_o45.hpp b/mllm/models/minicpm_o45/modeling_minicpm_o45.hpp new file mode 100644 index 000000000..1cb92d4d8 --- /dev/null +++ b/mllm/models/minicpm_o45/modeling_minicpm_o45.hpp @@ -0,0 +1,910 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "mllm/mllm.hpp" +#include "mllm/models/ARGeneration.hpp" +#include "mllm/models/llama/configuration_llama.hpp" +#include "mllm/models/llama/modeling_llama.hpp" +#include "mllm/models/minicpm_o2_6/configuration_minicpmo.hpp" +#include "mllm/models/minicpm_o2_6/modeling_resampler.hpp" +#include "mllm/models/minicpm_o2_6/modeling_siglip.hpp" +#include "mllm/models/minicpm_o2_6/modeling_whisper_encoder.hpp" +#include "mllm/models/minicpm_o45/configuration_minicpm_o45.hpp" +#include "mllm/models/qwen3/configuration_qwen3.hpp" +#include "mllm/models/qwen3/modeling_qwen3.hpp" +#include "mllm/nn/Functional.hpp" +#include "mllm/nn/Module.hpp" +#include "mllm/nn/Nn.hpp" +#include "mllm/nn/lmcache/StaticCache.hpp" +#include "mllm/utils/Enumerate.hpp" +#include "mllm/utils/Log.hpp" + +namespace mllm::models::minicpm_o45 { + +class AudioProjectionLayer final : public nn::Module { + public: + AudioProjectionLayer() = default; + + AudioProjectionLayer(const std::string& name, int32_t input_dim, int32_t hidden_dim, int32_t output_dim) : Module(name) { + linear1_ = reg("linear1", input_dim, hidden_dim, true); + relu_ = reg("relu"); + linear2_ = reg("linear2", hidden_dim, output_dim, true); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = inputs[0]; + x = linear1_(x); + x = relu_(x); + x = linear2_(x); + return {x}; + } + + private: + nn::Linear linear1_; + nn::ReLU relu_; + nn::Linear linear2_; +}; + +class AudioAvgPooler final : public nn::Module { + public: + AudioAvgPooler() = default; + + AudioAvgPooler(const std::string& name, int32_t kernel_size, int32_t stride) : Module(name) { + avg_pool_ = reg("pool", kernel_size, stride); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + return {avg_pool_(inputs[0])}; + } + + private: + nn::AvgPool1d avg_pool_; +}; + 
+class TTSProjector final : public nn::Module { + public: + TTSProjector() = default; + + TTSProjector(const std::string& name, int32_t input_dim, int32_t output_dim) : nn::Module(name) { + linear1_ = reg("linear1", input_dim, output_dim, true); + relu_ = reg("relu"); + linear2_ = reg("linear2", output_dim, output_dim, true); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = linear1_(inputs[0]); + x = relu_(x); + x = linear2_(x); + return {x}; + } + + private: + nn::Linear linear1_; + nn::ReLU relu_; + nn::Linear linear2_; +}; + +struct MiniCPMO45TTSGenerationConfig { + int32_t max_new_tokens = 1024; + int32_t min_new_tokens = 50; + bool force_no_stop = false; + bool do_sample = true; + int32_t top_k = 25; + float top_p = 0.85f; + float repetition_penalty = 1.05f; + int32_t repetition_penalty_window = 16; + std::vector temperature = {0.8f}; + int32_t debug_interval = 16; + std::function& tokens, bool has_eos)> step_callback = nullptr; +}; + +struct MiniCPMO45TTSGenerationOutput { + Tensor new_ids = Tensor::nil(); + bool finished = false; +}; + +class MiniCPMO45TTS final : public nn::Module { + public: + MiniCPMO45TTS() = default; + + MiniCPMO45TTS(const std::string& name, const MiniCPMO45Config& cfg) : nn::Module(name), cfg_(cfg) { + projector_spk_ = reg("projector_spk", cfg.tts_llm_dim, cfg.tts_hidden_size); + projector_semantic_ = reg("projector_semantic", cfg.tts_llm_dim, cfg.tts_hidden_size); + + emb_text_ = reg("emb_text", cfg.tts_num_text_tokens, cfg.tts_hidden_size); + + emb_code_.reserve(cfg.tts_num_vq); + for (int32_t i = 0; i < cfg.tts_num_vq; ++i) { + emb_code_.emplace_back(reg("emb_code." 
+ std::to_string(i), cfg.tts_num_audio_tokens, cfg.tts_hidden_size)); + } + + auto llama_cfg = llama::LLaMAConfig(); + llama_cfg.vocab_size = cfg.tts_backbone_vocab_size; + llama_cfg.hidden_size = cfg.tts_hidden_size; + llama_cfg.intermediate_size = cfg.tts_intermediate_size; + llama_cfg.num_attention_heads = cfg.tts_num_attention_heads; + llama_cfg.num_key_value_heads = cfg.tts_num_key_value_heads; + llama_cfg.num_hidden_layers = cfg.tts_num_hidden_layers; + llama_cfg.max_position_embeddings = cfg.tts_max_position_embeddings; + llama_cfg.rms_norm_eps = cfg.tts_rms_norm_eps; + llama_cfg.rope_theta = cfg.tts_rope_theta; + llama_cfg.hidden_act = cfg.tts_hidden_act; + llama_cfg.tie_word_embeddings = false; + llama_cfg.attention_bias = false; + llama_cfg.linear_impl_type = cfg.linear_impl_type; + model_ = reg("model", llama_cfg); + } + + void loadFromParameter(const ParameterFile::ptr_t& param_file) { + nn::Module::load(param_file); + + head_code_weight_.clear(); + head_code_weight_.reserve(cfg_.tts_num_vq); + + auto prefix = getModuleName() + ".head_code."; + for (int32_t i = 0; i < cfg_.tts_num_vq; ++i) { + auto g = param_file->pull(prefix + std::to_string(i) + ".parametrizations.weight.original0"); + auto v = param_file->pull(prefix + std::to_string(i) + ".parametrizations.weight.original1"); + if (g.dtype() != kFloat32) { g = g.to(kFloat32); } + if (v.dtype() != kFloat32) { v = v.to(kFloat32); } + g = g.contiguous(); + v = v.contiguous().view({cfg_.tts_num_audio_tokens, cfg_.tts_hidden_size}); + + auto weight = Tensor::empty({cfg_.tts_num_audio_tokens, cfg_.tts_hidden_size}, kFloat32, kCPU).alloc(); + + auto* g_ptr = g.ptr(); + auto* v_ptr = v.ptr(); + auto* w_ptr = weight.ptr(); + + constexpr float kEps = 1e-12f; + for (int32_t out_idx = 0; out_idx < cfg_.tts_num_audio_tokens; ++out_idx) { + float norm = 0.0f; + auto row_offset = out_idx * cfg_.tts_hidden_size; + for (int32_t d = 0; d < cfg_.tts_hidden_size; ++d) { + auto val = v_ptr[row_offset + d]; + norm += val 
* val; + } + norm = std::sqrt(norm); + if (norm < kEps) { norm = kEps; } + auto scale = g_ptr[out_idx] / norm; + for (int32_t d = 0; d < cfg_.tts_hidden_size; ++d) { w_ptr[row_offset + d] = v_ptr[row_offset + d] * scale; } + } + + head_code_weight_.push_back(weight); + } + } + + Tensor makeConditionEmbeddings(const std::vector& text_token_ids, const std::vector& text_hidden_states) { + if (text_token_ids.empty() || text_hidden_states.empty()) { return Tensor::nil(); } + if (text_token_ids.size() != text_hidden_states.size()) { + MLLM_ERROR("MiniCPM-o-4_5 TTS input mismatch: token count {} != hidden count {}.", + text_token_ids.size(), text_hidden_states.size()); + return Tensor::nil(); + } + + Tensor token_ids = Tensor::empty({1, static_cast(text_token_ids.size())}, kInt64, kCPU).alloc(); + for (size_t i = 0; i < text_token_ids.size(); ++i) { + auto token_id = text_token_ids[i]; + if (token_id < 0 || token_id >= cfg_.tts_num_text_tokens) { + MLLM_ERROR("MiniCPM-o-4_5 TTS text token id out of range: token_id={} valid=[0, {}).", + token_id, cfg_.tts_num_text_tokens); + return Tensor::nil(); + } + token_ids.at({0, static_cast(i)}) = token_id; + } + + auto llm_embeds = emb_text_(token_ids); + + Tensor hidden_states = text_hidden_states.size() == 1 ? 
text_hidden_states[0] : nn::functional::concat(text_hidden_states, 1); + auto projected_hidden = projector_semantic_(hidden_states)[0]; + if (cfg_.tts_normalize_projected_hidden) { projected_hidden = normalizeProjectedHidden(projected_hidden); } + + auto tts_embeds = llm_embeds + projected_hidden; + + Tensor text_eos = Tensor::empty({1, 1}, kInt64, kCPU).alloc(); + text_eos.at({0, 0}) = cfg_.tts_text_eos_token_id; + Tensor audio_bos = Tensor::empty({1, 1}, kInt64, kCPU).alloc(); + audio_bos.at({0, 0}) = cfg_.tts_audio_bos_token_id; + if (cfg_.tts_text_eos_token_id < 0 || cfg_.tts_text_eos_token_id >= cfg_.tts_num_text_tokens) { + MLLM_ERROR("MiniCPM-o-4_5 TTS text_eos_token_id out of range: {} (vocab={}).", + cfg_.tts_text_eos_token_id, cfg_.tts_num_text_tokens); + return Tensor::nil(); + } + if (cfg_.tts_audio_bos_token_id < 0 || cfg_.tts_audio_bos_token_id >= cfg_.tts_num_text_tokens) { + MLLM_ERROR("MiniCPM-o-4_5 TTS audio_bos_token_id out of range: {} (vocab={}).", + cfg_.tts_audio_bos_token_id, cfg_.tts_num_text_tokens); + return Tensor::nil(); + } + + auto text_eos_embed = emb_text_(text_eos); + auto audio_bos_embed = emb_text_(audio_bos); + + return nn::functional::concat({tts_embeds, text_eos_embed, audio_bos_embed}, 1); + } + + MiniCPMO45TTSGenerationOutput generate(const Tensor& condition_embeds, + const MiniCPMO45TTSGenerationConfig& generation_cfg = {}) { + if (condition_embeds.isNil()) { return {}; } + + auto eos_token = cfg_.tts_num_audio_tokens - 1; + + std::vector temperature = generation_cfg.temperature; + if (temperature.empty()) { temperature.assign(cfg_.tts_num_vq, 1.0f); } + if (temperature.size() < static_cast(cfg_.tts_num_vq)) { + temperature.resize(cfg_.tts_num_vq, temperature.back()); + } + + nn::StaticCache kv_cache(cfg_.tts_max_position_embeddings, cfg_.tts_num_hidden_layers, + cfg_.tts_num_attention_heads, // q heads + cfg_.tts_num_key_value_heads, // kv heads + cfg_.tts_hidden_size / cfg_.tts_num_attention_heads, + kFloat32, // k dtype 
+ kFloat32, // v dtype + kCPU, // device + false // use fa2 + ); + + Tensor generated = Tensor::zeros({1, generation_cfg.max_new_tokens, cfg_.tts_num_vq}, kInt64, kCPU); + int32_t generated_len = 0; + bool finished = false; + auto condition_length = condition_embeds.shape()[1]; + std::vector> generated_history(cfg_.tts_num_vq); + + for (int32_t t = 0; t < generation_cfg.max_new_tokens; ++t) { + Tensor inputs_embeds = Tensor::nil(); + Tensor position_ids = Tensor::nil(); + + if (t == 0) { + inputs_embeds = condition_embeds; + position_ids = Tensor::empty({1, condition_length}, kInt64, kCPU).alloc(); + for (int32_t i = 0; i < condition_length; ++i) { position_ids.at({0, i}) = i; } + } else { + for (int32_t q = 0; q < cfg_.tts_num_vq; ++q) { + auto code_ids = generated[{kAll, {t - 1, t}, {q, q + 1}}].contiguous().view({1, 1}); + auto code_embeds = emb_code_[q](code_ids); + if (q == 0) { + inputs_embeds = code_embeds; + } else { + inputs_embeds = inputs_embeds + code_embeds; + } + } + position_ids = Tensor::empty({1, 1}, kInt64, kCPU).alloc(); + position_ids.at({0, 0}) = condition_length + t - 1; + } + + auto [llm_embedding_sin, llm_embedding_cos] = llama::makeRotaryPosEmbedding(position_ids, model_.getBuffer("inv_freq"), 1.0f); + Tensor causal_mask = Tensor::nil(); + auto* cache_ptr = static_cast(&kv_cache); + auto hidden_states = model_(inputs_embeds, llm_embedding_sin, llm_embedding_cos, causal_mask, AnyValue(cache_ptr))[0]; + + auto seq_len = hidden_states.shape()[1]; + auto last_hidden = hidden_states[{kAll, {seq_len - 1, seq_len}, kAll}].contiguous(); + + bool has_eos = false; + std::vector step_tokens; + step_tokens.reserve(cfg_.tts_num_vq); + for (int32_t q = 0; q < cfg_.tts_num_vq; ++q) { + MLLM_RT_ASSERT(q < static_cast(head_code_weight_.size())); + auto logits = nn::functional::matmul(last_hidden, head_code_weight_[q], false, true)[{0, 0, kAll}].contiguous(); + auto temp = std::max(temperature[q], 1e-5f); + logits = logits / temp; + + if (t > 0) { + 
applyRepetitionPenalty(logits, generated_history[q], generation_cfg.repetition_penalty, + generation_cfg.repetition_penalty_window); + applyTopPLogits(logits, generation_cfg.top_p, 3); + applyTopKLogits(logits, generation_cfg.top_k, 3); + } + + if (t < generation_cfg.min_new_tokens || generation_cfg.force_no_stop) { + if (logits.dtype() == kFloat32) { + logits.ptr()[eos_token] = -std::numeric_limits::infinity(); + } else if (logits.dtype() == kFloat16) { + logits.ptr()[eos_token] = -65504.0f; + } + } + + bool use_sampling = generation_cfg.do_sample || generation_cfg.top_k > 0 || generation_cfg.top_p > 0.0f + || std::abs(temp - 1.0f) > 1e-6f; + auto token_id = sampleFromLogits(logits, use_sampling); + generated.at({0, t, q}) = token_id; + generated_history[q].push_back(token_id); + step_tokens.push_back(token_id); + has_eos = has_eos || token_id == eos_token; + } + + if (generation_cfg.step_callback) { + auto interval = std::max(generation_cfg.debug_interval, 1); + if (t == 0 || ((t + 1) % interval) == 0 || has_eos) { + generation_cfg.step_callback(t + 1, step_tokens, has_eos); + } + } + + generated_len = t + 1; + if (has_eos) { + finished = true; + break; + } + } + + auto out_len = generated_len; + if (finished && out_len > 0) { out_len -= 1; } // do not return terminal token + + Tensor out_ids = Tensor::nil(); + if (out_len > 0) { out_ids = generated[{kAll, {0, out_len}, kAll}].contiguous(); } + return {.new_ids = out_ids, .finished = finished}; + } + + private: + static int64_t argmax1d(const Tensor& logits) { + auto probs = logits; + if (probs.dtype() != kFloat32) { probs = probs.to(kFloat32); } + auto* data = probs.ptr(); + auto n = probs.shape().back(); + + auto max_idx = 0; + auto max_value = data[0]; + for (int32_t i = 1; i < n; ++i) { + if (data[i] > max_value) { + max_value = data[i]; + max_idx = i; + } + } + return max_idx; + } + + static int64_t sampleFromLogits(Tensor logits, bool do_sample) { + if (logits.dtype() != kFloat32) { logits = 
logits.to(kFloat32); } + if (!do_sample) { return argmax1d(logits); } + + auto probs = nn::functional::softmax(logits, -1); + if (probs.dtype() != kFloat32) { probs = probs.to(kFloat32); } + return categoricalSample1d(probs); + } + + static int64_t categoricalSample1d(const Tensor& probs) { + MLLM_RT_ASSERT_EQ(probs.dtype(), kFloat32); + auto* prob_data = probs.ptr(); + auto vocab_size = probs.shape().back(); + + std::vector cumulative_probs(vocab_size); + std::partial_sum(prob_data, prob_data + vocab_size, cumulative_probs.begin()); + + auto total = cumulative_probs.back(); + if (total <= 0.0f) { return argmax1d(probs); } + + static thread_local std::mt19937 rng(std::random_device{}()); + std::uniform_real_distribution dist(0.0f, total); + auto target = dist(rng); + + auto it = std::lower_bound(cumulative_probs.begin(), cumulative_probs.end(), target); + if (it == cumulative_probs.end()) { return vocab_size - 1; } + return static_cast(std::distance(cumulative_probs.begin(), it)); + } + + static void applyRepetitionPenalty(Tensor& logits, const std::vector& token_ids, float penalty, + int32_t past_window) { + if (penalty <= 1.0f || token_ids.empty()) { return; } + if (logits.dtype() != kFloat32) { logits = logits.to(kFloat32); } + + auto vocab_size = logits.shape().back(); + std::unordered_map frequencies; + + int32_t start = 0; + if (past_window > 0 && static_cast(token_ids.size()) > past_window) { + start = static_cast(token_ids.size()) - past_window; + } + for (int32_t i = start; i < static_cast(token_ids.size()); ++i) { + auto token_id = token_ids[i]; + if (token_id < 0 || token_id >= vocab_size) { continue; } + frequencies[token_id] += 1; + } + + auto* logits_ptr = logits.ptr(); + for (const auto& [token_id, freq] : frequencies) { + auto alpha = std::pow(penalty, static_cast(freq)); + float& value = logits_ptr[token_id]; + value = value < 0.0f ? 
value * alpha : value / alpha; + } + } + + static void applyTopKLogits(Tensor& logits, int32_t top_k, int32_t min_tokens_to_keep) { + if (top_k <= 0) { return; } + if (logits.dtype() != kFloat32) { logits = logits.to(kFloat32); } + + auto vocab_size = logits.shape().back(); + int32_t k = std::min(std::max(top_k, min_tokens_to_keep), vocab_size); + if (k >= vocab_size) { return; } + + auto* logits_ptr = logits.ptr(); + std::vector indices(vocab_size); + std::iota(indices.begin(), indices.end(), 0); + std::partial_sort(indices.begin(), indices.begin() + k, indices.end(), + [&logits_ptr](int32_t lhs, int32_t rhs) { return logits_ptr[lhs] > logits_ptr[rhs]; }); + + auto threshold = logits_ptr[indices[k - 1]]; + auto neg_inf = -std::numeric_limits::infinity(); + for (int32_t i = 0; i < vocab_size; ++i) { + if (logits_ptr[i] < threshold) { logits_ptr[i] = neg_inf; } + } + } + + static void applyTopPLogits(Tensor& logits, float top_p, int32_t min_tokens_to_keep) { + if (top_p <= 0.0f || top_p >= 1.0f) { return; } + if (logits.dtype() != kFloat32) { logits = logits.to(kFloat32); } + + auto vocab_size = logits.shape().back(); + if (vocab_size <= min_tokens_to_keep) { return; } + + auto* logits_ptr = logits.ptr(); + std::vector indices(vocab_size); + std::iota(indices.begin(), indices.end(), 0); + std::sort(indices.begin(), indices.end(), + [&logits_ptr](int32_t lhs, int32_t rhs) { return logits_ptr[lhs] > logits_ptr[rhs]; }); + + auto max_logit = logits_ptr[indices[0]]; + std::vector probs(vocab_size); + float sum_exp = 0.0f; + for (int32_t i = 0; i < vocab_size; ++i) { + auto prob = std::exp(logits_ptr[indices[i]] - max_logit); + probs[i] = prob; + sum_exp += prob; + } + if (sum_exp <= 0.0f) { return; } + for (auto& p : probs) { p /= sum_exp; } + + int32_t keep = 0; + float cumulative = 0.0f; + for (int32_t i = 0; i < vocab_size; ++i) { + cumulative += probs[i]; + keep += 1; + if (cumulative >= top_p && keep >= min_tokens_to_keep) { break; } + } + keep = std::max(keep, 
min_tokens_to_keep); + keep = std::min(keep, vocab_size); + + auto neg_inf = -std::numeric_limits::infinity(); + for (int32_t i = keep; i < vocab_size; ++i) { logits_ptr[indices[i]] = neg_inf; } + } + + static Tensor normalizeProjectedHidden(Tensor hidden_states) { + auto original_dtype = hidden_states.dtype(); + auto normalized = original_dtype == kFloat32 ? hidden_states.contiguous() : hidden_states.to(kFloat32).contiguous(); + + auto B = normalized.shape()[0]; + auto S = normalized.shape()[1]; + auto D = normalized.shape()[2]; + auto* ptr = normalized.ptr(); + + constexpr float kEps = 1e-12f; + for (int32_t b = 0; b < B; ++b) { + for (int32_t s = 0; s < S; ++s) { + auto base = b * S * D + s * D; + float norm = 0.0f; + for (int32_t d = 0; d < D; ++d) { norm += ptr[base + d] * ptr[base + d]; } + norm = std::sqrt(norm); + if (norm < kEps) { norm = kEps; } + for (int32_t d = 0; d < D; ++d) { ptr[base + d] /= norm; } + } + } + + if (original_dtype != kFloat32) { return normalized.to(original_dtype); } + return normalized; + } + + private: + MiniCPMO45Config cfg_; + TTSProjector projector_spk_; + TTSProjector projector_semantic_; + std::vector emb_code_; + nn::Embedding emb_text_; + std::vector head_code_weight_; + llama::LlamaText model_; +}; + +class MiniCPMO45TextModel final : public nn::Module { + public: + MiniCPMO45TextModel() = default; + + MiniCPMO45TextModel(const std::string& name, const MiniCPMO45Config& cfg) : Module(name) { + auto llm_cfg = toQwen3Config(cfg); + decode_blocks_ = reg>("layers", llm_cfg.num_hidden_layers, llm_cfg); + for (auto [idx, block] : enumerate(decode_blocks_.list())) { block.self_attn_.layer_idx_ = idx; } + norm_ = reg("norm", llm_cfg.rms_norm_eps); + embedding_ = reg("embed_tokens", llm_cfg.vocab_size, llm_cfg.hidden_size); + registerBuffer("last_hidden_states", Tensor::nil()); + } + + Tensor embed(const Tensor& input_ids) { return embedding_(input_ids); } + + std::vector forward(const std::vector& inputs, const std::vector& args) 
override { + auto hidden_states = inputs[0]; + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto kv_cache = args[0].get(); + + for (auto& block : decode_blocks_.list()) { + hidden_states = block(hidden_states, llm_embedding_sin, llm_embedding_cos, AnyValue(kv_cache))[0]; + } + updateBuffer("last_hidden_states", hidden_states); + hidden_states = norm_(hidden_states); + return {hidden_states}; + } + + private: + static qwen3::Qwen3Config toQwen3Config(const MiniCPMO45Config& cfg) { + qwen3::Qwen3Config llm_cfg; + llm_cfg.attention_bias = cfg.attention_bias; + llm_cfg.hidden_size = cfg.hidden_size; + llm_cfg.head_dim = cfg.head_dim; + llm_cfg.intermediate_size = cfg.intermediate_size; + llm_cfg.num_attention_heads = cfg.num_attention_heads; + llm_cfg.num_key_value_heads = cfg.num_key_value_heads; + llm_cfg.num_hidden_layers = cfg.num_hidden_layers; + llm_cfg.max_position_embeddings = cfg.max_position_embeddings; + llm_cfg.rms_norm_eps = cfg.rms_norm_eps; + llm_cfg.vocab_size = cfg.vocab_size; + llm_cfg.bos_token_id = cfg.bos_token_id; + llm_cfg.eos_token_id = cfg.eos_token_id; + llm_cfg.end_of_text_token_id = static_cast(cfg.eos_token_id); + llm_cfg.rope_theta = cfg.rope_theta; + llm_cfg.tie_word_embeddings = cfg.tie_word_embeddings; + llm_cfg.max_cache_length = cfg.max_cache_length; + llm_cfg.linear_impl_type = cfg.linear_impl_type; + return llm_cfg; + } + + private: + nn::ModuleList decode_blocks_; + nn::RMSNorm norm_; + + public: + nn::Embedding embedding_; +}; + +class MiniCPMO45LLM final : public nn::Module { + public: + MiniCPMO45LLM() = default; + + MiniCPMO45LLM(const std::string& name, const MiniCPMO45Config& cfg) : nn::Module(name) { + model_ = reg("model", cfg); + lm_head_ = reg("lm_head", cfg.hidden_size, cfg.vocab_size, false, cfg.linear_impl_type); + registerBuffer("inv_freq", qwen3::makeRoPEInvFreq(cfg.head_dim, cfg.rope_theta)); + } + + Tensor embed(const Tensor& input_ids) { return model_.embedding_(input_ids); } + + 
Tensor logits(const Tensor& hidden_states) { return lm_head_(hidden_states); } + + Tensor hiddenStates(Tensor& input_embeddings, Tensor& llm_embedding_sin, Tensor& llm_embedding_cos, nn::StaticCache* kv_cache) { + return model_(input_embeddings, llm_embedding_sin, llm_embedding_cos, AnyValue(kv_cache))[0]; + } + + public: + MiniCPMO45TextModel model_; + + private: + nn::Linear lm_head_; +}; + +class MiniCPMO45ForCausalLM : public models::ARGeneration { + public: + struct TextGenerationWithHiddenOutput { + std::vector generated_tokens; + std::vector aligned_tokens; + std::vector aligned_hidden_states; + bool finished = false; + }; + + explicit MiniCPMO45ForCausalLM(const MiniCPMO45Config& config) + : config_(config), + legacy_config_(createLegacyConfig(config)), + llm_("llm", config), + vpm_("vpm", legacy_config_), + resampler_("resampler", config.query_num, config.hidden_size, config.num_attention_heads, config.vision_hidden_size), + apm_("apm", legacy_config_), + audio_projection_layer_("audio_projection_layer", config.audio_hidden_size, config.hidden_size, config.hidden_size), + audio_avg_pooler_("audio_avg_pooler", config.audio_pool_step, config.audio_pool_step), + tts_("tts", config), + kv_cache_(config.max_cache_length, config.num_hidden_layers, + config.num_attention_heads, // q heads + config.num_key_value_heads, // kv heads + config.head_dim, // kv dim + kFloat32, // k dtype + kFloat32, // v dtype + kCPU, // device + false // use fa2 + ) { + eos_token_id_ = static_cast(config.eos_token_id); + max_length_ = config.max_cache_length; + } + + ARGenerationOutputPast forward(const ARGenerationOutputPast& inputs, const ARGenerationArgs& args) override { + Tensor input_ids = Tensor::nil(); + if (inputs.count("input_ids")) { + input_ids = inputs.at("input_ids"); + } else if (inputs.count("sequence")) { + input_ids = inputs.at("sequence"); + } else { + MLLM_ERROR("No input_ids or sequence found in MiniCPM-o-4_5 forward input."); + return {}; + } + + auto 
input_embeddings = llm_.embed(input_ids); + + Tensor prev_position_ids = inputs.count("position_ids") ? inputs.at("position_ids") : Tensor::nil(); + + // Prefill-only multimodal embedding injection. + if (prev_position_ids.isNil()) { + auto pixel_values = inputs.count("pixel_values") ? inputs.at("pixel_values") : Tensor::nil(); + auto tgt_sizes = inputs.count("tgt_sizes") ? inputs.at("tgt_sizes") : Tensor::nil(); + auto image_bounds = inputs.count("image_bounds") ? inputs.at("image_bounds") : Tensor::nil(); + + if (!pixel_values.isNil() && !tgt_sizes.isNil() && !image_bounds.isNil()) { + auto vision_outputs = vpm_(pixel_values, tgt_sizes)[0]; + auto vision_embeddings = resampler_(vision_outputs, tgt_sizes)[0]; + input_embeddings = mergeVisionTextEmbeddings(input_embeddings, vision_embeddings, image_bounds); + } + + auto audio_features = inputs.count("audio_features") ? inputs.at("audio_features") : Tensor::nil(); + auto audio_bounds = inputs.count("audio_bounds") ? inputs.at("audio_bounds") : Tensor::nil(); + + if (!audio_features.isNil() && !audio_bounds.isNil()) { + auto audio_embeddings = encodeAudio(audio_features); + input_embeddings = mergeAudioTextEmbeddings(input_embeddings, audio_embeddings, audio_bounds); + } + } + + Tensor position_ids = makePositionIds(input_embeddings.shape()[1], prev_position_ids); + + auto [llm_embedding_sin, llm_embedding_cos] = qwen3::makeRotaryPosEmbedding(position_ids, llm_.getBuffer("inv_freq"), 1.0f); + + auto hidden_states = llm_.hiddenStates(input_embeddings, llm_embedding_sin, llm_embedding_cos, &kv_cache_); + auto seq_len = hidden_states.shape()[1]; + auto last_hidden = hidden_states[{kAll, {seq_len - 1, seq_len}, kAll}].contiguous(); + auto logits = llm_.logits(last_hidden); + + return { + {"sequence", logits}, + {"position_ids", position_ids}, + {"last_hidden", last_hidden}, + }; + } + + Tensor encodeAudio(const Tensor& audio_features) { + // 1) Whisper encoder. 
+ auto audio_states = apm_(audio_features)[0]; + + // 2) Project to the LLM hidden space. + auto audio_embeds = audio_projection_layer_(audio_states)[0]; + + // 3) Temporal pooling. + audio_embeds = audio_embeds.transpose(1, 2); + audio_embeds = audio_avg_pooler_(audio_embeds)[0]; + audio_embeds = audio_embeds.transpose(1, 2); + return audio_embeds; + } + + TextGenerationWithHiddenOutput generateTextWithHidden(const ARGenerationOutputPast& initial_inputs, int32_t max_new_tokens, + const std::vector& stop_token_ids, bool do_sample = false, + float temperature = 1.0f, int32_t top_k = 0, float top_p = 0.0f, + const std::function& step_callback = + nullptr) { + TextGenerationWithHiddenOutput result; + + auto current_input = initial_inputs; + bool has_previous_generated = false; + int64_t previous_generated_token = 0; + + for (int32_t i = 0; i < max_new_tokens; ++i) { + auto output = forward(current_input, {}); + + if (has_previous_generated && output.count("last_hidden")) { + result.aligned_tokens.push_back(previous_generated_token); + result.aligned_hidden_states.push_back(output.at("last_hidden").contiguous().clone()); + } + + Tensor logits = output.at("sequence"); + int64_t next_token_id = 0; + if (do_sample || temperature != 1.0f || top_k > 0 || top_p > 0.0f) { + if (top_k > 0) { + next_token_id = sampleTopK(logits, top_k, temperature); + } else if (top_p > 0.0f) { + next_token_id = sampleTopP(logits, top_p, temperature); + } else { + next_token_id = sampleTemperature(logits, temperature); + } + } else { + next_token_id = sampleGreedy(logits); + } + result.generated_tokens.push_back(next_token_id); + if (step_callback) { step_callback(i + 1, next_token_id); } + + if (isStopToken(next_token_id, stop_token_ids)) { + result.finished = true; + break; + } + + current_input = std::move(output); + current_input["sequence"] = Tensor::empty({1, 1}, kInt64, kCPU).alloc(); + current_input["sequence"].at({0, 0}) = next_token_id; + + previous_generated_token = next_token_id; + 
has_previous_generated = true; + } + + if (!result.finished && has_previous_generated + && result.aligned_tokens.size() + 1 == result.generated_tokens.size()) { + auto probe_output = forward(current_input, {}); + if (probe_output.count("last_hidden")) { + result.aligned_tokens.push_back(previous_generated_token); + result.aligned_hidden_states.push_back(probe_output.at("last_hidden").contiguous().clone()); + } + } + + return result; + } + + public: + MiniCPMO45Config config_; + minicpmo::MiniCPMOConfig legacy_config_; + + MiniCPMO45LLM llm_; + minicpmo::SiglipVisionModel vpm_; + minicpmo::Resampler resampler_; + + minicpmo::WhisperEncoder apm_; + AudioProjectionLayer audio_projection_layer_; + AudioAvgPooler audio_avg_pooler_; + MiniCPMO45TTS tts_; + + private: + template + static void copyEmbeddingVector(Tensor& dst, const Tensor& src, int32_t dst_batch, int32_t dst_pos, int32_t src_batch, + int32_t src_pos, int32_t hidden_size) { + auto* dst_ptr = dst.offsettedPtr({dst_batch, dst_pos, 0}); + auto* src_ptr = src.coffsettedPtr({src_batch, src_pos, 0}); + std::memcpy(dst_ptr, src_ptr, hidden_size * sizeof(DType)); + } + + static Tensor mergeVisionTextEmbeddings(Tensor& text_embeddings, Tensor& vision_embeddings, const Tensor& image_bounds) { + auto batch_size = text_embeddings.shape()[0]; + auto hidden_size = text_embeddings.shape()[2]; + auto vision_seq_len = vision_embeddings.shape()[1]; + auto num_bounds = std::min(image_bounds.shape()[0], vision_embeddings.shape()[0]); + + if (vision_embeddings.shape()[0] != image_bounds.shape()[0]) { + MLLM_WARN("MiniCPM-o-4_5 vision bound count ({}) != embedding group count ({}). 
Using min={}.", + image_bounds.shape()[0], vision_embeddings.shape()[0], num_bounds); + } + + if (vision_embeddings.dtype() != text_embeddings.dtype()) { vision_embeddings = vision_embeddings.to(text_embeddings.dtype()); } + + for (int32_t b = 0; b < batch_size; ++b) { + for (int32_t bound_idx = 0; bound_idx < num_bounds; ++bound_idx) { + int32_t vision_idx = 0; + auto start_pos = image_bounds.constAt({bound_idx, 0}) + 1; + auto end_pos = image_bounds.constAt({bound_idx, 1}) - 1; + + for (int32_t pos = start_pos; pos <= end_pos && vision_idx < vision_seq_len; ++pos, ++vision_idx) { + if (text_embeddings.dtype() == kFloat32) { + copyEmbeddingVector(text_embeddings, vision_embeddings, b, pos, bound_idx, vision_idx, hidden_size); + } else if (text_embeddings.dtype() == kFloat16) { + copyEmbeddingVector(text_embeddings, vision_embeddings, b, pos, bound_idx, vision_idx, hidden_size); + } else { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Unsupported text embedding dtype in MiniCPM-o-4_5 vision merge."); + } + } + } + } + return text_embeddings; + } + + static Tensor mergeAudioTextEmbeddings(Tensor& text_embeddings, Tensor& audio_embeddings, const Tensor& audio_bounds) { + auto batch_size = text_embeddings.shape()[0]; + auto hidden_size = text_embeddings.shape()[2]; + auto audio_seq_len = audio_embeddings.shape()[1]; + auto num_bounds = std::min(audio_bounds.shape()[0], audio_embeddings.shape()[0]); + + if (audio_embeddings.shape()[0] != audio_bounds.shape()[0]) { + MLLM_WARN("MiniCPM-o-4_5 audio bound count ({}) != embedding group count ({}). 
Using min={}.", + audio_bounds.shape()[0], audio_embeddings.shape()[0], num_bounds); + } + + if (audio_embeddings.dtype() != text_embeddings.dtype()) { audio_embeddings = audio_embeddings.to(text_embeddings.dtype()); } + + for (int32_t b = 0; b < batch_size; ++b) { + for (int32_t bound_idx = 0; bound_idx < num_bounds; ++bound_idx) { + int32_t audio_idx = 0; + auto start_pos = audio_bounds.constAt({bound_idx, 0}); + auto end_pos = audio_bounds.constAt({bound_idx, 1}) - 1; + + for (int32_t pos = start_pos; pos <= end_pos && audio_idx < audio_seq_len; ++pos, ++audio_idx) { + if (text_embeddings.dtype() == kFloat32) { + copyEmbeddingVector(text_embeddings, audio_embeddings, b, pos, bound_idx, audio_idx, hidden_size); + } else if (text_embeddings.dtype() == kFloat16) { + copyEmbeddingVector(text_embeddings, audio_embeddings, b, pos, bound_idx, audio_idx, hidden_size); + } else { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Unsupported text embedding dtype in MiniCPM-o-4_5 audio merge."); + } + } + } + } + return text_embeddings; + } + + Tensor makePositionIds(int32_t seq_len, const Tensor& prev_position_ids) { + Tensor position_ids = Tensor::empty({1, seq_len}, kInt64).alloc(); + if (!prev_position_ids.isNil()) { + auto last_pos = *prev_position_ids.coffsettedPtr({0, prev_position_ids.shape()[1] - 1}); + for (int32_t i = 0; i < seq_len; ++i) { position_ids.at({0, i}) = last_pos + i + 1; } + return position_ids; + } + + auto last_seen_tokens = kv_cache_.getCurrentSeqCnt(0); + for (int32_t i = 0; i < seq_len; ++i) { position_ids.at({0, i}) = last_seen_tokens + i; } + return position_ids; + } + + static minicpmo::MiniCPMOConfig createLegacyConfig(const MiniCPMO45Config& config) { + minicpmo::MiniCPMOConfig legacy; + legacy.vision_hidden_size = config.vision_hidden_size; + legacy.vision_intermediate_size = config.vision_intermediate_size; + legacy.vision_num_hidden_layers = config.vision_num_hidden_layers; + legacy.vision_num_attention_heads = config.vision_num_attention_heads; 
+ legacy.vision_num_channels = config.vision_num_channels; + legacy.vision_image_size = config.vision_image_size; + legacy.vision_patch_size = config.vision_patch_size; + + legacy.hidden_size = config.hidden_size; + legacy.intermediate_size = config.intermediate_size; + legacy.num_attention_heads = config.num_attention_heads; + legacy.num_key_value_heads = config.num_key_value_heads; + legacy.num_hidden_layers = config.num_hidden_layers; + legacy.max_position_embeddings = config.max_position_embeddings; + legacy.rms_norm_eps = config.rms_norm_eps; + legacy.vocab_size = config.vocab_size; + + legacy.query_num = config.query_num; + + legacy.audio_hidden_size = config.audio_hidden_size; + legacy.audio_num_hidden_layers = config.audio_num_hidden_layers; + legacy.audio_num_attention_heads = config.audio_num_attention_heads; + legacy.audio_max_position_embeddings = config.audio_max_position_embeddings; + legacy.audio_chunk_length = config.audio_chunk_length; + legacy.audio_pool_step = config.audio_pool_step; + + legacy.max_cache_length = config.max_cache_length; + legacy.eos_token_id = config.eos_token_id; + legacy.bos_token_id = config.bos_token_id; + legacy.rope_theta = config.rope_theta; + legacy.tie_word_embeddings = config.tie_word_embeddings; + + legacy.linear_impl_type = config.linear_impl_type; + return legacy; + } + + static bool isStopToken(int64_t token_id, const std::vector& stop_token_ids) { + for (auto id : stop_token_ids) { + if (token_id == id) { return true; } + } + return false; + } + + private: + nn::StaticCache kv_cache_; +}; + +} // namespace mllm::models::minicpm_o45 diff --git a/mllm/models/minicpm_o45/modeling_minicpm_o45_token2wav.hpp b/mllm/models/minicpm_o45/modeling_minicpm_o45_token2wav.hpp new file mode 100644 index 000000000..0e145d22c --- /dev/null +++ b/mllm/models/minicpm_o45/modeling_minicpm_o45_token2wav.hpp @@ -0,0 +1,1522 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "mllm/core/Parallel.hpp" +#include "mllm/mllm.hpp" +#include "mllm/nn/Functional.hpp" +#include "mllm/nn/Module.hpp" +#include "mllm/nn/Nn.hpp" +#include "mllm/nn/layers/STFT.hpp" +#include "mllm/utils/Enumerate.hpp" +#include "mllm/utils/Log.hpp" + +#include "mllm/models/minicpm_o45/token2wav_prompt_cache.hpp" +#include "mllm/models/minicpm_o45/token2wav_weight_norm.hpp" + +namespace mllm::models::minicpm_o45 { + +struct MiniCPMO45FlowConfig { + int32_t input_size = 512; + int32_t output_size = 80; + int32_t spk_embed_dim = 192; + int32_t vocab_size = 6561; + int32_t up_rate = 2; + + int32_t encoder_attention_heads = 8; + int32_t encoder_linear_units = 2048; + int32_t encoder_num_blocks = 6; + int32_t encoder_num_up_blocks = 4; + int32_t pre_lookahead_len = 3; + + int32_t dit_in_channels = 320; + int32_t dit_out_channels = 80; + float dit_mlp_ratio = 4.0f; + int32_t dit_depth = 16; + int32_t dit_num_heads = 8; + int32_t dit_head_dim = 64; + int32_t dit_hidden_size = 512; + float cfm_inference_cfg_rate = 0.7f; +}; + +struct MiniCPMO45HiFTConfig { + int32_t in_channels = 80; + int32_t base_channels = 512; + int32_t nb_harmonics = 8; + int32_t sampling_rate = 24000; + float nsf_alpha = 0.1f; + float nsf_sigma = 0.003f; + float nsf_voiced_threshold = 10.0f; + std::vector upsample_rates = {8, 5, 3}; + std::vector upsample_kernel_sizes = {16, 11, 7}; + int32_t istft_n_fft = 16; + int32_t istft_hop_len = 4; + std::vector resblock_kernel_sizes = {3, 7, 11}; + std::vector> resblock_dilation_sizes = {{1, 3, 5}, {1, 3, 5}, {1, 3, 5}}; + std::vector source_resblock_kernel_sizes = {7, 7, 11}; + std::vector> source_resblock_dilation_sizes = {{1, 3, 5}, {1, 3, 5}, {1, 3, 5}}; + float lrelu_slope = 0.1f; + float audio_limit = 0.99f; +}; + +struct MiniCPMO45Token2WavConfig { + MiniCPMO45FlowConfig flow{}; + MiniCPMO45HiFTConfig hift{}; +}; + 
+namespace token2wav { + +inline bool isDebugEnabled() { + static bool enabled = []() { + const char* v = std::getenv("MLLM_TOKEN2WAV_DEBUG"); + if (v == nullptr) { return false; } + return std::string(v) != "0"; + }(); + return enabled; +} + +inline void debugLog(const std::string& msg) { + if (!isDebugEnabled()) { return; } + std::cerr << "[token2wav-cpp] " << msg << std::endl; +} + +inline std::string shapeOf(const Tensor& x) { + std::string s = "["; + const auto& sh = x.shape(); + for (int32_t i = 0; i < static_cast(sh.size()); ++i) { + s += std::to_string(sh[i]); + if (i + 1 != static_cast(sh.size())) { s += ","; } + } + s += "]"; + return s; +} + +inline std::string descOf(const Tensor& x) { + return "shape=" + shapeOf(x) + ",dtype=" + std::to_string(static_cast(x.dtype())); +} + +inline Tensor repeatInterleaveSeq(Tensor x, int32_t repeats) { + MLLM_RT_ASSERT_EQ(x.dtype(), kFloat32); + auto in = x.contiguous(); + const auto& shape = in.shape(); + MLLM_RT_ASSERT_EQ(static_cast(shape.size()), 3); + const int32_t batch = shape[0]; + const int32_t seq_len = shape[1]; + const int32_t channels = shape[2]; + + auto out = Tensor::empty({batch, seq_len * repeats, channels}, kFloat32, kCPU).alloc(); + const auto* src = in.ptr(); + auto* dst = out.ptr(); + const int64_t in_stride_b = static_cast(seq_len) * channels; + const int64_t out_stride_b = static_cast(seq_len) * repeats * channels; + + for (int32_t b = 0; b < batch; ++b) { + const float* src_b = src + static_cast(b) * in_stride_b; + float* dst_b = dst + static_cast(b) * out_stride_b; + for (int32_t s = 0; s < seq_len; ++s) { + const float* src_s = src_b + static_cast(s) * channels; + for (int32_t r = 0; r < repeats; ++r) { + float* dst_s = dst_b + (static_cast(s) * repeats + r) * channels; + std::memcpy(dst_s, src_s, sizeof(float) * channels); + } + } + } + return out; +} + +inline Tensor concatInt64Seq(Tensor a, Tensor b) { + MLLM_RT_ASSERT_EQ(a.dtype(), kInt64); + MLLM_RT_ASSERT_EQ(b.dtype(), kInt64); + 
MLLM_RT_ASSERT_EQ(static_cast(a.shape().size()), 2); + MLLM_RT_ASSERT_EQ(static_cast(b.shape().size()), 2); + MLLM_RT_ASSERT_EQ(a.shape()[0], b.shape()[0]); + + auto av = a.contiguous(); + auto bv = b.contiguous(); + const int32_t B = av.shape()[0]; + const int32_t Ta = av.shape()[1]; + const int32_t Tb = bv.shape()[1]; + auto out = Tensor::empty({B, Ta + Tb}, kInt64, kCPU).alloc(); + + const auto* ap = av.ptr(); + const auto* bp = bv.ptr(); + auto* op = out.ptr(); + for (int32_t bidx = 0; bidx < B; ++bidx) { + std::memcpy(op + static_cast(bidx) * (Ta + Tb), ap + static_cast(bidx) * Ta, sizeof(int64_t) * Ta); + std::memcpy(op + static_cast(bidx) * (Ta + Tb) + Ta, bp + static_cast(bidx) * Tb, sizeof(int64_t) * Tb); + } + return out; +} + +inline Tensor repeatInterleave1d(Tensor x, int32_t repeats) { + MLLM_RT_ASSERT_EQ(x.dtype(), kFloat32); + auto in = x.contiguous(); + const auto& shape = in.shape(); + MLLM_RT_ASSERT_EQ(static_cast(shape.size()), 3); + const int32_t batch = shape[0]; + const int32_t channels = shape[1]; + const int32_t seq_len = shape[2]; + + auto out = Tensor::empty({batch, channels, seq_len * repeats}, kFloat32, kCPU).alloc(); + const auto* src = in.ptr(); + auto* dst = out.ptr(); + + const int64_t in_stride_b = static_cast(channels) * seq_len; + const int64_t out_stride_b = static_cast(channels) * seq_len * repeats; + + for (int32_t b = 0; b < batch; ++b) { + const float* src_b = src + static_cast(b) * in_stride_b; + float* dst_b = dst + static_cast(b) * out_stride_b; + for (int32_t c = 0; c < channels; ++c) { + const float* src_c = src_b + static_cast(c) * seq_len; + float* dst_c = dst_b + static_cast(c) * seq_len * repeats; + for (int32_t t = 0; t < seq_len; ++t) { + const float v = src_c[t]; + for (int32_t r = 0; r < repeats; ++r) { dst_c[t * repeats + r] = v; } + } + } + } + return out; +} + +inline Tensor l2NormalizeRow(Tensor x, float eps = 1e-12f) { + MLLM_RT_ASSERT_EQ(x.dtype(), kFloat32); + auto in = x.contiguous(); + const auto& shape 
= in.shape(); + MLLM_RT_ASSERT_EQ(static_cast(shape.size()), 2); + const int32_t batch = shape[0]; + const int32_t dim = shape[1]; + auto out = Tensor::empty(shape, kFloat32, kCPU).alloc(); + + const auto* src = in.ptr(); + auto* dst = out.ptr(); + for (int32_t b = 0; b < batch; ++b) { + const int64_t base = static_cast(b) * dim; + float norm = 0.0f; + for (int32_t i = 0; i < dim; ++i) { + const float v = src[base + i]; + norm += v * v; + } + norm = std::sqrt(std::max(norm, eps)); + for (int32_t i = 0; i < dim; ++i) { dst[base + i] = src[base + i] / norm; } + } + return out; +} + +inline Tensor tensorMish(Tensor x) { + MLLM_RT_ASSERT_EQ(x.dtype(), kFloat32); + auto in = x.contiguous(); + auto out = Tensor::empty(in.shape(), kFloat32, kCPU).alloc(); + const auto* src = in.ptr(); + auto* dst = out.ptr(); + const int64_t n = static_cast(in.numel()); + MLLM_CONDITIONAL_PARALLEL_FOR(n > 4096, 4, i, 0, n, 1, { + const float v = src[i]; + const float sp = std::log1p(std::exp(v)); + dst[i] = v * std::tanh(sp); + }); + return out; +} + +inline Tensor tensorLeakyRelu(Tensor x, float slope) { + MLLM_RT_ASSERT_EQ(x.dtype(), kFloat32); + auto in = x.contiguous(); + auto out = Tensor::empty(in.shape(), kFloat32, kCPU).alloc(); + const auto* src = in.ptr(); + auto* dst = out.ptr(); + const int64_t n = static_cast(in.numel()); + MLLM_CONDITIONAL_PARALLEL_FOR(n > 4096, 4, i, 0, n, 1, { + const float v = src[i]; + dst[i] = (v >= 0.0f) ? 
v : (v * slope); + }); + return out; +} + +inline Tensor makeHannWindow(int32_t win_length) { + auto w = Tensor::empty({1, win_length}, kFloat32, kCPU).alloc(); + auto* ptr = w.ptr(); + constexpr float kPi = 3.14159265358979323846f; + for (int32_t i = 0; i < win_length; ++i) { + ptr[i] = 0.5f - 0.5f * std::cos(2.0f * kPi * static_cast(i) / static_cast(win_length)); + } + return w; +} + +inline Tensor relShift(Tensor x) { + // x: [B, H, T, 2T-1], output [B, H, T, T] + MLLM_RT_ASSERT_EQ(x.dtype(), kFloat32); + auto in = x.contiguous(); + const int32_t B = in.shape()[0]; + const int32_t H = in.shape()[1]; + const int32_t T = in.shape()[2]; + const int32_t R = in.shape()[3]; + MLLM_RT_ASSERT_EQ(R, 2 * T - 1); + + auto out = Tensor::empty({B, H, T, T}, kFloat32, kCPU).alloc(); + const auto* src = in.ptr(); + auto* dst = out.ptr(); + const int64_t in_stride_b = static_cast(H) * T * R; + const int64_t in_stride_h = static_cast(T) * R; + const int64_t in_stride_t = R; + const int64_t out_stride_b = static_cast(H) * T * T; + const int64_t out_stride_h = static_cast(T) * T; + const int64_t out_stride_t = T; + + for (int32_t b = 0; b < B; ++b) { + for (int32_t h = 0; h < H; ++h) { + const float* src_h = src + static_cast(b) * in_stride_b + static_cast(h) * in_stride_h; + float* dst_h = dst + static_cast(b) * out_stride_b + static_cast(h) * out_stride_h; + for (int32_t i = 0; i < T; ++i) { + const float* src_i = src_h + static_cast(i) * in_stride_t; + float* dst_i = dst_h + static_cast(i) * out_stride_t; + for (int32_t j = 0; j < T; ++j) { + const int32_t src_idx = j - i + T - 1; + dst_i[j] = src_i[src_idx]; + } + } + } + } + return out; +} + +inline void addHeadBiasInplace(Tensor& q, Tensor bias) { + // q: [B, H, T, D], bias: [H, D] + MLLM_RT_ASSERT_EQ(q.dtype(), kFloat32); + MLLM_RT_ASSERT_EQ(bias.dtype(), kFloat32); + auto qv = q.contiguous(); + auto bv = bias.contiguous(); + const int32_t B = qv.shape()[0]; + const int32_t H = qv.shape()[1]; + const int32_t T = 
qv.shape()[2]; + const int32_t D = qv.shape()[3]; + MLLM_RT_ASSERT_EQ(bv.shape()[0], H); + MLLM_RT_ASSERT_EQ(bv.shape()[1], D); + auto* q_ptr = qv.ptr(); + const auto* b_ptr = bv.ptr(); + + const int64_t q_stride_b = static_cast(H) * T * D; + const int64_t q_stride_h = static_cast(T) * D; + const int64_t q_stride_t = D; + + for (int32_t b = 0; b < B; ++b) { + for (int32_t h = 0; h < H; ++h) { + const float* bh = b_ptr + static_cast(h) * D; + for (int32_t t = 0; t < T; ++t) { + float* row = q_ptr + static_cast(b) * q_stride_b + static_cast(h) * q_stride_h + + static_cast(t) * q_stride_t; + for (int32_t d = 0; d < D; ++d) { row[d] += bh[d]; } + } + } + } + q = qv; +} + +inline Tensor concatChannel(const std::vector& xs) { + return nn::functional::concat(xs, 1); +} + +inline Tensor makeTimeStepsTensor(const std::vector& values) { + auto t = Tensor::empty({static_cast(values.size())}, kFloat32, kCPU).alloc(); + auto* ptr = t.ptr(); + for (size_t i = 0; i < values.size(); ++i) { ptr[i] = values[i]; } + return t; +} + +inline Tensor randomNormalLike(const std::vector& shape, float scale = 1.0f) { + auto out = Tensor::empty(shape, kFloat32, kCPU).alloc(); + auto* ptr = out.ptr(); + const int64_t n = static_cast(out.numel()); + static thread_local std::mt19937 rng(std::random_device{}()); + std::normal_distribution dist(0.0f, 1.0f); + for (int64_t i = 0; i < n; ++i) { ptr[i] = dist(rng) * scale; } + return out; +} + +class EspnetRelPositionalEncoding final : public nn::Module { + public: + EspnetRelPositionalEncoding() = default; + EspnetRelPositionalEncoding(const std::string& name, int32_t dim) : nn::Module(name), dim_(dim) { xscale_ = std::sqrt(static_cast(dim)); } + + std::pair forwardWithPos(Tensor x) { + const int32_t T = x.shape()[1]; + auto pos = positionEncoding(T); + return {x * xscale_, pos}; + } + + private: + Tensor positionEncoding(int32_t size) const { + const int32_t dim = dim_; + auto pe_pos = Tensor::empty({size, dim}, kFloat32, kCPU).alloc(); + auto 
pe_neg = Tensor::empty({size, dim}, kFloat32, kCPU).alloc(); + auto* pos_ptr = pe_pos.ptr(); + auto* neg_ptr = pe_neg.ptr(); + + for (int32_t p = 0; p < size; ++p) { + for (int32_t i = 0; i < dim; i += 2) { + const float div = std::exp(-std::log(10000.0f) * static_cast(i) / static_cast(dim)); + const float v1 = std::sin(static_cast(p) * div); + const float v2 = std::cos(static_cast(p) * div); + pos_ptr[p * dim + i] = v1; + pos_ptr[p * dim + i + 1] = v2; + neg_ptr[p * dim + i] = -v1; + neg_ptr[p * dim + i + 1] = v2; + } + } + + auto pe_positive = Tensor::empty({1, size, dim}, kFloat32, kCPU).alloc(); + auto pe_negative = Tensor::empty({1, std::max(size - 1, 0), dim}, kFloat32, kCPU).alloc(); + auto* pp = pe_positive.ptr(); + auto* pn = pe_negative.ptr(); + for (int32_t i = 0; i < size; ++i) { + std::memcpy(pp + static_cast(i) * dim, pos_ptr + static_cast(size - 1 - i) * dim, + sizeof(float) * dim); + } + for (int32_t i = 1; i < size; ++i) { + std::memcpy(pn + static_cast(i - 1) * dim, neg_ptr + static_cast(i) * dim, sizeof(float) * dim); + } + return nn::functional::concat({pe_positive, pe_negative}, 1); + } + + private: + int32_t dim_ = 0; + float xscale_ = 1.0f; +}; + +class LinearNoSubsampling final : public nn::Module { + public: + LinearNoSubsampling() = default; + LinearNoSubsampling(const std::string& name, int32_t idim, int32_t odim) : nn::Module(name) { + out_linear_ = reg("out.0", idim, odim, true); + out_norm_ = reg("out.1", std::vector{odim}, true, true, 1e-5f); + pos_enc_ = reg("pos_enc", odim); + } + + std::pair forwardWithPos(Tensor x) { + auto y = out_linear_(x); + y = out_norm_(y); + return pos_enc_.forwardWithPos(y); + } + + private: + nn::Linear out_linear_; + nn::LayerNorm out_norm_; + EspnetRelPositionalEncoding pos_enc_; +}; + +class PositionwiseFeedForward final : public nn::Module { + public: + PositionwiseFeedForward() = default; + PositionwiseFeedForward(const std::string& name, int32_t idim, int32_t hidden_units) : nn::Module(name) { + w1_ 
= reg("w_1", idim, hidden_units, true); + w2_ = reg("w_2", hidden_units, idim, true); + } + + Tensor forwardOne(Tensor x) { + auto y = w1_(x); + y = nn::functional::silu(y); + y = w2_(y); + return y; + } + + private: + nn::Linear w1_; + nn::Linear w2_; +}; + +class RelPositionMultiHeadedAttention final : public nn::Module { + public: + RelPositionMultiHeadedAttention() = default; + RelPositionMultiHeadedAttention(const std::string& name, int32_t n_head, int32_t n_feat, bool key_bias) + : nn::Module(name), n_head_(n_head), n_feat_(n_feat) { + d_k_ = n_feat_ / n_head_; + linear_q_ = reg("linear_q", n_feat_, n_feat_, true); + linear_k_ = reg("linear_k", n_feat_, n_feat_, key_bias); + linear_v_ = reg("linear_v", n_feat_, n_feat_, true); + linear_out_ = reg("linear_out", n_feat_, n_feat_, true); + linear_pos_ = reg("linear_pos", n_feat_, n_feat_, false); + pos_bias_u_ = reg("pos_bias_u", getModuleName() + ".pos_bias_u", Tensor::shape_t{n_head_, d_k_}); + pos_bias_v_ = reg("pos_bias_v", getModuleName() + ".pos_bias_v", Tensor::shape_t{n_head_, d_k_}); + } + + Tensor forwardOne(Tensor x, Tensor pos_emb) { + auto q = linear_q_(x).view({x.shape()[0], x.shape()[1], n_head_, d_k_}).transpose(1, 2); // [B,H,T,D] + auto k = linear_k_(x).view({x.shape()[0], x.shape()[1], n_head_, d_k_}).transpose(1, 2); + auto v = linear_v_(x).view({x.shape()[0], x.shape()[1], n_head_, d_k_}).transpose(1, 2); + + auto p = linear_pos_(pos_emb).view({pos_emb.shape()[0], pos_emb.shape()[1], n_head_, d_k_}).transpose(1, 2); // [1,H,2T-1,D] + + auto q_with_bias_u = q.contiguous(); + auto q_with_bias_v = q.contiguous(); + addHeadBiasInplace(q_with_bias_u, pos_bias_u_.weight()); + addHeadBiasInplace(q_with_bias_v, pos_bias_v_.weight()); + + auto matrix_ac = nn::functional::matmul(q_with_bias_u, k.transpose(2, 3), false, false); // [B,H,T,T] + auto matrix_bd = nn::functional::matmul(q_with_bias_v, p.transpose(2, 3), false, false); // [B,H,T,2T-1] + if (matrix_ac.shape()[3] != matrix_bd.shape()[3]) { 
matrix_bd = relShift(matrix_bd); } + auto scores = (matrix_ac + matrix_bd) / std::sqrt(static_cast(d_k_)); + auto attn = nn::functional::softmax(scores, -1); + auto y = nn::functional::matmul(attn, v, false, false); // [B,H,T,D] + y = y.transpose(1, 2).view({x.shape()[0], x.shape()[1], n_feat_}); + return linear_out_(y); + } + + private: + int32_t n_head_ = 0; + int32_t n_feat_ = 0; + int32_t d_k_ = 0; + nn::Linear linear_q_; + nn::Linear linear_k_; + nn::Linear linear_v_; + nn::Linear linear_out_; + nn::Linear linear_pos_; + nn::Param pos_bias_u_; + nn::Param pos_bias_v_; +}; + +class ConformerEncoderLayer final : public nn::Module { + public: + ConformerEncoderLayer() = default; + ConformerEncoderLayer(const std::string& name, int32_t size, int32_t n_head, int32_t linear_units, bool key_bias) + : nn::Module(name) { + self_attn_ = reg("self_attn", n_head, size, key_bias); + feed_forward_ = reg("feed_forward", size, linear_units); + norm_ff_ = reg("norm_ff", std::vector{size}, true, true, 1e-12f); + norm_mha_ = reg("norm_mha", std::vector{size}, true, true, 1e-12f); + } + + Tensor forwardOne(Tensor x, Tensor pos_emb) { + auto h = norm_mha_(x); + auto y = self_attn_.forwardOne(h, pos_emb); + y = x + y; + auto z = norm_ff_(y); + z = feed_forward_.forwardOne(z); + return y + z; + } + + private: + RelPositionMultiHeadedAttention self_attn_; + PositionwiseFeedForward feed_forward_; + nn::LayerNorm norm_ff_; + nn::LayerNorm norm_mha_; +}; + +class PreLookaheadLayer final : public nn::Module { + public: + PreLookaheadLayer() = default; + PreLookaheadLayer(const std::string& name, int32_t channels, int32_t pre_lookahead_len) : nn::Module(name), pre_(pre_lookahead_len) { + conv1_ = reg("conv1", channels, channels, pre_ + 1, 1, 0, 1, 1, true); + conv2_ = reg("conv2", channels, channels, 3, 1, 0, 1, 1, true); + } + + Tensor forwardOne(Tensor inputs) { + auto x = inputs.transpose(1, 2).contiguous(); // [B,C,T] + x = nn::functional::pad(x, {0, pre_}, aops::PadMode::kConstant, 
0.0f); // right pad + x = conv1_(x); + x = tensorLeakyRelu(x, 0.01f); + x = nn::functional::pad(x, {2, 0}, aops::PadMode::kConstant, 0.0f); // left pad + x = conv2_(x); + x = x.transpose(1, 2).contiguous(); // [B,T,C] + return x + inputs; + } + + private: + int32_t pre_ = 3; + nn::Conv1D conv1_; + nn::Conv1D conv2_; +}; + +class Upsample1D final : public nn::Module { + public: + Upsample1D() = default; + Upsample1D(const std::string& name, int32_t channels, int32_t out_channels, int32_t stride) : nn::Module(name), stride_(stride) { + conv_ = reg("conv", channels, out_channels, stride_ * 2 + 1, 1, 0, 1, 1, true); + } + + Tensor forwardOne(Tensor inputs) { + auto x = repeatInterleave1d(inputs, stride_); + x = nn::functional::pad(x, {stride_ * 2, 0}, aops::PadMode::kConstant, 0.0f); + return conv_(x); + } + + int32_t stride() const { return stride_; } + + private: + int32_t stride_ = 2; + nn::Conv1D conv_; +}; + +class UpsampleConformerEncoderV2 final : public nn::Module { + public: + UpsampleConformerEncoderV2() = default; + UpsampleConformerEncoderV2(const std::string& name, const MiniCPMO45FlowConfig& cfg) : nn::Module(name), cfg_(cfg) { + embed_ = reg("embed", cfg.input_size, cfg.input_size); + pre_lookahead_ = reg("pre_lookahead_layer", cfg.input_size, cfg.pre_lookahead_len); + encoders_ = reg>("encoders", cfg.encoder_num_blocks, cfg.input_size, + cfg.encoder_attention_heads, cfg.encoder_linear_units, true); + up_layer_ = reg("up_layer", cfg.input_size, cfg.input_size, cfg.up_rate); + up_embed_ = reg("up_embed", cfg.input_size, cfg.input_size); + up_encoders_ = reg>("up_encoders", cfg.encoder_num_up_blocks, cfg.input_size, + cfg.encoder_attention_heads, cfg.encoder_linear_units, true); + after_norm_ = reg("after_norm", std::vector{cfg.input_size}, true, true, 1e-5f); + } + + Tensor forwardOne(Tensor xs) { + auto [x0, pos0] = embed_.forwardWithPos(xs); + x0 = pre_lookahead_.forwardOne(x0); + for (auto& layer : encoders_.list()) { x0 = layer.forwardOne(x0, pos0); } 
+ + x0 = x0.transpose(1, 2).contiguous(); + x0 = up_layer_.forwardOne(x0); + x0 = x0.transpose(1, 2).contiguous(); + + auto [x1, pos1] = up_embed_.forwardWithPos(x0); + for (auto& layer : up_encoders_.list()) { x1 = layer.forwardOne(x1, pos1); } + x1 = after_norm_(x1); + return x1; + } + + private: + MiniCPMO45FlowConfig cfg_; + LinearNoSubsampling embed_; + PreLookaheadLayer pre_lookahead_; + nn::ModuleList encoders_; + Upsample1D up_layer_; + LinearNoSubsampling up_embed_; + nn::ModuleList up_encoders_; + nn::LayerNorm after_norm_; +}; + +class DiTAttention final : public nn::Module { + public: + DiTAttention() = default; + DiTAttention(const std::string& name, int32_t dim, int32_t num_heads, int32_t head_dim) : nn::Module(name), + dim_(dim), heads_(num_heads), head_dim_(head_dim), inner_dim_(num_heads * head_dim) { + to_q_ = reg("to_q", dim_, inner_dim_, true); + to_k_ = reg("to_k", dim_, inner_dim_, true); + to_v_ = reg("to_v", dim_, inner_dim_, true); + q_norm_ = reg("q_norm", std::vector{head_dim_}, true, true, 1e-5f); + k_norm_ = reg("k_norm", std::vector{head_dim_}, true, true, 1e-5f); + proj_ = reg("proj", inner_dim_, dim_, true); + } + + Tensor forwardOne(Tensor x) { + debugLog("dit.attn: enter x(" + descOf(x) + ")"); + auto q = to_q_(x).view({x.shape()[0], x.shape()[1], heads_, head_dim_}).transpose(1, 2); // [B,H,T,D] + debugLog("dit.attn: to_q done"); + auto k = to_k_(x).view({x.shape()[0], x.shape()[1], heads_, head_dim_}).transpose(1, 2); + debugLog("dit.attn: to_k done"); + auto v = to_v_(x).view({x.shape()[0], x.shape()[1], heads_, head_dim_}).transpose(1, 2); + debugLog("dit.attn: to_v done"); + + q = q_norm_(q); + k = k_norm_(k); + + auto out = nn::functional::scaledDotProductAttention(q, k, v); // [B,H,T,D] + out = out.transpose(1, 2).contiguous().view({x.shape()[0], x.shape()[1], inner_dim_}); + out = proj_(out); + debugLog("dit.attn: exit"); + return out; + } + + private: + int32_t dim_ = 0; + int32_t heads_ = 0; + int32_t head_dim_ = 0; + 
int32_t inner_dim_ = 0; + nn::Linear to_q_; + nn::Linear to_k_; + nn::Linear to_v_; + nn::LayerNorm q_norm_; + nn::LayerNorm k_norm_; + nn::Linear proj_; +}; + +class CausalConv1dBlock final : public nn::Module { + public: + CausalConv1dBlock() = default; + CausalConv1dBlock(const std::string& name, int32_t in_channels, int32_t out_channels, int32_t kernel_size) : nn::Module(name), + in_channels_(in_channels), out_channels_(out_channels), kernel_size_(kernel_size) { + conv1_ = reg("block.1", in_channels_, out_channels_, kernel_size_, 1, 0, 1, 1, true); + norm_ = reg("block.3", std::vector{out_channels_}, true, true, 1e-5f); + conv2_ = reg("block.6", out_channels_, out_channels_, kernel_size_, 1, 0, 1, 1, true); + } + + Tensor forwardOne(Tensor x) { + auto y = x.transpose(1, 2).contiguous(); + y = nn::functional::pad(y, {kernel_size_ - 1, 0}, aops::PadMode::kConstant, 0.0f); + y = conv1_(y); + y = y.transpose(1, 2).contiguous(); + y = norm_(y); + y = tensorMish(y); + y = y.transpose(1, 2).contiguous(); + y = nn::functional::pad(y, {kernel_size_ - 1, 0}, aops::PadMode::kConstant, 0.0f); + y = conv2_(y); + y = y.transpose(1, 2).contiguous(); + return y; + } + + private: + int32_t in_channels_ = 0; + int32_t out_channels_ = 0; + int32_t kernel_size_ = 3; + nn::Conv1D conv1_; + nn::LayerNorm norm_; + nn::Conv1D conv2_; +}; + +class DiTMLP final : public nn::Module { + public: + DiTMLP() = default; + DiTMLP(const std::string& name, int32_t in_features, int32_t hidden_features) : nn::Module(name) { + fc1_ = reg("fc1", in_features, hidden_features, true); + gelu_ = reg("act"); + fc2_ = reg("fc2", hidden_features, in_features, true); + } + + Tensor forwardOne(Tensor x) { + auto y = fc1_(x); + y = gelu_(y); + y = fc2_(y); + return y; + } + + private: + nn::Linear fc1_; + nn::GELU gelu_; + nn::Linear fc2_; +}; + +class TimestepEmbedder final : public nn::Module { + public: + TimestepEmbedder() = default; + TimestepEmbedder(const std::string& name, int32_t hidden_size, int32_t 
// NOTE(review): this span is a collapsed unified-diff fragment — each bare " + "
// run below is a diff line prefix, not an arithmetic operator, and template
// arguments were stripped by text mangling (e.g. `reg(...)` without the module
// type, `ptr()` without the element type, `static_cast(j)`, `std::vector{...}`).
// Restore angle-bracket arguments from the upstream patch before compiling.
// Units here: tail of TimestepEmbedder (sinusoidal timestep embedding — t is
// scaled by 1000, cos fills slots [0,half), sin fills [half,2*half), and an odd
// `dim` gets its last slot zeroed), FinalLayer (adaLN: SiLU->Linear produces
// shift/scale, applied after norm_final, then the output Linear), DiTBlock
// (a 9*C adaLN projection split into msa/mlp/conv {shift,scale,gate} triples,
// gating residual attention, causal-conv, and MLP sub-layers — note conv uses
// norm3_ and runs BEFORE the norm2_/MLP branch), and the head of DiTEstimator.
frequency_embedding_size = 256) + : nn::Module(name), hidden_size_(hidden_size), freq_size_(frequency_embedding_size) { + fc1_ = reg("mlp.0", freq_size_, hidden_size_, true); + act_ = reg("mlp.1"); + fc2_ = reg("mlp.2", hidden_size_, hidden_size_, true); + } + + Tensor forwardOne(Tensor t) { + auto emb = timestepEmbedding(t, freq_size_); + emb = fc1_(emb); + emb = act_(emb); + emb = fc2_(emb); + return emb; + } + + private: + Tensor timestepEmbedding(Tensor t, int32_t dim) const { + MLLM_RT_ASSERT_EQ(t.dtype(), kFloat32); + auto tt = t.contiguous(); + const int32_t N = tt.shape()[0]; + const int32_t half = dim / 2; + auto out = Tensor::empty({N, dim}, kFloat32, kCPU).alloc(); + const auto* tp = tt.ptr(); + auto* op = out.ptr(); + for (int32_t i = 0; i < N; ++i) { + const float tv = tp[i] * 1000.0f; + for (int32_t j = 0; j < half; ++j) { + const float freq = std::exp(-std::log(10000.0f) * static_cast(j) / static_cast(half)); + const float a = tv * freq; + op[i * dim + j] = std::cos(a); + op[i * dim + half + j] = std::sin(a); + } + if (dim % 2 == 1) { op[i * dim + dim - 1] = 0.0f; } + } + return out; + } + + private: + int32_t hidden_size_ = 0; + int32_t freq_size_ = 0; + nn::Linear fc1_; + nn::SiLU act_; + nn::Linear fc2_; +}; + +class FinalLayer final : public nn::Module { + public: + FinalLayer() = default; + FinalLayer(const std::string& name, int32_t hidden_size, int32_t out_channels) : nn::Module(name) { + adaln_act_ = reg("adaLN_modulation.0"); + adaln_linear_ = reg("adaLN_modulation.1", hidden_size, 2 * hidden_size, true); + norm_ = reg("norm_final", std::vector{hidden_size}, false, false, 1e-6f); + linear_ = reg("linear", hidden_size, out_channels, true); + } + + Tensor forwardOne(Tensor x, Tensor c) { + auto m = adaln_linear_(adaln_act_(c)); + auto chunks = nn::functional::chunk<2>(m, 2); + auto shift = chunks[0]; + auto scale = chunks[1]; + auto y = norm_(x); + if (scale.rank() == 2) { scale = scale.view({scale.shape()[0], 1, scale.shape()[1]}); } + if 
(shift.rank() == 2) { shift = shift.view({shift.shape()[0], 1, shift.shape()[1]}); } + y = y * (scale + 1.0f) + shift; + y = linear_(y); + return y; + } + + private: + nn::SiLU adaln_act_; + nn::Linear adaln_linear_; + nn::LayerNorm norm_; + nn::Linear linear_; +}; + +class DiTBlock final : public nn::Module { + public: + DiTBlock() = default; + DiTBlock(const std::string& name, int32_t hidden_size, int32_t num_heads, int32_t head_dim, float mlp_ratio) + : nn::Module(name), hidden_size_(hidden_size) { + norm1_ = reg("norm1", std::vector{hidden_size_}, false, false, 1e-6f); + attn_ = reg("attn", hidden_size_, num_heads, head_dim); + norm2_ = reg("norm2", std::vector{hidden_size_}, false, false, 1e-6f); + norm3_ = reg("norm3", std::vector{hidden_size_}, false, false, 1e-6f); + const int32_t mlp_hidden = static_cast(hidden_size_ * mlp_ratio); + mlp_ = reg("mlp", hidden_size_, mlp_hidden); + conv_ = reg("conv", hidden_size_, hidden_size_, 3); + adaln_act_ = reg("adaLN_modulation.0"); + adaln_linear_ = reg("adaLN_modulation.1", hidden_size_, hidden_size_ * 9, true); + } + + Tensor forwardOne(Tensor x, Tensor c) { + debugLog("dit.block: enter x(" + descOf(x) + ") c(" + descOf(c) + ")"); + auto mods = adaln_linear_(adaln_act_(c)); // [B,1,9C] + debugLog("dit.block: adaln_linear done mods(" + descOf(mods) + ")"); + const int32_t C = hidden_size_; + auto shift_msa = mods[{kAll, kAll, {0 * C, 1 * C}}].contiguous(); + auto scale_msa = mods[{kAll, kAll, {1 * C, 2 * C}}].contiguous(); + auto gate_msa = mods[{kAll, kAll, {2 * C, 3 * C}}].contiguous(); + auto shift_mlp = mods[{kAll, kAll, {3 * C, 4 * C}}].contiguous(); + auto scale_mlp = mods[{kAll, kAll, {4 * C, 5 * C}}].contiguous(); + auto gate_mlp = mods[{kAll, kAll, {5 * C, 6 * C}}].contiguous(); + auto shift_conv = mods[{kAll, kAll, {6 * C, 7 * C}}].contiguous(); + auto scale_conv = mods[{kAll, kAll, {7 * C, 8 * C}}].contiguous(); + auto gate_conv = mods[{kAll, kAll, {8 * C, 9 * C}}].contiguous(); + debugLog("dit.block: 
chunk9 done"); + + auto y = norm1_(x); + y = y * (scale_msa + 1.0f) + shift_msa; + debugLog("dit.block: before attn y(" + descOf(y) + ")"); + auto attn_out = attn_.forwardOne(y); + debugLog("dit.block: attn done"); + auto h = x + attn_out * gate_msa; + + auto c_in = norm3_(h); + c_in = c_in * (scale_conv + 1.0f) + shift_conv; + auto conv_out = conv_.forwardOne(c_in); + debugLog("dit.block: conv done"); + h = h + conv_out * gate_conv; + + auto m_in = norm2_(h); + m_in = m_in * (scale_mlp + 1.0f) + shift_mlp; + auto mlp_out = mlp_.forwardOne(m_in); + debugLog("dit.block: mlp done"); + h = h + mlp_out * gate_mlp; + debugLog("dit.block: exit"); + return h; + } + + private: + int32_t hidden_size_ = 0; + nn::LayerNorm norm1_; + DiTAttention attn_; + nn::LayerNorm norm2_; + nn::LayerNorm norm3_; + DiTMLP mlp_; + CausalConv1dBlock conv_; + nn::SiLU adaln_act_; + nn::Linear adaln_linear_; +}; + +class DiTEstimator final : public nn::Module { + public: + DiTEstimator() = default; + DiTEstimator(const std::string& name, const MiniCPMO45FlowConfig& cfg) : nn::Module(name), cfg_(cfg) { + t_embedder_ = reg("t_embedder", cfg.dit_hidden_size, 256); + in_proj_ = reg("in_proj", cfg.dit_in_channels, cfg.dit_hidden_size, true); + blocks_ = reg>("blocks", cfg.dit_depth, cfg.dit_hidden_size, cfg.dit_num_heads, cfg.dit_head_dim, + cfg.dit_mlp_ratio); + final_layer_ = reg("final_layer", cfg.dit_hidden_size, cfg.dit_out_channels); + } + + Tensor forwardOne(Tensor x, Tensor mu, Tensor t, Tensor spks, Tensor cond) { + // x,mu,cond: [B,C,T], spks: [B,C], t:[B] + debugLog("dit.forward: begin"); + auto time_emb = t_embedder_.forwardOne(t).view({t.shape()[0], 1, cfg_.dit_hidden_size}); + debugLog("dit.forward: t_embedder done"); + auto spk_seq = spks.view({spks.shape()[0], spks.shape()[1], 1}).repeat(x.shape()[2], 2); + auto packed = concatChannel({x, mu, spk_seq, cond}); // [B,320,T] + debugLog("dit.forward: concat packed done"); + auto h = packed.transpose(1, 2).contiguous(); // [B,T,320] + h 
// NOTE(review): DiTEstimator continues — in_proj, sequential DiT blocks
// conditioned on the time embedding, final adaLN layer, transpose back to
// channel-major [B,C,T]. `blocks_.list()` iteration order defines layer order.
= in_proj_(h); // [B,T,512] + debugLog("dit.forward: in_proj done"); + int32_t block_idx = 0; + for (auto& block : blocks_.list()) { + h = block.forwardOne(h, time_emb); + if (block_idx == 0) { debugLog("dit.forward: block0 done"); } + ++block_idx; + } + h = final_layer_.forwardOne(h, time_emb); // [B,T,80] + debugLog("dit.forward: final_layer done"); + h = h.transpose(1, 2).contiguous(); // [B,80,T] + debugLog("dit.forward: end"); + return h; + } + + private: + MiniCPMO45FlowConfig cfg_; + TimestepEmbedder t_embedder_; + nn::Linear in_proj_; + nn::ModuleList blocks_; + FinalLayer final_layer_; +}; + +class CausalConditionalCFM final : public nn::Module { + public: + CausalConditionalCFM() = default; + CausalConditionalCFM(const std::string& name, const MiniCPMO45FlowConfig& cfg) : nn::Module(name), cfg_(cfg) { + estimator_ = reg("estimator", cfg_); + } + + Tensor forwardOne(Tensor mu, Tensor spks, Tensor cond, int32_t n_timesteps, float temperature = 1.0f) { + // all in float32 cpu. + debugLog("cfm.forward: start"); + const int32_t B = mu.shape()[0]; + const int32_t C = mu.shape()[1]; + const int32_t T = mu.shape()[2]; + MLLM_RT_ASSERT_EQ(B, 1); + + auto z = randomNormalLike({B, C, T}, temperature); + + std::vector t_span(static_cast(n_timesteps + 1), 0.0f); + constexpr float kPi = 3.14159265358979323846f; + for (int32_t i = 0; i <= n_timesteps; ++i) { + float t = static_cast(i) / static_cast(n_timesteps); + t_span[static_cast(i)] = 1.0f - std::cos(t * 0.5f * kPi); + } + + auto x = z; + auto mu_in = nn::functional::concat({mu, Tensor::zeros(mu.shape(), kFloat32, kCPU)}, 0); + auto spk_in = nn::functional::concat({spks, Tensor::zeros(spks.shape(), kFloat32, kCPU)}, 0); + auto cond_in = nn::functional::concat({cond, Tensor::zeros(cond.shape(), kFloat32, kCPU)}, 0); + + float t = t_span[0]; + float dt = t_span[1] - t_span[0]; + for (int32_t step = 1; step <= n_timesteps; ++step) { + if (step == 1) { debugLog("cfm.forward: first estimator step"); } + auto x_in = 
// NOTE(review): collapsed diff fragment (bare " + " = diff line prefix; stripped
// template arguments, e.g. `reg(...)`, `ptr()`, `MLLM_RT_ASSERT(mel_len2 > 0)`
// preceded by a cut declaration). Units here: the CFM Euler loop tail —
// classifier-free guidance runs conditional+unconditional halves in one doubled
// batch, mixes as main*(1+cfg_rate) - uncond*cfg_rate, and integrates x with a
// cosine-warped, non-uniform dt schedule; CausalMaskedDiffWithXvec.inference —
// L2-normalized speaker embedding -> affine, prompt+generated tokens embedded,
// conformer-encoded, projected to 80 mel channels, with the prompt mel copied
// into the conditioning prefix and the prompt region cut from the decoder
// output; and the head of SnakeActivation (x + sin^2(alpha*x)/alpha, eps-guarded).
nn::functional::concat({x, x}, 0); // [2,C,T] + auto t_in = makeTimeStepsTensor({t, t}); + auto dphi = estimator_.forwardOne(x_in, mu_in, t_in, spk_in, cond_in); // [2,C,T] + auto dphi_split = nn::functional::chunk<2>(dphi, 0); + auto dphi_main = dphi_split[0]; + auto dphi_cfg = dphi_split[1]; + auto dphi_out = dphi_main * (1.0f + cfg_.cfm_inference_cfg_rate) - dphi_cfg * cfg_.cfm_inference_cfg_rate; + x = x + dphi_out * dt; + t += dt; + if (step < n_timesteps) { dt = t_span[static_cast(step + 1)] - t; } + } + debugLog("cfm.forward: finish"); + return x; + } + + private: + MiniCPMO45FlowConfig cfg_; + DiTEstimator estimator_; +}; + +class CausalMaskedDiffWithXvec final : public nn::Module { + public: + CausalMaskedDiffWithXvec() = default; + CausalMaskedDiffWithXvec(const std::string& name, const MiniCPMO45FlowConfig& cfg) : nn::Module(name), cfg_(cfg) { + input_embedding_ = reg("input_embedding", cfg.vocab_size, cfg.input_size); + spk_embed_affine_layer_ = reg("spk_embed_affine_layer", cfg.spk_embed_dim, cfg.output_size, true); + encoder_ = reg("encoder", cfg); + encoder_proj_ = reg("encoder_proj", cfg.input_size, cfg.output_size, true); + decoder_ = reg("decoder", cfg); + } + + Tensor inference(Tensor token, Tensor prompt_token, Tensor prompt_feat, Tensor embedding, + int32_t n_timesteps) { + // token/prompt_token: [1,T], int64 + debugLog("flow.inference: start"); + auto spk = l2NormalizeRow(embedding); + spk = spk_embed_affine_layer_(spk); // [1,80] + debugLog("flow.inference: spk_embed_affine_layer done"); + + auto all_token = concatInt64Seq(prompt_token, token); + auto token_embed = input_embedding_(all_token); + debugLog("flow.inference: input_embedding done"); + + auto h = encoder_.forwardOne(token_embed); + debugLog("flow.inference: encoder done"); + h = encoder_proj_(h); // [1, Tm, 80] + debugLog("flow.inference: encoder_proj done"); + + const int32_t mel_len1 = prompt_feat.shape()[1]; + const int32_t mel_len_total = h.shape()[1]; + const int32_t mel_len2 
= mel_len_total - mel_len1; + MLLM_RT_ASSERT(mel_len2 > 0); + + auto conds = Tensor::zeros(h.shape(), kFloat32, kCPU); + // copy prompt mel to prefix. + auto* cond_ptr = conds.ptr(); + const auto* prm_ptr = prompt_feat.ptr(); + const int32_t C = h.shape()[2]; + for (int32_t t = 0; t < mel_len1; ++t) { + std::memcpy(cond_ptr + static_cast(t) * C, prm_ptr + static_cast(t) * C, sizeof(float) * C); + } + + auto feat = decoder_.forwardOne(h.transpose(1, 2).contiguous(), spk, conds.transpose(1, 2).contiguous(), n_timesteps); + debugLog("flow.inference: decoder done"); + // remove prompt part. + auto out = feat[{kAll, kAll, {mel_len1, mel_len1 + mel_len2}}].contiguous(); + debugLog("flow.inference: finish"); + return out; + } + + private: + MiniCPMO45FlowConfig cfg_; + nn::Embedding input_embedding_; + nn::Linear spk_embed_affine_layer_; + UpsampleConformerEncoderV2 encoder_; + nn::Linear encoder_proj_; + CausalConditionalCFM decoder_; +}; + +class SnakeActivation final : public nn::Module { + public: + SnakeActivation() = default; + SnakeActivation(const std::string& name, int32_t channels) : nn::Module(name) { + alpha_ = reg("alpha", getModuleName() + ".alpha", Tensor::shape_t{channels}); + } + + Tensor forwardOne(Tensor x) { + MLLM_RT_ASSERT_EQ(x.dtype(), kFloat32); + auto out = Tensor::empty(x.shape(), kFloat32, kCPU).alloc(); + auto in = x.contiguous(); + auto* dst = out.ptr(); + const auto* src = in.ptr(); + const auto* alpha = alpha_.weight().contiguous().ptr(); + const int32_t B = in.shape()[0]; + const int32_t C = in.shape()[1]; + const int32_t T = in.shape()[2]; + const int64_t stride_b = static_cast(C) * T; + const int64_t stride_c = T; + constexpr float eps = 1e-9f; + for (int32_t b = 0; b < B; ++b) { + for (int32_t c = 0; c < C; ++c) { + const float a = alpha[c]; + for (int32_t t = 0; t < T; ++t) { + const int64_t idx = static_cast(b) * stride_b + static_cast(c) * stride_c + t; + const float v = src[idx]; + const float s = std::sin(v * a); + dst[idx] = v + (s 
// NOTE(review): collapsed diff fragment (bare " + " = diff line prefix;
// template arguments stripped, e.g. `std::tuple forward`, the distribution
// types, `interpolateByScale` element types). Units here: ConvRNNF0Predictor
// tail — five Conv1D+ELU stages, a per-frame Linear classifier, abs() to keep
// F0 non-negative, with hand-rolled parallel abs/ELU helpers; SineGen2 —
// NSF-style sine source: harmonics k*f0, phase wrapped mod sampling_rate,
// random initial phase per harmonic (except the fundamental, h starts at 1),
// cumulative phase computed at the downsampled rate then linearly upsampled,
// voiced/unvoiced mask from `voiced_threshold_`, plus Gaussian noise mixing;
// and the head of SourceModuleHnNSF2.
classifier_ = reg("classifier", cond_channels, 1, true); + } + + Tensor forwardOne(Tensor x) { + auto y = condnet_0_(x); + y = tensorElu(y); + y = condnet_2_(y); + y = tensorElu(y); + y = condnet_4_(y); + y = tensorElu(y); + y = condnet_6_(y); + y = tensorElu(y); + y = condnet_8_(y); + y = tensorElu(y); + y = y.transpose(1, 2).contiguous(); + y = classifier_(y).squeeze(-1); + y = tensorAbs(y); + return y; + } + + private: + static Tensor tensorAbs(Tensor x) { + auto in = x.contiguous(); + auto out = Tensor::empty(in.shape(), kFloat32, kCPU).alloc(); + const auto* src = in.ptr(); + auto* dst = out.ptr(); + const int64_t n = static_cast(in.numel()); + MLLM_CONDITIONAL_PARALLEL_FOR(n > 4096, 4, i, 0, n, 1, { dst[i] = std::abs(src[i]); }); + return out; + } + + static Tensor tensorElu(Tensor x) { + auto in = x.contiguous(); + auto out = Tensor::empty(in.shape(), kFloat32, kCPU).alloc(); + const auto* src = in.ptr(); + auto* dst = out.ptr(); + const int64_t n = static_cast(in.numel()); + MLLM_CONDITIONAL_PARALLEL_FOR(n > 4096, 4, i, 0, n, 1, { + const float v = src[i]; + dst[i] = (v >= 0.0f) ? 
v : std::expm1(v); + }); + return out; + } + + private: + nn::Conv1D condnet_0_; + nn::Conv1D condnet_2_; + nn::Conv1D condnet_4_; + nn::Conv1D condnet_6_; + nn::Conv1D condnet_8_; + nn::Linear classifier_; +}; + +class SineGen2 { + public: + SineGen2() = default; + SineGen2(int32_t sampling_rate, int32_t upsample_scale, int32_t harmonic_num, float sine_amp, float noise_std, float voiced_threshold) + : sampling_rate_(sampling_rate), + upsample_scale_(upsample_scale), + harmonic_num_(harmonic_num), + sine_amp_(sine_amp), + noise_std_(noise_std), + voiced_threshold_(voiced_threshold) {} + + std::tuple forward(Tensor f0) { + // f0: [B, T, 1] + auto fn = makeHarmonics(f0); + auto sine = f02sine(fn) * sine_amp_; + auto uv = f02uv(f0); + auto inv_uv = uv * -1.0f + 1.0f; + auto noise_amp = uv * noise_std_ + inv_uv * (sine_amp_ / 3.0f); + auto noise = randomLike(noise_amp); + auto out = sine * uv + noise_amp * noise; + return {out, uv, noise_amp * noise}; + } + + private: + Tensor makeHarmonics(Tensor f0) const { + const int32_t B = f0.shape()[0]; + const int32_t T = f0.shape()[1]; + const int32_t H = harmonic_num_ + 1; + auto out = Tensor::empty({B, T, H}, kFloat32, kCPU).alloc(); + const auto* fp = f0.contiguous().ptr(); + auto* op = out.ptr(); + for (int32_t b = 0; b < B; ++b) { + for (int32_t t = 0; t < T; ++t) { + const float v = fp[(static_cast(b) * T + t)]; + for (int32_t h = 0; h < H; ++h) { op[(static_cast(b) * T + t) * H + h] = v * static_cast(h + 1); } + } + } + return out; + } + + Tensor f02uv(Tensor f0) const { + auto out = Tensor::empty(f0.shape(), kFloat32, kCPU).alloc(); + const auto* src = f0.contiguous().ptr(); + auto* dst = out.ptr(); + const int64_t n = static_cast(out.numel()); + for (int64_t i = 0; i < n; ++i) { dst[i] = src[i] > voiced_threshold_ ? 
1.0f : 0.0f; } + return out; + } + + Tensor f02sine(Tensor f0_values) const { + // f0_values: [B, T, H] + auto fv = f0_values.contiguous(); + const int32_t B = fv.shape()[0]; + const int32_t T = fv.shape()[1]; + const int32_t H = fv.shape()[2]; + auto rad = Tensor::empty(fv.shape(), kFloat32, kCPU).alloc(); + const auto* fp = fv.ptr(); + auto* rp = rad.ptr(); + for (int32_t b = 0; b < B; ++b) { + for (int32_t t = 0; t < T; ++t) { + for (int32_t h = 0; h < H; ++h) { + const int64_t idx = (static_cast(b) * T + t) * H + h; + float v = fp[idx] / static_cast(sampling_rate_); + v = v - std::floor(v); + rp[idx] = v; + } + } + } + + std::mt19937 rng(std::random_device{}()); + std::uniform_real_distribution uni(0.0f, 1.0f); + for (int32_t b = 0; b < B; ++b) { + for (int32_t h = 1; h < H; ++h) { + const float phase0 = uni(rng); + rp[(static_cast(b) * T + 0) * H + h] += phase0; + } + } + + // linear interpolate in time by 1 / upsample_scale, then cumulative phase, then upsample back. + auto rad_t = rad.transpose(1, 2).contiguous(); // [B,H,T] + auto down_t = nn::functional::interpolateByScale(rad_t, {1.0f / static_cast(upsample_scale_)}, + aops::InterpolateOpMode::kLinear, false, false); + down_t = down_t.transpose(1, 2).contiguous(); // [B,T',H] + + auto phase = Tensor::empty(down_t.shape(), kFloat32, kCPU).alloc(); + auto* pp = phase.ptr(); + const auto* dp = down_t.ptr(); + const int32_t Td = down_t.shape()[1]; + for (int32_t b = 0; b < B; ++b) { + for (int32_t h = 0; h < H; ++h) { + float acc = 0.0f; + for (int32_t t = 0; t < Td; ++t) { + const int64_t idx = (static_cast(b) * Td + t) * H + h; + acc += dp[idx]; + constexpr float kPi = 3.14159265358979323846f; + pp[idx] = acc * 2.0f * kPi; + } + } + } + + auto phase_t = phase.transpose(1, 2).contiguous(); // [B,H,T'] + phase_t = phase_t * static_cast(upsample_scale_); + auto up_t = nn::functional::interpolateByScale(phase_t, {static_cast(upsample_scale_)}, + aops::InterpolateOpMode::kLinear, false, false); + up_t = 
up_t.transpose(1, 2).contiguous(); // [B,T,H] + auto out = nn::functional::sin(up_t); + return out; + } + + static Tensor randomLike(Tensor x) { + auto out = Tensor::empty(x.shape(), kFloat32, kCPU).alloc(); + auto* dst = out.ptr(); + const int64_t n = static_cast(out.numel()); + static thread_local std::mt19937 rng(std::random_device{}()); + std::normal_distribution dist(0.0f, 1.0f); + for (int64_t i = 0; i < n; ++i) { dst[i] = dist(rng); } + return out; + } + + private: + int32_t sampling_rate_ = 24000; + int32_t upsample_scale_ = 480; + int32_t harmonic_num_ = 8; + float sine_amp_ = 0.1f; + float noise_std_ = 0.003f; + float voiced_threshold_ = 10.0f; +}; + +class SourceModuleHnNSF2 { + public: + SourceModuleHnNSF2() = default; + SourceModuleHnNSF2(int32_t sampling_rate, int32_t upsample_scale, int32_t harmonic_num, float sine_amp, float noise_std, + float voiced_threshold) + : l_sin_gen_(sampling_rate, upsample_scale, harmonic_num, sine_amp, noise_std, voiced_threshold), + sine_amp_(sine_amp) {} + + // This wrapper only supports loading external weight via setLinearWeights(). 
// NOTE(review): collapsed diff fragment (bare " + " = diff line prefix; some
// " + " inside string-concat expressions like `"source_resblocks." +
// std::to_string(i)` are real operators — disambiguate by context when
// restoring line breaks; template arguments stripped throughout). Units here:
// SourceModuleHnNSF2 — hand-rolled [H]->1 linear merge of the sine harmonics
// (weights injected via setLinearWeights, not the module registry), tanh
// squash, plus uv-shaped noise; HiFTGenerator — NSF source via the predicted
// F0 repeated to sample rate, STFT of the source split into real/imag
// conditioning, transposed-conv upsampling stack with Snake resblocks, and
// ISTFT synthesis from exp-clipped magnitude and sin-squashed phase;
// loadFromParameter pulls m_source.l_linear.{weight,bias} manually; and
// MiniCPMO45Token2WavModel — materializes weight-norm params for the HiFT
// scope, then flow (token->mel, prompt prefix) followed by HiFT (mel->wav).
+ void setLinearWeights(Tensor w, Tensor b) { + linear_w_ = w.contiguous(); + linear_b_ = b.contiguous(); + } + + std::tuple forward(Tensor x) { + auto [sine_wavs, uv, _] = l_sin_gen_.forward(x); + auto sine_merge = linearForward(sine_wavs); + sine_merge = tensorTanh(sine_merge); + auto noise = randomLike(uv) * (sine_amp_ / 3.0f); + return {sine_merge, noise, uv}; + } + + private: + Tensor linearForward(Tensor x) { + // x: [B,T,H], weight [1,H] + MLLM_RT_ASSERT(!linear_w_.isNil()); + MLLM_RT_ASSERT(!linear_b_.isNil()); + const int32_t B = x.shape()[0]; + const int32_t T = x.shape()[1]; + const int32_t H = x.shape()[2]; + auto out = Tensor::empty({B, T, 1}, kFloat32, kCPU).alloc(); + const auto* xp = x.contiguous().ptr(); + const auto* wp = linear_w_.contiguous().ptr(); + const float bias = linear_b_.constAt({0}); + auto* op = out.ptr(); + for (int32_t b = 0; b < B; ++b) { + for (int32_t t = 0; t < T; ++t) { + float acc = bias; + for (int32_t h = 0; h < H; ++h) { acc += xp[(static_cast(b) * T + t) * H + h] * wp[h]; } + op[static_cast(b) * T + t] = acc; + } + } + return out; + } + + static Tensor randomLike(Tensor x) { + auto out = Tensor::empty(x.shape(), kFloat32, kCPU).alloc(); + auto* dst = out.ptr(); + const int64_t n = static_cast(out.numel()); + static thread_local std::mt19937 rng(std::random_device{}()); + std::normal_distribution dist(0.0f, 1.0f); + for (int64_t i = 0; i < n; ++i) { dst[i] = dist(rng); } + return out; + } + + static Tensor tensorTanh(Tensor x) { + auto in = x.contiguous(); + auto out = Tensor::empty(in.shape(), kFloat32, kCPU).alloc(); + const auto* src = in.ptr(); + auto* dst = out.ptr(); + const int64_t n = static_cast(in.numel()); + for (int64_t i = 0; i < n; ++i) { dst[i] = std::tanh(src[i]); } + return out; + } + + private: + SineGen2 l_sin_gen_; + float sine_amp_ = 0.1f; + Tensor linear_w_ = Tensor::nil(); + Tensor linear_b_ = Tensor::nil(); +}; + +class HiFTGenerator final : public nn::Module { + public: + HiFTGenerator() = default; 
+ HiFTGenerator(const std::string& name, const MiniCPMO45HiFTConfig& cfg) + : nn::Module(name), cfg_(cfg), + upsample_total_scale_(cfg.upsample_rates[0] * cfg.upsample_rates[1] * cfg.upsample_rates[2] * cfg.istft_hop_len), + m_source_(cfg.sampling_rate, upsample_total_scale_, cfg.nb_harmonics, cfg.nsf_alpha, cfg.nsf_sigma, cfg.nsf_voiced_threshold) { + conv_pre_ = reg("conv_pre", cfg.in_channels, cfg.base_channels, 7, 1, 3, 1, 1, true); + for (int32_t i = 0; i < static_cast(cfg.upsample_rates.size()); ++i) { + const int32_t in_ch = cfg.base_channels / static_cast(std::pow(2, i)); + const int32_t out_ch = cfg.base_channels / static_cast(std::pow(2, i + 1)); + ups_.emplace_back( + reg("ups." + std::to_string(i), in_ch, out_ch, cfg.upsample_kernel_sizes[i], cfg.upsample_rates[i], + (cfg.upsample_kernel_sizes[i] - cfg.upsample_rates[i]) / 2, 0, 1, 1, true)); + } + + // source downs + std::vector downsample_rates = {1, cfg.upsample_rates[2], cfg.upsample_rates[2] * cfg.upsample_rates[1]}; + std::reverse(downsample_rates.begin(), downsample_rates.end()); // [15,3,1] + for (int32_t i = 0; i < static_cast(downsample_rates.size()); ++i) { + const int32_t u = downsample_rates[i]; + const int32_t out_ch = cfg.base_channels / static_cast(std::pow(2, i + 1)); + if (u == 1) { + source_downs_.emplace_back(reg("source_downs." + std::to_string(i), cfg.istft_n_fft + 2, out_ch, 1, 1, 0, 1, 1, true)); + } else { + source_downs_.emplace_back(reg("source_downs." + std::to_string(i), cfg.istft_n_fft + 2, out_ch, u * 2, u, (u / 2), 1, 1, true)); + } + source_resblocks_.emplace_back( + reg("source_resblocks." 
+ std::to_string(i), out_ch, cfg.source_resblock_kernel_sizes[i], cfg.source_resblock_dilation_sizes[i])); + } + + const int32_t num_ups = static_cast(cfg.upsample_rates.size()); + const int32_t num_kernels = static_cast(cfg.resblock_kernel_sizes.size()); + for (int32_t i = 0; i < num_ups; ++i) { + const int32_t ch = cfg.base_channels / static_cast(std::pow(2, i + 1)); + for (int32_t j = 0; j < num_kernels; ++j) { + resblocks_.emplace_back(reg("resblocks." + std::to_string(static_cast(resblocks_.size())), ch, + cfg.resblock_kernel_sizes[j], cfg.resblock_dilation_sizes[j])); + } + } + + conv_post_ = reg("conv_post", cfg.base_channels / static_cast(std::pow(2, cfg.upsample_rates.size())), + cfg.istft_n_fft + 2, 7, 1, 3, 1, 1, true); + f0_predictor_ = reg("f0_predictor"); + stft_ = reg("internal_stft", cfg.istft_n_fft, cfg.istft_hop_len, cfg.istft_n_fft, true, true, "reflect", false); + istft_ = reg("internal_istft", cfg.istft_n_fft, cfg.istft_hop_len, cfg.istft_n_fft, true, true, "reflect"); + hann_window_ = makeHannWindow(cfg.istft_n_fft); + } + + void loadFromParameter(const ParameterFile::ptr_t& param) { + nn::Module::load(param); + // SourceModuleHnNSF2 linear is not a nn::Module member, load manually. 
+ auto w = param->pull(getModuleName() + ".m_source.l_linear.weight"); + auto b = param->pull(getModuleName() + ".m_source.l_linear.bias"); + if (w.dtype() != kFloat32) { w = w.to(kFloat32); } + if (b.dtype() != kFloat32) { b = b.to(kFloat32); } + w = w.contiguous().view({1, cfg_.nb_harmonics + 1}); + b = b.contiguous().view({1}); + m_source_.setLinearWeights(w, b); + } + + Tensor forwardOne(Tensor speech_feat) { + auto f0 = f0_predictor_.forwardOne(speech_feat); // [B,T] + auto f0_ex = f0.view({f0.shape()[0], 1, f0.shape()[1]}); // [B,1,T] + auto s = repeatInterleave1d(f0_ex, upsample_total_scale_).transpose(1, 2); // [B,S,1] + auto [s_merge, _, _uv] = m_source_.forward(s); + auto src = s_merge.transpose(1, 2).contiguous(); // [B,1,S] + auto wav = decode(speech_feat, src); + return wav; + } + + private: + Tensor decode(Tensor x_in, Tensor s) { + auto stft = stft_(s.squeeze(1), hann_window_); // [B,F,T,2] + auto stft_chunks = nn::functional::chunk<2>(stft, 3); + auto s_real = stft_chunks[0].squeeze(-1); + auto s_imag = stft_chunks[1].squeeze(-1); + auto s_stft = nn::functional::concat({s_real, s_imag}, 1); // [B,F*2,T] + + auto x = conv_pre_(x_in); + const int32_t num_ups = static_cast(ups_.size()); + const int32_t num_kernels = static_cast(cfg_.resblock_kernel_sizes.size()); + for (int32_t i = 0; i < num_ups; ++i) { + x = tensorLeakyRelu(x, cfg_.lrelu_slope); + x = ups_[i](x); + if (i == num_ups - 1) { x = nn::functional::pad(x, {1, 0}, aops::PadMode::kReflect); } + + auto si = source_downs_[i](s_stft); + si = source_resblocks_[i].forwardOne(si); + x = x + si; + + Tensor xs = Tensor::nil(); + for (int32_t j = 0; j < num_kernels; ++j) { + auto y = resblocks_[i * num_kernels + j].forwardOne(x); + if (j == 0) { + xs = y; + } else { + xs = xs + y; + } + } + x = xs / static_cast(num_kernels); + } + + x = tensorLeakyRelu(x, 0.01f); + x = conv_post_(x); // [B,18,T] + auto mag = x[{kAll, {0, cfg_.istft_n_fft / 2 + 1}, kAll}].contiguous(); + auto phase = x[{kAll, 
{cfg_.istft_n_fft / 2 + 1, cfg_.istft_n_fft + 2}, kAll}].contiguous(); + mag = nn::functional::exp(mag); + mag = nn::functional::clip(mag, 0.0f, 1e2f); + // Keep parity with python HiFT: phase is first squashed by sin() before ISTFT synthesis. + phase = nn::functional::sin(phase); + auto real = mag * nn::functional::cos(phase); + auto imag = mag * nn::functional::sin(phase); + auto S = real + std::complex{0, 1} * imag; + auto wav = istft_(S, hann_window_); + wav = nn::functional::clip(wav, -cfg_.audio_limit, cfg_.audio_limit); + return wav; + } + + private: + MiniCPMO45HiFTConfig cfg_; + int32_t upsample_total_scale_ = 480; + nn::Conv1D conv_pre_; + std::vector ups_; + std::vector source_downs_; + std::vector source_resblocks_; + std::vector resblocks_; + nn::Conv1D conv_post_; + ConvRNNF0Predictor f0_predictor_; + nn::STFT stft_; + nn::ISTFT istft_; + Tensor hann_window_ = Tensor::nil(); + SourceModuleHnNSF2 m_source_; +}; + +class MiniCPMO45Token2WavModel final : public nn::Module { + public: + MiniCPMO45Token2WavModel() = default; + MiniCPMO45Token2WavModel(const std::string& name, const MiniCPMO45Token2WavConfig& cfg) : nn::Module(name), cfg_(cfg) { + flow_model_ = reg("flow_model", cfg_.flow); + hift_model_ = reg("hift_model", cfg_.hift); + } + + void loadFromParameter(const ParameterFile::ptr_t& param_file) { + // Materialize weight_norm reparameterized conv weights in-place. 
+ (void)materializeWeightNormParameters(param_file, getModuleName() + ".hift_model."); + flow_model_.load(param_file); + hift_model_.loadFromParameter(param_file); + } + + Tensor infer(const std::vector& token_ids, const MiniCPMO45Token2WavPromptCache& prompt_cache, int32_t n_timesteps) { + if (token_ids.empty()) { MLLM_ERROR_EXIT(ExitCode::kCoreError, "MiniCPM-o-4_5 token2wav got empty token ids."); } + if (prompt_cache.prompt_tokens.empty() || prompt_cache.prompt_mels.isNil() || prompt_cache.spk_emb.isNil()) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "MiniCPM-o-4_5 token2wav prompt cache is incomplete."); + } + + auto token = Tensor::empty({1, static_cast(token_ids.size())}, kInt64, kCPU).alloc(); + for (int32_t i = 0; i < static_cast(token_ids.size()); ++i) { token.at({0, i}) = token_ids[static_cast(i)]; } + + auto prompt_token = Tensor::empty({1, static_cast(prompt_cache.prompt_tokens.size())}, kInt64, kCPU).alloc(); + for (int32_t i = 0; i < static_cast(prompt_cache.prompt_tokens.size()); ++i) { + prompt_token.at({0, i}) = static_cast(prompt_cache.prompt_tokens[static_cast(i)]); + } + + auto prompt_mels = Tensor(prompt_cache.prompt_mels); + auto spk = Tensor(prompt_cache.spk_emb); + if (prompt_mels.dtype() != kFloat32) { prompt_mels = prompt_mels.to(kFloat32); } + if (spk.dtype() != kFloat32) { spk = spk.to(kFloat32); } + + auto mel = flow_model_.inference(token, prompt_token, prompt_mels, spk, n_timesteps); + auto wav = hift_model_.forwardOne(mel); + return wav; + } + + private: + MiniCPMO45Token2WavConfig cfg_; + CausalMaskedDiffWithXvec flow_model_; + HiFTGenerator hift_model_; +}; + +} // namespace token2wav + +using token2wav::MiniCPMO45Token2WavModel; + +} // namespace mllm::models::minicpm_o45 diff --git a/mllm/models/minicpm_o45/token2wav_prompt_cache.hpp b/mllm/models/minicpm_o45/token2wav_prompt_cache.hpp new file mode 100644 index 000000000..fecbe1f93 --- /dev/null +++ b/mllm/models/minicpm_o45/token2wav_prompt_cache.hpp @@ -0,0 +1,70 @@ +// 
// NOTE(review): collapsed diff fragment covering two new headers. The blank
// `#include` lines at the top of each file lost their angle-bracketed header
// names to the same mangling that stripped template arguments (std::array,
// reinterpret_cast, pull, etc.) — restore from the upstream patch.
// token2wav_prompt_cache.hpp: binary prompt-cache loader — 8-byte magic
// "M45PC1\0\0", u32 version (must be 1), four i32 shape fields validated > 0,
// then int32 tokens, [mel_frames, mel_dim] float mels, [spk_dim] float speaker
// embedding; every read is checked and exits via MLLM_ERROR_EXIT on failure.
// token2wav_weight_norm.hpp: for each "*.parametrizations.weight.original0"
// key in scope, recomputes w = g * v / ||v||_row (per-output-row L2 norm,
// eps-clamped) and pushes it as "*.weight" unless already present — i.e. it
// folds PyTorch weight_norm parametrizations into plain weights at load time.
Copyright (c) MLLM Team. +// Licensed under the MIT License. +#pragma once + +#include +#include +#include +#include +#include + +#include "mllm/core/Tensor.hpp" +#include "mllm/utils/Log.hpp" + +namespace mllm::models::minicpm_o45 { + +struct MiniCPMO45Token2WavPromptCache { + std::vector prompt_tokens; + Tensor prompt_mels = Tensor::nil(); // [1, Tm, 80], float32 + Tensor spk_emb = Tensor::nil(); // [1, 192], float32 +}; + +inline MiniCPMO45Token2WavPromptCache loadMiniCPMO45Token2WavPromptCache(const std::string& file_path) { + MiniCPMO45Token2WavPromptCache out; + + std::ifstream in(file_path, std::ios::binary); + if (!in.is_open()) { + MLLM_ERROR_EXIT(ExitCode::kIOError, "Failed to open MiniCPM-o-4_5 prompt cache: {}", file_path); + } + + std::array magic{}; + in.read(magic.data(), static_cast(magic.size())); + if (!in.good()) { MLLM_ERROR_EXIT(ExitCode::kIOError, "Prompt cache header read failed: {}", file_path); } + const std::array expected = {'M', '4', '5', 'P', 'C', '1', '\0', '\0'}; + if (magic != expected) { MLLM_ERROR_EXIT(ExitCode::kIOError, "Invalid prompt cache magic: {}", file_path); } + + uint32_t version = 0; + in.read(reinterpret_cast(&version), sizeof(version)); + if (version != 1) { MLLM_ERROR_EXIT(ExitCode::kIOError, "Unsupported prompt cache version {}: {}", version, file_path); } + + int32_t token_len = 0; + int32_t mel_frames = 0; + int32_t mel_dim = 0; + int32_t spk_dim = 0; + in.read(reinterpret_cast(&token_len), sizeof(token_len)); + in.read(reinterpret_cast(&mel_frames), sizeof(mel_frames)); + in.read(reinterpret_cast(&mel_dim), sizeof(mel_dim)); + in.read(reinterpret_cast(&spk_dim), sizeof(spk_dim)); + if (!in.good()) { MLLM_ERROR_EXIT(ExitCode::kIOError, "Prompt cache meta read failed: {}", file_path); } + if (token_len <= 0 || mel_frames <= 0 || mel_dim <= 0 || spk_dim <= 0) { + MLLM_ERROR_EXIT(ExitCode::kIOError, "Prompt cache has invalid shape metadata: {}", file_path); + } + + out.prompt_tokens.resize(static_cast(token_len)); + 
in.read(reinterpret_cast(out.prompt_tokens.data()), sizeof(int32_t) * static_cast(token_len)); + if (!in.good()) { MLLM_ERROR_EXIT(ExitCode::kIOError, "Prompt token section read failed: {}", file_path); } + + out.prompt_mels = Tensor::empty({1, mel_frames, mel_dim}, kFloat32, kCPU).alloc(); + in.read(reinterpret_cast(out.prompt_mels.ptr()), + sizeof(float) * static_cast(mel_frames) * static_cast(mel_dim)); + if (!in.good()) { MLLM_ERROR_EXIT(ExitCode::kIOError, "Prompt mel section read failed: {}", file_path); } + + out.spk_emb = Tensor::empty({1, spk_dim}, kFloat32, kCPU).alloc(); + in.read(reinterpret_cast(out.spk_emb.ptr()), sizeof(float) * static_cast(spk_dim)); + if (!in.good()) { MLLM_ERROR_EXIT(ExitCode::kIOError, "Prompt speaker embedding section read failed: {}", file_path); } + + return out; +} + +} // namespace mllm::models::minicpm_o45 + diff --git a/mllm/models/minicpm_o45/token2wav_weight_norm.hpp b/mllm/models/minicpm_o45/token2wav_weight_norm.hpp new file mode 100644 index 000000000..cfc7fdd9a --- /dev/null +++ b/mllm/models/minicpm_o45/token2wav_weight_norm.hpp @@ -0,0 +1,80 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +#pragma once + +#include +#include +#include +#include +#include + +#include "mllm/core/ParameterFile.hpp" +#include "mllm/core/Tensor.hpp" +#include "mllm/utils/Log.hpp" + +namespace mllm::models::minicpm_o45 { + +inline bool _endsWith(const std::string& s, const std::string& suffix) { + return s.size() >= suffix.size() && s.compare(s.size() - suffix.size(), suffix.size(), suffix) == 0; +} + +inline Tensor _materializeWeightNorm(Tensor g_in, Tensor v_in) { + auto g = g_in.dtype() == kFloat32 ? g_in.contiguous() : g_in.to(kFloat32).contiguous(); + auto v = v_in.dtype() == kFloat32 ? 
v_in.contiguous() : v_in.to(kFloat32).contiguous(); + + const int64_t out_dim = static_cast(g.numel()); + const int64_t total = static_cast(v.numel()); + if (out_dim <= 0 || total <= 0 || (total % out_dim) != 0) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, + "Invalid weight-norm tensor shape: g.numel()={}, v.numel()={}", + out_dim, total); + } + + const int64_t row = total / out_dim; + auto w = Tensor::empty({static_cast(total)}, kFloat32, kCPU).alloc(); + auto* g_ptr = g.ptr(); + auto* v_ptr = v.ptr(); + auto* w_ptr = w.ptr(); + + constexpr float kEps = 1e-12f; + for (int64_t i = 0; i < out_dim; ++i) { + const int64_t base = i * row; + float norm = 0.0f; + for (int64_t j = 0; j < row; ++j) { + const float val = v_ptr[base + j]; + norm += val * val; + } + norm = std::sqrt(std::max(norm, kEps)); + const float scale = g_ptr[i] / norm; + for (int64_t j = 0; j < row; ++j) { w_ptr[base + j] = v_ptr[base + j] * scale; } + } + return w; +} + +inline int32_t materializeWeightNormParameters(const ParameterFile::ptr_t& param_file, const std::string& scope_prefix) { + std::vector keys; + keys.reserve(param_file->dict().size()); + for (const auto& kv : param_file->dict()) { keys.push_back(kv.first); } + + const std::string marker = ".parametrizations.weight.original0"; + int32_t count = 0; + for (const auto& key : keys) { + if (!_endsWith(key, marker)) { continue; } + if (!scope_prefix.empty() && key.rfind(scope_prefix, 0) != 0) { continue; } + + const auto prefix = key.substr(0, key.size() - marker.size()); + const auto key_g = prefix + ".parametrizations.weight.original0"; + const auto key_v = prefix + ".parametrizations.weight.original1"; + const auto key_w = prefix + ".weight"; + + if (param_file->has(key_w)) { continue; } + if (!param_file->has(key_g) || !param_file->has(key_v)) { continue; } + + auto weight = _materializeWeightNorm(param_file->pull(key_g), param_file->pull(key_v)); + param_file->push(key_w, weight); + ++count; + } + return count; +} + +} // namespace 
// NOTE(review): collapsed diff fragment — start of
// tokenization_minicpm_o45.hpp. The empty `#include` lines lost their header
// names and the special-token list below lost every bracketed literal (entries
// shown as bare L"" once held names like image/slice/unit markers) to the same
// angle-bracket stripping; recover both from the upstream patch. Visible units:
// miniCPMO45TokenizerMatchPattern — hand-written GPT-2/Qwen-style pretokenizer
// (contractions, letter runs with optional single-char prefix, single digits,
// punctuation runs pulling trailing CR/LF, whitespace+newline handling; note
// the final two `iswspace` branches — the second is only reachable when the
// first reset `pos`); miniCPMO45Regex — drives the matcher over a wide string,
// skipping unmatched positions; MiniCPMO45Message — chat-template builder whose
// exact literal strings ("<|im_start|>", "<|spk_bos|>"... ) are runtime data
// and must not be altered; and the head of MiniCPMO45Tokenizer (its
// constructor continues beyond this fragment).
mllm::models::minicpm_o45 diff --git a/mllm/models/minicpm_o45/tokenization_minicpm_o45.hpp b/mllm/models/minicpm_o45/tokenization_minicpm_o45.hpp new file mode 100644 index 000000000..0a0e00ca0 --- /dev/null +++ b/mllm/models/minicpm_o45/tokenization_minicpm_o45.hpp @@ -0,0 +1,547 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "mllm/core/DataTypes.hpp" +#include "mllm/models/ARGeneration.hpp" +#include "mllm/models/minicpm_o2_6/audio_preprocessor_minicpmo.hpp" +#include "mllm/models/minicpm_o2_6/image_preprocessor_minicpmo.hpp" +#include "mllm/preprocessor/audio/Audio.hpp" +#include "mllm/preprocessor/tokenizers/AutoTokenizer.hpp" +#include "mllm/preprocessor/tokenizers/BPE.hpp" +#include "mllm/preprocessor/tokenizers/Unicode.hpp" + +namespace mllm::models::minicpm_o45 { + +// Same tokenizer splitting rules as Qwen2/Qwen3 family. +inline bool miniCPMO45TokenizerMatchPattern(const std::wstring& str, size_t& pos, std::wstring& matched) { + if (pos >= str.size()) return false; + + static const std::wstring contractions[] = {L"'s", L"'t", L"'re", L"'ve", L"'m", L"'ll", L"'d"}; + for (const auto& contraction : contractions) { + if (pos + contraction.size() <= str.size() && str.compare(pos, contraction.size(), contraction) == 0) { + matched = contraction; + pos += contraction.size(); + return true; + } + } + + { + size_t original_pos = pos; + bool has_prefix = false; + matched.clear(); + + if (!preprocessor::isLetter(str[pos]) && !preprocessor::isDigit(str[pos]) && str[pos] != L'\r' && str[pos] != L'\n') { + matched += str[pos]; + ++pos; + has_prefix = true; + } + + if (pos < str.size() && preprocessor::isLetter(str[pos])) { + do { + matched += str[pos]; + ++pos; + } while (pos < str.size() && preprocessor::isLetter(str[pos])); + return true; + } + + if (has_prefix) { + pos = original_pos; + matched.clear(); + } + } + + if 
(preprocessor::isDigit(str[pos])) { + matched = str.substr(pos, 1); + ++pos; + return true; + } + + { + size_t original_pos = pos; + matched.clear(); + size_t start = pos; + + if (str[pos] == L' ') { ++pos; } + + if (pos < str.size() && !std::iswspace(str[pos]) && !preprocessor::isLetter(str[pos]) && !preprocessor::isDigit(str[pos])) { + do { + ++pos; + } while (pos < str.size() && !std::iswspace(str[pos]) && !preprocessor::isLetter(str[pos]) + && !preprocessor::isDigit(str[pos])); + + matched = str.substr(start, pos - start); + while (pos < str.size() && (str[pos] == L'\r' || str[pos] == L'\n')) { + matched += str[pos]; + ++pos; + } + return true; + } + + pos = original_pos; + } + + { + size_t start = pos; + while (pos < str.size() && std::iswspace(str[pos])) ++pos; + if (pos < str.size() && (str[pos] == L'\r' || str[pos] == L'\n')) { + while (pos < str.size() && (str[pos] == L'\r' || str[pos] == L'\n')) ++pos; + matched = str.substr(start, pos - start); + return true; + } + pos = start; + } + + if (std::iswspace(str[pos])) { + size_t start = pos; + while (pos < str.size() && std::iswspace(str[pos])) ++pos; + if (pos >= str.size() || std::iswspace(str[pos])) { + matched = str.substr(start, pos - start); + return true; + } + pos = start; + } + + if (std::iswspace(str[pos])) { + size_t start = pos; + while (pos < str.size() && std::iswspace(str[pos])) ++pos; + matched = str.substr(start, pos - start); + return true; + } + + return false; +} + +inline bool miniCPMO45Regex(const std::string& str, std::vector& splitted) { + auto w_string = preprocessor::utf8string2WideString(str); + size_t pos = 0; + while (pos < w_string.size()) { + std::wstring matched; + if (miniCPMO45TokenizerMatchPattern(w_string, pos, matched)) { + splitted.push_back(matched); + } else { + ++pos; + } + } + return true; +} + +struct MiniCPMO45Message { + std::string prompt; + std::string img_file_path; + std::string audio_file_path; + std::string system_prompt = + "You are a helpful assistant. 
You can accept video, audio and text input and output voice and text."; + + [[nodiscard]] std::string buildChatMessage(bool generate_audio = false) const { + std::string result; + if (!system_prompt.empty()) { result += "<|im_start|>system\n" + system_prompt + "<|im_end|>\n"; } + + result += "<|im_start|>user\n"; + if (!img_file_path.empty()) { result += "(./)"; } + if (!audio_file_path.empty()) { result += "()"; } + + if (!prompt.empty()) { + if (!img_file_path.empty() || !audio_file_path.empty()) { result += "\n"; } + result += prompt; + } + + result += "<|im_end|>\n"; + result += "<|im_start|>assistant\n"; + + if (generate_audio) { result += "<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>"; } + return result; + } +}; + +class MiniCPMO45Tokenizer final : public mllm::preprocessor::AutoTokenizer { + public: + explicit MiniCPMO45Tokenizer(const std::string& tokenizer_path, int32_t patch_size = 14, int32_t audio_pool_step = 5) + : image_preprocessor_(patch_size), + audio_preprocessor_(16000, 80, 160), + audio_pool_step_(audio_pool_step) { + preprocessor::initLocal(); + preprocessor::makeBytes2UnicodeMap(bytes_2_unicode_dict_); + for (auto& kv : bytes_2_unicode_dict_) { bytes_2_unicode_dict_inverse_.insert({kv.second, kv.first}); } + + bpe_.initFromSentencePieceJson(tokenizer_path); + + const std::vector special_tokens = { + L"", + L"<|endoftext|>", + L"<|im_start|>", + L"<|im_end|>", + L"<|object_ref_start|>", + L"<|object_ref_end|>", + L"<|box_start|>", + L"<|box_end|>", + L"<|quad_start|>", + L"<|quad_end|>", + L"<|vision_start|>", + L"<|vision_end|>", + L"<|vision_pad|>", + L"<|image_pad|>", + L"<|video_pad|>", + L"", + L"", + L"", + L"", + L"", + L"", + L"", + L"", + L"", + L"", + L"", + L"", + L"", + L"", + L"", + L"", + L"", + L"", + L"", + L"", + L"", + L"", + L"", + L"", + L"", + L"", + L"<|audio_start|>", + L"<|audio|>", + L"<|audio_end|>", + L"<|spk_bos|>", + L"<|spk|>", + L"<|spk_eos|>", + L"<|tts_bos|>", + L"<|tts_eos|>", + L"<|listen|>", + L"<|speak|>", + 
L"<|interrupt|>", + L"<|vad_start|>", + L"<|vad_end|>", + L"<|chunk_eos|>", + L"<|chunk_bos|>", + L"<|chunk_tts_bos|>", + L"<|chunk_tts_eos|>", + }; + + for (const auto& token : special_tokens) { addSpecialToken(token); } + loadSpecialTokensFromTokenizerJson(tokenizer_path); + } + + std::vector _tokenize(const std::string& str) override { + std::vector ret; + std::vector splitted; + ::mllm::models::minicpm_o45::miniCPMO45Regex(str, splitted); + for (const auto& s : splitted) { + auto utf_8_str = preprocessor::wideString2Utf8String(s); + std::wstring mapped_str; + for (unsigned char c : utf_8_str) { mapped_str.push_back(bytes_2_unicode_dict_[c]); } + + auto bpe_tokens = bpe_._bpe(mapped_str); + for (const auto& bpe_token : bpe_tokens) { ret.push_back(bpe_token); } + } + return ret; + } + + std::vector tokenize(const std::string& str) override { + auto tokens = special_tokens_trie_.split(preprocessor::utf8string2WideString(str)); + std::vector all_tokens; + for (const auto& token : tokens) { + if (special_tokens_trie_.isSpecialToken(token)) { + all_tokens.emplace_back(token); + continue; + } + auto tmp_tokens = _tokenize(preprocessor::wideString2Utf8String(token)); + all_tokens.insert(all_tokens.end(), tmp_tokens.begin(), tmp_tokens.end()); + } + return all_tokens; + } + + std::wstring _detokenize(int64_t pos_idx) override { return bpe_._lookup_inverse_vocab(pos_idx); } + + std::wstring detokenize(int64_t pos_idx) override { + auto str = _detokenize(pos_idx); + std::string utf_8_str; + for (wchar_t c : str) { utf_8_str.push_back(static_cast(bytes_2_unicode_dict_inverse_[c])); } + return {mllm::preprocessor::utf8string2WideString(utf_8_str)}; + } + + Tensor convert2Ids(const std::vector& strs) override { + std::vector ids; + ids.reserve(strs.size()); + for (const auto& str : strs) { ids.emplace_back(bpe_._lookup_vocab(str)); } + + Tensor ret = Tensor::empty({1, static_cast(ids.size())}, kInt64, kCPU) + .setMemType(kExtraInput) + .setName("minicpmo45-tokenizer-i0") + 
.alloc(); + + auto ptr = ret.ptr(); + for (size_t i = 0; i < ids.size(); ++i) { ptr[i] = ids[i]; } + return ret; + } + + int64_t lookupTokenId(const std::wstring& token) { return bpe_._lookup_vocab(token); } + + ARGenerationOutputPast convertMessage(const MiniCPMO45Message& message, bool generate_audio_prompt = false) { + bool has_image = !message.img_file_path.empty(); + bool has_audio = !message.audio_file_path.empty(); + + auto applied_string = message.buildChatMessage(generate_audio_prompt); + + std::vector img_tensors; + std::vector> original_sizes; + std::vector> tgt_sizes; + std::vector grid; + + Tensor audio_features = Tensor::nil(); + int32_t audio_length = 0; + + if (has_image) { + auto [tensors, orig_size, target_sizes, img_grid] = image_preprocessor_.process(message.img_file_path); + img_tensors = std::move(tensors); + original_sizes = std::move(orig_size); + tgt_sizes = std::move(target_sizes); + grid = std::move(img_grid); + } + + if (has_audio) { + auto audio_data = mllm::audio::readWAV(message.audio_file_path, 16000); + audio_length = static_cast(audio_data.size()); + audio_features = audio_preprocessor_.processAudioData(audio_data.data(), audio_length); + } + + if (has_image) { + std::regex img_pattern(R"(\(\./\))"); + std::vector image_tags; + std::sregex_iterator iter(applied_string.begin(), applied_string.end(), img_pattern); + std::sregex_iterator end; + + for (; iter != end; ++iter) { image_tags.push_back(iter->str()); } + + std::vector text_chunks; + int32_t pos = 0; + for (const auto& tag : image_tags) { + auto found = applied_string.find(tag, pos); + if (found != std::string::npos) { + text_chunks.push_back(applied_string.substr(pos, found - pos)); + pos = static_cast(found + tag.size()); + } + } + text_chunks.push_back(applied_string.substr(pos)); + + std::string final_text; + for (size_t i = 0; i < image_tags.size(); ++i) { + final_text += text_chunks[i]; + final_text += image_preprocessor_.get_slice_image_placeholder(original_sizes[i], 
grid, static_cast(i)); + } + final_text += text_chunks.back(); + applied_string = final_text; + } + + if (has_audio) { + auto audio_placeholder = getAudioPlaceholder(audio_length, false); + size_t audio_placeholder_pos = applied_string.find("()"); + if (audio_placeholder_pos != std::string::npos) { + applied_string.replace(audio_placeholder_pos, std::string("()").size(), audio_placeholder); + } + } + + auto sequence_str = tokenize(applied_string); + std::vector input_ids_vec; + input_ids_vec.reserve(sequence_str.size()); + for (const auto& str : sequence_str) { input_ids_vec.emplace_back(bpe_._lookup_vocab(str)); } + + std::vector> image_bounds; + std::vector> audio_bounds; + + if (has_image) { + auto [_, bounds] = image_preprocessor_.calc_bounds(input_ids_vec, bpe_); + image_bounds = std::move(bounds); + } + + if (has_audio) { + int64_t audio_start_id = bpe_._lookup_vocab(L"<|audio_start|>"); + int64_t audio_end_id = bpe_._lookup_vocab(L"<|audio_end|>"); + audio_bounds = audio_preprocessor_.calcAudioBounds(input_ids_vec, audio_start_id, audio_end_id); + } + + return convertToTensors(input_ids_vec, img_tensors, tgt_sizes, image_bounds, audio_features, audio_bounds); + } + + private: + void addSpecialToken(const std::wstring& token) { + if (!token.empty()) { special_tokens_trie_.add(token); } + } + + void loadSpecialTokensFromTokenizerJson(const std::string& tokenizer_path) { + std::ifstream in(tokenizer_path); + if (!in.is_open()) { return; } + + nlohmann::json json_data; + try { + json_data = nlohmann::json::parse(in); + } catch (...) 
{ + return; + } + + if (!json_data.contains("added_tokens") || !json_data["added_tokens"].is_array()) { return; } + for (const auto& token_info : json_data["added_tokens"]) { + if (!token_info.contains("content")) { continue; } + addSpecialToken(preprocessor::utf8string2WideString(token_info["content"].get())); + } + } + + [[nodiscard]] std::string getAudioPlaceholder(int32_t audio_length, bool chunk_input, float chunk_length = 1.0f) const { + int32_t capped_audio_length = std::min(audio_length, max_audio_samples_); + int32_t feature_lens = static_cast(std::ceil(static_cast(capped_audio_length) / hop_length_)); + feature_lens = (feature_lens - 1) / 2 + 1; + + auto output_lens = (feature_lens - audio_pool_step_) / audio_pool_step_ + 1; + output_lens = std::max(output_lens, 0); + + if (!chunk_input) { + std::string audio_placeholder = "<|audio_start|>"; + for (int32_t i = 0; i < output_lens; ++i) { audio_placeholder += ""; } + audio_placeholder += "<|audio_end|>"; + return audio_placeholder; + } + + auto fbank_feat_in_chunk = static_cast(chunk_length * 100); + auto cnn_feat_in_chunk = (fbank_feat_in_chunk - 1) / 2 + 1; + auto audio_embeds_in_chunk = (cnn_feat_in_chunk - audio_pool_step_) / audio_pool_step_ + 1; + audio_embeds_in_chunk = std::max(audio_embeds_in_chunk, 1); + + auto num_audio_chunks = (output_lens + audio_embeds_in_chunk - 1) / audio_embeds_in_chunk; + + std::string placeholders; + int32_t total_unk_len = 0; + for (int32_t i = 0; i < num_audio_chunks; ++i) { + auto unk_len = std::min(audio_embeds_in_chunk, output_lens - total_unk_len); + placeholders += "<|audio_start|>"; + for (int32_t j = 0; j < unk_len; ++j) { placeholders += ""; } + placeholders += "<|audio_end|>"; + total_unk_len += unk_len; + } + return placeholders; + } + + ARGenerationOutputPast convertToTensors(const std::vector& input_ids_vec, std::vector& img_tensors, + const std::vector>& tgt_sizes, + const std::vector>& image_bounds, const Tensor& audio_features, + const std::vector>& 
audio_bounds) { + ARGenerationOutputPast result; + + if (!input_ids_vec.empty()) { + auto input_ids_tensor = Tensor::empty({1, static_cast(input_ids_vec.size())}, kInt64, kCPU) + .setMemType(kExtraInput) + .setName("input_ids") + .alloc(); + auto* input_ids_ptr = input_ids_tensor.ptr(); + for (size_t i = 0; i < input_ids_vec.size(); ++i) { input_ids_ptr[i] = input_ids_vec[i]; } + result["input_ids"] = input_ids_tensor; + } + + if (!img_tensors.empty()) { + int32_t channels = img_tensors[0].shape()[0]; + int32_t patch_size = img_tensors[0].shape()[1]; + int32_t hw_patch_size = img_tensors[0].shape()[2]; + for (const auto& img_tensor : img_tensors) { + if (img_tensor.shape()[2] > hw_patch_size) { hw_patch_size = img_tensor.shape()[2]; } + } + + auto pixel_values = Tensor::empty({static_cast(img_tensors.size()), channels, patch_size, hw_patch_size}, kFloat32, + kCPU) + .setMemType(kExtraInput) + .setName("pixel_values") + .alloc(); + auto* pixel_values_ptr = pixel_values.ptr(); + std::memset(pixel_values_ptr, 0, static_cast(img_tensors.size()) * channels * patch_size * hw_patch_size * sizeof(float_t)); + + for (int32_t b = 0; b < static_cast(img_tensors.size()); ++b) { + int32_t src_hw = img_tensors[b].shape()[2]; + const auto* src_ptr = img_tensors[b].ptr(); + + for (int32_t c = 0; c < channels; ++c) { + for (int32_t p = 0; p < patch_size; ++p) { + int32_t src_offset = c * patch_size * src_hw + p * src_hw; + int32_t dst_offset = b * channels * patch_size * hw_patch_size + c * patch_size * hw_patch_size + p * hw_patch_size; + std::memcpy(pixel_values_ptr + dst_offset, src_ptr + src_offset, src_hw * sizeof(float_t)); + } + } + } + + result["pixel_values"] = pixel_values; + } + + if (!tgt_sizes.empty()) { + auto tgt_sizes_tensor = Tensor::empty({static_cast(tgt_sizes.size()), 2}, kInt32, kCPU) + .setMemType(kExtraInput) + .setName("tgt_sizes") + .alloc(); + auto* tgt_sizes_ptr = tgt_sizes_tensor.ptr(); + for (size_t i = 0; i < tgt_sizes.size(); ++i) { + tgt_sizes_ptr[i 
* 2] = tgt_sizes[i].first; + tgt_sizes_ptr[i * 2 + 1] = tgt_sizes[i].second; + } + result["tgt_sizes"] = tgt_sizes_tensor; + } + + if (!image_bounds.empty()) { + auto image_bounds_tensor = Tensor::empty({static_cast(image_bounds.size()), 2}, kInt32, kCPU) + .setMemType(kExtraInput) + .setName("image_bounds") + .alloc(); + auto* image_bounds_ptr = image_bounds_tensor.ptr(); + for (size_t i = 0; i < image_bounds.size(); ++i) { + image_bounds_ptr[i * 2] = image_bounds[i].first; + image_bounds_ptr[i * 2 + 1] = image_bounds[i].second; + } + result["image_bounds"] = image_bounds_tensor; + } + + if (!audio_features.isNil()) { result["audio_features"] = audio_features; } + + if (!audio_bounds.empty()) { + auto audio_bounds_tensor = Tensor::empty({static_cast(audio_bounds.size()), 2}, kInt32, kCPU) + .setMemType(kExtraInput) + .setName("audio_bounds") + .alloc(); + auto* audio_bounds_ptr = audio_bounds_tensor.ptr(); + for (size_t i = 0; i < audio_bounds.size(); ++i) { + audio_bounds_ptr[i * 2] = audio_bounds[i].first; + audio_bounds_ptr[i * 2 + 1] = audio_bounds[i].second; + } + result["audio_bounds"] = audio_bounds_tensor; + } + + return result; + } + + private: + minicpmo::MiniCPMOImageProcessor image_preprocessor_; + minicpmo::MiniCPMOAudioProcessor audio_preprocessor_; + int32_t audio_pool_step_ = 5; + int32_t hop_length_ = 160; + int32_t max_audio_samples_ = 30 * 16000; + + preprocessor::BPE bpe_; + std::unordered_map bytes_2_unicode_dict_; + std::unordered_map bytes_2_unicode_dict_inverse_; +}; + +} // namespace mllm::models::minicpm_o45 diff --git a/mllm/models/qwen2_5omni/audio_preprocessor_qwen2_5omni.hpp b/mllm/models/qwen2_5omni/audio_preprocessor_qwen2_5omni.hpp new file mode 100644 index 000000000..392bfc17b --- /dev/null +++ b/mllm/models/qwen2_5omni/audio_preprocessor_qwen2_5omni.hpp @@ -0,0 +1,240 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "mllm/core/Tensor.hpp" +#include "mllm/nn/Module.hpp" +#include "mllm/nn/Nn.hpp" +#include "mllm/nn/Functional.hpp" +#include "mllm/preprocessor/audio/Audio.hpp" + +namespace mllm::models::qwen2_5omni { + +inline float hertz_to_mel_slaney(float freq) { + constexpr float kMinLogHertz = 1000.0f; + constexpr float kMinLogMel = 15.0f; + const float logstep = 27.0f / std::log(6.4f); + + if (freq < kMinLogHertz) { + return 3.0f * freq / 200.0f; + } + return kMinLogMel + std::log(freq / kMinLogHertz) * logstep; +} + +inline float mel_to_hertz_slaney(float mel) { + constexpr float kMinLogHertz = 1000.0f; + constexpr float kMinLogMel = 15.0f; + const float logstep = std::log(6.4f) / 27.0f; + + if (mel < kMinLogMel) { + return 200.0f * mel / 3.0f; + } + return kMinLogHertz * std::exp(logstep * (mel - kMinLogMel)); +} + +inline Tensor create_hann_window(int32_t window_length, bool periodic = true) { + int32_t length = periodic ? window_length + 1 : window_length; + auto window = Tensor::empty({1, window_length}, kFloat32, kCPU).alloc(); + float* window_ptr = window.ptr(); + + for (int32_t i = 0; i < window_length; ++i) { + float n = static_cast(i); + float denominator = periodic ? 
static_cast(length) : static_cast(length - 1); + window_ptr[i] = 0.5f - 0.5f * std::cos(2.0f * M_PI * n / denominator); + } + + return window; +} + +inline Tensor create_mel_filterbank(int32_t num_frequency_bins, int32_t num_mel_filters, float min_frequency, + float max_frequency, int32_t sampling_rate) { + std::vector fft_freqs(num_frequency_bins); + for (int32_t i = 0; i < num_frequency_bins; ++i) { + fft_freqs[i] = static_cast(i) * (sampling_rate / 2.0f) / (num_frequency_bins - 1); + } + + float mel_min = hertz_to_mel_slaney(min_frequency); + float mel_max = hertz_to_mel_slaney(max_frequency); + + std::vector mel_freqs(num_mel_filters + 2); + for (int32_t i = 0; i < num_mel_filters + 2; ++i) { + mel_freqs[i] = mel_min + static_cast(i) * (mel_max - mel_min) / (num_mel_filters + 1); + } + + std::vector filter_freqs(num_mel_filters + 2); + for (int32_t i = 0; i < num_mel_filters + 2; ++i) { filter_freqs[i] = mel_to_hertz_slaney(mel_freqs[i]); } + + auto mel_filters = Tensor::empty({num_frequency_bins, num_mel_filters}, kFloat32, kCPU).alloc(); + float* filters_ptr = mel_filters.ptr(); + std::fill_n(filters_ptr, num_frequency_bins * num_mel_filters, 0.0f); + + for (int32_t mel_idx = 0; mel_idx < num_mel_filters; ++mel_idx) { + float left_freq = filter_freqs[mel_idx]; + float center_freq = filter_freqs[mel_idx + 1]; + float right_freq = filter_freqs[mel_idx + 2]; + + for (int32_t freq_idx = 0; freq_idx < num_frequency_bins; ++freq_idx) { + float freq = fft_freqs[freq_idx]; + float value = 0.0f; + + if (freq >= left_freq && freq <= center_freq && center_freq != left_freq) { + value = (freq - left_freq) / (center_freq - left_freq); + } else if (freq >= center_freq && freq <= right_freq && right_freq != center_freq) { + value = (right_freq - freq) / (right_freq - center_freq); + } + + filters_ptr[freq_idx * num_mel_filters + mel_idx] = value; + } + } + + for (int32_t mel_idx = 0; mel_idx < num_mel_filters; ++mel_idx) { + float enorm = 2.0f / (filter_freqs[mel_idx + 2] - 
filter_freqs[mel_idx]); + for (int32_t freq_idx = 0; freq_idx < num_frequency_bins; ++freq_idx) { + filters_ptr[freq_idx * num_mel_filters + mel_idx] *= enorm; + } + } + + return mel_filters; +} + +class MelSpectrogramFeatures final : public nn::Module { + int32_t n_fft_; + int32_t hop_length_; + int32_t win_length_; + int32_t n_mels_; + std::string padding_; + int power_; + nn::STFT stft_; + Tensor window_; + Tensor melscale_fbanks_; + + public: + MelSpectrogramFeatures() = default; + + explicit inline MelSpectrogramFeatures(const std::string& name, int32_t sample_rate = 16000, int32_t n_fft = 400, + int32_t hop_length = 160, int32_t n_mels = 128, + const std::string& padding = "center", int power = 2) + : nn::Module(name), n_fft_(n_fft), hop_length_(hop_length), n_mels_(n_mels), padding_(padding), power_(power) { + if (padding != "center" && padding != "same") { throw std::invalid_argument("Padding must be 'center' or 'same'."); } + + win_length_ = n_fft_; + stft_ = reg("stft", n_fft_, hop_length_, win_length_, true, true, "reflect", true); + window_ = create_hann_window(win_length_, true); + melscale_fbanks_ = create_mel_filterbank(n_fft_ / 2 + 1, n_mels_, 0.0f, 8000.0f, sample_rate); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto audio = inputs[0]; // [B, T] + + if (padding_ == "same") { + NYI("apply same padding in MelSpectrogramFeatures not implemented"); + } + + auto stft_result = stft_(audio, window_); + auto specgram = stft_result.abs(); + if (power_ == 2) { + specgram = specgram * specgram; + } else if (power_ != 1) { + NYI("power != 1 and power != 2 not implemented"); + } + + auto mel_specgram = nn::functional::matmul(specgram.T(), melscale_fbanks_).T(); + mel_specgram = nn::functional::clip(mel_specgram, 1e-10f, std::numeric_limits::max()); + mel_specgram = nn::functional::log(mel_specgram) / std::log(10.0f); + auto max_val = mel_specgram.max(); + float threshold = max_val.item() - 8.0f; + mel_specgram 
= nn::functional::clip(mel_specgram, threshold, std::numeric_limits::max()); + mel_specgram = (mel_specgram + 4.0f) / 4.0f; + + return {mel_specgram}; + } +}; + +struct Qwen2_5OmniAudioFeatures { + Tensor input_features = Tensor::nil(); + int32_t feature_length = 0; +}; + +class Qwen2_5OmniAudioPreprocessor { + MelSpectrogramFeatures mel_extractor_; + int32_t sample_rate_; + int32_t n_mels_; + int32_t hop_length_; + int32_t chunk_length_; + int32_t n_samples_; + + public: + explicit Qwen2_5OmniAudioPreprocessor(int32_t sample_rate = 16000, int32_t n_mels = 128, int32_t hop_length = 160, + int32_t chunk_length = 300) + : mel_extractor_("feature_extractor.mel_spec", sample_rate, 400, hop_length, n_mels, "center", 2), + sample_rate_(sample_rate), + n_mels_(n_mels), + hop_length_(hop_length), + chunk_length_(chunk_length), + n_samples_(chunk_length * sample_rate) {} + + [[nodiscard]] Qwen2_5OmniAudioFeatures processAudioFile(const std::string& audio_file_path) { + auto audio_data = mllm::audio::readWAV(audio_file_path, sample_rate_); + if (audio_data.empty()) { return {}; } + return processAudioData(audio_data.data(), static_cast(audio_data.size())); + } + + [[nodiscard]] Qwen2_5OmniAudioFeatures processAudioData(const float* audio_data, int32_t audio_length) { + Qwen2_5OmniAudioFeatures result; + if (audio_data == nullptr || audio_length <= 0) { return result; } + + int32_t padded_length = n_samples_; + int32_t effective_length = std::min(audio_length, padded_length); + + auto audio_tensor = Tensor::empty({1, padded_length}, kFloat32, kCPU).alloc(); + float* audio_ptr = audio_tensor.ptr(); + + if (audio_length <= padded_length) { + std::memcpy(audio_ptr, audio_data, audio_length * sizeof(float)); + std::fill(audio_ptr + audio_length, audio_ptr + padded_length, 0.0f); + } else { + std::memcpy(audio_ptr, audio_data, padded_length * sizeof(float)); + } + + auto mel_spec = mel_extractor_.forward({audio_tensor}, {})[0]; // [1, n_mels, n_frames] + + int32_t valid_frames = 
calcFeatureLength(effective_length); + int32_t max_frames = mel_spec.shape()[2]; + if (valid_frames > max_frames) { valid_frames = max_frames; } + if (valid_frames <= 0) { return result; } + + auto trimmed = Tensor::empty({1, n_mels_, valid_frames}, kFloat32, kCPU).alloc(); + for (int32_t m = 0; m < n_mels_; ++m) { + auto src_ptr = mel_spec.offsettedPtr({0, m, 0}); + auto dst_ptr = trimmed.offsettedPtr({0, m, 0}); + std::memcpy(dst_ptr, src_ptr, valid_frames * sizeof(float)); + } + + result.input_features = trimmed; + result.feature_length = valid_frames; + return result; + } + + [[nodiscard]] int32_t calcFeatureLength(int32_t audio_length) const { + if (audio_length <= 0) { return 0; } + return (audio_length + hop_length_ - 1) / hop_length_; + } + + [[nodiscard]] int32_t calcAudioTokenLength(int32_t feature_length) const { + if (feature_length <= 0) { return 0; } + int32_t after_conv = (feature_length - 1) / 2 + 1; + if (after_conv < 2) { return 0; } + int32_t after_pool = (after_conv - 2) / 2 + 1; + return std::max(0, after_pool); + } +}; + +} // namespace mllm::models::qwen2_5omni diff --git a/mllm/models/qwen2_5omni/configuration_qwen2_5omni.hpp b/mllm/models/qwen2_5omni/configuration_qwen2_5omni.hpp new file mode 100644 index 000000000..496ff6996 --- /dev/null +++ b/mllm/models/qwen2_5omni/configuration_qwen2_5omni.hpp @@ -0,0 +1,382 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+#pragma once + +#include +#include +#include + +#include "mllm/core/aops/LinearOp.hpp" +#include "mllm/engine/ConfigFile.hpp" + +namespace mllm::models::qwen2_5omni { + +struct Qwen2_5OmniTalkerConfig { + Qwen2_5OmniTalkerConfig() = default; + + explicit Qwen2_5OmniTalkerConfig(const nlohmann::json& root) { parse(root); } + + void parse(const nlohmann::json& root) { + audio_token_id = root.value("audio_token_index", audio_token_id); + image_token_id = root.value("image_token_index", image_token_id); + video_token_id = root.value("video_token_index", video_token_id); + + vocab_size = root.value("vocab_size", vocab_size); + tts_text_start_token_id = root.value("tts_text_start_token_id", tts_text_start_token_id); + tts_text_end_token_id = root.value("tts_text_end_token_id", tts_text_end_token_id); + tts_text_pad_token_id = root.value("tts_text_pad_token_id", tts_text_pad_token_id); + tts_codec_start_token_id = root.value("tts_codec_start_token_id", tts_codec_start_token_id); + tts_codec_end_token_id = root.value("tts_codec_end_token_id", tts_codec_end_token_id); + tts_codec_pad_token_id = root.value("tts_codec_pad_token_id", tts_codec_pad_token_id); + tts_codec_mask_token_id = root.value("tts_codec_mask_token_id", tts_codec_mask_token_id); + + vision_start_token_id = root.value("vision_start_token_id", vision_start_token_id); + vision_end_token_id = root.value("vision_end_token_id", vision_end_token_id); + audio_start_token_id = root.value("audio_start_token_id", audio_start_token_id); + audio_end_token_id = root.value("audio_end_token_id", audio_end_token_id); + + embedding_size = root.value("embedding_size", embedding_size); + hidden_size = root.value("hidden_size", hidden_size); + intermediate_size = root.value("intermediate_size", intermediate_size); + num_hidden_layers = root.value("num_hidden_layers", num_hidden_layers); + num_attention_heads = root.value("num_attention_heads", num_attention_heads); + num_key_value_heads = root.value("num_key_value_heads", 
num_key_value_heads); + head_dim = root.value("head_dim", head_dim); + max_position_embeddings = root.value("max_position_embeddings", max_position_embeddings); + rms_norm_eps = root.value("rms_norm_eps", rms_norm_eps); + rope_theta = root.value("rope_theta", rope_theta); + use_sliding_window = root.value("use_sliding_window", use_sliding_window); + sliding_window = root.value("sliding_window", sliding_window); + max_window_layers = root.value("max_window_layers", max_window_layers); + attention_dropout = root.value("attention_dropout", attention_dropout); + position_id_per_seconds = root.value("position_id_per_seconds", position_id_per_seconds); + seconds_per_chunk = root.value("seconds_per_chunk", seconds_per_chunk); + spatial_merge_size = root.value("spatial_merge_size", spatial_merge_size); + + if (root.contains("rope_scaling") && root["rope_scaling"].contains("mrope_section")) { + mrope_section = root["rope_scaling"]["mrope_section"].get>(); + } + } + + int64_t audio_token_id = 151646; + int64_t image_token_id = 151655; + int64_t video_token_id = 151656; + + int32_t vocab_size = 8448; + int64_t tts_text_start_token_id = 151860; + int64_t tts_text_end_token_id = 151861; + int64_t tts_text_pad_token_id = 151859; + int64_t tts_codec_start_token_id = 8293; + int64_t tts_codec_end_token_id = 8294; + int64_t tts_codec_pad_token_id = 8292; + int64_t tts_codec_mask_token_id = 8296; + + int64_t vision_start_token_id = 151652; + int64_t vision_end_token_id = 151653; + int64_t audio_start_token_id = 151647; + int64_t audio_end_token_id = 151648; + + int32_t embedding_size = 3584; + int32_t hidden_size = 896; + int32_t intermediate_size = 18944; + int32_t num_hidden_layers = 24; + int32_t num_attention_heads = 12; + int32_t num_key_value_heads = 4; + int32_t head_dim = 128; + int32_t max_position_embeddings = 32768; + float rms_norm_eps = 1e-06f; + float rope_theta = 1000000.0f; + bool use_sliding_window = false; + int32_t sliding_window = 32768; + int32_t 
max_window_layers = 28; + float attention_dropout = 0.0f; + int32_t position_id_per_seconds = 25; + int32_t seconds_per_chunk = 2; + int32_t spatial_merge_size = 2; + std::vector mrope_section = {16, 24, 24}; +}; + +struct Qwen2_5OmniDiTConfig { + Qwen2_5OmniDiTConfig() = default; + + explicit Qwen2_5OmniDiTConfig(const nlohmann::json& root) { parse(root); } + + void parse(const nlohmann::json& root) { + hidden_size = root.value("dim", hidden_size); + num_hidden_layers = root.value("depth", num_hidden_layers); + num_attention_heads = root.value("heads", num_attention_heads); + ff_mult = root.value("ff_mult", ff_mult); + emb_dim = root.value("emb_dim", emb_dim); + head_dim = root.value("head_dim", head_dim); + repeats = root.value("repeats", repeats); + num_embeds = root.value("num_embeds", num_embeds); + mel_dim = root.value("mel_dim", mel_dim); + dropout = root.value("dropout", dropout); + + max_position_embeddings = root.value("max_position_embeddings", max_position_embeddings); + block_size = root.value("block_size", block_size); + if (root.contains("look_ahead_layers")) { look_ahead_layers = root["look_ahead_layers"].get>(); } + if (root.contains("look_backward_layers")) { look_backward_layers = root["look_backward_layers"].get>(); } + rope_theta = root.value("rope_theta", rope_theta); + rope_type = root.value("rope_type", rope_type); + if (root.contains("rope_parameters")) { + const auto& rope_params = root["rope_parameters"]; + rope_theta = rope_params.value("rope_theta", rope_theta); + rope_type = rope_params.value("rope_type", rope_type); + } + + enc_emb_dim = root.value("enc_emb_dim", enc_emb_dim); + enc_dim = root.value("enc_dim", enc_dim); + if (root.contains("enc_channels")) { enc_channels = root["enc_channels"].get>(); } + if (root.contains("enc_kernel_sizes")) { enc_kernel_sizes = root["enc_kernel_sizes"].get>(); } + if (root.contains("enc_dilations")) { enc_dilations = root["enc_dilations"].get>(); } + enc_attention_channels = 
root.value("enc_attention_channels", enc_attention_channels); + enc_res2net_scale = root.value("enc_res2net_scale", enc_res2net_scale); + enc_se_channels = root.value("enc_se_channels", enc_se_channels); + } + + int32_t hidden_size = 1024; + int32_t num_hidden_layers = 22; + int32_t num_attention_heads = 16; + int32_t ff_mult = 2; + int32_t emb_dim = 512; + int32_t head_dim = 64; + int32_t max_position_embeddings = 32768; + int32_t block_size = 24; + std::vector look_ahead_layers = {10}; + std::vector look_backward_layers = {0, 20}; + int32_t repeats = 2; + int32_t num_embeds = 8193; + int32_t mel_dim = 80; + float dropout = 0.1f; + + int32_t enc_emb_dim = 192; + int32_t enc_dim = 128; + std::vector enc_channels = {256, 256, 256, 256, 768}; + std::vector enc_kernel_sizes = {5, 3, 3, 3, 1}; + std::vector enc_dilations = {1, 2, 3, 4, 1}; + int32_t enc_attention_channels = 64; + int32_t enc_res2net_scale = 2; + int32_t enc_se_channels = 64; + + float rope_theta = 10000.0f; + std::string rope_type = "default"; +}; + +struct Qwen2_5OmniBigVGANConfig { + Qwen2_5OmniBigVGANConfig() = default; + + explicit Qwen2_5OmniBigVGANConfig(const nlohmann::json& root) { parse(root); } + + void parse(const nlohmann::json& root) { + mel_dim = root.value("mel_dim", mel_dim); + upsample_initial_channel = root.value("upsample_initial_channel", upsample_initial_channel); + if (root.contains("resblock_kernel_sizes")) { + resblock_kernel_sizes = root["resblock_kernel_sizes"].get>(); + } + if (root.contains("resblock_dilation_sizes")) { + resblock_dilation_sizes = root["resblock_dilation_sizes"].get>>(); + } + if (root.contains("upsample_rates")) { upsample_rates = root["upsample_rates"].get>(); } + if (root.contains("upsample_kernel_sizes")) { + upsample_kernel_sizes = root["upsample_kernel_sizes"].get>(); + } + } + + int32_t mel_dim = 80; + int32_t upsample_initial_channel = 1536; + std::vector resblock_kernel_sizes = {3, 7, 11}; + std::vector> resblock_dilation_sizes = {{1, 3, 5}, {1, 3, 
5}, {1, 3, 5}}; + std::vector upsample_rates = {5, 3, 2, 2, 2, 2}; + std::vector upsample_kernel_sizes = {11, 7, 4, 4, 4, 4}; +}; + +struct Qwen2_5OmniToken2WavConfig { + Qwen2_5OmniToken2WavConfig() = default; + + explicit Qwen2_5OmniToken2WavConfig(const nlohmann::json& root) { parse(root); } + + void parse(const nlohmann::json& root) { + if (root.contains("dit_config")) { dit_config.parse(root["dit_config"]); } + if (root.contains("bigvgan_config")) { bigvgan_config.parse(root["bigvgan_config"]); } + } + + Qwen2_5OmniDiTConfig dit_config{}; + Qwen2_5OmniBigVGANConfig bigvgan_config{}; +}; + +struct Qwen2_5OmniConfig : protected ConfigFile { + Qwen2_5OmniConfig() = default; + + explicit Qwen2_5OmniConfig(const std::string& file_path) : ConfigFile(file_path) { + auto& root = data(); + enable_audio_output = root.value("enable_audio_output", root.value("enable_talker", enable_audio_output)); + + if (root.contains("talker_config")) { talker_cfg.parse(root["talker_config"]); } + if (root.contains("token2wav_config")) { token2wav_cfg.parse(root["token2wav_config"]); } + + if (root.contains("thinker_config")) { + auto& thinker_cfg = root["thinker_config"]; + auto& text_cfg = thinker_cfg["text_config"]; + + hidden_size = text_cfg["hidden_size"]; + intermediate_size = text_cfg["intermediate_size"]; + num_attention_heads = text_cfg["num_attention_heads"]; + num_key_value_heads = text_cfg["num_key_value_heads"]; + num_hidden_layers = text_cfg["num_hidden_layers"]; + max_position_embeddings = text_cfg["max_position_embeddings"]; + rms_norm_eps = text_cfg["rms_norm_eps"]; + vocab_size = text_cfg["vocab_size"]; + rope_theta = text_cfg["rope_theta"]; + tie_word_embeddings = text_cfg.value("tie_word_embeddings", false); + + if (text_cfg.contains("rope_scaling") && text_cfg["rope_scaling"].contains("mrope_section")) { + mrope_section = text_cfg["rope_scaling"]["mrope_section"].get>(); + } + + if (thinker_cfg.contains("vision_config")) { + auto& vision_cfg = 
thinker_cfg["vision_config"]; + visual_in_chans = vision_cfg.value("in_channels", vision_cfg.value("in_chans", visual_in_chans)); + visual_hidden_size = vision_cfg.value("hidden_size", vision_cfg.value("embed_dim", visual_hidden_size)); + visual_patch_size = vision_cfg.value("patch_size", vision_cfg.value("spatial_patch_size", visual_patch_size)); + visual_temporal_patch_size = vision_cfg.value("temporal_patch_size", visual_temporal_patch_size); + visual_spatial_merge_size = vision_cfg.value("spatial_merge_size", visual_spatial_merge_size); + visual_out_hidden_size = vision_cfg.value("out_hidden_size", visual_out_hidden_size); + visual_num_heads = vision_cfg.value("num_heads", visual_num_heads); + visual_depth = vision_cfg.value("depth", visual_depth); + visual_intermediate_size = vision_cfg.value("intermediate_size", visual_intermediate_size); + if (vision_cfg.contains("fullatt_block_indexes")) { + visual_fullatt_block_indexes = vision_cfg["fullatt_block_indexes"].get>(); + } + visual_window_size = vision_cfg.value("window_size", visual_window_size); + } + + if (thinker_cfg.contains("audio_config")) { + auto& audio_cfg = thinker_cfg["audio_config"]; + audio_d_model = audio_cfg.value("d_model", audio_d_model); + audio_num_mel_bins = audio_cfg.value("num_mel_bins", audio_num_mel_bins); + audio_encoder_layers = audio_cfg.value("encoder_layers", audio_encoder_layers); + audio_encoder_attention_heads = audio_cfg.value("encoder_attention_heads", audio_encoder_attention_heads); + audio_encoder_ffn_dim = audio_cfg.value("encoder_ffn_dim", audio_encoder_ffn_dim); + audio_max_source_positions = audio_cfg.value("max_source_positions", audio_max_source_positions); + audio_n_window = audio_cfg.value("n_window", audio_n_window); + audio_output_dim = audio_cfg.value("output_dim", audio_output_dim); + } + + bos_token_id = thinker_cfg.value("bos_token_id", bos_token_id); + eos_token_id = thinker_cfg.value("eos_token_id", eos_token_id); + pad_token_id = 
thinker_cfg.value("pad_token_id", pad_token_id); + image_token_id = thinker_cfg.value("image_token_index", image_token_id); + audio_token_id = thinker_cfg.value("audio_token_index", audio_token_id); + video_token_id = thinker_cfg.value("video_token_index", video_token_id); + audio_start_token_id = thinker_cfg.value("audio_start_token_id", audio_start_token_id); + audio_end_token_id = thinker_cfg.value("audio_end_token_id", audio_end_token_id); + vision_start_token_id = thinker_cfg.value("vision_start_token_id", vision_start_token_id); + vision_end_token_id = thinker_cfg.value("vision_end_token_id", vision_end_token_id); + vision_token_id = thinker_cfg.value("vision_token_id", vision_token_id); + position_id_per_seconds = thinker_cfg.value("position_id_per_seconds", position_id_per_seconds); + seconds_per_chunk = thinker_cfg.value("seconds_per_chunk", seconds_per_chunk); + } else { + hidden_size = root["hidden_size"]; + intermediate_size = root["intermediate_size"]; + num_attention_heads = root["num_attention_heads"]; + num_key_value_heads = root["num_key_value_heads"]; + num_hidden_layers = root["num_hidden_layers"]; + max_position_embeddings = root["max_position_embeddings"]; + rms_norm_eps = root["rms_norm_eps"]; + vocab_size = root["vocab_size"]; + rope_theta = root["rope_theta"]; + tie_word_embeddings = root.value("tie_word_embeddings", tie_word_embeddings); + if (root.contains("mrope_section")) { + mrope_section = root["mrope_section"].get>(); + } + if (root.contains("audio_config")) { + auto& audio_cfg = root["audio_config"]; + audio_d_model = audio_cfg.value("d_model", audio_d_model); + audio_num_mel_bins = audio_cfg.value("num_mel_bins", audio_num_mel_bins); + audio_encoder_layers = audio_cfg.value("encoder_layers", audio_encoder_layers); + audio_encoder_attention_heads = audio_cfg.value("encoder_attention_heads", audio_encoder_attention_heads); + audio_encoder_ffn_dim = audio_cfg.value("encoder_ffn_dim", audio_encoder_ffn_dim); + audio_max_source_positions 
= audio_cfg.value("max_source_positions", audio_max_source_positions); + audio_n_window = audio_cfg.value("n_window", audio_n_window); + audio_output_dim = audio_cfg.value("output_dim", audio_output_dim); + } + bos_token_id = root.value("bos_token_id", bos_token_id); + eos_token_id = root.value("eos_token_id", eos_token_id); + pad_token_id = root.value("pad_token_id", pad_token_id); + image_token_id = root.value("image_token_id", image_token_id); + audio_token_id = root.value("audio_token_id", audio_token_id); + video_token_id = root.value("video_token_id", video_token_id); + audio_start_token_id = root.value("audio_start_token_id", audio_start_token_id); + audio_end_token_id = root.value("audio_end_token_id", audio_end_token_id); + vision_start_token_id = root.value("vision_start_token_id", vision_start_token_id); + vision_end_token_id = root.value("vision_end_token_id", vision_end_token_id); + vision_token_id = root.value("vision_token_id", vision_token_id); + position_id_per_seconds = root.value("position_id_per_seconds", position_id_per_seconds); + seconds_per_chunk = root.value("seconds_per_chunk", seconds_per_chunk); + } + + max_cache_length = root.value("max_cache_length", max_position_embeddings); + + if (root.contains("linear_impl_type")) { + linear_impl_type = aops::str2LinearImplTypes(root["linear_impl_type"]); + } + } + + int32_t hidden_size = 3584; + int32_t intermediate_size = 18944; + int32_t num_attention_heads = 28; + int32_t num_key_value_heads = 4; + int32_t num_hidden_layers = 28; + int32_t max_position_embeddings = 32768; + float rms_norm_eps = 1e-06f; + int32_t vocab_size = 152064; + std::vector mrope_section = {16, 24, 24}; + float rope_theta = 1000000.0f; + bool tie_word_embeddings = false; + + int32_t visual_in_chans = 3; + int32_t visual_hidden_size = 1280; + int32_t visual_patch_size = 14; + int32_t visual_temporal_patch_size = 2; + int32_t visual_spatial_merge_size = 2; + int32_t visual_out_hidden_size = 3584; + int32_t visual_num_heads 
= 16; + int32_t visual_depth = 32; + int32_t visual_intermediate_size = 3420; + std::vector visual_fullatt_block_indexes = {7, 15, 23, 31}; + int32_t visual_window_size = 112; + + int32_t audio_d_model = 1280; + int32_t audio_num_mel_bins = 128; + int32_t audio_encoder_layers = 32; + int32_t audio_encoder_attention_heads = 20; + int32_t audio_encoder_ffn_dim = 5120; + int32_t audio_max_source_positions = 1500; + int32_t audio_n_window = 100; + int32_t audio_output_dim = 3584; + + int32_t max_cache_length = 32768; + + int64_t bos_token_id = 151644; + int64_t eos_token_id = 151645; + int64_t pad_token_id = 151643; + int64_t image_token_id = 151655; + int64_t audio_token_id = 151646; + int64_t video_token_id = 151656; + int64_t audio_start_token_id = 151647; + int64_t audio_end_token_id = 151648; + int64_t vision_start_token_id = 151652; + int64_t vision_end_token_id = 151653; + int64_t vision_token_id = 151654; + int32_t position_id_per_seconds = 25; + int32_t seconds_per_chunk = 2; + + bool enable_audio_output = true; + Qwen2_5OmniTalkerConfig talker_cfg{}; + Qwen2_5OmniToken2WavConfig token2wav_cfg{}; + + aops::LinearImplTypes linear_impl_type = aops::LinearImplTypes::kDefault; +}; + +} // namespace mllm::models::qwen2_5omni diff --git a/mllm/models/qwen2_5omni/modeling_qwen2_5omni.hpp b/mllm/models/qwen2_5omni/modeling_qwen2_5omni.hpp new file mode 100644 index 000000000..42bae162f --- /dev/null +++ b/mllm/models/qwen2_5omni/modeling_qwen2_5omni.hpp @@ -0,0 +1,2036 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "mllm/mllm.hpp" +#include "mllm/core/SlicePrimitives.hpp" +#include "mllm/nn/Module.hpp" +#include "mllm/nn/Nn.hpp" +#include "mllm/nn/Functional.hpp" +#include "mllm/nn/lmcache/StaticCache.hpp" +#include "mllm/models/ARGeneration.hpp" +#include "mllm/utils/Enumerate.hpp" + +#include "mllm/models/qwen2_5omni/configuration_qwen2_5omni.hpp" +#include "mllm/models/qwen2_5omni/modeling_qwen2_5omni_talker.hpp" +#include "mllm/models/qwen2_5omni/modeling_qwen2_5omni_token2wav.hpp" + +namespace mllm::models::qwen2_5omni { + +inline auto makeMultimodalRoPEInvFreq(int output_dim, float rope_theta) -> Tensor { + auto inv_freq = Tensor::empty({output_dim / 2}, kFloat32, kCPU).alloc(); + auto inv_freq_ptr = inv_freq.ptr(); + for (int i = 0; i < output_dim / 2; i++) { inv_freq_ptr[i] = 1.0f / std::pow(rope_theta, 2.0f * i / output_dim); } + return inv_freq; +} + +inline auto makeMultimodalPositionEmbedding(Tensor& position_ids, const Tensor& inv_freq, int seq_len, int output_dim, + const std::vector& mrope_section) -> std::pair { + MLLM_RT_ASSERT_EQ(position_ids.shape().size(), 3); + MLLM_RT_ASSERT_EQ(position_ids.shape()[1], 1); + + Tensor tmp_sin = Tensor::empty({3, position_ids.shape()[2], inv_freq.shape()[0] * 2}).alloc(); + Tensor tmp_cos = Tensor::empty({3, position_ids.shape()[2], inv_freq.shape()[0] * 2}).alloc(); + + for (int b = 0; b < 3; ++b) { + for (int d = 0; d < inv_freq.shape()[0]; ++d) { + for (int s = 0; s < position_ids.shape()[2]; ++s) { + auto value = inv_freq.ptr()[d] * (*position_ids.offsettedPtr({b, 0, s})); + *tmp_cos.offsettedPtr({b, s, d}) = cosf(value); + *tmp_cos.offsettedPtr({b, s, d + inv_freq.shape()[0]}) = cosf(value); + *tmp_sin.offsettedPtr({b, s, d}) = sinf(value); + *tmp_sin.offsettedPtr({b, s, d + inv_freq.shape()[0]}) = sinf(value); + } + } + } + + Tensor sin = Tensor::nil(); + Tensor cos = Tensor::nil(); + + if 
(!mrope_section.empty()) { + auto double_rope_section = mrope_section; + for (int i : mrope_section) { double_rope_section.push_back(i); } + + int num_rows = tmp_sin.shape()[1]; + int num_cols = tmp_sin.shape()[2]; + + sin = Tensor::empty({num_rows, num_cols}, kFloat32, kCPU).alloc(); + cos = Tensor::empty({num_rows, num_cols}, kFloat32, kCPU).alloc(); + + std::vector start_cols; + int current_start = 0; + start_cols.push_back(current_start); + for (int s : double_rope_section) { + current_start += s; + start_cols.push_back(current_start); + } + + for (int j = 0; j < static_cast(double_rope_section.size()); ++j) { + int layer = j % 3; + int s_j = double_rope_section[j]; + int start_col_in = start_cols[j]; + int start_col_out = start_cols[j]; + for (int row = 0; row < num_rows; ++row) { + auto in_cos_row_ptr = tmp_cos.offsettedPtr({layer, row, 0}); + auto out_cos_row_ptr = cos.offsettedPtr({row, 0}); + for (int c = 0; c < s_j; ++c) { out_cos_row_ptr[start_col_out + c] = in_cos_row_ptr[start_col_in + c]; } + + auto in_sin_row_ptr = tmp_sin.offsettedPtr({layer, row, 0}); + auto out_sin_row_ptr = sin.offsettedPtr({row, 0}); + for (int c = 0; c < s_j; ++c) { out_sin_row_ptr[start_col_out + c] = in_sin_row_ptr[start_col_in + c]; } + } + } + } else { + sin = tmp_sin; + cos = tmp_cos; + } + + return {sin, cos}; +} + + +inline auto makeWindowIndex(const Tensor& grid_thw, int window_size, int spatial_merge_size, + int patch_size) -> std::pair, std::vector> { + MLLM_RT_ASSERT_EQ(grid_thw.shape().size(), 2); + const int grid_num = grid_thw.shape()[0]; + + const int vit_merger_window_size = window_size / spatial_merge_size / patch_size; + const int spatial_merge_unit = spatial_merge_size * spatial_merge_size; + + std::vector window_index; + std::vector cu_window_seqlens = {0}; + int window_index_id = 0; + + for (int grid_idx = 0; grid_idx < grid_num; ++grid_idx) { + const int grid_t = grid_thw.constAt({grid_idx, 0}); + const int grid_h = grid_thw.constAt({grid_idx, 1}); + const 
int grid_w = grid_thw.constAt({grid_idx, 2}); + + const int llm_grid_h = grid_h / spatial_merge_size; + const int llm_grid_w = grid_w / spatial_merge_size; + const int pad_h = (vit_merger_window_size - llm_grid_h % vit_merger_window_size) % vit_merger_window_size; + const int pad_w = (vit_merger_window_size - llm_grid_w % vit_merger_window_size) % vit_merger_window_size; + + const int num_windows_h = (llm_grid_h + pad_h) / vit_merger_window_size; + const int num_windows_w = (llm_grid_w + pad_w) / vit_merger_window_size; + const int total_windows = grid_t * num_windows_h * num_windows_w; + + std::vector>> index( + grid_t, std::vector>(llm_grid_h, std::vector(llm_grid_w))); + + int counter = 0; + for (int t = 0; t < grid_t; t++) { + for (int h = 0; h < llm_grid_h; h++) { + for (int w = 0; w < llm_grid_w; w++) { index[t][h][w] = counter++; } + } + } + + std::vector>> index_padded( + grid_t, std::vector>(llm_grid_h + pad_h, std::vector(llm_grid_w + pad_w, -100))); + + for (int t = 0; t < grid_t; t++) { + for (int h = 0; h < llm_grid_h; h++) { + for (int w = 0; w < llm_grid_w; w++) { index_padded[t][h][w] = index[t][h][w]; } + } + } + + std::vector seqlens(total_windows, 0); + for (int t = 0; t < grid_t; t++) { + for (int wh = 0; wh < num_windows_h; wh++) { + for (int ww = 0; ww < num_windows_w; ww++) { + const int window_idx = t * num_windows_h * num_windows_w + wh * num_windows_w + ww; + for (int h = 0; h < vit_merger_window_size; h++) { + for (int w = 0; w < vit_merger_window_size; w++) { + const int orig_h = wh * vit_merger_window_size + h; + const int orig_w = ww * vit_merger_window_size + w; + if (index_padded[t][orig_h][orig_w] != -100) { + window_index.push_back(index_padded[t][orig_h][orig_w] + window_index_id); + seqlens[window_idx]++; + } + } + } + } + } + } + + int cumulative = cu_window_seqlens.back(); + for (int i = 0; i < total_windows; i++) { + cumulative += seqlens[i] * spatial_merge_unit; + cu_window_seqlens.push_back(cumulative); + } + + 
window_index_id += grid_t * llm_grid_h * llm_grid_w; + } + + return {window_index, cu_window_seqlens}; +} + +inline auto makeVisualRoPEInvFreq(int32_t dims, float theta) -> Tensor { + const int half_dim = dims / (2 * 2); + Tensor inv_freq = Tensor::empty({half_dim}, kFloat32).alloc(); + float* inv_freq_ptr = inv_freq.ptr(); + const float dims_inv = 1.0f / static_cast(dims / 2); + for (int i = 0; i < half_dim; ++i) { + const float exponent = (2.0f * i) * dims_inv; + inv_freq_ptr[i] = 1.0f / std::pow(theta, exponent); + } + return inv_freq; +} + +inline auto makeVisualRotaryPosEmbIds(Tensor& grid_thw, int32_t spatial_merge_size) -> Tensor { + MLLM_RT_ASSERT_EQ(grid_thw.shape().size(), 2); + + const auto img_nums = grid_thw.shape()[0]; + int total_positions = 0; + for (int row = 0; row < img_nums; ++row) { + const int* dims = grid_thw.offsettedPtr({row, 0}); + total_positions += dims[0] * dims[1] * dims[2]; + } + + Tensor out = Tensor::empty({total_positions, 2}, kInt32).alloc(); + int* out_ptr = out.ptr(); + int out_offset = 0; + + for (int row = 0; row < img_nums; ++row) { + const int* dims = grid_thw.offsettedPtr({row, 0}); + const int t = dims[0]; + const int h = dims[1]; + const int w = dims[2]; + + const int num_h_blocks = h / spatial_merge_size; + const int num_w_blocks = w / spatial_merge_size; + const int total_blocks = num_h_blocks * num_w_blocks; + const int block_area = spatial_merge_size * spatial_merge_size; + const int grid_size = h * w; + + std::vector flatten_hpos(grid_size); + std::vector flatten_wpos(grid_size); + + for (int block_idx = 0; block_idx < total_blocks; ++block_idx) { + const int i_h = block_idx / num_w_blocks; + const int i_w = block_idx % num_w_blocks; + const int start_idx = block_idx * block_area; + + const int base_h = i_h * spatial_merge_size; + const int base_w = i_w * spatial_merge_size; + + for (int j_h = 0; j_h < spatial_merge_size; ++j_h) { + const int global_h = base_h + j_h; + for (int j_w = 0; j_w < spatial_merge_size; 
++j_w) { + const int global_w = base_w + j_w; + const int pos = start_idx + j_h * spatial_merge_size + j_w; + flatten_hpos[pos] = global_h; + flatten_wpos[pos] = global_w; + } + } + } + + for (int frame = 0; frame < t; ++frame) { + for (int pos = 0; pos < grid_size; ++pos) { + const int out_idx = out_offset + (frame * grid_size + pos) * 2; + out_ptr[out_idx] = flatten_hpos[pos]; + out_ptr[out_idx + 1] = flatten_wpos[pos]; + } + } + out_offset += t * grid_size * 2; + } + + return out; +} + +inline float kaiserBesselI0(float x) { + const float ax = std::fabs(x); + if (ax < 3.75f) { + const float y = (x / 3.75f) * (x / 3.75f); + return 1.0f + y * (3.5156229f + y * (3.0899424f + y * (1.2067492f + y * (0.2659732f + y * (0.0360768f + y * 0.0045813f))))); + } + const float y = 3.75f / ax; + return (std::exp(ax) / std::sqrt(ax)) * + (0.39894228f + y * (0.01328592f + y * (0.00225319f + y * (-0.00157565f + y * (0.00916281f + + y * (-0.02057706f + y * (0.02635537f + y * (-0.01647633f + y * 0.00392377f)))))))); +} + +inline Tensor kaiserSincFilter1d(float cutoff, float half_width, int32_t kernel_size) { + const bool is_even = (kernel_size % 2 == 0); + const int32_t half_size = kernel_size / 2; + + const float delta_f = 4.0f * half_width; + const float attenuation = 2.285f * (half_size - 1) * static_cast(M_PI) * delta_f + 7.95f; + + float beta = 0.0f; + if (attenuation > 50.0f) { + beta = 0.1102f * (attenuation - 8.7f); + } else if (attenuation >= 21.0f) { + beta = 0.5842f * std::pow(attenuation - 21.0f, 0.4f) + 0.07886f * (attenuation - 21.0f); + } + + std::vector window(kernel_size); + const float denom = kaiserBesselI0(beta); + for (int32_t n = 0; n < kernel_size; ++n) { + const float ratio = (kernel_size == 1) ? 0.0f : (2.0f * n) / (kernel_size - 1) - 1.0f; + const float val = beta * std::sqrt(std::max(0.0f, 1.0f - ratio * ratio)); + window[n] = (denom == 0.0f) ? 
0.0f : kaiserBesselI0(val) / denom; + } + + std::vector time_indices(kernel_size); + if (is_even) { + for (int32_t i = 0; i < kernel_size; ++i) { time_indices[i] = static_cast(i - half_size) + 0.5f; } + } else { + for (int32_t i = 0; i < kernel_size; ++i) { time_indices[i] = static_cast(i - half_size); } + } + + Tensor filter = Tensor::empty({1, 1, kernel_size}, kFloat32, kCPU).alloc(); + auto* filter_ptr = filter.ptr(); + + if (cutoff == 0.0f) { + std::fill(filter_ptr, filter_ptr + kernel_size, 0.0f); + return filter; + } + + float sum = 0.0f; + for (int32_t i = 0; i < kernel_size; ++i) { + const float x = 2.0f * cutoff * time_indices[i]; + const float sinc = (x == 0.0f) ? 1.0f : std::sin(static_cast(M_PI) * x) / (static_cast(M_PI) * x); + const float value = 2.0f * cutoff * window[i] * sinc; + filter_ptr[i] = value; + sum += value; + } + if (sum != 0.0f) { + for (int32_t i = 0; i < kernel_size; ++i) { filter_ptr[i] /= sum; } + } + + return filter; +} + +inline auto makeVisualRotaryPosEmbFull(Tensor& inv_freq, int seq_len) -> Tensor { + MLLM_RT_ASSERT(seq_len > 0); + const int32_t dim = inv_freq.shape()[0]; + Tensor freqs = Tensor::empty({seq_len, dim}, kFloat32, kCPU).alloc(); + float* inv_freq_ptr = inv_freq.ptr(); + float* freqs_ptr = freqs.ptr(); + for (int i = 0; i < seq_len; ++i) { + const float i_val = static_cast(i); + float* row_ptr = freqs_ptr + i * dim; + for (int j = 0; j < dim; ++j) { row_ptr[j] = i_val * inv_freq_ptr[j]; } + } + return freqs; +} + +inline auto makeVisualRotaryPosEmb(Tensor& rotary_pos_emb_full, Tensor& pos_ids, Tensor& grid_thw) -> Tensor { + const int32_t dim = rotary_pos_emb_full.shape()[1]; + const int32_t batch_size = pos_ids.shape()[0]; + const int32_t seq_len = pos_ids.shape()[1]; + + int total_positions = 0; + for (int row = 0; row < grid_thw.shape()[0]; ++row) { + const int* dims = grid_thw.offsettedPtr({row, 0}); + total_positions += dims[0] * dims[1] * dims[2]; + } + + Tensor out = Tensor::empty({batch_size, seq_len * dim}, 
kFloat32, kCPU).alloc(); + + auto rotary_pos_emb_full_ptr = rotary_pos_emb_full.ptr(); + auto pos_ids_ptr = pos_ids.ptr(); + + if (rotary_pos_emb_full.shape()[0] <= 0 || dim <= 0 || batch_size <= 0) { + MLLM_ERROR_EXIT(ExitCode::kSliceOB, "Invalid tensor dimensions"); + } + + if (total_positions != batch_size) { MLLM_ERROR_EXIT(ExitCode::kSliceOB, "Grid dimensions mismatch with batch size"); } + + for (int i = 0; i < batch_size; ++i) { + for (int j = 0; j < seq_len; ++j) { + const int idx = pos_ids_ptr[i * seq_len + j]; + if (idx < 0 || idx >= rotary_pos_emb_full.shape()[0]) { + MLLM_ERROR_EXIT(ExitCode::kSliceOB, "Position index out of bounds"); + } + } + } + + for (int i = 0; i < batch_size; ++i) { + auto batch_ptr = out.offsettedPtr({i, 0}); + size_t offset = 0; + for (int j = 0; j < seq_len; ++j) { + const int idx = pos_ids_ptr[i * seq_len + j]; + auto emb_ptr = rotary_pos_emb_full_ptr + idx * dim; + std::copy(emb_ptr, emb_ptr + dim, batch_ptr + offset); + offset += dim; + } + } + + return out; +} + +inline auto makeVisualRotarySinCos(Tensor& rotary_pos_emb) -> std::pair { + const auto seq = rotary_pos_emb.shape()[0]; + const auto dim = rotary_pos_emb.shape()[1]; + + auto rotary_pos_emb_ptr = rotary_pos_emb.ptr(); + + Tensor sin_pos_emb = Tensor::empty({seq, dim}, kFloat32, kCPU).alloc(); + Tensor cos_pos_emb = Tensor::empty({seq, dim}, kFloat32, kCPU).alloc(); + + auto sin_pos_emb_ptr = sin_pos_emb.ptr(); + auto cos_pos_emb_ptr = cos_pos_emb.ptr(); + + for (int i = 0; i < seq; i++) { + for (int j = 0; j < dim; j++) { + sin_pos_emb_ptr[i * dim + j] = std::sin(rotary_pos_emb_ptr[i * dim + j]); + cos_pos_emb_ptr[i * dim + j] = std::cos(rotary_pos_emb_ptr[i * dim + j]); + } + } + + return {sin_pos_emb, cos_pos_emb}; +} + +inline auto makeAudioSinusoidalPosEmb(int32_t length, int32_t channels, float max_timescale = 10000.0f) -> Tensor { + MLLM_RT_ASSERT(channels % 2 == 0); + auto pos_emb = Tensor::empty({length, channels}, kFloat32, kCPU).alloc(); + auto pos_ptr = 
pos_emb.ptr(); + + const int half = channels / 2; + const float log_timescale_increment = std::log(max_timescale) / static_cast(half - 1); + + std::vector inv_timescales(half); + for (int i = 0; i < half; ++i) { + inv_timescales[i] = std::exp(-log_timescale_increment * static_cast(i)); + } + + for (int t = 0; t < length; ++t) { + for (int i = 0; i < half; ++i) { + const float scaled_time = static_cast(t) * inv_timescales[i]; + pos_ptr[t * channels + i] = std::sin(scaled_time); + pos_ptr[t * channels + half + i] = std::cos(scaled_time); + } + } + + return pos_emb; +} + +class Qwen2_5OmniPatchEmbed final : public nn::Module { + int32_t in_chans_; + int32_t embed_dim_; + int32_t patch_size_; + int32_t temporal_patch_size_; + + nn::Conv3D proj_; + + public: + Qwen2_5OmniPatchEmbed() = default; + + explicit Qwen2_5OmniPatchEmbed(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + in_chans_ = cfg.visual_in_chans; + embed_dim_ = cfg.visual_hidden_size; + patch_size_ = cfg.visual_patch_size; + temporal_patch_size_ = cfg.visual_temporal_patch_size; + + proj_ = reg("proj", cfg.visual_in_chans, cfg.visual_hidden_size, + std::vector{cfg.visual_temporal_patch_size, cfg.visual_patch_size, cfg.visual_patch_size}, + std::vector{cfg.visual_temporal_patch_size, cfg.visual_patch_size, cfg.visual_patch_size}, + false); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; + hidden_states = hidden_states.view({-1, in_chans_, temporal_patch_size_, patch_size_, patch_size_}); + hidden_states = proj_(hidden_states).view({-1, embed_dim_}); + return {hidden_states}; + } +}; + +class Qwen2_5OmniPatchMerger final : public nn::Module { + int32_t hidden_size_; + int32_t spatial_merge_size_; + int32_t context_dim_; + + nn::RMSNorm ln_q_; + nn::Linear mlp_0_; + nn::Linear mlp_2_; + nn::GELU mlp_gelu_; + + public: + Qwen2_5OmniPatchMerger() = default; + + explicit Qwen2_5OmniPatchMerger(const 
std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + context_dim_ = cfg.visual_hidden_size; + spatial_merge_size_ = cfg.visual_spatial_merge_size; + hidden_size_ = context_dim_ * spatial_merge_size_ * spatial_merge_size_; + + ln_q_ = reg("ln_q", 1e-6); + mlp_0_ = reg("mlp.0", hidden_size_, hidden_size_, true, cfg.linear_impl_type); + mlp_gelu_ = reg("mlp.gelu"); + mlp_2_ = reg("mlp.2", hidden_size_, cfg.visual_out_hidden_size, true, cfg.linear_impl_type); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto o = ln_q_(inputs[0]).view({-1, hidden_size_}); + o = mlp_0_(o); + o = mlp_gelu_(o); + o = mlp_2_(o); + return {o}; + } +}; + +class Qwen2_5OmniVisionMLP final : public nn::Module { + nn::Linear gate_proj_; + nn::Linear up_proj_; + nn::Linear down_proj_; + nn::SiLU silu_; + + public: + Qwen2_5OmniVisionMLP() = default; + explicit Qwen2_5OmniVisionMLP(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + gate_proj_ = reg("gate_proj", cfg.visual_hidden_size, cfg.visual_intermediate_size, true); + silu_ = reg("act"); + up_proj_ = reg("up_proj", cfg.visual_hidden_size, cfg.visual_intermediate_size, true); + down_proj_ = reg("down_proj", cfg.visual_intermediate_size, cfg.visual_hidden_size, true); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = gate_proj_(inputs[0]); + x = silu_(x); + auto y = up_proj_(inputs[0]); + x = x * y; + x = down_proj_(x); + return {x}; + } +}; + +class Qwen2_5OmniVisionAttention final : public nn::Module { + int32_t dim_; + int32_t num_heads_; + int32_t head_dim_; + + nn::Linear q_; + nn::Linear k_; + nn::Linear v_; + nn::Linear proj_; + nn::Softmax softmax_; + nn::VisionRoPE vision_rope_q_; + nn::VisionRoPE vision_rope_k_; + + public: + Qwen2_5OmniVisionAttention() = default; + + explicit Qwen2_5OmniVisionAttention(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + dim_ = 
cfg.visual_hidden_size; + num_heads_ = cfg.visual_num_heads; + head_dim_ = dim_ / num_heads_; + + q_ = reg("q", dim_, dim_, true, cfg.linear_impl_type); + k_ = reg("k", dim_, dim_, true, cfg.linear_impl_type); + v_ = reg("v", dim_, dim_, true, cfg.linear_impl_type); + proj_ = reg("proj", dim_, dim_, true, cfg.linear_impl_type); + softmax_ = reg("softmax", -1); + + vision_rope_q_ = reg("vision_rope_q", aops::VisionRoPEOpOptionsType::kQwen2VL, + aops::Qwen2VLRoPEOpOptions{ + .dims = head_dim_, + .spatial_merge_size = cfg.visual_spatial_merge_size, + .theta = 10000.0, + }); + vision_rope_k_ = reg("vision_rope_k", aops::VisionRoPEOpOptionsType::kQwen2VL, + aops::Qwen2VLRoPEOpOptions{ + .dims = head_dim_, + .spatial_merge_size = cfg.visual_spatial_merge_size, + .theta = 10000.0, + }); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; + auto visual_embedding_sin = inputs[1]; + auto visual_embedding_cos = inputs[2]; + auto& mask = inputs[3]; + + auto seq_length = hidden_states.shape()[0]; + + auto query_states = q_(hidden_states).view({seq_length, num_heads_, head_dim_}).unsqueeze(0); + auto key_states = k_(hidden_states).view({seq_length, num_heads_, head_dim_}).unsqueeze(0); + auto value_states = v_(hidden_states).view({seq_length, num_heads_, head_dim_}).unsqueeze(0); + + query_states = vision_rope_q_(query_states, visual_embedding_sin, visual_embedding_cos); + key_states = vision_rope_k_(key_states, visual_embedding_sin, visual_embedding_cos); + + query_states = query_states.transpose(1, 2); + key_states = key_states.transpose(1, 2); + value_states = value_states.transpose(1, 2); + + auto attn = nn::functional::matmul(query_states, key_states, false, true) * (1.f / sqrtf(head_dim_)); + if (mask) { attn = attn + mask; } + attn = softmax_(attn); + + auto attn_output = nn::functional::matmul(attn, value_states); + attn_output = attn_output.transpose(1, 2).view({seq_length, -1}); + attn_output = 
proj_(attn_output); + return {attn_output}; + } +}; + +class Qwen2_5OmniVisionBlock final : public nn::Module { + nn::RMSNorm norm1_; + nn::RMSNorm norm2_; + + Qwen2_5OmniVisionAttention attn_; + Qwen2_5OmniVisionMLP mlp_; + + public: + Qwen2_5OmniVisionBlock() = default; + + explicit Qwen2_5OmniVisionBlock(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + norm1_ = reg("norm1", 1e-6); + norm2_ = reg("norm2", 1e-6); + attn_ = reg("attn", cfg); + mlp_ = reg("mlp", cfg); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; + auto visual_embedding_sin = inputs[1]; + auto visual_embedding_cos = inputs[2]; + auto mask = inputs[3]; + + hidden_states = hidden_states + attn_(norm1_(hidden_states), visual_embedding_sin, visual_embedding_cos, mask)[0]; + hidden_states = hidden_states + mlp_(norm2_(hidden_states))[0]; + return {hidden_states}; + } +}; + +class Qwen2_5OmniVisionEncoder final : public nn::Module { + Qwen2_5OmniPatchEmbed patch_embed_; + Qwen2_5OmniPatchMerger patch_merger_; + nn::ModuleList blocks_; + std::vector visual_fullatt_block_indexes_; + int32_t visual_window_size_ = 0; + int32_t visual_spatial_merge_size_ = 1; + int32_t visual_patch_size_ = 1; + int32_t spatial_merge_unit_ = 1; + + public: + Qwen2_5OmniVisionEncoder() = default; + + explicit Qwen2_5OmniVisionEncoder(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + visual_window_size_ = cfg.visual_window_size; + visual_spatial_merge_size_ = cfg.visual_spatial_merge_size; + visual_patch_size_ = cfg.visual_patch_size; + spatial_merge_unit_ = visual_spatial_merge_size_ * visual_spatial_merge_size_; + visual_fullatt_block_indexes_ = cfg.visual_fullatt_block_indexes; + patch_embed_ = reg("patch_embed", cfg); + patch_merger_ = reg("merger", cfg); + blocks_ = reg>("blocks", cfg.visual_depth, cfg); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override 
{ + auto hidden_states = inputs[0]; + auto embedding_sin = inputs[1]; + auto embedding_cos = inputs[2]; + auto& grid_thw = inputs[3]; + + hidden_states = patch_embed_(hidden_states)[0]; + auto [window_index, cu_window_seqlens] = + makeWindowIndex(grid_thw, visual_window_size_, visual_spatial_merge_size_, visual_patch_size_); + + auto seq_len = hidden_states.shape()[0]; + hidden_states = hidden_states.view({seq_len / spatial_merge_unit_, spatial_merge_unit_, -1}); + hidden_states = hidden_states[{window_index, {kAll}, {kAll}}]; + hidden_states = hidden_states.view({seq_len, -1}); + + embedding_sin = embedding_sin.view({seq_len / spatial_merge_unit_, spatial_merge_unit_, -1}); + embedding_sin = embedding_sin[{window_index, {kAll}, {kAll}}]; + embedding_sin = embedding_sin.view({seq_len, -1}); + embedding_cos = embedding_cos.view({seq_len / spatial_merge_unit_, spatial_merge_unit_, -1}); + embedding_cos = embedding_cos[{window_index, {kAll}, {kAll}}]; + embedding_cos = embedding_cos.view({seq_len, -1}); + + auto mask = Tensor::empty({1, 1, seq_len, seq_len}, DataTypes::kFloat32, DeviceTypes::kCPU).alloc(); + { + auto mask_ptr = mask.ptr(); + const mllm_fp32_t neg_inf = -1e12f; + for (int i = 0; i < seq_len * seq_len; ++i) { mask_ptr[i] = neg_inf; } + for (int i = 1; i < cu_window_seqlens.size(); ++i) { + const int start = cu_window_seqlens[i - 1]; + const int end = cu_window_seqlens[i]; + for (int r = start; r < end; ++r) { + for (int c = start; c < end; ++c) { mask_ptr[r * seq_len + c] = 0.0f; } + } + } + } + + for (auto [layer_idx, b] : enumerate(blocks_.list())) { + if (std::find(visual_fullatt_block_indexes_.begin(), visual_fullatt_block_indexes_.end(), layer_idx) + != visual_fullatt_block_indexes_.end()) { + hidden_states = b(hidden_states, embedding_sin, embedding_cos, Tensor::nil())[0]; + } else { + hidden_states = b(hidden_states, embedding_sin, embedding_cos, mask)[0]; + } + } + + hidden_states = patch_merger_(hidden_states)[0]; + + std::vector 
reverse_indices(window_index.size()); + std::iota(reverse_indices.begin(), reverse_indices.end(), 0); + std::sort(reverse_indices.begin(), reverse_indices.end(), + [&window_index](int i, int j) { return window_index[i] < window_index[j]; }); + hidden_states = hidden_states[{reverse_indices, {kAll}}]; + + return {hidden_states}; + } +}; + +class Qwen2_5OmniAudioAttention final : public nn::Module { + int32_t embed_dim_ = 0; + int32_t num_heads_ = 0; + int32_t head_dim_ = 0; + + nn::Linear k_proj_; + nn::Linear v_proj_; + nn::Linear q_proj_; + nn::Linear out_proj_; + + public: + Qwen2_5OmniAudioAttention() = default; + + explicit Qwen2_5OmniAudioAttention(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + embed_dim_ = cfg.audio_d_model; + num_heads_ = cfg.audio_encoder_attention_heads; + head_dim_ = embed_dim_ / num_heads_; + + k_proj_ = reg("k_proj", embed_dim_, embed_dim_, false); + v_proj_ = reg("v_proj", embed_dim_, embed_dim_, true); + q_proj_ = reg("q_proj", embed_dim_, embed_dim_, true); + out_proj_ = reg("out_proj", embed_dim_, embed_dim_, true); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; // [seq_len, embed_dim] + auto seq_len = hidden_states.shape()[0]; + + auto hidden = hidden_states.unsqueeze(0); // [1, seq_len, embed_dim] + auto query_states = q_proj_(hidden); + auto key_states = k_proj_(hidden); + auto value_states = v_proj_(hidden); + + query_states = query_states.view({1, seq_len, num_heads_, head_dim_}).transpose(1, 2); + key_states = key_states.view({1, seq_len, num_heads_, head_dim_}).transpose(1, 2); + value_states = value_states.view({1, seq_len, num_heads_, head_dim_}).transpose(1, 2); + + float scale = 1.0f / std::sqrt(static_cast(head_dim_)); + auto attn_weights = nn::functional::matmul(query_states, key_states.transpose(-2, -1)) * scale; + attn_weights = nn::functional::softmax(attn_weights, -1); + auto attn_output = 
nn::functional::matmul(attn_weights, value_states); + + attn_output = attn_output.transpose(1, 2).contiguous().view({1, seq_len, embed_dim_}); + attn_output = out_proj_(attn_output); + + return {attn_output.squeeze(0)}; + } +}; + +class Qwen2_5OmniAudioEncoderLayer final : public nn::Module { + Qwen2_5OmniAudioAttention self_attn_; + nn::LayerNorm self_attn_layer_norm_; + nn::Linear fc1_; + nn::Linear fc2_; + nn::LayerNorm final_layer_norm_; + nn::GELU activation_fn_; + + public: + Qwen2_5OmniAudioEncoderLayer() = default; + + explicit Qwen2_5OmniAudioEncoderLayer(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + const int32_t embed_dim = cfg.audio_d_model; + self_attn_ = reg("self_attn", cfg); + self_attn_layer_norm_ = + reg("self_attn_layer_norm", std::vector{embed_dim}, true, true, 1e-5); + fc1_ = reg("fc1", embed_dim, cfg.audio_encoder_ffn_dim, true); + fc2_ = reg("fc2", cfg.audio_encoder_ffn_dim, embed_dim, true); + final_layer_norm_ = reg("final_layer_norm", std::vector{embed_dim}, true, true, 1e-5); + activation_fn_ = reg("activation_fn"); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; + auto residual = hidden_states; + + hidden_states = self_attn_layer_norm_(hidden_states); + hidden_states = self_attn_(hidden_states)[0]; + hidden_states = residual + hidden_states; + + residual = hidden_states; + hidden_states = final_layer_norm_(hidden_states); + hidden_states = fc1_(hidden_states); + hidden_states = activation_fn_(hidden_states); + hidden_states = fc2_(hidden_states); + hidden_states = residual + hidden_states; + + if (hidden_states.dtype() == kFloat16) { + const float clamp_value = 65504.0f - 1000.0f; + hidden_states = nn::functional::clip(hidden_states, -clamp_value, clamp_value); + } + + return {hidden_states}; + } +}; + +class Qwen2_5OmniAudioEncoder final : public nn::Module { + nn::Conv1D conv1_; + nn::Conv1D conv2_; + nn::GELU gelu_; + 
nn::ModuleList layers_; + nn::LayerNorm ln_post_; + nn::AvgPool1d avg_pooler_; + nn::Linear proj_; + nn::Embedding audio_bos_eos_token_; + + int32_t num_mel_bins_ = 0; + int32_t embed_dim_ = 0; + int32_t n_window_ = 0; + int32_t output_dim_ = 0; + + public: + Qwen2_5OmniAudioEncoder() = default; + + explicit Qwen2_5OmniAudioEncoder(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + num_mel_bins_ = cfg.audio_num_mel_bins; + embed_dim_ = cfg.audio_d_model; + n_window_ = cfg.audio_n_window; + output_dim_ = cfg.audio_output_dim; + + conv1_ = reg("conv1", num_mel_bins_, embed_dim_, 3, 1, 1); + conv2_ = reg("conv2", embed_dim_, embed_dim_, 3, 2, 1); + gelu_ = reg("gelu"); + audio_bos_eos_token_ = reg("audio_bos_eos_token", 2, cfg.audio_output_dim); + layers_ = reg>("layers", cfg.audio_encoder_layers, cfg); + ln_post_ = reg("ln_post", std::vector{embed_dim_}, true, true, 1e-5); + avg_pooler_ = reg("avg_pooler", 2, 2); + proj_ = reg("proj", embed_dim_, cfg.audio_output_dim, true); + + auto pos_emb = makeAudioSinusoidalPosEmb(cfg.audio_max_source_positions, embed_dim_); + registerBuffer("positional_embedding", pos_emb); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto input_features = inputs[0]; // [B, n_mels, T] + MLLM_RT_ASSERT_EQ(input_features.shape().size(), 3); + + const int32_t batch_size = input_features.shape()[0]; + MLLM_RT_ASSERT_EQ(input_features.shape()[1], num_mel_bins_); + const int32_t feature_len = input_features.shape()[2]; + MLLM_RT_ASSERT(feature_len > 0); + + auto pos_emb = getBuffer("positional_embedding"); + + std::vector audio_outputs; + audio_outputs.reserve(batch_size); + + for (int32_t b = 0; b < batch_size; ++b) { + Tensor audio_b = input_features[make_slice(b), kAll, kAll].view({1, num_mel_bins_, feature_len}).contiguous(); + + const int32_t chunk_size = n_window_ * 2; + const int32_t num_chunks = (feature_len + chunk_size - 1) / chunk_size; + + std::vector 
chunk_outputs; + chunk_outputs.reserve(num_chunks); + + for (int32_t chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) { + const int32_t start = chunk_idx * chunk_size; + const int32_t chunk_len = std::min(chunk_size, feature_len - start); + auto chunk = Tensor::empty({1, num_mel_bins_, chunk_len}, kFloat32, kCPU).alloc(); + for (int32_t m = 0; m < num_mel_bins_; ++m) { + auto src_ptr = audio_b.offsettedPtr({0, m, start}); + auto dst_ptr = chunk.offsettedPtr({0, m, 0}); + std::memcpy(dst_ptr, src_ptr, chunk_len * sizeof(float)); + } + + auto x = conv1_(chunk); + x = gelu_(x); + x = conv2_(x); + x = gelu_(x); + x = x.transpose(1, 2).contiguous(); // [1, T2, D] + + const int32_t t2 = x.shape()[1]; + MLLM_RT_ASSERT(t2 <= pos_emb.shape()[0]); + auto pos_ptr = pos_emb.ptr(); + auto x_ptr = x.ptr(); + for (int32_t t = 0; t < t2; ++t) { + const float* pos_row = pos_ptr + t * embed_dim_; + float* x_row = x_ptr + t * embed_dim_; + for (int32_t d = 0; d < embed_dim_; ++d) { x_row[d] += pos_row[d]; } + } + + auto hidden_states = x.squeeze(0); // [T2, D] + for (auto& layer : layers_.list()) { hidden_states = layer(hidden_states)[0]; } + if (hidden_states.shape()[0] < 2) { continue; } + + auto pooled = hidden_states.unsqueeze(0).transpose(1, 2); // [1, D, T] + pooled = avg_pooler_(pooled); + pooled = pooled.transpose(1, 2).squeeze(0); // [T', D] + pooled = ln_post_(pooled); + pooled = proj_(pooled); + chunk_outputs.push_back(pooled); + } + + int32_t total_len = 0; + for (const auto& chunk : chunk_outputs) { total_len += chunk.shape()[0]; } + + auto merged = Tensor::empty({total_len, output_dim_}, kFloat32, kCPU).alloc(); + int32_t offset = 0; + for (const auto& chunk : chunk_outputs) { + const int32_t len = chunk.shape()[0]; + const float* src_ptr = chunk.ptr(); + float* dst_ptr = merged.offsettedPtr({offset, 0}); + std::memcpy(dst_ptr, src_ptr, len * output_dim_ * sizeof(float)); + offset += len; + } + + audio_outputs.push_back(merged); + } + + int32_t total_audio_tokens = 
0; + for (const auto& out : audio_outputs) { total_audio_tokens += out.shape()[0]; } + + auto output = Tensor::empty({total_audio_tokens, output_dim_}, kFloat32, kCPU).alloc(); + int32_t offset = 0; + for (const auto& out : audio_outputs) { + const int32_t len = out.shape()[0]; + const float* src_ptr = out.ptr(); + float* dst_ptr = output.offsettedPtr({offset, 0}); + std::memcpy(dst_ptr, src_ptr, len * output_dim_ * sizeof(float)); + offset += len; + } + + return {output}; + } +}; + +class Qwen2_5OmniMLP final : public nn::Module { + nn::Linear gate_proj_; + nn::Linear up_proj_; + nn::Linear down_proj_; + nn::SiLU silu_; + + public: + Qwen2_5OmniMLP() = default; + Qwen2_5OmniMLP(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + gate_proj_ = reg("gate_proj", cfg.hidden_size, cfg.intermediate_size, false, cfg.linear_impl_type); + silu_ = reg("act"); + up_proj_ = reg("up_proj", cfg.hidden_size, cfg.intermediate_size, false, cfg.linear_impl_type); + down_proj_ = reg("down_proj", cfg.intermediate_size, cfg.hidden_size, false, cfg.linear_impl_type); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = gate_proj_(inputs[0]); + x = silu_(x); + auto y = up_proj_(inputs[0]); + x = x * y; + x = down_proj_(x); + return {x}; + } +}; + +class Qwen2_5OmniAttention final : public nn::Module { + nn::Linear q_proj_; + nn::Linear k_proj_; + nn::Linear v_proj_; + nn::Linear o_proj_; + nn::MultimodalRoPE q_rope_; + nn::MultimodalRoPE k_rope_; + nn::CausalMask mask_; + nn::Softmax softmax_; + + int hidden_size_; + int head_dim_; + int num_attention_heads_; + int num_key_value_heads_; + int num_key_value_groups_; + + public: + Qwen2_5OmniAttention() = default; + + Qwen2_5OmniAttention(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + hidden_size_ = cfg.hidden_size; + num_attention_heads_ = cfg.num_attention_heads; + num_key_value_heads_ = cfg.num_key_value_heads; + head_dim_ = 
hidden_size_ / num_attention_heads_; + num_key_value_groups_ = num_attention_heads_ / num_key_value_heads_; + + q_proj_ = reg("q_proj", hidden_size_, head_dim_ * num_attention_heads_, true, cfg.linear_impl_type); + k_proj_ = reg("k_proj", hidden_size_, head_dim_ * num_key_value_heads_, true, cfg.linear_impl_type); + v_proj_ = reg("v_proj", hidden_size_, head_dim_ * num_key_value_heads_, true, cfg.linear_impl_type); + o_proj_ = reg("o_proj", head_dim_ * num_attention_heads_, hidden_size_, false, cfg.linear_impl_type); + + q_rope_ = reg( + "q_rope", aops::Qwen2VLMultimodalRoPEOpOptions{.rope_theta = cfg.rope_theta, + .max_position_embeddings = cfg.max_position_embeddings, + .mrope_section = cfg.mrope_section}); + k_rope_ = reg( + "k_rope", aops::Qwen2VLMultimodalRoPEOpOptions{.rope_theta = cfg.rope_theta, + .max_position_embeddings = cfg.max_position_embeddings, + .mrope_section = cfg.mrope_section}); + + mask_ = reg("mask"); + softmax_ = reg("softmax", -1); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = inputs[0]; + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto past_kv_cache = args[0].get(); + + auto query_states = q_proj_(x); + auto key_states = k_proj_(x); + auto value_states = v_proj_(x); + + int B = inputs[0].shape()[0]; + int S = inputs[0].shape()[1]; + + query_states = query_states.view({B, S, num_attention_heads_, head_dim_}); + key_states = key_states.view({B, S, num_key_value_heads_, head_dim_}); + value_states = value_states.view({B, S, num_key_value_heads_, head_dim_}); + + query_states = query_states.transpose(1, 2); + key_states = key_states.transpose(1, 2); + value_states = value_states.transpose(1, 2); + + query_states = q_rope_(query_states, llm_embedding_sin, llm_embedding_cos); + key_states = k_rope_(key_states, llm_embedding_sin, llm_embedding_cos); + + auto [k, v] = past_kv_cache->updateKVCache(layer_idx_, key_states, value_states); + key_states = k; + 
value_states = v; + + Tensor attn; + if (key_states.dtype() == kFloat32) { + attn = nn::functional::matmul(query_states, key_states, false, true) * (1.f / sqrtf(head_dim_)); + attn = mask_(attn); + attn = softmax_(attn); + } else if (key_states.dtype() == kFloat16) { + attn = nn::functional::matmul(query_states.to(kFloat32), key_states.to(kFloat32), false, true) * (1.f / sqrtf(head_dim_)); + attn = mask_(attn); + attn = softmax_(attn); + attn = attn.to(kFloat16); + } + + auto output = nn::functional::matmul(attn, value_states); + output = output.transpose(1, 2).view({B, S, num_attention_heads_ * head_dim_}); + output = o_proj_(output); + return {output}; + } + + int layer_idx_; +}; + +class Qwen2_5OmniDecoder final : public nn::Module { + public: + Qwen2_5OmniAttention self_attn_; + Qwen2_5OmniMLP mlp_; + nn::RMSNorm input_layer_norm_; + nn::RMSNorm post_attention_layer_norm_; + + Qwen2_5OmniDecoder() = default; + + Qwen2_5OmniDecoder(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + self_attn_ = reg("self_attn", cfg); + mlp_ = reg("mlp", cfg); + input_layer_norm_ = reg("input_layernorm", cfg.rms_norm_eps); + post_attention_layer_norm_ = reg("post_attention_layernorm", cfg.rms_norm_eps); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto& kv_cache = args[0]; + + auto x = input_layer_norm_(inputs[0]); + x = self_attn_(x, llm_embedding_sin, llm_embedding_cos, kv_cache)[0]; + auto tmp = x + inputs[0]; + x = post_attention_layer_norm_(tmp); + x = mlp_(x)[0]; + x = x + tmp; + return {x}; + } +}; + +class Qwen2_5OmniText final : public nn::Module { + nn::ModuleList decode_blocks_; + nn::RMSNorm norm_; + + public: + Qwen2_5OmniText() = default; + + Qwen2_5OmniText(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + decode_blocks_ = reg>("layers", cfg.num_hidden_layers, cfg); + for (auto [idx, b] : 
enumerate(decode_blocks_.list())) { b.self_attn_.layer_idx_ = idx; } + + norm_ = reg("norm", cfg.rms_norm_eps); + embedding_ = reg("embed_tokens", cfg.vocab_size, cfg.hidden_size); + + auto inv = makeMultimodalRoPEInvFreq(cfg.hidden_size / cfg.num_attention_heads, cfg.rope_theta); + registerBuffer("inv_freq", inv); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto& blocks = decode_blocks_.list(); + auto x = inputs[0]; + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto& kv_cache = args[0]; + + for (auto& block : blocks) { x = block(x, llm_embedding_sin, llm_embedding_cos, kv_cache)[0]; } + x = norm_(x); + + return {x}; + } + + nn::Embedding embedding_; +}; + +class Qwen2_5OmniThinker final : public nn::Module { + public: + Qwen2_5OmniThinker() = default; + Qwen2_5OmniThinker(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + model_ = reg("model", cfg); + audio_tower_ = reg("audio_tower", cfg); + visual_ = reg("visual", cfg); + lm_head_ = reg("lm_head", cfg.hidden_size, cfg.vocab_size, false, cfg.linear_impl_type); + } + + Qwen2_5OmniText model_; + Qwen2_5OmniAudioEncoder audio_tower_; + Qwen2_5OmniVisionEncoder visual_; + nn::Linear lm_head_; +}; + +class Qwen2_5OmniForCausalLM : public ARGeneration { + public: + explicit Qwen2_5OmniForCausalLM(const Qwen2_5OmniConfig& cfg) : cfg_(cfg), thinker_("thinker", cfg) { + kv_cache_ = nn::StaticCache(cfg.max_cache_length, cfg.num_hidden_layers, + cfg.num_attention_heads, + cfg.num_key_value_heads, + cfg.hidden_size / cfg.num_attention_heads, + kFloat32, + kFloat32, + kCPU, + false); + eos_token_id_ = cfg.eos_token_id; + max_length_ = cfg.max_cache_length; + } + + void clearCache() { kv_cache_.clearCache(); } + + ARGenerationOutputPast forward(const ARGenerationOutputPast& input, const ARGenerationArgs& args) override { + auto sequence = input.at("sequence"); + + auto input_embeddings = 
thinker_.model_.embedding_(sequence); + + if (input.count("input_features")) { + auto input_features = input.at("input_features"); + auto audio_embeddings = thinker_.audio_tower_(input_features)[0]; + MLLM_RT_ASSERT_EQ(audio_embeddings.shape()[1], input_embeddings.shape()[2]); + if (audio_embeddings.dtype() != input_embeddings.dtype()) { + audio_embeddings = audio_embeddings.to(input_embeddings.dtype()); + } + + MLLM_RT_ASSERT_EQ(sequence.shape()[0], 1); + auto S = sequence.shape()[1]; + std::vector audio_positions; + audio_positions.reserve(audio_embeddings.shape()[0]); + auto input_ids_ptr = sequence.ptr(); + for (int s = 0; s < S; ++s) { + if (input_ids_ptr[s] == cfg_.audio_token_id) { audio_positions.push_back(s); } + } + MLLM_RT_ASSERT_EQ(static_cast(audio_positions.size()), audio_embeddings.shape()[0]); + + auto D = input_embeddings.shape()[2]; + if (input_embeddings.dtype() == kFloat32) { + for (size_t i = 0; i < audio_positions.size(); ++i) { + auto out_ptr = input_embeddings.offsettedPtr({0, audio_positions[i], 0}); + auto in_ptr = audio_embeddings.offsettedPtr({static_cast(i), 0}); + std::copy(in_ptr, in_ptr + D, out_ptr); + } + } else if (input_embeddings.dtype() == kFloat16) { + for (size_t i = 0; i < audio_positions.size(); ++i) { + auto out_ptr = input_embeddings.offsettedPtr({0, audio_positions[i], 0}); + auto in_ptr = audio_embeddings.offsettedPtr({static_cast(i), 0}); + std::copy(in_ptr, in_ptr + D, out_ptr); + } + } else { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Unsupported embedding dtype for Qwen2.5-Omni audio input."); + } + } + + if (input.count("img")) { + auto img = input.at("img"); + auto grid_thw = input.at("grid_thw"); + + auto inv_freq = makeVisualRoPEInvFreq(cfg_.visual_hidden_size / cfg_.visual_num_heads, 10000.0f); + auto pos_ids = makeVisualRotaryPosEmbIds(grid_thw, cfg_.visual_spatial_merge_size); + + int max_grid = 0; + for (int row = 0; row < grid_thw.shape()[0]; ++row) { + const int* dims = grid_thw.offsettedPtr({row, 0}); + 
max_grid = std::max({max_grid, dims[1], dims[2]}); + } + MLLM_RT_ASSERT(max_grid > 0); + auto rotary_pos_emb_full = makeVisualRotaryPosEmbFull(inv_freq, max_grid); + auto pos_emb = makeVisualRotaryPosEmb(rotary_pos_emb_full, pos_ids, grid_thw); + auto [visual_embedding_sin, visual_embedding_cos] = makeVisualRotarySinCos(pos_emb); + + auto visual_embeddings = thinker_.visual_(img, visual_embedding_sin, visual_embedding_cos, grid_thw)[0]; + MLLM_RT_ASSERT_EQ(visual_embeddings.shape()[1], input_embeddings.shape()[2]); + if (visual_embeddings.dtype() != input_embeddings.dtype()) { + visual_embeddings = visual_embeddings.to(input_embeddings.dtype()); + } + + MLLM_RT_ASSERT_EQ(sequence.shape()[0], 1); + auto S = sequence.shape()[1]; + std::vector image_positions; + image_positions.reserve(visual_embeddings.shape()[0]); + auto input_ids_ptr = sequence.ptr(); + for (int s = 0; s < S; ++s) { + if (input_ids_ptr[s] == cfg_.image_token_id) { image_positions.push_back(s); } + } + MLLM_RT_ASSERT_EQ(static_cast(image_positions.size()), visual_embeddings.shape()[0]); + + auto D = input_embeddings.shape()[2]; + if (input_embeddings.dtype() == kFloat32) { + for (size_t i = 0; i < image_positions.size(); ++i) { + auto out_ptr = input_embeddings.offsettedPtr({0, image_positions[i], 0}); + auto in_ptr = visual_embeddings.offsettedPtr({static_cast(i), 0}); + std::copy(in_ptr, in_ptr + D, out_ptr); + } + } else if (input_embeddings.dtype() == kFloat16) { + for (size_t i = 0; i < image_positions.size(); ++i) { + auto out_ptr = input_embeddings.offsettedPtr({0, image_positions[i], 0}); + auto in_ptr = visual_embeddings.offsettedPtr({static_cast(i), 0}); + std::copy(in_ptr, in_ptr + D, out_ptr); + } + } else { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Unsupported embedding dtype for Qwen2.5-Omni image input."); + } + } + + Tensor position_ids = input.count("position_ids") ? input.at("position_ids") : Tensor::nil(); + Tensor img = input.count("img") ? 
input.at("img") : Tensor::nil(); + Tensor grid_thw = input.count("grid_thw") ? input.at("grid_thw") : Tensor::nil(); + position_ids = getPositionIds(img, grid_thw, sequence, position_ids); + + auto [llm_embedding_sin, llm_embedding_cos] = + makeMultimodalPositionEmbedding(position_ids, thinker_.model_.getBuffer("inv_freq"), cfg_.max_position_embeddings, + cfg_.hidden_size / cfg_.num_attention_heads, cfg_.mrope_section); + + auto hidden_states = thinker_.model_(input_embeddings, llm_embedding_sin, llm_embedding_cos, AnyValue(&kv_cache_))[0]; + auto seq_len = hidden_states.shape()[1]; + auto last_hidden = hidden_states[{kAll, {seq_len - 1}, kAll}]; + auto logits = thinker_.lm_head_(last_hidden); + + const bool output_hidden_states = + args.count("output_hidden_states") ? args.at("output_hidden_states").get() : false; + + if (output_hidden_states) { + return { + {"sequence", logits}, + {"position_ids", position_ids}, + {"hidden_states", hidden_states}, + {"input_embeddings", input_embeddings}, + }; + } + + return { + {"sequence", logits}, + {"position_ids", position_ids}, + }; + } + + Qwen2_5OmniThinker thinker_; + + private: + Tensor getPositionIds(Tensor& img, Tensor& grid_thw, Tensor& input_ids, Tensor& position_ids) const { + MLLM_RT_ASSERT_EQ(input_ids.shape().size(), 2); + + bool has_multimodal = false; + auto input_ids_ptr = input_ids.ptr(); + auto seq_len = input_ids.shape()[1]; + for (int s = 0; s < seq_len; ++s) { + if (input_ids_ptr[s] == cfg_.vision_start_token_id || input_ids_ptr[s] == cfg_.audio_start_token_id) { + has_multimodal = true; + break; + } + } + + if (has_multimodal) { return getPositionIdsPrefill(input_ids, grid_thw); } + + if (!position_ids.isNil()) { + auto last_pos = *position_ids.offsettedPtr({0, 0, position_ids.shape()[2] - 1}); + auto ret_position_ids = Tensor::empty({3, 1, 1}, kInt64, kCPU).alloc(); + *ret_position_ids.offsettedPtr({0, 0, 0}) = last_pos + 1; + *ret_position_ids.offsettedPtr({1, 0, 0}) = last_pos + 1; + 
*ret_position_ids.offsettedPtr({2, 0, 0}) = last_pos + 1; + return ret_position_ids; + } + + auto B = input_ids.shape()[0]; + auto S = seq_len; + MLLM_RT_ASSERT_EQ(B, 1); + + Tensor out = Tensor::empty({3, B, S}, kInt64, kCPU).alloc(); + for (int d = 0; d < 3; ++d) { + auto out_ptr = out.offsettedPtr({d, 0, 0}); + for (int64_t s = 0; s < S; ++s) { out_ptr[s] = s; } + } + return out; + } + + Tensor getPositionIdsPrefill(Tensor& input_ids, Tensor& image_grid_thw) const { + MLLM_RT_ASSERT_EQ(input_ids.shape().size(), 2); + + auto B = input_ids.shape()[0]; + auto S = input_ids.shape()[1]; + MLLM_RT_ASSERT_EQ(B, 1); + + Tensor position_ids = Tensor::empty({3, B, S}, kInt64, kCPU).alloc(); + + auto input_ids_ptr = input_ids.ptr(); + + auto fill_text_positions = [&](int start_seq, int len, int64_t start_id) { + for (int d = 0; d < 3; ++d) { + auto out_ptr = position_ids.offsettedPtr({d, 0, 0}); + for (int i = 0; i < len; ++i) { out_ptr[start_seq + i] = start_id + i; } + } + }; + + int seq_idx = 0; + int image_idx = 0; + int64_t current_max_position_id = -1; + const int total_images = image_grid_thw.isNil() ? 0 : image_grid_thw.shape()[0]; + + while (seq_idx < S) { + int next_vision = -1; + int next_audio = -1; + for (int i = seq_idx; i < S; ++i) { + if (input_ids_ptr[i] == cfg_.vision_start_token_id) { + next_vision = i; + break; + } + } + for (int i = seq_idx; i < S; ++i) { + if (input_ids_ptr[i] == cfg_.audio_start_token_id) { + next_audio = i; + break; + } + } + + if (next_vision == -1 && next_audio == -1) { + const int text_len = S - seq_idx; + if (text_len > 0) { fill_text_positions(seq_idx, text_len, current_max_position_id + 1); } + break; + } + + const bool is_vision = (next_vision != -1) && (next_audio == -1 || next_vision < next_audio); + const int segment_start = is_vision ? 
next_vision : next_audio; + + const int text_len = segment_start - seq_idx; + if (text_len > 0) { + fill_text_positions(seq_idx, text_len, current_max_position_id + 1); + current_max_position_id += text_len; + } + + if (is_vision) { + fill_text_positions(segment_start, 1, current_max_position_id + 1); + current_max_position_id += 1; + + int vision_end = -1; + for (int i = segment_start + 1; i < S; ++i) { + if (input_ids_ptr[i] == cfg_.vision_end_token_id) { + vision_end = i; + break; + } + } + MLLM_RT_ASSERT(vision_end != -1); + MLLM_RT_ASSERT(image_idx < total_images); + if (image_grid_thw.isNil()) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Missing grid_thw for Qwen2.5-Omni vision input."); + } + MLLM_RT_ASSERT_EQ(image_grid_thw.shape().size(), 2); + + std::vector image_positions; + for (int i = segment_start + 1; i < vision_end; ++i) { + if (input_ids_ptr[i] == cfg_.image_token_id) { + image_positions.push_back(i); + } else { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Unsupported token inside vision segment."); + } + } + + const int* grid_dims = image_grid_thw.offsettedPtr({image_idx, 0}); + const int grid_t = grid_dims[0]; + const int grid_h = grid_dims[1]; + const int grid_w = grid_dims[2]; + + const int image_token_len = (grid_t * grid_h * grid_w) + / (cfg_.visual_spatial_merge_size * cfg_.visual_spatial_merge_size); + MLLM_RT_ASSERT_EQ(static_cast(image_positions.size()), image_token_len); + + const int inputs_t = grid_t; + const int inputs_h = grid_h / cfg_.visual_spatial_merge_size; + const int inputs_w = grid_w / cfg_.visual_spatial_merge_size; + + const int64_t vision_start_id = current_max_position_id + 1; + int pos_counter = 0; + for (int ti = 0; ti < inputs_t; ++ti) { + const int64_t t_id = vision_start_id + static_cast(ti) * cfg_.position_id_per_seconds; + for (int hi = 0; hi < inputs_h; ++hi) { + for (int wi = 0; wi < inputs_w; ++wi) { + const auto seq_pos = image_positions[pos_counter++]; + *position_ids.offsettedPtr({0, 0, seq_pos}) = t_id; + 
*position_ids.offsettedPtr({1, 0, seq_pos}) = vision_start_id + hi; + *position_ids.offsettedPtr({2, 0, seq_pos}) = vision_start_id + wi; + } + } + } + + const int64_t dim_0_tail = vision_start_id + static_cast(inputs_t - 1) * cfg_.position_id_per_seconds; + const int64_t dim_1_tail = vision_start_id + inputs_h - 1; + const int64_t dim_2_tail = vision_start_id + inputs_w - 1; + current_max_position_id = std::max({dim_0_tail, dim_1_tail, dim_2_tail}); + + fill_text_positions(vision_end, 1, current_max_position_id + 1); + current_max_position_id += 1; + + seq_idx = vision_end + 1; + image_idx += 1; + } else { + fill_text_positions(segment_start, 1, current_max_position_id + 1); + current_max_position_id += 1; + + int audio_end = -1; + for (int i = segment_start + 1; i < S; ++i) { + if (input_ids_ptr[i] == cfg_.audio_end_token_id) { + audio_end = i; + break; + } + } + MLLM_RT_ASSERT(audio_end != -1); + + std::vector audio_positions; + for (int i = segment_start + 1; i < audio_end; ++i) { + if (input_ids_ptr[i] == cfg_.audio_token_id) { + audio_positions.push_back(i); + } else { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Unsupported token inside audio segment."); + } + } + + const int audio_len = static_cast(audio_positions.size()); + if (audio_len == 0) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Empty audio tokens inside audio segment."); + } + const int64_t audio_start_id = current_max_position_id + 1; + for (int i = 0; i < audio_len; ++i) { + const int64_t pos_id = audio_start_id + i; + for (int d = 0; d < 3; ++d) { + *position_ids.offsettedPtr({d, 0, audio_positions[i]}) = pos_id; + } + } + current_max_position_id += audio_len; + + fill_text_positions(audio_end, 1, current_max_position_id + 1); + current_max_position_id += 1; + + seq_idx = audio_end + 1; + } + } + + MLLM_RT_ASSERT_EQ(image_idx, total_images); + return position_ids; + } + + const Qwen2_5OmniConfig& cfg_; + nn::StaticCache kv_cache_; +}; + +struct Qwen2_5OmniAudioGenerationConfig { + int32_t 
thinker_max_new_tokens = 1024; + bool thinker_do_sample = false; + int32_t thinker_top_k = 0; + float thinker_top_p = 0.0f; + float thinker_temperature = 1.0f; + + int32_t talker_max_new_tokens = 1024; + int32_t talker_min_new_tokens = 128; + bool talker_do_sample = true; + int32_t talker_top_k = 40; + float talker_top_p = 0.8f; + float talker_temperature = 0.9f; + float talker_repetition_penalty = 1.05f; + std::vector talker_eos_token_ids = {}; + bool suppress_codec_bos = true; + + int32_t token2wav_num_steps = 10; + float token2wav_guidance_scale = 0.5f; + float token2wav_sway_coefficient = -1.0f; +}; + +struct Qwen2_5OmniAudioGenerationResult { + Tensor sequences = Tensor::nil(); + Tensor wav = Tensor::nil(); +}; + +class Qwen2_5OmniForConditionalGeneration { + public: + explicit Qwen2_5OmniForConditionalGeneration(const Qwen2_5OmniConfig& cfg) + : cfg_(cfg), + thinker_(cfg_), + talker_("talker", cfg_.talker_cfg), + token2wav_("token2wav", cfg_.token2wav_cfg) {} + + void load(const ParameterFile::ptr_t& param) { + thinker_.thinker_.load(param); + if (cfg_.enable_audio_output) { + talker_.load(param); + token2wav_.load(param); + } + } + + void loadSpeakers(const std::string& path) { speaker_map_ = loadSpeakerMap(path); } + + void clearCache() { + thinker_.clearCache(); + talker_.clearCache(); + } + + Qwen2_5OmniAudioGenerationResult generateAudio(const ARGenerationOutputPast& input, const Qwen2_5OmniAudioGenerationConfig& gen_cfg, + const std::string& speaker = "") { + if (!cfg_.enable_audio_output) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Audio output is disabled in Qwen2.5-Omni config."); + } + if (speaker_map_.speakers.empty()) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Speaker map is empty. Call loadSpeakers() first."); + } + + const std::string speaker_name = speaker.empty() ? 
speaker_map_.default_speaker : speaker; + auto spk_it = speaker_map_.speakers.find(speaker_name); + if (spk_it == speaker_map_.speakers.end()) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Unknown speaker '{}'.", speaker_name); + } + + auto thinker_output = runThinkerGeneration(input, gen_cfg); + if (thinker_output.generated_ids.empty()) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Thinker produced no tokens; cannot run talker."); + } + + auto talker_output = runTalkerGeneration(input, thinker_output, spk_it->second, gen_cfg); + auto wav = token2wav_.forward(talker_output, spk_it->second.cond.to(kFloat32), spk_it->second.ref_mel.to(kFloat32), + gen_cfg.token2wav_num_steps, gen_cfg.token2wav_guidance_scale, gen_cfg.token2wav_sway_coefficient); + + return { + .sequences = thinker_output.sequences, + .wav = wav, + }; + } + + Tensor generateReferenceWav(const std::string& speaker = "") { + if (speaker_map_.speakers.empty()) { return Tensor::nil(); } + const std::string speaker_name = speaker.empty() ? 
speaker_map_.default_speaker : speaker; + auto spk_it = speaker_map_.speakers.find(speaker_name); + if (spk_it == speaker_map_.speakers.end()) { return Tensor::nil(); } + auto ref_mel = spk_it->second.ref_mel.to(kFloat32); + ref_mel = ref_mel.permute({0, 2, 1}); + if (!ref_mel.isContiguous()) { ref_mel = ref_mel.contiguous(); } + return token2wav_.vocodeMel(ref_mel); + } + + private: + Qwen2_5OmniConfig cfg_; + Qwen2_5OmniSpeakerMap speaker_map_{}; + + public: + Qwen2_5OmniForCausalLM thinker_; + Qwen2_5OmniTalker talker_; + Qwen2_5OmniToken2WavModel token2wav_; + + private: + struct ThinkerGenerationOutput { + Tensor sequences = Tensor::nil(); + std::vector generated_ids; + std::vector token_embeddings; + std::vector token_hidden_states; + int32_t prompt_len = 0; + }; + + static Tensor makeTokenTensor(int64_t token_id) { + Tensor out = Tensor::empty({1, 1}, kInt64, kCPU).alloc(); + out.at({0, 0}) = token_id; + return out; + } + + static Tensor makeTokenTensor(const std::vector& ids) { + Tensor out = Tensor::empty({1, static_cast(ids.size())}, kInt64, kCPU).alloc(); + auto* ptr = out.ptr(); + std::copy(ids.begin(), ids.end(), ptr); + return out; + } + + static Tensor concatTokenTensors(const std::vector& parts) { + MLLM_RT_ASSERT(!parts.empty()); + int32_t total_len = 0; + for (const auto& part : parts) { + MLLM_RT_ASSERT_EQ(part.shape().size(), 2); + MLLM_RT_ASSERT_EQ(part.shape()[0], 1); + MLLM_RT_ASSERT_EQ(part.dtype(), kInt64); + MLLM_RT_ASSERT_EQ(part.device(), kCPU); + total_len += part.shape()[1]; + } + + Tensor out = Tensor::empty({1, total_len}, kInt64, kCPU).alloc(); + auto* out_ptr = out.ptr(); + int32_t offset = 0; + for (const auto& part : parts) { + auto* in_ptr = part.ptr(); + int32_t len = part.shape()[1]; + std::copy(in_ptr, in_ptr + len, out_ptr + offset); + offset += len; + } + return out; + } + + static void zeroEmbeddingsByTokenId(Tensor& embeds, const Tensor& input_ids, int64_t token_id) { + MLLM_RT_ASSERT_EQ(input_ids.shape().size(), 2); + 
MLLM_RT_ASSERT_EQ(embeds.shape().size(), 3); + MLLM_RT_ASSERT_EQ(input_ids.shape()[1], embeds.shape()[1]); + + auto seq_len = input_ids.shape()[1]; + auto dim = embeds.shape()[2]; + auto* ids = input_ids.ptr(); + + if (embeds.dtype() == kFloat32) { + for (int s = 0; s < seq_len; ++s) { + if (ids[s] != token_id) continue; + auto* out_ptr = embeds.offsettedPtr({0, s, 0}); + std::fill(out_ptr, out_ptr + dim, 0.0f); + } + } else if (embeds.dtype() == kFloat16) { + for (int s = 0; s < seq_len; ++s) { + if (ids[s] != token_id) continue; + auto* out_ptr = embeds.offsettedPtr({0, s, 0}); + std::fill(out_ptr, out_ptr + dim, static_cast(0)); + } + } else { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Unsupported embedding dtype for Qwen2.5-Omni talker preparation."); + } + } + + static Tensor getLastLogits(const Tensor& logits) { + MLLM_RT_ASSERT_EQ(logits.shape().size(), 3); + if (logits.shape()[1] == 1) { return logits; } + return logits[{kAll, logits.shape()[1] - 1, kAll}]; + } + + static int64_t sampleFromDistribution(const std::vector& probs) { + std::random_device rd; + std::mt19937 gen(rd()); + std::discrete_distribution<> dist(probs.begin(), probs.end()); + return dist(gen); + } + + static int64_t categoricalSample(const Tensor& probs) { + MLLM_RT_ASSERT_EQ(probs.dtype(), kFloat32); + auto* prob_data = probs.ptr(); + int vocab_size = probs.shape().back(); + + std::vector cumulative_probs(vocab_size); + std::partial_sum(prob_data, prob_data + vocab_size, cumulative_probs.begin()); + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_real_distribution<> dis(0.0, 1.0); + float r = dis(gen); + + auto it = std::lower_bound(cumulative_probs.begin(), cumulative_probs.end(), r); + if (it == cumulative_probs.end()) { return static_cast(vocab_size - 1); } + return static_cast(std::distance(cumulative_probs.begin(), it)); + } + + static void applyRepetitionPenalty(Tensor& logits, const std::vector& token_ids, float penalty) { + if (penalty <= 1.0f || 
token_ids.empty()) { return; } + if (logits.dtype() != kFloat32) { logits = logits.to(kFloat32); } + + int vocab_size = logits.shape().back(); + if (logits.shape().size() == 2) { MLLM_RT_ASSERT_EQ(logits.shape()[0], 1); } + + std::unordered_set unique_ids; + unique_ids.reserve(token_ids.size()); + for (auto id : token_ids) { unique_ids.insert(id); } + + auto* logits_ptr = logits.ptr(); + for (auto id : unique_ids) { + if (id < 0 || id >= vocab_size) { continue; } + float& v = logits_ptr[id]; + v = (v < 0.0f) ? v * penalty : v / penalty; + } + } + + static void applyTopKLogits(Tensor& logits, int32_t top_k) { + if (top_k <= 0) { return; } + if (logits.dtype() != kFloat32) { logits = logits.to(kFloat32); } + if (logits.shape().size() == 2) { MLLM_RT_ASSERT_EQ(logits.shape()[0], 1); } + + int vocab_size = logits.shape().back(); + int k = std::min(std::max(top_k, 1), vocab_size); + + auto* logits_ptr = logits.ptr(); + std::vector indices(vocab_size); + std::iota(indices.begin(), indices.end(), 0); + std::partial_sort(indices.begin(), indices.begin() + k, indices.end(), + [&logits_ptr](int i1, int i2) { return logits_ptr[i1] > logits_ptr[i2]; }); + + float threshold = logits_ptr[indices[k - 1]]; + float neg_inf = -std::numeric_limits::infinity(); + for (int i = 0; i < vocab_size; ++i) { + if (logits_ptr[i] < threshold) { logits_ptr[i] = neg_inf; } + } + } + + static void applyTopPLogits(Tensor& logits, float top_p) { + if (top_p <= 0.0f || top_p >= 1.0f) { return; } + if (logits.dtype() != kFloat32) { logits = logits.to(kFloat32); } + if (logits.shape().size() == 2) { MLLM_RT_ASSERT_EQ(logits.shape()[0], 1); } + + int vocab_size = logits.shape().back(); + auto* logits_ptr = logits.ptr(); + + std::vector indices(vocab_size); + std::iota(indices.begin(), indices.end(), 0); + std::sort(indices.begin(), indices.end(), [&logits_ptr](int i1, int i2) { return logits_ptr[i1] > logits_ptr[i2]; }); + + float max_logit = logits_ptr[indices[0]]; + std::vector probs(vocab_size); + 
float sum_exp = 0.0f; + for (int i = 0; i < vocab_size; ++i) { + float exp_val = std::exp(logits_ptr[indices[i]] - max_logit); + probs[i] = exp_val; + sum_exp += exp_val; + } + if (sum_exp <= 0.0f) { return; } + for (auto& p : probs) { p /= sum_exp; } + + float cumulative = 0.0f; + int keep = 0; + for (int i = 0; i < vocab_size; ++i) { + cumulative += probs[i]; + keep++; + if (cumulative > top_p) { break; } + } + keep = std::max(keep, 1); + + float neg_inf = -std::numeric_limits::infinity(); + for (int i = keep; i < vocab_size; ++i) { + logits_ptr[indices[i]] = neg_inf; + } + } + + static int64_t sampleFromLogits(Tensor logits, bool do_sample) { + if (logits.dtype() != kFloat32) { logits = logits.to(kFloat32); } + if (!do_sample) { + auto* logits_ptr = logits.ptr(); + int vocab_size = logits.shape().back(); + auto max_it = std::max_element(logits_ptr, logits_ptr + vocab_size); + return static_cast(std::distance(logits_ptr, max_it)); + } + Tensor probs = nn::functional::softmax(logits, -1); + if (probs.dtype() != kFloat32) { probs = probs.to(kFloat32); } + return categoricalSample(probs); + } + + static int64_t sampleGreedyLocal(const Tensor& logits) { + Tensor last_logits = getLastLogits(logits); + if (last_logits.dtype() != kFloat32) { last_logits = last_logits.to(kFloat32); } + auto* logits_data = last_logits.ptr(); + int vocab_size = last_logits.shape().back(); + auto max_it = std::max_element(logits_data, logits_data + vocab_size); + return static_cast(std::distance(logits_data, max_it)); + } + + static int64_t sampleTemperatureLocal(const Tensor& logits, float temperature) { + Tensor last_logits = getLastLogits(logits); + if (last_logits.dtype() != kFloat32) { last_logits = last_logits.to(kFloat32); } + if (temperature != 1.0f && temperature > 0.0f) { last_logits = last_logits * (1.f / temperature); } + Tensor probs = nn::functional::softmax(last_logits, -1); + if (probs.dtype() != kFloat32) { probs = probs.to(kFloat32); } + return categoricalSample(probs); + 
} + + static int64_t sampleTopKLocal(const Tensor& logits, int k, float temperature) { + Tensor last_logits = getLastLogits(logits); + if (last_logits.dtype() != kFloat32) { last_logits = last_logits.to(kFloat32); } + if (temperature != 1.0f && temperature > 0.0f) { last_logits = last_logits * (1.f / temperature); } + Tensor probs = nn::functional::softmax(last_logits, -1); + if (probs.dtype() != kFloat32) { probs = probs.to(kFloat32); } + + auto* prob_data = probs.ptr(); + int vocab_size = probs.shape().back(); + if (k <= 0 || k > vocab_size) { k = vocab_size; } + + std::vector indices(vocab_size); + std::iota(indices.begin(), indices.end(), 0); + std::partial_sort(indices.begin(), indices.begin() + k, indices.end(), + [&prob_data](int i1, int i2) { return prob_data[i1] > prob_data[i2]; }); + + std::vector top_k_probs(k); + float sum = 0.0f; + for (int i = 0; i < k; ++i) { + top_k_probs[i] = prob_data[indices[i]]; + sum += top_k_probs[i]; + } + if (sum <= 0.0f) { return static_cast(indices[0]); } + for (int i = 0; i < k; ++i) { top_k_probs[i] *= (1.f / sum); } + + return static_cast(indices[sampleFromDistribution(top_k_probs)]); + } + + static int64_t sampleTopPLocal(const Tensor& logits, float p, float temperature) { + Tensor last_logits = getLastLogits(logits); + if (last_logits.dtype() != kFloat32) { last_logits = last_logits.to(kFloat32); } + if (temperature != 1.0f && temperature > 0.0f) { last_logits = last_logits * (1.f / temperature); } + Tensor probs = nn::functional::softmax(last_logits, -1); + if (probs.dtype() != kFloat32) { probs = probs.to(kFloat32); } + + auto* prob_data = probs.ptr(); + int vocab_size = probs.shape().back(); + + std::vector indices(vocab_size); + std::iota(indices.begin(), indices.end(), 0); + std::sort(indices.begin(), indices.end(), [&prob_data](int i1, int i2) { return prob_data[i1] > prob_data[i2]; }); + + std::vector top_probs; + float cumulative_prob = 0.0f; + int i = 0; + for (; i < vocab_size && cumulative_prob < p; ++i) { 
+ top_probs.push_back(prob_data[indices[i]]); + cumulative_prob += prob_data[indices[i]]; + } + + float sum = std::accumulate(top_probs.begin(), top_probs.end(), 0.0f); + if (sum <= 0.0f) { return static_cast(indices[0]); } + for (float& prob : top_probs) { prob *= (1.f / sum); } + + return static_cast(indices[sampleFromDistribution(top_probs)]); + } + + int64_t sampleToken(const Tensor& logits, bool do_sample, int32_t top_k, float top_p, float temperature) { + bool use_sampling = do_sample || (temperature != 1.0f) || (top_k > 0) || (top_p > 0.0f); + if (use_sampling) { + if (top_k > 0) { return sampleTopKLocal(logits, top_k, temperature); } + if (top_p > 0.0f) { return sampleTopPLocal(logits, top_p, temperature); } + return sampleTemperatureLocal(logits, temperature); + } + return sampleGreedyLocal(logits); + } + + ThinkerGenerationOutput runThinkerGeneration(const ARGenerationOutputPast& input, const Qwen2_5OmniAudioGenerationConfig& gen_cfg) { + thinker_.clearCache(); + + ARGenerationOutputPast past = input; + ARGenerationArgs args; + args.emplace("output_hidden_states", AnyValue(true)); + + const auto& input_ids = input.at("sequence"); + MLLM_RT_ASSERT_EQ(input_ids.shape().size(), 2); + + std::vector generated_ids; + std::vector token_embeddings; + std::vector token_hidden_states; + + for (int32_t step = 0; step < gen_cfg.thinker_max_new_tokens; ++step) { + auto output = thinker_.forward(past, args); + auto logits = output.at("sequence"); + + auto input_embeddings = output.at("input_embeddings"); + auto hidden_states = output.at("hidden_states"); + + if (step == 0) { + auto embeds_to_talker = input_embeddings.clone(); + if (input.count("input_features")) { zeroEmbeddingsByTokenId(embeds_to_talker, input_ids, cfg_.audio_token_id); } + if (input.count("img")) { zeroEmbeddingsByTokenId(embeds_to_talker, input_ids, cfg_.image_token_id); } + if (input.count("video")) { zeroEmbeddingsByTokenId(embeds_to_talker, input_ids, cfg_.video_token_id); } + 
token_embeddings.emplace_back(std::move(embeds_to_talker)); + } else { + token_embeddings.emplace_back(std::move(input_embeddings)); + } + token_hidden_states.emplace_back(std::move(hidden_states)); + + int64_t next_token_id = sampleToken(logits, gen_cfg.thinker_do_sample, gen_cfg.thinker_top_k, gen_cfg.thinker_top_p, + gen_cfg.thinker_temperature); + generated_ids.push_back(next_token_id); + + if (next_token_id == cfg_.eos_token_id) { break; } + + past = std::move(output); + past["sequence"] = makeTokenTensor(next_token_id); + } + + std::vector sequence_ids; + sequence_ids.reserve(input_ids.shape()[1] + generated_ids.size()); + auto* input_ptr = input_ids.ptr(); + for (int i = 0; i < input_ids.shape()[1]; ++i) { sequence_ids.push_back(input_ptr[i]); } + sequence_ids.insert(sequence_ids.end(), generated_ids.begin(), generated_ids.end()); + + return { + .sequences = makeTokenTensor(sequence_ids), + .generated_ids = std::move(generated_ids), + .token_embeddings = std::move(token_embeddings), + .token_hidden_states = std::move(token_hidden_states), + .prompt_len = input_ids.shape()[1], + }; + } + + Tensor runTalkerGeneration(const ARGenerationOutputPast& input, const ThinkerGenerationOutput& thinker_output, + const Qwen2_5OmniSpeakerParams& speaker_params, const Qwen2_5OmniAudioGenerationConfig& gen_cfg) { + if (thinker_output.generated_ids.empty()) { return Tensor::nil(); } + + talker_.clearCache(); + + const auto& input_ids = input.at("sequence"); + const auto& token_embeddings = thinker_output.token_embeddings; + const auto& token_hidden_states = thinker_output.token_hidden_states; + + std::vector reply_hidden_states(token_hidden_states.begin() + 1, token_hidden_states.end()); + std::vector reply_token_embeds(token_embeddings.begin() + 1, token_embeddings.end()); + + auto hidden_dtype = token_hidden_states[0].dtype(); + auto hidden_device = token_hidden_states[0].device(); + auto embed_dtype = token_embeddings[0].dtype(); + auto embed_device = 
token_embeddings[0].device(); + Tensor reply_hidden = reply_hidden_states.empty() + ? Tensor::empty({1, 0, token_hidden_states[0].shape()[2]}, hidden_dtype, hidden_device).alloc() + : nn::functional::concat(reply_hidden_states, 1); + Tensor reply_embeds = reply_token_embeds.empty() + ? Tensor::empty({1, 0, token_embeddings[0].shape()[2]}, embed_dtype, embed_device).alloc() + : nn::functional::concat(reply_token_embeds, 1); + auto thinker_reply_part = reply_hidden + reply_embeds; + if (thinker_reply_part.shape()[1] == 0) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Thinker response is too short for talker conditioning."); + } + + std::vector talker_text_ids; + talker_text_ids.reserve(input_ids.shape()[1] + 2); + auto* input_ptr = input_ids.ptr(); + for (int i = 0; i < input_ids.shape()[1]; ++i) { talker_text_ids.push_back(input_ptr[i]); } + talker_text_ids.push_back(speaker_params.bos_token); + talker_text_ids.push_back(thinker_output.generated_ids.front()); + auto talker_input_text_ids = makeTokenTensor(talker_text_ids); + + std::vector talker_codec_ids(input_ids.shape()[1] + 2, talker_.codec_mask_token()); + talker_codec_ids[input_ids.shape()[1]] = talker_.codec_pad_token(); + talker_codec_ids[input_ids.shape()[1] + 1] = talker_.codec_bos_token(); + auto talker_input_ids = makeTokenTensor(talker_codec_ids); + + auto talker_inputs_embeds = Tensor(token_hidden_states[0]); + talker_inputs_embeds = talker_inputs_embeds + token_embeddings[0]; + auto talker_text_bos_embed = thinker_.thinker_.model_.embedding_(makeTokenTensor(speaker_params.bos_token)); + auto first_reply = thinker_reply_part.shape()[1] > 0 + ? 
thinker_reply_part[{kAll, {0, 1}, kAll}] + : Tensor::empty({1, 0, talker_inputs_embeds.shape()[2]}, talker_inputs_embeds.dtype(), talker_inputs_embeds.device()) + .alloc(); + talker_inputs_embeds = nn::functional::concat({talker_inputs_embeds, talker_text_bos_embed, first_reply}, 1); + + auto eos_embedding = thinker_.thinker_.model_.embedding_(makeTokenTensor(talker_.text_eos_token())); + auto pad_embedding = thinker_.thinker_.model_.embedding_(makeTokenTensor(talker_.text_pad_token())); + Tensor reply_tail = + thinker_reply_part.shape()[1] > 1 + ? thinker_reply_part[{kAll, {1, thinker_reply_part.shape()[1]}, kAll}] + : Tensor::empty({1, 0, talker_inputs_embeds.shape()[2]}, talker_inputs_embeds.dtype(), talker_inputs_embeds.device()).alloc(); + thinker_reply_part = nn::functional::concat({reply_tail, eos_embedding, pad_embedding}, 1); + + Tensor talker_attention_mask = Tensor::nil(); + if (input.count("attention_mask")) { + auto mask = input.at("attention_mask"); + if (mask.dtype() != kFloat16 && mask.dtype() != kFloat32) { mask = mask.to(kFloat32); } + auto ones = Tensor::ones({1, 2}, mask.dtype(), mask.device()); + talker_attention_mask = nn::functional::concat({mask, ones}, 1); + } + + Tensor image_grid_thw = input.count("grid_thw") ? 
input.at("grid_thw") : Tensor::nil(); + + std::vector generated_codes; + Tensor position_ids = Tensor::nil(); + Tensor cur_input_ids = talker_input_ids; + Tensor cur_input_text_ids = talker_input_text_ids; + Tensor cur_inputs_embeds = talker_inputs_embeds; + Tensor cur_reply_part = thinker_reply_part; + + std::vector repetition_tokens = talker_codec_ids; + repetition_tokens.reserve(talker_codec_ids.size() + gen_cfg.talker_max_new_tokens); + + std::vector eos_ids = gen_cfg.talker_eos_token_ids; + if (eos_ids.empty()) { + eos_ids.push_back(talker_.codec_pad_token()); + eos_ids.push_back(talker_.codec_eos_token()); + } + + for (int32_t step = 0; step < gen_cfg.talker_max_new_tokens; ++step) { + auto output = talker_.forward(cur_input_ids, cur_input_text_ids, cur_reply_part, cur_inputs_embeds, talker_attention_mask, + image_grid_thw, position_ids); + + auto logits = output.logits; + auto last_logits = getLastLogits(logits); + + const int32_t vocab_size = last_logits.shape().back(); + + if (gen_cfg.suppress_codec_bos) { + auto* logits_ptr = last_logits.ptr(); + logits_ptr[talker_.codec_bos_token()] = -1e9f; + } + if (gen_cfg.talker_min_new_tokens > 0 && step < gen_cfg.talker_min_new_tokens) { + auto* logits_ptr = last_logits.ptr(); + for (int64_t eos_id : eos_ids) { + if (eos_id >= 0 && eos_id < vocab_size) { logits_ptr[eos_id] = -1e9f; } + } + } + applyRepetitionPenalty(last_logits, repetition_tokens, gen_cfg.talker_repetition_penalty); + + Tensor sample_logits = last_logits; + if (gen_cfg.talker_temperature != 1.0f && gen_cfg.talker_temperature > 0.0f) { + sample_logits = sample_logits * (1.f / gen_cfg.talker_temperature); + } + if (gen_cfg.talker_do_sample) { + if (gen_cfg.talker_top_k > 0) { applyTopKLogits(sample_logits, gen_cfg.talker_top_k); } + if (gen_cfg.talker_top_p > 0.0f) { applyTopPLogits(sample_logits, gen_cfg.talker_top_p); } + } + + int64_t next_token_id = sampleFromLogits(sample_logits, gen_cfg.talker_do_sample); + 
generated_codes.push_back(next_token_id); + repetition_tokens.push_back(next_token_id); + + if (std::find(eos_ids.begin(), eos_ids.end(), next_token_id) != eos_ids.end()) { break; } + + position_ids = output.position_ids; + cur_reply_part = output.thinker_reply_part; + cur_input_ids = makeTokenTensor(next_token_id); + cur_input_text_ids = Tensor::nil(); + cur_inputs_embeds = Tensor::nil(); + } + + if (!generated_codes.empty()) { generated_codes.pop_back(); } + if (generated_codes.empty()) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Talker produced no codec tokens."); + } + return makeTokenTensor(generated_codes); + } + + +}; + +} // namespace mllm::models::qwen2_5omni diff --git a/mllm/models/qwen2_5omni/modeling_qwen2_5omni_talker.hpp b/mllm/models/qwen2_5omni/modeling_qwen2_5omni_talker.hpp new file mode 100644 index 000000000..df8019a84 --- /dev/null +++ b/mllm/models/qwen2_5omni/modeling_qwen2_5omni_talker.hpp @@ -0,0 +1,626 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "mllm/core/Parallel.hpp" +#include "mllm/core/SlicePrimitives.hpp" +#include "mllm/mllm.hpp" +#include "mllm/nn/Functional.hpp" +#include "mllm/nn/Module.hpp" +#include "mllm/nn/Nn.hpp" +#include "mllm/nn/lmcache/StaticCache.hpp" +#include "mllm/utils/Common.hpp" +#include "mllm/utils/Enumerate.hpp" + +#include "mllm/models/qwen2_5omni/configuration_qwen2_5omni.hpp" + +namespace mllm::models::qwen2_5omni { + +constexpr float kPi = 3.14159265358979323846f; + +inline auto makeTalkerRoPEInvFreq(int output_dim, float rope_theta) -> Tensor { + auto inv_freq = Tensor::empty({output_dim / 2}, kFloat32, kCPU).alloc(); + auto inv_freq_ptr = inv_freq.ptr(); + for (int i = 0; i < output_dim / 2; i++) { inv_freq_ptr[i] = 1.0f / std::pow(rope_theta, 2.0f * i / output_dim); } + return inv_freq; +} + +inline auto makeTalkerPositionEmbedding(Tensor& position_ids, const Tensor& inv_freq, const std::vector& mrope_section) + -> std::pair { + MLLM_RT_ASSERT_EQ(position_ids.shape().size(), 3); + MLLM_RT_ASSERT_EQ(position_ids.shape()[1], 1); + + Tensor tmp_sin = Tensor::empty({3, position_ids.shape()[2], inv_freq.shape()[0] * 2}).alloc(); + Tensor tmp_cos = Tensor::empty({3, position_ids.shape()[2], inv_freq.shape()[0] * 2}).alloc(); + + for (int b = 0; b < 3; ++b) { + for (int d = 0; d < inv_freq.shape()[0]; ++d) { + for (int s = 0; s < position_ids.shape()[2]; ++s) { + auto value = inv_freq.ptr()[d] * (*position_ids.offsettedPtr({b, 0, s})); + *tmp_cos.offsettedPtr({b, s, d}) = cosf(value); + *tmp_cos.offsettedPtr({b, s, d + inv_freq.shape()[0]}) = cosf(value); + *tmp_sin.offsettedPtr({b, s, d}) = sinf(value); + *tmp_sin.offsettedPtr({b, s, d + inv_freq.shape()[0]}) = sinf(value); + } + } + } + + Tensor sin = Tensor::nil(); + Tensor cos = Tensor::nil(); + + if (!mrope_section.empty()) { + auto double_rope_section = mrope_section; + for (int i : 
mrope_section) { double_rope_section.push_back(i); } + + int num_rows = tmp_sin.shape()[1]; + int num_cols = tmp_sin.shape()[2]; + + sin = Tensor::empty({num_rows, num_cols}, kFloat32, kCPU).alloc(); + cos = Tensor::empty({num_rows, num_cols}, kFloat32, kCPU).alloc(); + + std::vector start_cols; + int current_start = 0; + start_cols.push_back(current_start); + for (int s : double_rope_section) { + current_start += s; + start_cols.push_back(current_start); + } + + for (int j = 0; j < static_cast(double_rope_section.size()); ++j) { + int layer = j % 3; + int s_j = double_rope_section[j]; + int start_col_in = start_cols[j]; + int start_col_out = start_cols[j]; + for (int row = 0; row < num_rows; ++row) { + auto in_cos_row_ptr = tmp_cos.offsettedPtr({layer, row, 0}); + auto out_cos_row_ptr = cos.offsettedPtr({row, 0}); + for (int c = 0; c < s_j; ++c) { out_cos_row_ptr[start_col_out + c] = in_cos_row_ptr[start_col_in + c]; } + + auto in_sin_row_ptr = tmp_sin.offsettedPtr({layer, row, 0}); + auto out_sin_row_ptr = sin.offsettedPtr({row, 0}); + for (int c = 0; c < s_j; ++c) { out_sin_row_ptr[start_col_out + c] = in_sin_row_ptr[start_col_in + c]; } + } + } + } else { + sin = tmp_sin; + cos = tmp_cos; + } + + return {sin, cos}; +} + +struct Qwen2_5OmniSpeakerParams { + int64_t bos_token = 0; + Tensor cond = Tensor::nil(); + Tensor ref_mel = Tensor::nil(); +}; + +struct Qwen2_5OmniSpeakerMap { + std::unordered_map speakers; + std::string default_speaker; +}; + +inline Tensor tensorFromJson(const nlohmann::ordered_json& obj) { + if (!obj.contains("shape") || !obj.contains("data")) { + MLLM_ERROR_EXIT(ExitCode::kIOError, "Invalid speaker json entry: missing shape/data."); + } + auto shape = obj["shape"].get>(); + auto data = obj["data"].get>(); + + int64_t expected = 1; + for (auto dim : shape) { expected *= dim; } + MLLM_RT_ASSERT_EQ(expected, static_cast(data.size())); + + Tensor out = Tensor::empty(shape, kFloat32, kCPU).alloc(); + std::copy(data.begin(), data.end(), 
out.ptr()); + return out; +} + +inline Qwen2_5OmniSpeakerMap loadSpeakerMap(const std::string& path) { + std::ifstream in(path); + if (!in.is_open()) { MLLM_ERROR_EXIT(ExitCode::kIOError, "Failed to open spk_dict.json at {}", path); } + + nlohmann::ordered_json root; + in >> root; + + Qwen2_5OmniSpeakerMap map; + bool first = true; + for (auto it = root.begin(); it != root.end(); ++it) { + const auto& name = it.key(); + const auto& entry = it.value(); + Qwen2_5OmniSpeakerParams params; + params.bos_token = entry.value("bos_token", 0); + params.cond = tensorFromJson(entry["cond"]); + params.ref_mel = tensorFromJson(entry["ref_mel"]); + map.speakers.emplace(name, std::move(params)); + if (first) { + map.default_speaker = name; + first = false; + } + } + + if (map.speakers.empty()) { MLLM_ERROR_EXIT(ExitCode::kIOError, "Empty speaker map in {}", path); } + return map; +} + +class Qwen2_5OmniTalkerMLP final : public nn::Module { + nn::Linear gate_proj_; + nn::Linear up_proj_; + nn::Linear down_proj_; + nn::SiLU silu_; + + public: + Qwen2_5OmniTalkerMLP() = default; + Qwen2_5OmniTalkerMLP(const std::string& name, const Qwen2_5OmniTalkerConfig& cfg) : nn::Module(name) { + gate_proj_ = reg("gate_proj", cfg.hidden_size, cfg.intermediate_size, false); + silu_ = reg("act"); + up_proj_ = reg("up_proj", cfg.hidden_size, cfg.intermediate_size, false); + down_proj_ = reg("down_proj", cfg.intermediate_size, cfg.hidden_size, false); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = gate_proj_(inputs[0]); + x = silu_(x); + auto y = up_proj_(inputs[0]); + x = x * y; + x = down_proj_(x); + return {x}; + } +}; + +class Qwen2_5OmniTalkerAttention final : public nn::Module { + nn::Linear q_proj_; + nn::Linear k_proj_; + nn::Linear v_proj_; + nn::Linear o_proj_; + nn::MultimodalRoPE q_rope_; + nn::MultimodalRoPE k_rope_; + nn::CausalMask mask_; + nn::Softmax softmax_; + + int hidden_size_; + int head_dim_; + int num_attention_heads_; + 
int num_key_value_heads_; + int num_key_value_groups_; + + public: + Qwen2_5OmniTalkerAttention() = default; + + Qwen2_5OmniTalkerAttention(const std::string& name, const Qwen2_5OmniTalkerConfig& cfg) : nn::Module(name) { + hidden_size_ = cfg.hidden_size; + head_dim_ = cfg.head_dim; + num_attention_heads_ = cfg.num_attention_heads; + num_key_value_heads_ = cfg.num_key_value_heads; + num_key_value_groups_ = num_attention_heads_ / num_key_value_heads_; + + q_proj_ = reg("q_proj", hidden_size_, head_dim_ * num_attention_heads_, true); + k_proj_ = reg("k_proj", hidden_size_, head_dim_ * num_key_value_heads_, true); + v_proj_ = reg("v_proj", hidden_size_, head_dim_ * num_key_value_heads_, true); + o_proj_ = reg("o_proj", head_dim_ * num_attention_heads_, hidden_size_, false); + + q_rope_ = reg( + "q_rope", aops::Qwen2VLMultimodalRoPEOpOptions{.rope_theta = cfg.rope_theta, + .max_position_embeddings = cfg.max_position_embeddings, + .mrope_section = cfg.mrope_section}); + k_rope_ = reg( + "k_rope", aops::Qwen2VLMultimodalRoPEOpOptions{.rope_theta = cfg.rope_theta, + .max_position_embeddings = cfg.max_position_embeddings, + .mrope_section = cfg.mrope_section}); + + mask_ = reg("mask"); + softmax_ = reg("softmax", -1); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = inputs[0]; + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto past_kv_cache = args[0].get(); + + auto query_states = q_proj_(x); + auto key_states = k_proj_(x); + auto value_states = v_proj_(x); + + int B = inputs[0].shape()[0]; + int S = inputs[0].shape()[1]; + + query_states = query_states.view({B, S, num_attention_heads_, head_dim_}); + key_states = key_states.view({B, S, num_key_value_heads_, head_dim_}); + value_states = value_states.view({B, S, num_key_value_heads_, head_dim_}); + + query_states = query_states.transpose(1, 2); + key_states = key_states.transpose(1, 2); + value_states = value_states.transpose(1, 2); + + 
query_states = q_rope_(query_states, llm_embedding_sin, llm_embedding_cos); + key_states = k_rope_(key_states, llm_embedding_sin, llm_embedding_cos); + + auto [k, v] = past_kv_cache->updateKVCache(layer_idx_, key_states, value_states); + key_states = k; + value_states = v; + + Tensor attn; + if (key_states.dtype() == kFloat32) { + attn = nn::functional::matmul(query_states, key_states, false, true) * (1.f / sqrtf(head_dim_)); + attn = mask_(attn); + attn = softmax_(attn); + } else if (key_states.dtype() == kFloat16) { + attn = nn::functional::matmul(query_states.to(kFloat32), key_states.to(kFloat32), false, true) * (1.f / sqrtf(head_dim_)); + attn = mask_(attn); + attn = softmax_(attn); + attn = attn.to(kFloat16); + } + + auto output = nn::functional::matmul(attn, value_states); + output = output.transpose(1, 2).view({B, S, num_attention_heads_ * head_dim_}); + output = o_proj_(output); + return {output}; + } + + int layer_idx_ = 0; +}; + +class Qwen2_5OmniTalkerDecoder final : public nn::Module { + public: + Qwen2_5OmniTalkerAttention self_attn_; + Qwen2_5OmniTalkerMLP mlp_; + nn::RMSNorm input_layer_norm_; + nn::RMSNorm post_attention_layer_norm_; + + Qwen2_5OmniTalkerDecoder() = default; + + Qwen2_5OmniTalkerDecoder(const std::string& name, const Qwen2_5OmniTalkerConfig& cfg) : nn::Module(name) { + self_attn_ = reg("self_attn", cfg); + mlp_ = reg("mlp", cfg); + input_layer_norm_ = reg("input_layernorm", cfg.rms_norm_eps); + post_attention_layer_norm_ = reg("post_attention_layernorm", cfg.rms_norm_eps); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto& kv_cache = args[0]; + + auto x = input_layer_norm_(inputs[0]); + x = self_attn_(x, llm_embedding_sin, llm_embedding_cos, kv_cache)[0]; + auto tmp = x + inputs[0]; + x = post_attention_layer_norm_(tmp); + x = mlp_(x)[0]; + x = x + tmp; + return {x}; + } +}; + +class Qwen2_5OmniTalkerModel 
final : public nn::Module { + nn::ModuleList decode_blocks_; + nn::RMSNorm norm_; + + public: + Qwen2_5OmniTalkerModel() = default; + + Qwen2_5OmniTalkerModel(const std::string& name, const Qwen2_5OmniTalkerConfig& cfg) : nn::Module(name) { + decode_blocks_ = reg>("layers", cfg.num_hidden_layers, cfg); + for (auto [idx, b] : enumerate(decode_blocks_.list())) { b.self_attn_.layer_idx_ = idx; } + + norm_ = reg("norm", cfg.rms_norm_eps); + embedding_ = reg("embed_tokens", cfg.vocab_size, cfg.embedding_size); + + auto inv = makeTalkerRoPEInvFreq(cfg.head_dim, cfg.rope_theta); + registerBuffer("inv_freq", inv); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto& blocks = decode_blocks_.list(); + auto x = inputs[0]; + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto& kv_cache = args[0]; + + for (auto& block : blocks) { x = block(x, llm_embedding_sin, llm_embedding_cos, kv_cache)[0]; } + x = norm_(x); + + return {x}; + } + + nn::Embedding embedding_; +}; + +struct Qwen2_5OmniTalkerOutput { + Tensor logits = Tensor::nil(); + Tensor thinker_reply_part = Tensor::nil(); + Tensor position_ids = Tensor::nil(); +}; + +class Qwen2_5OmniTalker final : public nn::Module { + public: + Qwen2_5OmniTalker() = delete; + Qwen2_5OmniTalker(const std::string& name, const Qwen2_5OmniTalkerConfig& cfg) : nn::Module(name), cfg_(cfg) { + thinker_to_talker_proj_ = reg("thinker_to_talker_proj", cfg.embedding_size, cfg.hidden_size, true); + model_ = reg("model", cfg); + codec_head_ = reg("codec_head", cfg.hidden_size, cfg.vocab_size, false); + + kv_cache_ = nn::StaticCache(cfg.max_position_embeddings, cfg.num_hidden_layers, cfg.num_attention_heads, cfg.num_key_value_heads, + cfg.head_dim, kFloat32, kFloat32, kCPU, false); + + codec_bos_token_ = cfg.tts_codec_start_token_id; + codec_eos_token_ = cfg.tts_codec_end_token_id; + codec_pad_token_ = cfg.tts_codec_pad_token_id; + codec_mask_token_ = 
cfg.tts_codec_mask_token_id; + text_bos_token_ = cfg.tts_text_start_token_id; + text_eos_token_ = cfg.tts_text_end_token_id; + text_pad_token_ = cfg.tts_text_pad_token_id; + } + + void clearCache() { + kv_cache_.clearCache(); + rope_deltas_ = Tensor::nil(); + } + + Qwen2_5OmniTalkerOutput forward(const Tensor& input_ids, const Tensor& input_text_ids, Tensor thinker_reply_part, + Tensor inputs_embeds, const Tensor& attention_mask, const Tensor& image_grid_thw, + Tensor position_ids) { + Tensor ids_for_pos = input_text_ids.isNil() ? input_ids : input_text_ids; + position_ids = getPositionIds(ids_for_pos, image_grid_thw, position_ids); + + const bool prefill = kv_cache_.getCurrentSeqCnt(0) == 0; + if (!inputs_embeds.isNil() && prefill) { + const auto S = inputs_embeds.shape()[1]; + MLLM_RT_ASSERT(S >= 2); + + auto bos_token = Tensor::empty({1, 1}, kInt64, kCPU).alloc(); + bos_token.at({0, 0}) = codec_bos_token_; + auto bos_embed = model_.embedding_(bos_token); + + auto pad_token = Tensor::empty({1, 1}, kInt64, kCPU).alloc(); + pad_token.at({0, 0}) = codec_pad_token_; + auto pad_embed = model_.embedding_(pad_token); + + auto embed_dim = inputs_embeds.shape()[2]; + if (inputs_embeds.dtype() == kFloat32) { + auto* out_ptr = inputs_embeds.offsettedPtr({0, S - 1, 0}); + auto* pad_ptr = inputs_embeds.offsettedPtr({0, S - 2, 0}); + auto* bos_ptr = bos_embed.ptr(); + auto* pad_src_ptr = pad_embed.ptr(); + for (int d = 0; d < embed_dim; ++d) { + out_ptr[d] += bos_ptr[d]; + pad_ptr[d] += pad_src_ptr[d]; + } + } else if (inputs_embeds.dtype() == kFloat16) { + auto* out_ptr = inputs_embeds.offsettedPtr({0, S - 1, 0}); + auto* pad_ptr = inputs_embeds.offsettedPtr({0, S - 2, 0}); + auto* bos_ptr = bos_embed.ptr(); + auto* pad_src_ptr = pad_embed.ptr(); + for (int d = 0; d < embed_dim; ++d) { + out_ptr[d] = static_cast(static_cast(out_ptr[d]) + static_cast(bos_ptr[d])); + pad_ptr[d] = static_cast(static_cast(pad_ptr[d]) + static_cast(pad_src_ptr[d])); + } + } + } + + if 
(inputs_embeds.isNil()) { + auto codec_embeds = model_.embedding_(input_ids); + inputs_embeds = codec_embeds + thinker_reply_part[{kAll, {0, 1}, kAll}]; + if (thinker_reply_part.shape()[1] > 1) { + thinker_reply_part = thinker_reply_part[{kAll, {1, thinker_reply_part.shape()[1]}, kAll}]; + } + } + + auto [llm_embedding_sin, llm_embedding_cos] = + makeTalkerPositionEmbedding(position_ids, model_.getBuffer("inv_freq"), cfg_.mrope_section); + + auto talker_lm_input = thinker_to_talker_proj_(inputs_embeds); + auto hidden_states = model_(talker_lm_input, llm_embedding_sin, llm_embedding_cos, AnyValue(&kv_cache_))[0]; + auto logits = codec_head_(hidden_states).to(kFloat32); + + return { + .logits = logits, + .thinker_reply_part = thinker_reply_part, + .position_ids = position_ids, + }; + } + + int64_t codec_bos_token() const { return codec_bos_token_; } + int64_t codec_eos_token() const { return codec_eos_token_; } + int64_t codec_pad_token() const { return codec_pad_token_; } + int64_t codec_mask_token() const { return codec_mask_token_; } + int64_t text_eos_token() const { return text_eos_token_; } + int64_t text_pad_token() const { return text_pad_token_; } + int64_t text_bos_token() const { return text_bos_token_; } + + Qwen2_5OmniTalkerModel model_; + + private: + Tensor getPositionIds(const Tensor& input_ids, const Tensor& image_grid_thw, const Tensor& position_ids) const { + MLLM_RT_ASSERT_EQ(input_ids.shape().size(), 2); + + bool has_multimodal = false; + auto input_ids_ptr = input_ids.ptr(); + auto seq_len = input_ids.shape()[1]; + for (int s = 0; s < seq_len; ++s) { + if (input_ids_ptr[s] == cfg_.vision_start_token_id || input_ids_ptr[s] == cfg_.audio_start_token_id) { + has_multimodal = true; + break; + } + } + + if (has_multimodal) { return getPositionIdsPrefill(input_ids, image_grid_thw); } + + if (!position_ids.isNil()) { + auto last_pos = position_ids.constAt({0, 0, position_ids.shape()[2] - 1}); + auto ret_position_ids = Tensor::empty({3, 1, 1}, kInt64, 
kCPU).alloc(); + *ret_position_ids.offsettedPtr({0, 0, 0}) = last_pos + 1; + *ret_position_ids.offsettedPtr({1, 0, 0}) = last_pos + 1; + *ret_position_ids.offsettedPtr({2, 0, 0}) = last_pos + 1; + return ret_position_ids; + } + + auto B = input_ids.shape()[0]; + auto S = seq_len; + MLLM_RT_ASSERT_EQ(B, 1); + + Tensor out = Tensor::empty({3, B, S}, kInt64, kCPU).alloc(); + for (int d = 0; d < 3; ++d) { + auto out_ptr = out.offsettedPtr({d, 0, 0}); + for (int64_t s = 0; s < S; ++s) { out_ptr[s] = s; } + } + return out; + } + + Tensor getPositionIdsPrefill(const Tensor& input_ids, const Tensor& image_grid_thw) const { + MLLM_RT_ASSERT_EQ(input_ids.shape().size(), 2); + + auto B = input_ids.shape()[0]; + auto S = input_ids.shape()[1]; + MLLM_RT_ASSERT_EQ(B, 1); + + Tensor position_ids = Tensor::empty({3, B, S}, kInt64, kCPU).alloc(); + auto input_ids_ptr = input_ids.ptr(); + + auto fill_text_positions = [&](int start_seq, int len, int64_t start_id) { + for (int d = 0; d < 3; ++d) { + auto out_ptr = position_ids.offsettedPtr({d, 0, 0}); + for (int i = 0; i < len; ++i) { out_ptr[start_seq + i] = start_id + i; } + } + }; + + int seq_idx = 0; + int image_idx = 0; + int64_t current_max_position_id = -1; + const int total_images = image_grid_thw.isNil() ? 0 : image_grid_thw.shape()[0]; + + while (seq_idx < S) { + int next_vision = -1; + int next_audio = -1; + for (int i = seq_idx; i < S; ++i) { + if (input_ids_ptr[i] == cfg_.vision_start_token_id) { + next_vision = i; + break; + } + } + for (int i = seq_idx; i < S; ++i) { + if (input_ids_ptr[i] == cfg_.audio_start_token_id) { + next_audio = i; + break; + } + } + + if (next_vision == -1 && next_audio == -1) { + const int text_len = S - seq_idx; + if (text_len > 0) { fill_text_positions(seq_idx, text_len, current_max_position_id + 1); } + break; + } + + const bool is_vision = (next_vision != -1) && (next_audio == -1 || next_vision < next_audio); + const int segment_start = is_vision ? 
next_vision : next_audio; + + const int text_len = segment_start - seq_idx; + if (text_len > 0) { + fill_text_positions(seq_idx, text_len, current_max_position_id + 1); + current_max_position_id += text_len; + } + + if (is_vision) { + fill_text_positions(segment_start, 1, current_max_position_id + 1); + current_max_position_id += 1; + + int vision_end = -1; + for (int i = segment_start + 1; i < S; ++i) { + if (input_ids_ptr[i] == cfg_.vision_end_token_id) { + vision_end = i; + break; + } + } + MLLM_RT_ASSERT(vision_end != -1); + + if (image_idx >= total_images) { MLLM_ERROR_EXIT(ExitCode::kCoreError, "Image index out of range."); } + + auto grid_t = image_grid_thw.ptr()[image_idx * 3]; + auto grid_h = image_grid_thw.ptr()[image_idx * 3 + 1]; + auto grid_w = image_grid_thw.ptr()[image_idx * 3 + 2]; + int vision_len = grid_t * grid_h * grid_w; + vision_len /= (cfg_.spatial_merge_size * cfg_.spatial_merge_size); + + for (int i = 0; i < vision_len; ++i) { + const int pos = segment_start + 1 + i; + if (pos >= S) { break; } + for (int d = 0; d < 3; ++d) { + *position_ids.offsettedPtr({d, 0, pos}) = current_max_position_id + 1 + i; + } + } + current_max_position_id += vision_len; + + fill_text_positions(vision_end, 1, current_max_position_id + 1); + current_max_position_id += 1; + + seq_idx = vision_end + 1; + image_idx += 1; + } else { + fill_text_positions(segment_start, 1, current_max_position_id + 1); + current_max_position_id += 1; + + int audio_end = -1; + for (int i = segment_start + 1; i < S; ++i) { + if (input_ids_ptr[i] == cfg_.audio_end_token_id) { + audio_end = i; + break; + } + } + MLLM_RT_ASSERT(audio_end != -1); + + std::vector audio_positions; + for (int i = segment_start + 1; i < audio_end; ++i) { + if (input_ids_ptr[i] == cfg_.audio_token_id) { + audio_positions.push_back(i); + } else { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Unsupported token inside audio segment."); + } + } + const int audio_len = static_cast(audio_positions.size()); + if (audio_len 
== 0) { MLLM_ERROR_EXIT(ExitCode::kCoreError, "Empty audio tokens inside audio segment."); } + const int64_t audio_start_id = current_max_position_id + 1; + for (int i = 0; i < audio_len; ++i) { + const int64_t pos_id = audio_start_id + i; + for (int d = 0; d < 3; ++d) { + *position_ids.offsettedPtr({d, 0, audio_positions[i]}) = pos_id; + } + } + current_max_position_id += audio_len; + fill_text_positions(audio_end, 1, current_max_position_id + 1); + current_max_position_id += 1; + seq_idx = audio_end + 1; + } + } + + return position_ids; + } + + const Qwen2_5OmniTalkerConfig& cfg_; + nn::Linear thinker_to_talker_proj_; + nn::Linear codec_head_; + nn::StaticCache kv_cache_; + Tensor rope_deltas_ = Tensor::nil(); + + int64_t codec_bos_token_ = 0; + int64_t codec_eos_token_ = 0; + int64_t codec_pad_token_ = 0; + int64_t codec_mask_token_ = 0; + int64_t text_bos_token_ = 0; + int64_t text_eos_token_ = 0; + int64_t text_pad_token_ = 0; +}; + +} // namespace mllm::models::qwen2_5omni diff --git a/mllm/models/qwen2_5omni/modeling_qwen2_5omni_token2wav.hpp b/mllm/models/qwen2_5omni/modeling_qwen2_5omni_token2wav.hpp new file mode 100644 index 000000000..6e5939a44 --- /dev/null +++ b/mllm/models/qwen2_5omni/modeling_qwen2_5omni_token2wav.hpp @@ -0,0 +1,1508 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "mllm/core/Parallel.hpp" +#include "mllm/core/SlicePrimitives.hpp" +#include "mllm/mllm.hpp" +#include "mllm/nn/Functional.hpp" +#include "mllm/nn/Module.hpp" +#include "mllm/nn/Nn.hpp" +#include "mllm/utils/Common.hpp" +#include "mllm/utils/Enumerate.hpp" + +#include "mllm/models/qwen2_5omni/configuration_qwen2_5omni.hpp" + +namespace mllm::models::qwen2_5omni { + +namespace token2wav { + +constexpr float kPi = 3.14159265358979323846f; + +inline Tensor pad1dReflect(const Tensor& x, int32_t pad_left, int32_t pad_right) { + if (pad_left == 0 && pad_right == 0) { return x; } + return nn::functional::pad(x, {pad_left, pad_right}, aops::PadMode::kReflect); +} + +inline Tensor pad1dReplicate(const Tensor& x, int32_t pad_left, int32_t pad_right) { + if (pad_left == 0 && pad_right == 0) { return x; } + return nn::functional::pad(x, {pad_left, pad_right}, aops::PadMode::kReplicate); +} + +inline Tensor clampTensor(const Tensor& x, float min_val, float max_val) { + MLLM_RT_ASSERT_EQ(x.device(), kCPU); + MLLM_RT_ASSERT_EQ(x.dtype(), kFloat32); + + auto out = Tensor::empty(x.shape(), x.dtype(), x.device()).alloc(); + const auto* src = x.ptr(); + auto* dst = out.ptr(); + const auto numel = x.numel(); + + MLLM_CONDITIONAL_PARALLEL_FOR(numel > 1024, 4, idx, 0, numel, 1, { + float v = src[idx]; + v = std::min(std::max(v, min_val), max_val); + dst[idx] = v; + }); + return out; +} + +inline Tensor amplitudeToDb(const Tensor& amplitude, float min_db_level) { + MLLM_RT_ASSERT_EQ(amplitude.device(), kCPU); + MLLM_RT_ASSERT_EQ(amplitude.dtype(), kFloat32); + + const float min_level = std::exp(min_db_level / 20.0f * std::log(10.0f)); + const float log10_scale = 1.0f / std::log(10.0f); + + auto out = Tensor::empty(amplitude.shape(), amplitude.dtype(), amplitude.device()).alloc(); + const auto* src = amplitude.ptr(); + auto* dst = out.ptr(); + const 
auto numel = amplitude.numel(); + + MLLM_CONDITIONAL_PARALLEL_FOR(numel > 1024, 4, idx, 0, numel, 1, { + float v = std::max(src[idx], min_level); + dst[idx] = 20.0f * std::log(v) * log10_scale; + }); + + return out; +} + +inline Tensor normalizeSpectrogram(const Tensor& spectrogram, float max_value, float min_db) { + MLLM_RT_ASSERT_EQ(spectrogram.device(), kCPU); + MLLM_RT_ASSERT_EQ(spectrogram.dtype(), kFloat32); + + auto out = Tensor::empty(spectrogram.shape(), spectrogram.dtype(), spectrogram.device()).alloc(); + const auto* src = spectrogram.ptr(); + auto* dst = out.ptr(); + const auto numel = spectrogram.numel(); + + const float scale = (2.0f * max_value) / (-min_db); + MLLM_CONDITIONAL_PARALLEL_FOR(numel > 1024, 4, idx, 0, numel, 1, { + float v = scale * (src[idx] - min_db) - max_value; + v = std::min(std::max(v, -max_value), max_value); + dst[idx] = v; + }); + return out; +} + +inline float besselI0(float x) { + const float ax = std::abs(x); + if (ax < 3.75f) { + const float y = (ax / 3.75f); + const float y2 = y * y; + return 1.0f + y2 * (3.5156229f + + y2 * (3.0899424f + + y2 * (1.2067492f + + y2 * (0.2659732f + + y2 * (0.0360768f + + y2 * 0.0045813f))))); + } + + const float y = 3.75f / ax; + const float exp_ax = std::exp(ax); + return (exp_ax / std::sqrt(ax)) * + (0.39894228f + + y * (0.01328592f + + y * (0.00225319f + + y * (-0.00157565f + + y * (0.00916281f + + y * (-0.02057706f + + y * (0.02635537f + + y * (-0.01647633f + + y * 0.00392377f)))))))); +} + +inline Tensor kaiserSincFilter1d(float cutoff, float half_width, int32_t kernel_size) { + const bool is_even = (kernel_size % 2) == 0; + const int32_t half_size = kernel_size / 2; + + if (cutoff == 0.0f) { return Tensor::zeros({1, 1, kernel_size}, kFloat32, kCPU); } + + const float delta_f = 4.0f * half_width; + const float attenuation = 2.285f * static_cast(half_size - 1) * kPi * delta_f + 7.95f; + + float beta = 0.0f; + if (attenuation > 50.0f) { + beta = 0.1102f * (attenuation - 8.7f); + } else if 
(attenuation >= 21.0f) { + beta = 0.5842f * std::pow(attenuation - 21.0f, 0.4f) + 0.07886f * (attenuation - 21.0f); + } + + const float denom = besselI0(beta); + std::vector window(kernel_size, 1.0f); + for (int32_t n = 0; n < kernel_size; ++n) { + const float ratio = (2.0f * static_cast(n) / static_cast(kernel_size - 1)) - 1.0f; + const float val = std::sqrt(std::max(0.0f, 1.0f - ratio * ratio)); + window[n] = besselI0(beta * val) / denom; + } + + std::vector filter(kernel_size, 0.0f); + float sum = 0.0f; + for (int32_t n = 0; n < kernel_size; ++n) { + float t = static_cast(n) - static_cast(half_size); + if (is_even) { t += 0.5f; } + const float arg = 2.0f * cutoff * t; + const float sinc = (arg == 0.0f) ? 1.0f : std::sin(kPi * arg) / (kPi * arg); + const float v = 2.0f * cutoff * window[n] * sinc; + filter[n] = v; + sum += v; + } + + if (sum != 0.0f) { + for (auto& v : filter) { v /= sum; } + } + + auto out = Tensor::empty({1, 1, kernel_size}, kFloat32, kCPU).alloc(); + std::copy(filter.begin(), filter.end(), out.ptr()); + return out; +} + +inline Tensor convTranspose1dDepthwise(const Tensor& input, const Tensor& filter, int32_t stride) { + MLLM_RT_ASSERT_EQ(input.device(), kCPU); + MLLM_RT_ASSERT_EQ(input.dtype(), kFloat32); + MLLM_RT_ASSERT_EQ(filter.device(), kCPU); + MLLM_RT_ASSERT_EQ(filter.dtype(), kFloat32); + + const auto& in_shape = input.shape(); + const int32_t batch = in_shape[0]; + const int32_t channels = in_shape[1]; + const int32_t in_len = in_shape[2]; + const int32_t kernel = filter.shape()[2]; + + const int32_t out_len = (in_len - 1) * stride + kernel; + auto out = Tensor::zeros({batch, channels, out_len}, kFloat32, kCPU); + + const auto* in_ptr = input.ptr(); + const auto* filt_ptr = filter.ptr(); + auto* out_ptr = out.ptr(); + + const int32_t in_step = channels * in_len; + const int32_t out_step = channels * out_len; + + for (int32_t b = 0; b < batch; ++b) { + const float* in_b = in_ptr + b * in_step; + float* out_b = out_ptr + b * out_step; 
+ for (int32_t c = 0; c < channels; ++c) { + const float* in_c = in_b + c * in_len; + float* out_c = out_b + c * out_len; + const float* f = filt_ptr; + for (int32_t i = 0; i < in_len; ++i) { + const float v = in_c[i]; + const int32_t base = i * stride; + for (int32_t k = 0; k < kernel; ++k) { out_c[base + k] += v * f[k]; } + } + } + } + + return out; +} + +inline Tensor conv1dDepthwise(const Tensor& input, const Tensor& filter, int32_t stride) { + MLLM_RT_ASSERT_EQ(input.device(), kCPU); + MLLM_RT_ASSERT_EQ(input.dtype(), kFloat32); + MLLM_RT_ASSERT_EQ(filter.device(), kCPU); + MLLM_RT_ASSERT_EQ(filter.dtype(), kFloat32); + + const auto& in_shape = input.shape(); + const int32_t batch = in_shape[0]; + const int32_t channels = in_shape[1]; + const int32_t in_len = in_shape[2]; + const int32_t kernel = filter.shape()[2]; + + const int32_t out_len = (in_len - kernel) / stride + 1; + auto out = Tensor::zeros({batch, channels, out_len}, kFloat32, kCPU); + + const auto* in_ptr = input.ptr(); + const auto* filt_ptr = filter.ptr(); + auto* out_ptr = out.ptr(); + + const int32_t in_step = channels * in_len; + const int32_t out_step = channels * out_len; + + for (int32_t b = 0; b < batch; ++b) { + const float* in_b = in_ptr + b * in_step; + float* out_b = out_ptr + b * out_step; + for (int32_t c = 0; c < channels; ++c) { + const float* in_c = in_b + c * in_len; + float* out_c = out_b + c * out_len; + const float* f = filt_ptr; + for (int32_t o = 0; o < out_len; ++o) { + float sum = 0.0f; + const int32_t base = o * stride; + for (int32_t k = 0; k < kernel; ++k) { sum += in_c[base + k] * f[k]; } + out_c[o] = sum; + } + } + } + + return out; +} + +inline Tensor randomNormal(const std::vector& shape, float mean = 0.0f, float std = 1.0f) { + auto out = Tensor::empty(shape, kFloat32, kCPU).alloc(); + auto* ptr = out.ptr(); + const int64_t numel = out.numel(); + std::mt19937 gen(static_cast(mllm::Context::instance().getRandomState())); + std::normal_distribution dist(mean, std); + 
for (int64_t i = 0; i < numel; ++i) { ptr[i] = dist(gen); } + return out; +} + +inline Tensor linspace(float start, float end, int32_t steps) { + auto out = Tensor::empty({steps}, kFloat32, kCPU).alloc(); + auto* ptr = out.ptr(); + if (steps <= 1) { + if (steps == 1) { ptr[0] = start; } + return out; + } + const float step = (end - start) / static_cast(steps - 1); + for (int32_t i = 0; i < steps; ++i) { ptr[i] = start + step * static_cast(i); } + return out; +} + +inline Tensor repeatInterleave(const Tensor& input, int32_t repeats, int32_t dim) { + MLLM_RT_ASSERT_EQ(input.device(), kCPU); + MLLM_RT_ASSERT_EQ(input.dtype(), kFloat32); + MLLM_RT_ASSERT_EQ(dim, 1); + + if (repeats == 1) { return input; } + + const auto& shape = input.shape(); + const int32_t batch = shape[0]; + const int32_t seq_len = shape[1]; + const int32_t channels = shape[2]; + + auto out = Tensor::empty({batch, seq_len * repeats, channels}, kFloat32, kCPU).alloc(); + const auto* src = input.ptr(); + auto* dst = out.ptr(); + + const int64_t in_stride_b = static_cast(seq_len) * channels; + const int64_t out_stride_b = static_cast(seq_len) * repeats * channels; + + for (int32_t b = 0; b < batch; ++b) { + const float* src_b = src + b * in_stride_b; + float* dst_b = dst + b * out_stride_b; + for (int32_t s = 0; s < seq_len; ++s) { + const float* src_s = src_b + static_cast(s) * channels; + for (int32_t r = 0; r < repeats; ++r) { + float* dst_s = dst_b + (static_cast(s) * repeats + r) * channels; + std::memcpy(dst_s, src_s, sizeof(float) * channels); + } + } + } + + return out; +} + +class SnakeBeta final : public nn::Module { + nn::Param alpha_; + nn::Param beta_; + float no_div_by_zero_ = 1e-9f; + + public: + SnakeBeta() = default; + SnakeBeta(const std::string& name, int32_t in_features) : nn::Module(name) { + alpha_ = reg("alpha", getModuleName() + ".alpha", std::vector{in_features}); + beta_ = reg("beta", getModuleName() + ".beta", std::vector{in_features}); + } + + std::vector forward(const 
std::vector& inputs, const std::vector&) override { + auto x = inputs[0]; + MLLM_RT_ASSERT_EQ(x.device(), kCPU); + MLLM_RT_ASSERT_EQ(x.dtype(), kFloat32); + if (!x.isContiguous()) { x = x.contiguous(); } + + const auto& shape = x.shape(); + const int32_t batch = shape[0]; + const int32_t channels = shape[1]; + const int32_t seq_len = shape[2]; + + auto y = Tensor::empty(shape, kFloat32, kCPU).alloc(); + const auto* x_ptr = x.ptr(); + auto* y_ptr = y.ptr(); + + auto alpha = alpha_.weight(); + auto beta = beta_.weight(); + const auto* alpha_ptr = alpha.ptr(); + const auto* beta_ptr = beta.ptr(); + + const int32_t stride_c = seq_len; + const int32_t stride_b = channels * seq_len; + + for (int32_t b = 0; b < batch; ++b) { + for (int32_t c = 0; c < channels; ++c) { + const float a = std::exp(alpha_ptr[c]); + const float bb = std::exp(beta_ptr[c]); + const float inv_b = 1.0f / (bb + no_div_by_zero_); + const int32_t base = b * stride_b + c * stride_c; + for (int32_t t = 0; t < seq_len; ++t) { + float v = x_ptr[base + t]; + const float s = std::sin(v * a); + v = v + inv_b * (s * s); + y_ptr[base + t] = v; + } + } + } + + return {y}; + } + +}; + +class TorchActivation1d final : public nn::Module { + public: + TorchActivation1d() = default; + TorchActivation1d(const std::string& name, int32_t channels, int32_t up_ratio = 2, int32_t down_ratio = 2, + int32_t up_kernel_size = 12, int32_t down_kernel_size = 12) + : nn::Module(name), + up_ratio_(up_ratio), + down_ratio_(down_ratio), + up_kernel_size_(up_kernel_size), + down_kernel_size_(down_kernel_size) { + act_ = reg("act", channels); + + up_kernel_size_ = (up_kernel_size_ <= 0) ? static_cast(int(6 * up_ratio_ / 2) * 2) : up_kernel_size_; + up_stride_ = up_ratio_; + up_pad_ = up_kernel_size_ / up_ratio_ - 1; + up_pad_left_ = up_pad_ * up_stride_ + (up_kernel_size_ - up_stride_) / 2; + up_pad_right_ = up_pad_ * up_stride_ + (up_kernel_size_ - up_stride_ + 1) / 2; + + down_kernel_size_ = (down_kernel_size_ <= 0) ? 
static_cast(int(6 * down_ratio_ / 2) * 2) : down_kernel_size_; + down_stride_ = down_ratio_; + down_even_ = (down_kernel_size_ % 2) == 0; + down_pad_left_ = down_kernel_size_ / 2 - (down_even_ ? 1 : 0); + down_pad_right_ = down_kernel_size_ / 2; + + up_filter_ = kaiserSincFilter1d(0.5f / static_cast(up_ratio_), 0.6f / static_cast(up_ratio_), up_kernel_size_); + down_filter_ = + kaiserSincFilter1d(0.5f / static_cast(down_ratio_), 0.6f / static_cast(down_ratio_), down_kernel_size_); + } + + std::vector forward(const std::vector& inputs, const std::vector&) override { + auto x = inputs[0]; + x = upsample(x); + x = act_(x)[0]; + x = downsample(x); + return {x}; + } + + private: + Tensor upsample(const Tensor& input) const { + auto padded = pad1dReplicate(input, up_pad_, up_pad_); + auto out = convTranspose1dDepthwise(padded, up_filter_, up_stride_); + out = out * static_cast(up_ratio_); + if (up_pad_left_ > 0 || up_pad_right_ > 0) { + auto length = out.shape()[2]; + auto start = up_pad_left_; + auto end = length - up_pad_right_; + out = out[{kAll, kAll, {start, end}}]; + } + return out; + } + + Tensor downsample(const Tensor& input) const { + auto padded = pad1dReplicate(input, down_pad_left_, down_pad_right_); + auto out = conv1dDepthwise(padded, down_filter_, down_stride_); + return out; + } + + SnakeBeta act_; + int32_t up_ratio_ = 2; + int32_t down_ratio_ = 2; + int32_t up_kernel_size_ = 12; + int32_t down_kernel_size_ = 12; + int32_t up_stride_ = 2; + int32_t down_stride_ = 2; + int32_t up_pad_ = 0; + int32_t up_pad_left_ = 0; + int32_t up_pad_right_ = 0; + int32_t down_pad_left_ = 0; + int32_t down_pad_right_ = 0; + bool down_even_ = false; + Tensor up_filter_ = Tensor::nil(); + Tensor down_filter_ = Tensor::nil(); +}; + +class TimeDelayNetBlock final : public nn::Module { + public: + TimeDelayNetBlock() = default; + TimeDelayNetBlock(const std::string& name, int32_t in_channels, int32_t out_channels, int32_t kernel_size, int32_t dilation) + : nn::Module(name), 
kernel_size_(kernel_size), dilation_(dilation) { + conv_ = reg("conv", in_channels, out_channels, kernel_size_, 1, 0, dilation_, 1, true); + relu_ = reg("relu"); + } + + std::vector forward(const std::vector& inputs, const std::vector&) override { + auto x = inputs[0]; + const int32_t pad_total = dilation_ * (kernel_size_ - 1); + const int32_t pad_left = pad_total / 2; + const int32_t pad_right = pad_total - pad_left; + if (pad_total > 0) { x = pad1dReflect(x, pad_left, pad_right); } + x = conv_(x); + x = relu_(x); + return {x}; + } + + private: + nn::Conv1D conv_; + nn::ReLU relu_; + int32_t kernel_size_ = 1; + int32_t dilation_ = 1; +}; + +class Res2NetBlock final : public nn::Module { + public: + Res2NetBlock() = default; + Res2NetBlock(const std::string& name, int32_t in_channels, int32_t out_channels, int32_t scale, int32_t kernel_size, int32_t dilation) + : nn::Module(name), scale_(scale) { + const int32_t in_channel = in_channels / scale; + const int32_t hidden_channel = out_channels / scale; + blocks_ = reg>("blocks", scale_ - 1, in_channel, hidden_channel, kernel_size, dilation); + } + + std::vector forward(const std::vector& inputs, const std::vector&) override { + auto x = inputs[0]; + const int32_t channels = x.shape()[1]; + const int32_t split = channels / scale_; + + std::vector outputs; + outputs.reserve(scale_); + Tensor output_part = Tensor::nil(); + + for (int32_t i = 0; i < scale_; ++i) { + auto hidden_part = x[{kAll, {i * split, (i + 1) * split}, kAll}]; + if (i == 0) { + output_part = hidden_part; + } else if (i == 1) { + output_part = blocks_.list()[i - 1](hidden_part)[0]; + } else { + output_part = blocks_.list()[i - 1](hidden_part + output_part)[0]; + } + outputs.push_back(output_part); + } + + auto out = nn::functional::concat(outputs, 1); + return {out}; + } + + private: + int32_t scale_ = 1; + nn::ModuleList blocks_; +}; + +class SqueezeExcitationBlock final : public nn::Module { + public: + SqueezeExcitationBlock() = default; + 
SqueezeExcitationBlock(const std::string& name, int32_t in_channels, int32_t se_channels, int32_t out_channels) + : nn::Module(name) { + conv1_ = reg("conv1", in_channels, se_channels, 1, 1, 0, 1, 1, true); + conv2_ = reg("conv2", se_channels, out_channels, 1, 1, 0, 1, 1, true); + relu_ = reg("relu"); + sigmoid_ = reg("sigmoid"); + } + + std::vector forward(const std::vector& inputs, const std::vector&) override { + auto hidden_states = inputs[0]; + auto hidden_mean = nn::functional::mean(hidden_states, 2, true); + hidden_mean = relu_(conv1_(hidden_mean)); + hidden_mean = sigmoid_(conv2_(hidden_mean)); + hidden_states = hidden_states * hidden_mean; + return {hidden_states}; + } + + private: + nn::Conv1D conv1_; + nn::Conv1D conv2_; + nn::ReLU relu_; + nn::Sigmoid sigmoid_; +}; + +class AttentiveStatisticsPooling final : public nn::Module { + public: + AttentiveStatisticsPooling() = default; + AttentiveStatisticsPooling(const std::string& name, int32_t channels, int32_t attention_channels) + : nn::Module(name), channels_(channels) { + tdnn_ = reg("tdnn", channels * 3, attention_channels, 1, 1); + tanh_ = reg("tanh"); + conv_ = reg("conv", attention_channels, channels, 1, 1, 0, 1, 1, true); + } + + std::vector forward(const std::vector& inputs, const std::vector&) override { + auto hidden_states = inputs[0]; + MLLM_RT_ASSERT_EQ(hidden_states.dtype(), kFloat32); + MLLM_RT_ASSERT_EQ(hidden_states.device(), kCPU); + + const int32_t batch = hidden_states.shape()[0]; + const int32_t channels = hidden_states.shape()[1]; + const int32_t seq_len = hidden_states.shape()[2]; + + auto mean = Tensor::empty({batch, channels}, kFloat32, kCPU).alloc(); + auto std = Tensor::empty({batch, channels}, kFloat32, kCPU).alloc(); + + const auto* x_ptr = hidden_states.ptr(); + auto* mean_ptr = mean.ptr(); + auto* std_ptr = std.ptr(); + + const int32_t stride_c = seq_len; + const int32_t stride_b = channels * seq_len; + + for (int32_t b = 0; b < batch; ++b) { + for (int32_t c = 0; c < 
channels; ++c) { + const int32_t base = b * stride_b + c * stride_c; + float sum = 0.0f; + for (int32_t t = 0; t < seq_len; ++t) { sum += x_ptr[base + t]; } + float m = sum / static_cast(seq_len); + mean_ptr[b * channels + c] = m; + + float var = 0.0f; + for (int32_t t = 0; t < seq_len; ++t) { + float diff = x_ptr[base + t] - m; + var += diff * diff; + } + var /= static_cast(seq_len); + std_ptr[b * channels + c] = std::sqrt(std::max(var, 1e-12f)); + } + } + + auto mean_rep = mean.view({batch, channels, 1}).repeat(seq_len, 2); + auto std_rep = std.view({batch, channels, 1}).repeat(seq_len, 2); + + auto attention = nn::functional::concat({hidden_states, mean_rep, std_rep}, 1); + attention = tdnn_(attention)[0]; + attention = tanh_(attention); + attention = conv_(attention); + attention = nn::functional::softmax(attention, 2); + + auto out_mean = Tensor::empty({batch, channels}, kFloat32, kCPU).alloc(); + auto out_std = Tensor::empty({batch, channels}, kFloat32, kCPU).alloc(); + auto* out_mean_ptr = out_mean.ptr(); + auto* out_std_ptr = out_std.ptr(); + const auto* attn_ptr = attention.ptr(); + + for (int32_t b = 0; b < batch; ++b) { + for (int32_t c = 0; c < channels; ++c) { + const int32_t base = b * stride_b + c * stride_c; + float m = 0.0f; + for (int32_t t = 0; t < seq_len; ++t) { m += attn_ptr[base + t] * x_ptr[base + t]; } + out_mean_ptr[b * channels + c] = m; + + float var = 0.0f; + for (int32_t t = 0; t < seq_len; ++t) { + float diff = x_ptr[base + t] - m; + var += attn_ptr[base + t] * diff * diff; + } + out_std_ptr[b * channels + c] = std::sqrt(std::max(var, 1e-12f)); + } + } + + auto pooled = nn::functional::concat({out_mean, out_std}, 1).view({batch, channels * 2, 1}); + return {pooled}; + } + + private: + int32_t channels_ = 0; + TimeDelayNetBlock tdnn_; + nn::Tanh tanh_; + nn::Conv1D conv_; +}; + +class SqueezeExcitationRes2NetBlock final : public nn::Module { + public: + SqueezeExcitationRes2NetBlock() = default; + SqueezeExcitationRes2NetBlock(const 
std::string& name, int32_t in_channels, int32_t out_channels, int32_t res2net_scale, + int32_t se_channels, int32_t kernel_size, int32_t dilation) + : nn::Module(name), out_channels_(out_channels) { + tdnn1_ = reg("tdnn1", in_channels, out_channels, 1, 1); + res2net_block_ = reg("res2net_block", out_channels, out_channels, res2net_scale, kernel_size, dilation); + tdnn2_ = reg("tdnn2", out_channels, out_channels, 1, 1); + se_block_ = reg("se_block", out_channels, se_channels, out_channels); + } + + std::vector forward(const std::vector& inputs, const std::vector&) override { + auto hidden_state = inputs[0]; + auto residual = hidden_state; + + hidden_state = tdnn1_(hidden_state)[0]; + hidden_state = res2net_block_(hidden_state)[0]; + hidden_state = tdnn2_(hidden_state)[0]; + hidden_state = se_block_(hidden_state)[0]; + hidden_state = hidden_state + residual; + return {hidden_state}; + } + + private: + int32_t out_channels_ = 0; + TimeDelayNetBlock tdnn1_; + Res2NetBlock res2net_block_; + TimeDelayNetBlock tdnn2_; + SqueezeExcitationBlock se_block_; +}; + +class ECAPA_TimeDelayNet final : public nn::Module { + public: + ECAPA_TimeDelayNet() = default; + explicit ECAPA_TimeDelayNet(const std::string& name, const Qwen2_5OmniDiTConfig& cfg) : nn::Module(name) { + if (cfg.enc_channels.size() != cfg.enc_kernel_sizes.size() || cfg.enc_channels.size() != cfg.enc_dilations.size()) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "enc_channels, enc_kernel_sizes and enc_dilations should have same length"); + } + + if (cfg.enc_channels.empty()) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "enc_channels should not be empty"); + } + + const int32_t num_blocks = static_cast(cfg.enc_channels.size()); + tdnn0_ = reg("blocks.0", cfg.mel_dim, cfg.enc_channels[0], cfg.enc_kernel_sizes[0], cfg.enc_dilations[0]); + + for (int32_t i = 1; i < num_blocks - 1; ++i) { + se_blocks_.emplace_back(reg( + "blocks." 
+ std::to_string(i), + cfg.enc_channels[i - 1], + cfg.enc_channels[i], + cfg.enc_res2net_scale, + cfg.enc_se_channels, + cfg.enc_kernel_sizes[i], + cfg.enc_dilations[i])); + } + + mfa_ = reg("mfa", cfg.enc_channels.back(), cfg.enc_channels.back(), cfg.enc_kernel_sizes.back(), + cfg.enc_dilations.back()); + asp_ = reg("asp", cfg.enc_channels.back(), cfg.enc_attention_channels); + fc_ = reg("fc", cfg.enc_channels.back() * 2, cfg.enc_dim, 1, 1, 0, 1, 1, true); + } + + std::vector forward(const std::vector& inputs, const std::vector&) override { + auto hidden_states = inputs[0]; + MLLM_RT_ASSERT_EQ(hidden_states.dtype(), kFloat32); + MLLM_RT_ASSERT_EQ(hidden_states.device(), kCPU); + + hidden_states = hidden_states.transpose(1, 2); + + std::vector hidden_states_list; + hidden_states = tdnn0_(hidden_states)[0]; + hidden_states_list.push_back(hidden_states); + + for (auto& block : se_blocks_) { + hidden_states = block(hidden_states)[0]; + hidden_states_list.push_back(hidden_states); + } + + if (hidden_states_list.size() <= 1) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "ECAPA_TimeDelayNet expects at least 2 blocks."); + } + + std::vector mfa_inputs; + for (size_t i = 1; i < hidden_states_list.size(); ++i) { mfa_inputs.push_back(hidden_states_list[i]); } + hidden_states = nn::functional::concat(mfa_inputs, 1); + hidden_states = mfa_(hidden_states)[0]; + hidden_states = asp_(hidden_states)[0]; + hidden_states = fc_(hidden_states); + hidden_states = hidden_states.squeeze(-1); + + return {hidden_states}; + } + + private: + TimeDelayNetBlock tdnn0_; + std::vector se_blocks_; + TimeDelayNetBlock mfa_; + AttentiveStatisticsPooling asp_; + nn::Conv1D fc_; +}; + +class DiTInputEmbedding final : public nn::Module { + public: + DiTInputEmbedding() = default; + explicit DiTInputEmbedding(const std::string& name, const Qwen2_5OmniDiTConfig& cfg) : nn::Module(name) { + const int32_t in_dim = cfg.mel_dim + cfg.enc_dim + cfg.enc_emb_dim + cfg.emb_dim; + proj_ = reg("proj", in_dim, 
cfg.hidden_size, true); + spk_encoder_ = reg("spk_encoder", cfg); + } + + Tensor forward(const Tensor& hidden_states, const Tensor& speaker_embedding, const Tensor& condition_vector, const Tensor& code_embed, + bool drop_audio_cond, const Tensor& code_embed_uncond, bool apply_cfg) { + auto x = hidden_states; + auto spk = speaker_embedding; + auto cond = condition_vector; + auto code = code_embed; + + if (apply_cfg) { + x = nn::functional::concat({x, x}, 0); + spk = nn::functional::concat({spk, Tensor::zeros(spk.shape(), spk.dtype(), spk.device())}, 0); + cond = nn::functional::concat({cond, Tensor::zeros(cond.shape(), cond.dtype(), cond.device())}, 0); + code = nn::functional::concat({code, code_embed_uncond}, 0); + } else if (drop_audio_cond) { + cond = Tensor::zeros(cond.shape(), cond.dtype(), cond.device()); + spk = Tensor::zeros(spk.shape(), spk.dtype(), spk.device()); + } + + auto cond_embed = spk_encoder_(cond)[0]; + const int32_t seq_len = x.shape()[1]; + cond_embed = cond_embed.view({cond_embed.shape()[0], 1, cond_embed.shape()[1]}).repeat(seq_len, 1); + + auto merged = nn::functional::concat({x, cond_embed, code, spk}, -1); + auto out = proj_(merged); + return out; + } + + private: + nn::Linear proj_; + ECAPA_TimeDelayNet spk_encoder_; +}; + +class DiTCodecEmbedding final : public nn::Module { + public: + DiTCodecEmbedding() = default; + DiTCodecEmbedding(const std::string& name, int32_t codec_num_embeds, int32_t codec_dim, int32_t repeats) + : nn::Module(name), repeats_(repeats) { + codec_embed_ = reg("codec_embed", codec_num_embeds + 1, codec_dim); + } + + Tensor forward(const Tensor& code, bool drop_code) { + Tensor code_ids = code; + if (drop_code) { code_ids = Tensor::zeros(code.shape(), code.dtype(), code.device()); } + auto code_embed = codec_embed_(code_ids); + return repeatInterleave(code_embed, repeats_, 1); + } + + private: + int32_t repeats_ = 1; + nn::Embedding codec_embed_; +}; + +class Qwen2_5_OmniAdaLayerNormZero final : public nn::Module { 
+ public: + Qwen2_5_OmniAdaLayerNormZero() = default; + Qwen2_5_OmniAdaLayerNormZero(const std::string& name, int32_t dim) : nn::Module(name) { + silu_ = reg("silu"); + linear_ = reg("linear", dim, dim * 6, true); + norm_ = reg("norm", std::vector{dim}, false, false, 1e-6f); + } + + std::vector forward(const std::vector& inputs, const std::vector&) override { + auto hidden_states = inputs[0]; + auto emb = inputs[1]; + emb = linear_(silu_(emb)); + + auto chunks = nn::functional::chunk<6>(emb, 1); + auto shift_msa = chunks[0]; + auto scale_msa = chunks[1]; + auto gate_msa = chunks[2]; + auto shift_mlp = chunks[3]; + auto scale_mlp = chunks[4]; + auto gate_mlp = chunks[5]; + + auto normed = norm_(hidden_states); + const int32_t seq_len = hidden_states.shape()[1]; + auto scale = scale_msa.view({scale_msa.shape()[0], 1, scale_msa.shape()[1]}).repeat(seq_len, 1); + auto shift = shift_msa.view({shift_msa.shape()[0], 1, shift_msa.shape()[1]}).repeat(seq_len, 1); + normed = normed * (scale + 1.0f) + shift; + + return {normed, gate_msa, shift_mlp, scale_mlp, gate_mlp}; + } + + private: + nn::SiLU silu_; + nn::Linear linear_; + nn::LayerNorm norm_; +}; + +class Qwen2_5_OmniAdaLayerNormZero_Final final : public nn::Module { + public: + Qwen2_5_OmniAdaLayerNormZero_Final() = default; + Qwen2_5_OmniAdaLayerNormZero_Final(const std::string& name, int32_t dim) : nn::Module(name) { + silu_ = reg("silu"); + linear_ = reg("linear", dim, dim * 2, true); + norm_ = reg("norm", std::vector{dim}, false, false, 1e-6f); + } + + Tensor forward(const Tensor& hidden_states, const Tensor& emb) { + auto emb_out = linear_(silu_(emb)); + auto chunks = nn::functional::chunk<2>(emb_out, 1); + auto scale = chunks[0]; + auto shift = chunks[1]; + + auto normed = norm_(hidden_states); + const int32_t seq_len = hidden_states.shape()[1]; + scale = scale.view({scale.shape()[0], 1, scale.shape()[1]}).repeat(seq_len, 1); + shift = shift.view({shift.shape()[0], 1, shift.shape()[1]}).repeat(seq_len, 1); + 
normed = normed * (scale + 1.0f) + shift; + return normed; + } + + private: + nn::SiLU silu_; + nn::Linear linear_; + nn::LayerNorm norm_; +}; + +class DiTMLP final : public nn::Module { + public: + DiTMLP() = default; + DiTMLP(const std::string& name, int32_t dim, int32_t mult) : nn::Module(name) { + const int32_t inner_dim = dim * mult; + fc1_ = reg("ff.0", dim, inner_dim, true); + act_ = reg("ff.1"); + fc2_ = reg("ff.3", inner_dim, dim, true); + } + + Tensor forward(const Tensor& hidden_states) { + auto x = fc1_(hidden_states); + x = act_(x); + x = fc2_(x); + return x; + } + + private: + nn::Linear fc1_; + nn::GELU act_; + nn::Linear fc2_; +}; + +inline void applyRotaryPosEmbFirstHead(Tensor& q, Tensor& k, const Tensor& cos, const Tensor& sin) { + MLLM_RT_ASSERT_EQ(q.device(), kCPU); + MLLM_RT_ASSERT_EQ(k.device(), kCPU); + MLLM_RT_ASSERT_EQ(cos.device(), kCPU); + MLLM_RT_ASSERT_EQ(sin.device(), kCPU); + MLLM_RT_ASSERT_EQ(q.dtype(), kFloat32); + MLLM_RT_ASSERT_EQ(k.dtype(), kFloat32); + MLLM_RT_ASSERT_EQ(cos.dtype(), kFloat32); + MLLM_RT_ASSERT_EQ(sin.dtype(), kFloat32); + + const int32_t batch = q.shape()[0]; + const int32_t heads = q.shape()[1]; + const int32_t seq_len = q.shape()[2]; + const int32_t head_dim = q.shape()[3]; + MLLM_RT_ASSERT_EQ(head_dim % 2, 0); + MLLM_RT_ASSERT_EQ(cos.shape()[0], batch); + MLLM_RT_ASSERT_EQ(cos.shape()[1], seq_len); + MLLM_RT_ASSERT_EQ(cos.shape()[2], head_dim); + + const auto* cos_ptr = cos.ptr(); + const auto* sin_ptr = sin.ptr(); + auto* q_ptr = q.ptr(); + auto* k_ptr = k.ptr(); + + const int64_t stride_q_b = static_cast(heads) * seq_len * head_dim; + const int64_t stride_q_h = static_cast(seq_len) * head_dim; + const int64_t stride_q_s = head_dim; + + const int64_t stride_cos_b = static_cast(seq_len) * head_dim; + const int64_t stride_cos_s = head_dim; + + for (int32_t b = 0; b < batch; ++b) { + const int64_t q_base_b = static_cast(b) * stride_q_b; + const int64_t cos_base_b = static_cast(b) * stride_cos_b; + for (int32_t 
s = 0; s < seq_len; ++s) { + float* q_row = q_ptr + q_base_b + 0 * stride_q_h + static_cast(s) * stride_q_s; + float* k_row = k_ptr + q_base_b + 0 * stride_q_h + static_cast(s) * stride_q_s; + const float* cos_row = cos_ptr + cos_base_b + static_cast(s) * stride_cos_s; + const float* sin_row = sin_ptr + cos_base_b + static_cast(s) * stride_cos_s; + for (int32_t d = 0; d < head_dim; d += 2) { + const float c = cos_row[d]; + const float ss = sin_row[d]; + const float q1 = q_row[d]; + const float q2 = q_row[d + 1]; + const float k1 = k_row[d]; + const float k2 = k_row[d + 1]; + q_row[d] = q1 * c - q2 * ss; + q_row[d + 1] = q1 * ss + q2 * c; + k_row[d] = k1 * c - k2 * ss; + k_row[d + 1] = k1 * ss + k2 * c; + } + } + } +} + +inline Tensor makeBlockDiff(int32_t batch, int32_t heads, int32_t seq_len, int32_t block_size) { + (void)heads; + MLLM_RT_ASSERT(block_size > 0); + std::vector block_indices(seq_len, 0); + for (int32_t i = 0; i < seq_len; ++i) { block_indices[i] = i / block_size; } + + std::vector base(static_cast(seq_len) * seq_len, 0.0f); + for (int32_t i = 0; i < seq_len; ++i) { + for (int32_t j = 0; j < seq_len; ++j) { + base[static_cast(i) * seq_len + j] = static_cast(block_indices[j] - block_indices[i]); + } + } + + // Use a broadcast-friendly shape to avoid materializing head copies while keeping naive broadcast support. 
+ auto out = Tensor::empty({batch, 1, seq_len, seq_len}, kFloat32, kCPU).alloc(); + const int64_t block_stride = static_cast(seq_len) * seq_len; + auto* out_ptr = out.ptr(); + for (int32_t b = 0; b < batch; ++b) { + float* dst = out_ptr + static_cast(b) * block_stride; + std::memcpy(dst, base.data(), sizeof(float) * base.size()); + } + return out; +} + +inline Tensor makeBlockMask(const Tensor& block_diff, int32_t look_backward_block, int32_t look_ahead_block) { + MLLM_RT_ASSERT_EQ(block_diff.device(), kCPU); + MLLM_RT_ASSERT_EQ(block_diff.dtype(), kFloat32); + + auto mask = Tensor::empty(block_diff.shape(), kFloat32, kCPU).alloc(); + const auto* src = block_diff.ptr(); + auto* dst = mask.ptr(); + const int64_t numel = block_diff.numel(); + const float lower = -static_cast(look_backward_block); + const float upper = static_cast(look_ahead_block); + + MLLM_CONDITIONAL_PARALLEL_FOR(numel > 1024, 4, idx, 0, numel, 1, { + const float v = src[idx]; + dst[idx] = (v >= lower && v <= upper) ? 0.0f : -1e4f; + }); + return mask; +} + +class DiTAttention final : public nn::Module { + public: + DiTAttention() = default; + explicit DiTAttention(const std::string& name, const Qwen2_5OmniDiTConfig& cfg) : nn::Module(name), cfg_(cfg) { + dim_ = cfg.hidden_size; + heads_ = cfg.num_attention_heads; + head_dim_ = cfg.head_dim; + inner_dim_ = head_dim_ * heads_; + + to_q_ = reg("to_q", dim_, inner_dim_, true); + to_k_ = reg("to_k", dim_, inner_dim_, true); + to_v_ = reg("to_v", dim_, inner_dim_, true); + to_out_ = reg("to_out.0", inner_dim_, dim_, true); + } + + Tensor forward(const Tensor& hidden_states, const std::pair& position_embeddings, const Tensor& attention_mask) { + auto query = to_q_(hidden_states); + auto key = to_k_(hidden_states); + auto value = to_v_(hidden_states); + + const int32_t batch = hidden_states.shape()[0]; + const int32_t seq_len = hidden_states.shape()[1]; + + query = query.view({batch, seq_len, heads_, head_dim_}).transpose(1, 2); + key = key.view({batch, 
seq_len, heads_, head_dim_}).transpose(1, 2); + value = value.view({batch, seq_len, heads_, head_dim_}).transpose(1, 2); + + if (!position_embeddings.first.isNil()) { + applyRotaryPosEmbFirstHead(query, key, position_embeddings.first, position_embeddings.second); + } + + auto attn_output = nn::functional::scaledDotProductAttention(query, key, value, attention_mask); + attn_output = attn_output.transpose(1, 2).view({batch, seq_len, inner_dim_}); + attn_output = to_out_(attn_output); + return attn_output; + } + + private: + Qwen2_5OmniDiTConfig cfg_; + int32_t dim_ = 0; + int32_t heads_ = 0; + int32_t head_dim_ = 0; + int32_t inner_dim_ = 0; + nn::Linear to_q_; + nn::Linear to_k_; + nn::Linear to_v_; + nn::Linear to_out_; +}; + +class SinusPositionEmbedding final : public nn::Module { + public: + SinusPositionEmbedding() = default; + explicit SinusPositionEmbedding(const std::string& name, int32_t dim) : nn::Module(name), dim_(dim) {} + + Tensor forward(const Tensor& hidden_states, float scale = 1000.0f) { + MLLM_RT_ASSERT_EQ(hidden_states.device(), kCPU); + MLLM_RT_ASSERT_EQ(hidden_states.dtype(), kFloat32); + + const int32_t batch = hidden_states.shape()[0]; + const int32_t half_dim = dim_ / 2; + auto out = Tensor::empty({batch, dim_}, kFloat32, kCPU).alloc(); + auto* out_ptr = out.ptr(); + const auto* hs_ptr = hidden_states.ptr(); + + const float emb = std::log(10000.0f) / static_cast(half_dim - 1); + std::vector freqs(half_dim); + for (int32_t i = 0; i < half_dim; ++i) { freqs[i] = std::exp(-emb * static_cast(i)); } + + for (int32_t b = 0; b < batch; ++b) { + const float t = hs_ptr[b] * scale; + float* row = out_ptr + static_cast(b) * dim_; + for (int32_t i = 0; i < half_dim; ++i) { + const float val = t * freqs[i]; + row[i] = std::sin(val); + row[i + half_dim] = std::cos(val); + } + } + + return out; + } + + private: + int32_t dim_ = 0; +}; + +class DiTTimestepEmbedding final : public nn::Module { + public: + DiTTimestepEmbedding() = default; + explicit 
DiTTimestepEmbedding(const std::string& name, int32_t dim, int32_t freq_embed_dim = 256) + : nn::Module(name), freq_embed_dim_(freq_embed_dim) { + time_embed_ = reg("time_embed", freq_embed_dim_); + fc1_ = reg("time_mlp.0", freq_embed_dim_, dim, true); + act_ = reg("time_mlp.1"); + fc2_ = reg("time_mlp.2", dim, dim, true); + } + + Tensor forward(const Tensor& timestep) { + auto time_hidden = time_embed_.forward(timestep); + time_hidden = fc1_(time_hidden); + time_hidden = act_(time_hidden); + time_hidden = fc2_(time_hidden); + return time_hidden; + } + + private: + int32_t freq_embed_dim_ = 256; + SinusPositionEmbedding time_embed_; + nn::Linear fc1_; + nn::SiLU act_; + nn::Linear fc2_; +}; + +class DiTDecoderLayer final : public nn::Module { + public: + DiTDecoderLayer() = default; + DiTDecoderLayer(const std::string& name, const Qwen2_5OmniDiTConfig& cfg, int32_t look_ahead_block, int32_t look_backward_block) + : nn::Module(name), look_ahead_block_(look_ahead_block), look_backward_block_(look_backward_block) { + attn_norm_ = reg("attn_norm", cfg.hidden_size); + attn_ = reg("attn", cfg); + ff_norm_ = reg("ff_norm", std::vector{cfg.hidden_size}, false, false, 1e-6f); + ff_ = reg("ff", cfg.hidden_size, cfg.ff_mult); + } + + Tensor forward(const Tensor& hidden_states, const Tensor& timestep, const std::pair& position_embeddings, + const Tensor& block_diff) { + auto attn_norm_out = attn_norm_(hidden_states, timestep); + auto norm = attn_norm_out[0]; + auto gate_msa = attn_norm_out[1]; + auto shift_mlp = attn_norm_out[2]; + auto scale_mlp = attn_norm_out[3]; + auto gate_mlp = attn_norm_out[4]; + + Tensor attn_mask = Tensor::nil(); + if (!block_diff.isNil()) { attn_mask = makeBlockMask(block_diff, look_backward_block_, look_ahead_block_); } + auto attn_output = attn_.forward(norm, position_embeddings, attn_mask); + + auto gate_msa_rep = gate_msa.view({gate_msa.shape()[0], 1, gate_msa.shape()[1]}).repeat(hidden_states.shape()[1], 1); + auto x = Tensor(hidden_states); + x 
= x + gate_msa_rep * attn_output; + + auto norm_ff = ff_norm_(x); + auto scale_rep = scale_mlp.view({scale_mlp.shape()[0], 1, scale_mlp.shape()[1]}).repeat(x.shape()[1], 1); + auto shift_rep = shift_mlp.view({shift_mlp.shape()[0], 1, shift_mlp.shape()[1]}).repeat(x.shape()[1], 1); + norm_ff = norm_ff * (scale_rep + 1.0f) + shift_rep; + auto ff_output = ff_.forward(norm_ff); + auto gate_mlp_rep = gate_mlp.view({gate_mlp.shape()[0], 1, gate_mlp.shape()[1]}).repeat(x.shape()[1], 1); + x = x + gate_mlp_rep * ff_output; + return x; + } + + private: + Qwen2_5_OmniAdaLayerNormZero attn_norm_; + DiTAttention attn_; + nn::LayerNorm ff_norm_; + DiTMLP ff_; + int32_t look_ahead_block_ = 0; + int32_t look_backward_block_ = 0; +}; + +class Qwen2_5OmniDiTRotaryEmbedding final : public nn::Module { + public: + Qwen2_5OmniDiTRotaryEmbedding() = default; + explicit Qwen2_5OmniDiTRotaryEmbedding(const std::string& name, const Qwen2_5OmniDiTConfig& cfg) : nn::Module(name), cfg_(cfg) { + const int32_t dim = cfg.head_dim; + inv_freq_ = reg("inv_freq", getModuleName() + ".inv_freq", std::vector{dim / 2}); + attention_scaling_ = 1.0f; + + auto inv = inv_freq_.weight(); + if (!inv.isNil() && inv.numel() == 0) { + inv = Tensor::empty({dim / 2}, kFloat32, kCPU).alloc(); + inv_freq_.weight().copy2(inv); + } + } + + std::pair forward(const Tensor& x, const Tensor& position_ids) { + MLLM_RT_ASSERT_EQ(x.device(), kCPU); + MLLM_RT_ASSERT_EQ(position_ids.device(), kCPU); + MLLM_RT_ASSERT_EQ(position_ids.dtype(), kInt64); + + const int32_t batch = position_ids.shape()[0]; + const int32_t seq_len = position_ids.shape()[1]; + auto inv_freq = inv_freq_.weight(); + if (inv_freq.isNil() || inv_freq.numel() == 0) { + const int32_t dim = cfg_.head_dim; + inv_freq = Tensor::empty({dim / 2}, kFloat32, kCPU).alloc(); + auto* ptr = inv_freq.ptr(); + for (int32_t i = 0; i < dim / 2; ++i) { + ptr[i] = 1.0f / std::pow(cfg_.rope_theta, 2.0f * i / static_cast(dim)); + } + } + + const int32_t half_dim = 
inv_freq.shape()[0]; + auto cos = Tensor::empty({batch, seq_len, half_dim * 2}, kFloat32, kCPU).alloc(); + auto sin = Tensor::empty({batch, seq_len, half_dim * 2}, kFloat32, kCPU).alloc(); + + const auto* inv_ptr = inv_freq.ptr(); + const auto* pos_ptr = position_ids.ptr(); + auto* cos_ptr = cos.ptr(); + auto* sin_ptr = sin.ptr(); + + const int64_t stride_pos_b = seq_len; + const int64_t stride_cos_b = static_cast(seq_len) * half_dim * 2; + const int64_t stride_cos_s = half_dim * 2; + + for (int32_t b = 0; b < batch; ++b) { + const int64_t pos_base = static_cast(b) * stride_pos_b; + const int64_t out_base = static_cast(b) * stride_cos_b; + for (int32_t s = 0; s < seq_len; ++s) { + const float position = static_cast(pos_ptr[pos_base + s]); + float* cos_row = cos_ptr + out_base + static_cast(s) * stride_cos_s; + float* sin_row = sin_ptr + out_base + static_cast(s) * stride_cos_s; + for (int32_t d = 0; d < half_dim; ++d) { + const float freq = inv_ptr[d] * position; + const float c = std::cos(freq) * attention_scaling_; + const float ss = std::sin(freq) * attention_scaling_; + cos_row[d] = c; + cos_row[d + half_dim] = c; + sin_row[d] = ss; + sin_row[d + half_dim] = ss; + } + } + } + + return {cos, sin}; + } + + private: + Qwen2_5OmniDiTConfig cfg_; + nn::Param inv_freq_; + float attention_scaling_ = 1.0f; +}; + +class RungeKutta4ODESolver { + public: + using Function = std::function; + + RungeKutta4ODESolver(Function function, Tensor initial_value) + : function_(std::move(function)), initial_value_(std::move(initial_value)) {} + + Tensor integrate(const std::vector& time_points) { + auto current_value = initial_value_; + if (time_points.size() < 2) { return current_value; } + + for (size_t i = 0; i + 1 < time_points.size(); ++i) { + const float time_start = time_points[i]; + const float time_end = time_points[i + 1]; + const float time_step = time_end - time_start; + + auto k1 = function_(time_start, current_value); + auto k2 = function_(time_start + time_step * 
one_third_, current_value + k1 * (time_step * one_third_)); + auto k3 = function_(time_start + time_step * two_thirds_, + current_value + (k2 - k1 * one_third_) * time_step); + auto k4 = function_(time_end, current_value + (k1 - k2 + k3) * time_step); + + auto delta = (k1 + (k2 + k3) * 3.0f + k4) * (time_step / 8.0f); + current_value = current_value + delta; + } + + return current_value; + } + + private: + Function function_; + Tensor initial_value_; + float one_third_ = 1.0f / 3.0f; + float two_thirds_ = 2.0f / 3.0f; +}; + +class Qwen2_5OmniToken2WavDiTModel final : public nn::Module { + public: + Qwen2_5OmniToken2WavDiTModel() = default; + explicit Qwen2_5OmniToken2WavDiTModel(const std::string& name, const Qwen2_5OmniDiTConfig& cfg) : nn::Module(name), cfg_(cfg) { + mel_dim_ = cfg.mel_dim; + repeats_ = cfg.repeats; + block_size_ = cfg.block_size; + num_attention_heads_ = cfg.num_attention_heads; + + time_embed_ = reg("time_embed", cfg.hidden_size); + text_embed_ = reg("text_embed", cfg.num_embeds, cfg.emb_dim, cfg.repeats); + input_embed_ = reg("input_embed", cfg); + rotary_embed_ = reg("rotary_embed", cfg); + + for (int32_t i = 0; i < cfg.num_hidden_layers; ++i) { + const bool look_ahead = std::find(cfg.look_ahead_layers.begin(), cfg.look_ahead_layers.end(), i) != cfg.look_ahead_layers.end(); + const bool look_backward = + std::find(cfg.look_backward_layers.begin(), cfg.look_backward_layers.end(), i) != cfg.look_backward_layers.end(); + transformer_blocks_.emplace_back(reg("transformer_blocks." + std::to_string(i), cfg, look_ahead ? 1 : 0, + look_backward ? 
1 : 0)); + } + + norm_out_ = reg("norm_out", cfg.hidden_size); + proj_out_ = reg("proj_out", cfg.hidden_size, cfg.mel_dim, true); + } + + Tensor forward(const Tensor& hidden_states, const Tensor& condition_vector, const Tensor& speaker_embedding, const Tensor& quantized_code, + const Tensor& time_step, bool drop_audio_conditioning, bool drop_code, bool apply_cfg) { + Tensor timestep = time_step; + if (timestep.shape().empty()) { timestep = timestep.view({1}); } + if (timestep.shape().size() == 1 && timestep.shape()[0] == 1 && hidden_states.shape()[0] > 1) { + timestep = timestep.repeat(hidden_states.shape()[0], 0); + } + + auto time_embedding = time_embed_.forward(timestep); + auto text_embedding = text_embed_.forward(quantized_code, apply_cfg ? false : drop_code); + Tensor text_embedding_uncond = Tensor::nil(); + if (apply_cfg) { text_embedding_uncond = text_embed_.forward(quantized_code, true); } + + auto x = input_embed_.forward(hidden_states, speaker_embedding, condition_vector, text_embedding, drop_audio_conditioning, + text_embedding_uncond, apply_cfg); + + const int32_t seq_len = x.shape()[1]; + auto position_ids = Tensor::empty({x.shape()[0], seq_len}, kInt64, kCPU).alloc(); + auto* pos_ptr = position_ids.ptr(); + for (int32_t b = 0; b < position_ids.shape()[0]; ++b) { + for (int32_t s = 0; s < seq_len; ++s) { pos_ptr[b * seq_len + s] = s; } + } + + auto position_embeddings = rotary_embed_.forward(x, position_ids); + auto block_diff = makeBlockDiff(x.shape()[0], num_attention_heads_, seq_len, block_size_); + + for (auto& block : transformer_blocks_) { x = block.forward(x, time_embedding, position_embeddings, block_diff); } + + x = norm_out_.forward(x, time_embedding); + x = proj_out_(x); + return x; + } + + Tensor sample(const Tensor& conditioning_vector, const Tensor& reference_mel, const Tensor& quantized_code, int32_t num_steps, + float guidance_scale, float sway_coefficient) { + const int32_t max_duration = quantized_code.shape()[1] * repeats_; + auto 
initial_state = randomNormal({1, max_duration, mel_dim_}); + + const int32_t batch = reference_mel.shape()[0]; + if (batch != 1) { MLLM_ERROR_EXIT(ExitCode::kCoreError, "Only batch size = 1 is supported for Qwen2.5-Omni token2wav."); } + + auto cond = Tensor(conditioning_vector); + cond = cond.view({batch, 1, conditioning_vector.shape()[1]}).repeat(max_duration, 1); + + auto ode_function = [&](float time_step, const Tensor& hidden) -> Tensor { + auto t = Tensor::empty({1}, kFloat32, kCPU).alloc(); + t.ptr()[0] = time_step; + + if (guidance_scale < 1e-5f) { + return forward(hidden, reference_mel, cond, quantized_code, t, false, false, false); + } + + auto model_output = forward(hidden, reference_mel, cond, quantized_code, t, false, false, true); + auto outputs = nn::functional::chunk<2>(model_output, 0); + return outputs[0] + (outputs[0] - outputs[1]) * guidance_scale; + }; + + auto time_points_tensor = linspace(0.0f, 1.0f, num_steps); + std::vector time_points(static_cast(num_steps)); + const auto* tp_ptr = time_points_tensor.ptr(); + for (int32_t i = 0; i < num_steps; ++i) { time_points[i] = tp_ptr[i]; } + + if (sway_coefficient != 0.0f) { + for (auto& t : time_points) { + t = t + sway_coefficient * (std::cos(kPi / 2.0f * t) - 1.0f + t); + } + } + + RungeKutta4ODESolver solver(ode_function, initial_state); + auto generated = solver.integrate(time_points); + auto mel = generated.permute({0, 2, 1}); + if (!mel.isContiguous()) { mel = mel.contiguous(); } + return mel; + } + + private: + Qwen2_5OmniDiTConfig cfg_; + int32_t mel_dim_ = 0; + int32_t repeats_ = 1; + int32_t block_size_ = 1; + int32_t num_attention_heads_ = 1; + + DiTTimestepEmbedding time_embed_; + DiTCodecEmbedding text_embed_; + DiTInputEmbedding input_embed_; + Qwen2_5OmniDiTRotaryEmbedding rotary_embed_; + std::vector transformer_blocks_; + Qwen2_5_OmniAdaLayerNormZero_Final norm_out_; + nn::Linear proj_out_; +}; + +class AMPBlock final : public nn::Module { + public: + AMPBlock() = default; + 
AMPBlock(const std::string& name, int32_t channels, int32_t kernel_size, const std::vector& dilations) + : nn::Module(name) { + if (dilations.size() != 3) { MLLM_ERROR_EXIT(ExitCode::kCoreError, "AMPBlock expects 3 dilation values."); } + + convs1_.emplace_back(reg("convs1.0", channels, channels, kernel_size, 1, getPadding(kernel_size, dilations[0]), + dilations[0], 1, true)); + convs1_.emplace_back(reg("convs1.1", channels, channels, kernel_size, 1, getPadding(kernel_size, dilations[1]), + dilations[1], 1, true)); + convs1_.emplace_back(reg("convs1.2", channels, channels, kernel_size, 1, getPadding(kernel_size, dilations[2]), + dilations[2], 1, true)); + + convs2_.emplace_back(reg("convs2.0", channels, channels, kernel_size, 1, getPadding(kernel_size, 1), 1, 1, true)); + convs2_.emplace_back(reg("convs2.1", channels, channels, kernel_size, 1, getPadding(kernel_size, 1), 1, 1, true)); + convs2_.emplace_back(reg("convs2.2", channels, channels, kernel_size, 1, getPadding(kernel_size, 1), 1, 1, true)); + + const int32_t num_layers = static_cast(convs1_.size() + convs2_.size()); + for (int32_t i = 0; i < num_layers; ++i) { + activations_.emplace_back(reg("activations." 
+ std::to_string(i), channels)); + } + } + + Tensor forward(const Tensor& hidden_states) { + auto out = hidden_states; + const int32_t num_blocks = static_cast(convs1_.size()); + for (int32_t i = 0; i < num_blocks; ++i) { + auto residual = out; + auto x = activations_[i * 2].forward({out}, {})[0]; + x = convs1_[i](x); + x = activations_[i * 2 + 1].forward({x}, {})[0]; + x = convs2_[i](x); + out = residual + x; + } + return out; + } + + private: + static int32_t getPadding(int32_t kernel_size, int32_t dilation) { + return static_cast((kernel_size * dilation - dilation) / 2); + } + + std::vector convs1_; + std::vector convs2_; + std::vector activations_; +}; + +class Qwen2_5OmniToken2WavBigVGANModel final : public nn::Module { + public: + Qwen2_5OmniToken2WavBigVGANModel() = default; + explicit Qwen2_5OmniToken2WavBigVGANModel(const std::string& name, const Qwen2_5OmniBigVGANConfig& cfg) : nn::Module(name), cfg_(cfg) { + num_residual_blocks_ = static_cast(cfg.resblock_kernel_sizes.size()); + num_upsample_layers_ = static_cast(cfg.upsample_rates.size()); + + conv_pre_ = reg("conv_pre", cfg.mel_dim, cfg.upsample_initial_channel, 7, 1, 3, 1, 1, true); + + for (int32_t layer_idx = 0; layer_idx < num_upsample_layers_; ++layer_idx) { + const int32_t stride = cfg.upsample_rates[layer_idx]; + const int32_t kernel = cfg.upsample_kernel_sizes[layer_idx]; + const int32_t in_ch = cfg.upsample_initial_channel / static_cast(std::pow(2, layer_idx)); + const int32_t out_ch = cfg.upsample_initial_channel / static_cast(std::pow(2, layer_idx + 1)); + const int32_t padding = (kernel - stride) / 2; + ups_.emplace_back(reg("ups." 
+ std::to_string(layer_idx) + ".0", in_ch, out_ch, kernel, stride, + padding, 0, 1, 1, true)); + } + + for (int32_t layer_idx = 0; layer_idx < num_upsample_layers_; ++layer_idx) { + const int32_t channels = cfg.upsample_initial_channel / static_cast(std::pow(2, layer_idx + 1)); + for (size_t i = 0; i < cfg.resblock_kernel_sizes.size(); ++i) { + resblocks_.emplace_back(reg("resblocks." + std::to_string(resblocks_.size()), channels, + cfg.resblock_kernel_sizes[i], cfg.resblock_dilation_sizes[i])); + } + } + + activation_post_ = + reg("activation_post", cfg.upsample_initial_channel / static_cast(std::pow(2, num_upsample_layers_))); + conv_post_ = reg("conv_post", + cfg.upsample_initial_channel / static_cast(std::pow(2, num_upsample_layers_)), 1, 7, 1, 3, 1, 1, + false); + } + + Tensor forward(const Tensor& mel_spectrogram) { + auto mel = mel_spectrogram; + if (!mel.isContiguous()) { mel = mel.contiguous(); } + auto processed = processMelSpectrogram(mel); + return forwardProcessed(processed); + } + + private: + Tensor forwardProcessed(const Tensor& processed) { + auto hidden = conv_pre_(processed); + + for (int32_t layer_idx = 0; layer_idx < num_upsample_layers_; ++layer_idx) { + hidden = ups_[layer_idx](hidden); + Tensor residual_sum = Tensor::zeros(hidden.shape(), hidden.dtype(), hidden.device()); + for (int32_t block_idx = 0; block_idx < num_residual_blocks_; ++block_idx) { + residual_sum = residual_sum + resblocks_[layer_idx * num_residual_blocks_ + block_idx].forward(hidden); + } + hidden = residual_sum * (1.0f / static_cast(num_residual_blocks_)); + } + + hidden = activation_post_.forward({hidden}, {})[0]; + auto output = conv_post_(hidden); + output = clampTensor(output, -1.0f, 1.0f); + return output.squeeze(); + } + Tensor processMelSpectrogram(const Tensor& mel_spectrogram) const { + auto amplitude = nn::functional::exp(mel_spectrogram); + auto decibel = amplitudeToDb(amplitude, -115.0f) + (-20.0f); + return normalizeSpectrogram(decibel, 1.0f, -115.0f); + } + 
+ Qwen2_5OmniBigVGANConfig cfg_; + int32_t num_residual_blocks_ = 0; + int32_t num_upsample_layers_ = 0; + nn::Conv1D conv_pre_; + std::vector ups_; + std::vector resblocks_; + TorchActivation1d activation_post_; + nn::Conv1D conv_post_; +}; + +class Qwen2_5OmniToken2WavModel final : public nn::Module { + public: + Qwen2_5OmniToken2WavModel() = default; + explicit Qwen2_5OmniToken2WavModel(const std::string& name, const Qwen2_5OmniToken2WavConfig& cfg) : nn::Module(name), cfg_(cfg) { + code2wav_dit_model_ = reg("code2wav_dit_model", cfg.dit_config); + code2wav_bigvgan_model_ = reg("code2wav_bigvgan_model", cfg.bigvgan_config); + } + + Tensor forward(const Tensor& code, const Tensor& conditioning, const Tensor& reference_mel, int32_t num_steps = 10, + float guidance_scale = 0.5f, float sway_coefficient = -1.0f) { + auto mel = code2wav_dit_model_.sample(conditioning, reference_mel, code, num_steps, guidance_scale, sway_coefficient); + if (!mel.isContiguous()) { mel = mel.contiguous(); } + return code2wav_bigvgan_model_.forward(mel); + } + + Tensor vocodeMel(const Tensor& mel) { + return code2wav_bigvgan_model_.forward(mel); + } + + private: + Qwen2_5OmniToken2WavConfig cfg_; + Qwen2_5OmniToken2WavDiTModel code2wav_dit_model_; + Qwen2_5OmniToken2WavBigVGANModel code2wav_bigvgan_model_; +}; + +} // namespace token2wav + +using token2wav::Qwen2_5OmniToken2WavBigVGANModel; +using token2wav::Qwen2_5OmniToken2WavDiTModel; +using token2wav::Qwen2_5OmniToken2WavModel; + +} // namespace mllm::models::qwen2_5omni diff --git a/mllm/models/qwen2_5omni/tokenization_qwen2_5omni.hpp b/mllm/models/qwen2_5omni/tokenization_qwen2_5omni.hpp new file mode 100644 index 000000000..961b5c8f2 --- /dev/null +++ b/mllm/models/qwen2_5omni/tokenization_qwen2_5omni.hpp @@ -0,0 +1,385 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+#pragma once + +#include +#include +#include + +#include "mllm/preprocessor/tokenizers/BPE.hpp" +#include "mllm/preprocessor/tokenizers/Unicode.hpp" +#include "mllm/preprocessor/tokenizers/AutoTokenizer.hpp" +#include "mllm/models/ARGeneration.hpp" +#include "mllm/models/qwen2vl/image_preprocessor_qwen2vl.hpp" +#include "mllm/models/qwen2_5omni/audio_preprocessor_qwen2_5omni.hpp" +#include "mllm/utils/Common.hpp" + +namespace mllm::models::qwen2_5omni { + +// same regex as Qwen2/Qwen2-VL tokenizers +inline bool qwen2_5OmniTokenizerMatchPattern(const std::wstring& str, size_t& pos, std::wstring& matched) { + if (pos >= str.size()) return false; + + static const std::wstring contractions[] = {L"'s", L"'t", L"'re", L"'ve", L"'m", L"'ll", L"'d"}; + for (const auto& contraction : contractions) { + if (pos + contraction.size() <= str.size() && str.compare(pos, contraction.size(), contraction) == 0) { + matched = contraction; + pos += contraction.size(); + return true; + } + } + + { + size_t original_pos = pos; + bool has_prefix = false; + matched.clear(); + + if (!preprocessor::isLetter(str[pos]) && !preprocessor::isDigit(str[pos]) && str[pos] != L'\r' && str[pos] != L'\n') { + matched += str[pos]; + ++pos; + has_prefix = true; + } + + if (pos < str.size() && preprocessor::isLetter(str[pos])) { + do { + matched += str[pos]; + ++pos; + } while (pos < str.size() && preprocessor::isLetter(str[pos])); + return true; + } else if (has_prefix) { + pos = original_pos; + matched.clear(); + } + } + + if (preprocessor::isDigit(str[pos])) { + matched = str.substr(pos, 1); + ++pos; + return true; + } + + { + size_t original_pos = pos; + matched.clear(); + size_t start = pos; + + if (str[pos] == L' ') { ++pos; } + + if (pos < str.size() && !std::iswspace(str[pos]) && !preprocessor::isLetter(str[pos]) && !preprocessor::isDigit(str[pos])) { + do { + ++pos; + } while (pos < str.size() && !std::iswspace(str[pos]) && !preprocessor::isLetter(str[pos]) + && 
!preprocessor::isDigit(str[pos])); + + matched = str.substr(start, pos - start); + + while (pos < str.size() && (str[pos] == L'\r' || str[pos] == L'\n')) { + matched += str[pos]; + ++pos; + } + return true; + } else { + pos = original_pos; + } + } + + { + size_t start = pos; + while (pos < str.size() && std::iswspace(str[pos])) ++pos; + if (pos < str.size() && (str[pos] == L'\r' || str[pos] == L'\n')) { + while (pos < str.size() && (str[pos] == L'\r' || str[pos] == L'\n')) ++pos; + matched = str.substr(start, pos - start); + return true; + } else { + pos = start; + } + } + + if (std::iswspace(str[pos])) { + size_t start = pos; + while (pos < str.size() && std::iswspace(str[pos])) ++pos; + if (pos >= str.size() || std::iswspace(str[pos])) { + matched = str.substr(start, pos - start); + return true; + } else { + pos = start; + } + } + + if (std::iswspace(str[pos])) { + size_t start = pos; + while (pos < str.size() && std::iswspace(str[pos])) ++pos; + matched = str.substr(start, pos - start); + return true; + } + + return false; +} + +inline bool qwen2_5OmniRegex(const std::string& str, std::vector& splitted) { + auto w_string = preprocessor::utf8string2WideString(str); + size_t pos = 0; + while (pos < w_string.size()) { + std::wstring matched; + if (qwen2_5OmniTokenizerMatchPattern(w_string, pos, matched)) { + splitted.push_back(matched); + } else { + ++pos; + } + } + return true; +} + +struct Qwen2_5OmniMessage { + std::string prompt; + std::string system_prompt = "You are a helpful assistant."; + + [[nodiscard]] std::string buildChatMessage() const { + std::string result; + if (!system_prompt.empty()) { + result += "<|im_start|>system\n" + system_prompt + "<|im_end|>\n"; + } + result += "<|im_start|>user\n" + prompt + "<|im_end|>\n"; + result += "<|im_start|>assistant\n"; + return result; + } +}; + +struct Qwen2_5OmniVisionMessage { + std::string prompt; + std::string img_file_path; + std::string system_prompt = "You are a helpful assistant."; + + [[nodiscard]] 
std::string buildChatMessage() const { + std::string result; + if (!system_prompt.empty()) { + result += "<|im_start|>system\n" + system_prompt + "<|im_end|>\n"; + } + result += "<|im_start|>user\n<|vision_bos|><|IMAGE|><|vision_eos|>" + prompt + "<|im_end|>\n"; + result += "<|im_start|>assistant\n"; + return result; + } +}; + +struct Qwen2_5OmniAudioMessage { + std::string prompt; + std::string audio_file_path; + std::string system_prompt = "You are a helpful assistant."; + + [[nodiscard]] std::string buildChatMessage() const { + std::string result; + if (!system_prompt.empty()) { + result += "<|im_start|>system\n" + system_prompt + "<|im_end|>\n"; + } + result += "<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>" + prompt + "<|im_end|>\n"; + result += "<|im_start|>assistant\n"; + return result; + } +}; + +class Qwen2_5OmniTokenizer final : public mllm::preprocessor::AutoTokenizer { + public: + explicit Qwen2_5OmniTokenizer(const std::string& file_path, + int32_t spatial_merge_size = 2, + int32_t min_pixels = 56 * 56, + int32_t max_pixels = 1280 * 1280, + int32_t audio_sample_rate = 16000, + int32_t audio_n_mels = 128, + int32_t audio_hop_length = 160, + int32_t audio_chunk_length = 300) + //interestingly, the answer went bad when setting max_pixels higher, eg. 
3584*3584) + : image_preprocessor_(min_pixels, max_pixels), + audio_preprocessor_(audio_sample_rate, audio_n_mels, audio_hop_length, audio_chunk_length), + spatial_merge_size_(spatial_merge_size) { + preprocessor::initLocal(); + preprocessor::makeBytes2UnicodeMap(bytes_2_unicode_dict_); + for (auto& kv : bytes_2_unicode_dict_) { bytes_2_unicode_dict_inverse_.insert({kv.second, kv.first}); } + bpe_.initFromSentencePieceJson(file_path); + special_tokens_trie_.add(L"<|endoftext|>"); + special_tokens_trie_.add(L"<|im_start|>"); + special_tokens_trie_.add(L"<|im_end|>"); + special_tokens_trie_.add(L"<|object_ref_start|>"); + special_tokens_trie_.add(L"<|object_ref_end|>"); + special_tokens_trie_.add(L"<|box_start|>"); + special_tokens_trie_.add(L"<|box_end|>"); + special_tokens_trie_.add(L"<|quad_start|>"); + special_tokens_trie_.add(L"<|quad_end|>"); + special_tokens_trie_.add(L"<|vision_bos|>"); + special_tokens_trie_.add(L"<|vision_eos|>"); + special_tokens_trie_.add(L"<|vision_pad|>"); + special_tokens_trie_.add(L"<|image_pad|>"); + special_tokens_trie_.add(L"<|video_pad|>"); + special_tokens_trie_.add(L"<|AUDIO|>"); + special_tokens_trie_.add(L"<|audio_bos|>"); + special_tokens_trie_.add(L"<|audio_eos|>"); + special_tokens_trie_.add(L"<|IMAGE|>"); + special_tokens_trie_.add(L"<|VIDEO|>"); + } + + std::vector _tokenize(const std::string& str) override { + std::vector ret; + std::vector splitted; + ::mllm::models::qwen2_5omni::qwen2_5OmniRegex(str, splitted); + for (const auto& s : splitted) { + auto utf_8_str = preprocessor::wideString2Utf8String(s); + std::wstring mapped_str; + for (unsigned char c : utf_8_str) { mapped_str.push_back(bytes_2_unicode_dict_[c]); } + + auto bpe_ts = bpe_._bpe(mapped_str); + + for (const auto& bpe_t : bpe_ts) { ret.push_back(bpe_t); } + } + + return ret; + } + + std::vector tokenize(const std::string& str) override { + auto tokens = special_tokens_trie_.split(preprocessor::utf8string2WideString(str)); + std::vector all_tokens; + for 
(const auto& token : tokens) { + if (special_tokens_trie_.isSpecialToken(token)) { + all_tokens.emplace_back(token); + continue; + } + auto tmp_tokens = _tokenize(preprocessor::wideString2Utf8String(token)); + all_tokens.insert(all_tokens.end(), tmp_tokens.begin(), tmp_tokens.end()); + } + return all_tokens; + } + + std::wstring _detokenize(int64_t pos_idx) override { return bpe_._lookup_inverse_vocab(pos_idx); } + + std::wstring detokenize(int64_t pos_idx) override { + auto str = _detokenize(pos_idx); + std::string utf_8_str; + for (wchar_t c : str) { utf_8_str.push_back((unsigned char)(bytes_2_unicode_dict_inverse_[c])); } + return {mllm::preprocessor::utf8string2WideString(utf_8_str)}; + } + + Tensor convert2Ids(const std::vector& strs) override { + std::vector ids; + ids.reserve(strs.size()); + for (const auto& str : strs) { ids.emplace_back(bpe_._lookup_vocab(str)); } + Tensor ret = Tensor::empty({1, static_cast(ids.size())}, kInt64, kCPU) + .setMemType(kExtraInput) + .setName("qwen2_5omni-tokenizer-i0") + .alloc(); + + auto ptr = ret.ptr(); + for (size_t i = 0; i < ids.size(); ++i) { ptr[i] = ids[i]; } + + return ret; + } + + ARGenerationOutputPast convertMessage(const Qwen2_5OmniMessage& message) { + auto applied_string = message.buildChatMessage(); + auto sequence_str = tokenize(applied_string); + + std::vector ids; + ids.reserve(sequence_str.size()); + for (const auto& str : sequence_str) { ids.emplace_back(bpe_._lookup_vocab(str)); } + + Tensor sequence = Tensor::empty({1, static_cast(ids.size())}, kInt64, kCPU) + .setMemType(kNormal) + .setName("qwen2_5omni-tokenizer-i0") + .alloc(); + + auto ptr = sequence.ptr(); + for (size_t i = 0; i < ids.size(); ++i) { ptr[i] = ids[i]; } + + return {{"sequence", sequence}}; + } + + ARGenerationOutputPast convertVisionMessage(const Qwen2_5OmniVisionMessage& message) { + auto applied_string = message.buildChatMessage(); + + auto [img, grid_thw] = image_preprocessor_(message.img_file_path); + + auto sequence_str = 
tokenize(applied_string); + std::vector ids; + ids.reserve(sequence_str.size()); + for (const auto& str : sequence_str) { ids.emplace_back(bpe_._lookup_vocab(str)); } + + auto grid_t = grid_thw.ptr()[0]; + auto grid_h = grid_thw.ptr()[1]; + auto grid_w = grid_thw.ptr()[2]; + int32_t img_token_nums = grid_t * grid_h * grid_w; + img_token_nums /= (spatial_merge_size_ * spatial_merge_size_); + + auto image_token_id = bpe_._lookup_vocab(L"<|IMAGE|>"); + { + auto it = std::find(ids.begin(), ids.end(), image_token_id); + if (it == ids.end()) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Missing <|IMAGE|> token in Qwen2.5-Omni prompt template."); + } + ids.insert(it + 1, img_token_nums - 1, image_token_id); + } + + Tensor sequence = Tensor::empty({1, static_cast(ids.size())}, kInt64, kCPU) + .setMemType(kNormal) + .setName("qwen2_5omni-tokenizer-i0") + .alloc(); + + auto ptr = sequence.ptr(); + for (size_t i = 0; i < ids.size(); ++i) { ptr[i] = ids[i]; } + + return { + {"sequence", sequence}, + {"img", img}, + {"grid_thw", grid_thw}, + }; + } + + ARGenerationOutputPast convertAudioMessage(const Qwen2_5OmniAudioMessage& message) { + auto applied_string = message.buildChatMessage(); + auto sequence_str = tokenize(applied_string); + + std::vector ids; + ids.reserve(sequence_str.size()); + for (const auto& str : sequence_str) { ids.emplace_back(bpe_._lookup_vocab(str)); } + + auto audio_result = audio_preprocessor_.processAudioFile(message.audio_file_path); + if (audio_result.input_features.isNil() || audio_result.feature_length <= 0) { + MLLM_ERROR_EXIT(ExitCode::kIOError, "Failed to extract audio features for Qwen2.5-Omni."); + } + + int32_t audio_token_nums = audio_preprocessor_.calcAudioTokenLength(audio_result.feature_length); + if (audio_token_nums <= 0) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Invalid audio token length for Qwen2.5-Omni."); + } + + auto audio_token_id = bpe_._lookup_vocab(L"<|AUDIO|>"); + { + auto it = std::find(ids.begin(), ids.end(), audio_token_id); 
+ if (it == ids.end()) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Missing <|AUDIO|> token in Qwen2.5-Omni prompt template."); + } + ids.insert(it + 1, audio_token_nums - 1, audio_token_id); + } + + Tensor sequence = Tensor::empty({1, static_cast(ids.size())}, kInt64, kCPU) + .setMemType(kNormal) + .setName("qwen2_5omni-tokenizer-i0") + .alloc(); + + auto ptr = sequence.ptr(); + for (size_t i = 0; i < ids.size(); ++i) { ptr[i] = ids[i]; } + + audio_result.input_features.setName("input_features"); + + return { + {"sequence", sequence}, + {"input_features", audio_result.input_features}, + }; + } + + private: + preprocessor::BPE bpe_; + std::unordered_map bytes_2_unicode_dict_; + std::unordered_map bytes_2_unicode_dict_inverse_; + mllm::models::qwen2vl::Qwen2VLImagePreprocessor image_preprocessor_; + Qwen2_5OmniAudioPreprocessor audio_preprocessor_; + int32_t spatial_merge_size_ = 2; +}; + +} // namespace mllm::models::qwen2_5omni diff --git a/mllm/nn/Nn.hpp b/mllm/nn/Nn.hpp index fdb0edc82..160e1cb43 100644 --- a/mllm/nn/Nn.hpp +++ b/mllm/nn/Nn.hpp @@ -11,6 +11,7 @@ #include "mllm/nn/layers/RMSNorm.hpp" // IWYU pragma: export #include "mllm/nn/layers/SiLU.hpp" // IWYU pragma: export #include "mllm/nn/layers/Sigmoid.hpp" // IWYU pragma: export +#include "mllm/nn/layers/Tanh.hpp" // IWYU pragma: export #include "mllm/nn/layers/Embedding.hpp" // IWYU pragma: export #include "mllm/nn/layers/GELU.hpp" // IWYU pragma: export #include "mllm/nn/layers/QuickGELU.hpp" // IWYU pragma: export @@ -26,6 +27,7 @@ #include "mllm/nn/layers/Param.hpp" // IWYU pragma: export #include "mllm/nn/layers/KVCache.hpp" // IWYU pragma: export #include "mllm/nn/layers/Conv1D.hpp" // IWYU pragma: export +#include "mllm/nn/layers/ConvTranspose1D.hpp" // IWYU pragma: export #include "mllm/nn/layers/AvgPool1d.hpp" // IWYU pragma: export #include "mllm/nn/layers/STFT.hpp" // IWYU pragma: export #include "mllm/nn/layers/PagedAttn.hpp" // IWYU pragma: export diff --git 
a/mllm/nn/layers/ConvTranspose1D.cpp b/mllm/nn/layers/ConvTranspose1D.cpp new file mode 100644 index 000000000..de2a7a5c7 --- /dev/null +++ b/mllm/nn/layers/ConvTranspose1D.cpp @@ -0,0 +1,32 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include "mllm/nn/layers/ConvTranspose1D.hpp" + +namespace mllm::nn { + +ConvTranspose1D::ConvTranspose1D() : Layer(OpTypes::kConvTranspose1D, aops::ConvTranspose1DOpOptions{}) {} + +ConvTranspose1D::ConvTranspose1D(int32_t in_channels, int32_t out_channels, int32_t kernel_size, int32_t stride_size, + int32_t padding, int32_t output_padding, int32_t dilation, int32_t groups, bool bias) + : Layer(OpTypes::kConvTranspose1D, aops::ConvTranspose1DOpOptions{.in_channels = in_channels, + .out_channels = out_channels, + .kernel_size = kernel_size, + .stride = stride_size, + .padding = padding, + .output_padding = output_padding, + .dilation = dilation, + .groups = groups, + .bias = bias}) {} + +ConvTranspose1D::ConvTranspose1D(const aops::ConvTranspose1DOpOptions& options) : Layer(OpTypes::kConvTranspose1D, options) {} + +Tensor ConvTranspose1D::weight() const { + return std::static_pointer_cast(impl()->getInstancedOp())->weight(); +} + +Tensor ConvTranspose1D::bias() const { + return std::static_pointer_cast(impl()->getInstancedOp())->bias(); +} + +} // namespace mllm::nn diff --git a/mllm/nn/layers/ConvTranspose1D.hpp b/mllm/nn/layers/ConvTranspose1D.hpp new file mode 100644 index 000000000..6ddc2fac3 --- /dev/null +++ b/mllm/nn/layers/ConvTranspose1D.hpp @@ -0,0 +1,29 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+ +#pragma once + +#include +#include "mllm/nn/Layer.hpp" +#include "mllm/core/aops/ConvTranspose1DOp.hpp" + +namespace mllm::nn { + +class ConvTranspose1D : public Layer { + public: + ConvTranspose1D(); + + ConvTranspose1D(int32_t in_channels, int32_t out_channels, int32_t kernel_size, int32_t stride_size = 1, + int32_t padding = 0, int32_t output_padding = 0, int32_t dilation = 1, int32_t groups = 1, + bool bias = true); + + explicit ConvTranspose1D(const aops::ConvTranspose1DOpOptions& options); + + [[nodiscard]] Tensor weight() const; + + [[nodiscard]] Tensor bias() const; + + MLLM_LAYER_ANY_INPUTS_1_OUTPUTS_FORWARD +}; + +} // namespace mllm::nn diff --git a/mllm/nn/layers/Tanh.cpp b/mllm/nn/layers/Tanh.cpp new file mode 100644 index 000000000..dda95f7ae --- /dev/null +++ b/mllm/nn/layers/Tanh.cpp @@ -0,0 +1,12 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include "mllm/nn/layers/Tanh.hpp" + +namespace mllm::nn { + +Tanh::Tanh() : Layer(OpTypes::kTanh, aops::TanhOpOptions{}) {} + +Tanh::Tanh(const aops::TanhOpOptions& options) : Layer(OpTypes::kTanh, options) {} + +} // namespace mllm::nn diff --git a/mllm/nn/layers/Tanh.hpp b/mllm/nn/layers/Tanh.hpp new file mode 100644 index 000000000..ab84e7eeb --- /dev/null +++ b/mllm/nn/layers/Tanh.hpp @@ -0,0 +1,21 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include "mllm/nn/Layer.hpp" +#include "mllm/core/aops/TanhOp.hpp" + +namespace mllm::nn { + +class Tanh : public Layer { + public: + Tanh(); + + explicit Tanh(const aops::TanhOpOptions& options); + + MLLM_LAYER_ANY_INPUTS_1_OUTPUTS_FORWARD + MLLM_LAYER_ENABLE_INPLACE_ATTRIBUTE(Tanh) +}; + +} // namespace mllm::nn diff --git a/tests/cpu/ConvTranspose1DKernelTest.hpp b/tests/cpu/ConvTranspose1DKernelTest.hpp new file mode 100644 index 000000000..d7657baf1 --- /dev/null +++ b/tests/cpu/ConvTranspose1DKernelTest.hpp @@ -0,0 +1,134 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+#pragma once + +#include +#include + +#include "KernelTestHelper.hpp" +#include "mllm/core/ParameterFile.hpp" +#include "mllm/mllm.hpp" +#include "mllm/nn/Nn.hpp" + +using namespace mllm; // NOLINT + +void naive_conv_transpose1d(const float* input_data, const float* weight_data, const float* bias_data, float* output_data, + int batch, int in_channels, int sequence, int out_channels, int kernel_size, int stride, + int padding, int dilation, int output_padding, int groups) { + const int out_sequence = (sequence - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1; + std::fill_n(output_data, batch * out_channels * out_sequence, 0.0f); + + const int in_channels_per_group = in_channels / groups; + const int out_channels_per_group = out_channels / groups; + + for (int b = 0; b < batch; ++b) { + for (int oc = 0; oc < out_channels; ++oc) { + const int group_idx = oc / out_channels_per_group; + const int oc_in_group = oc % out_channels_per_group; + for (int out_pos = 0; out_pos < out_sequence; ++out_pos) { + float sum = 0.0f; + for (int ic_in_group = 0; ic_in_group < in_channels_per_group; ++ic_in_group) { + const int ic = group_idx * in_channels_per_group + ic_in_group; + const int base_input_idx = b * (in_channels * sequence) + ic * sequence; + const int base_weight_idx = (ic * out_channels_per_group + oc_in_group) * kernel_size; + + for (int k = 0; k < kernel_size; ++k) { + int input_pos = out_pos + padding - k * dilation; + if (input_pos % stride != 0) { continue; } + input_pos /= stride; + if (input_pos < 0 || input_pos >= sequence) { continue; } + + const int input_idx = base_input_idx + input_pos; + const int weight_idx = base_weight_idx + k; + sum += input_data[input_idx] * weight_data[weight_idx]; + } + } + if (bias_data != nullptr) { sum += bias_data[oc]; } + const int output_idx = b * (out_channels * out_sequence) + oc * out_sequence + out_pos; + output_data[output_idx] = sum; + } + } + } +} + +class ConvTranspose1DModule : public 
nn::Module { + nn::ConvTranspose1D conv_; + + public: + ConvTranspose1DModule(int in_channel, int out_channel, int kernel_size, int stride, int padding, int output_padding, + int dilation, int groups, bool bias) { + conv_ = reg("conv", in_channel, out_channel, kernel_size, stride, padding, output_padding, dilation, + groups, bias); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + return {conv_(inputs[0])}; + } +}; + +class ConvTranspose1DKernelTest : public KernelTest { + public: + bool testConvTranspose1DOnce(const std::unordered_map& cfg) { + auto batch = cfg.at("batch"); + auto in_channel = cfg.at("in_channel"); + auto out_channel = cfg.at("out_channel"); + auto sequence = cfg.at("sequence"); + auto kernel_size = cfg.at("kernel_size"); + auto stride = cfg.at("stride"); + auto padding = cfg.at("padding"); + auto output_padding = cfg.at("output_padding"); + auto dilation = cfg.at("dilation"); + auto groups = cfg.at("groups"); + auto bias = cfg.at("bias"); + + auto module = ConvTranspose1DModule(in_channel, out_channel, kernel_size, stride, padding, output_padding, dilation, + groups, bias); + + auto weight_param = + Tensor::random({in_channel, out_channel / groups, kernel_size}, -1, 1, kFloat32, kCPU); + auto bias_param = Tensor::random({out_channel}, -1, 1, kFloat32, kCPU); + weight_param.setName("conv.weight"); + bias_param.setName("conv.bias"); + + auto param = ParameterFile::create(); + param->push("conv.weight", weight_param); + if (bias) { param->push("conv.bias", bias_param); } + module.load(param); + + auto input = Tensor::random({batch, in_channel, sequence}, -1, 1, kFloat32, kCPU); + auto predict = module(input)[0]; + + auto expected = Tensor::zeros(predict.shape(), kFloat32, kCPU); + naive_conv_transpose1d(input.ptr(), weight_param.ptr(), bias ? 
bias_param.ptr() : nullptr, + expected.ptr(), batch, in_channel, sequence, out_channel, kernel_size, stride, padding, + dilation, output_padding, groups); + + auto result = test::allClose(expected, predict, 1e-4f, 1e-4f); + if (!result) { + print(result); + return false; + } + return true; + } + + bool testConvTranspose1D(const std::vector>& cfgs) { + for (auto& cfg : cfgs) { + if (!testConvTranspose1DOnce(cfg)) { + auto batch = cfg.at("batch"); + auto in_channel = cfg.at("in_channel"); + auto out_channel = cfg.at("out_channel"); + auto sequence = cfg.at("sequence"); + auto kernel_size = cfg.at("kernel_size"); + auto stride = cfg.at("stride"); + auto padding = cfg.at("padding"); + auto output_padding = cfg.at("output_padding"); + auto dilation = cfg.at("dilation"); + auto groups = cfg.at("groups"); + auto bias = cfg.at("bias"); + print(batch, in_channel, out_channel, sequence, kernel_size, stride, padding, output_padding, dilation, groups, bias); + return false; + } + } + return true; + } +}; diff --git a/tests/cpu/KernelTest.cpp b/tests/cpu/KernelTest.cpp index 9f8d613ee..575360703 100644 --- a/tests/cpu/KernelTest.cpp +++ b/tests/cpu/KernelTest.cpp @@ -857,6 +857,48 @@ TEST_F(FlashAttn2KernelTest, fwd_bshd) { } #endif +//===----------------------------------------------------------------------===// +// Tanh +//===----------------------------------------------------------------------===// +#include "TanhKernelTest.hpp" +TEST_F(TanhKernelTest, TanhFloat32) { EXPECT_EQ(testTanh({{8}, {2, 3, 4}}), true); } + +//===----------------------------------------------------------------------===// +// ConvTranspose1D +//===----------------------------------------------------------------------===// +#include "ConvTranspose1DKernelTest.hpp" +TEST_F(ConvTranspose1DKernelTest, Basic) { + EXPECT_EQ(testConvTranspose1D({ + { + {"batch", 1}, + {"in_channel", 2}, + {"out_channel", 3}, + {"sequence", 4}, + {"kernel_size", 3}, + {"stride", 2}, + {"padding", 1}, + {"output_padding", 0}, 
+ {"dilation", 1}, + {"groups", 1}, + {"bias", 1}, + }, + { + {"batch", 2}, + {"in_channel", 1}, + {"out_channel", 2}, + {"sequence", 5}, + {"kernel_size", 2}, + {"stride", 1}, + {"padding", 0}, + {"output_padding", 0}, + {"dilation", 1}, + {"groups", 1}, + {"bias", 0}, + }, + }), + true); +} + //===----------------------------------------------------------------------===// // Conv2D Test // diff --git a/tests/cpu/TanhKernelTest.hpp b/tests/cpu/TanhKernelTest.hpp new file mode 100644 index 000000000..ff6762170 --- /dev/null +++ b/tests/cpu/TanhKernelTest.hpp @@ -0,0 +1,49 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +#pragma once + +#include + +#include "KernelTestHelper.hpp" +#include "mllm/mllm.hpp" +#include "mllm/nn/Nn.hpp" + +class TanhModule : public mllm::nn::Module { + mllm::nn::Tanh tanh_; + + public: + TanhModule() { tanh_ = reg("tanh"); } + + std::vector forward(const std::vector& inputs, + const std::vector& args) override { + return {tanh_(inputs[0])}; + } +}; + +class TanhKernelTest : public KernelTest { + public: + bool testTanh(const std::vector& shapes) { + using mllm::Tensor; + using mllm::kCPU; + using mllm::kFloat32; + TanhModule module; + + for (auto& s : shapes) { + auto input = Tensor::random(s, -3, 3, kFloat32, kCPU); + auto output = module(input)[0]; + auto expected = Tensor::empty(s, kFloat32, kCPU).alloc(); + + const auto* in_ptr = input.ptr(); + auto* out_ptr = expected.ptr(); + const auto numel = input.numel(); + for (size_t i = 0; i < numel; ++i) { out_ptr[i] = std::tanh(in_ptr[i]); } + + auto result = mllm::test::allClose(expected, output, 1e-5f, 1e-5f); + if (!result) { + mllm::print(result); + return false; + } + } + return true; + } +};