diff --git a/contrib/models/Qwen2.5-Omni-7B/README.md b/contrib/models/Qwen2.5-Omni-7B/README.md index c5830ceb..5b21028e 100644 --- a/contrib/models/Qwen2.5-Omni-7B/README.md +++ b/contrib/models/Qwen2.5-Omni-7B/README.md @@ -1,130 +1,292 @@ -# Contrib Model: Qwen2.5 Omni 7B +# Contrib Model: Qwen2.5-Omni-7B -NeuronX Distributed Inference implementation of Qwen2.5 Omni 7B. - -> **Note:** This implementation has been validated using the **text backbone only**. Vision/audio modalities are implemented but not yet verified. +NeuronX Distributed Inference implementation of [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B) with full multimodal support: text generation, image understanding, audio understanding, and text-to-speech. ## Model Information - **HuggingFace ID:** `Qwen/Qwen2.5-Omni-7B` -- **Model Type:** Decoder-only transformer +- **Model Type:** Multimodal encoder-decoder (Thinker + Vision + Audio + Talker + Token2Wav) +- **Architecture:** Qwen2-based text backbone with vision/audio encoders and speech synthesis - **License:** Check HuggingFace model card ## Architecture Details -- **Layers:** Check model config -- **Hidden Size:** Check model config -- **Attention Heads:** Check model config -- **Vocabulary:** Check model config -- **Max Position Embeddings:** Check model config +| Component | Runtime | TP | Parameters | +|-----------|---------|-----|------------| +| Thinker (text) | Neuron | 4 | hidden=3584, heads=28, kv_heads=4, layers=28 | +| Vision encoder | Neuron | 4 | embed=1280, heads=16, depth=32, SwiGLU MLP | +| Audio encoder | CPU+Neuron | 4 | d_model=1280, heads=20, layers=32, chunked attention | +| Talker | Neuron | 4 | hidden=896, heads=12, kv_heads=4, head_dim=128, layers=24, vocab=8448, fused embed (8448→896) | +| Token2Wav | CPU+Neuron (fp32) | N/A | DiT: dim=1024, 22 blocks (Neuron); BigVGAN: 6 upsample stages (CPU) | + +**Total state dict keys:** 2448 (Text: 339, Vision: 518, Audio: 489, Talker: 293, Token2Wav: 809) + +Key features: +- **Thinker**: Architecturally identical to Qwen2.5-7B; reuses `NeuronQwen2ForCausalLM` with state-dict prefix remapping (28 heads / 4 TP = 7 per rank, 4 kv_heads / 4 TP = 1 per rank) +- **Vision encoder**: SwiGLU MLP, RMSNorm, separate QKV projections, PatchMerger (16 heads / 4 TP = 4 per rank) +- **Audio encoder**: Whisper-style with chunked attention. Hybrid CPU+Neuron: Conv1d frontend + chunking on CPU, 32 transformer layers on Neuron (20 heads / 4 TP = 5 per rank), AvgPool + LayerNorm + projection on CPU +- **Talker**: Neuron-compiled with fused embedding (embed_tokens 8448→3584 + thinker_to_talker_proj 3584→896 collapsed into 8448→896), explicit head_dim=128, 3D mRoPE, per-step thinker state injection via vision_embeddings (12 heads / 4 TP = 3 per rank, 4 kv_heads / 4 TP = 1 per rank). Auto-pads vision_embeddings to max_context_length for compiled bucket compatibility. +- **Token2Wav**: DiT transformer core (22 blocks) on Neuron + BigVGAN vocoder on CPU, ODE sampling (Runge-Kutta 4, 10 steps), float32. Split architecture: CPU preprocessing (ECAPA-TDNN, codec embed, input embed, rotary) + Neuron transformer core + CPU ODE solver + CPU BigVGAN. Automatic CPU fallback when mel_len exceeds compiled max. + +## Prerequisites + +- **Instance**: trn2.48xlarge or trn2.xlarge (4+ NeuronCores sufficient) +- **Weights**: Download from [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B) — the example scripts auto-download via `huggingface_hub.snapshot_download` on first run. +- **Python dependencies** (on top of the NxDI venv): + ```bash + pip install soundfile # writes WAV output in generate_qwen25_omni_speech.py + pip install qwen-omni-utils[decord] # process_mm_info() in generate_qwen25_omni.py for image/audio/video inputs + ``` +- **Pin all three Neuron models to the same core group**. Set `NEURON_RT_VISIBLE_CORES=0-3` before launching the speech pipeline so the Thinker (TP=4), Talker (TP=4), and the single-device Token2Wav DiT all live on the same four NeuronCores. Without this the Neuron runtime places the DiT NEFF on a different core group and every DiT forward pays a cross-group scheduling penalty (~30% slower). `examples/generate_qwen25_omni_speech.py` already sets this via `os.environ.setdefault` before any Neuron module is imported; if you embed the pipeline in your own entrypoint, do the same. + ```bash + export NEURON_RT_VISIBLE_CORES=0-3 + python examples/generate_qwen25_omni_speech.py + ``` -## Validation Results +## Usage -**Validated:** 2026-01-29 -**Configuration:** TP=2, batch_size=1, seq_len=128, bfloat16 +### Text-only (Thinker) -### Test Results +```python +import sys +from pathlib import Path + +# Make this contrib package's src/ importable (flat, per upstream contrib convention). +sys.path.insert(0, str(Path("contrib/models/Qwen2.5-Omni-7B/src").resolve())) +import _upstream_compat # noqa: F401 (applies hf_adapter bug fix) + +import torch +from transformers import AutoTokenizer +from neuronx_distributed_inference.models.config import NeuronConfig, OnDeviceSamplingConfig +from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config, HuggingFaceGenerationAdapter +from modeling_qwen25_omni import ( + NeuronQwen25OmniForCausalLM, + Qwen25OmniInferenceConfig, +) -| Test | Status | Result | -|------|--------|--------| -| Smoke Test | ✅ PASS | Model loads successfully | -| Token Matching | ⚠️ N/A | **0.0% match** | -| TTFT (P50) | ✅ PASS | 50.15ms (threshold: 100ms) | -| Throughput | ✅ PASS | 19.82 tok/s (threshold: 10 tok/s) | +model_path = "/path/to/Qwen2.5-Omni-7B/" +compiled_path = "/path/to/compiled/" -### Performance Metrics +neuron_config = NeuronConfig( + tp_degree=4, + batch_size=1, + seq_len=4096, + max_context_length=4096, + torch_dtype=torch.bfloat16, + on_device_sampling_config=OnDeviceSamplingConfig( + do_sample=True, temperature=0.6, top_k=20, top_p=0.95 + ), +) -| Metric | Value | -|--------|-------| -| TTFT (P50) | 50.15ms | -| Throughput | 19.82 tokens/s | +config = Qwen25OmniInferenceConfig( + neuron_config, load_config=load_pretrained_config(model_path) +) +model = NeuronQwen25OmniForCausalLM(model_path, config) +model.compile(compiled_path) +model.load(compiled_path) -**Status:** ✅ VALIDATED +tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) +adapter = HuggingFaceGenerationAdapter(model, tokenizer) +output = adapter.generate("What is quantum computing?", max_new_tokens=256) +``` -### Device Profiling Metrics +### Multimodal (Vision + Audio + Speech) -**Configuration:** TP=2, batch_size=1, seq_len=128, bfloat16 -**Instance:** trn1.32xlarge | **Profiled:** 2026-03-18 +```python +from modeling_qwen25_omni import ( + NeuronQwen25OmniMultimodalForCausalLM, + Qwen25OmniMultimodalInferenceConfig, +) -| Metric | Context Encoding | Token Generation | -|--------|-----------------|------------------| -| MFU (%) | 0.19 | 0.00 | -| MBU (%) | 0.36 | 0.42 | -| HFU (%) | 0.19 | 0.00 | -| Execution Time (us) | 0.05 | 0.04 | -| HBM Read | 7.19 GB | 7.08 GB | -| HBM Write | 88.46 MB | 2.78 MB | +# After loading text model, enable multimodal components: +# model.enable_audio_encoder(audio_state_dict) +# model.compile_audio_encoder("/path/to/compiled_audio/") # compile Neuron transformer +# model.load_audio_encoder("/path/to/compiled_audio/") # load compiled model +# model.enable_talker(talker_state_dict) +# model.enable_token2wav(token2wav_state_dict, speaker_dict_path="spk_dict.pt") +# +# Full multimodal pipeline: +# Thinker generates text -> hidden states passed to Talker +# -> Talker generates codec tokens -> Token2Wav generates waveform +``` -**Throughput:** 19.81 tok/s | **Compile Time:** 332.09s +## vLLM Integration -> Metrics from `neuron-profile capture` on compiled NEFFs. MFU = Model FLOPs Utilization, -> MBU = Memory Bandwidth Utilization, HFU = Hardware FLOPs Utilization. +Qwen2.5-Omni can be served via [vllm-neuron](https://github.com/aws-neuron/vllm-neuron) for text-only inference. A patch is required for the nested config structure. -## Usage +### Setup -```python -from transformers import AutoTokenizer, GenerationConfig -from neuronx_distributed_inference.models.config import NeuronConfig -from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config +```bash +# 1. Install vllm-neuron +pip install vllm-neuron -# Import model classes from src -from src.modeling_qwen2_5_omni_7b import NeuronQwen25Omni7BForCausalLM, Qwen25Omni7BInferenceConfig +# 2. Apply the Qwen2.5-Omni patch +python perf_test/apply_vllm_neuron_patch_qwen25omni.py +``` -model_path = "/path/to/Qwen2.5-Omni-7B/" -compiled_model_path = "/path/to/compiled/" +### Serving -# Configure -neuron_config = NeuronConfig( - tp_degree=2, - batch_size=1, - seq_len=512, - torch_dtype=torch.bfloat16, -) +```bash +python3 -m vllm.entrypoints.openai.api_server \ + --model /path/to/Qwen2.5-Omni-7B \ + --tensor-parallel-size 4 \ + --max-model-len 4096 \ + --max-num-seqs 32 \ + --no-enable-chunked-prefill \ + --no-enable-prefix-caching \ + --trust_remote_code \ + --additional-config '{ + "override_neuron_config": { + "tp_degree": 4, + "fused_qkv": false, + "flash_decoding_enabled": false, + "sequence_parallel_enabled": false, + "batch_size": 32, + "ctx_batch_size": 1, + "tkg_batch_size": 32, + "max_context_length": 4096, + "seq_len": 4096, + "is_continuous_batching": true, + "enable_bucketing": true, + "async_mode": true, + "on_device_sampling_config": { + "do_sample": true, "temperature": 0.6, "top_k": 20, "top_p": 0.95 + } + } + }' +``` -config = Qwen25Omni7BInferenceConfig( - neuron_config, - load_config=load_pretrained_config(model_path), -) +### Key vLLM Patch Changes -# Compile and load -model = NeuronQwen25Omni7BForCausalLM(model_path, config) -model.compile(compiled_model_path) -model.load(compiled_model_path) +The patch (`perf_test/apply_vllm_neuron_patch_qwen25omni.py`) modifies vllm-neuron to: +- Extract text config from nested `thinker_config.text_config` +- Map `Qwen2_5OmniModel` architecture to `qwen2_5_omni` model type +- Handle layer count extraction for nested config -# Generate -tokenizer = AutoTokenizer.from_pretrained(model_path) -# ... (see integration test for full example) -``` +See `perf_test/3_bench_qwen25_omni_7b.sh` for full benchmark configurations. + +## Performance + +Text-only benchmark (trn2, BF16, TP=4): + +| Config | TPOT (ms) | Output tok/s | Notes | +|--------|-----------|--------------|-------| +| BS=1, non-CB, greedy | ~11-13 | ~77-90 | Tested with chat template | +| BS=4, CB, c=4 | TBD | TBD | vLLM serving | + +Model load time: ~15s (from compiled artifacts on NVMe). + +Audio encoder performance (CPU frontend + CPU postprocessor, no Neuron transformer): + +| Audio Length | Mel Frames | Frontend | Postprocessor | +|-------------|-----------|----------|---------------| +| 1s | ~100 | ~20ms | included | +| 3s | ~300 | ~22ms | included | +| 10s | ~1000 | ~33ms | included | +| 30s | ~3000 | ~34ms | included | + +### End-to-End Multimodal (CPU inference, trn2.48xlarge) + +| Test | Input | Output | Time | +|------|-------|--------|------| +| Text → Text | "What is the capital of France?" | Correct answer (Paris) | 15.1s | +| Image + Text → Text | Synthetic image (shapes) + description prompt | Correctly identified red square, blue circle, yellow circle, green triangle | 59.5s | +| Audio + Text → Text | 440Hz sine wave + "What do you hear?" | Text response generated | 12.1s | +| Text → Speech | "Say hello and tell me the weather" | Text + audio waveform (14.2s audio) | 197.2s | + +### Speech Pipeline Profiling (CPU inference, trn2.48xlarge) + +Per-component measured breakdown for text-to-speech (14.1s audio output): + +| Component | Time | % of Total | RTF | Notes | +|-----------|------|------------|-----|-------| +| Thinker (7B) | 31.0s | 12% | — | 59 text tokens, ~1.9 tok/s on CPU | +| Talker (690M) | 103.3s | 41% | 7.3x | Autoregressive codec token generation, 24 layers | +| Token2Wav (DiT+BigVGAN) | 117.9s | 47% | 8.4x | 22 DiT blocks × 10 ODE steps × 2 (CFG) = 440 forward passes | +| **Total** | **252.1s** | **100%** | **17.9x** | Generating 14.1s audio takes 252.1s on CPU | + +### Full Neuron Speech Pipeline (trn2.48xlarge, TP=4, BF16) + +End-to-end speech synthesis with all components on Neuron (9.1s audio output): + +| Component | Time | Notes | +|-----------|------|-------| +| Thinker (7B, Neuron TP=4) | 0.3s | 24 text tokens | +| CPU hidden state extraction | ~3s | HF model forward for thinker states | +| Talker (690M, Neuron TP=4) | 2.1s | 454 codec tokens, per-step thinker injection | +| Token2Wav (Neuron DiT + CPU BigVGAN) | 9.9s | 22 DiT blocks × 10 ODE steps | +| **Total** | **~15s** | **9.1s audio, RTF ~1.7x** | + +### Neuron vs CPU Speedup (trn2.48xlarge, TP=4, BF16) + +| Component | CPU Time | Neuron Time | Speedup | Notes | +|-----------|----------|-------------|---------|-------| +| Thinker (7B) | 30.4s | 0.47s | **64.7x** | TPOT 10.2ms | +| Talker (690M) | 98.1s | 2.0s (500 tokens) | **49.1x** | TPOT 4.0ms | +| Token2Wav DiT (85M) | 24.1s | 3.8s | **6.3x** | 22 blocks × 10 ODE steps, batch=2 (CFG) | +| Token2Wav BigVGAN | 2.8s | 2.8s (CPU) | 1x | Stays on CPU | +| **Total** | **267.9s** | **~15s** | **~18x** | All Neuron components active | + +Token2Wav component breakdown (300 codec tokens / 6.0s audio): + +| Config | CPU | Neuron DiT | Speedup | +|--------|-----|-----------|---------| +| DiT only (22 blocks, 10 ODE steps) | 24.1s | 3.8s | 6.3x | +| Token2Wav end-to-end | 13.7s | 5.2s | 2.7x | +| DiT core single forward (batch=2, mel_len=1024) | 592ms | 60ms | 9.8x | + +Key observations: +- **Full Neuron speech pipeline** verified end-to-end: Thinker → Talker → Token2Wav all on Neuron, producing real human speech +- Thinker and Talker achieve **49-65x speedup** on Neuron +- Token2Wav DiT achieves **6.3x speedup** (9.8x for isolated transformer core) +- BigVGAN vocoder (CPU) is now the remaining bottleneck +- **Per-step thinker state injection**: Talker v2 adds thinker_reply_part[step] embedding at each autoregressive step, matching HF behavior +- **Vision embeddings auto-padding**: Compiled Neuron models require fixed bucket shapes; vision_embeddings are auto-padded to max_context_length +- Split architecture for Token2Wav: CPU preprocessing (ECAPA-TDNN, codec/input embed, rotary, block_diff) + Neuron transformer core (22 blocks + norm + proj) +- Overcame XLA tracing limitations: in-place slice assignment in DiTAttention (→ torch.cat), SDPA dispatch (→ explicit matmul), ECAPA-TDNN/codec embed issues (→ kept on CPU) +- Automatic CPU fallback when mel_len exceeds compiled DiT max ## Compatibility Matrix -| Instance/Version | 2.20+ | 2.19 and earlier | -|------------------|-------|------------------| -| Trn1 | ✅ Working | Not tested | -| Inf2 | Not tested | Not tested | +| Instance/Version | 2.23+ (PyTorch 2.9) | 2.22 and earlier | +|------------------|---------------------|------------------| +| Trn2 (trn2.48xlarge) | Tested (TP=4) | Not tested | +| Trn2 (trn2.xlarge) | Supported (TP=4) | Not tested | +| Trn1 (trn1.32xlarge) | Should work (TP=4, 4 NeuronCores) | Not tested | +| Inf2 (inf2.48xlarge) | Should work (TP=4) | Not tested | ## Testing -Run integration tests: +Verified on trn2.48xlarge with real Qwen2.5-Omni-7B weights: + +- **Imports**: All model classes import successfully +- **Config**: TP=4 head divisibility verified (Thinker 7/1, Audio 5, Vision 4 per rank) +- **State dict**: All 2448 keys converted correctly (text=339, audio=489, vision=518, talker=293, token2wav=809) +- **Audio CPU**: Frontend+postprocessor 1s=20ms, 30s=34ms +- **Talker CPU**: 1351M params loaded in ~10s, codec tokens verified +- **Text generation (TP=4)**: Compile + load + generate working, TPOT ~11-13ms, correct outputs verified ```bash -pytest nxdi_contrib_models/models/Qwen2.5-Omni-7B/test/integration/test_model.py --capture=tee-sys +# End-to-end test (compile + load + generate) +python /tmp/test_qwen25_omni_tp4.py ``` -Or run manually: +## Key Implementation Notes -```bash -cd nxdi_contrib_models/models/Qwen2.5-Omni-7B -python3 test/integration/test_model.py -``` +1. **TP=4 for all Neuron components**: Thinker (28 heads/4=7 per rank), Vision (16 heads/4=4), Audio (20 heads/4=5). All heads divisible by 4. +2. **Audio encoder hybrid architecture**: Conv1d frontend + chunking on CPU, 32 transformer layers on Neuron with TP=4, AvgPool + LayerNorm + projection on CPU. Asymmetric attention bias (q/v have bias, k has none) handled via ColumnParallelLinear. +3. **Talker on Neuron**: Non-standard head_dim (128 != 896/12), 3D mRoPE with per-step thinker-state injection, ~690M params. Uses ImageToTextModelWrapper with 24 positional args. Fused embedding (embed_tokens 8448→3584 + proj 3584→896 collapsed into 8448→896). Per-step thinker reply states injected via vision_embeddings during token generation. Vision embeddings auto-padded to max_context_length for compiled bucket compatibility. TPOT 4.0ms. +4. **Token2Wav split architecture**: DiT transformer core (22 blocks) on Neuron via torch_neuronx.trace(). CPU preprocessing: ECAPA-TDNN speaker encoder, codec embedding (repeat_interleave), input embedding, rotary embedding, block_diff mask. CPU postprocessing: ODE solver (RK4, 10 steps), BigVGAN vocoder. Float32 for ODE precision. XLA fixes: DiTAttention in-place slice assignment → torch.cat, SDPA dispatch → explicit matmul attention, float additive attention mask. Automatic CPU fallback when mel_len exceeds compiled max. +5. **Speaker support**: `spk_dict.pt` contains per-speaker conditioning (Ethan, Chelsie) +6. **State dict prefix remapping**: `thinker.model.*` -> `model.*`, `thinker.lm_head.*` -> `lm_head.*`, `thinker.visual.*` -> `visual.*`, `thinker.audio_tower.*` -> `frontend.*`/`transformer.*`/`postprocessor.*` ## Example Checkpoints -* Qwen/Qwen2.5-Omni-7B +* [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B) ## Maintainer -Annapurna Labs +Henan Wan (whn09) -**Last Updated:** 2026-01-29 +**Last Updated:** 2026-04-15 diff --git a/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni.py b/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni.py new file mode 100644 index 00000000..ad8d9858 --- /dev/null +++ b/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni.py @@ -0,0 +1,456 @@ +#!/usr/bin/env python3 +""" +End-to-end multimodal inference for Qwen2.5-Omni-7B on NeuronX (TP=4). + +Supports: + 1. Text-only: Thinker generates text + 2. Image + text: Vision encoder + Thinker + 3. Audio + text: Audio encoder + Thinker + 4. Image + audio + text: All encoders + Thinker + 5. Speech output: Thinker → Talker → Token2Wav (optional) + +Usage: + source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + pip install qwen-omni-utils[decord] + + # First run compiles (~30 min), subsequent runs load from cache: + python3 examples/generate_qwen25_omni.py + + # Text-only (fastest, uses simpler text-only model): + python3 examples/generate_qwen25_omni.py --mode text + + # Image understanding: + python3 examples/generate_qwen25_omni.py --mode image + + # Audio understanding: + python3 examples/generate_qwen25_omni.py --mode audio + + # Full multimodal (image + audio + text → text + speech): + python3 examples/generate_qwen25_omni.py --mode full +""" + +# --- Qwen2.5-Omni contrib bootstrap --- +import sys as _sys +from pathlib import Path as _Path +_SRC = _Path(__file__).resolve().parents[1] / "src" +if str(_SRC) not in _sys.path: + _sys.path.insert(0, str(_SRC)) +import _upstream_compat # noqa: F401 (applies hf_adapter shim) +# --- end bootstrap --- + +import argparse +import gc +import os +import sys +import time + +import torch + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- +from _model_path import resolve_model_path +MODEL_PATH = resolve_model_path() +COMPILED_PATH = os.environ.get( + "QWEN25_OMNI_COMPILED_PATH", "/tmp/qwen25_omni_compiled" +) +TP_DEGREE = int(os.environ.get("QWEN25_OMNI_TP_DEGREE", "4")) + +# Test media from Qwen official examples +IMAGE_URL = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg" +AUDIO_URL = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-Omni/cough.wav" + +# Sequence lengths +TEXT_SEQ_LENGTH = 4096 +TEXT_BUCKETS = [256, 512, 1024, 2048, 4096] +VISION_SEQ_LENGTH = 1012 # single image +VISION_BUCKETS = [1] # 1 image + + +class Timer: + def __init__(self, label): + self.label = label + self.elapsed = 0 + + def __enter__(self): + self.start = time.time() + return self + + def __exit__(self, *args): + self.elapsed = time.time() - self.start + print(f" [{self.label}] {self.elapsed:.2f}s") + + +# --------------------------------------------------------------------------- +# Mode 1: Text-only (uses the simpler NeuronQwen25OmniForCausalLM) +# --------------------------------------------------------------------------- +def run_text_only(model_path=MODEL_PATH, compiled_path=COMPILED_PATH): + """Text-only inference using the Thinker model.""" + from neuronx_distributed_inference.models.config import ( + NeuronConfig, + OnDeviceSamplingConfig, + ) + from neuronx_distributed_inference.utils.hf_adapter import ( + load_pretrained_config, + HuggingFaceGenerationAdapter, + ) + from modeling_qwen25_omni import ( + NeuronQwen25OmniForCausalLM, + Qwen25OmniInferenceConfig, + ) + from transformers import AutoTokenizer + + print("\n" + "=" * 60) + print("Mode: Text-only (Thinker on Neuron, TP=4)") + print("=" * 60) + + neuron_config = NeuronConfig( + tp_degree=TP_DEGREE, + batch_size=1, + seq_len=2048, + max_context_length=2048, + torch_dtype=torch.bfloat16, + on_device_sampling_config=OnDeviceSamplingConfig(do_sample=False, top_k=1), + ) + + hf_config = load_pretrained_config(model_path) + config = Qwen25OmniInferenceConfig(neuron_config, load_config=hf_config) + + compiled_dir = os.path.join(compiled_path, "thinker_tp4") + + with Timer("Create model"): + model = NeuronQwen25OmniForCausalLM(model_path, config) + + if not os.path.exists(os.path.join(compiled_dir, "neuron_config.json")): + with Timer("Compile"): + model.compile(compiled_dir) + else: + print(" Compiled artifacts found") + + with Timer("Load"): + model.load(compiled_dir) + + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + adapter = HuggingFaceGenerationAdapter(model) + + prompts = [ + "What is 2+3? Answer with just the number.", + "Write a haiku about the ocean.", + "Explain quantum computing in one sentence.", + ] + + print("\n--- Generation Results ---") + for prompt in prompts: + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ] + text = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + encoded = tokenizer(text, return_tensors="pt") + + with Timer(f"Generate"): + output_ids = adapter.generate( + input_ids=encoded["input_ids"], + attention_mask=encoded["attention_mask"], + max_new_tokens=128, + eos_token_id=[tokenizer.eos_token_id, 151645], + ) + + new_tokens = output_ids[0, encoded["input_ids"].shape[1] :] + output_text = tokenizer.decode(new_tokens, skip_special_tokens=True) + print(f" Q: {prompt}") + print(f" A: {output_text.strip()[:300]}") + print() + + del model, adapter + gc.collect() + + +# --------------------------------------------------------------------------- +# Mode 2: Multimodal (image/audio + text → text, optional speech) +# --------------------------------------------------------------------------- +def run_multimodal(mode="full", model_path=MODEL_PATH, compiled_path=COMPILED_PATH): + """Multimodal inference: image + audio + text → text (+ optional speech). + + Args: + mode: "image" (image+text), "audio" (audio+text), "full" (image+audio+text) + model_path: Path to HF model weights + compiled_path: Path for compiled artifacts + """ + from neuronx_distributed_inference.models.config import ( + NeuronConfig, + OnDeviceSamplingConfig, + ) + from neuronx_distributed_inference.utils.hf_adapter import ( + load_pretrained_config, + HuggingFaceGenerationAdapter, + ) + from modeling_qwen25_omni import ( + NeuronQwen25OmniMultimodalForCausalLM, + Qwen25OmniMultimodalInferenceConfig, + ) + from transformers import AutoProcessor, GenerationConfig + + print("\n" + "=" * 60) + print(f"Mode: Multimodal ({mode}) on Neuron, TP={TP_DEGREE}") + print("=" * 60) + + # --- Step 1: Create configs --- + text_neuron_config = NeuronConfig( + batch_size=1, + seq_len=TEXT_SEQ_LENGTH, + max_context_length=TEXT_SEQ_LENGTH, + ctx_batch_size=1, + tp_degree=TP_DEGREE, + torch_dtype=torch.bfloat16, + fused_qkv=False, + sequence_parallel_enabled=False, + flash_decoding_enabled=False, + qkv_kernel_enabled=False, + qkv_nki_kernel_enabled=False, + attn_kernel_enabled=False, + enable_bucketing=True, + context_encoding_buckets=TEXT_BUCKETS, + token_generation_buckets=TEXT_BUCKETS, + on_device_sampling_config=OnDeviceSamplingConfig(do_sample=False, top_k=1), + ) + + vision_neuron_config = NeuronConfig( + batch_size=1, + seq_len=VISION_SEQ_LENGTH, + tp_degree=TP_DEGREE, + torch_dtype=torch.bfloat16, + enable_bucketing=True, + buckets=VISION_BUCKETS, + fused_qkv=False, # Qwen2.5-Omni vision uses separate Q/K/V + qkv_kernel_enabled=False, + attn_kernel_enabled=False, + ) + + config = Qwen25OmniMultimodalInferenceConfig( + text_neuron_config=text_neuron_config, + vision_neuron_config=vision_neuron_config, + load_config=load_pretrained_config(model_path), + ) + + processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) + + # --- Step 2: Create and compile/load the model --- + compiled_dir = os.path.join(compiled_path, "multimodal_tp4") + + with Timer("Create multimodal model"): + model = NeuronQwen25OmniMultimodalForCausalLM( + model_path=model_path, config=config + ) + + if not os.path.exists(os.path.join(compiled_dir, "neuron_config.json")): + with Timer("Compile text + vision (this takes 20-40 minutes)"): + model.compile(compiled_dir) + processor.save_pretrained(compiled_dir) + else: + print(" Compiled artifacts found") + + with Timer("Load compiled model"): + model.load(compiled_dir) + processor = AutoProcessor.from_pretrained(compiled_dir) + + # --- Step 3: Enable audio encoder (if needed) --- + if mode in ("audio", "full"): + print("\n Enabling audio encoder...") + from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import ( + Qwen2_5OmniForConditionalGeneration, + ) + from modeling_qwen25_omni_audio import ( + NeuronQwen25OmniAudioEncoder, + ) + + with Timer("Load HF model for audio weights"): + hf_model = Qwen2_5OmniForConditionalGeneration.from_pretrained( + model_path, + trust_remote_code=True, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + ) + full_sd = hf_model.state_dict() + + audio_sd = NeuronQwen25OmniAudioEncoder.convert_hf_to_neuron_state_dict( + full_sd, dtype=torch.bfloat16 + ) + del hf_model, full_sd + gc.collect() + + model.enable_audio_encoder(audio_sd) + + compiled_audio_dir = os.path.join(compiled_path, "audio_encoder_tp4") + if not os.path.exists( + os.path.join(compiled_audio_dir, "neuron_config.json") + ): + with Timer("Compile audio encoder transformer"): + model.compile_audio_encoder(compiled_audio_dir) + else: + print(" Audio encoder compiled artifacts found") + + with Timer("Load audio encoder"): + model.load_audio_encoder(compiled_audio_dir) + + del audio_sd + gc.collect() + + # --- Step 4: Run inference --- + adapter = HuggingFaceGenerationAdapter(model) + generation_config = GenerationConfig( + do_sample=False, + eos_token_id=[151645], + pad_token_id=151645, + ) + + # Prepare inputs based on mode + from qwen_omni_utils import process_mm_info + + if mode == "image": + print("\n--- Image Understanding ---") + messages = [ + { + "role": "system", + "content": [{"type": "text", "text": "You are a helpful assistant."}], + }, + { + "role": "user", + "content": [ + {"type": "image", "image": IMAGE_URL}, + {"type": "text", "text": "Describe this image in detail."}, + ], + }, + ] + elif mode == "audio": + print("\n--- Audio Understanding ---") + messages = [ + { + "role": "system", + "content": [{"type": "text", "text": "You are a helpful assistant."}], + }, + { + "role": "user", + "content": [ + {"type": "audio", "audio": AUDIO_URL}, + {"type": "text", "text": "What sound is this? Describe what you hear."}, + ], + }, + ] + else: # full + print("\n--- Full Multimodal (Image + Audio + Text) ---") + messages = [ + { + "role": "system", + "content": [{"type": "text", "text": "You are a helpful assistant."}], + }, + { + "role": "user", + "content": [ + {"type": "image", "image": IMAGE_URL}, + {"type": "audio", "audio": AUDIO_URL}, + { + "type": "text", + "text": "Describe the image and the audio you received.", + }, + ], + }, + ] + + # Process multimodal inputs + with Timer("Process multimodal inputs"): + audios, images, videos = process_mm_info(messages, use_audio_in_video=False) + text = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + inputs = processor( + text=[text], + images=images if images else None, + audios=audios if audios else None, + return_tensors="pt", + padding=True, + ) + + print(f" input_ids shape: {inputs.input_ids.shape}") + if hasattr(inputs, "pixel_values") and inputs.pixel_values is not None: + print(f" pixel_values shape: {inputs.pixel_values.shape}") + if hasattr(inputs, "image_grid_thw") and inputs.image_grid_thw is not None: + print(f" image_grid_thw: {inputs.image_grid_thw}") + if hasattr(inputs, "input_features") and inputs.input_features is not None: + print(f" input_features shape: {inputs.input_features.shape}") + + # Generate + generate_kwargs = { + "input_ids": inputs.input_ids, + "attention_mask": inputs.attention_mask, + "generation_config": generation_config, + "max_new_tokens": 256, + } + if hasattr(inputs, "pixel_values") and inputs.pixel_values is not None: + generate_kwargs["pixel_values"] = inputs.pixel_values + if hasattr(inputs, "image_grid_thw") and inputs.image_grid_thw is not None: + generate_kwargs["image_grid_thw"] = inputs.image_grid_thw + if hasattr(inputs, "input_features") and inputs.input_features is not None: + generate_kwargs["input_features"] = inputs.input_features + if ( + hasattr(inputs, "feature_attention_mask") + and inputs.feature_attention_mask is not None + ): + generate_kwargs["feature_attention_mask"] = inputs.feature_attention_mask + + with Timer("Generate"): + output_ids = adapter.generate(**generate_kwargs) + + output_text = processor.batch_decode( + output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) + print(f"\n Output: {output_text[0].strip()[:500]}") + + del model, adapter + gc.collect() + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- +def main(): + parser = argparse.ArgumentParser(description="Qwen2.5-Omni-7B inference on Neuron") + parser.add_argument( + "--mode", + choices=["text", "image", "audio", "full"], + default="text", + help="Inference mode (default: text)", + ) + parser.add_argument( + "--model-path", + default=MODEL_PATH, + help=f"Model path (default: {MODEL_PATH})", + ) + parser.add_argument( + "--compiled-path", + default=COMPILED_PATH, + help=f"Compiled model path (default: {COMPILED_PATH})", + ) + args = parser.parse_args() + + model_path = args.model_path + compiled_path = args.compiled_path + + print(f"Model: {model_path}") + print(f"Compiled: {compiled_path}") + print(f"TP: {TP_DEGREE}") + + if args.mode == "text": + run_text_only(model_path, compiled_path) + else: + run_multimodal(args.mode, model_path, compiled_path) + + print("\nDone!") + + +if __name__ == "__main__": + main() diff --git a/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni_speech.py b/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni_speech.py new file mode 100644 index 00000000..733614d6 --- /dev/null +++ b/contrib/models/Qwen2.5-Omni-7B/examples/generate_qwen25_omni_speech.py @@ -0,0 +1,763 @@ +#!/usr/bin/env python3 +""" +End-to-end speech synthesis for Qwen2.5-Omni-7B on NeuronX (TP=4). + +Full pipeline: Thinker (text) -> Talker (codec tokens) -> Token2Wav (audio). + +All three Neuron-compiled components (Thinker, Talker, Token2Wav DiT) are +loaded *once* into the same Python process on the same NeuronCores (TP=4, +core 0-3) and reused across runs. The first inference still pays the +full model-load cost, but subsequent runs are pure inference. + +Two-step workflow: + Step 1: Compile all Neuron components (one-time, ~30 min) + Step 2: Run inference + +Prerequisites: + - Trn2 instance (trn2.48xlarge or trn2.xlarge, 4+ NeuronCores) + - Neuron SDK 2.23+ with PyTorch 2.9 + - Model weights downloaded from Qwen/Qwen2.5-Omni-7B (auto-fetched on first run) + - pip install soundfile + +Usage: + source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + cd neuronx-distributed-inference + + # Step 1: Compile (one-time) + python examples/generate_qwen25_omni_speech.py --compile + + # Step 2: Run inference + python examples/generate_qwen25_omni_speech.py + python examples/generate_qwen25_omni_speech.py --prompt "Tell me about the weather" + python examples/generate_qwen25_omni_speech.py --speaker Chelsie --output hello.wav + + # Benchmark: load each model once, run N inferences, report avg latency + python examples/generate_qwen25_omni_speech.py --num-runs 5 +""" + +import os as _os + +# Pin all three Neuron-compiled models (Thinker, Talker, DiT) to the same +# four NeuronCores. Without this, the runtime places the single-device DiT +# NEFF on a different core group than the TP=4 Thinker/Talker, and the +# resulting cross-group scheduling makes every DiT forward ~30% slower. +# Set before any Neuron module is imported so the runtime picks it up. +_os.environ.setdefault("NEURON_RT_VISIBLE_CORES", "0-3") + +# --- Qwen2.5-Omni contrib bootstrap --- +import sys as _sys +from pathlib import Path as _Path +_SRC = _Path(__file__).resolve().parents[1] / "src" +if str(_SRC) not in _sys.path: + _sys.path.insert(0, str(_SRC)) +import _upstream_compat # noqa: F401 (applies hf_adapter shim) +# --- end bootstrap --- + +import argparse +import gc +import os +import sys +import time + +import torch + +try: + import soundfile as sf +except ImportError: + sys.exit( + "soundfile is required for WAV output. Install with: pip install soundfile" + ) + +from _model_path import resolve_model_path + +MODEL_PATH = resolve_model_path() +COMPILED_PATH = os.environ.get( + "QWEN25_OMNI_COMPILED_PATH", "/tmp/qwen25_omni_compiled" +) +TP_DEGREE = int(os.environ.get("QWEN25_OMNI_TP_DEGREE", "4")) + +DEFAULT_PROMPT = "Say hello and briefly introduce yourself in two sentences." +DEFAULT_SYSTEM = ( + "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, " + "capable of perceiving auditory and visual inputs, as well as generating " + "text and speech." +) +DEFAULT_SPEAKER = "Ethan" + +_ORIG_EMBEDDING_FORWARD = torch.nn.Embedding.forward + + +def _restore_embedding(): + """Restore original Embedding.forward if Neuron loading changed it.""" + if torch.nn.Embedding.forward is not _ORIG_EMBEDDING_FORWARD: + torch.nn.Embedding.forward = _ORIG_EMBEDDING_FORWARD + + +class Timer: + def __init__(self, label): + self.label = label + self.elapsed = 0 + + def __enter__(self): + self.start = time.time() + return self + + def __exit__(self, *args): + self.elapsed = time.time() - self.start + print(f" [{self.label}] {self.elapsed:.2f}s") + + +# ========================================================================== +# Compilation (--compile) +# ========================================================================== + +def _compile_thinker(model_path, out_path): + from neuronx_distributed_inference.models.config import ( + NeuronConfig, OnDeviceSamplingConfig, + ) + from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config + from modeling_qwen25_omni import ( + NeuronQwen25OmniForCausalLM, Qwen25OmniInferenceConfig, + ) + + nc = NeuronConfig( + tp_degree=TP_DEGREE, batch_size=1, seq_len=2048, max_context_length=2048, + torch_dtype=torch.bfloat16, + on_device_sampling_config=OnDeviceSamplingConfig( + do_sample=True, temperature=0.7, top_k=20, top_p=0.8, + ), + ) + cfg = Qwen25OmniInferenceConfig(nc, load_config=load_pretrained_config(model_path)) + model = NeuronQwen25OmniForCausalLM(model_path, cfg) + model.compile(out_path) + + +def _compile_talker(model_path, out_path): + from transformers import AutoConfig + from modeling_qwen25_omni_talker import ( + NeuronQwen25OmniTalkerForCausalLM, TalkerInferenceConfig, TalkerNeuronConfig, + ) + from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config + + hf = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + tc = hf.talker_config + + tnc = TalkerNeuronConfig( + tp_degree=TP_DEGREE, batch_size=1, seq_len=2048, max_context_length=2048, + torch_dtype=torch.bfloat16, + ) + tic = TalkerInferenceConfig( + neuron_config=tnc, load_config=load_pretrained_config(hf_config=tc), + ) + talker = NeuronQwen25OmniTalkerForCausalLM(model_path, config=tic) + talker.compile(out_path) + + +def _compile_dit(model_path, out_path): + from transformers import AutoConfig + from safetensors.torch import load_file + from modeling_qwen25_omni_token2wav import ( + NeuronQwen25OmniToken2WavWithNeuronDiT, + ) + + hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + t2w = NeuronQwen25OmniToken2WavWithNeuronDiT(hf_config.token2wav_config) + + state_dict = {} + for fn in sorted(os.listdir(model_path)): + if fn.endswith(".safetensors"): + sd = load_file(os.path.join(model_path, fn)) + for k, v in sd.items(): + if k.startswith("token2wav."): + state_dict[k[len("token2wav."):]] = v + t2w.load_state_dict(state_dict, strict=False) + t2w.compile_dit(out_path, max_mel_len=2048, batch_size=2) + + +def compile_all(model_path, compiled_path): + """Compile all three Neuron components: Thinker, Talker, DiT. + + Each component is compiled sequentially in the current process. Compilation + holds the Neuron compiler (not the runtime) so there's no core-conflict + issue even when all three share TP=4 / core 0-3. + """ + print("=" * 60) + print("Compiling Qwen2.5-Omni Speech Components") + print("=" * 60) + print(f" Model: {model_path}") + print(f" Output: {compiled_path}") + print(f" TP: {TP_DEGREE}") + t_total = time.time() + + stages = [ + ("Thinker", "thinker_tp4", "neuron_config.json", _compile_thinker), + ("Talker", "talker_tp4", "neuron_config.json", _compile_talker), + ("DiT", "dit_core", "dit_core_neuron.pt", _compile_dit), + ] + for idx, (label, subdir, marker, fn) in enumerate(stages, 1): + print(f"\n--- [{idx}/{len(stages)}] Compiling {label} ---") + out_path = os.path.join(compiled_path, subdir) + if os.path.exists(os.path.join(out_path, marker)): + print(" Already compiled, skipping.") + continue + t0 = time.time() + fn(model_path, out_path) + print(f" {label} compiled in {time.time() - t0:.1f}s") + + print(f"\nAll components compiled in {time.time() - t_total:.0f}s") + print(f"Artifacts saved to: {compiled_path}/") + return True + + +# ========================================================================== +# Inference: model loading (once per process) +# ========================================================================== + +def _check_compiled(compiled_path): + checks = [ + (os.path.join(compiled_path, "thinker_tp4", "neuron_config.json"), "Thinker"), + (os.path.join(compiled_path, "talker_tp4", "neuron_config.json"), "Talker"), + (os.path.join(compiled_path, "dit_core", "dit_core_neuron.pt"), "DiT"), + ] + missing = [name for path, name in checks if not os.path.exists(path)] + if missing: + print(f"ERROR: Missing compiled artifacts for: {', '.join(missing)}") + print(f"Run with --compile first:") + print(f" python {sys.argv[0]} --compile") + return False + return True + + +def load_thinker(model_path, compiled_path): + """Load the Thinker (Qwen2.5-Omni text model) onto Neuron, return (adapter, tokenizer).""" + from neuronx_distributed_inference.models.config import ( + NeuronConfig, OnDeviceSamplingConfig, + ) + from neuronx_distributed_inference.utils.hf_adapter import ( + load_pretrained_config, HuggingFaceGenerationAdapter, + ) + from modeling_qwen25_omni import ( + NeuronQwen25OmniForCausalLM, Qwen25OmniInferenceConfig, + ) + from transformers import AutoTokenizer + + nc = NeuronConfig( + tp_degree=TP_DEGREE, batch_size=1, seq_len=2048, max_context_length=2048, + torch_dtype=torch.bfloat16, + on_device_sampling_config=OnDeviceSamplingConfig( + do_sample=True, temperature=0.7, top_k=20, top_p=0.8, + ), + ) + cfg = Qwen25OmniInferenceConfig(nc, load_config=load_pretrained_config(model_path)) + model = NeuronQwen25OmniForCausalLM(model_path, cfg) + + t0 = time.time() + model.load(os.path.join(compiled_path, "thinker_tp4")) + load_time = time.time() - t0 + print(f" [Thinker] loaded in {load_time:.1f}s") + + tok = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + if tok.pad_token is None: + tok.pad_token = tok.eos_token + adapter = HuggingFaceGenerationAdapter(model) + + # Warmup the NEFF so the first real inference isn't artificially slow. + enc = tok( + tok.apply_chat_template( + [{"role": "user", "content": "Hi"}], + tokenize=False, add_generation_prompt=True, + ), + return_tensors="pt", + ) + _ = adapter.generate( + input_ids=enc["input_ids"], attention_mask=enc["attention_mask"], + max_new_tokens=5, eos_token_id=[tok.eos_token_id, 151645], + ) + print(" [Thinker] warmup done") + return adapter, tok, load_time + + +def load_talker(model_path, compiled_path): + """Load the Talker model onto Neuron and return (talker, adapter, talker_config).""" + from transformers import AutoConfig + from modeling_qwen25_omni_talker import ( + NeuronQwen25OmniTalkerForCausalLM, TalkerInferenceConfig, TalkerNeuronConfig, + ) + from neuronx_distributed_inference.utils.hf_adapter import ( + load_pretrained_config, HuggingFaceGenerationAdapter, + ) + + hf = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + tc = hf.talker_config + + tnc = TalkerNeuronConfig( + tp_degree=TP_DEGREE, batch_size=1, seq_len=2048, max_context_length=2048, + torch_dtype=torch.bfloat16, + ) + tic = TalkerInferenceConfig( + neuron_config=tnc, load_config=load_pretrained_config(hf_config=tc), + ) + talker = NeuronQwen25OmniTalkerForCausalLM(model_path, config=tic) + + t0 = time.time() + talker.load(os.path.join(compiled_path, "talker_tp4")) + load_time = time.time() - t0 + print(f" [Talker] loaded in {load_time:.1f}s") + + adapter = HuggingFaceGenerationAdapter(talker) + return talker, adapter, tc, load_time + + +def load_token2wav(model_path, compiled_path): + """Load the Token2Wav model (DiT on Neuron + BigVGAN on CPU).""" + from transformers import AutoConfig + from safetensors.torch import load_file + from modeling_qwen25_omni_token2wav import ( + NeuronQwen25OmniToken2WavWithNeuronDiT, + ) + + hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + t2w_cfg = hf_config.token2wav_config + + t2w = NeuronQwen25OmniToken2WavWithNeuronDiT(t2w_cfg) + + state_dict = {} + for fn in sorted(os.listdir(model_path)): + if fn.endswith(".safetensors"): + sd = load_file(os.path.join(model_path, fn)) + for k, v in sd.items(): + if k.startswith("token2wav."): + state_dict[k[len("token2wav."):]] = v + t2w.load_state_dict(state_dict, strict=False) + + t0 = time.time() + t2w.load_dit(os.path.join(compiled_path, "dit_core")) + load_time = time.time() - t0 + _restore_embedding() + print(f" [DiT] loaded in {load_time:.1f}s") + return t2w, t2w_cfg, load_time + + +def load_hf_cpu(model_path): + """Load the HF Qwen2.5-Omni model on CPU in bfloat16 (for hidden-state extraction).""" + from transformers import Qwen2_5OmniForConditionalGeneration + + t0 = time.time() + hf_model = Qwen2_5OmniForConditionalGeneration.from_pretrained( + model_path, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + trust_remote_code=True, + ) + hf_model.eval() + _restore_embedding() + load_time = time.time() - t0 + print(f" [HF CPU] loaded in {load_time:.1f}s") + return hf_model, load_time + + +# ========================================================================== +# Inference: per-run phases +# ========================================================================== + +def run_thinker(thinker_adapter, tokenizer, prompt, system_prompt): + """Phase 1: Thinker generates text.""" + chat = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt}, + ] + enc = tokenizer( + tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True), + return_tensors="pt", + ) + + t0 = time.time() + out = thinker_adapter.generate( + input_ids=enc["input_ids"], attention_mask=enc["attention_mask"], + max_new_tokens=200, eos_token_id=[tokenizer.eos_token_id, 151645], + ) + elapsed = time.time() - t0 + + prompt_len = enc["input_ids"].shape[1] + all_ids = out[0].tolist() + gen_ids = all_ids[prompt_len:] + text = tokenizer.decode(gen_ids, skip_special_tokens=True) + return { + "all_ids": all_ids, + "prompt_len": prompt_len, + "gen_text": text, + "n_tokens": len(gen_ids), + "gen_time": elapsed, + } + + +def extract_hidden_states(hf_model, thinker_result): + """Phase 2: Run HF Thinker on CPU to capture hidden states. + + The compiled Neuron Thinker uses on-device sampling and only emits tokens, + not hidden states. The Talker needs the per-token last-layer hidden states + to condition on, so we re-run the prompt+reply through the HF CPU model in + bfloat16 — all downstream consumers already round back to bf16 so float32 + here would be pure overhead. + """ + full_ids = torch.tensor([thinker_result["all_ids"]], dtype=torch.long) + prompt_len = thinker_result["prompt_len"] + + t0 = time.time() + with torch.no_grad(): + outputs = hf_model.thinker( + input_ids=full_ids, output_hidden_states=True, return_dict=True, + ) + elapsed = time.time() - t0 + return outputs, full_ids, prompt_len, elapsed + + +def prepare_talker_input(model_path, hf_model, outputs, full_ids, prompt_len, speaker): + """Phase 3: Build projected thinker states for the Talker.""" + from transformers import AutoConfig + from safetensors.torch import load_file + from modeling_qwen25_omni_talker import ThinkerToTalkerProjection + + hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + talker_cfg = hf_config.talker_config + + spk_dict = torch.load(os.path.join(model_path, "spk_dict.pt"), weights_only=True) + sp = spk_dict[speaker] + conditioning = sp["cond"].unsqueeze(0).float() if sp["cond"].dim() == 1 else sp["cond"].float() + if conditioning.dim() == 1: + conditioning = conditioning.unsqueeze(0) + reference_mel = sp["ref_mel"].unsqueeze(0).float() if sp["ref_mel"].dim() == 2 else sp["ref_mel"].float() + if reference_mel.dim() == 2: + reference_mel = reference_mel.unsqueeze(0) + bos_token = sp["bos_token"] + if isinstance(bos_token, torch.Tensor): + bos_token = bos_token.item() + + embedding_output = outputs.hidden_states[0] + last_hidden = outputs.hidden_states[-1] + total_len = full_ids.shape[1] + + context_embed = embedding_output[:, :prompt_len, :] + context_hidden = last_hidden[:, :prompt_len, :] + reply_embeds = [embedding_output[:, i:i+1, :] for i in range(prompt_len, total_len)] + reply_hiddens = [last_hidden[:, i:i+1, :] for i in range(prompt_len, total_len)] + + thinker_token_embeds = [context_embed] + reply_embeds + thinker_hidden_states_list = [context_hidden] + reply_hiddens + + thinker_reply_part = ( + torch.cat(thinker_hidden_states_list[1:], dim=1) + + torch.cat(thinker_token_embeds[1:], dim=1) + ) + talker_inputs_embeds = thinker_hidden_states_list[0] + thinker_token_embeds[0] + + thinker_embed_tokens = hf_model.thinker.get_input_embeddings() + bos_embed = thinker_embed_tokens(torch.tensor([[bos_token]], dtype=torch.long)) + talker_inputs_embeds = torch.cat([ + talker_inputs_embeds, bos_embed, thinker_reply_part[:, :1, :], + ], dim=1) + + talker_embed_weight = None + for fn in sorted(os.listdir(model_path)): + if fn.endswith(".safetensors"): + sd = load_file(os.path.join(model_path, fn)) + if "talker.model.embed_tokens.weight" in sd: + talker_embed_weight = sd["talker.model.embed_tokens.weight"] + break + if talker_embed_weight is not None: + talker_embed_layer = torch.nn.Embedding( + talker_embed_weight.shape[0], talker_embed_weight.shape[1], + ) + talker_embed_layer.weight.data = talker_embed_weight.float() + codec_bos_embed = talker_embed_layer( + torch.tensor([talker_cfg.tts_codec_start_token_id]), + ) + codec_pad_embed = talker_embed_layer( + torch.tensor([talker_cfg.tts_codec_pad_token_id]), + ) + talker_inputs_embeds[:, -1, :] += codec_bos_embed + talker_inputs_embeds[:, -2, :] += codec_pad_embed + + eos_embed = thinker_embed_tokens( + torch.tensor([[talker_cfg.tts_text_end_token_id]], dtype=torch.long), + ) + pad_embed = thinker_embed_tokens( + torch.tensor([[talker_cfg.tts_text_pad_token_id]], dtype=torch.long), + ) + thinker_reply_part = torch.cat([thinker_reply_part[:, 1:, :], eos_embed, pad_embed], dim=1) + + context_len = talker_inputs_embeds.shape[1] + n_reply = thinker_reply_part.shape[1] + + proj_weight = proj_bias = None + for k, v in hf_model.state_dict().items(): + if "thinker_to_talker_proj.weight" in k: + proj_weight = v + if "thinker_to_talker_proj.bias" in k: + proj_bias = v + + proj = ThinkerToTalkerProjection(proj_weight.shape[1], proj_weight.shape[0]) + proj.proj.weight.data = proj_weight + if proj_bias is not None: + proj.proj.bias.data = proj_bias + proj.to(proj_weight.dtype) + + projected_context = proj(talker_inputs_embeds) + projected_reply = proj(thinker_reply_part) + + return { + "projected_context": projected_context, + "projected_reply": projected_reply, + "context_len": context_len, + "n_reply": n_reply, + "conditioning": conditioning, + "reference_mel": reference_mel, + } + + +def run_talker(talker_model, talker_adapter, talker_cfg, talker_input): + """Phase 4: Talker generates codec tokens.""" + projected_context = talker_input["projected_context"] + projected_reply = talker_input["projected_reply"] + context_len = talker_input["context_len"] + + codec_bos = talker_cfg.tts_codec_start_token_id + codec_eos = talker_cfg.tts_codec_end_token_id + codec_pad = talker_cfg.tts_codec_pad_token_id + codec_mask = talker_cfg.tts_codec_mask_token_id + + talker_input_ids = torch.cat([ + torch.full((1, context_len - 2), codec_mask, dtype=torch.long), + torch.tensor([[codec_pad]], dtype=torch.long), + torch.tensor([[codec_bos]], dtype=torch.long), + ], dim=1) + talker_attention_mask = torch.ones_like(talker_input_ids, dtype=torch.long) + + max_gen = min(600, 2048 - context_len - 10) + + # Re-set vision embeddings before each run (context encoding consumes them). + ve = projected_context.to(torch.bfloat16) + vm = torch.ones(1, context_len, 1, dtype=torch.int32) + reply = projected_reply.to(torch.bfloat16) + talker_model.set_vision_embeddings(ve, vm, thinker_reply_embeds=reply) + + t0 = time.time() + out = talker_adapter.generate( + input_ids=talker_input_ids, + attention_mask=talker_attention_mask, + max_new_tokens=max_gen, + eos_token_id=[codec_eos, codec_pad], + suppress_tokens=[codec_bos], + do_sample=True, temperature=0.9, top_k=40, top_p=0.8, + repetition_penalty=1.05, + ) + elapsed = time.time() - t0 + + gen_tokens = out[0, context_len:].tolist() + while gen_tokens and gen_tokens[-1] == codec_eos: + gen_tokens.pop() + return gen_tokens, elapsed + + +def run_token2wav(t2w, t2w_cfg, codec_codes, conditioning, reference_mel): + """Phase 5: Token2Wav DiT + BigVGAN synthesize a waveform.""" + code_tensor = torch.tensor([codec_codes], dtype=torch.long) + num_embeds = getattr(t2w_cfg.dit_config, "num_embeds", 8193) + if code_tensor.max() >= num_embeds: + code_tensor = code_tensor.clamp(0, num_embeds) + + t0 = time.time() + wav = t2w( + code=code_tensor, + conditioning=conditioning, + reference_mel=reference_mel, + num_steps=10, + guidance_scale=0.5, + ) + elapsed = time.time() - t0 + return wav, elapsed + + +# ========================================================================== +# Main +# ========================================================================== + +def main(): + parser = argparse.ArgumentParser( + description="Qwen2.5-Omni-7B speech synthesis on Neuron" + ) + parser.add_argument( + "--compile", action="store_true", + help="Compile all Neuron components (one-time, ~30 min)", + ) + parser.add_argument( + "--num-runs", type=int, default=1, + help="Number of inference runs per component for benchmarking (default: 1)", + ) + parser.add_argument( + "--prompt", default=DEFAULT_PROMPT, + help="Text prompt for speech generation", + ) + parser.add_argument( + "--system-prompt", default=DEFAULT_SYSTEM, + help="System prompt", + ) + parser.add_argument( + "--speaker", default=DEFAULT_SPEAKER, choices=["Ethan", "Chelsie"], + help="Speaker voice (default: Ethan)", + ) + parser.add_argument( + "--model-path", default=MODEL_PATH, + help=f"Model path (default: {MODEL_PATH})", + ) + parser.add_argument( + "--compiled-path", default=COMPILED_PATH, + help=f"Compiled artifacts path (default: {COMPILED_PATH})", + ) + parser.add_argument( + "--output", default="speech_output.wav", + help="Output WAV file path (default: speech_output.wav)", + ) + args = parser.parse_args() + + model_path = args.model_path + compiled_path = args.compiled_path + num_runs = args.num_runs + + if args.compile: + ok = compile_all(model_path, compiled_path) + sys.exit(0 if ok else 1) + + if not _check_compiled(compiled_path): + sys.exit(1) + + print("=" * 60) + print("Qwen2.5-Omni Speech Pipeline (Neuron, TP=4, single process)") + print("=" * 60) + print(f" Model: {model_path}") + print(f" Compiled: {compiled_path}") + print(f" Speaker: {args.speaker}") + print(f" Prompt: {args.prompt}") + print(f" Output: {args.output}") + print(f" Runs: {num_runs}") + t_total = time.time() + + # ----- Load everything once ----- + print("\n--- Loading models (one-time cost) ---") + t_load_total = time.time() + thinker_adapter, tokenizer, thinker_load = load_thinker(model_path, compiled_path) + hf_model, hf_load = load_hf_cpu(model_path) + talker_model, talker_adapter, talker_cfg, talker_load = load_talker(model_path, compiled_path) + t2w, t2w_cfg, dit_load = load_token2wav(model_path, compiled_path) + total_load = time.time() - t_load_total + print(f" Total model load time: {total_load:.1f}s") + + # ----- Run the pipeline num_runs times ----- + thinker_times, talker_times, t2w_times = [], [], [] + hidden_times, prep_times = [], [] + first_text = first_codes = first_wav = None + first_audio_duration = 0.0 + + for i in range(num_runs): + print(f"\n--- Run {i+1}/{num_runs} ---") + + thinker_result = run_thinker( + thinker_adapter, tokenizer, args.prompt, args.system_prompt, + ) + thinker_times.append(thinker_result["gen_time"]) + print( + f" [Thinker] {thinker_result['n_tokens']} tokens in " + f"{thinker_result['gen_time']:.3f}s - {thinker_result['gen_text'][:80]}" + ) + + outputs, full_ids, prompt_len, hidden_time = extract_hidden_states( + hf_model, thinker_result, + ) + hidden_times.append(hidden_time) + print(f" [Hidden] forward pass in {hidden_time:.2f}s") + + t0 = time.time() + talker_input = prepare_talker_input( + model_path, hf_model, outputs, full_ids, prompt_len, args.speaker, + ) + prep_time = time.time() - t0 + prep_times.append(prep_time) + print( + f" [Prep] context={talker_input['context_len']} tokens, " + f"reply={talker_input['n_reply']} tokens ({prep_time:.2f}s)" + ) + + codec_codes, talker_time = run_talker( + talker_model, talker_adapter, talker_cfg, talker_input, + ) + talker_times.append(talker_time) + print(f" [Talker] {len(codec_codes)} codec tokens in {talker_time:.3f}s") + if not codec_codes: + print(" Talker produced no tokens, aborting run.") + continue + + wav, t2w_time = run_token2wav( + t2w, t2w_cfg, codec_codes, + talker_input["conditioning"], talker_input["reference_mel"], + ) + t2w_times.append(t2w_time) + print(f" [Token2Wav] synthesized in {t2w_time:.2f}s") + + if first_text is None: + first_text = thinker_result["gen_text"] + first_codes = codec_codes + first_wav = wav + + # Free the per-run temporaries so the heap doesn't grow across runs. + del outputs, full_ids, talker_input + gc.collect() + + # ----- Write first run's audio ----- + if first_wav is not None and isinstance(first_wav, torch.Tensor) and first_wav.numel() > 0: + wav_np = first_wav.detach().cpu().float().numpy().flatten() + sf.write(args.output, wav_np, 24000) + first_audio_duration = len(wav_np) / 24000 + print(f"\n Audio: {first_audio_duration:.1f}s saved to {args.output}") + + total_time = time.time() - t_total + + def _avg(xs): + return sum(xs) / len(xs) if xs else 0.0 + + print("\n" + "=" * 60) + print("RESULTS") + print("=" * 60) + if first_text: + print(f" Text: {first_text[:200]}") + print("\n Model load time (one-time cost, excluded from pipeline avg):") + print(f" Thinker: {thinker_load:.1f}s") + print(f" HF CPU: {hf_load:.1f}s") + print(f" Talker: {talker_load:.1f}s") + print(f" DiT: {dit_load:.1f}s") + print(f" Total: {total_load:.1f}s") + print(f"\n Per-run latency (avg of {num_runs} runs):") + print(f" Thinker: {_avg(thinker_times):.3f}s") + print(f" Hidden: {_avg(hidden_times):.3f}s (HF CPU forward)") + print(f" Prep: {_avg(prep_times):.3f}s") + print(f" Talker: {_avg(talker_times):.3f}s") + print(f" Token2Wav: {_avg(t2w_times):.2f}s") + pipeline_avg = ( + _avg(thinker_times) + _avg(hidden_times) + _avg(prep_times) + + _avg(talker_times) + _avg(t2w_times) + ) + print(f" Pipeline: {pipeline_avg:.2f}s total") + if first_audio_duration > 0: + print(f"\n Audio: {first_audio_duration:.1f}s") + print(f" RTF: {pipeline_avg/first_audio_duration:.2f}x") + if num_runs > 1: + print(f"\n Per-run breakdown ({num_runs} runs):") + print(f" Thinker: {['%.3f' % t for t in thinker_times]}") + print(f" Hidden: {['%.3f' % t for t in hidden_times]}") + print(f" Talker: {['%.3f' % t for t in talker_times]}") + print(f" Token2Wav: {['%.2f' % t for t in t2w_times]}") + print(f"\n Wall time: {total_time:.1f}s (load + {num_runs} run(s))") + print(f" Output: {args.output}") + + +if __name__ == "__main__": + main() diff --git a/contrib/models/Qwen2.5-Omni-7B/perf_test/3_bench_qwen25_omni_7b.sh b/contrib/models/Qwen2.5-Omni-7B/perf_test/3_bench_qwen25_omni_7b.sh new file mode 100644 index 00000000..e398db20 --- /dev/null +++ b/contrib/models/Qwen2.5-Omni-7B/perf_test/3_bench_qwen25_omni_7b.sh @@ -0,0 +1,171 @@ +#!/bin/bash +set -e + +source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + +MODEL_PATH="/opt/dlami/nvme/models/Qwen2.5-Omni-7B" +PORT=8000 +RESULTS_DIR="/var/tmp/bench_results/qwen25_omni_7b" +mkdir -p "$RESULTS_DIR" + +# Helper: wait for vLLM server to be ready +wait_for_server() { + echo " Waiting for vLLM server to be ready..." + for i in $(seq 1 360); do + if curl -s http://localhost:$PORT/health > /dev/null 2>&1; then + echo " Server ready! (${i}s * 5 = $((i*5))s)" + return 0 + fi + sleep 5 + done + echo " ERROR: Server did not start within 1800s" + return 1 +} + +# Helper: run benchmark +run_bench() { + local config_name=$1 + local concurrency=$2 + local num_prompts=$3 + + echo " Benchmark: concurrency=$concurrency, prompts=$num_prompts" + vllm bench serve \ + --backend vllm \ + --model "$MODEL_PATH" \ + --tokenizer "$MODEL_PATH" \ + --endpoint /v1/completions \ + --dataset-name random \ + --num-prompts "$num_prompts" \ + --random-input-len 900 \ + --random-output-len 90 \ + --random-range-ratio 0.03 \ + --max-concurrency "$concurrency" \ + 2>&1 | tee "$RESULTS_DIR/${config_name}_c${concurrency}.txt" + echo "" +} + +# Helper: stop server +stop_server() { + echo " Stopping vLLM server..." + pkill -f "vllm.entrypoints.openai.api_server" 2>/dev/null || true + sleep 5 +} + +# Helper: quick sanity check +sanity_check() { + echo " Running sanity check..." + curl -s http://localhost:$PORT/v1/chat/completions \ + -H 'Content-Type: application/json' \ + -d '{ + "messages": [{"role": "user", "content": "What is 1+1? Answer briefly."}], + "model": "'"$MODEL_PATH"'", + "max_tokens": 64, + "temperature": 0.0, + "stream": false + }' | python3 -c "import sys,json; r=json.load(sys.stdin); print(' Sanity:', r['choices'][0]['message']['content'][:100])" 2>/dev/null || echo " Sanity check: could not parse response" +} + +echo "==========================================" +echo "Qwen2.5-Omni-7B Performance Benchmark" +echo "==========================================" +echo "Model: $MODEL_PATH" +echo "Results: $RESULTS_DIR" +echo "" + +############################################################################### +# Config 1: BS=1, TP=4, non-CB (baseline latency) +# Qwen2.5-Omni-7B is a dense 7B model, TP=4 is sufficient +############################################################################### +CONFIG_NAME="bs1_tp4" +echo "--- Config 1: BS=1, TP=4, non-CB (baseline) ---" + +python3 -m vllm.entrypoints.openai.api_server \ + --model "$MODEL_PATH" \ + --tokenizer "$MODEL_PATH" \ + --tensor-parallel-size 4 \ + --max-model-len 4096 \ + --max-num-seqs 1 \ + --no-enable-chunked-prefill \ + --no-enable-prefix-caching \ + --port $PORT \ + --trust_remote_code \ + --additional-config '{ + "override_neuron_config": { + "tp_degree": 4, + "fused_qkv": false, + "flash_decoding_enabled": false, + "sequence_parallel_enabled": false, + "qkv_kernel_enabled": false, + "qkv_nki_kernel_enabled": false, + "attn_kernel_enabled": false, + "batch_size": 1, + "ctx_batch_size": 1, + "tkg_batch_size": 1, + "max_context_length": 4096, + "seq_len": 4096, + "is_continuous_batching": false, + "enable_bucketing": false, + "async_mode": true, + "on_device_sampling_config": { + "do_sample": true, "temperature": 0.6, "top_k": 20, "top_p": 0.95 + } + } + }' & + +wait_for_server +sanity_check +run_bench "$CONFIG_NAME" 1 16 +stop_server + +############################################################################### +# Config 2: BS=4, TP=4, CB (throughput) +############################################################################### +CONFIG_NAME="bs4_tp4_cb" +echo "--- Config 2: BS=4, TP=4, CB ---" + +python3 -m vllm.entrypoints.openai.api_server \ + --model "$MODEL_PATH" \ + --tokenizer "$MODEL_PATH" \ + --tensor-parallel-size 4 \ + --max-model-len 4096 \ + --max-num-seqs 4 \ + --no-enable-chunked-prefill \ + --no-enable-prefix-caching \ + --port $PORT \ + --trust_remote_code \ + --additional-config '{ + "override_neuron_config": { + "tp_degree": 4, + "fused_qkv": false, + "flash_decoding_enabled": false, + "sequence_parallel_enabled": false, + "qkv_kernel_enabled": false, + "qkv_nki_kernel_enabled": false, + "attn_kernel_enabled": false, + "batch_size": 4, + "ctx_batch_size": 1, + "tkg_batch_size": 4, + "max_context_length": 4096, + "seq_len": 4096, + "is_continuous_batching": true, + "enable_bucketing": true, + "context_encoding_buckets": [4096], + "token_generation_buckets": [4096], + "async_mode": true, + "on_device_sampling_config": { + "do_sample": true, "temperature": 0.6, "top_k": 20, "top_p": 0.95 + } + } + }' & + +wait_for_server +sanity_check +run_bench "$CONFIG_NAME" 1 16 +run_bench "$CONFIG_NAME" 4 64 +stop_server + +echo "==========================================" +echo "Qwen2.5-Omni-7B benchmarks complete!" +echo "Results saved to: $RESULTS_DIR" +echo "==========================================" +ls -la "$RESULTS_DIR" diff --git a/contrib/models/Qwen2.5-Omni-7B/perf_test/apply_vllm_neuron_patch_qwen25omni.py b/contrib/models/Qwen2.5-Omni-7B/perf_test/apply_vllm_neuron_patch_qwen25omni.py new file mode 100644 index 00000000..68131ed0 --- /dev/null +++ b/contrib/models/Qwen2.5-Omni-7B/perf_test/apply_vllm_neuron_patch_qwen25omni.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 +"""Add Qwen2.5-Omni model support to vllm-neuron. + +This patch should be applied AFTER the MiMo/MiniMax patch (apply_vllm_neuron_patch.py). +It handles: + 1. Config extraction: Qwen2.5-Omni nests text config under thinker_config.text_config + 2. Architecture mapping: "Qwen2_5OmniModel" -> "qwen2_5_omni" model type + 3. Layer count extraction: get_num_layers_from_hf_config for nested config +""" + +import os + +# Patch 1 & 2: neuronx_distributed_model_loader.py +LOADER_FILE = "/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/lib/python3.12/site-packages/vllm_neuron/worker/neuronx_distributed_model_loader.py" + +with open(LOADER_FILE) as f: + content = f.read() + +# 1. In _get_model_configs: handle Qwen2.5-Omni nested config +content = content.replace( + ' if architecture in NEURON_MULTI_MODAL_MODELS:\n' + ' config = getattr(config, "text_config", None)\n' + ' num_key_value_heads = getattr(config, "num_key_value_heads", None)', + ' if architecture in NEURON_MULTI_MODAL_MODELS:\n' + ' config = getattr(config, "text_config", None)\n' + ' # Qwen2.5-Omni: text config is nested under thinker_config.text_config\n' + ' if architecture == "Qwen2_5OmniModel":\n' + ' thinker_config = getattr(config, "thinker_config", None)\n' + ' if thinker_config is not None:\n' + ' config = getattr(thinker_config, "text_config", config)\n' + ' num_key_value_heads = getattr(config, "num_key_value_heads", None)', +) + +# 2. In _get_neuron_model_cls: handle Qwen2_5OmniModel architecture +content = content.replace( + ' try:\n' + ' if "For" in architecture:', + ' # Qwen2.5-Omni: architecture is "Qwen2_5OmniModel" (no "For" in name)\n' + ' if architecture == "Qwen2_5OmniModel":\n' + ' return MODEL_TYPES["qwen2_5_omni"]["causal-lm"]\n' + '\n' + ' try:\n' + ' if "For" in architecture:', +) + +with open(LOADER_FILE, "w") as f: + f.write(content) + +print("Patch 1/2: neuronx_distributed_model_loader.py updated") + +# Patch 3: utils.py - handle Qwen2.5-Omni nested config for layer count +UTILS_FILE = "/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/lib/python3.12/site-packages/vllm_neuron/worker/utils.py" + +with open(UTILS_FILE) as f: + content = f.read() + +content = content.replace( + ' # Sum nested configs (multimodal models)\n' + ' total = 0\n' + ' for attr in ["text_config", "vision_config"]:', + ' # Qwen2.5-Omni: check thinker_config.text_config\n' + ' thinker_config = getattr(hf_config, "thinker_config", None)\n' + ' if thinker_config is not None:\n' + ' text_config = getattr(thinker_config, "text_config", None)\n' + ' if text_config is not None:\n' + ' layers = getattr(text_config, "num_hidden_layers", None)\n' + ' if layers is not None:\n' + ' return layers\n' + '\n' + ' # Sum nested configs (multimodal models)\n' + ' total = 0\n' + ' for attr in ["text_config", "vision_config"]:', +) + +with open(UTILS_FILE, "w") as f: + f.write(content) + +print("Patch 2/2: utils.py updated") +print() +print("Qwen2.5-Omni vllm-neuron patch applied successfully!") +print(" 1. Added thinker_config.text_config extraction in _get_model_configs") +print(" 2. Added Qwen2_5OmniModel -> qwen2_5_omni mapping in _get_neuron_model_cls") +print(" 3. Added thinker_config.text_config layer count extraction in utils.py") diff --git a/contrib/models/Qwen2.5-Omni-7B/src/__init__.py b/contrib/models/Qwen2.5-Omni-7B/src/__init__.py index 3d7d8c14..62f2a61c 100644 --- a/contrib/models/Qwen2.5-Omni-7B/src/__init__.py +++ b/contrib/models/Qwen2.5-Omni-7B/src/__init__.py @@ -1,3 +1,10 @@ -from .modeling_qwen2_5_omni import NeuronQwen2_5OmniForCausalLM, Qwen2_5OmniInferenceConfig +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Importing this package applies an upstream bug fix for +# HuggingFaceGenerationAdapter.prepare_inputs_for_generation so that +# adapter.generate() does not raise NameError when forwarding +# tensor_capture_hook downstream. The fix is idempotent and only activates +# if the upstream file still contains the bug. -__all__ = ["NeuronQwen2_5OmniForCausalLM", "Qwen2_5OmniInferenceConfig"] +from . import _upstream_compat # noqa: F401 (side-effect import) diff --git a/contrib/models/Qwen2.5-Omni-7B/src/_model_path.py b/contrib/models/Qwen2.5-Omni-7B/src/_model_path.py new file mode 100644 index 00000000..da9c1c69 --- /dev/null +++ b/contrib/models/Qwen2.5-Omni-7B/src/_model_path.py @@ -0,0 +1,22 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Helper for resolving the Qwen2.5-Omni-7B weight path. +# +# Honors ``$QWEN25_OMNI_MODEL_PATH`` if it points at a directory with a +# ``config.json``. Otherwise delegates to ``huggingface_hub.snapshot_download`` +# which is a no-op if the model is already cached and returns the real snapshot +# directory (including the commit hash) in either case. + +import os + + +HF_REPO_ID = "Qwen/Qwen2.5-Omni-7B" + + +def resolve_model_path() -> str: + env = os.environ.get("QWEN25_OMNI_MODEL_PATH") + if env and os.path.isfile(os.path.join(env, "config.json")): + return env + from huggingface_hub import snapshot_download + return snapshot_download(HF_REPO_ID) diff --git a/contrib/models/Qwen2.5-Omni-7B/src/_upstream_compat.py b/contrib/models/Qwen2.5-Omni-7B/src/_upstream_compat.py new file mode 100644 index 00000000..36bc015c --- /dev/null +++ b/contrib/models/Qwen2.5-Omni-7B/src/_upstream_compat.py @@ -0,0 +1,139 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Upstream compatibility shims for Qwen2.5-Omni. +# +# These patches sit here (not in src/neuronx_distributed_inference/) so that +# this contrib package has zero direct invasion on the upstream source tree. +# Each patch is idempotent: if upstream has already fixed the issue the +# original object stays unchanged. + +import logging +import inspect + +from neuronx_distributed_inference.utils.hf_adapter import HuggingFaceGenerationAdapter + +logger = logging.getLogger(__name__) + + +def _patch_prepare_inputs_for_generation(): + """Fix upstream NameError in HuggingFaceGenerationAdapter. + + Upstream's prepare_inputs_for_generation builds model_inputs with + "tensor_capture_hook": tensor_capture_hook + but never defines tensor_capture_hook as a parameter or extracts it from + **kwargs, so adapter.generate() raises NameError. Qwen2.5-Omni's Talker + drives generation via adapter.generate(), so this breaks it out-of-the-box. + + This wrapper re-dispatches to the upstream method with tensor_capture_hook + materialized as a local so the original body sees a defined name. + """ + src = inspect.getsource(HuggingFaceGenerationAdapter.prepare_inputs_for_generation) + references_hook = '"tensor_capture_hook": tensor_capture_hook' in src + already_extracted = 'tensor_capture_hook = kwargs.get("tensor_capture_hook"' in src + if already_extracted or not references_hook: + return # upstream already consistent + + original = HuggingFaceGenerationAdapter.prepare_inputs_for_generation + + def patched(self, input_ids, *args, **kwargs): + # Inject the missing local via the frame's globals is fragile; the + # cleanest fix is to guarantee the name exists in the caller's scope. + # Since Python resolves bare names via function locals/globals, and + # the original body has no local binding, we surface it through the + # **kwargs extraction idiom already used for input_capture_hook. + # The patched body below mirrors upstream but extracts the hook. + import torch # noqa: F401 + self.prev_kv_cache_populated = self.neuron_model.kv_cache_populated + if self.neuron_model.kv_cache_populated: + input_ids = input_ids[:, -1:] + + past_key_values = kwargs.pop("past_key_values", None) + attention_mask = kwargs.pop("attention_mask", None) + inputs_embeds = kwargs.pop("inputs_embeds", None) + sampling_params = kwargs.pop("sampling_params", None) + adapter_ids = kwargs.pop("adapter_ids", None) + divergence_idx = kwargs.pop("divergence_idx", None) + + accepted_indices = kwargs.get("accepted_indices", None) + current_length = kwargs.get("current_length", None) + medusa_mask = kwargs.get("medusa_mask", None) + scatter_index = kwargs.get("scatter_index", None) + position_ids = kwargs.get("position_ids", None) + input_capture_hook = kwargs.get("input_capture_hook", None) + tensor_capture_hook = kwargs.get("tensor_capture_hook", None) + + if attention_mask is not None and position_ids is None: + position_ids = attention_mask.long().cumsum(-1) - 1 + if self.input_start_offsets: + if len(self.input_start_offsets) > 1: + import torch as _torch + position_ids += _torch.tensor( + self.input_start_offsets, + dtype=position_ids.dtype, + device=position_ids.device, + )[:, None] + else: + position_ids += self.input_start_offsets[0] + import torch as _torch + for i, offset in enumerate(self.input_start_offsets): + position_ids[i, 0:offset] = _torch.arange(offset) + else: + position_ids.masked_fill_(attention_mask == 0, 1) + + if self.neuron_model.kv_cache_populated: + import torch as _torch + position_ids = _torch.amax(position_ids, 1, keepdim=True) + position_ids = position_ids + 1 + + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache", False), + "attention_mask": attention_mask, + "medusa_args": (accepted_indices, current_length, medusa_mask, scatter_index), + "sampling_params": sampling_params, + "input_capture_hook": input_capture_hook, + "tensor_capture_hook": tensor_capture_hook, + "adapter_ids": adapter_ids, + } + ) + + tf_args = [] + if self.neuron_config.tensor_replacement_config: + from neuronx_distributed_inference.utils.tensor_replacement.registry import ( + TensorReplacementRegister, + ) + reg = TensorReplacementRegister.get_instance() + tf, masks = reg.step_args( + self.generation_step, + divergence_idx=True if divergence_idx else False, + ) + tf_args = tf + masks + + if tf_args: + model_inputs["tf_args"] = tf_args + + additional_kwargs = self.neuron_model.get_required_kwargs() + for arg in additional_kwargs: + model_inputs.update({arg: kwargs.get(arg, None)}) + + return model_inputs + + HuggingFaceGenerationAdapter.prepare_inputs_for_generation = patched + HuggingFaceGenerationAdapter.prepare_inputs_for_generation.__doc__ = ( + "[Qwen2.5-Omni contrib patched] " + (original.__doc__ or "") + ) + logger.info( + "Qwen2.5-Omni contrib: patched HuggingFaceGenerationAdapter." + "prepare_inputs_for_generation to extract tensor_capture_hook from kwargs." + ) + + +_patch_prepare_inputs_for_generation() diff --git a/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni.py b/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni.py new file mode 100644 index 00000000..4e889aa0 --- /dev/null +++ b/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni.py @@ -0,0 +1,1106 @@ +# coding=utf-8 +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Qwen2.5-Omni support for NXD inference. +# +# Provides two modes: +# 1. Text-only (Thinker): NeuronQwen25OmniForCausalLM +# - Reuses Qwen2 decoder with thinker.model.* prefix remapping +# 2. Multimodal (Vision + Text): NeuronQwen25OmniMultimodalForCausalLM +# - Vision encoder: Qwen2.5-Omni ViT (SwiGLU, RMSNorm, separate QKV) +# - Text decoder: Qwen2-VL text model (multimodal RoPE) +# +# Reference: https://huggingface.co/Qwen/Qwen2.5-Omni-7B + +"""Qwen2.5-Omni model for NXD inference.""" + +import copy +import gc +import logging +from types import SimpleNamespace +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union + +import torch +from transformers.modeling_outputs import CausalLMOutputWithPast + +from neuronx_distributed_inference.models.config import InferenceConfig, NeuronConfig +from neuronx_distributed_inference.models.image_to_text_model_base import ( + ImageToTextInferenceConfig, + NeuronBaseForImageToText, +) +from neuronx_distributed_inference.models.model_wrapper import VISION_ENCODER_MODEL_TAG +from neuronx_distributed_inference.models.qwen2.modeling_qwen2 import ( + NeuronQwen2ForCausalLM, + NeuronQwen2Model, + convert_state_dict_to_fused_qkv, +) + +logger = logging.getLogger("Neuron") + + +# Attributes to extract from the thinker's text_config to the top-level config. +_TEXT_CONFIG_ATTRS = [ + "hidden_size", + "num_hidden_layers", + "num_attention_heads", + "num_key_value_heads", + "vocab_size", + "intermediate_size", + "max_position_embeddings", + "rope_theta", + "rms_norm_eps", + "hidden_act", + "tie_word_embeddings", + "max_window_layers", + "use_sliding_window", + "sliding_window", +] + + +class Qwen25OmniInferenceConfig(InferenceConfig): + """Inference config for Qwen2.5-Omni (Thinker text component). + + Handles the nested config structure: the HF config has attributes under + thinker_config.text_config that we need at the top level for NxDI. + """ + + def add_derived_config(self): + self.num_cores_per_group = 1 + # Qwen2.5-Omni text model has QKV bias but no output projection bias + self.qkv_bias = True + self.o_bias = False + + # Extract text config attributes from nested thinker_config + if hasattr(self, "thinker_config"): + thinker_cfg = self.thinker_config + # When loaded from saved JSON, thinker_config is a plain dict + if isinstance(thinker_cfg, dict): + thinker_cfg = SimpleNamespace(**thinker_cfg) + self.thinker_config = thinker_cfg + + text_cfg = thinker_cfg.text_config + if isinstance(text_cfg, dict): + text_cfg = SimpleNamespace(**text_cfg) + thinker_cfg.text_config = text_cfg + + # Text config attributes always take precedence over top-level + # defaults from PretrainedConfig (e.g. tie_word_embeddings defaults + # to True at the top level but is False in text_config). + for attr in _TEXT_CONFIG_ATTRS: + if hasattr(text_cfg, attr): + setattr(self, attr, getattr(text_cfg, attr)) + + # Set pad_token_id from thinker_config + if hasattr(thinker_cfg, "pad_token_id"): + if not hasattr(self, "pad_token_id") or self.pad_token_id is None: + self.pad_token_id = thinker_cfg.pad_token_id + + def get_required_attributes(self) -> List[str]: + return [ + "hidden_size", + "num_attention_heads", + "num_hidden_layers", + "num_key_value_heads", + "pad_token_id", + "vocab_size", + "max_position_embeddings", + "rope_theta", + "rms_norm_eps", + "hidden_act", + ] + + @classmethod + def get_neuron_config_cls(cls) -> Type[NeuronConfig]: + return NeuronConfig + + +class NeuronQwen25OmniForCausalLM(NeuronQwen2ForCausalLM): + """Qwen2.5-Omni Thinker text model for Causal LM on Neuron. + + Reuses the Qwen2 model architecture since the Thinker's text backbone + is architecturally identical to Qwen2.5. The main differences are: + - Weight keys are prefixed with 'thinker.model.' / 'thinker.lm_head.' + - Non-text weights (talker, token2wav, audio_tower, visual) are discarded + """ + + _model_cls = NeuronQwen2Model + _STATE_DICT_MODEL_PREFIX = "thinker.model." + + @staticmethod + def load_hf_model(model_path: str, **kwargs): + """Load the full Qwen2.5-Omni model from HuggingFace. + + Note: We load the full model and filter to thinker text weights + in convert_hf_to_neuron_state_dict. + """ + from transformers import AutoModelForCausalLM + + return AutoModelForCausalLM.from_pretrained( + model_path, + trust_remote_code=True, + **kwargs, + ) + + @classmethod + def get_config_cls(cls) -> Type[Qwen25OmniInferenceConfig]: + return Qwen25OmniInferenceConfig + + @staticmethod + def convert_hf_to_neuron_state_dict( + state_dict: Dict[str, Any], + config: Qwen25OmniInferenceConfig, + ) -> Dict[str, Any]: + """Convert Qwen2.5-Omni state dict to NxDI format. + + 1. Keep only thinker text model weights (discard talker, token2wav, audio, visual) + 2. Map 'thinker.lm_head.' -> 'lm_head.' + 3. Apply standard Qwen2 conversions (fused QKV, rank utils, etc.) + """ + neuron_config = config.neuron_config + + # Filter: keep only thinker text weights and lm_head + # After base-class prefix stripping, 'thinker.model.*' becomes '*' + # but 'thinker.lm_head.*' is NOT stripped, so handle it here. + keys_to_remove = [] + keys_to_rename = {} + for key in state_dict: + if key.startswith("thinker.lm_head."): + # Map thinker.lm_head.weight -> lm_head.weight + new_key = key.replace("thinker.lm_head.", "lm_head.", 1) + keys_to_rename[key] = new_key + elif not key.startswith(("layers.", "embed_tokens.", "norm.", "lm_head.")): + # After base-class prefix stripping, valid text keys start with + # layers.*, embed_tokens.*, norm.*, or lm_head.* + # Everything else (talker, token2wav, audio_tower, visual) should be removed + keys_to_remove.append(key) + + # Apply renames + for old_key, new_key in keys_to_rename.items(): + state_dict[new_key] = state_dict.pop(old_key) + + # Remove non-text weights + for key in keys_to_remove: + del state_dict[key] + + gc.collect() + logger.info( + "Filtered state dict to %d thinker text model weights", len(state_dict) + ) + + # Add rank utilities (same as Qwen2) + if neuron_config.vocab_parallel: + state_dict["embed_tokens.rank_util.rank"] = torch.arange( + 0, neuron_config.local_ranks_size + ) + + num_layers = config.num_hidden_layers + tp_degree = neuron_config.tp_degree + for i in range(num_layers): + state_dict[f"layers.{i}.self_attn.rank_util.rank"] = torch.arange( + 0, tp_degree, dtype=torch.int32 + ) + + if neuron_config.fused_qkv: + state_dict = convert_state_dict_to_fused_qkv(state_dict, config) + + state_dict["rank_util.rank"] = torch.arange(0, tp_degree, dtype=torch.int32) + return state_dict + + def get_compiler_args(self): + compiler_args = ( + "--enable-saturate-infinity " + "--enable-mixed-precision-accumulation " + "--auto-cast=none " + "--model-type transformer " + "-O1" + ) + compiler_args += ( + " --tensorizer-options='" + "--enable-ccop-compute-overlap " + "--cc-pipeline-tiling-factor=2 " + "--vectorize-strided-dma'" + ) + compiler_args += " --internal-hlo2tensorizer-options='--verify-hlo=true'" + return compiler_args + + +# --------------------------------------------------------------------------- +# Multimodal (Vision + Text) support +# --------------------------------------------------------------------------- + +# Keys from thinker_config.text_config to copy to the top-level multimodal config +_MULTIMODAL_TEXT_CONFIG_KEYS = [ + "hidden_size", + "num_attention_heads", + "num_hidden_layers", + "num_key_value_heads", + "pad_token_id", + "vocab_size", + "intermediate_size", + "max_position_embeddings", + "rms_norm_eps", + "rope_theta", + "rope_scaling", + "hidden_act", + "bos_token_id", + "eos_token_id", + "qkv_bias", + "o_bias", + "image_token_id", + "vision_token_id", + "video_token_id", + "vision_start_token_id", + "vision_end_token_id", +] + + +class Qwen25OmniMultimodalInferenceConfig(ImageToTextInferenceConfig): + """Inference config for Qwen2.5-Omni multimodal (vision + text). + + Handles the nested config structure where text_config and vision_config + are under thinker_config. Extracts them to the top level as required + by ImageToTextInferenceConfig. + """ + + def __init__( + self, + text_neuron_config, + vision_neuron_config, + fused_spec_config=None, + load_config=None, + metadata: Optional[Dict] = None, + **kwargs, + ): + # Extract text_config and vision_config from thinker_config + # The HF config nests them: thinker_config.text_config, thinker_config.vision_config + thinker = kwargs.get("thinker_config", None) + if thinker is not None: + if hasattr(thinker, "__dict__") and not isinstance(thinker, dict): + thinker = vars(thinker) + if isinstance(thinker, dict): + if "text_config" not in kwargs and "text_config" in thinker: + tc = thinker["text_config"] + kwargs["text_config"] = ( + vars(tc) if hasattr(tc, "__dict__") and not isinstance(tc, dict) else tc + ) + if "vision_config" not in kwargs and "vision_config" in thinker: + vc = thinker["vision_config"] + kwargs["vision_config"] = ( + vars(vc) if hasattr(vc, "__dict__") and not isinstance(vc, dict) else vc + ) + # Extract audio_config from thinker_config + if "audio_config" not in kwargs and "audio_config" in thinker: + ac = thinker["audio_config"] + kwargs["audio_config"] = ( + vars(ac) if hasattr(ac, "__dict__") and not isinstance(ac, dict) else ac + ) + # Extract special token IDs from thinker_config + for token_key in [ + "image_token_index", "audio_token_index", "video_token_index", + "audio_start_token_id", "audio_end_token_id", + "vision_start_token_id", "vision_end_token_id", + "vision_token_id", "pad_token_id", + ]: + if token_key in thinker and token_key not in kwargs: + kwargs[token_key] = thinker[token_key] + + super().__init__( + text_neuron_config=text_neuron_config, + vision_neuron_config=vision_neuron_config, + fused_spec_config=fused_spec_config, + load_config=load_config, + metadata=metadata, + **kwargs, + ) + + self.add_special_config() + self.validate_model_supported_configs() + + def add_special_config(self): + self.num_cores_per_group = 1 + self.qkv_bias = True + self.o_bias = False + + # Vision config derived attributes + self.vision_config.head_dim = ( + self.vision_config.embed_dim // self.vision_config.num_heads + ) + self.vision_config.num_cores_per_group = 1 + + # Vision encoder MUST use separate Q/K/V (not fused) + if getattr(self.vision_config.neuron_config, "fused_qkv", True): + self.vision_config.neuron_config.fused_qkv = False + logger.info( + "Qwen2.5-Omni vision encoder: set fused_qkv=False " + "(separate Q/K/V projections)" + ) + + # Copy text config keys to top-level (for compatibility) + for key in _MULTIMODAL_TEXT_CONFIG_KEYS: + if hasattr(self.text_config, key): + setattr(self, key, getattr(self.text_config, key)) + + # Map Qwen2.5-Omni token IDs to Qwen2-VL compatible names + if hasattr(self, "image_token_index"): + self.image_token_id = self.image_token_index + self.text_config.image_token_id = self.image_token_index + if hasattr(self, "video_token_index"): + self.video_token_id = self.video_token_index + self.text_config.video_token_id = self.video_token_index + if hasattr(self, "audio_token_index"): + self.audio_token_id = self.audio_token_index + self.text_config.audio_token_id = self.audio_token_index + + # Set pad_token_id + if hasattr(self, "pad_token_id"): + self.text_config.pad_token_id = self.pad_token_id + + # Store audio_config as SimpleNamespace for attribute access + if hasattr(self, "audio_config") and isinstance(self.audio_config, dict): + self.audio_config = SimpleNamespace(**self.audio_config) + + def validate_model_supported_configs(self): + # Disable unsupported features for text model + unsupported_text = [ + "is_prefix_caching", + "is_chunked_prefill", + "is_medusa", + "enable_fused_speculation", + ] + for cfg_name in unsupported_text: + if getattr(self.text_config.neuron_config, cfg_name, False): + setattr(self.text_config.neuron_config, cfg_name, False) + logger.warning( + f"Qwen2.5-Omni text model does not support " + f"'{cfg_name}'. Disabled." + ) + + # Disable unsupported features for vision model + unsupported_vision = [ + "sequence_parallel_enabled", + "flash_decoding_enabled", + "qkv_kernel_enabled", + ] + for cfg_name in unsupported_vision: + if getattr(self.vision_config.neuron_config, cfg_name, False): + setattr(self.vision_config.neuron_config, cfg_name, False) + logger.warning( + f"Qwen2.5-Omni vision model does not support " + f"'{cfg_name}'. Disabled." + ) + + def get_required_attributes(self) -> List[str]: + return [ + "text_config", + "vision_config", + "text_config.hidden_size", + "text_config.num_attention_heads", + "text_config.num_hidden_layers", + "text_config.num_key_value_heads", + "text_config.pad_token_id", + "text_config.vocab_size", + "text_config.max_position_embeddings", + "text_config.rope_theta", + "text_config.rms_norm_eps", + "text_config.hidden_act", + "vision_config.depth", + "vision_config.embed_dim", + "vision_config.num_heads", + "vision_config.in_channels", + "vision_config.patch_size", + "vision_config.spatial_merge_size", + "vision_config.temporal_patch_size", + "vision_config.out_hidden_size", + "vision_config.intermediate_size", + ] + + @classmethod + def get_neuron_config_cls(cls) -> Type[NeuronConfig]: + return NeuronConfig + + +class NeuronQwen25OmniMultimodalForCausalLM(NeuronBaseForImageToText): + """Qwen2.5-Omni multimodal model (vision encoder + text decoder) on Neuron. + + Reuses Qwen2-VL text model components (same mRoPE architecture) and + the Qwen2.5-Omni vision encoder (SwiGLU, RMSNorm, separate QKV). + """ + + # Import lazily to avoid circular imports + @staticmethod + def _get_text_model_cls(): + from neuronx_distributed_inference.models.qwen2_vl.modeling_qwen2_vl_text import ( + NeuronQwen2VLTextModel, + ) + return NeuronQwen2VLTextModel + + @staticmethod + def _get_text_model_wrapper(): + from neuronx_distributed_inference.models.qwen2_vl.modeling_qwen2_vl_text import ( + Qwen2VLTextModelWrapper, + ) + return Qwen2VLTextModelWrapper + + @staticmethod + def _get_vision_model_cls(): + from modeling_qwen25_omni_vision import ( + NeuronQwen25OmniVisionModel, + ) + return NeuronQwen25OmniVisionModel + + @staticmethod + def _get_vision_model_wrapper(): + from modeling_qwen25_omni_vision import ( + Qwen25OmniVisionModelWrapper, + ) + return Qwen25OmniVisionModelWrapper + + @staticmethod + def _get_audio_encoder_cls(): + from modeling_qwen25_omni_audio import ( + NeuronQwen25OmniAudioEncoder, + ) + return NeuronQwen25OmniAudioEncoder + + @staticmethod + def _get_talker_cls(): + from modeling_qwen25_omni_talker import ( + NeuronQwen25OmniTalker, + ) + return NeuronQwen25OmniTalker + + @staticmethod + def _get_neuron_talker_cls(): + from modeling_qwen25_omni_talker import ( + NeuronQwen25OmniTalkerForCausalLM, + ) + return NeuronQwen25OmniTalkerForCausalLM + + @staticmethod + def _get_talker_config_cls(): + from modeling_qwen25_omni_talker import ( + TalkerInferenceConfig, + ) + return TalkerInferenceConfig + + @staticmethod + def _get_thinker_projection_cls(): + from modeling_qwen25_omni_talker import ( + ThinkerToTalkerProjection, + ) + return ThinkerToTalkerProjection + + @staticmethod + def _get_token2wav_cls(): + from modeling_qwen25_omni_token2wav import ( + NeuronQwen25OmniToken2Wav, + ) + return NeuronQwen25OmniToken2Wav + + @staticmethod + def _get_neuron_token2wav_cls(): + from modeling_qwen25_omni_token2wav import ( + NeuronQwen25OmniToken2WavWithNeuronDiT, + ) + return NeuronQwen25OmniToken2WavWithNeuronDiT + + def __init__(self, *args, **kwargs): + super().__init__( + self._get_text_model_cls(), + self._get_vision_model_cls(), + self._get_text_model_wrapper(), + self._get_vision_model_wrapper(), + *args, + **kwargs, + ) + self.audio_encoder = None + self.talker = None + self.token2wav = None + self.speaker_map = {} + + def get_vision_compiler_args(self) -> str: + return ( + "--auto-cast=none --model-type=transformer " + "--tensorizer-options='--enable-ccop-compute-overlap " + "--cc-pipeline-tiling-factor=2 ' -O1 " + "--internal-hlo2tensorizer-options='--verify-hlo=true'" + ) + + def get_compiler_args(self) -> str: + return ( + "--enable-saturate-infinity " + "--enable-mixed-precision-accumulation " + "--auto-cast=none --model-type=transformer -O1 " + "--tensorizer-options='--enable-ccop-compute-overlap " + "--cc-pipeline-tiling-factor=2 " + "--vectorize-strided-dma' " + "--internal-hlo2tensorizer-options='--verify-hlo=true'" + ) + + def get_required_kwargs(self) -> List[str]: + return ["pixel_values", "vision_mask", "image_grid_thw"] + + def enable_vision_encoder( + self, enable_wlt_optimization: bool = True, **model_init_kwargs + ): + new_config = copy.deepcopy(self.config) + self.vision_encoder_model = self._get_vision_model_wrapper()( + config=new_config, + model_cls=self._get_vision_model_cls(), + tag=VISION_ENCODER_MODEL_TAG, + compiler_args=self.get_vision_compiler_args(), + model_init_kwargs=model_init_kwargs, + priority_model_idx=(0 if enable_wlt_optimization else None), + pipeline_execution=True, + return_ranked_to_cpu=False, + ) + self.vision_models.append(self.vision_encoder_model) + + def enable_audio_encoder(self, state_dict=None): + """Initialize the audio encoder (CPU frontend/postprocessor + Neuron transformer). + + The audio encoder is a hybrid CPU+Neuron module: + - CPU: Conv1d frontend, positional embeddings, chunking + - Neuron (TP=4): 32 transformer layers with block-diagonal attention + - CPU: AvgPool, LayerNorm, Linear projection + + The CPU components are loaded from the converted state dict. + The Neuron transformer must be compiled/loaded separately via + compile_audio_encoder() and load_audio_encoder(). + + Args: + state_dict: Converted state dict containing audio_tower.* keys + (already split into frontend.*, transformer.*, postprocessor.*). + If None, the encoder is created with random weights. + """ + audio_config = getattr(self.config, "audio_config", None) + if audio_config is None: + logger.warning( + "No audio_config found in model config. " + "Audio encoder will not be initialized." + ) + return + + AudioEncoderCls = self._get_audio_encoder_cls() + dtype = torch.bfloat16 + if hasattr(self.config, "neuron_config"): + dtype = getattr(self.config.neuron_config, "torch_dtype", dtype) + + if state_dict is not None: + self.audio_encoder = AudioEncoderCls.from_pretrained_state_dict( + audio_config, state_dict, dtype=dtype + ) + else: + self.audio_encoder = AudioEncoderCls(audio_config, dtype=dtype) + + self.audio_encoder.eval() + logger.info( + "Audio encoder initialized (CPU frontend/postprocessor, " + "Neuron transformer pending compile/load)" + ) + + def compile_audio_encoder(self, compiled_model_path, audio_neuron_config=None): + """Compile the audio encoder transformer layers on Neuron. + + Args: + compiled_model_path: Path to save compiled model artifacts. + audio_neuron_config: Optional NeuronConfig for audio transformer. + If None, creates a default config with TP matching the text model. + """ + if self.audio_encoder is None: + raise RuntimeError("Call enable_audio_encoder() first") + + from modeling_qwen25_omni_audio import ( + AudioEncoderInferenceConfig, + NeuronQwen25OmniForAudioEncoding, + ) + + audio_config = getattr(self.config, "audio_config", None) + if isinstance(audio_config, dict): + from types import SimpleNamespace + audio_config = SimpleNamespace(**audio_config) + + if audio_neuron_config is None: + # Default: match text model TP degree, reasonable seq_len buckets + tp_degree = self.neuron_config.tp_degree + audio_neuron_config = NeuronConfig( + tp_degree=tp_degree, + torch_dtype=self.neuron_config.torch_dtype, + batch_size=1, + # Audio seq_len buckets: typical values after conv + buckets=[256, 512, 1024, 1500], + ) + + audio_inf_config = AudioEncoderInferenceConfig( + neuron_config=audio_neuron_config, + audio_config=vars(audio_config) if hasattr(audio_config, '__dict__') else audio_config, + ) + + audio_app = NeuronQwen25OmniForAudioEncoding(audio_inf_config) + audio_app.compile(compiled_model_path) + + logger.info("Audio encoder transformer compiled to %s", compiled_model_path) + return audio_app + + def load_audio_encoder(self, compiled_model_path, audio_app=None): + """Load compiled audio encoder transformer layers. + + Args: + compiled_model_path: Path to compiled model artifacts. + audio_app: Optional pre-compiled NeuronQwen25OmniForAudioEncoding. + If None, creates and loads from compiled_model_path. + """ + if self.audio_encoder is None: + raise RuntimeError("Call enable_audio_encoder() first") + + if audio_app is None: + from modeling_qwen25_omni_audio import ( + AudioEncoderInferenceConfig, + NeuronQwen25OmniForAudioEncoding, + ) + # Load from compiled artifacts + audio_app = NeuronQwen25OmniForAudioEncoding.load(compiled_model_path) + + self.audio_encoder.transformer = audio_app.model + logger.info("Audio encoder transformer loaded from %s", compiled_model_path) + + def enable_talker(self, state_dict=None, use_neuron=False): + """Initialize the Talker model. + + The Talker converts Thinker hidden states into codec tokens for + speech synthesis. + + Args: + state_dict: Converted state dict containing talker keys. + If None, the Talker is created with random weights. + use_neuron: If True, use the Neuron-compiled Talker + (NeuronQwen25OmniTalkerForCausalLM). Requires separate + compilation with compile_talker() / load_talker(). + If False (default), use the CPU-based HF wrapper. + """ + talker_config = getattr(self.config, "talker_config", None) + if talker_config is None: + logger.warning( + "No talker_config found in model config. " + "Talker will not be initialized." + ) + return + + if use_neuron: + # Neuron Talker — just mark for later compile/load + logger.info( + "Neuron Talker mode enabled. Call compile_talker() or " + "load_talker() to activate." + ) + self._talker_use_neuron = True + self._talker_state_dict = state_dict + # Initialize CPU projection for thinker states + if state_dict is not None: + ProjCls = self._get_thinker_projection_cls() + self.thinker_to_talker_proj = ProjCls.from_state_dict(state_dict) + logger.info("Thinker→Talker CPU projection initialized") + else: + # CPU Talker (default) + TalkerCls = self._get_talker_cls() + dtype = torch.bfloat16 + + if state_dict is not None: + self.talker = TalkerCls.from_pretrained_state_dict( + talker_config, state_dict, dtype=dtype + ) + else: + self.talker = TalkerCls(talker_config, dtype=dtype) + + logger.info("Talker initialized on CPU") + + def compile_talker(self, compiled_model_path, talker_neuron_config=None): + """Compile the Talker on Neuron. + + Creates a NeuronQwen25OmniTalkerForCausalLM, converts the state + dict (fusing embedding + projection), and compiles. + + Recommended neuron_config: + - tp_degree=4 (12 Q heads / 4 = 3 per rank) + - batch_size=1 + - seq_len=4096 (max codec tokens) + + Args: + compiled_model_path: Path to save compiled model + talker_neuron_config: NeuronConfig for Talker compilation. + If None, creates a default config with TP=4. + """ + talker_config = getattr(self.config, "talker_config", None) + if talker_config is None: + raise RuntimeError("No talker_config in model config") + + NeuronTalkerCls = self._get_neuron_talker_cls() + TalkerConfigCls = self._get_talker_config_cls() + + if talker_neuron_config is None: + from modeling_qwen25_omni_talker import ( + TalkerNeuronConfig, + ) + talker_neuron_config = TalkerNeuronConfig( + tp_degree=4, + batch_size=1, + seq_len=4096, + torch_dtype=torch.bfloat16, + ) + + from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config + inference_config = TalkerConfigCls( + neuron_config=talker_neuron_config, + load_config=load_pretrained_config(hf_config=talker_config), + ) + + app = NeuronTalkerCls(self.model_path, config=inference_config) + app.compile(compiled_model_path) + logger.info("Talker compiled to %s", compiled_model_path) + + def load_talker(self, compiled_model_path, talker_app=None): + """Load a compiled Talker from disk. + + Args: + compiled_model_path: Path to compiled model artifacts. + talker_app: Optional pre-compiled NeuronQwen25OmniTalkerForCausalLM. + """ + if talker_app is None: + NeuronTalkerCls = self._get_neuron_talker_cls() + talker_app = NeuronTalkerCls.load(compiled_model_path) + + self.talker = talker_app + logger.info("Neuron Talker loaded from %s", compiled_model_path) + + def enable_token2wav( + self, state_dict=None, speaker_dict_path=None, use_neuron_dit=False + ): + """Initialize the Token2Wav vocoder. + + Token2Wav converts codec tokens from the Talker into audio + waveforms using DiT + BigVGAN. + + Args: + state_dict: Converted state dict containing token2wav keys. + If None, Token2Wav is created with random weights. + speaker_dict_path: Path to spk_dict.pt for speaker conditioning. + If provided, loads the speaker map for audio generation. + use_neuron_dit: If True, use the Neuron-accelerated version + (DiT on Neuron, ODE loop + BigVGAN on CPU). + Call compile_token2wav_dit() / load_token2wav_dit() to compile. + """ + token2wav_config = getattr(self.config, "token2wav_config", None) + if token2wav_config is None: + logger.warning( + "No token2wav_config found in model config. " + "Token2Wav will not be initialized." + ) + return + + if use_neuron_dit: + Token2WavCls = self._get_neuron_token2wav_cls() + else: + Token2WavCls = self._get_token2wav_cls() + + if state_dict is not None: + self.token2wav = Token2WavCls.from_pretrained_state_dict( + token2wav_config, state_dict + ) + else: + self.token2wav = Token2WavCls(token2wav_config) + + if speaker_dict_path is not None: + self.speaker_map = Token2WavCls.load_speaker_dict(speaker_dict_path) + logger.info( + "Loaded %d speakers: %s", + len(self.speaker_map), + list(self.speaker_map.keys()), + ) + + mode = "CPU (Neuron DiT capable)" if use_neuron_dit else "CPU" + logger.info("Token2Wav initialized on %s (float32)", mode) + + def compile_token2wav_dit(self, compiled_path, **kwargs): + """Compile the Token2Wav DiT on Neuron. + + Requires enable_token2wav(use_neuron_dit=True) to be called first. + + Args: + compiled_path: Directory to save compiled DiT + **kwargs: Additional args passed to compile_dit() + (max_codec_len, max_mel_len, batch_size) + """ + if self.token2wav is None: + raise RuntimeError("Call enable_token2wav(use_neuron_dit=True) first") + if not hasattr(self.token2wav, "compile_dit"): + raise RuntimeError( + "Token2Wav is not Neuron DiT capable. " + "Call enable_token2wav(use_neuron_dit=True)" + ) + self.token2wav.compile_dit(compiled_path, **kwargs) + logger.info("Token2Wav DiT compiled to %s", compiled_path) + + def load_token2wav_dit(self, compiled_path): + """Load a compiled Token2Wav DiT. + + Args: + compiled_path: Directory containing compiled DiT + """ + if self.token2wav is None: + raise RuntimeError("Call enable_token2wav(use_neuron_dit=True) first") + if not hasattr(self.token2wav, "load_dit"): + raise RuntimeError( + "Token2Wav is not Neuron DiT capable. " + "Call enable_token2wav(use_neuron_dit=True)" + ) + self.token2wav.load_dit(compiled_path) + logger.info("Token2Wav DiT loaded from %s", compiled_path) + + @staticmethod + def load_hf_model(model_path, **kwargs): + from transformers import AutoModelForCausalLM + + return AutoModelForCausalLM.from_pretrained( + model_path, trust_remote_code=True, **kwargs + ) + + @staticmethod + def convert_hf_to_neuron_state_dict( + state_dict: dict, + inference_config: Qwen25OmniMultimodalInferenceConfig, + include_talker: bool = False, + include_token2wav: bool = False, + ) -> dict: + """Convert Qwen2.5-Omni full state dict to NxDI format. + + 1. Remap thinker.* prefixes to Qwen2-VL compatible format + 2. Apply vision encoder conversion (separate Q/K/V, SwiGLU MLP) + 3. Apply audio encoder conversion (strip prefix, cast dtype) + 4. Apply text model conversion (fused QKV, attention key renames) + 5. Optionally strip talker.* and token2wav.* prefixes + + Args: + state_dict: Full HF state dict + inference_config: Multimodal inference config + include_talker: If True, include talker keys (stripped prefix) + include_token2wav: If True, include token2wav keys (stripped prefix) + """ + from modeling_qwen25_omni_audio import ( + NeuronQwen25OmniAudioEncoder, + ) + from modeling_qwen25_omni_vision import ( + NeuronQwen25OmniForImageEncoding, + ) + from neuronx_distributed_inference.models.qwen2_vl.modeling_qwen2_vl_text import ( + NeuronQwen2VLTextForCausalLM, + ) + + # Step 1: Remap thinker.* prefixes, optionally keep talker/token2wav + remapped = {} + talker_state = {} + token2wav_state = {} + for key, value in state_dict.items(): + if key.startswith("thinker.model."): + remapped["model." + key[len("thinker.model."):]] = value + elif key.startswith("thinker.lm_head."): + remapped[key[len("thinker."):]] = value + elif key.startswith("thinker.visual."): + remapped["visual." + key[len("thinker.visual."):]] = value + elif key.startswith("thinker.audio_tower."): + remapped["audio_tower." + key[len("thinker.audio_tower."):]] = value + elif key.startswith("talker.") and include_talker: + talker_state[key[len("talker."):]] = value + elif key.startswith("token2wav.") and include_token2wav: + token2wav_state[key[len("token2wav."):]] = value + + del state_dict + gc.collect() + + logger.info( + "Remapped %d thinker keys%s%s", + len(remapped), + " + %d talker keys" % len(talker_state) if talker_state else "", + " + %d token2wav keys" % len(token2wav_state) if token2wav_state else "", + ) + + # Step 2: Vision encoder conversion + remapped = NeuronQwen25OmniForImageEncoding.convert_hf_to_neuron_state_dict( + remapped, inference_config + ) + + # Step 3: Audio encoder conversion (strip prefix, cast dtype) + audio_dtype = getattr( + inference_config, "torch_dtype", torch.bfloat16 + ) + if hasattr(inference_config, "neuron_config"): + audio_dtype = getattr( + inference_config.neuron_config, "torch_dtype", audio_dtype + ) + remapped = NeuronQwen25OmniAudioEncoder.convert_hf_to_neuron_state_dict( + remapped, dtype=audio_dtype + ) + + # Step 4: Text model conversion + remapped = NeuronQwen2VLTextForCausalLM.convert_hf_to_neuron_state_dict( + remapped, inference_config.text_config + ) + + # Step 5: Merge talker and token2wav state dicts + if talker_state: + for k, v in talker_state.items(): + remapped["talker." + k] = v + logger.info("Included %d talker keys in output", len(talker_state)) + if token2wav_state: + for k, v in token2wav_state.items(): + remapped["token2wav." + k] = v + logger.info("Included %d token2wav keys in output", len(token2wav_state)) + + return remapped + + def get_padding_length(self, input_ids): + """Get the context encoding bucket size for given input_ids.""" + buckets = self.context_encoding_model.config.neuron_config.buckets + for val in buckets: + if val >= input_ids.shape[1]: + return val + raise Exception("No bucket found for provided input_ids!") + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + seq_ids: Optional[torch.LongTensor] = None, + sampling_params: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + vision_mask: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + input_features: Optional[torch.FloatTensor] = None, + feature_attention_mask: Optional[torch.LongTensor] = None, + adapter_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + medusa_args=None, + input_capture_hook: Optional[Callable] = None, + tensor_capture_hook: Optional[Callable] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + from neuronx_distributed_inference.models.llama4.utils.encoder_utils import ( + generate_positions_from_mask, + pad_positions, + ) + + pad_limit = self.get_padding_length(input_ids) + is_context_encoding = input_ids.shape[-1] > 1 + + # --- Audio encoding (CPU frontend + Neuron transformer + CPU postprocessor) --- + audio_embeddings = None + audio_positions = None + if ( + input_features is not None + and self.audio_encoder is not None + and is_context_encoding + ): + audio_token_id = getattr( + self.config, "audio_token_id", None + ) or getattr(self.config, "audio_token_index", 151646) + + with torch.no_grad(): + # Prepare audio features (same as HF get_audio_features) + if feature_attention_mask is not None: + audio_feature_lengths = feature_attention_mask.sum(-1) + # (batch, mel_len, n_mels) -> mask -> (n_mels, valid_len) + input_features_flat = input_features.permute(0, 2, 1)[ + feature_attention_mask.bool() + ].permute(1, 0) + else: + input_features_flat = input_features.squeeze(0).permute(1, 0) + audio_feature_lengths = torch.tensor( + [input_features_flat.shape[1]], dtype=torch.long + ) + + aftercnn_lens, audio_output_lens = ( + self.audio_encoder._get_feat_extract_output_lengths( + audio_feature_lengths + ) + ) + audio_embeddings = self.audio_encoder( + input_features_flat, + feature_lens=audio_feature_lengths, + aftercnn_lens=aftercnn_lens, + ) + + # Find audio token positions for scattering + audio_mask_bool = (input_ids == audio_token_id) + if audio_mask_bool.any() and audio_embeddings is not None: + audio_positions = generate_positions_from_mask( + audio_mask_bool.squeeze() + ) + + # --- Vision encoding (Neuron) --- + vision_embeddings = None + vision_positions = None + if ( + (pixel_values is not None) + and is_context_encoding + and pixel_values.sum() != 0 + ): + image_token_id = getattr(self.config, "image_token_id", None) or getattr( + self.config, "image_token_index", 151655 + ) + vision_mask_bool = (input_ids == image_token_id) + if vision_mask_bool.any(): + vision_positions = generate_positions_from_mask( + vision_mask_bool.squeeze() + ) + vision_embeddings = self.vision_encoder_model( + pixel_values.to(self.vision_config.neuron_config.torch_dtype), + image_grid_thw, + ) + + # --- Combine multimodal embeddings for scattering --- + # The text model's encode_vision_to_input scatters embeddings at + # specified positions. We combine audio + vision into one tensor. + if audio_embeddings is not None and vision_embeddings is not None: + # Both audio and vision present + # audio_embeddings is on CPU, vision_embeddings may be on XLA + # The model wrapper handles device transfer, so keep on CPU + all_embeddings = torch.cat([ + vision_embeddings.cpu() if vision_embeddings.is_cuda else vision_embeddings, + audio_embeddings, + ], dim=0) + all_positions = torch.cat([vision_positions, audio_positions]) + vision_embeddings = all_embeddings + vision_mask = pad_positions(all_positions, pad_limit, (pad_limit - 1)) + elif audio_embeddings is not None and audio_positions is not None: + # Audio only, no vision + vision_embeddings = audio_embeddings + vision_mask = pad_positions(audio_positions, pad_limit, (pad_limit - 1)) + elif vision_embeddings is not None and vision_positions is not None: + # Vision only, no audio + vision_mask = pad_positions(vision_positions, pad_limit, (pad_limit - 1)) + else: + # No multimodal input - use dummy embeddings + vision_embeddings, vision_mask = self._get_text_model_wrapper().get_dummy_vision_inputs( + config=self.text_config, + input_ids=input_ids, + n_active_tokens=pad_limit, + fill_value=(pad_limit - 1), + ) + + output_token = super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + seq_ids=seq_ids, + sampling_params=sampling_params, + input_capture_hook=input_capture_hook, + tensor_capture_hook=tensor_capture_hook, + vision_embeddings=vision_embeddings, + vision_mask=vision_mask, + ) + return output_token + + @classmethod + def get_config_cls(cls): + return Qwen25OmniMultimodalInferenceConfig diff --git a/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_audio.py b/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_audio.py new file mode 100644 index 00000000..a58f1f50 --- /dev/null +++ b/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_audio.py @@ -0,0 +1,700 @@ +# coding=utf-8 +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Qwen2.5-Omni Audio Encoder for NXD inference. +# +# Whisper-style audio encoder with TP=4 support: +# - Conv1d frontend + sinusoidal positional embeddings (CPU preprocessing) +# - 32 transformer layers with TP-parallel attention and MLP (Neuron) +# - AvgPool1d + LayerNorm + Linear projection (CPU postprocessing) +# +# Architecture: +# d_model=1280, heads=20 (5 per rank at TP=4), head_dim=64 +# ffn=5120 (1280 per rank at TP=4), output_dim=3584 +# Asymmetric attention bias: q/v have bias, k has NO bias +# +# The transformer layers are compiled on Neuron via ModelWrapper. +# Conv1d frontend and postprocessing run on CPU since they involve +# variable-length processing (chunking, per-audio AvgPool). + +"""Qwen2.5-Omni Audio Encoder for NXD inference.""" + +import logging +import math +from types import SimpleNamespace +from typing import List, Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + RowParallelLinear, +) +from neuronx_distributed_inference.models.application_base import NeuronApplicationBase +from neuronx_distributed_inference.models.config import InferenceConfig, NeuronConfig +from neuronx_distributed_inference.models.model_wrapper import ( + EncoderModelInstance, + ModelWrapper, +) +from neuronx_distributed_inference.modules.padding import pad_tensor + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# CPU components (not compiled on Neuron) +# --------------------------------------------------------------------------- + +class SinusoidsPositionEmbedding(nn.Module): + """Sinusoidal positional embeddings (same as Whisper/HF Qwen2.5-Omni).""" + + def __init__(self, length, channels, max_timescale=10000): + super().__init__() + if channels % 2 != 0: + raise ValueError("SinusoidsPositionEmbedding needs even channels") + log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) + inv_timescales = torch.exp( + -log_timescale_increment * torch.arange(channels // 2).float() + ) + scaled_time = ( + torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] + ) + self.register_buffer( + "positional_embedding", + torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1), + persistent=False, + ) + + def forward(self, seqlen: int): + return self.positional_embedding[:seqlen, :] + + +class AudioCPUFrontend(nn.Module): + """Conv1d frontend + positional embeddings + chunking (CPU). + + Processes raw mel spectrograms into token sequences ready for the + Neuron transformer. Also handles audio chunking for long audio. + """ + + def __init__(self, audio_config, dtype=torch.bfloat16): + super().__init__() + if isinstance(audio_config, dict): + audio_config = SimpleNamespace(**audio_config) + + d_model = audio_config.d_model # 1280 + num_mel_bins = audio_config.num_mel_bins # 128 + max_source_positions = audio_config.max_source_positions # 1500 + self.n_window = audio_config.n_window # 100 + self.d_model = d_model + + self.conv1 = nn.Conv1d( + num_mel_bins, d_model, kernel_size=3, padding=1, dtype=dtype + ) + self.conv2 = nn.Conv1d( + d_model, d_model, kernel_size=3, stride=2, padding=1, dtype=dtype + ) + self.positional_embedding = SinusoidsPositionEmbedding( + max_source_positions, d_model + ) + + def _padded_and_mask_function(self, chunk_list, chunk_lengths): + """Pad chunks to same length and create masks.""" + max_len = chunk_lengths.max().item() + dim = chunk_list[0].shape[0] + + padded_tensor = torch.zeros( + len(chunk_list), dim, max_len, + dtype=chunk_list[0].dtype, device=chunk_list[0].device, + ) + batch_mask = torch.zeros( + len(chunk_lengths), max_len, + dtype=torch.long, device=chunk_list[0].device, + ) + for i, (chunk, length) in enumerate(zip(chunk_list, chunk_lengths)): + length = length.item() + batch_mask[i, :length] = 1 + padded_tensor[i, :, :chunk.shape[1]] = chunk + + feature_lens_after_cnn = (chunk_lengths - 1) // 2 + 1 + max_len_after_cnn = feature_lens_after_cnn.max().item() + batch_mask_after_cnn = torch.zeros( + len(chunk_lengths), max_len_after_cnn, + dtype=torch.bool, device=chunk_list[0].device, + ) + for i, length in enumerate(feature_lens_after_cnn): + batch_mask_after_cnn[i, :length] = True + + return padded_tensor, batch_mask.unsqueeze(1), batch_mask_after_cnn + + def forward(self, input_features, feature_lens): + """Process mel spectrogram through conv frontend. + + Args: + input_features: (n_mels, total_mel_len) mel spectrogram + feature_lens: (num_audios,) mel length for each audio + + Returns: + hidden_states: (total_valid_tokens, d_model) token embeddings + aftercnn_lens: (num_audios,) valid token count per audio + cu_seqlens: cumulative sequence lengths for attention masking + """ + aftercnn_lens = (feature_lens - 1) // 2 + 1 + + # Split into chunks of n_window * 2 mel frames + chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long() + chunk_lengths = torch.tensor( + [self.n_window * 2] * chunk_num.sum().item(), + dtype=torch.long, device=feature_lens.device, + ) + tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:] + chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2) + chunk_lengths = torch.where( + chunk_lengths == 0, self.n_window * 2, chunk_lengths + ) + + chunk_list = input_features.split(chunk_lengths.tolist(), dim=1) + padded_feature, padded_mask, padded_mask_after_cnn = ( + self._padded_and_mask_function(chunk_list, chunk_lengths) + ) + + # Conv frontend + padded_embed = F.gelu(self.conv1(padded_feature)) * padded_mask + padded_embed = F.gelu(self.conv2(padded_embed)).transpose(1, 2) + + # Add positional embeddings + padded_embed = padded_embed + self.positional_embedding( + padded_embed.shape[1] + ).unsqueeze(0).to(padded_embed.dtype) + + # Flatten valid tokens + hidden_states = padded_embed[padded_mask_after_cnn] + + # Compute cu_seqlens for attention mask + cu_seqlens = torch.cat([ + torch.zeros(1, device=feature_lens.device, dtype=torch.int32), + padded_mask_after_cnn.sum(1).cumsum(0).to(torch.int32), + ]) + + return hidden_states, aftercnn_lens, cu_seqlens + + +class AudioCPUPostprocessor(nn.Module): + """AvgPool + LayerNorm + projection (CPU). + + Post-processes transformer output into final audio embeddings. + """ + + def __init__(self, audio_config, dtype=torch.bfloat16): + super().__init__() + if isinstance(audio_config, dict): + audio_config = SimpleNamespace(**audio_config) + + d_model = audio_config.d_model # 1280 + output_dim = audio_config.output_dim # 3584 + + self.ln_post = nn.LayerNorm(d_model) # stays float32 + self.avg_pooler = nn.AvgPool1d(2, stride=2) + self.proj = nn.Linear(d_model, output_dim, dtype=dtype) + self.audio_bos_eos_token = nn.Embedding(2, output_dim) + + def forward(self, hidden_states, aftercnn_lens): + """Post-process transformer output. + + Args: + hidden_states: (total_tokens, d_model) transformer output + aftercnn_lens: (num_audios,) token count per audio + + Returns: + audio_embeddings: (total_output_tokens, output_dim) + """ + hidden_states_list = hidden_states.split(aftercnn_lens.tolist(), dim=0) + token_audio_list = [] + for each_audio_states in hidden_states_list: + each_audio_states = self.avg_pooler( + each_audio_states.transpose(0, 1) + ).transpose_(0, 1) + each_audio_states = self.ln_post(each_audio_states.float()).to( + each_audio_states.dtype + ) + each_audio_states = self.proj(each_audio_states) + token_audio_list.append(each_audio_states) + + return torch.cat(token_audio_list, dim=0) + + +# --------------------------------------------------------------------------- +# Neuron-compiled transformer components (TP=4) +# --------------------------------------------------------------------------- + +class NeuronAudioAttention(nn.Module): + """TP-parallel self-attention for audio encoder. + + Asymmetric bias: q_proj and v_proj have bias, k_proj has NO bias. + Uses ColumnParallelLinear for Q/K/V and RowParallelLinear for output. + """ + + def __init__(self, d_model, num_heads, tp_degree, dtype=torch.bfloat16): + super().__init__() + self.d_model = d_model + self.num_heads = num_heads + self.head_dim = d_model // num_heads + self.num_heads_per_rank = num_heads // tp_degree + self.scaling = self.head_dim ** -0.5 + + self.q_proj = ColumnParallelLinear( + d_model, d_model, bias=True, gather_output=False, dtype=dtype, + ) + self.k_proj = ColumnParallelLinear( + d_model, d_model, bias=False, gather_output=False, dtype=dtype, + ) + self.v_proj = ColumnParallelLinear( + d_model, d_model, bias=True, gather_output=False, dtype=dtype, + ) + self.out_proj = RowParallelLinear( + d_model, d_model, bias=True, input_is_parallel=True, dtype=dtype, + ) + + def forward(self, hidden_states, attention_mask=None): + """ + Args: + hidden_states: (batch, seq_len, d_model) + attention_mask: (batch, 1, seq_len, seq_len) with 0 for valid, -inf for masked + """ + bsz, seq_len, _ = hidden_states.shape + + # Project Q/K/V (ColumnParallelLinear outputs d_model/tp per rank) + q = self.q_proj(hidden_states) # (bsz, seq, d_model/tp) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + # Reshape to (bsz, num_heads_per_rank, seq, head_dim) + q = q.view(bsz, seq_len, self.num_heads_per_rank, self.head_dim).transpose(1, 2) + k = k.view(bsz, seq_len, self.num_heads_per_rank, self.head_dim).transpose(1, 2) + v = v.view(bsz, seq_len, self.num_heads_per_rank, self.head_dim).transpose(1, 2) + + # Attention scores + scores = torch.matmul(q, k.transpose(-2, -1)) * self.scaling + if attention_mask is not None: + scores = scores + attention_mask + + attn_weights = F.softmax(scores.float(), dim=-1).to(q.dtype) + attn_output = torch.matmul(attn_weights, v) + + # Reshape and project output (RowParallelLinear allreduces across ranks) + attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seq_len, -1) + attn_output = self.out_proj(attn_output) + return attn_output + + +class NeuronAudioEncoderLayer(nn.Module): + """Audio encoder transformer layer (TP-parallel). + + Pre-norm: LayerNorm -> attention -> residual -> LayerNorm -> MLP -> residual. + Uses GELU activation (not SwiGLU). LayerNorm operates in float32. + """ + + def __init__(self, d_model, num_heads, ffn_dim, tp_degree, dtype=torch.bfloat16): + super().__init__() + self.self_attn = NeuronAudioAttention(d_model, num_heads, tp_degree, dtype) + self.self_attn_layer_norm = nn.LayerNorm(d_model) + self.fc1 = ColumnParallelLinear( + d_model, ffn_dim, bias=True, gather_output=False, dtype=dtype, + ) + self.fc2 = RowParallelLinear( + ffn_dim, d_model, bias=True, input_is_parallel=True, dtype=dtype, + ) + self.final_layer_norm = nn.LayerNorm(d_model) + + def forward(self, hidden_states, attention_mask=None): + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states.float()).to(residual.dtype) + hidden_states = self.self_attn(hidden_states, attention_mask) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states.float()).to(residual.dtype) + hidden_states = F.gelu(self.fc1(hidden_states)) + hidden_states = self.fc2(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class NeuronAudioTransformerModel(nn.Module): + """Audio encoder transformer (32 layers, compiled on Neuron). + + Takes pre-processed hidden states (after conv + positional embedding) + and returns transformer output. The attention mask handles block-diagonal + masking for chunked audio. + + Input: (1, padded_seq_len, d_model) + (1, 1, padded_seq_len, padded_seq_len) + Output: (1, padded_seq_len, d_model) + """ + + def __init__(self, config: InferenceConfig): + super().__init__() + audio_config = config.audio_config + if isinstance(audio_config, dict): + audio_config = SimpleNamespace(**audio_config) + + tp_degree = config.neuron_config.tp_degree + dtype = config.neuron_config.torch_dtype + + d_model = audio_config.d_model + num_heads = audio_config.encoder_attention_heads + ffn_dim = audio_config.encoder_ffn_dim + num_layers = audio_config.encoder_layers + + self.layers = nn.ModuleList([ + NeuronAudioEncoderLayer(d_model, num_heads, ffn_dim, tp_degree, dtype) + for _ in range(num_layers) + ]) + + def forward(self, hidden_states, attention_mask): + for layer in self.layers: + hidden_states = layer(hidden_states, attention_mask) + return hidden_states + + +# --------------------------------------------------------------------------- +# ModelWrapper and Application classes +# --------------------------------------------------------------------------- + +class AudioTransformerModelWrapper(ModelWrapper): + """Handles bucketing on sequence length, padding, and compilation.""" + + def __init__(self, config, model_cls, tag="", compiler_args=None, + priority_model_idx=None, pipeline_execution=True, + return_ranked_to_cpu=False, model_init_kwargs={}): + super().__init__( + config, model_cls, tag, compiler_args, priority_model_idx, + pipeline_execution, return_ranked_to_cpu, model_init_kwargs, + ) + + def input_generator(self) -> List[Tuple[torch.Tensor]]: + """Generate sample inputs for each sequence length bucket.""" + inputs = [] + dtype = self.config.neuron_config.torch_dtype + d_model = self.config.audio_config.d_model + if isinstance(d_model, dict): + d_model = d_model.get("d_model", 1280) + + for bucket in self.config.neuron_config.buckets: + hidden_states = torch.ones([1, bucket, d_model], dtype=dtype) + attention_mask = torch.zeros( + [1, 1, bucket, bucket], dtype=dtype, + ) + inputs.append((hidden_states, attention_mask)) + return inputs + + def get_model_instance(self): + return EncoderModelInstance(model_cls=self.model_cls, config=self.config) + + def get_target_bucket(self, seq_len): + """Find the smallest bucket that fits the sequence length.""" + for bucket in self.config.neuron_config.buckets: + if bucket >= seq_len: + return bucket + raise ValueError( + f"No bucket found for seq_len={seq_len}. " + f"Buckets: {self.config.neuron_config.buckets}" + ) + + def forward(self, hidden_states, attention_mask): + """Pad to bucket size and run compiled model.""" + if self.model is None: + raise RuntimeError("Forward called before load.") + + seq_len = hidden_states.shape[1] + bucket = self.get_target_bucket(seq_len) + + # Pad sequence to bucket size + if seq_len < bucket: + pad_len = bucket - seq_len + hidden_states = F.pad(hidden_states, (0, 0, 0, pad_len)) + # Extend attention mask: new positions are masked out + dtype = attention_mask.dtype + mask_pad = torch.full( + (1, 1, bucket, bucket), + torch.finfo(dtype).min, + dtype=dtype, + ) + mask_pad[:, :, :seq_len, :seq_len] = attention_mask + attention_mask = mask_pad + + output = self._forward(hidden_states, attention_mask) + # Trim back to original length + return output[:, :seq_len, :] + + +class NeuronQwen25OmniForAudioEncoding(NeuronApplicationBase): + """Neuron application for audio encoder transformer layers. + + Handles compilation, loading, and inference of the 32 transformer layers. + The conv frontend and postprocessing are handled separately on CPU. + """ + + _model_cls = NeuronAudioTransformerModel + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.model = AudioTransformerModelWrapper( + config=self.config, + model_cls=self._model_cls, + tag=self._model_cls.__name__, + compiler_args=self.get_compiler_args(), + priority_model_idx=0, + ) + self.models.append(self.model) + + def forward(self, hidden_states, attention_mask): + return self.models[0](hidden_states, attention_mask) + + def get_compiler_args(self): + return ( + "--auto-cast=none --model-type=transformer " + "--tensorizer-options='--enable-ccop-compute-overlap " + "--cc-pipeline-tiling-factor=2 ' -O1 " + "--internal-hlo2tensorizer-options='--verify-hlo=true'" + ) + + @staticmethod + def update_state_dict_for_tied_weights(state_dict): + pass + + @staticmethod + def load_hf_model(model_path, **kwargs): + from transformers import AutoModelForCausalLM + return AutoModelForCausalLM.from_pretrained( + model_path, trust_remote_code=True, **kwargs + ) + + @staticmethod + def convert_hf_to_neuron_state_dict(state_dict, inference_config): + """Convert HF state dict to Neuron format for audio transformer layers. + + Extracts audio_tower.layers.* keys and strips prefix. + Only returns transformer layer keys (not conv/postprocessing). + """ + new_state_dict = {} + for key, value in state_dict.items(): + if key.startswith("thinker.audio_tower.layers."): + new_key = key[len("thinker.audio_tower."):] + elif key.startswith("audio_tower.layers."): + new_key = key[len("audio_tower."):] + else: + new_state_dict[key] = value + continue + new_state_dict[new_key] = value + return new_state_dict + + @classmethod + def get_config_cls(cls): + return AudioEncoderInferenceConfig + + +class AudioEncoderInferenceConfig(InferenceConfig): + """Config for audio encoder transformer compilation.""" + + def __init__(self, neuron_config, audio_config, **kwargs): + # Store audio_config before calling super + self.audio_config = audio_config + if isinstance(audio_config, dict): + self.audio_config = SimpleNamespace(**audio_config) + super().__init__(neuron_config=neuron_config, **kwargs) + + def add_derived_config(self): + pass + + def get_required_attributes(self): + return [] + + +# --------------------------------------------------------------------------- +# Full Audio Encoder (combines CPU frontend + Neuron transformer + CPU postprocessing) +# --------------------------------------------------------------------------- + +class NeuronQwen25OmniAudioEncoder(nn.Module): + """Qwen2.5-Omni Audio Encoder with Neuron acceleration. + + Architecture: + CPU: mel → Conv1d frontend → positional embeddings → chunking + Neuron (TP=4): 32 transformer layers with block-diagonal attention + CPU: AvgPool → LayerNorm → Linear projection → audio embeddings + + Usage: + 1. Create with audio_config + 2. Call compile_transformer() to compile on Neuron + 3. Call load_transformer() to load compiled model + 4. Call forward() to encode audio + """ + + def __init__(self, audio_config, neuron_config=None, dtype=torch.bfloat16): + super().__init__() + if isinstance(audio_config, dict): + audio_config = SimpleNamespace(**audio_config) + self.audio_config = audio_config + self.dtype = dtype + self.n_window = audio_config.n_window + + # CPU components + self.frontend = AudioCPUFrontend(audio_config, dtype=dtype) + self.postprocessor = AudioCPUPostprocessor(audio_config, dtype=dtype) + + # Neuron transformer (initialized via compile/load cycle) + self.transformer = None + self._neuron_config = neuron_config + + def _get_feat_extract_output_lengths(self, input_lengths): + """Compute output lengths after conv and avg_pool.""" + input_lengths = (input_lengths - 1) // 2 + 1 + output_lengths = (input_lengths - 2) // 2 + 1 + return input_lengths, output_lengths + + def _prepare_attention_mask(self, seq_length, cu_seqlens, dtype): + """Create block-diagonal attention mask from cu_seqlens.""" + attention_mask = torch.full( + [1, 1, seq_length, seq_length], + torch.finfo(dtype).min, + dtype=dtype, + ) + for i in range(1, len(cu_seqlens)): + s, e = cu_seqlens[i - 1], cu_seqlens[i] + attention_mask[..., s:e, s:e] = 0 + return attention_mask + + def forward(self, input_features, feature_lens, aftercnn_lens=None): + """Process mel spectrogram through audio encoder. + + Args: + input_features: (n_mels, total_mel_len) mel spectrogram + feature_lens: (num_audios,) mel length for each audio + aftercnn_lens: optional pre-computed lengths after conv + + Returns: + audio_embeddings: (total_audio_tokens, output_dim) tensor + """ + if aftercnn_lens is None: + aftercnn_lens, _ = self._get_feat_extract_output_lengths(feature_lens) + + # CPU: Conv frontend + chunking + hidden_states, aftercnn_lens_actual, cu_seqlens = self.frontend( + input_features, feature_lens + ) + + if self.transformer is not None: + # Neuron: transformer layers + seq_len = hidden_states.shape[0] + attention_mask = self._prepare_attention_mask( + seq_len, cu_seqlens, self.dtype + ) + # Add batch dimension for Neuron model + hidden_states = hidden_states.unsqueeze(0) + hidden_states = self.transformer(hidden_states, attention_mask) + hidden_states = hidden_states.squeeze(0) + else: + # Fallback: CPU transformer (for testing without Neuron) + logger.warning( + "Audio transformer not compiled. Running transformer on CPU " + "(this is slow and should only be used for testing)." + ) + attention_mask = self._prepare_attention_mask( + hidden_states.shape[0], cu_seqlens, self.dtype + ) + # CPU fallback would require loading transformer weights separately + raise RuntimeError( + "Audio transformer must be compiled and loaded before inference. " + "Call compile_transformer() and load_transformer() first." + ) + + # CPU: Postprocessing + return self.postprocessor(hidden_states, aftercnn_lens_actual) + + @staticmethod + def convert_hf_to_neuron_state_dict(state_dict, dtype=torch.bfloat16): + """Convert HF state dict to audio encoder format. + + Splits keys into three groups: + - frontend.*: conv1, conv2, positional_embedding (CPU) + - layers.*: transformer layers (Neuron) + - postprocessor.*: ln_post, avg_pooler, proj, audio_bos_eos_token (CPU) + + Returns dict with prefixed keys for the split architecture. + """ + new_state_dict = {} + + # LayerNorm keys that should remain float32 + ln_suffixes = ( + "self_attn_layer_norm.weight", "self_attn_layer_norm.bias", + "final_layer_norm.weight", "final_layer_norm.bias", + "ln_post.weight", "ln_post.bias", + ) + + # CPU frontend keys + frontend_prefixes = ("conv1.", "conv2.", "positional_embedding.") + # CPU postprocessor keys + postprocessor_prefixes = ("ln_post.", "proj.", "avg_pooler.", "audio_bos_eos_token.") + + for key, value in state_dict.items(): + # Strip audio_tower prefix + if key.startswith("thinker.audio_tower."): + clean_key = key[len("thinker.audio_tower."):] + elif key.startswith("audio_tower."): + clean_key = key[len("audio_tower."):] + else: + new_state_dict[key] = value + continue + + # Determine target dtype + if any(clean_key.endswith(s) for s in ln_suffixes): + target_dtype = torch.float32 + else: + target_dtype = dtype + + # Route to correct sub-module + if any(clean_key.startswith(p) for p in frontend_prefixes): + new_state_dict["frontend." + clean_key] = ( + value.clone().detach().contiguous().to(target_dtype) + ) + elif any(clean_key.startswith(p) for p in postprocessor_prefixes): + new_state_dict["postprocessor." + clean_key] = ( + value.clone().detach().contiguous().to(target_dtype) + ) + elif clean_key.startswith("layers."): + # Transformer layers (will be loaded into Neuron model) + new_state_dict["transformer." + clean_key] = ( + value.clone().detach().contiguous().to(target_dtype) + ) + else: + logger.warning("Unknown audio key: %s", clean_key) + + return new_state_dict + + @staticmethod + def from_pretrained_state_dict(audio_config, state_dict, dtype=torch.bfloat16): + """Create audio encoder and load CPU weights from converted state dict. + + Note: Transformer weights need to be loaded separately via + compile_transformer() + load_transformer() for Neuron execution. + """ + encoder = NeuronQwen25OmniAudioEncoder(audio_config, dtype=dtype) + + # Load only frontend and postprocessor weights + cpu_keys = {} + for key, value in state_dict.items(): + if key.startswith("frontend.") or key.startswith("postprocessor."): + cpu_keys[key] = value + + if cpu_keys: + missing, unexpected = encoder.load_state_dict(cpu_keys, strict=False) + # Filter out transformer keys from missing (expected) + missing = [k for k in missing if not k.startswith("transformer.")] + if missing: + logger.warning("Audio encoder CPU missing keys: %s", missing[:10]) + logger.info("Loaded %d CPU weights into audio encoder", len(cpu_keys)) + + return encoder diff --git a/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_talker.py b/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_talker.py new file mode 100644 index 00000000..32e42b8a --- /dev/null +++ b/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_talker.py @@ -0,0 +1,1057 @@ +# coding=utf-8 +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Qwen2.5-Omni Talker for NXD inference. +# +# This file contains TWO implementations: +# +# 1. NeuronQwen25OmniTalker (CPU wrapper) +# - Wraps HF's Qwen2_5OmniTalkerForConditionalGeneration +# - Runs on CPU — suitable for quick testing or when Neuron resources +# are reserved for the 7B Thinker +# +# 2. NeuronQwen25OmniTalkerForCausalLM (Neuron-compiled) +# - Compiles the 24-layer transformer on Neuron with KV cache +# - Uses fused embedding (8448→3584→896 collapsed to 8448→896) +# - Supports mRoPE (3D position_ids) and explicit head_dim=128 +# - Recommended TP=4 (3 Q heads/rank, 1 KV head/rank) +# - Uses ImageToTextModelWrapper for thinker state injection during +# context encoding (projected thinker states passed as vision_embeddings) +# +# Talker Architecture: +# - embed_tokens: Embedding(8448, 3584) — codec vocab in Thinker's dim space +# - thinker_to_talker_proj: Linear(3584 -> 896) +# - 24 Qwen2 decoder layers (GQA: 12 heads, 4 kv_heads, head_dim=128) +# - MLP: SiLU gate_proj/up_proj(896->18944), down_proj(18944->896) +# - RMSNorm(896) +# - codec_head: Linear(896 -> 8448, no bias) +# +# Total Parameters: ~690M + +"""Qwen2.5-Omni Talker model for NXD inference.""" + +import gc +import logging +from typing import List, Optional, Tuple, Type + +import torch +from torch import nn + +logger = logging.getLogger(__name__) + + +# ============================================================================= +# Part 1: CPU-based Talker (HF wrapper) +# ============================================================================= + + +class NeuronQwen25OmniTalker: + """Wrapper around HF's Qwen2_5OmniTalkerForConditionalGeneration. + + The Talker takes Thinker hidden states + codec embeddings as input, + projects them from embedding_size (3584) to hidden_size (896), runs + through 24 Qwen2 decoder layers, and outputs codec tokens via a + codec_head linear layer. + + This wrapper: + 1. Instantiates the HF Talker from config + 2. Loads weights from converted state dict + 3. Exposes generation API for codec token synthesis + """ + + def __init__(self, talker_config, dtype=torch.bfloat16): + """Initialize the Talker. + + Args: + talker_config: Talker config (dict or HF config object). + Must contain: vocab_size, embedding_size, hidden_size, + num_hidden_layers, num_attention_heads, num_key_value_heads, + intermediate_size, etc. + dtype: Model dtype (default bfloat16). + """ + from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import ( + Qwen2_5OmniTalkerConfig, + ) + from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import ( + Qwen2_5OmniTalkerForConditionalGeneration, + ) + + if isinstance(talker_config, dict): + talker_config = Qwen2_5OmniTalkerConfig(**talker_config) + + self.model = Qwen2_5OmniTalkerForConditionalGeneration(talker_config) + self.model.to(dtype) + self.model.eval() + self.config = talker_config + self.dtype = dtype + + # Expose codec token IDs for orchestration + self.codec_bos_token = talker_config.tts_codec_start_token_id + self.codec_eos_token = talker_config.tts_codec_end_token_id + self.codec_pad_token = talker_config.tts_codec_pad_token_id + self.codec_mask_token = talker_config.tts_codec_mask_token_id + self.text_bos_token = talker_config.tts_text_start_token_id + self.text_eos_token = talker_config.tts_text_end_token_id + self.text_pad_token = talker_config.tts_text_pad_token_id + + def load_state_dict(self, state_dict, strict=True): + """Load converted state dict into the HF Talker model.""" + return self.model.load_state_dict(state_dict, strict=strict) + + @torch.no_grad() + def generate( + self, + input_ids, + input_text_ids, + thinker_reply_part, + inputs_embeds, + attention_mask=None, + max_new_tokens=4096, + do_sample=True, + top_k=40, + top_p=0.8, + temperature=0.9, + eos_token_id=None, + repetition_penalty=1.05, + suppress_tokens=None, + **kwargs, + ): + """Generate codec tokens from Thinker hidden states. + + Args: + input_ids: (batch, seq_len) codec input IDs + input_text_ids: (batch, seq_len) text input IDs (for position calc) + thinker_reply_part: (batch, reply_len, 3584) Thinker hidden states + inputs_embeds: (batch, seq_len, 3584) input embeddings + attention_mask: (batch, seq_len) optional attention mask + max_new_tokens: Maximum codec tokens to generate + do_sample: Whether to sample (vs greedy) + top_k: Top-k sampling + top_p: Nucleus sampling probability + temperature: Sampling temperature + eos_token_id: EOS token(s) for stopping + repetition_penalty: Repetition penalty + suppress_tokens: Token IDs to suppress during generation + **kwargs: Additional generation kwargs + + Returns: + (batch, total_len) generated codec token IDs + """ + if eos_token_id is None: + eos_token_id = [self.codec_eos_token, self.codec_pad_token] + if suppress_tokens is None: + suppress_tokens = [self.codec_bos_token] + + return self.model.generate( + input_ids=input_ids, + input_text_ids=input_text_ids, + thinker_reply_part=thinker_reply_part, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + max_new_tokens=max_new_tokens, + do_sample=do_sample, + top_k=top_k, + top_p=top_p, + temperature=temperature, + eos_token_id=eos_token_id, + repetition_penalty=repetition_penalty, + suppress_tokens=suppress_tokens, + **kwargs, + ) + + @staticmethod + def convert_hf_to_neuron_state_dict(state_dict: dict) -> dict: + """Convert HF state dict to Talker format. + + Strips 'talker.' prefix from keys. Non-talker keys are passed through. + + Args: + state_dict: Full or partial state dict with talker.* keys. + + Returns: + State dict with talker prefix stripped. + """ + new_state_dict = {} + for key, value in state_dict.items(): + if key.startswith("talker."): + new_state_dict[key[len("talker."):]] = value + else: + # Pass through non-talker keys + new_state_dict[key] = value + return new_state_dict + + @classmethod + def from_pretrained_state_dict(cls, talker_config, state_dict, dtype=torch.bfloat16): + """Create Talker and load weights from converted state dict. + + Args: + talker_config: Talker config (dict or HF config object) + state_dict: Already-converted state dict (talker keys only) + dtype: Target dtype + + Returns: + Initialized NeuronQwen25OmniTalker + """ + talker = cls(talker_config, dtype=dtype) + + # Filter to only talker keys (skip non-talker prefixes) + talker_keys = {} + for key, value in state_dict.items(): + if any( + key.startswith(p) + for p in [ + "lm_head.", "visual.", "audio_tower.", + "thinker.", "token2wav.", "talker.", + ] + ): + continue + talker_keys[key] = value + + missing, unexpected = talker.load_state_dict(talker_keys, strict=False) + if missing: + logger.warning("Talker missing keys: %s", missing[:10]) + if unexpected: + logger.warning("Talker unexpected keys: %s", unexpected[:10]) + logger.info("Loaded %d weights into Talker", len(talker_keys)) + + return talker + + +# ============================================================================= +# Part 2: Neuron-compiled Talker +# ============================================================================= + +from neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + ParallelEmbedding, +) +from neuronx_distributed.utils import cpu_mode +from transformers.models.llama.modeling_llama import LlamaRMSNorm + +from neuronx_distributed_inference.models.config import InferenceConfig, NeuronConfig +from neuronx_distributed_inference.models.llama.modeling_llama import NeuronLlamaMLP +from neuronx_distributed_inference.models.model_base import ( + NeuronBaseForCausalLM, + NeuronBaseModel, +) +from neuronx_distributed_inference.modules.attention.attention_base import NeuronAttentionBase +from neuronx_distributed_inference.modules.attention.utils import _rotate_half +from neuronx_distributed_inference.modules.custom_calls import CustomRMSNorm + + +def _get_rmsnorm_cls(): + return LlamaRMSNorm if cpu_mode() else CustomRMSNorm + + +def _apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): + """Apply multimodal RoPE (reused from Qwen2-VL).""" + mrope_section = mrope_section * 2 + split_indices = [sum(mrope_section[:i + 1]) for i in range(len(mrope_section) - 1)] + cos = torch.cat( + [m[i % 3] for i, m in enumerate(torch.tensor_split(cos, split_indices, dim=-1))], + dim=-1, + ).unsqueeze(unsqueeze_dim) + sin = torch.cat( + [m[i % 3] for i, m in enumerate(torch.tensor_split(sin, split_indices, dim=-1))], + dim=-1, + ).unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (_rotate_half(q) * sin) + k_embed = (k * cos) + (_rotate_half(k) * sin) + return q_embed, k_embed + + +def _apply_standard_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): + """Apply standard 1D RoPE.""" + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (_rotate_half(q) * sin) + k_embed = (k * cos) + (_rotate_half(k) * sin) + return q_embed, k_embed + + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- + +class TalkerNeuronConfig(NeuronConfig): + """NeuronConfig subclass for Talker. + + Sets the default attention class to NeuronTalkerAttention. + Recommended TP=4 for the Talker (12 Q heads / 4 = 3 per rank). + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.attn_cls = NeuronTalkerAttention + + +class TalkerInferenceConfig(InferenceConfig): + """InferenceConfig for the Neuron-compiled Talker. + + Talker-specific attributes: + - head_dim = 128 (explicit, NOT hidden_size // num_attention_heads) + - qkv_bias = True, o_bias = False (Qwen2 pattern) + - rope_scaling with mrope_section for 3D mRoPE + - thinker_hidden_size = 3584 (for projection during context encoding) + """ + + def add_derived_config(self): + self.num_cores_per_group = 1 + self.qkv_bias = True + self.o_bias = False + # Head dim is EXPLICIT for the Talker + # hidden_size=896, num_attention_heads=12 → 896/12=74.67 (fractional!) + # Actual head_dim=128, so attention internal dim = 12×128 = 1536 + if not hasattr(self, "head_dim") or self.head_dim is None: + self.head_dim = 128 + # mRoPE config (default matching Qwen2.5-Omni) + if not hasattr(self, "rope_scaling") or self.rope_scaling is None: + self.rope_scaling = {"type": "mrope", "mrope_section": [16, 24, 24]} + # Store thinker hidden size for projection + if not hasattr(self, "thinker_hidden_size"): + self.thinker_hidden_size = 3584 + + def get_required_attributes(self) -> List[str]: + return [ + "hidden_size", # 896 + "num_attention_heads", # 12 + "num_hidden_layers", # 24 + "num_key_value_heads", # 4 + "vocab_size", # 8448 + "rope_theta", + "rms_norm_eps", + "hidden_act", # silu + "intermediate_size", # 18944 + ] + + @classmethod + def get_neuron_config_cls(cls) -> Type[TalkerNeuronConfig]: + return TalkerNeuronConfig + + +# --------------------------------------------------------------------------- +# RoPE +# --------------------------------------------------------------------------- + +class TalkerRotaryEmbedding(nn.Module): + """Rotary position embedding for the Talker. + + Uses head_dim (128) as the RoPE dimension, NOT hidden_size // num_heads. + Supports both standard 1D and multimodal 3D position_ids. + """ + + def __init__(self, config: TalkerInferenceConfig, device=None): + super().__init__() + self.dim = config.head_dim # 128 + self.base = getattr(config, "rope_theta", 1000000.0) + self.attention_scaling = 1.0 + self.register_buffer("inv_freq", None, persistent=False) + self.inv_freq = self._compute_inv_freq(device) + + def _compute_inv_freq(self, device=None): + freq_indices = torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) + return 1.0 / (self.base ** (freq_indices / self.dim)) + + def forward(self, x, position_ids): + if position_ids.ndim == 2: + # Expand 2D (batch, seq) → 3D (3, batch, seq) for mRoPE + # Same approach as Qwen3-VL: replicate across temporal/height/width + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + # 3D mRoPE: position_ids shape (3, batch, seq) + inv_freq_expanded = self.inv_freq[None, None, :, None].expand( + 3, position_ids.shape[1], -1, 1 + ) + position_ids_expanded = position_ids[:, :, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(-2, -1) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# --------------------------------------------------------------------------- +# Attention +# --------------------------------------------------------------------------- + +class NeuronTalkerAttention(NeuronAttentionBase): + """Talker self-attention with explicit head_dim=128 and mRoPE. + + Key difference from standard Qwen2 attention: + - head_dim=128 ≠ hidden_size(896) / num_attention_heads(12) + - Internal attention dimension = 12 × 128 = 1536 ≠ hidden_size + - Q projection: (896, 1536), K/V projection: (896, 512) + - Output projection: (1536, 896) + """ + + def __init__(self, config: TalkerInferenceConfig): + rotary_emb = TalkerRotaryEmbedding(config) + super().__init__( + config=config, + hidden_size=config.hidden_size, # 896 + num_attention_heads=config.num_attention_heads, # 12 + num_key_value_heads=config.num_key_value_heads, # 4 + head_dim=config.head_dim, # 128 + qkv_bias=config.qkv_bias, + o_bias=config.o_bias, + rotary_emb=rotary_emb, + ) + self.rope_scaling = getattr(config, "rope_scaling", None) + self.use_mrope = ( + self.rope_scaling is not None + and "mrope_section" in self.rope_scaling + ) + if self.use_mrope: + self.mrope_section = self.rope_scaling["mrope_section"] + + def apply_rotary_embedding(self, Q, K, V, position_ids, cos_cache, sin_cache, use_polar_compatible_rope): + if self.rotary_emb is not None: + if cos_cache is None or sin_cache is None: + cos_cache, sin_cache = self.rotary_emb(V, position_ids) + if self.use_mrope: + Q, K = _apply_multimodal_rotary_pos_emb( + Q, K, cos_cache, sin_cache, self.mrope_section + ) + else: + Q, K = _apply_standard_rotary_pos_emb(Q, K, cos_cache, sin_cache) + return Q, K, cos_cache, sin_cache + + +# --------------------------------------------------------------------------- +# Decoder Layer +# --------------------------------------------------------------------------- + +class NeuronTalkerDecoderLayer(nn.Module): + """Talker decoder layer: pre-norm attention + SwiGLU MLP.""" + + def __init__(self, config: TalkerInferenceConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = NeuronTalkerAttention(config) + self.mlp = NeuronLlamaMLP(config) + self.input_layernorm = _get_rmsnorm_cls()( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = _get_rmsnorm_cls()( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + adapter_ids=None, + **kwargs, + ) -> Tuple[torch.FloatTensor, ...]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, present_key_value, cos_cache, sin_cache = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + adapter_ids=adapter_ids, + **kwargs, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states, adapter_ids=adapter_ids)[0] + hidden_states = residual + hidden_states + + return (hidden_states, present_key_value, cos_cache, sin_cache, None) + + +# --------------------------------------------------------------------------- +# Model (NeuronBaseModel) +# --------------------------------------------------------------------------- + +class NeuronTalkerModel(NeuronBaseModel): + """Neuron-compiled Talker transformer. + + Uses fused embedding: original embed_tokens(8448, 3584) + proj(3584, 896) + are collapsed into embed_tokens(8448, 896) during state dict conversion. + + For context encoding with thinker hidden states, the projected states + (batch, reply_len, 896) are passed as vision_embeddings and substituted + via encode_vision_to_input(). Unlike the upstream NeuronBaseModel which + only calls encode_vision_to_input during context encoding, this subclass + overrides get_model_output to also inject during token generation (for + per-step thinker state injection), matching HF's Qwen2.5-Omni behavior. + """ + + def get_model_output( + self, + input_ids=None, + inputs_embeds=None, + vision_embeddings=None, + vision_mask=None, + is_for_context_encoding: bool = False, + adapter_ids=None, + **kwargs, + ): + """Override to inject thinker states during token generation. + + Upstream gates encode_vision_to_input behind is_for_context_encoding, + which is correct for Qwen2-VL/Qwen3-VL but wrong for Talker because + Talker needs per-step thinker state injection during autoregressive + decode. We pre-inject here, then pass inputs_embeds to super() with + vision_embeddings=None/vision_mask=None so the parent doesn't try + to inject again. + """ + if ( + not is_for_context_encoding + and vision_embeddings is not None + and vision_mask is not None + ): + # Token-generation phase: compute inputs_embeds ourselves, inject, + # then let super() skip its vision block. + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + if vision_embeddings.dtype != self.config.neuron_config.torch_dtype: + vision_embeddings = vision_embeddings.to( + self.config.neuron_config.torch_dtype + ) + inputs_embeds = self.encode_vision_to_input( + inputs_embeds, vision_embeddings, vision_mask + ) + vision_embeddings = None + vision_mask = None + + return super().get_model_output( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + vision_embeddings=vision_embeddings, + vision_mask=vision_mask, + is_for_context_encoding=is_for_context_encoding, + adapter_ids=adapter_ids, + **kwargs, + ) + + def setup_attr_for_model(self, config: TalkerInferenceConfig): + self.on_device_sampling = ( + config.neuron_config.on_device_sampling_config is not None + ) + self.tp_degree = config.neuron_config.tp_degree + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.max_batch_size = config.neuron_config.max_batch_size + self.buckets = config.neuron_config.buckets + + def init_model(self, config: TalkerInferenceConfig): + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + # Fused embedding: 8448 → 896 (original 8448→3584→896 collapsed) + self.embed_tokens = ParallelEmbedding( + config.vocab_size, # 8448 + config.hidden_size, # 896 + self.padding_idx, + dtype=config.neuron_config.torch_dtype, + shard_across_embedding=True, + pad=True, + ) + self.layers = nn.ModuleList( + [NeuronTalkerDecoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + self.norm = _get_rmsnorm_cls()(config.hidden_size, eps=config.rms_norm_eps) + # codec_head: 896 → 8448 + self.lm_head = ColumnParallelLinear( + config.hidden_size, # 896 + config.vocab_size, # 8448 + bias=False, + pad=True, + gather_output=not self.on_device_sampling, + ) + + def encode_vision_to_input(self, inputs_embeds, vision_embeddings, vision_mask): + """Inject projected thinker states into token embeddings. + + Operates in two modes: + - Context encoding (seq > 1): REPLACE placeholder embeddings with + projected thinker states. All positions get thinker states. + - Token generation (seq == 1): ADD per-step thinker reply state to + the codec token embedding. This provides text guidance at each + autoregressive step, matching HF's per-step injection behavior. + + Args: + inputs_embeds: (batch, seq, 896) from embed_tokens + vision_embeddings: (batch, seq, 896) projected thinker states + vision_mask: (batch, seq, 1) int32 mask + + Returns: + (batch, seq, 896) with thinker states injected + """ + if inputs_embeds.shape[1] > 1: + # Context encoding: REPLACE embeddings with thinker states + vision_mask_bool = vision_mask.bool() + if vision_mask_bool.dim() == 3: + vision_mask_bool = vision_mask_bool.squeeze(-1) + mask_expanded = vision_mask_bool.unsqueeze(-1).expand_as(inputs_embeds) + return torch.where(mask_expanded, vision_embeddings, inputs_embeds) + else: + # Token generation: ADD thinker state to codec token embedding + # This matches HF's behavior where thinker_reply_part[step] is + # added to embed_tokens(codec_token) at each generation step. + return inputs_embeds + vision_embeddings + + +# --------------------------------------------------------------------------- +# Model Wrapper (tracing with per-step vision_embeddings) +# --------------------------------------------------------------------------- + + +class TalkerModelWrapper: + """Mixin that overrides get_dummy_vision_inputs for per-step injection. + + Unlike the default ImageToTextModelWrapper which uses empty + vision_embeddings during token generation tracing, this provides + (batch, 1, hidden_size) tensors so the compiled NEFF includes the + ADD operation for thinker state injection at each generation step. + """ + + @staticmethod + def get_dummy_vision_inputs(config, input_ids, n_active_tokens, fill_value): + input_batch_size, input_sequence_len = input_ids.shape[0], input_ids.shape[-1] + if input_sequence_len > 1: + # Context encoding: full-sequence vision embeddings + vision_embeddings = torch.zeros( + input_batch_size, + n_active_tokens, + config.hidden_size, + dtype=config.neuron_config.torch_dtype, + ) + vision_mask = torch.full( + size=(input_batch_size, n_active_tokens, 1), + fill_value=fill_value, + dtype=torch.int32, + ) + else: + # Token generation: single-step vision embeddings for per-step + # thinker state injection (ADD to codec token embedding) + vision_embeddings = torch.zeros( + input_batch_size, + 1, + config.hidden_size, + dtype=config.neuron_config.torch_dtype, + ) + vision_mask = torch.full( + size=(input_batch_size, 1, 1), + fill_value=fill_value, + dtype=torch.int32, + ) + return vision_embeddings, vision_mask + + +# --------------------------------------------------------------------------- +# Application (NeuronBaseForCausalLM) +# --------------------------------------------------------------------------- + +class NeuronQwen25OmniTalkerForCausalLM(NeuronBaseForCausalLM): + """Neuron-compiled Talker for autoregressive codec token generation. + + Compilation: + - Recommended TP=4 (12 Q heads / 4 = 3 per rank) + - Uses ImageToTextModelWrapper for vision_embeddings support + - Context encoding: thinker states injected as vision_embeddings + - Token generation: standard autoregressive with fused embedding + + State dict conversion: + - Fuses embed_tokens(8448, 3584) + thinker_to_talker_proj(3584, 896) + into embed_tokens(8448, 896) + - Maps codec_head → lm_head + - Supports fused QKV + + Usage: + # 1. Create config + talker_config = TalkerInferenceConfig(neuron_config, load_config=hf_config) + + # 2. Create application + app = NeuronQwen25OmniTalkerForCausalLM(model_path, config=talker_config) + + # 3. Compile + app.compile(compiled_model_path) + + # 4. Load + app.load(compiled_model_path) + + # 5. Generate (token generation on Neuron) + logits = app.forward(input_ids, attention_mask, position_ids, seq_ids, ...) + """ + + _model_cls = NeuronTalkerModel + + def get_model_wrapper_cls(self): + from neuronx_distributed_inference.models.image_to_text_model_wrapper import ( + ImageToTextModelWrapper, + ) + + # Dynamically create a wrapper with per-step vision_embeddings support. + # staticmethod() preserves the descriptor when assigning to class attr. + class _TalkerImageToTextModelWrapper(ImageToTextModelWrapper): + get_dummy_vision_inputs = staticmethod( + TalkerModelWrapper.get_dummy_vision_inputs + ) + + return _TalkerImageToTextModelWrapper + + @classmethod + def get_config_cls(cls): + return TalkerInferenceConfig + + def set_vision_embeddings(self, vision_embeddings, vision_mask, + thinker_reply_embeds=None): + """Store vision embeddings for the next generate() call. + + During context encoding, projected thinker states (896-dim) are + injected as vision_embeddings. During token generation, per-step + thinker reply states are ADDED to codec token embeddings. + + Vision embeddings are padded to max_context_length to match the + compiled bucket shapes (the compiled model expects fixed-size + vision_embeddings matching the bucket, while input_ids and + attention_mask are padded by preprocess_inputs). + + Args: + vision_embeddings: (batch, seq, 896) projected thinker states + for context encoding + vision_mask: (batch, seq, 1) int32 mask (all positions active) + thinker_reply_embeds: (batch, n_reply, 896) optional per-step + thinker reply states for token generation. If provided, + reply_embeds[:, step, :] is added to the codec token + embedding at each generation step. + """ + # Pad vision_embeddings and vision_mask to max_context_length so they + # match the compiled NEFF bucket shapes. + max_ctx = self.neuron_config.max_context_length + batch, seq, dim = vision_embeddings.shape + if seq < max_ctx: + pad_ve = torch.zeros( + batch, max_ctx - seq, dim, dtype=vision_embeddings.dtype + ) + vision_embeddings = torch.cat([vision_embeddings, pad_ve], dim=1) + pad_vm = torch.zeros( + batch, max_ctx - seq, 1, dtype=vision_mask.dtype + ) + vision_mask = torch.cat([vision_mask, pad_vm], dim=1) + + self._vision_embeddings = vision_embeddings + self._vision_mask = vision_mask + self._thinker_reply_embeds = thinker_reply_embeds + self._vision_dtype = vision_embeddings.dtype + self._tkg_step = 0 + + def _get_model_outputs( + self, input_ids, attention_mask, position_ids, seq_ids, + sampling_params, prev_hidden, adapter_ids, + medusa_args=None, llava_args=None, **kwargs + ): + """Override to pass vision_embeddings to ImageToTextModelWrapper. + + Context encoding: passes full projected thinker states as + vision_embeddings (REPLACE mode in encode_vision_to_input). + + Token generation: passes per-step thinker reply state as + vision_embeddings (ADD mode in encode_vision_to_input). + + ImageToTextModelWrapper traces with 24 positional args: + 0-4: input_ids, attention_mask, position_ids, seq_ids, sampling_params + 5-20: empty placeholders (prev_hidden, adapter_ids, medusa/block args) + 21: rotary_position_ids + 22: vision_embeddings + 23: vision_mask + """ + vision_embeddings = getattr(self, '_vision_embeddings', torch.empty(0)) + vision_mask = getattr(self, '_vision_mask', torch.empty(0)) + + if self._is_prefill(position_ids): + outputs = self.context_encoding_model( + input_ids, + attention_mask, + position_ids, + seq_ids, + sampling_params, + torch.empty(0), # prev_hidden + torch.empty(0), # adapter_ids + torch.empty(0), # accepted_indices + torch.empty(0), # current_length + torch.empty(0), # medusa_mask + torch.empty(0), # scatter_index + torch.empty(0), # slot_mapping + torch.empty(0), # active_block_table + torch.empty(0), # num_queries + torch.empty(0), # computed_context_lens + torch.empty(0), # tile_q_indices + torch.empty(0), # tile_block_tables + torch.empty(0), # tile_masks + torch.empty(0), # inputs_embeds + torch.empty(0), # kv_cache + torch.empty(0), # active_mask + torch.empty(0), # rotary_position_ids + vision_embeddings, + vision_mask, + ) + self.kv_cache_populated = True + # Clear context vision (no longer needed), keep reply embeds + self._vision_embeddings = torch.empty(0) + self._vision_mask = torch.empty(0) + self._tkg_step = 0 + is_run_on_neuron = self.context_encoding_model.is_neuron() + else: + # Get per-step thinker reply state for this generation step + reply_embeds = getattr(self, '_thinker_reply_embeds', None) + dtype = getattr(self, '_vision_dtype', torch.bfloat16) + batch_size = input_ids.shape[0] + hidden_size = self.config.hidden_size + + if reply_embeds is not None and self._tkg_step < reply_embeds.shape[1]: + step_ve = reply_embeds[:, self._tkg_step:self._tkg_step + 1, :] + step_vm = torch.ones(batch_size, 1, 1, dtype=torch.int32) + self._tkg_step += 1 + elif reply_embeds is not None and reply_embeds.shape[1] > 0: + # Repeat the last reply state (matches HF behavior where + # thinker_reply_part stays at the last element when exhausted) + step_ve = reply_embeds[:, -1:, :] + step_vm = torch.ones(batch_size, 1, 1, dtype=torch.int32) + else: + step_ve = torch.zeros( + batch_size, 1, hidden_size, dtype=dtype + ) + step_vm = torch.ones(batch_size, 1, 1, dtype=torch.int32) + + outputs = self.token_generation_model( + input_ids, + attention_mask, + position_ids, + seq_ids, + sampling_params, + torch.empty(0), # prev_hidden + torch.empty(0), # adapter_ids + torch.empty(0), # accepted_indices + torch.empty(0), # current_length + torch.empty(0), # medusa_mask + torch.empty(0), # scatter_index + torch.empty(0), # slot_mapping + torch.empty(0), # active_block_table + torch.empty(0), # num_queries + torch.empty(0), # computed_context_lens + torch.empty(0), # tile_q_indices + torch.empty(0), # tile_block_tables + torch.empty(0), # tile_masks + torch.empty(0), # inputs_embeds + torch.empty(0), # kv_cache + torch.empty(0), # active_mask + torch.empty(0), # rotary_position_ids + step_ve, # vision_embeddings (per-step thinker state) + step_vm, # vision_mask + ) + is_run_on_neuron = self.token_generation_model.is_neuron() + + return outputs, is_run_on_neuron + + @staticmethod + def load_hf_model(model_path, **kwargs): + from transformers import AutoModelForCausalLM + return AutoModelForCausalLM.from_pretrained( + model_path, trust_remote_code=True, **kwargs + ) + + @staticmethod + def convert_hf_to_neuron_state_dict( + state_dict: dict, config: TalkerInferenceConfig + ) -> dict: + """Convert HF Talker state dict to Neuron format. + + Key transformations: + 1. Strip 'talker.' and 'model.' prefixes + 2. Fuse embed_tokens(8448, 3584) + thinker_to_talker_proj(3584, 896) + → embed_tokens(8448, 896) + 3. Map codec_head → lm_head + 4. Add rank utilities for distributed inference + 5. Optionally fuse QKV projections + + The original thinker_to_talker_proj weights are stored as + '_thinker_proj_weight' and '_thinker_proj_bias' in the returned + dict for use during context encoding (CPU-side projection). + """ + neuron_config = config.neuron_config + new_state_dict = {} + proj_weight = None + proj_bias = None + + for key, value in state_dict.items(): + # Strip talker. prefix + if key.startswith("talker."): + key = key[len("talker."):] + + # Route keys + if key.startswith("model."): + new_key = key[len("model."):] + elif key == "codec_head.weight": + new_key = "lm_head.weight" + elif key == "thinker_to_talker_proj.weight": + proj_weight = value + continue + elif key == "thinker_to_talker_proj.bias": + proj_bias = value + continue + else: + # Skip keys from other components + if any(key.startswith(p) for p in [ + "lm_head.", "visual.", "audio_tower.", + "thinker.", "token2wav.", + ]): + continue + new_key = key + new_state_dict[new_key] = value + + # Fuse embedding: embed(8448, 3584) @ proj.T(3584, 896) → (8448, 896) + # NOTE: projection bias is NOT included in the fused embedding. + # During token generation, the bias is already included once in the + # projected thinker reply states (proj(reply) = W @ reply + bias). + # Including it here would cause double-bias: fused(token) + proj(reply) + # = (W @ E + bias) + (W @ reply + bias) = W @ (E+reply) + 2*bias, + # whereas HF computes proj(E + reply) = W @ (E+reply) + bias. + if "embed_tokens.weight" in new_state_dict and proj_weight is not None: + embed_weight = new_state_dict["embed_tokens.weight"] # (8448, 3584) + fused_embed = embed_weight.float() @ proj_weight.float().T # (8448, 896) + new_state_dict["embed_tokens.weight"] = fused_embed.to( + neuron_config.torch_dtype + ) + logger.info( + "Fused embed_tokens (%s) + proj (%s) → (%s) (bias excluded)", + list(embed_weight.shape), list(proj_weight.shape), + list(new_state_dict["embed_tokens.weight"].shape), + ) + + # Save projection weights for CPU-side context encoding + if proj_weight is not None: + new_state_dict["_thinker_proj_weight"] = proj_weight + if proj_bias is not None: + new_state_dict["_thinker_proj_bias"] = proj_bias + + # Add rank utilities + if neuron_config.vocab_parallel: + new_state_dict["embed_tokens.rank_util.rank"] = torch.arange( + 0, neuron_config.local_ranks_size + ) + + tp_degree = neuron_config.tp_degree + for i in range(config.num_hidden_layers): + new_state_dict[f"layers.{i}.self_attn.rank_util.rank"] = torch.arange( + 0, tp_degree, dtype=torch.int32 + ) + + # Fuse QKV if enabled + if neuron_config.fused_qkv: + new_state_dict = _fuse_talker_qkv(new_state_dict, config) + + new_state_dict["rank_util.rank"] = torch.arange( + 0, tp_degree, dtype=torch.int32 + ) + + gc.collect() + return new_state_dict + + @staticmethod + def update_state_dict_for_tied_weights(state_dict): + # Talker does not tie embed_tokens and lm_head weights + pass + + def get_compiler_args(self): + return ( + "--enable-saturate-infinity --enable-mixed-precision-accumulation " + "--auto-cast=none --model-type transformer -O1 " + "--tensorizer-options='--enable-ccop-compute-overlap " + "--cc-pipeline-tiling-factor=2 --vectorize-strided-dma' " + "--internal-hlo2tensorizer-options='--verify-hlo=true'" + ) + + +def _fuse_talker_qkv(state_dict: dict, config: InferenceConfig) -> dict: + """Fuse Q/K/V weight and bias tensors into Wqkv for the Talker.""" + for layer_idx in range(config.num_hidden_layers): + for attr in ["weight", "bias"]: + q_key = f"layers.{layer_idx}.self_attn.q_proj.{attr}" + k_key = f"layers.{layer_idx}.self_attn.k_proj.{attr}" + v_key = f"layers.{layer_idx}.self_attn.v_proj.{attr}" + if all(k in state_dict for k in [q_key, k_key, v_key]): + state_dict[f"layers.{layer_idx}.self_attn.Wqkv.{attr}"] = torch.cat([ + state_dict.pop(q_key), + state_dict.pop(k_key), + state_dict.pop(v_key), + ]) + gc.collect() + return state_dict + + +# --------------------------------------------------------------------------- +# Helper: CPU-side thinker state projection +# --------------------------------------------------------------------------- + +class ThinkerToTalkerProjection(nn.Module): + """CPU-side projection from Thinker hidden space to Talker hidden space. + + During context encoding, thinker hidden states (3584-d) need to be + projected to 896-d before being injected into the Neuron model as + vision_embeddings. + + This module is loaded from the original thinker_to_talker_proj weights + that are extracted during state dict conversion. + """ + + def __init__(self, thinker_hidden_size=3584, talker_hidden_size=896): + super().__init__() + self.proj = nn.Linear(thinker_hidden_size, talker_hidden_size) + + @torch.no_grad() + def forward(self, thinker_hidden_states): + """Project thinker states for Neuron context encoding. + + Args: + thinker_hidden_states: (batch, seq_len, 3584) + + Returns: + (batch, seq_len, 896) projected states ready for vision_embeddings + """ + return self.proj(thinker_hidden_states) + + @classmethod + def from_state_dict(cls, state_dict, dtype=torch.bfloat16): + """Create projection from extracted state dict. + + Args: + state_dict: Must contain '_thinker_proj_weight' and + optionally '_thinker_proj_bias'. + + Returns: + ThinkerToTalkerProjection on CPU in specified dtype + """ + proj_weight = state_dict.get("_thinker_proj_weight") + proj_bias = state_dict.get("_thinker_proj_bias") + if proj_weight is None: + raise ValueError( + "State dict missing '_thinker_proj_weight'. " + "Run convert_hf_to_neuron_state_dict first." + ) + in_features = proj_weight.shape[1] + out_features = proj_weight.shape[0] + module = cls(in_features, out_features) + module.proj.weight.data = proj_weight.to(dtype) + if proj_bias is not None: + module.proj.bias.data = proj_bias.to(dtype) + else: + module.proj.bias = None + module.to(dtype) + module.eval() + return module diff --git a/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_token2wav.py b/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_token2wav.py new file mode 100644 index 00000000..9e37ae6c --- /dev/null +++ b/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_token2wav.py @@ -0,0 +1,765 @@ +# coding=utf-8 +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Qwen2.5-Omni Token2Wav for NXD inference. +# +# This file contains TWO implementations: +# +# 1. NeuronQwen25OmniToken2Wav (CPU wrapper) +# - Wraps HF's Qwen2_5OmniToken2WavModel entirely on CPU in float32 +# - Suitable for quick testing or when Neuron resources are limited +# +# 2. NeuronQwen25OmniToken2WavWithNeuronDiT (Neuron-accelerated) +# - Compiles the DiT (22 transformer blocks) on Neuron via torch_neuronx.trace() +# - ODE solver loop stays on CPU (inherently sequential, 10-50 steps) +# - BigVGAN vocoder stays on CPU (convolutional, ~10-20M params) +# - Speaker encoder (ECAPA-TDNN) stays on CPU (~small) +# - DiT is the compute bottleneck: 22 blocks × 10 ODE steps = 220 forward passes +# +# Architecture: +# - DiT (Diffusion Transformer): 22 blocks, dim=1024, 16 heads +# - ECAPA-TDNN speaker encoder for speaker conditioning +# - Codec embedding + RoPE + AdaLayerNorm +# - ODE sampling (Runge-Kutta 4) for mel spectrogram generation +# - BigVGAN vocoder: mel spectrogram -> waveform +# - conv_pre(80->1536) + 6 upsample stages + AMPBlock residuals +# - Snake activation, conv_post(24->1) +# +# Runs on CPU in float32 (required for ODE solver precision). +# Token2Wav has ~809 state dict keys total. + +"""Qwen2.5-Omni Token2Wav model for NXD inference.""" + +import json +import logging +import os +from typing import Optional + +import torch + +logger = logging.getLogger(__name__) + + +# ============================================================================= +# Part 1: CPU-based Token2Wav (HF wrapper) +# ============================================================================= + + +class NeuronQwen25OmniToken2Wav: + """Wrapper around HF's Qwen2_5OmniToken2WavModel. + + Token2Wav converts codec tokens into audio waveforms through: + 1. DiT model: codec tokens + speaker embedding -> mel spectrogram + (via ODE sampling with classifier-free guidance) + 2. BigVGAN vocoder: mel spectrogram -> waveform + + Speaker conditioning requires a speaker dict (spk_dict.pt) containing + per-speaker 'cond' (conditioning) and 'ref_mel' (reference mel) tensors, + plus 'bos_token' for the Talker. + + This wrapper: + 1. Instantiates the HF Token2Wav from config + 2. Loads weights from converted state dict + 3. Exposes waveform generation API + """ + + def __init__(self, token2wav_config): + """Initialize Token2Wav. + + Args: + token2wav_config: Token2Wav config (dict or HF config object). + Must contain dit_config and bigvgan_config sub-configs. + """ + from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import ( + Qwen2_5OmniToken2WavConfig, + ) + from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import ( + Qwen2_5OmniToken2WavModel, + ) + + if isinstance(token2wav_config, dict): + token2wav_config = Qwen2_5OmniToken2WavConfig(**token2wav_config) + + self.model = Qwen2_5OmniToken2WavModel(token2wav_config) + # Token2Wav must run in float32 for ODE solver precision + self.model.float() + self.model.eval() + self.config = token2wav_config + + @property + def dtype(self): + """Return dtype of the underlying HF model (for HF generate compatibility).""" + return next(self.model.parameters()).dtype + + def float(self): + """Cast underlying model to float32 (for HF generate compatibility).""" + self.model.float() + return self + + def load_state_dict(self, state_dict, strict=True): + """Load converted state dict into the HF Token2Wav model.""" + return self.model.load_state_dict(state_dict, strict=strict) + + @torch.no_grad() + def __call__( + self, + code, + conditioning, + reference_mel, + num_steps=10, + guidance_scale=0.5, + sway_coefficient=-1.0, + **kwargs, + ): + """Generate waveform from codec tokens. + + Args: + code: (batch, seq_len) codec token IDs from the Talker + conditioning: (batch, mel_len, enc_dim) speaker conditioning + (from spk_dict.pt 'cond' key) + reference_mel: (batch, mel_len, mel_dim) reference mel spectrogram + (from spk_dict.pt 'ref_mel' key) + num_steps: Number of ODE solver steps (default 10) + guidance_scale: Classifier-free guidance scale (default 0.5) + sway_coefficient: Time schedule sway (default -1.0) + **kwargs: Additional kwargs passed to Token2Wav + + Returns: + waveform: (samples,) audio waveform tensor on CPU + """ + return self.model( + code=code, + conditioning=conditioning, + reference_mel=reference_mel, + num_steps=num_steps, + guidance_scale=guidance_scale, + sway_coefficient=sway_coefficient, + **kwargs, + ) + + @staticmethod + def convert_hf_to_neuron_state_dict(state_dict: dict) -> dict: + """Convert HF state dict to Token2Wav format. + + Strips 'token2wav.' prefix from keys. Non-token2wav keys are passed through. + + Args: + state_dict: Full or partial state dict with token2wav.* keys. + + Returns: + State dict with token2wav prefix stripped. + """ + new_state_dict = {} + for key, value in state_dict.items(): + if key.startswith("token2wav."): + new_state_dict[key[len("token2wav."):]] = value + else: + new_state_dict[key] = value + return new_state_dict + + @staticmethod + def load_speaker_dict(speaker_dict_path): + """Load speaker dictionary from spk_dict.pt. + + Args: + speaker_dict_path: Path to spk_dict.pt + + Returns: + dict: Speaker name -> {cond, ref_mel, bos_token} + """ + return torch.load(speaker_dict_path, weights_only=True) + + @classmethod + def from_pretrained_state_dict(cls, token2wav_config, state_dict): + """Create Token2Wav and load weights from converted state dict. + + Args: + token2wav_config: Token2Wav config (dict or HF config object) + state_dict: Already-converted state dict (token2wav keys only) + + Returns: + Initialized NeuronQwen25OmniToken2Wav + """ + token2wav = cls(token2wav_config) + + # Filter to only token2wav keys (skip non-token2wav prefixes) + t2w_keys = {} + for key, value in state_dict.items(): + if any( + key.startswith(p) + for p in [ + "lm_head.", "visual.", "audio_tower.", + "thinker.", "talker.", "token2wav.", + ] + ): + continue + t2w_keys[key] = value + + missing, unexpected = token2wav.load_state_dict(t2w_keys, strict=False) + if missing: + logger.warning("Token2Wav missing keys: %s", missing[:10]) + if unexpected: + logger.warning("Token2Wav unexpected keys: %s", unexpected[:10]) + logger.info("Loaded %d weights into Token2Wav", len(t2w_keys)) + + return token2wav + + +# ============================================================================= +# Part 2: Neuron-accelerated Token2Wav (DiT on Neuron) +# ============================================================================= + + +def _monkeypatch_dit_attention_for_neuron(dit_module): + """Replace DiTAttention.forward with XLA-traceable version. + + Fixes two XLA-incompatible operations in DiTAttention: + 1. In-place slice assignment: query[:, :1], key[:, :1] = ... + → replaced with torch.cat-based reassembly + 2. ALL_ATTENTION_FUNCTIONS dispatch → explicit matmul attention + + Args: + dit_module: Qwen2_5OmniToken2WavDiTModel instance + """ + from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import ( + apply_rotary_pos_emb, + ) + + for block in dit_module.transformer_blocks: + attn = block.attn + + def _make_patched(a): + def forward(hidden_states, position_embeddings=None, attention_mask=None): + batch_size = hidden_states.shape[0] + query = a.to_q(hidden_states) + key = a.to_k(hidden_states) + value = a.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // a.heads + query = query.view(batch_size, -1, a.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, a.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, a.heads, head_dim).transpose(1, 2) + + # Apply RoPE to first head only (matches HF behavior) + # FIX: use torch.cat instead of in-place slice assignment + cos, sin = position_embeddings + q_rope, k_rope = apply_rotary_pos_emb( + query[:, :1], key[:, :1], cos, sin + ) + query = torch.cat([q_rope, query[:, 1:]], dim=1) + key = torch.cat([k_rope, key[:, 1:]], dim=1) + + # Explicit matmul attention (XLA-safe, no SDPA dispatch) + scale = head_dim ** -0.5 + attn_weights = torch.matmul( + query, key.transpose(-2, -1) + ) * scale + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + attn_weights = torch.nn.functional.softmax( + attn_weights, dim=-1 + ) + attn_output = torch.matmul(attn_weights, value) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape( + batch_size, -1, a.heads * head_dim + ) + attn_output = attn_output.to(query.dtype) + + attn_output = a.to_out[0](attn_output) + attn_output = a.to_out[1](attn_output) + + return attn_output + + return forward + + attn.forward = _make_patched(attn) + + +class _NeuronDiTCore(torch.nn.Module): + """Traced wrapper for DiT transformer blocks + norm_out + proj_out. + + Only the compute-heavy transformer core is compiled on Neuron. + All preprocessing (time_embed, text_embed, input_embed with ECAPA-TDNN, + rotary_embed, block_diff) stays on CPU to avoid XLA tracing issues with: + - ECAPA-TDNN Conv1d(padding="same", padding_mode="reflect") + - AttentiveStatisticsPooling (dynamic masks, masked_fill(-inf)) + - DiTCodecEmbedding (torch.repeat_interleave) + - DiTInputEmbedding (2D/3D tensor cat mismatch) + - Rotary embedding (torch.autocast context manager) + + Each block may have a different attention pattern (look_backward/look_ahead). + Three per-block masks are pre-computed on CPU: + - mask_local: look_backward=0, look_ahead=0 (most blocks) + - mask_backward: look_backward=1, look_ahead=0 (blocks 0, 20) + - mask_ahead: look_backward=0, look_ahead=1 (block 10) + """ + + def __init__(self, dit_module): + super().__init__() + self.transformer_blocks = dit_module.transformer_blocks + self.norm_out = dit_module.norm_out + self.proj_out = dit_module.proj_out + # Build per-block mask selection (Python list for static trace) + # 0 = mask_local (0,0), 1 = mask_backward (1,0), 2 = mask_ahead (0,1) + self._block_mask_idx = [] + for block in dit_module.transformer_blocks: + lb = block.look_backward_block + la = block.look_ahead_block + if lb == 0 and la == 0: + self._block_mask_idx.append(0) + elif lb >= 1 and la == 0: + self._block_mask_idx.append(1) + else: # la >= 1 + self._block_mask_idx.append(2) + + def forward(self, hidden_states, time_embedding, cos, sin, + mask_local, mask_backward, mask_ahead): + """ + Args: + hidden_states: (batch, seq_len, dim) from input_embed + time_embedding: (time_batch, dim) from time_embed (broadcasts) + cos: (batch, seq_len, head_dim) from rotary_embed + sin: (batch, seq_len, head_dim) from rotary_embed + mask_local: (batch, 1, seq_len, seq_len) float mask (0,0) + mask_backward: (batch, 1, seq_len, seq_len) float mask (1,0) + mask_ahead: (batch, 1, seq_len, seq_len) float mask (0,1) + """ + position_embeddings = (cos, sin) + masks = [mask_local, mask_backward, mask_ahead] + + for i, block in enumerate(self.transformer_blocks): + # Select per-block mask (static Python int index → traced as constant) + attention_mask = masks[self._block_mask_idx[i]] + + norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = block.attn_norm( + hidden_states, emb=time_embedding + ) + attn_output = block.attn( + hidden_states=norm, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + ) + hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_output + norm = ( + block.ff_norm(hidden_states) * (1 + scale_mlp[:, None]) + + shift_mlp[:, None] + ) + ff_output = block.ff(norm) + hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output + + hidden_states = self.norm_out(hidden_states, time_embedding) + output = self.proj_out(hidden_states) + return output + + +class NeuronQwen25OmniToken2WavWithNeuronDiT(NeuronQwen25OmniToken2Wav): + """Token2Wav with DiT transformer core compiled on Neuron. + + The DiT is the compute bottleneck in Token2Wav: + - 22 transformer blocks × 10 ODE steps = 220 forward passes per generation + - Each block: self-attention (dim=1024, 16 heads) + FFN + - Total ~85M params + + Split architecture (CPU preprocessing + Neuron transformer core): + CPU: time_embed(time_step) → time embedding + CPU: text_embed(quantized_code) → codec features + CPU: input_embed(hidden_states, speaker_embedding, condition_vector, code_embed) + → includes ECAPA-TDNN speaker encoder, CFG batch doubling + CPU: rotary_embed(hidden_states) → cos, sin + CPU: _create_block_diff(hidden_states) → attention mask + Neuron: 22 transformer blocks + norm_out + proj_out (the compute hotspot) + CPU: ODE solver loop (Runge-Kutta 4, 10 steps) + CPU: BigVGAN vocoder → waveform + + Usage: + t2w = NeuronQwen25OmniToken2WavWithNeuronDiT(token2wav_config) + t2w.load_state_dict(state_dict) + t2w.compile_dit("compiled_dit/", max_mel_len=2048) + # Subsequent runs: + t2w.load_dit("compiled_dit/") + waveform = t2w(code, conditioning, reference_mel) + """ + + def __init__(self, token2wav_config): + super().__init__(token2wav_config) + self._neuron_dit_core = None + self._dit_compiled_path = None + self._dit_max_mel_len = None + self._dit_batch_size = None + + def _get_dit_module(self): + """Extract the DiT sub-module from the HF Token2Wav model.""" + for attr_name in [ + "code2wav_dit_model", "dit", "flow_model", "transformer", + ]: + if hasattr(self.model, attr_name): + return getattr(self.model, attr_name) + for name, module in self.model.named_children(): + if "dit" in name.lower() or "flow" in name.lower(): + return module + return None + + def compile_dit( + self, + compiled_path, + max_mel_len=2048, + batch_size=2, + ): + """Compile the DiT transformer core on Neuron. + + Only the 22 transformer blocks + norm_out + proj_out are compiled. + Preprocessing (ECAPA-TDNN, codec embedding, input embedding, rotary) + stays on CPU to avoid XLA tracing issues. + + Args: + compiled_path: Directory to save compiled model + max_mel_len: Maximum mel spectrogram length (covers ~24s audio). + Shorter inputs are padded; longer inputs fall back to CPU. + batch_size: Batch size for compilation. Use 2 for standard + inference with classifier-free guidance (CFG doubles batch). + """ + try: + import torch_neuronx + except ImportError: + raise ImportError( + "torch_neuronx required for DiT compilation. " + "Run on a Neuron instance (trn1/trn2/inf2)." + ) + + os.makedirs(compiled_path, exist_ok=True) + + dit = self._get_dit_module() + if dit is None: + raise RuntimeError("Could not extract DiT module from Token2Wav.") + + # Get model dimensions + dit_cfg = getattr(dit, "config", None) + dim = getattr(dit_cfg, "dim", 1024) + num_heads = getattr(dit_cfg, "num_attention_heads", 16) + head_dim = dim // num_heads + + logger.info( + "Compiling DiT core: batch=%d, mel_len=%d, dim=%d, heads=%d", + batch_size, max_mel_len, dim, num_heads, + ) + + # Monkeypatch DiTAttention to fix in-place slice assignment + _monkeypatch_dit_attention_for_neuron(dit) + + # Create wrapper for transformer core only + core = _NeuronDiTCore(dit) + core.float() + core.eval() + + # Create example inputs + # time_embedding uses batch=1 (broadcasts to hidden_states batch) + hidden_states = torch.randn( + batch_size, max_mel_len, dim, dtype=torch.float32 + ) + time_embedding = torch.randn(1, dim, dtype=torch.float32) + cos = torch.randn( + batch_size, max_mel_len, head_dim, dtype=torch.float32 + ) + sin = torch.randn( + batch_size, max_mel_len, head_dim, dtype=torch.float32 + ) + # Three per-block attention masks (local, backward, ahead) + mask_local = torch.zeros( + batch_size, 1, max_mel_len, max_mel_len, dtype=torch.float32 + ) + mask_backward = torch.zeros( + batch_size, 1, max_mel_len, max_mel_len, dtype=torch.float32 + ) + mask_ahead = torch.zeros( + batch_size, 1, max_mel_len, max_mel_len, dtype=torch.float32 + ) + + compiled = torch_neuronx.trace( + core, + (hidden_states, time_embedding, cos, sin, + mask_local, mask_backward, mask_ahead), + compiler_args=[ + "--auto-cast=none", + "--model-type=transformer", + "-O1", + ], + ) + + save_path = os.path.join(compiled_path, "dit_core_neuron.pt") + torch.jit.save(compiled, save_path) + + # Save metadata for load + meta = { + "max_mel_len": max_mel_len, + "batch_size": batch_size, + "dim": dim, + "num_heads": num_heads, + "head_dim": head_dim, + } + with open(os.path.join(compiled_path, "dit_core_meta.json"), "w") as f: + json.dump(meta, f) + + logger.info("Compiled DiT core saved to %s", save_path) + + self._neuron_dit_core = compiled + self._dit_compiled_path = compiled_path + self._dit_max_mel_len = max_mel_len + self._dit_batch_size = batch_size + + def load_dit(self, compiled_path): + """Load a previously compiled DiT core model. + + Args: + compiled_path: Directory containing compiled model + """ + save_path = os.path.join(compiled_path, "dit_core_neuron.pt") + meta_path = os.path.join(compiled_path, "dit_core_meta.json") + + if not os.path.exists(save_path): + raise FileNotFoundError( + f"Compiled DiT core not found at {save_path}" + ) + + self._neuron_dit_core = torch.jit.load(save_path) + self._dit_compiled_path = compiled_path + + if os.path.exists(meta_path): + with open(meta_path) as f: + meta = json.load(f) + self._dit_max_mel_len = meta["max_mel_len"] + self._dit_batch_size = meta["batch_size"] + + logger.info("Loaded compiled DiT core from %s", save_path) + + def _build_attention_masks(self, block_diff, actual_mel_len, max_mel_len): + """Build three per-block float additive attention masks with padding. + + DiT blocks have three distinct attention patterns based on their + look_backward_block/look_ahead_block attributes: + - mask_local (0,0): most blocks — attend only within same block + - mask_backward (1,0): blocks 0,20 — attend current + previous block + - mask_ahead (0,1): block 10 — attend current + next block + + Args: + block_diff: (batch, heads, actual_mel_len, actual_mel_len) + from _create_block_diff. Values are block index differences. + actual_mel_len: Actual sequence length before padding + max_mel_len: Padded sequence length (compiled size) + + Returns: + Tuple of 3 masks, each (batch, 1, max_mel_len, max_mel_len) float32. + 0.0 for attend, -1e4 for don't attend. + """ + mask_batch = self._dit_batch_size or 1 + bd = block_diff[0, 0] # (actual_mel_len, actual_mel_len) + + # Three patterns: (look_backward, look_ahead) + patterns = [(0, 0), (1, 0), (0, 1)] + masks = [] + for lb, la in patterns: + bool_mask = (bd >= -float(lb)) & (bd <= float(la)) + valid = torch.where( + bool_mask, + torch.tensor(0.0, dtype=torch.float32), + torch.tensor(-1e4, dtype=torch.float32), + ) + mask = torch.full( + (mask_batch, 1, max_mel_len, max_mel_len), + -1e4, dtype=torch.float32, + ) + for b in range(mask_batch): + mask[b, 0, :actual_mel_len, :actual_mel_len] = valid + masks.append(mask) + + return masks[0], masks[1], masks[2] + + @torch.no_grad() + def __call__( + self, + code, + conditioning, + reference_mel, + num_steps=10, + guidance_scale=0.5, + sway_coefficient=-1.0, + **kwargs, + ): + """Generate waveform. DiT core runs on Neuron if compiled, else CPU. + + Monkeypatches dit.forward to split execution: + - CPU: preprocessing (time/text/input embed, rotary, block_diff) + - Neuron: 22 transformer blocks + norm + proj + - CPU: ODE solver, BigVGAN vocoder + """ + if self._neuron_dit_core is not None: + dit = self._get_dit_module() + original_forward = dit.forward + neuron_core = self._neuron_dit_core + max_mel_len = self._dit_max_mel_len + expected_batch = self._dit_batch_size + build_masks = self._build_attention_masks + + def neuron_dit_forward( + hidden_states, + condition_vector, + speaker_embedding, + quantized_code, + time_step, + drop_audio_conditioning=False, + drop_code=False, + apply_cfg=True, + ): + """DiT forward with Neuron-accelerated transformer core.""" + batch_size = hidden_states.shape[0] + if time_step.ndim == 0: + time_step = time_step.repeat(batch_size) + + # Estimate mel_len early. input_embed doubles batch for + # CFG, so we must fall back BEFORE calling it — otherwise + # original_forward would receive already-modified tensors. + est_mel_len = hidden_states.shape[1] + if est_mel_len > max_mel_len: + logger.warning( + "mel_len %d > max %d, falling back to CPU", + est_mel_len, + max_mel_len, + ) + return original_forward( + hidden_states, + condition_vector, + speaker_embedding, + quantized_code, + time_step, + drop_audio_conditioning=drop_audio_conditioning, + drop_code=drop_code, + apply_cfg=apply_cfg, + ) + + # CPU: compute embeddings (same as HF original) + time_embedding = dit.time_embed(time_step) + text_embedding = dit.text_embed( + quantized_code, + drop_code=False if apply_cfg else drop_code, + ) + text_embedding_uncond = ( + dit.text_embed(quantized_code, drop_code=True) + if apply_cfg + else None + ) + + # CPU: input embedding (ECAPA-TDNN, CFG batch doubling) + hidden_states = dit.input_embed( + hidden_states, + speaker_embedding, + condition_vector, + text_embedding, + drop_audio_cond=drop_audio_conditioning, + code_embed_uncond=text_embedding_uncond, + apply_cfg=apply_cfg, + ) + + # CPU: positional encodings + cos, sin = dit.rotary_embed(hidden_states) + block_diff = dit._create_block_diff(hidden_states) + + actual_mel_len = hidden_states.shape[1] + actual_batch = hidden_states.shape[0] + + # Build three per-block attention masks + mask_local, mask_backward, mask_ahead = build_masks( + block_diff, actual_mel_len, max_mel_len + ) + + # Pad to compiled shapes + pad_mel = max_mel_len - actual_mel_len + if pad_mel > 0: + hidden_states = torch.nn.functional.pad( + hidden_states, (0, 0, 0, pad_mel) + ) + cos = torch.nn.functional.pad(cos, (0, 0, 0, pad_mel)) + sin = torch.nn.functional.pad(sin, (0, 0, 0, pad_mel)) + + pad_batch = expected_batch - actual_batch + if pad_batch > 0: + hidden_states = torch.nn.functional.pad( + hidden_states, (0, 0, 0, 0, 0, pad_batch) + ) + cos = torch.nn.functional.pad( + cos, (0, 0, 0, 0, 0, pad_batch) + ) + sin = torch.nn.functional.pad( + sin, (0, 0, 0, 0, 0, pad_batch) + ) + + # Run transformer core on Neuron (3 per-block masks) + output = neuron_core( + hidden_states.float(), + time_embedding.float(), + cos.float(), + sin.float(), + mask_local.float(), + mask_backward.float(), + mask_ahead.float(), + ) + + # Unpad to actual sizes + output = output[:actual_batch, :actual_mel_len] + return output + + dit.forward = neuron_dit_forward + try: + result = self.model( + code=code, + conditioning=conditioning, + reference_mel=reference_mel, + num_steps=num_steps, + guidance_scale=guidance_scale, + sway_coefficient=sway_coefficient, + **kwargs, + ) + finally: + dit.forward = original_forward + return result + else: + return self.model( + code=code, + conditioning=conditioning, + reference_mel=reference_mel, + num_steps=num_steps, + guidance_scale=guidance_scale, + sway_coefficient=sway_coefficient, + **kwargs, + ) + + @classmethod + def from_pretrained_state_dict(cls, token2wav_config, state_dict): + """Create Token2Wav with Neuron DiT support.""" + token2wav = cls(token2wav_config) + + t2w_keys = {} + for key, value in state_dict.items(): + if any( + key.startswith(p) + for p in [ + "lm_head.", "visual.", "audio_tower.", + "thinker.", "talker.", "token2wav.", + ] + ): + continue + t2w_keys[key] = value + + missing, unexpected = token2wav.load_state_dict(t2w_keys, strict=False) + if missing: + logger.warning("Token2Wav missing keys: %s", missing[:10]) + if unexpected: + logger.warning("Token2Wav unexpected keys: %s", unexpected[:10]) + logger.info( + "Loaded %d weights into Token2Wav (Neuron DiT capable)", + len(t2w_keys), + ) + + return token2wav diff --git a/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_vision.py b/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_vision.py new file mode 100644 index 00000000..1aebd529 --- /dev/null +++ b/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen25_omni_vision.py @@ -0,0 +1,579 @@ +# coding=utf-8 +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Qwen2.5-Omni Vision Encoder for NXD inference. +# +# Differences from Qwen2-VL vision encoder: +# - SwiGLU MLP (gate_proj, up_proj, down_proj) instead of simple FC1/FC2 +# - RMSNorm instead of LayerNorm +# - Separate Q/K/V projections instead of fused QKV +# - intermediate_size=3420 (not TP-divisible by 32, use nn.Linear for MLP) + +"""Qwen2.5-Omni Vision Encoder for NXD inference.""" + +import logging +import os +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from safetensors.torch import save_file +from transformers.models.qwen2_vl.modeling_qwen2_vl import ( + PatchEmbed, + VisionRotaryEmbedding, +) + +from neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + RowParallelLinear, +) +from neuronx_distributed_inference.models.application_base import NeuronApplicationBase +from neuronx_distributed_inference.models.config import InferenceConfig +from neuronx_distributed_inference.models.model_wrapper import ( + EncoderModelInstance, + ModelWrapper, +) +from neuronx_distributed_inference.models.qwen2_vl.modeling_qwen2_vl_vision import ( + Qwen2VLVisionRotaryEmbedding, +) +from neuronx_distributed_inference.models.qwen2_vl.utils.vision_utils import ( + calculate_max_grid_size, + get_image_dimensions, +) +from neuronx_distributed_inference.modules.attention.attention_base import ( + NeuronAttentionBase, +) +from neuronx_distributed_inference.modules.attention.utils import apply_rotary_pos_emb +from neuronx_distributed_inference.modules.padding import ( + pad_tensor, + pad_with_first_batchline, +) + +logger = logging.getLogger(__name__) + + +class VisionRMSNorm(nn.Module): + """RMSNorm for vision encoder (replaces LayerNorm used in Qwen2-VL).""" + + def __init__(self, hidden_size, eps=1e-6, dtype=torch.bfloat16): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size, dtype=dtype)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class VisionSwiGLUMLP(nn.Module): + """SwiGLU MLP for Qwen2.5-Omni vision encoder. + + Uses regular nn.Linear instead of ColumnParallelLinear/RowParallelLinear + because intermediate_size=3420 is not divisible by common TP degrees (16, 32). + The vision model is small enough that MLP weight replication is acceptable. + """ + + def __init__(self, dim, hidden_dim, dtype=torch.bfloat16): + super().__init__() + self.gate_proj = nn.Linear(dim, hidden_dim, bias=True) + self.up_proj = nn.Linear(dim, hidden_dim, bias=True) + self.down_proj = nn.Linear(hidden_dim, dim, bias=True) + # Cast to target dtype + self.gate_proj = self.gate_proj.to(dtype) + self.up_proj = self.up_proj.to(dtype) + self.down_proj = self.down_proj.to(dtype) + + def forward(self, x): + return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class NeuronQwen25OmniVisionAttention(NeuronAttentionBase): + """Vision attention with separate Q/K/V projections (not fused). + + Requires vision_config.neuron_config.fused_qkv = False. + """ + + def __init__(self, config): + super().__init__( + config=config, + hidden_size=config.embed_dim, + num_attention_heads=config.num_heads, + num_key_value_heads=config.num_heads, + head_dim=config.embed_dim // config.num_heads, + num_cores_per_group=config.num_cores_per_group, + sequence_parallel_enabled=False, + rotary_emb=Qwen2VLVisionRotaryEmbedding(), + qkv_bias=True, + o_bias=True, + ) + + def forward(self, hidden_states, position_embeddings=None, **kwargs): + self._position_embeddings = position_embeddings + try: + return super().forward(hidden_states, **kwargs) + finally: + self._position_embeddings = None + + def apply_rotary_embedding( + self, Q, K, V, position_ids, cos_cache, sin_cache, use_polar_compatible_rope + ): + if self.rotary_emb is not None: + if cos_cache is None or sin_cache is None: + cos_cache, sin_cache = self.rotary_emb(V, self._position_embeddings) + Q, K = apply_rotary_pos_emb(Q, K, cos_cache, sin_cache) + return Q, K, cos_cache, sin_cache + + +class Qwen25OmniVisionBlock(nn.Module): + """Vision transformer block with RMSNorm and SwiGLU MLP.""" + + def __init__(self, vision_config): + super().__init__() + dtype = vision_config.neuron_config.torch_dtype + self.norm1 = VisionRMSNorm( + vision_config.embed_dim, eps=1e-6, dtype=dtype + ) + self.norm2 = VisionRMSNorm( + vision_config.embed_dim, eps=1e-6, dtype=dtype + ) + self.attn = NeuronQwen25OmniVisionAttention(vision_config) + self.mlp = VisionSwiGLUMLP( + dim=vision_config.embed_dim, + hidden_dim=vision_config.intermediate_size, + dtype=dtype, + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + attn_output = self.attn( + self.norm1(hidden_states), + position_embeddings=position_embeddings, + )[0] + hidden_states = hidden_states + attn_output + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +class Qwen25OmniPatchMerger(nn.Module): + """Patch merger with RMSNorm (Qwen2-VL uses LayerNorm). + + Merges spatial_merge_size^2 patches into one, projecting from + embed_dim to out_hidden_size (text model hidden size). + """ + + def __init__(self, dim, context_dim, spatial_merge_size=2, dtype=torch.bfloat16): + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size ** 2) + self.ln_q = VisionRMSNorm(context_dim, eps=1e-6, dtype=dtype) + self.mlp = nn.Sequential( + ColumnParallelLinear( + self.hidden_size, + self.hidden_size, + gather_output=False, + dtype=dtype, + ), + nn.GELU(), + RowParallelLinear( + self.hidden_size, + dim, + input_is_parallel=True, + dtype=dtype, + reduce_dtype=dtype, + ), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.mlp(self.ln_q(x).view(-1, self.hidden_size)) + return x + + +class NeuronQwen25OmniVisionModel(nn.Module): + """Qwen2.5-Omni Vision Encoder on Neuron. + + Architecture is based on Qwen2-VL ViT but uses RMSNorm, SwiGLU MLP, + and separate Q/K/V projections. Reuses the same PatchEmbed and + VisionRotaryEmbedding from Qwen2-VL (identical parameters). + """ + + def __init__(self, config: InferenceConfig) -> None: + super().__init__() + self.config = config + self.vision_config = config.vision_config + + self.spatial_merge_size = self.vision_config.spatial_merge_size + + # Reuse Qwen2-VL PatchEmbed (same Conv3D architecture) + self.patch_embed = PatchEmbed( + patch_size=self.vision_config.patch_size, + temporal_patch_size=self.vision_config.temporal_patch_size, + in_channels=self.vision_config.in_channels, + embed_dim=self.vision_config.embed_dim, + ).to(self.vision_config.neuron_config.torch_dtype) + + head_dim = self.vision_config.embed_dim // self.vision_config.num_heads + self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList( + [ + Qwen25OmniVisionBlock(self.vision_config) + for _ in range(self.vision_config.depth) + ] + ) + + self.merger = Qwen25OmniPatchMerger( + dim=self.vision_config.out_hidden_size, # 3584 (text hidden size) + context_dim=self.vision_config.embed_dim, # 1280 + spatial_merge_size=self.vision_config.spatial_merge_size, + dtype=self.vision_config.neuron_config.torch_dtype, + ) + + # Calculate dynamic MAX_GRID_SIZE based on configured image dimensions + image_width, image_height = get_image_dimensions( + self.vision_config.neuron_config + ) + self.max_grid_size = calculate_max_grid_size( + image_width, + image_height, + patch_size=self.vision_config.patch_size, + ) + logger.info( + f"Calculated max_grid_size={self.max_grid_size} for " + f"image dimensions {image_width}x{image_height}" + ) + + self.precomputed_rotary_pos_emb = self.rotary_pos_emb(self.max_grid_size) + self.register_buffer( + "rotary_pos_emb_cache", + self.precomputed_rotary_pos_emb, + persistent=False, + ) + + def rot_pos_ids(self, grid_thw): + """Compute rotary position IDs for patches (same algorithm as Qwen2-VL).""" + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append( + torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1) + ) + pos_ids = torch.cat(pos_ids, dim=0) + return pos_ids + + def pad_to_text_seq_len(self, hidden_states): + """Pad vision outputs to text model sequence length.""" + padded_length = self.config.neuron_config.seq_len + hidden_states = hidden_states.to( + self.config.text_config.neuron_config.torch_dtype + ) + + hidden_size = hidden_states.shape[-1] + hidden_states, _ = pad_tensor( + hidden_states, (padded_length, hidden_size), pad_value=0 + ) + + # Flatten vision outputs: (seq_len, hidden_size) -> (1, seq_len, hidden_size) + hidden_states = hidden_states.view(-1, hidden_size).unsqueeze(0) + return hidden_states + + def forward( + self, hidden_states: torch.Tensor, grid_thw: torch.Tensor + ) -> torch.Tensor: + hidden_states = self.patch_embed(hidden_states) + + assert grid_thw[:, 1:].max() < self.max_grid_size, ( + f"Grid size {grid_thw[:, 1:].max()} exceeds max_grid_size " + f"{self.max_grid_size}. Increase default_image_width/height " + f"in vision_neuron_config." + ) + pos_ids = self.rot_pos_ids(grid_thw) + rotary_pos_emb = self.rotary_pos_emb_cache[pos_ids].flatten(1) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + cos_emb = emb.cos() + sin_emb = emb.sin() + cos_emb = cos_emb.reshape(grid_thw.shape[0], -1, cos_emb.shape[-1]) + sin_emb = sin_emb.reshape(grid_thw.shape[0], -1, sin_emb.shape[-1]) + position_embeddings = (cos_emb, sin_emb) + + hidden_states = hidden_states.reshape( + grid_thw.shape[0], -1, hidden_states.shape[-1] + ) + for blk in self.blocks: + hidden_states = blk(hidden_states, position_embeddings) + hidden_states_merger = self.merger(hidden_states) + return self.pad_to_text_seq_len(hidden_states_merger) + + +class Qwen25OmniVisionModelWrapper(ModelWrapper): + """Model wrapper for Qwen2.5-Omni vision encoder. + + Handles bucketing on number of images, padding, and input generation. + Same pattern as Qwen2VLVisionModelWrapper. + """ + + def __init__( + self, + config: InferenceConfig, + model_cls, + tag="", + compiler_args: str = None, + priority_model_idx: int = None, + pipeline_execution: bool = True, + return_ranked_to_cpu: bool = False, + model_init_kwargs={}, + ) -> None: + super().__init__( + config, + model_cls, + tag, + compiler_args, + priority_model_idx, + pipeline_execution, + return_ranked_to_cpu, + model_init_kwargs, + ) + from transformers.models.qwen2_vl.image_processing_qwen2_vl import ( + smart_resize, + ) + + image_width, image_height = get_image_dimensions( + self.config.vision_config.neuron_config + ) + resized_height, resized_width = smart_resize( + width=image_width, height=image_height + ) + self.pixels_per_image = ( + resized_height // self.config.vision_config.patch_size + ) * (resized_width // self.config.vision_config.patch_size) + + def input_generator(self) -> List[Tuple[torch.Tensor]]: + from transformers.models.qwen2_vl.image_processing_qwen2_vl import ( + smart_resize, + ) + + inputs = [] + image_width, image_height = get_image_dimensions( + self.config.vision_config.neuron_config + ) + resized_height, resized_width = smart_resize( + width=image_width, height=image_height + ) + vc = self.config.vision_config + for bucket in vc.neuron_config.buckets: + pixel_values = torch.ones( + [ + bucket * self.pixels_per_image, + vc.in_channels + * vc.patch_size + * vc.patch_size + * vc.temporal_patch_size, + ], + dtype=vc.neuron_config.torch_dtype, + ) + grid_thw = torch.tensor( + [ + [ + 1, + resized_height // vc.patch_size, + resized_width // vc.patch_size, + ] + ] + ).repeat(bucket, 1) + inputs.append((pixel_values, grid_thw)) + return inputs + + def get_model_instance(self): + return EncoderModelInstance(model_cls=self.model_cls, config=self.config) + + def get_padded_num_image(self, pixel_values): + """Get the bucket size (number of images) for given pixel_values.""" + buckets = self.config.vision_config.neuron_config.buckets + for val in buckets: + if val * self.pixels_per_image >= pixel_values.shape[0]: + return val + raise Exception( + f"No bucket found for pixel_values shape {pixel_values.shape[0]}. " + f"pixels_per_image={self.pixels_per_image}, buckets={buckets}" + ) + + def forward(self, pixel_values, grid_thw): + """Override ModelWrapper.forward() with padding to bucket size.""" + if self.model is None: + raise RuntimeError( + "Forward called before load. Run load() or load_state_dict() " + "before calling forward" + ) + padded_num_image = self.get_padded_num_image(pixel_values) + padded_pixel_values = pad_with_first_batchline( + pixel_values, + (padded_num_image * self.pixels_per_image, pixel_values.shape[1]), + ) + padded_grid_thw = pad_with_first_batchline( + grid_thw, (padded_num_image, 3) + ) + output = self._forward(padded_pixel_values, padded_grid_thw) + return output + + +class NeuronQwen25OmniForImageEncoding(NeuronApplicationBase): + """Standalone Neuron Application for Qwen2.5-Omni image encoding. + + Wraps NeuronQwen25OmniVisionModel with compile/load functionality. + Can be used independently or as part of the full multimodal model. + """ + + _model_cls = NeuronQwen25OmniVisionModel + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.model_wrapper = self.get_model_wrapper_cls() + + self.model = self.model_wrapper( + config=self.config, + model_cls=self._model_cls, + tag=self._model_cls.__name__, + compiler_args=self.get_compiler_args(), + priority_model_idx=0, + ) + self.models.append(self.model) + + def get_model_wrapper_cls(self): + return Qwen25OmniVisionModelWrapper + + def forward(self, pixel_values, grid_thw): + return self.models[0](pixel_values, grid_thw) + + def get_compiler_args(self): + compiler_args = ( + "--auto-cast=none --model-type=transformer " + "--tensorizer-options='--enable-ccop-compute-overlap " + "--cc-pipeline-tiling-factor=2 ' -O1 " + "--internal-hlo2tensorizer-options='--verify-hlo=true'" + ) + logger.info( + f"Compiling {self._model_cls.__name__} vision model " + f"with args: {compiler_args}" + ) + return compiler_args + + @staticmethod + def update_state_dict_for_tied_weights(state_dict): + pass + + @staticmethod + def load_hf_model(model_path, **kwargs): + """Load the full Qwen2.5-Omni model (vision weights will be filtered + in convert_hf_to_neuron_state_dict).""" + from transformers import AutoModelForCausalLM + + return AutoModelForCausalLM.from_pretrained( + model_path, trust_remote_code=True, **kwargs + ) + + @staticmethod + def convert_hf_to_neuron_state_dict( + state_dict: dict, inference_config: InferenceConfig + ) -> dict: + """Convert HF state dict to NxDI format for vision encoder. + + Handles: + 1. Filter to vision keys only (thinker.visual.* or visual.*) + 2. Strip prefix + 3. Rename separate Q/K/V attention keys to NeuronAttentionBase format + 4. Cast to target dtype + """ + new_state_dict = {} + dtype = inference_config.vision_config.neuron_config.torch_dtype + + for key, value in state_dict.items(): + # Accept both "thinker.visual." and "visual." prefixes + if key.startswith("thinker.visual."): + new_key = key[len("thinker.visual."):] + elif key.startswith("visual."): + new_key = key[len("visual."):] + else: + # Pass through non-vision keys unchanged + new_state_dict[key] = value + continue + + # Rename attention keys: separate Q/K/V -> NeuronAttentionBase format + if ".attn.proj." in new_key: + new_key = new_key.replace(".attn.proj.", ".attn.o_proj.") + elif ".attn.q." in new_key: + new_key = new_key.replace( + ".attn.q.", ".attn.qkv_proj.q_proj." + ) + elif ".attn.k." in new_key: + new_key = new_key.replace( + ".attn.k.", ".attn.qkv_proj.k_proj." + ) + elif ".attn.v." in new_key: + new_key = new_key.replace( + ".attn.v.", ".attn.qkv_proj.v_proj." + ) + + new_state_dict[new_key] = ( + value.clone().detach().contiguous().to(dtype) + ) + + del state_dict + return new_state_dict + + @classmethod + def get_config_cls(cls): + from modeling_qwen25_omni import ( + Qwen25OmniMultimodalInferenceConfig, + ) + + return Qwen25OmniMultimodalInferenceConfig + + @classmethod + def prepare_input_args(cls, prompts, images, processor, role="user", config=None): + """Prepare input arguments for Qwen2.5-Omni vision model.""" + from neuronx_distributed_inference.models.qwen2_vl.utils.input_processor import ( + prepare_generation_inputs_hf, + ) + + if len(prompts) > 1: + raise NotImplementedError( + "Qwen2.5-Omni currently only supports batch size 1" + ) + if isinstance(prompts, list): + prompts = prompts[0] + if images and isinstance(images, list) and isinstance(images[0], list): + images = images[0] + inputs = prepare_generation_inputs_hf( + prompts, images, processor, role, config + ) + vision_inputs = None + if hasattr(inputs, "pixel_values") and hasattr(inputs, "image_grid_thw"): + vision_inputs = { + "pixel_values": inputs.pixel_values, + "image_grid_thw": inputs.image_grid_thw, + } + return inputs.input_ids, inputs.attention_mask, vision_inputs diff --git a/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen2_5_omni.py b/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen2_5_omni.py deleted file mode 100644 index 89e929cd..00000000 --- a/contrib/models/Qwen2.5-Omni-7B/src/modeling_qwen2_5_omni.py +++ /dev/null @@ -1,620 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The Qwen team, Alibaba Group and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -PyTorch Qwen2.5-Omni model for NXD inference (Text-only version) - -This implementation ports the text model (Thinker) from Qwen2.5-Omni to NeuronX Distributed Inference. -It focuses on text-only inference, ignoring multimodal (audio/vision) components. - -Based on: -- Reference: NeuronxDistributedInference/src/neuronx_distributed_inference/models/qwen2/modeling_qwen2.py -""" -import json -import os -from typing import List, Optional, Tuple, Type - -import torch -from neuronx_distributed.parallel_layers.layers import ( - ColumnParallelLinear, - ParallelEmbedding, - RowParallelLinear, -) -from neuronx_distributed.utils import cpu_mode -from torch import nn -from transformers.models.llama.modeling_llama import LlamaRMSNorm - -from neuronx_distributed_inference.models.config import InferenceConfig, NeuronConfig -from neuronx_distributed_inference.models.model_base import ( - NeuronBaseForCausalLM, - NeuronBaseModel, -) -from neuronx_distributed_inference.modules.attention.attention_base import NeuronAttentionBase -from neuronx_distributed_inference.modules.attention.utils import RotaryEmbedding -from neuronx_distributed_inference.modules.custom_calls import CustomRMSNorm - - -def get_rmsnorm_cls(): - """ - Initialize to the appropriate implementation of RMSNorm - If infer on NXD -> CustomRMSNorm - If infer on CPU -> HF_RMSNorm (CustomRMSNorm does not work on CPU) - """ - return LlamaRMSNorm if cpu_mode() else CustomRMSNorm - - -class Qwen2_5OmniNeuronConfig(NeuronConfig): - """NeuronConfig for Qwen2.5-Omni model""" - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.attn_cls = NeuronQwen2_5OmniAttention - - -class Qwen2_5OmniInferenceConfig(InferenceConfig): - """ - Configuration class for Qwen2.5-Omni inference on NeuronX. - - This config handles the text model (Thinker) from Qwen2.5-Omni. - The thinker_config.text_config contains the core text model parameters. - """ - - def add_derived_config(self): - """Add derived configuration parameters""" - self.num_cores_per_group = 1 - self.qkv_bias = True # Qwen2.5-Omni has bias in Q/K/V projections - self.o_bias = False # No bias in output projection - - # Handle layer types for sliding window attention - # Default to all full attention if not specified - if not hasattr(self, 'layer_types') or self.layer_types is None: - self.layer_types = ['full_attention'] * self.num_hidden_layers - - # Multimodal RoPE section for 3D position embeddings - # [temporal, height, width] sections - for text-only, all positions are same - if not hasattr(self, 'mrope_section'): - self.mrope_section = [16, 24, 24] # Default from config - - # Add standard HuggingFace config attributes required by NeuronBaseModel - if not hasattr(self, 'output_attentions'): - self.output_attentions = False - if not hasattr(self, 'output_hidden_states'): - self.output_hidden_states = False - if not hasattr(self, 'use_return_dict'): - self.use_return_dict = True - if not hasattr(self, 'use_cache'): - self.use_cache = True - - def get_required_attributes(self) -> List[str]: - """List of required attributes for the configuration""" - return [ - "hidden_size", - "num_attention_heads", - "num_hidden_layers", - "num_key_value_heads", - "pad_token_id", - "vocab_size", - "max_position_embeddings", - "rope_theta", - "rms_norm_eps", - "hidden_act", - "intermediate_size", - ] - - @classmethod - def get_neuron_config_cls(cls) -> Type[Qwen2_5OmniNeuronConfig]: - """Return the NeuronConfig class to use""" - return Qwen2_5OmniNeuronConfig - - @classmethod - def from_pretrained(cls, model_path: str, **kwargs) -> "Qwen2_5OmniInferenceConfig": - """ - Load configuration from a pretrained Qwen2.5-Omni model directory. - - The Qwen2.5-Omni config has a nested structure: - config.json -> thinker_config -> text_config (the actual text model config) - - Args: - model_path: Path to the model directory containing config.json - **kwargs: Additional arguments to override configuration - - Returns: - Qwen2_5OmniInferenceConfig: Configuration object - """ - # Extract neuron_config from kwargs if it exists - neuron_config = kwargs.pop("neuron_config", None) - - # Try loading saved neuron config if not provided - # Try multiple possible locations - if neuron_config is None: - possible_paths = [ - os.path.join(model_path, "neuron_config.json"), - "neuron_config.json", # Current directory - ] - - for neuron_config_path in possible_paths: - if os.path.exists(neuron_config_path): - print(f"Loading neuron_config from: {neuron_config_path}") - with open(neuron_config_path, "r") as f: - neuron_config_data = json.load(f) - # The saved config has the neuron_config nested - if "neuron_config" in neuron_config_data: - neuron_config_dict = neuron_config_data["neuron_config"] - else: - neuron_config_dict = neuron_config_data - neuron_config = cls.get_neuron_config_cls()(**neuron_config_dict) - break - - # Read the full config.json - config_path = os.path.join(model_path, "config.json") - with open(config_path, "r") as f: - full_config = json.load(f) - - # Navigate to the text model config - # Path: config.json -> thinker_config -> text_config - thinker_config = full_config.get("thinker_config", {}) - text_config = thinker_config.get("text_config", {}) - - if not text_config: - raise ValueError( - f"Could not find text_config in {config_path}. " - "Expected structure: config.json -> thinker_config -> text_config" - ) - - # Extract configuration parameters from text_config - config_dict = { - "hidden_size": text_config.get("hidden_size"), - "num_attention_heads": text_config.get("num_attention_heads"), - "num_hidden_layers": text_config.get("num_hidden_layers"), - "num_key_value_heads": text_config.get("num_key_value_heads"), - "vocab_size": text_config.get("vocab_size"), - "max_position_embeddings": text_config.get("max_position_embeddings"), - "intermediate_size": text_config.get("intermediate_size"), - "rms_norm_eps": text_config.get("rms_norm_eps"), - "rope_theta": text_config.get("rope_theta"), - "hidden_act": text_config.get("hidden_act"), - "sliding_window": text_config.get("sliding_window"), - "use_sliding_window": text_config.get("use_sliding_window", False), - } - - # Extract pad_token_id from thinker_config (not in text_config) - config_dict["pad_token_id"] = thinker_config.get("pad_token_id") - - # Extract rope_scaling if present - if "rope_scaling" in text_config and text_config["rope_scaling"]: - rope_scaling = text_config["rope_scaling"] - config_dict["rope_scaling"] = rope_scaling - # Extract mrope_section for multimodal RoPE - config_dict["mrope_section"] = rope_scaling.get("mrope_section", [16, 24, 24]) - - # Handle layer_types for sliding window attention - # Qwen2.5-Omni alternates between full and sliding attention - num_layers = config_dict["num_hidden_layers"] - if config_dict.get("use_sliding_window"): - # Alternate between full and sliding attention - config_dict["layer_types"] = ["sliding_attention" if i % 2 else "full_attention" - for i in range(num_layers)] - else: - # All layers use full attention - config_dict["layer_types"] = ["full_attention"] * num_layers - - # Override with kwargs - config_dict.update(kwargs) - - # Create and return config - config = cls(neuron_config=neuron_config, **config_dict) - return config - - -class NeuronQwen2_5OmniAttention(NeuronAttentionBase): - """ - Qwen2.5-Omni attention mechanism for NeuronX. - - Based on NeuronQwen2Attention but with multimodal RoPE support. - The multimodal RoPE is handled at the model level, so this class - uses standard NeuronAttentionBase with bias configurations. - - Reference: - - HF: Qwen2_5OmniAttention in modeling_qwen2_5_omni.py - - NXD: NeuronQwen2Attention in modeling_qwen2.py - """ - - def __init__(self, config: Qwen2_5OmniInferenceConfig, layer_idx: int = 0): - """ - Initialize Qwen2.5-Omni attention. - - Args: - config: Model configuration - layer_idx: Layer index (used for sliding window) - """ - self.layer_idx = layer_idx - - # Create rotary embedding - rotary_emb = RotaryEmbedding( - config.hidden_size // config.num_attention_heads, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - ) - - # Determine if this layer uses sliding window - sliding_window = None - if hasattr(config, 'layer_types') and config.layer_types: - if config.layer_types[layer_idx] == "sliding_attention": - sliding_window = getattr(config, 'sliding_window', None) - - # Initialize base attention - super().__init__( - config=config, - hidden_size=config.hidden_size, - num_attention_heads=config.num_attention_heads, - num_key_value_heads=config.num_key_value_heads, - head_dim=config.hidden_size // config.num_attention_heads, - qkv_bias=config.qkv_bias, # Qwen2.5-Omni has bias in QKV - o_bias=config.o_bias, # No bias in output - rotary_emb=rotary_emb, - sliding_window=sliding_window, - ) - - -class NeuronQwen2_5OmniMLP(nn.Module): - """ - Qwen2.5-Omni MLP layer (same as Qwen2/Llama - SwiGLU activation). - - Architecture: - - gate_proj: Linear(hidden_size, intermediate_size) - - up_proj: Linear(hidden_size, intermediate_size) - - down_proj: Linear(intermediate_size, hidden_size) - - activation: SwiGLU = silu(gate_proj(x)) * up_proj(x) - - Reference: - - HF: Qwen2MLP in modeling_qwen2_5_omni.py - - NXD: NeuronLlamaMLP in modeling_llama.py - """ - - def __init__(self, config: Qwen2_5OmniInferenceConfig): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - - # Gate projection (for SwiGLU) - self.gate_proj = ColumnParallelLinear( - config.hidden_size, - config.intermediate_size, - bias=False, - gather_output=False, - dtype=config.neuron_config.torch_dtype, - ) - - # Up projection (for SwiGLU) - self.up_proj = ColumnParallelLinear( - config.hidden_size, - config.intermediate_size, - bias=False, - gather_output=False, - dtype=config.neuron_config.torch_dtype, - ) - - # Down projection - self.down_proj = RowParallelLinear( - config.intermediate_size, - config.hidden_size, - bias=False, - input_is_parallel=True, - dtype=config.neuron_config.torch_dtype, - ) - - # Activation function (SiLU for SwiGLU) - self.act_fn = nn.SiLU() - - def forward(self, x): - """ - Forward pass with SwiGLU activation. - - Args: - x: Input tensor [batch, seq_len, hidden_size] - - Returns: - Tuple of (output, None) for compatibility with NeuronBaseModel - """ - # SwiGLU: silu(gate_proj(x)) * up_proj(x) - gate_output = self.act_fn(self.gate_proj(x)) - up_output = self.up_proj(x) - intermediate_output = gate_output * up_output - - # Apply down projection - output = self.down_proj(intermediate_output) - - return output, None # Return None as second output for compatibility - - -class NeuronQwen2_5OmniDecoderLayer(nn.Module): - """ - Qwen2.5-Omni decoder layer for NeuronX. - - Architecture (pre-norm): - 1. hidden = input + self_attn(norm(input)) - 2. output = hidden + mlp(norm(hidden)) - - Reference: - - HF: Qwen2_5OmniDecoderLayer in modeling_qwen2_5_omni.py - - NXD: NeuronQwen2DecoderLayer in modeling_qwen2.py - """ - - def __init__(self, config: Qwen2_5OmniInferenceConfig, layer_idx: int): - super().__init__() - self.hidden_size = config.hidden_size - self.layer_idx = layer_idx - - # Self-attention with layer-specific sliding window - self.self_attn = NeuronQwen2_5OmniAttention(config, layer_idx=layer_idx) - - # MLP (SwiGLU) - self.mlp = NeuronQwen2_5OmniMLP(config) - - # Layer normalization (RMSNorm) - self.input_layernorm = get_rmsnorm_cls()( - config.hidden_size, - eps=config.rms_norm_eps, - ) - self.post_attention_layernorm = get_rmsnorm_cls()( - config.hidden_size, - eps=config.rms_norm_eps, - ) - - # Store attention type for this layer - self.attention_type = config.layer_types[layer_idx] if hasattr(config, 'layer_types') else 'full_attention' - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Forward pass for decoder layer. - - Args: - hidden_states: Input tensor [batch, seq_len, hidden_size] - attention_mask: Attention mask - position_ids: Position indices - past_key_value: Cached key-value pairs - - Returns: - Tuple of (hidden_states, present_key_value, cos_cache, sin_cache, None) - """ - # Pre-norm: normalize before attention - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - - # Self-attention - hidden_states, present_key_value, cos_cache, sin_cache = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - **kwargs, - ) - - # Residual connection - hidden_states = residual + hidden_states - - # Pre-norm: normalize before MLP - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - - # MLP - hidden_states = self.mlp(hidden_states)[0] # Take first element of tuple - - # Residual connection - hidden_states = residual + hidden_states - - outputs = (hidden_states, present_key_value, cos_cache, sin_cache, None) - - return outputs - - -class NeuronQwen2_5OmniModel(NeuronBaseModel): - """ - Qwen2.5-Omni text model for NeuronX inference. - - This implements the core text model (Thinker) from Qwen2.5-Omni, - focusing on text-only inference without multimodal components. - - Architecture: - - Token embeddings - - Stack of decoder layers with GQA and SwiGLU - - RMSNorm - - LM head for token generation - - Reference: - - HF: Qwen2_5OmniThinkerTextModel in modeling_qwen2_5_omni.py - - NXD: NeuronQwen2Model in modeling_qwen2.py - """ - - def setup_attr_for_model(self, config: Qwen2_5OmniInferenceConfig): - """Setup attributes for model initialization""" - self.on_device_sampling = config.neuron_config.on_device_sampling_config is not None - self.tp_degree = config.neuron_config.tp_degree - self.hidden_size = config.hidden_size - self.num_attention_heads = config.num_attention_heads - self.num_key_value_heads = config.num_key_value_heads - self.max_batch_size = config.neuron_config.max_batch_size - self.buckets = config.neuron_config.buckets - - def init_model(self, config: Qwen2_5OmniInferenceConfig): - """Initialize the model components""" - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - # Token embeddings - self.embed_tokens = ParallelEmbedding( - config.vocab_size, - config.hidden_size, - self.padding_idx, - dtype=config.neuron_config.torch_dtype, - shard_across_embedding=True, - pad=True, - ) - - # Decoder layers - self.layers = nn.ModuleList( - [NeuronQwen2_5OmniDecoderLayer(config, layer_idx) - for layer_idx in range(config.num_hidden_layers)] - ) - - # Final normalization - self.norm = get_rmsnorm_cls()(config.hidden_size, eps=config.rms_norm_eps) - - # LM head for token generation - self.lm_head = ColumnParallelLinear( - config.hidden_size, - config.vocab_size, - bias=False, - pad=True, - gather_output=not self.on_device_sampling, - dtype=config.neuron_config.torch_dtype, - ) - - -class NeuronQwen2_5OmniForCausalLM(NeuronBaseForCausalLM): - """ - Qwen2.5-Omni for Causal Language Modeling on NeuronX. - - This is the main entry point for using Qwen2.5-Omni on NeuronX. - It provides a HuggingFace-compatible interface for text generation. - - Usage: - config = Qwen2_5OmniInferenceConfig.from_pretrained(model_path, neuron_config=neuron_config) - model = NeuronQwen2_5OmniForCausalLM(config) - model.load_state_dict(state_dict) - model.compile_model() - model.save_pretrained(output_path) - - Reference: - - HF: Qwen2_5OmniThinkerForConditionalGeneration in modeling_qwen2_5_omni.py - - NXD: NeuronQwen2ForCausalLM in modeling_qwen2.py - """ - - _model_cls = NeuronQwen2_5OmniModel - - @staticmethod - def convert_hf_to_neuron_state_dict(state_dict, config): - """ - Convert HuggingFace Qwen2.5-Omni state dict to NeuronX format. - - The Qwen2.5-Omni checkpoint has a nested structure with the text model under - 'model.thinker.model' prefix. This function extracts and renames the weights. - - NeuronAttentionBase expects QKV weights in the format: - - layers.{i}.self_attn.qkv_proj.q_proj.weight (for separate Q/K/V) - - layers.{i}.self_attn.qkv_proj.q_proj.bias - - Weight mappings: - - thinker.model.embed_tokens.weight -> embed_tokens.weight - - thinker.model.layers.{i}.self_attn.q_proj.* -> layers.{i}.self_attn.qkv_proj.q_proj.* - - thinker.model.layers.{i}.self_attn.k_proj.* -> layers.{i}.self_attn.qkv_proj.k_proj.* - - thinker.model.layers.{i}.self_attn.v_proj.* -> layers.{i}.self_attn.qkv_proj.v_proj.* - - thinker.model.layers.{i}.self_attn.o_proj.* -> layers.{i}.self_attn.o_proj.* - - thinker.model.layers.{i}.mlp.* -> layers.{i}.mlp.* - - thinker.model.norm.weight -> norm.weight - - thinker.lm_head.weight -> lm_head.weight - - Args: - state_dict: HuggingFace state dictionary - config: Model configuration - - Returns: - Dictionary with NeuronX-formatted weights - """ - import torch - - neuron_state_dict = {} - - # Remove prefixes: either "model.thinker.model." or just "thinker.model." - # The actual prefix might vary - possible_prefixes = [ - "model.thinker.model.", - "thinker.model.", - "model.thinker.", - "thinker.", - "" - ] - - # Detect which prefix is used - actual_prefix = "" - for prefix in possible_prefixes: - test_key = f"{prefix}embed_tokens.weight" - if test_key in state_dict: - actual_prefix = prefix - break - - print(f"Detected HF weight prefix: '{actual_prefix}'") - - for name, param in state_dict.items(): - # Skip weights not belonging to the text model (thinker) - if not name.startswith(actual_prefix) and actual_prefix != "": - # Check if this is the lm_head - lm_head_patterns = ["model.thinker.lm_head.", "thinker.lm_head.", "lm_head."] - is_lm_head = any(name.startswith(p) for p in lm_head_patterns) - if not is_lm_head: - continue - - # Remove the prefix - if actual_prefix and name.startswith(actual_prefix): - new_name = name[len(actual_prefix):] - else: - # Handle lm_head separately - for lm_prefix in ["model.thinker.lm_head.", "thinker.lm_head.", "lm_head."]: - if name.startswith(lm_prefix): - new_name = name[len(lm_prefix):] - new_name = "lm_head." + new_name - break - else: - new_name = name - - # Map attention weights to qkv_proj structure - if ".self_attn.q_proj." in new_name: - new_name = new_name.replace(".self_attn.q_proj.", ".self_attn.qkv_proj.q_proj.") - elif ".self_attn.k_proj." in new_name: - new_name = new_name.replace(".self_attn.k_proj.", ".self_attn.qkv_proj.k_proj.") - elif ".self_attn.v_proj." in new_name: - new_name = new_name.replace(".self_attn.v_proj.", ".self_attn.qkv_proj.v_proj.") - - # Clone and store the parameter - neuron_state_dict[new_name] = param.clone() - - print(f"Converted {len(state_dict)} HF weights to {len(neuron_state_dict)} Neuron weights") - - # Verify key weights exist - required_keys = ["embed_tokens.weight", "norm.weight", "lm_head.weight"] - for key in required_keys: - if key not in neuron_state_dict: - print(f"⚠️ Warning: Required key '{key}' not found in converted state dict") - - # Verify layer 0 attention weights - layer0_attn_keys = [ - "layers.0.self_attn.qkv_proj.q_proj.weight", - "layers.0.self_attn.qkv_proj.k_proj.weight", - "layers.0.self_attn.qkv_proj.v_proj.weight", - "layers.0.self_attn.o_proj.weight" - ] - for key in layer0_attn_keys: - if key not in neuron_state_dict: - print(f"⚠️ Warning: Layer 0 attention key '{key}' not found") - - return neuron_state_dict diff --git a/contrib/models/Qwen2.5-Omni-7B/test/integration/test_e2e_qwen25_omni.py b/contrib/models/Qwen2.5-Omni-7B/test/integration/test_e2e_qwen25_omni.py new file mode 100644 index 00000000..835f8313 --- /dev/null +++ b/contrib/models/Qwen2.5-Omni-7B/test/integration/test_e2e_qwen25_omni.py @@ -0,0 +1,424 @@ +#!/usr/bin/env python3 +"""End-to-end test for Qwen2.5-Omni on Trn2 (CPU inference). + +Tests multimodal input (text, image, audio) → text and audio output +using HF Qwen2_5OmniForConditionalGeneration directly. + +Usage: + source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + python3 test_e2e_qwen25_omni.py [--model-path MODEL_PATH] [--output-dir OUTPUT_DIR] +""" + +import argparse +import gc +import json +import logging +import os +import sys +import time +import traceback + +import numpy as np +import torch + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") +logger = logging.getLogger(__name__) + +DEFAULT_MODEL = "Qwen/Qwen2.5-Omni-7B" +DEFAULT_OUTPUT = "/home/ubuntu/e2e_test_results" + +# Qwen2.5-Omni requires this specific system prompt for audio output +SYSTEM_PROMPT = ( + "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, " + "capable of perceiving auditory and visual inputs, as well as generating text and speech." +) + + +def download_test_assets(output_dir): + """Create sample image and test audio.""" + from PIL import Image, ImageDraw + import wave + + assets_dir = os.path.join(output_dir, "assets") + os.makedirs(assets_dir, exist_ok=True) + + image_path = os.path.join(assets_dir, "test_image.jpg") + if not os.path.exists(image_path): + logger.info("Creating test image...") + img = Image.new("RGB", (320, 240), "white") + draw = ImageDraw.Draw(img) + draw.rectangle([20, 20, 100, 100], fill="red", outline="black") + draw.ellipse([130, 30, 230, 130], fill="blue", outline="black") + draw.polygon([(260, 120), (310, 20), (210, 20)], fill="green", outline="black") + draw.rectangle([0, 160, 320, 240], fill="lightgreen") + draw.ellipse([220, 60, 300, 140], fill="yellow") + img.save(image_path) + logger.info("Saved %s", image_path) + + audio_path = os.path.join(assets_dir, "test_audio.wav") + if not os.path.exists(audio_path): + logger.info("Creating test audio...") + sr = 16000 + t = np.linspace(0, 2.0, sr * 2, endpoint=False) + audio = (0.5 * np.sin(2 * np.pi * 440 * t) * 32767).astype(np.int16) + with wave.open(audio_path, "w") as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(sr) + wf.writeframes(audio.tobytes()) + logger.info("Saved %s", audio_path) + + return image_path, audio_path + + +def load_model(model_path): + """Load Qwen2.5-Omni model and processor.""" + from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor + + logger.info("Loading processor from %s ...", model_path) + processor = Qwen2_5OmniProcessor.from_pretrained(model_path) + + logger.info("Loading model from %s (this may take a few minutes)...", model_path) + start = time.time() + model = Qwen2_5OmniForConditionalGeneration.from_pretrained( + model_path, + dtype=torch.bfloat16, + device_map="cpu", + ) + model.eval() + elapsed = time.time() - start + logger.info("Model loaded in %.1fs", elapsed) + + return model, processor + + +# ============================================================================ +# Test 1: Text → Text +# ============================================================================ + +def test_text_only(model, processor, output_dir): + logger.info("=" * 60) + logger.info("TEST 1: Text → Text") + logger.info("=" * 60) + + prompt = "What is the capital of France? Answer in one sentence." + + try: + messages = [ + {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]}, + {"role": "user", "content": [{"type": "text", "text": prompt}]}, + ] + + text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + inputs = processor(text=text, return_tensors="pt") + + start = time.time() + with torch.no_grad(): + output_ids = model.generate( + **inputs, + return_audio=False, + thinker_max_new_tokens=100, + ) + elapsed = time.time() - start + + response = processor.batch_decode( + output_ids[:, inputs["input_ids"].shape[1]:], + skip_special_tokens=True, + )[0] + + path = os.path.join(output_dir, "test1_text_response.txt") + with open(path, "w") as f: + f.write(f"Prompt: {prompt}\n\nResponse: {response}\n\nTime: {elapsed:.1f}s\n") + + logger.info("Response: %s", response) + logger.info("Time: %.1fs", elapsed) + return True, response, elapsed + + except Exception as e: + logger.error("Test 1 FAILED: %s", e) + traceback.print_exc() + return False, str(e), 0 + + +# ============================================================================ +# Test 2: Image + Text → Text +# ============================================================================ + +def test_image_text(model, processor, output_dir, image_path): + logger.info("=" * 60) + logger.info("TEST 2: Image + Text → Text") + logger.info("=" * 60) + + prompt = "Describe this image in detail. What shapes and colors do you see?" + + try: + from PIL import Image + image = Image.open(image_path).convert("RGB") + + messages = [ + {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]}, + { + "role": "user", + "content": [ + {"type": "image", "image": image}, + {"type": "text", "text": prompt}, + ], + }, + ] + + text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + inputs = processor(text=text, images=[image], return_tensors="pt") + + start = time.time() + with torch.no_grad(): + output_ids = model.generate( + **inputs, + return_audio=False, + thinker_max_new_tokens=200, + ) + elapsed = time.time() - start + + response = processor.batch_decode( + output_ids[:, inputs["input_ids"].shape[1]:], + skip_special_tokens=True, + )[0] + + path = os.path.join(output_dir, "test2_image_response.txt") + with open(path, "w") as f: + f.write(f"Prompt: {prompt}\nImage: {image_path}\n\nResponse: {response}\n\nTime: {elapsed:.1f}s\n") + + logger.info("Response: %s", response[:300]) + logger.info("Time: %.1fs", elapsed) + return True, response, elapsed + + except Exception as e: + logger.error("Test 2 FAILED: %s", e) + traceback.print_exc() + return False, str(e), 0 + + +# ============================================================================ +# Test 3: Audio + Text → Text +# ============================================================================ + +def test_audio_text(model, processor, output_dir, audio_path): + logger.info("=" * 60) + logger.info("TEST 3: Audio + Text → Text") + logger.info("=" * 60) + + prompt = "What do you hear in this audio? Describe it." + + try: + import wave + + with wave.open(audio_path, "r") as wf: + sr = wf.getframerate() + frames = wf.readframes(wf.getnframes()) + audio_data = np.frombuffer(frames, dtype=np.int16).astype(np.float32) / 32768.0 + + messages = [ + {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]}, + { + "role": "user", + "content": [ + {"type": "audio", "audio": audio_data, "sampling_rate": sr}, + {"type": "text", "text": prompt}, + ], + }, + ] + + text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + inputs = processor( + text=text, + audios=[audio_data], + sampling_rate=sr, + return_tensors="pt", + ) + + start = time.time() + with torch.no_grad(): + output_ids = model.generate( + **inputs, + return_audio=False, + thinker_max_new_tokens=200, + ) + elapsed = time.time() - start + + response = processor.batch_decode( + output_ids[:, inputs["input_ids"].shape[1]:], + skip_special_tokens=True, + )[0] + + path = os.path.join(output_dir, "test3_audio_response.txt") + with open(path, "w") as f: + f.write(f"Prompt: {prompt}\nAudio: {audio_path}\n\nResponse: {response}\n\nTime: {elapsed:.1f}s\n") + + logger.info("Response: %s", response[:300]) + logger.info("Time: %.1fs", elapsed) + return True, response, elapsed + + except Exception as e: + logger.error("Test 3 FAILED: %s", e) + traceback.print_exc() + return False, str(e), 0 + + +# ============================================================================ +# Test 4: Text → Speech +# ============================================================================ + +def test_speech_output(model, processor, output_dir): + logger.info("=" * 60) + logger.info("TEST 4: Text → Speech") + logger.info("=" * 60) + + prompt = "Say hello and tell me what the weather is like today." + + try: + messages = [ + {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]}, + {"role": "user", "content": [{"type": "text", "text": prompt}]}, + ] + + text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + inputs = processor(text=text, return_tensors="pt") + + start = time.time() + with torch.no_grad(): + result = model.generate( + **inputs, + return_audio=True, + thinker_max_new_tokens=200, + talker_max_new_tokens=2000, + speaker="Chelsie", + ) + elapsed = time.time() - start + + # result is (text_ids, audio_waveform) when return_audio=True + text_response = "" + audio_waveform = None + + if isinstance(result, tuple) and len(result) >= 2: + text_ids, audio_waveform = result[0], result[1] + text_response = processor.batch_decode( + text_ids[:, inputs["input_ids"].shape[1]:], + skip_special_tokens=True, + )[0] + else: + text_ids = result if not isinstance(result, tuple) else result[0] + text_response = processor.batch_decode( + text_ids[:, inputs["input_ids"].shape[1]:], + skip_special_tokens=True, + )[0] + + # Save text + text_path = os.path.join(output_dir, "test4_speech_text.txt") + with open(text_path, "w") as f: + f.write(f"Prompt: {prompt}\n\nResponse: {text_response}\n\nTime: {elapsed:.1f}s\n") + if audio_waveform is not None: + f.write(f"Audio: generated ({type(audio_waveform)})\n") + else: + f.write("Audio: none returned\n") + + # Save audio + wav_path = os.path.join(output_dir, "test4_speech_response.wav") + if audio_waveform is not None: + import wave as wave_mod + + if isinstance(audio_waveform, torch.Tensor): + audio_np = audio_waveform.cpu().float().numpy() + else: + audio_np = np.array(audio_waveform, dtype=np.float32) + + if audio_np.ndim > 1: + audio_np = audio_np.squeeze() + + max_val = max(abs(audio_np.max()), abs(audio_np.min()), 1e-8) + audio_int16 = (audio_np / max_val * 32767).astype(np.int16) + + with wave_mod.open(wav_path, "w") as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(24000) + wf.writeframes(audio_int16.tobytes()) + + logger.info("Speech saved: %s (%d samples, %.1fs audio)", + wav_path, len(audio_int16), len(audio_int16) / 24000) + else: + logger.warning("No audio waveform returned.") + with open(os.path.join(output_dir, "test4_no_audio.txt"), "w") as f: + f.write("model.generate(return_audio=True) did not return audio waveform.\n" + "This may require spk_dict.pt or the Talker/Token2Wav to be initialized.\n") + + logger.info("Text: %s", text_response[:300]) + logger.info("Time: %.1fs", elapsed) + return True, text_response, elapsed + + except Exception as e: + logger.error("Test 4 FAILED: %s", e) + traceback.print_exc() + return False, str(e), 0 + + +# ============================================================================ +# Main +# ============================================================================ + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", default=DEFAULT_MODEL) + parser.add_argument("--output-dir", default=DEFAULT_OUTPUT) + args = parser.parse_args() + + os.makedirs(args.output_dir, exist_ok=True) + + logger.info("=" * 60) + logger.info("Qwen2.5-Omni End-to-End Test") + logger.info("Model: %s", args.model_path) + logger.info("Output: %s", args.output_dir) + logger.info("=" * 60) + + image_path, audio_path = download_test_assets(args.output_dir) + model, processor = load_model(args.model_path) + + results = {} + + for name, fn, extra_args in [ + ("test1_text_to_text", test_text_only, []), + ("test2_image_text", test_image_text, [image_path]), + ("test3_audio_text", test_audio_text, [audio_path]), + ("test4_text_to_speech", test_speech_output, []), + ]: + ok, resp, t = fn(model, processor, args.output_dir, *extra_args) + results[name] = {"passed": ok, "response": resp[:500], "time": t} + gc.collect() + + passed = sum(1 for r in results.values() if r["passed"]) + total = len(results) + + summary_path = os.path.join(args.output_dir, "summary.txt") + with open(summary_path, "w") as f: + f.write("=" * 60 + "\n") + f.write("Qwen2.5-Omni End-to-End Test Results\n") + f.write(f"Model: {args.model_path}\n") + f.write("=" * 60 + "\n\n") + f.write(f"Results: {passed}/{total} passed\n\n") + for name, r in results.items(): + status = "PASS" if r["passed"] else "FAIL" + f.write(f"[{status}] {name} ({r['time']:.1f}s)\n") + f.write(f" {r['response'][:300]}\n\n") + + with open(os.path.join(args.output_dir, "results.json"), "w") as f: + json.dump(results, f, indent=2, ensure_ascii=False) + + logger.info("=" * 60) + logger.info("FINAL: %d/%d tests passed", passed, total) + logger.info("Results: %s", args.output_dir) + logger.info("=" * 60) + + if passed < total: + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/contrib/models/Qwen2.5-Omni-7B/test/integration/test_model.py b/contrib/models/Qwen2.5-Omni-7B/test/integration/test_model.py index 633d8387..b4cbc728 100755 --- a/contrib/models/Qwen2.5-Omni-7B/test/integration/test_model.py +++ b/contrib/models/Qwen2.5-Omni-7B/test/integration/test_model.py @@ -1,236 +1,809 @@ #!/usr/bin/env python3 """ -Integration tests for Qwen2.5-Omni-7B NeuronX implementation. -""" +Integration tests for Qwen2.5-Omni-7B on NeuronX (TP=4). -import pytest -import torch -import json -from pathlib import Path -from transformers import AutoTokenizer, GenerationConfig +Tests: + 1. Import validation + 2. Config creation and TP=4 head divisibility + 3. State dict conversion (all 2448 keys) + 4. Audio encoder CPU components (frontend + postprocessor) + 5. Talker CPU model (weight loading + codec tokens) + 6. Text-only Thinker compile + load + generate + 7. Image understanding (requires vision encoder compiled) + 8. Audio understanding (requires audio encoder compiled) + +Usage: + source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + + # Run all tests (skip tests requiring compilation with --quick): + python3 test_model.py + + # Run only text generation (model must be pre-compiled): + python3 test_model.py --test text_gen -from neuronx_distributed_inference.models.config import NeuronConfig -from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config + # Run with pytest: + pytest test_model.py -v +""" -# Import from src directory +# --- Qwen2.5-Omni contrib bootstrap --- +import sys as _sys +from pathlib import Path as _Path +_SRC = _Path(__file__).resolve().parents[2] / "src" +if str(_SRC) not in _sys.path: + _sys.path.insert(0, str(_SRC)) +import _upstream_compat # noqa: F401 (applies hf_adapter shim) +# --- end bootstrap --- + +import argparse +import gc +import os import sys -sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src")) -from modeling_qwen2_5_omni import * - - -# Test configuration -MODEL_PATH = "/home/ubuntu/models/Qwen2.5-Omni-7B/" -COMPILED_MODEL_PATH = "/home/ubuntu/neuron_models/Qwen2.5-Omni-7B/" - - -def load_neuron_config_from_compiled(compiled_path: str): - """Load neuron configuration from compiled model's neuron_config.json.""" - config_path = Path(compiled_path) / "neuron_config.json" - - if not config_path.exists(): - raise FileNotFoundError(f"neuron_config.json not found: {config_path}") - - with open(config_path) as f: - config_data = json.load(f) - - if "neuron_config" in config_data: - return config_data["neuron_config"] - else: - return config_data +import time +import traceback +import torch -def create_model_for_inference(compiled_path: str, model_path: str): - """Create model for inference using compiled neuron_config.""" - neuron_config_dict = load_neuron_config_from_compiled(compiled_path) - - dtype_str = neuron_config_dict.get('torch_dtype', 'torch.bfloat16') - if isinstance(dtype_str, str): - dtype = getattr(torch, dtype_str.split('.')[1]) if dtype_str.startswith('torch.') else torch.bfloat16 +# Default paths - override with environment variables +from _model_path import resolve_model_path +MODEL_PATH = resolve_model_path() +COMPILED_PATH = os.environ.get( + "QWEN25_OMNI_COMPILED_PATH", "/tmp/qwen25_omni_tp4_compiled" +) +TP_DEGREE = int(os.environ.get("QWEN25_OMNI_TP_DEGREE", "4")) + +# Test media URLs from Qwen official examples +IMAGE_URL = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg" +AUDIO_URL = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-Omni/cough.wav" + + +class Timer: + """Context manager to time a block.""" + + def __init__(self, label): + self.label = label + self.elapsed = 0 + + def __enter__(self): + self.start = time.time() + return self + + def __exit__(self, *args): + self.elapsed = time.time() - self.start + print(f" [{self.label}] {self.elapsed:.2f}s") + + +# --------------------------------------------------------------------------- +# Test 1: Imports +# --------------------------------------------------------------------------- +def test_imports(): + """Verify all Qwen2.5-Omni modules import correctly.""" + print("=" * 60) + print("Test 1: Import validation") + print("=" * 60) + + from neuronx_distributed_inference.models.config import ( + NeuronConfig, + OnDeviceSamplingConfig, + ) + from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config + + print(" NeuronConfig, OnDeviceSamplingConfig imported OK") + + from modeling_qwen25_omni import ( + NeuronQwen25OmniForCausalLM, + Qwen25OmniInferenceConfig, + NeuronQwen25OmniMultimodalForCausalLM, + Qwen25OmniMultimodalInferenceConfig, + ) + + print(" Qwen25Omni model classes imported OK") + + from modeling_qwen25_omni_audio import ( + NeuronQwen25OmniAudioEncoder, + NeuronQwen25OmniForAudioEncoding, + AudioEncoderInferenceConfig, + AudioCPUFrontend, + AudioCPUPostprocessor, + NeuronAudioTransformerModel, + AudioTransformerModelWrapper, + ) + + print(" Audio encoder classes imported OK") + + from modeling_qwen25_omni_talker import ( + NeuronQwen25OmniTalker, + ) + + print(" Talker imported OK") + + from modeling_qwen25_omni_token2wav import ( + NeuronQwen25OmniToken2Wav, + ) + + print(" Token2Wav imported OK") + + from modeling_qwen25_omni_vision import ( + NeuronQwen25OmniForImageEncoding, + ) + + print(" Vision encoder imported OK") + + print(" PASS: All imports successful\n") + return True + + +# --------------------------------------------------------------------------- +# Test 2: Config +# --------------------------------------------------------------------------- +def test_config(): + """Create configs and verify TP=4 head divisibility.""" + print("=" * 60) + print("Test 2: Config creation and TP=4 validation") + print("=" * 60) + + from neuronx_distributed_inference.models.config import ( + NeuronConfig, + OnDeviceSamplingConfig, + ) + from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config + from modeling_qwen25_omni import ( + Qwen25OmniInferenceConfig, + ) + from transformers import AutoConfig + + neuron_config = NeuronConfig( + tp_degree=TP_DEGREE, + batch_size=1, + seq_len=2048, + max_context_length=2048, + torch_dtype=torch.bfloat16, + on_device_sampling_config=OnDeviceSamplingConfig( + do_sample=True, temperature=0.6, top_k=20, top_p=0.95 + ), + ) + + hf_config = load_pretrained_config(MODEL_PATH) + config = Qwen25OmniInferenceConfig(neuron_config, load_config=hf_config) + + # Validate text config + print(f" hidden_size={config.hidden_size}") + print(f" num_attention_heads={config.num_attention_heads}") + print(f" num_key_value_heads={config.num_key_value_heads}") + print(f" num_hidden_layers={config.num_hidden_layers}") + print(f" vocab_size={config.vocab_size}") + + # TP divisibility check for Thinker + assert ( + config.num_attention_heads % TP_DEGREE == 0 + ), f"num_attention_heads={config.num_attention_heads} not divisible by TP={TP_DEGREE}" + assert ( + config.num_key_value_heads % TP_DEGREE == 0 + ), f"num_key_value_heads={config.num_key_value_heads} not divisible by TP={TP_DEGREE}" + print( + f" Thinker heads: {config.num_attention_heads}/{TP_DEGREE}=" + f"{config.num_attention_heads // TP_DEGREE} per rank, " + f"kv_heads: {config.num_key_value_heads}/{TP_DEGREE}=" + f"{config.num_key_value_heads // TP_DEGREE} per rank" + ) + + # Validate audio config (via AutoConfig for full nested access) + full_config = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True) + if hasattr(full_config, "thinker_config"): + tc = full_config.thinker_config + if hasattr(tc, "__dict__") and not isinstance(tc, dict): + tc = vars(tc) + if "audio_config" in tc: + ac = tc["audio_config"] + if hasattr(ac, "__dict__"): + ac = vars(ac) + audio_heads = ac.get("encoder_attention_heads", 20) + print( + f" Audio: d_model={ac.get('d_model')}, heads={audio_heads}, " + f"layers={ac.get('encoder_layers')}" + ) + assert ( + audio_heads % TP_DEGREE == 0 + ), f"audio heads={audio_heads} not divisible by TP={TP_DEGREE}" + print( + f" Audio heads: {audio_heads}/{TP_DEGREE}={audio_heads // TP_DEGREE} per rank" + ) + + # Validate talker config + if hasattr(full_config, "talker_config"): + tc = full_config.talker_config + if hasattr(tc, "__dict__") and not isinstance(tc, dict): + tc = vars(tc) + talker_heads = tc.get("num_attention_heads", 12) + talker_hidden = tc.get("hidden_size", 896) + print( + f" Talker: hidden={talker_hidden}, heads={talker_heads}, " + f"head_dim={tc.get('head_dim', 128)}" + ) + print( + f" Talker stays on CPU (head_dim={tc.get('head_dim', 128)} " + f"!= {talker_hidden}//{talker_heads})" + ) + + print(" PASS: Config creation and TP=4 validation\n") + return True + + +# --------------------------------------------------------------------------- +# Test 3: State dict conversion +# --------------------------------------------------------------------------- +def test_state_dict(): + """Load full HF model and convert state dict for all components.""" + print("=" * 60) + print("Test 3: State dict conversion (all components)") + print("=" * 60) + + from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import ( + Qwen2_5OmniForConditionalGeneration, + ) + from modeling_qwen25_omni_audio import ( + NeuronQwen25OmniAudioEncoder, + ) + from modeling_qwen25_omni_talker import ( + NeuronQwen25OmniTalker, + ) + + # Load HF model state dict + with Timer("Load HF model"): + hf_model = Qwen2_5OmniForConditionalGeneration.from_pretrained( + MODEL_PATH, + trust_remote_code=True, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + ) + full_sd = hf_model.state_dict() + print(f" Full state dict: {len(full_sd)} keys") + + # Count keys by prefix + prefix_counts = {} + for k in full_sd: + p = k.split(".")[0] + prefix_counts[p] = prefix_counts.get(p, 0) + 1 + for p, c in sorted(prefix_counts.items()): + print(f" {p}: {c} keys") + + # Test audio encoder state dict conversion + with Timer("Audio state dict conversion"): + audio_sd = NeuronQwen25OmniAudioEncoder.convert_hf_to_neuron_state_dict( + {k: v for k, v in full_sd.items() if "audio_tower" in k}, + dtype=torch.bfloat16, + ) + frontend_keys = [k for k in audio_sd if k.startswith("frontend.")] + transformer_keys = [k for k in audio_sd if k.startswith("transformer.")] + postprocessor_keys = [k for k in audio_sd if k.startswith("postprocessor.")] + print( + f" Audio keys: frontend={len(frontend_keys)}, " + f"transformer={len(transformer_keys)}, " + f"postprocessor={len(postprocessor_keys)}" + ) + + # Test talker state dict conversion + with Timer("Talker state dict conversion"): + talker_sd = NeuronQwen25OmniTalker.convert_hf_to_neuron_state_dict( + {k: v for k, v in full_sd.items() if k.startswith("talker.")} + ) + print(f" Talker keys: {len(talker_sd)}") + + del hf_model, full_sd + gc.collect() + + print(" PASS: State dict conversion\n") + return True + + +# --------------------------------------------------------------------------- +# Test 4: Audio encoder CPU components +# --------------------------------------------------------------------------- +def test_audio_encoder(): + """Test audio encoder CPU frontend and postprocessor with synthetic input.""" + print("=" * 60) + print("Test 4: Audio encoder CPU components") + print("=" * 60) + + from transformers import AutoConfig + from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import ( + Qwen2_5OmniForConditionalGeneration, + ) + from modeling_qwen25_omni_audio import ( + NeuronQwen25OmniAudioEncoder, + ) + + # Load config + hf_config = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True) + if not hasattr(hf_config, "thinker_config"): + print(" SKIP: No thinker_config found") + return True + + tc = hf_config.thinker_config + if hasattr(tc, "__dict__") and not isinstance(tc, dict): + tc = vars(tc) + audio_config = tc.get("audio_config", None) + if audio_config is not None and hasattr(audio_config, "__dict__"): + audio_config = vars(audio_config) + if audio_config is None: + print(" SKIP: No audio_config found") + return True + + print( + f" Audio config: d_model={audio_config.get('d_model')}, " + f"heads={audio_config.get('encoder_attention_heads')}, " + f"layers={audio_config.get('encoder_layers')}" + ) + + # Load HF model for weights + with Timer("Load HF model for audio weights"): + hf_model = Qwen2_5OmniForConditionalGeneration.from_pretrained( + MODEL_PATH, + trust_remote_code=True, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + ) + full_sd = hf_model.state_dict() + + with Timer("Convert audio state dict"): + converted_sd = NeuronQwen25OmniAudioEncoder.convert_hf_to_neuron_state_dict( + full_sd, dtype=torch.bfloat16 + ) + + del hf_model, full_sd + gc.collect() + + with Timer("Create audio encoder + load CPU weights"): + encoder = NeuronQwen25OmniAudioEncoder.from_pretrained_state_dict( + audio_config, converted_sd, dtype=torch.bfloat16 + ) + encoder.eval() + + del converted_sd + gc.collect() + + # Test with synthetic mel spectrograms of various lengths + n_mels = audio_config.get("num_mel_bins", 128) + test_cases = [(100, "1s"), (300, "3s"), (1000, "10s"), (3000, "30s")] + + for mel_len, label in test_cases: + mel_input = torch.randn(n_mels, mel_len, dtype=torch.bfloat16) + feature_lens = torch.tensor([mel_len], dtype=torch.long) + + t0 = time.time() + hidden, aftercnn_lens, cu_seqlens = encoder.frontend(mel_input, feature_lens) + audio_embeds = encoder.postprocessor(hidden, aftercnn_lens) + elapsed = time.time() - t0 + + print( + f" {label} ({mel_len} frames): frontend→{hidden.shape[0]} tokens, " + f"postprocessor→{audio_embeds.shape}, time={elapsed*1000:.1f}ms" + ) + + del encoder + gc.collect() + + print(" PASS: Audio encoder CPU components\n") + return True + + +# --------------------------------------------------------------------------- +# Test 5: Talker CPU model +# --------------------------------------------------------------------------- +def test_talker(): + """Test Talker CPU model weight loading and codec token IDs.""" + print("=" * 60) + print("Test 5: Talker CPU model") + print("=" * 60) + + from transformers import AutoConfig + from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import ( + Qwen2_5OmniForConditionalGeneration, + ) + from modeling_qwen25_omni_talker import ( + NeuronQwen25OmniTalker, + ) + + hf_config = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True) + talker_config = getattr(hf_config, "talker_config", None) + if talker_config is None: + print(" SKIP: No talker_config found") + return True + + if hasattr(talker_config, "__dict__") and not isinstance(talker_config, dict): + tc = vars(talker_config) else: - dtype = dtype_str - - neuron_config_kwargs = { - 'tp_degree': neuron_config_dict.get('tp_degree', 2), - 'batch_size': neuron_config_dict.get('batch_size', 1), - 'seq_len': neuron_config_dict.get('seq_len', 128), - 'torch_dtype': dtype, - } - - neuron_config = NeuronConfig(**neuron_config_kwargs) - - # This will use the imported model and config classes - # The actual class names will be determined at runtime - return None, neuron_config - - -def generate_with_neuron_model(model, input_ids, max_new_tokens: int): - """Generate tokens using manual forward pass loop.""" - generated_ids = input_ids.clone() - - for _ in range(max_new_tokens): - seq_len = generated_ids.shape[1] - position_ids = torch.arange(seq_len).unsqueeze(0).expand(generated_ids.shape[0], -1) - - with torch.no_grad(): - outputs = model(generated_ids, position_ids=position_ids) - - if hasattr(outputs, 'logits'): - logits = outputs.logits - elif isinstance(outputs, tuple): - logits = outputs[0] - else: - logits = outputs - - next_token_logits = logits[:, -1, :] - next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1) - generated_ids = torch.cat([generated_ids, next_token], dim=-1) - - return generated_ids - - -@pytest.fixture(scope="module") -def compiled_model(): - """Load pre-compiled model.""" - # Note: Actual implementation would load the specific model class - # This is a template that should be customized per model - return None - - -@pytest.fixture(scope="module") -def tokenizer(): - """Load tokenizer.""" - tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, padding_side="right", trust_remote_code=True) + tc = talker_config + + print( + f" Talker config: hidden={tc.get('hidden_size')}, " + f"heads={tc.get('num_attention_heads')}, " + f"kv_heads={tc.get('num_key_value_heads')}, " + f"layers={tc.get('num_hidden_layers')}, " + f"vocab={tc.get('vocab_size')}, " + f"embedding_size={tc.get('embedding_size')}" + ) + + with Timer("Load HF model for talker weights"): + hf_model = Qwen2_5OmniForConditionalGeneration.from_pretrained( + MODEL_PATH, + trust_remote_code=True, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + ) + full_sd = hf_model.state_dict() + + talker_sd = NeuronQwen25OmniTalker.convert_hf_to_neuron_state_dict( + {k: v for k, v in full_sd.items() if k.startswith("talker.")} + ) + print(f" Talker state dict: {len(talker_sd)} keys") + + del hf_model, full_sd + gc.collect() + + with Timer("Create Talker + load weights"): + talker = NeuronQwen25OmniTalker.from_pretrained_state_dict( + talker_config, talker_sd, dtype=torch.bfloat16 + ) + + print( + f" Talker codec tokens: bos={talker.codec_bos_token}, " + f"eos={talker.codec_eos_token}, pad={talker.codec_pad_token}" + ) + + del talker, talker_sd + gc.collect() + + print(" PASS: Talker CPU model\n") + return True + + +# --------------------------------------------------------------------------- +# Test 6: Text-only Thinker compile + load + generate +# --------------------------------------------------------------------------- +def test_text_gen(): + """Compile (if needed), load, and generate with the Thinker at TP=4.""" + print("=" * 60) + print("Test 6: Text-only Thinker compile + load + generate (TP=4)") + print("=" * 60) + + from neuronx_distributed_inference.models.config import ( + NeuronConfig, + OnDeviceSamplingConfig, + ) + from neuronx_distributed_inference.utils.hf_adapter import ( + load_pretrained_config, + HuggingFaceGenerationAdapter, + ) + from modeling_qwen25_omni import ( + NeuronQwen25OmniForCausalLM, + Qwen25OmniInferenceConfig, + ) + from transformers import AutoTokenizer + + neuron_config = NeuronConfig( + tp_degree=TP_DEGREE, + batch_size=1, + seq_len=2048, + max_context_length=2048, + torch_dtype=torch.bfloat16, + on_device_sampling_config=OnDeviceSamplingConfig( + do_sample=False, + top_k=1, + ), + ) + + hf_config = load_pretrained_config(MODEL_PATH) + config = Qwen25OmniInferenceConfig(neuron_config, load_config=hf_config) + + compiled_dir = os.path.join(COMPILED_PATH, "thinker_tp4") + + with Timer("Create model"): + model = NeuronQwen25OmniForCausalLM(MODEL_PATH, config) + + if not os.path.exists(os.path.join(compiled_dir, "neuron_config.json")): + with Timer("Compile (this takes several minutes)"): + model.compile(compiled_dir) + else: + print(" Compiled artifacts found, skipping compilation") + + with Timer("Load compiled model"): + model.load(compiled_dir) + + tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token - return tokenizer - - -def test_model_loads(compiled_model): - """Test that model loads successfully (smoke test).""" - assert compiled_model is not None - assert hasattr(compiled_model, 'config') - print("✓ Smoke test passed - Model loaded successfully") - - -def test_model_generates(compiled_model, tokenizer): - """Test that model can generate text.""" - prompt = "The capital of France is" - inputs = tokenizer(prompt, return_tensors="pt", padding=True) - - generated_ids = generate_with_neuron_model(compiled_model, inputs.input_ids, max_new_tokens=20) - output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) - - assert len(output_text) > len(prompt), "Output should be longer than prompt" - print(f"✓ Generation test passed") - print(f" Output: {output_text}") - - -def test_output_coherence(compiled_model, tokenizer): - """Test that output is coherent (not gibberish).""" - prompt = "Hello, how are you?" - inputs = tokenizer(prompt, return_tensors="pt", padding=True) - - generated_ids = generate_with_neuron_model(compiled_model, inputs.input_ids, max_new_tokens=30) - output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) - - # Coherence checks - assert len(output_text.split()) > 3, "Output should have multiple words" - assert not _is_repetitive(output_text), "Output should not be repetitive" - - print(f"✓ Coherence test passed") - print(f" Output: {output_text[:100]}...") - - - -def _is_repetitive(text: str, max_repeat: int = 5) -> bool: - """Check if text has excessive repetition.""" - words = text.split() - if len(words) < 10: - return False - - # Check for repeated words - for i in range(len(words) - max_repeat): - word = words[i] - if all(words[i+j] == word for j in range(max_repeat)): - return True - - # Check for repeated characters - new_text = text[-100:] if len(text) > 100 else text - if len(new_text) > 20: - char_counts = {} - for c in new_text: - char_counts[c] = char_counts.get(c, 0) + 1 - max_char_ratio = max(char_counts.values()) / len(new_text) - if max_char_ratio > 0.5: - return True - - return False - - -def test_performance_ttft(compiled_model, tokenizer): - """Test Time To First Token (TTFT) performance.""" - import time - - prompt = "Hello, how are you?" - inputs = tokenizer(prompt, return_tensors="pt", padding=True) - input_ids = inputs.input_ids - - # Warmup - for _ in range(3): - seq_len = input_ids.shape[1] - position_ids = torch.arange(seq_len).unsqueeze(0).expand(input_ids.shape[0], -1) - with torch.no_grad(): - _ = compiled_model(input_ids, position_ids=position_ids) - - # Measure TTFT - times = [] - for _ in range(10): - seq_len = input_ids.shape[1] - position_ids = torch.arange(seq_len).unsqueeze(0).expand(input_ids.shape[0], -1) - - start = time.perf_counter() - with torch.no_grad(): - _ = compiled_model(input_ids, position_ids=position_ids) - end = time.perf_counter() - - times.append((end - start) * 1000) # ms - - avg_ttft = sum(times) / len(times) - print(f"✓ TTFT: {avg_ttft:.2f}ms") - - - -def test_performance_throughput(compiled_model, tokenizer): - """Test token generation throughput.""" - import time - - prompt = "Hello" - inputs = tokenizer(prompt, return_tensors="pt", padding=True) - input_ids = inputs.input_ids - num_tokens = 50 - - # Warmup - _ = generate_with_neuron_model(compiled_model, input_ids, max_new_tokens=5) - - # Measure throughput - start = time.perf_counter() - _ = generate_with_neuron_model(compiled_model, input_ids, max_new_tokens=num_tokens) - end = time.perf_counter() - - total_time = end - start - throughput = num_tokens / total_time - print(f"✓ Throughput: {throughput:.2f} tok/s") + # HuggingFaceGenerationAdapter does NOT take tokenizer as argument. + adapter = HuggingFaceGenerationAdapter(model) + + def make_chat_input(prompt): + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ] + text = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + encoded = tokenizer(text, return_tensors="pt") + return encoded["input_ids"], encoded["attention_mask"] + + prompts = [ + "What is 2+3? Answer with just the number.", + "Write a haiku about the ocean.", + "Explain quantum computing in one sentence.", + ] + + for prompt in prompts: + input_ids, attention_mask = make_chat_input(prompt) + prompt_len = input_ids.shape[1] + + with Timer(f"Generate '{prompt[:40]}...'"): + output_ids = adapter.generate( + input_ids=input_ids, + attention_mask=attention_mask, + max_new_tokens=128, + eos_token_id=[tokenizer.eos_token_id, 151645], + ) + + new_tokens = output_ids[0, prompt_len:] + output_text = tokenizer.decode(new_tokens, skip_special_tokens=True) + n_new = sum( + 1 + for tok in new_tokens + if tok.item() + not in [tokenizer.eos_token_id, tokenizer.pad_token_id, 151645] + ) + print(f" Input: {prompt}") + print(f" Output: {output_text.strip()[:200]}") + print(f" Tokens: {n_new}") + print() + + del model, adapter + gc.collect() + + print(" PASS: Text-only Thinker compile + load + generate\n") + return True + + +# --------------------------------------------------------------------------- +# Test 7: Image understanding (requires multimodal model) +# --------------------------------------------------------------------------- +def test_image_understanding(): + """Test image understanding with the Qwen2.5-Omni vision encoder. + + This test requires the multimodal model (vision encoder + text decoder) + to be compiled. It downloads a test image and asks the model to describe it. + + NOTE: This is a placeholder for the full multimodal pipeline. The vision + encoder must be compiled on Neuron before this test can run end-to-end. + Currently tests the preprocessing pipeline only. + """ + print("=" * 60) + print("Test 7: Image understanding (preprocessing)") + print("=" * 60) + + try: + from qwen_omni_utils import process_mm_info + except ImportError: + print(" SKIP: qwen-omni-utils not installed") + print(" Install with: pip install qwen-omni-utils[decord]") + return True + + from transformers import AutoConfig + + # Build message with image + messages = [ + { + "role": "system", + "content": [ + { + "type": "text", + "text": "You are a helpful assistant.", + } + ], + }, + { + "role": "user", + "content": [ + {"type": "image", "image": IMAGE_URL}, + {"type": "text", "text": "Describe this image briefly."}, + ], + }, + ] + + print(f" Image URL: {IMAGE_URL}") + + with Timer("Process multimodal info (download + preprocess)"): + audios, images, videos = process_mm_info( + messages, use_audio_in_video=False + ) + + if images: + print(f" Images processed: {len(images)}") + for i, img in enumerate(images): + if hasattr(img, "shape"): + print(f" Image {i}: shape={img.shape}") + elif hasattr(img, "size"): + print(f" Image {i}: size={img.size}") + else: + print(" No images found in processed output") + + # Verify config has vision token IDs + hf_config = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True) + if hasattr(hf_config, "thinker_config"): + tc = hf_config.thinker_config + if hasattr(tc, "__dict__") and not isinstance(tc, dict): + tc = vars(tc) + image_token_id = tc.get("image_token_index") + vision_start_id = tc.get("vision_start_token_id") + vision_end_id = tc.get("vision_end_token_id") + print( + f" Vision tokens: image={image_token_id}, " + f"start={vision_start_id}, end={vision_end_id}" + ) + + print( + " NOTE: Full end-to-end image understanding requires " + "compiled vision encoder on Neuron." + ) + print(" PASS: Image preprocessing\n") + return True + + +# --------------------------------------------------------------------------- +# Test 8: Audio understanding (requires audio encoder on Neuron) +# --------------------------------------------------------------------------- +def test_audio_understanding(): + """Test audio understanding preprocessing pipeline. + + Downloads a test audio file and preprocesses it through the Qwen2.5-Omni + audio pipeline. Full end-to-end inference requires the audio encoder's + Neuron transformer to be compiled. + """ + print("=" * 60) + print("Test 8: Audio understanding (preprocessing)") + print("=" * 60) + + try: + from qwen_omni_utils import process_mm_info + except ImportError: + print(" SKIP: qwen-omni-utils not installed") + print(" Install with: pip install qwen-omni-utils[decord]") + return True + + from transformers import AutoConfig + + # Build message with audio + messages = [ + { + "role": "system", + "content": [ + { + "type": "text", + "text": "You are a helpful assistant.", + } + ], + }, + { + "role": "user", + "content": [ + {"type": "audio", "audio": AUDIO_URL}, + {"type": "text", "text": "What sound is this?"}, + ], + }, + ] + + print(f" Audio URL: {AUDIO_URL}") + + with Timer("Process multimodal info (download + preprocess)"): + audios, images, videos = process_mm_info( + messages, use_audio_in_video=False + ) + + if audios: + print(f" Audios processed: {len(audios)}") + for i, audio in enumerate(audios): + if hasattr(audio, "shape"): + print(f" Audio {i}: shape={audio.shape}") + else: + print(f" Audio {i}: type={type(audio)}") + else: + print(" No audios found in processed output") + + # Verify config has audio token IDs + hf_config = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True) + if hasattr(hf_config, "thinker_config"): + tc = hf_config.thinker_config + if hasattr(tc, "__dict__") and not isinstance(tc, dict): + tc = vars(tc) + audio_token_id = tc.get("audio_token_index") + audio_start_id = tc.get("audio_start_token_id") + audio_end_id = tc.get("audio_end_token_id") + print( + f" Audio tokens: audio={audio_token_id}, " + f"start={audio_start_id}, end={audio_end_id}" + ) + + print( + " NOTE: Full end-to-end audio understanding requires " + "compiled audio encoder on Neuron." + ) + print(" PASS: Audio preprocessing\n") + return True + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- +ALL_TESTS = [ + ("imports", test_imports), + ("config", test_config), + ("state_dict", test_state_dict), + ("audio_encoder", test_audio_encoder), + ("talker", test_talker), + ("text_gen", test_text_gen), + ("image", test_image_understanding), + ("audio", test_audio_understanding), +] + + +def main(): + parser = argparse.ArgumentParser( + description="Qwen2.5-Omni-7B integration tests" + ) + parser.add_argument( + "--test", + choices=[name for name, _ in ALL_TESTS], + nargs="+", + help="Run specific test(s). Default: all.", + ) + parser.add_argument( + "--quick", + action="store_true", + help="Skip heavyweight tests (state_dict, audio_encoder, talker).", + ) + args = parser.parse_args() + + if args.test: + tests = [(n, fn) for n, fn in ALL_TESTS if n in args.test] + elif args.quick: + skip = {"state_dict", "audio_encoder", "talker"} + tests = [(n, fn) for n, fn in ALL_TESTS if n not in skip] + else: + tests = ALL_TESTS + + print("\n" + "=" * 60) + print(f"Qwen2.5-Omni-7B TP={TP_DEGREE} Integration Tests") + print(f"Model: {MODEL_PATH}") + print(f"Compiled: {COMPILED_PATH}") + print("=" * 60 + "\n") + + results = {} + total_start = time.time() + + for name, test_fn in tests: + try: + passed = test_fn() + results[name] = "PASS" if passed else "FAIL" + except Exception as e: + print(f"\n FAIL: {e}") + traceback.print_exc() + results[name] = f"FAIL: {e}" + + total_time = time.time() - total_start + + print("\n" + "=" * 60) + print("RESULTS SUMMARY") + print("=" * 60) + for name, result in results.items(): + status = "PASS" if result == "PASS" else "FAIL" + print(f" [{status}] {name}: {result}") + print(f"\n Total time: {total_time:.1f}s") + print("=" * 60) + + if any(r != "PASS" for r in results.values()): + sys.exit(1) if __name__ == "__main__": - print("="*80) - print("Qwen2.5-Omni-7B Integration Tests") - print("="*80) - - print("\nNote: This is a template test file.") - print("For actual model testing, customize the model loading logic.") - - print("\n" + "="*80) - print("✓ Template structure verified!") - print("="*80) + main() diff --git a/contrib/models/Qwen2.5-Omni-7B/test/integration/test_talker_neuron.py b/contrib/models/Qwen2.5-Omni-7B/test/integration/test_talker_neuron.py new file mode 100644 index 00000000..90a83463 --- /dev/null +++ b/contrib/models/Qwen2.5-Omni-7B/test/integration/test_talker_neuron.py @@ -0,0 +1,610 @@ +#!/usr/bin/env python3 +"""Tests for the Neuron-compiled Talker and Token2Wav implementations. + +Tests CPU-level logic: config, state dict conversion, fused embedding, +class structure. Does NOT require Neuron hardware. + +Mocks neuronx_distributed and torch_neuronx at sys.modules level so tests +can run on any machine (Mac, Linux, etc.) without Neuron SDK. +""" + +# --- Qwen2.5-Omni contrib bootstrap --- +import sys as _sys +from pathlib import Path as _Path +_SRC = _Path(__file__).resolve().parents[2] / "src" +if str(_SRC) not in _sys.path: + _sys.path.insert(0, str(_SRC)) +import _upstream_compat # noqa: F401 (applies hf_adapter shim) +# --- end bootstrap --- + +import sys +import types +import torch +from pathlib import Path +from unittest.mock import MagicMock + +# Add src to path +sys.path.insert(0, str(Path(__file__).parent.parent / "src")) + + +# ============================================================================ +# Mock setup: neuronx_distributed and torch_neuronx +# ============================================================================ + +class _MockModuleType(types.ModuleType): + """Module type that returns MagicMock for any missing attribute.""" + def __getattr__(self, name): + return MagicMock(name=f"{self.__name__}.{name}") + + +class _AutoMockFinder: + """Meta path finder that auto-mocks any package that can't be imported normally. + + Only intercepts packages listed in _MOCK_PREFIXES to avoid breaking stdlib. + """ + _MOCK_PREFIXES = ( + "neuronx_distributed", "torch_neuronx", "torch_xla", + "nki", "nkilib", "neuronxcc", "transformers", "huggingface_hub", + "safetensors", "accelerate", "sentencepiece", "tokenizers", + ) + + def find_module(self, fullname, path=None): + if any(fullname == p or fullname.startswith(p + ".") for p in self._MOCK_PREFIXES): + return self + return None + + def find_spec(self, fullname, path, target=None): + if any(fullname == p or fullname.startswith(p + ".") for p in self._MOCK_PREFIXES): + from importlib.machinery import ModuleSpec + return ModuleSpec(fullname, self, is_package=True) + return None + + def create_module(self, spec): + mod = _MockModuleType(spec.name) + mod.__path__ = [] + mod.__package__ = spec.name + mod.__loader__ = self + mod.__spec__ = spec + return mod + + def exec_module(self, module): + pass + + def load_module(self, fullname): + if fullname in sys.modules: + return sys.modules[fullname] + mod = _MockModuleType(fullname) + mod.__path__ = [] + mod.__package__ = fullname + mod.__loader__ = self + sys.modules[fullname] = mod + return mod + + +def _setup_neuron_mocks(): + """Install auto-mock import hook and set up key attributes. + + Must be called before importing any neuronx_distributed_inference modules. + """ + # Install the auto-mock finder + sys.meta_path.insert(0, _AutoMockFinder()) + + # Force-import mocked modules so they exist in sys.modules + import neuronx_distributed + import neuronx_distributed.utils + + # Set specific attributes that need real values + neuronx_distributed.utils.cpu_mode = MagicMock(return_value=True) + + import neuronx_distributed.utils.utils + mock_hardware = MagicMock(name="hardware") + mock_hardware.TRN1 = "trn1" + mock_hardware.return_value = "trn2" + neuronx_distributed.utils.utils.hardware = mock_hardware + + import torch_neuronx.utils + torch_neuronx.utils.get_platform_target = MagicMock(return_value="trn2") + + # ColumnParallelLinear / RowParallelLinear / ParallelEmbedding need to be + # real classes so NxDI code can subclass them (e.g. lora_layer.py) + # ColumnParallelLinear / RowParallelLinear / ParallelEmbedding need to be + # real classes so NxDI code can subclass them (e.g. lora_layer.py) + class _MockParallelLinear(torch.nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + + class _MockParallelEmbedding(torch.nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + + class _MockSPMDRank(torch.nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + + # Set on both .parallel_layers and .parallel_layers.layers + # (code imports from both paths) + import neuronx_distributed.parallel_layers as _pl + import neuronx_distributed.parallel_layers.layers as _pl_layers + for mod in (_pl, _pl_layers): + mod.ColumnParallelLinear = _MockParallelLinear + mod.RowParallelLinear = _MockParallelLinear + mod.ParallelEmbedding = _MockParallelEmbedding + mod.SPMDRank = _MockSPMDRank + + # LlamaRMSNorm needs to be a real nn.Module subclass for RMSNorm usage + class MockLlamaRMSNorm(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + import transformers.models.llama.modeling_llama + transformers.models.llama.modeling_llama.LlamaRMSNorm = MockLlamaRMSNorm + + # ACT2FN for MLP activation lookup + import transformers.activations + transformers.activations.ACT2FN = { + "silu": torch.nn.SiLU(), + "gelu": torch.nn.GELU(), + "relu": torch.nn.ReLU(), + } + + +# Install mocks before any NxDI imports +_MOCK_MODULES = _setup_neuron_mocks() + + +# ============================================================================ +# Helper: create a load_config callable for InferenceConfig +# ============================================================================ + +def _make_load_config(**attrs): + """Create a load_config callable that sets attributes on an InferenceConfig.""" + def load_config(self): + for key, value in attrs.items(): + setattr(self, key, value) + return load_config + + +# ============================================================================ +# Test 1: Config classes +# ============================================================================ + +def test_talker_inference_config(): + """Test TalkerInferenceConfig creation and derived attributes.""" + from modeling_qwen25_omni_talker import ( + TalkerInferenceConfig, + TalkerNeuronConfig, + ) + + neuron_config = TalkerNeuronConfig( + tp_degree=4, + batch_size=1, + seq_len=512, + torch_dtype=torch.bfloat16, + on_cpu=True, + ) + + config = TalkerInferenceConfig( + neuron_config=neuron_config, + load_config=_make_load_config( + hidden_size=896, + num_attention_heads=12, + num_hidden_layers=24, + num_key_value_heads=4, + vocab_size=8448, + rope_theta=1000000.0, + rms_norm_eps=1e-6, + hidden_act="silu", + intermediate_size=18944, + pad_token_id=0, + max_position_embeddings=32768, + ), + ) + + # Check derived config + assert config.head_dim == 128, f"Expected head_dim=128, got {config.head_dim}" + assert config.qkv_bias == True + assert config.o_bias == False + assert config.num_cores_per_group == 1 + assert config.rope_scaling is not None + assert config.rope_scaling["mrope_section"] == [16, 24, 24] + assert config.thinker_hidden_size == 3584 + + # Check neuron config cls + assert TalkerInferenceConfig.get_neuron_config_cls() == TalkerNeuronConfig + + # Check required attributes + required = config.get_required_attributes() + assert "hidden_size" in required + assert "num_attention_heads" in required + assert "head_dim" not in required # head_dim is derived, not required from HF + + print("PASS: TalkerInferenceConfig") + + +# ============================================================================ +# Test 2: Fused embedding in state dict conversion +# ============================================================================ + +def test_fused_embedding_conversion(): + """Test that embed_tokens + thinker_to_talker_proj are correctly fused.""" + from modeling_qwen25_omni_talker import ( + NeuronQwen25OmniTalkerForCausalLM, + TalkerInferenceConfig, + TalkerNeuronConfig, + ) + + # Create fake state dict with talker weights + vocab_size, embed_dim, hidden_size = 8448, 3584, 896 + embed_weight = torch.randn(vocab_size, embed_dim) + proj_weight = torch.randn(hidden_size, embed_dim) + proj_bias = torch.randn(hidden_size) + + state_dict = { + "talker.model.embed_tokens.weight": embed_weight, + "talker.thinker_to_talker_proj.weight": proj_weight, + "talker.thinker_to_talker_proj.bias": proj_bias, + "talker.codec_head.weight": torch.randn(vocab_size, hidden_size), + "talker.model.layers.0.self_attn.q_proj.weight": torch.randn(12 * 128, hidden_size), + "talker.model.layers.0.self_attn.k_proj.weight": torch.randn(4 * 128, hidden_size), + "talker.model.layers.0.self_attn.v_proj.weight": torch.randn(4 * 128, hidden_size), + "talker.model.layers.0.self_attn.q_proj.bias": torch.randn(12 * 128), + "talker.model.layers.0.self_attn.k_proj.bias": torch.randn(4 * 128), + "talker.model.layers.0.self_attn.v_proj.bias": torch.randn(4 * 128), + "talker.model.norm.weight": torch.randn(hidden_size), + } + + # Create config + neuron_config = TalkerNeuronConfig( + tp_degree=4, + batch_size=1, + seq_len=512, + torch_dtype=torch.bfloat16, + on_cpu=True, + fused_qkv=True, + ) + + config = TalkerInferenceConfig( + neuron_config=neuron_config, + load_config=_make_load_config( + hidden_size=896, + num_attention_heads=12, + num_hidden_layers=1, # Just 1 layer for testing + num_key_value_heads=4, + vocab_size=8448, + rope_theta=1000000.0, + rms_norm_eps=1e-6, + hidden_act="silu", + intermediate_size=18944, + pad_token_id=0, + max_position_embeddings=32768, + ), + ) + + # Convert + converted = NeuronQwen25OmniTalkerForCausalLM.convert_hf_to_neuron_state_dict( + state_dict, config + ) + + # Check fused embedding + assert "embed_tokens.weight" in converted + fused_embed = converted["embed_tokens.weight"] + assert fused_embed.shape == (vocab_size, hidden_size), \ + f"Expected ({vocab_size}, {hidden_size}), got {fused_embed.shape}" + + # Verify fused embedding is correct: embed @ proj.T + bias + expected = embed_weight.float() @ proj_weight.float().T + proj_bias.float().unsqueeze(0) + expected = expected.to(torch.bfloat16) + assert torch.allclose(fused_embed, expected, atol=1e-2), \ + "Fused embedding values don't match expected computation" + + # Check codec_head → lm_head mapping + assert "lm_head.weight" in converted + assert "codec_head.weight" not in converted + + # Check prefix stripping + assert "layers.0.self_attn.q_proj.weight" not in converted # Should be fused + assert "layers.0.self_attn.Wqkv.weight" in converted # Fused QKV + + # Check fused QKV shape: q(1536) + k(512) + v(512) = 2560 + qkv = converted["layers.0.self_attn.Wqkv.weight"] + assert qkv.shape[0] == 12 * 128 + 4 * 128 + 4 * 128, \ + f"Fused QKV wrong shape: {qkv.shape}" + + # Check projection weights saved for CPU context encoding + assert "_thinker_proj_weight" in converted + assert "_thinker_proj_bias" in converted + assert converted["_thinker_proj_weight"].shape == (hidden_size, embed_dim) + + # Check rank utilities + assert "rank_util.rank" in converted + assert "layers.0.self_attn.rank_util.rank" in converted + + print("PASS: Fused embedding conversion") + + +# ============================================================================ +# Test 3: ThinkerToTalkerProjection +# ============================================================================ + +def test_thinker_to_talker_projection(): + """Test CPU-side thinker state projection.""" + from modeling_qwen25_omni_talker import ( + ThinkerToTalkerProjection, + ) + + thinker_dim, talker_dim = 3584, 896 + + # Create from state dict + proj_weight = torch.randn(talker_dim, thinker_dim) + proj_bias = torch.randn(talker_dim) + state_dict = { + "_thinker_proj_weight": proj_weight, + "_thinker_proj_bias": proj_bias, + } + + proj = ThinkerToTalkerProjection.from_state_dict(state_dict, dtype=torch.float32) + + # Test forward + batch, seq = 2, 10 + thinker_states = torch.randn(batch, seq, thinker_dim) + output = proj(thinker_states) + + assert output.shape == (batch, seq, talker_dim), \ + f"Expected ({batch}, {seq}, {talker_dim}), got {output.shape}" + + # Verify output matches manual computation (relax atol for large dim=3584) + expected = thinker_states @ proj_weight.T + proj_bias + assert torch.allclose(output, expected, atol=1e-3), \ + "Projection output doesn't match expected" + + print("PASS: ThinkerToTalkerProjection") + + +# ============================================================================ +# Test 4: Talker RoPE +# ============================================================================ + +def test_talker_rotary_embedding(): + """Test TalkerRotaryEmbedding with both 1D and 3D positions.""" + from modeling_qwen25_omni_talker import ( + TalkerRotaryEmbedding, + ) + + class MockConfig: + head_dim = 128 + rope_theta = 1000000.0 + + emb = TalkerRotaryEmbedding(MockConfig()) + + # Test with 2D position_ids (standard RoPE) + batch, seq = 2, 16 + x = torch.randn(batch, seq, 128) + pos_2d = torch.arange(seq).unsqueeze(0).expand(batch, -1) # (batch, seq) + cos, sin = emb(x, pos_2d) + assert cos.shape == (batch, seq, 128), f"2D RoPE cos shape: {cos.shape}" + assert sin.shape == (batch, seq, 128), f"2D RoPE sin shape: {sin.shape}" + + # Test with 3D position_ids (mRoPE) + pos_3d = torch.arange(seq).unsqueeze(0).unsqueeze(0).expand(3, batch, -1) # (3, batch, seq) + cos, sin = emb(x, pos_3d) + assert cos.shape == (3, batch, seq, 128), f"3D mRoPE cos shape: {cos.shape}" + assert sin.shape == (3, batch, seq, 128), f"3D mRoPE sin shape: {sin.shape}" + + print("PASS: TalkerRotaryEmbedding") + + +# ============================================================================ +# Test 5: mRoPE application +# ============================================================================ + +def test_apply_multimodal_rotary_pos_emb(): + """Test mRoPE application function.""" + from modeling_qwen25_omni_talker import ( + _apply_multimodal_rotary_pos_emb, + ) + + batch, n_heads, seq, head_dim = 2, 12, 16, 128 + q = torch.randn(batch, n_heads, seq, head_dim) + k = torch.randn(batch, n_heads, seq, head_dim) + + # cos/sin from mRoPE: (3, batch, seq, head_dim) + cos = torch.randn(3, batch, seq, head_dim) + sin = torch.randn(3, batch, seq, head_dim) + mrope_section = [16, 24, 24] # sum=64, *2=128=head_dim + + q_out, k_out = _apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section) + + assert q_out.shape == q.shape, f"Q shape mismatch: {q_out.shape} vs {q.shape}" + assert k_out.shape == k.shape, f"K shape mismatch: {k_out.shape} vs {k.shape}" + + # Verify it's not identity (something changed) + assert not torch.allclose(q_out, q), "Q unchanged after RoPE" + assert not torch.allclose(k_out, k), "K unchanged after RoPE" + + print("PASS: apply_multimodal_rotary_pos_emb") + + +# ============================================================================ +# Test 6: Token2Wav Neuron DiT class +# ============================================================================ + +def test_token2wav_neuron_dit_class(): + """Test NeuronQwen25OmniToken2WavWithNeuronDiT class structure.""" + from modeling_qwen25_omni_token2wav import ( + NeuronQwen25OmniToken2Wav, + NeuronQwen25OmniToken2WavWithNeuronDiT, + ) + + # Check inheritance + assert issubclass(NeuronQwen25OmniToken2WavWithNeuronDiT, NeuronQwen25OmniToken2Wav) + + # Check that Neuron DiT class has the expected methods + assert hasattr(NeuronQwen25OmniToken2WavWithNeuronDiT, "compile_dit") + assert hasattr(NeuronQwen25OmniToken2WavWithNeuronDiT, "load_dit") + assert hasattr(NeuronQwen25OmniToken2WavWithNeuronDiT, "_get_dit_module") + + print("PASS: Token2Wav Neuron DiT class structure") + + +# ============================================================================ +# Test 7: Orchestration methods +# ============================================================================ + +def test_orchestration_methods(): + """Test that orchestration class has the new methods.""" + import ast + + orchestration_path = ( + Path(__file__).resolve().parents[2] + / "src" / "modeling_qwen25_omni.py" + ) + with open(orchestration_path) as f: + tree = ast.parse(f.read()) + + # Find NeuronQwen25OmniMultimodalForCausalLM + target_cls = None + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef) and node.name == "NeuronQwen25OmniMultimodalForCausalLM": + target_cls = node + break + + assert target_cls is not None, "NeuronQwen25OmniMultimodalForCausalLM not found" + + methods = { + n.name for n in target_cls.body + if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef)) + } + + # Check new methods exist + expected_methods = { + "enable_talker", + "compile_talker", + "load_talker", + "enable_token2wav", + "compile_token2wav_dit", + "load_token2wav_dit", + "_get_talker_cls", + "_get_neuron_talker_cls", + "_get_talker_config_cls", + "_get_thinker_projection_cls", + "_get_token2wav_cls", + "_get_neuron_token2wav_cls", + } + + missing = expected_methods - methods + assert not missing, f"Missing methods in orchestration: {missing}" + + print("PASS: Orchestration methods") + + +# ============================================================================ +# Test 8: Encode vision to input (thinker state injection) +# ============================================================================ + +def test_encode_vision_to_input(): + """Test NeuronTalkerModel.encode_vision_to_input for thinker state injection.""" + from modeling_qwen25_omni_talker import ( + NeuronTalkerModel, + ) + + batch, seq, hidden = 2, 16, 896 + + # Placeholder embeddings from embed_tokens + inputs_embeds = torch.zeros(batch, seq, hidden) + # Projected thinker states + vision_embeddings = torch.randn(batch, seq, hidden) + # Full mask (all positions are thinker states) + vision_mask = torch.ones(batch, seq, 1, dtype=torch.int32) + + # Call static-like method (doesn't need model instance) + result = NeuronTalkerModel.encode_vision_to_input(None, inputs_embeds, vision_embeddings, vision_mask) + + assert result.shape == (batch, seq, hidden) + assert torch.allclose(result, vision_embeddings), \ + "Full mask should replace all positions with thinker states" + + # Test partial mask + vision_mask_partial = torch.zeros(batch, seq, 1, dtype=torch.int32) + vision_mask_partial[:, :8, :] = 1 # First 8 positions are thinker states + + result_partial = NeuronTalkerModel.encode_vision_to_input( + None, inputs_embeds, vision_embeddings, vision_mask_partial + ) + assert torch.allclose(result_partial[:, :8, :], vision_embeddings[:, :8, :]), \ + "Masked positions should have thinker states" + assert torch.allclose(result_partial[:, 8:, :], inputs_embeds[:, 8:, :]), \ + "Unmasked positions should keep original embeddings" + + print("PASS: encode_vision_to_input") + + +# ============================================================================ +# Test 9: Class imports resolve correctly +# ============================================================================ + +def test_imports(): + """Test that all new classes can be imported.""" + from modeling_qwen25_omni_talker import ( + NeuronQwen25OmniTalker, + TalkerNeuronConfig, + TalkerInferenceConfig, + TalkerRotaryEmbedding, + NeuronTalkerAttention, + NeuronTalkerDecoderLayer, + NeuronTalkerModel, + NeuronQwen25OmniTalkerForCausalLM, + ThinkerToTalkerProjection, + ) + + from modeling_qwen25_omni_token2wav import ( + NeuronQwen25OmniToken2Wav, + NeuronQwen25OmniToken2WavWithNeuronDiT, + ) + + print("PASS: All imports successful") + + +# ============================================================================ +# Main +# ============================================================================ + +if __name__ == "__main__": + tests = [ + test_imports, + test_talker_inference_config, + test_fused_embedding_conversion, + test_thinker_to_talker_projection, + test_talker_rotary_embedding, + test_apply_multimodal_rotary_pos_emb, + test_token2wav_neuron_dit_class, + test_orchestration_methods, + test_encode_vision_to_input, + ] + + passed = 0 + failed = 0 + for test in tests: + try: + test() + passed += 1 + except Exception as e: + print(f"FAIL: {test.__name__}: {e}") + import traceback + traceback.print_exc() + failed += 1 + + print(f"\n{'='*50}") + print(f"Results: {passed} passed, {failed} failed out of {len(tests)} tests") + if failed == 0: + print("All tests passed!") + else: + sys.exit(1)