From cf9acbf03c38222f01c7b069f14eeb067c49fc22 Mon Sep 17 00:00:00 2001 From: Jim Burtoft Date: Sat, 25 Apr 2026 20:41:32 -0400 Subject: [PATCH 1/3] Add YOLO26 object detection contrib model Ultralytics YOLO26 (n/s/m/l/x) on Trainium2 via torch_neuronx.trace(). All 5 detection variants plus pose and OBB task heads compile and run with high accuracy (CosSim 0.987-0.997). Peak throughput on trn2.3xlarge (LNC=1, DP=8): - YOLO26s: 1,523 img/s (1.43x vs A10G compiled) - YOLO26m: 1,267 img/s (2.67x vs A10G compiled) - YOLO26l: 1,093 img/s (2.95x vs A10G compiled) - YOLO26x: 876 img/s (4.49x vs A10G compiled) Includes modeling module, 13 integration tests (all passing), Jupyter notebook, and README with benchmarks. --- contrib/models/YOLO26/README.md | 132 ++++ contrib/models/YOLO26/src/__init__.py | 25 + contrib/models/YOLO26/src/modeling_yolo26.py | 415 ++++++++++++ contrib/models/YOLO26/test/__init__.py | 0 .../YOLO26/test/integration/__init__.py | 0 .../YOLO26/test/integration/test_model.py | 465 +++++++++++++ contrib/models/YOLO26/test/unit/__init__.py | 0 .../YOLO26/yolo26_neuron_notebook.ipynb | 637 ++++++++++++++++++ 8 files changed, 1674 insertions(+) create mode 100644 contrib/models/YOLO26/README.md create mode 100644 contrib/models/YOLO26/src/__init__.py create mode 100644 contrib/models/YOLO26/src/modeling_yolo26.py create mode 100644 contrib/models/YOLO26/test/__init__.py create mode 100644 contrib/models/YOLO26/test/integration/__init__.py create mode 100644 contrib/models/YOLO26/test/integration/test_model.py create mode 100644 contrib/models/YOLO26/test/unit/__init__.py create mode 100644 contrib/models/YOLO26/yolo26_neuron_notebook.ipynb diff --git a/contrib/models/YOLO26/README.md b/contrib/models/YOLO26/README.md new file mode 100644 index 00000000..44753f30 --- /dev/null +++ b/contrib/models/YOLO26/README.md @@ -0,0 +1,132 @@ +# Contrib Model: YOLO26 + +Ultralytics YOLO26 object detection models on AWS Trainium2 using `torch_neuronx.trace()`. 
+ +## Model Information + +- **Source:** [Ultralytics YOLO26](https://github.com/ultralytics/ultralytics) +- **Model Type:** Object detection (also supports segmentation, pose estimation, oriented bounding boxes) +- **Variants:** 5 detection sizes — n (2.4M), s (10.0M), m (21.9M), l (26.3M), x (58.9M) +- **Architecture:** CNN backbone (Conv2d + BatchNorm + SiLU), FPN/PAN neck, Detect head with C2PSA attention +- **Input:** `[B, 3, 640, 640]` (fixed resolution) +- **Output:** `[B, 84, 8400]` (4 bbox + 80 COCO class scores per anchor) +- **License:** AGPL-3.0 or Ultralytics Enterprise License + +## Architecture Details + +YOLO26 is a 24-layer convolutional neural network optimized for real-time object detection. Unlike transformer-based vision models, it is dominated by Conv2d operations with a small C2PSA self-attention block on the P5 feature map. + +This contrib uses `torch_neuronx.trace()` rather than NxDI model classes because: (1) all variants fit trivially on a single NeuronCore (<180 MB NEFF), (2) there is no KV cache or token generation, and (3) the Conv2d-dominant architecture does not benefit from NxDI's attention infrastructure. Data Parallelism across NeuronCores provides throughput scaling. + +Key Neuron porting challenges: +- **`topk`/`sort` unsupported:** End-to-end postprocessing requires `torch.topk` which fails with `NCC_EVRF029`. Solution: trace with `end2end=False` for raw output, run postprocessing on CPU. +- **FP32 SB overflow for m/l/x:** Larger variants exceed Neuron's SB allocation in FP32. Solution: BF16 compilation (halves tensor sizes). +- **`--auto-cast=matmult` produces NaN:** Conv2d-dominant models get NaN with matmult autocast. Solution: no autocast flags. 
+ +## Validation Results + +**Validated:** 2026-04-25 +**Instance:** trn2.3xlarge (1 Trainium2 chip) +**SDK:** Neuron SDK 2.28, PyTorch 2.9 + +### Peak Throughput (LNC=1, DP=8) + +| Variant | Params | Dtype | NEFF (MB) | BS/core | img/s | A10G Compiled | Speedup | +|---------|--------|-------|-----------|---------|-------|---------------|---------| +| YOLO26n | 2.4M | FP32 | 19.5 | 1 | 272 | 2,166 | 0.13x | +| YOLO26s | 10.0M | FP32 | 69.6 | 32 | 1,523 | 1,065 | **1.43x** | +| YOLO26m | 21.9M | BF16 | 66.4 | 32 | 1,267 | 474 | **2.67x** | +| YOLO26l | 26.3M | BF16 | 80.8 | 32 | 1,093 | 371 | **2.95x** | +| YOLO26x | 58.9M | BF16 | 177.7 | 16 | 876 | 195 | **4.49x** | + +### Accuracy Validation + +| Variant | Dtype | Cosine Similarity | Max Error | Has NaN | +|---------|-------|-------------------|-----------|---------| +| YOLO26n | FP32 | 0.9943 | 373.3 | No | +| YOLO26s | FP32 | 0.9932 | 439.8 | No | +| YOLO26m | BF16 | 0.9879 | 488.0 | No | +| YOLO26l | BF16 | 0.9967 | 242.0 | No | +| YOLO26x | BF16 | 0.9950 | 378.0 | No | + +### Additional Task Heads + +| Task | Head | CosSim | img/s (single core) | Status | +|------|------|--------|---------------------|--------| +| Pose | Pose26 | 0.9996 | 81.6 | Production ready | +| OBB | OBB26 | 0.9999 | 85.3 | Production ready | +| Segmentation | Segment26 | 0.995/0.858 | 63.9 | Proto mask needs validation | +| Classification | Classify | 0.257 | 671.0 | Precision issue (softmax sensitivity) | + +**Status:** VALIDATED + +## Usage + +### Quick Start + +```python +import torch +from src import YOLO26NeuronModel + +# Single core +model = YOLO26NeuronModel("s", batch_size=1) +output = model(torch.randn(1, 3, 640, 640)) + +# Data parallel (4 cores on LNC=2) +model = YOLO26NeuronModel("s", batch_size=8, num_cores=4) +output = model(torch.randn(32, 3, 640, 640)) + +# Benchmark +results = model.benchmark(warmup=10, iterations=50) +print(f"Throughput: {results['throughput_img_s']} img/s") +``` + +### Low-Level API + +```python +from src import 
prepare_yolo26, compile_yolo26, validate_accuracy + +# Prepare and compile +model = compile_yolo26("yolo26s.pt", batch_size=1, save_path="compiled/yolo26s.pt") + +# Validate accuracy +metrics = validate_accuracy("yolo26s.pt", model) +print(f"CosSim: {metrics['cosine_similarity']}") +``` + +### Known Issues + +1. **`topk` not supported on Neuron.** Models must be traced with `end2end=False`. Postprocessing (topk, NMS) runs on CPU with ~0.1ms overhead. +2. **FP32 fails for m/l/x variants.** Use BF16 (`torch.bfloat16`) for these variants. FP32 for n/s only. +3. **`--auto-cast=matmult` produces NaN.** Do not use autocast flags with YOLO26. +4. **LNC=1 requires `--lnc 1` compiler flag.** NEFFs compiled without this flag cannot run on LNC=1 runtime. +5. **Classification variant has precision issue.** Narrow logit range + softmax amplification causes CosSim 0.257. Detection, pose, and OBB are unaffected. + +## Compatibility Matrix + +| Instance | SDK 2.28 | SDK 2.29 | +|----------|----------|----------| +| trn2.3xlarge | Validated | Expected compatible | +| trn2.48xlarge | Expected compatible | Expected compatible | +| inf2.xlarge | Not tested (n/s only) | Not tested | + +## Testing Instructions + +```bash +# Activate Neuron environment +source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate +pip install ultralytics + +# Run integration tests +cd contrib/models/YOLO26 +pytest test/integration/test_model.py -v + +# Or standalone +python test/integration/test_model.py +``` + +## Maintainer + +Jim Burtoft +Community contribution + +**Last Updated:** 2026-04-25 diff --git a/contrib/models/YOLO26/src/__init__.py b/contrib/models/YOLO26/src/__init__.py new file mode 100644 index 00000000..b0e68c8a --- /dev/null +++ b/contrib/models/YOLO26/src/__init__.py @@ -0,0 +1,25 @@ +"""YOLO26 Object Detection on AWS Neuron (Trainium2 / Inferentia2).""" + +from .modeling_yolo26 import ( + YOLO26NeuronModel, + prepare_yolo26, + compile_yolo26, + validate_accuracy, + 
get_variant_dtype, + get_neuron_core_count, + VARIANT_DTYPES, + INPUT_SHAPE, + COSINE_SIM_THRESHOLDS, +) + +__all__ = [ + "YOLO26NeuronModel", + "prepare_yolo26", + "compile_yolo26", + "validate_accuracy", + "get_variant_dtype", + "get_neuron_core_count", + "VARIANT_DTYPES", + "INPUT_SHAPE", + "COSINE_SIM_THRESHOLDS", +] diff --git a/contrib/models/YOLO26/src/modeling_yolo26.py b/contrib/models/YOLO26/src/modeling_yolo26.py new file mode 100644 index 00000000..c52fd1ff --- /dev/null +++ b/contrib/models/YOLO26/src/modeling_yolo26.py @@ -0,0 +1,415 @@ +"""YOLO26 Object Detection on AWS Neuron (Trainium2 / Inferentia2). + +Compiles Ultralytics YOLO26 detection models for inference on NeuronCores using +``torch_neuronx.trace()``. Supports all 5 detection variants (n/s/m/l/x) plus +segmentation, pose, and OBB task heads. + +Architecture +------------ +YOLO26 is a 24-layer CNN (Conv2d + BatchNorm + SiLU backbone, FPN/PAN neck, +Detect head) with a small C2PSA self-attention block. The DFL layer is +``nn.Identity()`` (reg_max=1). Models range from 2.4M to 58.9M parameters. 
+ +Neuron Strategy +--------------- +- ``torch_neuronx.trace()`` with fixed [B, 3, 640, 640] input shape +- ``end2end=False`` before ``fuse()`` — ``topk``/``sort`` not supported on trn2 +- FP32 for n/s variants; BF16 required for m/l/x (FP32 exceeds SB allocation) +- No ``--auto-cast`` flags — ``matmult`` produces NaN for Conv2d-dominant models +- Data Parallelism across NeuronCores for throughput scaling +- ``--lnc 1`` compiler flag required when running on LNC=1 mode + +Key Results (trn2.3xlarge, LNC=1, DP=8) +---------------------------------------- +- YOLO26s: 1,523 img/s (1.43x vs A10G compiled) +- YOLO26m: 1,267 img/s (2.67x vs A10G compiled) +- YOLO26l: 1,093 img/s (2.95x vs A10G compiled) +- YOLO26x: 876 img/s (4.49x vs A10G compiled) +""" + +import os +import time + +import torch +import torch.nn as nn + +try: + import torch_neuronx +except ImportError: + torch_neuronx = None + +from ultralytics import YOLO +from ultralytics.nn.modules.block import C2f + + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +VARIANT_DTYPES = { + "n": torch.float32, + "s": torch.float32, + "m": torch.bfloat16, + "l": torch.bfloat16, + "x": torch.bfloat16, +} + +INPUT_SHAPE = (3, 640, 640) # C, H, W + +DEFAULT_COMPILER_ARGS = [] # No autocast for YOLO26 + +COSINE_SIM_THRESHOLDS = { + "n": 0.99, + "s": 0.99, + "m": 0.98, + "l": 0.99, + "x": 0.99, +} + + +# --------------------------------------------------------------------------- +# Model Preparation +# --------------------------------------------------------------------------- + + +def prepare_yolo26(weight_path: str, dtype: torch.dtype = torch.float32) -> nn.Module: + """Load and prepare a YOLO26 model for Neuron tracing. + + The preparation recipe: + 1. Set ``end2end=False`` before ``fuse()`` (topk unsupported on Neuron) + 2. ``fuse()`` merges BatchNorm into Conv2d, removes training branches + 3. 
Set ``export=True``, ``dynamic=False`` for clean single-tensor output + 4. Replace ``C2f.forward`` with ``C2f.forward_split`` (split vs chunk) + 5. Convert to target dtype if not FP32 + + Parameters + ---------- + weight_path : str + Path to the ``.pt`` weight file (e.g., ``"yolo26s.pt"``). + dtype : torch.dtype + Target dtype. Use ``torch.float32`` for n/s, ``torch.bfloat16`` for m/l/x. + + Returns + ------- + nn.Module + Prepared model in eval mode, ready for ``torch_neuronx.trace()``. + """ + model = YOLO(weight_path) + pytorch_model = model.model.eval() + + # Disable end2end BEFORE fuse (topk not supported on Neuron) + detect = pytorch_model.model[-1] + detect.end2end = False + + pytorch_model = pytorch_model.fuse(verbose=False) + + for m in pytorch_model.modules(): + if hasattr(m, "export"): + m.export = True + if hasattr(m, "dynamic"): + m.dynamic = False + if hasattr(m, "format"): + m.format = "torchscript" + if hasattr(m, "shape"): + m.shape = None + if isinstance(m, C2f): + m.forward = m.forward_split + + if dtype != torch.float32: + pytorch_model = pytorch_model.to(dtype) + + return pytorch_model + + +def get_variant_dtype(variant: str) -> torch.dtype: + """Return the recommended dtype for a YOLO26 variant. + + Parameters + ---------- + variant : str + One of ``"n"``, ``"s"``, ``"m"``, ``"l"``, ``"x"``. + + Returns + ------- + torch.dtype + ``torch.float32`` for n/s, ``torch.bfloat16`` for m/l/x. + """ + return VARIANT_DTYPES[variant] + + +# --------------------------------------------------------------------------- +# Compilation +# --------------------------------------------------------------------------- + + +def compile_yolo26( + weight_path: str, + batch_size: int = 1, + dtype: torch.dtype | None = None, + save_path: str | None = None, + lnc: int | None = None, + compiler_args: list[str] | None = None, +) -> torch.jit.ScriptModule: + """Compile a YOLO26 model for Neuron inference. 
+ + Parameters + ---------- + weight_path : str + Path to the ``.pt`` weight file. + batch_size : int + Batch size per NeuronCore. + dtype : torch.dtype, optional + Override dtype. If ``None``, uses the recommended dtype for the variant. + save_path : str, optional + If provided, saves the compiled model to this path. + lnc : int, optional + Logical NeuronCore config (1 or 2). Adds ``--lnc`` compiler flag if set. + compiler_args : list[str], optional + Additional compiler arguments. Defaults to no autocast flags. + + Returns + ------- + torch.jit.ScriptModule + The traced Neuron model. + """ + if torch_neuronx is None: + raise RuntimeError("torch_neuronx is not installed. Run on a Neuron instance.") + + # Infer variant from weight path + variant = _infer_variant(weight_path) + if dtype is None: + dtype = get_variant_dtype(variant) if variant else torch.float32 + + model = prepare_yolo26(weight_path, dtype=dtype) + + dummy = torch.randn(batch_size, *INPUT_SHAPE, dtype=dtype) + with torch.no_grad(): + _ = model(dummy) # dry run to populate anchors + + args = list(compiler_args) if compiler_args else list(DEFAULT_COMPILER_ARGS) + if lnc is not None: + args.extend(["--lnc", str(lnc)]) + + traced = torch_neuronx.trace(model, dummy, compiler_args=args) + + if save_path: + os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True) + torch.jit.save(traced, save_path) + + return traced + + +# --------------------------------------------------------------------------- +# Pipeline +# --------------------------------------------------------------------------- + + +class YOLO26NeuronModel: + """High-level wrapper for YOLO26 inference on Neuron. + + Handles compilation, caching, data-parallel scaling, and accuracy + validation in a single class. + + Parameters + ---------- + variant : str + One of ``"n"``, ``"s"``, ``"m"``, ``"l"``, ``"x"``. + batch_size : int + Batch size per NeuronCore. + cache_dir : str + Directory for cached compiled models. 
+ lnc : int, optional + Logical NeuronCore config (1 or 2). Detected from environment if not set. + num_cores : int, optional + Number of NeuronCores for data parallelism. If > 1, wraps in DataParallel. + + Examples + -------- + >>> model = YOLO26NeuronModel("s", batch_size=8, num_cores=4) + >>> output = model(torch.randn(32, 3, 640, 640)) + >>> output.shape + torch.Size([32, 84, 8400]) + """ + + def __init__( + self, + variant: str, + batch_size: int = 1, + cache_dir: str = "compiled", + lnc: int | None = None, + num_cores: int = 1, + ): + self.variant = variant + self.batch_size = batch_size + self.dtype = get_variant_dtype(variant) + self.dtype_name = "bf16" if self.dtype == torch.bfloat16 else "fp32" + self.num_cores = num_cores + + if lnc is None: + lnc = int(os.environ.get("NEURON_LOGICAL_NC_CONFIG", "2")) + self.lnc = lnc + + weight_path = f"yolo26{variant}.pt" + neff_name = f"yolo26{variant}_{self.dtype_name}_bs{batch_size}_lnc{lnc}.pt" + save_path = os.path.join(cache_dir, neff_name) + + if os.path.exists(save_path): + self._model = torch.jit.load(save_path) + else: + self._model = compile_yolo26( + weight_path, + batch_size=batch_size, + dtype=self.dtype, + save_path=save_path, + lnc=lnc, + ) + + if num_cores > 1 and torch_neuronx is not None: + self._model = torch_neuronx.DataParallel( + self._model, + device_ids=list(range(num_cores)), + dim=0, + ) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Run inference. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape ``[B, 3, 640, 640]``. + + Returns + ------- + torch.Tensor + Raw detection output ``[B, 84, 8400]``. + """ + with torch.no_grad(): + return self._model(x.to(self.dtype)) + + def benchmark(self, warmup: int = 10, iterations: int = 50) -> dict: + """Measure throughput and latency. + + Returns + ------- + dict + Keys: ``p50_ms``, ``p95_ms``, ``p99_ms``, ``throughput_img_s``. 
+ """ + import numpy as np + + total_bs = self.batch_size * self.num_cores + dummy = torch.randn(total_bs, *INPUT_SHAPE, dtype=self.dtype) + + for _ in range(warmup): + self(dummy) + + latencies = [] + for _ in range(iterations): + t0 = time.time() + self(dummy) + latencies.append((time.time() - t0) * 1000) + + lat = np.array(sorted(latencies)) + p50 = float(np.percentile(lat, 50)) + return { + "p50_ms": round(p50, 2), + "p95_ms": round(float(np.percentile(lat, 95)), 2), + "p99_ms": round(float(np.percentile(lat, 99)), 2), + "throughput_img_s": round(total_bs / (p50 / 1000), 1), + } + + +# --------------------------------------------------------------------------- +# Validation +# --------------------------------------------------------------------------- + + +def validate_accuracy( + weight_path: str, + neuron_model: torch.jit.ScriptModule | YOLO26NeuronModel, + dtype: torch.dtype | None = None, + seed: int = 42, +) -> dict: + """Compare Neuron output against CPU reference. + + Parameters + ---------- + weight_path : str + Path to the ``.pt`` weight file. + neuron_model : ScriptModule or YOLO26NeuronModel + The compiled Neuron model. + dtype : torch.dtype, optional + Input dtype. Inferred from variant if ``None``. + seed : int + Random seed for reproducibility. + + Returns + ------- + dict + Keys: ``cosine_similarity``, ``max_error``, ``mean_error``, ``has_nan``. 
+ """ + variant = _infer_variant(weight_path) + if dtype is None: + dtype = get_variant_dtype(variant) if variant else torch.float32 + + torch.manual_seed(seed) + dummy = torch.randn(1, *INPUT_SHAPE, dtype=dtype) + + cpu_model = prepare_yolo26(weight_path, dtype=dtype) + with torch.no_grad(): + cpu_out = cpu_model(dummy) + + if isinstance(neuron_model, YOLO26NeuronModel): + nrn_out = neuron_model(dummy) + else: + with torch.no_grad(): + nrn_out = neuron_model(dummy) + + cpu_flat = cpu_out.flatten().float() + nrn_flat = nrn_out.flatten().float() + + cossim = torch.nn.functional.cosine_similarity( + cpu_flat.unsqueeze(0), nrn_flat.unsqueeze(0) + ).item() + + diff = (cpu_flat - nrn_flat).abs() + + return { + "cosine_similarity": round(cossim, 6), + "max_error": round(diff.max().item(), 6), + "mean_error": round(diff.mean().item(), 6), + "has_nan": bool(torch.isnan(nrn_out).any().item()), + } + + +# --------------------------------------------------------------------------- +# Utilities +# --------------------------------------------------------------------------- + + +def get_neuron_core_count() -> int: + """Detect the number of available NeuronCores.""" + import subprocess + import json + + try: + result = subprocess.run( + ["neuron-ls", "--json-output"], + capture_output=True, + text=True, + timeout=10, + ) + data = json.loads(result.stdout) + total = sum(dev.get("nc_count", 0) for dev in data) + return total + except Exception: + return 0 + + +def _infer_variant(weight_path: str) -> str | None: + """Infer the YOLO26 variant letter from a weight file path.""" + base = os.path.basename(weight_path).lower() + for v in ("n", "s", "m", "l", "x"): + if f"yolo26{v}" in base: + return v + return None diff --git a/contrib/models/YOLO26/test/__init__.py b/contrib/models/YOLO26/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/YOLO26/test/integration/__init__.py b/contrib/models/YOLO26/test/integration/__init__.py new file mode 100644 
index 00000000..e69de29b diff --git a/contrib/models/YOLO26/test/integration/test_model.py b/contrib/models/YOLO26/test/integration/test_model.py new file mode 100644 index 00000000..567900a8 --- /dev/null +++ b/contrib/models/YOLO26/test/integration/test_model.py @@ -0,0 +1,465 @@ +"""YOLO26 Neuron Integration Tests. + +Self-contained integration tests for YOLO26 detection models on Neuron. +All helpers are local to this file (tests do not import from src/). + +Usage +----- +With pytest:: + + pytest test/integration/test_model.py -v + +Standalone:: + + python test/integration/test_model.py + +Prerequisites +------------- +- Neuron instance (trn2.3xlarge or inf2) +- ``pip install ultralytics`` +- YOLO26 weights auto-download on first run +""" + +import json +import os +import subprocess +import time + +import numpy as np +import pytest +import torch +import torch.nn as nn + + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +VARIANTS = ["n", "s"] # Test small variants for speed; m/l/x follow same pattern +VARIANT_DTYPES = { + "n": torch.float32, + "s": torch.float32, + "m": torch.bfloat16, + "l": torch.bfloat16, + "x": torch.bfloat16, +} +INPUT_SHAPE = (3, 640, 640) +COSINE_SIM_THRESHOLD = 0.98 +CACHE_DIR = "/tmp/yolo26_test_cache" +WARMUP_ITERS = 10 +BENCHMARK_ITERS = 50 + + +# --------------------------------------------------------------------------- +# Helpers (self-contained, no src/ imports) +# --------------------------------------------------------------------------- + + +def get_neuron_core_count() -> int: + """Detect available NeuronCores via neuron-ls.""" + try: + result = subprocess.run( + ["neuron-ls", "--json-output"], + capture_output=True, + text=True, + timeout=10, + ) + data = json.loads(result.stdout) + return sum(dev.get("nc_count", 0) for dev in data) + except Exception: + return 0 + + +def prepare_yolo26(weight_path: str, 
dtype: torch.dtype = torch.float32) -> nn.Module: + """Prepare YOLO26 for tracing (same recipe as src/modeling_yolo26.py).""" + from ultralytics import YOLO + from ultralytics.nn.modules.block import C2f + + model = YOLO(weight_path) + pytorch_model = model.model.eval() + + detect = pytorch_model.model[-1] + detect.end2end = False + + pytorch_model = pytorch_model.fuse(verbose=False) + + for m in pytorch_model.modules(): + if hasattr(m, "export"): + m.export = True + if hasattr(m, "dynamic"): + m.dynamic = False + if hasattr(m, "format"): + m.format = "torchscript" + if hasattr(m, "shape"): + m.shape = None + if isinstance(m, C2f): + m.forward = m.forward_split + + if dtype != torch.float32: + pytorch_model = pytorch_model.to(dtype) + + return pytorch_model + + +def compile_neuron_model( + weight_path: str, batch_size: int = 1, dtype: torch.dtype = torch.float32 +) -> torch.jit.ScriptModule: + """Compile a YOLO26 model for Neuron, with caching.""" + import torch_neuronx + + variant = os.path.basename(weight_path).replace("yolo26", "").replace(".pt", "") + dtype_name = "bf16" if dtype == torch.bfloat16 else "fp32" + lnc = int(os.environ.get("NEURON_LOGICAL_NC_CONFIG", "2")) + cache_path = os.path.join( + CACHE_DIR, f"yolo26{variant}_{dtype_name}_bs{batch_size}_lnc{lnc}.pt" + ) + + if os.path.exists(cache_path): + return torch.jit.load(cache_path) + + model = prepare_yolo26(weight_path, dtype=dtype) + dummy = torch.randn(batch_size, *INPUT_SHAPE, dtype=dtype) + with torch.no_grad(): + _ = model(dummy) + + compiler_args = [] + if lnc == 1: + compiler_args = ["--lnc", "1"] + + traced = torch_neuronx.trace(model, dummy, compiler_args=compiler_args) + os.makedirs(CACHE_DIR, exist_ok=True) + torch.jit.save(traced, cache_path) + return traced + + +# --------------------------------------------------------------------------- +# Fixtures (module-scoped — compile once per test session) +# --------------------------------------------------------------------------- + + 
+@pytest.fixture(scope="module", params=VARIANTS) +def variant(request): + return request.param + + +@pytest.fixture(scope="module") +def cpu_model_n(): + return prepare_yolo26("yolo26n.pt", dtype=torch.float32) + + +@pytest.fixture(scope="module") +def cpu_model_s(): + return prepare_yolo26("yolo26s.pt", dtype=torch.float32) + + +@pytest.fixture(scope="module") +def neuron_model_n(): + return compile_neuron_model("yolo26n.pt", batch_size=1, dtype=torch.float32) + + +@pytest.fixture(scope="module") +def neuron_model_s(): + return compile_neuron_model("yolo26s.pt", batch_size=1, dtype=torch.float32) + + +# --------------------------------------------------------------------------- +# Tests: Model Loading & Compilation +# --------------------------------------------------------------------------- + + +class TestModelCompiles: + """Smoke tests: models compile and produce output of expected shape.""" + + def test_yolo26n_compiles(self, neuron_model_n): + dummy = torch.randn(1, *INPUT_SHAPE, dtype=torch.float32) + with torch.no_grad(): + out = neuron_model_n(dummy) + assert out.shape == (1, 84, 8400), ( + f"Expected [1, 84, 8400], got {list(out.shape)}" + ) + + def test_yolo26s_compiles(self, neuron_model_s): + dummy = torch.randn(1, *INPUT_SHAPE, dtype=torch.float32) + with torch.no_grad(): + out = neuron_model_s(dummy) + assert out.shape == (1, 84, 8400), ( + f"Expected [1, 84, 8400], got {list(out.shape)}" + ) + + def test_output_is_finite(self, neuron_model_n): + dummy = torch.randn(1, *INPUT_SHAPE, dtype=torch.float32) + with torch.no_grad(): + out = neuron_model_n(dummy) + assert not torch.isnan(out).any(), "Output contains NaN" + assert not torch.isinf(out).any(), "Output contains Inf" + + +# --------------------------------------------------------------------------- +# Tests: Accuracy +# --------------------------------------------------------------------------- + + +class TestAccuracy: + """Cosine similarity between CPU and Neuron outputs.""" + + 
@pytest.mark.parametrize("seed", [42, 123, 7]) + def test_yolo26n_cosine_similarity(self, cpu_model_n, neuron_model_n, seed): + torch.manual_seed(seed) + dummy = torch.randn(1, *INPUT_SHAPE, dtype=torch.float32) + + with torch.no_grad(): + cpu_out = cpu_model_n(dummy) + nrn_out = neuron_model_n(dummy) + + cossim = torch.nn.functional.cosine_similarity( + cpu_out.flatten().float().unsqueeze(0), + nrn_out.flatten().float().unsqueeze(0), + ).item() + + assert cossim >= COSINE_SIM_THRESHOLD, ( + f"CosSim {cossim:.6f} below threshold {COSINE_SIM_THRESHOLD} (seed={seed})" + ) + + @pytest.mark.parametrize("seed", [42, 123, 7]) + def test_yolo26s_cosine_similarity(self, cpu_model_s, neuron_model_s, seed): + torch.manual_seed(seed) + dummy = torch.randn(1, *INPUT_SHAPE, dtype=torch.float32) + + with torch.no_grad(): + cpu_out = cpu_model_s(dummy) + nrn_out = neuron_model_s(dummy) + + cossim = torch.nn.functional.cosine_similarity( + cpu_out.flatten().float().unsqueeze(0), + nrn_out.flatten().float().unsqueeze(0), + ).item() + + assert cossim >= COSINE_SIM_THRESHOLD, ( + f"CosSim {cossim:.6f} below threshold {COSINE_SIM_THRESHOLD} (seed={seed})" + ) + + +# --------------------------------------------------------------------------- +# Tests: Data Parallel +# --------------------------------------------------------------------------- + + +class TestDataParallel: + """Multi-core data parallel execution.""" + + def test_dp_runs(self, neuron_model_n): + import torch_neuronx + + num_cores = get_neuron_core_count() + if num_cores < 2: + pytest.skip("Only 1 NeuronCore available") + + model_dp = torch_neuronx.DataParallel( + neuron_model_n, + device_ids=list(range(num_cores)), + dim=0, + ) + + total_bs = num_cores # BS=1 per core + dummy = torch.randn(total_bs, *INPUT_SHAPE, dtype=torch.float32) + with torch.no_grad(): + out = model_dp(dummy) + + assert out.shape[0] == total_bs, ( + f"Expected batch dim {total_bs}, got {out.shape[0]}" + ) + assert out.shape[1:] == (84, 8400) + + def 
test_dp_speedup(self, neuron_model_n): + import torch_neuronx + + num_cores = get_neuron_core_count() + if num_cores < 2: + pytest.skip("Only 1 NeuronCore available") + + # Single core baseline + dummy_1 = torch.randn(1, *INPUT_SHAPE, dtype=torch.float32) + for _ in range(WARMUP_ITERS): + with torch.no_grad(): + neuron_model_n(dummy_1) + + single_lats = [] + for _ in range(BENCHMARK_ITERS): + t0 = time.time() + with torch.no_grad(): + neuron_model_n(dummy_1) + single_lats.append(time.time() - t0) + single_tput = 1.0 / np.median(single_lats) + + # Multi-core + model_dp = torch_neuronx.DataParallel( + neuron_model_n, + device_ids=list(range(num_cores)), + dim=0, + ) + total_bs = num_cores + dummy_dp = torch.randn(total_bs, *INPUT_SHAPE, dtype=torch.float32) + + for _ in range(WARMUP_ITERS): + with torch.no_grad(): + model_dp(dummy_dp) + + dp_lats = [] + for _ in range(BENCHMARK_ITERS): + t0 = time.time() + with torch.no_grad(): + model_dp(dummy_dp) + dp_lats.append(time.time() - t0) + dp_tput = total_bs / np.median(dp_lats) + + speedup = dp_tput / single_tput + assert speedup > 1.5, ( + f"DP speedup {speedup:.2f}x too low (expected >1.5x with {num_cores} cores)" + ) + + +# --------------------------------------------------------------------------- +# Tests: Performance +# --------------------------------------------------------------------------- + + +class TestPerformance: + """Throughput thresholds for detection variants.""" + + def test_yolo26n_throughput(self, neuron_model_n): + dummy = torch.randn(1, *INPUT_SHAPE, dtype=torch.float32) + + for _ in range(WARMUP_ITERS): + with torch.no_grad(): + neuron_model_n(dummy) + + latencies = [] + for _ in range(BENCHMARK_ITERS): + t0 = time.time() + with torch.no_grad(): + neuron_model_n(dummy) + latencies.append((time.time() - t0) * 1000) + + p50 = np.median(latencies) + throughput = 1000.0 / p50 + # Single core yolo26n should be at least 25 img/s + assert throughput >= 25.0, ( + f"Throughput {throughput:.1f} img/s below 
minimum 25.0 (p50={p50:.1f}ms)" + ) + + def test_yolo26s_throughput(self, neuron_model_s): + dummy = torch.randn(1, *INPUT_SHAPE, dtype=torch.float32) + + for _ in range(WARMUP_ITERS): + with torch.no_grad(): + neuron_model_s(dummy) + + latencies = [] + for _ in range(BENCHMARK_ITERS): + t0 = time.time() + with torch.no_grad(): + neuron_model_s(dummy) + latencies.append((time.time() - t0) * 1000) + + p50 = np.median(latencies) + throughput = 1000.0 / p50 + # Single core yolo26s should be at least 50 img/s + assert throughput >= 50.0, ( + f"Throughput {throughput:.1f} img/s below minimum 50.0 (p50={p50:.1f}ms)" + ) + + +# --------------------------------------------------------------------------- +# Standalone runner +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + print("=" * 60) + print("YOLO26 Neuron Integration Tests") + print("=" * 60) + + num_cores = get_neuron_core_count() + lnc = os.environ.get("NEURON_LOGICAL_NC_CONFIG", "2") + print(f"NeuronCores: {num_cores}, LNC: {lnc}") + + # [1/6] Compile n + print("\n[1/6] Compiling yolo26n (FP32, BS=1)...") + t0 = time.time() + neuron_n = compile_neuron_model("yolo26n.pt", batch_size=1, dtype=torch.float32) + print(f" Done in {time.time() - t0:.1f}s") + + # [2/6] Compile s + print("\n[2/6] Compiling yolo26s (FP32, BS=1)...") + t0 = time.time() + neuron_s = compile_neuron_model("yolo26s.pt", batch_size=1, dtype=torch.float32) + print(f" Done in {time.time() - t0:.1f}s") + + # [3/6] Shape test + print("\n[3/6] Testing output shapes...") + for name, model, dtype in [ + ("n", neuron_n, torch.float32), + ("s", neuron_s, torch.float32), + ]: + dummy = torch.randn(1, *INPUT_SHAPE, dtype=dtype) + with torch.no_grad(): + out = model(dummy) + print(f" yolo26{name}: {list(out.shape)}, NaN={torch.isnan(out).any().item()}") + assert out.shape == (1, 84, 8400) + + # [4/6] Accuracy + print("\n[4/6] Accuracy validation...") + for name, neuron_model, dtype in [ + ("n", 
neuron_n, torch.float32), + ("s", neuron_s, torch.float32), + ]: + cpu_model = prepare_yolo26(f"yolo26{name}.pt", dtype=dtype) + torch.manual_seed(42) + dummy = torch.randn(1, *INPUT_SHAPE, dtype=dtype) + with torch.no_grad(): + cpu_out = cpu_model(dummy) + nrn_out = neuron_model(dummy) + cossim = torch.nn.functional.cosine_similarity( + cpu_out.flatten().float().unsqueeze(0), + nrn_out.flatten().float().unsqueeze(0), + ).item() + print( + f" yolo26{name}: CosSim = {cossim:.6f} {'PASS' if cossim >= COSINE_SIM_THRESHOLD else 'FAIL'}" + ) + + # [5/6] Throughput + print("\n[5/6] Throughput benchmark (single core)...") + for name, model, dtype in [ + ("n", neuron_n, torch.float32), + ("s", neuron_s, torch.float32), + ]: + dummy = torch.randn(1, *INPUT_SHAPE, dtype=dtype) + for _ in range(WARMUP_ITERS): + with torch.no_grad(): + model(dummy) + lats = [] + for _ in range(BENCHMARK_ITERS): + t0 = time.time() + with torch.no_grad(): + model(dummy) + lats.append((time.time() - t0) * 1000) + p50 = np.median(lats) + print(f" yolo26{name}: {1000 / p50:.1f} img/s (p50={p50:.1f}ms)") + + # [6/6] Data Parallel + print("\n[6/6] Data Parallel test...") + if num_cores >= 2: + import torch_neuronx + + model_dp = torch_neuronx.DataParallel( + neuron_n, device_ids=list(range(num_cores)), dim=0 + ) + dummy_dp = torch.randn(num_cores, *INPUT_SHAPE, dtype=torch.float32) + with torch.no_grad(): + out_dp = model_dp(dummy_dp) + print(f" DP={num_cores}: output shape {list(out_dp.shape)}") + else: + print(" Skipped (only 1 NeuronCore)") + + print("\n" + "=" * 60) + print("All tests complete.") + print("=" * 60) diff --git a/contrib/models/YOLO26/test/unit/__init__.py b/contrib/models/YOLO26/test/unit/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/YOLO26/yolo26_neuron_notebook.ipynb b/contrib/models/YOLO26/yolo26_neuron_notebook.ipynb new file mode 100644 index 00000000..ccd628ad --- /dev/null +++ b/contrib/models/YOLO26/yolo26_neuron_notebook.ipynb @@ 
-0,0 +1,637 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# YOLO26 Object Detection on AWS Trainium2 with Neuron SDK\n", + "\n", + "This notebook demonstrates compiling and running [Ultralytics YOLO26](https://github.com/ultralytics/ultralytics) detection models on AWS Trainium2 (trn2) using `torch_neuronx.trace()`. All 5 detection variants (n/s/m/l/x) are covered, along with data-parallel scaling across NeuronCores.\n", + "\n", + "| | |\n", + "|---|---|\n", + "| **Instance** | trn2.3xlarge (1 Trainium2 chip, 4/8 NeuronCores) |\n", + "| **SDK** | Neuron SDK 2.29 (DLAMI 20260410) |\n", + "| **Time to complete** | ~15 minutes (compilation + benchmarks) |\n", + "| **Prerequisites** | trn2 instance with Neuron DLAMI |\n", + "\n", + "**Quick Start:**\n", + "```bash\n", + "source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate\n", + "pip install ultralytics\n", + "jupyter lab --no-browser --port=8888\n", + "```\n", + "\n", + "### Key Results\n", + "\n", + "| Variant | Params | Neuron Peak (LNC=1, DP=8) | A10G Compiled FP16 | Neuron Speedup |\n", + "|---------|--------|--------------------------|--------------------|-----------|\n", + "| YOLO26n | 2.4M | 272 img/s | 2,166 img/s | 0.13x |\n", + "| YOLO26s | 10.0M | **1,523 img/s** | 1,065 img/s | **1.43x** |\n", + "| YOLO26m | 21.9M | **1,267 img/s** | 474 img/s | **2.67x** |\n", + "| YOLO26l | 26.3M | **1,093 img/s** | 371 img/s | **2.95x** |\n", + "| YOLO26x | 58.9M | **876 img/s** | 195 img/s | **4.49x** |" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "## Step 0: Setup & Environment Check" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import subprocess\n", + "subprocess.run([\"pip\", \"install\", \"-q\", \"ultralytics\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import 
time\n", + "import json\n", + "\n", + "import torch\n", + "import torch_neuronx\n", + "import numpy as np\n", + "from ultralytics import YOLO\n", + "from ultralytics.nn.modules.block import C2f\n", + "\n", + "print(f\"PyTorch: {torch.__version__}\")\n", + "print(f\"torch-neuronx: {torch_neuronx.__version__}\")\n", + "print(f\"Neuron SDK: {torch_neuronx.__version__}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Show available NeuronCores\n", + "subprocess.run([\"neuron-ls\"])\n", + "\n", + "# Check LNC mode\n", + "lnc = os.environ.get(\"NEURON_LOGICAL_NC_CONFIG\", \"2\")\n", + "print(f\"\\nLNC mode: {lnc}\")\n", + "num_cores = 8 if lnc == \"1\" else 4\n", + "print(f\"Available logical cores: {num_cores}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "## Step 1: Download Model Weights" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "VARIANTS = [\"n\", \"s\", \"m\", \"l\", \"x\"]\n", + "WEIGHT_DIR = \".\"\n", + "\n", + "for v in VARIANTS:\n", + " wp = f\"yolo26{v}.pt\"\n", + " if not os.path.exists(wp):\n", + " print(f\"Downloading {wp}...\")\n", + " _ = YOLO(wp) # ultralytics auto-downloads\n", + " else:\n", + " print(f\"{wp} already exists\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "## Step 2: Model Preparation & Tracing Recipe\n", + "\n", + "YOLO26 requires specific preparation before `torch_neuronx.trace()`:\n", + "\n", + "1. **`end2end=False`** before `fuse()` -- disables `topk` (unsupported on Neuron)\n", + "2. **`fuse()`** -- merges BatchNorm into Conv2d, removes training branches\n", + "3. **`export=True`** -- single-tensor output for clean tracing\n", + "4. **`C2f.forward_split`** -- uses `split()` instead of `chunk()` for tracing\n", + "5. **Dry run** -- populates anchor/stride caches\n", + "6. 
**BF16 for m/l/x** -- FP32 exceeds SB allocation for larger variants\n", + "\n", + "Output: raw `[B, 84, 8400]` tensor (4 bbox coords + 80 COCO class scores per anchor). Postprocessing (topk, NMS) happens on CPU." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def prepare_yolo26(weight_path, dtype=torch.float32):\n", + " \"\"\"Prepare YOLO26 detection model for Neuron tracing.\"\"\"\n", + " model = YOLO(weight_path)\n", + " pytorch_model = model.model.eval()\n", + "\n", + " # Disable end2end BEFORE fuse (topk not supported on Neuron)\n", + " detect = pytorch_model.model[-1]\n", + " detect.end2end = False\n", + "\n", + " pytorch_model = pytorch_model.fuse(verbose=False)\n", + "\n", + " for m in pytorch_model.modules():\n", + " if hasattr(m, \"export\"):\n", + " m.export = True\n", + " if hasattr(m, \"dynamic\"):\n", + " m.dynamic = False\n", + " if hasattr(m, \"format\"):\n", + " m.format = \"torchscript\"\n", + " if hasattr(m, \"shape\"):\n", + " m.shape = None\n", + " if isinstance(m, C2f):\n", + " m.forward = m.forward_split\n", + "\n", + " if dtype != torch.float32:\n", + " pytorch_model = pytorch_model.to(dtype)\n", + "\n", + " return pytorch_model\n", + "\n", + "\n", + "# Variant configurations: n/s use FP32, m/l/x require BF16\n", + "VARIANT_CONFIG = {\n", + " \"n\": {\"dtype\": torch.float32, \"dtype_name\": \"fp32\"},\n", + " \"s\": {\"dtype\": torch.float32, \"dtype_name\": \"fp32\"},\n", + " \"m\": {\"dtype\": torch.bfloat16, \"dtype_name\": \"bf16\"},\n", + " \"l\": {\"dtype\": torch.bfloat16, \"dtype_name\": \"bf16\"},\n", + " \"x\": {\"dtype\": torch.bfloat16, \"dtype_name\": \"bf16\"},\n", + "}\n", + "\n", + "print(\"Variant configuration:\")\n", + "for v, cfg in VARIANT_CONFIG.items():\n", + " print(f\" yolo26{v}: {cfg['dtype_name']}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "## Step 3: Compile All Variants\n", + "\n", + "Compile all 5 
detection variants for Neuron. Each compilation takes 30-70 seconds." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "COMPILE_DIR = \"compiled\"\n", + "os.makedirs(COMPILE_DIR, exist_ok=True)\n", + "BS = 1 # Batch size for initial compilation\n", + "\n", + "# Check LNC mode for compiler args\n", + "compiler_args = []\n", + "if lnc == \"1\":\n", + " compiler_args = [\"--lnc\", \"1\"]\n", + " print(\"Using --lnc 1 compiler flag for LNC=1 mode\")\n", + "\n", + "compile_results = {}\n", + "\n", + "for v in VARIANTS:\n", + " cfg = VARIANT_CONFIG[v]\n", + " neff_path = os.path.join(COMPILE_DIR, f\"yolo26{v}_{cfg['dtype_name']}_bs{BS}.pt\")\n", + "\n", + " if os.path.exists(neff_path):\n", + " print(f\"yolo26{v}: already compiled at {neff_path}\")\n", + " compile_results[v] = {\"status\": \"cached\", \"path\": neff_path}\n", + " continue\n", + "\n", + " print(f\"\\nCompiling yolo26{v} ({cfg['dtype_name']}, BS={BS})...\")\n", + " model = prepare_yolo26(f\"yolo26{v}.pt\", dtype=cfg[\"dtype\"])\n", + "\n", + " dummy = torch.randn(BS, 3, 640, 640, dtype=cfg[\"dtype\"])\n", + " with torch.no_grad():\n", + " _ = model(dummy) # dry run for anchors\n", + "\n", + " t0 = time.time()\n", + " traced = torch_neuronx.trace(model, dummy, compiler_args=compiler_args)\n", + " compile_time = time.time() - t0\n", + "\n", + " torch.jit.save(traced, neff_path)\n", + " neff_size = os.path.getsize(neff_path) / (1024 * 1024)\n", + "\n", + " print(f\" Compiled in {compile_time:.1f}s, NEFF: {neff_size:.1f} MB\")\n", + " compile_results[v] = {\n", + " \"status\": \"compiled\",\n", + " \"path\": neff_path,\n", + " \"compile_time_s\": round(compile_time, 1),\n", + " \"neff_size_mb\": round(neff_size, 1),\n", + " }\n", + " del model, traced\n", + "\n", + "print(\"\\nCompilation summary:\")\n", + "for v, r in compile_results.items():\n", + " ct = r.get('compile_time_s', 'cached')\n", + " ns = r.get('neff_size_mb', '?')\n", + " print(f\" 
yolo26{v}: {ct}s, {ns} MB\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "## Step 4: Accuracy Validation\n", + "\n", + "Compare Neuron output against CPU reference using cosine similarity." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(f\"{'Variant':<10} {'Dtype':<6} {'CosSim':<10} {'MaxErr':<12} {'MeanErr':<12} {'Neuron NaN':<10}\")\n", + "print(\"-\" * 60)\n", + "\n", + "for v in VARIANTS:\n", + " cfg = VARIANT_CONFIG[v]\n", + " neff_path = compile_results[v][\"path\"]\n", + "\n", + " # CPU reference\n", + " model_cpu = prepare_yolo26(f\"yolo26{v}.pt\", dtype=cfg[\"dtype\"])\n", + " dummy = torch.randn(1, 3, 640, 640, dtype=cfg[\"dtype\"])\n", + " with torch.no_grad():\n", + " cpu_out = model_cpu(dummy)\n", + "\n", + " # Neuron inference\n", + " traced = torch.jit.load(neff_path)\n", + " with torch.no_grad():\n", + " nrn_out = traced(dummy)\n", + "\n", + " # Compare\n", + " cpu_flat = cpu_out.flatten().float()\n", + " nrn_flat = nrn_out.flatten().float()\n", + " cossim = torch.nn.functional.cosine_similarity(cpu_flat.unsqueeze(0), nrn_flat.unsqueeze(0)).item()\n", + " max_err = (cpu_flat - nrn_flat).abs().max().item()\n", + " mean_err = (cpu_flat - nrn_flat).abs().mean().item()\n", + " has_nan = torch.isnan(nrn_out).any().item()\n", + "\n", + " print(f\"yolo26{v:<4} {cfg['dtype_name']:<6} {cossim:<10.6f} {max_err:<12.6f} {mean_err:<12.6f} {has_nan}\")\n", + "\n", + " del model_cpu, traced" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "## Step 5: Single-Core Latency Benchmark" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "WARMUP = 20\n", + "ITERATIONS = 100\n", + "\n", + "print(f\"{'Variant':<10} {'Dtype':<6} {'p50 (ms)':<12} {'p95 (ms)':<12} {'p99 (ms)':<12} {'img/s':<10}\")\n", + "print(\"-\" * 62)\n", + "\n", + "single_core_results = 
{}\n", + "\n", + "for v in VARIANTS:\n", + " cfg = VARIANT_CONFIG[v]\n", + " traced = torch.jit.load(compile_results[v][\"path\"])\n", + " dummy = torch.randn(1, 3, 640, 640, dtype=cfg[\"dtype\"])\n", + "\n", + " # Warmup\n", + " for _ in range(WARMUP):\n", + " with torch.no_grad():\n", + " traced(dummy)\n", + "\n", + " # Measure\n", + " latencies = []\n", + " for _ in range(ITERATIONS):\n", + " t0 = time.time()\n", + " with torch.no_grad():\n", + " traced(dummy)\n", + " latencies.append((time.time() - t0) * 1000)\n", + "\n", + " lat = np.array(sorted(latencies))\n", + " p50 = np.percentile(lat, 50)\n", + " p95 = np.percentile(lat, 95)\n", + " p99 = np.percentile(lat, 99)\n", + " throughput = 1000.0 / p50\n", + "\n", + " print(f\"yolo26{v:<4} {cfg['dtype_name']:<6} {p50:<12.2f} {p95:<12.2f} {p99:<12.2f} {throughput:<10.1f}\")\n", + " single_core_results[v] = {\"p50\": p50, \"throughput\": throughput}\n", + "\n", + " del traced" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "## Step 6: Data-Parallel Scaling\n", + "\n", + "Scale inference across all NeuronCores using `torch_neuronx.DataParallel`. 
Each core runs an independent model replica, multiplying throughput.\n", + "\n", + "On trn2.3xlarge:\n", + "- **LNC=2** (default): 4 logical cores, DP=4\n", + "- **LNC=1**: 8 logical cores, DP=8 (recommended for YOLO26)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Pick yolo26s (best overall balance) for the DP demo\n", + "DP_WARMUP = 10\n", + "DP_ITERATIONS = 50\n", + "DEMO_VARIANT = \"s\"\n", + "cfg = VARIANT_CONFIG[DEMO_VARIANT]\n", + "BATCH_SIZES = [1, 4, 8]\n", + "\n", + "print(f\"Data-Parallel benchmark: yolo26{DEMO_VARIANT} ({cfg['dtype_name']}), DP={num_cores}\")\n", + "print(f\"{'BS/core':<10} {'Total BS':<10} {'p50 (ms)':<12} {'Throughput':<14} {'vs BS=1':<10}\")\n", + "print(\"-\" * 56)\n", + "\n", + "dp_results = []\n", + "\n", + "for bs in BATCH_SIZES:\n", + " # Compile for this batch size if not already done\n", + " neff_path = os.path.join(COMPILE_DIR, f\"yolo26{DEMO_VARIANT}_{cfg['dtype_name']}_bs{bs}.pt\")\n", + " if not os.path.exists(neff_path):\n", + " model = prepare_yolo26(f\"yolo26{DEMO_VARIANT}.pt\", dtype=cfg[\"dtype\"])\n", + " dummy = torch.randn(bs, 3, 640, 640, dtype=cfg[\"dtype\"])\n", + " with torch.no_grad():\n", + " _ = model(dummy)\n", + " traced = torch_neuronx.trace(model, dummy, compiler_args=compiler_args)\n", + " torch.jit.save(traced, neff_path)\n", + " del model, traced\n", + "\n", + " # Load and create DataParallel\n", + " traced = torch.jit.load(neff_path)\n", + " model_dp = torch_neuronx.DataParallel(\n", + " traced,\n", + " device_ids=list(range(num_cores)),\n", + " dim=0,\n", + " )\n", + "\n", + " total_bs = bs * num_cores\n", + " dummy_dp = torch.randn(total_bs, 3, 640, 640, dtype=cfg[\"dtype\"])\n", + "\n", + " # Warmup\n", + " for _ in range(DP_WARMUP):\n", + " with torch.no_grad():\n", + " model_dp(dummy_dp)\n", + "\n", + " # Measure\n", + " latencies = []\n", + " for _ in range(DP_ITERATIONS):\n", + " t0 = time.time()\n", + " with 
torch.no_grad():\n", + " model_dp(dummy_dp)\n", + " latencies.append((time.time() - t0) * 1000)\n", + "\n", + " lat = np.array(sorted(latencies))\n", + " p50 = np.percentile(lat, 50)\n", + " throughput = total_bs / (p50 / 1000)\n", + "\n", + " dp_results.append({\"bs\": bs, \"total_bs\": total_bs, \"p50\": p50, \"throughput\": throughput})\n", + "\n", + " ratio = throughput / dp_results[0][\"throughput\"] if dp_results else 1.0\n", + " print(f\"{bs:<10} {total_bs:<10} {p50:<12.2f} {throughput:<14.1f} {ratio:<10.2f}x\")\n", + "\n", + " del model_dp, traced\n", + "\n", + "peak = max(dp_results, key=lambda x: x[\"throughput\"])\n", + "print(f\"\\nPeak: {peak['throughput']:.1f} img/s at BS={peak['bs']}/core (total BS={peak['total_bs']})\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "## Step 7: Multi-Variant DP Benchmark\n", + "\n", + "Benchmark all 5 variants at their optimal batch sizes with full data parallelism." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Optimal BS per variant (from Task 008 sweep)\n", + "# LNC=1 (DP=8): n=1, s=32, m=32, l=32, x=16\n", + "# LNC=2 (DP=4): use smaller BS to reduce compile+benchmark time\n", + "if lnc == \"1\":\n", + " OPTIMAL_BS = {\"n\": 1, \"s\": 32, \"m\": 32, \"l\": 32, \"x\": 16}\n", + "else:\n", + " OPTIMAL_BS = {\"n\": 1, \"s\": 8, \"m\": 8, \"l\": 8, \"x\": 4}\n", + "\n", + "print(f\"All variants, DP={num_cores}, optimal batch sizes\")\n", + "print(f\"{'Variant':<10} {'Dtype':<6} {'BS/core':<10} {'Total BS':<10} {'p50 (ms)':<12} {'img/s':<12}\")\n", + "print(\"-\" * 60)\n", + "\n", + "all_variant_results = {}\n", + "\n", + "for v in VARIANTS:\n", + " cfg = VARIANT_CONFIG[v]\n", + " bs = OPTIMAL_BS[v]\n", + " total_bs = bs * num_cores\n", + "\n", + " # Compile if needed\n", + " neff_path = os.path.join(COMPILE_DIR, f\"yolo26{v}_{cfg['dtype_name']}_bs{bs}.pt\")\n", + " if not os.path.exists(neff_path):\n", + " 
model = prepare_yolo26(f\"yolo26{v}.pt\", dtype=cfg[\"dtype\"])\n", + " dummy = torch.randn(bs, 3, 640, 640, dtype=cfg[\"dtype\"])\n", + " with torch.no_grad():\n", + " _ = model(dummy)\n", + " traced = torch_neuronx.trace(model, dummy, compiler_args=compiler_args)\n", + " torch.jit.save(traced, neff_path)\n", + " del model, traced\n", + "\n", + " traced = torch.jit.load(neff_path)\n", + " model_dp = torch_neuronx.DataParallel(\n", + " traced, device_ids=list(range(num_cores)), dim=0\n", + " )\n", + "\n", + " dummy_dp = torch.randn(total_bs, 3, 640, 640, dtype=cfg[\"dtype\"])\n", + "\n", + " for _ in range(DP_WARMUP):\n", + " with torch.no_grad():\n", + " model_dp(dummy_dp)\n", + "\n", + " latencies = []\n", + " for _ in range(DP_ITERATIONS):\n", + " t0 = time.time()\n", + " with torch.no_grad():\n", + " model_dp(dummy_dp)\n", + " latencies.append((time.time() - t0) * 1000)\n", + "\n", + " lat = np.array(sorted(latencies))\n", + " p50 = np.percentile(lat, 50)\n", + " throughput = total_bs / (p50 / 1000)\n", + "\n", + " print(f\"yolo26{v:<4} {cfg['dtype_name']:<6} {bs:<10} {total_bs:<10} {p50:<12.2f} {throughput:<12.1f}\")\n", + " all_variant_results[v] = {\"throughput\": throughput, \"p50\": p50, \"bs\": bs}\n", + "\n", + " del model_dp, traced" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "## Step 8: GPU Comparison\n", + "\n", + "Comparison against NVIDIA A10G (g5.xlarge) with `torch.compile(mode='reduce-overhead')` + FP16.\n", + "\n", + "| Variant | Neuron trn2 (DP=8) | A10G Compiled FP16 | Neuron / A10G | Cost: Neuron $/M | Cost: A10G $/M |\n", + "|---------|-------------------|--------------------|---------------|-----------------|----------------|\n", + "| YOLO26n | 272 img/s | **2,166 img/s** | 0.13x | $3.87 | $0.13 |\n", + "| YOLO26s | **1,523 img/s** | 1,065 img/s | **1.43x** | $0.69 | $0.26 |\n", + "| YOLO26m | **1,267 img/s** | 474 img/s | **2.67x** | $0.83 | $0.59 |\n", + "| YOLO26l | **1,093 img/s** | 371 img/s | 
**2.95x** | $0.96 | $0.75 |\n", + "| YOLO26x | **876 img/s** | 195 img/s | **4.49x** | $1.20 | $1.43 |\n", + "\n", + "*Pricing: trn2.3xlarge ~$3.79/hr, g5.xlarge $1.006/hr (us-east-1 on-demand)*\n", + "\n", + "**Neuron wins 1.4-4.5x on throughput for s/m/l/x variants.** Cost-advantaged on YOLO26x ($1.20 vs $1.43 per million images). The advantage grows with model size." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "## Step 9: Additional Task Variants\n", + "\n", + "The same tracing recipe works for segmentation, pose, and OBB variants:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ADDITIONAL_VARIANTS = {\n", + " \"pose\": {\"weight\": \"yolo26s-pose.pt\", \"input\": (1, 3, 640, 640)},\n", + " \"obb\": {\"weight\": \"yolo26s-obb.pt\", \"input\": (1, 3, 640, 640)},\n", + "}\n", + "\n", + "for task_name, task_cfg in ADDITIONAL_VARIANTS.items():\n", + " print(f\"\\n--- {task_name.upper()} ---\")\n", + " wp = task_cfg[\"weight\"]\n", + " if not os.path.exists(wp):\n", + " _ = YOLO(wp)\n", + "\n", + " model = prepare_yolo26(wp)\n", + " dummy = torch.randn(*task_cfg[\"input\"])\n", + " with torch.no_grad():\n", + " cpu_out = model(dummy)\n", + "\n", + " neff_path = os.path.join(COMPILE_DIR, f\"{wp.replace('.pt', '_neuron.pt')}\")\n", + " if not os.path.exists(neff_path):\n", + " traced = torch_neuronx.trace(model, dummy, compiler_args=compiler_args)\n", + " torch.jit.save(traced, neff_path)\n", + " else:\n", + " traced = torch.jit.load(neff_path)\n", + "\n", + " with torch.no_grad():\n", + " nrn_out = traced(dummy)\n", + "\n", + " cossim = torch.nn.functional.cosine_similarity(\n", + " cpu_out.flatten().float().unsqueeze(0),\n", + " nrn_out.flatten().float().unsqueeze(0)\n", + " ).item()\n", + "\n", + " # Latency\n", + " for _ in range(10):\n", + " with torch.no_grad():\n", + " traced(dummy)\n", + " latencies = []\n", + " for _ in range(50):\n", + " t0 = 
time.time()\n", + " with torch.no_grad():\n", + " traced(dummy)\n", + " latencies.append((time.time() - t0) * 1000)\n", + " p50 = np.median(latencies)\n", + "\n", + " print(f\" Output: {list(nrn_out.shape) if isinstance(nrn_out, torch.Tensor) else [list(t.shape) for t in nrn_out]}\")\n", + " print(f\" CosSim: {cossim:.6f}\")\n", + " print(f\" Latency: {p50:.2f} ms ({1000/p50:.1f} img/s)\")\n", + "\n", + " del model, traced" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "## Summary\n", + "\n", + "This notebook demonstrated:\n", + "\n", + "- **All 5 YOLO26 detection variants** (n/s/m/l/x) compile and run on Neuron\n", + "- **Same tracing recipe** works across detection, pose, and OBB task heads\n", + "- **Data-parallel scaling** across 4-8 NeuronCores for throughput multiplication\n", + "- **Neuron outperforms A10G** by 1.4-4.5x on s/m/l/x variants (compiled GPU baseline)\n", + "- **Key tracing requirements**: `end2end=False` (topk unsupported), BF16 for m/l/x (FP32 SB overflow), no `--auto-cast` flags (NaN for Conv2d models)\n", + "\n", + "### Recommended Configurations\n", + "\n", + "| Variant | LNC | DP | BS/core | Peak img/s |\n", + "|---------|-----|-----|---------|------------|\n", + "| YOLO26n | 1 | 8 | 1 (FP32) | 272 |\n", + "| YOLO26s | 1 | 8 | 32 (FP32) | 1,523 |\n", + "| YOLO26m | 1 | 8 | 32 (BF16) | 1,267 |\n", + "| YOLO26l | 1 | 8 | 32 (BF16) | 1,093 |\n", + "| YOLO26x | 1 | 8 | 16 (BF16) | 876 |" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.12.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} From 1adc61750642f3a0d3e3352c92435ca507f12f89 Mon Sep 17 00:00:00 2001 From: Jim Burtoft Date: Wed, 29 Apr 2026 13:36:58 -0400 Subject: [PATCH 2/3] Validate on SDK 2.28/2.29, trn2 and inf2 Tested all 4 combinations: - trn2.3xlarge SDK 2.28: 13/13 pytest passed - trn2.3xlarge SDK 2.29:
13/13 pytest passed - inf2.xlarge SDK 2.28: 6/6 standalone tests passed - inf2.xlarge SDK 2.29: 6/6 standalone tests passed inf2 single-core throughput: yolo26n 60-70 img/s, yolo26s 64-77 img/s. Updated compatibility matrix and notebook prerequisites. --- contrib/models/YOLO26/README.md | 19 ++++++++++++++----- .../YOLO26/yolo26_neuron_notebook.ipynb | 8 ++++---- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/contrib/models/YOLO26/README.md b/contrib/models/YOLO26/README.md index 44753f30..a8dc16c0 100644 --- a/contrib/models/YOLO26/README.md +++ b/contrib/models/YOLO26/README.md @@ -25,9 +25,9 @@ Key Neuron porting challenges: ## Validation Results -**Validated:** 2026-04-25 -**Instance:** trn2.3xlarge (1 Trainium2 chip) -**SDK:** Neuron SDK 2.28, PyTorch 2.9 +**Validated:** 2026-04-29 +**Instance:** trn2.3xlarge (1 Trainium2 chip), inf2.xlarge (1 Inferentia2 chip) +**SDK:** Neuron SDK 2.28 and 2.29, PyTorch 2.9 ### Peak Throughput (LNC=1, DP=8) @@ -105,9 +105,18 @@ print(f"CosSim: {metrics['cosine_similarity']}") | Instance | SDK 2.28 | SDK 2.29 | |----------|----------|----------| -| trn2.3xlarge | Validated | Expected compatible | +| trn2.3xlarge | Validated (13/13 tests) | Validated (13/13 tests) | | trn2.48xlarge | Expected compatible | Expected compatible | -| inf2.xlarge | Not tested (n/s only) | Not tested | +| inf2.xlarge | Validated (6/6 tests, n/s only) | Validated (6/6 tests, n/s only) | + +### inf2 Single-Core Throughput + +| Variant | inf2 SDK 2.28 | inf2 SDK 2.29 | trn2 SDK 2.28 | +|---------|---------------|---------------|---------------| +| YOLO26n | 60.1 img/s | 69.7 img/s | 32.3 img/s | +| YOLO26s | 64.1 img/s | 76.7 img/s | 66.0 img/s | + +*Note: inf2 single-core outperforms trn2 single-core for small models. 
trn2 advantage comes from DP=8 (LNC=1) scaling.* ## Testing Instructions diff --git a/contrib/models/YOLO26/yolo26_neuron_notebook.ipynb b/contrib/models/YOLO26/yolo26_neuron_notebook.ipynb index ccd628ad..554386ec 100644 --- a/contrib/models/YOLO26/yolo26_neuron_notebook.ipynb +++ b/contrib/models/YOLO26/yolo26_neuron_notebook.ipynb @@ -10,10 +10,10 @@ "\n", "| | |\n", "|---|---|\n", - "| **Instance** | trn2.3xlarge (1 Trainium2 chip, 4/8 NeuronCores) |\n", - "| **SDK** | Neuron SDK 2.29 (DLAMI 20260410) |\n", - "| **Time to complete** | ~15 minutes (compilation + benchmarks) |\n", - "| **Prerequisites** | trn2 instance with Neuron DLAMI |\n", + "| **Instance** | trn2.3xlarge (1 Trainium2 chip, 4/8 NeuronCores) |\n", + "| **SDK** | Neuron SDK 2.28 or 2.29 |\n", + "| **Time to complete** | ~15 minutes (compilation + benchmarks) |\n", + "| **Prerequisites** | trn2 or inf2 instance with Neuron DLAMI |\n", "\n", "**Quick Start:**\n", "```bash\n", From 9ad5d38a4907d2b924b83e296e3ba858e39ad9d8 Mon Sep 17 00:00:00 2001 From: Jim Burtoft Date: Wed, 29 Apr 2026 15:10:34 -0400 Subject: [PATCH 3/3] Validate all 5 YOLO26 variants on inf2.xlarge (SDK 2.29) --- contrib/models/YOLO26/README.md | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/contrib/models/YOLO26/README.md b/contrib/models/YOLO26/README.md index a8dc16c0..596a51e0 100644 --- a/contrib/models/YOLO26/README.md +++ b/contrib/models/YOLO26/README.md @@ -107,16 +107,19 @@ print(f"CosSim: {metrics['cosine_similarity']}") |----------|----------|----------| | trn2.3xlarge | Validated (13/13 tests) | Validated (13/13 tests) | | trn2.48xlarge | Expected compatible | Expected compatible | -| inf2.xlarge | Validated (6/6 tests, n/s only) | Validated (6/6 tests, n/s only) | +| inf2.xlarge | Validated (n/s) | Validated (all 5 variants) | -### inf2 Single-Core Throughput +### inf2 Single-Core Throughput (SDK 2.29) -| Variant | inf2 SDK 2.28 | inf2 SDK 2.29 | trn2 SDK 2.28 | 
-|---------|---------------|---------------|---------------| -| YOLO26n | 60.1 img/s | 69.7 img/s | 32.3 img/s | -| YOLO26s | 64.1 img/s | 76.7 img/s | 66.0 img/s | +| Variant | Dtype | CosSim | img/s | trn2 single-core | +|---------|-------|--------|-------|------------------| +| YOLO26n | FP32 | 0.9966 | 55.2 | 32.3 | +| YOLO26s | FP32 | 0.9934 | 82.8 | 66.0 | +| YOLO26m | BF16 | 0.9877 | 104.0 | 75.5 | +| YOLO26l | BF16 | 0.9966 | 91.6 | 66.3 | +| YOLO26x | BF16 | 0.9950 | 69.9 | 57.3 | -*Note: inf2 single-core outperforms trn2 single-core for small models. trn2 advantage comes from DP=8 (LNC=1) scaling.* +*Note: inf2 single-core outperforms trn2 single-core for all variants. trn2 advantage comes from DP=8 (LNC=1) scaling.* ## Testing Instructions