From efd848e77323cc5ad36126edc0858a2a3beea2a7 Mon Sep 17 00:00:00 2001 From: whn09 Date: Wed, 22 Apr 2026 11:14:05 +0800 Subject: [PATCH 01/14] [contrib] Refactor Qwen2.5-Omni-7B to zero-invasion contrib layout MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit All Qwen2.5-Omni code now lives under contrib/models/Qwen2.5-Omni-7B/: src/ - modeling files (text, vision, audio, talker, token2wav) examples/ - end-to-end generation + speech pipeline test/ - integration tests perf_test/ - vLLM benchmarks Key refactors vs feature/qwen25-omni-support: - Removed src/neuronx_distributed_inference/models/qwen25_omni/ entirely; contrib files import each other flat via sys.path.insert + _upstream_compat bootstrap, mirroring upstream's contrib/Qwen2-Audio-7B convention. - NeuronTalkerModel previously required a one-line src/model_base.py patch (apply_vision_during_token_gen gate). Replaced by a get_model_output override inside the contrib Talker class — no src/ change needed. - Worked around an upstream hf_adapter.py bug (prepare_inputs_for_generation references an undefined tensor_capture_hook local) via a local _upstream_compat shim applied at contrib import time. Also leaves room for a separate upstream PR with the 1-line fix. - Removed the qwen2_5_omni entries from utils/constants.py and inference_demo.py that were unnecessary for contrib usage. Result: git diff upstream/main..HEAD -- src/ is empty; all changes are additive under contrib/models/Qwen2.5-Omni-7B/. Co-Authored-By: Claude Opus 4.7 --- contrib/models/Qwen2.5-Omni-7B/README.md | 314 +++-- .../examples/generate_qwen25_omni.py | 457 +++++++ .../examples/generate_qwen25_omni_speech.py | 851 +++++++++++++ .../perf_test/3_bench_qwen25_omni_7b.sh | 171 +++ .../apply_vllm_neuron_patch_qwen25omni.py | 83 ++ .../models/Qwen2.5-Omni-7B/src/__init__.py | 11 +- .../Qwen2.5-Omni-7B/src/_upstream_compat.py | 139 +++ .../src/modeling_qwen25_omni.py | 1106 +++++++++++++++++ .../src/modeling_qwen25_omni_audio.py | 700 +++++++++++ .../src/modeling_qwen25_omni_talker.py | 1057 ++++++++++++++++ .../src/modeling_qwen25_omni_token2wav.py | 765 ++++++++++++ .../src/modeling_qwen25_omni_vision.py | 579 +++++++++ .../src/modeling_qwen2_5_omni.py | 620 --------- .../test/integration/test_e2e_qwen25_omni.py | 424 +++++++ .../test/integration/test_model.py | 1016 +++++++++++---- .../test/integration/test_talker_neuron.py | 610 +++++++++ 16 files changed, 7979 insertions(+), 924 deletions(-) create mode 100644 contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni.py create mode 100644 contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni_speech.py create mode 100644 contrib/models/Qwen2.5-Omni-7B/perf_test/3_bench_qwen25_omni_7b.sh create mode 100644 contrib/models/Qwen2.5-Omni-7B/perf_test/apply_vllm_neuron_patch_qwen25omni.py create mode 100644 contrib/models/Qwen2.5-Omni-7B/src/_upstream_compat.py create mode 100644 contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni.py create mode 100644 contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_audio.py create mode 100644 contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_talker.py create mode 100644 contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_token2wav.py create mode 100644 contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_vision.py delete mode 100644 contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen2_5_omni.py create mode 100644 contrib/models/Qwen2.5-Omni-7B/test/integration/test_e2e_qwen25_omni.py create mode 100644 
contrib/models/Qwen2.5-Omni-7B/test/integration/test_talker_neuron.py diff --git a/contrib/models/Qwen2.5-Omni-7B/README.md b/contrib/models/Qwen2.5-Omni-7B/README.md index c5830ceb..2d56ac4c 100644 --- a/contrib/models/Qwen2.5-Omni-7B/README.md +++ b/contrib/models/Qwen2.5-Omni-7B/README.md @@ -1,130 +1,282 @@ -# Contrib Model: Qwen2.5 Omni 7B +# Contrib Model: Qwen2.5-Omni-7B -NeuronX Distributed Inference implementation of Qwen2.5 Omni 7B. - -> **Note:** This implementation has been validated using the **text backbone only**. Vision/audio modalities are implemented but not yet verified. +NeuronX Distributed Inference implementation of [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B) with full multimodal support: text generation, image understanding, audio understanding, and text-to-speech. ## Model Information - **HuggingFace ID:** `Qwen/Qwen2.5-Omni-7B` -- **Model Type:** Decoder-only transformer +- **Model Type:** Multimodal encoder-decoder (Thinker + Vision + Audio + Talker + Token2Wav) +- **Architecture:** Qwen2-based text backbone with vision/audio encoders and speech synthesis - **License:** Check HuggingFace model card ## Architecture Details -- **Layers:** Check model config -- **Hidden Size:** Check model config -- **Attention Heads:** Check model config -- **Vocabulary:** Check model config -- **Max Position Embeddings:** Check model config - -## Validation Results +| Component | Runtime | TP | Parameters | +|-----------|---------|-----|------------| +| Thinker (text) | Neuron | 4 | hidden=3584, heads=28, kv_heads=4, layers=28 | +| Vision encoder | Neuron | 4 | embed=1280, heads=16, depth=32, SwiGLU MLP | +| Audio encoder | CPU+Neuron | 4 | d_model=1280, heads=20, layers=32, chunked attention | +| Talker | Neuron | 4 | hidden=896, heads=12, kv_heads=4, head_dim=128, layers=24, vocab=8448, fused embed (8448→896) | +| Token2Wav | CPU+Neuron (fp32) | N/A | DiT: dim=1024, 22 blocks (Neuron); BigVGAN: 6 upsample stages (CPU) | -**Validated:** 2026-01-29 -**Configuration:** TP=2, batch_size=1, seq_len=128, bfloat16 +**Total state dict keys:** 2448 (Text: 339, Vision: 518, Audio: 489, Talker: 293, Token2Wav: 809) -### Test Results +Key features: +- **Thinker**: Architecturally identical to Qwen2.5-7B; reuses `NeuronQwen2ForCausalLM` with state-dict prefix remapping (28 heads / 4 TP = 7 per rank, 4 kv_heads / 4 TP = 1 per rank) +- **Vision encoder**: SwiGLU MLP, RMSNorm, separate QKV projections, PatchMerger (16 heads / 4 TP = 4 per rank) +- **Audio encoder**: Whisper-style with chunked attention. Hybrid CPU+Neuron: Conv1d frontend + chunking on CPU, 32 transformer layers on Neuron (20 heads / 4 TP = 5 per rank), AvgPool + LayerNorm + projection on CPU +- **Talker**: Neuron-compiled with fused embedding (embed_tokens 8448→3584 + thinker_to_talker_proj 3584→896 collapsed into 8448→896), explicit head_dim=128, 3D mRoPE, per-step thinker state injection via vision_embeddings (12 heads / 4 TP = 3 per rank, 4 kv_heads / 4 TP = 1 per rank). Auto-pads vision_embeddings to max_context_length for compiled bucket compatibility. +- **Token2Wav**: DiT transformer core (22 blocks) on Neuron + BigVGAN vocoder on CPU, ODE sampling (Runge-Kutta 4, 10 steps), float32. Split architecture: CPU preprocessing (ECAPA-TDNN, codec embed, input embed, rotary) + Neuron transformer core + CPU ODE solver + CPU BigVGAN. Automatic CPU fallback when mel_len exceeds compiled max. 
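+
+The Thinker reuse is purely a checkpoint-key transformation. Below is a minimal, illustrative sketch of the prefix remapping (the helper name is hypothetical; the real converters under `src/` also handle the vision, audio, talker, and token2wav prefixes listed in Key Implementation Notes):
+
+```python
+def remap_thinker_prefixes(hf_state_dict):
+    """Illustrative only: strip the `thinker.` prefix so the text backbone
+    loads under the NeuronQwen2ForCausalLM-style module names."""
+    remapped = {}
+    for key, value in hf_state_dict.items():
+        if key.startswith("thinker.model."):
+            remapped["model." + key[len("thinker.model."):]] = value
+        elif key.startswith("thinker.lm_head."):
+            remapped["lm_head." + key[len("thinker.lm_head."):]] = value
+        # thinker.visual.* / thinker.audio_tower.* / talker.* / token2wav.*
+        # go through their own converters (see modeling_qwen25_omni_*.py)
+    return remapped
+```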
-| Test | Status | Result | -|------|--------|--------| -| Smoke Test | ✅ PASS | Model loads successfully | -| Token Matching | ⚠️ N/A | **0.0% match** | -| TTFT (P50) | ✅ PASS | 50.15ms (threshold: 100ms) | -| Throughput | ✅ PASS | 19.82 tok/s (threshold: 10 tok/s) | +## Prerequisites -### Performance Metrics +- **Instance**: trn2.48xlarge or trn2.xlarge (4+ NeuronCores sufficient) +- **Weights**: Download from [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B) -| Metric | Value | -|--------|-------| -| TTFT (P50) | 50.15ms | -| Throughput | 19.82 tokens/s | +## Usage +### Text-only (Thinker) -**Status:** ✅ VALIDATED +```python +import sys +from pathlib import Path + +# Make this contrib package's src/ importable (flat, per upstream contrib convention). +sys.path.insert(0, str(Path("contrib/models/Qwen2.5-Omni-7B/src").resolve())) +import _upstream_compat # noqa: F401 (applies hf_adapter bug fix) + +import torch +from transformers import AutoTokenizer +from neuronx_distributed_inference.models.config import NeuronConfig, OnDeviceSamplingConfig +from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config, HuggingFaceGenerationAdapter +from modeling_qwen25_omni import ( + NeuronQwen25OmniForCausalLM, + Qwen25OmniInferenceConfig, +) -### Device Profiling Metrics +model_path = "/path/to/Qwen2.5-Omni-7B/" +compiled_path = "/path/to/compiled/" -**Configuration:** TP=2, batch_size=1, seq_len=128, bfloat16 -**Instance:** trn1.32xlarge | **Profiled:** 2026-03-18 +neuron_config = NeuronConfig( + tp_degree=4, + batch_size=1, + seq_len=4096, + max_context_length=4096, + torch_dtype=torch.bfloat16, + on_device_sampling_config=OnDeviceSamplingConfig( + do_sample=True, temperature=0.6, top_k=20, top_p=0.95 + ), +) -| Metric | Context Encoding | Token Generation | -|--------|-----------------|------------------| -| MFU (%) | 0.19 | 0.00 | -| MBU (%) | 0.36 | 0.42 | -| HFU (%) | 0.19 | 0.00 | -| Execution Time (us) | 0.05 | 0.04 | -| HBM Read | 7.19 GB | 7.08 GB | -| HBM Write | 88.46 MB | 2.78 MB | +config = Qwen25OmniInferenceConfig( + neuron_config, load_config=load_pretrained_config(model_path) +) -**Throughput:** 19.81 tok/s | **Compile Time:** 332.09s +model = NeuronQwen25OmniForCausalLM(model_path, config) +model.compile(compiled_path) +model.load(compiled_path) -> Metrics from `neuron-profile capture` on compiled NEFFs. MFU = Model FLOPs Utilization, -> MBU = Memory Bandwidth Utilization, HFU = Hardware FLOPs Utilization. 
+tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) +adapter = HuggingFaceGenerationAdapter(model) +inputs = tokenizer("What is quantum computing?", return_tensors="pt") +output_ids = adapter.generate( + input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + max_new_tokens=256, +) +print(tokenizer.decode(output_ids[0], skip_special_tokens=True)) +``` -## Usage +### Multimodal (Vision + Audio + Speech) ```python -from transformers import AutoTokenizer, GenerationConfig -from neuronx_distributed_inference.models.config import NeuronConfig -from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config +from modeling_qwen25_omni import ( + NeuronQwen25OmniMultimodalForCausalLM, + Qwen25OmniMultimodalInferenceConfig, +) -# Import model classes from src -from src.modeling_qwen2_5_omni_7b import NeuronQwen25Omni7BForCausalLM, Qwen25Omni7BInferenceConfig +# After loading text model, enable multimodal components: +# model.enable_audio_encoder(audio_state_dict) +# model.compile_audio_encoder("/path/to/compiled_audio/") # compile Neuron transformer +# model.load_audio_encoder("/path/to/compiled_audio/") # load compiled model +# model.enable_talker(talker_state_dict) +# model.enable_token2wav(token2wav_state_dict, speaker_dict_path="spk_dict.pt") +# +# Full multimodal pipeline: +# Thinker generates text -> hidden states passed to Talker +# -> Talker generates codec tokens -> Token2Wav generates waveform +``` -model_path = "/path/to/Qwen2.5-Omni-7B/" -compiled_model_path = "/path/to/compiled/" ## vLLM Integration -# Configure -neuron_config = NeuronConfig( - tp_degree=2, - batch_size=1, - seq_len=512, - torch_dtype=torch.bfloat16, -) Qwen2.5-Omni can be served via [vllm-neuron](https://github.com/aws-neuron/vllm-neuron) for text-only inference. A patch is required for the nested config structure. -config = Qwen25Omni7BInferenceConfig( - neuron_config, - load_config=load_pretrained_config(model_path), -) ### Setup + ```bash +# 1. Install vllm-neuron +pip install vllm-neuron -# Compile and load -model = NeuronQwen25Omni7BForCausalLM(model_path, config) -model.compile(compiled_model_path) -model.load(compiled_model_path) +# 2. Apply the Qwen2.5-Omni patch +python perf_test/apply_vllm_neuron_patch_qwen25omni.py +``` + +### Serving -# Generate -tokenizer = AutoTokenizer.from_pretrained(model_path) -# ... (see integration test for full example) +```bash +python3 -m vllm.entrypoints.openai.api_server \ + --model /path/to/Qwen2.5-Omni-7B \ + --tensor-parallel-size 4 \ + --max-model-len 4096 \ + --max-num-seqs 32 \ + --no-enable-chunked-prefill \ + --no-enable-prefix-caching \ + --trust_remote_code \ + --additional-config '{ + "override_neuron_config": { + "tp_degree": 4, + "fused_qkv": false, + "flash_decoding_enabled": false, + "sequence_parallel_enabled": false, + "batch_size": 32, + "ctx_batch_size": 1, + "tkg_batch_size": 32, + "max_context_length": 4096, + "seq_len": 4096, + "is_continuous_batching": true, + "enable_bucketing": true, + "async_mode": true, + "on_device_sampling_config": { + "do_sample": true, "temperature": 0.6, "top_k": 20, "top_p": 0.95 + } + } + }' ``` +### Key vLLM Patch Changes + +The patch (`perf_test/apply_vllm_neuron_patch_qwen25omni.py`) modifies vllm-neuron to: +- Extract text config from nested `thinker_config.text_config` +- Map `Qwen2_5OmniModel` architecture to `qwen2_5_omni` model type +- Handle layer count extraction for nested config + +See `perf_test/3_bench_qwen25_omni_7b.sh` for full benchmark configurations. 
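+
+### Querying the server
+
+Once the server is up, you can sanity-check it through the OpenAI-compatible HTTP API. A minimal sketch using `requests` (assumes the default port 8000 and that the model is addressed by the same path passed to `--model`; adjust if you set `--port` or `--served-model-name`):
+
+```python
+import requests
+
+resp = requests.post(
+    "http://localhost:8000/v1/chat/completions",
+    json={
+        "model": "/path/to/Qwen2.5-Omni-7B",  # must match the --model value above
+        "messages": [{"role": "user", "content": "What is quantum computing?"}],
+        "max_tokens": 128,
+        "temperature": 0.6,
+    },
+    timeout=300,
+)
+print(resp.json()["choices"][0]["message"]["content"])
+```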
+ +## Performance + +Text-only benchmark (trn2, BF16, TP=4): + +| Config | TPOT (ms) | Output tok/s | Notes | +|--------|-----------|--------------|-------| +| BS=1, non-CB, greedy | ~11-13 | ~77-90 | Tested with chat template | +| BS=4, CB, c=4 | TBD | TBD | vLLM serving | + +Model load time: ~15s (from compiled artifacts on NVMe). + +Audio encoder performance (CPU frontend + CPU postprocessor, no Neuron transformer): + +| Audio Length | Mel Frames | Frontend | Postprocessor | +|-------------|-----------|----------|---------------| +| 1s | ~100 | ~20ms | included | +| 3s | ~300 | ~22ms | included | +| 10s | ~1000 | ~33ms | included | +| 30s | ~3000 | ~34ms | included | + +### End-to-End Multimodal (CPU inference, trn2.48xlarge) + +| Test | Input | Output | Time | +|------|-------|--------|------| +| Text → Text | "What is the capital of France?" | Correct answer (Paris) | 15.1s | +| Image + Text → Text | Synthetic image (shapes) + description prompt | Correctly identified red square, blue circle, yellow circle, green triangle | 59.5s | +| Audio + Text → Text | 440Hz sine wave + "What do you hear?" | Text response generated | 12.1s | +| Text → Speech | "Say hello and tell me the weather" | Text + audio waveform (14.2s audio) | 197.2s | + +### Speech Pipeline Profiling (CPU inference, trn2.48xlarge) + +Per-component measured breakdown for text-to-speech (14.1s audio output): + +| Component | Time | % of Total | RTF | Notes | +|-----------|------|------------|-----|-------| +| Thinker (7B) | 31.0s | 12% | — | 59 text tokens, ~1.9 tok/s on CPU | +| Talker (690M) | 103.3s | 41% | 7.3x | Autoregressive codec token generation, 24 layers | +| Token2Wav (DiT+BigVGAN) | 117.9s | 47% | 8.4x | 22 DiT blocks × 10 ODE steps × 2 (CFG) = 440 forward passes | +| **Total** | **252.1s** | **100%** | **17.9x** | Generating 14.1s audio takes 252.1s on CPU | + +### Full Neuron Speech Pipeline (trn2.48xlarge, TP=4, BF16) + +End-to-end speech synthesis with all components on Neuron (9.1s audio output): + +| Component | Time | Notes | +|-----------|------|-------| +| Thinker (7B, Neuron TP=4) | 0.3s | 24 text tokens | +| CPU hidden state extraction | ~3s | HF model forward for thinker states | +| Talker (690M, Neuron TP=4) | 2.1s | 454 codec tokens, per-step thinker injection | +| Token2Wav (Neuron DiT + CPU BigVGAN) | 9.9s | 22 DiT blocks × 10 ODE steps | +| **Total** | **~15s** | **9.1s audio, RTF ~1.7x** | + +### Neuron vs CPU Speedup (trn2.48xlarge, TP=4, BF16) + +| Component | CPU Time | Neuron Time | Speedup | Notes | +|-----------|----------|-------------|---------|-------| +| Thinker (7B) | 30.4s | 0.47s | **64.7x** | TPOT 10.2ms | +| Talker (690M) | 98.1s | 2.0s (500 tokens) | **49.1x** | TPOT 4.0ms | +| Token2Wav DiT (85M) | 24.1s | 3.8s | **6.3x** | 22 blocks × 10 ODE steps, batch=2 (CFG) | +| Token2Wav BigVGAN | 2.8s | 2.8s (CPU) | 1x | Stays on CPU | +| **Total** | **267.9s** | **~15s** | **~18x** | All Neuron components active | + +Token2Wav component breakdown (300 codec tokens / 6.0s audio): + +| Config | CPU | Neuron DiT | Speedup | +|--------|-----|-----------|---------| +| DiT only (22 blocks, 10 ODE steps) | 24.1s | 3.8s | 6.3x | +| Token2Wav end-to-end | 13.7s | 5.2s | 2.7x | +| DiT core single forward (batch=2, mel_len=1024) | 592ms | 60ms | 9.8x | + +Key observations: +- **Full Neuron speech pipeline** verified end-to-end: Thinker → Talker → Token2Wav all on Neuron, producing real human speech +- Thinker and Talker achieve **49-65x speedup** on Neuron +- Token2Wav DiT achieves **6.3x 
speedup** (9.8x for isolated transformer core) +- BigVGAN vocoder (CPU) is now the remaining bottleneck +- **Per-step thinker state injection**: Talker v2 adds thinker_reply_part[step] embedding at each autoregressive step, matching HF behavior +- **Vision embeddings auto-padding**: Compiled Neuron models require fixed bucket shapes; vision_embeddings are auto-padded to max_context_length +- Split architecture for Token2Wav: CPU preprocessing (ECAPA-TDNN, codec/input embed, rotary, block_diff) + Neuron transformer core (22 blocks + norm + proj) +- Overcame XLA tracing limitations: in-place slice assignment in DiTAttention (→ torch.cat), SDPA dispatch (→ explicit matmul), ECAPA-TDNN/codec embed issues (→ kept on CPU) +- Automatic CPU fallback when mel_len exceeds compiled DiT max + ## Compatibility Matrix -| Instance/Version | 2.20+ | 2.19 and earlier | -|------------------|-------|------------------| -| Trn1 | ✅ Working | Not tested | -| Inf2 | Not tested | Not tested | +| Instance/Version | 2.23+ (PyTorch 2.9) | 2.22 and earlier | +|------------------|---------------------|------------------| +| Trn2 (trn2.48xlarge) | Tested (TP=4) | Not tested | +| Trn2 (trn2.xlarge) | Supported (TP=4) | Not tested | +| Trn1 (trn1.32xlarge) | Should work (TP=4, 4 NeuronCores) | Not tested | +| Inf2 (inf2.48xlarge) | Should work (TP=4) | Not tested | ## Testing -Run integration tests: +Verified on trn2.48xlarge with real Qwen2.5-Omni-7B weights: + +- **Imports**: All model classes import successfully +- **Config**: TP=4 head divisibility verified (Thinker 7/1, Audio 5, Vision 4 per rank) +- **State dict**: All 2448 keys converted correctly (text=339, audio=489, vision=518, talker=293, token2wav=809) +- **Audio CPU**: Frontend+postprocessor 1s=20ms, 30s=34ms +- **Talker CPU**: 1351M params loaded in ~10s, codec tokens verified +- **Text generation (TP=4)**: Compile + load + generate working, TPOT ~11-13ms, correct outputs verified ```bash -pytest nxdi_contrib_models/models/Qwen2.5-Omni-7B/test/integration/test_model.py --capture=tee-sys +# End-to-end test (compile + load + generate) +python /tmp/test_qwen25_omni_tp4.py ``` -Or run manually: +## Key Implementation Notes -```bash -cd nxdi_contrib_models/models/Qwen2.5-Omni-7B -python3 test/integration/test_model.py -``` +1. **TP=4 for all Neuron components**: Thinker (28 heads/4=7 per rank), Vision (16 heads/4=4), Audio (20 heads/4=5). All heads divisible by 4. +2. **Audio encoder hybrid architecture**: Conv1d frontend + chunking on CPU, 32 transformer layers on Neuron with TP=4, AvgPool + LayerNorm + projection on CPU. Asymmetric attention bias (q/v have bias, k has none) handled via ColumnParallelLinear. +3. **Talker on Neuron**: Non-standard head_dim (128 != 896/12), 3D mRoPE with per-step thinker-state injection, ~690M params. Uses ImageToTextModelWrapper with 24 positional args. Fused embedding (embed_tokens 8448→3584 + proj 3584→896 collapsed into 8448→896). Per-step thinker reply states injected via vision_embeddings during token generation. Vision embeddings auto-padded to max_context_length for compiled bucket compatibility. TPOT 4.0ms. +4. **Token2Wav split architecture**: DiT transformer core (22 blocks) on Neuron via torch_neuronx.trace(). CPU preprocessing: ECAPA-TDNN speaker encoder, codec embedding (repeat_interleave), input embedding, rotary embedding, block_diff mask. CPU postprocessing: ODE solver (RK4, 10 steps), BigVGAN vocoder. Float32 for ODE precision. 
XLA fixes: DiTAttention in-place slice assignment → torch.cat, SDPA dispatch → explicit matmul attention, float additive attention mask. Automatic CPU fallback when mel_len exceeds compiled max. +5. **Speaker support**: `spk_dict.pt` contains per-speaker conditioning (Ethan, Chelsie) +6. **State dict prefix remapping**: `thinker.model.*` -> `model.*`, `thinker.lm_head.*` -> `lm_head.*`, `thinker.visual.*` -> `visual.*`, `thinker.audio_tower.*` -> `frontend.*`/`transformer.*`/`postprocessor.*` ## Example Checkpoints -* Qwen/Qwen2.5-Omni-7B +* [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B) ## Maintainer -Annapurna Labs +Henan Wan (whn09) -**Last Updated:** 2026-01-29 +**Last Updated:** 2026-04-15 diff --git a/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni.py b/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni.py new file mode 100644 index 00000000..9b4edfce --- /dev/null +++ b/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni.py @@ -0,0 +1,457 @@ +#!/usr/bin/env python3 +""" +End-to-end multimodal inference for Qwen2.5-Omni-7B on NeuronX (TP=4). + +Supports: + 1. Text-only: Thinker generates text + 2. Image + text: Vision encoder + Thinker + 3. Audio + text: Audio encoder + Thinker + 4. Image + audio + text: All encoders + Thinker + 5. Speech output: Thinker → Talker → Token2Wav (optional) + +Usage: + source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + pip install qwen-omni-utils[decord] + + # First run compiles (~30 min), subsequent runs load from cache: + python3 examples/generate_qwen25_omni.py + + # Text-only (fastest, uses simpler text-only model): + python3 examples/generate_qwen25_omni.py --mode text + + # Image understanding: + python3 examples/generate_qwen25_omni.py --mode image + + # Audio understanding: + python3 examples/generate_qwen25_omni.py --mode audio + + # Full multimodal (image + audio + text → text + speech): + python3 examples/generate_qwen25_omni.py --mode full +""" + +# --- Qwen2.5-Omni contrib bootstrap --- +import sys as _sys +from pathlib import Path as _Path +_SRC = _Path(__file__).resolve().parents[1] / "src" +if str(_SRC) not in _sys.path: + _sys.path.insert(0, str(_SRC)) +import _upstream_compat # noqa: F401 (applies hf_adapter shim) +# --- end bootstrap --- + +import argparse +import gc +import os +import sys +import time + +import torch + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- +MODEL_PATH = os.environ.get( + "QWEN25_OMNI_MODEL_PATH", "/opt/dlami/nvme/models/Qwen2.5-Omni-7B" +) +COMPILED_PATH = os.environ.get( + "QWEN25_OMNI_COMPILED_PATH", "/tmp/qwen25_omni_compiled" +) +TP_DEGREE = int(os.environ.get("QWEN25_OMNI_TP_DEGREE", "4")) + +# Test media from Qwen official examples +IMAGE_URL = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg" +AUDIO_URL = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-Omni/cough.wav" + +# Sequence lengths +TEXT_SEQ_LENGTH = 4096 +TEXT_BUCKETS = [256, 512, 1024, 2048, 4096] +VISION_SEQ_LENGTH = 1012 # single image +VISION_BUCKETS = [1] # 1 image + + +class Timer: + def __init__(self, label): + self.label = label + self.elapsed = 0 + + def __enter__(self): + self.start = time.time() + return self + + def __exit__(self, *args): + self.elapsed = time.time() - self.start + print(f" [{self.label}] {self.elapsed:.2f}s") + + +# 
--------------------------------------------------------------------------- +# Mode 1: Text-only (uses the simpler NeuronQwen25OmniForCausalLM) +# --------------------------------------------------------------------------- +def run_text_only(model_path=MODEL_PATH, compiled_path=COMPILED_PATH): + """Text-only inference using the Thinker model.""" + from neuronx_distributed_inference.models.config import ( + NeuronConfig, + OnDeviceSamplingConfig, + ) + from neuronx_distributed_inference.utils.hf_adapter import ( + load_pretrained_config, + HuggingFaceGenerationAdapter, + ) + from modeling_qwen25_omni import ( + NeuronQwen25OmniForCausalLM, + Qwen25OmniInferenceConfig, + ) + from transformers import AutoTokenizer + + print("\n" + "=" * 60) + print("Mode: Text-only (Thinker on Neuron, TP=4)") + print("=" * 60) + + neuron_config = NeuronConfig( + tp_degree=TP_DEGREE, + batch_size=1, + seq_len=2048, + max_context_length=2048, + torch_dtype=torch.bfloat16, + on_device_sampling_config=OnDeviceSamplingConfig(do_sample=False, top_k=1), + ) + + hf_config = load_pretrained_config(model_path) + config = Qwen25OmniInferenceConfig(neuron_config, load_config=hf_config) + + compiled_dir = os.path.join(compiled_path, "thinker_tp4") + + with Timer("Create model"): + model = NeuronQwen25OmniForCausalLM(model_path, config) + + if not os.path.exists(os.path.join(compiled_dir, "neuron_config.json")): + with Timer("Compile"): + model.compile(compiled_dir) + else: + print(" Compiled artifacts found") + + with Timer("Load"): + model.load(compiled_dir) + + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + adapter = HuggingFaceGenerationAdapter(model) + + prompts = [ + "What is 2+3? Answer with just the number.", + "Write a haiku about the ocean.", + "Explain quantum computing in one sentence.", + ] + + print("\n--- Generation Results ---") + for prompt in prompts: + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ] + text = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + encoded = tokenizer(text, return_tensors="pt") + + with Timer(f"Generate"): + output_ids = adapter.generate( + input_ids=encoded["input_ids"], + attention_mask=encoded["attention_mask"], + max_new_tokens=128, + eos_token_id=[tokenizer.eos_token_id, 151645], + ) + + new_tokens = output_ids[0, encoded["input_ids"].shape[1] :] + output_text = tokenizer.decode(new_tokens, skip_special_tokens=True) + print(f" Q: {prompt}") + print(f" A: {output_text.strip()[:300]}") + print() + + del model, adapter + gc.collect() + + +# --------------------------------------------------------------------------- +# Mode 2: Multimodal (image/audio + text → text, optional speech) +# --------------------------------------------------------------------------- +def run_multimodal(mode="full", model_path=MODEL_PATH, compiled_path=COMPILED_PATH): + """Multimodal inference: image + audio + text → text (+ optional speech). 
+ + Args: + mode: "image" (image+text), "audio" (audio+text), "full" (image+audio+text) + model_path: Path to HF model weights + compiled_path: Path for compiled artifacts + """ + from neuronx_distributed_inference.models.config import ( + NeuronConfig, + OnDeviceSamplingConfig, + ) + from neuronx_distributed_inference.utils.hf_adapter import ( + load_pretrained_config, + HuggingFaceGenerationAdapter, + ) + from modeling_qwen25_omni import ( + NeuronQwen25OmniMultimodalForCausalLM, + Qwen25OmniMultimodalInferenceConfig, + ) + from transformers import AutoProcessor, GenerationConfig + + print("\n" + "=" * 60) + print(f"Mode: Multimodal ({mode}) on Neuron, TP={TP_DEGREE}") + print("=" * 60) + + # --- Step 1: Create configs --- + text_neuron_config = NeuronConfig( + batch_size=1, + seq_len=TEXT_SEQ_LENGTH, + max_context_length=TEXT_SEQ_LENGTH, + ctx_batch_size=1, + tp_degree=TP_DEGREE, + torch_dtype=torch.bfloat16, + fused_qkv=False, + sequence_parallel_enabled=False, + flash_decoding_enabled=False, + qkv_kernel_enabled=False, + qkv_nki_kernel_enabled=False, + attn_kernel_enabled=False, + enable_bucketing=True, + context_encoding_buckets=TEXT_BUCKETS, + token_generation_buckets=TEXT_BUCKETS, + on_device_sampling_config=OnDeviceSamplingConfig(do_sample=False, top_k=1), + ) + + vision_neuron_config = NeuronConfig( + batch_size=1, + seq_len=VISION_SEQ_LENGTH, + tp_degree=TP_DEGREE, + torch_dtype=torch.bfloat16, + enable_bucketing=True, + buckets=VISION_BUCKETS, + fused_qkv=False, # Qwen2.5-Omni vision uses separate Q/K/V + qkv_kernel_enabled=False, + attn_kernel_enabled=False, + ) + + config = Qwen25OmniMultimodalInferenceConfig( + text_neuron_config=text_neuron_config, + vision_neuron_config=vision_neuron_config, + load_config=load_pretrained_config(model_path), + ) + + processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) + + # --- Step 2: Create and compile/load the model --- + compiled_dir = os.path.join(compiled_path, "multimodal_tp4") + + with Timer("Create multimodal model"): + model = NeuronQwen25OmniMultimodalForCausalLM( + model_path=model_path, config=config + ) + + if not os.path.exists(os.path.join(compiled_dir, "neuron_config.json")): + with Timer("Compile text + vision (this takes 20-40 minutes)"): + model.compile(compiled_dir) + processor.save_pretrained(compiled_dir) + else: + print(" Compiled artifacts found") + + with Timer("Load compiled model"): + model.load(compiled_dir) + processor = AutoProcessor.from_pretrained(compiled_dir) + + # --- Step 3: Enable audio encoder (if needed) --- + if mode in ("audio", "full"): + print("\n Enabling audio encoder...") + from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import ( + Qwen2_5OmniForConditionalGeneration, + ) + from modeling_qwen25_omni_audio import ( + NeuronQwen25OmniAudioEncoder, + ) + + with Timer("Load HF model for audio weights"): + hf_model = Qwen2_5OmniForConditionalGeneration.from_pretrained( + model_path, + trust_remote_code=True, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + ) + full_sd = hf_model.state_dict() + + audio_sd = NeuronQwen25OmniAudioEncoder.convert_hf_to_neuron_state_dict( + full_sd, dtype=torch.bfloat16 + ) + del hf_model, full_sd + gc.collect() + + model.enable_audio_encoder(audio_sd) + + compiled_audio_dir = os.path.join(compiled_path, "audio_encoder_tp4") + if not os.path.exists( + os.path.join(compiled_audio_dir, "neuron_config.json") + ): + with Timer("Compile audio encoder transformer"): + model.compile_audio_encoder(compiled_audio_dir) + else: + 
print(" Audio encoder compiled artifacts found") + + with Timer("Load audio encoder"): + model.load_audio_encoder(compiled_audio_dir) + + del audio_sd + gc.collect() + + # --- Step 4: Run inference --- + adapter = HuggingFaceGenerationAdapter(model) + generation_config = GenerationConfig( + do_sample=False, + eos_token_id=[151645], + pad_token_id=151645, + ) + + # Prepare inputs based on mode + from qwen_omni_utils import process_mm_info + + if mode == "image": + print("\n--- Image Understanding ---") + messages = [ + { + "role": "system", + "content": [{"type": "text", "text": "You are a helpful assistant."}], + }, + { + "role": "user", + "content": [ + {"type": "image", "image": IMAGE_URL}, + {"type": "text", "text": "Describe this image in detail."}, + ], + }, + ] + elif mode == "audio": + print("\n--- Audio Understanding ---") + messages = [ + { + "role": "system", + "content": [{"type": "text", "text": "You are a helpful assistant."}], + }, + { + "role": "user", + "content": [ + {"type": "audio", "audio": AUDIO_URL}, + {"type": "text", "text": "What sound is this? Describe what you hear."}, + ], + }, + ] + else: # full + print("\n--- Full Multimodal (Image + Audio + Text) ---") + messages = [ + { + "role": "system", + "content": [{"type": "text", "text": "You are a helpful assistant."}], + }, + { + "role": "user", + "content": [ + {"type": "image", "image": IMAGE_URL}, + {"type": "audio", "audio": AUDIO_URL}, + { + "type": "text", + "text": "Describe the image and the audio you received.", + }, + ], + }, + ] + + # Process multimodal inputs + with Timer("Process multimodal inputs"): + audios, images, videos = process_mm_info(messages, use_audio_in_video=False) + text = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + inputs = processor( + text=[text], + images=images if images else None, + audios=audios if audios else None, + return_tensors="pt", + padding=True, + ) + + print(f" input_ids shape: {inputs.input_ids.shape}") + if hasattr(inputs, "pixel_values") and inputs.pixel_values is not None: + print(f" pixel_values shape: {inputs.pixel_values.shape}") + if hasattr(inputs, "image_grid_thw") and inputs.image_grid_thw is not None: + print(f" image_grid_thw: {inputs.image_grid_thw}") + if hasattr(inputs, "input_features") and inputs.input_features is not None: + print(f" input_features shape: {inputs.input_features.shape}") + + # Generate + generate_kwargs = { + "input_ids": inputs.input_ids, + "attention_mask": inputs.attention_mask, + "generation_config": generation_config, + "max_new_tokens": 256, + } + if hasattr(inputs, "pixel_values") and inputs.pixel_values is not None: + generate_kwargs["pixel_values"] = inputs.pixel_values + if hasattr(inputs, "image_grid_thw") and inputs.image_grid_thw is not None: + generate_kwargs["image_grid_thw"] = inputs.image_grid_thw + if hasattr(inputs, "input_features") and inputs.input_features is not None: + generate_kwargs["input_features"] = inputs.input_features + if ( + hasattr(inputs, "feature_attention_mask") + and inputs.feature_attention_mask is not None + ): + generate_kwargs["feature_attention_mask"] = inputs.feature_attention_mask + + with Timer("Generate"): + output_ids = adapter.generate(**generate_kwargs) + + output_text = processor.batch_decode( + output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) + print(f"\n Output: {output_text[0].strip()[:500]}") + + del model, adapter + gc.collect() + + +# 
--------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- +def main(): + parser = argparse.ArgumentParser(description="Qwen2.5-Omni-7B inference on Neuron") + parser.add_argument( + "--mode", + choices=["text", "image", "audio", "full"], + default="text", + help="Inference mode (default: text)", + ) + parser.add_argument( + "--model-path", + default=MODEL_PATH, + help=f"Model path (default: {MODEL_PATH})", + ) + parser.add_argument( + "--compiled-path", + default=COMPILED_PATH, + help=f"Compiled model path (default: {COMPILED_PATH})", + ) + args = parser.parse_args() + + model_path = args.model_path + compiled_path = args.compiled_path + + print(f"Model: {model_path}") + print(f"Compiled: {compiled_path}") + print(f"TP: {TP_DEGREE}") + + if args.mode == "text": + run_text_only(model_path, compiled_path) + else: + run_multimodal(args.mode, model_path, compiled_path) + + print("\nDone!") + + +if __name__ == "__main__": + main() diff --git a/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni_speech.py b/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni_speech.py new file mode 100644 index 00000000..2dc7b4e2 --- /dev/null +++ b/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni_speech.py @@ -0,0 +1,851 @@ +#!/usr/bin/env python3 +""" +End-to-end speech synthesis for Qwen2.5-Omni-7B on NeuronX (TP=4). + +Full pipeline: Thinker (text) -> Talker (codec tokens) -> Token2Wav (audio) + +Two-step workflow: + Step 1: Compile all Neuron components (one-time, ~30 min) + Step 2: Run inference (loads compiled artifacts, ~15s per utterance) + +Architecture note: + Thinker (TP=4) and Talker (TP=4) each require exclusive Neuron access, + so they run in separate subprocesses. Within each subprocess, the model + is loaded ONCE and reused for all --num-runs iterations. + Token2Wav DiT runs in the main process. 
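+
+    Data handoff between phases (intermediate files live in the run's temp dir):
+      Phase 1  Thinker subprocess       -> thinker_output.json (generated token ids + text)
+      Phase 2  CPU hidden-state pass    -> kept in memory (HF thinker forward)
+      Phase 3  Talker input prep (CPU)  -> talker_input.pt (projected context/reply, speaker cond)
+      Phase 4  Talker subprocess        -> talker_output.json (codec token ids)
+      Phase 5  Token2Wav (main process) -> output WAV at 24 kHz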
+ +Prerequisites: + - Trn2 instance (trn2.48xlarge or trn2.xlarge, 4+ NeuronCores) + - Neuron SDK 2.23+ with PyTorch 2.9 + - Model weights: huggingface-cli download Qwen/Qwen2.5-Omni-7B + +Usage: + source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + cd neuronx-distributed-inference + + # Step 1: Compile (one-time, ~30 min) + python examples/generate_qwen25_omni_speech.py --compile + + # Step 2: Run inference + python examples/generate_qwen25_omni_speech.py + python examples/generate_qwen25_omni_speech.py --prompt "Tell me about the weather" + python examples/generate_qwen25_omni_speech.py --speaker Chelsie --output hello.wav + + # Benchmark: load each model once, run N inferences, report avg latency + python examples/generate_qwen25_omni_speech.py --num-runs 5 + +Pipeline timing (trn2.48xlarge, TP=4, model already loaded): + Thinker: ~0.3s (text generation) + Talker: ~2-3s (codec token generation) + Token2Wav: ~10s (mel spectrogram + vocoder) + Total: ~15s for ~10s of audio (RTF ~1.5x) +""" + +# --- Qwen2.5-Omni contrib bootstrap --- +import sys as _sys +from pathlib import Path as _Path +_SRC = _Path(__file__).resolve().parents[1] / "src" +if str(_SRC) not in _sys.path: + _sys.path.insert(0, str(_SRC)) +import _upstream_compat # noqa: F401 (applies hf_adapter shim) +# --- end bootstrap --- + +import argparse +import gc +import json +import os +import subprocess +import sys +import time + +import torch + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- +MODEL_PATH = os.environ.get( + "QWEN25_OMNI_MODEL_PATH", + os.path.expanduser("~/.cache/huggingface/hub/models--Qwen--Qwen2.5-Omni-7B/snapshots/"), +) +COMPILED_PATH = os.environ.get( + "QWEN25_OMNI_COMPILED_PATH", "/tmp/qwen25_omni_compiled" +) +TP_DEGREE = int(os.environ.get("QWEN25_OMNI_TP_DEGREE", "4")) + +DEFAULT_PROMPT = "Say hello and briefly introduce yourself in two sentences." +DEFAULT_SYSTEM = ( + "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, " + "capable of perceiving auditory and visual inputs, as well as generating " + "text and speech." 
+) +DEFAULT_SPEAKER = "Ethan" + +_ORIG_EMBEDDING_FORWARD = torch.nn.Embedding.forward + + +def _resolve_model_path(path): + """Resolve model path, handling HF cache snapshot directories.""" + if os.path.isdir(path) and not os.path.exists(os.path.join(path, "config.json")): + snaps = [d for d in os.listdir(path) if not d.startswith(".")] + if snaps: + path = os.path.join(path, snaps[0]) + return path + + +def _restore_embedding(): + """Restore original Embedding.forward if Neuron loading changed it.""" + if torch.nn.Embedding.forward is not _ORIG_EMBEDDING_FORWARD: + torch.nn.Embedding.forward = _ORIG_EMBEDDING_FORWARD + + +class Timer: + def __init__(self, label): + self.label = label + self.elapsed = 0 + + def __enter__(self): + self.start = time.time() + return self + + def __exit__(self, *args): + self.elapsed = time.time() - self.start + print(f" [{self.label}] {self.elapsed:.2f}s") + + +_SUBPROCESS_BOOTSTRAP = ( + "import sys, os\n" + f"sys.path.insert(0, {str(_SRC)!r})\n" + "import _upstream_compat # noqa: F401\n" +) + + +def _run_subprocess(script_code, label, temp_dir): + """Run Python code as a subprocess (required for Neuron process isolation).""" + script_path = os.path.join(temp_dir, f"{label}.py") + with open(script_path, "w") as f: + f.write(_SUBPROCESS_BOOTSTRAP + script_code) + env = dict(os.environ) + existing = env.get("PYTHONPATH", "") + env["PYTHONPATH"] = f"{_SRC}{os.pathsep}{existing}" if existing else str(_SRC) + t0 = time.time() + result = subprocess.run( + [sys.executable, script_path], + capture_output=True, text=True, timeout=600, env=env, + ) + elapsed = time.time() - t0 + if result.returncode != 0: + print(f" [{label}] FAILED ({elapsed:.1f}s)") + for line in result.stderr.strip().split("\n")[-15:]: + if line.strip() and not any( + x in line for x in ["WARN", "TDRV", "NMGR", "NRT", "nccl", "blockwise"] + ): + print(f" {line}") + for line in result.stdout.strip().split("\n")[-5:]: + if line.strip(): + print(f" {line}") + return False + print(f" [{label}] subprocess finished ({elapsed:.1f}s)") + for line in result.stdout.strip().split("\n"): + if line.strip(): + print(f" {line}") + return True + + +# ========================================================================== +# Compilation (--compile) +# ========================================================================== + +def compile_all(model_path, compiled_path): + """Compile all three Neuron components: Thinker, Talker, DiT.""" + print("=" * 60) + print("Compiling Qwen2.5-Omni Speech Components") + print("=" * 60) + print(f" Model: {model_path}") + print(f" Output: {compiled_path}") + print(f" TP: {TP_DEGREE}") + t_total = time.time() + + temp_dir = os.path.join(compiled_path, "_tmp") + os.makedirs(temp_dir, exist_ok=True) + + # --- 1. 
Compile Thinker --- + print("\n--- [1/3] Compiling Thinker ---") + thinker_compiled = os.path.join(compiled_path, "thinker_tp4") + if os.path.exists(os.path.join(thinker_compiled, "neuron_config.json")): + print(" Already compiled, skipping.") + else: + script = f''' +import torch, os +MODEL_PATH = "{model_path}" +COMPILED = "{thinker_compiled}" + +from neuronx_distributed_inference.models.config import NeuronConfig, OnDeviceSamplingConfig +from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config +from modeling_qwen25_omni import ( + NeuronQwen25OmniForCausalLM, Qwen25OmniInferenceConfig, +) + +nc = NeuronConfig( + tp_degree={TP_DEGREE}, batch_size=1, seq_len=2048, max_context_length=2048, + torch_dtype=torch.bfloat16, + on_device_sampling_config=OnDeviceSamplingConfig( + do_sample=True, temperature=0.7, top_k=20, top_p=0.8 + ), +) +cfg = Qwen25OmniInferenceConfig(nc, load_config=load_pretrained_config(MODEL_PATH)) +model = NeuronQwen25OmniForCausalLM(MODEL_PATH, cfg) +model.compile(COMPILED) +print("Thinker compiled successfully") +''' + ok = _run_subprocess(script, "compile_thinker", temp_dir) + if not ok: + print("FATAL: Thinker compilation failed.") + return False + + # --- 2. Compile Talker --- + print("\n--- [2/3] Compiling Talker ---") + talker_compiled = os.path.join(compiled_path, "talker_tp4") + if os.path.exists(os.path.join(talker_compiled, "neuron_config.json")): + print(" Already compiled, skipping.") + else: + script = f''' +import torch, os +MODEL_PATH = "{model_path}" +COMPILED = "{talker_compiled}" + +from transformers import AutoConfig +from modeling_qwen25_omni_talker import ( + NeuronQwen25OmniTalkerForCausalLM, TalkerInferenceConfig, TalkerNeuronConfig, +) +from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config + +hf = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True) +tc = hf.talker_config + +tnc = TalkerNeuronConfig( + tp_degree={TP_DEGREE}, batch_size=1, seq_len=2048, max_context_length=2048, + torch_dtype=torch.bfloat16, +) +tic = TalkerInferenceConfig(neuron_config=tnc, load_config=load_pretrained_config(hf_config=tc)) +talker = NeuronQwen25OmniTalkerForCausalLM(MODEL_PATH, config=tic) +talker.compile(COMPILED) +print("Talker compiled successfully") +''' + ok = _run_subprocess(script, "compile_talker", temp_dir) + if not ok: + print("FATAL: Talker compilation failed.") + return False + + # --- 3. 
Compile DiT --- + print("\n--- [3/3] Compiling Token2Wav DiT ---") + dit_compiled = os.path.join(compiled_path, "dit_core") + if os.path.exists(os.path.join(dit_compiled, "dit_core_neuron.pt")): + print(" Already compiled, skipping.") + else: + script = f''' +import torch, os +MODEL_PATH = "{model_path}" +COMPILED = "{dit_compiled}" + +from transformers import AutoConfig +from safetensors.torch import load_file +from modeling_qwen25_omni_token2wav import ( + NeuronQwen25OmniToken2WavWithNeuronDiT, +) + +hf_config = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True) +t2w = NeuronQwen25OmniToken2WavWithNeuronDiT(hf_config.token2wav_config) + +state_dict = {{}} +for fn in sorted(os.listdir(MODEL_PATH)): + if fn.endswith(".safetensors"): + sd = load_file(os.path.join(MODEL_PATH, fn)) + for k, v in sd.items(): + if k.startswith("token2wav."): + state_dict[k[len("token2wav."):]] = v +t2w.load_state_dict(state_dict, strict=False) + +t2w.compile_dit(COMPILED, max_mel_len=2048, batch_size=2) +print("DiT compiled successfully") +''' + ok = _run_subprocess(script, "compile_dit", temp_dir) + if not ok: + print("FATAL: DiT compilation failed.") + return False + + total = time.time() - t_total + print(f"\nAll components compiled in {total:.0f}s") + print(f"Artifacts saved to: {compiled_path}/") + print(f" thinker_tp4/ - Thinker (7B text model)") + print(f" talker_tp4/ - Talker (690M codec model)") + print(f" dit_core/ - Token2Wav DiT (85M transformer)") + return True + + +# ========================================================================== +# Inference +# ========================================================================== + +def _check_compiled(compiled_path): + """Verify all compiled artifacts exist.""" + checks = [ + (os.path.join(compiled_path, "thinker_tp4", "neuron_config.json"), "Thinker"), + (os.path.join(compiled_path, "talker_tp4", "neuron_config.json"), "Talker"), + (os.path.join(compiled_path, "dit_core", "dit_core_neuron.pt"), "DiT"), + ] + missing = [name for path, name in checks if not os.path.exists(path)] + if missing: + print(f"ERROR: Missing compiled artifacts for: {', '.join(missing)}") + print(f"Run with --compile first:") + print(f" python {sys.argv[0]} --compile") + return False + return True + + +def run_thinker(model_path, compiled_path, prompt, system_prompt, num_runs, temp_dir): + """Phase 1: Load Thinker once, generate num_runs times, return first result + avg time.""" + print(f"\n--- Phase 1: Thinker (text generation, {num_runs} runs) ---") + + thinker_compiled = os.path.join(compiled_path, "thinker_tp4") + output_file = os.path.join(temp_dir, "thinker_output.json") + + script = f''' +import torch, os, json, time +MODEL_PATH = "{model_path}" +COMPILED = "{thinker_compiled}" +OUTPUT = "{output_file}" +NUM_RUNS = {num_runs} + +from neuronx_distributed_inference.models.config import NeuronConfig, OnDeviceSamplingConfig +from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config, HuggingFaceGenerationAdapter +from modeling_qwen25_omni import ( + NeuronQwen25OmniForCausalLM, Qwen25OmniInferenceConfig, +) +from transformers import AutoTokenizer + +nc = NeuronConfig( + tp_degree={TP_DEGREE}, batch_size=1, seq_len=2048, max_context_length=2048, + torch_dtype=torch.bfloat16, + on_device_sampling_config=OnDeviceSamplingConfig( + do_sample=True, temperature=0.7, top_k=20, top_p=0.8 + ), +) +cfg = Qwen25OmniInferenceConfig(nc, load_config=load_pretrained_config(MODEL_PATH)) +model = NeuronQwen25OmniForCausalLM(MODEL_PATH, cfg) + +t_load = 
time.time() +model.load(COMPILED) +t_load = time.time() - t_load +print(f"Model loaded in {{t_load:.1f}}s") + +tok = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) +if tok.pad_token is None: + tok.pad_token = tok.eos_token +adp = HuggingFaceGenerationAdapter(model) + +# Warmup (first inference is slower due to Neuron warm-up) +enc = tok(tok.apply_chat_template( + [{{"role":"user","content":"Hi"}}], tokenize=False, add_generation_prompt=True +), return_tensors="pt") +_ = adp.generate( + input_ids=enc["input_ids"], attention_mask=enc["attention_mask"], + max_new_tokens=5, eos_token_id=[tok.eos_token_id, 151645], +) +print("Warmup done") + +chat = [ + {{"role":"system","content":"{system_prompt}"}}, + {{"role":"user","content":"{prompt}"}}, +] +enc = tok(tok.apply_chat_template(chat, tokenize=False, add_generation_prompt=True), return_tensors="pt") + +print(f"Running {{NUM_RUNS}} inferences (model stays loaded)...") +first_result = None +times = [] +for i in range(NUM_RUNS): + t0 = time.time() + out = adp.generate( + input_ids=enc["input_ids"], attention_mask=enc["attention_mask"], + max_new_tokens=200, eos_token_id=[tok.eos_token_id, 151645], + ) + elapsed = time.time() - t0 + times.append(elapsed) + + prompt_len = enc["input_ids"].shape[1] + all_ids = out[0].tolist() + gen_ids = all_ids[prompt_len:] + text = tok.decode(gen_ids, skip_special_tokens=True) + n_tokens = len(gen_ids) + print(f" Run {{i+1}}/{{NUM_RUNS}}: {{n_tokens}} tokens in {{elapsed:.3f}}s - {{text[:80]}}") + + if first_result is None: + first_result = {{"all_ids": all_ids, "prompt_len": prompt_len, + "gen_text": text, "n_tokens": n_tokens}} + +avg_time = sum(times) / len(times) +first_result["gen_time"] = avg_time +first_result["all_times"] = times +first_result["load_time"] = t_load +print(f"Avg inference: {{avg_time:.3f}}s (load: {{t_load:.1f}}s, not included in avg)") + +with open(OUTPUT, "w") as f: + json.dump(first_result, f) +''' + ok = _run_subprocess(script, "thinker", temp_dir) + if not ok: + return None + + with open(output_file) as f: + return json.load(f) + + +def extract_hidden_states(model_path, thinker_result): + """Phase 2: Extract thinker hidden states via CPU forward pass.""" + print("\n--- Phase 2: Hidden state extraction (CPU) ---") + from transformers import Qwen2_5OmniForConditionalGeneration + + with Timer("Load HF model"): + hf_model = Qwen2_5OmniForConditionalGeneration.from_pretrained( + model_path, torch_dtype=torch.float32, trust_remote_code=True, + ) + hf_model.eval() + _restore_embedding() + + full_ids = torch.tensor([thinker_result["all_ids"]], dtype=torch.long) + prompt_len = thinker_result["prompt_len"] + + with Timer("Forward pass"): + with torch.no_grad(): + outputs = hf_model.thinker( + input_ids=full_ids, output_hidden_states=True, return_dict=True, + ) + + return hf_model, outputs, full_ids, prompt_len + + +def prepare_talker_input(model_path, hf_model, outputs, full_ids, prompt_len, speaker, temp_dir): + """Phase 3: Build projected thinker states for the Talker.""" + print("\n--- Phase 3: Talker input preparation ---") + from transformers import AutoConfig + from safetensors.torch import load_file + from modeling_qwen25_omni_talker import ( + ThinkerToTalkerProjection, + ) + + hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + talker_cfg = hf_config.talker_config + + spk_dict = torch.load(os.path.join(model_path, "spk_dict.pt"), weights_only=True) + sp = spk_dict[speaker] + conditioning = sp["cond"].unsqueeze(0).float() if sp["cond"].dim() == 
1 else sp["cond"].float() + if conditioning.dim() == 1: + conditioning = conditioning.unsqueeze(0) + reference_mel = sp["ref_mel"].unsqueeze(0).float() if sp["ref_mel"].dim() == 2 else sp["ref_mel"].float() + if reference_mel.dim() == 2: + reference_mel = reference_mel.unsqueeze(0) + bos_token = sp["bos_token"] + if isinstance(bos_token, torch.Tensor): + bos_token = bos_token.item() + + embedding_output = outputs.hidden_states[0] + last_hidden = outputs.hidden_states[-1] + total_len = full_ids.shape[1] + + context_embed = embedding_output[:, :prompt_len, :] + context_hidden = last_hidden[:, :prompt_len, :] + reply_embeds = [embedding_output[:, i:i+1, :] for i in range(prompt_len, total_len)] + reply_hiddens = [last_hidden[:, i:i+1, :] for i in range(prompt_len, total_len)] + + thinker_token_embeds = [context_embed] + reply_embeds + thinker_hidden_states_list = [context_hidden] + reply_hiddens + + thinker_reply_part = ( + torch.cat(thinker_hidden_states_list[1:], dim=1) + + torch.cat(thinker_token_embeds[1:], dim=1) + ) + talker_inputs_embeds = thinker_hidden_states_list[0] + thinker_token_embeds[0] + + thinker_embed_tokens = hf_model.thinker.get_input_embeddings() + bos_embed = thinker_embed_tokens(torch.tensor([[bos_token]], dtype=torch.long)) + talker_inputs_embeds = torch.cat([ + talker_inputs_embeds, bos_embed, thinker_reply_part[:, :1, :], + ], dim=1) + + talker_embed_weight = None + for fn in sorted(os.listdir(model_path)): + if fn.endswith(".safetensors"): + sd = load_file(os.path.join(model_path, fn)) + if "talker.model.embed_tokens.weight" in sd: + talker_embed_weight = sd["talker.model.embed_tokens.weight"] + break + if talker_embed_weight is not None: + talker_embed_layer = torch.nn.Embedding( + talker_embed_weight.shape[0], talker_embed_weight.shape[1] + ) + talker_embed_layer.weight.data = talker_embed_weight.float() + codec_bos_embed = talker_embed_layer( + torch.tensor([talker_cfg.tts_codec_start_token_id]) + ) + codec_pad_embed = talker_embed_layer( + torch.tensor([talker_cfg.tts_codec_pad_token_id]) + ) + talker_inputs_embeds[:, -1, :] += codec_bos_embed + talker_inputs_embeds[:, -2, :] += codec_pad_embed + + eos_embed = thinker_embed_tokens( + torch.tensor([[talker_cfg.tts_text_end_token_id]], dtype=torch.long) + ) + pad_embed = thinker_embed_tokens( + torch.tensor([[talker_cfg.tts_text_pad_token_id]], dtype=torch.long) + ) + thinker_reply_part = torch.cat([thinker_reply_part[:, 1:, :], eos_embed, pad_embed], dim=1) + + context_len = talker_inputs_embeds.shape[1] + n_reply = thinker_reply_part.shape[1] + print(f" Context: {context_len} tokens, Reply: {n_reply} tokens") + + proj_weight = proj_bias = None + for k, v in hf_model.state_dict().items(): + if "thinker_to_talker_proj.weight" in k: + proj_weight = v + if "thinker_to_talker_proj.bias" in k: + proj_bias = v + + proj = ThinkerToTalkerProjection(proj_weight.shape[1], proj_weight.shape[0]) + proj.proj.weight.data = proj_weight + if proj_bias is not None: + proj.proj.bias.data = proj_bias + + projected_context = proj(talker_inputs_embeds) + projected_reply = proj(thinker_reply_part) + + talker_input = { + "projected_context": projected_context, + "projected_reply": projected_reply, + "context_len": context_len, + "n_reply": n_reply, + "prompt_len": prompt_len, + "conditioning": conditioning, + "reference_mel": reference_mel, + } + torch.save(talker_input, os.path.join(temp_dir, "talker_input.pt")) + + del hf_model, outputs + gc.collect() + return context_len + + +def run_talker(model_path, compiled_path, context_len, 
num_runs, temp_dir): + """Phase 4: Load Talker once, generate num_runs times, return first result + avg time.""" + print(f"\n--- Phase 4: Talker (codec token generation, {num_runs} runs) ---") + + talker_compiled = os.path.join(compiled_path, "talker_tp4") + output_file = os.path.join(temp_dir, "talker_output.json") + + script = f''' +import torch, os, json, time +MODEL_PATH = "{model_path}" +COMPILED = "{talker_compiled}" +TEMP_DIR = "{temp_dir}" +NUM_RUNS = {num_runs} + +from transformers import AutoConfig +from modeling_qwen25_omni_talker import ( + NeuronQwen25OmniTalkerForCausalLM, TalkerInferenceConfig, TalkerNeuronConfig, +) +from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config, HuggingFaceGenerationAdapter + +hf = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True) +tc = hf.talker_config + +tnc = TalkerNeuronConfig( + tp_degree={TP_DEGREE}, batch_size=1, seq_len=2048, max_context_length=2048, + torch_dtype=torch.bfloat16, +) +tic = TalkerInferenceConfig(neuron_config=tnc, load_config=load_pretrained_config(hf_config=tc)) +talker = NeuronQwen25OmniTalkerForCausalLM(MODEL_PATH, config=tic) + +t_load = time.time() +talker.load(COMPILED) +t_load = time.time() - t_load +print(f"Model loaded in {{t_load:.1f}}s") + +adp = HuggingFaceGenerationAdapter(talker) + +inp = torch.load(os.path.join(TEMP_DIR, "talker_input.pt"), weights_only=False) +projected_context = inp["projected_context"] +projected_reply = inp["projected_reply"] +context_len = inp["context_len"] + +codec_bos = tc.tts_codec_start_token_id +codec_eos = tc.tts_codec_end_token_id +codec_pad = tc.tts_codec_pad_token_id +codec_mask = tc.tts_codec_mask_token_id + +talker_input_ids = torch.cat([ + torch.full((1, context_len - 2), codec_mask, dtype=torch.long), + torch.tensor([[codec_pad]], dtype=torch.long), + torch.tensor([[codec_bos]], dtype=torch.long), +], dim=1) +talker_attention_mask = torch.ones_like(talker_input_ids, dtype=torch.long) + +max_gen = min(600, 2048 - context_len - 10) + +print(f"Running {{NUM_RUNS}} inferences (model stays loaded)...") +first_codes = None +times = [] +for i in range(NUM_RUNS): + # Re-set vision embeddings before each run (context encoding consumes them) + ve = projected_context.to(torch.bfloat16) + vm = torch.ones(1, context_len, 1, dtype=torch.int32) + reply = projected_reply.to(torch.bfloat16) + talker.set_vision_embeddings(ve, vm, thinker_reply_embeds=reply) + + t0 = time.time() + out = adp.generate( + input_ids=talker_input_ids, + attention_mask=talker_attention_mask, + max_new_tokens=max_gen, + eos_token_id=[codec_eos, codec_pad], + suppress_tokens=[codec_bos], + do_sample=True, temperature=0.9, top_k=40, top_p=0.8, + repetition_penalty=1.05, + ) + elapsed = time.time() - t0 + times.append(elapsed) + + gen_tokens = out[0, context_len:].tolist() + while gen_tokens and gen_tokens[-1] == codec_eos: + gen_tokens.pop() + print(f" Run {{i+1}}/{{NUM_RUNS}}: {{len(gen_tokens)}} codec tokens in {{elapsed:.3f}}s") + + if first_codes is None: + first_codes = gen_tokens + +avg_time = sum(times) / len(times) +print(f"Avg inference: {{avg_time:.3f}}s (load: {{t_load:.1f}}s, not included in avg)") + +result = {{"codes": first_codes, "gen_time": avg_time, "all_times": times, "load_time": t_load}} +with open(os.path.join(TEMP_DIR, "talker_output.json"), "w") as f: + json.dump(result, f) +''' + ok = _run_subprocess(script, "talker", temp_dir) + if not ok: + return None + + with open(output_file) as f: + return json.load(f) + + +def run_token2wav(model_path, 
compiled_path, codec_codes, num_runs, temp_dir, output_wav): + """Phase 5: Load DiT once, run num_runs times, save first result, report avg time.""" + print(f"\n--- Phase 5: Token2Wav (waveform synthesis, {num_runs} runs) ---") + + from transformers import AutoConfig + from safetensors.torch import load_file + from modeling_qwen25_omni_token2wav import ( + NeuronQwen25OmniToken2WavWithNeuronDiT, + ) + + hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + t2w_cfg = hf_config.token2wav_config + + t2w = NeuronQwen25OmniToken2WavWithNeuronDiT(t2w_cfg) + + state_dict = {} + for fn in sorted(os.listdir(model_path)): + if fn.endswith(".safetensors"): + sd = load_file(os.path.join(model_path, fn)) + for k, v in sd.items(): + if k.startswith("token2wav."): + state_dict[k[len("token2wav."):]] = v + t2w.load_state_dict(state_dict, strict=False) + + dit_compiled = os.path.join(compiled_path, "dit_core") + with Timer("Load compiled DiT"): + t2w.load_dit(dit_compiled) + _restore_embedding() + + code_tensor = torch.tensor([codec_codes], dtype=torch.long) + num_embeds = getattr(t2w_cfg.dit_config, "num_embeds", 8193) + if code_tensor.max() >= num_embeds: + code_tensor = code_tensor.clamp(0, num_embeds) + + inp = torch.load(os.path.join(temp_dir, "talker_input.pt"), weights_only=False) + conditioning = inp["conditioning"] + reference_mel = inp["reference_mel"] + + print(f" Running {num_runs} inferences (DiT stays loaded)...") + first_wav = None + times = [] + for i in range(num_runs): + t0 = time.time() + wav = t2w( + code=code_tensor, + conditioning=conditioning, + reference_mel=reference_mel, + num_steps=10, + guidance_scale=0.5, + ) + elapsed = time.time() - t0 + times.append(elapsed) + print(f" Run {i+1}/{num_runs}: {elapsed:.2f}s") + + if first_wav is None: + first_wav = wav + + avg_time = sum(times) / len(times) + print(f" Avg inference: {avg_time:.2f}s") + + audio_duration = 0 + if first_wav is not None and isinstance(first_wav, torch.Tensor) and first_wav.numel() > 0: + import soundfile as sf + wav_np = first_wav.detach().cpu().float().numpy().flatten() + sf.write(output_wav, wav_np, 24000) + audio_duration = len(wav_np) / 24000 + print(f" Audio: {audio_duration:.1f}s saved to {output_wav}") + + return audio_duration, avg_time, times + + +# ========================================================================== +# Main +# ========================================================================== + +def main(): + parser = argparse.ArgumentParser( + description="Qwen2.5-Omni-7B speech synthesis on Neuron" + ) + parser.add_argument( + "--compile", action="store_true", + help="Compile all Neuron components (one-time, ~30 min)", + ) + parser.add_argument( + "--num-runs", type=int, default=1, + help="Number of inference runs per component for benchmarking (default: 1)", + ) + parser.add_argument( + "--prompt", default=DEFAULT_PROMPT, + help="Text prompt for speech generation", + ) + parser.add_argument( + "--system-prompt", default=DEFAULT_SYSTEM, + help="System prompt", + ) + parser.add_argument( + "--speaker", default=DEFAULT_SPEAKER, choices=["Ethan", "Chelsie"], + help="Speaker voice (default: Ethan)", + ) + parser.add_argument( + "--model-path", default=MODEL_PATH, + help=f"Model path (default: {MODEL_PATH})", + ) + parser.add_argument( + "--compiled-path", default=COMPILED_PATH, + help=f"Compiled artifacts path (default: {COMPILED_PATH})", + ) + parser.add_argument( + "--output", default="speech_output.wav", + help="Output WAV file path (default: speech_output.wav)", + ) 
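+    # Illustrative invocations (assumed script name; flags correspond to the
+    # arguments defined above):
+    #   python generate_qwen25_omni_speech.py --compile
+    #   python generate_qwen25_omni_speech.py --prompt "Tell me a short story." \
+    #       --speaker Chelsie --num-runs 3 --output speech_output.wav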
+ args = parser.parse_args() + + model_path = _resolve_model_path(args.model_path) + compiled_path = args.compiled_path + num_runs = args.num_runs + + # --- Compile mode --- + if args.compile: + ok = compile_all(model_path, compiled_path) + sys.exit(0 if ok else 1) + + # --- Inference mode --- + if not _check_compiled(compiled_path): + sys.exit(1) + + temp_dir = os.path.join(compiled_path, "_tmp") + os.makedirs(temp_dir, exist_ok=True) + + print("=" * 60) + print("Qwen2.5-Omni Speech Pipeline (Neuron, TP=4)") + print("=" * 60) + print(f" Model: {model_path}") + print(f" Compiled: {compiled_path}") + print(f" Speaker: {args.speaker}") + print(f" Prompt: {args.prompt}") + print(f" Output: {args.output}") + print(f" Runs: {num_runs}") + t_total = time.time() + + # Phase 1: Thinker (load once, run N times in subprocess) + thinker_result = run_thinker( + model_path, compiled_path, args.prompt, args.system_prompt, + num_runs, temp_dir, + ) + if not thinker_result: + print("Thinker failed, aborting.") + return + print(f" Text: {thinker_result['gen_text'][:200]}") + + # Phase 2: Hidden states (CPU, run once) + hf_model, outputs, full_ids, prompt_len = extract_hidden_states( + model_path, thinker_result + ) + + # Phase 3: Talker input prep (CPU, run once) + context_len = prepare_talker_input( + model_path, hf_model, outputs, full_ids, prompt_len, + args.speaker, temp_dir, + ) + + # Phase 4: Talker (load once, run N times in subprocess) + talker_result = run_talker( + model_path, compiled_path, context_len, num_runs, temp_dir, + ) + if not talker_result or not talker_result["codes"]: + print("Talker failed or produced no tokens, aborting.") + return + print(f" {len(talker_result['codes'])} codec tokens (avg {talker_result['gen_time']:.3f}s)") + + # Phase 5: Token2Wav (load DiT once, run N times in main process) + audio_duration, t2w_avg, t2w_times = run_token2wav( + model_path, compiled_path, talker_result["codes"], + num_runs, temp_dir, args.output, + ) + + # Summary + total_time = time.time() - t_total + thinker_avg = thinker_result["gen_time"] + talker_avg = talker_result["gen_time"] + thinker_load = thinker_result.get("load_time", 0) + talker_load = talker_result.get("load_time", 0) + + print("\n" + "=" * 60) + print("RESULTS") + print("=" * 60) + print(f" Text: {thinker_result['gen_text'][:200]}") + print(f"\n Model load time (one-time cost, excluded from pipeline avg):") + print(f" Thinker: {thinker_load:.1f}s") + print(f" Talker: {talker_load:.1f}s") + print(f"\n Inference latency (avg of {num_runs} runs, model already loaded):") + print(f" Thinker: {thinker_avg:.3f}s ({thinker_result['n_tokens']} tokens)") + print(f" Talker: {talker_avg:.3f}s ({len(talker_result['codes'])} codec tokens)") + print(f" Token2Wav: {t2w_avg:.2f}s") + pipeline_avg = thinker_avg + talker_avg + t2w_avg + print(f" Pipeline: {pipeline_avg:.2f}s total") + if audio_duration > 0: + print(f"\n Audio: {audio_duration:.1f}s") + print(f" RTF: {pipeline_avg/audio_duration:.2f}x (pipeline_avg / audio_duration)") + if num_runs > 1: + thinker_times = thinker_result.get("all_times", []) + talker_times = talker_result.get("all_times", []) + print(f"\n Per-run breakdown ({num_runs} runs):") + print(f" Thinker: {['%.3f' % t for t in thinker_times]}") + print(f" Talker: {['%.3f' % t for t in talker_times]}") + print(f" Token2Wav: {['%.2f' % t for t in t2w_times]}") + print(f"\n Wall time: {total_time:.1f}s (includes model loading + CPU phases)") + print(f" Output: {args.output}") + + +if __name__ == "__main__": + main() diff --git 
a/contrib/models/Qwen2.5-Omni-7B/perf_test/3_bench_qwen25_omni_7b.sh b/contrib/models/Qwen2.5-Omni-7B/perf_test/3_bench_qwen25_omni_7b.sh new file mode 100644 index 00000000..e398db20 --- /dev/null +++ b/contrib/models/Qwen2.5-Omni-7B/perf_test/3_bench_qwen25_omni_7b.sh @@ -0,0 +1,171 @@ +#!/bin/bash +set -e + +source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + +MODEL_PATH="/opt/dlami/nvme/models/Qwen2.5-Omni-7B" +PORT=8000 +RESULTS_DIR="/var/tmp/bench_results/qwen25_omni_7b" +mkdir -p "$RESULTS_DIR" + +# Helper: wait for vLLM server to be ready +wait_for_server() { + echo " Waiting for vLLM server to be ready..." + for i in $(seq 1 360); do + if curl -s http://localhost:$PORT/health > /dev/null 2>&1; then + echo " Server ready! (${i}s * 5 = $((i*5))s)" + return 0 + fi + sleep 5 + done + echo " ERROR: Server did not start within 1800s" + return 1 +} + +# Helper: run benchmark +run_bench() { + local config_name=$1 + local concurrency=$2 + local num_prompts=$3 + + echo " Benchmark: concurrency=$concurrency, prompts=$num_prompts" + vllm bench serve \ + --backend vllm \ + --model "$MODEL_PATH" \ + --tokenizer "$MODEL_PATH" \ + --endpoint /v1/completions \ + --dataset-name random \ + --num-prompts "$num_prompts" \ + --random-input-len 900 \ + --random-output-len 90 \ + --random-range-ratio 0.03 \ + --max-concurrency "$concurrency" \ + 2>&1 | tee "$RESULTS_DIR/${config_name}_c${concurrency}.txt" + echo "" +} + +# Helper: stop server +stop_server() { + echo " Stopping vLLM server..." + pkill -f "vllm.entrypoints.openai.api_server" 2>/dev/null || true + sleep 5 +} + +# Helper: quick sanity check +sanity_check() { + echo " Running sanity check..." + curl -s http://localhost:$PORT/v1/chat/completions \ + -H 'Content-Type: application/json' \ + -d '{ + "messages": [{"role": "user", "content": "What is 1+1? 
Answer briefly."}], + "model": "'"$MODEL_PATH"'", + "max_tokens": 64, + "temperature": 0.0, + "stream": false + }' | python3 -c "import sys,json; r=json.load(sys.stdin); print(' Sanity:', r['choices'][0]['message']['content'][:100])" 2>/dev/null || echo " Sanity check: could not parse response" +} + +echo "==========================================" +echo "Qwen2.5-Omni-7B Performance Benchmark" +echo "==========================================" +echo "Model: $MODEL_PATH" +echo "Results: $RESULTS_DIR" +echo "" + +############################################################################### +# Config 1: BS=1, TP=4, non-CB (baseline latency) +# Qwen2.5-Omni-7B is a dense 7B model, TP=4 is sufficient +############################################################################### +CONFIG_NAME="bs1_tp4" +echo "--- Config 1: BS=1, TP=4, non-CB (baseline) ---" + +python3 -m vllm.entrypoints.openai.api_server \ + --model "$MODEL_PATH" \ + --tokenizer "$MODEL_PATH" \ + --tensor-parallel-size 4 \ + --max-model-len 4096 \ + --max-num-seqs 1 \ + --no-enable-chunked-prefill \ + --no-enable-prefix-caching \ + --port $PORT \ + --trust_remote_code \ + --additional-config '{ + "override_neuron_config": { + "tp_degree": 4, + "fused_qkv": false, + "flash_decoding_enabled": false, + "sequence_parallel_enabled": false, + "qkv_kernel_enabled": false, + "qkv_nki_kernel_enabled": false, + "attn_kernel_enabled": false, + "batch_size": 1, + "ctx_batch_size": 1, + "tkg_batch_size": 1, + "max_context_length": 4096, + "seq_len": 4096, + "is_continuous_batching": false, + "enable_bucketing": false, + "async_mode": true, + "on_device_sampling_config": { + "do_sample": true, "temperature": 0.6, "top_k": 20, "top_p": 0.95 + } + } + }' & + +wait_for_server +sanity_check +run_bench "$CONFIG_NAME" 1 16 +stop_server + +############################################################################### +# Config 2: BS=4, TP=4, CB (throughput) +############################################################################### +CONFIG_NAME="bs4_tp4_cb" +echo "--- Config 2: BS=4, TP=4, CB ---" + +python3 -m vllm.entrypoints.openai.api_server \ + --model "$MODEL_PATH" \ + --tokenizer "$MODEL_PATH" \ + --tensor-parallel-size 4 \ + --max-model-len 4096 \ + --max-num-seqs 4 \ + --no-enable-chunked-prefill \ + --no-enable-prefix-caching \ + --port $PORT \ + --trust_remote_code \ + --additional-config '{ + "override_neuron_config": { + "tp_degree": 4, + "fused_qkv": false, + "flash_decoding_enabled": false, + "sequence_parallel_enabled": false, + "qkv_kernel_enabled": false, + "qkv_nki_kernel_enabled": false, + "attn_kernel_enabled": false, + "batch_size": 4, + "ctx_batch_size": 1, + "tkg_batch_size": 4, + "max_context_length": 4096, + "seq_len": 4096, + "is_continuous_batching": true, + "enable_bucketing": true, + "context_encoding_buckets": [4096], + "token_generation_buckets": [4096], + "async_mode": true, + "on_device_sampling_config": { + "do_sample": true, "temperature": 0.6, "top_k": 20, "top_p": 0.95 + } + } + }' & + +wait_for_server +sanity_check +run_bench "$CONFIG_NAME" 1 16 +run_bench "$CONFIG_NAME" 4 64 +stop_server + +echo "==========================================" +echo "Qwen2.5-Omni-7B benchmarks complete!" 
+echo "Results saved to: $RESULTS_DIR" +echo "==========================================" +ls -la "$RESULTS_DIR" diff --git a/contrib/models/Qwen2.5-Omni-7B/perf_test/apply_vllm_neuron_patch_qwen25omni.py b/contrib/models/Qwen2.5-Omni-7B/perf_test/apply_vllm_neuron_patch_qwen25omni.py new file mode 100644 index 00000000..68131ed0 --- /dev/null +++ b/contrib/models/Qwen2.5-Omni-7B/perf_test/apply_vllm_neuron_patch_qwen25omni.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 +"""Add Qwen2.5-Omni model support to vllm-neuron. + +This patch should be applied AFTER the MiMo/MiniMax patch (apply_vllm_neuron_patch.py). +It handles: + 1. Config extraction: Qwen2.5-Omni nests text config under thinker_config.text_config + 2. Architecture mapping: "Qwen2_5OmniModel" -> "qwen2_5_omni" model type + 3. Layer count extraction: get_num_layers_from_hf_config for nested config +""" + +import os + +# Patch 1 & 2: neuronx_distributed_model_loader.py +LOADER_FILE = "/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/lib/python3.12/site-packages/vllm_neuron/worker/neuronx_distributed_model_loader.py" + +with open(LOADER_FILE) as f: + content = f.read() + +# 1. In _get_model_configs: handle Qwen2.5-Omni nested config +content = content.replace( + ' if architecture in NEURON_MULTI_MODAL_MODELS:\n' + ' config = getattr(config, "text_config", None)\n' + ' num_key_value_heads = getattr(config, "num_key_value_heads", None)', + ' if architecture in NEURON_MULTI_MODAL_MODELS:\n' + ' config = getattr(config, "text_config", None)\n' + ' # Qwen2.5-Omni: text config is nested under thinker_config.text_config\n' + ' if architecture == "Qwen2_5OmniModel":\n' + ' thinker_config = getattr(config, "thinker_config", None)\n' + ' if thinker_config is not None:\n' + ' config = getattr(thinker_config, "text_config", config)\n' + ' num_key_value_heads = getattr(config, "num_key_value_heads", None)', +) + +# 2. In _get_neuron_model_cls: handle Qwen2_5OmniModel architecture +content = content.replace( + ' try:\n' + ' if "For" in architecture:', + ' # Qwen2.5-Omni: architecture is "Qwen2_5OmniModel" (no "For" in name)\n' + ' if architecture == "Qwen2_5OmniModel":\n' + ' return MODEL_TYPES["qwen2_5_omni"]["causal-lm"]\n' + '\n' + ' try:\n' + ' if "For" in architecture:', +) + +with open(LOADER_FILE, "w") as f: + f.write(content) + +print("Patch 1/2: neuronx_distributed_model_loader.py updated") + +# Patch 3: utils.py - handle Qwen2.5-Omni nested config for layer count +UTILS_FILE = "/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/lib/python3.12/site-packages/vllm_neuron/worker/utils.py" + +with open(UTILS_FILE) as f: + content = f.read() + +content = content.replace( + ' # Sum nested configs (multimodal models)\n' + ' total = 0\n' + ' for attr in ["text_config", "vision_config"]:', + ' # Qwen2.5-Omni: check thinker_config.text_config\n' + ' thinker_config = getattr(hf_config, "thinker_config", None)\n' + ' if thinker_config is not None:\n' + ' text_config = getattr(thinker_config, "text_config", None)\n' + ' if text_config is not None:\n' + ' layers = getattr(text_config, "num_hidden_layers", None)\n' + ' if layers is not None:\n' + ' return layers\n' + '\n' + ' # Sum nested configs (multimodal models)\n' + ' total = 0\n' + ' for attr in ["text_config", "vision_config"]:', +) + +with open(UTILS_FILE, "w") as f: + f.write(content) + +print("Patch 2/2: utils.py updated") +print() +print("Qwen2.5-Omni vllm-neuron patch applied successfully!") +print(" 1. Added thinker_config.text_config extraction in _get_model_configs") +print(" 2. 
Added Qwen2_5OmniModel -> qwen2_5_omni mapping in _get_neuron_model_cls") +print(" 3. Added thinker_config.text_config layer count extraction in utils.py") diff --git a/contrib/models/Qwen2.5-Omni-7B/src/__init__.py b/contrib/models/Qwen2.5-Omni-7B/src/__init__.py index 3d7d8c14..62f2a61c 100644 --- a/contrib/models/Qwen2.5-Omni-7B/src/__init__.py +++ b/contrib/models/Qwen2.5-Omni-7B/src/__init__.py @@ -1,3 +1,10 @@ -from .modeling_qwen2_5_omni import NeuronQwen2_5OmniForCausalLM, Qwen2_5OmniInferenceConfig +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Importing this package applies an upstream bug fix for +# HuggingFaceGenerationAdapter.prepare_inputs_for_generation so that +# adapter.generate() does not raise NameError when forwarding +# tensor_capture_hook downstream. The fix is idempotent and only activates +# if the upstream file still contains the bug. -__all__ = ["NeuronQwen2_5OmniForCausalLM", "Qwen2_5OmniInferenceConfig"] +from . import _upstream_compat # noqa: F401 (side-effect import) diff --git a/contrib/models/Qwen2.5-Omni-7B/src/_upstream_compat.py b/contrib/models/Qwen2.5-Omni-7B/src/_upstream_compat.py new file mode 100644 index 00000000..36bc015c --- /dev/null +++ b/contrib/models/Qwen2.5-Omni-7B/src/_upstream_compat.py @@ -0,0 +1,139 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Upstream compatibility shims for Qwen2.5-Omni. +# +# These patches sit here (not in src/neuronx_distributed_inference/) so that +# this contrib package has zero direct invasion on the upstream source tree. +# Each patch is idempotent: if upstream has already fixed the issue the +# original object stays unchanged. + +import logging +import inspect + +from neuronx_distributed_inference.utils.hf_adapter import HuggingFaceGenerationAdapter + +logger = logging.getLogger(__name__) + + +def _patch_prepare_inputs_for_generation(): + """Fix upstream NameError in HuggingFaceGenerationAdapter. + + Upstream's prepare_inputs_for_generation builds model_inputs with + "tensor_capture_hook": tensor_capture_hook + but never defines tensor_capture_hook as a parameter or extracts it from + **kwargs, so adapter.generate() raises NameError. Qwen2.5-Omni's Talker + drives generation via adapter.generate(), so this breaks it out-of-the-box. + + This wrapper re-dispatches to the upstream method with tensor_capture_hook + materialized as a local so the original body sees a defined name. + """ + src = inspect.getsource(HuggingFaceGenerationAdapter.prepare_inputs_for_generation) + references_hook = '"tensor_capture_hook": tensor_capture_hook' in src + already_extracted = 'tensor_capture_hook = kwargs.get("tensor_capture_hook"' in src + if already_extracted or not references_hook: + return # upstream already consistent + + original = HuggingFaceGenerationAdapter.prepare_inputs_for_generation + + def patched(self, input_ids, *args, **kwargs): + # Inject the missing local via the frame's globals is fragile; the + # cleanest fix is to guarantee the name exists in the caller's scope. + # Since Python resolves bare names via function locals/globals, and + # the original body has no local binding, we surface it through the + # **kwargs extraction idiom already used for input_capture_hook. + # The patched body below mirrors upstream but extracts the hook. 
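+        # Equivalent upstream one-line fix, which the `already_extracted` check
+        # above looks for (illustrative; not yet present upstream):
+        #     tensor_capture_hook = kwargs.get("tensor_capture_hook", None)
+        # Until then, the full method body is re-implemented below.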
+ import torch # noqa: F401 + self.prev_kv_cache_populated = self.neuron_model.kv_cache_populated + if self.neuron_model.kv_cache_populated: + input_ids = input_ids[:, -1:] + + past_key_values = kwargs.pop("past_key_values", None) + attention_mask = kwargs.pop("attention_mask", None) + inputs_embeds = kwargs.pop("inputs_embeds", None) + sampling_params = kwargs.pop("sampling_params", None) + adapter_ids = kwargs.pop("adapter_ids", None) + divergence_idx = kwargs.pop("divergence_idx", None) + + accepted_indices = kwargs.get("accepted_indices", None) + current_length = kwargs.get("current_length", None) + medusa_mask = kwargs.get("medusa_mask", None) + scatter_index = kwargs.get("scatter_index", None) + position_ids = kwargs.get("position_ids", None) + input_capture_hook = kwargs.get("input_capture_hook", None) + tensor_capture_hook = kwargs.get("tensor_capture_hook", None) + + if attention_mask is not None and position_ids is None: + position_ids = attention_mask.long().cumsum(-1) - 1 + if self.input_start_offsets: + if len(self.input_start_offsets) > 1: + import torch as _torch + position_ids += _torch.tensor( + self.input_start_offsets, + dtype=position_ids.dtype, + device=position_ids.device, + )[:, None] + else: + position_ids += self.input_start_offsets[0] + import torch as _torch + for i, offset in enumerate(self.input_start_offsets): + position_ids[i, 0:offset] = _torch.arange(offset) + else: + position_ids.masked_fill_(attention_mask == 0, 1) + + if self.neuron_model.kv_cache_populated: + import torch as _torch + position_ids = _torch.amax(position_ids, 1, keepdim=True) + position_ids = position_ids + 1 + + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache", False), + "attention_mask": attention_mask, + "medusa_args": (accepted_indices, current_length, medusa_mask, scatter_index), + "sampling_params": sampling_params, + "input_capture_hook": input_capture_hook, + "tensor_capture_hook": tensor_capture_hook, + "adapter_ids": adapter_ids, + } + ) + + tf_args = [] + if self.neuron_config.tensor_replacement_config: + from neuronx_distributed_inference.utils.tensor_replacement.registry import ( + TensorReplacementRegister, + ) + reg = TensorReplacementRegister.get_instance() + tf, masks = reg.step_args( + self.generation_step, + divergence_idx=True if divergence_idx else False, + ) + tf_args = tf + masks + + if tf_args: + model_inputs["tf_args"] = tf_args + + additional_kwargs = self.neuron_model.get_required_kwargs() + for arg in additional_kwargs: + model_inputs.update({arg: kwargs.get(arg, None)}) + + return model_inputs + + HuggingFaceGenerationAdapter.prepare_inputs_for_generation = patched + HuggingFaceGenerationAdapter.prepare_inputs_for_generation.__doc__ = ( + "[Qwen2.5-Omni contrib patched] " + (original.__doc__ or "") + ) + logger.info( + "Qwen2.5-Omni contrib: patched HuggingFaceGenerationAdapter." + "prepare_inputs_for_generation to extract tensor_capture_hook from kwargs." 
+ ) + + +_patch_prepare_inputs_for_generation() diff --git a/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni.py b/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni.py new file mode 100644 index 00000000..4e889aa0 --- /dev/null +++ b/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni.py @@ -0,0 +1,1106 @@ +# coding=utf-8 +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Qwen2.5-Omni support for NXD inference. +# +# Provides two modes: +# 1. Text-only (Thinker): NeuronQwen25OmniForCausalLM +# - Reuses Qwen2 decoder with thinker.model.* prefix remapping +# 2. Multimodal (Vision + Text): NeuronQwen25OmniMultimodalForCausalLM +# - Vision encoder: Qwen2.5-Omni ViT (SwiGLU, RMSNorm, separate QKV) +# - Text decoder: Qwen2-VL text model (multimodal RoPE) +# +# Reference: https://huggingface.co/Qwen/Qwen2.5-Omni-7B + +"""Qwen2.5-Omni model for NXD inference.""" + +import copy +import gc +import logging +from types import SimpleNamespace +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union + +import torch +from transformers.modeling_outputs import CausalLMOutputWithPast + +from neuronx_distributed_inference.models.config import InferenceConfig, NeuronConfig +from neuronx_distributed_inference.models.image_to_text_model_base import ( + ImageToTextInferenceConfig, + NeuronBaseForImageToText, +) +from neuronx_distributed_inference.models.model_wrapper import VISION_ENCODER_MODEL_TAG +from neuronx_distributed_inference.models.qwen2.modeling_qwen2 import ( + NeuronQwen2ForCausalLM, + NeuronQwen2Model, + convert_state_dict_to_fused_qkv, +) + +logger = logging.getLogger("Neuron") + + +# Attributes to extract from the thinker's text_config to the top-level config. +_TEXT_CONFIG_ATTRS = [ + "hidden_size", + "num_hidden_layers", + "num_attention_heads", + "num_key_value_heads", + "vocab_size", + "intermediate_size", + "max_position_embeddings", + "rope_theta", + "rms_norm_eps", + "hidden_act", + "tie_word_embeddings", + "max_window_layers", + "use_sliding_window", + "sliding_window", +] + + +class Qwen25OmniInferenceConfig(InferenceConfig): + """Inference config for Qwen2.5-Omni (Thinker text component). + + Handles the nested config structure: the HF config has attributes under + thinker_config.text_config that we need at the top level for NxDI. + """ + + def add_derived_config(self): + self.num_cores_per_group = 1 + # Qwen2.5-Omni text model has QKV bias but no output projection bias + self.qkv_bias = True + self.o_bias = False + + # Extract text config attributes from nested thinker_config + if hasattr(self, "thinker_config"): + thinker_cfg = self.thinker_config + # When loaded from saved JSON, thinker_config is a plain dict + if isinstance(thinker_cfg, dict): + thinker_cfg = SimpleNamespace(**thinker_cfg) + self.thinker_config = thinker_cfg + + text_cfg = thinker_cfg.text_config + if isinstance(text_cfg, dict): + text_cfg = SimpleNamespace(**text_cfg) + thinker_cfg.text_config = text_cfg + + # Text config attributes always take precedence over top-level + # defaults from PretrainedConfig (e.g. tie_word_embeddings defaults + # to True at the top level but is False in text_config). 
+ for attr in _TEXT_CONFIG_ATTRS: + if hasattr(text_cfg, attr): + setattr(self, attr, getattr(text_cfg, attr)) + + # Set pad_token_id from thinker_config + if hasattr(thinker_cfg, "pad_token_id"): + if not hasattr(self, "pad_token_id") or self.pad_token_id is None: + self.pad_token_id = thinker_cfg.pad_token_id + + def get_required_attributes(self) -> List[str]: + return [ + "hidden_size", + "num_attention_heads", + "num_hidden_layers", + "num_key_value_heads", + "pad_token_id", + "vocab_size", + "max_position_embeddings", + "rope_theta", + "rms_norm_eps", + "hidden_act", + ] + + @classmethod + def get_neuron_config_cls(cls) -> Type[NeuronConfig]: + return NeuronConfig + + +class NeuronQwen25OmniForCausalLM(NeuronQwen2ForCausalLM): + """Qwen2.5-Omni Thinker text model for Causal LM on Neuron. + + Reuses the Qwen2 model architecture since the Thinker's text backbone + is architecturally identical to Qwen2.5. The main differences are: + - Weight keys are prefixed with 'thinker.model.' / 'thinker.lm_head.' + - Non-text weights (talker, token2wav, audio_tower, visual) are discarded + """ + + _model_cls = NeuronQwen2Model + _STATE_DICT_MODEL_PREFIX = "thinker.model." + + @staticmethod + def load_hf_model(model_path: str, **kwargs): + """Load the full Qwen2.5-Omni model from HuggingFace. + + Note: We load the full model and filter to thinker text weights + in convert_hf_to_neuron_state_dict. + """ + from transformers import AutoModelForCausalLM + + return AutoModelForCausalLM.from_pretrained( + model_path, + trust_remote_code=True, + **kwargs, + ) + + @classmethod + def get_config_cls(cls) -> Type[Qwen25OmniInferenceConfig]: + return Qwen25OmniInferenceConfig + + @staticmethod + def convert_hf_to_neuron_state_dict( + state_dict: Dict[str, Any], + config: Qwen25OmniInferenceConfig, + ) -> Dict[str, Any]: + """Convert Qwen2.5-Omni state dict to NxDI format. + + 1. Keep only thinker text model weights (discard talker, token2wav, audio, visual) + 2. Map 'thinker.lm_head.' -> 'lm_head.' + 3. Apply standard Qwen2 conversions (fused QKV, rank utils, etc.) + """ + neuron_config = config.neuron_config + + # Filter: keep only thinker text weights and lm_head + # After base-class prefix stripping, 'thinker.model.*' becomes '*' + # but 'thinker.lm_head.*' is NOT stripped, so handle it here. 
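+        # Illustrative key handling (assumed typical checkpoint key names):
+        #   'layers.0.self_attn.q_proj.weight' -> kept ('thinker.model.' already stripped)
+        #   'thinker.lm_head.weight'           -> 'lm_head.weight' (renamed below)
+        #   'talker.*' / 'token2wav.*' / 'thinker.visual.*' / 'thinker.audio_tower.*' -> removed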
+ keys_to_remove = [] + keys_to_rename = {} + for key in state_dict: + if key.startswith("thinker.lm_head."): + # Map thinker.lm_head.weight -> lm_head.weight + new_key = key.replace("thinker.lm_head.", "lm_head.", 1) + keys_to_rename[key] = new_key + elif not key.startswith(("layers.", "embed_tokens.", "norm.", "lm_head.")): + # After base-class prefix stripping, valid text keys start with + # layers.*, embed_tokens.*, norm.*, or lm_head.* + # Everything else (talker, token2wav, audio_tower, visual) should be removed + keys_to_remove.append(key) + + # Apply renames + for old_key, new_key in keys_to_rename.items(): + state_dict[new_key] = state_dict.pop(old_key) + + # Remove non-text weights + for key in keys_to_remove: + del state_dict[key] + + gc.collect() + logger.info( + "Filtered state dict to %d thinker text model weights", len(state_dict) + ) + + # Add rank utilities (same as Qwen2) + if neuron_config.vocab_parallel: + state_dict["embed_tokens.rank_util.rank"] = torch.arange( + 0, neuron_config.local_ranks_size + ) + + num_layers = config.num_hidden_layers + tp_degree = neuron_config.tp_degree + for i in range(num_layers): + state_dict[f"layers.{i}.self_attn.rank_util.rank"] = torch.arange( + 0, tp_degree, dtype=torch.int32 + ) + + if neuron_config.fused_qkv: + state_dict = convert_state_dict_to_fused_qkv(state_dict, config) + + state_dict["rank_util.rank"] = torch.arange(0, tp_degree, dtype=torch.int32) + return state_dict + + def get_compiler_args(self): + compiler_args = ( + "--enable-saturate-infinity " + "--enable-mixed-precision-accumulation " + "--auto-cast=none " + "--model-type transformer " + "-O1" + ) + compiler_args += ( + " --tensorizer-options='" + "--enable-ccop-compute-overlap " + "--cc-pipeline-tiling-factor=2 " + "--vectorize-strided-dma'" + ) + compiler_args += " --internal-hlo2tensorizer-options='--verify-hlo=true'" + return compiler_args + + +# --------------------------------------------------------------------------- +# Multimodal (Vision + Text) support +# --------------------------------------------------------------------------- + +# Keys from thinker_config.text_config to copy to the top-level multimodal config +_MULTIMODAL_TEXT_CONFIG_KEYS = [ + "hidden_size", + "num_attention_heads", + "num_hidden_layers", + "num_key_value_heads", + "pad_token_id", + "vocab_size", + "intermediate_size", + "max_position_embeddings", + "rms_norm_eps", + "rope_theta", + "rope_scaling", + "hidden_act", + "bos_token_id", + "eos_token_id", + "qkv_bias", + "o_bias", + "image_token_id", + "vision_token_id", + "video_token_id", + "vision_start_token_id", + "vision_end_token_id", +] + + +class Qwen25OmniMultimodalInferenceConfig(ImageToTextInferenceConfig): + """Inference config for Qwen2.5-Omni multimodal (vision + text). + + Handles the nested config structure where text_config and vision_config + are under thinker_config. Extracts them to the top level as required + by ImageToTextInferenceConfig. 
+ """ + + def __init__( + self, + text_neuron_config, + vision_neuron_config, + fused_spec_config=None, + load_config=None, + metadata: Optional[Dict] = None, + **kwargs, + ): + # Extract text_config and vision_config from thinker_config + # The HF config nests them: thinker_config.text_config, thinker_config.vision_config + thinker = kwargs.get("thinker_config", None) + if thinker is not None: + if hasattr(thinker, "__dict__") and not isinstance(thinker, dict): + thinker = vars(thinker) + if isinstance(thinker, dict): + if "text_config" not in kwargs and "text_config" in thinker: + tc = thinker["text_config"] + kwargs["text_config"] = ( + vars(tc) if hasattr(tc, "__dict__") and not isinstance(tc, dict) else tc + ) + if "vision_config" not in kwargs and "vision_config" in thinker: + vc = thinker["vision_config"] + kwargs["vision_config"] = ( + vars(vc) if hasattr(vc, "__dict__") and not isinstance(vc, dict) else vc + ) + # Extract audio_config from thinker_config + if "audio_config" not in kwargs and "audio_config" in thinker: + ac = thinker["audio_config"] + kwargs["audio_config"] = ( + vars(ac) if hasattr(ac, "__dict__") and not isinstance(ac, dict) else ac + ) + # Extract special token IDs from thinker_config + for token_key in [ + "image_token_index", "audio_token_index", "video_token_index", + "audio_start_token_id", "audio_end_token_id", + "vision_start_token_id", "vision_end_token_id", + "vision_token_id", "pad_token_id", + ]: + if token_key in thinker and token_key not in kwargs: + kwargs[token_key] = thinker[token_key] + + super().__init__( + text_neuron_config=text_neuron_config, + vision_neuron_config=vision_neuron_config, + fused_spec_config=fused_spec_config, + load_config=load_config, + metadata=metadata, + **kwargs, + ) + + self.add_special_config() + self.validate_model_supported_configs() + + def add_special_config(self): + self.num_cores_per_group = 1 + self.qkv_bias = True + self.o_bias = False + + # Vision config derived attributes + self.vision_config.head_dim = ( + self.vision_config.embed_dim // self.vision_config.num_heads + ) + self.vision_config.num_cores_per_group = 1 + + # Vision encoder MUST use separate Q/K/V (not fused) + if getattr(self.vision_config.neuron_config, "fused_qkv", True): + self.vision_config.neuron_config.fused_qkv = False + logger.info( + "Qwen2.5-Omni vision encoder: set fused_qkv=False " + "(separate Q/K/V projections)" + ) + + # Copy text config keys to top-level (for compatibility) + for key in _MULTIMODAL_TEXT_CONFIG_KEYS: + if hasattr(self.text_config, key): + setattr(self, key, getattr(self.text_config, key)) + + # Map Qwen2.5-Omni token IDs to Qwen2-VL compatible names + if hasattr(self, "image_token_index"): + self.image_token_id = self.image_token_index + self.text_config.image_token_id = self.image_token_index + if hasattr(self, "video_token_index"): + self.video_token_id = self.video_token_index + self.text_config.video_token_id = self.video_token_index + if hasattr(self, "audio_token_index"): + self.audio_token_id = self.audio_token_index + self.text_config.audio_token_id = self.audio_token_index + + # Set pad_token_id + if hasattr(self, "pad_token_id"): + self.text_config.pad_token_id = self.pad_token_id + + # Store audio_config as SimpleNamespace for attribute access + if hasattr(self, "audio_config") and isinstance(self.audio_config, dict): + self.audio_config = SimpleNamespace(**self.audio_config) + + def validate_model_supported_configs(self): + # Disable unsupported features for text model + unsupported_text = [ + 
"is_prefix_caching", + "is_chunked_prefill", + "is_medusa", + "enable_fused_speculation", + ] + for cfg_name in unsupported_text: + if getattr(self.text_config.neuron_config, cfg_name, False): + setattr(self.text_config.neuron_config, cfg_name, False) + logger.warning( + f"Qwen2.5-Omni text model does not support " + f"'{cfg_name}'. Disabled." + ) + + # Disable unsupported features for vision model + unsupported_vision = [ + "sequence_parallel_enabled", + "flash_decoding_enabled", + "qkv_kernel_enabled", + ] + for cfg_name in unsupported_vision: + if getattr(self.vision_config.neuron_config, cfg_name, False): + setattr(self.vision_config.neuron_config, cfg_name, False) + logger.warning( + f"Qwen2.5-Omni vision model does not support " + f"'{cfg_name}'. Disabled." + ) + + def get_required_attributes(self) -> List[str]: + return [ + "text_config", + "vision_config", + "text_config.hidden_size", + "text_config.num_attention_heads", + "text_config.num_hidden_layers", + "text_config.num_key_value_heads", + "text_config.pad_token_id", + "text_config.vocab_size", + "text_config.max_position_embeddings", + "text_config.rope_theta", + "text_config.rms_norm_eps", + "text_config.hidden_act", + "vision_config.depth", + "vision_config.embed_dim", + "vision_config.num_heads", + "vision_config.in_channels", + "vision_config.patch_size", + "vision_config.spatial_merge_size", + "vision_config.temporal_patch_size", + "vision_config.out_hidden_size", + "vision_config.intermediate_size", + ] + + @classmethod + def get_neuron_config_cls(cls) -> Type[NeuronConfig]: + return NeuronConfig + + +class NeuronQwen25OmniMultimodalForCausalLM(NeuronBaseForImageToText): + """Qwen2.5-Omni multimodal model (vision encoder + text decoder) on Neuron. + + Reuses Qwen2-VL text model components (same mRoPE architecture) and + the Qwen2.5-Omni vision encoder (SwiGLU, RMSNorm, separate QKV). 
+ """ + + # Import lazily to avoid circular imports + @staticmethod + def _get_text_model_cls(): + from neuronx_distributed_inference.models.qwen2_vl.modeling_qwen2_vl_text import ( + NeuronQwen2VLTextModel, + ) + return NeuronQwen2VLTextModel + + @staticmethod + def _get_text_model_wrapper(): + from neuronx_distributed_inference.models.qwen2_vl.modeling_qwen2_vl_text import ( + Qwen2VLTextModelWrapper, + ) + return Qwen2VLTextModelWrapper + + @staticmethod + def _get_vision_model_cls(): + from modeling_qwen25_omni_vision import ( + NeuronQwen25OmniVisionModel, + ) + return NeuronQwen25OmniVisionModel + + @staticmethod + def _get_vision_model_wrapper(): + from modeling_qwen25_omni_vision import ( + Qwen25OmniVisionModelWrapper, + ) + return Qwen25OmniVisionModelWrapper + + @staticmethod + def _get_audio_encoder_cls(): + from modeling_qwen25_omni_audio import ( + NeuronQwen25OmniAudioEncoder, + ) + return NeuronQwen25OmniAudioEncoder + + @staticmethod + def _get_talker_cls(): + from modeling_qwen25_omni_talker import ( + NeuronQwen25OmniTalker, + ) + return NeuronQwen25OmniTalker + + @staticmethod + def _get_neuron_talker_cls(): + from modeling_qwen25_omni_talker import ( + NeuronQwen25OmniTalkerForCausalLM, + ) + return NeuronQwen25OmniTalkerForCausalLM + + @staticmethod + def _get_talker_config_cls(): + from modeling_qwen25_omni_talker import ( + TalkerInferenceConfig, + ) + return TalkerInferenceConfig + + @staticmethod + def _get_thinker_projection_cls(): + from modeling_qwen25_omni_talker import ( + ThinkerToTalkerProjection, + ) + return ThinkerToTalkerProjection + + @staticmethod + def _get_token2wav_cls(): + from modeling_qwen25_omni_token2wav import ( + NeuronQwen25OmniToken2Wav, + ) + return NeuronQwen25OmniToken2Wav + + @staticmethod + def _get_neuron_token2wav_cls(): + from modeling_qwen25_omni_token2wav import ( + NeuronQwen25OmniToken2WavWithNeuronDiT, + ) + return NeuronQwen25OmniToken2WavWithNeuronDiT + + def __init__(self, *args, **kwargs): + super().__init__( + self._get_text_model_cls(), + self._get_vision_model_cls(), + self._get_text_model_wrapper(), + self._get_vision_model_wrapper(), + *args, + **kwargs, + ) + self.audio_encoder = None + self.talker = None + self.token2wav = None + self.speaker_map = {} + + def get_vision_compiler_args(self) -> str: + return ( + "--auto-cast=none --model-type=transformer " + "--tensorizer-options='--enable-ccop-compute-overlap " + "--cc-pipeline-tiling-factor=2 ' -O1 " + "--internal-hlo2tensorizer-options='--verify-hlo=true'" + ) + + def get_compiler_args(self) -> str: + return ( + "--enable-saturate-infinity " + "--enable-mixed-precision-accumulation " + "--auto-cast=none --model-type=transformer -O1 " + "--tensorizer-options='--enable-ccop-compute-overlap " + "--cc-pipeline-tiling-factor=2 " + "--vectorize-strided-dma' " + "--internal-hlo2tensorizer-options='--verify-hlo=true'" + ) + + def get_required_kwargs(self) -> List[str]: + return ["pixel_values", "vision_mask", "image_grid_thw"] + + def enable_vision_encoder( + self, enable_wlt_optimization: bool = True, **model_init_kwargs + ): + new_config = copy.deepcopy(self.config) + self.vision_encoder_model = self._get_vision_model_wrapper()( + config=new_config, + model_cls=self._get_vision_model_cls(), + tag=VISION_ENCODER_MODEL_TAG, + compiler_args=self.get_vision_compiler_args(), + model_init_kwargs=model_init_kwargs, + priority_model_idx=(0 if enable_wlt_optimization else None), + pipeline_execution=True, + return_ranked_to_cpu=False, + ) + 
self.vision_models.append(self.vision_encoder_model) + + def enable_audio_encoder(self, state_dict=None): + """Initialize the audio encoder (CPU frontend/postprocessor + Neuron transformer). + + The audio encoder is a hybrid CPU+Neuron module: + - CPU: Conv1d frontend, positional embeddings, chunking + - Neuron (TP=4): 32 transformer layers with block-diagonal attention + - CPU: AvgPool, LayerNorm, Linear projection + + The CPU components are loaded from the converted state dict. + The Neuron transformer must be compiled/loaded separately via + compile_audio_encoder() and load_audio_encoder(). + + Args: + state_dict: Converted state dict containing audio_tower.* keys + (already split into frontend.*, transformer.*, postprocessor.*). + If None, the encoder is created with random weights. + """ + audio_config = getattr(self.config, "audio_config", None) + if audio_config is None: + logger.warning( + "No audio_config found in model config. " + "Audio encoder will not be initialized." + ) + return + + AudioEncoderCls = self._get_audio_encoder_cls() + dtype = torch.bfloat16 + if hasattr(self.config, "neuron_config"): + dtype = getattr(self.config.neuron_config, "torch_dtype", dtype) + + if state_dict is not None: + self.audio_encoder = AudioEncoderCls.from_pretrained_state_dict( + audio_config, state_dict, dtype=dtype + ) + else: + self.audio_encoder = AudioEncoderCls(audio_config, dtype=dtype) + + self.audio_encoder.eval() + logger.info( + "Audio encoder initialized (CPU frontend/postprocessor, " + "Neuron transformer pending compile/load)" + ) + + def compile_audio_encoder(self, compiled_model_path, audio_neuron_config=None): + """Compile the audio encoder transformer layers on Neuron. + + Args: + compiled_model_path: Path to save compiled model artifacts. + audio_neuron_config: Optional NeuronConfig for audio transformer. + If None, creates a default config with TP matching the text model. + """ + if self.audio_encoder is None: + raise RuntimeError("Call enable_audio_encoder() first") + + from modeling_qwen25_omni_audio import ( + AudioEncoderInferenceConfig, + NeuronQwen25OmniForAudioEncoding, + ) + + audio_config = getattr(self.config, "audio_config", None) + if isinstance(audio_config, dict): + from types import SimpleNamespace + audio_config = SimpleNamespace(**audio_config) + + if audio_neuron_config is None: + # Default: match text model TP degree, reasonable seq_len buckets + tp_degree = self.neuron_config.tp_degree + audio_neuron_config = NeuronConfig( + tp_degree=tp_degree, + torch_dtype=self.neuron_config.torch_dtype, + batch_size=1, + # Audio seq_len buckets: typical values after conv + buckets=[256, 512, 1024, 1500], + ) + + audio_inf_config = AudioEncoderInferenceConfig( + neuron_config=audio_neuron_config, + audio_config=vars(audio_config) if hasattr(audio_config, '__dict__') else audio_config, + ) + + audio_app = NeuronQwen25OmniForAudioEncoding(audio_inf_config) + audio_app.compile(compiled_model_path) + + logger.info("Audio encoder transformer compiled to %s", compiled_model_path) + return audio_app + + def load_audio_encoder(self, compiled_model_path, audio_app=None): + """Load compiled audio encoder transformer layers. + + Args: + compiled_model_path: Path to compiled model artifacts. + audio_app: Optional pre-compiled NeuronQwen25OmniForAudioEncoding. + If None, creates and loads from compiled_model_path. 
+ """ + if self.audio_encoder is None: + raise RuntimeError("Call enable_audio_encoder() first") + + if audio_app is None: + from modeling_qwen25_omni_audio import ( + AudioEncoderInferenceConfig, + NeuronQwen25OmniForAudioEncoding, + ) + # Load from compiled artifacts + audio_app = NeuronQwen25OmniForAudioEncoding.load(compiled_model_path) + + self.audio_encoder.transformer = audio_app.model + logger.info("Audio encoder transformer loaded from %s", compiled_model_path) + + def enable_talker(self, state_dict=None, use_neuron=False): + """Initialize the Talker model. + + The Talker converts Thinker hidden states into codec tokens for + speech synthesis. + + Args: + state_dict: Converted state dict containing talker keys. + If None, the Talker is created with random weights. + use_neuron: If True, use the Neuron-compiled Talker + (NeuronQwen25OmniTalkerForCausalLM). Requires separate + compilation with compile_talker() / load_talker(). + If False (default), use the CPU-based HF wrapper. + """ + talker_config = getattr(self.config, "talker_config", None) + if talker_config is None: + logger.warning( + "No talker_config found in model config. " + "Talker will not be initialized." + ) + return + + if use_neuron: + # Neuron Talker — just mark for later compile/load + logger.info( + "Neuron Talker mode enabled. Call compile_talker() or " + "load_talker() to activate." + ) + self._talker_use_neuron = True + self._talker_state_dict = state_dict + # Initialize CPU projection for thinker states + if state_dict is not None: + ProjCls = self._get_thinker_projection_cls() + self.thinker_to_talker_proj = ProjCls.from_state_dict(state_dict) + logger.info("Thinker→Talker CPU projection initialized") + else: + # CPU Talker (default) + TalkerCls = self._get_talker_cls() + dtype = torch.bfloat16 + + if state_dict is not None: + self.talker = TalkerCls.from_pretrained_state_dict( + talker_config, state_dict, dtype=dtype + ) + else: + self.talker = TalkerCls(talker_config, dtype=dtype) + + logger.info("Talker initialized on CPU") + + def compile_talker(self, compiled_model_path, talker_neuron_config=None): + """Compile the Talker on Neuron. + + Creates a NeuronQwen25OmniTalkerForCausalLM, converts the state + dict (fusing embedding + projection), and compiles. + + Recommended neuron_config: + - tp_degree=4 (12 Q heads / 4 = 3 per rank) + - batch_size=1 + - seq_len=4096 (max codec tokens) + + Args: + compiled_model_path: Path to save compiled model + talker_neuron_config: NeuronConfig for Talker compilation. + If None, creates a default config with TP=4. 
+ """ + talker_config = getattr(self.config, "talker_config", None) + if talker_config is None: + raise RuntimeError("No talker_config in model config") + + NeuronTalkerCls = self._get_neuron_talker_cls() + TalkerConfigCls = self._get_talker_config_cls() + + if talker_neuron_config is None: + from modeling_qwen25_omni_talker import ( + TalkerNeuronConfig, + ) + talker_neuron_config = TalkerNeuronConfig( + tp_degree=4, + batch_size=1, + seq_len=4096, + torch_dtype=torch.bfloat16, + ) + + from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config + inference_config = TalkerConfigCls( + neuron_config=talker_neuron_config, + load_config=load_pretrained_config(hf_config=talker_config), + ) + + app = NeuronTalkerCls(self.model_path, config=inference_config) + app.compile(compiled_model_path) + logger.info("Talker compiled to %s", compiled_model_path) + + def load_talker(self, compiled_model_path, talker_app=None): + """Load a compiled Talker from disk. + + Args: + compiled_model_path: Path to compiled model artifacts. + talker_app: Optional pre-compiled NeuronQwen25OmniTalkerForCausalLM. + """ + if talker_app is None: + NeuronTalkerCls = self._get_neuron_talker_cls() + talker_app = NeuronTalkerCls.load(compiled_model_path) + + self.talker = talker_app + logger.info("Neuron Talker loaded from %s", compiled_model_path) + + def enable_token2wav( + self, state_dict=None, speaker_dict_path=None, use_neuron_dit=False + ): + """Initialize the Token2Wav vocoder. + + Token2Wav converts codec tokens from the Talker into audio + waveforms using DiT + BigVGAN. + + Args: + state_dict: Converted state dict containing token2wav keys. + If None, Token2Wav is created with random weights. + speaker_dict_path: Path to spk_dict.pt for speaker conditioning. + If provided, loads the speaker map for audio generation. + use_neuron_dit: If True, use the Neuron-accelerated version + (DiT on Neuron, ODE loop + BigVGAN on CPU). + Call compile_token2wav_dit() / load_token2wav_dit() to compile. + """ + token2wav_config = getattr(self.config, "token2wav_config", None) + if token2wav_config is None: + logger.warning( + "No token2wav_config found in model config. " + "Token2Wav will not be initialized." + ) + return + + if use_neuron_dit: + Token2WavCls = self._get_neuron_token2wav_cls() + else: + Token2WavCls = self._get_token2wav_cls() + + if state_dict is not None: + self.token2wav = Token2WavCls.from_pretrained_state_dict( + token2wav_config, state_dict + ) + else: + self.token2wav = Token2WavCls(token2wav_config) + + if speaker_dict_path is not None: + self.speaker_map = Token2WavCls.load_speaker_dict(speaker_dict_path) + logger.info( + "Loaded %d speakers: %s", + len(self.speaker_map), + list(self.speaker_map.keys()), + ) + + mode = "CPU (Neuron DiT capable)" if use_neuron_dit else "CPU" + logger.info("Token2Wav initialized on %s (float32)", mode) + + def compile_token2wav_dit(self, compiled_path, **kwargs): + """Compile the Token2Wav DiT on Neuron. + + Requires enable_token2wav(use_neuron_dit=True) to be called first. + + Args: + compiled_path: Directory to save compiled DiT + **kwargs: Additional args passed to compile_dit() + (max_codec_len, max_mel_len, batch_size) + """ + if self.token2wav is None: + raise RuntimeError("Call enable_token2wav(use_neuron_dit=True) first") + if not hasattr(self.token2wav, "compile_dit"): + raise RuntimeError( + "Token2Wav is not Neuron DiT capable. 
" + "Call enable_token2wav(use_neuron_dit=True)" + ) + self.token2wav.compile_dit(compiled_path, **kwargs) + logger.info("Token2Wav DiT compiled to %s", compiled_path) + + def load_token2wav_dit(self, compiled_path): + """Load a compiled Token2Wav DiT. + + Args: + compiled_path: Directory containing compiled DiT + """ + if self.token2wav is None: + raise RuntimeError("Call enable_token2wav(use_neuron_dit=True) first") + if not hasattr(self.token2wav, "load_dit"): + raise RuntimeError( + "Token2Wav is not Neuron DiT capable. " + "Call enable_token2wav(use_neuron_dit=True)" + ) + self.token2wav.load_dit(compiled_path) + logger.info("Token2Wav DiT loaded from %s", compiled_path) + + @staticmethod + def load_hf_model(model_path, **kwargs): + from transformers import AutoModelForCausalLM + + return AutoModelForCausalLM.from_pretrained( + model_path, trust_remote_code=True, **kwargs + ) + + @staticmethod + def convert_hf_to_neuron_state_dict( + state_dict: dict, + inference_config: Qwen25OmniMultimodalInferenceConfig, + include_talker: bool = False, + include_token2wav: bool = False, + ) -> dict: + """Convert Qwen2.5-Omni full state dict to NxDI format. + + 1. Remap thinker.* prefixes to Qwen2-VL compatible format + 2. Apply vision encoder conversion (separate Q/K/V, SwiGLU MLP) + 3. Apply audio encoder conversion (strip prefix, cast dtype) + 4. Apply text model conversion (fused QKV, attention key renames) + 5. Optionally strip talker.* and token2wav.* prefixes + + Args: + state_dict: Full HF state dict + inference_config: Multimodal inference config + include_talker: If True, include talker keys (stripped prefix) + include_token2wav: If True, include token2wav keys (stripped prefix) + """ + from modeling_qwen25_omni_audio import ( + NeuronQwen25OmniAudioEncoder, + ) + from modeling_qwen25_omni_vision import ( + NeuronQwen25OmniForImageEncoding, + ) + from neuronx_distributed_inference.models.qwen2_vl.modeling_qwen2_vl_text import ( + NeuronQwen2VLTextForCausalLM, + ) + + # Step 1: Remap thinker.* prefixes, optionally keep talker/token2wav + remapped = {} + talker_state = {} + token2wav_state = {} + for key, value in state_dict.items(): + if key.startswith("thinker.model."): + remapped["model." + key[len("thinker.model."):]] = value + elif key.startswith("thinker.lm_head."): + remapped[key[len("thinker."):]] = value + elif key.startswith("thinker.visual."): + remapped["visual." + key[len("thinker.visual."):]] = value + elif key.startswith("thinker.audio_tower."): + remapped["audio_tower." 
+ key[len("thinker.audio_tower."):]] = value + elif key.startswith("talker.") and include_talker: + talker_state[key[len("talker."):]] = value + elif key.startswith("token2wav.") and include_token2wav: + token2wav_state[key[len("token2wav."):]] = value + + del state_dict + gc.collect() + + logger.info( + "Remapped %d thinker keys%s%s", + len(remapped), + " + %d talker keys" % len(talker_state) if talker_state else "", + " + %d token2wav keys" % len(token2wav_state) if token2wav_state else "", + ) + + # Step 2: Vision encoder conversion + remapped = NeuronQwen25OmniForImageEncoding.convert_hf_to_neuron_state_dict( + remapped, inference_config + ) + + # Step 3: Audio encoder conversion (strip prefix, cast dtype) + audio_dtype = getattr( + inference_config, "torch_dtype", torch.bfloat16 + ) + if hasattr(inference_config, "neuron_config"): + audio_dtype = getattr( + inference_config.neuron_config, "torch_dtype", audio_dtype + ) + remapped = NeuronQwen25OmniAudioEncoder.convert_hf_to_neuron_state_dict( + remapped, dtype=audio_dtype + ) + + # Step 4: Text model conversion + remapped = NeuronQwen2VLTextForCausalLM.convert_hf_to_neuron_state_dict( + remapped, inference_config.text_config + ) + + # Step 5: Merge talker and token2wav state dicts + if talker_state: + for k, v in talker_state.items(): + remapped["talker." + k] = v + logger.info("Included %d talker keys in output", len(talker_state)) + if token2wav_state: + for k, v in token2wav_state.items(): + remapped["token2wav." + k] = v + logger.info("Included %d token2wav keys in output", len(token2wav_state)) + + return remapped + + def get_padding_length(self, input_ids): + """Get the context encoding bucket size for given input_ids.""" + buckets = self.context_encoding_model.config.neuron_config.buckets + for val in buckets: + if val >= input_ids.shape[1]: + return val + raise Exception("No bucket found for provided input_ids!") + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + seq_ids: Optional[torch.LongTensor] = None, + sampling_params: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + vision_mask: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + input_features: Optional[torch.FloatTensor] = None, + feature_attention_mask: Optional[torch.LongTensor] = None, + adapter_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + medusa_args=None, + input_capture_hook: Optional[Callable] = None, + tensor_capture_hook: Optional[Callable] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + from neuronx_distributed_inference.models.llama4.utils.encoder_utils import ( + generate_positions_from_mask, + pad_positions, + ) + + pad_limit = self.get_padding_length(input_ids) + is_context_encoding = input_ids.shape[-1] > 1 + + # --- Audio encoding (CPU frontend + Neuron transformer + CPU postprocessor) --- + audio_embeddings = None + audio_positions = None + if ( + input_features is not None + and self.audio_encoder is not None + and is_context_encoding + ): + audio_token_id = getattr( + self.config, "audio_token_id", None + ) or getattr(self.config, "audio_token_index", 151646) + + with torch.no_grad(): + # Prepare audio features (same as HF get_audio_features) + if feature_attention_mask is not None: + audio_feature_lengths = 
feature_attention_mask.sum(-1) + # (batch, mel_len, n_mels) -> mask -> (n_mels, valid_len) + input_features_flat = input_features.permute(0, 2, 1)[ + feature_attention_mask.bool() + ].permute(1, 0) + else: + input_features_flat = input_features.squeeze(0).permute(1, 0) + audio_feature_lengths = torch.tensor( + [input_features_flat.shape[1]], dtype=torch.long + ) + + aftercnn_lens, audio_output_lens = ( + self.audio_encoder._get_feat_extract_output_lengths( + audio_feature_lengths + ) + ) + audio_embeddings = self.audio_encoder( + input_features_flat, + feature_lens=audio_feature_lengths, + aftercnn_lens=aftercnn_lens, + ) + + # Find audio token positions for scattering + audio_mask_bool = (input_ids == audio_token_id) + if audio_mask_bool.any() and audio_embeddings is not None: + audio_positions = generate_positions_from_mask( + audio_mask_bool.squeeze() + ) + + # --- Vision encoding (Neuron) --- + vision_embeddings = None + vision_positions = None + if ( + (pixel_values is not None) + and is_context_encoding + and pixel_values.sum() != 0 + ): + image_token_id = getattr(self.config, "image_token_id", None) or getattr( + self.config, "image_token_index", 151655 + ) + vision_mask_bool = (input_ids == image_token_id) + if vision_mask_bool.any(): + vision_positions = generate_positions_from_mask( + vision_mask_bool.squeeze() + ) + vision_embeddings = self.vision_encoder_model( + pixel_values.to(self.vision_config.neuron_config.torch_dtype), + image_grid_thw, + ) + + # --- Combine multimodal embeddings for scattering --- + # The text model's encode_vision_to_input scatters embeddings at + # specified positions. We combine audio + vision into one tensor. + if audio_embeddings is not None and vision_embeddings is not None: + # Both audio and vision present + # audio_embeddings is on CPU, vision_embeddings may be on XLA + # The model wrapper handles device transfer, so keep on CPU + all_embeddings = torch.cat([ + vision_embeddings.cpu() if vision_embeddings.is_cuda else vision_embeddings, + audio_embeddings, + ], dim=0) + all_positions = torch.cat([vision_positions, audio_positions]) + vision_embeddings = all_embeddings + vision_mask = pad_positions(all_positions, pad_limit, (pad_limit - 1)) + elif audio_embeddings is not None and audio_positions is not None: + # Audio only, no vision + vision_embeddings = audio_embeddings + vision_mask = pad_positions(audio_positions, pad_limit, (pad_limit - 1)) + elif vision_embeddings is not None and vision_positions is not None: + # Vision only, no audio + vision_mask = pad_positions(vision_positions, pad_limit, (pad_limit - 1)) + else: + # No multimodal input - use dummy embeddings + vision_embeddings, vision_mask = self._get_text_model_wrapper().get_dummy_vision_inputs( + config=self.text_config, + input_ids=input_ids, + n_active_tokens=pad_limit, + fill_value=(pad_limit - 1), + ) + + output_token = super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + seq_ids=seq_ids, + sampling_params=sampling_params, + input_capture_hook=input_capture_hook, + tensor_capture_hook=tensor_capture_hook, + vision_embeddings=vision_embeddings, + vision_mask=vision_mask, + ) + return output_token + + @classmethod + def get_config_cls(cls): + return Qwen25OmniMultimodalInferenceConfig diff --git a/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_audio.py b/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_audio.py new file mode 100644 index 00000000..a58f1f50 --- /dev/null +++ 
b/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_audio.py @@ -0,0 +1,700 @@ +# coding=utf-8 +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Qwen2.5-Omni Audio Encoder for NXD inference. +# +# Whisper-style audio encoder with TP=4 support: +# - Conv1d frontend + sinusoidal positional embeddings (CPU preprocessing) +# - 32 transformer layers with TP-parallel attention and MLP (Neuron) +# - AvgPool1d + LayerNorm + Linear projection (CPU postprocessing) +# +# Architecture: +# d_model=1280, heads=20 (5 per rank at TP=4), head_dim=64 +# ffn=5120 (1280 per rank at TP=4), output_dim=3584 +# Asymmetric attention bias: q/v have bias, k has NO bias +# +# The transformer layers are compiled on Neuron via ModelWrapper. +# Conv1d frontend and postprocessing run on CPU since they involve +# variable-length processing (chunking, per-audio AvgPool). + +"""Qwen2.5-Omni Audio Encoder for NXD inference.""" + +import logging +import math +from types import SimpleNamespace +from typing import List, Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + RowParallelLinear, +) +from neuronx_distributed_inference.models.application_base import NeuronApplicationBase +from neuronx_distributed_inference.models.config import InferenceConfig, NeuronConfig +from neuronx_distributed_inference.models.model_wrapper import ( + EncoderModelInstance, + ModelWrapper, +) +from neuronx_distributed_inference.modules.padding import pad_tensor + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# CPU components (not compiled on Neuron) +# --------------------------------------------------------------------------- + +class SinusoidsPositionEmbedding(nn.Module): + """Sinusoidal positional embeddings (same as Whisper/HF Qwen2.5-Omni).""" + + def __init__(self, length, channels, max_timescale=10000): + super().__init__() + if channels % 2 != 0: + raise ValueError("SinusoidsPositionEmbedding needs even channels") + log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) + inv_timescales = torch.exp( + -log_timescale_increment * torch.arange(channels // 2).float() + ) + scaled_time = ( + torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] + ) + self.register_buffer( + "positional_embedding", + torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1), + persistent=False, + ) + + def forward(self, seqlen: int): + return self.positional_embedding[:seqlen, :] + + +class AudioCPUFrontend(nn.Module): + """Conv1d frontend + positional embeddings + chunking (CPU). + + Processes raw mel spectrograms into token sequences ready for the + Neuron transformer. Also handles audio chunking for long audio. 
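+
+    Example (chunking arithmetic, assuming the default n_window=100): a
+    520-frame mel spectrogram splits into chunks of 200, 200, and 120 frames;
+    after the stride-2 conv2 these become 100, 100, and 60 tokens, matching
+    aftercnn_lens = (520 - 1) // 2 + 1 = 260 total valid tokens.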
+ """ + + def __init__(self, audio_config, dtype=torch.bfloat16): + super().__init__() + if isinstance(audio_config, dict): + audio_config = SimpleNamespace(**audio_config) + + d_model = audio_config.d_model # 1280 + num_mel_bins = audio_config.num_mel_bins # 128 + max_source_positions = audio_config.max_source_positions # 1500 + self.n_window = audio_config.n_window # 100 + self.d_model = d_model + + self.conv1 = nn.Conv1d( + num_mel_bins, d_model, kernel_size=3, padding=1, dtype=dtype + ) + self.conv2 = nn.Conv1d( + d_model, d_model, kernel_size=3, stride=2, padding=1, dtype=dtype + ) + self.positional_embedding = SinusoidsPositionEmbedding( + max_source_positions, d_model + ) + + def _padded_and_mask_function(self, chunk_list, chunk_lengths): + """Pad chunks to same length and create masks.""" + max_len = chunk_lengths.max().item() + dim = chunk_list[0].shape[0] + + padded_tensor = torch.zeros( + len(chunk_list), dim, max_len, + dtype=chunk_list[0].dtype, device=chunk_list[0].device, + ) + batch_mask = torch.zeros( + len(chunk_lengths), max_len, + dtype=torch.long, device=chunk_list[0].device, + ) + for i, (chunk, length) in enumerate(zip(chunk_list, chunk_lengths)): + length = length.item() + batch_mask[i, :length] = 1 + padded_tensor[i, :, :chunk.shape[1]] = chunk + + feature_lens_after_cnn = (chunk_lengths - 1) // 2 + 1 + max_len_after_cnn = feature_lens_after_cnn.max().item() + batch_mask_after_cnn = torch.zeros( + len(chunk_lengths), max_len_after_cnn, + dtype=torch.bool, device=chunk_list[0].device, + ) + for i, length in enumerate(feature_lens_after_cnn): + batch_mask_after_cnn[i, :length] = True + + return padded_tensor, batch_mask.unsqueeze(1), batch_mask_after_cnn + + def forward(self, input_features, feature_lens): + """Process mel spectrogram through conv frontend. 
+ + Args: + input_features: (n_mels, total_mel_len) mel spectrogram + feature_lens: (num_audios,) mel length for each audio + + Returns: + hidden_states: (total_valid_tokens, d_model) token embeddings + aftercnn_lens: (num_audios,) valid token count per audio + cu_seqlens: cumulative sequence lengths for attention masking + """ + aftercnn_lens = (feature_lens - 1) // 2 + 1 + + # Split into chunks of n_window * 2 mel frames + chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long() + chunk_lengths = torch.tensor( + [self.n_window * 2] * chunk_num.sum().item(), + dtype=torch.long, device=feature_lens.device, + ) + tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:] + chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2) + chunk_lengths = torch.where( + chunk_lengths == 0, self.n_window * 2, chunk_lengths + ) + + chunk_list = input_features.split(chunk_lengths.tolist(), dim=1) + padded_feature, padded_mask, padded_mask_after_cnn = ( + self._padded_and_mask_function(chunk_list, chunk_lengths) + ) + + # Conv frontend + padded_embed = F.gelu(self.conv1(padded_feature)) * padded_mask + padded_embed = F.gelu(self.conv2(padded_embed)).transpose(1, 2) + + # Add positional embeddings + padded_embed = padded_embed + self.positional_embedding( + padded_embed.shape[1] + ).unsqueeze(0).to(padded_embed.dtype) + + # Flatten valid tokens + hidden_states = padded_embed[padded_mask_after_cnn] + + # Compute cu_seqlens for attention mask + cu_seqlens = torch.cat([ + torch.zeros(1, device=feature_lens.device, dtype=torch.int32), + padded_mask_after_cnn.sum(1).cumsum(0).to(torch.int32), + ]) + + return hidden_states, aftercnn_lens, cu_seqlens + + +class AudioCPUPostprocessor(nn.Module): + """AvgPool + LayerNorm + projection (CPU). + + Post-processes transformer output into final audio embeddings. + """ + + def __init__(self, audio_config, dtype=torch.bfloat16): + super().__init__() + if isinstance(audio_config, dict): + audio_config = SimpleNamespace(**audio_config) + + d_model = audio_config.d_model # 1280 + output_dim = audio_config.output_dim # 3584 + + self.ln_post = nn.LayerNorm(d_model) # stays float32 + self.avg_pooler = nn.AvgPool1d(2, stride=2) + self.proj = nn.Linear(d_model, output_dim, dtype=dtype) + self.audio_bos_eos_token = nn.Embedding(2, output_dim) + + def forward(self, hidden_states, aftercnn_lens): + """Post-process transformer output. + + Args: + hidden_states: (total_tokens, d_model) transformer output + aftercnn_lens: (num_audios,) token count per audio + + Returns: + audio_embeddings: (total_output_tokens, output_dim) + """ + hidden_states_list = hidden_states.split(aftercnn_lens.tolist(), dim=0) + token_audio_list = [] + for each_audio_states in hidden_states_list: + each_audio_states = self.avg_pooler( + each_audio_states.transpose(0, 1) + ).transpose_(0, 1) + each_audio_states = self.ln_post(each_audio_states.float()).to( + each_audio_states.dtype + ) + each_audio_states = self.proj(each_audio_states) + token_audio_list.append(each_audio_states) + + return torch.cat(token_audio_list, dim=0) + + +# --------------------------------------------------------------------------- +# Neuron-compiled transformer components (TP=4) +# --------------------------------------------------------------------------- + +class NeuronAudioAttention(nn.Module): + """TP-parallel self-attention for audio encoder. + + Asymmetric bias: q_proj and v_proj have bias, k_proj has NO bias. + Uses ColumnParallelLinear for Q/K/V and RowParallelLinear for output. 
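+
+    Shape sketch at TP=4 (d_model=1280, 20 heads, head_dim=64): each rank
+    holds 5 heads, so q/k/v project (bsz, seq, 1280) -> (bsz, seq, 320) per
+    rank, attention runs over (bsz, 5, seq, 64) per rank, and out_proj
+    all-reduces the partial 320-dim outputs back to the 1280-dim residual.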
+ """ + + def __init__(self, d_model, num_heads, tp_degree, dtype=torch.bfloat16): + super().__init__() + self.d_model = d_model + self.num_heads = num_heads + self.head_dim = d_model // num_heads + self.num_heads_per_rank = num_heads // tp_degree + self.scaling = self.head_dim ** -0.5 + + self.q_proj = ColumnParallelLinear( + d_model, d_model, bias=True, gather_output=False, dtype=dtype, + ) + self.k_proj = ColumnParallelLinear( + d_model, d_model, bias=False, gather_output=False, dtype=dtype, + ) + self.v_proj = ColumnParallelLinear( + d_model, d_model, bias=True, gather_output=False, dtype=dtype, + ) + self.out_proj = RowParallelLinear( + d_model, d_model, bias=True, input_is_parallel=True, dtype=dtype, + ) + + def forward(self, hidden_states, attention_mask=None): + """ + Args: + hidden_states: (batch, seq_len, d_model) + attention_mask: (batch, 1, seq_len, seq_len) with 0 for valid, -inf for masked + """ + bsz, seq_len, _ = hidden_states.shape + + # Project Q/K/V (ColumnParallelLinear outputs d_model/tp per rank) + q = self.q_proj(hidden_states) # (bsz, seq, d_model/tp) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + # Reshape to (bsz, num_heads_per_rank, seq, head_dim) + q = q.view(bsz, seq_len, self.num_heads_per_rank, self.head_dim).transpose(1, 2) + k = k.view(bsz, seq_len, self.num_heads_per_rank, self.head_dim).transpose(1, 2) + v = v.view(bsz, seq_len, self.num_heads_per_rank, self.head_dim).transpose(1, 2) + + # Attention scores + scores = torch.matmul(q, k.transpose(-2, -1)) * self.scaling + if attention_mask is not None: + scores = scores + attention_mask + + attn_weights = F.softmax(scores.float(), dim=-1).to(q.dtype) + attn_output = torch.matmul(attn_weights, v) + + # Reshape and project output (RowParallelLinear allreduces across ranks) + attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seq_len, -1) + attn_output = self.out_proj(attn_output) + return attn_output + + +class NeuronAudioEncoderLayer(nn.Module): + """Audio encoder transformer layer (TP-parallel). + + Pre-norm: LayerNorm -> attention -> residual -> LayerNorm -> MLP -> residual. + Uses GELU activation (not SwiGLU). LayerNorm operates in float32. + """ + + def __init__(self, d_model, num_heads, ffn_dim, tp_degree, dtype=torch.bfloat16): + super().__init__() + self.self_attn = NeuronAudioAttention(d_model, num_heads, tp_degree, dtype) + self.self_attn_layer_norm = nn.LayerNorm(d_model) + self.fc1 = ColumnParallelLinear( + d_model, ffn_dim, bias=True, gather_output=False, dtype=dtype, + ) + self.fc2 = RowParallelLinear( + ffn_dim, d_model, bias=True, input_is_parallel=True, dtype=dtype, + ) + self.final_layer_norm = nn.LayerNorm(d_model) + + def forward(self, hidden_states, attention_mask=None): + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states.float()).to(residual.dtype) + hidden_states = self.self_attn(hidden_states, attention_mask) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states.float()).to(residual.dtype) + hidden_states = F.gelu(self.fc1(hidden_states)) + hidden_states = self.fc2(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class NeuronAudioTransformerModel(nn.Module): + """Audio encoder transformer (32 layers, compiled on Neuron). + + Takes pre-processed hidden states (after conv + positional embedding) + and returns transformer output. The attention mask handles block-diagonal + masking for chunked audio. 
+ + Input: (1, padded_seq_len, d_model) + (1, 1, padded_seq_len, padded_seq_len) + Output: (1, padded_seq_len, d_model) + """ + + def __init__(self, config: InferenceConfig): + super().__init__() + audio_config = config.audio_config + if isinstance(audio_config, dict): + audio_config = SimpleNamespace(**audio_config) + + tp_degree = config.neuron_config.tp_degree + dtype = config.neuron_config.torch_dtype + + d_model = audio_config.d_model + num_heads = audio_config.encoder_attention_heads + ffn_dim = audio_config.encoder_ffn_dim + num_layers = audio_config.encoder_layers + + self.layers = nn.ModuleList([ + NeuronAudioEncoderLayer(d_model, num_heads, ffn_dim, tp_degree, dtype) + for _ in range(num_layers) + ]) + + def forward(self, hidden_states, attention_mask): + for layer in self.layers: + hidden_states = layer(hidden_states, attention_mask) + return hidden_states + + +# --------------------------------------------------------------------------- +# ModelWrapper and Application classes +# --------------------------------------------------------------------------- + +class AudioTransformerModelWrapper(ModelWrapper): + """Handles bucketing on sequence length, padding, and compilation.""" + + def __init__(self, config, model_cls, tag="", compiler_args=None, + priority_model_idx=None, pipeline_execution=True, + return_ranked_to_cpu=False, model_init_kwargs={}): + super().__init__( + config, model_cls, tag, compiler_args, priority_model_idx, + pipeline_execution, return_ranked_to_cpu, model_init_kwargs, + ) + + def input_generator(self) -> List[Tuple[torch.Tensor]]: + """Generate sample inputs for each sequence length bucket.""" + inputs = [] + dtype = self.config.neuron_config.torch_dtype + d_model = self.config.audio_config.d_model + if isinstance(d_model, dict): + d_model = d_model.get("d_model", 1280) + + for bucket in self.config.neuron_config.buckets: + hidden_states = torch.ones([1, bucket, d_model], dtype=dtype) + attention_mask = torch.zeros( + [1, 1, bucket, bucket], dtype=dtype, + ) + inputs.append((hidden_states, attention_mask)) + return inputs + + def get_model_instance(self): + return EncoderModelInstance(model_cls=self.model_cls, config=self.config) + + def get_target_bucket(self, seq_len): + """Find the smallest bucket that fits the sequence length.""" + for bucket in self.config.neuron_config.buckets: + if bucket >= seq_len: + return bucket + raise ValueError( + f"No bucket found for seq_len={seq_len}. " + f"Buckets: {self.config.neuron_config.buckets}" + ) + + def forward(self, hidden_states, attention_mask): + """Pad to bucket size and run compiled model.""" + if self.model is None: + raise RuntimeError("Forward called before load.") + + seq_len = hidden_states.shape[1] + bucket = self.get_target_bucket(seq_len) + + # Pad sequence to bucket size + if seq_len < bucket: + pad_len = bucket - seq_len + hidden_states = F.pad(hidden_states, (0, 0, 0, pad_len)) + # Extend attention mask: new positions are masked out + dtype = attention_mask.dtype + mask_pad = torch.full( + (1, 1, bucket, bucket), + torch.finfo(dtype).min, + dtype=dtype, + ) + mask_pad[:, :, :seq_len, :seq_len] = attention_mask + attention_mask = mask_pad + + output = self._forward(hidden_states, attention_mask) + # Trim back to original length + return output[:, :seq_len, :] + + +class NeuronQwen25OmniForAudioEncoding(NeuronApplicationBase): + """Neuron application for audio encoder transformer layers. + + Handles compilation, loading, and inference of the 32 transformer layers. 
+ The conv frontend and postprocessing are handled separately on CPU. + """ + + _model_cls = NeuronAudioTransformerModel + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.model = AudioTransformerModelWrapper( + config=self.config, + model_cls=self._model_cls, + tag=self._model_cls.__name__, + compiler_args=self.get_compiler_args(), + priority_model_idx=0, + ) + self.models.append(self.model) + + def forward(self, hidden_states, attention_mask): + return self.models[0](hidden_states, attention_mask) + + def get_compiler_args(self): + return ( + "--auto-cast=none --model-type=transformer " + "--tensorizer-options='--enable-ccop-compute-overlap " + "--cc-pipeline-tiling-factor=2 ' -O1 " + "--internal-hlo2tensorizer-options='--verify-hlo=true'" + ) + + @staticmethod + def update_state_dict_for_tied_weights(state_dict): + pass + + @staticmethod + def load_hf_model(model_path, **kwargs): + from transformers import AutoModelForCausalLM + return AutoModelForCausalLM.from_pretrained( + model_path, trust_remote_code=True, **kwargs + ) + + @staticmethod + def convert_hf_to_neuron_state_dict(state_dict, inference_config): + """Convert HF state dict to Neuron format for audio transformer layers. + + Extracts audio_tower.layers.* keys and strips prefix. + Only returns transformer layer keys (not conv/postprocessing). + """ + new_state_dict = {} + for key, value in state_dict.items(): + if key.startswith("thinker.audio_tower.layers."): + new_key = key[len("thinker.audio_tower."):] + elif key.startswith("audio_tower.layers."): + new_key = key[len("audio_tower."):] + else: + new_state_dict[key] = value + continue + new_state_dict[new_key] = value + return new_state_dict + + @classmethod + def get_config_cls(cls): + return AudioEncoderInferenceConfig + + +class AudioEncoderInferenceConfig(InferenceConfig): + """Config for audio encoder transformer compilation.""" + + def __init__(self, neuron_config, audio_config, **kwargs): + # Store audio_config before calling super + self.audio_config = audio_config + if isinstance(audio_config, dict): + self.audio_config = SimpleNamespace(**audio_config) + super().__init__(neuron_config=neuron_config, **kwargs) + + def add_derived_config(self): + pass + + def get_required_attributes(self): + return [] + + +# --------------------------------------------------------------------------- +# Full Audio Encoder (combines CPU frontend + Neuron transformer + CPU postprocessing) +# --------------------------------------------------------------------------- + +class NeuronQwen25OmniAudioEncoder(nn.Module): + """Qwen2.5-Omni Audio Encoder with Neuron acceleration. + + Architecture: + CPU: mel → Conv1d frontend → positional embeddings → chunking + Neuron (TP=4): 32 transformer layers with block-diagonal attention + CPU: AvgPool → LayerNorm → Linear projection → audio embeddings + + Usage: + 1. Create with audio_config + 2. Call compile_transformer() to compile on Neuron + 3. Call load_transformer() to load compiled model + 4. 
Call forward() to encode audio + """ + + def __init__(self, audio_config, neuron_config=None, dtype=torch.bfloat16): + super().__init__() + if isinstance(audio_config, dict): + audio_config = SimpleNamespace(**audio_config) + self.audio_config = audio_config + self.dtype = dtype + self.n_window = audio_config.n_window + + # CPU components + self.frontend = AudioCPUFrontend(audio_config, dtype=dtype) + self.postprocessor = AudioCPUPostprocessor(audio_config, dtype=dtype) + + # Neuron transformer (initialized via compile/load cycle) + self.transformer = None + self._neuron_config = neuron_config + + def _get_feat_extract_output_lengths(self, input_lengths): + """Compute output lengths after conv and avg_pool.""" + input_lengths = (input_lengths - 1) // 2 + 1 + output_lengths = (input_lengths - 2) // 2 + 1 + return input_lengths, output_lengths + + def _prepare_attention_mask(self, seq_length, cu_seqlens, dtype): + """Create block-diagonal attention mask from cu_seqlens.""" + attention_mask = torch.full( + [1, 1, seq_length, seq_length], + torch.finfo(dtype).min, + dtype=dtype, + ) + for i in range(1, len(cu_seqlens)): + s, e = cu_seqlens[i - 1], cu_seqlens[i] + attention_mask[..., s:e, s:e] = 0 + return attention_mask + + def forward(self, input_features, feature_lens, aftercnn_lens=None): + """Process mel spectrogram through audio encoder. + + Args: + input_features: (n_mels, total_mel_len) mel spectrogram + feature_lens: (num_audios,) mel length for each audio + aftercnn_lens: optional pre-computed lengths after conv + + Returns: + audio_embeddings: (total_audio_tokens, output_dim) tensor + """ + if aftercnn_lens is None: + aftercnn_lens, _ = self._get_feat_extract_output_lengths(feature_lens) + + # CPU: Conv frontend + chunking + hidden_states, aftercnn_lens_actual, cu_seqlens = self.frontend( + input_features, feature_lens + ) + + if self.transformer is not None: + # Neuron: transformer layers + seq_len = hidden_states.shape[0] + attention_mask = self._prepare_attention_mask( + seq_len, cu_seqlens, self.dtype + ) + # Add batch dimension for Neuron model + hidden_states = hidden_states.unsqueeze(0) + hidden_states = self.transformer(hidden_states, attention_mask) + hidden_states = hidden_states.squeeze(0) + else: + # Fallback: CPU transformer (for testing without Neuron) + logger.warning( + "Audio transformer not compiled. Running transformer on CPU " + "(this is slow and should only be used for testing)." + ) + attention_mask = self._prepare_attention_mask( + hidden_states.shape[0], cu_seqlens, self.dtype + ) + # CPU fallback would require loading transformer weights separately + raise RuntimeError( + "Audio transformer must be compiled and loaded before inference. " + "Call compile_transformer() and load_transformer() first." + ) + + # CPU: Postprocessing + return self.postprocessor(hidden_states, aftercnn_lens_actual) + + @staticmethod + def convert_hf_to_neuron_state_dict(state_dict, dtype=torch.bfloat16): + """Convert HF state dict to audio encoder format. + + Splits keys into three groups: + - frontend.*: conv1, conv2, positional_embedding (CPU) + - layers.*: transformer layers (Neuron) + - postprocessor.*: ln_post, avg_pooler, proj, audio_bos_eos_token (CPU) + + Returns dict with prefixed keys for the split architecture. 
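+
+        Example routing (illustrative keys following the prefixes handled below):
+            "thinker.audio_tower.conv1.weight"        -> "frontend.conv1.weight"
+            "thinker.audio_tower.layers.0.fc1.weight" -> "transformer.layers.0.fc1.weight"
+            "thinker.audio_tower.ln_post.weight"      -> "postprocessor.ln_post.weight" (kept float32)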
+ """ + new_state_dict = {} + + # LayerNorm keys that should remain float32 + ln_suffixes = ( + "self_attn_layer_norm.weight", "self_attn_layer_norm.bias", + "final_layer_norm.weight", "final_layer_norm.bias", + "ln_post.weight", "ln_post.bias", + ) + + # CPU frontend keys + frontend_prefixes = ("conv1.", "conv2.", "positional_embedding.") + # CPU postprocessor keys + postprocessor_prefixes = ("ln_post.", "proj.", "avg_pooler.", "audio_bos_eos_token.") + + for key, value in state_dict.items(): + # Strip audio_tower prefix + if key.startswith("thinker.audio_tower."): + clean_key = key[len("thinker.audio_tower."):] + elif key.startswith("audio_tower."): + clean_key = key[len("audio_tower."):] + else: + new_state_dict[key] = value + continue + + # Determine target dtype + if any(clean_key.endswith(s) for s in ln_suffixes): + target_dtype = torch.float32 + else: + target_dtype = dtype + + # Route to correct sub-module + if any(clean_key.startswith(p) for p in frontend_prefixes): + new_state_dict["frontend." + clean_key] = ( + value.clone().detach().contiguous().to(target_dtype) + ) + elif any(clean_key.startswith(p) for p in postprocessor_prefixes): + new_state_dict["postprocessor." + clean_key] = ( + value.clone().detach().contiguous().to(target_dtype) + ) + elif clean_key.startswith("layers."): + # Transformer layers (will be loaded into Neuron model) + new_state_dict["transformer." + clean_key] = ( + value.clone().detach().contiguous().to(target_dtype) + ) + else: + logger.warning("Unknown audio key: %s", clean_key) + + return new_state_dict + + @staticmethod + def from_pretrained_state_dict(audio_config, state_dict, dtype=torch.bfloat16): + """Create audio encoder and load CPU weights from converted state dict. + + Note: Transformer weights need to be loaded separately via + compile_transformer() + load_transformer() for Neuron execution. + """ + encoder = NeuronQwen25OmniAudioEncoder(audio_config, dtype=dtype) + + # Load only frontend and postprocessor weights + cpu_keys = {} + for key, value in state_dict.items(): + if key.startswith("frontend.") or key.startswith("postprocessor."): + cpu_keys[key] = value + + if cpu_keys: + missing, unexpected = encoder.load_state_dict(cpu_keys, strict=False) + # Filter out transformer keys from missing (expected) + missing = [k for k in missing if not k.startswith("transformer.")] + if missing: + logger.warning("Audio encoder CPU missing keys: %s", missing[:10]) + logger.info("Loaded %d CPU weights into audio encoder", len(cpu_keys)) + + return encoder diff --git a/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_talker.py b/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_talker.py new file mode 100644 index 00000000..32e42b8a --- /dev/null +++ b/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_talker.py @@ -0,0 +1,1057 @@ +# coding=utf-8 +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Qwen2.5-Omni Talker for NXD inference. +# +# This file contains TWO implementations: +# +# 1. NeuronQwen25OmniTalker (CPU wrapper) +# - Wraps HF's Qwen2_5OmniTalkerForConditionalGeneration +# - Runs on CPU — suitable for quick testing or when Neuron resources +# are reserved for the 7B Thinker +# +# 2. 
NeuronQwen25OmniTalkerForCausalLM (Neuron-compiled) +# - Compiles the 24-layer transformer on Neuron with KV cache +# - Uses fused embedding (8448→3584→896 collapsed to 8448→896) +# - Supports mRoPE (3D position_ids) and explicit head_dim=128 +# - Recommended TP=4 (3 Q heads/rank, 1 KV head/rank) +# - Uses ImageToTextModelWrapper for thinker state injection during +# context encoding (projected thinker states passed as vision_embeddings) +# +# Talker Architecture: +# - embed_tokens: Embedding(8448, 3584) — codec vocab in Thinker's dim space +# - thinker_to_talker_proj: Linear(3584 -> 896) +# - 24 Qwen2 decoder layers (GQA: 12 heads, 4 kv_heads, head_dim=128) +# - MLP: SiLU gate_proj/up_proj(896->18944), down_proj(18944->896) +# - RMSNorm(896) +# - codec_head: Linear(896 -> 8448, no bias) +# +# Total Parameters: ~690M + +"""Qwen2.5-Omni Talker model for NXD inference.""" + +import gc +import logging +from typing import List, Optional, Tuple, Type + +import torch +from torch import nn + +logger = logging.getLogger(__name__) + + +# ============================================================================= +# Part 1: CPU-based Talker (HF wrapper) +# ============================================================================= + + +class NeuronQwen25OmniTalker: + """Wrapper around HF's Qwen2_5OmniTalkerForConditionalGeneration. + + The Talker takes Thinker hidden states + codec embeddings as input, + projects them from embedding_size (3584) to hidden_size (896), runs + through 24 Qwen2 decoder layers, and outputs codec tokens via a + codec_head linear layer. + + This wrapper: + 1. Instantiates the HF Talker from config + 2. Loads weights from converted state dict + 3. Exposes generation API for codec token synthesis + """ + + def __init__(self, talker_config, dtype=torch.bfloat16): + """Initialize the Talker. + + Args: + talker_config: Talker config (dict or HF config object). + Must contain: vocab_size, embedding_size, hidden_size, + num_hidden_layers, num_attention_heads, num_key_value_heads, + intermediate_size, etc. + dtype: Model dtype (default bfloat16). 
+ """ + from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import ( + Qwen2_5OmniTalkerConfig, + ) + from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import ( + Qwen2_5OmniTalkerForConditionalGeneration, + ) + + if isinstance(talker_config, dict): + talker_config = Qwen2_5OmniTalkerConfig(**talker_config) + + self.model = Qwen2_5OmniTalkerForConditionalGeneration(talker_config) + self.model.to(dtype) + self.model.eval() + self.config = talker_config + self.dtype = dtype + + # Expose codec token IDs for orchestration + self.codec_bos_token = talker_config.tts_codec_start_token_id + self.codec_eos_token = talker_config.tts_codec_end_token_id + self.codec_pad_token = talker_config.tts_codec_pad_token_id + self.codec_mask_token = talker_config.tts_codec_mask_token_id + self.text_bos_token = talker_config.tts_text_start_token_id + self.text_eos_token = talker_config.tts_text_end_token_id + self.text_pad_token = talker_config.tts_text_pad_token_id + + def load_state_dict(self, state_dict, strict=True): + """Load converted state dict into the HF Talker model.""" + return self.model.load_state_dict(state_dict, strict=strict) + + @torch.no_grad() + def generate( + self, + input_ids, + input_text_ids, + thinker_reply_part, + inputs_embeds, + attention_mask=None, + max_new_tokens=4096, + do_sample=True, + top_k=40, + top_p=0.8, + temperature=0.9, + eos_token_id=None, + repetition_penalty=1.05, + suppress_tokens=None, + **kwargs, + ): + """Generate codec tokens from Thinker hidden states. + + Args: + input_ids: (batch, seq_len) codec input IDs + input_text_ids: (batch, seq_len) text input IDs (for position calc) + thinker_reply_part: (batch, reply_len, 3584) Thinker hidden states + inputs_embeds: (batch, seq_len, 3584) input embeddings + attention_mask: (batch, seq_len) optional attention mask + max_new_tokens: Maximum codec tokens to generate + do_sample: Whether to sample (vs greedy) + top_k: Top-k sampling + top_p: Nucleus sampling probability + temperature: Sampling temperature + eos_token_id: EOS token(s) for stopping + repetition_penalty: Repetition penalty + suppress_tokens: Token IDs to suppress during generation + **kwargs: Additional generation kwargs + + Returns: + (batch, total_len) generated codec token IDs + """ + if eos_token_id is None: + eos_token_id = [self.codec_eos_token, self.codec_pad_token] + if suppress_tokens is None: + suppress_tokens = [self.codec_bos_token] + + return self.model.generate( + input_ids=input_ids, + input_text_ids=input_text_ids, + thinker_reply_part=thinker_reply_part, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + max_new_tokens=max_new_tokens, + do_sample=do_sample, + top_k=top_k, + top_p=top_p, + temperature=temperature, + eos_token_id=eos_token_id, + repetition_penalty=repetition_penalty, + suppress_tokens=suppress_tokens, + **kwargs, + ) + + @staticmethod + def convert_hf_to_neuron_state_dict(state_dict: dict) -> dict: + """Convert HF state dict to Talker format. + + Strips 'talker.' prefix from keys. Non-talker keys are passed through. + + Args: + state_dict: Full or partial state dict with talker.* keys. + + Returns: + State dict with talker prefix stripped. 
+ """ + new_state_dict = {} + for key, value in state_dict.items(): + if key.startswith("talker."): + new_state_dict[key[len("talker."):]] = value + else: + # Pass through non-talker keys + new_state_dict[key] = value + return new_state_dict + + @classmethod + def from_pretrained_state_dict(cls, talker_config, state_dict, dtype=torch.bfloat16): + """Create Talker and load weights from converted state dict. + + Args: + talker_config: Talker config (dict or HF config object) + state_dict: Already-converted state dict (talker keys only) + dtype: Target dtype + + Returns: + Initialized NeuronQwen25OmniTalker + """ + talker = cls(talker_config, dtype=dtype) + + # Filter to only talker keys (skip non-talker prefixes) + talker_keys = {} + for key, value in state_dict.items(): + if any( + key.startswith(p) + for p in [ + "lm_head.", "visual.", "audio_tower.", + "thinker.", "token2wav.", "talker.", + ] + ): + continue + talker_keys[key] = value + + missing, unexpected = talker.load_state_dict(talker_keys, strict=False) + if missing: + logger.warning("Talker missing keys: %s", missing[:10]) + if unexpected: + logger.warning("Talker unexpected keys: %s", unexpected[:10]) + logger.info("Loaded %d weights into Talker", len(talker_keys)) + + return talker + + +# ============================================================================= +# Part 2: Neuron-compiled Talker +# ============================================================================= + +from neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + ParallelEmbedding, +) +from neuronx_distributed.utils import cpu_mode +from transformers.models.llama.modeling_llama import LlamaRMSNorm + +from neuronx_distributed_inference.models.config import InferenceConfig, NeuronConfig +from neuronx_distributed_inference.models.llama.modeling_llama import NeuronLlamaMLP +from neuronx_distributed_inference.models.model_base import ( + NeuronBaseForCausalLM, + NeuronBaseModel, +) +from neuronx_distributed_inference.modules.attention.attention_base import NeuronAttentionBase +from neuronx_distributed_inference.modules.attention.utils import _rotate_half +from neuronx_distributed_inference.modules.custom_calls import CustomRMSNorm + + +def _get_rmsnorm_cls(): + return LlamaRMSNorm if cpu_mode() else CustomRMSNorm + + +def _apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): + """Apply multimodal RoPE (reused from Qwen2-VL).""" + mrope_section = mrope_section * 2 + split_indices = [sum(mrope_section[:i + 1]) for i in range(len(mrope_section) - 1)] + cos = torch.cat( + [m[i % 3] for i, m in enumerate(torch.tensor_split(cos, split_indices, dim=-1))], + dim=-1, + ).unsqueeze(unsqueeze_dim) + sin = torch.cat( + [m[i % 3] for i, m in enumerate(torch.tensor_split(sin, split_indices, dim=-1))], + dim=-1, + ).unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (_rotate_half(q) * sin) + k_embed = (k * cos) + (_rotate_half(k) * sin) + return q_embed, k_embed + + +def _apply_standard_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): + """Apply standard 1D RoPE.""" + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (_rotate_half(q) * sin) + k_embed = (k * cos) + (_rotate_half(k) * sin) + return q_embed, k_embed + + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- + +class TalkerNeuronConfig(NeuronConfig): + """NeuronConfig subclass for Talker. 
+ + Sets the default attention class to NeuronTalkerAttention. + Recommended TP=4 for the Talker (12 Q heads / 4 = 3 per rank). + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.attn_cls = NeuronTalkerAttention + + +class TalkerInferenceConfig(InferenceConfig): + """InferenceConfig for the Neuron-compiled Talker. + + Talker-specific attributes: + - head_dim = 128 (explicit, NOT hidden_size // num_attention_heads) + - qkv_bias = True, o_bias = False (Qwen2 pattern) + - rope_scaling with mrope_section for 3D mRoPE + - thinker_hidden_size = 3584 (for projection during context encoding) + """ + + def add_derived_config(self): + self.num_cores_per_group = 1 + self.qkv_bias = True + self.o_bias = False + # Head dim is EXPLICIT for the Talker + # hidden_size=896, num_attention_heads=12 → 896/12=74.67 (fractional!) + # Actual head_dim=128, so attention internal dim = 12×128 = 1536 + if not hasattr(self, "head_dim") or self.head_dim is None: + self.head_dim = 128 + # mRoPE config (default matching Qwen2.5-Omni) + if not hasattr(self, "rope_scaling") or self.rope_scaling is None: + self.rope_scaling = {"type": "mrope", "mrope_section": [16, 24, 24]} + # Store thinker hidden size for projection + if not hasattr(self, "thinker_hidden_size"): + self.thinker_hidden_size = 3584 + + def get_required_attributes(self) -> List[str]: + return [ + "hidden_size", # 896 + "num_attention_heads", # 12 + "num_hidden_layers", # 24 + "num_key_value_heads", # 4 + "vocab_size", # 8448 + "rope_theta", + "rms_norm_eps", + "hidden_act", # silu + "intermediate_size", # 18944 + ] + + @classmethod + def get_neuron_config_cls(cls) -> Type[TalkerNeuronConfig]: + return TalkerNeuronConfig + + +# --------------------------------------------------------------------------- +# RoPE +# --------------------------------------------------------------------------- + +class TalkerRotaryEmbedding(nn.Module): + """Rotary position embedding for the Talker. + + Uses head_dim (128) as the RoPE dimension, NOT hidden_size // num_heads. + Supports both standard 1D and multimodal 3D position_ids. 
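+
+    With dim=128 there are 64 frequency pairs per position component; 2D
+    (batch, seq) position_ids are broadcast to (3, batch, seq), and the
+    resulting per-component cos/sin are later recombined in the attention
+    layer according to mrope_section (default [16, 24, 24]).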
+ """ + + def __init__(self, config: TalkerInferenceConfig, device=None): + super().__init__() + self.dim = config.head_dim # 128 + self.base = getattr(config, "rope_theta", 1000000.0) + self.attention_scaling = 1.0 + self.register_buffer("inv_freq", None, persistent=False) + self.inv_freq = self._compute_inv_freq(device) + + def _compute_inv_freq(self, device=None): + freq_indices = torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) + return 1.0 / (self.base ** (freq_indices / self.dim)) + + def forward(self, x, position_ids): + if position_ids.ndim == 2: + # Expand 2D (batch, seq) → 3D (3, batch, seq) for mRoPE + # Same approach as Qwen3-VL: replicate across temporal/height/width + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + # 3D mRoPE: position_ids shape (3, batch, seq) + inv_freq_expanded = self.inv_freq[None, None, :, None].expand( + 3, position_ids.shape[1], -1, 1 + ) + position_ids_expanded = position_ids[:, :, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(-2, -1) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# --------------------------------------------------------------------------- +# Attention +# --------------------------------------------------------------------------- + +class NeuronTalkerAttention(NeuronAttentionBase): + """Talker self-attention with explicit head_dim=128 and mRoPE. + + Key difference from standard Qwen2 attention: + - head_dim=128 ≠ hidden_size(896) / num_attention_heads(12) + - Internal attention dimension = 12 × 128 = 1536 ≠ hidden_size + - Q projection: (896, 1536), K/V projection: (896, 512) + - Output projection: (1536, 896) + """ + + def __init__(self, config: TalkerInferenceConfig): + rotary_emb = TalkerRotaryEmbedding(config) + super().__init__( + config=config, + hidden_size=config.hidden_size, # 896 + num_attention_heads=config.num_attention_heads, # 12 + num_key_value_heads=config.num_key_value_heads, # 4 + head_dim=config.head_dim, # 128 + qkv_bias=config.qkv_bias, + o_bias=config.o_bias, + rotary_emb=rotary_emb, + ) + self.rope_scaling = getattr(config, "rope_scaling", None) + self.use_mrope = ( + self.rope_scaling is not None + and "mrope_section" in self.rope_scaling + ) + if self.use_mrope: + self.mrope_section = self.rope_scaling["mrope_section"] + + def apply_rotary_embedding(self, Q, K, V, position_ids, cos_cache, sin_cache, use_polar_compatible_rope): + if self.rotary_emb is not None: + if cos_cache is None or sin_cache is None: + cos_cache, sin_cache = self.rotary_emb(V, position_ids) + if self.use_mrope: + Q, K = _apply_multimodal_rotary_pos_emb( + Q, K, cos_cache, sin_cache, self.mrope_section + ) + else: + Q, K = _apply_standard_rotary_pos_emb(Q, K, cos_cache, sin_cache) + return Q, K, cos_cache, sin_cache + + +# --------------------------------------------------------------------------- +# Decoder Layer +# --------------------------------------------------------------------------- + +class NeuronTalkerDecoderLayer(nn.Module): + """Talker decoder layer: pre-norm attention + SwiGLU MLP.""" + + def __init__(self, config: TalkerInferenceConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = 
NeuronTalkerAttention(config) + self.mlp = NeuronLlamaMLP(config) + self.input_layernorm = _get_rmsnorm_cls()( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = _get_rmsnorm_cls()( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + adapter_ids=None, + **kwargs, + ) -> Tuple[torch.FloatTensor, ...]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, present_key_value, cos_cache, sin_cache = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + adapter_ids=adapter_ids, + **kwargs, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states, adapter_ids=adapter_ids)[0] + hidden_states = residual + hidden_states + + return (hidden_states, present_key_value, cos_cache, sin_cache, None) + + +# --------------------------------------------------------------------------- +# Model (NeuronBaseModel) +# --------------------------------------------------------------------------- + +class NeuronTalkerModel(NeuronBaseModel): + """Neuron-compiled Talker transformer. + + Uses fused embedding: original embed_tokens(8448, 3584) + proj(3584, 896) + are collapsed into embed_tokens(8448, 896) during state dict conversion. + + For context encoding with thinker hidden states, the projected states + (batch, reply_len, 896) are passed as vision_embeddings and substituted + via encode_vision_to_input(). Unlike the upstream NeuronBaseModel which + only calls encode_vision_to_input during context encoding, this subclass + overrides get_model_output to also inject during token generation (for + per-step thinker state injection), matching HF's Qwen2.5-Omni behavior. + """ + + def get_model_output( + self, + input_ids=None, + inputs_embeds=None, + vision_embeddings=None, + vision_mask=None, + is_for_context_encoding: bool = False, + adapter_ids=None, + **kwargs, + ): + """Override to inject thinker states during token generation. + + Upstream gates encode_vision_to_input behind is_for_context_encoding, + which is correct for Qwen2-VL/Qwen3-VL but wrong for Talker because + Talker needs per-step thinker state injection during autoregressive + decode. We pre-inject here, then pass inputs_embeds to super() with + vision_embeddings=None/vision_mask=None so the parent doesn't try + to inject again. + """ + if ( + not is_for_context_encoding + and vision_embeddings is not None + and vision_mask is not None + ): + # Token-generation phase: compute inputs_embeds ourselves, inject, + # then let super() skip its vision block. 
+ if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + if vision_embeddings.dtype != self.config.neuron_config.torch_dtype: + vision_embeddings = vision_embeddings.to( + self.config.neuron_config.torch_dtype + ) + inputs_embeds = self.encode_vision_to_input( + inputs_embeds, vision_embeddings, vision_mask + ) + vision_embeddings = None + vision_mask = None + + return super().get_model_output( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + vision_embeddings=vision_embeddings, + vision_mask=vision_mask, + is_for_context_encoding=is_for_context_encoding, + adapter_ids=adapter_ids, + **kwargs, + ) + + def setup_attr_for_model(self, config: TalkerInferenceConfig): + self.on_device_sampling = ( + config.neuron_config.on_device_sampling_config is not None + ) + self.tp_degree = config.neuron_config.tp_degree + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.max_batch_size = config.neuron_config.max_batch_size + self.buckets = config.neuron_config.buckets + + def init_model(self, config: TalkerInferenceConfig): + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + # Fused embedding: 8448 → 896 (original 8448→3584→896 collapsed) + self.embed_tokens = ParallelEmbedding( + config.vocab_size, # 8448 + config.hidden_size, # 896 + self.padding_idx, + dtype=config.neuron_config.torch_dtype, + shard_across_embedding=True, + pad=True, + ) + self.layers = nn.ModuleList( + [NeuronTalkerDecoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + self.norm = _get_rmsnorm_cls()(config.hidden_size, eps=config.rms_norm_eps) + # codec_head: 896 → 8448 + self.lm_head = ColumnParallelLinear( + config.hidden_size, # 896 + config.vocab_size, # 8448 + bias=False, + pad=True, + gather_output=not self.on_device_sampling, + ) + + def encode_vision_to_input(self, inputs_embeds, vision_embeddings, vision_mask): + """Inject projected thinker states into token embeddings. + + Operates in two modes: + - Context encoding (seq > 1): REPLACE placeholder embeddings with + projected thinker states. All positions get thinker states. + - Token generation (seq == 1): ADD per-step thinker reply state to + the codec token embedding. This provides text guidance at each + autoregressive step, matching HF's per-step injection behavior. + + Args: + inputs_embeds: (batch, seq, 896) from embed_tokens + vision_embeddings: (batch, seq, 896) projected thinker states + vision_mask: (batch, seq, 1) int32 mask + + Returns: + (batch, seq, 896) with thinker states injected + """ + if inputs_embeds.shape[1] > 1: + # Context encoding: REPLACE embeddings with thinker states + vision_mask_bool = vision_mask.bool() + if vision_mask_bool.dim() == 3: + vision_mask_bool = vision_mask_bool.squeeze(-1) + mask_expanded = vision_mask_bool.unsqueeze(-1).expand_as(inputs_embeds) + return torch.where(mask_expanded, vision_embeddings, inputs_embeds) + else: + # Token generation: ADD thinker state to codec token embedding + # This matches HF's behavior where thinker_reply_part[step] is + # added to embed_tokens(codec_token) at each generation step. 
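+            # Shapes in this branch: inputs_embeds and vision_embeddings are
+            # both (batch, 1, 896), so this is a plain elementwise sum of the
+            # codec-token embedding and the current per-step thinker state.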
+ return inputs_embeds + vision_embeddings + + +# --------------------------------------------------------------------------- +# Model Wrapper (tracing with per-step vision_embeddings) +# --------------------------------------------------------------------------- + + +class TalkerModelWrapper: + """Mixin that overrides get_dummy_vision_inputs for per-step injection. + + Unlike the default ImageToTextModelWrapper which uses empty + vision_embeddings during token generation tracing, this provides + (batch, 1, hidden_size) tensors so the compiled NEFF includes the + ADD operation for thinker state injection at each generation step. + """ + + @staticmethod + def get_dummy_vision_inputs(config, input_ids, n_active_tokens, fill_value): + input_batch_size, input_sequence_len = input_ids.shape[0], input_ids.shape[-1] + if input_sequence_len > 1: + # Context encoding: full-sequence vision embeddings + vision_embeddings = torch.zeros( + input_batch_size, + n_active_tokens, + config.hidden_size, + dtype=config.neuron_config.torch_dtype, + ) + vision_mask = torch.full( + size=(input_batch_size, n_active_tokens, 1), + fill_value=fill_value, + dtype=torch.int32, + ) + else: + # Token generation: single-step vision embeddings for per-step + # thinker state injection (ADD to codec token embedding) + vision_embeddings = torch.zeros( + input_batch_size, + 1, + config.hidden_size, + dtype=config.neuron_config.torch_dtype, + ) + vision_mask = torch.full( + size=(input_batch_size, 1, 1), + fill_value=fill_value, + dtype=torch.int32, + ) + return vision_embeddings, vision_mask + + +# --------------------------------------------------------------------------- +# Application (NeuronBaseForCausalLM) +# --------------------------------------------------------------------------- + +class NeuronQwen25OmniTalkerForCausalLM(NeuronBaseForCausalLM): + """Neuron-compiled Talker for autoregressive codec token generation. + + Compilation: + - Recommended TP=4 (12 Q heads / 4 = 3 per rank) + - Uses ImageToTextModelWrapper for vision_embeddings support + - Context encoding: thinker states injected as vision_embeddings + - Token generation: standard autoregressive with fused embedding + + State dict conversion: + - Fuses embed_tokens(8448, 3584) + thinker_to_talker_proj(3584, 896) + into embed_tokens(8448, 896) + - Maps codec_head → lm_head + - Supports fused QKV + + Usage: + # 1. Create config + talker_config = TalkerInferenceConfig(neuron_config, load_config=hf_config) + + # 2. Create application + app = NeuronQwen25OmniTalkerForCausalLM(model_path, config=talker_config) + + # 3. Compile + app.compile(compiled_model_path) + + # 4. Load + app.load(compiled_model_path) + + # 5. Generate (token generation on Neuron) + logits = app.forward(input_ids, attention_mask, position_ids, seq_ids, ...) + """ + + _model_cls = NeuronTalkerModel + + def get_model_wrapper_cls(self): + from neuronx_distributed_inference.models.image_to_text_model_wrapper import ( + ImageToTextModelWrapper, + ) + + # Dynamically create a wrapper with per-step vision_embeddings support. + # staticmethod() preserves the descriptor when assigning to class attr. 
+ class _TalkerImageToTextModelWrapper(ImageToTextModelWrapper): + get_dummy_vision_inputs = staticmethod( + TalkerModelWrapper.get_dummy_vision_inputs + ) + + return _TalkerImageToTextModelWrapper + + @classmethod + def get_config_cls(cls): + return TalkerInferenceConfig + + def set_vision_embeddings(self, vision_embeddings, vision_mask, + thinker_reply_embeds=None): + """Store vision embeddings for the next generate() call. + + During context encoding, projected thinker states (896-dim) are + injected as vision_embeddings. During token generation, per-step + thinker reply states are ADDED to codec token embeddings. + + Vision embeddings are padded to max_context_length to match the + compiled bucket shapes (the compiled model expects fixed-size + vision_embeddings matching the bucket, while input_ids and + attention_mask are padded by preprocess_inputs). + + Args: + vision_embeddings: (batch, seq, 896) projected thinker states + for context encoding + vision_mask: (batch, seq, 1) int32 mask (all positions active) + thinker_reply_embeds: (batch, n_reply, 896) optional per-step + thinker reply states for token generation. If provided, + reply_embeds[:, step, :] is added to the codec token + embedding at each generation step. + """ + # Pad vision_embeddings and vision_mask to max_context_length so they + # match the compiled NEFF bucket shapes. + max_ctx = self.neuron_config.max_context_length + batch, seq, dim = vision_embeddings.shape + if seq < max_ctx: + pad_ve = torch.zeros( + batch, max_ctx - seq, dim, dtype=vision_embeddings.dtype + ) + vision_embeddings = torch.cat([vision_embeddings, pad_ve], dim=1) + pad_vm = torch.zeros( + batch, max_ctx - seq, 1, dtype=vision_mask.dtype + ) + vision_mask = torch.cat([vision_mask, pad_vm], dim=1) + + self._vision_embeddings = vision_embeddings + self._vision_mask = vision_mask + self._thinker_reply_embeds = thinker_reply_embeds + self._vision_dtype = vision_embeddings.dtype + self._tkg_step = 0 + + def _get_model_outputs( + self, input_ids, attention_mask, position_ids, seq_ids, + sampling_params, prev_hidden, adapter_ids, + medusa_args=None, llava_args=None, **kwargs + ): + """Override to pass vision_embeddings to ImageToTextModelWrapper. + + Context encoding: passes full projected thinker states as + vision_embeddings (REPLACE mode in encode_vision_to_input). + + Token generation: passes per-step thinker reply state as + vision_embeddings (ADD mode in encode_vision_to_input). 
+ + ImageToTextModelWrapper traces with 24 positional args: + 0-4: input_ids, attention_mask, position_ids, seq_ids, sampling_params + 5-20: empty placeholders (prev_hidden, adapter_ids, medusa/block args) + 21: rotary_position_ids + 22: vision_embeddings + 23: vision_mask + """ + vision_embeddings = getattr(self, '_vision_embeddings', torch.empty(0)) + vision_mask = getattr(self, '_vision_mask', torch.empty(0)) + + if self._is_prefill(position_ids): + outputs = self.context_encoding_model( + input_ids, + attention_mask, + position_ids, + seq_ids, + sampling_params, + torch.empty(0), # prev_hidden + torch.empty(0), # adapter_ids + torch.empty(0), # accepted_indices + torch.empty(0), # current_length + torch.empty(0), # medusa_mask + torch.empty(0), # scatter_index + torch.empty(0), # slot_mapping + torch.empty(0), # active_block_table + torch.empty(0), # num_queries + torch.empty(0), # computed_context_lens + torch.empty(0), # tile_q_indices + torch.empty(0), # tile_block_tables + torch.empty(0), # tile_masks + torch.empty(0), # inputs_embeds + torch.empty(0), # kv_cache + torch.empty(0), # active_mask + torch.empty(0), # rotary_position_ids + vision_embeddings, + vision_mask, + ) + self.kv_cache_populated = True + # Clear context vision (no longer needed), keep reply embeds + self._vision_embeddings = torch.empty(0) + self._vision_mask = torch.empty(0) + self._tkg_step = 0 + is_run_on_neuron = self.context_encoding_model.is_neuron() + else: + # Get per-step thinker reply state for this generation step + reply_embeds = getattr(self, '_thinker_reply_embeds', None) + dtype = getattr(self, '_vision_dtype', torch.bfloat16) + batch_size = input_ids.shape[0] + hidden_size = self.config.hidden_size + + if reply_embeds is not None and self._tkg_step < reply_embeds.shape[1]: + step_ve = reply_embeds[:, self._tkg_step:self._tkg_step + 1, :] + step_vm = torch.ones(batch_size, 1, 1, dtype=torch.int32) + self._tkg_step += 1 + elif reply_embeds is not None and reply_embeds.shape[1] > 0: + # Repeat the last reply state (matches HF behavior where + # thinker_reply_part stays at the last element when exhausted) + step_ve = reply_embeds[:, -1:, :] + step_vm = torch.ones(batch_size, 1, 1, dtype=torch.int32) + else: + step_ve = torch.zeros( + batch_size, 1, hidden_size, dtype=dtype + ) + step_vm = torch.ones(batch_size, 1, 1, dtype=torch.int32) + + outputs = self.token_generation_model( + input_ids, + attention_mask, + position_ids, + seq_ids, + sampling_params, + torch.empty(0), # prev_hidden + torch.empty(0), # adapter_ids + torch.empty(0), # accepted_indices + torch.empty(0), # current_length + torch.empty(0), # medusa_mask + torch.empty(0), # scatter_index + torch.empty(0), # slot_mapping + torch.empty(0), # active_block_table + torch.empty(0), # num_queries + torch.empty(0), # computed_context_lens + torch.empty(0), # tile_q_indices + torch.empty(0), # tile_block_tables + torch.empty(0), # tile_masks + torch.empty(0), # inputs_embeds + torch.empty(0), # kv_cache + torch.empty(0), # active_mask + torch.empty(0), # rotary_position_ids + step_ve, # vision_embeddings (per-step thinker state) + step_vm, # vision_mask + ) + is_run_on_neuron = self.token_generation_model.is_neuron() + + return outputs, is_run_on_neuron + + @staticmethod + def load_hf_model(model_path, **kwargs): + from transformers import AutoModelForCausalLM + return AutoModelForCausalLM.from_pretrained( + model_path, trust_remote_code=True, **kwargs + ) + + @staticmethod + def convert_hf_to_neuron_state_dict( + state_dict: dict, config: 
TalkerInferenceConfig + ) -> dict: + """Convert HF Talker state dict to Neuron format. + + Key transformations: + 1. Strip 'talker.' and 'model.' prefixes + 2. Fuse embed_tokens(8448, 3584) + thinker_to_talker_proj(3584, 896) + → embed_tokens(8448, 896) + 3. Map codec_head → lm_head + 4. Add rank utilities for distributed inference + 5. Optionally fuse QKV projections + + The original thinker_to_talker_proj weights are stored as + '_thinker_proj_weight' and '_thinker_proj_bias' in the returned + dict for use during context encoding (CPU-side projection). + """ + neuron_config = config.neuron_config + new_state_dict = {} + proj_weight = None + proj_bias = None + + for key, value in state_dict.items(): + # Strip talker. prefix + if key.startswith("talker."): + key = key[len("talker."):] + + # Route keys + if key.startswith("model."): + new_key = key[len("model."):] + elif key == "codec_head.weight": + new_key = "lm_head.weight" + elif key == "thinker_to_talker_proj.weight": + proj_weight = value + continue + elif key == "thinker_to_talker_proj.bias": + proj_bias = value + continue + else: + # Skip keys from other components + if any(key.startswith(p) for p in [ + "lm_head.", "visual.", "audio_tower.", + "thinker.", "token2wav.", + ]): + continue + new_key = key + new_state_dict[new_key] = value + + # Fuse embedding: embed(8448, 3584) @ proj.T(3584, 896) → (8448, 896) + # NOTE: projection bias is NOT included in the fused embedding. + # During token generation, the bias is already included once in the + # projected thinker reply states (proj(reply) = W @ reply + bias). + # Including it here would cause double-bias: fused(token) + proj(reply) + # = (W @ E + bias) + (W @ reply + bias) = W @ (E+reply) + 2*bias, + # whereas HF computes proj(E + reply) = W @ (E+reply) + bias. 
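+        # Worked shapes for the fusion below: E (8448, 3584) @ W.T (3584, 896)
+        # -> fused embed_tokens.weight (8448, 896); a lookup then yields
+        # W @ E[token] directly, with the projection bias applied exactly once
+        # through the separately projected thinker reply states.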
+ if "embed_tokens.weight" in new_state_dict and proj_weight is not None: + embed_weight = new_state_dict["embed_tokens.weight"] # (8448, 3584) + fused_embed = embed_weight.float() @ proj_weight.float().T # (8448, 896) + new_state_dict["embed_tokens.weight"] = fused_embed.to( + neuron_config.torch_dtype + ) + logger.info( + "Fused embed_tokens (%s) + proj (%s) → (%s) (bias excluded)", + list(embed_weight.shape), list(proj_weight.shape), + list(new_state_dict["embed_tokens.weight"].shape), + ) + + # Save projection weights for CPU-side context encoding + if proj_weight is not None: + new_state_dict["_thinker_proj_weight"] = proj_weight + if proj_bias is not None: + new_state_dict["_thinker_proj_bias"] = proj_bias + + # Add rank utilities + if neuron_config.vocab_parallel: + new_state_dict["embed_tokens.rank_util.rank"] = torch.arange( + 0, neuron_config.local_ranks_size + ) + + tp_degree = neuron_config.tp_degree + for i in range(config.num_hidden_layers): + new_state_dict[f"layers.{i}.self_attn.rank_util.rank"] = torch.arange( + 0, tp_degree, dtype=torch.int32 + ) + + # Fuse QKV if enabled + if neuron_config.fused_qkv: + new_state_dict = _fuse_talker_qkv(new_state_dict, config) + + new_state_dict["rank_util.rank"] = torch.arange( + 0, tp_degree, dtype=torch.int32 + ) + + gc.collect() + return new_state_dict + + @staticmethod + def update_state_dict_for_tied_weights(state_dict): + # Talker does not tie embed_tokens and lm_head weights + pass + + def get_compiler_args(self): + return ( + "--enable-saturate-infinity --enable-mixed-precision-accumulation " + "--auto-cast=none --model-type transformer -O1 " + "--tensorizer-options='--enable-ccop-compute-overlap " + "--cc-pipeline-tiling-factor=2 --vectorize-strided-dma' " + "--internal-hlo2tensorizer-options='--verify-hlo=true'" + ) + + +def _fuse_talker_qkv(state_dict: dict, config: InferenceConfig) -> dict: + """Fuse Q/K/V weight and bias tensors into Wqkv for the Talker.""" + for layer_idx in range(config.num_hidden_layers): + for attr in ["weight", "bias"]: + q_key = f"layers.{layer_idx}.self_attn.q_proj.{attr}" + k_key = f"layers.{layer_idx}.self_attn.k_proj.{attr}" + v_key = f"layers.{layer_idx}.self_attn.v_proj.{attr}" + if all(k in state_dict for k in [q_key, k_key, v_key]): + state_dict[f"layers.{layer_idx}.self_attn.Wqkv.{attr}"] = torch.cat([ + state_dict.pop(q_key), + state_dict.pop(k_key), + state_dict.pop(v_key), + ]) + gc.collect() + return state_dict + + +# --------------------------------------------------------------------------- +# Helper: CPU-side thinker state projection +# --------------------------------------------------------------------------- + +class ThinkerToTalkerProjection(nn.Module): + """CPU-side projection from Thinker hidden space to Talker hidden space. + + During context encoding, thinker hidden states (3584-d) need to be + projected to 896-d before being injected into the Neuron model as + vision_embeddings. + + This module is loaded from the original thinker_to_talker_proj weights + that are extracted during state dict conversion. + """ + + def __init__(self, thinker_hidden_size=3584, talker_hidden_size=896): + super().__init__() + self.proj = nn.Linear(thinker_hidden_size, talker_hidden_size) + + @torch.no_grad() + def forward(self, thinker_hidden_states): + """Project thinker states for Neuron context encoding. 
+ + Args: + thinker_hidden_states: (batch, seq_len, 3584) + + Returns: + (batch, seq_len, 896) projected states ready for vision_embeddings + """ + return self.proj(thinker_hidden_states) + + @classmethod + def from_state_dict(cls, state_dict, dtype=torch.bfloat16): + """Create projection from extracted state dict. + + Args: + state_dict: Must contain '_thinker_proj_weight' and + optionally '_thinker_proj_bias'. + + Returns: + ThinkerToTalkerProjection on CPU in specified dtype + """ + proj_weight = state_dict.get("_thinker_proj_weight") + proj_bias = state_dict.get("_thinker_proj_bias") + if proj_weight is None: + raise ValueError( + "State dict missing '_thinker_proj_weight'. " + "Run convert_hf_to_neuron_state_dict first." + ) + in_features = proj_weight.shape[1] + out_features = proj_weight.shape[0] + module = cls(in_features, out_features) + module.proj.weight.data = proj_weight.to(dtype) + if proj_bias is not None: + module.proj.bias.data = proj_bias.to(dtype) + else: + module.proj.bias = None + module.to(dtype) + module.eval() + return module diff --git a/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_token2wav.py b/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_token2wav.py new file mode 100644 index 00000000..9e37ae6c --- /dev/null +++ b/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_token2wav.py @@ -0,0 +1,765 @@ +# coding=utf-8 +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Qwen2.5-Omni Token2Wav for NXD inference. +# +# This file contains TWO implementations: +# +# 1. NeuronQwen25OmniToken2Wav (CPU wrapper) +# - Wraps HF's Qwen2_5OmniToken2WavModel entirely on CPU in float32 +# - Suitable for quick testing or when Neuron resources are limited +# +# 2. NeuronQwen25OmniToken2WavWithNeuronDiT (Neuron-accelerated) +# - Compiles the DiT (22 transformer blocks) on Neuron via torch_neuronx.trace() +# - ODE solver loop stays on CPU (inherently sequential, 10-50 steps) +# - BigVGAN vocoder stays on CPU (convolutional, ~10-20M params) +# - Speaker encoder (ECAPA-TDNN) stays on CPU (~small) +# - DiT is the compute bottleneck: 22 blocks × 10 ODE steps = 220 forward passes +# +# Architecture: +# - DiT (Diffusion Transformer): 22 blocks, dim=1024, 16 heads +# - ECAPA-TDNN speaker encoder for speaker conditioning +# - Codec embedding + RoPE + AdaLayerNorm +# - ODE sampling (Runge-Kutta 4) for mel spectrogram generation +# - BigVGAN vocoder: mel spectrogram -> waveform +# - conv_pre(80->1536) + 6 upsample stages + AMPBlock residuals +# - Snake activation, conv_post(24->1) +# +# Runs on CPU in float32 (required for ODE solver precision). +# Token2Wav has ~809 state dict keys total. + +"""Qwen2.5-Omni Token2Wav model for NXD inference.""" + +import json +import logging +import os +from typing import Optional + +import torch + +logger = logging.getLogger(__name__) + + +# ============================================================================= +# Part 1: CPU-based Token2Wav (HF wrapper) +# ============================================================================= + + +class NeuronQwen25OmniToken2Wav: + """Wrapper around HF's Qwen2_5OmniToken2WavModel. + + Token2Wav converts codec tokens into audio waveforms through: + 1. DiT model: codec tokens + speaker embedding -> mel spectrogram + (via ODE sampling with classifier-free guidance) + 2. 
BigVGAN vocoder: mel spectrogram -> waveform + + Speaker conditioning requires a speaker dict (spk_dict.pt) containing + per-speaker 'cond' (conditioning) and 'ref_mel' (reference mel) tensors, + plus 'bos_token' for the Talker. + + This wrapper: + 1. Instantiates the HF Token2Wav from config + 2. Loads weights from converted state dict + 3. Exposes waveform generation API + """ + + def __init__(self, token2wav_config): + """Initialize Token2Wav. + + Args: + token2wav_config: Token2Wav config (dict or HF config object). + Must contain dit_config and bigvgan_config sub-configs. + """ + from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import ( + Qwen2_5OmniToken2WavConfig, + ) + from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import ( + Qwen2_5OmniToken2WavModel, + ) + + if isinstance(token2wav_config, dict): + token2wav_config = Qwen2_5OmniToken2WavConfig(**token2wav_config) + + self.model = Qwen2_5OmniToken2WavModel(token2wav_config) + # Token2Wav must run in float32 for ODE solver precision + self.model.float() + self.model.eval() + self.config = token2wav_config + + @property + def dtype(self): + """Return dtype of the underlying HF model (for HF generate compatibility).""" + return next(self.model.parameters()).dtype + + def float(self): + """Cast underlying model to float32 (for HF generate compatibility).""" + self.model.float() + return self + + def load_state_dict(self, state_dict, strict=True): + """Load converted state dict into the HF Token2Wav model.""" + return self.model.load_state_dict(state_dict, strict=strict) + + @torch.no_grad() + def __call__( + self, + code, + conditioning, + reference_mel, + num_steps=10, + guidance_scale=0.5, + sway_coefficient=-1.0, + **kwargs, + ): + """Generate waveform from codec tokens. + + Args: + code: (batch, seq_len) codec token IDs from the Talker + conditioning: (batch, mel_len, enc_dim) speaker conditioning + (from spk_dict.pt 'cond' key) + reference_mel: (batch, mel_len, mel_dim) reference mel spectrogram + (from spk_dict.pt 'ref_mel' key) + num_steps: Number of ODE solver steps (default 10) + guidance_scale: Classifier-free guidance scale (default 0.5) + sway_coefficient: Time schedule sway (default -1.0) + **kwargs: Additional kwargs passed to Token2Wav + + Returns: + waveform: (samples,) audio waveform tensor on CPU + """ + return self.model( + code=code, + conditioning=conditioning, + reference_mel=reference_mel, + num_steps=num_steps, + guidance_scale=guidance_scale, + sway_coefficient=sway_coefficient, + **kwargs, + ) + + @staticmethod + def convert_hf_to_neuron_state_dict(state_dict: dict) -> dict: + """Convert HF state dict to Token2Wav format. + + Strips 'token2wav.' prefix from keys. Non-token2wav keys are passed through. + + Args: + state_dict: Full or partial state dict with token2wav.* keys. + + Returns: + State dict with token2wav prefix stripped. + """ + new_state_dict = {} + for key, value in state_dict.items(): + if key.startswith("token2wav."): + new_state_dict[key[len("token2wav."):]] = value + else: + new_state_dict[key] = value + return new_state_dict + + @staticmethod + def load_speaker_dict(speaker_dict_path): + """Load speaker dictionary from spk_dict.pt. + + Args: + speaker_dict_path: Path to spk_dict.pt + + Returns: + dict: Speaker name -> {cond, ref_mel, bos_token} + """ + return torch.load(speaker_dict_path, weights_only=True) + + @classmethod + def from_pretrained_state_dict(cls, token2wav_config, state_dict): + """Create Token2Wav and load weights from converted state dict. 
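+
+        Minimal usage sketch (names ``cfg``, ``sd``, ``code`` and the speaker
+        key are illustrative placeholders):
+
+            t2w = NeuronQwen25OmniToken2Wav.from_pretrained_state_dict(cfg, sd)
+            spk = NeuronQwen25OmniToken2Wav.load_speaker_dict("spk_dict.pt")["<speaker>"]
+            waveform = t2w(code, spk["cond"], spk["ref_mel"])  # code: Talker codec tokens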
+ + Args: + token2wav_config: Token2Wav config (dict or HF config object) + state_dict: Already-converted state dict (token2wav keys only) + + Returns: + Initialized NeuronQwen25OmniToken2Wav + """ + token2wav = cls(token2wav_config) + + # Filter to only token2wav keys (skip non-token2wav prefixes) + t2w_keys = {} + for key, value in state_dict.items(): + if any( + key.startswith(p) + for p in [ + "lm_head.", "visual.", "audio_tower.", + "thinker.", "talker.", "token2wav.", + ] + ): + continue + t2w_keys[key] = value + + missing, unexpected = token2wav.load_state_dict(t2w_keys, strict=False) + if missing: + logger.warning("Token2Wav missing keys: %s", missing[:10]) + if unexpected: + logger.warning("Token2Wav unexpected keys: %s", unexpected[:10]) + logger.info("Loaded %d weights into Token2Wav", len(t2w_keys)) + + return token2wav + + +# ============================================================================= +# Part 2: Neuron-accelerated Token2Wav (DiT on Neuron) +# ============================================================================= + + +def _monkeypatch_dit_attention_for_neuron(dit_module): + """Replace DiTAttention.forward with XLA-traceable version. + + Fixes two XLA-incompatible operations in DiTAttention: + 1. In-place slice assignment: query[:, :1], key[:, :1] = ... + → replaced with torch.cat-based reassembly + 2. ALL_ATTENTION_FUNCTIONS dispatch → explicit matmul attention + + Args: + dit_module: Qwen2_5OmniToken2WavDiTModel instance + """ + from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import ( + apply_rotary_pos_emb, + ) + + for block in dit_module.transformer_blocks: + attn = block.attn + + def _make_patched(a): + def forward(hidden_states, position_embeddings=None, attention_mask=None): + batch_size = hidden_states.shape[0] + query = a.to_q(hidden_states) + key = a.to_k(hidden_states) + value = a.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // a.heads + query = query.view(batch_size, -1, a.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, a.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, a.heads, head_dim).transpose(1, 2) + + # Apply RoPE to first head only (matches HF behavior) + # FIX: use torch.cat instead of in-place slice assignment + cos, sin = position_embeddings + q_rope, k_rope = apply_rotary_pos_emb( + query[:, :1], key[:, :1], cos, sin + ) + query = torch.cat([q_rope, query[:, 1:]], dim=1) + key = torch.cat([k_rope, key[:, 1:]], dim=1) + + # Explicit matmul attention (XLA-safe, no SDPA dispatch) + scale = head_dim ** -0.5 + attn_weights = torch.matmul( + query, key.transpose(-2, -1) + ) * scale + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + attn_weights = torch.nn.functional.softmax( + attn_weights, dim=-1 + ) + attn_output = torch.matmul(attn_weights, value) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape( + batch_size, -1, a.heads * head_dim + ) + attn_output = attn_output.to(query.dtype) + + attn_output = a.to_out[0](attn_output) + attn_output = a.to_out[1](attn_output) + + return attn_output + + return forward + + attn.forward = _make_patched(attn) + + +class _NeuronDiTCore(torch.nn.Module): + """Traced wrapper for DiT transformer blocks + norm_out + proj_out. + + Only the compute-heavy transformer core is compiled on Neuron. 
+ All preprocessing (time_embed, text_embed, input_embed with ECAPA-TDNN, + rotary_embed, block_diff) stays on CPU to avoid XLA tracing issues with: + - ECAPA-TDNN Conv1d(padding="same", padding_mode="reflect") + - AttentiveStatisticsPooling (dynamic masks, masked_fill(-inf)) + - DiTCodecEmbedding (torch.repeat_interleave) + - DiTInputEmbedding (2D/3D tensor cat mismatch) + - Rotary embedding (torch.autocast context manager) + + Each block may have a different attention pattern (look_backward/look_ahead). + Three per-block masks are pre-computed on CPU: + - mask_local: look_backward=0, look_ahead=0 (most blocks) + - mask_backward: look_backward=1, look_ahead=0 (blocks 0, 20) + - mask_ahead: look_backward=0, look_ahead=1 (block 10) + """ + + def __init__(self, dit_module): + super().__init__() + self.transformer_blocks = dit_module.transformer_blocks + self.norm_out = dit_module.norm_out + self.proj_out = dit_module.proj_out + # Build per-block mask selection (Python list for static trace) + # 0 = mask_local (0,0), 1 = mask_backward (1,0), 2 = mask_ahead (0,1) + self._block_mask_idx = [] + for block in dit_module.transformer_blocks: + lb = block.look_backward_block + la = block.look_ahead_block + if lb == 0 and la == 0: + self._block_mask_idx.append(0) + elif lb >= 1 and la == 0: + self._block_mask_idx.append(1) + else: # la >= 1 + self._block_mask_idx.append(2) + + def forward(self, hidden_states, time_embedding, cos, sin, + mask_local, mask_backward, mask_ahead): + """ + Args: + hidden_states: (batch, seq_len, dim) from input_embed + time_embedding: (time_batch, dim) from time_embed (broadcasts) + cos: (batch, seq_len, head_dim) from rotary_embed + sin: (batch, seq_len, head_dim) from rotary_embed + mask_local: (batch, 1, seq_len, seq_len) float mask (0,0) + mask_backward: (batch, 1, seq_len, seq_len) float mask (1,0) + mask_ahead: (batch, 1, seq_len, seq_len) float mask (0,1) + """ + position_embeddings = (cos, sin) + masks = [mask_local, mask_backward, mask_ahead] + + for i, block in enumerate(self.transformer_blocks): + # Select per-block mask (static Python int index → traced as constant) + attention_mask = masks[self._block_mask_idx[i]] + + norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = block.attn_norm( + hidden_states, emb=time_embedding + ) + attn_output = block.attn( + hidden_states=norm, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + ) + hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_output + norm = ( + block.ff_norm(hidden_states) * (1 + scale_mlp[:, None]) + + shift_mlp[:, None] + ) + ff_output = block.ff(norm) + hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output + + hidden_states = self.norm_out(hidden_states, time_embedding) + output = self.proj_out(hidden_states) + return output + + +class NeuronQwen25OmniToken2WavWithNeuronDiT(NeuronQwen25OmniToken2Wav): + """Token2Wav with DiT transformer core compiled on Neuron. 
+ + The DiT is the compute bottleneck in Token2Wav: + - 22 transformer blocks × 10 ODE steps = 220 forward passes per generation + - Each block: self-attention (dim=1024, 16 heads) + FFN + - Total ~85M params + + Split architecture (CPU preprocessing + Neuron transformer core): + CPU: time_embed(time_step) → time embedding + CPU: text_embed(quantized_code) → codec features + CPU: input_embed(hidden_states, speaker_embedding, condition_vector, code_embed) + → includes ECAPA-TDNN speaker encoder, CFG batch doubling + CPU: rotary_embed(hidden_states) → cos, sin + CPU: _create_block_diff(hidden_states) → attention mask + Neuron: 22 transformer blocks + norm_out + proj_out (the compute hotspot) + CPU: ODE solver loop (Runge-Kutta 4, 10 steps) + CPU: BigVGAN vocoder → waveform + + Usage: + t2w = NeuronQwen25OmniToken2WavWithNeuronDiT(token2wav_config) + t2w.load_state_dict(state_dict) + t2w.compile_dit("compiled_dit/", max_mel_len=2048) + # Subsequent runs: + t2w.load_dit("compiled_dit/") + waveform = t2w(code, conditioning, reference_mel) + """ + + def __init__(self, token2wav_config): + super().__init__(token2wav_config) + self._neuron_dit_core = None + self._dit_compiled_path = None + self._dit_max_mel_len = None + self._dit_batch_size = None + + def _get_dit_module(self): + """Extract the DiT sub-module from the HF Token2Wav model.""" + for attr_name in [ + "code2wav_dit_model", "dit", "flow_model", "transformer", + ]: + if hasattr(self.model, attr_name): + return getattr(self.model, attr_name) + for name, module in self.model.named_children(): + if "dit" in name.lower() or "flow" in name.lower(): + return module + return None + + def compile_dit( + self, + compiled_path, + max_mel_len=2048, + batch_size=2, + ): + """Compile the DiT transformer core on Neuron. + + Only the 22 transformer blocks + norm_out + proj_out are compiled. + Preprocessing (ECAPA-TDNN, codec embedding, input embedding, rotary) + stays on CPU to avoid XLA tracing issues. + + Args: + compiled_path: Directory to save compiled model + max_mel_len: Maximum mel spectrogram length (covers ~24s audio). + Shorter inputs are padded; longer inputs fall back to CPU. + batch_size: Batch size for compilation. Use 2 for standard + inference with classifier-free guidance (CFG doubles batch). + """ + try: + import torch_neuronx + except ImportError: + raise ImportError( + "torch_neuronx required for DiT compilation. " + "Run on a Neuron instance (trn1/trn2/inf2)." 
+ ) + + os.makedirs(compiled_path, exist_ok=True) + + dit = self._get_dit_module() + if dit is None: + raise RuntimeError("Could not extract DiT module from Token2Wav.") + + # Get model dimensions + dit_cfg = getattr(dit, "config", None) + dim = getattr(dit_cfg, "dim", 1024) + num_heads = getattr(dit_cfg, "num_attention_heads", 16) + head_dim = dim // num_heads + + logger.info( + "Compiling DiT core: batch=%d, mel_len=%d, dim=%d, heads=%d", + batch_size, max_mel_len, dim, num_heads, + ) + + # Monkeypatch DiTAttention to fix in-place slice assignment + _monkeypatch_dit_attention_for_neuron(dit) + + # Create wrapper for transformer core only + core = _NeuronDiTCore(dit) + core.float() + core.eval() + + # Create example inputs + # time_embedding uses batch=1 (broadcasts to hidden_states batch) + hidden_states = torch.randn( + batch_size, max_mel_len, dim, dtype=torch.float32 + ) + time_embedding = torch.randn(1, dim, dtype=torch.float32) + cos = torch.randn( + batch_size, max_mel_len, head_dim, dtype=torch.float32 + ) + sin = torch.randn( + batch_size, max_mel_len, head_dim, dtype=torch.float32 + ) + # Three per-block attention masks (local, backward, ahead) + mask_local = torch.zeros( + batch_size, 1, max_mel_len, max_mel_len, dtype=torch.float32 + ) + mask_backward = torch.zeros( + batch_size, 1, max_mel_len, max_mel_len, dtype=torch.float32 + ) + mask_ahead = torch.zeros( + batch_size, 1, max_mel_len, max_mel_len, dtype=torch.float32 + ) + + compiled = torch_neuronx.trace( + core, + (hidden_states, time_embedding, cos, sin, + mask_local, mask_backward, mask_ahead), + compiler_args=[ + "--auto-cast=none", + "--model-type=transformer", + "-O1", + ], + ) + + save_path = os.path.join(compiled_path, "dit_core_neuron.pt") + torch.jit.save(compiled, save_path) + + # Save metadata for load + meta = { + "max_mel_len": max_mel_len, + "batch_size": batch_size, + "dim": dim, + "num_heads": num_heads, + "head_dim": head_dim, + } + with open(os.path.join(compiled_path, "dit_core_meta.json"), "w") as f: + json.dump(meta, f) + + logger.info("Compiled DiT core saved to %s", save_path) + + self._neuron_dit_core = compiled + self._dit_compiled_path = compiled_path + self._dit_max_mel_len = max_mel_len + self._dit_batch_size = batch_size + + def load_dit(self, compiled_path): + """Load a previously compiled DiT core model. + + Args: + compiled_path: Directory containing compiled model + """ + save_path = os.path.join(compiled_path, "dit_core_neuron.pt") + meta_path = os.path.join(compiled_path, "dit_core_meta.json") + + if not os.path.exists(save_path): + raise FileNotFoundError( + f"Compiled DiT core not found at {save_path}" + ) + + self._neuron_dit_core = torch.jit.load(save_path) + self._dit_compiled_path = compiled_path + + if os.path.exists(meta_path): + with open(meta_path) as f: + meta = json.load(f) + self._dit_max_mel_len = meta["max_mel_len"] + self._dit_batch_size = meta["batch_size"] + + logger.info("Loaded compiled DiT core from %s", save_path) + + def _build_attention_masks(self, block_diff, actual_mel_len, max_mel_len): + """Build three per-block float additive attention masks with padding. 
+ + DiT blocks have three distinct attention patterns based on their + look_backward_block/look_ahead_block attributes: + - mask_local (0,0): most blocks — attend only within same block + - mask_backward (1,0): blocks 0,20 — attend current + previous block + - mask_ahead (0,1): block 10 — attend current + next block + + Args: + block_diff: (batch, heads, actual_mel_len, actual_mel_len) + from _create_block_diff. Values are block index differences. + actual_mel_len: Actual sequence length before padding + max_mel_len: Padded sequence length (compiled size) + + Returns: + Tuple of 3 masks, each (batch, 1, max_mel_len, max_mel_len) float32. + 0.0 for attend, -1e4 for don't attend. + """ + mask_batch = self._dit_batch_size or 1 + bd = block_diff[0, 0] # (actual_mel_len, actual_mel_len) + + # Three patterns: (look_backward, look_ahead) + patterns = [(0, 0), (1, 0), (0, 1)] + masks = [] + for lb, la in patterns: + bool_mask = (bd >= -float(lb)) & (bd <= float(la)) + valid = torch.where( + bool_mask, + torch.tensor(0.0, dtype=torch.float32), + torch.tensor(-1e4, dtype=torch.float32), + ) + mask = torch.full( + (mask_batch, 1, max_mel_len, max_mel_len), + -1e4, dtype=torch.float32, + ) + for b in range(mask_batch): + mask[b, 0, :actual_mel_len, :actual_mel_len] = valid + masks.append(mask) + + return masks[0], masks[1], masks[2] + + @torch.no_grad() + def __call__( + self, + code, + conditioning, + reference_mel, + num_steps=10, + guidance_scale=0.5, + sway_coefficient=-1.0, + **kwargs, + ): + """Generate waveform. DiT core runs on Neuron if compiled, else CPU. + + Monkeypatches dit.forward to split execution: + - CPU: preprocessing (time/text/input embed, rotary, block_diff) + - Neuron: 22 transformer blocks + norm + proj + - CPU: ODE solver, BigVGAN vocoder + """ + if self._neuron_dit_core is not None: + dit = self._get_dit_module() + original_forward = dit.forward + neuron_core = self._neuron_dit_core + max_mel_len = self._dit_max_mel_len + expected_batch = self._dit_batch_size + build_masks = self._build_attention_masks + + def neuron_dit_forward( + hidden_states, + condition_vector, + speaker_embedding, + quantized_code, + time_step, + drop_audio_conditioning=False, + drop_code=False, + apply_cfg=True, + ): + """DiT forward with Neuron-accelerated transformer core.""" + batch_size = hidden_states.shape[0] + if time_step.ndim == 0: + time_step = time_step.repeat(batch_size) + + # Estimate mel_len early. input_embed doubles batch for + # CFG, so we must fall back BEFORE calling it — otherwise + # original_forward would receive already-modified tensors. 
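+                # At this point hidden_states is still the raw mel-space input
+                # (before input_embed and before CFG batch doubling), so dim 1
+                # is the true mel length to compare against the compiled bucket.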
+ est_mel_len = hidden_states.shape[1] + if est_mel_len > max_mel_len: + logger.warning( + "mel_len %d > max %d, falling back to CPU", + est_mel_len, + max_mel_len, + ) + return original_forward( + hidden_states, + condition_vector, + speaker_embedding, + quantized_code, + time_step, + drop_audio_conditioning=drop_audio_conditioning, + drop_code=drop_code, + apply_cfg=apply_cfg, + ) + + # CPU: compute embeddings (same as HF original) + time_embedding = dit.time_embed(time_step) + text_embedding = dit.text_embed( + quantized_code, + drop_code=False if apply_cfg else drop_code, + ) + text_embedding_uncond = ( + dit.text_embed(quantized_code, drop_code=True) + if apply_cfg + else None + ) + + # CPU: input embedding (ECAPA-TDNN, CFG batch doubling) + hidden_states = dit.input_embed( + hidden_states, + speaker_embedding, + condition_vector, + text_embedding, + drop_audio_cond=drop_audio_conditioning, + code_embed_uncond=text_embedding_uncond, + apply_cfg=apply_cfg, + ) + + # CPU: positional encodings + cos, sin = dit.rotary_embed(hidden_states) + block_diff = dit._create_block_diff(hidden_states) + + actual_mel_len = hidden_states.shape[1] + actual_batch = hidden_states.shape[0] + + # Build three per-block attention masks + mask_local, mask_backward, mask_ahead = build_masks( + block_diff, actual_mel_len, max_mel_len + ) + + # Pad to compiled shapes + pad_mel = max_mel_len - actual_mel_len + if pad_mel > 0: + hidden_states = torch.nn.functional.pad( + hidden_states, (0, 0, 0, pad_mel) + ) + cos = torch.nn.functional.pad(cos, (0, 0, 0, pad_mel)) + sin = torch.nn.functional.pad(sin, (0, 0, 0, pad_mel)) + + pad_batch = expected_batch - actual_batch + if pad_batch > 0: + hidden_states = torch.nn.functional.pad( + hidden_states, (0, 0, 0, 0, 0, pad_batch) + ) + cos = torch.nn.functional.pad( + cos, (0, 0, 0, 0, 0, pad_batch) + ) + sin = torch.nn.functional.pad( + sin, (0, 0, 0, 0, 0, pad_batch) + ) + + # Run transformer core on Neuron (3 per-block masks) + output = neuron_core( + hidden_states.float(), + time_embedding.float(), + cos.float(), + sin.float(), + mask_local.float(), + mask_backward.float(), + mask_ahead.float(), + ) + + # Unpad to actual sizes + output = output[:actual_batch, :actual_mel_len] + return output + + dit.forward = neuron_dit_forward + try: + result = self.model( + code=code, + conditioning=conditioning, + reference_mel=reference_mel, + num_steps=num_steps, + guidance_scale=guidance_scale, + sway_coefficient=sway_coefficient, + **kwargs, + ) + finally: + dit.forward = original_forward + return result + else: + return self.model( + code=code, + conditioning=conditioning, + reference_mel=reference_mel, + num_steps=num_steps, + guidance_scale=guidance_scale, + sway_coefficient=sway_coefficient, + **kwargs, + ) + + @classmethod + def from_pretrained_state_dict(cls, token2wav_config, state_dict): + """Create Token2Wav with Neuron DiT support.""" + token2wav = cls(token2wav_config) + + t2w_keys = {} + for key, value in state_dict.items(): + if any( + key.startswith(p) + for p in [ + "lm_head.", "visual.", "audio_tower.", + "thinker.", "talker.", "token2wav.", + ] + ): + continue + t2w_keys[key] = value + + missing, unexpected = token2wav.load_state_dict(t2w_keys, strict=False) + if missing: + logger.warning("Token2Wav missing keys: %s", missing[:10]) + if unexpected: + logger.warning("Token2Wav unexpected keys: %s", unexpected[:10]) + logger.info( + "Loaded %d weights into Token2Wav (Neuron DiT capable)", + len(t2w_keys), + ) + + return token2wav diff --git 
a/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_vision.py b/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_vision.py new file mode 100644 index 00000000..1aebd529 --- /dev/null +++ b/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_vision.py @@ -0,0 +1,579 @@ +# coding=utf-8 +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Qwen2.5-Omni Vision Encoder for NXD inference. +# +# Differences from Qwen2-VL vision encoder: +# - SwiGLU MLP (gate_proj, up_proj, down_proj) instead of simple FC1/FC2 +# - RMSNorm instead of LayerNorm +# - Separate Q/K/V projections instead of fused QKV +# - intermediate_size=3420 (not TP-divisible by 32, use nn.Linear for MLP) + +"""Qwen2.5-Omni Vision Encoder for NXD inference.""" + +import logging +import os +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from safetensors.torch import save_file +from transformers.models.qwen2_vl.modeling_qwen2_vl import ( + PatchEmbed, + VisionRotaryEmbedding, +) + +from neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + RowParallelLinear, +) +from neuronx_distributed_inference.models.application_base import NeuronApplicationBase +from neuronx_distributed_inference.models.config import InferenceConfig +from neuronx_distributed_inference.models.model_wrapper import ( + EncoderModelInstance, + ModelWrapper, +) +from neuronx_distributed_inference.models.qwen2_vl.modeling_qwen2_vl_vision import ( + Qwen2VLVisionRotaryEmbedding, +) +from neuronx_distributed_inference.models.qwen2_vl.utils.vision_utils import ( + calculate_max_grid_size, + get_image_dimensions, +) +from neuronx_distributed_inference.modules.attention.attention_base import ( + NeuronAttentionBase, +) +from neuronx_distributed_inference.modules.attention.utils import apply_rotary_pos_emb +from neuronx_distributed_inference.modules.padding import ( + pad_tensor, + pad_with_first_batchline, +) + +logger = logging.getLogger(__name__) + + +class VisionRMSNorm(nn.Module): + """RMSNorm for vision encoder (replaces LayerNorm used in Qwen2-VL).""" + + def __init__(self, hidden_size, eps=1e-6, dtype=torch.bfloat16): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size, dtype=dtype)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class VisionSwiGLUMLP(nn.Module): + """SwiGLU MLP for Qwen2.5-Omni vision encoder. + + Uses regular nn.Linear instead of ColumnParallelLinear/RowParallelLinear + because intermediate_size=3420 is not divisible by common TP degrees (16, 32). + The vision model is small enough that MLP weight replication is acceptable. 
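+
+    Computes down_proj(silu(gate_proj(x)) * up_proj(x)); the MLP weights are
+    replicated across ranks rather than sharded.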
+ """ + + def __init__(self, dim, hidden_dim, dtype=torch.bfloat16): + super().__init__() + self.gate_proj = nn.Linear(dim, hidden_dim, bias=True) + self.up_proj = nn.Linear(dim, hidden_dim, bias=True) + self.down_proj = nn.Linear(hidden_dim, dim, bias=True) + # Cast to target dtype + self.gate_proj = self.gate_proj.to(dtype) + self.up_proj = self.up_proj.to(dtype) + self.down_proj = self.down_proj.to(dtype) + + def forward(self, x): + return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class NeuronQwen25OmniVisionAttention(NeuronAttentionBase): + """Vision attention with separate Q/K/V projections (not fused). + + Requires vision_config.neuron_config.fused_qkv = False. + """ + + def __init__(self, config): + super().__init__( + config=config, + hidden_size=config.embed_dim, + num_attention_heads=config.num_heads, + num_key_value_heads=config.num_heads, + head_dim=config.embed_dim // config.num_heads, + num_cores_per_group=config.num_cores_per_group, + sequence_parallel_enabled=False, + rotary_emb=Qwen2VLVisionRotaryEmbedding(), + qkv_bias=True, + o_bias=True, + ) + + def forward(self, hidden_states, position_embeddings=None, **kwargs): + self._position_embeddings = position_embeddings + try: + return super().forward(hidden_states, **kwargs) + finally: + self._position_embeddings = None + + def apply_rotary_embedding( + self, Q, K, V, position_ids, cos_cache, sin_cache, use_polar_compatible_rope + ): + if self.rotary_emb is not None: + if cos_cache is None or sin_cache is None: + cos_cache, sin_cache = self.rotary_emb(V, self._position_embeddings) + Q, K = apply_rotary_pos_emb(Q, K, cos_cache, sin_cache) + return Q, K, cos_cache, sin_cache + + +class Qwen25OmniVisionBlock(nn.Module): + """Vision transformer block with RMSNorm and SwiGLU MLP.""" + + def __init__(self, vision_config): + super().__init__() + dtype = vision_config.neuron_config.torch_dtype + self.norm1 = VisionRMSNorm( + vision_config.embed_dim, eps=1e-6, dtype=dtype + ) + self.norm2 = VisionRMSNorm( + vision_config.embed_dim, eps=1e-6, dtype=dtype + ) + self.attn = NeuronQwen25OmniVisionAttention(vision_config) + self.mlp = VisionSwiGLUMLP( + dim=vision_config.embed_dim, + hidden_dim=vision_config.intermediate_size, + dtype=dtype, + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + attn_output = self.attn( + self.norm1(hidden_states), + position_embeddings=position_embeddings, + )[0] + hidden_states = hidden_states + attn_output + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +class Qwen25OmniPatchMerger(nn.Module): + """Patch merger with RMSNorm (Qwen2-VL uses LayerNorm). + + Merges spatial_merge_size^2 patches into one, projecting from + embed_dim to out_hidden_size (text model hidden size). 
+ """ + + def __init__(self, dim, context_dim, spatial_merge_size=2, dtype=torch.bfloat16): + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size ** 2) + self.ln_q = VisionRMSNorm(context_dim, eps=1e-6, dtype=dtype) + self.mlp = nn.Sequential( + ColumnParallelLinear( + self.hidden_size, + self.hidden_size, + gather_output=False, + dtype=dtype, + ), + nn.GELU(), + RowParallelLinear( + self.hidden_size, + dim, + input_is_parallel=True, + dtype=dtype, + reduce_dtype=dtype, + ), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.mlp(self.ln_q(x).view(-1, self.hidden_size)) + return x + + +class NeuronQwen25OmniVisionModel(nn.Module): + """Qwen2.5-Omni Vision Encoder on Neuron. + + Architecture is based on Qwen2-VL ViT but uses RMSNorm, SwiGLU MLP, + and separate Q/K/V projections. Reuses the same PatchEmbed and + VisionRotaryEmbedding from Qwen2-VL (identical parameters). + """ + + def __init__(self, config: InferenceConfig) -> None: + super().__init__() + self.config = config + self.vision_config = config.vision_config + + self.spatial_merge_size = self.vision_config.spatial_merge_size + + # Reuse Qwen2-VL PatchEmbed (same Conv3D architecture) + self.patch_embed = PatchEmbed( + patch_size=self.vision_config.patch_size, + temporal_patch_size=self.vision_config.temporal_patch_size, + in_channels=self.vision_config.in_channels, + embed_dim=self.vision_config.embed_dim, + ).to(self.vision_config.neuron_config.torch_dtype) + + head_dim = self.vision_config.embed_dim // self.vision_config.num_heads + self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList( + [ + Qwen25OmniVisionBlock(self.vision_config) + for _ in range(self.vision_config.depth) + ] + ) + + self.merger = Qwen25OmniPatchMerger( + dim=self.vision_config.out_hidden_size, # 3584 (text hidden size) + context_dim=self.vision_config.embed_dim, # 1280 + spatial_merge_size=self.vision_config.spatial_merge_size, + dtype=self.vision_config.neuron_config.torch_dtype, + ) + + # Calculate dynamic MAX_GRID_SIZE based on configured image dimensions + image_width, image_height = get_image_dimensions( + self.vision_config.neuron_config + ) + self.max_grid_size = calculate_max_grid_size( + image_width, + image_height, + patch_size=self.vision_config.patch_size, + ) + logger.info( + f"Calculated max_grid_size={self.max_grid_size} for " + f"image dimensions {image_width}x{image_height}" + ) + + self.precomputed_rotary_pos_emb = self.rotary_pos_emb(self.max_grid_size) + self.register_buffer( + "rotary_pos_emb_cache", + self.precomputed_rotary_pos_emb, + persistent=False, + ) + + def rot_pos_ids(self, grid_thw): + """Compute rotary position IDs for patches (same algorithm as Qwen2-VL).""" + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append( + torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1) + ) + pos_ids = torch.cat(pos_ids, dim=0) + return pos_ids + + def pad_to_text_seq_len(self, hidden_states): + """Pad 
vision outputs to text model sequence length.""" + padded_length = self.config.neuron_config.seq_len + hidden_states = hidden_states.to( + self.config.text_config.neuron_config.torch_dtype + ) + + hidden_size = hidden_states.shape[-1] + hidden_states, _ = pad_tensor( + hidden_states, (padded_length, hidden_size), pad_value=0 + ) + + # Flatten vision outputs: (seq_len, hidden_size) -> (1, seq_len, hidden_size) + hidden_states = hidden_states.view(-1, hidden_size).unsqueeze(0) + return hidden_states + + def forward( + self, hidden_states: torch.Tensor, grid_thw: torch.Tensor + ) -> torch.Tensor: + hidden_states = self.patch_embed(hidden_states) + + assert grid_thw[:, 1:].max() < self.max_grid_size, ( + f"Grid size {grid_thw[:, 1:].max()} exceeds max_grid_size " + f"{self.max_grid_size}. Increase default_image_width/height " + f"in vision_neuron_config." + ) + pos_ids = self.rot_pos_ids(grid_thw) + rotary_pos_emb = self.rotary_pos_emb_cache[pos_ids].flatten(1) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + cos_emb = emb.cos() + sin_emb = emb.sin() + cos_emb = cos_emb.reshape(grid_thw.shape[0], -1, cos_emb.shape[-1]) + sin_emb = sin_emb.reshape(grid_thw.shape[0], -1, sin_emb.shape[-1]) + position_embeddings = (cos_emb, sin_emb) + + hidden_states = hidden_states.reshape( + grid_thw.shape[0], -1, hidden_states.shape[-1] + ) + for blk in self.blocks: + hidden_states = blk(hidden_states, position_embeddings) + hidden_states_merger = self.merger(hidden_states) + return self.pad_to_text_seq_len(hidden_states_merger) + + +class Qwen25OmniVisionModelWrapper(ModelWrapper): + """Model wrapper for Qwen2.5-Omni vision encoder. + + Handles bucketing on number of images, padding, and input generation. + Same pattern as Qwen2VLVisionModelWrapper. + """ + + def __init__( + self, + config: InferenceConfig, + model_cls, + tag="", + compiler_args: str = None, + priority_model_idx: int = None, + pipeline_execution: bool = True, + return_ranked_to_cpu: bool = False, + model_init_kwargs={}, + ) -> None: + super().__init__( + config, + model_cls, + tag, + compiler_args, + priority_model_idx, + pipeline_execution, + return_ranked_to_cpu, + model_init_kwargs, + ) + from transformers.models.qwen2_vl.image_processing_qwen2_vl import ( + smart_resize, + ) + + image_width, image_height = get_image_dimensions( + self.config.vision_config.neuron_config + ) + resized_height, resized_width = smart_resize( + width=image_width, height=image_height + ) + self.pixels_per_image = ( + resized_height // self.config.vision_config.patch_size + ) * (resized_width // self.config.vision_config.patch_size) + + def input_generator(self) -> List[Tuple[torch.Tensor]]: + from transformers.models.qwen2_vl.image_processing_qwen2_vl import ( + smart_resize, + ) + + inputs = [] + image_width, image_height = get_image_dimensions( + self.config.vision_config.neuron_config + ) + resized_height, resized_width = smart_resize( + width=image_width, height=image_height + ) + vc = self.config.vision_config + for bucket in vc.neuron_config.buckets: + pixel_values = torch.ones( + [ + bucket * self.pixels_per_image, + vc.in_channels + * vc.patch_size + * vc.patch_size + * vc.temporal_patch_size, + ], + dtype=vc.neuron_config.torch_dtype, + ) + grid_thw = torch.tensor( + [ + [ + 1, + resized_height // vc.patch_size, + resized_width // vc.patch_size, + ] + ] + ).repeat(bucket, 1) + inputs.append((pixel_values, grid_thw)) + return inputs + + def get_model_instance(self): + return EncoderModelInstance(model_cls=self.model_cls, config=self.config) 
+ + def get_padded_num_image(self, pixel_values): + """Get the bucket size (number of images) for given pixel_values.""" + buckets = self.config.vision_config.neuron_config.buckets + for val in buckets: + if val * self.pixels_per_image >= pixel_values.shape[0]: + return val + raise Exception( + f"No bucket found for pixel_values shape {pixel_values.shape[0]}. " + f"pixels_per_image={self.pixels_per_image}, buckets={buckets}" + ) + + def forward(self, pixel_values, grid_thw): + """Override ModelWrapper.forward() with padding to bucket size.""" + if self.model is None: + raise RuntimeError( + "Forward called before load. Run load() or load_state_dict() " + "before calling forward" + ) + padded_num_image = self.get_padded_num_image(pixel_values) + padded_pixel_values = pad_with_first_batchline( + pixel_values, + (padded_num_image * self.pixels_per_image, pixel_values.shape[1]), + ) + padded_grid_thw = pad_with_first_batchline( + grid_thw, (padded_num_image, 3) + ) + output = self._forward(padded_pixel_values, padded_grid_thw) + return output + + +class NeuronQwen25OmniForImageEncoding(NeuronApplicationBase): + """Standalone Neuron Application for Qwen2.5-Omni image encoding. + + Wraps NeuronQwen25OmniVisionModel with compile/load functionality. + Can be used independently or as part of the full multimodal model. + """ + + _model_cls = NeuronQwen25OmniVisionModel + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.model_wrapper = self.get_model_wrapper_cls() + + self.model = self.model_wrapper( + config=self.config, + model_cls=self._model_cls, + tag=self._model_cls.__name__, + compiler_args=self.get_compiler_args(), + priority_model_idx=0, + ) + self.models.append(self.model) + + def get_model_wrapper_cls(self): + return Qwen25OmniVisionModelWrapper + + def forward(self, pixel_values, grid_thw): + return self.models[0](pixel_values, grid_thw) + + def get_compiler_args(self): + compiler_args = ( + "--auto-cast=none --model-type=transformer " + "--tensorizer-options='--enable-ccop-compute-overlap " + "--cc-pipeline-tiling-factor=2 ' -O1 " + "--internal-hlo2tensorizer-options='--verify-hlo=true'" + ) + logger.info( + f"Compiling {self._model_cls.__name__} vision model " + f"with args: {compiler_args}" + ) + return compiler_args + + @staticmethod + def update_state_dict_for_tied_weights(state_dict): + pass + + @staticmethod + def load_hf_model(model_path, **kwargs): + """Load the full Qwen2.5-Omni model (vision weights will be filtered + in convert_hf_to_neuron_state_dict).""" + from transformers import AutoModelForCausalLM + + return AutoModelForCausalLM.from_pretrained( + model_path, trust_remote_code=True, **kwargs + ) + + @staticmethod + def convert_hf_to_neuron_state_dict( + state_dict: dict, inference_config: InferenceConfig + ) -> dict: + """Convert HF state dict to NxDI format for vision encoder. + + Handles: + 1. Filter to vision keys only (thinker.visual.* or visual.*) + 2. Strip prefix + 3. Rename separate Q/K/V attention keys to NeuronAttentionBase format + 4. Cast to target dtype + """ + new_state_dict = {} + dtype = inference_config.vision_config.neuron_config.torch_dtype + + for key, value in state_dict.items(): + # Accept both "thinker.visual." and "visual." 
prefixes + if key.startswith("thinker.visual."): + new_key = key[len("thinker.visual."):] + elif key.startswith("visual."): + new_key = key[len("visual."):] + else: + # Pass through non-vision keys unchanged + new_state_dict[key] = value + continue + + # Rename attention keys: separate Q/K/V -> NeuronAttentionBase format + if ".attn.proj." in new_key: + new_key = new_key.replace(".attn.proj.", ".attn.o_proj.") + elif ".attn.q." in new_key: + new_key = new_key.replace( + ".attn.q.", ".attn.qkv_proj.q_proj." + ) + elif ".attn.k." in new_key: + new_key = new_key.replace( + ".attn.k.", ".attn.qkv_proj.k_proj." + ) + elif ".attn.v." in new_key: + new_key = new_key.replace( + ".attn.v.", ".attn.qkv_proj.v_proj." + ) + + new_state_dict[new_key] = ( + value.clone().detach().contiguous().to(dtype) + ) + + del state_dict + return new_state_dict + + @classmethod + def get_config_cls(cls): + from modeling_qwen25_omni import ( + Qwen25OmniMultimodalInferenceConfig, + ) + + return Qwen25OmniMultimodalInferenceConfig + + @classmethod + def prepare_input_args(cls, prompts, images, processor, role="user", config=None): + """Prepare input arguments for Qwen2.5-Omni vision model.""" + from neuronx_distributed_inference.models.qwen2_vl.utils.input_processor import ( + prepare_generation_inputs_hf, + ) + + if len(prompts) > 1: + raise NotImplementedError( + "Qwen2.5-Omni currently only supports batch size 1" + ) + if isinstance(prompts, list): + prompts = prompts[0] + if images and isinstance(images, list) and isinstance(images[0], list): + images = images[0] + inputs = prepare_generation_inputs_hf( + prompts, images, processor, role, config + ) + vision_inputs = None + if hasattr(inputs, "pixel_values") and hasattr(inputs, "image_grid_thw"): + vision_inputs = { + "pixel_values": inputs.pixel_values, + "image_grid_thw": inputs.image_grid_thw, + } + return inputs.input_ids, inputs.attention_mask, vision_inputs diff --git a/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen2_5_omni.py b/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen2_5_omni.py deleted file mode 100644 index 89e929cd..00000000 --- a/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen2_5_omni.py +++ /dev/null @@ -1,620 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The Qwen team, Alibaba Group and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -PyTorch Qwen2.5-Omni model for NXD inference (Text-only version) - -This implementation ports the text model (Thinker) from Qwen2.5-Omni to NeuronX Distributed Inference. -It focuses on text-only inference, ignoring multimodal (audio/vision) components. 
- -Based on: -- Reference: NeuronxDistributedInference/src/neuronx_distributed_inference/models/qwen2/modeling_qwen2.py -""" -import json -import os -from typing import List, Optional, Tuple, Type - -import torch -from neuronx_distributed.parallel_layers.layers import ( - ColumnParallelLinear, - ParallelEmbedding, - RowParallelLinear, -) -from neuronx_distributed.utils import cpu_mode -from torch import nn -from transformers.models.llama.modeling_llama import LlamaRMSNorm - -from neuronx_distributed_inference.models.config import InferenceConfig, NeuronConfig -from neuronx_distributed_inference.models.model_base import ( - NeuronBaseForCausalLM, - NeuronBaseModel, -) -from neuronx_distributed_inference.modules.attention.attention_base import NeuronAttentionBase -from neuronx_distributed_inference.modules.attention.utils import RotaryEmbedding -from neuronx_distributed_inference.modules.custom_calls import CustomRMSNorm - - -def get_rmsnorm_cls(): - """ - Initialize to the appropriate implementation of RMSNorm - If infer on NXD -> CustomRMSNorm - If infer on CPU -> HF_RMSNorm (CustomRMSNorm does not work on CPU) - """ - return LlamaRMSNorm if cpu_mode() else CustomRMSNorm - - -class Qwen2_5OmniNeuronConfig(NeuronConfig): - """NeuronConfig for Qwen2.5-Omni model""" - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.attn_cls = NeuronQwen2_5OmniAttention - - -class Qwen2_5OmniInferenceConfig(InferenceConfig): - """ - Configuration class for Qwen2.5-Omni inference on NeuronX. - - This config handles the text model (Thinker) from Qwen2.5-Omni. - The thinker_config.text_config contains the core text model parameters. - """ - - def add_derived_config(self): - """Add derived configuration parameters""" - self.num_cores_per_group = 1 - self.qkv_bias = True # Qwen2.5-Omni has bias in Q/K/V projections - self.o_bias = False # No bias in output projection - - # Handle layer types for sliding window attention - # Default to all full attention if not specified - if not hasattr(self, 'layer_types') or self.layer_types is None: - self.layer_types = ['full_attention'] * self.num_hidden_layers - - # Multimodal RoPE section for 3D position embeddings - # [temporal, height, width] sections - for text-only, all positions are same - if not hasattr(self, 'mrope_section'): - self.mrope_section = [16, 24, 24] # Default from config - - # Add standard HuggingFace config attributes required by NeuronBaseModel - if not hasattr(self, 'output_attentions'): - self.output_attentions = False - if not hasattr(self, 'output_hidden_states'): - self.output_hidden_states = False - if not hasattr(self, 'use_return_dict'): - self.use_return_dict = True - if not hasattr(self, 'use_cache'): - self.use_cache = True - - def get_required_attributes(self) -> List[str]: - """List of required attributes for the configuration""" - return [ - "hidden_size", - "num_attention_heads", - "num_hidden_layers", - "num_key_value_heads", - "pad_token_id", - "vocab_size", - "max_position_embeddings", - "rope_theta", - "rms_norm_eps", - "hidden_act", - "intermediate_size", - ] - - @classmethod - def get_neuron_config_cls(cls) -> Type[Qwen2_5OmniNeuronConfig]: - """Return the NeuronConfig class to use""" - return Qwen2_5OmniNeuronConfig - - @classmethod - def from_pretrained(cls, model_path: str, **kwargs) -> "Qwen2_5OmniInferenceConfig": - """ - Load configuration from a pretrained Qwen2.5-Omni model directory. 
- - The Qwen2.5-Omni config has a nested structure: - config.json -> thinker_config -> text_config (the actual text model config) - - Args: - model_path: Path to the model directory containing config.json - **kwargs: Additional arguments to override configuration - - Returns: - Qwen2_5OmniInferenceConfig: Configuration object - """ - # Extract neuron_config from kwargs if it exists - neuron_config = kwargs.pop("neuron_config", None) - - # Try loading saved neuron config if not provided - # Try multiple possible locations - if neuron_config is None: - possible_paths = [ - os.path.join(model_path, "neuron_config.json"), - "neuron_config.json", # Current directory - ] - - for neuron_config_path in possible_paths: - if os.path.exists(neuron_config_path): - print(f"Loading neuron_config from: {neuron_config_path}") - with open(neuron_config_path, "r") as f: - neuron_config_data = json.load(f) - # The saved config has the neuron_config nested - if "neuron_config" in neuron_config_data: - neuron_config_dict = neuron_config_data["neuron_config"] - else: - neuron_config_dict = neuron_config_data - neuron_config = cls.get_neuron_config_cls()(**neuron_config_dict) - break - - # Read the full config.json - config_path = os.path.join(model_path, "config.json") - with open(config_path, "r") as f: - full_config = json.load(f) - - # Navigate to the text model config - # Path: config.json -> thinker_config -> text_config - thinker_config = full_config.get("thinker_config", {}) - text_config = thinker_config.get("text_config", {}) - - if not text_config: - raise ValueError( - f"Could not find text_config in {config_path}. " - "Expected structure: config.json -> thinker_config -> text_config" - ) - - # Extract configuration parameters from text_config - config_dict = { - "hidden_size": text_config.get("hidden_size"), - "num_attention_heads": text_config.get("num_attention_heads"), - "num_hidden_layers": text_config.get("num_hidden_layers"), - "num_key_value_heads": text_config.get("num_key_value_heads"), - "vocab_size": text_config.get("vocab_size"), - "max_position_embeddings": text_config.get("max_position_embeddings"), - "intermediate_size": text_config.get("intermediate_size"), - "rms_norm_eps": text_config.get("rms_norm_eps"), - "rope_theta": text_config.get("rope_theta"), - "hidden_act": text_config.get("hidden_act"), - "sliding_window": text_config.get("sliding_window"), - "use_sliding_window": text_config.get("use_sliding_window", False), - } - - # Extract pad_token_id from thinker_config (not in text_config) - config_dict["pad_token_id"] = thinker_config.get("pad_token_id") - - # Extract rope_scaling if present - if "rope_scaling" in text_config and text_config["rope_scaling"]: - rope_scaling = text_config["rope_scaling"] - config_dict["rope_scaling"] = rope_scaling - # Extract mrope_section for multimodal RoPE - config_dict["mrope_section"] = rope_scaling.get("mrope_section", [16, 24, 24]) - - # Handle layer_types for sliding window attention - # Qwen2.5-Omni alternates between full and sliding attention - num_layers = config_dict["num_hidden_layers"] - if config_dict.get("use_sliding_window"): - # Alternate between full and sliding attention - config_dict["layer_types"] = ["sliding_attention" if i % 2 else "full_attention" - for i in range(num_layers)] - else: - # All layers use full attention - config_dict["layer_types"] = ["full_attention"] * num_layers - - # Override with kwargs - config_dict.update(kwargs) - - # Create and return config - config = cls(neuron_config=neuron_config, 
**config_dict) - return config - - -class NeuronQwen2_5OmniAttention(NeuronAttentionBase): - """ - Qwen2.5-Omni attention mechanism for NeuronX. - - Based on NeuronQwen2Attention but with multimodal RoPE support. - The multimodal RoPE is handled at the model level, so this class - uses standard NeuronAttentionBase with bias configurations. - - Reference: - - HF: Qwen2_5OmniAttention in modeling_qwen2_5_omni.py - - NXD: NeuronQwen2Attention in modeling_qwen2.py - """ - - def __init__(self, config: Qwen2_5OmniInferenceConfig, layer_idx: int = 0): - """ - Initialize Qwen2.5-Omni attention. - - Args: - config: Model configuration - layer_idx: Layer index (used for sliding window) - """ - self.layer_idx = layer_idx - - # Create rotary embedding - rotary_emb = RotaryEmbedding( - config.hidden_size // config.num_attention_heads, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - ) - - # Determine if this layer uses sliding window - sliding_window = None - if hasattr(config, 'layer_types') and config.layer_types: - if config.layer_types[layer_idx] == "sliding_attention": - sliding_window = getattr(config, 'sliding_window', None) - - # Initialize base attention - super().__init__( - config=config, - hidden_size=config.hidden_size, - num_attention_heads=config.num_attention_heads, - num_key_value_heads=config.num_key_value_heads, - head_dim=config.hidden_size // config.num_attention_heads, - qkv_bias=config.qkv_bias, # Qwen2.5-Omni has bias in QKV - o_bias=config.o_bias, # No bias in output - rotary_emb=rotary_emb, - sliding_window=sliding_window, - ) - - -class NeuronQwen2_5OmniMLP(nn.Module): - """ - Qwen2.5-Omni MLP layer (same as Qwen2/Llama - SwiGLU activation). - - Architecture: - - gate_proj: Linear(hidden_size, intermediate_size) - - up_proj: Linear(hidden_size, intermediate_size) - - down_proj: Linear(intermediate_size, hidden_size) - - activation: SwiGLU = silu(gate_proj(x)) * up_proj(x) - - Reference: - - HF: Qwen2MLP in modeling_qwen2_5_omni.py - - NXD: NeuronLlamaMLP in modeling_llama.py - """ - - def __init__(self, config: Qwen2_5OmniInferenceConfig): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - - # Gate projection (for SwiGLU) - self.gate_proj = ColumnParallelLinear( - config.hidden_size, - config.intermediate_size, - bias=False, - gather_output=False, - dtype=config.neuron_config.torch_dtype, - ) - - # Up projection (for SwiGLU) - self.up_proj = ColumnParallelLinear( - config.hidden_size, - config.intermediate_size, - bias=False, - gather_output=False, - dtype=config.neuron_config.torch_dtype, - ) - - # Down projection - self.down_proj = RowParallelLinear( - config.intermediate_size, - config.hidden_size, - bias=False, - input_is_parallel=True, - dtype=config.neuron_config.torch_dtype, - ) - - # Activation function (SiLU for SwiGLU) - self.act_fn = nn.SiLU() - - def forward(self, x): - """ - Forward pass with SwiGLU activation. 
- - Args: - x: Input tensor [batch, seq_len, hidden_size] - - Returns: - Tuple of (output, None) for compatibility with NeuronBaseModel - """ - # SwiGLU: silu(gate_proj(x)) * up_proj(x) - gate_output = self.act_fn(self.gate_proj(x)) - up_output = self.up_proj(x) - intermediate_output = gate_output * up_output - - # Apply down projection - output = self.down_proj(intermediate_output) - - return output, None # Return None as second output for compatibility - - -class NeuronQwen2_5OmniDecoderLayer(nn.Module): - """ - Qwen2.5-Omni decoder layer for NeuronX. - - Architecture (pre-norm): - 1. hidden = input + self_attn(norm(input)) - 2. output = hidden + mlp(norm(hidden)) - - Reference: - - HF: Qwen2_5OmniDecoderLayer in modeling_qwen2_5_omni.py - - NXD: NeuronQwen2DecoderLayer in modeling_qwen2.py - """ - - def __init__(self, config: Qwen2_5OmniInferenceConfig, layer_idx: int): - super().__init__() - self.hidden_size = config.hidden_size - self.layer_idx = layer_idx - - # Self-attention with layer-specific sliding window - self.self_attn = NeuronQwen2_5OmniAttention(config, layer_idx=layer_idx) - - # MLP (SwiGLU) - self.mlp = NeuronQwen2_5OmniMLP(config) - - # Layer normalization (RMSNorm) - self.input_layernorm = get_rmsnorm_cls()( - config.hidden_size, - eps=config.rms_norm_eps, - ) - self.post_attention_layernorm = get_rmsnorm_cls()( - config.hidden_size, - eps=config.rms_norm_eps, - ) - - # Store attention type for this layer - self.attention_type = config.layer_types[layer_idx] if hasattr(config, 'layer_types') else 'full_attention' - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Forward pass for decoder layer. - - Args: - hidden_states: Input tensor [batch, seq_len, hidden_size] - attention_mask: Attention mask - position_ids: Position indices - past_key_value: Cached key-value pairs - - Returns: - Tuple of (hidden_states, present_key_value, cos_cache, sin_cache, None) - """ - # Pre-norm: normalize before attention - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - - # Self-attention - hidden_states, present_key_value, cos_cache, sin_cache = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - **kwargs, - ) - - # Residual connection - hidden_states = residual + hidden_states - - # Pre-norm: normalize before MLP - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - - # MLP - hidden_states = self.mlp(hidden_states)[0] # Take first element of tuple - - # Residual connection - hidden_states = residual + hidden_states - - outputs = (hidden_states, present_key_value, cos_cache, sin_cache, None) - - return outputs - - -class NeuronQwen2_5OmniModel(NeuronBaseModel): - """ - Qwen2.5-Omni text model for NeuronX inference. - - This implements the core text model (Thinker) from Qwen2.5-Omni, - focusing on text-only inference without multimodal components. 
- - Architecture: - - Token embeddings - - Stack of decoder layers with GQA and SwiGLU - - RMSNorm - - LM head for token generation - - Reference: - - HF: Qwen2_5OmniThinkerTextModel in modeling_qwen2_5_omni.py - - NXD: NeuronQwen2Model in modeling_qwen2.py - """ - - def setup_attr_for_model(self, config: Qwen2_5OmniInferenceConfig): - """Setup attributes for model initialization""" - self.on_device_sampling = config.neuron_config.on_device_sampling_config is not None - self.tp_degree = config.neuron_config.tp_degree - self.hidden_size = config.hidden_size - self.num_attention_heads = config.num_attention_heads - self.num_key_value_heads = config.num_key_value_heads - self.max_batch_size = config.neuron_config.max_batch_size - self.buckets = config.neuron_config.buckets - - def init_model(self, config: Qwen2_5OmniInferenceConfig): - """Initialize the model components""" - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - # Token embeddings - self.embed_tokens = ParallelEmbedding( - config.vocab_size, - config.hidden_size, - self.padding_idx, - dtype=config.neuron_config.torch_dtype, - shard_across_embedding=True, - pad=True, - ) - - # Decoder layers - self.layers = nn.ModuleList( - [NeuronQwen2_5OmniDecoderLayer(config, layer_idx) - for layer_idx in range(config.num_hidden_layers)] - ) - - # Final normalization - self.norm = get_rmsnorm_cls()(config.hidden_size, eps=config.rms_norm_eps) - - # LM head for token generation - self.lm_head = ColumnParallelLinear( - config.hidden_size, - config.vocab_size, - bias=False, - pad=True, - gather_output=not self.on_device_sampling, - dtype=config.neuron_config.torch_dtype, - ) - - -class NeuronQwen2_5OmniForCausalLM(NeuronBaseForCausalLM): - """ - Qwen2.5-Omni for Causal Language Modeling on NeuronX. - - This is the main entry point for using Qwen2.5-Omni on NeuronX. - It provides a HuggingFace-compatible interface for text generation. - - Usage: - config = Qwen2_5OmniInferenceConfig.from_pretrained(model_path, neuron_config=neuron_config) - model = NeuronQwen2_5OmniForCausalLM(config) - model.load_state_dict(state_dict) - model.compile_model() - model.save_pretrained(output_path) - - Reference: - - HF: Qwen2_5OmniThinkerForConditionalGeneration in modeling_qwen2_5_omni.py - - NXD: NeuronQwen2ForCausalLM in modeling_qwen2.py - """ - - _model_cls = NeuronQwen2_5OmniModel - - @staticmethod - def convert_hf_to_neuron_state_dict(state_dict, config): - """ - Convert HuggingFace Qwen2.5-Omni state dict to NeuronX format. - - The Qwen2.5-Omni checkpoint has a nested structure with the text model under - 'model.thinker.model' prefix. This function extracts and renames the weights. 
- - NeuronAttentionBase expects QKV weights in the format: - - layers.{i}.self_attn.qkv_proj.q_proj.weight (for separate Q/K/V) - - layers.{i}.self_attn.qkv_proj.q_proj.bias - - Weight mappings: - - thinker.model.embed_tokens.weight -> embed_tokens.weight - - thinker.model.layers.{i}.self_attn.q_proj.* -> layers.{i}.self_attn.qkv_proj.q_proj.* - - thinker.model.layers.{i}.self_attn.k_proj.* -> layers.{i}.self_attn.qkv_proj.k_proj.* - - thinker.model.layers.{i}.self_attn.v_proj.* -> layers.{i}.self_attn.qkv_proj.v_proj.* - - thinker.model.layers.{i}.self_attn.o_proj.* -> layers.{i}.self_attn.o_proj.* - - thinker.model.layers.{i}.mlp.* -> layers.{i}.mlp.* - - thinker.model.norm.weight -> norm.weight - - thinker.lm_head.weight -> lm_head.weight - - Args: - state_dict: HuggingFace state dictionary - config: Model configuration - - Returns: - Dictionary with NeuronX-formatted weights - """ - import torch - - neuron_state_dict = {} - - # Remove prefixes: either "model.thinker.model." or just "thinker.model." - # The actual prefix might vary - possible_prefixes = [ - "model.thinker.model.", - "thinker.model.", - "model.thinker.", - "thinker.", - "" - ] - - # Detect which prefix is used - actual_prefix = "" - for prefix in possible_prefixes: - test_key = f"{prefix}embed_tokens.weight" - if test_key in state_dict: - actual_prefix = prefix - break - - print(f"Detected HF weight prefix: '{actual_prefix}'") - - for name, param in state_dict.items(): - # Skip weights not belonging to the text model (thinker) - if not name.startswith(actual_prefix) and actual_prefix != "": - # Check if this is the lm_head - lm_head_patterns = ["model.thinker.lm_head.", "thinker.lm_head.", "lm_head."] - is_lm_head = any(name.startswith(p) for p in lm_head_patterns) - if not is_lm_head: - continue - - # Remove the prefix - if actual_prefix and name.startswith(actual_prefix): - new_name = name[len(actual_prefix):] - else: - # Handle lm_head separately - for lm_prefix in ["model.thinker.lm_head.", "thinker.lm_head.", "lm_head."]: - if name.startswith(lm_prefix): - new_name = name[len(lm_prefix):] - new_name = "lm_head." + new_name - break - else: - new_name = name - - # Map attention weights to qkv_proj structure - if ".self_attn.q_proj." in new_name: - new_name = new_name.replace(".self_attn.q_proj.", ".self_attn.qkv_proj.q_proj.") - elif ".self_attn.k_proj." in new_name: - new_name = new_name.replace(".self_attn.k_proj.", ".self_attn.qkv_proj.k_proj.") - elif ".self_attn.v_proj." 
in new_name: - new_name = new_name.replace(".self_attn.v_proj.", ".self_attn.qkv_proj.v_proj.") - - # Clone and store the parameter - neuron_state_dict[new_name] = param.clone() - - print(f"Converted {len(state_dict)} HF weights to {len(neuron_state_dict)} Neuron weights") - - # Verify key weights exist - required_keys = ["embed_tokens.weight", "norm.weight", "lm_head.weight"] - for key in required_keys: - if key not in neuron_state_dict: - print(f"⚠️ Warning: Required key '{key}' not found in converted state dict") - - # Verify layer 0 attention weights - layer0_attn_keys = [ - "layers.0.self_attn.qkv_proj.q_proj.weight", - "layers.0.self_attn.qkv_proj.k_proj.weight", - "layers.0.self_attn.qkv_proj.v_proj.weight", - "layers.0.self_attn.o_proj.weight" - ] - for key in layer0_attn_keys: - if key not in neuron_state_dict: - print(f"⚠️ Warning: Layer 0 attention key '{key}' not found") - - return neuron_state_dict diff --git a/contrib/models/Qwen2.5-Omni-7B/test/integration/test_e2e_qwen25_omni.py b/contrib/models/Qwen2.5-Omni-7B/test/integration/test_e2e_qwen25_omni.py new file mode 100644 index 00000000..835f8313 --- /dev/null +++ b/contrib/models/Qwen2.5-Omni-7B/test/integration/test_e2e_qwen25_omni.py @@ -0,0 +1,424 @@ +#!/usr/bin/env python3 +"""End-to-end test for Qwen2.5-Omni on Trn2 (CPU inference). + +Tests multimodal input (text, image, audio) → text and audio output +using HF Qwen2_5OmniForConditionalGeneration directly. + +Usage: + source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + python3 test_e2e_qwen25_omni.py [--model-path MODEL_PATH] [--output-dir OUTPUT_DIR] +""" + +import argparse +import gc +import json +import logging +import os +import sys +import time +import traceback + +import numpy as np +import torch + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") +logger = logging.getLogger(__name__) + +DEFAULT_MODEL = "Qwen/Qwen2.5-Omni-7B" +DEFAULT_OUTPUT = "/home/ubuntu/e2e_test_results" + +# Qwen2.5-Omni requires this specific system prompt for audio output +SYSTEM_PROMPT = ( + "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, " + "capable of perceiving auditory and visual inputs, as well as generating text and speech." 
+) + + +def download_test_assets(output_dir): + """Create sample image and test audio.""" + from PIL import Image, ImageDraw + import wave + + assets_dir = os.path.join(output_dir, "assets") + os.makedirs(assets_dir, exist_ok=True) + + image_path = os.path.join(assets_dir, "test_image.jpg") + if not os.path.exists(image_path): + logger.info("Creating test image...") + img = Image.new("RGB", (320, 240), "white") + draw = ImageDraw.Draw(img) + draw.rectangle([20, 20, 100, 100], fill="red", outline="black") + draw.ellipse([130, 30, 230, 130], fill="blue", outline="black") + draw.polygon([(260, 120), (310, 20), (210, 20)], fill="green", outline="black") + draw.rectangle([0, 160, 320, 240], fill="lightgreen") + draw.ellipse([220, 60, 300, 140], fill="yellow") + img.save(image_path) + logger.info("Saved %s", image_path) + + audio_path = os.path.join(assets_dir, "test_audio.wav") + if not os.path.exists(audio_path): + logger.info("Creating test audio...") + sr = 16000 + t = np.linspace(0, 2.0, sr * 2, endpoint=False) + audio = (0.5 * np.sin(2 * np.pi * 440 * t) * 32767).astype(np.int16) + with wave.open(audio_path, "w") as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(sr) + wf.writeframes(audio.tobytes()) + logger.info("Saved %s", audio_path) + + return image_path, audio_path + + +def load_model(model_path): + """Load Qwen2.5-Omni model and processor.""" + from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor + + logger.info("Loading processor from %s ...", model_path) + processor = Qwen2_5OmniProcessor.from_pretrained(model_path) + + logger.info("Loading model from %s (this may take a few minutes)...", model_path) + start = time.time() + model = Qwen2_5OmniForConditionalGeneration.from_pretrained( + model_path, + dtype=torch.bfloat16, + device_map="cpu", + ) + model.eval() + elapsed = time.time() - start + logger.info("Model loaded in %.1fs", elapsed) + + return model, processor + + +# ============================================================================ +# Test 1: Text → Text +# ============================================================================ + +def test_text_only(model, processor, output_dir): + logger.info("=" * 60) + logger.info("TEST 1: Text → Text") + logger.info("=" * 60) + + prompt = "What is the capital of France? Answer in one sentence." 
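+    # With return_audio=False only the Thinker runs, so generate() is expected to
+    # return plain token ids of shape [batch, prompt_len + new_tokens]; the decode
+    # below slices off the prompt. Speech output (Talker + Token2Wav) is exercised
+    # separately in Test 4.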
+ + try: + messages = [ + {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]}, + {"role": "user", "content": [{"type": "text", "text": prompt}]}, + ] + + text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + inputs = processor(text=text, return_tensors="pt") + + start = time.time() + with torch.no_grad(): + output_ids = model.generate( + **inputs, + return_audio=False, + thinker_max_new_tokens=100, + ) + elapsed = time.time() - start + + response = processor.batch_decode( + output_ids[:, inputs["input_ids"].shape[1]:], + skip_special_tokens=True, + )[0] + + path = os.path.join(output_dir, "test1_text_response.txt") + with open(path, "w") as f: + f.write(f"Prompt: {prompt}\n\nResponse: {response}\n\nTime: {elapsed:.1f}s\n") + + logger.info("Response: %s", response) + logger.info("Time: %.1fs", elapsed) + return True, response, elapsed + + except Exception as e: + logger.error("Test 1 FAILED: %s", e) + traceback.print_exc() + return False, str(e), 0 + + +# ============================================================================ +# Test 2: Image + Text → Text +# ============================================================================ + +def test_image_text(model, processor, output_dir, image_path): + logger.info("=" * 60) + logger.info("TEST 2: Image + Text → Text") + logger.info("=" * 60) + + prompt = "Describe this image in detail. What shapes and colors do you see?" + + try: + from PIL import Image + image = Image.open(image_path).convert("RGB") + + messages = [ + {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]}, + { + "role": "user", + "content": [ + {"type": "image", "image": image}, + {"type": "text", "text": prompt}, + ], + }, + ] + + text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + inputs = processor(text=text, images=[image], return_tensors="pt") + + start = time.time() + with torch.no_grad(): + output_ids = model.generate( + **inputs, + return_audio=False, + thinker_max_new_tokens=200, + ) + elapsed = time.time() - start + + response = processor.batch_decode( + output_ids[:, inputs["input_ids"].shape[1]:], + skip_special_tokens=True, + )[0] + + path = os.path.join(output_dir, "test2_image_response.txt") + with open(path, "w") as f: + f.write(f"Prompt: {prompt}\nImage: {image_path}\n\nResponse: {response}\n\nTime: {elapsed:.1f}s\n") + + logger.info("Response: %s", response[:300]) + logger.info("Time: %.1fs", elapsed) + return True, response, elapsed + + except Exception as e: + logger.error("Test 2 FAILED: %s", e) + traceback.print_exc() + return False, str(e), 0 + + +# ============================================================================ +# Test 3: Audio + Text → Text +# ============================================================================ + +def test_audio_text(model, processor, output_dir, audio_path): + logger.info("=" * 60) + logger.info("TEST 3: Audio + Text → Text") + logger.info("=" * 60) + + prompt = "What do you hear in this audio? Describe it." 
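+    # The test WAV produced by download_test_assets() is 16-bit PCM mono at 16 kHz
+    # containing a 440 Hz tone; it is read below and rescaled to float32 in [-1, 1]
+    # before being handed to the processor, so a sensible answer describes a beep
+    # or steady tone.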
+ + try: + import wave + + with wave.open(audio_path, "r") as wf: + sr = wf.getframerate() + frames = wf.readframes(wf.getnframes()) + audio_data = np.frombuffer(frames, dtype=np.int16).astype(np.float32) / 32768.0 + + messages = [ + {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]}, + { + "role": "user", + "content": [ + {"type": "audio", "audio": audio_data, "sampling_rate": sr}, + {"type": "text", "text": prompt}, + ], + }, + ] + + text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + inputs = processor( + text=text, + audios=[audio_data], + sampling_rate=sr, + return_tensors="pt", + ) + + start = time.time() + with torch.no_grad(): + output_ids = model.generate( + **inputs, + return_audio=False, + thinker_max_new_tokens=200, + ) + elapsed = time.time() - start + + response = processor.batch_decode( + output_ids[:, inputs["input_ids"].shape[1]:], + skip_special_tokens=True, + )[0] + + path = os.path.join(output_dir, "test3_audio_response.txt") + with open(path, "w") as f: + f.write(f"Prompt: {prompt}\nAudio: {audio_path}\n\nResponse: {response}\n\nTime: {elapsed:.1f}s\n") + + logger.info("Response: %s", response[:300]) + logger.info("Time: %.1fs", elapsed) + return True, response, elapsed + + except Exception as e: + logger.error("Test 3 FAILED: %s", e) + traceback.print_exc() + return False, str(e), 0 + + +# ============================================================================ +# Test 4: Text → Speech +# ============================================================================ + +def test_speech_output(model, processor, output_dir): + logger.info("=" * 60) + logger.info("TEST 4: Text → Speech") + logger.info("=" * 60) + + prompt = "Say hello and tell me what the weather is like today." 
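+    # This is the only test that sets return_audio=True; in that mode generate() is
+    # expected to return a (text_ids, audio_waveform) tuple, handled defensively
+    # below. speaker="Chelsie" selects one of the model's built-in voices, and the
+    # waveform is written out at 24 kHz, the sampling rate produced by Token2Wav.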
+ + try: + messages = [ + {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]}, + {"role": "user", "content": [{"type": "text", "text": prompt}]}, + ] + + text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + inputs = processor(text=text, return_tensors="pt") + + start = time.time() + with torch.no_grad(): + result = model.generate( + **inputs, + return_audio=True, + thinker_max_new_tokens=200, + talker_max_new_tokens=2000, + speaker="Chelsie", + ) + elapsed = time.time() - start + + # result is (text_ids, audio_waveform) when return_audio=True + text_response = "" + audio_waveform = None + + if isinstance(result, tuple) and len(result) >= 2: + text_ids, audio_waveform = result[0], result[1] + text_response = processor.batch_decode( + text_ids[:, inputs["input_ids"].shape[1]:], + skip_special_tokens=True, + )[0] + else: + text_ids = result if not isinstance(result, tuple) else result[0] + text_response = processor.batch_decode( + text_ids[:, inputs["input_ids"].shape[1]:], + skip_special_tokens=True, + )[0] + + # Save text + text_path = os.path.join(output_dir, "test4_speech_text.txt") + with open(text_path, "w") as f: + f.write(f"Prompt: {prompt}\n\nResponse: {text_response}\n\nTime: {elapsed:.1f}s\n") + if audio_waveform is not None: + f.write(f"Audio: generated ({type(audio_waveform)})\n") + else: + f.write("Audio: none returned\n") + + # Save audio + wav_path = os.path.join(output_dir, "test4_speech_response.wav") + if audio_waveform is not None: + import wave as wave_mod + + if isinstance(audio_waveform, torch.Tensor): + audio_np = audio_waveform.cpu().float().numpy() + else: + audio_np = np.array(audio_waveform, dtype=np.float32) + + if audio_np.ndim > 1: + audio_np = audio_np.squeeze() + + max_val = max(abs(audio_np.max()), abs(audio_np.min()), 1e-8) + audio_int16 = (audio_np / max_val * 32767).astype(np.int16) + + with wave_mod.open(wav_path, "w") as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(24000) + wf.writeframes(audio_int16.tobytes()) + + logger.info("Speech saved: %s (%d samples, %.1fs audio)", + wav_path, len(audio_int16), len(audio_int16) / 24000) + else: + logger.warning("No audio waveform returned.") + with open(os.path.join(output_dir, "test4_no_audio.txt"), "w") as f: + f.write("model.generate(return_audio=True) did not return audio waveform.\n" + "This may require spk_dict.pt or the Talker/Token2Wav to be initialized.\n") + + logger.info("Text: %s", text_response[:300]) + logger.info("Time: %.1fs", elapsed) + return True, text_response, elapsed + + except Exception as e: + logger.error("Test 4 FAILED: %s", e) + traceback.print_exc() + return False, str(e), 0 + + +# ============================================================================ +# Main +# ============================================================================ + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", default=DEFAULT_MODEL) + parser.add_argument("--output-dir", default=DEFAULT_OUTPUT) + args = parser.parse_args() + + os.makedirs(args.output_dir, exist_ok=True) + + logger.info("=" * 60) + logger.info("Qwen2.5-Omni End-to-End Test") + logger.info("Model: %s", args.model_path) + logger.info("Output: %s", args.output_dir) + logger.info("=" * 60) + + image_path, audio_path = download_test_assets(args.output_dir) + model, processor = load_model(args.model_path) + + results = {} + + for name, fn, extra_args in [ + ("test1_text_to_text", test_text_only, []), + ("test2_image_text", 
test_image_text, [image_path]), + ("test3_audio_text", test_audio_text, [audio_path]), + ("test4_text_to_speech", test_speech_output, []), + ]: + ok, resp, t = fn(model, processor, args.output_dir, *extra_args) + results[name] = {"passed": ok, "response": resp[:500], "time": t} + gc.collect() + + passed = sum(1 for r in results.values() if r["passed"]) + total = len(results) + + summary_path = os.path.join(args.output_dir, "summary.txt") + with open(summary_path, "w") as f: + f.write("=" * 60 + "\n") + f.write("Qwen2.5-Omni End-to-End Test Results\n") + f.write(f"Model: {args.model_path}\n") + f.write("=" * 60 + "\n\n") + f.write(f"Results: {passed}/{total} passed\n\n") + for name, r in results.items(): + status = "PASS" if r["passed"] else "FAIL" + f.write(f"[{status}] {name} ({r['time']:.1f}s)\n") + f.write(f" {r['response'][:300]}\n\n") + + with open(os.path.join(args.output_dir, "results.json"), "w") as f: + json.dump(results, f, indent=2, ensure_ascii=False) + + logger.info("=" * 60) + logger.info("FINAL: %d/%d tests passed", passed, total) + logger.info("Results: %s", args.output_dir) + logger.info("=" * 60) + + if passed < total: + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/contrib/models/Qwen2.5-Omni-7B/test/integration/test_model.py b/contrib/models/Qwen2.5-Omni-7B/test/integration/test_model.py index 633d8387..2fde653c 100755 --- a/contrib/models/Qwen2.5-Omni-7B/test/integration/test_model.py +++ b/contrib/models/Qwen2.5-Omni-7B/test/integration/test_model.py @@ -1,236 +1,810 @@ #!/usr/bin/env python3 """ -Integration tests for Qwen2.5-Omni-7B NeuronX implementation. -""" +Integration tests for Qwen2.5-Omni-7B on NeuronX (TP=4). -import pytest -import torch -import json -from pathlib import Path -from transformers import AutoTokenizer, GenerationConfig +Tests: + 1. Import validation + 2. Config creation and TP=4 head divisibility + 3. State dict conversion (all 2448 keys) + 4. Audio encoder CPU components (frontend + postprocessor) + 5. Talker CPU model (weight loading + codec tokens) + 6. Text-only Thinker compile + load + generate + 7. Image understanding (requires vision encoder compiled) + 8. 
Audio understanding (requires audio encoder compiled) + +Usage: + source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + + # Run all tests (skip tests requiring compilation with --quick): + python3 test_model.py + + # Run only text generation (model must be pre-compiled): + python3 test_model.py --test text_gen -from neuronx_distributed_inference.models.config import NeuronConfig -from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config + # Run with pytest: + pytest test_model.py -v +""" -# Import from src directory +# --- Qwen2.5-Omni contrib bootstrap --- +import sys as _sys +from pathlib import Path as _Path +_SRC = _Path(__file__).resolve().parents[2] / "src" +if str(_SRC) not in _sys.path: + _sys.path.insert(0, str(_SRC)) +import _upstream_compat # noqa: F401 (applies hf_adapter shim) +# --- end bootstrap --- + +import argparse +import gc +import os import sys -sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src")) -from modeling_qwen2_5_omni import * - - -# Test configuration -MODEL_PATH = "/home/ubuntu/models/Qwen2.5-Omni-7B/" -COMPILED_MODEL_PATH = "/home/ubuntu/neuron_models/Qwen2.5-Omni-7B/" - - -def load_neuron_config_from_compiled(compiled_path: str): - """Load neuron configuration from compiled model's neuron_config.json.""" - config_path = Path(compiled_path) / "neuron_config.json" - - if not config_path.exists(): - raise FileNotFoundError(f"neuron_config.json not found: {config_path}") - - with open(config_path) as f: - config_data = json.load(f) - - if "neuron_config" in config_data: - return config_data["neuron_config"] - else: - return config_data +import time +import traceback +import torch -def create_model_for_inference(compiled_path: str, model_path: str): - """Create model for inference using compiled neuron_config.""" - neuron_config_dict = load_neuron_config_from_compiled(compiled_path) - - dtype_str = neuron_config_dict.get('torch_dtype', 'torch.bfloat16') - if isinstance(dtype_str, str): - dtype = getattr(torch, dtype_str.split('.')[1]) if dtype_str.startswith('torch.') else torch.bfloat16 +# Default paths - override with environment variables +MODEL_PATH = os.environ.get( + "QWEN25_OMNI_MODEL_PATH", "/opt/dlami/nvme/models/Qwen2.5-Omni-7B" +) +COMPILED_PATH = os.environ.get( + "QWEN25_OMNI_COMPILED_PATH", "/tmp/qwen25_omni_tp4_compiled" +) +TP_DEGREE = int(os.environ.get("QWEN25_OMNI_TP_DEGREE", "4")) + +# Test media URLs from Qwen official examples +IMAGE_URL = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg" +AUDIO_URL = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-Omni/cough.wav" + + +class Timer: + """Context manager to time a block.""" + + def __init__(self, label): + self.label = label + self.elapsed = 0 + + def __enter__(self): + self.start = time.time() + return self + + def __exit__(self, *args): + self.elapsed = time.time() - self.start + print(f" [{self.label}] {self.elapsed:.2f}s") + + +# --------------------------------------------------------------------------- +# Test 1: Imports +# --------------------------------------------------------------------------- +def test_imports(): + """Verify all Qwen2.5-Omni modules import correctly.""" + print("=" * 60) + print("Test 1: Import validation") + print("=" * 60) + + from neuronx_distributed_inference.models.config import ( + NeuronConfig, + OnDeviceSamplingConfig, + ) + from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config + + print(" NeuronConfig, OnDeviceSamplingConfig imported OK") + 
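+    # The flat imports below (no package prefix) resolve through the sys.path /
+    # _upstream_compat bootstrap at the top of this file, which points at the
+    # contrib src/ directory instead of an installed package.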
+ from modeling_qwen25_omni import ( + NeuronQwen25OmniForCausalLM, + Qwen25OmniInferenceConfig, + NeuronQwen25OmniMultimodalForCausalLM, + Qwen25OmniMultimodalInferenceConfig, + ) + + print(" Qwen25Omni model classes imported OK") + + from modeling_qwen25_omni_audio import ( + NeuronQwen25OmniAudioEncoder, + NeuronQwen25OmniForAudioEncoding, + AudioEncoderInferenceConfig, + AudioCPUFrontend, + AudioCPUPostprocessor, + NeuronAudioTransformerModel, + AudioTransformerModelWrapper, + ) + + print(" Audio encoder classes imported OK") + + from modeling_qwen25_omni_talker import ( + NeuronQwen25OmniTalker, + ) + + print(" Talker imported OK") + + from modeling_qwen25_omni_token2wav import ( + NeuronQwen25OmniToken2Wav, + ) + + print(" Token2Wav imported OK") + + from modeling_qwen25_omni_vision import ( + NeuronQwen25OmniForImageEncoding, + ) + + print(" Vision encoder imported OK") + + print(" PASS: All imports successful\n") + return True + + +# --------------------------------------------------------------------------- +# Test 2: Config +# --------------------------------------------------------------------------- +def test_config(): + """Create configs and verify TP=4 head divisibility.""" + print("=" * 60) + print("Test 2: Config creation and TP=4 validation") + print("=" * 60) + + from neuronx_distributed_inference.models.config import ( + NeuronConfig, + OnDeviceSamplingConfig, + ) + from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config + from modeling_qwen25_omni import ( + Qwen25OmniInferenceConfig, + ) + from transformers import AutoConfig + + neuron_config = NeuronConfig( + tp_degree=TP_DEGREE, + batch_size=1, + seq_len=2048, + max_context_length=2048, + torch_dtype=torch.bfloat16, + on_device_sampling_config=OnDeviceSamplingConfig( + do_sample=True, temperature=0.6, top_k=20, top_p=0.95 + ), + ) + + hf_config = load_pretrained_config(MODEL_PATH) + config = Qwen25OmniInferenceConfig(neuron_config, load_config=hf_config) + + # Validate text config + print(f" hidden_size={config.hidden_size}") + print(f" num_attention_heads={config.num_attention_heads}") + print(f" num_key_value_heads={config.num_key_value_heads}") + print(f" num_hidden_layers={config.num_hidden_layers}") + print(f" vocab_size={config.vocab_size}") + + # TP divisibility check for Thinker + assert ( + config.num_attention_heads % TP_DEGREE == 0 + ), f"num_attention_heads={config.num_attention_heads} not divisible by TP={TP_DEGREE}" + assert ( + config.num_key_value_heads % TP_DEGREE == 0 + ), f"num_key_value_heads={config.num_key_value_heads} not divisible by TP={TP_DEGREE}" + print( + f" Thinker heads: {config.num_attention_heads}/{TP_DEGREE}=" + f"{config.num_attention_heads // TP_DEGREE} per rank, " + f"kv_heads: {config.num_key_value_heads}/{TP_DEGREE}=" + f"{config.num_key_value_heads // TP_DEGREE} per rank" + ) + + # Validate audio config (via AutoConfig for full nested access) + full_config = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True) + if hasattr(full_config, "thinker_config"): + tc = full_config.thinker_config + if hasattr(tc, "__dict__") and not isinstance(tc, dict): + tc = vars(tc) + if "audio_config" in tc: + ac = tc["audio_config"] + if hasattr(ac, "__dict__"): + ac = vars(ac) + audio_heads = ac.get("encoder_attention_heads", 20) + print( + f" Audio: d_model={ac.get('d_model')}, heads={audio_heads}, " + f"layers={ac.get('encoder_layers')}" + ) + assert ( + audio_heads % TP_DEGREE == 0 + ), f"audio heads={audio_heads} not divisible by TP={TP_DEGREE}" + print( + 
f" Audio heads: {audio_heads}/{TP_DEGREE}={audio_heads // TP_DEGREE} per rank" + ) + + # Validate talker config + if hasattr(full_config, "talker_config"): + tc = full_config.talker_config + if hasattr(tc, "__dict__") and not isinstance(tc, dict): + tc = vars(tc) + talker_heads = tc.get("num_attention_heads", 12) + talker_hidden = tc.get("hidden_size", 896) + print( + f" Talker: hidden={talker_hidden}, heads={talker_heads}, " + f"head_dim={tc.get('head_dim', 128)}" + ) + print( + f" Talker stays on CPU (head_dim={tc.get('head_dim', 128)} " + f"!= {talker_hidden}//{talker_heads})" + ) + + print(" PASS: Config creation and TP=4 validation\n") + return True + + +# --------------------------------------------------------------------------- +# Test 3: State dict conversion +# --------------------------------------------------------------------------- +def test_state_dict(): + """Load full HF model and convert state dict for all components.""" + print("=" * 60) + print("Test 3: State dict conversion (all components)") + print("=" * 60) + + from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import ( + Qwen2_5OmniForConditionalGeneration, + ) + from modeling_qwen25_omni_audio import ( + NeuronQwen25OmniAudioEncoder, + ) + from modeling_qwen25_omni_talker import ( + NeuronQwen25OmniTalker, + ) + + # Load HF model state dict + with Timer("Load HF model"): + hf_model = Qwen2_5OmniForConditionalGeneration.from_pretrained( + MODEL_PATH, + trust_remote_code=True, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + ) + full_sd = hf_model.state_dict() + print(f" Full state dict: {len(full_sd)} keys") + + # Count keys by prefix + prefix_counts = {} + for k in full_sd: + p = k.split(".")[0] + prefix_counts[p] = prefix_counts.get(p, 0) + 1 + for p, c in sorted(prefix_counts.items()): + print(f" {p}: {c} keys") + + # Test audio encoder state dict conversion + with Timer("Audio state dict conversion"): + audio_sd = NeuronQwen25OmniAudioEncoder.convert_hf_to_neuron_state_dict( + {k: v for k, v in full_sd.items() if "audio_tower" in k}, + dtype=torch.bfloat16, + ) + frontend_keys = [k for k in audio_sd if k.startswith("frontend.")] + transformer_keys = [k for k in audio_sd if k.startswith("transformer.")] + postprocessor_keys = [k for k in audio_sd if k.startswith("postprocessor.")] + print( + f" Audio keys: frontend={len(frontend_keys)}, " + f"transformer={len(transformer_keys)}, " + f"postprocessor={len(postprocessor_keys)}" + ) + + # Test talker state dict conversion + with Timer("Talker state dict conversion"): + talker_sd = NeuronQwen25OmniTalker.convert_hf_to_neuron_state_dict( + {k: v for k, v in full_sd.items() if k.startswith("talker.")} + ) + print(f" Talker keys: {len(talker_sd)}") + + del hf_model, full_sd + gc.collect() + + print(" PASS: State dict conversion\n") + return True + + +# --------------------------------------------------------------------------- +# Test 4: Audio encoder CPU components +# --------------------------------------------------------------------------- +def test_audio_encoder(): + """Test audio encoder CPU frontend and postprocessor with synthetic input.""" + print("=" * 60) + print("Test 4: Audio encoder CPU components") + print("=" * 60) + + from transformers import AutoConfig + from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import ( + Qwen2_5OmniForConditionalGeneration, + ) + from modeling_qwen25_omni_audio import ( + NeuronQwen25OmniAudioEncoder, + ) + + # Load config + hf_config = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True) 
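+    # In the Omni checkpoint the audio settings are nested as
+    # thinker_config.audio_config; the code below tolerates both dict-like and
+    # attribute-style sub-configs by falling back to vars().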
+ if not hasattr(hf_config, "thinker_config"): + print(" SKIP: No thinker_config found") + return True + + tc = hf_config.thinker_config + if hasattr(tc, "__dict__") and not isinstance(tc, dict): + tc = vars(tc) + audio_config = tc.get("audio_config", None) + if audio_config is not None and hasattr(audio_config, "__dict__"): + audio_config = vars(audio_config) + if audio_config is None: + print(" SKIP: No audio_config found") + return True + + print( + f" Audio config: d_model={audio_config.get('d_model')}, " + f"heads={audio_config.get('encoder_attention_heads')}, " + f"layers={audio_config.get('encoder_layers')}" + ) + + # Load HF model for weights + with Timer("Load HF model for audio weights"): + hf_model = Qwen2_5OmniForConditionalGeneration.from_pretrained( + MODEL_PATH, + trust_remote_code=True, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + ) + full_sd = hf_model.state_dict() + + with Timer("Convert audio state dict"): + converted_sd = NeuronQwen25OmniAudioEncoder.convert_hf_to_neuron_state_dict( + full_sd, dtype=torch.bfloat16 + ) + + del hf_model, full_sd + gc.collect() + + with Timer("Create audio encoder + load CPU weights"): + encoder = NeuronQwen25OmniAudioEncoder.from_pretrained_state_dict( + audio_config, converted_sd, dtype=torch.bfloat16 + ) + encoder.eval() + + del converted_sd + gc.collect() + + # Test with synthetic mel spectrograms of various lengths + n_mels = audio_config.get("num_mel_bins", 128) + test_cases = [(100, "1s"), (300, "3s"), (1000, "10s"), (3000, "30s")] + + for mel_len, label in test_cases: + mel_input = torch.randn(n_mels, mel_len, dtype=torch.bfloat16) + feature_lens = torch.tensor([mel_len], dtype=torch.long) + + t0 = time.time() + hidden, aftercnn_lens, cu_seqlens = encoder.frontend(mel_input, feature_lens) + audio_embeds = encoder.postprocessor(hidden, aftercnn_lens) + elapsed = time.time() - t0 + + print( + f" {label} ({mel_len} frames): frontend→{hidden.shape[0]} tokens, " + f"postprocessor→{audio_embeds.shape}, time={elapsed*1000:.1f}ms" + ) + + del encoder + gc.collect() + + print(" PASS: Audio encoder CPU components\n") + return True + + +# --------------------------------------------------------------------------- +# Test 5: Talker CPU model +# --------------------------------------------------------------------------- +def test_talker(): + """Test Talker CPU model weight loading and codec token IDs.""" + print("=" * 60) + print("Test 5: Talker CPU model") + print("=" * 60) + + from transformers import AutoConfig + from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import ( + Qwen2_5OmniForConditionalGeneration, + ) + from modeling_qwen25_omni_talker import ( + NeuronQwen25OmniTalker, + ) + + hf_config = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True) + talker_config = getattr(hf_config, "talker_config", None) + if talker_config is None: + print(" SKIP: No talker_config found") + return True + + if hasattr(talker_config, "__dict__") and not isinstance(talker_config, dict): + tc = vars(talker_config) else: - dtype = dtype_str - - neuron_config_kwargs = { - 'tp_degree': neuron_config_dict.get('tp_degree', 2), - 'batch_size': neuron_config_dict.get('batch_size', 1), - 'seq_len': neuron_config_dict.get('seq_len', 128), - 'torch_dtype': dtype, - } - - neuron_config = NeuronConfig(**neuron_config_kwargs) - - # This will use the imported model and config classes - # The actual class names will be determined at runtime - return None, neuron_config - - -def generate_with_neuron_model(model, input_ids, 
max_new_tokens: int): - """Generate tokens using manual forward pass loop.""" - generated_ids = input_ids.clone() - - for _ in range(max_new_tokens): - seq_len = generated_ids.shape[1] - position_ids = torch.arange(seq_len).unsqueeze(0).expand(generated_ids.shape[0], -1) - - with torch.no_grad(): - outputs = model(generated_ids, position_ids=position_ids) - - if hasattr(outputs, 'logits'): - logits = outputs.logits - elif isinstance(outputs, tuple): - logits = outputs[0] - else: - logits = outputs - - next_token_logits = logits[:, -1, :] - next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1) - generated_ids = torch.cat([generated_ids, next_token], dim=-1) - - return generated_ids - - -@pytest.fixture(scope="module") -def compiled_model(): - """Load pre-compiled model.""" - # Note: Actual implementation would load the specific model class - # This is a template that should be customized per model - return None - - -@pytest.fixture(scope="module") -def tokenizer(): - """Load tokenizer.""" - tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, padding_side="right", trust_remote_code=True) + tc = talker_config + + print( + f" Talker config: hidden={tc.get('hidden_size')}, " + f"heads={tc.get('num_attention_heads')}, " + f"kv_heads={tc.get('num_key_value_heads')}, " + f"layers={tc.get('num_hidden_layers')}, " + f"vocab={tc.get('vocab_size')}, " + f"embedding_size={tc.get('embedding_size')}" + ) + + with Timer("Load HF model for talker weights"): + hf_model = Qwen2_5OmniForConditionalGeneration.from_pretrained( + MODEL_PATH, + trust_remote_code=True, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + ) + full_sd = hf_model.state_dict() + + talker_sd = NeuronQwen25OmniTalker.convert_hf_to_neuron_state_dict( + {k: v for k, v in full_sd.items() if k.startswith("talker.")} + ) + print(f" Talker state dict: {len(talker_sd)} keys") + + del hf_model, full_sd + gc.collect() + + with Timer("Create Talker + load weights"): + talker = NeuronQwen25OmniTalker.from_pretrained_state_dict( + talker_config, talker_sd, dtype=torch.bfloat16 + ) + + print( + f" Talker codec tokens: bos={talker.codec_bos_token}, " + f"eos={talker.codec_eos_token}, pad={talker.codec_pad_token}" + ) + + del talker, talker_sd + gc.collect() + + print(" PASS: Talker CPU model\n") + return True + + +# --------------------------------------------------------------------------- +# Test 6: Text-only Thinker compile + load + generate +# --------------------------------------------------------------------------- +def test_text_gen(): + """Compile (if needed), load, and generate with the Thinker at TP=4.""" + print("=" * 60) + print("Test 6: Text-only Thinker compile + load + generate (TP=4)") + print("=" * 60) + + from neuronx_distributed_inference.models.config import ( + NeuronConfig, + OnDeviceSamplingConfig, + ) + from neuronx_distributed_inference.utils.hf_adapter import ( + load_pretrained_config, + HuggingFaceGenerationAdapter, + ) + from modeling_qwen25_omni import ( + NeuronQwen25OmniForCausalLM, + Qwen25OmniInferenceConfig, + ) + from transformers import AutoTokenizer + + neuron_config = NeuronConfig( + tp_degree=TP_DEGREE, + batch_size=1, + seq_len=2048, + max_context_length=2048, + torch_dtype=torch.bfloat16, + on_device_sampling_config=OnDeviceSamplingConfig( + do_sample=False, + top_k=1, + ), + ) + + hf_config = load_pretrained_config(MODEL_PATH) + config = Qwen25OmniInferenceConfig(neuron_config, load_config=hf_config) + + compiled_dir = os.path.join(COMPILED_PATH, "thinker_tp4") + + with Timer("Create 
model"): + model = NeuronQwen25OmniForCausalLM(MODEL_PATH, config) + + if not os.path.exists(os.path.join(compiled_dir, "neuron_config.json")): + with Timer("Compile (this takes several minutes)"): + model.compile(compiled_dir) + else: + print(" Compiled artifacts found, skipping compilation") + + with Timer("Load compiled model"): + model.load(compiled_dir) + + tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token - return tokenizer - - -def test_model_loads(compiled_model): - """Test that model loads successfully (smoke test).""" - assert compiled_model is not None - assert hasattr(compiled_model, 'config') - print("✓ Smoke test passed - Model loaded successfully") - - -def test_model_generates(compiled_model, tokenizer): - """Test that model can generate text.""" - prompt = "The capital of France is" - inputs = tokenizer(prompt, return_tensors="pt", padding=True) - - generated_ids = generate_with_neuron_model(compiled_model, inputs.input_ids, max_new_tokens=20) - output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) - - assert len(output_text) > len(prompt), "Output should be longer than prompt" - print(f"✓ Generation test passed") - print(f" Output: {output_text}") - - -def test_output_coherence(compiled_model, tokenizer): - """Test that output is coherent (not gibberish).""" - prompt = "Hello, how are you?" - inputs = tokenizer(prompt, return_tensors="pt", padding=True) - - generated_ids = generate_with_neuron_model(compiled_model, inputs.input_ids, max_new_tokens=30) - output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) - - # Coherence checks - assert len(output_text.split()) > 3, "Output should have multiple words" - assert not _is_repetitive(output_text), "Output should not be repetitive" - - print(f"✓ Coherence test passed") - print(f" Output: {output_text[:100]}...") - - - -def _is_repetitive(text: str, max_repeat: int = 5) -> bool: - """Check if text has excessive repetition.""" - words = text.split() - if len(words) < 10: - return False - - # Check for repeated words - for i in range(len(words) - max_repeat): - word = words[i] - if all(words[i+j] == word for j in range(max_repeat)): - return True - - # Check for repeated characters - new_text = text[-100:] if len(text) > 100 else text - if len(new_text) > 20: - char_counts = {} - for c in new_text: - char_counts[c] = char_counts.get(c, 0) + 1 - max_char_ratio = max(char_counts.values()) / len(new_text) - if max_char_ratio > 0.5: - return True - - return False - - -def test_performance_ttft(compiled_model, tokenizer): - """Test Time To First Token (TTFT) performance.""" - import time - - prompt = "Hello, how are you?" 
- inputs = tokenizer(prompt, return_tensors="pt", padding=True) - input_ids = inputs.input_ids - - # Warmup - for _ in range(3): - seq_len = input_ids.shape[1] - position_ids = torch.arange(seq_len).unsqueeze(0).expand(input_ids.shape[0], -1) - with torch.no_grad(): - _ = compiled_model(input_ids, position_ids=position_ids) - - # Measure TTFT - times = [] - for _ in range(10): - seq_len = input_ids.shape[1] - position_ids = torch.arange(seq_len).unsqueeze(0).expand(input_ids.shape[0], -1) - - start = time.perf_counter() - with torch.no_grad(): - _ = compiled_model(input_ids, position_ids=position_ids) - end = time.perf_counter() - - times.append((end - start) * 1000) # ms - - avg_ttft = sum(times) / len(times) - print(f"✓ TTFT: {avg_ttft:.2f}ms") - - - -def test_performance_throughput(compiled_model, tokenizer): - """Test token generation throughput.""" - import time - - prompt = "Hello" - inputs = tokenizer(prompt, return_tensors="pt", padding=True) - input_ids = inputs.input_ids - num_tokens = 50 - - # Warmup - _ = generate_with_neuron_model(compiled_model, input_ids, max_new_tokens=5) - - # Measure throughput - start = time.perf_counter() - _ = generate_with_neuron_model(compiled_model, input_ids, max_new_tokens=num_tokens) - end = time.perf_counter() - - total_time = end - start - throughput = num_tokens / total_time - print(f"✓ Throughput: {throughput:.2f} tok/s") + # HuggingFaceGenerationAdapter does NOT take tokenizer as argument. + adapter = HuggingFaceGenerationAdapter(model) + + def make_chat_input(prompt): + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ] + text = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + encoded = tokenizer(text, return_tensors="pt") + return encoded["input_ids"], encoded["attention_mask"] + + prompts = [ + "What is 2+3? Answer with just the number.", + "Write a haiku about the ocean.", + "Explain quantum computing in one sentence.", + ] + + for prompt in prompts: + input_ids, attention_mask = make_chat_input(prompt) + prompt_len = input_ids.shape[1] + + with Timer(f"Generate '{prompt[:40]}...'"): + output_ids = adapter.generate( + input_ids=input_ids, + attention_mask=attention_mask, + max_new_tokens=128, + eos_token_id=[tokenizer.eos_token_id, 151645], + ) + + new_tokens = output_ids[0, prompt_len:] + output_text = tokenizer.decode(new_tokens, skip_special_tokens=True) + n_new = sum( + 1 + for tok in new_tokens + if tok.item() + not in [tokenizer.eos_token_id, tokenizer.pad_token_id, 151645] + ) + print(f" Input: {prompt}") + print(f" Output: {output_text.strip()[:200]}") + print(f" Tokens: {n_new}") + print() + + del model, adapter + gc.collect() + + print(" PASS: Text-only Thinker compile + load + generate\n") + return True + + +# --------------------------------------------------------------------------- +# Test 7: Image understanding (requires multimodal model) +# --------------------------------------------------------------------------- +def test_image_understanding(): + """Test image understanding with the Qwen2.5-Omni vision encoder. + + This test requires the multimodal model (vision encoder + text decoder) + to be compiled. It downloads a test image and asks the model to describe it. + + NOTE: This is a placeholder for the full multimodal pipeline. The vision + encoder must be compiled on Neuron before this test can run end-to-end. + Currently tests the preprocessing pipeline only. 
+ """ + print("=" * 60) + print("Test 7: Image understanding (preprocessing)") + print("=" * 60) + + try: + from qwen_omni_utils import process_mm_info + except ImportError: + print(" SKIP: qwen-omni-utils not installed") + print(" Install with: pip install qwen-omni-utils[decord]") + return True + + from transformers import AutoConfig + + # Build message with image + messages = [ + { + "role": "system", + "content": [ + { + "type": "text", + "text": "You are a helpful assistant.", + } + ], + }, + { + "role": "user", + "content": [ + {"type": "image", "image": IMAGE_URL}, + {"type": "text", "text": "Describe this image briefly."}, + ], + }, + ] + + print(f" Image URL: {IMAGE_URL}") + + with Timer("Process multimodal info (download + preprocess)"): + audios, images, videos = process_mm_info( + messages, use_audio_in_video=False + ) + + if images: + print(f" Images processed: {len(images)}") + for i, img in enumerate(images): + if hasattr(img, "shape"): + print(f" Image {i}: shape={img.shape}") + elif hasattr(img, "size"): + print(f" Image {i}: size={img.size}") + else: + print(" No images found in processed output") + + # Verify config has vision token IDs + hf_config = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True) + if hasattr(hf_config, "thinker_config"): + tc = hf_config.thinker_config + if hasattr(tc, "__dict__") and not isinstance(tc, dict): + tc = vars(tc) + image_token_id = tc.get("image_token_index") + vision_start_id = tc.get("vision_start_token_id") + vision_end_id = tc.get("vision_end_token_id") + print( + f" Vision tokens: image={image_token_id}, " + f"start={vision_start_id}, end={vision_end_id}" + ) + + print( + " NOTE: Full end-to-end image understanding requires " + "compiled vision encoder on Neuron." + ) + print(" PASS: Image preprocessing\n") + return True + + +# --------------------------------------------------------------------------- +# Test 8: Audio understanding (requires audio encoder on Neuron) +# --------------------------------------------------------------------------- +def test_audio_understanding(): + """Test audio understanding preprocessing pipeline. + + Downloads a test audio file and preprocesses it through the Qwen2.5-Omni + audio pipeline. Full end-to-end inference requires the audio encoder's + Neuron transformer to be compiled. 
+ """ + print("=" * 60) + print("Test 8: Audio understanding (preprocessing)") + print("=" * 60) + + try: + from qwen_omni_utils import process_mm_info + except ImportError: + print(" SKIP: qwen-omni-utils not installed") + print(" Install with: pip install qwen-omni-utils[decord]") + return True + + from transformers import AutoConfig + + # Build message with audio + messages = [ + { + "role": "system", + "content": [ + { + "type": "text", + "text": "You are a helpful assistant.", + } + ], + }, + { + "role": "user", + "content": [ + {"type": "audio", "audio": AUDIO_URL}, + {"type": "text", "text": "What sound is this?"}, + ], + }, + ] + + print(f" Audio URL: {AUDIO_URL}") + + with Timer("Process multimodal info (download + preprocess)"): + audios, images, videos = process_mm_info( + messages, use_audio_in_video=False + ) + + if audios: + print(f" Audios processed: {len(audios)}") + for i, audio in enumerate(audios): + if hasattr(audio, "shape"): + print(f" Audio {i}: shape={audio.shape}") + else: + print(f" Audio {i}: type={type(audio)}") + else: + print(" No audios found in processed output") + + # Verify config has audio token IDs + hf_config = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True) + if hasattr(hf_config, "thinker_config"): + tc = hf_config.thinker_config + if hasattr(tc, "__dict__") and not isinstance(tc, dict): + tc = vars(tc) + audio_token_id = tc.get("audio_token_index") + audio_start_id = tc.get("audio_start_token_id") + audio_end_id = tc.get("audio_end_token_id") + print( + f" Audio tokens: audio={audio_token_id}, " + f"start={audio_start_id}, end={audio_end_id}" + ) + + print( + " NOTE: Full end-to-end audio understanding requires " + "compiled audio encoder on Neuron." + ) + print(" PASS: Audio preprocessing\n") + return True + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- +ALL_TESTS = [ + ("imports", test_imports), + ("config", test_config), + ("state_dict", test_state_dict), + ("audio_encoder", test_audio_encoder), + ("talker", test_talker), + ("text_gen", test_text_gen), + ("image", test_image_understanding), + ("audio", test_audio_understanding), +] + + +def main(): + parser = argparse.ArgumentParser( + description="Qwen2.5-Omni-7B integration tests" + ) + parser.add_argument( + "--test", + choices=[name for name, _ in ALL_TESTS], + nargs="+", + help="Run specific test(s). 
Default: all.", + ) + parser.add_argument( + "--quick", + action="store_true", + help="Skip heavyweight tests (state_dict, audio_encoder, talker).", + ) + args = parser.parse_args() + + if args.test: + tests = [(n, fn) for n, fn in ALL_TESTS if n in args.test] + elif args.quick: + skip = {"state_dict", "audio_encoder", "talker"} + tests = [(n, fn) for n, fn in ALL_TESTS if n not in skip] + else: + tests = ALL_TESTS + + print("\n" + "=" * 60) + print(f"Qwen2.5-Omni-7B TP={TP_DEGREE} Integration Tests") + print(f"Model: {MODEL_PATH}") + print(f"Compiled: {COMPILED_PATH}") + print("=" * 60 + "\n") + + results = {} + total_start = time.time() + + for name, test_fn in tests: + try: + passed = test_fn() + results[name] = "PASS" if passed else "FAIL" + except Exception as e: + print(f"\n FAIL: {e}") + traceback.print_exc() + results[name] = f"FAIL: {e}" + + total_time = time.time() - total_start + + print("\n" + "=" * 60) + print("RESULTS SUMMARY") + print("=" * 60) + for name, result in results.items(): + status = "PASS" if result == "PASS" else "FAIL" + print(f" [{status}] {name}: {result}") + print(f"\n Total time: {total_time:.1f}s") + print("=" * 60) + + if any(r != "PASS" for r in results.values()): + sys.exit(1) if __name__ == "__main__": - print("="*80) - print("Qwen2.5-Omni-7B Integration Tests") - print("="*80) - - print("\nNote: This is a template test file.") - print("For actual model testing, customize the model loading logic.") - - print("\n" + "="*80) - print("✓ Template structure verified!") - print("="*80) + main() diff --git a/contrib/models/Qwen2.5-Omni-7B/test/integration/test_talker_neuron.py b/contrib/models/Qwen2.5-Omni-7B/test/integration/test_talker_neuron.py new file mode 100644 index 00000000..90a83463 --- /dev/null +++ b/contrib/models/Qwen2.5-Omni-7B/test/integration/test_talker_neuron.py @@ -0,0 +1,610 @@ +#!/usr/bin/env python3 +"""Tests for the Neuron-compiled Talker and Token2Wav implementations. + +Tests CPU-level logic: config, state dict conversion, fused embedding, +class structure. Does NOT require Neuron hardware. + +Mocks neuronx_distributed and torch_neuronx at sys.modules level so tests +can run on any machine (Mac, Linux, etc.) without Neuron SDK. +""" + +# --- Qwen2.5-Omni contrib bootstrap --- +import sys as _sys +from pathlib import Path as _Path +_SRC = _Path(__file__).resolve().parents[2] / "src" +if str(_SRC) not in _sys.path: + _sys.path.insert(0, str(_SRC)) +import _upstream_compat # noqa: F401 (applies hf_adapter shim) +# --- end bootstrap --- + +import sys +import types +import torch +from pathlib import Path +from unittest.mock import MagicMock + +# Add src to path +sys.path.insert(0, str(Path(__file__).parent.parent / "src")) + + +# ============================================================================ +# Mock setup: neuronx_distributed and torch_neuronx +# ============================================================================ + +class _MockModuleType(types.ModuleType): + """Module type that returns MagicMock for any missing attribute.""" + def __getattr__(self, name): + return MagicMock(name=f"{self.__name__}.{name}") + + +class _AutoMockFinder: + """Meta path finder that auto-mocks any package that can't be imported normally. + + Only intercepts packages listed in _MOCK_PREFIXES to avoid breaking stdlib. 
+ """ + _MOCK_PREFIXES = ( + "neuronx_distributed", "torch_neuronx", "torch_xla", + "nki", "nkilib", "neuronxcc", "transformers", "huggingface_hub", + "safetensors", "accelerate", "sentencepiece", "tokenizers", + ) + + def find_module(self, fullname, path=None): + if any(fullname == p or fullname.startswith(p + ".") for p in self._MOCK_PREFIXES): + return self + return None + + def find_spec(self, fullname, path, target=None): + if any(fullname == p or fullname.startswith(p + ".") for p in self._MOCK_PREFIXES): + from importlib.machinery import ModuleSpec + return ModuleSpec(fullname, self, is_package=True) + return None + + def create_module(self, spec): + mod = _MockModuleType(spec.name) + mod.__path__ = [] + mod.__package__ = spec.name + mod.__loader__ = self + mod.__spec__ = spec + return mod + + def exec_module(self, module): + pass + + def load_module(self, fullname): + if fullname in sys.modules: + return sys.modules[fullname] + mod = _MockModuleType(fullname) + mod.__path__ = [] + mod.__package__ = fullname + mod.__loader__ = self + sys.modules[fullname] = mod + return mod + + +def _setup_neuron_mocks(): + """Install auto-mock import hook and set up key attributes. + + Must be called before importing any neuronx_distributed_inference modules. + """ + # Install the auto-mock finder + sys.meta_path.insert(0, _AutoMockFinder()) + + # Force-import mocked modules so they exist in sys.modules + import neuronx_distributed + import neuronx_distributed.utils + + # Set specific attributes that need real values + neuronx_distributed.utils.cpu_mode = MagicMock(return_value=True) + + import neuronx_distributed.utils.utils + mock_hardware = MagicMock(name="hardware") + mock_hardware.TRN1 = "trn1" + mock_hardware.return_value = "trn2" + neuronx_distributed.utils.utils.hardware = mock_hardware + + import torch_neuronx.utils + torch_neuronx.utils.get_platform_target = MagicMock(return_value="trn2") + + # ColumnParallelLinear / RowParallelLinear / ParallelEmbedding need to be + # real classes so NxDI code can subclass them (e.g. lora_layer.py) + # ColumnParallelLinear / RowParallelLinear / ParallelEmbedding need to be + # real classes so NxDI code can subclass them (e.g. 
lora_layer.py) + class _MockParallelLinear(torch.nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + + class _MockParallelEmbedding(torch.nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + + class _MockSPMDRank(torch.nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + + # Set on both .parallel_layers and .parallel_layers.layers + # (code imports from both paths) + import neuronx_distributed.parallel_layers as _pl + import neuronx_distributed.parallel_layers.layers as _pl_layers + for mod in (_pl, _pl_layers): + mod.ColumnParallelLinear = _MockParallelLinear + mod.RowParallelLinear = _MockParallelLinear + mod.ParallelEmbedding = _MockParallelEmbedding + mod.SPMDRank = _MockSPMDRank + + # LlamaRMSNorm needs to be a real nn.Module subclass for RMSNorm usage + class MockLlamaRMSNorm(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + import transformers.models.llama.modeling_llama + transformers.models.llama.modeling_llama.LlamaRMSNorm = MockLlamaRMSNorm + + # ACT2FN for MLP activation lookup + import transformers.activations + transformers.activations.ACT2FN = { + "silu": torch.nn.SiLU(), + "gelu": torch.nn.GELU(), + "relu": torch.nn.ReLU(), + } + + +# Install mocks before any NxDI imports +_MOCK_MODULES = _setup_neuron_mocks() + + +# ============================================================================ +# Helper: create a load_config callable for InferenceConfig +# ============================================================================ + +def _make_load_config(**attrs): + """Create a load_config callable that sets attributes on an InferenceConfig.""" + def load_config(self): + for key, value in attrs.items(): + setattr(self, key, value) + return load_config + + +# ============================================================================ +# Test 1: Config classes +# ============================================================================ + +def test_talker_inference_config(): + """Test TalkerInferenceConfig creation and derived attributes.""" + from modeling_qwen25_omni_talker import ( + TalkerInferenceConfig, + TalkerNeuronConfig, + ) + + neuron_config = TalkerNeuronConfig( + tp_degree=4, + batch_size=1, + seq_len=512, + torch_dtype=torch.bfloat16, + on_cpu=True, + ) + + config = TalkerInferenceConfig( + neuron_config=neuron_config, + load_config=_make_load_config( + hidden_size=896, + num_attention_heads=12, + num_hidden_layers=24, + num_key_value_heads=4, + vocab_size=8448, + rope_theta=1000000.0, + rms_norm_eps=1e-6, + hidden_act="silu", + intermediate_size=18944, + pad_token_id=0, + max_position_embeddings=32768, + ), + ) + + # Check derived config + assert config.head_dim == 128, f"Expected head_dim=128, got {config.head_dim}" + assert config.qkv_bias == True + assert config.o_bias == False + assert config.num_cores_per_group == 1 + assert config.rope_scaling is not None + assert config.rope_scaling["mrope_section"] == [16, 24, 24] + assert config.thinker_hidden_size == 3584 + + # Check neuron config cls + assert TalkerInferenceConfig.get_neuron_config_cls() == 
TalkerNeuronConfig + + # Check required attributes + required = config.get_required_attributes() + assert "hidden_size" in required + assert "num_attention_heads" in required + assert "head_dim" not in required # head_dim is derived, not required from HF + + print("PASS: TalkerInferenceConfig") + + +# ============================================================================ +# Test 2: Fused embedding in state dict conversion +# ============================================================================ + +def test_fused_embedding_conversion(): + """Test that embed_tokens + thinker_to_talker_proj are correctly fused.""" + from modeling_qwen25_omni_talker import ( + NeuronQwen25OmniTalkerForCausalLM, + TalkerInferenceConfig, + TalkerNeuronConfig, + ) + + # Create fake state dict with talker weights + vocab_size, embed_dim, hidden_size = 8448, 3584, 896 + embed_weight = torch.randn(vocab_size, embed_dim) + proj_weight = torch.randn(hidden_size, embed_dim) + proj_bias = torch.randn(hidden_size) + + state_dict = { + "talker.model.embed_tokens.weight": embed_weight, + "talker.thinker_to_talker_proj.weight": proj_weight, + "talker.thinker_to_talker_proj.bias": proj_bias, + "talker.codec_head.weight": torch.randn(vocab_size, hidden_size), + "talker.model.layers.0.self_attn.q_proj.weight": torch.randn(12 * 128, hidden_size), + "talker.model.layers.0.self_attn.k_proj.weight": torch.randn(4 * 128, hidden_size), + "talker.model.layers.0.self_attn.v_proj.weight": torch.randn(4 * 128, hidden_size), + "talker.model.layers.0.self_attn.q_proj.bias": torch.randn(12 * 128), + "talker.model.layers.0.self_attn.k_proj.bias": torch.randn(4 * 128), + "talker.model.layers.0.self_attn.v_proj.bias": torch.randn(4 * 128), + "talker.model.norm.weight": torch.randn(hidden_size), + } + + # Create config + neuron_config = TalkerNeuronConfig( + tp_degree=4, + batch_size=1, + seq_len=512, + torch_dtype=torch.bfloat16, + on_cpu=True, + fused_qkv=True, + ) + + config = TalkerInferenceConfig( + neuron_config=neuron_config, + load_config=_make_load_config( + hidden_size=896, + num_attention_heads=12, + num_hidden_layers=1, # Just 1 layer for testing + num_key_value_heads=4, + vocab_size=8448, + rope_theta=1000000.0, + rms_norm_eps=1e-6, + hidden_act="silu", + intermediate_size=18944, + pad_token_id=0, + max_position_embeddings=32768, + ), + ) + + # Convert + converted = NeuronQwen25OmniTalkerForCausalLM.convert_hf_to_neuron_state_dict( + state_dict, config + ) + + # Check fused embedding + assert "embed_tokens.weight" in converted + fused_embed = converted["embed_tokens.weight"] + assert fused_embed.shape == (vocab_size, hidden_size), \ + f"Expected ({vocab_size}, {hidden_size}), got {fused_embed.shape}" + + # Verify fused embedding is correct: embed @ proj.T + bias + expected = embed_weight.float() @ proj_weight.float().T + proj_bias.float().unsqueeze(0) + expected = expected.to(torch.bfloat16) + assert torch.allclose(fused_embed, expected, atol=1e-2), \ + "Fused embedding values don't match expected computation" + + # Check codec_head → lm_head mapping + assert "lm_head.weight" in converted + assert "codec_head.weight" not in converted + + # Check prefix stripping + assert "layers.0.self_attn.q_proj.weight" not in converted # Should be fused + assert "layers.0.self_attn.Wqkv.weight" in converted # Fused QKV + + # Check fused QKV shape: q(1536) + k(512) + v(512) = 2560 + qkv = converted["layers.0.self_attn.Wqkv.weight"] + assert qkv.shape[0] == 12 * 128 + 4 * 128 + 4 * 128, \ + f"Fused QKV wrong shape: {qkv.shape}" + + # 
Check projection weights saved for CPU context encoding + assert "_thinker_proj_weight" in converted + assert "_thinker_proj_bias" in converted + assert converted["_thinker_proj_weight"].shape == (hidden_size, embed_dim) + + # Check rank utilities + assert "rank_util.rank" in converted + assert "layers.0.self_attn.rank_util.rank" in converted + + print("PASS: Fused embedding conversion") + + +# ============================================================================ +# Test 3: ThinkerToTalkerProjection +# ============================================================================ + +def test_thinker_to_talker_projection(): + """Test CPU-side thinker state projection.""" + from modeling_qwen25_omni_talker import ( + ThinkerToTalkerProjection, + ) + + thinker_dim, talker_dim = 3584, 896 + + # Create from state dict + proj_weight = torch.randn(talker_dim, thinker_dim) + proj_bias = torch.randn(talker_dim) + state_dict = { + "_thinker_proj_weight": proj_weight, + "_thinker_proj_bias": proj_bias, + } + + proj = ThinkerToTalkerProjection.from_state_dict(state_dict, dtype=torch.float32) + + # Test forward + batch, seq = 2, 10 + thinker_states = torch.randn(batch, seq, thinker_dim) + output = proj(thinker_states) + + assert output.shape == (batch, seq, talker_dim), \ + f"Expected ({batch}, {seq}, {talker_dim}), got {output.shape}" + + # Verify output matches manual computation (relax atol for large dim=3584) + expected = thinker_states @ proj_weight.T + proj_bias + assert torch.allclose(output, expected, atol=1e-3), \ + "Projection output doesn't match expected" + + print("PASS: ThinkerToTalkerProjection") + + +# ============================================================================ +# Test 4: Talker RoPE +# ============================================================================ + +def test_talker_rotary_embedding(): + """Test TalkerRotaryEmbedding with both 1D and 3D positions.""" + from modeling_qwen25_omni_talker import ( + TalkerRotaryEmbedding, + ) + + class MockConfig: + head_dim = 128 + rope_theta = 1000000.0 + + emb = TalkerRotaryEmbedding(MockConfig()) + + # Test with 2D position_ids (standard RoPE) + batch, seq = 2, 16 + x = torch.randn(batch, seq, 128) + pos_2d = torch.arange(seq).unsqueeze(0).expand(batch, -1) # (batch, seq) + cos, sin = emb(x, pos_2d) + assert cos.shape == (batch, seq, 128), f"2D RoPE cos shape: {cos.shape}" + assert sin.shape == (batch, seq, 128), f"2D RoPE sin shape: {sin.shape}" + + # Test with 3D position_ids (mRoPE) + pos_3d = torch.arange(seq).unsqueeze(0).unsqueeze(0).expand(3, batch, -1) # (3, batch, seq) + cos, sin = emb(x, pos_3d) + assert cos.shape == (3, batch, seq, 128), f"3D mRoPE cos shape: {cos.shape}" + assert sin.shape == (3, batch, seq, 128), f"3D mRoPE sin shape: {sin.shape}" + + print("PASS: TalkerRotaryEmbedding") + + +# ============================================================================ +# Test 5: mRoPE application +# ============================================================================ + +def test_apply_multimodal_rotary_pos_emb(): + """Test mRoPE application function.""" + from modeling_qwen25_omni_talker import ( + _apply_multimodal_rotary_pos_emb, + ) + + batch, n_heads, seq, head_dim = 2, 12, 16, 128 + q = torch.randn(batch, n_heads, seq, head_dim) + k = torch.randn(batch, n_heads, seq, head_dim) + + # cos/sin from mRoPE: (3, batch, seq, head_dim) + cos = torch.randn(3, batch, seq, head_dim) + sin = torch.randn(3, batch, seq, head_dim) + mrope_section = [16, 24, 24] # sum=64, *2=128=head_dim + + q_out, k_out 
= _apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section) + + assert q_out.shape == q.shape, f"Q shape mismatch: {q_out.shape} vs {q.shape}" + assert k_out.shape == k.shape, f"K shape mismatch: {k_out.shape} vs {k.shape}" + + # Verify it's not identity (something changed) + assert not torch.allclose(q_out, q), "Q unchanged after RoPE" + assert not torch.allclose(k_out, k), "K unchanged after RoPE" + + print("PASS: apply_multimodal_rotary_pos_emb") + + +# ============================================================================ +# Test 6: Token2Wav Neuron DiT class +# ============================================================================ + +def test_token2wav_neuron_dit_class(): + """Test NeuronQwen25OmniToken2WavWithNeuronDiT class structure.""" + from modeling_qwen25_omni_token2wav import ( + NeuronQwen25OmniToken2Wav, + NeuronQwen25OmniToken2WavWithNeuronDiT, + ) + + # Check inheritance + assert issubclass(NeuronQwen25OmniToken2WavWithNeuronDiT, NeuronQwen25OmniToken2Wav) + + # Check that Neuron DiT class has the expected methods + assert hasattr(NeuronQwen25OmniToken2WavWithNeuronDiT, "compile_dit") + assert hasattr(NeuronQwen25OmniToken2WavWithNeuronDiT, "load_dit") + assert hasattr(NeuronQwen25OmniToken2WavWithNeuronDiT, "_get_dit_module") + + print("PASS: Token2Wav Neuron DiT class structure") + + +# ============================================================================ +# Test 7: Orchestration methods +# ============================================================================ + +def test_orchestration_methods(): + """Test that orchestration class has the new methods.""" + import ast + + orchestration_path = ( + Path(__file__).resolve().parents[2] + / "src" / "modeling_qwen25_omni.py" + ) + with open(orchestration_path) as f: + tree = ast.parse(f.read()) + + # Find NeuronQwen25OmniMultimodalForCausalLM + target_cls = None + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef) and node.name == "NeuronQwen25OmniMultimodalForCausalLM": + target_cls = node + break + + assert target_cls is not None, "NeuronQwen25OmniMultimodalForCausalLM not found" + + methods = { + n.name for n in target_cls.body + if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef)) + } + + # Check new methods exist + expected_methods = { + "enable_talker", + "compile_talker", + "load_talker", + "enable_token2wav", + "compile_token2wav_dit", + "load_token2wav_dit", + "_get_talker_cls", + "_get_neuron_talker_cls", + "_get_talker_config_cls", + "_get_thinker_projection_cls", + "_get_token2wav_cls", + "_get_neuron_token2wav_cls", + } + + missing = expected_methods - methods + assert not missing, f"Missing methods in orchestration: {missing}" + + print("PASS: Orchestration methods") + + +# ============================================================================ +# Test 8: Encode vision to input (thinker state injection) +# ============================================================================ + +def test_encode_vision_to_input(): + """Test NeuronTalkerModel.encode_vision_to_input for thinker state injection.""" + from modeling_qwen25_omni_talker import ( + NeuronTalkerModel, + ) + + batch, seq, hidden = 2, 16, 896 + + # Placeholder embeddings from embed_tokens + inputs_embeds = torch.zeros(batch, seq, hidden) + # Projected thinker states + vision_embeddings = torch.randn(batch, seq, hidden) + # Full mask (all positions are thinker states) + vision_mask = torch.ones(batch, seq, 1, dtype=torch.int32) + + # Call static-like method (doesn't need model instance) + result = 
NeuronTalkerModel.encode_vision_to_input(None, inputs_embeds, vision_embeddings, vision_mask) + + assert result.shape == (batch, seq, hidden) + assert torch.allclose(result, vision_embeddings), \ + "Full mask should replace all positions with thinker states" + + # Test partial mask + vision_mask_partial = torch.zeros(batch, seq, 1, dtype=torch.int32) + vision_mask_partial[:, :8, :] = 1 # First 8 positions are thinker states + + result_partial = NeuronTalkerModel.encode_vision_to_input( + None, inputs_embeds, vision_embeddings, vision_mask_partial + ) + assert torch.allclose(result_partial[:, :8, :], vision_embeddings[:, :8, :]), \ + "Masked positions should have thinker states" + assert torch.allclose(result_partial[:, 8:, :], inputs_embeds[:, 8:, :]), \ + "Unmasked positions should keep original embeddings" + + print("PASS: encode_vision_to_input") + + +# ============================================================================ +# Test 9: Class imports resolve correctly +# ============================================================================ + +def test_imports(): + """Test that all new classes can be imported.""" + from modeling_qwen25_omni_talker import ( + NeuronQwen25OmniTalker, + TalkerNeuronConfig, + TalkerInferenceConfig, + TalkerRotaryEmbedding, + NeuronTalkerAttention, + NeuronTalkerDecoderLayer, + NeuronTalkerModel, + NeuronQwen25OmniTalkerForCausalLM, + ThinkerToTalkerProjection, + ) + + from modeling_qwen25_omni_token2wav import ( + NeuronQwen25OmniToken2Wav, + NeuronQwen25OmniToken2WavWithNeuronDiT, + ) + + print("PASS: All imports successful") + + +# ============================================================================ +# Main +# ============================================================================ + +if __name__ == "__main__": + tests = [ + test_imports, + test_talker_inference_config, + test_fused_embedding_conversion, + test_thinker_to_talker_projection, + test_talker_rotary_embedding, + test_apply_multimodal_rotary_pos_emb, + test_token2wav_neuron_dit_class, + test_orchestration_methods, + test_encode_vision_to_input, + ] + + passed = 0 + failed = 0 + for test in tests: + try: + test() + passed += 1 + except Exception as e: + print(f"FAIL: {test.__name__}: {e}") + import traceback + traceback.print_exc() + failed += 1 + + print(f"\n{'='*50}") + print(f"Results: {passed} passed, {failed} failed out of {len(tests)} tests") + if failed == 0: + print("All tests passed!") + else: + sys.exit(1) From a377af90583b13fd9a795f47973ff1e9e1c573d2 Mon Sep 17 00:00:00 2001 From: whn09 Date: Wed, 22 Apr 2026 11:53:07 +0800 Subject: [PATCH 02/14] Use huggingface_hub.snapshot_download as the default model path The previous defaults were brittle: one pointed at /opt/dlami/nvme/models/ which only exists on specific DLAMI instances, and another pointed at the parent snapshots/ directory (missing the commit hash sub-folder, so transformers could never find config.json). Replace them with a _model_path.resolve_model_path() helper that: 1. Honors QWEN25_OMNI_MODEL_PATH if it points at a dir with config.json. 2. Otherwise calls huggingface_hub.snapshot_download(HF_REPO_ID), which is a no-op if the model is already cached and returns the real snapshot directory (including commit hash) in either case. This lets developers run the examples and tests with zero setup beyond the NxDI venv -- if weights are not cached, they auto-download. 
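For illustration (the /data/models path below is a placeholder, not a tested
default), the resolution order described above boils down to:

    # Hypothetical local override -- only honored if the directory holds config.json.
    # Run from a script that has the contrib src/ dir on sys.path, as the examples do.
    import os
    os.environ["QWEN25_OMNI_MODEL_PATH"] = "/data/models/Qwen2.5-Omni-7B"

    from _model_path import resolve_model_path
    print(resolve_model_path())
    # -> /data/models/Qwen2.5-Omni-7B if config.json is present there,
    #    otherwise the huggingface_hub snapshot dir for Qwen/Qwen2.5-Omni-7B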
Also removes the now-redundant _resolve_model_path() helper in generate_qwen25_omni_speech.py that manually reached into the snapshots/ directory to guess the commit subdir. Co-Authored-By: Claude Opus 4.7 --- .../examples/generate_qwen25_omni.py | 5 ++--- .../examples/generate_qwen25_omni_speech.py | 17 +++----------- .../models/Qwen2.5-Omni-7B/src/_model_path.py | 22 +++++++++++++++++++ .../test/integration/test_model.py | 5 ++--- 4 files changed, 29 insertions(+), 20 deletions(-) create mode 100644 contrib/models/Qwen2.5-Omni-7B/src/_model_path.py diff --git a/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni.py b/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni.py index 9b4edfce..ad8d9858 100644 --- a/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni.py +++ b/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni.py @@ -49,9 +49,8 @@ # --------------------------------------------------------------------------- # Configuration # --------------------------------------------------------------------------- -MODEL_PATH = os.environ.get( - "QWEN25_OMNI_MODEL_PATH", "/opt/dlami/nvme/models/Qwen2.5-Omni-7B" -) +from _model_path import resolve_model_path +MODEL_PATH = resolve_model_path() COMPILED_PATH = os.environ.get( "QWEN25_OMNI_COMPILED_PATH", "/tmp/qwen25_omni_compiled" ) diff --git a/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni_speech.py b/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni_speech.py index 2dc7b4e2..e0de7341 100644 --- a/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni_speech.py +++ b/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni_speech.py @@ -63,10 +63,8 @@ # --------------------------------------------------------------------------- # Configuration # --------------------------------------------------------------------------- -MODEL_PATH = os.environ.get( - "QWEN25_OMNI_MODEL_PATH", - os.path.expanduser("~/.cache/huggingface/hub/models--Qwen--Qwen2.5-Omni-7B/snapshots/"), -) +from _model_path import resolve_model_path +MODEL_PATH = resolve_model_path() COMPILED_PATH = os.environ.get( "QWEN25_OMNI_COMPILED_PATH", "/tmp/qwen25_omni_compiled" ) @@ -83,15 +81,6 @@ _ORIG_EMBEDDING_FORWARD = torch.nn.Embedding.forward -def _resolve_model_path(path): - """Resolve model path, handling HF cache snapshot directories.""" - if os.path.isdir(path) and not os.path.exists(os.path.join(path, "config.json")): - snaps = [d for d in os.listdir(path) if not d.startswith(".")] - if snaps: - path = os.path.join(path, snaps[0]) - return path - - def _restore_embedding(): """Restore original Embedding.forward if Neuron loading changed it.""" if torch.nn.Embedding.forward is not _ORIG_EMBEDDING_FORWARD: @@ -750,7 +739,7 @@ def main(): ) args = parser.parse_args() - model_path = _resolve_model_path(args.model_path) + model_path = args.model_path compiled_path = args.compiled_path num_runs = args.num_runs diff --git a/contrib/models/Qwen2.5-Omni-7B/src/_model_path.py b/contrib/models/Qwen2.5-Omni-7B/src/_model_path.py new file mode 100644 index 00000000..da9c1c69 --- /dev/null +++ b/contrib/models/Qwen2.5-Omni-7B/src/_model_path.py @@ -0,0 +1,22 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Helper for resolving the Qwen2.5-Omni-7B weight path. +# +# Honors ``$QWEN25_OMNI_MODEL_PATH`` if it points at a directory with a +# ``config.json``. 
Otherwise delegates to ``huggingface_hub.snapshot_download`` +# which is a no-op if the model is already cached and returns the real snapshot +# directory (including the commit hash) in either case. + +import os + + +HF_REPO_ID = "Qwen/Qwen2.5-Omni-7B" + + +def resolve_model_path() -> str: + env = os.environ.get("QWEN25_OMNI_MODEL_PATH") + if env and os.path.isfile(os.path.join(env, "config.json")): + return env + from huggingface_hub import snapshot_download + return snapshot_download(HF_REPO_ID) diff --git a/contrib/models/Qwen2.5-Omni-7B/test/integration/test_model.py b/contrib/models/Qwen2.5-Omni-7B/test/integration/test_model.py index 2fde653c..b4cbc728 100755 --- a/contrib/models/Qwen2.5-Omni-7B/test/integration/test_model.py +++ b/contrib/models/Qwen2.5-Omni-7B/test/integration/test_model.py @@ -44,9 +44,8 @@ import torch # Default paths - override with environment variables -MODEL_PATH = os.environ.get( - "QWEN25_OMNI_MODEL_PATH", "/opt/dlami/nvme/models/Qwen2.5-Omni-7B" -) +from _model_path import resolve_model_path +MODEL_PATH = resolve_model_path() COMPILED_PATH = os.environ.get( "QWEN25_OMNI_COMPILED_PATH", "/tmp/qwen25_omni_tp4_compiled" ) From 1654edb4800e03abe496bdeb419070a6ceb33e3f Mon Sep 17 00:00:00 2001 From: whn09 Date: Wed, 22 Apr 2026 12:16:24 +0800 Subject: [PATCH 03/14] Document soundfile and qwen-omni-utils dependencies Two Python packages outside the NxDI venv are needed by the examples: - soundfile: used in generate_qwen25_omni_speech.py Phase 5 to write the output WAV. Previously imported lazily so the error only surfaced after ~100 seconds of pipeline execution; now imported up front with a clear install hint. - qwen-omni-utils[decord]: used in generate_qwen25_omni.py for multimodal preprocessing. Already documented in the script's docstring, now also in the README. Added a "Python dependencies" subsection under Prerequisites listing both pip-install commands. Co-Authored-By: Claude Opus 4.7 --- contrib/models/Qwen2.5-Omni-7B/README.md | 7 ++++++- .../examples/generate_qwen25_omni_speech.py | 8 +++++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/contrib/models/Qwen2.5-Omni-7B/README.md b/contrib/models/Qwen2.5-Omni-7B/README.md index 2d56ac4c..23276499 100644 --- a/contrib/models/Qwen2.5-Omni-7B/README.md +++ b/contrib/models/Qwen2.5-Omni-7B/README.md @@ -31,7 +31,12 @@ Key features: ## Prerequisites - **Instance**: trn2.48xlarge or trn2.xlarge (4+ NeuronCores sufficient) -- **Weights**: Download from [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B) +- **Weights**: Download from [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B) — the example scripts auto-download via `huggingface_hub.snapshot_download` on first run. 
+- **Python dependencies** (on top of the NxDI venv): + ```bash + pip install soundfile # writes WAV output in generate_qwen25_omni_speech.py + pip install qwen-omni-utils[decord] # process_mm_info() in generate_qwen25_omni.py for image/audio/video inputs + ``` ## Usage diff --git a/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni_speech.py b/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni_speech.py index e0de7341..b0158d3e 100644 --- a/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni_speech.py +++ b/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni_speech.py @@ -60,6 +60,13 @@ import torch +try: + import soundfile as sf # noqa: F401 (used in Phase 5 to write the output WAV) +except ImportError: + sys.exit( + "soundfile is required for WAV output. Install with: pip install soundfile" + ) + # --------------------------------------------------------------------------- # Configuration # --------------------------------------------------------------------------- @@ -688,7 +695,6 @@ def run_token2wav(model_path, compiled_path, codec_codes, num_runs, temp_dir, ou audio_duration = 0 if first_wav is not None and isinstance(first_wav, torch.Tensor) and first_wav.numel() > 0: - import soundfile as sf wav_np = first_wav.detach().cpu().float().numpy().flatten() sf.write(output_wav, wav_np, 24000) audio_duration = len(wav_np) / 24000 From e92995bed38086bd3dc86ab5e328953c46bec45c Mon Sep 17 00:00:00 2001 From: whn09 Date: Wed, 22 Apr 2026 12:30:07 +0800 Subject: [PATCH 04/14] Speed up Phase 2 hidden-state extraction by loading HF model in bf16 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase 2 was loading Qwen2.5-Omni in float32 even though the downstream Talker consumes the projected states as bfloat16 — the upcast was pure overhead (~28GB RAM, ~6s load + ~4s forward on Trn2). Load the HF model in bfloat16 (the checkpoint's native dtype) with low_cpu_mem_usage=True, and make sure ThinkerToTalkerProjection is cast to match the weight dtype so the Linear runs in bf16 end-to-end. This shaves roughly half the Phase 2 wall time with no accuracy change (everything downstream already rounds back to bf16). Co-Authored-By: Claude Opus 4.7 --- .../examples/generate_qwen25_omni_speech.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni_speech.py b/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni_speech.py index b0158d3e..5476a034 100644 --- a/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni_speech.py +++ b/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni_speech.py @@ -396,13 +396,23 @@ def run_thinker(model_path, compiled_path, prompt, system_prompt, num_runs, temp def extract_hidden_states(model_path, thinker_result): - """Phase 2: Extract thinker hidden states via CPU forward pass.""" + """Phase 2: Extract thinker hidden states via CPU forward pass. + + We load the HF model on CPU in bfloat16 (the checkpoint's native dtype) and + run one forward pass to capture ``hidden_states[0]`` (token embeddings) and + ``hidden_states[-1]`` (final layer output). These are consumed by Phase 3's + ``thinker_to_talker_proj`` and ultimately reach the Talker as bfloat16, + so float32 here would just burn 2x memory and time with no accuracy win. 
+ """ print("\n--- Phase 2: Hidden state extraction (CPU) ---") from transformers import Qwen2_5OmniForConditionalGeneration with Timer("Load HF model"): hf_model = Qwen2_5OmniForConditionalGeneration.from_pretrained( - model_path, torch_dtype=torch.float32, trust_remote_code=True, + model_path, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + trust_remote_code=True, ) hf_model.eval() _restore_embedding() @@ -511,6 +521,7 @@ def prepare_talker_input(model_path, hf_model, outputs, full_ids, prompt_len, sp proj.proj.weight.data = proj_weight if proj_bias is not None: proj.proj.bias.data = proj_bias + proj.to(proj_weight.dtype) projected_context = proj(talker_inputs_embeds) projected_reply = proj(thinker_reply_part) From 846a060a15758be25626e82c14cd9f3840916ae1 Mon Sep 17 00:00:00 2001 From: whn09 Date: Wed, 22 Apr 2026 14:42:33 +0800 Subject: [PATCH 05/14] Drop subprocess isolation: load Thinker/Talker/DiT once in the same process MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous implementation spawned a separate Python subprocess for each Neuron-compiled component (Thinker, Talker, Token2Wav DiT) based on the assumption that each model needed exclusive NeuronCore access. That wasn't actually true: since all three share TP=4 (NeuronCores 0-3), they can all live in the same process and swap NEFFs on the fly, exactly the pattern used by upstream contrib/Qwen-Image-Edit (PR #117). Concretely: - Removed `_run_subprocess`, `_SUBPROCESS_BOOTSTRAP`, and all three embedded f-string scripts (one per component). - `run_thinker` / `run_talker` / `run_token2wav` are now regular Python functions that take in-memory arguments and return in-memory results. No more `torch.save` / `torch.load` round-trips through a temp dir. - `load_thinker` / `load_talker` / `load_token2wav` / `load_hf_cpu` are invoked once in `main()`; the `for i in range(num_runs)` loop now calls the per-run phase functions on the already-loaded models. - `compile_all` was also de-subprocessed; compilation holds the Neuron compiler (not the runtime) so there was never a core-conflict risk. Performance impact (per your ~14s warm pipeline on trn2.48xlarge): - First run still pays the full ~60s one-shot load. - All subsequent runs skip the ~60s load — previously they paid it every time because each subprocess re-loaded and then exited. Co-Authored-By: Claude Opus 4.7 --- .../examples/generate_qwen25_omni_speech.py | 909 ++++++++---------- 1 file changed, 403 insertions(+), 506 deletions(-) diff --git a/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni_speech.py b/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni_speech.py index 5476a034..c903e496 100644 --- a/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni_speech.py +++ b/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni_speech.py @@ -2,28 +2,28 @@ """ End-to-end speech synthesis for Qwen2.5-Omni-7B on NeuronX (TP=4). -Full pipeline: Thinker (text) -> Talker (codec tokens) -> Token2Wav (audio) +Full pipeline: Thinker (text) -> Talker (codec tokens) -> Token2Wav (audio). + +All three Neuron-compiled components (Thinker, Talker, Token2Wav DiT) are +loaded *once* into the same Python process on the same NeuronCores (TP=4, +core 0-3) and reused across runs. The first inference still pays the +full model-load cost, but subsequent runs are pure inference. 
Two-step workflow: Step 1: Compile all Neuron components (one-time, ~30 min) - Step 2: Run inference (loads compiled artifacts, ~15s per utterance) - -Architecture note: - Thinker (TP=4) and Talker (TP=4) each require exclusive Neuron access, - so they run in separate subprocesses. Within each subprocess, the model - is loaded ONCE and reused for all --num-runs iterations. - Token2Wav DiT runs in the main process. + Step 2: Run inference Prerequisites: - Trn2 instance (trn2.48xlarge or trn2.xlarge, 4+ NeuronCores) - Neuron SDK 2.23+ with PyTorch 2.9 - - Model weights: huggingface-cli download Qwen/Qwen2.5-Omni-7B + - Model weights downloaded from Qwen/Qwen2.5-Omni-7B (auto-fetched on first run) + - pip install soundfile Usage: source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate cd neuronx-distributed-inference - # Step 1: Compile (one-time, ~30 min) + # Step 1: Compile (one-time) python examples/generate_qwen25_omni_speech.py --compile # Step 2: Run inference @@ -33,12 +33,6 @@ # Benchmark: load each model once, run N inferences, report avg latency python examples/generate_qwen25_omni_speech.py --num-runs 5 - -Pipeline timing (trn2.48xlarge, TP=4, model already loaded): - Thinker: ~0.3s (text generation) - Talker: ~2-3s (codec token generation) - Token2Wav: ~10s (mel spectrogram + vocoder) - Total: ~15s for ~10s of audio (RTF ~1.5x) """ # --- Qwen2.5-Omni contrib bootstrap --- @@ -52,25 +46,21 @@ import argparse import gc -import json import os -import subprocess import sys import time import torch try: - import soundfile as sf # noqa: F401 (used in Phase 5 to write the output WAV) + import soundfile as sf except ImportError: sys.exit( "soundfile is required for WAV output. Install with: pip install soundfile" ) -# --------------------------------------------------------------------------- -# Configuration -# --------------------------------------------------------------------------- from _model_path import resolve_model_path + MODEL_PATH = resolve_model_path() COMPILED_PATH = os.environ.get( "QWEN25_OMNI_COMPILED_PATH", "/tmp/qwen25_omni_compiled" @@ -108,51 +98,80 @@ def __exit__(self, *args): print(f" [{self.label}] {self.elapsed:.2f}s") -_SUBPROCESS_BOOTSTRAP = ( - "import sys, os\n" - f"sys.path.insert(0, {str(_SRC)!r})\n" - "import _upstream_compat # noqa: F401\n" -) +# ========================================================================== +# Compilation (--compile) +# ========================================================================== +def _compile_thinker(model_path, out_path): + from neuronx_distributed_inference.models.config import ( + NeuronConfig, OnDeviceSamplingConfig, + ) + from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config + from modeling_qwen25_omni import ( + NeuronQwen25OmniForCausalLM, Qwen25OmniInferenceConfig, + ) -def _run_subprocess(script_code, label, temp_dir): - """Run Python code as a subprocess (required for Neuron process isolation).""" - script_path = os.path.join(temp_dir, f"{label}.py") - with open(script_path, "w") as f: - f.write(_SUBPROCESS_BOOTSTRAP + script_code) - env = dict(os.environ) - existing = env.get("PYTHONPATH", "") - env["PYTHONPATH"] = f"{_SRC}{os.pathsep}{existing}" if existing else str(_SRC) - t0 = time.time() - result = subprocess.run( - [sys.executable, script_path], - capture_output=True, text=True, timeout=600, env=env, + nc = NeuronConfig( + tp_degree=TP_DEGREE, batch_size=1, seq_len=2048, max_context_length=2048, + torch_dtype=torch.bfloat16, + 
on_device_sampling_config=OnDeviceSamplingConfig( + do_sample=True, temperature=0.7, top_k=20, top_p=0.8, + ), ) - elapsed = time.time() - t0 - if result.returncode != 0: - print(f" [{label}] FAILED ({elapsed:.1f}s)") - for line in result.stderr.strip().split("\n")[-15:]: - if line.strip() and not any( - x in line for x in ["WARN", "TDRV", "NMGR", "NRT", "nccl", "blockwise"] - ): - print(f" {line}") - for line in result.stdout.strip().split("\n")[-5:]: - if line.strip(): - print(f" {line}") - return False - print(f" [{label}] subprocess finished ({elapsed:.1f}s)") - for line in result.stdout.strip().split("\n"): - if line.strip(): - print(f" {line}") - return True + cfg = Qwen25OmniInferenceConfig(nc, load_config=load_pretrained_config(model_path)) + model = NeuronQwen25OmniForCausalLM(model_path, cfg) + model.compile(out_path) -# ========================================================================== -# Compilation (--compile) -# ========================================================================== +def _compile_talker(model_path, out_path): + from transformers import AutoConfig + from modeling_qwen25_omni_talker import ( + NeuronQwen25OmniTalkerForCausalLM, TalkerInferenceConfig, TalkerNeuronConfig, + ) + from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config + + hf = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + tc = hf.talker_config + + tnc = TalkerNeuronConfig( + tp_degree=TP_DEGREE, batch_size=1, seq_len=2048, max_context_length=2048, + torch_dtype=torch.bfloat16, + ) + tic = TalkerInferenceConfig( + neuron_config=tnc, load_config=load_pretrained_config(hf_config=tc), + ) + talker = NeuronQwen25OmniTalkerForCausalLM(model_path, config=tic) + talker.compile(out_path) + + +def _compile_dit(model_path, out_path): + from transformers import AutoConfig + from safetensors.torch import load_file + from modeling_qwen25_omni_token2wav import ( + NeuronQwen25OmniToken2WavWithNeuronDiT, + ) + + hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + t2w = NeuronQwen25OmniToken2WavWithNeuronDiT(hf_config.token2wav_config) + + state_dict = {} + for fn in sorted(os.listdir(model_path)): + if fn.endswith(".safetensors"): + sd = load_file(os.path.join(model_path, fn)) + for k, v in sd.items(): + if k.startswith("token2wav."): + state_dict[k[len("token2wav."):]] = v + t2w.load_state_dict(state_dict, strict=False) + t2w.compile_dit(out_path, max_mel_len=2048, batch_size=2) + def compile_all(model_path, compiled_path): - """Compile all three Neuron components: Thinker, Talker, DiT.""" + """Compile all three Neuron components: Thinker, Talker, DiT. + + Each component is compiled sequentially in the current process. Compilation + holds the Neuron compiler (not the runtime) so there's no core-conflict + issue even when all three share TP=4 / core 0-3. + """ print("=" * 60) print("Compiling Qwen2.5-Omni Speech Components") print("=" * 60) @@ -161,129 +180,31 @@ def compile_all(model_path, compiled_path): print(f" TP: {TP_DEGREE}") t_total = time.time() - temp_dir = os.path.join(compiled_path, "_tmp") - os.makedirs(temp_dir, exist_ok=True) - - # --- 1. 
Compile Thinker --- - print("\n--- [1/3] Compiling Thinker ---") - thinker_compiled = os.path.join(compiled_path, "thinker_tp4") - if os.path.exists(os.path.join(thinker_compiled, "neuron_config.json")): - print(" Already compiled, skipping.") - else: - script = f''' -import torch, os -MODEL_PATH = "{model_path}" -COMPILED = "{thinker_compiled}" - -from neuronx_distributed_inference.models.config import NeuronConfig, OnDeviceSamplingConfig -from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config -from modeling_qwen25_omni import ( - NeuronQwen25OmniForCausalLM, Qwen25OmniInferenceConfig, -) - -nc = NeuronConfig( - tp_degree={TP_DEGREE}, batch_size=1, seq_len=2048, max_context_length=2048, - torch_dtype=torch.bfloat16, - on_device_sampling_config=OnDeviceSamplingConfig( - do_sample=True, temperature=0.7, top_k=20, top_p=0.8 - ), -) -cfg = Qwen25OmniInferenceConfig(nc, load_config=load_pretrained_config(MODEL_PATH)) -model = NeuronQwen25OmniForCausalLM(MODEL_PATH, cfg) -model.compile(COMPILED) -print("Thinker compiled successfully") -''' - ok = _run_subprocess(script, "compile_thinker", temp_dir) - if not ok: - print("FATAL: Thinker compilation failed.") - return False - - # --- 2. Compile Talker --- - print("\n--- [2/3] Compiling Talker ---") - talker_compiled = os.path.join(compiled_path, "talker_tp4") - if os.path.exists(os.path.join(talker_compiled, "neuron_config.json")): - print(" Already compiled, skipping.") - else: - script = f''' -import torch, os -MODEL_PATH = "{model_path}" -COMPILED = "{talker_compiled}" - -from transformers import AutoConfig -from modeling_qwen25_omni_talker import ( - NeuronQwen25OmniTalkerForCausalLM, TalkerInferenceConfig, TalkerNeuronConfig, -) -from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config - -hf = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True) -tc = hf.talker_config - -tnc = TalkerNeuronConfig( - tp_degree={TP_DEGREE}, batch_size=1, seq_len=2048, max_context_length=2048, - torch_dtype=torch.bfloat16, -) -tic = TalkerInferenceConfig(neuron_config=tnc, load_config=load_pretrained_config(hf_config=tc)) -talker = NeuronQwen25OmniTalkerForCausalLM(MODEL_PATH, config=tic) -talker.compile(COMPILED) -print("Talker compiled successfully") -''' - ok = _run_subprocess(script, "compile_talker", temp_dir) - if not ok: - print("FATAL: Talker compilation failed.") - return False - - # --- 3. 
Compile DiT --- - print("\n--- [3/3] Compiling Token2Wav DiT ---") - dit_compiled = os.path.join(compiled_path, "dit_core") - if os.path.exists(os.path.join(dit_compiled, "dit_core_neuron.pt")): - print(" Already compiled, skipping.") - else: - script = f''' -import torch, os -MODEL_PATH = "{model_path}" -COMPILED = "{dit_compiled}" - -from transformers import AutoConfig -from safetensors.torch import load_file -from modeling_qwen25_omni_token2wav import ( - NeuronQwen25OmniToken2WavWithNeuronDiT, -) + stages = [ + ("Thinker", "thinker_tp4", "neuron_config.json", _compile_thinker), + ("Talker", "talker_tp4", "neuron_config.json", _compile_talker), + ("DiT", "dit_core", "dit_core_neuron.pt", _compile_dit), + ] + for idx, (label, subdir, marker, fn) in enumerate(stages, 1): + print(f"\n--- [{idx}/{len(stages)}] Compiling {label} ---") + out_path = os.path.join(compiled_path, subdir) + if os.path.exists(os.path.join(out_path, marker)): + print(" Already compiled, skipping.") + continue + t0 = time.time() + fn(model_path, out_path) + print(f" {label} compiled in {time.time() - t0:.1f}s") -hf_config = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True) -t2w = NeuronQwen25OmniToken2WavWithNeuronDiT(hf_config.token2wav_config) - -state_dict = {{}} -for fn in sorted(os.listdir(MODEL_PATH)): - if fn.endswith(".safetensors"): - sd = load_file(os.path.join(MODEL_PATH, fn)) - for k, v in sd.items(): - if k.startswith("token2wav."): - state_dict[k[len("token2wav."):]] = v -t2w.load_state_dict(state_dict, strict=False) - -t2w.compile_dit(COMPILED, max_mel_len=2048, batch_size=2) -print("DiT compiled successfully") -''' - ok = _run_subprocess(script, "compile_dit", temp_dir) - if not ok: - print("FATAL: DiT compilation failed.") - return False - - total = time.time() - t_total - print(f"\nAll components compiled in {total:.0f}s") + print(f"\nAll components compiled in {time.time() - t_total:.0f}s") print(f"Artifacts saved to: {compiled_path}/") - print(f" thinker_tp4/ - Thinker (7B text model)") - print(f" talker_tp4/ - Talker (690M codec model)") - print(f" dit_core/ - Token2Wav DiT (85M transformer)") return True # ========================================================================== -# Inference +# Inference: model loading (once per process) # ========================================================================== def _check_compiled(compiled_path): - """Verify all compiled artifacts exist.""" checks = [ (os.path.join(compiled_path, "thinker_tp4", "neuron_config.json"), "Thinker"), (os.path.join(compiled_path, "talker_tp4", "neuron_config.json"), "Talker"), @@ -298,145 +219,195 @@ def _check_compiled(compiled_path): return True -def run_thinker(model_path, compiled_path, prompt, system_prompt, num_runs, temp_dir): - """Phase 1: Load Thinker once, generate num_runs times, return first result + avg time.""" - print(f"\n--- Phase 1: Thinker (text generation, {num_runs} runs) ---") +def load_thinker(model_path, compiled_path): + """Load the Thinker (Qwen2.5-Omni text model) onto Neuron, return (adapter, tokenizer).""" + from neuronx_distributed_inference.models.config import ( + NeuronConfig, OnDeviceSamplingConfig, + ) + from neuronx_distributed_inference.utils.hf_adapter import ( + load_pretrained_config, HuggingFaceGenerationAdapter, + ) + from modeling_qwen25_omni import ( + NeuronQwen25OmniForCausalLM, Qwen25OmniInferenceConfig, + ) + from transformers import AutoTokenizer + + nc = NeuronConfig( + tp_degree=TP_DEGREE, batch_size=1, seq_len=2048, max_context_length=2048, + 
torch_dtype=torch.bfloat16, + on_device_sampling_config=OnDeviceSamplingConfig( + do_sample=True, temperature=0.7, top_k=20, top_p=0.8, + ), + ) + cfg = Qwen25OmniInferenceConfig(nc, load_config=load_pretrained_config(model_path)) + model = NeuronQwen25OmniForCausalLM(model_path, cfg) - thinker_compiled = os.path.join(compiled_path, "thinker_tp4") - output_file = os.path.join(temp_dir, "thinker_output.json") + t0 = time.time() + model.load(os.path.join(compiled_path, "thinker_tp4")) + load_time = time.time() - t0 + print(f" [Thinker] loaded in {load_time:.1f}s") + + tok = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + if tok.pad_token is None: + tok.pad_token = tok.eos_token + adapter = HuggingFaceGenerationAdapter(model) + + # Warmup the NEFF so the first real inference isn't artificially slow. + enc = tok( + tok.apply_chat_template( + [{"role": "user", "content": "Hi"}], + tokenize=False, add_generation_prompt=True, + ), + return_tensors="pt", + ) + _ = adapter.generate( + input_ids=enc["input_ids"], attention_mask=enc["attention_mask"], + max_new_tokens=5, eos_token_id=[tok.eos_token_id, 151645], + ) + print(" [Thinker] warmup done") + return adapter, tok, load_time - script = f''' -import torch, os, json, time -MODEL_PATH = "{model_path}" -COMPILED = "{thinker_compiled}" -OUTPUT = "{output_file}" -NUM_RUNS = {num_runs} -from neuronx_distributed_inference.models.config import NeuronConfig, OnDeviceSamplingConfig -from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config, HuggingFaceGenerationAdapter -from modeling_qwen25_omni import ( - NeuronQwen25OmniForCausalLM, Qwen25OmniInferenceConfig, -) -from transformers import AutoTokenizer - -nc = NeuronConfig( - tp_degree={TP_DEGREE}, batch_size=1, seq_len=2048, max_context_length=2048, - torch_dtype=torch.bfloat16, - on_device_sampling_config=OnDeviceSamplingConfig( - do_sample=True, temperature=0.7, top_k=20, top_p=0.8 - ), -) -cfg = Qwen25OmniInferenceConfig(nc, load_config=load_pretrained_config(MODEL_PATH)) -model = NeuronQwen25OmniForCausalLM(MODEL_PATH, cfg) - -t_load = time.time() -model.load(COMPILED) -t_load = time.time() - t_load -print(f"Model loaded in {{t_load:.1f}}s") - -tok = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) -if tok.pad_token is None: - tok.pad_token = tok.eos_token -adp = HuggingFaceGenerationAdapter(model) - -# Warmup (first inference is slower due to Neuron warm-up) -enc = tok(tok.apply_chat_template( - [{{"role":"user","content":"Hi"}}], tokenize=False, add_generation_prompt=True -), return_tensors="pt") -_ = adp.generate( - input_ids=enc["input_ids"], attention_mask=enc["attention_mask"], - max_new_tokens=5, eos_token_id=[tok.eos_token_id, 151645], -) -print("Warmup done") - -chat = [ - {{"role":"system","content":"{system_prompt}"}}, - {{"role":"user","content":"{prompt}"}}, -] -enc = tok(tok.apply_chat_template(chat, tokenize=False, add_generation_prompt=True), return_tensors="pt") - -print(f"Running {{NUM_RUNS}} inferences (model stays loaded)...") -first_result = None -times = [] -for i in range(NUM_RUNS): +def load_talker(model_path, compiled_path): + """Load the Talker model onto Neuron and return (talker, adapter, talker_config).""" + from transformers import AutoConfig + from modeling_qwen25_omni_talker import ( + NeuronQwen25OmniTalkerForCausalLM, TalkerInferenceConfig, TalkerNeuronConfig, + ) + from neuronx_distributed_inference.utils.hf_adapter import ( + load_pretrained_config, HuggingFaceGenerationAdapter, + ) + + hf = 
AutoConfig.from_pretrained(model_path, trust_remote_code=True) + tc = hf.talker_config + + tnc = TalkerNeuronConfig( + tp_degree=TP_DEGREE, batch_size=1, seq_len=2048, max_context_length=2048, + torch_dtype=torch.bfloat16, + ) + tic = TalkerInferenceConfig( + neuron_config=tnc, load_config=load_pretrained_config(hf_config=tc), + ) + talker = NeuronQwen25OmniTalkerForCausalLM(model_path, config=tic) + + t0 = time.time() + talker.load(os.path.join(compiled_path, "talker_tp4")) + load_time = time.time() - t0 + print(f" [Talker] loaded in {load_time:.1f}s") + + adapter = HuggingFaceGenerationAdapter(talker) + return talker, adapter, tc, load_time + + +def load_token2wav(model_path, compiled_path): + """Load the Token2Wav model (DiT on Neuron + BigVGAN on CPU).""" + from transformers import AutoConfig + from safetensors.torch import load_file + from modeling_qwen25_omni_token2wav import ( + NeuronQwen25OmniToken2WavWithNeuronDiT, + ) + + hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + t2w_cfg = hf_config.token2wav_config + + t2w = NeuronQwen25OmniToken2WavWithNeuronDiT(t2w_cfg) + + state_dict = {} + for fn in sorted(os.listdir(model_path)): + if fn.endswith(".safetensors"): + sd = load_file(os.path.join(model_path, fn)) + for k, v in sd.items(): + if k.startswith("token2wav."): + state_dict[k[len("token2wav."):]] = v + t2w.load_state_dict(state_dict, strict=False) + t0 = time.time() - out = adp.generate( + t2w.load_dit(os.path.join(compiled_path, "dit_core")) + load_time = time.time() - t0 + _restore_embedding() + print(f" [DiT] loaded in {load_time:.1f}s") + return t2w, t2w_cfg, load_time + + +def load_hf_cpu(model_path): + """Load the HF Qwen2.5-Omni model on CPU in bfloat16 (for hidden-state extraction).""" + from transformers import Qwen2_5OmniForConditionalGeneration + + t0 = time.time() + hf_model = Qwen2_5OmniForConditionalGeneration.from_pretrained( + model_path, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + trust_remote_code=True, + ) + hf_model.eval() + _restore_embedding() + load_time = time.time() - t0 + print(f" [HF CPU] loaded in {load_time:.1f}s") + return hf_model, load_time + + +# ========================================================================== +# Inference: per-run phases +# ========================================================================== + +def run_thinker(thinker_adapter, tokenizer, prompt, system_prompt): + """Phase 1: Thinker generates text.""" + chat = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt}, + ] + enc = tokenizer( + tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True), + return_tensors="pt", + ) + + t0 = time.time() + out = thinker_adapter.generate( input_ids=enc["input_ids"], attention_mask=enc["attention_mask"], - max_new_tokens=200, eos_token_id=[tok.eos_token_id, 151645], + max_new_tokens=200, eos_token_id=[tokenizer.eos_token_id, 151645], ) elapsed = time.time() - t0 - times.append(elapsed) prompt_len = enc["input_ids"].shape[1] all_ids = out[0].tolist() gen_ids = all_ids[prompt_len:] - text = tok.decode(gen_ids, skip_special_tokens=True) - n_tokens = len(gen_ids) - print(f" Run {{i+1}}/{{NUM_RUNS}}: {{n_tokens}} tokens in {{elapsed:.3f}}s - {{text[:80]}}") - - if first_result is None: - first_result = {{"all_ids": all_ids, "prompt_len": prompt_len, - "gen_text": text, "n_tokens": n_tokens}} - -avg_time = sum(times) / len(times) -first_result["gen_time"] = avg_time -first_result["all_times"] = times -first_result["load_time"] = 
t_load -print(f"Avg inference: {{avg_time:.3f}}s (load: {{t_load:.1f}}s, not included in avg)") - -with open(OUTPUT, "w") as f: - json.dump(first_result, f) -''' - ok = _run_subprocess(script, "thinker", temp_dir) - if not ok: - return None - - with open(output_file) as f: - return json.load(f) - - -def extract_hidden_states(model_path, thinker_result): - """Phase 2: Extract thinker hidden states via CPU forward pass. - - We load the HF model on CPU in bfloat16 (the checkpoint's native dtype) and - run one forward pass to capture ``hidden_states[0]`` (token embeddings) and - ``hidden_states[-1]`` (final layer output). These are consumed by Phase 3's - ``thinker_to_talker_proj`` and ultimately reach the Talker as bfloat16, - so float32 here would just burn 2x memory and time with no accuracy win. - """ - print("\n--- Phase 2: Hidden state extraction (CPU) ---") - from transformers import Qwen2_5OmniForConditionalGeneration + text = tokenizer.decode(gen_ids, skip_special_tokens=True) + return { + "all_ids": all_ids, + "prompt_len": prompt_len, + "gen_text": text, + "n_tokens": len(gen_ids), + "gen_time": elapsed, + } - with Timer("Load HF model"): - hf_model = Qwen2_5OmniForConditionalGeneration.from_pretrained( - model_path, - torch_dtype=torch.bfloat16, - low_cpu_mem_usage=True, - trust_remote_code=True, - ) - hf_model.eval() - _restore_embedding() +def extract_hidden_states(hf_model, thinker_result): + """Phase 2: Run HF Thinker on CPU to capture hidden states. + + The compiled Neuron Thinker uses on-device sampling and only emits tokens, + not hidden states. The Talker needs the per-token last-layer hidden states + to condition on, so we re-run the prompt+reply through the HF CPU model in + bfloat16 — all downstream consumers already round back to bf16 so float32 + here would be pure overhead. 
+ """ full_ids = torch.tensor([thinker_result["all_ids"]], dtype=torch.long) prompt_len = thinker_result["prompt_len"] - with Timer("Forward pass"): - with torch.no_grad(): - outputs = hf_model.thinker( - input_ids=full_ids, output_hidden_states=True, return_dict=True, - ) - - return hf_model, outputs, full_ids, prompt_len + t0 = time.time() + with torch.no_grad(): + outputs = hf_model.thinker( + input_ids=full_ids, output_hidden_states=True, return_dict=True, + ) + elapsed = time.time() - t0 + return outputs, full_ids, prompt_len, elapsed -def prepare_talker_input(model_path, hf_model, outputs, full_ids, prompt_len, speaker, temp_dir): +def prepare_talker_input(model_path, hf_model, outputs, full_ids, prompt_len, speaker): """Phase 3: Build projected thinker states for the Talker.""" - print("\n--- Phase 3: Talker input preparation ---") from transformers import AutoConfig from safetensors.torch import load_file - from modeling_qwen25_omni_talker import ( - ThinkerToTalkerProjection, - ) + from modeling_qwen25_omni_talker import ThinkerToTalkerProjection hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) talker_cfg = hf_config.talker_config @@ -486,29 +457,28 @@ def prepare_talker_input(model_path, hf_model, outputs, full_ids, prompt_len, sp break if talker_embed_weight is not None: talker_embed_layer = torch.nn.Embedding( - talker_embed_weight.shape[0], talker_embed_weight.shape[1] + talker_embed_weight.shape[0], talker_embed_weight.shape[1], ) talker_embed_layer.weight.data = talker_embed_weight.float() codec_bos_embed = talker_embed_layer( - torch.tensor([talker_cfg.tts_codec_start_token_id]) + torch.tensor([talker_cfg.tts_codec_start_token_id]), ) codec_pad_embed = talker_embed_layer( - torch.tensor([talker_cfg.tts_codec_pad_token_id]) + torch.tensor([talker_cfg.tts_codec_pad_token_id]), ) talker_inputs_embeds[:, -1, :] += codec_bos_embed talker_inputs_embeds[:, -2, :] += codec_pad_embed eos_embed = thinker_embed_tokens( - torch.tensor([[talker_cfg.tts_text_end_token_id]], dtype=torch.long) + torch.tensor([[talker_cfg.tts_text_end_token_id]], dtype=torch.long), ) pad_embed = thinker_embed_tokens( - torch.tensor([[talker_cfg.tts_text_pad_token_id]], dtype=torch.long) + torch.tensor([[talker_cfg.tts_text_pad_token_id]], dtype=torch.long), ) thinker_reply_part = torch.cat([thinker_reply_part[:, 1:, :], eos_embed, pad_embed], dim=1) context_len = talker_inputs_embeds.shape[1] n_reply = thinker_reply_part.shape[1] - print(f" Context: {context_len} tokens, Reply: {n_reply} tokens") proj_weight = proj_bias = None for k, v in hf_model.state_dict().items(): @@ -526,90 +496,44 @@ def prepare_talker_input(model_path, hf_model, outputs, full_ids, prompt_len, sp projected_context = proj(talker_inputs_embeds) projected_reply = proj(thinker_reply_part) - talker_input = { + return { "projected_context": projected_context, "projected_reply": projected_reply, "context_len": context_len, "n_reply": n_reply, - "prompt_len": prompt_len, "conditioning": conditioning, "reference_mel": reference_mel, } - torch.save(talker_input, os.path.join(temp_dir, "talker_input.pt")) - - del hf_model, outputs - gc.collect() - return context_len - -def run_talker(model_path, compiled_path, context_len, num_runs, temp_dir): - """Phase 4: Load Talker once, generate num_runs times, return first result + avg time.""" - print(f"\n--- Phase 4: Talker (codec token generation, {num_runs} runs) ---") - talker_compiled = os.path.join(compiled_path, "talker_tp4") - output_file = os.path.join(temp_dir, 
"talker_output.json") +def run_talker(talker_model, talker_adapter, talker_cfg, talker_input): + """Phase 4: Talker generates codec tokens.""" + projected_context = talker_input["projected_context"] + projected_reply = talker_input["projected_reply"] + context_len = talker_input["context_len"] - script = f''' -import torch, os, json, time -MODEL_PATH = "{model_path}" -COMPILED = "{talker_compiled}" -TEMP_DIR = "{temp_dir}" -NUM_RUNS = {num_runs} + codec_bos = talker_cfg.tts_codec_start_token_id + codec_eos = talker_cfg.tts_codec_end_token_id + codec_pad = talker_cfg.tts_codec_pad_token_id + codec_mask = talker_cfg.tts_codec_mask_token_id -from transformers import AutoConfig -from modeling_qwen25_omni_talker import ( - NeuronQwen25OmniTalkerForCausalLM, TalkerInferenceConfig, TalkerNeuronConfig, -) -from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config, HuggingFaceGenerationAdapter + talker_input_ids = torch.cat([ + torch.full((1, context_len - 2), codec_mask, dtype=torch.long), + torch.tensor([[codec_pad]], dtype=torch.long), + torch.tensor([[codec_bos]], dtype=torch.long), + ], dim=1) + talker_attention_mask = torch.ones_like(talker_input_ids, dtype=torch.long) -hf = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True) -tc = hf.talker_config + max_gen = min(600, 2048 - context_len - 10) -tnc = TalkerNeuronConfig( - tp_degree={TP_DEGREE}, batch_size=1, seq_len=2048, max_context_length=2048, - torch_dtype=torch.bfloat16, -) -tic = TalkerInferenceConfig(neuron_config=tnc, load_config=load_pretrained_config(hf_config=tc)) -talker = NeuronQwen25OmniTalkerForCausalLM(MODEL_PATH, config=tic) - -t_load = time.time() -talker.load(COMPILED) -t_load = time.time() - t_load -print(f"Model loaded in {{t_load:.1f}}s") - -adp = HuggingFaceGenerationAdapter(talker) - -inp = torch.load(os.path.join(TEMP_DIR, "talker_input.pt"), weights_only=False) -projected_context = inp["projected_context"] -projected_reply = inp["projected_reply"] -context_len = inp["context_len"] - -codec_bos = tc.tts_codec_start_token_id -codec_eos = tc.tts_codec_end_token_id -codec_pad = tc.tts_codec_pad_token_id -codec_mask = tc.tts_codec_mask_token_id - -talker_input_ids = torch.cat([ - torch.full((1, context_len - 2), codec_mask, dtype=torch.long), - torch.tensor([[codec_pad]], dtype=torch.long), - torch.tensor([[codec_bos]], dtype=torch.long), -], dim=1) -talker_attention_mask = torch.ones_like(talker_input_ids, dtype=torch.long) - -max_gen = min(600, 2048 - context_len - 10) - -print(f"Running {{NUM_RUNS}} inferences (model stays loaded)...") -first_codes = None -times = [] -for i in range(NUM_RUNS): - # Re-set vision embeddings before each run (context encoding consumes them) + # Re-set vision embeddings before each run (context encoding consumes them). 
ve = projected_context.to(torch.bfloat16) vm = torch.ones(1, context_len, 1, dtype=torch.int32) reply = projected_reply.to(torch.bfloat16) - talker.set_vision_embeddings(ve, vm, thinker_reply_embeds=reply) + talker_model.set_vision_embeddings(ve, vm, thinker_reply_embeds=reply) t0 = time.time() - out = adp.generate( + out = talker_adapter.generate( input_ids=talker_input_ids, attention_mask=talker_attention_mask, max_new_tokens=max_gen, @@ -619,99 +543,30 @@ def run_talker(model_path, compiled_path, context_len, num_runs, temp_dir): repetition_penalty=1.05, ) elapsed = time.time() - t0 - times.append(elapsed) gen_tokens = out[0, context_len:].tolist() while gen_tokens and gen_tokens[-1] == codec_eos: gen_tokens.pop() - print(f" Run {{i+1}}/{{NUM_RUNS}}: {{len(gen_tokens)}} codec tokens in {{elapsed:.3f}}s") - - if first_codes is None: - first_codes = gen_tokens - -avg_time = sum(times) / len(times) -print(f"Avg inference: {{avg_time:.3f}}s (load: {{t_load:.1f}}s, not included in avg)") - -result = {{"codes": first_codes, "gen_time": avg_time, "all_times": times, "load_time": t_load}} -with open(os.path.join(TEMP_DIR, "talker_output.json"), "w") as f: - json.dump(result, f) -''' - ok = _run_subprocess(script, "talker", temp_dir) - if not ok: - return None - - with open(output_file) as f: - return json.load(f) - - -def run_token2wav(model_path, compiled_path, codec_codes, num_runs, temp_dir, output_wav): - """Phase 5: Load DiT once, run num_runs times, save first result, report avg time.""" - print(f"\n--- Phase 5: Token2Wav (waveform synthesis, {num_runs} runs) ---") - - from transformers import AutoConfig - from safetensors.torch import load_file - from modeling_qwen25_omni_token2wav import ( - NeuronQwen25OmniToken2WavWithNeuronDiT, - ) - - hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) - t2w_cfg = hf_config.token2wav_config + return gen_tokens, elapsed - t2w = NeuronQwen25OmniToken2WavWithNeuronDiT(t2w_cfg) - - state_dict = {} - for fn in sorted(os.listdir(model_path)): - if fn.endswith(".safetensors"): - sd = load_file(os.path.join(model_path, fn)) - for k, v in sd.items(): - if k.startswith("token2wav."): - state_dict[k[len("token2wav."):]] = v - t2w.load_state_dict(state_dict, strict=False) - - dit_compiled = os.path.join(compiled_path, "dit_core") - with Timer("Load compiled DiT"): - t2w.load_dit(dit_compiled) - _restore_embedding() +def run_token2wav(t2w, t2w_cfg, codec_codes, conditioning, reference_mel): + """Phase 5: Token2Wav DiT + BigVGAN synthesize a waveform.""" code_tensor = torch.tensor([codec_codes], dtype=torch.long) num_embeds = getattr(t2w_cfg.dit_config, "num_embeds", 8193) if code_tensor.max() >= num_embeds: code_tensor = code_tensor.clamp(0, num_embeds) - inp = torch.load(os.path.join(temp_dir, "talker_input.pt"), weights_only=False) - conditioning = inp["conditioning"] - reference_mel = inp["reference_mel"] - - print(f" Running {num_runs} inferences (DiT stays loaded)...") - first_wav = None - times = [] - for i in range(num_runs): - t0 = time.time() - wav = t2w( - code=code_tensor, - conditioning=conditioning, - reference_mel=reference_mel, - num_steps=10, - guidance_scale=0.5, - ) - elapsed = time.time() - t0 - times.append(elapsed) - print(f" Run {i+1}/{num_runs}: {elapsed:.2f}s") - - if first_wav is None: - first_wav = wav - - avg_time = sum(times) / len(times) - print(f" Avg inference: {avg_time:.2f}s") - - audio_duration = 0 - if first_wav is not None and isinstance(first_wav, torch.Tensor) and first_wav.numel() > 0: - wav_np = 
first_wav.detach().cpu().float().numpy().flatten() - sf.write(output_wav, wav_np, 24000) - audio_duration = len(wav_np) / 24000 - print(f" Audio: {audio_duration:.1f}s saved to {output_wav}") - - return audio_duration, avg_time, times + t0 = time.time() + wav = t2w( + code=code_tensor, + conditioning=conditioning, + reference_mel=reference_mel, + num_steps=10, + guidance_scale=0.5, + ) + elapsed = time.time() - t0 + return wav, elapsed # ========================================================================== @@ -760,20 +615,15 @@ def main(): compiled_path = args.compiled_path num_runs = args.num_runs - # --- Compile mode --- if args.compile: ok = compile_all(model_path, compiled_path) sys.exit(0 if ok else 1) - # --- Inference mode --- if not _check_compiled(compiled_path): sys.exit(1) - temp_dir = os.path.join(compiled_path, "_tmp") - os.makedirs(temp_dir, exist_ok=True) - print("=" * 60) - print("Qwen2.5-Omni Speech Pipeline (Neuron, TP=4)") + print("Qwen2.5-Omni Speech Pipeline (Neuron, TP=4, single process)") print("=" * 60) print(f" Model: {model_path}") print(f" Compiled: {compiled_path}") @@ -783,74 +633,121 @@ def main(): print(f" Runs: {num_runs}") t_total = time.time() - # Phase 1: Thinker (load once, run N times in subprocess) - thinker_result = run_thinker( - model_path, compiled_path, args.prompt, args.system_prompt, - num_runs, temp_dir, - ) - if not thinker_result: - print("Thinker failed, aborting.") - return - print(f" Text: {thinker_result['gen_text'][:200]}") + # ----- Load everything once ----- + print("\n--- Loading models (one-time cost) ---") + t_load_total = time.time() + thinker_adapter, tokenizer, thinker_load = load_thinker(model_path, compiled_path) + hf_model, hf_load = load_hf_cpu(model_path) + talker_model, talker_adapter, talker_cfg, talker_load = load_talker(model_path, compiled_path) + t2w, t2w_cfg, dit_load = load_token2wav(model_path, compiled_path) + total_load = time.time() - t_load_total + print(f" Total model load time: {total_load:.1f}s") + + # ----- Run the pipeline num_runs times ----- + thinker_times, talker_times, t2w_times = [], [], [] + hidden_times, prep_times = [], [] + first_text = first_codes = first_wav = None + first_audio_duration = 0.0 - # Phase 2: Hidden states (CPU, run once) - hf_model, outputs, full_ids, prompt_len = extract_hidden_states( - model_path, thinker_result - ) + for i in range(num_runs): + print(f"\n--- Run {i+1}/{num_runs} ---") - # Phase 3: Talker input prep (CPU, run once) - context_len = prepare_talker_input( - model_path, hf_model, outputs, full_ids, prompt_len, - args.speaker, temp_dir, - ) + thinker_result = run_thinker( + thinker_adapter, tokenizer, args.prompt, args.system_prompt, + ) + thinker_times.append(thinker_result["gen_time"]) + print( + f" [Thinker] {thinker_result['n_tokens']} tokens in " + f"{thinker_result['gen_time']:.3f}s - {thinker_result['gen_text'][:80]}" + ) - # Phase 4: Talker (load once, run N times in subprocess) - talker_result = run_talker( - model_path, compiled_path, context_len, num_runs, temp_dir, - ) - if not talker_result or not talker_result["codes"]: - print("Talker failed or produced no tokens, aborting.") - return - print(f" {len(talker_result['codes'])} codec tokens (avg {talker_result['gen_time']:.3f}s)") + outputs, full_ids, prompt_len, hidden_time = extract_hidden_states( + hf_model, thinker_result, + ) + hidden_times.append(hidden_time) + print(f" [Hidden] forward pass in {hidden_time:.2f}s") - # Phase 5: Token2Wav (load DiT once, run N times in main process) - 
audio_duration, t2w_avg, t2w_times = run_token2wav( - model_path, compiled_path, talker_result["codes"], - num_runs, temp_dir, args.output, - ) + t0 = time.time() + talker_input = prepare_talker_input( + model_path, hf_model, outputs, full_ids, prompt_len, args.speaker, + ) + prep_time = time.time() - t0 + prep_times.append(prep_time) + print( + f" [Prep] context={talker_input['context_len']} tokens, " + f"reply={talker_input['n_reply']} tokens ({prep_time:.2f}s)" + ) + + codec_codes, talker_time = run_talker( + talker_model, talker_adapter, talker_cfg, talker_input, + ) + talker_times.append(talker_time) + print(f" [Talker] {len(codec_codes)} codec tokens in {talker_time:.3f}s") + if not codec_codes: + print(" Talker produced no tokens, aborting run.") + continue + + wav, t2w_time = run_token2wav( + t2w, t2w_cfg, codec_codes, + talker_input["conditioning"], talker_input["reference_mel"], + ) + t2w_times.append(t2w_time) + print(f" [Token2Wav] synthesized in {t2w_time:.2f}s") + + if first_text is None: + first_text = thinker_result["gen_text"] + first_codes = codec_codes + first_wav = wav + + # Free the per-run temporaries so the heap doesn't grow across runs. + del outputs, full_ids, talker_input + gc.collect() + + # ----- Write first run's audio ----- + if first_wav is not None and isinstance(first_wav, torch.Tensor) and first_wav.numel() > 0: + wav_np = first_wav.detach().cpu().float().numpy().flatten() + sf.write(args.output, wav_np, 24000) + first_audio_duration = len(wav_np) / 24000 + print(f"\n Audio: {first_audio_duration:.1f}s saved to {args.output}") - # Summary total_time = time.time() - t_total - thinker_avg = thinker_result["gen_time"] - talker_avg = talker_result["gen_time"] - thinker_load = thinker_result.get("load_time", 0) - talker_load = talker_result.get("load_time", 0) + + def _avg(xs): + return sum(xs) / len(xs) if xs else 0.0 print("\n" + "=" * 60) print("RESULTS") print("=" * 60) - print(f" Text: {thinker_result['gen_text'][:200]}") - print(f"\n Model load time (one-time cost, excluded from pipeline avg):") + if first_text: + print(f" Text: {first_text[:200]}") + print("\n Model load time (one-time cost, excluded from pipeline avg):") print(f" Thinker: {thinker_load:.1f}s") + print(f" HF CPU: {hf_load:.1f}s") print(f" Talker: {talker_load:.1f}s") - print(f"\n Inference latency (avg of {num_runs} runs, model already loaded):") - print(f" Thinker: {thinker_avg:.3f}s ({thinker_result['n_tokens']} tokens)") - print(f" Talker: {talker_avg:.3f}s ({len(talker_result['codes'])} codec tokens)") - print(f" Token2Wav: {t2w_avg:.2f}s") - pipeline_avg = thinker_avg + talker_avg + t2w_avg - print(f" Pipeline: {pipeline_avg:.2f}s total") - if audio_duration > 0: - print(f"\n Audio: {audio_duration:.1f}s") - print(f" RTF: {pipeline_avg/audio_duration:.2f}x (pipeline_avg / audio_duration)") + print(f" DiT: {dit_load:.1f}s") + print(f" Total: {total_load:.1f}s") + print(f"\n Per-run latency (avg of {num_runs} runs):") + print(f" Thinker: {_avg(thinker_times):.3f}s") + print(f" Hidden: {_avg(hidden_times):.3f}s (HF CPU forward)") + print(f" Prep: {_avg(prep_times):.3f}s") + print(f" Talker: {_avg(talker_times):.3f}s") + print(f" Token2Wav: {_avg(t2w_times):.2f}s") + pipeline_avg = ( + _avg(thinker_times) + _avg(hidden_times) + _avg(prep_times) + + _avg(talker_times) + _avg(t2w_times) + ) + print(f" Pipeline: {pipeline_avg:.2f}s total") + if first_audio_duration > 0: + print(f"\n Audio: {first_audio_duration:.1f}s") + print(f" RTF: {pipeline_avg/first_audio_duration:.2f}x") if num_runs 
> 1: - thinker_times = thinker_result.get("all_times", []) - talker_times = talker_result.get("all_times", []) print(f"\n Per-run breakdown ({num_runs} runs):") - print(f" Thinker: {['%.3f' % t for t in thinker_times]}") - print(f" Talker: {['%.3f' % t for t in talker_times]}") - print(f" Token2Wav: {['%.2f' % t for t in t2w_times]}") - print(f"\n Wall time: {total_time:.1f}s (includes model loading + CPU phases)") - print(f" Output: {args.output}") + print(f" Thinker: {['%.3f' % t for t in thinker_times]}") + print(f" Hidden: {['%.3f' % t for t in hidden_times]}") + print(f" Talker: {['%.3f' % t for t in talker_times]}") + print(f" Token2Wav: {['%.2f' % t for t in t2w_times]}") + print(f"\n Wall time: {total_time:.1f}s (load + {num_runs} run(s))") + print(f" Output: {args.output}") if __name__ == "__main__": From 495df41f8c1ee375549ff9c918440cb04da690e9 Mon Sep 17 00:00:00 2001 From: whn09 Date: Wed, 22 Apr 2026 15:00:31 +0800 Subject: [PATCH 06/14] Compile DiT with parallel_model_trace(TP=4) so it co-locates with Thinker/Talker Previously the Token2Wav DiT core was compiled with torch_neuronx.trace() (single-device). When the whole speech pipeline runs in one Python process (after the subprocess removal), the Neuron runtime places a single-device NEFF on a different core group than the TP=4 Thinker/Talker. The result is a cross-core-group scheduling penalty: DiT per-call jumps from ~10.7s (when DiT had its own subprocess and core group) to ~14-18s. Switch DiT compilation to neuronx_distributed.trace.parallel_model_trace with tp_degree=4 in *replicated* mode: - Linears inside _NeuronDiTCore are NOT sharded; DiT has only ~85M params, so there's no memory win from sharding and the code change would be much larger. - The win is pure co-location: all three Neuron models now live on the same core group (0..TP-1) and the runtime schedules their NEFFs as peers. - parallel_model_trace takes a no-arg builder callable; captured the state dict and _block_mask_idx list so the builder can rehydrate the same module on the XLA device for each rank. Artifact layout change: - New: /dit_core/dit_core_parallel/ (directory, via parallel_model_save; reloaded with parallel_model_load). - Legacy: /dit_core/dit_core_neuron.pt (single-file torch.jit). load_dit() still accepts this for backwards compat and prints a warning recommending recompile. - compile_all() and _check_compiled() both recognize either artifact. Co-Authored-By: Claude Opus 4.7 --- .../examples/generate_qwen25_omni_speech.py | 31 ++++-- .../src/modeling_qwen25_omni_token2wav.py | 96 ++++++++++++++----- 2 files changed, 93 insertions(+), 34 deletions(-) diff --git a/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni_speech.py b/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni_speech.py index c903e496..58f490e9 100644 --- a/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni_speech.py +++ b/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni_speech.py @@ -181,14 +181,16 @@ def compile_all(model_path, compiled_path): t_total = time.time() stages = [ - ("Thinker", "thinker_tp4", "neuron_config.json", _compile_thinker), - ("Talker", "talker_tp4", "neuron_config.json", _compile_talker), - ("DiT", "dit_core", "dit_core_neuron.pt", _compile_dit), + ("Thinker", "thinker_tp4", ["neuron_config.json"], _compile_thinker), + ("Talker", "talker_tp4", ["neuron_config.json"], _compile_talker), + # DiT has two possible artifacts: TP-replicated directory (current) + # or the legacy single-file .pt (pre-TP rewrite). 
+ ("DiT", "dit_core", ["dit_core_parallel", "dit_core_neuron.pt"], _compile_dit), ] - for idx, (label, subdir, marker, fn) in enumerate(stages, 1): + for idx, (label, subdir, markers, fn) in enumerate(stages, 1): print(f"\n--- [{idx}/{len(stages)}] Compiling {label} ---") out_path = os.path.join(compiled_path, subdir) - if os.path.exists(os.path.join(out_path, marker)): + if any(os.path.exists(os.path.join(out_path, m)) for m in markers): print(" Already compiled, skipping.") continue t0 = time.time() @@ -205,12 +207,23 @@ def compile_all(model_path, compiled_path): # ========================================================================== def _check_compiled(compiled_path): + """Confirm that each component has been compiled. + + Accepts both the new TP-replicated DiT artifact (``dit_core_parallel/``) + and the legacy single-device one (``dit_core_neuron.pt``). + """ checks = [ - (os.path.join(compiled_path, "thinker_tp4", "neuron_config.json"), "Thinker"), - (os.path.join(compiled_path, "talker_tp4", "neuron_config.json"), "Talker"), - (os.path.join(compiled_path, "dit_core", "dit_core_neuron.pt"), "DiT"), + ([os.path.join(compiled_path, "thinker_tp4", "neuron_config.json")], "Thinker"), + ([os.path.join(compiled_path, "talker_tp4", "neuron_config.json")], "Talker"), + ( + [ + os.path.join(compiled_path, "dit_core", "dit_core_parallel"), + os.path.join(compiled_path, "dit_core", "dit_core_neuron.pt"), + ], + "DiT", + ), ] - missing = [name for path, name in checks if not os.path.exists(path)] + missing = [name for paths, name in checks if not any(os.path.exists(p) for p in paths)] if missing: print(f"ERROR: Missing compiled artifacts for: {', '.join(missing)}") print(f"Run with --compile first:") diff --git a/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_token2wav.py b/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_token2wav.py index 9e37ae6c..71672ee2 100644 --- a/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_token2wav.py +++ b/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_token2wav.py @@ -408,26 +408,41 @@ def compile_dit( compiled_path, max_mel_len=2048, batch_size=2, + tp_degree=4, ): - """Compile the DiT transformer core on Neuron. + """Compile the DiT transformer core on Neuron using TP=4. Only the 22 transformer blocks + norm_out + proj_out are compiled. Preprocessing (ECAPA-TDNN, codec embedding, input embedding, rotary) stays on CPU to avoid XLA tracing issues. + Uses ``neuronx_distributed.trace.parallel_model_trace`` (replicated, + not sharded) so the DiT lives on the same NeuronCore group (0..tp_degree-1) + as the Thinker and Talker. This matters when all three models share one + Python process: a single-device ``torch_neuronx.trace`` NEFF gets + placed on a separate core group and pays a cross-group scheduling + penalty (~4s per DiT forward on trn2.48xlarge). Replicating onto + the same TP group makes the NeuronCore runtime treat all three as + peers on the same logical device set. + + DiT itself is small (~85M params) so there is no memory win from + sharding the linears; the win is purely co-location. + Args: - compiled_path: Directory to save compiled model + compiled_path: Directory to save compiled model. max_mel_len: Maximum mel spectrogram length (covers ~24s audio). Shorter inputs are padded; longer inputs fall back to CPU. batch_size: Batch size for compilation. Use 2 for standard inference with classifier-free guidance (CFG doubles batch). 
+ tp_degree: Replication degree; should match the Thinker/Talker + ``tp_degree`` so all three live on the same NeuronCore group. """ try: - import torch_neuronx + from neuronx_distributed.trace import parallel_model_trace except ImportError: raise ImportError( - "torch_neuronx required for DiT compilation. " - "Run on a Neuron instance (trn1/trn2/inf2)." + "neuronx_distributed required for DiT compilation. " + "Run on a Neuron instance with the NxDI venv active." ) os.makedirs(compiled_path, exist_ok=True) @@ -443,17 +458,28 @@ def compile_dit( head_dim = dim // num_heads logger.info( - "Compiling DiT core: batch=%d, mel_len=%d, dim=%d, heads=%d", - batch_size, max_mel_len, dim, num_heads, + "Compiling DiT core: batch=%d, mel_len=%d, dim=%d, heads=%d, tp=%d", + batch_size, max_mel_len, dim, num_heads, tp_degree, ) # Monkeypatch DiTAttention to fix in-place slice assignment _monkeypatch_dit_attention_for_neuron(dit) - # Create wrapper for transformer core only - core = _NeuronDiTCore(dit) - core.float() - core.eval() + # Capture state dict so parallel_model_trace's builder closure can + # reconstruct _NeuronDiTCore on the XLA device with the right weights. + core_template = _NeuronDiTCore(dit) + core_template.float() + core_template.eval() + core_state_dict = core_template.state_dict() + block_mask_idx = core_template._block_mask_idx + + def build_core(): + fresh = _NeuronDiTCore(dit) + fresh.float() + fresh.eval() + fresh._block_mask_idx = block_mask_idx + fresh.load_state_dict(core_state_dict) + return fresh # Create example inputs # time_embedding uses batch=1 (broadcasts to hidden_states batch) @@ -478,10 +504,11 @@ def compile_dit( batch_size, 1, max_mel_len, max_mel_len, dtype=torch.float32 ) - compiled = torch_neuronx.trace( - core, + compiled = parallel_model_trace( + build_core, (hidden_states, time_embedding, cos, sin, mask_local, mask_backward, mask_ahead), + tp_degree=tp_degree, compiler_args=[ "--auto-cast=none", "--model-type=transformer", @@ -489,8 +516,12 @@ def compile_dit( ], ) - save_path = os.path.join(compiled_path, "dit_core_neuron.pt") - torch.jit.save(compiled, save_path) + # parallel_model_trace produces a ParallelModel that serializes as a + # directory (multiple per-rank artifacts), not a single .pt file. + save_dir = os.path.join(compiled_path, "dit_core_parallel") + os.makedirs(save_dir, exist_ok=True) + from neuronx_distributed.trace import parallel_model_save + parallel_model_save(compiled, save_dir) # Save metadata for load meta = { @@ -499,11 +530,12 @@ def compile_dit( "dim": dim, "num_heads": num_heads, "head_dim": head_dim, + "tp_degree": tp_degree, } with open(os.path.join(compiled_path, "dit_core_meta.json"), "w") as f: json.dump(meta, f) - logger.info("Compiled DiT core saved to %s", save_path) + logger.info("Compiled DiT core (TP=%d) saved to %s", tp_degree, save_dir) self._neuron_dit_core = compiled self._dit_compiled_path = compiled_path @@ -513,18 +545,34 @@ def compile_dit( def load_dit(self, compiled_path): """Load a previously compiled DiT core model. - Args: - compiled_path: Directory containing compiled model + Supports both the old single-device ``torch.jit`` artifact (filename + ``dit_core_neuron.pt``) and the new TP-replicated ``parallel_model`` + artifact (directory ``dit_core_parallel/``). Loading a legacy + single-device artifact will work but pays the cross-core-group + scheduling penalty described in ``compile_dit``. 
""" - save_path = os.path.join(compiled_path, "dit_core_neuron.pt") meta_path = os.path.join(compiled_path, "dit_core_meta.json") - - if not os.path.exists(save_path): + parallel_dir = os.path.join(compiled_path, "dit_core_parallel") + legacy_path = os.path.join(compiled_path, "dit_core_neuron.pt") + + if os.path.isdir(parallel_dir): + from neuronx_distributed.trace import parallel_model_load + self._neuron_dit_core = parallel_model_load(parallel_dir) + logger.info("Loaded TP-replicated DiT core from %s", parallel_dir) + elif os.path.exists(legacy_path): + self._neuron_dit_core = torch.jit.load(legacy_path) + logger.warning( + "Loaded legacy single-device DiT core from %s; recompile with " + "compile_dit() to get the TP=4 replicated artifact and avoid " + "cross-core-group scheduling overhead when running alongside " + "the Thinker and Talker.", + legacy_path, + ) + else: raise FileNotFoundError( - f"Compiled DiT core not found at {save_path}" + f"Compiled DiT core not found at {parallel_dir} or {legacy_path}" ) - self._neuron_dit_core = torch.jit.load(save_path) self._dit_compiled_path = compiled_path if os.path.exists(meta_path): @@ -533,8 +581,6 @@ def load_dit(self, compiled_path): self._dit_max_mel_len = meta["max_mel_len"] self._dit_batch_size = meta["batch_size"] - logger.info("Loaded compiled DiT core from %s", save_path) - def _build_attention_masks(self, block_diff, actual_mel_len, max_mel_len): """Build three per-block float additive attention masks with padding. From bdd5a2a2780baea56c4060a315006b19f6c91b5b Mon Sep 17 00:00:00 2001 From: whn09 Date: Wed, 22 Apr 2026 15:07:13 +0800 Subject: [PATCH 07/14] Fix DiT compile: pass builder args to spawn'd child via a temp pickle The previous attempt defined the parallel_model_trace builder as a closure inside compile_dit, which fails under start_method='spawn' (the child can't pickle a local function). Switched to a module-level, importable builder _build_dit_core_for_trace. But spawn'd children don't inherit globals either, so the builder can't read the DiT module/state_dict from a module-level variable. Use an env var to point the child at a torch.save()'d temp file written right before parallel_model_trace. The builder loads it back on first call and then the file is cleaned up (and env var unset) in a finally block. The temp file is ~350MB (weights + dit reference). It lives next to the compiled artifacts and is removed after compilation. Co-Authored-By: Claude Opus 4.7 --- .../src/modeling_qwen25_omni_token2wav.py | 95 +++++++++++++++---- 1 file changed, 74 insertions(+), 21 deletions(-) diff --git a/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_token2wav.py b/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_token2wav.py index 71672ee2..b2c7a13c 100644 --- a/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_token2wav.py +++ b/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_token2wav.py @@ -279,6 +279,47 @@ def forward(hidden_states, position_embeddings=None, attention_mask=None): attn.forward = _make_patched(attn) +# File path used by the picklable builder `_build_dit_core_for_trace`. +# parallel_model_trace spawns child processes via torch.multiprocessing (start +# method = 'spawn'), so the builder must be importable by fully qualified name, +# and its only dependency — the DiT state dict + original module — has to be +# recoverable inside the child. 
We stash it on disk (torch.save) and let the +# child reload it: the Qwen2.5-Omni-7B repo weights are already on disk anyway, +# so an extra ~350MB pickle round-trip is acceptable. +_DIT_BUILDER_STASH_PATH: str = "" + + +def _build_dit_core_for_trace(): + """Module-level, picklable builder for parallel_model_trace. + + Reads the stashed pickle (see `_DIT_BUILDER_STASH_PATH`) that + `NeuronQwen25OmniToken2WavWithNeuronDiT.compile_dit` wrote before invoking + parallel_model_trace. The file is a plain ``torch.save`` of + ``{"dit_module": ..., "state_dict": ..., "block_mask_idx": ...}``. + """ + import os as _os + import torch as _torch + + stash_path = _os.environ.get("_QWEN25_OMNI_DIT_STASH", "") + if not stash_path or not _os.path.isfile(stash_path): + raise RuntimeError( + "Builder stash not found. Expected _QWEN25_OMNI_DIT_STASH env var " + "to point at a torch.save()'d dict written by compile_dit()." + ) + + payload = _torch.load(stash_path, weights_only=False, map_location="cpu") + dit = payload["dit_module"] + state_dict = payload["state_dict"] + block_mask_idx = payload["block_mask_idx"] + + fresh = _NeuronDiTCore(dit) + fresh.float() + fresh.eval() + fresh._block_mask_idx = block_mask_idx + fresh.load_state_dict(state_dict) + return fresh + + class _NeuronDiTCore(torch.nn.Module): """Traced wrapper for DiT transformer blocks + norm_out + proj_out. @@ -465,21 +506,26 @@ def compile_dit( # Monkeypatch DiTAttention to fix in-place slice assignment _monkeypatch_dit_attention_for_neuron(dit) - # Capture state dict so parallel_model_trace's builder closure can + # Capture state dict so parallel_model_trace's builder can # reconstruct _NeuronDiTCore on the XLA device with the right weights. + # The builder must be a module-level function (see _build_dit_core_for_trace + # above) because parallel_model_trace pickles the builder across spawn'd + # processes. Since spawn'd children don't inherit globals, we write the + # inputs to a temp file and point the child at it via an env var. 
core_template = _NeuronDiTCore(dit) core_template.float() core_template.eval() - core_state_dict = core_template.state_dict() - block_mask_idx = core_template._block_mask_idx - def build_core(): - fresh = _NeuronDiTCore(dit) - fresh.float() - fresh.eval() - fresh._block_mask_idx = block_mask_idx - fresh.load_state_dict(core_state_dict) - return fresh + stash_path = os.path.join(compiled_path, "_dit_builder_stash.pt") + torch.save( + { + "dit_module": dit, + "state_dict": core_template.state_dict(), + "block_mask_idx": core_template._block_mask_idx, + }, + stash_path, + ) + os.environ["_QWEN25_OMNI_DIT_STASH"] = stash_path # Create example inputs # time_embedding uses batch=1 (broadcasts to hidden_states batch) @@ -504,17 +550,24 @@ def build_core(): batch_size, 1, max_mel_len, max_mel_len, dtype=torch.float32 ) - compiled = parallel_model_trace( - build_core, - (hidden_states, time_embedding, cos, sin, - mask_local, mask_backward, mask_ahead), - tp_degree=tp_degree, - compiler_args=[ - "--auto-cast=none", - "--model-type=transformer", - "-O1", - ], - ) + try: + compiled = parallel_model_trace( + _build_dit_core_for_trace, + (hidden_states, time_embedding, cos, sin, + mask_local, mask_backward, mask_ahead), + tp_degree=tp_degree, + compiler_args=[ + "--auto-cast=none", + "--model-type=transformer", + "-O1", + ], + ) + finally: + os.environ.pop("_QWEN25_OMNI_DIT_STASH", None) + try: + os.remove(stash_path) + except OSError: + pass # parallel_model_trace produces a ParallelModel that serializes as a # directory (multiple per-rank artifacts), not a single .pt file. From 462e605411cd855af7fb1c527fab5b3d95c3f13f Mon Sep 17 00:00:00 2001 From: whn09 Date: Wed, 22 Apr 2026 15:14:59 +0800 Subject: [PATCH 08/14] Fix DiT spawn'd builder: rebuild DiT from checkpoint, don't pickle the module MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous attempt torch.save()'d the DiT nn.Module into a stash file for the spawn'd child to reload. That fails because the Neuron SDK runtime-patches torch classes like nn.Embedding, so the module tree contains class objects whose identity doesn't match torch.nn.modules.sparse.Embedding — the pickler raises: _pickle.PicklingError: Can't pickle : it's not the same object as torch.nn.modules.sparse.Embedding Rework the builder so each spawn'd child rebuilds the DiT itself from the HuggingFace checkpoint on disk. We stash only plain tensors / ints (the transformer-core state_dict + per-block mask indices) and pass the model path and stash path via env vars (which spawn children DO inherit). Additional env-var plumbing for spawn: - PYTHONPATH is extended with this module's directory so the child can import `modeling_qwen25_omni_token2wav` (and therefore find `_build_dit_core_for_trace`). The parent state is restored in the finally block. - _QWEN25_OMNI_DIT_MODEL_PATH points children at the HF checkpoint. - _QWEN25_OMNI_DIT_STASH points children at the tensor stash. compile_dit() now requires `model_path=...`; the caller in generate_qwen25_omni_speech.py already knows it. 
Co-Authored-By: Claude Opus 4.7 --- .../examples/generate_qwen25_omni_speech.py | 2 +- .../src/modeling_qwen25_omni_token2wav.py | 110 +++++++++++++----- 2 files changed, 83 insertions(+), 29 deletions(-) diff --git a/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni_speech.py b/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni_speech.py index 58f490e9..b071241c 100644 --- a/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni_speech.py +++ b/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni_speech.py @@ -162,7 +162,7 @@ def _compile_dit(model_path, out_path): if k.startswith("token2wav."): state_dict[k[len("token2wav."):]] = v t2w.load_state_dict(state_dict, strict=False) - t2w.compile_dit(out_path, max_mel_len=2048, batch_size=2) + t2w.compile_dit(out_path, max_mel_len=2048, batch_size=2, model_path=model_path) def compile_all(model_path, compiled_path): diff --git a/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_token2wav.py b/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_token2wav.py index b2c7a13c..c18f50ba 100644 --- a/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_token2wav.py +++ b/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_token2wav.py @@ -279,44 +279,67 @@ def forward(hidden_states, position_embeddings=None, attention_mask=None): attn.forward = _make_patched(attn) -# File path used by the picklable builder `_build_dit_core_for_trace`. -# parallel_model_trace spawns child processes via torch.multiprocessing (start -# method = 'spawn'), so the builder must be importable by fully qualified name, -# and its only dependency — the DiT state dict + original module — has to be -# recoverable inside the child. We stash it on disk (torch.save) and let the -# child reload it: the Qwen2.5-Omni-7B repo weights are already on disk anyway, -# so an extra ~350MB pickle round-trip is acceptable. -_DIT_BUILDER_STASH_PATH: str = "" - - def _build_dit_core_for_trace(): """Module-level, picklable builder for parallel_model_trace. - Reads the stashed pickle (see `_DIT_BUILDER_STASH_PATH`) that - `NeuronQwen25OmniToken2WavWithNeuronDiT.compile_dit` wrote before invoking - parallel_model_trace. The file is a plain ``torch.save`` of - ``{"dit_module": ..., "state_dict": ..., "block_mask_idx": ...}``. + parallel_model_trace spawns child processes via torch.multiprocessing with + ``start_method='spawn'``, so: + + - The builder must be importable by fully qualified name (a nested closure + can't be pickled). + - Spawn'd children don't inherit parent globals, so we can't pass the DiT + module via a module-level variable — and ``torch.save``'ing the module + tree fails because Neuron SDK patches classes like ``nn.Embedding``, + producing "not the same object" pickler errors. + + Solution: the child builds the DiT from the HuggingFace checkpoint itself, + using a model path passed via an env var. The child loads only the + token2wav weights from the safetensors shards, plus a pre-saved copy of + the block-mask-index list and the DiT state dict (both are cheap to pickle + since they're just tensors / ints). """ import os as _os + import json as _json import torch as _torch + model_path = _os.environ.get("_QWEN25_OMNI_DIT_MODEL_PATH", "") stash_path = _os.environ.get("_QWEN25_OMNI_DIT_STASH", "") - if not stash_path or not _os.path.isfile(stash_path): + if not model_path or not stash_path or not _os.path.isfile(stash_path): raise RuntimeError( - "Builder stash not found. 
Expected _QWEN25_OMNI_DIT_STASH env var " - "to point at a torch.save()'d dict written by compile_dit()." + "DiT builder: expected _QWEN25_OMNI_DIT_MODEL_PATH and " + "_QWEN25_OMNI_DIT_STASH env vars to be set by compile_dit()." ) - payload = _torch.load(stash_path, weights_only=False, map_location="cpu") - dit = payload["dit_module"] - state_dict = payload["state_dict"] + from transformers import AutoConfig + from safetensors.torch import load_file + + hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + t2w = NeuronQwen25OmniToken2WavWithNeuronDiT(hf_config.token2wav_config) + + state_dict = {} + for fn in sorted(_os.listdir(model_path)): + if fn.endswith(".safetensors"): + sd = load_file(_os.path.join(model_path, fn)) + for k, v in sd.items(): + if k.startswith("token2wav."): + state_dict[k[len("token2wav."):]] = v + t2w.load_state_dict(state_dict, strict=False) + + dit = t2w._get_dit_module() + if dit is None: + raise RuntimeError("Could not extract DiT module from Token2Wav.") + + _monkeypatch_dit_attention_for_neuron(dit) + + payload = _torch.load(stash_path, weights_only=True, map_location="cpu") + core_state_dict = payload["state_dict"] block_mask_idx = payload["block_mask_idx"] fresh = _NeuronDiTCore(dit) fresh.float() fresh.eval() fresh._block_mask_idx = block_mask_idx - fresh.load_state_dict(state_dict) + fresh.load_state_dict(core_state_dict) return fresh @@ -450,6 +473,7 @@ def compile_dit( max_mel_len=2048, batch_size=2, tp_degree=4, + model_path=None, ): """Compile the DiT transformer core on Neuron using TP=4. @@ -477,6 +501,11 @@ def compile_dit( inference with classifier-free guidance (CFG doubles batch). tp_degree: Replication degree; should match the Thinker/Talker ``tp_degree`` so all three live on the same NeuronCore group. + model_path: Path to the HuggingFace Qwen2.5-Omni-7B checkpoint. + Required because ``parallel_model_trace`` spawns fresh child + processes that each rebuild the DiT from this checkpoint + (we can't pickle the in-memory DiT due to Neuron-patched + ``nn.Embedding`` class identity). """ try: from neuronx_distributed.trace import parallel_model_trace @@ -486,6 +515,13 @@ def compile_dit( "Run on a Neuron instance with the NxDI venv active." ) + if not model_path: + raise ValueError( + "compile_dit(model_path=...) is required — the spawn'd " + "compilation workers rebuild the DiT from this HuggingFace " + "checkpoint path." + ) + os.makedirs(compiled_path, exist_ok=True) dit = self._get_dit_module() @@ -506,12 +542,13 @@ def compile_dit( # Monkeypatch DiTAttention to fix in-place slice assignment _monkeypatch_dit_attention_for_neuron(dit) - # Capture state dict so parallel_model_trace's builder can - # reconstruct _NeuronDiTCore on the XLA device with the right weights. - # The builder must be a module-level function (see _build_dit_core_for_trace - # above) because parallel_model_trace pickles the builder across spawn'd - # processes. Since spawn'd children don't inherit globals, we write the - # inputs to a temp file and point the child at it via an env var. + # Serialize just what the spawn'd builder needs: the transformer + # core's state_dict (plain tensors) and the per-block mask indices + # (plain ints). We deliberately DO NOT pickle the DiT module itself — + # Neuron SDK patches torch classes like nn.Embedding, which makes + # torch.save trip on "not the same object" identity checks. Each + # child rebuilds the DiT from the on-disk HuggingFace checkpoint + # using the path we pass via _QWEN25_OMNI_DIT_MODEL_PATH. 
core_template = _NeuronDiTCore(dit) core_template.float() core_template.eval() @@ -519,13 +556,25 @@ def compile_dit( stash_path = os.path.join(compiled_path, "_dit_builder_stash.pt") torch.save( { - "dit_module": dit, "state_dict": core_template.state_dict(), "block_mask_idx": core_template._block_mask_idx, }, stash_path, ) os.environ["_QWEN25_OMNI_DIT_STASH"] = stash_path + os.environ["_QWEN25_OMNI_DIT_MODEL_PATH"] = str(model_path) + + # The builder lives in this module, which normally gets found via a + # sys.path.insert bootstrap in the caller. Spawn'd children won't have + # that, so push this module's directory onto PYTHONPATH for them. + _module_dir = os.path.dirname(os.path.abspath(__file__)) + existing_pythonpath = os.environ.get("PYTHONPATH", "") + prior_pythonpath = existing_pythonpath # restore in finally + if _module_dir not in existing_pythonpath.split(os.pathsep): + os.environ["PYTHONPATH"] = ( + f"{_module_dir}{os.pathsep}{existing_pythonpath}" + if existing_pythonpath else _module_dir + ) # Create example inputs # time_embedding uses batch=1 (broadcasts to hidden_states batch) @@ -564,6 +613,11 @@ def compile_dit( ) finally: os.environ.pop("_QWEN25_OMNI_DIT_STASH", None) + os.environ.pop("_QWEN25_OMNI_DIT_MODEL_PATH", None) + if prior_pythonpath: + os.environ["PYTHONPATH"] = prior_pythonpath + else: + os.environ.pop("PYTHONPATH", None) try: os.remove(stash_path) except OSError: From 9f8bcc19cfe4574a5c97fb86db997288141f1370 Mon Sep 17 00:00:00 2001 From: whn09 Date: Wed, 22 Apr 2026 15:19:37 +0800 Subject: [PATCH 09/14] DiT builder: return (model, {}) tuple for parallel_model_trace parallel_model_trace's worker unpacks the builder result as ``model, input_output_alias = func(**func_kwargs)``. We were returning just the model, producing ``TypeError: cannot unpack non-iterable _NeuronDiTCore``. DiT has no input/output weight aliasing, so the alias dict is empty. Co-Authored-By: Claude Opus 4.7 --- .../Qwen2.5-Omni-7B/src/modeling_qwen25_omni_token2wav.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_token2wav.py b/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_token2wav.py index c18f50ba..be421e5f 100644 --- a/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_token2wav.py +++ b/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_token2wav.py @@ -340,7 +340,9 @@ def _build_dit_core_for_trace(): fresh.eval() fresh._block_mask_idx = block_mask_idx fresh.load_state_dict(core_state_dict) - return fresh + # parallel_model_trace's worker expects (model, input_output_alias). + # DiT has no input/output weight aliasing, so the alias dict is empty. + return fresh, {} class _NeuronDiTCore(torch.nn.Module): From d8afc84cfb5ec708ec86f524c8f05eace462bc8a Mon Sep 17 00:00:00 2001 From: whn09 Date: Wed, 22 Apr 2026 15:55:00 +0800 Subject: [PATCH 10/14] Revert "DiT builder: return (model, {}) tuple for parallel_model_trace" This reverts commit 9f8bcc19cfe4574a5c97fb86db997288141f1370. 
--- .../Qwen2.5-Omni-7B/src/modeling_qwen25_omni_token2wav.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_token2wav.py b/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_token2wav.py index be421e5f..c18f50ba 100644 --- a/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_token2wav.py +++ b/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_token2wav.py @@ -340,9 +340,7 @@ def _build_dit_core_for_trace(): fresh.eval() fresh._block_mask_idx = block_mask_idx fresh.load_state_dict(core_state_dict) - # parallel_model_trace's worker expects (model, input_output_alias). - # DiT has no input/output weight aliasing, so the alias dict is empty. - return fresh, {} + return fresh class _NeuronDiTCore(torch.nn.Module): From 2ec0a38d447ed6276378574f3ef8f490d28084f3 Mon Sep 17 00:00:00 2001 From: whn09 Date: Wed, 22 Apr 2026 15:55:00 +0800 Subject: [PATCH 11/14] Revert "Fix DiT spawn'd builder: rebuild DiT from checkpoint, don't pickle the module" This reverts commit 462e605411cd855af7fb1c527fab5b3d95c3f13f. --- .../examples/generate_qwen25_omni_speech.py | 2 +- .../src/modeling_qwen25_omni_token2wav.py | 110 +++++------------- 2 files changed, 29 insertions(+), 83 deletions(-) diff --git a/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni_speech.py b/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni_speech.py index b071241c..58f490e9 100644 --- a/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni_speech.py +++ b/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni_speech.py @@ -162,7 +162,7 @@ def _compile_dit(model_path, out_path): if k.startswith("token2wav."): state_dict[k[len("token2wav."):]] = v t2w.load_state_dict(state_dict, strict=False) - t2w.compile_dit(out_path, max_mel_len=2048, batch_size=2, model_path=model_path) + t2w.compile_dit(out_path, max_mel_len=2048, batch_size=2) def compile_all(model_path, compiled_path): diff --git a/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_token2wav.py b/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_token2wav.py index c18f50ba..b2c7a13c 100644 --- a/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_token2wav.py +++ b/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_token2wav.py @@ -279,67 +279,44 @@ def forward(hidden_states, position_embeddings=None, attention_mask=None): attn.forward = _make_patched(attn) +# File path used by the picklable builder `_build_dit_core_for_trace`. +# parallel_model_trace spawns child processes via torch.multiprocessing (start +# method = 'spawn'), so the builder must be importable by fully qualified name, +# and its only dependency — the DiT state dict + original module — has to be +# recoverable inside the child. We stash it on disk (torch.save) and let the +# child reload it: the Qwen2.5-Omni-7B repo weights are already on disk anyway, +# so an extra ~350MB pickle round-trip is acceptable. +_DIT_BUILDER_STASH_PATH: str = "" + + def _build_dit_core_for_trace(): """Module-level, picklable builder for parallel_model_trace. - parallel_model_trace spawns child processes via torch.multiprocessing with - ``start_method='spawn'``, so: - - - The builder must be importable by fully qualified name (a nested closure - can't be pickled). 
- - Spawn'd children don't inherit parent globals, so we can't pass the DiT - module via a module-level variable — and ``torch.save``'ing the module - tree fails because Neuron SDK patches classes like ``nn.Embedding``, - producing "not the same object" pickler errors. - - Solution: the child builds the DiT from the HuggingFace checkpoint itself, - using a model path passed via an env var. The child loads only the - token2wav weights from the safetensors shards, plus a pre-saved copy of - the block-mask-index list and the DiT state dict (both are cheap to pickle - since they're just tensors / ints). + Reads the stashed pickle (see `_DIT_BUILDER_STASH_PATH`) that + `NeuronQwen25OmniToken2WavWithNeuronDiT.compile_dit` wrote before invoking + parallel_model_trace. The file is a plain ``torch.save`` of + ``{"dit_module": ..., "state_dict": ..., "block_mask_idx": ...}``. """ import os as _os - import json as _json import torch as _torch - model_path = _os.environ.get("_QWEN25_OMNI_DIT_MODEL_PATH", "") stash_path = _os.environ.get("_QWEN25_OMNI_DIT_STASH", "") - if not model_path or not stash_path or not _os.path.isfile(stash_path): + if not stash_path or not _os.path.isfile(stash_path): raise RuntimeError( - "DiT builder: expected _QWEN25_OMNI_DIT_MODEL_PATH and " - "_QWEN25_OMNI_DIT_STASH env vars to be set by compile_dit()." + "Builder stash not found. Expected _QWEN25_OMNI_DIT_STASH env var " + "to point at a torch.save()'d dict written by compile_dit()." ) - from transformers import AutoConfig - from safetensors.torch import load_file - - hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) - t2w = NeuronQwen25OmniToken2WavWithNeuronDiT(hf_config.token2wav_config) - - state_dict = {} - for fn in sorted(_os.listdir(model_path)): - if fn.endswith(".safetensors"): - sd = load_file(_os.path.join(model_path, fn)) - for k, v in sd.items(): - if k.startswith("token2wav."): - state_dict[k[len("token2wav."):]] = v - t2w.load_state_dict(state_dict, strict=False) - - dit = t2w._get_dit_module() - if dit is None: - raise RuntimeError("Could not extract DiT module from Token2Wav.") - - _monkeypatch_dit_attention_for_neuron(dit) - - payload = _torch.load(stash_path, weights_only=True, map_location="cpu") - core_state_dict = payload["state_dict"] + payload = _torch.load(stash_path, weights_only=False, map_location="cpu") + dit = payload["dit_module"] + state_dict = payload["state_dict"] block_mask_idx = payload["block_mask_idx"] fresh = _NeuronDiTCore(dit) fresh.float() fresh.eval() fresh._block_mask_idx = block_mask_idx - fresh.load_state_dict(core_state_dict) + fresh.load_state_dict(state_dict) return fresh @@ -473,7 +450,6 @@ def compile_dit( max_mel_len=2048, batch_size=2, tp_degree=4, - model_path=None, ): """Compile the DiT transformer core on Neuron using TP=4. @@ -501,11 +477,6 @@ def compile_dit( inference with classifier-free guidance (CFG doubles batch). tp_degree: Replication degree; should match the Thinker/Talker ``tp_degree`` so all three live on the same NeuronCore group. - model_path: Path to the HuggingFace Qwen2.5-Omni-7B checkpoint. - Required because ``parallel_model_trace`` spawns fresh child - processes that each rebuild the DiT from this checkpoint - (we can't pickle the in-memory DiT due to Neuron-patched - ``nn.Embedding`` class identity). """ try: from neuronx_distributed.trace import parallel_model_trace @@ -515,13 +486,6 @@ def compile_dit( "Run on a Neuron instance with the NxDI venv active." 
) - if not model_path: - raise ValueError( - "compile_dit(model_path=...) is required — the spawn'd " - "compilation workers rebuild the DiT from this HuggingFace " - "checkpoint path." - ) - os.makedirs(compiled_path, exist_ok=True) dit = self._get_dit_module() @@ -542,13 +506,12 @@ def compile_dit( # Monkeypatch DiTAttention to fix in-place slice assignment _monkeypatch_dit_attention_for_neuron(dit) - # Serialize just what the spawn'd builder needs: the transformer - # core's state_dict (plain tensors) and the per-block mask indices - # (plain ints). We deliberately DO NOT pickle the DiT module itself — - # Neuron SDK patches torch classes like nn.Embedding, which makes - # torch.save trip on "not the same object" identity checks. Each - # child rebuilds the DiT from the on-disk HuggingFace checkpoint - # using the path we pass via _QWEN25_OMNI_DIT_MODEL_PATH. + # Capture state dict so parallel_model_trace's builder can + # reconstruct _NeuronDiTCore on the XLA device with the right weights. + # The builder must be a module-level function (see _build_dit_core_for_trace + # above) because parallel_model_trace pickles the builder across spawn'd + # processes. Since spawn'd children don't inherit globals, we write the + # inputs to a temp file and point the child at it via an env var. core_template = _NeuronDiTCore(dit) core_template.float() core_template.eval() @@ -556,25 +519,13 @@ def compile_dit( stash_path = os.path.join(compiled_path, "_dit_builder_stash.pt") torch.save( { + "dit_module": dit, "state_dict": core_template.state_dict(), "block_mask_idx": core_template._block_mask_idx, }, stash_path, ) os.environ["_QWEN25_OMNI_DIT_STASH"] = stash_path - os.environ["_QWEN25_OMNI_DIT_MODEL_PATH"] = str(model_path) - - # The builder lives in this module, which normally gets found via a - # sys.path.insert bootstrap in the caller. Spawn'd children won't have - # that, so push this module's directory onto PYTHONPATH for them. - _module_dir = os.path.dirname(os.path.abspath(__file__)) - existing_pythonpath = os.environ.get("PYTHONPATH", "") - prior_pythonpath = existing_pythonpath # restore in finally - if _module_dir not in existing_pythonpath.split(os.pathsep): - os.environ["PYTHONPATH"] = ( - f"{_module_dir}{os.pathsep}{existing_pythonpath}" - if existing_pythonpath else _module_dir - ) # Create example inputs # time_embedding uses batch=1 (broadcasts to hidden_states batch) @@ -613,11 +564,6 @@ def compile_dit( ) finally: os.environ.pop("_QWEN25_OMNI_DIT_STASH", None) - os.environ.pop("_QWEN25_OMNI_DIT_MODEL_PATH", None) - if prior_pythonpath: - os.environ["PYTHONPATH"] = prior_pythonpath - else: - os.environ.pop("PYTHONPATH", None) try: os.remove(stash_path) except OSError: From 8b187e3e7a37dae9df20e06183f15b38a36e27cc Mon Sep 17 00:00:00 2001 From: whn09 Date: Wed, 22 Apr 2026 15:55:00 +0800 Subject: [PATCH 12/14] Revert "Fix DiT compile: pass builder args to spawn'd child via a temp pickle" This reverts commit bdd5a2a2780baea56c4060a315006b19f6c91b5b. 
--- .../src/modeling_qwen25_omni_token2wav.py | 95 ++++--------------- 1 file changed, 21 insertions(+), 74 deletions(-) diff --git a/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_token2wav.py b/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_token2wav.py index b2c7a13c..71672ee2 100644 --- a/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_token2wav.py +++ b/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_token2wav.py @@ -279,47 +279,6 @@ def forward(hidden_states, position_embeddings=None, attention_mask=None): attn.forward = _make_patched(attn) -# File path used by the picklable builder `_build_dit_core_for_trace`. -# parallel_model_trace spawns child processes via torch.multiprocessing (start -# method = 'spawn'), so the builder must be importable by fully qualified name, -# and its only dependency — the DiT state dict + original module — has to be -# recoverable inside the child. We stash it on disk (torch.save) and let the -# child reload it: the Qwen2.5-Omni-7B repo weights are already on disk anyway, -# so an extra ~350MB pickle round-trip is acceptable. -_DIT_BUILDER_STASH_PATH: str = "" - - -def _build_dit_core_for_trace(): - """Module-level, picklable builder for parallel_model_trace. - - Reads the stashed pickle (see `_DIT_BUILDER_STASH_PATH`) that - `NeuronQwen25OmniToken2WavWithNeuronDiT.compile_dit` wrote before invoking - parallel_model_trace. The file is a plain ``torch.save`` of - ``{"dit_module": ..., "state_dict": ..., "block_mask_idx": ...}``. - """ - import os as _os - import torch as _torch - - stash_path = _os.environ.get("_QWEN25_OMNI_DIT_STASH", "") - if not stash_path or not _os.path.isfile(stash_path): - raise RuntimeError( - "Builder stash not found. Expected _QWEN25_OMNI_DIT_STASH env var " - "to point at a torch.save()'d dict written by compile_dit()." - ) - - payload = _torch.load(stash_path, weights_only=False, map_location="cpu") - dit = payload["dit_module"] - state_dict = payload["state_dict"] - block_mask_idx = payload["block_mask_idx"] - - fresh = _NeuronDiTCore(dit) - fresh.float() - fresh.eval() - fresh._block_mask_idx = block_mask_idx - fresh.load_state_dict(state_dict) - return fresh - - class _NeuronDiTCore(torch.nn.Module): """Traced wrapper for DiT transformer blocks + norm_out + proj_out. @@ -506,26 +465,21 @@ def compile_dit( # Monkeypatch DiTAttention to fix in-place slice assignment _monkeypatch_dit_attention_for_neuron(dit) - # Capture state dict so parallel_model_trace's builder can + # Capture state dict so parallel_model_trace's builder closure can # reconstruct _NeuronDiTCore on the XLA device with the right weights. - # The builder must be a module-level function (see _build_dit_core_for_trace - # above) because parallel_model_trace pickles the builder across spawn'd - # processes. Since spawn'd children don't inherit globals, we write the - # inputs to a temp file and point the child at it via an env var. 
core_template = _NeuronDiTCore(dit) core_template.float() core_template.eval() + core_state_dict = core_template.state_dict() + block_mask_idx = core_template._block_mask_idx - stash_path = os.path.join(compiled_path, "_dit_builder_stash.pt") - torch.save( - { - "dit_module": dit, - "state_dict": core_template.state_dict(), - "block_mask_idx": core_template._block_mask_idx, - }, - stash_path, - ) - os.environ["_QWEN25_OMNI_DIT_STASH"] = stash_path + def build_core(): + fresh = _NeuronDiTCore(dit) + fresh.float() + fresh.eval() + fresh._block_mask_idx = block_mask_idx + fresh.load_state_dict(core_state_dict) + return fresh # Create example inputs # time_embedding uses batch=1 (broadcasts to hidden_states batch) @@ -550,24 +504,17 @@ def compile_dit( batch_size, 1, max_mel_len, max_mel_len, dtype=torch.float32 ) - try: - compiled = parallel_model_trace( - _build_dit_core_for_trace, - (hidden_states, time_embedding, cos, sin, - mask_local, mask_backward, mask_ahead), - tp_degree=tp_degree, - compiler_args=[ - "--auto-cast=none", - "--model-type=transformer", - "-O1", - ], - ) - finally: - os.environ.pop("_QWEN25_OMNI_DIT_STASH", None) - try: - os.remove(stash_path) - except OSError: - pass + compiled = parallel_model_trace( + build_core, + (hidden_states, time_embedding, cos, sin, + mask_local, mask_backward, mask_ahead), + tp_degree=tp_degree, + compiler_args=[ + "--auto-cast=none", + "--model-type=transformer", + "-O1", + ], + ) # parallel_model_trace produces a ParallelModel that serializes as a # directory (multiple per-rank artifacts), not a single .pt file. From d3dd3ef3a63c5ffb2c6923a4b44dfb8044cb304d Mon Sep 17 00:00:00 2001 From: whn09 Date: Wed, 22 Apr 2026 15:55:00 +0800 Subject: [PATCH 13/14] Revert "Compile DiT with parallel_model_trace(TP=4) so it co-locates with Thinker/Talker" This reverts commit 495df41f8c1ee375549ff9c918440cb04da690e9. --- .../examples/generate_qwen25_omni_speech.py | 31 ++---- .../src/modeling_qwen25_omni_token2wav.py | 96 +++++-------------- 2 files changed, 34 insertions(+), 93 deletions(-) diff --git a/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni_speech.py b/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni_speech.py index 58f490e9..c903e496 100644 --- a/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni_speech.py +++ b/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni_speech.py @@ -181,16 +181,14 @@ def compile_all(model_path, compiled_path): t_total = time.time() stages = [ - ("Thinker", "thinker_tp4", ["neuron_config.json"], _compile_thinker), - ("Talker", "talker_tp4", ["neuron_config.json"], _compile_talker), - # DiT has two possible artifacts: TP-replicated directory (current) - # or the legacy single-file .pt (pre-TP rewrite). 
- ("DiT", "dit_core", ["dit_core_parallel", "dit_core_neuron.pt"], _compile_dit), + ("Thinker", "thinker_tp4", "neuron_config.json", _compile_thinker), + ("Talker", "talker_tp4", "neuron_config.json", _compile_talker), + ("DiT", "dit_core", "dit_core_neuron.pt", _compile_dit), ] - for idx, (label, subdir, markers, fn) in enumerate(stages, 1): + for idx, (label, subdir, marker, fn) in enumerate(stages, 1): print(f"\n--- [{idx}/{len(stages)}] Compiling {label} ---") out_path = os.path.join(compiled_path, subdir) - if any(os.path.exists(os.path.join(out_path, m)) for m in markers): + if os.path.exists(os.path.join(out_path, marker)): print(" Already compiled, skipping.") continue t0 = time.time() @@ -207,23 +205,12 @@ def compile_all(model_path, compiled_path): # ========================================================================== def _check_compiled(compiled_path): - """Confirm that each component has been compiled. - - Accepts both the new TP-replicated DiT artifact (``dit_core_parallel/``) - and the legacy single-device one (``dit_core_neuron.pt``). - """ checks = [ - ([os.path.join(compiled_path, "thinker_tp4", "neuron_config.json")], "Thinker"), - ([os.path.join(compiled_path, "talker_tp4", "neuron_config.json")], "Talker"), - ( - [ - os.path.join(compiled_path, "dit_core", "dit_core_parallel"), - os.path.join(compiled_path, "dit_core", "dit_core_neuron.pt"), - ], - "DiT", - ), + (os.path.join(compiled_path, "thinker_tp4", "neuron_config.json"), "Thinker"), + (os.path.join(compiled_path, "talker_tp4", "neuron_config.json"), "Talker"), + (os.path.join(compiled_path, "dit_core", "dit_core_neuron.pt"), "DiT"), ] - missing = [name for paths, name in checks if not any(os.path.exists(p) for p in paths)] + missing = [name for path, name in checks if not os.path.exists(path)] if missing: print(f"ERROR: Missing compiled artifacts for: {', '.join(missing)}") print(f"Run with --compile first:") diff --git a/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_token2wav.py b/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_token2wav.py index 71672ee2..9e37ae6c 100644 --- a/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_token2wav.py +++ b/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_token2wav.py @@ -408,41 +408,26 @@ def compile_dit( compiled_path, max_mel_len=2048, batch_size=2, - tp_degree=4, ): - """Compile the DiT transformer core on Neuron using TP=4. + """Compile the DiT transformer core on Neuron. Only the 22 transformer blocks + norm_out + proj_out are compiled. Preprocessing (ECAPA-TDNN, codec embedding, input embedding, rotary) stays on CPU to avoid XLA tracing issues. - Uses ``neuronx_distributed.trace.parallel_model_trace`` (replicated, - not sharded) so the DiT lives on the same NeuronCore group (0..tp_degree-1) - as the Thinker and Talker. This matters when all three models share one - Python process: a single-device ``torch_neuronx.trace`` NEFF gets - placed on a separate core group and pays a cross-group scheduling - penalty (~4s per DiT forward on trn2.48xlarge). Replicating onto - the same TP group makes the NeuronCore runtime treat all three as - peers on the same logical device set. - - DiT itself is small (~85M params) so there is no memory win from - sharding the linears; the win is purely co-location. - Args: - compiled_path: Directory to save compiled model. + compiled_path: Directory to save compiled model max_mel_len: Maximum mel spectrogram length (covers ~24s audio). Shorter inputs are padded; longer inputs fall back to CPU. 
batch_size: Batch size for compilation. Use 2 for standard inference with classifier-free guidance (CFG doubles batch). - tp_degree: Replication degree; should match the Thinker/Talker - ``tp_degree`` so all three live on the same NeuronCore group. """ try: - from neuronx_distributed.trace import parallel_model_trace + import torch_neuronx except ImportError: raise ImportError( - "neuronx_distributed required for DiT compilation. " - "Run on a Neuron instance with the NxDI venv active." + "torch_neuronx required for DiT compilation. " + "Run on a Neuron instance (trn1/trn2/inf2)." ) os.makedirs(compiled_path, exist_ok=True) @@ -458,28 +443,17 @@ def compile_dit( head_dim = dim // num_heads logger.info( - "Compiling DiT core: batch=%d, mel_len=%d, dim=%d, heads=%d, tp=%d", - batch_size, max_mel_len, dim, num_heads, tp_degree, + "Compiling DiT core: batch=%d, mel_len=%d, dim=%d, heads=%d", + batch_size, max_mel_len, dim, num_heads, ) # Monkeypatch DiTAttention to fix in-place slice assignment _monkeypatch_dit_attention_for_neuron(dit) - # Capture state dict so parallel_model_trace's builder closure can - # reconstruct _NeuronDiTCore on the XLA device with the right weights. - core_template = _NeuronDiTCore(dit) - core_template.float() - core_template.eval() - core_state_dict = core_template.state_dict() - block_mask_idx = core_template._block_mask_idx - - def build_core(): - fresh = _NeuronDiTCore(dit) - fresh.float() - fresh.eval() - fresh._block_mask_idx = block_mask_idx - fresh.load_state_dict(core_state_dict) - return fresh + # Create wrapper for transformer core only + core = _NeuronDiTCore(dit) + core.float() + core.eval() # Create example inputs # time_embedding uses batch=1 (broadcasts to hidden_states batch) @@ -504,11 +478,10 @@ def build_core(): batch_size, 1, max_mel_len, max_mel_len, dtype=torch.float32 ) - compiled = parallel_model_trace( - build_core, + compiled = torch_neuronx.trace( + core, (hidden_states, time_embedding, cos, sin, mask_local, mask_backward, mask_ahead), - tp_degree=tp_degree, compiler_args=[ "--auto-cast=none", "--model-type=transformer", @@ -516,12 +489,8 @@ def build_core(): ], ) - # parallel_model_trace produces a ParallelModel that serializes as a - # directory (multiple per-rank artifacts), not a single .pt file. - save_dir = os.path.join(compiled_path, "dit_core_parallel") - os.makedirs(save_dir, exist_ok=True) - from neuronx_distributed.trace import parallel_model_save - parallel_model_save(compiled, save_dir) + save_path = os.path.join(compiled_path, "dit_core_neuron.pt") + torch.jit.save(compiled, save_path) # Save metadata for load meta = { @@ -530,12 +499,11 @@ def build_core(): "dim": dim, "num_heads": num_heads, "head_dim": head_dim, - "tp_degree": tp_degree, } with open(os.path.join(compiled_path, "dit_core_meta.json"), "w") as f: json.dump(meta, f) - logger.info("Compiled DiT core (TP=%d) saved to %s", tp_degree, save_dir) + logger.info("Compiled DiT core saved to %s", save_path) self._neuron_dit_core = compiled self._dit_compiled_path = compiled_path @@ -545,34 +513,18 @@ def build_core(): def load_dit(self, compiled_path): """Load a previously compiled DiT core model. - Supports both the old single-device ``torch.jit`` artifact (filename - ``dit_core_neuron.pt``) and the new TP-replicated ``parallel_model`` - artifact (directory ``dit_core_parallel/``). Loading a legacy - single-device artifact will work but pays the cross-core-group - scheduling penalty described in ``compile_dit``. 
+        Args:
+            compiled_path: Directory containing compiled model
         """
+        save_path = os.path.join(compiled_path, "dit_core_neuron.pt")
         meta_path = os.path.join(compiled_path, "dit_core_meta.json")
-        parallel_dir = os.path.join(compiled_path, "dit_core_parallel")
-        legacy_path = os.path.join(compiled_path, "dit_core_neuron.pt")
-
-        if os.path.isdir(parallel_dir):
-            from neuronx_distributed.trace import parallel_model_load
-            self._neuron_dit_core = parallel_model_load(parallel_dir)
-            logger.info("Loaded TP-replicated DiT core from %s", parallel_dir)
-        elif os.path.exists(legacy_path):
-            self._neuron_dit_core = torch.jit.load(legacy_path)
-            logger.warning(
-                "Loaded legacy single-device DiT core from %s; recompile with "
-                "compile_dit() to get the TP=4 replicated artifact and avoid "
-                "cross-core-group scheduling overhead when running alongside "
-                "the Thinker and Talker.",
-                legacy_path,
-            )
-        else:
+
+        if not os.path.exists(save_path):
             raise FileNotFoundError(
-                f"Compiled DiT core not found at {parallel_dir} or {legacy_path}"
+                f"Compiled DiT core not found at {save_path}"
             )
 
+        self._neuron_dit_core = torch.jit.load(save_path)
         self._dit_compiled_path = compiled_path
 
         if os.path.exists(meta_path):
@@ -581,6 +533,8 @@ def load_dit(self, compiled_path):
             self._dit_max_mel_len = meta["max_mel_len"]
             self._dit_batch_size = meta["batch_size"]
 
+        logger.info("Loaded compiled DiT core from %s", save_path)
+
     def _build_attention_masks(self, block_diff, actual_mel_len, max_mel_len):
         """Build three per-block float additive attention masks with padding.
 

From 7916fb6ef410a13e738cedfc1eb899e8a7f68131 Mon Sep 17 00:00:00 2001
From: whn09
Date: Wed, 22 Apr 2026 15:57:14 +0800
Subject: [PATCH 14/14] Pin all three Neuron models to NeuronCores 0-3 via
 NEURON_RT_VISIBLE_CORES

After reverting the DiT TP=4 experiment back to the single-device
torch_neuronx.trace path, Token2Wav was still slower than the subprocess
baseline (~14s vs ~10.7s) because the Neuron runtime was placing the
single-device DiT NEFF on a different core group than the TP=4
Thinker/Talker, paying a cross-group scheduling penalty on every DiT
forward.

Fix: set NEURON_RT_VISIBLE_CORES=0-3 so all three NEFFs share the same
four NeuronCores. Done in two places:

- generate_qwen25_omni_speech.py: os.environ.setdefault before any Neuron
  module is imported. setdefault so users can still override if they need
  a different core range.
- README "Prerequisites" section: explicit instruction + rationale for
  anyone embedding the pipeline in their own entrypoint.

Co-Authored-By: Claude Opus 4.7
---
 contrib/models/Qwen2.5-Omni-7B/README.md                    | 5 +++++
 .../examples/generate_qwen25_omni_speech.py                 | 9 +++++++++
 2 files changed, 14 insertions(+)

diff --git a/contrib/models/Qwen2.5-Omni-7B/README.md b/contrib/models/Qwen2.5-Omni-7B/README.md
index 23276499..5b21028e 100644
--- a/contrib/models/Qwen2.5-Omni-7B/README.md
+++ b/contrib/models/Qwen2.5-Omni-7B/README.md
@@ -37,6 +37,11 @@ Key features:
   pip install soundfile                # writes WAV output in generate_qwen25_omni_speech.py
   pip install qwen-omni-utils[decord]  # process_mm_info() in generate_qwen25_omni.py for image/audio/video inputs
   ```
+- **Pin all three Neuron models to the same core group**. Set `NEURON_RT_VISIBLE_CORES=0-3` before launching the speech pipeline so the Thinker (TP=4), Talker (TP=4), and the single-device Token2Wav DiT all live on the same four NeuronCores. Without this the Neuron runtime places the DiT NEFF on a different core group and every DiT forward pays a cross-group scheduling penalty (~30% slower). `examples/generate_qwen25_omni_speech.py` already sets this via `os.environ.setdefault` before any Neuron module is imported; if you embed the pipeline in your own entrypoint, do the same.
+  ```bash
+  export NEURON_RT_VISIBLE_CORES=0-3
+  python examples/generate_qwen25_omni_speech.py
+  ```
 
 ## Usage
 
diff --git a/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni_speech.py b/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni_speech.py
index c903e496..733614d6 100644
--- a/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni_speech.py
+++ b/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni_speech.py
@@ -35,6 +35,15 @@
 python examples/generate_qwen25_omni_speech.py --num-runs 5
 """
 
+import os as _os
+
+# Pin all three Neuron-compiled models (Thinker, Talker, DiT) to the same
+# four NeuronCores. Without this, the runtime places the single-device DiT
+# NEFF on a different core group than the TP=4 Thinker/Talker, and the
+# resulting cross-group scheduling makes every DiT forward ~30% slower.
+# Set before any Neuron module is imported so the runtime picks it up.
+_os.environ.setdefault("NEURON_RT_VISIBLE_CORES", "0-3")
+
 # --- Qwen2.5-Omni contrib bootstrap ---
 import sys as _sys
 from pathlib import Path as _Path
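
The README bullet above asks anyone embedding the speech pipeline in their own entrypoint to replicate the core pinning. A minimal sketch of that pattern follows; the `my_omni_pipeline` module and `run_speech_pipeline` call are hypothetical placeholders for whatever code actually loads the Thinker, Talker, and Token2Wav, and only the environment-variable handling is taken from the patch.

```python
# Sketch of a custom entrypoint that embeds the speech pipeline.
# Assumption: `my_omni_pipeline` is a placeholder module that imports
# torch_neuronx / neuronx_distributed internally.
import os

# Pin Thinker (TP=4), Talker (TP=4), and the single-device DiT NEFF to the
# same four NeuronCores; setdefault keeps a user-supplied override intact.
os.environ.setdefault("NEURON_RT_VISIBLE_CORES", "0-3")

# Import anything Neuron-related only after the variable is set, otherwise
# the runtime may already have chosen a different core group.
from my_omni_pipeline import run_speech_pipeline  # hypothetical

if __name__ == "__main__":
    run_speech_pipeline(prompt="Hello from Neuron!")  # hypothetical API
```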
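
For reference, a small sketch of sanity-checking the compiled artifact layout that the Token2Wav revert above settles on. The `compiled` root directory is an assumed example path; the file names and the `dit_core_meta.json` keys are the ones used in the diffs.

```python
# Sketch: verify the three compiled artifacts exist and inspect the DiT
# core's compile-time limits. Paths under `compiled/` are assumptions.
import json
import os

import torch
import torch_neuronx  # registers Neuron ops so torch.jit.load can resolve them

compiled_path = "compiled"  # assumed example root

expected = [
    os.path.join(compiled_path, "thinker_tp4", "neuron_config.json"),
    os.path.join(compiled_path, "talker_tp4", "neuron_config.json"),
    os.path.join(compiled_path, "dit_core", "dit_core_neuron.pt"),
]
missing = [p for p in expected if not os.path.exists(p)]
if missing:
    raise SystemExit(f"Missing compiled artifacts: {missing}")

# The DiT core is a TorchScript module plus a JSON sidecar; inputs longer
# than the recorded max_mel_len fall back to the CPU path at runtime.
dit_core = torch.jit.load(os.path.join(compiled_path, "dit_core", "dit_core_neuron.pt"))
with open(os.path.join(compiled_path, "dit_core", "dit_core_meta.json")) as f:
    meta = json.load(f)
print("DiT compiled for batch", meta["batch_size"], "and mel_len", meta["max_mel_len"])
```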