diff --git a/contrib/models/InternVL3-8B-Instruct/README.md b/contrib/models/InternVL3-8B-Instruct/README.md new file mode 100644 index 00000000..d12f5698 --- /dev/null +++ b/contrib/models/InternVL3-8B-Instruct/README.md @@ -0,0 +1,212 @@ +# Contrib Model: InternVL3-8B-Instruct + +InternVL3-8B-Instruct is a vision-language model (VLM) running on AWS Trainium2 via NxD Inference. It supports both text-only and multimodal (text + image) inference using the NeuronBaseForImageToText framework. + +**Maintainer:** Jim Burtoft ([@jimburtoft](https://github.com/jimburtoft)) + +## Model Information + +- **HuggingFace ID:** [`OpenGVLab/InternVL3-8B-Instruct`](https://huggingface.co/OpenGVLab/InternVL3-8B-Instruct) +- **Model Type:** Vision-language model (decoder-only transformer with vision encoder) +- **Parameters:** ~8B total (InternViT-300M vision encoder + Qwen2.5-7B text backbone) +- **Architecture:** GQA (28 heads / 4 KV heads), RoPE, RMSNorm (text); LayerNorm, GELU, absolute position embeddings (vision); pixel shuffle downsampling + 2-layer MLP projector +- **License:** MIT (Apache-2.0 for Qwen2.5 component) +- **Precision:** BF16 + +### Architecture Overview + +| Component | Details | +|-----------|---------| +| **Vision encoder** | InternViT-300M-448px-V2.5: 24 layers, hidden=1024, 16 heads, patch_size=14, image_size=448 | +| **Projector** | Pixel shuffle (downsample_ratio=0.5, 1024→256 tokens) + LayerNorm + Linear(4096, 3584) + GELU + Linear(3584, 3584) | +| **Text backbone** | Qwen2.5-7B: 28 layers, hidden=3584, intermediate=18944, GQA (28/4), vocab=151674, tie_word_embeddings=False | + +## Validation Results + +**Validated:** 2026-04-28 +**Instance:** trn2.3xlarge (LNC=2, TP=4) +**SDK:** Neuron SDK 2.29 (NxDI 0.9.17334, neuronx-cc 2.24.5133.0, PyTorch 2.9) + +### Benchmark Results + +#### Performance (TP=4, batch_size=1) + +| Sequence Length | TTFT (ms) | TKG Throughput (tok/s) | +|----------------|-----------|------------------------| +| 2048 | 138 | 75.1 | +| 4096 | 230 | 58.9 | +| 8192 | 482 | 40.0 | +| 16384 | 1019 | 23.6 | +| 32768 | 2438 | 11.4 | + +Vision encoder latency: 34.5 ms per 448x448 tile (batch=1). 
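One way to reproduce a TTFT-style number is to time the context-encoding forward pass directly. The sketch below is illustrative only and assumes `model` and `tokenizer` have already been compiled and loaded as shown in the Usage section later in this README; the helper name, warm-up count, and iteration count are arbitrary choices, not part of this repo.

```python
import time
import torch

def measure_ttft_ms(model, tokenizer, prompt, warmup=2, iters=10):
    # Average context-encoding latency (a TTFT proxy) in milliseconds.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    with torch.no_grad():
        for _ in range(warmup):           # warm-up passes are not timed
            model(input_ids=input_ids)
        start = time.perf_counter()
        for _ in range(iters):
            model(input_ids=input_ids)
    return (time.perf_counter() - start) / iters * 1000.0

print(f"TTFT ~ {measure_ttft_ms(model, tokenizer, 'The capital of France is'):.1f} ms")
```

Decode (TKG) throughput requires a full token-generation loop and is not shown here.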
+ +#### GPU Comparison (1x NVIDIA L40S, BF16, SDPA) + +| Metric | GPU (L40S) | Neuron (trn2.3xlarge TP=4) | Speedup | +|--------|------------|---------------------------|---------| +| TTFT (2048 input tokens) | 153.5 ms | 138 ms | 1.11x | +| Output tok/s (BS=1) | 40.5 | 75.1 | **1.85x** | + +### Accuracy Validation + +| Test | Status | Metrics | +|------|--------|---------| +| CTE logit comparison (vs CPU FP32) | PASS | cosine=0.9984, top-1 match, top-5 5/5, top-10 8/10 | +| TKG text generation | PASS | Correct, coherent output ("The capital of France is Paris.") | +| Multimodal generation | PASS | Vision encoder + text pipeline end-to-end working | + +## Usage + +### Prerequisites + +```bash +# Activate NxDI environment +source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + +# Download model +huggingface-cli download OpenGVLab/InternVL3-8B-Instruct --local-dir /mnt/models/InternVL3-8B-Instruct/ +``` + +### Compile and Run + +```python +import sys +import torch +from pathlib import Path +from transformers import AutoTokenizer + +# Add contrib src to path +sys.path.insert(0, str(Path("contrib/models/InternVL3-8B-Instruct/src"))) + +from modeling_internvl3 import NeuronInternVL3ForCausalLM, InternVL3InferenceConfig +from neuronx_distributed_inference.models.config import NeuronConfig + +MODEL_PATH = "/mnt/models/InternVL3-8B-Instruct/" +COMPILED_PATH = "/mnt/models/neuron_models/InternVL3-8B-Instruct/" + +# Configure +text_neuron_config = NeuronConfig( + tp_degree=4, + max_batch_size=1, + seq_len=2048, + torch_dtype=torch.bfloat16, + on_device_sampling_config=None, + save_sharded_checkpoint=True, +) +vision_neuron_config = NeuronConfig( + tp_degree=4, + max_batch_size=1, + seq_len=256, + torch_dtype=torch.bfloat16, + on_device_sampling_config=None, + buckets=[1], + fused_qkv=True, + save_sharded_checkpoint=True, +) + +config = InternVL3InferenceConfig.from_pretrained( + MODEL_PATH, + text_neuron_config=text_neuron_config, + vision_neuron_config=vision_neuron_config, +) + +# Compile (first time only) +model = NeuronInternVL3ForCausalLM(MODEL_PATH, config=config) +model.compile(COMPILED_PATH) + +# Load and generate +model.load(COMPILED_PATH) + +tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) +prompt = "The capital of France is" +input_ids = tokenizer(prompt, return_tensors="pt").input_ids + +outputs = model(input_ids=input_ids) +next_token = outputs.logits[0, -1].argmax().item() +print(f"{prompt} {tokenizer.decode([next_token])}") +# Output: The capital of France is Paris +``` + +### Multimodal Inference + +```python +# Build input with vision tokens +IMG_CONTEXT_ID = 151667 # +IMG_START_ID = 151665 # +IMG_END_ID = 151666 # + +text_ids = tokenizer("Describe this image:", return_tensors="pt").input_ids[0] +img_tokens = torch.full((256,), IMG_CONTEXT_ID, dtype=torch.long) + +input_ids = torch.cat([ + text_ids, + torch.tensor([IMG_START_ID]), + img_tokens, + torch.tensor([IMG_END_ID]), +]).unsqueeze(0) + +# Pixel values for a single 448x448 tile +pixel_values = preprocess_image(image) # [1, 3, 448, 448] + +outputs = model(input_ids=input_ids, pixel_values=pixel_values) +``` + +## Compatibility Matrix + +| Instance | SDK 2.29 | SDK 2.28 | +|----------|----------|----------| +| trn2.3xlarge (LNC=2, TP=4) | **VALIDATED** | Not tested | + +## Example Checkpoints + +* [OpenGVLab/InternVL3-8B-Instruct](https://huggingface.co/OpenGVLab/InternVL3-8B-Instruct) + +## Testing Instructions + +```bash +# Ensure model is compiled first (see Usage above), then: +cd 
contrib/models/InternVL3-8B-Instruct/ +pytest test/integration/test_model.py -v --tb=short +``` + +## Implementation Notes + +### Three-File VLM Architecture + +The model uses the NxDI `NeuronBaseForImageToText` framework with three files: + +- `src/modeling_internvl3.py` — Top-level VLM orchestrating vision + text +- `src/modeling_internvl3_text.py` — Text model (Qwen2.5-7B) with vision embedding injection via `scatter_by_index_put()` +- `src/modeling_internvl3_vision.py` — Vision encoder (InternViT-300M) compiled via `torch_neuronx.trace()` with pixel shuffle and MLP projector + +### Weight Mapping + +| HuggingFace Key | NxDI Key | +|-----------------|----------| +| `language_model.model.layers.{i}.*` | `layers.{i}.*` | +| `language_model.model.embed_tokens.weight` | `embed_tokens.weight` | +| `language_model.model.norm.weight` | `norm.weight` | +| `language_model.lm_head.weight` | `lm_head.weight` | +| `vision_model.*` | Vision encoder (separate NEFF) | +| `mlp1.*` | Projector (part of vision NEFF) | + +### Special Tokens + +| Token | ID | Purpose | +|-------|-----|---------| +| `` | 151667 | Visual token placeholder in text sequence | +| `` | 151665 | Image region start marker | +| `` | 151666 | Image region end marker | +| `<|im_end|>` | 151645 | EOS token | + +## Known Issues + +- **V2PE not implemented**: Variable Visual Position Encoding (described in InternVL2.5/3 papers) is not implemented in the HuggingFace model code and is not included here. Standard position IDs are used. Accuracy validation passes without V2PE. +- **Batch size > 1**: Single-request batch inference (batch_size > 1) has a known issue with sampling_params shape. Use vLLM for multi-request concurrent serving. +- **trust_remote_code**: The HuggingFace tokenizer requires `trust_remote_code=True`. The NxDI model code reads config.json directly and does not require it. +- **NKI kernels**: Not applicable for this model. Qwen2.5-7B's `intermediate_size=18944` is incompatible with NxDI NKI `mlp_kernel` and `attn_block_tkg` kernels at tested TP degrees. + +## vLLM Integration + +This model can be served through vLLM-neuron with patches to the vllm-neuron worker. See the [NxD Inference vLLM User Guide](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/nxd-inference/developer_guides/vllm-user-guide.html) for general vLLM setup. InternVL3 requires modifications to vllm-neuron's model loader and runner to register the custom architecture. Contact the maintainer for patch details. diff --git a/contrib/models/InternVL3-8B-Instruct/compile_internvl3_vlm.py b/contrib/models/InternVL3-8B-Instruct/compile_internvl3_vlm.py new file mode 100644 index 00000000..4e26b846 --- /dev/null +++ b/contrib/models/InternVL3-8B-Instruct/compile_internvl3_vlm.py @@ -0,0 +1,193 @@ +#!/usr/bin/env python3 +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Compile InternVL3-8B-Instruct full VLM on Neuron. + +This compiles three NEFFs: +1. Vision encoder: InternViT-300M + pixel shuffle + MLP projector +2. Text CTE: Qwen2.5-7B context encoding (with vision embedding injection) +3. 
Text TKG: Qwen2.5-7B token generation + +Usage: + python compile_internvl3_vlm.py [--text-only] [--vision-only] + +Target: trn2.3xlarge LNC=2 TP=4 +""" + +import argparse +import sys +import time +from pathlib import Path + +import torch + +# Add contrib src to path +sys.path.insert(0, str(Path(__file__).parent / "src")) + +from modeling_internvl3 import InternVL3InferenceConfig, NeuronInternVL3ForCausalLM +from neuronx_distributed_inference.models.config import NeuronConfig + +MODEL_PATH = "/mnt/models/InternVL3-8B-Instruct/" +COMPILED_PATH = "/mnt/models/neuron_models/InternVL3-8B-Instruct/" + + +def create_config(): + """Create InternVL3 VLM inference config with text + vision NeuronConfigs.""" + + # Text NeuronConfig: TP=4 on trn2.3xlarge LNC=2 + text_neuron_config = NeuronConfig( + tp_degree=4, + max_batch_size=1, + seq_len=2048, + torch_dtype=torch.bfloat16, + on_device_sampling_config=None, + save_sharded_checkpoint=True, + ) + + # Vision NeuronConfig: Must match text TP degree (NxDI requirement) + # Vision encoder is small; weights are replicated across TP ranks. + # Bucket = [1] (one image at a time) + vision_neuron_config = NeuronConfig( + tp_degree=4, + max_batch_size=1, + seq_len=256, # 256 vision tokens after pixel shuffle + torch_dtype=torch.bfloat16, + on_device_sampling_config=None, + buckets=[1], # number of images + fused_qkv=True, # vision encoder has fused QKV weights + save_sharded_checkpoint=True, + ) + + config = InternVL3InferenceConfig.from_pretrained( + MODEL_PATH, + text_neuron_config=text_neuron_config, + vision_neuron_config=vision_neuron_config, + ) + + return config + + +def compile_and_test(): + """Compile full VLM and run smoke tests.""" + print("=" * 60) + print("InternVL3-8B-Instruct: Full VLM Compilation") + print("=" * 60) + + config = create_config() + + print(f"\nModel path: {MODEL_PATH}") + print(f"Compiled path: {COMPILED_PATH}") + print(f"\n--- Text Config ---") + print(f" TP degree: {config.text_config.neuron_config.tp_degree}") + print(f" Seq len: {config.text_config.neuron_config.seq_len}") + print(f" Batch size: {config.text_config.neuron_config.max_batch_size}") + print(f" hidden_size: {config.text_config.hidden_size}") + print(f" num_hidden_layers: {config.text_config.num_hidden_layers}") + print(f" vocab_size: {config.text_config.vocab_size}") + print(f"\n--- Vision Config ---") + print(f" TP degree: {config.vision_config.neuron_config.tp_degree}") + print(f" Buckets: {config.vision_config.neuron_config.buckets}") + + # Create model + print("\n--- Creating model ---") + model = NeuronInternVL3ForCausalLM(MODEL_PATH, config=config) + + # Compile + print("\n--- Compiling (text + vision) ---") + start = time.time() + model.compile(COMPILED_PATH) + elapsed = time.time() - start + print(f"\nCompilation completed in {elapsed:.1f}s ({elapsed / 60:.1f} min)") + + # Load + print("\n--- Loading compiled model ---") + start = time.time() + model.load(COMPILED_PATH) + elapsed = time.time() - start + print(f"Load completed in {elapsed:.1f}s") + + # Smoke test 1: Text-only + print("\n--- Smoke test: text-only ---") + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) + prompt = "The capital of France is" + inputs = tokenizer(prompt, return_tensors="pt") + input_ids = inputs.input_ids + seq_len = input_ids.shape[-1] + position_ids = torch.arange(seq_len, dtype=torch.int32).unsqueeze(0) + seq_ids = torch.zeros(1, dtype=torch.int32) + + print(f"Prompt: {prompt}") + with torch.no_grad(): + 
outputs = model( + input_ids=input_ids, + position_ids=position_ids, + seq_ids=seq_ids, + ) + logits = outputs.logits + top5 = torch.topk(logits[0, -1].float(), 5) + print("Top-5 next tokens:") + for t, v in zip(top5.indices, top5.values): + print( + f" {tokenizer.decode([t.item()])!r} (id={t.item()}, logit={v.item():.4f})" + ) + + # Smoke test 2: Multimodal (synthetic image) + print("\n--- Smoke test: multimodal ---") + # Build input with 256 tokens + IMG_CONTEXT_ID = 151667 + IMG_START_ID = 151665 + IMG_END_ID = 151666 + + text_before = "Describe this image:" + text_before_ids = tokenizer(text_before, return_tensors="pt").input_ids[0] + + # Construct: *256 + img_tokens = torch.full((256,), IMG_CONTEXT_ID, dtype=torch.long) + full_ids = torch.cat( + [ + text_before_ids, + torch.tensor([IMG_START_ID]), + img_tokens, + torch.tensor([IMG_END_ID]), + ] + ).unsqueeze(0) + + full_seq_len = full_ids.shape[-1] + full_position_ids = torch.arange(full_seq_len, dtype=torch.int32).unsqueeze(0) + full_seq_ids = torch.zeros(1, dtype=torch.int32) + + # Synthetic pixel values (random noise as placeholder) + pixel_values = torch.randn(1, 3, 448, 448) + + print(f"Input IDs shape: {full_ids.shape}") + print(f"Pixel values shape: {pixel_values.shape}") + print( + f"Number of tokens: {(full_ids == IMG_CONTEXT_ID).sum().item()}" + ) + + with torch.no_grad(): + outputs = model( + input_ids=full_ids, + position_ids=full_position_ids, + seq_ids=full_seq_ids, + pixel_values=pixel_values, + ) + logits = outputs.logits + top5 = torch.topk(logits[0, -1].float(), 5) + print("Top-5 next tokens (multimodal):") + for t, v in zip(top5.indices, top5.values): + print( + f" {tokenizer.decode([t.item()])!r} (id={t.item()}, logit={v.item():.4f})" + ) + + print("\n" + "=" * 60) + print("SUCCESS: InternVL3 full VLM compiled and running on Neuron") + print("=" * 60) + + +if __name__ == "__main__": + compile_and_test() diff --git a/contrib/models/InternVL3-8B-Instruct/src/__init__.py b/contrib/models/InternVL3-8B-Instruct/src/__init__.py new file mode 100644 index 00000000..88cea97a --- /dev/null +++ b/contrib/models/InternVL3-8B-Instruct/src/__init__.py @@ -0,0 +1,28 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +from modeling_internvl3 import ( + InternVL3InferenceConfig, + NeuronInternVL3ForCausalLM, +) +from modeling_internvl3_text import ( + InternVL3TextModelWrapper, + NeuronInternVL3TextForCausalLM, + NeuronInternVL3TextModel, +) +from modeling_internvl3_vision import ( + InternVL3VisionModelWrapper, + NeuronInternVL3VisionModel, + convert_vision_hf_to_neuron_state_dict, +) + +__all__ = [ + "InternVL3InferenceConfig", + "NeuronInternVL3ForCausalLM", + "InternVL3TextModelWrapper", + "NeuronInternVL3TextForCausalLM", + "NeuronInternVL3TextModel", + "InternVL3VisionModelWrapper", + "NeuronInternVL3VisionModel", + "convert_vision_hf_to_neuron_state_dict", +] diff --git a/contrib/models/InternVL3-8B-Instruct/src/modeling_internvl3.py b/contrib/models/InternVL3-8B-Instruct/src/modeling_internvl3.py new file mode 100644 index 00000000..763bcd33 --- /dev/null +++ b/contrib/models/InternVL3-8B-Instruct/src/modeling_internvl3.py @@ -0,0 +1,585 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +NxDI contrib model for InternVL3-8B-Instruct (OpenGVLab). + +Architecture: InternViT-300M vision encoder + pixel shuffle MLP projector + Qwen2.5-7B text backbone. 
+ +This is the top-level VLM class that inherits from NeuronBaseForImageToText, +orchestrating both the vision encoder (separate NEFF) and text decoder (CTE + TKG NEFFs). + +Vision pipeline: + pixel_values [B,3,448,448] -> InternViT -> strip CLS -> pixel_shuffle -> MLP projector + -> [1, seq_len, 3584] padded vision embeddings + +Text pipeline: + input_ids -> embed_tokens -> scatter vision embeddings at positions + -> 28 Qwen2.5 decoder layers -> lm_head -> logits + +Special tokens: + = 151667 (image placeholder token in input_ids) + = 151665 (image start) + = 151666 (image end) +""" + +import copy +import json +import logging +import os +from typing import Callable, Dict, List, Optional, Tuple, Type, Union + +import torch +from transformers.modeling_outputs import CausalLMOutputWithPast + +from neuronx_distributed_inference.models.config import InferenceConfig, NeuronConfig +from neuronx_distributed_inference.models.image_to_text_model_base import ( + ImageToTextInferenceConfig, + NeuronBaseForImageToText, +) +from neuronx_distributed_inference.models.llama4.utils.encoder_utils import ( + generate_positions_from_mask, + pad_positions, +) +from neuronx_distributed_inference.models.model_wrapper import VISION_ENCODER_MODEL_TAG + +from modeling_internvl3_text import ( + InternVL3TextModelWrapper, + NeuronInternVL3TextForCausalLM, + NeuronInternVL3TextModel, +) +from modeling_internvl3_vision import ( + InternVL3VisionModelWrapper, + NeuronInternVL3VisionModel, + convert_vision_hf_to_neuron_state_dict, +) + +logger = logging.getLogger("Neuron") + +# InternVL3 special token ID for image context placeholder +IMG_CONTEXT_TOKEN_ID = 151667 + +# Keys from top-level config that must be copied to text_config +INTERNVL3_TEXT_CONFIG_KEYS = [ + "hidden_size", + "num_attention_heads", + "num_hidden_layers", + "num_key_value_heads", + "pad_token_id", + "vocab_size", + "intermediate_size", + "max_position_embeddings", + "rms_norm_eps", + "rope_theta", + "hidden_act", + "bos_token_id", + "eos_token_id", + "tie_word_embeddings", +] + + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- + + +class InternVL3InferenceConfig(ImageToTextInferenceConfig): + """ + Inference configuration for InternVL3 on Neuron. + + Requires two NeuronConfig objects: + - text_neuron_config: for the Qwen2.5-7B text decoder (CTE + TKG) + - vision_neuron_config: for the InternViT-300M vision encoder + + The HF config.json has text params under "llm_config" and vision params + under "vision_config". This class handles the mapping. + """ + + def __init__( + self, + text_neuron_config, + vision_neuron_config, + fused_spec_config=None, + load_config=None, + metadata: Optional[Dict] = None, + **kwargs, + ): + # Wrap load_config to map InternVL's llm_config -> text_config + # The NxDI base class calls load_config(self) which sets attributes from + # config.to_dict(). InternVL HF config has "llm_config" not "text_config", + # so we intercept to create text_config from llm_config before validation. 
+ if load_config is not None: + original_load_config = load_config + + def _patched_load_config(config_obj): + original_load_config(config_obj) + # Map llm_config -> text_config if needed + if ( + not hasattr(config_obj, "text_config") + or config_obj.text_config is None + ): + llm_cfg = getattr(config_obj, "llm_config", None) + if llm_cfg is not None: + # llm_config can be a dict, PretrainedConfig, or SimpleNamespace + if isinstance(llm_cfg, dict): + text_config_dict = dict(llm_cfg) + elif hasattr(llm_cfg, "to_dict"): + text_config_dict = llm_cfg.to_dict() + elif hasattr(llm_cfg, "__dict__"): + text_config_dict = dict(vars(llm_cfg)) + else: + text_config_dict = { + k: getattr(llm_cfg, k) + for k in dir(llm_cfg) + if not k.startswith("_") + } + # Set HF defaults needed by NxDI + text_config_dict.setdefault("output_attentions", False) + text_config_dict.setdefault("output_hidden_states", False) + config_obj.text_config = text_config_dict + # Also propagate text config keys to top level + for key in INTERNVL3_TEXT_CONFIG_KEYS: + if key in text_config_dict and not hasattr(config_obj, key): + setattr(config_obj, key, text_config_dict[key]) + + load_config = _patched_load_config + + super().__init__( + text_neuron_config=text_neuron_config, + vision_neuron_config=vision_neuron_config, + fused_spec_config=fused_spec_config, + load_config=load_config, + metadata=metadata, + **kwargs, + ) + self.add_special_config() + self.validate_model_supported_configs() + + def add_special_config(self): + """Set InternVL3-specific config defaults.""" + self.num_cores_per_group = 1 + + # Qwen2.5 text backbone: QKV bias, no O bias + self.qkv_bias = True + self.o_bias = False + + # Image token ID for vision mask generation + self.image_token_id = IMG_CONTEXT_TOKEN_ID + + # Copy text keys from top-level config to text_config + for key in INTERNVL3_TEXT_CONFIG_KEYS: + if hasattr(self, key): + setattr(self.text_config, key, getattr(self, key)) + + self.pad_token_id = getattr(self.text_config, "pad_token_id", 0) + + def validate_model_supported_configs(self): + """Validate and disable unsupported NeuronConfig options.""" + # Validate text config keys match + for key in INTERNVL3_TEXT_CONFIG_KEYS: + if hasattr(self, key) and hasattr(self.text_config, key): + top_val = getattr(self, key) + text_val = getattr(self.text_config, key) + if top_val != text_val: + logger.warning( + f"Config mismatch: {key} top={top_val} vs text={text_val}, using top" + ) + setattr(self.text_config, key, top_val) + + # Disable unsupported text model features + TEXT_UNSUPPORTED = [ + "is_block_kv_layout", + "is_prefix_caching", + "is_chunked_prefill", + "is_medusa", + "enable_fused_speculation", + ] + for cfg_name in TEXT_UNSUPPORTED: + if getattr(self.text_config.neuron_config, cfg_name, False) is not False: + setattr(self.text_config.neuron_config, cfg_name, False) + logger.warning( + f"InternVL3 text model: '{cfg_name}' unsupported, disabled." + ) + + # Disable unsupported vision model features + VISION_UNSUPPORTED = [ + "sequence_parallel_enabled", + "flash_decoding_enabled", + "qkv_kernel_enabled", + "attn_block_tkg_nki_kernel_cache_update", + "attn_block_tkg_nki_kernel_enabled", + ] + for cfg_name in VISION_UNSUPPORTED: + if getattr(self.vision_config.neuron_config, cfg_name, False) is not False: + setattr(self.vision_config.neuron_config, cfg_name, False) + logger.warning( + f"InternVL3 vision model: '{cfg_name}' unsupported, disabled." 
+ ) + + def get_required_attributes(self) -> List[str]: + return [ + "text_config", + "vision_config", + "text_config.hidden_size", + "text_config.num_attention_heads", + "text_config.num_hidden_layers", + "text_config.num_key_value_heads", + "text_config.pad_token_id", + "text_config.vocab_size", + "text_config.max_position_embeddings", + "text_config.rope_theta", + "text_config.rms_norm_eps", + "text_config.hidden_act", + "vision_config.hidden_size", + "vision_config.num_attention_heads", + "vision_config.num_hidden_layers", + "vision_config.image_size", + "vision_config.patch_size", + ] + + @classmethod + def get_neuron_config_cls(cls) -> Type[NeuronConfig]: + return NeuronConfig + + @classmethod + def from_pretrained( + cls, + model_path: str, + text_neuron_config=None, + vision_neuron_config=None, + **kwargs, + ) -> "InternVL3InferenceConfig": + """ + Load configuration from a pretrained InternVL3 model directory. + + InternVL3 config.json structure: + - Top-level: model_type, downsample_ratio, force_image_size, etc. + - llm_config: Qwen2.5-7B text params (model_type=qwen2) + - vision_config: InternViT params (model_type=intern_vit_6b) + + The ImageToTextInferenceConfig parent expects "text_config" and + "vision_config" keys in kwargs. We map llm_config -> text_config. + """ + config_path = os.path.join(model_path, "config.json") + if not os.path.exists(config_path): + raise FileNotFoundError(f"config.json not found at {config_path}") + + with open(config_path, "r") as f: + config_dict = json.load(f) + + # Extract text config (InternVL uses "llm_config", not "text_config") + llm_config = config_dict.get("llm_config", {}) + vision_config = config_dict.get("vision_config", {}) + + # Build the config dict that ImageToTextInferenceConfig expects + # Must have "text_config" and "vision_config" at top level + inference_kwargs = {} + + # Copy top-level InternVL params + for key in [ + "downsample_ratio", + "force_image_size", + "select_layer", + "model_type", + "architectures", + "tie_word_embeddings", + ]: + if key in config_dict: + inference_kwargs[key] = config_dict[key] + + # Copy text params from llm_config -> text_config + text_config_dict = {} + for key in [ + "hidden_size", + "num_attention_heads", + "num_hidden_layers", + "num_key_value_heads", + "vocab_size", + "max_position_embeddings", + "rope_theta", + "rms_norm_eps", + "hidden_act", + "intermediate_size", + "pad_token_id", + "bos_token_id", + "eos_token_id", + "tie_word_embeddings", + ]: + if key in llm_config: + text_config_dict[key] = llm_config[key] + + # HF PretrainedConfig defaults required by NxDI model_base._setup_func_config() + # These are normally set by HF's from_pretrained() but we build config manually + text_config_dict.setdefault("output_attentions", False) + text_config_dict.setdefault("output_hidden_states", False) + # Note: do NOT set use_return_dict here — it's a read-only computed attribute + # in PretrainedConfig and will raise AttributeError when HuggingFaceGenerationAdapter + # converts our config via to_pretrained_config() → PretrainedConfig(**text_config_dict) + + # Also set at top level (required by ImageToTextInferenceConfig) + for key, value in text_config_dict.items(): + inference_kwargs[key] = value + + inference_kwargs["text_config"] = text_config_dict + + # Copy vision config as-is + inference_kwargs["vision_config"] = vision_config + + # Set image_token_id + inference_kwargs["image_token_id"] = IMG_CONTEXT_TOKEN_ID + + # Set _name_or_path + inference_kwargs["_name_or_path"] = model_path + + # 
Merge user kwargs + inference_kwargs.update(kwargs) + + return cls( + text_neuron_config=text_neuron_config, + vision_neuron_config=vision_neuron_config, + **inference_kwargs, + ) + + +# --------------------------------------------------------------------------- +# Top-level VLM class +# --------------------------------------------------------------------------- + + +class NeuronInternVL3ForCausalLM(NeuronBaseForImageToText): + """ + InternVL3 vision-language model for Neuron inference. + + Orchestrates: + - Vision encoder NEFF (InternViT-300M + pixel shuffle + projector) + - Text decoder CTE NEFF (Qwen2.5-7B context encoding with vision embedding injection) + - Text decoder TKG NEFF (Qwen2.5-7B token generation) + + Usage: + text_nc = NeuronConfig(tp_degree=4, max_batch_size=1, seq_len=4096, ...) + vision_nc = NeuronConfig(tp_degree=1, max_batch_size=1, buckets=[1], ...) + config = InternVL3InferenceConfig.from_pretrained( + model_path, text_neuron_config=text_nc, vision_neuron_config=vision_nc + ) + model = NeuronInternVL3ForCausalLM(config) + model.compile(compiled_path) + model.load(compiled_path) + output = model(input_ids, attention_mask, position_ids, seq_ids, + sampling_params, pixel_values=pixel_values) + """ + + text_model_cls = NeuronInternVL3TextModel + vision_model_cls = NeuronInternVL3VisionModel + text_model_wrapper = InternVL3TextModelWrapper + vision_model_wrapper = InternVL3VisionModelWrapper + + def __init__(self, *args, **kwargs): + super().__init__( + self.text_model_cls, + self.vision_model_cls, + self.text_model_wrapper, + self.vision_model_wrapper, + *args, + **kwargs, + ) + + def get_vision_compiler_args(self) -> str: + """Compiler args for vision encoder NEFF.""" + return "--auto-cast=matmult --model-type=transformer -O1" + + def get_compiler_args(self) -> str: + """Compiler args for text model NEFFs (CTE + TKG).""" + return "--auto-cast=matmult --model-type=transformer -O1" + + def get_required_kwargs(self) -> List[str]: + """Additional input args for HuggingFaceGenerationAdapter.""" + return ["pixel_values", "vision_mask"] + + def enable_vision_encoder( + self, enable_wlt_optimization: bool = True, **model_init_kwargs + ): + """Create the vision encoder model wrapper.""" + new_config = copy.deepcopy(self.config) + self.vision_encoder_model = self.vision_model_wrapper( + config=new_config, + model_cls=self.vision_model_cls, + tag=VISION_ENCODER_MODEL_TAG, + compiler_args=self.get_vision_compiler_args(), + model_init_kwargs=model_init_kwargs, + priority_model_idx=(0 if enable_wlt_optimization else None), + pipeline_execution=True, + return_ranked_to_cpu=False, + ) + self.vision_models.append(self.vision_encoder_model) + + def get_padding_length(self, input_ids): + """Get the CTE bucket size for the given input length.""" + buckets = self.context_encoding_model.config.neuron_config.buckets + for val in buckets: + if val >= input_ids.shape[1]: + return val + raise RuntimeError( + f"No bucket found for input_ids length {input_ids.shape[1]}. 
" + f"Available buckets: {buckets}" + ) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + seq_ids: Optional[torch.LongTensor] = None, + sampling_params: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + vision_mask: Optional[torch.FloatTensor] = None, + adapter_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + medusa_args=None, + input_capture_hook: Optional[Callable] = None, + tensor_capture_hook: Optional[Callable] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + """ + Forward pass orchestrating vision encoder and text decoder. + + For context encoding with images: + 1. Identify token positions in input_ids + 2. Run vision encoder NEFF -> padded vision embeddings + 3. Pass vision_embeddings + vision_mask to text decoder CTE + 4. Inside CTE NEFF: embed_tokens -> scatter vision -> decoder layers -> logits + + For token generation or text-only: + - Pass dummy (zero) vision tensors to text decoder TKG + """ + # Work around NxDI issue: NeuronBaseForImageToText.forward() doesn't + # capture preprocess_inputs() return values, so sampling_params=None + # flows through to the compiled NEFF which expects [batch, 3]. + # Provide default sampling_params here if not supplied by caller. + if sampling_params is None: + sampling_params = self.default_sampling_params + + pad_limit = self.get_padding_length(input_ids) + + if ( + pixel_values is not None + and input_ids.shape[-1] > 1 + and pixel_values.sum() != 0 + ): + # Context encoding with images + # Build vision_mask: find positions in input_ids + vision_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) + vision_mask = vision_mask.to(torch.bool) + vision_mask = generate_positions_from_mask(vision_mask.squeeze()) + vision_mask = pad_positions(vision_mask, pad_limit, (pad_limit - 1)) + + # Run vision encoder NEFF + # The NEFF is compiled for batch=1 (single image). For multi-image + # requests, loop over images and concatenate vision embeddings. + num_images = pixel_values.shape[0] + pv_dtype = self.vision_config.neuron_config.torch_dtype + if num_images == 1: + vision_embeddings = self.vision_encoder_model(pixel_values.to(pv_dtype)) + else: + # Each NEFF call returns [1, text_seq_len, hidden] with the first + # tokens_per_image positions containing real embeddings (rest are + # zero-padded by pad_to_text_seq_len inside the NEFF). + # Extract real tokens from each call, concatenate, re-pad. 
+ tokens_per_image = 256 # (448/14)^2 * downsample_ratio^2 + emb_list = [] + for i in range(num_images): + emb_i = self.vision_encoder_model( + pixel_values[i : i + 1].to(pv_dtype) + ) + # Extract real tokens only: [1, tokens_per_image, hidden] + emb_list.append(emb_i[:, :tokens_per_image, :]) + # Concatenate: [1, num_images * tokens_per_image, hidden] + vision_embeddings = torch.cat(emb_list, dim=1) + # Pad to CTE bucket size (pad_limit) + total_vis_tokens = vision_embeddings.shape[1] + if total_vis_tokens < pad_limit: + pad_zeros = torch.zeros( + 1, + pad_limit - total_vis_tokens, + vision_embeddings.shape[2], + dtype=vision_embeddings.dtype, + device=vision_embeddings.device, + ) + vision_embeddings = torch.cat([vision_embeddings, pad_zeros], dim=1) + elif total_vis_tokens > pad_limit: + vision_embeddings = vision_embeddings[:, :pad_limit, :] + else: + # Token generation or text-only: use dummy zeros + vision_embeddings, vision_mask = ( + self.text_model_wrapper.get_dummy_vision_inputs( + config=self.text_config, + input_ids=input_ids, + n_active_tokens=pad_limit, + fill_value=(pad_limit - 1), + ) + ) + + output_token = super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + seq_ids=seq_ids, + sampling_params=sampling_params, + input_capture_hook=input_capture_hook, + tensor_capture_hook=tensor_capture_hook, + vision_embeddings=vision_embeddings, + vision_mask=vision_mask, + ) + return output_token + + @classmethod + def get_config_cls(cls): + return InternVL3InferenceConfig + + @staticmethod + def load_hf_model(model_path, **kwargs): + """Load original HF InternVL3 model for CPU reference inference. + + Uses FP32 to avoid -inf logits that occur with BF16, + which cause false failures in logit_validation. + """ + from transformers import AutoModel + + model = AutoModel.from_pretrained( + model_path, trust_remote_code=True, torch_dtype=torch.float32 + ).eval() + # HF generate() requires img_context_token_id to be set + model.img_context_token_id = IMG_CONTEXT_TOKEN_ID + return model + + @staticmethod + def convert_hf_to_neuron_state_dict( + state_dict: dict, inference_config: InternVL3InferenceConfig + ) -> dict: + """ + Convert full InternVL3 HF state dict to Neuron format. + + Delegates to: + - convert_vision_hf_to_neuron_state_dict() for vision + projector weights + - NeuronInternVL3TextForCausalLM.convert_hf_to_neuron_state_dict() for text weights + """ + # Vision weights (encoder + projector) + vision_state_dict = convert_vision_hf_to_neuron_state_dict(state_dict) + + # Text weights + text_state_dict = ( + NeuronInternVL3TextForCausalLM.convert_hf_to_neuron_state_dict( + state_dict, inference_config.text_config + ) + ) + + # Merge (vision and text keys should not overlap) + merged = {} + merged.update(vision_state_dict) + merged.update(text_state_dict) + return merged + + @staticmethod + def update_state_dict_for_tied_weights(state_dict): + """Handle tied embeddings.""" + NeuronInternVL3TextForCausalLM.update_state_dict_for_tied_weights(state_dict) diff --git a/contrib/models/InternVL3-8B-Instruct/src/modeling_internvl3_text.py b/contrib/models/InternVL3-8B-Instruct/src/modeling_internvl3_text.py new file mode 100644 index 00000000..be8f2094 --- /dev/null +++ b/contrib/models/InternVL3-8B-Instruct/src/modeling_internvl3_text.py @@ -0,0 +1,445 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +InternVL3-8B-Instruct: Text backbone (Qwen2.5-7B) for NxDI. 
+ +This module contains the text decoder model, attention, decoder layers, +text model wrapper (ImageToTextModelWrapper), and weight conversion. + +Text backbone: Qwen2.5-7B +- 28 layers, hidden_size=3584, 28 Q heads, 4 KV heads (GQA 7:1) +- Standard RoPE (rope_theta=1e6), RMSNorm, SiLU gated MLP +- QKV bias=True, O bias=False +- vocab_size=151674, tie_word_embeddings=False + +Weight key mapping (HF -> NxDI): + language_model.model.layers.{i}.* -> layers.{i}.* + language_model.model.embed_tokens.weight -> embed_tokens.weight + language_model.model.norm.weight -> norm.weight + language_model.lm_head.weight -> lm_head.weight +""" + +import torch +from torch import nn + +from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm + +from neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + ParallelEmbedding, +) +from neuronx_distributed.utils import cpu_mode + +from neuronx_distributed_inference.models.config import InferenceConfig, NeuronConfig +from neuronx_distributed_inference.models.image_to_text_model_wrapper import ( + ImageToTextModelWrapper, +) +from neuronx_distributed_inference.models.llama.modeling_llama import NeuronLlamaMLP +from neuronx_distributed_inference.models.llama4.utils.encoder_utils import ( + scatter_by_index_put, +) +from neuronx_distributed_inference.models.model_base import ( + NeuronBaseForCausalLM, + NeuronBaseModel, +) +from neuronx_distributed_inference.modules.attention.attention_base import ( + NeuronAttentionBase, +) +from neuronx_distributed_inference.modules.attention.utils import RotaryEmbedding +from neuronx_distributed_inference.modules.custom_calls import CustomRMSNorm + + +def get_rmsnorm_cls(): + """Get RMSNorm implementation: HF for CPU, CustomRMSNorm for Neuron.""" + return Qwen2RMSNorm if cpu_mode() else CustomRMSNorm + + +# --------------------------------------------------------------------------- +# Attention +# --------------------------------------------------------------------------- + + +class NeuronInternVL3Attention(NeuronAttentionBase): + """ + InternVL3 text attention: GQA with standard RoPE and QKV bias. + + - 28 Q heads, 4 KV heads (7:1 GQA ratio) + - head_dim = 128 + - Q/K/V have bias, O does not + - Standard RoPE (not M-RoPE) + - No Q-K normalization (unlike Qwen3) + """ + + def __init__(self, config): + head_dim = getattr( + config, + "head_dim", + config.hidden_size // config.num_attention_heads, + ) + rotary_emb = RotaryEmbedding( + dim=head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + ) + super().__init__( + config=config, + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + num_key_value_heads=config.num_key_value_heads, + head_dim=head_dim, + rotary_emb=rotary_emb, + qkv_bias=True, + o_bias=False, + rms_norm_eps=config.rms_norm_eps, + ) + + +# --------------------------------------------------------------------------- +# Decoder layer +# --------------------------------------------------------------------------- + + +class NeuronInternVL3DecoderLayer(nn.Module): + """ + InternVL3 text decoder layer: pre-norm RMSNorm + GQA attention + SwiGLU MLP. + + Supports NKI kernel fused RMSNorm when qkv_kernel_enabled or mlp_kernel_enabled. 
+ """ + + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = NeuronInternVL3Attention(config) + self.mlp = NeuronLlamaMLP(config) + self.input_layernorm = get_rmsnorm_cls()( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = get_rmsnorm_cls()( + config.hidden_size, eps=config.rms_norm_eps + ) + + # NKI kernel flags — fuse RMSNorm into kernels when enabled + neuron_config = config.neuron_config + self.qkv_kernel_enabled = neuron_config.qkv_kernel_enabled + self.mlp_kernel_enabled = neuron_config.mlp_kernel_enabled + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask=None, + position_ids=None, + past_key_value=None, + **kwargs, + ): + residual = hidden_states + + # When QKV kernel is enabled, pass the RMSNorm module to be fused + # into the NKI kernel instead of applying it separately + if self.qkv_kernel_enabled: + qkv_fused_rmsnorm = self.input_layernorm + else: + hidden_states = self.input_layernorm(hidden_states) + qkv_fused_rmsnorm = None + + hidden_states, present_key_value, cos_cache, sin_cache = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + rmsnorm=qkv_fused_rmsnorm, + **kwargs, + ) + + hidden_states = residual + hidden_states + + residual = hidden_states + + # When MLP kernel is enabled, pass the RMSNorm module to be fused + if self.mlp_kernel_enabled: + hidden_states, _ = self.mlp( + hidden_states, + rmsnorm=self.post_attention_layernorm, + ) + else: + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states)[0] + + hidden_states = residual + hidden_states + + return (hidden_states, present_key_value, cos_cache, sin_cache, None) + + +# --------------------------------------------------------------------------- +# Text model wrapper (ImageToTextModelWrapper) +# --------------------------------------------------------------------------- + + +class InternVL3TextModelWrapper(ImageToTextModelWrapper): + """ + Text model wrapper for InternVL3 that includes vision embedding inputs + in the compiled NEFF trace signature. + + Inherits ImageToTextModelWrapper which generates 24-argument input tuples + with vision_embeddings (arg 22) and vision_mask (arg 23). + """ + + def __init__( + self, + config, + model_cls, + tag="", + compiler_args=None, + priority_model_idx=None, + pipeline_execution=True, + return_ranked_to_cpu=True, + model_init_kwargs={}, + ) -> None: + super().__init__( + config, + model_cls, + tag, + compiler_args, + priority_model_idx, + pipeline_execution, + return_ranked_to_cpu, + model_init_kwargs, + ) + + @staticmethod + def get_dummy_vision_inputs(config, input_ids, n_active_tokens, fill_value): + """ + Create dummy vision tensors for tracing and text-only / token-gen passes. 
+ + For context encoding (seq_len > 1): + - vision_embeddings: [batch, seq_len, hidden_size] zeros + - vision_mask: [batch, n_active_tokens, 1] filled with fill_value (int32 positions) + For token generation (seq_len == 1): + - Both are empty tensors + """ + input_batch_size, input_sequence_len = input_ids.shape[0], input_ids.shape[-1] + if input_sequence_len > 1: + vision_embeddings = torch.zeros( + input_batch_size, + config.neuron_config.seq_len, + config.hidden_size, + dtype=config.neuron_config.torch_dtype, + ) + vision_mask = torch.full( + size=(input_batch_size, n_active_tokens, 1), + fill_value=fill_value, + dtype=torch.int32, + ) + else: + vision_embeddings = torch.zeros((0), dtype=config.neuron_config.torch_dtype) + vision_mask = torch.zeros((0), dtype=torch.bool) + return vision_embeddings, vision_mask + + +# --------------------------------------------------------------------------- +# Text model (traced on Neuron) +# --------------------------------------------------------------------------- + + +class NeuronInternVL3TextModel(NeuronBaseModel): + """ + InternVL3 text model (Qwen2.5-7B backbone) for NxDI. + + Components: + - ParallelEmbedding (vocab=151674, hidden=3584) + - 28x NeuronInternVL3DecoderLayer + - RMSNorm + - ColumnParallelLinear lm_head + + Implements encode_vision_to_input() for merging vision embeddings + into text embeddings during context encoding. + """ + + def setup_attr_for_model(self, config): + self.on_device_sampling = ( + config.neuron_config.on_device_sampling_config is not None + ) + self.tp_degree = config.neuron_config.tp_degree + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.max_batch_size = config.neuron_config.max_batch_size + self.buckets = config.neuron_config.buckets + + def init_model(self, config): + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = ParallelEmbedding( + config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=config.neuron_config.torch_dtype, + shard_across_embedding=True, + pad=True, + sequence_parallel_enabled=config.neuron_config.sequence_parallel_enabled, + ) + + self.layers = nn.ModuleList( + [ + NeuronInternVL3DecoderLayer(config) + for _ in range(config.num_hidden_layers) + ] + ) + + self.norm = get_rmsnorm_cls()(config.hidden_size, eps=config.rms_norm_eps) + + self.lm_head = ColumnParallelLinear( + config.hidden_size, + config.vocab_size, + bias=False, + pad=True, + gather_output=not self.on_device_sampling, + dtype=config.neuron_config.torch_dtype, + ) + + def encode_vision_to_input( + self, + inputs_embeds: torch.Tensor, + vision_embeddings: torch.Tensor, + vision_mask: torch.Tensor, + ) -> torch.Tensor: + """ + Scatter vision embeddings into text embeddings at image token positions. + + Called by NeuronBaseModel.get_model_output() during context encoding only. + Runs ON-DEVICE inside the compiled NEFF. 
+ + Args: + inputs_embeds: [batch, seq_len, hidden_size] -- text token embeddings + vision_embeddings: [batch, seq_len, hidden_size] -- padded vision embeddings + vision_mask: [batch, seq_len, 1] -- int32 position indices + + Returns: + inputs_embeds with vision positions replaced by vision embeddings + """ + return scatter_by_index_put(inputs_embeds, vision_embeddings, vision_mask) + + +# --------------------------------------------------------------------------- +# Weight conversion helper (text-only, used by top-level model) +# --------------------------------------------------------------------------- + + +class NeuronInternVL3TextForCausalLM(NeuronBaseForCausalLM): + """ + Helper class for text weight conversion only. + Not used directly for inference -- the top-level NeuronBaseForImageToText + class handles that via NeuronInternVL3TextModel. + """ + + _model_cls = NeuronInternVL3TextModel + + @staticmethod + def load_hf_model(model_path, **kwargs): + return None + + @staticmethod + def convert_hf_to_neuron_state_dict( + state_dict: dict, config: InferenceConfig + ) -> dict: + """ + Convert InternVL3 text weights from HuggingFace to Neuron format. + + HF layout: + language_model.model.embed_tokens.weight + language_model.model.layers.{i}.self_attn.{q,k,v}_proj.{weight,bias} + language_model.model.layers.{i}.self_attn.o_proj.weight + language_model.model.layers.{i}.mlp.{gate,up,down}_proj.weight + language_model.model.layers.{i}.{input,post_attention}_layernorm.weight + language_model.model.norm.weight + language_model.lm_head.weight + + NxDI layout (fused_qkv=False): + embed_tokens.weight + layers.{i}.self_attn.{q,k,v}_proj.{weight,bias} + layers.{i}.self_attn.o_proj.weight + layers.{i}.mlp.{gate,up,down}_proj.weight + layers.{i}.{input,post_attention}_layernorm.weight + norm.weight + lm_head.weight + + NxDI layout (fused_qkv=True): + layers.{i}.self_attn.qkv_proj.Wqkv.{weight,bias} + (q/k/v fused into single Wqkv tensor) + + When fused_qkv=False, separate q/k/v_proj are kept and the GQA + preshard_hook fuses them during weight sharding. + + When fused_qkv=True (required for NKI QKV kernel), we pre-fuse + q/k/v into Wqkv here because the preshard_hook expects it. + """ + neuron_config = config.neuron_config + neuron_state_dict = {} + + # Add rank tensors for tensor parallelism + if neuron_config.vocab_parallel: + neuron_state_dict["embed_tokens.rank_util.rank"] = torch.arange( + 0, neuron_config.local_ranks_size + ) + + for key, value in state_dict.items(): + # Only process text weights (language_model.*) + if not key.startswith("language_model."): + continue + + # Strip the language_model.model. 
prefix + new_key = key + if key.startswith("language_model.model."): + new_key = key[len("language_model.model.") :] + elif key.startswith("language_model."): + new_key = key[len("language_model.") :] + + neuron_state_dict[new_key] = value.detach().clone() + + # When fused_qkv=True, fuse separate q/k/v weights into Wqkv + if neuron_config.fused_qkv: + for i in range(config.num_hidden_layers): + prefix = f"layers.{i}.self_attn" + # Fuse weights: [q_size, hidden] + [kv_size, hidden] + [kv_size, hidden] + q_w = neuron_state_dict.pop(f"{prefix}.q_proj.weight") + k_w = neuron_state_dict.pop(f"{prefix}.k_proj.weight") + v_w = neuron_state_dict.pop(f"{prefix}.v_proj.weight") + neuron_state_dict[f"{prefix}.qkv_proj.Wqkv.weight"] = torch.cat( + [q_w, k_w, v_w], dim=0 + ) + # Fuse biases (Q/K/V all have bias in InternVL3) + q_b = neuron_state_dict.pop(f"{prefix}.q_proj.bias", None) + k_b = neuron_state_dict.pop(f"{prefix}.k_proj.bias", None) + v_b = neuron_state_dict.pop(f"{prefix}.v_proj.bias", None) + if q_b is not None and k_b is not None and v_b is not None: + neuron_state_dict[f"{prefix}.qkv_proj.Wqkv.bias"] = torch.cat( + [q_b, k_b, v_b], dim=0 + ) + + # Add per-layer rank tensors for attention TP sharding + tp_degree = neuron_config.tp_degree + for i in range(config.num_hidden_layers): + neuron_state_dict[f"layers.{i}.self_attn.rank_util.rank"] = torch.arange( + 0, tp_degree, dtype=torch.int32 + ) + + # Add base model rank tensor + neuron_state_dict["rank_util.rank"] = torch.arange( + 0, tp_degree, dtype=torch.int32 + ) + + return neuron_state_dict + + @staticmethod + def update_state_dict_for_tied_weights(state_dict): + """Handle tied embeddings (InternVL3: tie_word_embeddings=False, but support both).""" + if "lm_head.weight" not in state_dict and "embed_tokens.weight" in state_dict: + state_dict["lm_head.weight"] = state_dict["embed_tokens.weight"].clone() + + @classmethod + def get_config_cls(cls): + # This helper class doesn't need its own config -- used through top-level + from neuronx_distributed_inference.models.config import InferenceConfig + + return InferenceConfig diff --git a/contrib/models/InternVL3-8B-Instruct/src/modeling_internvl3_vision.py b/contrib/models/InternVL3-8B-Instruct/src/modeling_internvl3_vision.py new file mode 100644 index 00000000..0d5ee256 --- /dev/null +++ b/contrib/models/InternVL3-8B-Instruct/src/modeling_internvl3_vision.py @@ -0,0 +1,447 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +InternVL3-8B-Instruct: Vision encoder (InternViT-300M) for NxDI. + +Architecture: + - InternViT-300M-448px-V2.5 + - 24 layers, hidden_size=1024, 16 heads, head_dim=64 + - Patch size 14x14, image 448x448 -> 1024 patches + CLS = 1025 tokens + - LayerNorm, GELU, LayerScale (ls1, ls2) + - Fused QKV: attn.qkv [3072, 1024] + - Position: absolute learned embeddings [1, 1025, 1024] + +Full vision pipeline: + 1. Patch embedding + CLS + position -> 24 transformer layers -> [B, 1025, 1024] + 2. Strip CLS token: [B, 1024, 1024] + 3. Pixel shuffle (0.5x): [B, 256, 4096] + 4. Projector MLP: LayerNorm(4096) -> Linear(4096,3584) -> GELU -> Linear(3584,3584) + 5. 
Pad to text seq_len: [1, seq_len, 3584] + +Weight key mapping (HF -> NxDI vision): + vision_model.embeddings.class_embedding -> encoder.class_embedding + vision_model.embeddings.patch_embedding.{weight,bias} -> encoder.patch_embedding.{weight,bias} + vision_model.embeddings.position_embedding -> encoder.position_embedding + vision_model.encoder.layers.{i}.* -> encoder.layers.{i}.* + mlp1.0.{weight,bias} -> proj_norm.{weight,bias} + mlp1.1.{weight,bias} -> proj_linear1.{weight,bias} + mlp1.3.{weight,bias} -> proj_linear2.{weight,bias} +""" + +import math +from typing import List, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from neuronx_distributed_inference.models.config import InferenceConfig +from neuronx_distributed_inference.models.model_wrapper import ( + EncoderModelInstance, + ModelWrapper, +) + +# InternVL3 vision constants +VISION_HIDDEN_SIZE = 1024 +VISION_NUM_HEADS = 16 +VISION_NUM_LAYERS = 24 +VISION_INTERMEDIATE_SIZE = 4096 +VISION_PATCH_SIZE = 14 +VISION_IMAGE_SIZE = 448 +VISION_NUM_PATCHES = (VISION_IMAGE_SIZE // VISION_PATCH_SIZE) ** 2 # 1024 +VISION_NUM_OUTPUT_TOKENS = 256 # after pixel shuffle (0.5x): 1024 / 4 +TEXT_HIDDEN_SIZE = 3584 +DOWNSAMPLE_RATIO = 0.5 + + +# --------------------------------------------------------------------------- +# Vision encoder components (pure PyTorch, for tracing) +# --------------------------------------------------------------------------- + + +class InternVisionAttention(nn.Module): + """InternViT attention with fused QKV projection.""" + + def __init__(self, hidden_size=VISION_HIDDEN_SIZE, num_heads=VISION_NUM_HEADS): + super().__init__() + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + self.scale = self.head_dim**-0.5 + + self.qkv = nn.Linear(hidden_size, 3 * hidden_size, bias=True) + self.proj = nn.Linear(hidden_size, hidden_size, bias=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch, seq_len, _ = hidden_states.shape + + qkv = self.qkv(hidden_states) + qkv = qkv.reshape(batch, seq_len, 3, self.num_heads, self.head_dim) + qkv = qkv.permute(2, 0, 3, 1, 4) # [3, B, heads, seq, head_dim] + q, k, v = qkv.unbind(0) + + # Scaled dot-product attention (no causal mask for vision) + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + out = attn @ v # [B, heads, seq, head_dim] + + out = out.transpose(1, 2).reshape(batch, seq_len, -1) + return self.proj(out) + + +class InternVisionMLP(nn.Module): + """InternViT MLP: fc1 -> GELU -> fc2.""" + + def __init__( + self, hidden_size=VISION_HIDDEN_SIZE, intermediate_size=VISION_INTERMEDIATE_SIZE + ): + super().__init__() + self.fc1 = nn.Linear(hidden_size, intermediate_size, bias=True) + self.fc2 = nn.Linear(intermediate_size, hidden_size, bias=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.fc2(F.gelu(self.fc1(hidden_states))) + + +class InternVisionLayer(nn.Module): + """InternViT transformer layer with pre-norm and LayerScale.""" + + def __init__( + self, + hidden_size=VISION_HIDDEN_SIZE, + num_heads=VISION_NUM_HEADS, + intermediate_size=VISION_INTERMEDIATE_SIZE, + ): + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size) + self.attn = InternVisionAttention(hidden_size, num_heads) + self.ls1 = nn.Parameter(torch.ones(hidden_size) * 0.1) + + self.norm2 = nn.LayerNorm(hidden_size) + self.mlp = InternVisionMLP(hidden_size, intermediate_size) + self.ls2 = nn.Parameter(torch.ones(hidden_size) * 0.1) + + def forward(self, hidden_states: 
torch.Tensor) -> torch.Tensor: + hidden_states = hidden_states + self.ls1 * self.attn(self.norm1(hidden_states)) + hidden_states = hidden_states + self.ls2 * self.mlp(self.norm2(hidden_states)) + return hidden_states + + +class InternVisionEncoder(nn.Module): + """ + InternViT-300M vision encoder. + + Takes pixel_values [B, 3, 448, 448], returns [B, 1025, 1024] (with CLS). + """ + + def __init__( + self, + hidden_size=VISION_HIDDEN_SIZE, + num_heads=VISION_NUM_HEADS, + num_layers=VISION_NUM_LAYERS, + intermediate_size=VISION_INTERMEDIATE_SIZE, + patch_size=VISION_PATCH_SIZE, + image_size=VISION_IMAGE_SIZE, + ): + super().__init__() + self.hidden_size = hidden_size + self.num_patches = (image_size // patch_size) ** 2 + + self.patch_embedding = nn.Conv2d( + 3, hidden_size, kernel_size=patch_size, stride=patch_size, bias=True + ) + self.class_embedding = nn.Parameter(torch.zeros(1, 1, hidden_size)) + self.position_embedding = nn.Parameter( + torch.zeros(1, self.num_patches + 1, hidden_size) + ) + + self.layers = nn.ModuleList( + [ + InternVisionLayer(hidden_size, num_heads, intermediate_size) + for _ in range(num_layers) + ] + ) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + batch = pixel_values.shape[0] + # Cast pixel_values to match Conv2d weight dtype to avoid + # "Input type (BFloat16) and bias type (float) should be the same" error + # when HF weights are loaded in float32 but input arrives in bf16. + # The compiler's --auto-cast=matmult handles the actual mixed-precision math. + pixel_values = pixel_values.to(dtype=self.patch_embedding.weight.dtype) + patches = self.patch_embedding(pixel_values) + patches = patches.flatten(2).transpose(1, 2) + cls_tokens = self.class_embedding.expand(batch, -1, -1) + hidden_states = torch.cat([cls_tokens, patches], dim=1) + hidden_states = hidden_states + self.position_embedding + for layer in self.layers: + hidden_states = layer(hidden_states) + return hidden_states + + +# --------------------------------------------------------------------------- +# Full vision pipeline (encoder + pixel shuffle + projector + padding) +# --------------------------------------------------------------------------- + + +class NeuronInternVL3VisionModel(nn.Module): + """ + Full InternVL3 vision pipeline for NxDI tracing. + + Input: pixel_values [B, 3, 448, 448] + Output: [1, text_seq_len, text_hidden_size] (padded for scatter_by_index_put) + + Pipeline: + 1. InternViT encoder -> [B, 1025, 1024] + 2. Strip CLS token -> [B, 1024, 1024] + 3. Pixel shuffle (0.5x) -> [B, 256, 4096] + 4. Projector MLP: LayerNorm -> Linear -> GELU -> Linear -> [B, 256, 3584] + 5. 
Pad to text seq_len -> [1, seq_len, 3584] + """ + + def __init__(self, config: InferenceConfig): + super().__init__() + self.config = config + + # Extract text seq_len for output padding + self.text_seq_len = config.text_config.neuron_config.seq_len + self.text_hidden_size = config.text_config.hidden_size + self.text_dtype = config.text_config.neuron_config.torch_dtype + + # Vision encoder + self.encoder = InternVisionEncoder( + hidden_size=VISION_HIDDEN_SIZE, + num_heads=VISION_NUM_HEADS, + num_layers=VISION_NUM_LAYERS, + intermediate_size=VISION_INTERMEDIATE_SIZE, + patch_size=VISION_PATCH_SIZE, + image_size=VISION_IMAGE_SIZE, + ) + + # Projector MLP (pixel_shuffle output -> text hidden size) + proj_input_dim = int(VISION_HIDDEN_SIZE * (1.0 / DOWNSAMPLE_RATIO) ** 2) # 4096 + self.proj_norm = nn.LayerNorm(proj_input_dim) # mlp1.0 + self.proj_linear1 = nn.Linear( + proj_input_dim, TEXT_HIDDEN_SIZE, bias=True + ) # mlp1.1 + # mlp1.2 is GELU (no weights) + self.proj_linear2 = nn.Linear( + TEXT_HIDDEN_SIZE, TEXT_HIDDEN_SIZE, bias=True + ) # mlp1.3 + + def pixel_shuffle(self, x: torch.Tensor) -> torch.Tensor: + """ + Pixel shuffle downsampling (downsample_ratio=0.5). + + Input: [B, H*W, C] where H=W=32, C=1024 + Output: [B, (H/2)*(W/2), C*4] = [B, 256, 4096] + """ + batch, n_patches, channels = x.shape + h = w = int(math.sqrt(n_patches)) + + x = x.reshape(batch, h, w, channels) + + ratio = DOWNSAMPLE_RATIO + x = x.reshape(batch, h, int(w * ratio), int(channels / ratio)) + x = x.permute(0, 2, 1, 3).contiguous() + x = x.reshape( + batch, int(h * ratio), int(w * ratio), int(channels / (ratio * ratio)) + ) + x = x.permute(0, 2, 1, 3).contiguous() + + return x.reshape(batch, -1, x.shape[-1]) + + def pad_to_text_seq_len(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Pad vision embeddings to text model's sequence length. 
+ + Input: [B, 256, 3584] + Output: [1, text_seq_len, 3584] (zero-padded, batch=1) + """ + hidden_states = hidden_states.to(self.text_dtype) + batch, n_tokens, hidden_size = hidden_states.shape + + # Pad sequence dimension to text seq_len + if n_tokens < self.text_seq_len: + pad = torch.zeros( + batch, + self.text_seq_len - n_tokens, + hidden_size, + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + hidden_states = torch.cat([hidden_states, pad], dim=1) + + # Reshape to [1, seq_len, hidden] (batch=1 for scatter) + hidden_states = hidden_states.view(-1, hidden_size).unsqueeze(0) + return hidden_states + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + """ + Args: + pixel_values: [batch, 3, 448, 448] + + Returns: + vision_embeddings: [1, text_seq_len, text_hidden_size] padded + """ + # Encoder: [B, 1025, 1024] + hidden_states = self.encoder(pixel_values) + + # Strip CLS: [B, 1024, 1024] + hidden_states = hidden_states[:, 1:, :] + + # Pixel shuffle: [B, 256, 4096] + hidden_states = self.pixel_shuffle(hidden_states) + + # Projector MLP: [B, 256, 3584] + hidden_states = self.proj_norm(hidden_states) + hidden_states = F.gelu(self.proj_linear1(hidden_states)) + hidden_states = self.proj_linear2(hidden_states) + + # Pad to text seq_len: [1, seq_len, 3584] + hidden_states = self.pad_to_text_seq_len(hidden_states) + + return hidden_states + + +# --------------------------------------------------------------------------- +# Vision model wrapper (for NxDI tracing) +# --------------------------------------------------------------------------- + + +class InternVL3VisionModelWrapper(ModelWrapper): + """ + Wrapper for tracing the InternVL3 vision encoder on Neuron. + + Uses EncoderModelInstance (no KV cache). + Vision buckets represent number of images (always 1 for InternVL3). + """ + + def __init__( + self, + config: InferenceConfig, + model_cls, + tag="", + compiler_args=None, + priority_model_idx=None, + pipeline_execution=True, + return_ranked_to_cpu=False, + model_init_kwargs={}, + ) -> None: + super().__init__( + config, + model_cls, + tag, + compiler_args, + priority_model_idx, + pipeline_execution, + return_ranked_to_cpu, + model_init_kwargs, + ) + + def input_generator(self) -> List[Tuple[torch.Tensor]]: + """ + Generate example inputs for vision encoder tracing. + + InternVL3 processes one 448x448 image at a time (no dynamic patching + at this stage). Single bucket with batch=1. + """ + inputs = [] + # Vision buckets = [1] (single image) + for bucket in self.config.vision_config.neuron_config.buckets: + pixel_values = torch.ones( + [bucket, 3, VISION_IMAGE_SIZE, VISION_IMAGE_SIZE], + dtype=self.config.vision_config.neuron_config.torch_dtype, + ) + inputs.append((pixel_values,)) + return inputs + + def get_model_instance(self): + return EncoderModelInstance(model_cls=self.model_cls, config=self.config) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + """ + Run vision encoder on Neuron. + + Args: + pixel_values: [batch, 3, 448, 448] + + Returns: + vision_embeddings: [1, text_seq_len, text_hidden_size] + """ + if self.model is None: + raise RuntimeError( + "Forward called before load. Run load() or load_state_dict() first." 
+ ) + output = self._forward(pixel_values) + return output + + +# --------------------------------------------------------------------------- +# Vision weight conversion +# --------------------------------------------------------------------------- + + +def convert_vision_hf_to_neuron_state_dict(state_dict: dict) -> dict: + """ + Convert InternVL3 vision + projector weights from HF to NxDI format. + + HF keys: + vision_model.embeddings.class_embedding + vision_model.embeddings.patch_embedding.{weight,bias} + vision_model.embeddings.position_embedding + vision_model.encoder.layers.{i}.attn.qkv.{weight,bias} + vision_model.encoder.layers.{i}.attn.proj.{weight,bias} + vision_model.encoder.layers.{i}.ls1 + vision_model.encoder.layers.{i}.ls2 + vision_model.encoder.layers.{i}.norm1.{weight,bias} + vision_model.encoder.layers.{i}.norm2.{weight,bias} + vision_model.encoder.layers.{i}.mlp.fc1.{weight,bias} + vision_model.encoder.layers.{i}.mlp.fc2.{weight,bias} + mlp1.0.{weight,bias} -> proj_norm + mlp1.1.{weight,bias} -> proj_linear1 + mlp1.3.{weight,bias} -> proj_linear2 + + NxDI keys: + encoder.class_embedding + encoder.patch_embedding.{weight,bias} + encoder.position_embedding + encoder.layers.{i}.* + proj_norm.{weight,bias} + proj_linear1.{weight,bias} + proj_linear2.{weight,bias} + """ + neuron_state_dict = {} + + # Projector key mapping + PROJECTOR_MAP = { + "mlp1.0.weight": "proj_norm.weight", + "mlp1.0.bias": "proj_norm.bias", + "mlp1.1.weight": "proj_linear1.weight", + "mlp1.1.bias": "proj_linear1.bias", + "mlp1.3.weight": "proj_linear2.weight", + "mlp1.3.bias": "proj_linear2.bias", + } + + for key, tensor in state_dict.items(): + # Projector weights + if key in PROJECTOR_MAP: + neuron_state_dict[PROJECTOR_MAP[key]] = tensor.detach().clone() + continue + + # Vision encoder embeddings + if key.startswith("vision_model.embeddings."): + suffix = key[len("vision_model.embeddings.") :] + neuron_state_dict[f"encoder.{suffix}"] = tensor.detach().clone() + continue + + # Vision encoder layers + if key.startswith("vision_model.encoder.layers."): + suffix = key[len("vision_model.encoder.") :] + neuron_state_dict[f"encoder.{suffix}"] = tensor.detach().clone() + continue + + # Skip other vision_model keys and all non-vision keys + # (text weights handled by text conversion) + + return neuron_state_dict diff --git a/contrib/models/InternVL3-8B-Instruct/test/__init__.py b/contrib/models/InternVL3-8B-Instruct/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/InternVL3-8B-Instruct/test/integration/__init__.py b/contrib/models/InternVL3-8B-Instruct/test/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/InternVL3-8B-Instruct/test/integration/test_model.py b/contrib/models/InternVL3-8B-Instruct/test/integration/test_model.py new file mode 100644 index 00000000..eafb79ad --- /dev/null +++ b/contrib/models/InternVL3-8B-Instruct/test/integration/test_model.py @@ -0,0 +1,242 @@ +#!/usr/bin/env python3 +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Integration tests for InternVL3-8B-Instruct NxDI contrib model. + +Validates Neuron model accuracy against CPU reference using logit_validation. 
+ +Usage: + pytest test_model.py -v --tb=short + +Prerequisites: + - Model downloaded to MODEL_PATH + - Compiled model at COMPILED_MODEL_PATH (run compile_internvl3_vlm.py first) + - Neuron runtime available (trn2.3xlarge, LNC=2, TP=4) +""" + +import json +import math +import os +import sys +from pathlib import Path + +import pytest +import torch +from torch_neuronx.testing.validation import logit_validation +from transformers import AutoTokenizer, GenerationConfig + +from neuronx_distributed_inference.models.config import NeuronConfig +from neuronx_distributed_inference.utils.accuracy import generate_expected_logits +from neuronx_distributed_inference.utils.hf_adapter import HuggingFaceGenerationAdapter + +# Import from src directory +sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src")) +from modeling_internvl3 import ( + NeuronInternVL3ForCausalLM, + InternVL3InferenceConfig, +) + + +# Test configuration — override via environment variables if needed +MODEL_PATH = os.environ.get( + "INTERNVL3_MODEL_PATH", "/mnt/models/InternVL3-8B-Instruct/" +) +COMPILED_MODEL_PATH = os.environ.get( + "INTERNVL3_COMPILED_PATH", "/mnt/models/neuron_models/InternVL3-8B-Instruct/" +) +NUM_TOKENS_TO_CHECK = 16 +TOKEN_DIVERGENCE_ATOL = 0.02 + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def load_compiled_model(model_path: str, compiled_path: str): + """Load pre-compiled InternVL3 model for inference.""" + config_path = Path(compiled_path) / "neuron_config.json" + if not config_path.exists(): + raise FileNotFoundError(f"neuron_config.json not found at {config_path}") + + with open(config_path) as f: + config_data = json.load(f) + nc_dict = config_data.get("neuron_config", config_data) + + dtype_str = nc_dict.get("torch_dtype", "torch.bfloat16") + if isinstance(dtype_str, str): + dtype = ( + getattr(torch, dtype_str.split(".")[-1]) + if "torch" in dtype_str + else torch.bfloat16 + ) + else: + dtype = dtype_str + + text_neuron_config = NeuronConfig( + tp_degree=nc_dict.get("tp_degree", 4), + max_batch_size=nc_dict.get("batch_size", nc_dict.get("max_batch_size", 1)), + seq_len=nc_dict.get("seq_len", 2048), + torch_dtype=dtype, + on_device_sampling_config=None, + save_sharded_checkpoint=True, + ) + + vision_neuron_config = NeuronConfig( + tp_degree=nc_dict.get("tp_degree", 4), + max_batch_size=1, + seq_len=256, + torch_dtype=dtype, + on_device_sampling_config=None, + buckets=[1], + fused_qkv=True, + save_sharded_checkpoint=True, + ) + + config = InternVL3InferenceConfig.from_pretrained( + model_path, + text_neuron_config=text_neuron_config, + vision_neuron_config=vision_neuron_config, + ) + + model = NeuronInternVL3ForCausalLM(model_path, config=config) + model.load(compiled_path) + return model + + +def get_tp_aligned_vocab_size(vocab_size: int, tp_degree: int) -> int: + """Get the largest vocab index that's valid (not TP padding). + + When vocab_size is not divisible by tp_degree, lm_head pads up to the + next multiple of tp_degree. The padded positions contain -FLT_MAX in + the Neuron output. Truncate to this boundary to avoid false failures. 
+ """ + return (vocab_size // tp_degree) * tp_degree + + +# --------------------------------------------------------------------------- +# Integration tests (requires pre-compiled Neuron model) +# --------------------------------------------------------------------------- + + +class TestInternVL3Integration: + """Integration tests for InternVL3-8B-Instruct on Neuron.""" + + @pytest.fixture(scope="class") + def neuron_model(self): + """Load pre-compiled Neuron model (shared across tests in this class).""" + compiled_path = Path(COMPILED_MODEL_PATH) + if not compiled_path.exists(): + pytest.skip(f"Compiled model not found at {compiled_path}") + return load_compiled_model(MODEL_PATH, COMPILED_MODEL_PATH) + + @pytest.fixture(scope="class") + def tokenizer(self): + """Load tokenizer.""" + return AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) + + def test_config(self, neuron_model): + """Validate InternVL3 config matches expected Qwen2.5-7B architecture.""" + config = neuron_model.config + assert config.hidden_size == 3584 + assert config.num_attention_heads == 28 + assert config.num_key_value_heads == 4 + assert config.num_hidden_layers == 28 + assert config.intermediate_size == 18944 + assert config.vocab_size == 151674 + assert config.rope_theta == 1000000.0 + + def test_text_logit_validation(self, neuron_model, tokenizer): + """ + Validate text-only logit accuracy: Neuron vs CPU reference. + + Uses generate_expected_logits() for CPU golden logits and + logit_validation() for multi-tier logit comparison. + + InternVL3 vocab_size (151674) is not divisible by TP degree (4), + so lm_head pads to the next multiple (151676). The 2 padding + positions get -FLT_MAX in Neuron output. We truncate both CPU + and Neuron logits to the TP-aligned boundary (151672) to avoid + false failures from these padding artifacts. + """ + prompt = "The capital of France is" + input_ids = tokenizer(prompt, return_tensors="pt").input_ids + attention_mask = torch.ones_like(input_ids, dtype=torch.int32) + + generation_config = GenerationConfig( + do_sample=False, + max_new_tokens=NUM_TOKENS_TO_CHECK, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id or 0, + ) + + # 1. Generate CPU reference logits + expected_logits = generate_expected_logits( + neuron_model=neuron_model, + input_ids=input_ids, + inputs_attention_mask=attention_mask, + generation_config=generation_config, + num_tokens=NUM_TOKENS_TO_CHECK, + ) + + # 2. Truncate expected logits to TP-aligned vocab boundary + tp_degree = neuron_model.config.neuron_config.tp_degree + vocab_size = neuron_model.config.vocab_size + aligned_vocab = get_tp_aligned_vocab_size(vocab_size, tp_degree) + expected_logits = expected_logits[..., :aligned_vocab] + + # 3. 
Build generate_fn that returns Neuron logits (also truncated) + adapter = HuggingFaceGenerationAdapter(neuron_model) + + expected_attention_mask = torch.ones( + (attention_mask.shape[0], expected_logits.shape[0]), + dtype=torch.int32, + ) + extrapolated_attention_mask = torch.cat( + (attention_mask, expected_attention_mask), dim=1 + ) + + def generate_fn(input_ids_tensor): + neuron_model.reset() + input_length = input_ids_tensor.shape[1] + attn_mask = extrapolated_attention_mask[:, :input_length] + new_tokens = NUM_TOKENS_TO_CHECK + input_ids.shape[1] - input_length + + outputs = adapter.generate( + input_ids=input_ids_tensor, + attention_mask=attn_mask, + max_new_tokens=new_tokens, + min_new_tokens=new_tokens, + do_sample=False, + return_dict_in_generate=True, + output_scores=True, + generation_config=generation_config, + ) + actual_logits = torch.stack(outputs.scores) + # Truncate to TP-aligned vocab boundary + actual_logits = actual_logits[..., :aligned_vocab] + return actual_logits + + # 4. Run logit_validation with BF16-appropriate tolerances. + # Default tolerances (top-5: 0.01, top-50: 0.02) are calibrated for + # FP16/FP32 models. BF16 has lower mantissa precision (7 bits vs 10), + # so normalized logit errors of 0.03-0.06 are expected. + bf16_tol_map = { + 5: (1e-5, 0.05), + 50: (1e-5, 0.06), + 1000: (1e-5, 0.06), + None: (1e-5, 0.08), + } + + passed, results, status_msg = logit_validation( + input_ids=input_ids, + generate_fn=generate_fn, + expected_logits=expected_logits, + tol_map=bf16_tol_map, + divergence_difference_tol=TOKEN_DIVERGENCE_ATOL, + ) + + print(f"\n{status_msg}") + assert passed, f"Logit validation failed:\n{status_msg}" diff --git a/contrib/models/InternVL3-8B-Instruct/test/unit/__init__.py b/contrib/models/InternVL3-8B-Instruct/test/unit/__init__.py new file mode 100644 index 00000000..e69de29b