diff --git a/contrib/models/flux2-klein/README.md b/contrib/models/flux2-klein/README.md new file mode 100644 index 00000000..25583c23 --- /dev/null +++ b/contrib/models/flux2-klein/README.md @@ -0,0 +1,152 @@ +# Contrib Model: FLUX.2-klein-base-9B + +FLUX.2-klein-base-9B image generation model running on AWS Neuron with NxD Inference tensor parallelism. + +## Model Information + +- **HuggingFace ID:** `black-forest-labs/FLUX.2-klein-base-9B` +- **Model Type:** Diffusion transformer (DiT) for text-to-image generation +- **Parameters:** ~9.08B (BF16) +- **Architecture:** 8 double-stream MMDiT blocks + 24 single-stream DiT blocks, 4D RoPE, SwiGLU, pre-computed modulation, Qwen3-8B text encoder +- **License:** Check HuggingFace model card (gated model, requires access approval) + +## Key Architecture Notes + +FLUX.2-klein differs from FLUX.1 in several important ways: + +1. **Pre-computed modulation**: Modulation parameters are computed once from the timestep embedding and shared across all blocks (FLUX.1 computes per-block). +2. **Fused QKV+MLP**: Single-stream blocks use `Flux2ParallelSelfAttention` with a fused `to_qkv_mlp_proj` linear. For NxDI, this is split into separate TP-sharded Q/K/V and MLP projections. +3. **SwiGLU activation**: Uses `SiLU(x1) * x2` instead of GELU. +4. **Qwen3-8B text encoder**: Single encoder producing multi-layer hidden states (layers 9, 18, 27 concatenated for `joint_attention_dim=12288`), replacing FLUX.1's CLIP+T5 dual encoder. +5. **32 latent channels** (packed to 128), 4D RoPE with `axes_dims=(32,32,32,32)`, `theta=2000`. +6. **Classic CFG**: Two forward passes (positive + negative prompt) rather than guidance distillation. + +## Validation Results + +**Validated:** Yes +**Instance:** trn2.3xlarge (LNC=2, 4 logical cores) +**SDK:** Neuron SDK 2.29 (DLAMI 20260410), PyTorch 2.9, NxD Inference 0.9 + +### Benchmark Results (1024x1024, 30 steps, guidance_scale=4.0) + +| Metric | Value | +|--------|-------| +| Resolution | 1024x1024 | +| Inference steps | 30 | +| TP Degree | 4 | +| CFG | Classic (2 forward passes per step) | +| E2E generation time | 31.09s +/- 0.08s | +| Pipeline steps/sec | 0.96 | +| Per-step latency | 1036ms | +| Backbone forward/sec | 1.93 (2 CFG passes/step) | +| Compilation time | ~135s (2.3 min) | +| Model load time | ~20s | + +### Accuracy Validation + +| Metric | Value | +|--------|-------| +| Backbone cosine similarity (Neuron vs CPU) | 0.9987 | +| Max absolute difference | 0.15 | +| Mean absolute difference | 0.022 | + +## Usage + +```python +import torch +from flux2_klein.src.application import ( + NeuronFlux2KleinApplication, + create_flux2_klein_config, +) + +MODEL_PATH = "/shared/flux2-klein/" +COMPILE_DIR = "/tmp/flux2_klein/compiled/" + +# Configure +backbone_config = create_flux2_klein_config( + model_path=MODEL_PATH, + backbone_tp_degree=4, # trn2.3xlarge LNC=2 + dtype=torch.bfloat16, + height=1024, + width=1024, +) + +# Create application +app = NeuronFlux2KleinApplication( + model_path=MODEL_PATH, + backbone_config=backbone_config, +) + +# Compile (first time only, ~2-3 min) +app.compile(COMPILE_DIR) + +# Load compiled model +app.load(COMPILE_DIR) + +# Generate +result = app( + prompt="A cat holding a sign that says hello world", + negative_prompt="", + height=1024, + width=1024, + num_inference_steps=30, + guidance_scale=4.0, +) +result.images[0].save("output.png") +``` + +## Compatibility Matrix + +| Instance | SDK 2.29 | SDK 2.28 | +|----------|----------|----------| +| trn2.3xlarge (TP=4, LNC=2) | Validated | Not tested 
| +| trn2.48xlarge (TP=4, LNC=2) | Not tested | Not tested | + +## Example Checkpoints + +* [black-forest-labs/FLUX.2-klein-base-9B](https://huggingface.co/black-forest-labs/FLUX.2-klein-base-9B) (gated, requires access approval) + +## Testing Instructions + +```bash +# Set environment variables +export FLUX2_MODEL_PATH=/shared/flux2-klein/ +export FLUX2_COMPILE_DIR=/tmp/flux2_klein_test/ +export FLUX2_TP_DEGREE=4 + +# Run integration tests +cd contrib/models/flux2-klein/ +pytest test/integration/test_model.py -v + +# Or run standalone +python test/integration/test_model.py +``` + +## Design Decisions + +### Why NxDI instead of torch_neuronx.trace()? + +A single `Flux2SingleTransformerBlock` generates 7.4M instructions (exceeding the 5M compiler limit) due to the fused `to_qkv_mlp_proj: Linear(4096, 36864)`. With NxDI tensor parallelism (TP=4), each rank handles only 1/4 of the attention heads and MLP hidden dimensions, bringing the instruction count well within limits. + +### Why is the text encoder on CPU? + +Qwen3-8B (8.19B parameters) is too large for efficient Neuron compilation as a single traced model. Since it only runs once per image (not in the denoising loop), the ~5s CPU execution time is negligible compared to the 30-step denoising (~31s). + +### Why split the fused QKV+MLP projection? + +The HuggingFace `Flux2ParallelSelfAttention` uses a single massive `Linear(4096, 36864)` that fuses Q, K, V projections with MLP input. For TP sharding, we split this into separate `ColumnParallelLinear` layers for Q/K/V and MLP, each independently sharded across TP ranks. + +### Why split SwiGLU linear_in into gate + value? + +FLUX.2-klein uses SwiGLU activation (`SiLU(gate) * value`) in its feed-forward networks, where the HuggingFace implementation fuses gate and value into a single `linear_in` weight of shape `[2*inner_dim, dim]`. With `ColumnParallelLinear`, each TP rank receives a contiguous 1/N partition of the output dimension. If the fused weight is sharded directly, the SwiGLU split on each rank won't correspond to correct gate/value pairs. The fix is to split `linear_in` into two separate `ColumnParallelLinear` layers (`linear_in_gate` and `linear_in_value`), so each is independently and correctly sharded. The same pattern applies to the single block's fused MLP projection. + +## Known Issues + +- **Gated model**: Requires HuggingFace access approval before use +- **CPU text encoding**: Qwen3-8B runs on CPU (~5s per prompt) +- **First compilation**: Takes ~2-3 minutes on trn2.3xlarge +- **Memory**: The 9B transformer requires all 4 cores with TP=4 on trn2.3xlarge + +## Maintainer + +Jim Burtoft diff --git a/contrib/models/flux2-klein/samples/hello_world_cat.png b/contrib/models/flux2-klein/samples/hello_world_cat.png new file mode 100644 index 00000000..2febd0f0 Binary files /dev/null and b/contrib/models/flux2-klein/samples/hello_world_cat.png differ diff --git a/contrib/models/flux2-klein/src/__init__.py b/contrib/models/flux2-klein/src/__init__.py new file mode 100644 index 00000000..f2ac73a9 --- /dev/null +++ b/contrib/models/flux2-klein/src/__init__.py @@ -0,0 +1,2 @@ +# NxDI FLUX.2-klein-base-9B Diffusion Model +# Neuron-optimized implementation of the FLUX.2-klein transformer backbone diff --git a/contrib/models/flux2-klein/src/application.py b/contrib/models/flux2-klein/src/application.py new file mode 100644 index 00000000..20728ad6 --- /dev/null +++ b/contrib/models/flux2-klein/src/application.py @@ -0,0 +1,292 @@ +# Copyright 2024 Black Forest Labs, The HuggingFace Team. 
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +FLUX.2-klein-base-9B NxDI application and pipeline integration. + +This module provides: +- NeuronFlux2KleinApplication: Top-level orchestrator for compile/load/generate +- create_flux2_klein_config: Config factory for backbone + VAE +- NeuronTransformerWrapper: Drop-in replacement for Flux2Transformer2DModel in Diffusers pipeline + +The text encoder (Qwen3-8B) runs on CPU since it only executes once per prompt. +The VAE decoder is compiled separately with torch_neuronx.trace(). +""" + +import gc +import logging +import os +import time +from typing import Optional + +import torch +import torch.nn as nn + +try: + from neuronx_distributed_inference.models.config import ( + InferenceConfig, + NeuronConfig, + ) + from neuronx_distributed_inference.utils.diffusers_adapter import ( + load_diffusers_config, + ) + + try: + from .modeling_flux2_klein import ( + Flux2KleinBackboneInferenceConfig, + NeuronFlux2KleinBackboneApplication, + ) + except ImportError: + from modeling_flux2_klein import ( + Flux2KleinBackboneInferenceConfig, + NeuronFlux2KleinBackboneApplication, + ) + + NEURON_AVAILABLE = True +except ImportError: + NEURON_AVAILABLE = False + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +def create_flux2_klein_config( + model_path: str, + backbone_tp_degree: int = 4, + dtype=torch.bfloat16, + height: int = 1024, + width: int = 1024, +): + """ + Create NxDI configuration for FLUX.2-klein backbone. + + Args: + model_path: Path to the HuggingFace model directory + backbone_tp_degree: Tensor parallelism degree (4 for trn2.3xlarge LNC=2) + dtype: Model data type + height: Image height + width: Image width + + Returns: + Flux2KleinBackboneInferenceConfig for the transformer backbone + """ + transformer_path = os.path.join(model_path, "transformer") + + backbone_neuron_config = NeuronConfig( + tp_degree=backbone_tp_degree, + world_size=backbone_tp_degree, + torch_dtype=dtype, + ) + + backbone_config = Flux2KleinBackboneInferenceConfig( + neuron_config=backbone_neuron_config, + load_config=load_diffusers_config(transformer_path), + height=height, + width=width, + ) + + # FLUX.2-klein specific defaults + if ( + not hasattr(backbone_config, "out_channels") + or backbone_config.out_channels is None + ): + backbone_config.out_channels = backbone_config.in_channels + + # VAE scale factor for FLUX.2: 16 (8x spatial * 2x patch) + setattr(backbone_config, "vae_scale_factor", 16) + + return backbone_config + + +class NeuronTransformerWrapper(nn.Module): + """ + Drop-in replacement for Flux2Transformer2DModel in the Diffusers pipeline. + + This wrapper accepts the same forward() signature as the HuggingFace model + and delegates to the compiled NxDI backbone. 
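+
+    A minimal usage sketch (variable names here are hypothetical; in practice
+    NeuronFlux2KleinApplication.load() performs this swap):
+
+        wrapper = NeuronTransformerWrapper(backbone_app, backbone_config)
+        pipe.transformer = wrapper          # hot-swap the HF transformer
+        (sample,) = wrapper(hidden_states=latents,
+                            encoder_hidden_states=text_emb,
+                            timestep=t, img_ids=img_ids, txt_ids=txt_ids)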
+ """ + + def __init__(self, neuron_app: "NeuronFlux2KleinBackboneApplication", config): + super().__init__() + self.neuron_app = neuron_app + self.config = config + # Expose config attributes that the pipeline expects + self._internal_dict = getattr(config, "_internal_dict", {}) + + def cache_context(self, name: str): + """Delegate cache_context to the NxDI backbone application.""" + return self.neuron_app.cache_context(name) + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + timestep=None, + img_ids=None, + txt_ids=None, + guidance=None, + joint_attention_kwargs=None, + return_dict=False, + **kwargs, + ): + """Match Flux2Transformer2DModel.forward() signature.""" + output = self.neuron_app.model( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + img_ids=img_ids, + txt_ids=txt_ids, + guidance=guidance, + return_dict=return_dict, + ) + # Wrap output to match Diffusers Transformer2DModelOutput + if return_dict: + from diffusers.models.modeling_outputs import Transformer2DModelOutput + + return Transformer2DModelOutput(sample=output) + return (output,) + + @property + def dtype(self): + return self.config.neuron_config.torch_dtype + + +class NeuronVAEDecoderWrapper(nn.Module): + """ + Wrapper for the VAE decoder compiled with torch_neuronx.trace(). + + Handles the FLUX.2-specific latent post-processing: + 1. Unpack latents from sequence to spatial format + 2. Batch norm normalize + 3. Unpatchify (32ch -> 8ch with 2x2 spatial expansion) + 4. VAE decode + """ + + def __init__(self, traced_model, vae_config): + super().__init__() + self.traced = traced_model + self.vae_config = vae_config + + def forward(self, latent: torch.Tensor) -> torch.Tensor: + """ + Args: + latent: [B, C_vae, H_lat, W_lat] latent after unpatchify (e.g., [1, 8, 128, 128]) + Returns: + image: [B, 3, H, W] decoded image + """ + return self.traced(latent) + + +class NeuronFlux2KleinApplication(nn.Module): + """ + Top-level application for FLUX.2-klein-base-9B on Neuron. 
+ + Components: + - Transformer backbone: Compiled with NxDI (TP=4 on trn2.3xlarge) + - Text encoder (Qwen3-8B): Runs on CPU (single execution per prompt) + - VAE decoder: Compiled with torch_neuronx.trace() (single core) + - Scheduler: FlowMatchEulerDiscreteScheduler from Diffusers + + Usage: + app = NeuronFlux2KleinApplication(model_path, backbone_config) + app.compile(compile_dir) + app.load(compile_dir) + image = app("A photo of a cat", height=1024, width=1024) + """ + + def __init__( + self, + model_path: str, + backbone_config: "Flux2KleinBackboneInferenceConfig", + height: int = 1024, + width: int = 1024, + ): + super().__init__() + self.model_path = model_path + self.backbone_config = backbone_config + self.height = height + self.width = width + + transformer_path = os.path.join(model_path, "transformer") + + # Create NxDI backbone application + self.backbone_app = NeuronFlux2KleinBackboneApplication( + model_path=transformer_path, + config=backbone_config, + ) + + # Diffusers pipeline (for text encoding, scheduling, VAE) + self.pipe = None + self.vae_traced = None + + def _load_pipeline(self): + """Load the Diffusers pipeline for CPU components.""" + if self.pipe is not None: + return + + from diffusers import Flux2KleinPipeline + + logger.info("Loading Flux2KleinPipeline...") + self.pipe = Flux2KleinPipeline.from_pretrained( + self.model_path, + torch_dtype=torch.bfloat16, + ) + logger.info("Pipeline loaded successfully.") + + def compile(self, compiled_model_path: str, debug: bool = False): + """Compile the transformer backbone.""" + transformer_dir = os.path.join(compiled_model_path, "transformer") + if os.path.exists(transformer_dir): + logger.info(f"Transformer already compiled at {transformer_dir}, skipping.") + else: + os.environ["BASE_COMPILE_WORK_DIR"] = os.path.join( + compiled_model_path, "compiler_workdir" + ) + self.backbone_app.compile(transformer_dir, debug) + + def load(self, compiled_model_path: str, skip_warmup: bool = False): + """Load the compiled transformer backbone.""" + transformer_dir = os.path.join(compiled_model_path, "transformer") + self.backbone_app.load(transformer_dir, skip_warmup=skip_warmup) + + # Load pipeline for CPU components + self._load_pipeline() + + # Hot-swap transformer with Neuron version + self.pipe.transformer = NeuronTransformerWrapper( + self.backbone_app, + self.backbone_config, + ) + + def __call__( + self, + prompt: str, + height: int = None, + width: int = None, + num_inference_steps: int = 50, + guidance_scale: float = 4.0, + **kwargs, + ): + """Generate an image.""" + height = height or self.height + width = width or self.width + + return self.pipe( + prompt=prompt, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + **kwargs, + ) diff --git a/contrib/models/flux2-klein/src/generate_flux2_klein.py b/contrib/models/flux2-klein/src/generate_flux2_klein.py new file mode 100644 index 00000000..463b7721 --- /dev/null +++ b/contrib/models/flux2-klein/src/generate_flux2_klein.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python3 +""" +FLUX.2-klein-base-9B inference on Neuron using NxD Inference. 
+ +Usage: + # Compile and generate + python generate_flux2_klein.py --checkpoint_dir /path/to/model --compile_workdir /tmp/flux2_klein/ + + # With custom prompt + python generate_flux2_klein.py -p "A futuristic cityscape at sunset" --save_image + + # Benchmark mode + python generate_flux2_klein.py --num_images 5 --save_image + +Requirements: + - trn2.3xlarge instance with LNC=2 (4 logical cores, TP=4) + - Neuron SDK 2.29+ + - diffusers >= 0.37.1 (for Flux2KleinPipeline) + - HuggingFace access to black-forest-labs/FLUX.2-klein-base-9B +""" + +import argparse +import os +import sys +import time + +import torch + +# Add src dir to path for local imports +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, SCRIPT_DIR) + +from application import ( + NeuronFlux2KleinApplication, + create_flux2_klein_config, +) + +DEFAULT_MODEL_PATH = "/shared/flux2-klein/" +DEFAULT_COMPILE_DIR = "/tmp/flux2_klein/compiled/" + + +def main(args): + print(f"FLUX.2-klein-base-9B Neuron Inference") + print(f" Model: {args.checkpoint_dir}") + print(f" TP degree: {args.tp_degree}") + print(f" Resolution: {args.height}x{args.width}") + print(f" Steps: {args.num_inference_steps}") + print(f" Guidance scale: {args.guidance_scale}") + print() + + # Create config + backbone_config = create_flux2_klein_config( + model_path=args.checkpoint_dir, + backbone_tp_degree=args.tp_degree, + dtype=torch.bfloat16, + height=args.height, + width=args.width, + ) + + # Create application + app = NeuronFlux2KleinApplication( + model_path=args.checkpoint_dir, + backbone_config=backbone_config, + height=args.height, + width=args.width, + ) + + # Compile + print("Compiling transformer backbone...") + t0 = time.time() + app.compile(args.compile_workdir) + compile_time = time.time() - t0 + print(f"Compilation: {compile_time:.1f}s") + + # Load + print("Loading compiled model...") + t0 = time.time() + app.load(args.compile_workdir) + load_time = time.time() - t0 + print(f"Load: {load_time:.1f}s") + + # Warmup + print("Warming up...") + for _ in range(2): + app( + prompt=args.prompt, + negative_prompt=args.negative_prompt, + height=args.height, + width=args.width, + num_inference_steps=args.num_inference_steps, + guidance_scale=args.guidance_scale, + ) + + # Generate + print(f"\nGenerating {args.num_images} images...") + latencies = [] + for i in range(args.num_images): + t0 = time.time() + result = app( + prompt=args.prompt, + negative_prompt=args.negative_prompt, + height=args.height, + width=args.width, + num_inference_steps=args.num_inference_steps, + guidance_scale=args.guidance_scale, + ) + elapsed = time.time() - t0 + latencies.append(elapsed) + print(f" Image {i + 1}: {elapsed:.2f}s") + + if args.save_image: + image = result.images[0] + filename = os.path.join(args.compile_workdir, f"output_{i + 1}.png") + image.save(filename) + print(f" Saved: {filename}") + + # Summary + import numpy as np + + latencies = np.array(latencies) + print(f"\nResults:") + print(f" Mean latency: {latencies.mean():.2f}s") + print(f" Std: {latencies.std():.2f}s") + print(f" Steps/sec: {args.num_inference_steps / latencies.mean():.2f}") + print(f" img/s: {1.0 / latencies.mean():.3f}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="FLUX.2-klein-base-9B Neuron inference" + ) + parser.add_argument( + "-p", "--prompt", type=str, default="A cat holding a sign that says hello world" + ) + parser.add_argument("--negative_prompt", type=str, default="") + parser.add_argument("-hh", "--height", type=int, 
default=1024) + parser.add_argument("-w", "--width", type=int, default=1024) + parser.add_argument("-n", "--num_inference_steps", type=int, default=50) + parser.add_argument("-g", "--guidance_scale", type=float, default=4.0) + parser.add_argument( + "--tp_degree", + type=int, + default=4, + help="Tensor parallelism degree (4 for trn2.3xlarge LNC=2)", + ) + parser.add_argument( + "-c", + "--checkpoint_dir", + type=str, + default=DEFAULT_MODEL_PATH, + help="Path to the HuggingFace model directory", + ) + parser.add_argument( + "--compile_workdir", + type=str, + default=DEFAULT_COMPILE_DIR, + help="Path for compiled model artifacts", + ) + parser.add_argument("--num_images", type=int, default=1) + parser.add_argument("--save_image", action="store_true") + parser.add_argument("--profile", action="store_true") + + args = parser.parse_args() + main(args) diff --git a/contrib/models/flux2-klein/src/modeling_flux2_klein.py b/contrib/models/flux2-klein/src/modeling_flux2_klein.py new file mode 100644 index 00000000..d836e48f --- /dev/null +++ b/contrib/models/flux2-klein/src/modeling_flux2_klein.py @@ -0,0 +1,1349 @@ +# Copyright 2024 Black Forest Labs, The HuggingFace Team. All rights reserved. +# +# This implementation is derived from the Diffusers library and the NxD Inference +# FLUX implementation. It has been adapted for FLUX.2-klein-base-9B architecture +# and optimized for execution on Amazon Neuron devices with tensor parallelism. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +FLUX.2-klein-base-9B NxDI model implementation. 
+ +Architecture differences from FLUX.1: +- Pre-computed modulation (shared across all blocks, not per-block) +- SwiGLU activation in feed-forward (not GELU) +- Fused QKV+MLP in single-stream blocks (Flux2ParallelSelfAttention) +- 4D RoPE (T, H, W, L) with axes_dims=(32,32,32,32), theta=2000 +- Single text encoder: Qwen3-8B (not CLIP+T5) +- 32 latent channels packed to 128 (not 16->64) +- Classic CFG (two forward passes) instead of guidance distillation +- Timestep embedding without pooled text projection +""" + +import logging +import math +import os +from typing import Any, Dict, List, Optional, Tuple +from contextlib import contextmanager + +import torch +import torch.nn as nn +import torch.nn.functional as F + +try: + from neuronx_distributed.parallel_layers.layer_norm import LayerNorm + from neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + RowParallelLinear, + SPMDRank, + ) + from neuronx_distributed.parallel_layers.mappings import ( + gather_from_tensor_model_parallel_region_with_dim, + reduce_from_tensor_model_parallel_region, + ) + from neuronx_distributed.parallel_layers.parallel_state import ( + get_data_parallel_group, + get_tensor_model_parallel_size, + get_world_group, + ) + from neuronx_distributed_inference.models.diffusers.embeddings import ( + FluxPosEmbed, + apply_rotary_emb, + ) + from neuronx_distributed_inference.modules.custom_calls import CustomRMSNorm + from neuronx_distributed_inference.utils.distributed import get_dp_rank_spmd + from nkilib.core.attention.attention_cte import attention_cte + from neuronx_distributed.utils.utils import hardware + from torch_neuronx.utils import get_platform_target + from neuronx_distributed_inference.models.application_base import ( + NeuronApplicationBase, + ) + from neuronx_distributed_inference.models.config import InferenceConfig + from neuronx_distributed_inference.models.layer_boundary_marker import ( + ModuleMarkerEndWrapper, + ModuleMarkerStartWrapper, + ) + from neuronx_distributed_inference.models.model_wrapper import ( + BaseModelInstance, + ModelWrapper, + ) + + NEURON_AVAILABLE = True + _HARDWARE = hardware(get_platform_target()) +except ImportError: + NEURON_AVAILABLE = False + _HARDWARE = None + +logger = logging.getLogger(__name__) + + +# ============================================================ +# NKI Flash Attention Wrapper +# ============================================================ + + +def attention_wrapper_sharded(query, key, value): + """ + NKI flash attention for FLUX.2-klein (bi-directional, no causal mask). 
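+
+    Batch and head dims are flattened into a single leading dim before the
+    NKI kernel call (attention_cte operates on 3-D tensors) and restored on
+    return.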
+ + Input shapes: query, key, value all [bs, n_head, seq_len, d_head] + Output shape: [bs, n_head, q_len, d_head] + """ + bs, n_head, q_len, d_head = query.shape + k_len = key.shape[2] + v_len = value.shape[2] + + q = query.reshape((bs * n_head, q_len, d_head)) + k = key.reshape((bs * n_head, k_len, d_head)) + v = value.reshape((bs * n_head, v_len, d_head)) + + vc_size = int(os.getenv("NEURON_RT_VIRTUAL_CORE_SIZE", "1")) + use_sharded = vc_size == 2 + scale = 1 / math.sqrt(d_head) + + if use_sharded: + attn_output = attention_cte[2]( + q, + k, + v, + scale, + causal_mask=False, + tp_q=True, + tp_k=True, + tp_out=False, + ) + else: + attn_output = attention_cte( + q, + k, + v, + scale, + causal_mask=False, + tp_q=True, + tp_k=True, + tp_out=False, + ) + + return attn_output.reshape((bs, n_head, q_len, d_head)) + + +# ============================================================ +# FLUX.2-klein Modulation +# ============================================================ + + +class NeuronFlux2Modulation(nn.Module): + """ + Pre-computed modulation for FLUX.2-klein. + + Unlike FLUX.1 where each block has its own AdaLayerNormZero with + an internal linear projection, FLUX.2 computes modulation ONCE from + the timestep embedding and passes it to ALL blocks. + + Args: + dim: Hidden dimension (4096 for Klein) + mod_param_sets: Number of parameter sets per block type. + - 2 for double blocks (shift/scale/gate for attn + shift/scale/gate for FF = 6 params) + - 1 for single blocks (shift/scale/gate for fused attn+MLP = 3 params) + """ + + def __init__(self, dim: int, mod_param_sets: int = 1, reduce_dtype=torch.bfloat16): + super().__init__() + self.mod_param_sets = mod_param_sets + self.act = nn.SiLU() + # Output: 3 * mod_param_sets * dim + # For single: 3 * 1 * 4096 = 12288 + # For double: 3 * 2 * 4096 = 24576 + self.linear = ColumnParallelLinear( + dim, + 3 * mod_param_sets * dim, + bias=False, + gather_output=True, + reduce_dtype=reduce_dtype, + ) + + def forward(self, temb: torch.Tensor) -> torch.Tensor: + """ + Args: + temb: [B, dim] timestep embedding + Returns: + modulation params: [B, 3 * mod_param_sets * dim] + """ + return self.linear(self.act(temb)) + + @staticmethod + def split(mod: torch.Tensor, n: int): + """Split modulation output into n groups of (shift, scale, gate).""" + # mod: [B, 3*n*dim] -> n groups of [B, 3*dim] -> each into (shift, scale, gate) + chunks = mod.chunk(3 * n, dim=-1) + groups = [] + for i in range(n): + shift = chunks[3 * i] + scale = chunks[3 * i + 1] + gate = chunks[3 * i + 2] + groups.append((shift, scale, gate)) + return groups + + +# ============================================================ +# FLUX.2-klein Timestep Embedding +# ============================================================ + + +class NeuronFlux2TimestepEmbedder(nn.Module): + """Container for timestep MLP to match HF key naming.""" + + def __init__(self, time_proj_dim, embedding_dim, reduce_dtype): + super().__init__() + self.linear_1 = ColumnParallelLinear( + time_proj_dim, + embedding_dim, + bias=False, + gather_output=False, + reduce_dtype=reduce_dtype, + ) + self.linear_2 = RowParallelLinear( + embedding_dim, + embedding_dim, + bias=False, + input_is_parallel=True, + reduce_dtype=reduce_dtype, + ) + + +class NeuronFlux2TimestepEmbedding(nn.Module): + """ + Timestep + optional guidance embedding for FLUX.2-klein. + + Unlike FLUX.1 which also incorporates a pooled text projection, + FLUX.2's time_guidance_embed only takes timestep and guidance. 
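+
+    With time_proj_dim=256, a scalar timestep t maps to the 256-dim vector
+    [cos(t * w_0..w_127), sin(t * w_0..w_127)], w_i = exp(-i * ln(10000) / 127),
+    which then passes through the two-layer MLP with a SiLU in between.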
+ + HF key structure: + time_guidance_embed.timestep_embedder.linear_1.weight [4096, 256] + time_guidance_embed.timestep_embedder.linear_2.weight [4096, 4096] + """ + + def __init__(self, embedding_dim: int, reduce_dtype=torch.bfloat16): + super().__init__() + self.time_proj_dim = 256 + self.timestep_embedder = NeuronFlux2TimestepEmbedder( + self.time_proj_dim, embedding_dim, reduce_dtype + ) + self.guidance_proj = None # Klein base does not use guidance + + def get_timestep_sinusoidal(self, timesteps: torch.Tensor) -> torch.Tensor: + """Sinusoidal positional encoding for timesteps.""" + half_dim = self.time_proj_dim // 2 + emb = math.log(10000.0) / (half_dim - 1) + emb = torch.exp( + torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb + ) + emb = timesteps.float().unsqueeze(-1) * emb.unsqueeze(0) + emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=-1) + return emb.to(timesteps.dtype) + + def forward( + self, timestep: torch.Tensor, guidance: Optional[torch.Tensor] = None + ) -> torch.Tensor: + t_emb = self.get_timestep_sinusoidal(timestep) + temb = F.silu(self.timestep_embedder.linear_1(t_emb)) + temb = self.timestep_embedder.linear_2(temb) + return temb + + +# ============================================================ +# ============================================================ +# FLUX.2-klein Feed-Forward (Double-stream blocks) +# ============================================================ + + +class NeuronFlux2FeedForward(nn.Module): + """ + SwiGLU-based feed-forward for FLUX.2-klein double-stream blocks. + + Architecture: SwiGLU(Linear_gate(x), Linear_value(x)) -> Linear_out + + HF key structure: + ff.linear_in.weight [inner_dim*2, dim] (gate and value concatenated) + ff.linear_out.weight [dim, inner_dim] + + For TP: We split linear_in into two ColumnParallelLinear (gate + value) + to avoid the SwiGLU split crossing TP partition boundaries. + """ + + def __init__(self, dim: int, mlp_ratio: float = 3.0, reduce_dtype=torch.bfloat16): + super().__init__() + inner_dim = int(dim * mlp_ratio) + + # Two separate projections for SwiGLU gate and value + # These will be loaded by splitting the fused linear_in weight + self.linear_in_gate = ColumnParallelLinear( + dim, + inner_dim, + bias=False, + gather_output=False, + reduce_dtype=reduce_dtype, + ) + self.linear_in_value = ColumnParallelLinear( + dim, + inner_dim, + bias=False, + gather_output=False, + reduce_dtype=reduce_dtype, + ) + self.linear_out = RowParallelLinear( + inner_dim, + dim, + bias=False, + input_is_parallel=True, + reduce_dtype=reduce_dtype, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate = F.silu(self.linear_in_gate(x)) + value = self.linear_in_value(x) + return self.linear_out(gate * value) + + +# ============================================================ +# FLUX.2-klein Attention (shared between double and single blocks) +# ============================================================ + + +class NeuronFlux2Attention(nn.Module): + """ + Attention module for FLUX.2-klein double-stream blocks. + + Separate Q/K/V projections for image and text streams, + concatenated for joint attention, then split back. + + Uses RMSNorm on Q/K (not LayerNorm as in FLUX.1). + No bias on Q/K/V linears. 
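+
+    Shape sketch at 1024x1024 with 512 text tokens (sequence lengths follow
+    input_generator below; heads_per_rank = padded_heads / tp_degree):
+        image Q/K/V: [B, 4096, dim] -> [B, heads_per_rank, 4096, head_dim]
+        text Q/K/V:  [B, 512, dim]  -> [B, heads_per_rank, 512, head_dim]
+        joint attention over the text-first concatenation (seq = 4608)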
+ """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + added_kv_proj_dim: Optional[int] = None, + bias: bool = False, + eps: float = 1e-6, + reduce_dtype=torch.bfloat16, + ): + super().__init__() + self.heads = num_attention_heads + self.head_dim = attention_head_dim + + tp_degree = get_tensor_model_parallel_size() + # Pad heads to be divisible by TP degree + padded_heads = math.ceil(num_attention_heads / tp_degree) * tp_degree + self.padded_inner_dim = padded_heads * attention_head_dim + self.heads_per_rank = padded_heads // tp_degree + + # Image stream Q/K/V + self.to_q = ColumnParallelLinear( + dim, + self.padded_inner_dim, + bias=bias, + gather_output=False, + reduce_dtype=reduce_dtype, + ) + self.to_k = ColumnParallelLinear( + dim, + self.padded_inner_dim, + bias=bias, + gather_output=False, + reduce_dtype=reduce_dtype, + ) + self.to_v = ColumnParallelLinear( + dim, + self.padded_inner_dim, + bias=bias, + gather_output=False, + reduce_dtype=reduce_dtype, + ) + + # QK norm + self.norm_q = CustomRMSNorm(attention_head_dim, eps=eps) + self.norm_k = CustomRMSNorm(attention_head_dim, eps=eps) + + # Text stream Q/K/V (for double-stream blocks) + self.added_kv_proj_dim = added_kv_proj_dim + if added_kv_proj_dim is not None: + self.add_q_proj = ColumnParallelLinear( + added_kv_proj_dim, + self.padded_inner_dim, + bias=bias, + gather_output=False, + reduce_dtype=reduce_dtype, + ) + self.add_k_proj = ColumnParallelLinear( + added_kv_proj_dim, + self.padded_inner_dim, + bias=bias, + gather_output=False, + reduce_dtype=reduce_dtype, + ) + self.add_v_proj = ColumnParallelLinear( + added_kv_proj_dim, + self.padded_inner_dim, + bias=bias, + gather_output=False, + reduce_dtype=reduce_dtype, + ) + self.norm_added_q = CustomRMSNorm(attention_head_dim, eps=eps) + self.norm_added_k = CustomRMSNorm(attention_head_dim, eps=eps) + + # Output projections + # HF uses ModuleList for to_out (key: attn.to_out.0.weight) + # but direct Linear for to_add_out (key: attn.to_add_out.weight) + self.to_out = nn.ModuleList( + [ + RowParallelLinear( + self.padded_inner_dim, + dim, + bias=False, + input_is_parallel=True, + reduce_dtype=reduce_dtype, + ), + ] + ) + self.to_add_out = RowParallelLinear( + self.padded_inner_dim, + dim, + bias=False, + input_is_parallel=True, + reduce_dtype=reduce_dtype, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + batch_size = hidden_states.shape[0] + head_dim = self.head_dim + + # Image Q/K/V + query = self.to_q(hidden_states) + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + + query = query.view(batch_size, -1, self.heads_per_rank, head_dim).transpose( + 1, 2 + ) + key = key.view(batch_size, -1, self.heads_per_rank, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, self.heads_per_rank, head_dim).transpose( + 1, 2 + ) + + query = self.norm_q(query) + key = self.norm_k(key) + + if encoder_hidden_states is not None and self.added_kv_proj_dim is not None: + # Text Q/K/V + add_query = self.add_q_proj(encoder_hidden_states) + add_key = self.add_k_proj(encoder_hidden_states) + add_value = self.add_v_proj(encoder_hidden_states) + + add_query = add_query.view( + batch_size, -1, self.heads_per_rank, head_dim + ).transpose(1, 2) + add_key = add_key.view( + batch_size, -1, self.heads_per_rank, head_dim + ).transpose(1, 2) + add_value = add_value.view( + 
batch_size, -1, self.heads_per_rank, head_dim + ).transpose(1, 2) + + add_query = self.norm_added_q(add_query) + add_key = self.norm_added_k(add_key) + + # Concatenate text + image for joint attention BEFORE RoPE + # image_rotary_emb has positions for [text_seq + img_seq] + query = torch.cat([add_query, query], dim=2) + key = torch.cat([add_key, key], dim=2) + value = torch.cat([add_value, value], dim=2) + + # Apply RoPE to the full concatenated sequence + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + # Flash attention + if _HARDWARE == hardware.TRN1: + hidden_states = F.scaled_dot_product_attention( + query, + key, + value, + dropout_p=0.0, + is_causal=False, + ) + else: + hidden_states = attention_wrapper_sharded(query, key, value) + + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, self.heads_per_rank * head_dim + ) + hidden_states = hidden_states.to(query.dtype) + + # Split back to text and image + txt_len = encoder_hidden_states.shape[1] + encoder_attn_out = hidden_states[:, :txt_len] + hidden_attn_out = hidden_states[:, txt_len:] + + hidden_attn_out = self.to_out[0](hidden_attn_out) + encoder_attn_out = self.to_add_out(encoder_attn_out) + + return hidden_attn_out, encoder_attn_out + else: + # Single-stream: no text stream (used when called from parallel attn) + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + if _HARDWARE == hardware.TRN1: + hidden_states = F.scaled_dot_product_attention( + query, + key, + value, + dropout_p=0.0, + is_causal=False, + ) + else: + hidden_states = attention_wrapper_sharded(query, key, value) + + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, self.heads_per_rank * head_dim + ) + return hidden_states.to(query.dtype), None + + +# ============================================================ +# FLUX.2-klein Double-Stream Transformer Block +# ============================================================ + + +class NeuronFlux2TransformerBlock(nn.Module): + """ + Double-stream transformer block for FLUX.2-klein. + + Unlike FLUX.1's NeuronFluxTransformerBlock which has internal AdaLayerNormZero, + this block receives pre-computed modulation parameters from the parent model. 
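+
+    Modulation layout (see NeuronFlux2Modulation.split): each temb_mod_*
+    tensor is [B, 6*dim], six contiguous dim-wide chunks
+        [shift_attn | scale_attn | gate_attn | shift_ff | scale_ff | gate_ff]
+    where the first (shift, scale, gate) set modulates attention and the
+    second the feed-forward.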
+ """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 3.0, + eps: float = 1e-6, + reduce_dtype=torch.bfloat16, + ): + super().__init__() + + # Layer norms (no affine parameters -- modulated externally via shift/scale) + self.norm1 = LayerNorm(dim, elementwise_affine=False, eps=eps) + self.norm1_context = LayerNorm(dim, elementwise_affine=False, eps=eps) + + # Joint attention + self.attn = NeuronFlux2Attention( + dim=dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + added_kv_proj_dim=dim, + bias=False, + eps=eps, + reduce_dtype=reduce_dtype, + ) + + # Feed-forward (SwiGLU, not GELU) + self.norm2 = LayerNorm(dim, elementwise_affine=False, eps=eps) + self.ff = NeuronFlux2FeedForward( + dim=dim, mlp_ratio=mlp_ratio, reduce_dtype=reduce_dtype + ) + + self.norm2_context = LayerNorm(dim, elementwise_affine=False, eps=eps) + self.ff_context = NeuronFlux2FeedForward( + dim=dim, mlp_ratio=mlp_ratio, reduce_dtype=reduce_dtype + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb_mod_img: torch.Tensor, + temb_mod_txt: torch.Tensor, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + hidden_states: [B, img_seq, dim] image features + encoder_hidden_states: [B, txt_seq, dim] text features + temb_mod_img: [B, 6*dim] pre-computed image modulation (2 sets of shift/scale/gate) + temb_mod_txt: [B, 6*dim] pre-computed text modulation + image_rotary_emb: RoPE frequencies + """ + # Split modulation: 2 sets of (shift, scale, gate) for img and txt + img_mods = NeuronFlux2Modulation.split(temb_mod_img, 2) + txt_mods = NeuronFlux2Modulation.split(temb_mod_txt, 2) + + img_shift_attn, img_scale_attn, img_gate_attn = img_mods[0] + img_shift_ff, img_scale_ff, img_gate_ff = img_mods[1] + txt_shift_attn, txt_scale_attn, txt_gate_attn = txt_mods[0] + txt_shift_ff, txt_scale_ff, txt_gate_ff = txt_mods[1] + + # Attention sub-block: image stream + norm_hidden = self.norm1(hidden_states) + norm_hidden = ( + norm_hidden * (1 + img_scale_attn[:, None]) + img_shift_attn[:, None] + ) + + # Attention sub-block: text stream + norm_encoder = self.norm1_context(encoder_hidden_states) + norm_encoder = ( + norm_encoder * (1 + txt_scale_attn[:, None]) + txt_shift_attn[:, None] + ) + + # Joint attention + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden, + encoder_hidden_states=norm_encoder, + image_rotary_emb=image_rotary_emb, + ) + + # Residual + gate for image + hidden_states = hidden_states + img_gate_attn[:, None] * attn_output + + # FF sub-block: image stream + norm_hidden = self.norm2(hidden_states) + norm_hidden = norm_hidden * (1 + img_scale_ff[:, None]) + img_shift_ff[:, None] + ff_output = self.ff(norm_hidden) + hidden_states = hidden_states + img_gate_ff[:, None] * ff_output + + # Residual + gate for text + encoder_hidden_states = ( + encoder_hidden_states + txt_gate_attn[:, None] * context_attn_output + ) + + # FF sub-block: text stream + norm_encoder = self.norm2_context(encoder_hidden_states) + norm_encoder = ( + norm_encoder * (1 + txt_scale_ff[:, None]) + txt_shift_ff[:, None] + ) + context_ff_output = self.ff_context(norm_encoder) + encoder_hidden_states = ( + encoder_hidden_states + txt_gate_ff[:, None] * context_ff_output + ) + + return encoder_hidden_states, hidden_states + + +# ============================================================ +# FLUX.2-klein Single-Stream Transformer 
Block +# ============================================================ + + +class NeuronFlux2SingleTransformerBlock(nn.Module): + """ + Single-stream transformer block for FLUX.2-klein with fused attention+MLP. + + The key architectural feature is Flux2ParallelSelfAttention: a single large + linear projects to Q, K, V, and MLP input simultaneously. The MLP uses SwiGLU. + Attention output and MLP output are concatenated and projected together. + + For NxDI with tensor parallelism, we split this fused projection into separate + TP-sharded Q/K/V and MLP projections to stay within compiler instruction limits. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 3.0, + eps: float = 1e-6, + reduce_dtype=torch.bfloat16, + ): + super().__init__() + + self.norm = LayerNorm(dim, elementwise_affine=False, eps=eps) + self.mlp_hidden_dim = int(dim * mlp_ratio) + + tp_degree = get_tensor_model_parallel_size() + padded_heads = math.ceil(num_attention_heads / tp_degree) * tp_degree + self.padded_inner_dim = padded_heads * attention_head_dim + self.heads_per_rank = padded_heads // tp_degree + self.head_dim = attention_head_dim + + # Separate Q/K/V projections (split from the fused to_qkv_mlp_proj) + self.to_q = ColumnParallelLinear( + dim, + self.padded_inner_dim, + bias=False, + gather_output=False, + reduce_dtype=reduce_dtype, + ) + self.to_k = ColumnParallelLinear( + dim, + self.padded_inner_dim, + bias=False, + gather_output=False, + reduce_dtype=reduce_dtype, + ) + self.to_v = ColumnParallelLinear( + dim, + self.padded_inner_dim, + bias=False, + gather_output=False, + reduce_dtype=reduce_dtype, + ) + + # QK norm + self.norm_q = CustomRMSNorm(attention_head_dim, eps=eps) + self.norm_k = CustomRMSNorm(attention_head_dim, eps=eps) + + # MLP input projection: split the fused to_qkv_mlp_proj's MLP part + # into gate and value for correct TP sharding with SwiGLU + self.proj_mlp_gate = ColumnParallelLinear( + dim, + self.mlp_hidden_dim, + bias=False, + gather_output=False, + reduce_dtype=reduce_dtype, + ) + self.proj_mlp_value = ColumnParallelLinear( + dim, + self.mlp_hidden_dim, + bias=False, + gather_output=False, + reduce_dtype=reduce_dtype, + ) + + # Output: attention + MLP projected separately then summed + # (Same pattern as FLUX.1 NxDI: split proj_out into attn + mlp parts) + self.proj_out_attn = RowParallelLinear( + self.padded_inner_dim, + dim, + bias=False, + input_is_parallel=True, + reduce_dtype=reduce_dtype, + reduce_output=False, + ) + self.proj_out_mlp = RowParallelLinear( + self.mlp_hidden_dim, + dim, + bias=False, + input_is_parallel=True, + reduce_dtype=reduce_dtype, + reduce_output=False, + ) + + def forward( + self, + hidden_states: torch.Tensor, + temb_mod: torch.Tensor, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Args: + hidden_states: [B, seq_len, dim] (text + image concatenated) + temb_mod: [B, 3*dim] pre-computed single-stream modulation + image_rotary_emb: RoPE frequencies for the full sequence + """ + residual = hidden_states + batch_size = hidden_states.shape[0] + + # Modulation: 1 set of (shift, scale, gate) + mods = NeuronFlux2Modulation.split(temb_mod, 1) + shift, scale, gate = mods[0] + + # Norm + modulate + norm_hidden = self.norm(hidden_states) + norm_hidden = norm_hidden * (1 + scale[:, None]) + shift[:, None] + + # Parallel: Q/K/V projections + MLP input projection + query = self.to_q(norm_hidden) + key = self.to_k(norm_hidden) + value = self.to_v(norm_hidden) + + query = 
query.view( + batch_size, -1, self.heads_per_rank, self.head_dim + ).transpose(1, 2) + key = key.view(batch_size, -1, self.heads_per_rank, self.head_dim).transpose( + 1, 2 + ) + value = value.view( + batch_size, -1, self.heads_per_rank, self.head_dim + ).transpose(1, 2) + + query = self.norm_q(query) + key = self.norm_k(key) + + # MLP branch (parallel with attention) - SwiGLU + mlp_gate = F.silu(self.proj_mlp_gate(norm_hidden)) + mlp_value = self.proj_mlp_value(norm_hidden) + mlp_hidden = mlp_gate * mlp_value + + # RoPE + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + # Attention + if _HARDWARE == hardware.TRN1: + attn_output = F.scaled_dot_product_attention( + query, + key, + value, + dropout_p=0.0, + is_causal=False, + ) + else: + attn_output = attention_wrapper_sharded(query, key, value) + + attn_output = attn_output.transpose(1, 2).reshape( + batch_size, -1, self.heads_per_rank * self.head_dim + ) + attn_output = attn_output.to(query.dtype) + + # Fused output projection: attention + MLP summed with single all-reduce + out_attn = self.proj_out_attn(attn_output) + out_mlp = self.proj_out_mlp(mlp_hidden) + proj_out = reduce_from_tensor_model_parallel_region( + out_attn + out_mlp, + process_group=self.proj_out_attn.tensor_parallel_group, + ) + + hidden_states = gate[:, None] * proj_out + hidden_states = residual + hidden_states + + return hidden_states + + +# ============================================================ +# FLUX.2-klein Output Normalization +# ============================================================ + + +class NeuronFlux2AdaLayerNormContinuous(nn.Module): + """ + Continuous adaptive layer norm for the output layer. + Produces scale/shift from the conditioning embedding (temb). + """ + + def __init__( + self, + embedding_dim: int, + conditioning_dim: int, + eps: float = 1e-6, + reduce_dtype=torch.bfloat16, + ): + super().__init__() + self.norm = LayerNorm(embedding_dim, elementwise_affine=False, eps=eps) + self.linear = ColumnParallelLinear( + conditioning_dim, + embedding_dim * 2, + bias=False, + gather_output=True, + reduce_dtype=reduce_dtype, + ) + self.act = nn.SiLU() + + def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor: + emb = self.linear(self.act(conditioning)) + scale, shift = emb.chunk(2, dim=-1) + x = self.norm(x) * (1 + scale[:, None]) + shift[:, None] + return x + + +# ============================================================ +# FLUX.2-klein Transformer Model (Top-level) +# ============================================================ + + +class NeuronFlux2KleinTransformer(nn.Module): + """ + FLUX.2-klein-base-9B transformer backbone for NxD Inference. + + This model contains only the transformer blocks + output layers. + Input embeddings (x_embedder, context_embedder), timestep embedding, + modulation computation, and RoPE are all included in this compiled graph. 
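+
+    I/O sketch at 1024x1024 (shapes mirror input_generator in the wrapper):
+        hidden_states         [1, 4096, 128]   packed latents
+        encoder_hidden_states [1, 512, 12288]  Qwen3 text embeddings
+        timestep              [1]              in [0, 1]; scaled x1000 inside
+        image_rotary_emb      [4608, 128, 2]   precomputed 4D RoPE
+        output                [1, 4096, patch_size^2 * out_channels]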
+ + Architecture: + - 8 double-stream blocks (Flux2TransformerBlock) + - 24 single-stream blocks (Flux2SingleTransformerBlock) + - Pre-computed modulation (shared across all blocks of each type) + - 4D RoPE (T, H, W, L) + - No guidance embedding (Klein base) + """ + + def __init__(self, config): + super().__init__() + self.config = config + + self.inner_dim = config.num_attention_heads * config.attention_head_dim + reduce_dtype = config.neuron_config.torch_dtype + + # Input projections (no bias in HF model) + self.x_embedder = ColumnParallelLinear( + config.in_channels, + self.inner_dim, + bias=False, + gather_output=True, + reduce_dtype=reduce_dtype, + ) + self.context_embedder = ColumnParallelLinear( + config.joint_attention_dim, + self.inner_dim, + bias=False, + gather_output=True, + reduce_dtype=reduce_dtype, + ) + + # Timestep embedding (no pooled text, no guidance for Klein base) + self.time_guidance_embed = NeuronFlux2TimestepEmbedding( + embedding_dim=self.inner_dim, + reduce_dtype=reduce_dtype, + ) + + # Pre-computed modulation + self.double_stream_modulation_img = NeuronFlux2Modulation( + self.inner_dim, + mod_param_sets=2, + reduce_dtype=reduce_dtype, + ) + self.double_stream_modulation_txt = NeuronFlux2Modulation( + self.inner_dim, + mod_param_sets=2, + reduce_dtype=reduce_dtype, + ) + self.single_stream_modulation = NeuronFlux2Modulation( + self.inner_dim, + mod_param_sets=1, + reduce_dtype=reduce_dtype, + ) + + # Double-stream blocks (8 for Klein) + self.transformer_blocks = nn.ModuleList( + [ + NeuronFlux2TransformerBlock( + dim=self.inner_dim, + num_attention_heads=config.num_attention_heads, + attention_head_dim=config.attention_head_dim, + mlp_ratio=getattr(config, "mlp_ratio", 3.0), + reduce_dtype=reduce_dtype, + ) + for _ in range(config.num_layers) + ] + ) + + # Single-stream blocks (24 for Klein) + self.single_transformer_blocks = nn.ModuleList( + [ + NeuronFlux2SingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=config.num_attention_heads, + attention_head_dim=config.attention_head_dim, + mlp_ratio=getattr(config, "mlp_ratio", 3.0), + reduce_dtype=reduce_dtype, + ) + for _ in range(config.num_single_layers) + ] + ) + + # Output normalization and projection + self.norm_out = NeuronFlux2AdaLayerNormContinuous( + self.inner_dim, + self.inner_dim, + eps=1e-6, + reduce_dtype=reduce_dtype, + ) + self.proj_out = ColumnParallelLinear( + self.inner_dim, + config.patch_size * config.patch_size * config.out_channels, + bias=False, + gather_output=True, + reduce_dtype=reduce_dtype, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: torch.Tensor, + image_rotary_emb: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + hidden_states: [B, img_seq, in_channels] packed latents + encoder_hidden_states: [B, txt_seq, joint_attention_dim] text embeddings + timestep: [B] diffusion timestep (0-1 range, will be scaled by 1000) + image_rotary_emb: [total_seq, head_dim, 2] precomputed RoPE + Returns: + output: [B, img_seq, patch_size^2 * out_channels] denoised prediction + """ + # Scale timestep + timestep = timestep.to(self.config.neuron_config.torch_dtype) * 1000 + + # Timestep embedding + temb = self.time_guidance_embed(timestep, guidance=None) + + # Pre-compute all modulation parameters + double_mod_img = self.double_stream_modulation_img(temb) + double_mod_txt = self.double_stream_modulation_txt(temb) + single_mod = self.single_stream_modulation(temb) + + # Input projections + hidden_states = 
self.x_embedder(hidden_states) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + # Double-stream blocks + hidden_states, encoder_hidden_states = ModuleMarkerStartWrapper()( + hidden_states, encoder_hidden_states + ) + for block in self.transformer_blocks: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb_mod_img=double_mod_img, + temb_mod_txt=double_mod_txt, + image_rotary_emb=image_rotary_emb, + ) + hidden_states, encoder_hidden_states = ModuleMarkerEndWrapper()( + hidden_states, encoder_hidden_states + ) + + # Concatenate text + image for single-stream processing + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + # Single-stream blocks + for idx, block in enumerate(self.single_transformer_blocks): + if idx % 2 == 0: + hidden_states = ModuleMarkerStartWrapper()(hidden_states) + hidden_states = block( + hidden_states=hidden_states, + temb_mod=single_mod, + image_rotary_emb=image_rotary_emb, + ) + if idx % 2 == 1: + hidden_states = ModuleMarkerEndWrapper()(hidden_states) + + # Remove text tokens, keep only image tokens + hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] + + # Output normalization and projection + hidden_states = ModuleMarkerStartWrapper()(hidden_states) + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + output = ModuleMarkerEndWrapper()(output) + + return output + + +# ============================================================ +# NxDI Config and Application +# ============================================================ + + +class Flux2KleinBackboneInferenceConfig( + InferenceConfig if NEURON_AVAILABLE else object +): + """Configuration for the FLUX.2-klein backbone.""" + + def __init__(self, *args, **kwargs): + if NEURON_AVAILABLE: + super().__init__(*args, **kwargs) + + def get_required_attributes(self) -> List[str]: + return [ + "attention_head_dim", + "in_channels", + "joint_attention_dim", + "num_attention_heads", + "num_layers", + "num_single_layers", + "patch_size", + "out_channels", + "height", + "width", + ] + + +class ModelWrapperFlux2KleinBackbone(ModelWrapper if NEURON_AVAILABLE else object): + """Model wrapper for FLUX.2-klein backbone compilation.""" + + def __init__( + self, + config, + model_cls, + tag="", + compiler_args=None, + priority_model_idx=None, + model_init_kwargs={}, + ): + if NEURON_AVAILABLE: + super().__init__( + config, + model_cls, + tag, + compiler_args, + priority_model_idx, + model_init_kwargs, + ) + self.pos_embed = FluxPosEmbed(theta=2000, axes_dim=(32, 32, 32, 32)) + self.image_rotary_emb = None + self.cache_image_rotary_emb = False + + def input_generator(self) -> List[Tuple[torch.Tensor]]: + """Generate example inputs for tracing.""" + in_channels = self.config.in_channels # 128 + joint_attention_dim = self.config.joint_attention_dim # 12288 + attention_head_dim = self.config.attention_head_dim # 128 + # For FLUX.2-klein: latent is (B, H*W, 128) after pack + # At 1024x1024: H=W=64 (1024/16), so img_seq = 4096 + vae_scale_factor = getattr(self.config, "vae_scale_factor", 16) + num_patches = self.config.height * self.config.width // (vae_scale_factor**2) + txt_seq_len = 512 + total_seq = num_patches + txt_seq_len + + dtype = self.config.neuron_config.torch_dtype + + model_inputs = ( + # hidden_states: [1, img_seq, in_channels] + torch.randn([1, num_patches, in_channels], dtype=dtype), + # encoder_hidden_states: [1, txt_seq, 
joint_attention_dim] + torch.randn([1, txt_seq_len, joint_attention_dim], dtype=dtype), + # timestep: [1] + torch.randn([1], dtype=dtype), + # image_rotary_emb: [total_seq, head_dim, 2] + torch.randn([total_seq, attention_head_dim, 2], dtype=dtype), + ) + + return [model_inputs] + + def get_model_instance(self): + def _create_model(): + model = self.model_cls(self.config) + model = model.to(dtype=self.config.neuron_config.torch_dtype) + model.eval() + return model + + return BaseModelInstance(module_cls=_create_model, input_output_aliases={}) + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + timestep=None, + img_ids=None, + txt_ids=None, + guidance=None, + return_dict=False, + **kwargs, + ): + """Override ModelWrapper.forward() to match Diffusers Flux2Transformer2DModel API.""" + if self.model is None: + raise RuntimeError("Forward called before load. Call load() first.") + + if timestep is not None: + timestep = timestep.to(self.config.neuron_config.torch_dtype) + + # Compute image_rotary_emb + image_rotary_emb = self.image_rotary_emb + if image_rotary_emb is None: + if txt_ids.ndim == 3: + txt_ids = txt_ids[0] + if img_ids.ndim == 3: + img_ids = img_ids[0] + ids = torch.cat((txt_ids, img_ids), dim=0) + image_rotary_emb = torch.stack(self.pos_embed(ids), dim=2).to( + dtype=self.config.neuron_config.torch_dtype + ) + + if self.cache_image_rotary_emb: + self.image_rotary_emb = image_rotary_emb + + output = self._forward( + hidden_states, + encoder_hidden_states, + timestep, + image_rotary_emb, + ) + return output + + +class NeuronFlux2KleinBackboneApplication( + NeuronApplicationBase if NEURON_AVAILABLE else object +): + """NxDI Application for the FLUX.2-klein transformer backbone.""" + + _model_cls = NeuronFlux2KleinTransformer + + def __init__(self, *args, **kwargs): + if NEURON_AVAILABLE: + super().__init__(*args, **kwargs) + self.model_wrapper = ModelWrapperFlux2KleinBackbone + + self.model = self.model_wrapper( + config=self.config, + model_cls=self._model_cls, + tag=self._model_cls.__name__, + compiler_args=self.get_compiler_args(), + priority_model_idx=0, + ) + self.models.append(self.model) + self.dtype = self.config.neuron_config.torch_dtype + + def get_model_wrapper_cls(self): + return ModelWrapperFlux2KleinBackbone + + def forward(self, *model_inputs, **kwargs): + return self.models[0](*model_inputs, **kwargs) + + @contextmanager + def cache_context(self, name: str): + """Cache context for matching Diffusers pipeline. Currently a no-op.""" + yield + + @contextmanager + def image_rotary_emb_cache_context(self): + self.model.cache_image_rotary_emb = True + self.model.image_rotary_emb = None + yield + self.model.cache_image_rotary_emb = False + self.model.image_rotary_emb = None + + def get_compiler_args(self): + compiler_args = "--model-type=transformer -O1" + compiler_args += " --tensorizer-options='--enable-ccop-compute-overlap'" + compiler_args += ( + " --auto-cast=none --internal-hlo2tensorizer-options='--verify-hlo=true'" + ) + + os.environ["LOCAL_WORLD_SIZE"] = str(self.config.neuron_config.world_size) + if _HARDWARE == hardware.TRN2: + os.environ["NEURON_RT_VIRTUAL_CORE_SIZE"] = "2" + return compiler_args + + @staticmethod + def update_state_dict_for_tied_weights(state_dict): + pass + + @staticmethod + def convert_hf_to_neuron_state_dict(state_dict: dict, config) -> dict: + """ + Convert HuggingFace Flux2Transformer2DModel state dict to NxDI format. + + Key transformations: + 1. 
Single blocks: HF fused `attn.to_qkv_mlp_proj.weight` is split into + to_q, to_k, to_v, proj_mlp_gate, proj_mlp_value. + 2. Single blocks: HF fused `attn.to_out.weight` is split into + proj_out_attn and proj_out_mlp. + 3. Single blocks: `attn.norm_q/k` renamed to `norm_q/k`. + 4. Double blocks: HF `ff.linear_in.weight` is split into + `ff.linear_in_gate.weight` and `ff.linear_in_value.weight`. + 5. Same for `ff_context.linear_in.weight`. + + All SwiGLU gate/value splits are needed so that TP sharding + partitions gate and value weights together, preserving correctness. + """ + inner_dim = config.num_attention_heads * config.attention_head_dim # 4096 + mlp_hidden_dim = int(inner_dim * getattr(config, "mlp_ratio", 3.0)) # 12288 + + new_state_dict = {} + + for key, value in state_dict.items(): + # ---- Single-stream block weight remapping ---- + if "single_transformer_blocks." in key: + block_idx = key.split(".")[1] + prefix = f"single_transformer_blocks.{block_idx}" + + if ".attn.to_qkv_mlp_proj.weight" in key: + # HF fused weight shape: [Q + K + V + MLP_gate + MLP_value, dim] + # = [4096 + 4096 + 4096 + 12288 + 12288, 4096] = [36864, 4096] + w = value + q_w = w[:inner_dim, :] + k_w = w[inner_dim : 2 * inner_dim, :] + v_w = w[2 * inner_dim : 3 * inner_dim, :] + mlp_gate_w = w[3 * inner_dim : 3 * inner_dim + mlp_hidden_dim, :] + mlp_value_w = w[3 * inner_dim + mlp_hidden_dim :, :] + + new_state_dict[f"{prefix}.to_q.weight"] = q_w.clone() + new_state_dict[f"{prefix}.to_k.weight"] = k_w.clone() + new_state_dict[f"{prefix}.to_v.weight"] = v_w.clone() + new_state_dict[f"{prefix}.proj_mlp_gate.weight"] = ( + mlp_gate_w.clone() + ) + new_state_dict[f"{prefix}.proj_mlp_value.weight"] = ( + mlp_value_w.clone() + ) + continue + + if ".attn.to_out.weight" in key: + # HF fused output: [dim, attn_dim + mlp_dim] + # = [4096, 4096 + 12288] = [4096, 16384] + w = value + attn_w = w[:, :inner_dim] + mlp_w = w[:, inner_dim:] + + new_state_dict[f"{prefix}.proj_out_attn.weight"] = attn_w.clone() + new_state_dict[f"{prefix}.proj_out_mlp.weight"] = mlp_w.clone() + continue + + # Rename QK norm from attn.norm_q/k to top-level norm_q/k + if ".attn.norm_q.weight" in key: + new_state_dict[f"{prefix}.norm_q.weight"] = value.contiguous() + continue + if ".attn.norm_k.weight" in key: + new_state_dict[f"{prefix}.norm_k.weight"] = value.contiguous() + continue + + # All other single block keys pass through + + # ---- Double-stream block: split SwiGLU linear_in ---- + if "transformer_blocks." 
in key and "single_transformer_blocks." not in key:  # exclude single-stream keys, which also contain this substring and are handled above + # ff.linear_in.weight [24576, 4096] -> gate [12288, 4096] + value [12288, 4096] + if ".ff.linear_in.weight" in key: + block_prefix = key.rsplit(".ff.linear_in.weight", 1)[0] + w = value + gate_w = w[:mlp_hidden_dim, :] + value_w = w[mlp_hidden_dim:, :] + new_state_dict[f"{block_prefix}.ff.linear_in_gate.weight"] = ( + gate_w.clone() + ) + new_state_dict[f"{block_prefix}.ff.linear_in_value.weight"] = ( + value_w.clone() + ) + continue + + if ".ff_context.linear_in.weight" in key: + block_prefix = key.rsplit(".ff_context.linear_in.weight", 1)[0] + w = value + gate_w = w[:mlp_hidden_dim, :] + value_w = w[mlp_hidden_dim:, :] + new_state_dict[ + f"{block_prefix}.ff_context.linear_in_gate.weight" + ] = gate_w.clone() + new_state_dict[ + f"{block_prefix}.ff_context.linear_in_value.weight" + ] = value_w.clone() + continue + + # ---- All other keys pass through unchanged ---- + new_state_dict[key] = value.contiguous() + + return new_state_dict diff --git a/contrib/models/flux2-klein/test/__init__.py b/contrib/models/flux2-klein/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/flux2-klein/test/integration/__init__.py b/contrib/models/flux2-klein/test/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/flux2-klein/test/integration/test_model.py b/contrib/models/flux2-klein/test/integration/test_model.py new file mode 100644 index 00000000..384c50e8 --- /dev/null +++ b/contrib/models/flux2-klein/test/integration/test_model.py @@ -0,0 +1,273 @@ +""" +Integration tests for FLUX.2-klein-base-9B on Neuron. + +Tests: +1. test_smoke_pipeline_loads: Pipeline loads without errors +2. test_generation_produces_image: Generates a 1024x1024 image +3. test_accuracy_vs_cpu: SSIM > 0.7 between Neuron and CPU-generated reference +4. 
test_warm_generation_time: Warm generation < 120s + +Requirements: + - trn2.3xlarge with LNC=2 (4 logical cores) + - Neuron SDK 2.29+ + - diffusers >= 0.37.1 + - Model downloaded to /shared/flux2-klein/ or $FLUX2_MODEL_PATH + +Run: + pytest test_model.py -v + # Or standalone: + python test_model.py +""" + +import gc +import os +import sys +import time + +import numpy as np +import pytest +import torch + +# Add contrib to path +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +CONTRIB_ROOT = os.path.abspath(os.path.join(SCRIPT_DIR, "..", "..")) +sys.path.insert(0, os.path.join(CONTRIB_ROOT, "..")) + +MODEL_PATH = os.environ.get("FLUX2_MODEL_PATH", "/shared/flux2-klein/") +COMPILE_DIR = os.environ.get("FLUX2_COMPILE_DIR", "/tmp/flux2_klein_test/") +TP_DEGREE = int(os.environ.get("FLUX2_TP_DEGREE", "4")) +HEIGHT = 1024 +WIDTH = 1024 +NUM_STEPS = 50 +GUIDANCE_SCALE = 4.0 +PROMPT = "A cat holding a sign that says hello world" + + +def compute_ssim(img1: np.ndarray, img2: np.ndarray) -> float: + """Compute SSIM between two images (H, W, C) in [0, 255] uint8.""" + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + + mu1 = img1.mean() + mu2 = img2.mean() + sigma1_sq = ((img1 - mu1) ** 2).mean() + sigma2_sq = ((img2 - mu2) ** 2).mean() + sigma12 = ((img1 - mu1) * (img2 - mu2)).mean() + + C1 = (0.01 * 255) ** 2 + C2 = (0.03 * 255) ** 2 + + ssim = ((2 * mu1 * mu2 + C1) * (2 * sigma12 + C2)) / ( + (mu1**2 + mu2**2 + C1) * (sigma1_sq + sigma2_sq + C2) + ) + return float(ssim) + + +@pytest.fixture(scope="module") +def neuron_app(): + """Create, compile, and load the FLUX.2-klein NxDI application.""" + from flux2_klein.src.application import ( + NeuronFlux2KleinApplication, + create_flux2_klein_config, + ) + + backbone_config = create_flux2_klein_config( + model_path=MODEL_PATH, + backbone_tp_degree=TP_DEGREE, + dtype=torch.bfloat16, + height=HEIGHT, + width=WIDTH, + ) + + app = NeuronFlux2KleinApplication( + model_path=MODEL_PATH, + backbone_config=backbone_config, + height=HEIGHT, + width=WIDTH, + ) + + # Compile + app.compile(COMPILE_DIR) + # Load + app.load(COMPILE_DIR) + + # Warmup + app( + prompt=PROMPT, + height=HEIGHT, + width=WIDTH, + num_inference_steps=NUM_STEPS, + guidance_scale=GUIDANCE_SCALE, + ) + + yield app + + # Cleanup + del app + gc.collect() + + +def test_smoke_pipeline_loads(neuron_app): + """Pipeline loads without errors and has required attributes.""" + assert neuron_app is not None + assert neuron_app.pipe is not None + assert hasattr(neuron_app, "backbone_app") + assert neuron_app.pipe.transformer is not None + assert neuron_app.pipe.text_encoder is not None + assert neuron_app.pipe.vae is not None + + +def test_generation_produces_image(neuron_app): + """Generates an image at the expected resolution.""" + result = neuron_app( + prompt=PROMPT, + height=HEIGHT, + width=WIDTH, + num_inference_steps=NUM_STEPS, + guidance_scale=GUIDANCE_SCALE, + ) + + assert result is not None + assert hasattr(result, "images") + assert len(result.images) == 1 + + image = result.images[0] + assert image.size == (WIDTH, HEIGHT), ( + f"Expected ({WIDTH}, {HEIGHT}), got {image.size}" + ) + + # Save for inspection + os.makedirs(os.path.join(COMPILE_DIR, "test_outputs"), exist_ok=True) + image.save(os.path.join(COMPILE_DIR, "test_outputs", "test_generation.png")) + + +def test_accuracy_vs_cpu(neuron_app): + """ + Compare Neuron output against CPU reference using SSIM. 
+ + Since FLUX.2-klein uses classic CFG (not guidance distillation), + the Neuron output should closely match CPU output when using the + same random seed and parameters. + """ + from neuronx_distributed_inference.utils.random import set_random_seed + + # Generate on Neuron + set_random_seed(42) + neuron_result = neuron_app( + prompt=PROMPT, + height=HEIGHT, + width=WIDTH, + num_inference_steps=NUM_STEPS, + guidance_scale=GUIDANCE_SCALE, + ) + neuron_image = np.array(neuron_result.images[0]) + + # Generate CPU reference + from diffusers import Flux2KleinPipeline + + set_random_seed(42) + cpu_pipe = Flux2KleinPipeline.from_pretrained( + MODEL_PATH, + torch_dtype=torch.bfloat16, + ) + cpu_result = cpu_pipe( + prompt=PROMPT, + height=HEIGHT, + width=WIDTH, + num_inference_steps=NUM_STEPS, + guidance_scale=GUIDANCE_SCALE, + ) + cpu_image = np.array(cpu_result.images[0]) + + del cpu_pipe + gc.collect() + + # Compare + ssim = compute_ssim(neuron_image, cpu_image) + print(f"SSIM (Neuron vs CPU): {ssim:.4f}") + + # Save both for visual comparison + os.makedirs(os.path.join(COMPILE_DIR, "test_outputs"), exist_ok=True) + from PIL import Image + + Image.fromarray(neuron_image).save( + os.path.join(COMPILE_DIR, "test_outputs", "accuracy_neuron.png") + ) + Image.fromarray(cpu_image).save( + os.path.join(COMPILE_DIR, "test_outputs", "accuracy_cpu.png") + ) + + assert ssim > 0.7, f"SSIM {ssim:.4f} below threshold 0.7" + + + def test_warm_generation_time(neuron_app): + """Warm generation should complete in reasonable time.""" + t0 = time.time() + neuron_app( + prompt=PROMPT, + height=HEIGHT, + width=WIDTH, + num_inference_steps=NUM_STEPS, + guidance_scale=GUIDANCE_SCALE, + ) + elapsed = time.time() - t0 + print(f"Warm generation time: {elapsed:.2f}s") + + # FLUX.2-klein with 50 steps + CFG (2x forward) on TP=4 should be < 120s + assert elapsed < 120, f"Generation took {elapsed:.2f}s, expected < 120s" + + + # Standalone runner + if __name__ == "__main__": + from flux2_klein.src.application import ( + NeuronFlux2KleinApplication, + create_flux2_klein_config, + ) + + print("=" * 60) + print("FLUX.2-klein-base-9B Integration Tests") + print("=" * 60) + + backbone_config = create_flux2_klein_config( + model_path=MODEL_PATH, + backbone_tp_degree=TP_DEGREE, + dtype=torch.bfloat16, + height=HEIGHT, + width=WIDTH, + ) + + app = NeuronFlux2KleinApplication( + model_path=MODEL_PATH, + backbone_config=backbone_config, + height=HEIGHT, + width=WIDTH, + ) + + print("\n[1/6] Compiling...") + app.compile(COMPILE_DIR) + + print("\n[2/6] Loading...") + app.load(COMPILE_DIR) + + print("\n[3/6] Warmup...") + app( + prompt=PROMPT, + height=HEIGHT, + width=WIDTH, + num_inference_steps=NUM_STEPS, + guidance_scale=GUIDANCE_SCALE, + ) + + print("\n[4/6] test_smoke_pipeline_loads") + test_smoke_pipeline_loads(app) + print(" PASSED") + + print("\n[5/6] test_generation_produces_image") + test_generation_produces_image(app) + print(" PASSED") + + print("\n[6/6] test_warm_generation_time") + test_warm_generation_time(app) + print(" PASSED") + + print("\nAll tests passed!") diff --git a/contrib/models/flux2-klein/test/unit/__init__.py b/contrib/models/flux2-klein/test/unit/__init__.py new file mode 100644 index 00000000..e69de29b
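
The weight-split logic in `convert_hf_to_neuron_state_dict` is the most error-prone part of the port, so a quick numerical check is useful. The following standalone sketch is not part of the package: it uses plain `torch` on CPU with random weights, and takes the dimensions straight from the comments in the conversion code (`inner_dim=4096`, `mlp_hidden_dim=12288`). It verifies that slicing the fused single-stream `to_qkv_mlp_proj` and `to_out` weights in the documented order is a lossless re-layout:

```python
import torch

# Dimensions from convert_hf_to_neuron_state_dict above:
# inner_dim = num_attention_heads * attention_head_dim = 4096
# mlp_hidden_dim = inner_dim * mlp_ratio = 4096 * 3 = 12288
inner_dim = 4096
mlp_hidden_dim = 12288

# Fused single-stream projection: rows ordered [Q | K | V | MLP_gate | MLP_value]
fused = torch.randn(3 * inner_dim + 2 * mlp_hidden_dim, inner_dim)
assert fused.shape == (36864, 4096)

q_w = fused[:inner_dim]
k_w = fused[inner_dim : 2 * inner_dim]
v_w = fused[2 * inner_dim : 3 * inner_dim]
gate_w = fused[3 * inner_dim : 3 * inner_dim + mlp_hidden_dim]
value_w = fused[3 * inner_dim + mlp_hidden_dim :]

# Projecting with the fused weight and slicing the output columns must agree
# with the separate projections (the split is a pure re-layout of rows).
x = torch.randn(2, inner_dim)
h = x @ fused.T
assert torch.allclose(h[:, :inner_dim], x @ q_w.T, rtol=1e-4, atol=1e-3)
assert torch.allclose(h[:, 2 * inner_dim : 3 * inner_dim], x @ v_w.T, rtol=1e-4, atol=1e-3)
assert torch.allclose(
    h[:, 3 * inner_dim : 3 * inner_dim + mlp_hidden_dim], x @ gate_w.T, rtol=1e-4, atol=1e-3
)

# Fused single-stream output projection: columns ordered [attn | MLP]
to_out = torch.randn(inner_dim, inner_dim + mlp_hidden_dim)
proj_out_attn, proj_out_mlp = to_out[:, :inner_dim], to_out[:, inner_dim:]
assert proj_out_attn.shape == (4096, 4096)
assert proj_out_mlp.shape == (4096, 12288)
print("single-stream splits OK")
```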
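The same style of check applies to the double-stream SwiGLU split: computing `SiLU(gate) * value` from the two split projections must match the fused `ff.linear_in` followed by an output-column split. A minimal sketch, again with random weights and the same documented dimensions:

```python
import torch
import torch.nn.functional as F

inner_dim, mlp_hidden_dim = 4096, 12288

# Double-stream ff.linear_in: rows ordered [gate | value], as assumed by the
# conversion code above.
linear_in = torch.randn(2 * mlp_hidden_dim, inner_dim)
gate_w, value_w = linear_in[:mlp_hidden_dim], linear_in[mlp_hidden_dim:]

x = torch.randn(2, inner_dim)

# Fused path: one projection, SwiGLU applied by splitting the output columns.
h = x @ linear_in.T
ref = F.silu(h[:, :mlp_hidden_dim]) * h[:, mlp_hidden_dim:]

# Split path: what the converted NxDI module computes, with gate and value as
# separate projections.
out = F.silu(x @ gate_w.T) * (x @ value_w.T)

assert torch.allclose(ref, out, rtol=1e-4, atol=1e-3)
print("SwiGLU gate/value split OK")
```

If these asserts hold, sharding the gate and value projections along their output dimension keeps each tensor-parallel rank's SwiGLU self-consistent, which is exactly why the conversion splits gate and value before sharding rather than sharding the fused weight directly.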