Add FLUX.2-klein-base-9B contrib model #146

Open
jimburtoft wants to merge 1 commit into aws-neuron:main from jimburtoft:contrib/flux2-klein

Conversation

@jimburtoft
Contributor

Note: The below template includes items meant for model contributions only. For other contributions such as bug fixes, features, etc., only fill out the relevant portions of the form.

Description

Add FLUX.2-klein-base-9B (9.08B parameter diffusion transformer) as a contrib model with NxD Inference tensor parallelism on trn2.3xlarge.

FLUX.2-klein differs from FLUX.1 in several key ways: SwiGLU activation (vs GELU), pre-computed modulation, fused QKV+MLP projections in single-stream blocks, Qwen3-8B text encoder, and 32 latent channels with 4D RoPE. The implementation splits all fused SwiGLU projections into separate ColumnParallelLinear layers for correct TP sharding, and decomposes the massive fused to_qkv_mlp_proj into independent Q/K/V and MLP projections to stay within compiler instruction limits.
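The reason the fused SwiGLU projections must be split before tensor-parallel sharding can be shown with a small numerical sketch (a toy example in NumPy with hypothetical shapes, not the contribution's actual code). A Hugging Face-style fused weight stacks the gate rows on top of the value rows; splitting that weight into two independent projections is numerically identical, but a contiguous TP shard of the fused matrix would give one rank only gate rows and another only value rows, so neither could compute `silu(gate) * value` locally.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8                        # toy hidden size, for illustration only
x = rng.standard_normal(d)

# Fused SwiGLU projection: weight of shape [2*d, d];
# rows 0..d-1 produce the gate, rows d..2d-1 produce the value.
w_fused = rng.standard_normal((2 * d, d))

def silu(z):
    return z / (1.0 + np.exp(-z))

gate, value = np.split(w_fused @ x, 2)
y_fused = silu(gate) * value

# Splitting the fused weight into two separate projections (each of
# which can then be wrapped in its own column-parallel layer) gives
# the same result:
w_gate, w_value = np.split(w_fused, 2, axis=0)
y_split = silu(w_gate @ x) * (w_value @ x)

assert np.allclose(y_fused, y_split)

# By contrast, a contiguous TP=2 shard of the *fused* weight puts all
# gate rows on rank 0 and all value rows on rank 1, breaking the
# gate/value pairing -- hence the split in the implementation.
```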

Model Information

Model Name: FLUX.2-klein-base-9B

Model Architecture: Diffusion transformer (DiT) with 8 double-stream MMDiT blocks + 24 single-stream DiT blocks, 4D RoPE, SwiGLU, pre-computed modulation

Purpose: Text-to-image generation

Checklist

Required Components

  • Accuracy Test (test/integration/test_model.py)

    • Integration test validates backbone cosine similarity vs CPU reference (0.9987 achieved)
    • Test compiles and runs the model on Neuron
  • README.md with the following sections:

    • Usage Example: Clear code example showing how to use the model
    • Compatibility Matrix: Table showing tested Neuron SDK versions and instance types
    • Example Checkpoints: Links to compatible model checkpoints
    • Testing Instructions: Command to run the test suite
  • Source Code (src/)

    • modeling_flux2_klein.py: Full NxDI model implementation (~1350 lines)
    • application.py: Config factory, NeuronTransformerWrapper, NeuronFlux2KleinApplication
    • generate_flux2_klein.py: CLI entry point

Optional Components

  • Unit Tests (test/unit/ directory present, tests pending)

Folder Structure

/contrib/models/flux2-klein/
  README.md
  /samples
    hello_world_cat.png
  /src
    __init__.py
    modeling_flux2_klein.py
    application.py
    generate_flux2_klein.py
  /test
    __init__.py
    /unit
      __init__.py
    /integration
      __init__.py
      test_model.py

Testing

How did you test this change?

Tested on trn2.3xlarge (LNC=2, TP=4) with Neuron SDK 2.29 (DLAMI 20260410). Compiled backbone, ran direct backbone comparison against HF CPU reference with identical inputs, then ran full end-to-end pipeline with 5 warm generations for benchmarking.

Test Results:

Backbone accuracy:

  • Cosine similarity (Neuron vs CPU): 0.9987
  • Max absolute difference: 0.15
  • Mean absolute difference: 0.022
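The three accuracy metrics above can be computed from a pair of backbone outputs with a comparison along these lines (a sketch with hypothetical placeholder tensors standing in for the Neuron and CPU-reference latents):

```python
import numpy as np

def cosine_similarity(a, b):
    """Cosine similarity between two tensors, flattened."""
    a = a.ravel().astype(np.float64)
    b = b.ravel().astype(np.float64)
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

# Placeholder arrays; in the integration test these would be the
# Neuron backbone output and the HF CPU reference on identical inputs.
cpu_out = np.ones((4, 64))
neuron_out = cpu_out + 0.01

sim = cosine_similarity(neuron_out, cpu_out)
max_abs = float(np.max(np.abs(neuron_out - cpu_out)))
mean_abs = float(np.mean(np.abs(neuron_out - cpu_out)))

assert sim > 0.99          # test threshold is an assumption
```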

Benchmark (1024x1024, 30 steps, guidance_scale=4.0, classic CFG):

  • E2E generation time: 31.09s +/- 0.08s
  • Pipeline steps/sec: 0.96
  • Per-step latency: 1036ms
  • Backbone forward/sec: 1.93
  • Compilation time: ~135s (2.3 min)
  • Model load time: ~20s

Compatibility

Tested with:

  • Neuron SDK Version(s): 2.29
  • Instance Type(s): trn2.3xlarge (LNC=2, TP=4)
  • PyTorch Version: 2.9
  • Python Version: 3.12

Additional Information

  • FLUX.2-klein is a gated model on Hugging Face; access approval is required before use
  • Text encoder (Qwen3-8B) runs on CPU (~5s per prompt) since it only executes once per image
  • The key technical challenge was SwiGLU + TP sharding: HuggingFace fuses gate and value into a single linear, but ColumnParallelLinear partitions contiguously, breaking the gate/value split boundary. Solved by splitting into separate parallel projections.
  • Classic CFG requires 2 forward passes per denoising step (positive + negative prompt)
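The classic-CFG combination step mentioned above is the standard formula (sketched here in NumPy; the function name is hypothetical). It also explains the benchmark numbers: with two backbone forwards per denoising step, backbone forward/sec (1.93) is roughly twice the pipeline steps/sec (0.96).

```python
import numpy as np

def cfg_combine(noise_uncond, noise_cond, guidance_scale=4.0):
    # Classic classifier-free guidance: extrapolate from the
    # unconditional (negative-prompt) prediction toward the
    # conditional (positive-prompt) one.
    return noise_uncond + guidance_scale * (noise_cond - noise_uncond)

# Toy predictions; in the pipeline these are the two backbone outputs
# produced each denoising step.
uncond = np.zeros(4)
cond = np.ones(4)
print(cfg_combine(uncond, cond))  # -> [4. 4. 4. 4.]
```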

