contrib/models/S3Diff/README.md (new file, +136 lines)
# Contrib Model: S3Diff

S3Diff performs one-step 4x image super-resolution on AWS Neuron, compiled with `torch_neuronx.trace()`.

## Model Information

- **HuggingFace ID:** `zhangap/S3Diff` (weights), base model `stabilityai/sd-turbo`
- **Model Type:** One-step diffusion model for image super-resolution
- **Parameters:** ~1.3B total (~2 GB on disk)
- **Architecture:** SD-Turbo UNet with degradation-guided dynamic LoRA modulation (DEResNet encoder, CLIP text encoder, VAE encoder/decoder with per-layer LoRA, UNet with per-layer LoRA)
- **Paper:** "Degradation-Guided One-Step Image Super-Resolution with Diffusion Priors" (ECCV 2024)
- **License:** Check model cards for SD-Turbo and S3Diff

## Key Architecture Notes

S3Diff is unusual among diffusion models:

1. **Single denoising step**: Only one UNet forward pass per image (at t=999), making it extremely fast.
2. **Dynamic LoRA modulation**: A DEResNet encoder estimates input degradation and produces per-layer LoRA scaling matrices. These `[rank, rank]` modulation matrices are injected between `lora_A` and `lora_B` via einsum operations, conditioning the UNet on the specific degradation pattern of each input.
3. **Two LoRA ranks**: VAE uses rank=16 (6 blocks), UNet uses rank=32 (10 blocks).
4. **Small model**: Total size ~2 GB, fits on a single NeuronCore with no tensor parallelism needed.

This contrib uses `torch_neuronx.trace()` rather than NxDI tensor parallelism; plain tracing is the right fit given the model's small size and non-autoregressive architecture.
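The dynamic LoRA modulation in point 2 above can be sketched as follows. This is a minimal illustration with hypothetical tensor names, not code from this repo: a degradation-conditioned `[rank, rank]` matrix is multiplied in between the LoRA down- and up-projections via einsum.

```python
import torch

def modulated_lora(x, lora_A, lora_B, mod, scale=1.0):
    """Sketch of a degradation-modulated LoRA path (illustrative names).

    x:      [batch, in_features]   input activations
    lora_A: [rank, in_features]    LoRA down-projection
    mod:    [rank, rank]           degradation-conditioned modulation matrix
    lora_B: [out_features, rank]   LoRA up-projection
    """
    down = torch.einsum("bi,ri->br", x, lora_A)    # project down to rank
    mixed = torch.einsum("br,rs->bs", down, mod)   # inject modulation
    up = torch.einsum("bs,os->bo", mixed, lora_B)  # project back up
    return scale * up

# Toy shapes matching the UNet's rank=32 described above
x = torch.randn(2, 64)
lora_A = torch.randn(32, 64)
lora_B = torch.randn(128, 32)
mod = torch.eye(32)  # identity modulation reduces to a plain LoRA path
out = modulated_lora(x, lora_A, lora_B, mod)
print(out.shape)  # torch.Size([2, 128])
```

With an identity modulation matrix this reduces exactly to a standard LoRA delta (`x @ A.T @ B.T`); the DEResNet's job is to produce a non-identity `mod` per layer from the estimated input degradation.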

## Validation Results

**Validated:** 2026-04-28
**Instance:** trn2.3xlarge (LNC=2)
**SDK:** Neuron SDK 2.29 (DLAMI 20260410), PyTorch 2.9

### Benchmark Results (128x128 -> 512x512, single step)

| Component | Time |
|-----------|------|
| DEResNet | 3.8ms |
| Modulation (CPU) | 0.5ms |
| VAE Encode | 83.2ms |
| UNet x2 (CFG) | 218.8ms |
| VAE Decode | 164.6ms |
| **Total** | **0.471s** |

| Metric | Value |
|--------|-------|
| Resolution | 128x128 -> 512x512 (4x SR) |
| Inference steps | 1 (one-step model) |
| Warm generation time | 0.544s |
| Throughput | ~1.8 img/s |
| Total compile time | ~21 min |
| CPU baseline | 11.53s |
| Speedup vs CPU | ~21x |
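The headline numbers above are internally consistent; a quick sanity check using the reported values:

```python
# Values copied from the benchmark tables above.
component_ms = {
    "DEResNet": 3.8,
    "Modulation (CPU)": 0.5,
    "VAE Encode": 83.2,
    "UNet x2 (CFG)": 218.8,
    "VAE Decode": 164.6,
}
total_s = sum(component_ms.values()) / 1000.0
print(f"Component sum: {total_s:.3f}s")           # matches the 0.471s Total row
print(f"Throughput:    {1.0 / 0.544:.2f} img/s")  # from the 0.544s warm time
print(f"Speedup:       {11.53 / 0.544:.1f}x")     # vs the 11.53s CPU baseline
```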

### Accuracy Validation

Visual quality validated against CPU reference output. The model produces high-quality 4x upscaled images with correct degradation-aware enhancement.
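One simple way to quantify agreement with a CPU reference is PSNR between the two output images. This is a generic sketch, not the validation harness used in this repo:

```python
import numpy as np

def psnr(a: np.ndarray, b: np.ndarray, max_val: float = 255.0) -> float:
    """Peak signal-to-noise ratio between two uint8 image arrays."""
    mse = np.mean((a.astype(np.float64) - b.astype(np.float64)) ** 2)
    if mse == 0:
        return float("inf")
    return 10.0 * np.log10(max_val**2 / mse)

# Toy example: a 512x512 reference and a copy with one perturbed pixel
ref = np.full((512, 512, 3), 128, dtype=np.uint8)
test = ref.copy()
test[0, 0, 0] += 10
print(f"PSNR: {psnr(ref, test):.1f} dB")
```

In practice, comparing the Neuron pipeline's output against the CPU reference this way (e.g. expecting PSNR well above 30 dB) gives a reproducible check alongside visual inspection.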

## Usage

```python
from S3Diff.src.modeling_s3diff import S3DiffNeuronPipeline
from PIL import Image

pipeline = S3DiffNeuronPipeline(
sd_turbo_path="/shared/sd-turbo/",
s3diff_weights_path="/shared/s3diff/s3diff.pkl",
de_net_path="/shared/s3diff/de_net.pth",
compile_dir="/tmp/s3diff/compiled/",
lr_size=128,
)
pipeline.load()
pipeline.compile()

lr_image = Image.open("input_128x128.png").convert("RGB")
sr_image = pipeline(lr_image)
sr_image.save("output_512x512.png")
```

Or use the provided script:

```bash
python src/generate_s3diff.py \
--download \
--input_image input.png \
--output_image output.png \
--compile_dir /tmp/s3diff/compiled/
```

## Setup

```bash
# Activate NxDI environment
source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate

# Install dependencies
pip install diffusers transformers peft accelerate torchvision

# Download weights
python src/generate_s3diff.py --download

# Or manually:
# SD-Turbo: huggingface-cli download stabilityai/sd-turbo --local-dir /shared/sd-turbo/
# S3Diff: huggingface-cli download zhangap/S3Diff --local-dir /shared/s3diff/
# DEResNet: git clone https://github.com/ArcticHare105/S3Diff.git /tmp/s3diff_repo
# cp /tmp/s3diff_repo/assets/mm-realsr/de_net.pth /shared/s3diff/
```

## Compatibility Matrix

| Instance/Version | SDK 2.29 | SDK 2.28 |
|------------------|----------|----------|
| trn2.3xlarge | VALIDATED | Not tested |

## Example Checkpoints

* [zhangap/S3Diff](https://huggingface.co/zhangap/S3Diff) -- S3Diff LoRA weights
* [stabilityai/sd-turbo](https://huggingface.co/stabilityai/sd-turbo) -- Base SD-Turbo model

## Testing Instructions

```bash
export SD_TURBO_PATH=/shared/sd-turbo/
export S3DIFF_WEIGHTS=/shared/s3diff/s3diff.pkl
export DE_NET_WEIGHTS=/shared/s3diff/de_net.pth

cd contrib/models/S3Diff/
pytest test/integration/test_model.py -v

# Or standalone
python test/integration/test_model.py
```

## Known Issues

- **LoRA + `--auto-cast=matmult` produces NaN**: The LoRA modulation einsum operations are numerically unstable when `--auto-cast=matmult` casts them to BF16. The VAE encoder, UNet, and VAE decoder all use `--model-type=unet-inference` instead, which avoids this issue. Only DEResNet and text encoder (no LoRA) use `--auto-cast=matmult`.
- **Compilation time**: ~21 minutes total (UNet is the slowest at ~12 min). Compiled models are cached for reuse.
- **CFG is sequential**: Two separate UNet passes (positive + negative prompt), not batched. Batching with batch_size=2 would halve UNet wall time but requires recompilation.
- **Neuron runtime HBM**: Once loaded, compiled models stay in HBM even if the Python object is deleted (within the same process). Plan memory accordingly.
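The batched-CFG idea from the list above can be sketched as follows. `unet_fn` and the shapes are hypothetical stand-ins: the two conditioning branches are stacked into a batch of 2 so one compiled UNet call (at batch_size=2) replaces the two sequential passes.

```python
import torch

def cfg_batched(unet_fn, latent, pos_emb, neg_emb, guidance_scale=1.1):
    """Classifier-free guidance with both branches in a single batched call.

    unet_fn: callable taking (latents, embeddings) -> noise prediction;
             would need to be recompiled for batch_size=2 on Neuron.
    """
    latents = torch.cat([latent, latent], dim=0)   # duplicate the latent
    embs = torch.cat([pos_emb, neg_emb], dim=0)    # stack pos/neg conditioning
    noise = unet_fn(latents, embs)                 # one pass instead of two
    noise_pos, noise_neg = noise.chunk(2, dim=0)
    return noise_neg + guidance_scale * (noise_pos - noise_neg)

# Toy demonstration with an identity stand-in for the UNet (shapes only)
latent = torch.randn(1, 4, 64, 64)
pos = torch.randn(1, 77, 1024)
neg = torch.randn(1, 77, 1024)
out = cfg_batched(lambda lat, emb: lat, latent, pos, neg)
print(out.shape)  # torch.Size([1, 4, 64, 64])
```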
contrib/models/S3Diff/src/__init__.py (new file, +3 lines)
# S3Diff one-step super-resolution on AWS Neuron
# Uses torch_neuronx.trace() for compilation -- not NxDI TP sharding
# (model is ~2 GB, fits on a single NeuronCore)
contrib/models/S3Diff/src/generate_s3diff.py (new file, +159 lines)
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

"""
S3Diff one-step 4x super-resolution on AWS Neuron.

Downloads required weights (SD-Turbo, S3Diff LoRA, DEResNet), compiles all
components, and runs super-resolution inference.

Usage:
python generate_s3diff.py \
--input_image /path/to/lr_image.png \
--output_image /path/to/sr_output.png \
--compile_dir /tmp/s3diff/compiled/

Requirements:
pip install diffusers transformers peft accelerate torchvision
"""

import argparse
import os
import time

import torch
from PIL import Image

try:
from .modeling_s3diff import S3DiffNeuronPipeline
except ImportError:
from modeling_s3diff import S3DiffNeuronPipeline


DEFAULT_SD_TURBO_PATH = "/shared/sd-turbo/"
DEFAULT_S3DIFF_WEIGHTS = "/shared/s3diff/s3diff.pkl"
DEFAULT_DE_NET_WEIGHTS = "/shared/s3diff/de_net.pth"
DEFAULT_COMPILE_DIR = "/tmp/s3diff/compiled/"


def download_weights(sd_turbo_path, s3diff_weights_path, de_net_path):
"""Download model weights if not already present."""
from huggingface_hub import hf_hub_download, snapshot_download

if not os.path.exists(sd_turbo_path):
print("Downloading SD-Turbo...")
snapshot_download("stabilityai/sd-turbo", local_dir=sd_turbo_path)

if not os.path.exists(s3diff_weights_path):
print("Downloading S3Diff weights...")
os.makedirs(os.path.dirname(s3diff_weights_path), exist_ok=True)
hf_hub_download(
"zhangap/S3Diff",
filename="s3diff.pkl",
local_dir=os.path.dirname(s3diff_weights_path),
)

if not os.path.exists(de_net_path):
print("Downloading DEResNet weights...")
os.makedirs(os.path.dirname(de_net_path), exist_ok=True)
# DEResNet weights are in the S3Diff GitHub repo
import subprocess

repo_dir = "/tmp/s3diff_repo"
if not os.path.exists(repo_dir):
subprocess.run(
[
"git",
"clone",
"https://github.com/ArcticHare105/S3Diff.git",
repo_dir,
],
check=True,
)
import shutil

shutil.copy2(
os.path.join(repo_dir, "assets", "mm-realsr", "de_net.pth"),
de_net_path,
)


def main():
parser = argparse.ArgumentParser(description="S3Diff 4x SR on AWS Neuron")
parser.add_argument(
"--input_image",
type=str,
default=None,
help="Path to input low-resolution image (128x128 recommended)",
)
parser.add_argument(
"--output_image", type=str, default="sr_output.png", help="Output path"
)
parser.add_argument("--sd_turbo_path", type=str, default=DEFAULT_SD_TURBO_PATH)
parser.add_argument("--s3diff_weights", type=str, default=DEFAULT_S3DIFF_WEIGHTS)
parser.add_argument("--de_net_weights", type=str, default=DEFAULT_DE_NET_WEIGHTS)
parser.add_argument("--compile_dir", type=str, default=DEFAULT_COMPILE_DIR)
parser.add_argument("--num_images", type=int, default=3)
parser.add_argument("--warmup_rounds", type=int, default=5)
parser.add_argument("--download", action="store_true", help="Download weights")
args = parser.parse_args()

if args.download:
download_weights(args.sd_turbo_path, args.s3diff_weights, args.de_net_weights)

# Create a test image if none provided
if args.input_image is None:
print("No input image provided, creating a 128x128 test pattern...")
import numpy as np

test_img = np.random.randint(0, 255, (128, 128, 3), dtype=np.uint8)
lr_image = Image.fromarray(test_img)
else:
lr_image = Image.open(args.input_image).convert("RGB")
print(f"Input image: {lr_image.size}")

# Build pipeline
pipeline = S3DiffNeuronPipeline(
sd_turbo_path=args.sd_turbo_path,
s3diff_weights_path=args.s3diff_weights,
de_net_path=args.de_net_weights,
compile_dir=args.compile_dir,
lr_size=lr_image.size[0],
)

print("\nLoading model...")
pipeline.load()

print("\nCompiling...")
t0 = time.time()
pipeline.compile()
compile_time = time.time() - t0
print(f"Total compilation: {compile_time:.1f}s")

# Warmup
print(f"\nWarming up ({args.warmup_rounds} rounds)...")
for _ in range(args.warmup_rounds):
pipeline(lr_image)

# Benchmark
print(f"\nGenerating {args.num_images} images...")
total_time = 0
for i in range(args.num_images):
t0 = time.time()
sr_image = pipeline(lr_image)
elapsed = time.time() - t0
total_time += elapsed
print(f" Image {i + 1}: {elapsed:.3f}s")

avg_time = total_time / args.num_images
    print("\nResults:")
print(f" Average time: {avg_time:.3f}s")
print(f" Throughput: {1.0 / avg_time:.2f} img/s")
print(f" Compilation: {compile_time:.1f}s")

sr_image.save(args.output_image)
print(f" Saved: {args.output_image}")


if __name__ == "__main__":
main()