160 changes: 160 additions & 0 deletions contrib/models/flux1-lite-8b/README.md
@@ -0,0 +1,160 @@
# Contrib Model: FLUX.1-lite-8B-alpha

The FLUX.1-lite-8B-alpha image generation model runs on AWS Neuron using NxDI's first-party FLUX.1 implementation, with zero code modifications.

## Model Information

- **HuggingFace ID:** `Freepik/flux.1-lite-8B-alpha`
- **Model Type:** Diffusion transformer (DiT) for text-to-image generation
- **Parameters:** ~8B (BF16)
- **Architecture:** 8 double-stream MMDiT blocks + 38 single-stream DiT blocks, CLIP + T5-XXL text encoders, 16-channel VAE, FlowMatchEulerDiscrete scheduler
- **License:** Check HuggingFace model card (gated model, requires access approval)

## Key Finding: Native NxDI FLUX.1 Compatibility

**FLUX.1-lite-8B-alpha shares the FLUX.1-dev architecture.** The only difference is the number of double-stream blocks (8 instead of 19). All other components are the same:

| Component | FLUX.1-dev | FLUX.1-lite-8B | Same? |
|-----------|-----------|----------------|-------|
| Double-stream (MMDiT) blocks | 19 | 8 | Different |
| Single-stream (DiT) blocks | 38 | 38 | Same |
| Attention heads | 24 | 24 | Same |
| Attention head dim | 128 | 128 | Same |
| Joint attention dim | 4096 | 4096 | Same |
| Text encoders | CLIP + T5-XXL | CLIP + T5-XXL | Same |
| VAE latent channels | 16 | 16 | Same |
| RoPE axes_dim | (16, 56, 56) | (16, 56, 56) | Same |
| Pipeline class | FluxPipeline | FluxPipeline | Same |
| Scheduler | FlowMatchEulerDiscrete | FlowMatchEulerDiscrete | Same |
| guidance_embeds | True | True | Same |

Because NxDI's FLUX.1 implementation reads `num_layers` and `num_single_layers` from the model's `config.json` at runtime (via `load_diffusers_config()`), it automatically adapts to FLUX.1-lite's configuration. **No custom modeling code is needed.**
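As a quick sanity check before compiling, the block counts can be read straight from the checkpoint. The sketch below assumes the standard diffusers directory layout, where the transformer config sits at `transformer/config.json` under the model directory. The helper name `read_block_counts` is hypothetical and is not part of NxDI; it does not use `load_diffusers_config()`.

```python
import json
from pathlib import Path


def read_block_counts(model_dir):
    """Return (num_layers, num_single_layers) from a diffusers-style checkpoint.

    Assumption: the transformer config is stored at
    <model_dir>/transformer/config.json.
    """
    config_path = Path(model_dir) / "transformer" / "config.json"
    with config_path.open() as f:
        config = json.load(f)
    return config["num_layers"], config["num_single_layers"]


# Example (path taken from the Setup section):
#   read_block_counts("/shared/flux1-lite-8b")
# For FLUX.1-lite-8B-alpha this should report 8 double-stream and
# 38 single-stream blocks, matching the comparison table.
```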

This contrib provides:
- A standalone generation script (`src/generate_flux_lite.py`)
- Integration tests validating correct operation on Neuron
- Benchmark results demonstrating the performance benefit of the lighter architecture

## Validation Results

**Validated:** 2026-04-28
**Instance:** trn2.3xlarge (LNC=2, 4 logical cores)
**SDK:** Neuron SDK 2.29 (DLAMI 20260410), PyTorch 2.9, NxD Inference 0.9

### Benchmark Results (1024x1024, 25 steps, guidance_scale=3.5)

| Metric | Value |
|--------|-------|
| Resolution | 1024x1024 |
| Inference steps | 25 |
| TP Degree | 4 |
| CFG | Guidance distillation (single forward pass/step) |
| E2E generation time | 5.91s avg |
| Pipeline steps/sec | 4.23 |
| Backbone forward/sec | 4.49 |
| Compilation time | ~128s (CLIP 69s + T5 5s + backbone 53s + VAE ~2s) |
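The throughput rows follow from the timing rows by simple arithmetic. The sketch below recomputes pipeline steps/sec from the reported figures and, assuming one backbone forward pass per step as the CFG row states, estimates how much of the end-to-end time falls outside the denoising loop (text encoding, VAE decode, scheduler overhead). That split is an estimate derived from the table, not a separately measured value.

```python
def steps_per_sec(num_steps, seconds):
    """Throughput as steps divided by wall-clock seconds."""
    if seconds <= 0:
        raise ValueError("seconds must be positive")
    return num_steps / seconds


# Figures reported in the benchmark table.
NUM_STEPS = 25
E2E_SECONDS = 5.91
BACKBONE_FORWARD_PER_SEC = 4.49

pipeline_rate = steps_per_sec(NUM_STEPS, E2E_SECONDS)

# Assumption: one backbone forward per step, so the denoising loop takes
# NUM_STEPS / BACKBONE_FORWARD_PER_SEC seconds and the rest is other work.
backbone_seconds = NUM_STEPS / BACKBONE_FORWARD_PER_SEC
other_seconds = E2E_SECONDS - backbone_seconds

print(f"pipeline steps/sec: {pipeline_rate:.2f}")
print(f"estimated backbone time: {backbone_seconds:.2f}s")
print(f"estimated non-backbone time: {other_seconds:.2f}s")
```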

## Usage

```python
import torch
from neuronx_distributed_inference.models.diffusers.flux.application import (
    NeuronFluxApplication,
    create_flux_config,
    get_flux_parallelism_config,
)

MODEL_PATH = "/shared/flux1-lite-8b/"
COMPILE_DIR = "/tmp/flux-lite/compiled/"

# Configure (reads num_layers=8 from model's config.json automatically)
world_size = get_flux_parallelism_config(backbone_tp_degree=4)
clip_cfg, t5_cfg, backbone_cfg, decoder_cfg = create_flux_config(
    MODEL_PATH, world_size, backbone_tp_degree=4,
    dtype=torch.bfloat16, height=1024, width=1024,
)

# Create application
app = NeuronFluxApplication(
    model_path=MODEL_PATH,
    text_encoder_config=clip_cfg,
    text_encoder2_config=t5_cfg,
    backbone_config=backbone_cfg,
    decoder_config=decoder_cfg,
    height=1024, width=1024,
)

# Compile + load
app.compile(COMPILE_DIR)
app.load(COMPILE_DIR)

# Generate
image = app(
    "A cat holding a sign that says hello world",
    height=1024, width=1024,
    guidance_scale=3.5,
    num_inference_steps=25,
).images[0]
image.save("output.png")
```

Or use the provided script:

```bash
python src/generate_flux_lite.py \
    --checkpoint_dir /shared/flux1-lite-8b \
    --compile_workdir /tmp/flux-lite/compiled/ \
    --prompt "A cat holding a sign that says hello world" \
    --height 1024 --width 1024 \
    --num_inference_steps 25 \
    --save_image
```

## Setup

```bash
# Activate NxDI environment
source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate

# Install diffusers (not pre-installed in NxDI venv)
pip install diffusers transformers accelerate sentencepiece protobuf

# Download model (requires HuggingFace token with access)
huggingface-cli download Freepik/flux.1-lite-8B-alpha \
    --local-dir /shared/flux1-lite-8b
```

## Compatibility Matrix

| Instance/Version | SDK 2.29 | SDK 2.28 |
|------------------|----------|----------|
| trn2.3xlarge (LNC=2, TP=4) | VALIDATED | Not tested |

## Example Checkpoints

* [Freepik/flux.1-lite-8B-alpha](https://huggingface.co/Freepik/flux.1-lite-8B-alpha)

## Testing Instructions

```bash
# Set model path
export FLUX_LITE_MODEL_PATH=/shared/flux1-lite-8b/

# Run with pytest
cd contrib/models/flux1-lite-8b/
pytest test/integration/test_model.py -v

# Or standalone
python test/integration/test_model.py
```
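The test file itself is not reproduced in this README. As a minimal sketch of how a test could pick up the model path, the helper below reads `FLUX_LITE_MODEL_PATH` and falls back to the download location used in the Setup section. The function name `resolve_model_path` and the fallback behavior are assumptions for illustration and may differ from what `test_model.py` actually does.

```python
import os

# Fallback matches the --local-dir used in the Setup section.
DEFAULT_MODEL_PATH = "/shared/flux1-lite-8b/"


def resolve_model_path(env=None):
    """Return the model directory, preferring the FLUX_LITE_MODEL_PATH variable.

    `env` is any mapping of environment variables; it defaults to os.environ
    and is injectable so the helper can be exercised without touching the
    real environment.
    """
    env = os.environ if env is None else env
    return env.get("FLUX_LITE_MODEL_PATH", DEFAULT_MODEL_PATH)


print(resolve_model_path())
```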

## Known Issues

- The NxDI venv (`/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/`) does not include `diffusers` by default. Install it with pip before running.
- `attention_cte` kernel warnings about batch size x seqlen_q x seqlen_k appear during inference. These are informational and do not affect output quality.
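To catch the first issue before a long compile, a short pre-flight check can confirm that the extra packages are importable. This is a sketch using only the standard library; the helper name `missing_packages` is hypothetical, and `protobuf` is left out because its import name (`google.protobuf`) differs from its pip name.

```python
import importlib.util

# Packages from the Setup section whose import name matches the pip name.
REQUIRED = ("diffusers", "transformers", "accelerate", "sentencepiece")


def missing_packages(names=REQUIRED):
    """Return the subset of top-level package names that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_packages()
if missing:
    print("Missing packages, install with: pip install " + " ".join(missing))
else:
    print("All required packages are importable.")
```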

## Sample Output

![FLUX.1-lite output](samples/flux_lite_cat_hello_world.png)

*"A cat holding a sign that says hello world" -- 1024x1024, 25 steps, guidance_scale=3.5*
3 changes: 3 additions & 0 deletions contrib/models/flux1-lite-8b/src/__init__.py
@@ -0,0 +1,3 @@
# NxDI FLUX.1-lite-8B-alpha Diffusion Model
# Demonstrates that FLUX.1-lite runs natively on NxDI's first-party FLUX.1 implementation
# with no code modifications -- only different model weights.
147 changes: 147 additions & 0 deletions contrib/models/flux1-lite-8b/src/generate_flux_lite.py
@@ -0,0 +1,147 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

"""
FLUX.1-lite-8B-alpha generation script for AWS Neuron.

FLUX.1-lite-8B-alpha (Freepik) is architecturally identical to FLUX.1-dev with a
reduced backbone: 8 double-stream MMDiT blocks instead of 19. It uses the same
CLIP + T5-XXL text encoders, FluxPipeline, VAE, scheduler, and RoPE configuration.

Because of this architectural compatibility, FLUX.1-lite runs natively on NxDI's
first-party FLUX.1 implementation with no code modifications. The NxDI FLUX.1
application reads `num_layers` and `num_single_layers` from the model's config.json
at runtime, so it automatically adapts to FLUX.1-lite's configuration:
- num_layers: 8 (vs 19 in FLUX.1-dev)
- num_single_layers: 38 (same as FLUX.1-dev)

Usage:
    # Download the model (requires HuggingFace access):
    huggingface-cli download Freepik/flux.1-lite-8B-alpha --local-dir /shared/flux1-lite-8b

    # Generate an image:
    python generate_flux_lite.py \\
        --checkpoint_dir /shared/flux1-lite-8b \\
        --compile_workdir /tmp/flux-lite/compiled/ \\
        --prompt "A cat holding a sign that says hello world" \\
        --height 1024 --width 1024 \\
        --num_inference_steps 25 \\
        --save_image

Requirements:
    pip install diffusers transformers accelerate sentencepiece protobuf
"""

import argparse
import time

import torch
from neuronx_distributed_inference.models.diffusers.flux.application import (
    NeuronFluxApplication,
    create_flux_config,
    get_flux_parallelism_config,
)
from neuronx_distributed_inference.utils.random import set_random_seed

set_random_seed(0)

DEFAULT_CKPT_DIR = "/shared/flux1-lite-8b/"
DEFAULT_COMPILE_DIR = "/tmp/flux-lite/compiled/"


def run_generate(args):
    print(f"FLUX.1-lite-8B generation with args: {args}")

    # Fall back to TP=4, the degree validated on trn2.3xlarge.
    backbone_tp_degree = args.backbone_tp_degree if args.backbone_tp_degree else 4
    world_size = get_flux_parallelism_config(backbone_tp_degree)
    dtype = torch.bfloat16

    # Build per-component configs; block counts come from the checkpoint's config.
    clip_config, t5_config, backbone_config, decoder_config = create_flux_config(
        args.checkpoint_dir,
        world_size,
        backbone_tp_degree,
        dtype,
        args.height,
        args.width,
    )

    flux_app = NeuronFluxApplication(
        model_path=args.checkpoint_dir,
        text_encoder_config=clip_config,
        text_encoder2_config=t5_config,
        backbone_config=backbone_config,
        decoder_config=decoder_config,
        height=args.height,
        width=args.width,
    )

    print("Compiling model...")
    compile_start = time.time()
    flux_app.compile(args.compile_workdir)
    compile_time = time.time() - compile_start
    print(f"Compilation completed in {compile_time:.1f}s")

    flux_app.load(args.compile_workdir)

    # Warmup: run full generations so the timed runs exclude first-call overhead.
    print("Warming up...")
    for _ in range(args.warmup_rounds):
        flux_app(
            args.prompt,
            height=args.height,
            width=args.width,
            guidance_scale=args.guidance_scale,
            num_inference_steps=args.num_inference_steps,
        ).images[0]

    # Generate and time each image end to end.
    total_time = 0.0
    for i in range(args.num_images):
        start = time.time()
        image = flux_app(
            args.prompt,
            height=args.height,
            width=args.width,
            guidance_scale=args.guidance_scale,
            num_inference_steps=args.num_inference_steps,
        ).images[0]
        gen_time = time.time() - start
        total_time += gen_time

        if args.save_image:
            filename = f"flux_lite_output_{i + 1}.png"
            image.save(filename)
            print(f"Image {i + 1} saved to {filename} in {gen_time:.2f}s")
        else:
            print(f"Image {i + 1} generated in {gen_time:.2f}s")

    print("\nResults:")
    # Guard against --num_images 0, which would otherwise divide by zero.
    if args.num_images > 0:
        avg_time = total_time / args.num_images
        steps_per_sec = args.num_inference_steps / avg_time
        print(f"  Average generation time: {avg_time:.2f}s")
        print(f"  Pipeline steps/sec: {steps_per_sec:.2f}")
    print(f"  Compilation time: {compile_time:.1f}s")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="FLUX.1-lite-8B on AWS Neuron (NxDI)")
    parser.add_argument(
        "-p", "--prompt", type=str, default="A cat holding a sign that says hello world"
    )
    parser.add_argument("-hh", "--height", type=int, default=1024)
    parser.add_argument("-w", "--width", type=int, default=1024)
    parser.add_argument("-n", "--num_inference_steps", type=int, default=25)
    parser.add_argument("-g", "--guidance_scale", type=float, default=3.5)
    parser.add_argument("-c", "--checkpoint_dir", type=str, default=DEFAULT_CKPT_DIR)
    parser.add_argument("--compile_workdir", type=str, default=DEFAULT_COMPILE_DIR)
    parser.add_argument("--num_images", type=int, default=3)
    parser.add_argument("--warmup_rounds", type=int, default=5)
    parser.add_argument("--save_image", action="store_true")
    parser.add_argument(
        "--backbone_tp_degree",
        type=int,
        default=None,
        help="Tensor parallelism degree (default: 4 for trn2.3xlarge)",
    )
    args = parser.parse_args()
    run_generate(args)