diff --git a/contrib/models/S3Diff/README.md b/contrib/models/S3Diff/README.md new file mode 100644 index 00000000..c0dfcbc0 --- /dev/null +++ b/contrib/models/S3Diff/README.md @@ -0,0 +1,136 @@ +# Contrib Model: S3Diff + +S3Diff one-step 4x super-resolution on AWS Neuron using `torch_neuronx.trace()`. + +## Model Information + +- **HuggingFace ID:** `zhangap/S3Diff` (weights), base model `stabilityai/sd-turbo` +- **Model Type:** One-step diffusion model for image super-resolution +- **Parameters:** ~1.3B total (~2 GB on disk) +- **Architecture:** SD-Turbo UNet with degradation-guided dynamic LoRA modulation (DEResNet encoder, CLIP text encoder, VAE encoder/decoder with per-layer LoRA, UNet with per-layer LoRA) +- **Paper:** "Degradation-Guided One-Step Image Super-Resolution with Diffusion Priors" (ECCV 2024) +- **License:** Check model cards for SD-Turbo and S3Diff + +## Key Architecture Notes + +S3Diff is unusual among diffusion models: + +1. **Single denoising step**: Only one UNet forward pass per image (at t=999), making it extremely fast. +2. **Dynamic LoRA modulation**: A DEResNet encoder estimates input degradation and produces per-layer LoRA scaling matrices. These `[rank, rank]` modulation matrices are injected between `lora_A` and `lora_B` via einsum operations, conditioning the UNet on the specific degradation pattern of each input. +3. **Two LoRA ranks**: VAE uses rank=16 (6 blocks), UNet uses rank=32 (10 blocks). +4. **Small model**: Total size ~2 GB, fits on a single NeuronCore with no tensor parallelism needed. + +This contrib uses `torch_neuronx.trace()` rather than NxDI tensor parallelism, which is appropriate for the model's small size and non-autoregressive architecture. 
+ +## Validation Results + +**Validated:** 2026-04-28 +**Instance:** trn2.3xlarge (LNC=2) +**SDK:** Neuron SDK 2.29 (DLAMI 20260410), PyTorch 2.9 + +### Benchmark Results (128x128 -> 512x512, single step) + +| Component | Time | +|-----------|------| +| DEResNet | 3.8ms | +| Modulation (CPU) | 0.5ms | +| VAE Encode | 83.2ms | +| UNet x2 (CFG) | 218.8ms | +| VAE Decode | 164.6ms | +| **Total** | **0.471s** | + +| Metric | Value | +|--------|-------| +| Resolution | 128x128 -> 512x512 (4x SR) | +| Inference steps | 1 (one-step model) | +| Warm generation time | 0.544s | +| Throughput | ~1.8 img/s | +| Total compile time | ~21 min | +| CPU baseline | 11.53s | +| Speedup vs CPU | ~21x | + +### Accuracy Validation + +Visual quality validated against CPU reference output. The model produces high-quality 4x upscaled images with correct degradation-aware enhancement. + +## Usage + +```python +from S3Diff.src.modeling_s3diff import S3DiffNeuronPipeline +from PIL import Image + +pipeline = S3DiffNeuronPipeline( + sd_turbo_path="/shared/sd-turbo/", + s3diff_weights_path="/shared/s3diff/s3diff.pkl", + de_net_path="/shared/s3diff/de_net.pth", + compile_dir="/tmp/s3diff/compiled/", + lr_size=128, +) +pipeline.load() +pipeline.compile() + +lr_image = Image.open("input_128x128.png").convert("RGB") +sr_image = pipeline(lr_image) +sr_image.save("output_512x512.png") +``` + +Or use the provided script: + +```bash +python src/generate_s3diff.py \ + --download \ + --input_image input.png \ + --output_image output.png \ + --compile_dir /tmp/s3diff/compiled/ +``` + +## Setup + +```bash +# Activate NxDI environment +source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + +# Install dependencies +pip install diffusers transformers peft accelerate torchvision + +# Download weights +python src/generate_s3diff.py --download + +# Or manually: +# SD-Turbo: huggingface-cli download stabilityai/sd-turbo --local-dir /shared/sd-turbo/ +# S3Diff: huggingface-cli download zhangap/S3Diff 
--local-dir /shared/s3diff/ +# DEResNet: git clone https://github.com/ArcticHare105/S3Diff.git /tmp/s3diff_repo +# cp /tmp/s3diff_repo/assets/mm-realsr/de_net.pth /shared/s3diff/ +``` + +## Compatibility Matrix + +| Instance/Version | SDK 2.29 | SDK 2.28 | +|------------------|----------|----------| +| trn2.3xlarge | VALIDATED | Not tested | + +## Example Checkpoints + +* [zhangap/S3Diff](https://huggingface.co/zhangap/S3Diff) -- S3Diff LoRA weights +* [stabilityai/sd-turbo](https://huggingface.co/stabilityai/sd-turbo) -- Base SD-Turbo model + +## Testing Instructions + +```bash +export SD_TURBO_PATH=/shared/sd-turbo/ +export S3DIFF_WEIGHTS=/shared/s3diff/s3diff.pkl +export DE_NET_WEIGHTS=/shared/s3diff/de_net.pth + +cd contrib/models/S3Diff/ +pytest test/integration/test_model.py -v + +# Or standalone +python test/integration/test_model.py +``` + +## Known Issues + +- **LoRA + `--auto-cast=matmult` produces NaN**: The LoRA modulation einsum operations are numerically unstable when `--auto-cast=matmult` casts them to BF16. The VAE encoder, UNet, and VAE decoder all use `--model-type=unet-inference` instead, which avoids this issue. Only DEResNet and text encoder (no LoRA) use `--auto-cast=matmult`. +- **Compilation time**: ~21 minutes total (UNet is the slowest at ~12 min). Compiled models are cached for reuse. +- **CFG is sequential**: Two separate UNet passes (positive + negative prompt), not batched. Batching with batch_size=2 would halve UNet wall time but requires recompilation. +- **Neuron runtime HBM**: Once loaded, compiled models stay in HBM even if the Python object is deleted (within the same process). Plan memory accordingly. 
DEFAULT_SD_TURBO_PATH = "/shared/sd-turbo/"
DEFAULT_S3DIFF_WEIGHTS = "/shared/s3diff/s3diff.pkl"
DEFAULT_DE_NET_WEIGHTS = "/shared/s3diff/de_net.pth"
DEFAULT_COMPILE_DIR = "/tmp/s3diff/compiled/"


def download_weights(sd_turbo_path, s3diff_weights_path, de_net_path):
    """Download any model weights that are not already present on disk.

    Args:
        sd_turbo_path: Directory that should hold the ``stabilityai/sd-turbo``
            snapshot.
        s3diff_weights_path: File path where the S3Diff ``s3diff.pkl`` weights
            should end up.
        de_net_path: File path where the DEResNet ``de_net.pth`` weights
            should end up.

    Raises:
        subprocess.CalledProcessError: If cloning the S3Diff GitHub repository
            fails.
    """
    import shutil
    import subprocess

    from huggingface_hub import hf_hub_download, snapshot_download

    if not os.path.exists(sd_turbo_path):
        print("Downloading SD-Turbo...")
        snapshot_download("stabilityai/sd-turbo", local_dir=sd_turbo_path)

    if not os.path.exists(s3diff_weights_path):
        print("Downloading S3Diff weights...")
        # dirname() is "" for a bare file name; makedirs("") would raise.
        weights_dir = os.path.dirname(s3diff_weights_path) or "."
        os.makedirs(weights_dir, exist_ok=True)
        downloaded = hf_hub_download(
            "zhangap/S3Diff",
            filename="s3diff.pkl",
            local_dir=weights_dir,
        )
        # The hub file is always called s3diff.pkl; honour a differently
        # named target so the path the caller asked for really exists.
        if os.path.abspath(downloaded) != os.path.abspath(s3diff_weights_path):
            shutil.copy2(downloaded, s3diff_weights_path)

    if not os.path.exists(de_net_path):
        print("Downloading DEResNet weights...")
        os.makedirs(os.path.dirname(de_net_path) or ".", exist_ok=True)
        # DEResNet weights are in the S3Diff GitHub repo
        repo_dir = "/tmp/s3diff_repo"
        if not os.path.exists(repo_dir):
            subprocess.run(
                [
                    "git",
                    "clone",
                    "https://github.com/ArcticHare105/S3Diff.git",
                    repo_dir,
                ],
                check=True,
            )
        shutil.copy2(
            os.path.join(repo_dir, "assets", "mm-realsr", "de_net.pth"),
            de_net_path,
        )


def main():
    """Command-line entry point: load, compile, warm up, benchmark and save.

    Parses the CLI arguments, optionally downloads the weights, builds an
    ``S3DiffNeuronPipeline`` sized to the input image, runs the requested
    number of warm-up and timed generations, prints the timing summary and
    saves the last super-resolved image to ``--output_image``.
    """
    parser = argparse.ArgumentParser(description="S3Diff 4x SR on AWS Neuron")
    parser.add_argument(
        "--input_image",
        type=str,
        default=None,
        help="Path to input low-resolution image (128x128 recommended)",
    )
    parser.add_argument(
        "--output_image", type=str, default="sr_output.png", help="Output path"
    )
    parser.add_argument("--sd_turbo_path", type=str, default=DEFAULT_SD_TURBO_PATH)
    parser.add_argument("--s3diff_weights", type=str, default=DEFAULT_S3DIFF_WEIGHTS)
    parser.add_argument("--de_net_weights", type=str, default=DEFAULT_DE_NET_WEIGHTS)
    parser.add_argument("--compile_dir", type=str, default=DEFAULT_COMPILE_DIR)
    parser.add_argument("--num_images", type=int, default=3)
    parser.add_argument("--warmup_rounds", type=int, default=5)
    parser.add_argument("--download", action="store_true", help="Download weights")
    args = parser.parse_args()

    # The summary divides by num_images and saves the last generated image,
    # so at least one timed generation is required.
    if args.num_images < 1:
        parser.error("--num_images must be at least 1")
    if args.warmup_rounds < 0:
        parser.error("--warmup_rounds must be non-negative")

    if args.download:
        download_weights(args.sd_turbo_path, args.s3diff_weights, args.de_net_weights)

    # Create a test image if none provided
    if args.input_image is None:
        print("No input image provided, creating a 128x128 test pattern...")
        import numpy as np

        # randint's upper bound is exclusive: 256 covers the full uint8 range.
        test_img = np.random.randint(0, 256, (128, 128, 3), dtype=np.uint8)
        lr_image = Image.fromarray(test_img)
    else:
        lr_image = Image.open(args.input_image).convert("RGB")
    print(f"Input image: {lr_image.size}")

    # The pipeline takes a single lr_size for both dimensions, so a
    # rectangular image cannot be represented; fail before compiling.
    width, height = lr_image.size
    if width != height:
        parser.error(
            f"input image must be square, got {width}x{height}; "
            "crop or resize it first"
        )

    # Build pipeline
    pipeline = S3DiffNeuronPipeline(
        sd_turbo_path=args.sd_turbo_path,
        s3diff_weights_path=args.s3diff_weights,
        de_net_path=args.de_net_weights,
        compile_dir=args.compile_dir,
        lr_size=width,
    )

    print("\nLoading model...")
    pipeline.load()

    print("\nCompiling...")
    t0 = time.perf_counter()
    pipeline.compile()
    compile_time = time.perf_counter() - t0
    print(f"Total compilation: {compile_time:.1f}s")

    # Warmup
    print(f"\nWarming up ({args.warmup_rounds} rounds)...")
    for _ in range(args.warmup_rounds):
        pipeline(lr_image)

    # Benchmark
    print(f"\nGenerating {args.num_images} images...")
    total_time = 0.0
    sr_image = None
    for i in range(args.num_images):
        t0 = time.perf_counter()
        sr_image = pipeline(lr_image)
        elapsed = time.perf_counter() - t0
        total_time += elapsed
        print(f" Image {i + 1}: {elapsed:.3f}s")

    avg_time = total_time / args.num_images
    print("\nResults:")
    print(f" Average time: {avg_time:.3f}s")
    print(f" Throughput: {1.0 / avg_time:.2f} img/s")
    print(f" Compilation: {compile_time:.1f}s")

    sr_image.save(args.output_image)
    print(f" Saved: {args.output_image}")


if __name__ == "__main__":
    main()
# Scale factor and padding constants
SF = 4
PAD_H = 512
PAD_W = 512

# ---------------------------------------------------------------------------
# DEResNet -- degradation estimator
# ---------------------------------------------------------------------------


class ResidualBlockNoBN(nn.Module):
    """Residual block without batch norm: ``x + conv2(relu(conv1(x)))``.

    Args:
        num_feat: Number of input and output channels of both 3x3 convolutions.
    """

    def __init__(self, num_feat=64):
        super().__init__()
        self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return x + self.conv2(self.relu(self.conv1(x)))


class DEResNet(nn.Module):
    """Degradation Estimation ResNet. Outputs per-degradation scores in [0, 1].

    One independent branch (first conv, residual body, pooled MLP head ending
    in a sigmoid) is built per degradation type; ``forward`` stacks the branch
    outputs along dim 1.

    Args:
        num_in_ch: Number of input image channels.
        num_degradation: Number of degradation types, i.e. branches.
        num_feats: Channel width of each stage.
        num_blocks: Number of residual blocks in each stage.
        downscales: Per-stage transition: ``1`` adds a stride-1 conv only when
            the next stage has a different width, ``2`` adds a stride-2 conv.
            Any other value adds no transition layer.

    Raises:
        ValueError: If the per-stage sequences do not all have the same length.
    """

    def __init__(
        self,
        num_in_ch=3,
        num_degradation=2,
        # Tuples instead of lists: default arguments are shared between calls
        # and must not be mutable.
        num_feats=(64, 64, 64, 128),
        num_blocks=(2, 2, 2, 2),
        downscales=(1, 1, 2, 1),
    ):
        super().__init__()
        if not (len(num_feats) == len(num_blocks) == len(downscales)):
            raise ValueError(
                "num_feats, num_blocks and downscales must have the same length, "
                f"got {len(num_feats)}, {len(num_blocks)} and {len(downscales)}"
            )
        self.conv_first = nn.ModuleList(
            nn.Conv2d(num_in_ch, num_feats[0], 3, 1, 1)
            for _ in range(num_degradation)
        )
        self.body = nn.ModuleList(
            self._make_branch(num_feats, num_blocks, downscales)
            for _ in range(num_degradation)
        )
        self.num_degradation = num_degradation
        self.fc_degree = nn.ModuleList(
            nn.Sequential(
                nn.Linear(num_feats[-1], 512),
                nn.ReLU(inplace=True),
                nn.Linear(512, 1),
                nn.Sigmoid(),
            )
            for _ in range(num_degradation)
        )
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

    @staticmethod
    def _make_branch(num_feats, num_blocks, downscales):
        """Build the residual body of one branch as an ``nn.Sequential``."""
        last_stage = len(num_feats) - 1
        layers = []
        for stage, (feat, blocks, scale) in enumerate(
            zip(num_feats, num_blocks, downscales)
        ):
            layers.extend(ResidualBlockNoBN(feat) for _ in range(blocks))
            next_feat = num_feats[min(stage + 1, last_stage)]
            if scale == 1:
                # Stride-1 transition only when the channel width changes.
                if stage < last_stage and feat != next_feat:
                    layers.append(nn.Conv2d(feat, next_feat, 3, 1, 1))
            elif scale == 2:
                layers.append(nn.Conv2d(feat, next_feat, 3, 2, 1))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return one score per degradation type, stacked along dim 1."""
        degrees = []
        for i in range(self.num_degradation):
            feat = self.body[i](self.conv_first[i](x))
            feat = self.avg_pool(feat).squeeze(-1).squeeze(-1)
            degrees.append(self.fc_degree[i](feat).squeeze(-1))
        return torch.stack(degrees, dim=1)


# ---------------------------------------------------------------------------
# Custom LoRA forward with degradation modulation
# ---------------------------------------------------------------------------


def my_lora_fwd(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
    """LoRA forward that injects degradation modulation between lora_A and lora_B.

    For Conv2d LoRA layers: einsum('...khw,...kr->...rhw', lora_A(x), de_mod)
    For Linear LoRA layers: einsum('...lk,...kr->...lr', lora_A(x), de_mod)

    The de_mod tensor is a [B, rank, rank] matrix set per-layer by the wrapper
    modules before the forward pass.

    Intended to be bound onto a LoRA layer instance in place of its own
    ``forward``; ``self`` must expose the attributes read below
    (``base_layer``, ``lora_A``, ``lora_B``, ``lora_dropout``, ``scaling``,
    ``use_dora``, ``active_adapters``, ``de_mod``, ...).

    Raises:
        NotImplementedError: If an active adapter's ``lora_A`` is neither
            ``nn.Conv2d`` nor ``nn.Linear``.
    """
    self._check_forward_args(x, *args, **kwargs)
    adapter_names = kwargs.pop("adapter_names", None)
    if self.disable_adapters:
        if self.merged:
            self.unmerge()
        result = self.base_layer(x, *args, **kwargs)
    elif adapter_names is not None:
        result = self._mixed_batch_forward(
            x, *args, adapter_names=adapter_names, **kwargs
        )
    elif self.merged:
        result = self.base_layer(x, *args, **kwargs)
    else:
        result = self.base_layer(x, *args, **kwargs)
        torch_result_dtype = result.dtype
        for active_adapter in self.active_adapters:
            if active_adapter not in self.lora_A.keys():
                continue
            lora_A = self.lora_A[active_adapter]
            lora_B = self.lora_B[active_adapter]
            dropout = self.lora_dropout[active_adapter]
            scaling = self.scaling[active_adapter]
            x = x.to(lora_A.weight.dtype)
            if not self.use_dora[active_adapter]:
                _tmp = lora_A(dropout(x))
                # Mix the rank dimension with the per-sample modulation matrix.
                if isinstance(lora_A, torch.nn.Conv2d):
                    _tmp = torch.einsum("...khw,...kr->...rhw", _tmp, self.de_mod)
                elif isinstance(lora_A, torch.nn.Linear):
                    _tmp = torch.einsum("...lk,...kr->...lr", _tmp, self.de_mod)
                else:
                    raise NotImplementedError(
                        "degradation-modulated LoRA supports only Conv2d and "
                        f"Linear layers, got {type(lora_A).__name__}"
                    )
                result = result + lora_B(_tmp) * scaling
            else:
                x = dropout(x)
                result = result + self._apply_dora(
                    x, lora_A, lora_B, scaling, active_adapter
                )
        result = result.to(torch_result_dtype)
    return result


# ---------------------------------------------------------------------------
# Neuron tracing wrappers
# ---------------------------------------------------------------------------


class TextEncoderWrapper(nn.Module):
    """Wraps a text encoder so that tracing sees a single tensor output."""

    def __init__(self, text_encoder):
        super().__init__()
        self.text_encoder = text_encoder

    def forward(self, input_ids):
        # Keep only the first element of the encoder output.
        return self.text_encoder(input_ids)[0]
class VAEEncoderWrapper(nn.Module):
    """Wraps VAE encoder to accept de_mod_all [B, 6, rank, rank] as explicit input.

    Each LoRA layer name is mapped to a block index: ``down_blocks.<i>`` -> i,
    ``mid_block`` -> 4, anything else -> 5. The second dot-separated
    component of the name decides.
    """

    def __init__(self, vae, vae_lora_layers, lora_rank_vae):
        super().__init__()
        self.vae = vae
        self.vae_lora_layers = vae_lora_layers
        self.lora_rank_vae = lora_rank_vae
        self.layer_block_map = {
            name: self._block_index(name) for name in vae_lora_layers
        }

    @staticmethod
    def _block_index(layer_name):
        """Return the modulation block index for one VAE LoRA layer name."""
        parts = layer_name.split(".")
        # Names without a second component fall into the catch-all block.
        scope = parts[1] if len(parts) > 1 else ""
        if scope == "down_blocks":
            return int(parts[2])
        if scope == "mid_block":
            return 4
        return 5

    def forward(self, pixel_values, de_mod_all):
        # Hand every LoRA layer its block's modulation matrix before encoding;
        # the patched LoRA forward reads it from ``module.de_mod``.
        for layer_name, module in self.vae.named_modules():
            block_idx = self.layer_block_map.get(layer_name)
            if block_idx is not None:
                module.de_mod = de_mod_all[:, block_idx]
        latent = (
            self.vae.encode(pixel_values).latent_dist.sample()
            * self.vae.config.scaling_factor
        )
        return latent


class UNetWrapper(nn.Module):
    """Wraps UNet to accept de_mod_all [B, 10, rank, rank] as explicit input.

    Each LoRA layer name is mapped to a block index: ``down_blocks.<i>`` -> i,
    ``mid_block`` -> 4, ``up_blocks.<i>`` -> i + 5, anything else -> 9. The
    first dot-separated component of the name decides.
    """

    def __init__(self, unet, unet_lora_layers, lora_rank_unet):
        super().__init__()
        self.unet = unet
        self.unet_lora_layers = unet_lora_layers
        self.lora_rank_unet = lora_rank_unet
        self.layer_block_map = {
            name: self._block_index(name) for name in unet_lora_layers
        }

    @staticmethod
    def _block_index(layer_name):
        """Return the modulation block index for one UNet LoRA layer name."""
        parts = layer_name.split(".")
        scope = parts[0]
        if scope == "down_blocks":
            return int(parts[1])
        if scope == "mid_block":
            return 4
        if scope == "up_blocks":
            return int(parts[1]) + 5
        return 9

    def forward(self, latent, timestep, encoder_hidden_states, de_mod_all):
        for layer_name, module in self.unet.named_modules():
            block_idx = self.layer_block_map.get(layer_name)
            if block_idx is not None:
                module.de_mod = de_mod_all[:, block_idx]
        return self.unet(
            latent, timestep, encoder_hidden_states=encoder_hidden_states
        ).sample


class VAEDecoderWrapper(nn.Module):
    """Simple wrapper for VAE decoder (no LoRA)."""

    def __init__(self, vae):
        super().__init__()
        self.vae = vae

    def forward(self, latent):
        # Undo the latent scaling applied on the encoder side.
        return self.vae.decode(latent / self.vae.config.scaling_factor).sample


# ---------------------------------------------------------------------------
# S3Diff Neuron Pipeline
# ---------------------------------------------------------------------------


class S3DiffNeuronPipeline:
    """End-to-end S3Diff super-resolution pipeline on Neuron.

    Handles model loading, compilation, and inference. Call ``load()``, then
    ``compile()``, then the instance itself on an ``lr_size`` x ``lr_size``
    image. The components are traced with fixed example shapes, so every
    input must match the size given at construction time.

    Args:
        sd_turbo_path: Path to SD-Turbo checkpoint (stabilityai/sd-turbo)
        s3diff_weights_path: Path to s3diff.pkl weights
        de_net_path: Path to de_net.pth weights
        compile_dir: Directory for compiled model cache
        lr_size: Input low-resolution size (default: 128)

    Raises:
        ValueError: If ``lr_size`` is not positive.
    """

    # Artifacts for this size are stored directly in ``compile_dir`` (the
    # layout used before caches were keyed by resolution).
    _DEFAULT_LR_SIZE = 128
    # The single denoising timestep used both for tracing and inference.
    _TIMESTEP = 999
    # Number of modulation blocks expected by the VAE / UNet wrappers.
    _NUM_VAE_BLOCKS = 6
    _NUM_UNET_BLOCKS = 10

    def __init__(
        self,
        sd_turbo_path: str,
        s3diff_weights_path: str,
        de_net_path: str,
        compile_dir: str = "/tmp/s3diff/compiled/",
        lr_size: int = 128,
    ):
        if lr_size <= 0:
            raise ValueError(f"lr_size must be positive, got {lr_size}")
        self.sd_turbo_path = sd_turbo_path
        self.s3diff_weights_path = s3diff_weights_path
        self.de_net_path = de_net_path
        self.compile_dir = compile_dir
        self.lr_size = lr_size
        self.hr_size = lr_size * SF

        # Set by compile(): Neuron-traced callables.
        self.de_net_neuron = None
        self.text_enc_neuron = None
        self.vae_enc_neuron = None
        self.unet_neuron = None
        self.vae_dec_neuron = None

        # Set by load(): CPU-side helpers and the modules to be traced.
        self.tokenizer = None
        self.sched = None
        self.compute_modulation = None
        self.lora_rank_unet = None
        self.lora_rank_vae = None
        self._de_net = None
        self._text_enc_wrapper = None
        self._vae_enc_wrapper = None
        self._unet_wrapper = None
        self._vae_dec_wrapper = None

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------

    @staticmethod
    def _overlay_state_dict(module, overrides):
        """Load ``overrides`` on top of ``module``'s current state dict."""
        merged = module.state_dict()
        merged.update(overrides)
        module.load_state_dict(merged)

    @staticmethod
    def _patch_lora_layers(model):
        """Bind ``my_lora_fwd`` onto every LoRA layer and return their names.

        A LoRA layer is recognised as the parent of a ``base_layer`` module.
        """
        suffix = ".base_layer"
        lora_layers = [
            name[: -len(suffix)]
            for name, _ in model.named_modules()
            if "base_layer" in name
        ]
        wanted = set(lora_layers)
        for name, module in model.named_modules():
            if name in wanted:
                module.forward = my_lora_fwd.__get__(module, module.__class__)
        return lora_layers

    def load(self):
        """Load all model components and build the modulation network."""
        from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
        from peft import LoraConfig
        from transformers import AutoTokenizer, CLIPTextModel

        print("Loading SD-Turbo + S3Diff LoRA weights...")

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.sd_turbo_path, subfolder="tokenizer"
        )
        text_encoder = CLIPTextModel.from_pretrained(
            self.sd_turbo_path, subfolder="text_encoder"
        ).eval()
        vae = AutoencoderKL.from_pretrained(self.sd_turbo_path, subfolder="vae")
        unet = UNet2DConditionModel.from_pretrained(
            self.sd_turbo_path, subfolder="unet"
        )

        # NOTE(security): torch.load deserialises a pickle; only point this at
        # weight files obtained from a trusted source.
        sd = torch.load(self.s3diff_weights_path, map_location="cpu")
        self.lora_rank_unet = sd["rank_unet"]
        self.lora_rank_vae = sd["rank_vae"]
        print(f"LoRA ranks: unet={self.lora_rank_unet}, vae={self.lora_rank_vae}")

        # Add LoRA adapters and load trained weights
        vae.add_adapter(
            LoraConfig(
                r=self.lora_rank_vae,
                init_lora_weights="gaussian",
                target_modules=sd["vae_lora_target_modules"],
            ),
            adapter_name="vae_skip",
        )
        self._overlay_state_dict(vae, sd["state_dict_vae"])

        unet.add_adapter(
            LoraConfig(
                r=self.lora_rank_unet,
                init_lora_weights="gaussian",
                target_modules=sd["unet_lora_target_modules"],
            )
        )
        self._overlay_state_dict(unet, sd["state_dict_unet"])

        # Monkey-patch LoRA forward methods
        vae_lora_layers = self._patch_lora_layers(vae)
        unet_lora_layers = self._patch_lora_layers(unet)

        vae.eval()
        unet.eval()

        # Modulation MLPs (tiny, run on CPU)
        num_embeddings = 64
        block_embedding_dim = 64
        W = nn.Parameter(sd["w"], requires_grad=False)

        mlps = {
            "vae_de_mlp": nn.Sequential(
                nn.Linear(num_embeddings * 4, 256), nn.ReLU(True)
            ),
            "unet_de_mlp": nn.Sequential(
                nn.Linear(num_embeddings * 4, 256), nn.ReLU(True)
            ),
            "vae_block_mlp": nn.Sequential(
                nn.Linear(block_embedding_dim, 64), nn.ReLU(True)
            ),
            "unet_block_mlp": nn.Sequential(
                nn.Linear(block_embedding_dim, 64), nn.ReLU(True)
            ),
            "vae_fuse_mlp": nn.Linear(256 + 64, self.lora_rank_vae**2),
            "unet_fuse_mlp": nn.Linear(256 + 64, self.lora_rank_unet**2),
        }
        for name, module in mlps.items():
            self._overlay_state_dict(module, sd[f"state_dict_{name}"])
            module.eval()

        vae_block_embeddings = nn.Embedding(self._NUM_VAE_BLOCKS, block_embedding_dim)
        unet_block_embeddings = nn.Embedding(
            self._NUM_UNET_BLOCKS, block_embedding_dim
        )
        vae_block_embeddings.load_state_dict(
            sd["state_embeddings"]["state_dict_vae_block"]
        )
        unet_block_embeddings.load_state_dict(
            sd["state_embeddings"]["state_dict_unet_block"]
        )

        # DEResNet
        de_net = DEResNet(num_in_ch=3, num_degradation=2)
        de_net.load_state_dict(torch.load(self.de_net_path, map_location="cpu"))
        de_net.eval()

        # Scheduler
        self.sched = DDPMScheduler.from_pretrained(
            self.sd_turbo_path, subfolder="scheduler"
        )
        self.sched.set_timesteps(1, device="cpu")

        # Build wrappers
        self._text_enc_wrapper = TextEncoderWrapper(text_encoder)
        self._vae_enc_wrapper = VAEEncoderWrapper(
            vae, vae_lora_layers, self.lora_rank_vae
        )
        self._unet_wrapper = UNetWrapper(unet, unet_lora_layers, self.lora_rank_unet)
        self._vae_dec_wrapper = VAEDecoderWrapper(vae)
        self._de_net = de_net

        # Build modulation closure
        lora_rank_vae = self.lora_rank_vae
        lora_rank_unet = self.lora_rank_unet
        num_vae_blocks = self._NUM_VAE_BLOCKS
        num_unet_blocks = self._NUM_UNET_BLOCKS

        def fuse(de_embed, block_embeds, fuse_mlp, batch):
            """Pair the degradation embedding with every block embedding."""
            return fuse_mlp(
                torch.cat(
                    [
                        de_embed.unsqueeze(1).repeat(1, block_embeds.shape[0], 1),
                        block_embeds.unsqueeze(0).repeat(batch, 1, 1),
                    ],
                    -1,
                )
            )

        def compute_modulation(deg_score):
            """Map degradation scores to per-block LoRA modulation matrices.

            Returns a ``(vae, unet)`` pair reshaped to ``[B, 6, r, r]`` and
            ``[B, 10, r, r]`` with the respective LoRA ranks.
            """
            # Sin/cos features of the scores projected through W.
            deg_proj = deg_score[..., None] * W[None, None, :] * 2 * np.pi
            deg_proj = torch.cat([torch.sin(deg_proj), torch.cos(deg_proj)], dim=-1)
            deg_proj = torch.cat([deg_proj[:, 0], deg_proj[:, 1]], dim=-1)

            B = deg_score.shape[0]
            vae_embeds = fuse(
                mlps["vae_de_mlp"](deg_proj),
                mlps["vae_block_mlp"](vae_block_embeddings.weight),
                mlps["vae_fuse_mlp"],
                B,
            )
            unet_embeds = fuse(
                mlps["unet_de_mlp"](deg_proj),
                mlps["unet_block_mlp"](unet_block_embeddings.weight),
                mlps["unet_fuse_mlp"],
                B,
            )
            return (
                vae_embeds.reshape(B, num_vae_blocks, lora_rank_vae, lora_rank_vae),
                unet_embeds.reshape(
                    B, num_unet_blocks, lora_rank_unet, lora_rank_unet
                ),
            )

        self.compute_modulation = compute_modulation

        print(f"VAE LoRA layers: {len(vae_lora_layers)}")
        print(f"UNet LoRA layers: {len(unet_lora_layers)}")
        print("Model loaded successfully.")

    # ------------------------------------------------------------------
    # Compilation
    # ------------------------------------------------------------------

    def _artifact_dir(self):
        """Return the directory holding compiled artifacts for ``lr_size``.

        Traced graphs are shape-specific. The default size keeps using
        ``compile_dir`` itself so existing caches stay valid; any other size
        gets its own ``lr<size>`` subdirectory so that artifacts traced for
        one resolution are never loaded for another.
        """
        if self.lr_size == self._DEFAULT_LR_SIZE:
            return self.compile_dir
        return os.path.join(self.compile_dir, f"lr{self.lr_size}")

    def _load_or_trace(self, label, filename, wrapper, make_inputs, compiler_args):
        """Load a cached traced module, or trace ``wrapper`` and cache it.

        Args:
            label: Human-readable component name used in progress messages.
            filename: Artifact file name inside the artifact directory.
            wrapper: CPU module to trace (``None`` until ``load()`` has run).
            make_inputs: Zero-argument callable building the example inputs;
                only invoked when tracing is actually needed.
            compiler_args: Arguments forwarded to the Neuron compiler.

        Raises:
            RuntimeError: If tracing is needed but ``load()`` was not called.
        """
        # Imported before torch.jit.load as in the original flow; presumably
        # required so Neuron-traced modules can be deserialised.
        import torch_neuronx

        path = os.path.join(self._artifact_dir(), filename)
        if os.path.exists(path):
            print(f"{label}: loading cached...")
            return torch.jit.load(path)

        if wrapper is None:
            raise RuntimeError(
                f"cannot compile {label}: call load() before compile()"
            )
        print(f"{label}: compiling...")
        t0 = time.time()
        traced = torch_neuronx.trace(
            wrapper, make_inputs(), compiler_args=list(compiler_args)
        )
        torch.jit.save(traced, path)
        print(f" Done in {time.time() - t0:.1f}s")
        return traced

    def compile(self):
        """Compile all components with torch_neuronx.trace().

        Components whose artifact already exists in the cache directory are
        loaded instead of being re-traced.
        """
        os.makedirs(self._artifact_dir(), exist_ok=True)
        lr_h, lr_w = self.lr_size, self.lr_size
        hr_h, hr_w = self.hr_size, self.hr_size
        lat_h, lat_w = hr_h // 8, hr_w // 8

        # LoRA-free components tolerate matmul auto-casting; the LoRA ones
        # are compiled with the unet-inference model type instead.
        autocast_args = ["--auto-cast", "matmult", "-O1"]
        unet_type_args = ["--model-type=unet-inference", "-O1"]

        # DEResNet
        self.de_net_neuron = self._load_or_trace(
            "DEResNet",
            "de_net.pt",
            self._de_net,
            lambda: torch.randn(1, 3, lr_h, lr_w),
            autocast_args,
        )

        # Text encoder
        self.text_enc_neuron = self._load_or_trace(
            "Text encoder",
            "text_encoder.pt",
            self._text_enc_wrapper,
            lambda: torch.zeros(1, 77, dtype=torch.long),
            autocast_args,
        )

        # VAE encoder (with LoRA)
        self.vae_enc_neuron = self._load_or_trace(
            "VAE encoder",
            "vae_encoder.pt",
            self._vae_enc_wrapper,
            lambda: (
                torch.randn(1, 3, hr_h, hr_w),
                torch.randn(
                    1, self._NUM_VAE_BLOCKS, self.lora_rank_vae, self.lora_rank_vae
                ),
            ),
            unet_type_args,
        )

        # UNet (with LoRA)
        self.unet_neuron = self._load_or_trace(
            "UNet",
            "unet.pt",
            self._unet_wrapper,
            lambda: (
                torch.randn(1, 4, lat_h, lat_w),
                torch.tensor([self._TIMESTEP], dtype=torch.long),
                torch.randn(1, 77, 1024),
                torch.randn(
                    1, self._NUM_UNET_BLOCKS, self.lora_rank_unet, self.lora_rank_unet
                ),
            ),
            unet_type_args,
        )

        # VAE decoder (no LoRA)
        self.vae_dec_neuron = self._load_or_trace(
            "VAE decoder",
            "vae_decoder.pt",
            self._vae_dec_wrapper,
            lambda: torch.randn(1, 4, lat_h, lat_w),
            unet_type_args,
        )

        print("All components compiled.")

    # ------------------------------------------------------------------
    # Inference
    # ------------------------------------------------------------------

    def _require_ready(self):
        """Raise ``RuntimeError`` unless load() and compile() have both run."""
        needed = (
            "tokenizer",
            "sched",
            "compute_modulation",
            "de_net_neuron",
            "text_enc_neuron",
            "vae_enc_neuron",
            "unet_neuron",
            "vae_dec_neuron",
        )
        missing = [name for name in needed if getattr(self, name, None) is None]
        if missing:
            raise RuntimeError(
                "pipeline is not ready (missing: "
                + ", ".join(missing)
                + "); call load() and compile() first"
            )

    def _check_input_size(self, lr_image):
        """Raise ``ValueError`` unless the image is lr_size x lr_size."""
        expected = (self.lr_size, self.lr_size)
        actual = tuple(lr_image.size)
        if actual != expected:
            raise ValueError(
                f"input image is {actual[0]}x{actual[1]} but the pipeline was "
                f"compiled for {expected[0]}x{expected[1]}; resize or crop the "
                "image, or build a pipeline with a matching lr_size"
            )

    def _encode_prompt(self, prompt):
        """Tokenize ``prompt`` to the model's max length and encode it."""
        tokens = self.tokenizer(
            prompt,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).input_ids
        return self.text_enc_neuron(tokens)

    @torch.no_grad()
    def __call__(
        self,
        lr_image: "Image.Image",
        pos_prompt: str = "high quality, highly detailed, clean",
        neg_prompt: str = "blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed",
        cfg_scale: float = 1.07,
    ) -> "Image.Image":
        """Run 4x super-resolution on a low-resolution PIL image.

        Args:
            lr_image: Input PIL image; must be lr_size x lr_size. Non-RGB
                images are converted to RGB.
            pos_prompt: Positive text prompt
            neg_prompt: Negative text prompt
            cfg_scale: Classifier-free guidance scale

        Returns:
            Super-resolved PIL image (4x the input resolution)

        Raises:
            RuntimeError: If load() and compile() have not both been called.
            ValueError: If the image size differs from the compiled size.
        """
        self._require_ready()
        self._check_input_size(lr_image)
        if lr_image.mode != "RGB":
            lr_image = lr_image.convert("RGB")

        im_lr = transforms.ToTensor()(lr_image).unsqueeze(0)

        # Preprocess: resize 4x + normalize to [-1, 1]. The size check above
        # guarantees the result is exactly hr_size x hr_size, i.e. the shape
        # the VAE encoder was traced with, so no padding or final cropping is
        # needed (the former fixed 512x512 padding only matched lr_size=128).
        ori_h, ori_w = im_lr.shape[2:]
        im_lr_resize = F.interpolate(
            im_lr,
            size=(ori_h * SF, ori_w * SF),
            mode="bilinear",
            align_corners=False,
        )
        im_lr_resize_norm = (im_lr_resize * 2 - 1.0).clamp(-1, 1)

        # 1. DEResNet -> degradation scores
        deg_score = self.de_net_neuron(im_lr)

        # 2. Compute modulation on CPU
        vae_de_mod_all, unet_de_mod_all = self.compute_modulation(deg_score)

        # 3. Text encoding
        pos_enc = self._encode_prompt(pos_prompt)
        neg_enc = self._encode_prompt(neg_prompt)

        # 4. VAE Encode (with de_mod)
        latent = self.vae_enc_neuron(im_lr_resize_norm, vae_de_mod_all)

        # 5. UNet x2 for CFG (with de_mod)
        timestep = torch.tensor([self._TIMESTEP], dtype=torch.long)
        pos_pred = self.unet_neuron(latent, timestep, pos_enc, unet_de_mod_all)
        neg_pred = self.unet_neuron(latent, timestep, neg_enc, unet_de_mod_all)
        model_pred = neg_pred + cfg_scale * (pos_pred - neg_pred)

        # 6. Scheduler step (CPU)
        x_denoised = self.sched.step(
            model_pred.cpu(),
            torch.tensor([self._TIMESTEP]),
            latent.cpu(),
            return_dict=True,
        ).prev_sample

        # 7. VAE Decode
        output = self.vae_dec_neuron(x_denoised).clamp(-1, 1)
        return transforms.ToPILImage()((output[0] * 0.5 + 0.5).cpu().clamp(0, 1))
SD_TURBO_PATH = os.environ.get("SD_TURBO_PATH", "/shared/sd-turbo/")
S3DIFF_WEIGHTS = os.environ.get("S3DIFF_WEIGHTS", "/shared/s3diff/s3diff.pkl")
DE_NET_WEIGHTS = os.environ.get("DE_NET_WEIGHTS", "/shared/s3diff/de_net.pth")
COMPILE_DIR = os.environ.get("S3DIFF_COMPILE_DIR", "/tmp/s3diff_test/compiled/")
LR_SIZE = 128
HR_SIZE = 512


def make_test_image(size=128):
    """Create a deterministic ``size`` x ``size`` RGB noise image.

    A private ``RandomState`` is used instead of ``np.random.seed`` so that
    building the fixture does not reseed the process-wide NumPy generator.
    Seed and draw call are unchanged, so the pixel values stay the same.
    """
    rng = np.random.RandomState(42)
    data = rng.randint(0, 255, (size, size, 3), dtype=np.uint8)
    return Image.fromarray(data)


def _build_pipeline():
    """Construct an unloaded pipeline from the module-level configuration."""
    return S3DiffNeuronPipeline(
        sd_turbo_path=SD_TURBO_PATH,
        s3diff_weights_path=S3DIFF_WEIGHTS,
        de_net_path=DE_NET_WEIGHTS,
        compile_dir=COMPILE_DIR,
        lr_size=LR_SIZE,
    )


@pytest.fixture(scope="module")
def pipeline():
    """Create, load, compile, and warm up the S3Diff pipeline."""
    pipe = _build_pipeline()
    pipe.load()
    pipe.compile()

    # Warmup
    pipe(make_test_image())

    yield pipe

    del pipe
    gc.collect()


def test_smoke_pipeline_loads(pipeline):
    """Pipeline loads without errors and has required components."""
    assert pipeline is not None
    assert pipeline.de_net_neuron is not None
    assert pipeline.text_enc_neuron is not None
    assert pipeline.vae_enc_neuron is not None
    assert pipeline.unet_neuron is not None
    assert pipeline.vae_dec_neuron is not None
    assert pipeline.tokenizer is not None
    assert pipeline.sched is not None
    assert pipeline.compute_modulation is not None


def test_sr_produces_correct_size(pipeline):
    """128x128 input produces 512x512 output."""
    test_img = make_test_image(LR_SIZE)
    sr_img = pipeline(test_img)

    assert isinstance(sr_img, Image.Image)
    assert sr_img.size == (HR_SIZE, HR_SIZE), (
        f"Expected ({HR_SIZE}, {HR_SIZE}), got {sr_img.size}"
    )

    # Verify reasonable pixel distribution (not blank)
    sr_array = np.array(sr_img)
    assert sr_array.shape == (HR_SIZE, HR_SIZE, 3)
    assert sr_array.std() > 10, "Output appears blank or uniform"

    # Save for inspection
    output_dir = os.path.join(COMPILE_DIR, "test_outputs")
    os.makedirs(output_dir, exist_ok=True)
    sr_img.save(os.path.join(output_dir, "test_sr.png"))
    print(f"SR output saved, pixel std={sr_array.std():.1f}")


def test_warm_generation_time(pipeline):
    """Warm generation should complete in < 2s."""
    test_img = make_test_image(LR_SIZE)

    # perf_counter is monotonic, so clock adjustments cannot skew the result.
    t0 = time.perf_counter()
    pipeline(test_img)
    elapsed = time.perf_counter() - t0
    print(f"Warm generation time: {elapsed:.3f}s")

    assert elapsed < 2.0, f"Generation took {elapsed:.3f}s, expected < 2s"


def _run_standalone():
    """Run load, compile, warm-up and the three tests without pytest."""
    print("=" * 60)
    print("S3Diff Integration Tests")
    print("=" * 60)

    pipe = _build_pipeline()

    # Six steps in total; every label carries the same denominator.
    print("\n[1/6] Loading...")
    pipe.load()

    print("\n[2/6] Compiling...")
    t0 = time.perf_counter()
    pipe.compile()
    print(f" Compilation: {time.perf_counter() - t0:.1f}s")

    print("\n[3/6] Warmup...")
    pipe(make_test_image())

    print("\n[4/6] test_smoke_pipeline_loads")
    test_smoke_pipeline_loads(pipe)
    print(" PASSED")

    print("\n[5/6] test_sr_produces_correct_size")
    test_sr_produces_correct_size(pipe)
    print(" PASSED")

    print("\n[6/6] test_warm_generation_time")
    test_warm_generation_time(pipe)
    print(" PASSED")

    print("\nAll tests passed!")


# Standalone runner
if __name__ == "__main__":
    _run_standalone()