4 changes: 2 additions & 2 deletions configs/flux_inference.yaml
@@ -127,8 +127,8 @@ trainer:
warmup_steps: 2000

# AdamW Optimizer Settings
max_lr: 0.0001 # Maximum learning rate for AdamW optimizer
min_lr: 0.00001 # Minimum learning rate for cosine decay schedule
adam_max_lr: 0.0001 # Maximum learning rate for AdamW optimizer
Contributor comment: adam_max_lr -> adamw_max_lr?

min_lr_ratio: 0.1 # Minimum LR as a fraction of the maximum LR for the cosine decay schedule
weight_decay: 0.0 # L2 regularization weight decay coefficient
adam_betas: [ 0.9, 0.95 ] # Beta coefficients for AdamW momentum terms [beta1, beta2]
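The renamed fields express the decay floor as a ratio of the maximum LR rather than an absolute value. A minimal sketch of how a warmup-plus-cosine schedule with these fields is commonly implemented (the function below is illustrative, not the repository's actual scheduler):

```python
import math

def lr_at_step(step: int, max_lr: float, min_lr_ratio: float,
               warmup_steps: int, max_steps: int) -> float:
    """Illustrative warmup + cosine schedule; the floor is min_lr_ratio * max_lr."""
    min_lr = min_lr_ratio * max_lr
    if step < warmup_steps:
        # Linear warmup from 0 to max_lr
        return max_lr * step / max(1, warmup_steps)
    # Cosine decay from max_lr down to min_lr over the remaining steps
    progress = min(1.0, (step - warmup_steps) / max(1, max_steps - warmup_steps))
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

# With the values above (max_lr=1e-4, min_lr_ratio=0.1) the floor is 1e-5,
# which matches the absolute min_lr this config previously hard-coded.
```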

4 changes: 2 additions & 2 deletions configs/flux_tiny_imagenet.yaml
@@ -111,8 +111,8 @@ trainer:
warmup_steps: 1000

# AdamW Optimizer Settings
max_lr: 0.0003 # Maximum learning rate for AdamW optimizer
min_lr: 0.00001 # Minimum learning rate for cosine decay schedule
adam_max_lr: 0.0003 # Maximum learning rate for AdamW optimizer
min_lr_ratio: 0.1 # Minimum LR as a fraction of the maximum LR for the cosine decay schedule
weight_decay: 0.0 # L2 regularization weight decay coefficient
adam_betas: [ 0.9, 0.95 ] # Beta coefficients for AdamW momentum terms [beta1, beta2]

216 changes: 216 additions & 0 deletions configs/flux_tiny_imagenet_muon.yaml
@@ -0,0 +1,216 @@
# Example configuration for Flow Matching training with the Muon optimizer

model:
# Structured component specifications
vae:
module: "models.flux_vae.AutoEncoder"
params:
in_channels: 3
z_channels: 16
scale_factor: 0.3611
shift_factor: 0.1159

text_encoder:
module: "models.vanilla_embedder.VanillaEmbedder"
params:
vocab_size: 1001
embedding_dim: 768
return_datum_lens: true
# Check the "data" section of this file for the correct paths
embeddings_path: "/mnt/localssd/data/imagenet/meta.pt::clip_embeddings"
txt_to_label_path: "/mnt/localssd/data/imagenet/meta.pt::txt_to_label"

clip_encoder:
module: "models.vanilla_embedder.VanillaEmbedder"
params:
vocab_size: 1001
embedding_dim: 768
return_datum_lens: false
# Check the "data" section of this file for the correct paths
embeddings_path: "/mnt/localssd/data/imagenet/meta.pt::clip_embeddings"
txt_to_label_path: "/mnt/localssd/data/imagenet/meta.pt::txt_to_label"

patchifier:
module: "models.patchifier.Patchifier"
params:
patch_size: [ 1, 2, 2 ] # [frames, height, width] - DiT typical
vae_latent_channels: 16 # VAE latent channels
# must agree with vae
vae_compression_factors: [ 1, 8, 8 ] # VAE compression factors [frames, height, width]

denoiser:
module: "models.flux_denoiser.FluxDenoiser"
params:
d_model: 1024
d_head: 64
# n_ds_blocks: 19
# n_ss_blocks: 38
n_ds_blocks: 8
n_ss_blocks: 16
d_txt: 768
d_vec: 768
# must match vae_latent_channels * prod(patch_size) in patchifier (16 * 1 * 2 * 2 = 64)
d_img: 64
# must have sum equal to d_head;
# must have the same number of elements as patch_size in patchifier
rope_axis_dim: [ 8, 28, 28 ] # tyx coordinates
guidance_embed: false
fsdp:
meta_device_init: true
shard_size: 8
param_dtype: "bf16"
reduce_dtype: "fp32"
ac_freq: 0
blocks_attr: [ "double_blocks", "single_blocks" ]
reshard_after_forward_policy: "default"
blocks_per_shard_group: 12 # -1
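The comments above encode a few cross-constraints between the patchifier and the denoiser. A small illustrative check, assuming each token packs one latent patch (the helper below is hypothetical, not part of this repo):

```python
from math import prod

def check_flux_shapes(vae_latent_channels: int, patch_size: list[int],
                      d_img: int, d_head: int, rope_axis_dim: list[int]) -> None:
    """Illustrative consistency checks for the patchifier/denoiser settings above."""
    # Each token packs a 1x2x2 latent patch with 16 channels -> 64 features per token
    assert d_img == vae_latent_channels * prod(patch_size)
    # RoPE splits the head dimension across the (t, y, x) axes
    assert sum(rope_axis_dim) == d_head
    assert len(rope_axis_dim) == len(patch_size)

check_flux_shapes(16, [1, 2, 2], d_img=64, d_head=64, rope_axis_dim=[8, 28, 28])
```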

time_sampler:
module: "utils_fm.noiser.TimeSampler"
params:
use_logit_normal: true
mu: 0.0 # Mean of the logit normal distribution
sigma: 1.0 # Standard deviation of the logit normal distribution
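With use_logit_normal enabled, these mu/sigma values usually mean timesteps are drawn as the sigmoid of a Gaussian (as in SD3-style flow matching); a minimal sketch under that assumption:

```python
import torch

def sample_logit_normal_t(batch: int, mu: float = 0.0, sigma: float = 1.0) -> torch.Tensor:
    # t = sigmoid(z) with z ~ N(mu, sigma^2); concentrates samples around t = 0.5
    z = torch.randn(batch) * sigma + mu
    return torch.sigmoid(z)
```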

time_warper:
module: "utils_fm.noiser.TimeWarper"
params:
base_len: 256 # Base sequence length
base_shift: 0.5 # Base shift parameter for time warping
max_len: 4096 # Maximum sequence length
max_shift: 1.15 # Maximum shift parameter for time warping
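These base_len/base_shift/max_len/max_shift values match the FLUX-style resolution-dependent time shift. A sketch assuming that convention (linear interpolation of the shift parameter in sequence length, followed by an exponential warp); the function is illustrative only:

```python
import math

def warp_time(t: float, seq_len: int,
              base_len: int = 256, max_len: int = 4096,
              base_shift: float = 0.5, max_shift: float = 1.15) -> float:
    """Illustrative FLUX-style time warp; expects t in (0, 1]."""
    # Interpolate the shift parameter mu linearly in sequence length
    slope = (max_shift - base_shift) / (max_len - base_len)
    mu = base_shift + slope * (seq_len - base_len)
    # Longer sequences get a larger shift, pushing timesteps toward the noisy end
    return math.exp(mu) / (math.exp(mu) + (1.0 / t - 1.0))
```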

time_weighter:
module: "utils_fm.noiser.TimeWeighter"
params:
use_logit_normal: false
mu: 0.0 # Mean of the logit normal distribution
sigma: 1.0 # Standard deviation of the logit normal distribution

flow_noiser:
module: "utils_fm.noiser.FlowNoiser"
params:
compute_dtype: "fp32" # Internal computation dtype: "fp32", "fp16", "bf16"
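For orientation, a hedged sketch of the rectified-flow interpolation a FlowNoiser of this kind typically implements (conventions vary; this assumes t = 1 is pure noise and the training target is the constant velocity):

```python
import torch

def flow_noise(x0: torch.Tensor, t: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Illustrative rectified-flow interpolant: x_t = (1 - t) * x0 + t * eps."""
    eps = torch.randn_like(x0)
    t = t.view(-1, *([1] * (x0.dim() - 1)))  # broadcast t over non-batch dims
    x_t = (1.0 - t) * x0 + t * eps
    v_target = eps - x0  # velocity the denoiser is trained to predict
    return x_t, v_target
```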

balancer:
use_dit_balancer: false # Use DIT balancer for sequence length balancing
dit_balancer_specs: "g1n8" # Bag specifications for DIT balancer
dit_balancer_gamma: 0.5 # Gamma parameter for DIT workload estimator

trainer:
module: "trainers.dit_trainer.DiTTrainer"
params:
# Text dropout probability
txt_drop_prob: 0.1

# EMA Settings
ema_decay: 0.999

# Training Schedule
max_steps: 1_000_000
warmup_steps: 200

# Optimizer Settings
adam_max_lr: 0.0003 # Maximum learning rate for AdamW optimizer
adam_betas: [ 0.9, 0.95 ] # Beta coefficients for AdamW momentum terms [beta1, beta2]
use_muon: true
Contributor comment: 🔥 🚀

muon_max_lr: 0.02
muon_mu: 0.95
muon_adjust_lr: "spectral_norm"
muon_param_patterns:
- "double_blocks.*.txt_attn.qkv.weight"
- "double_blocks.*.txt_attn.proj.weight"
- "double_blocks.*.img_attn.qkv.weight"
- "double_blocks.*.img_attn.proj.weight"
- "double_blocks.*.txt_mlp.*.weight"
- "double_blocks.*.img_mlp.*.weight"
- "single_blocks.*.linear1.weight"
- "single_blocks.*.linear2.weight"
# Note: Excludes txt_in, img_in, final_layer.linear (input/output projections)
# Note: Excludes modulation weights
# Note: Excludes all biases by explicitly matching only .weight

min_lr_ratio: 0.1 # Minimum LR as a fraction of the maximum LR for the cosine decay schedule
weight_decay: 0.0 # L2 regularization weight decay coefficient
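use_muon switches the weights listed in muon_param_patterns to the Muon optimizer pulled in via the dion dependency added below. The exact dion API is not part of this diff; purely for orientation, here is a hedged sketch of the core Muon update (momentum orthogonalized by a Newton-Schulz iteration, plus a shape-dependent step-size adjustment). The scaling rule shown is one common variant and may not match dion's "spectral_norm" option exactly:

```python
import torch

def orthogonalize(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    """Approximately orthogonalize a 2-D matrix via a quintic Newton-Schulz iteration."""
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (G.norm() + 1e-7)
    transposed = G.size(0) > G.size(1)
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    return X.T if transposed else X

@torch.no_grad()
def muon_step(p: torch.Tensor, momentum: torch.Tensor,
              lr: float = 0.02, mu: float = 0.95) -> None:
    """One illustrative Muon update for a single 2-D weight tensor."""
    momentum.mul_(mu).add_(p.grad)        # heavy-ball momentum accumulation
    update = orthogonalize(momentum)
    # Shape-dependent adjustment so steps are comparable across matrix shapes
    # (assumed here; dion's "spectral_norm" mode may use a different rule)
    scale = max(1.0, p.size(0) / p.size(1)) ** 0.5
    p.add_(update, alpha=-lr * scale)
```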

# Gradient accumulation settings
total_batch_size: 1024

# Gradient Safeguarding Settings
gradient_clip_norm: 1.0
grad_norm_spike_threshold: 2.0
grad_norm_spike_detection_start_step: 1000
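A hedged sketch of what a spike guard driven by these two settings typically does (the exact rule in this trainer may differ): after the detection start step, compare the current gradient norm to a running average and skip or re-clip the update when it exceeds the threshold multiple.

```python
def should_skip_step(grad_norm: float, running_avg: float, step: int,
                     spike_threshold: float = 2.0, start_step: int = 1000) -> bool:
    """Illustrative spike detection: skip the optimizer step on outlier grad norms."""
    if step < start_step or running_avg <= 0.0:
        return False  # not enough history yet
    return grad_norm > spike_threshold * running_avg
```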

# Checkpoint Settings
init_ckpt: null # Optional: "path/to/checkpoint"
init_ckpt_load_plan: "ckpt_model:mem_model,ckpt_ema:mem_ema,ckpt_optimizer:mem_optimizer,ckpt_scheduler:mem_scheduler,ckpt_step:mem_step"
ckpt_freq: 2000
exp_dir: "./experiments/flux_tiny_imagenet_muon"

# Logging Settings
wandb_mode: "disabled" # online, offline, or disabled (disabled = no wandb logging)
wandb_project: "minFM"
wandb_name: "flux_tiny_imagenet_muon" # Optional: experiment name, defaults to wandb auto-naming
# wandb_entity: <your-wandb-entity> # Optional: wandb entity/organization
# wandb_host: <your-wandb-host> # Optional: hostname for a custom-hosted wandb setup
log_freq: 20

# Validation Settings
val_freq: 10_000
val_num_samples: 10_000

# Inference Settings
inference_at_start: false
inference_then_exit: false
inference_freq: 2000

inferencer:
ckpt_dir: "./experiments/flux_tiny_imagenet/checkpoints/step_00098000"
inference_ops_args:
use_ema: false
prompt_file: "./resources/inference_imagenet_prompts.txt"
output_dir: "./experiments/inference_results_flux_tiny_imagenet"
img_fhw: [ 1, 256, 256 ]
samples_per_prompt: 4
num_steps: 50
neg_prompt: ""
cfg_scale: 5.0
eta: 1.0
file_ext: "jpg"
per_gpu_bs: 16
guidance: null
sample_method: "ddim"
save_as_npz: false
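cfg_scale controls classifier-free guidance at sampling time; a minimal sketch of the standard combination of conditional and unconditional predictions (illustrative, not this repo's sampler):

```python
import torch

def cfg_velocity(v_cond: torch.Tensor, v_uncond: torch.Tensor,
                 cfg_scale: float = 5.0) -> torch.Tensor:
    # Standard classifier-free guidance: extrapolate away from the unconditional prediction
    return v_uncond + cfg_scale * (v_cond - v_uncond)
```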

### Use the following inference setup for computing FID scores
### You can try different cfg_scale
### Usually lower cfg_scale leads to better FID scores, but visual quality may be worse
# inferencer:
# ckpt_dir: "./experiments/flux_tiny_imagenet/step_00380000"
# inference_ops_args:
# use_ema: true
# prompt_file: "./resources/inference_imagenet_1kcls.txt"
# output_dir: "./experiments/inference_results_flux_tiny_imagenet-cfg5"
# img_fhw: [ 1, 256, 256 ]
# samples_per_prompt: 50
# num_steps: 50
# neg_prompt: ""
# cfg_scale: 5.0
# eta: 1.0
# file_ext: "jpg"
# per_gpu_bs: 16
# guidance: null
# sample_method: "ddim"
# save_as_npz: true


data:
module: "data.imagenet.ImagenetDataModule"
params:
batch_size: 128
resolution: 256
num_workers: 16
p_horizon_flip: 0.5
data_root_dir: "$MINFM_DATA_DIR/imagenet"
image_metas_path: "$MINFM_DATA_DIR/imagenet/ilsvrc2012_meta.pt::image_metas"
label_to_txt_path: "$MINFM_DATA_DIR/imagenet/ilsvrc2012_meta.pt::label_to_txt"
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -36,6 +36,7 @@ dependencies = [
"knapformer==0.1.1",
"torch_fidelity==0.4.0-beta",
"scipy==1.15.3",
"dion==0.1.0",
]

[project.optional-dependencies]
@@ -65,6 +66,7 @@ dev-dependencies = [
flash-attn = { url = "https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.2/flash_attn-2.8.2+cu12torch2.7cxx11abiTRUE-cp310-cp310-linux_x86_64.whl" }
knapformer = { git = "https://github.com/Kai-46/KnapFormer.git" }
torch_fidelity = { git = "https://github.com/toshas/torch-fidelity.git" }
dion = { git = "https://github.com/Kai-46/dion.git" }

# Ruff Configuration
[tool.ruff]
49 changes: 28 additions & 21 deletions trainers/__init__.py
@@ -23,12 +23,35 @@ class BaseTrainerParams(BaseParams):
"""Base parameters for all trainers - contains common training configuration."""

# Learning rate and optimizer settings
max_lr: float = 0.0001
min_lr: float = 0.00001
# AdamW
adam_max_lr: float = 0.0003
adam_betas: tuple[float, float] = field(default_factory=lambda: (0.9, 0.95))
# Muon
use_muon: bool = False
muon_max_lr: float = 0.02
muon_mu: float = 0.95
muon_adjust_lr: str = "spectral_norm"
muon_param_patterns: list[str] = field(
default_factory=lambda: [
"double_blocks.*.txt_attn.qkv.weight",
"double_blocks.*.txt_attn.proj.weight",
"double_blocks.*.img_attn.qkv.weight",
"double_blocks.*.img_attn.proj.weight",
"double_blocks.*.txt_mlp.*.weight",
"double_blocks.*.img_mlp.*.weight",
"single_blocks.*.linear1.weight", # Contains qkv and mlp_in
"single_blocks.*.linear2.weight", # Contains proj and mlp_out
# Note: Excludes txt_in, img_in, final_layer.linear (input/output projections)
# Note: Excludes modulation weights
# Note: Excludes all biases by explicitly matching only .weight
]
)

# Shared between AdamW and Muon
min_lr_ratio: float = 0.1
warmup_steps: int = 2000
max_steps: int = 1_000_000
weight_decay: float = 0.0
adam_betas: tuple[float, float] = field(default_factory=lambda: (0.9, 0.95))

# Gradient accumulation settings
total_batch_size: int = -1
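A hedged sketch of how muon_param_patterns might be used to split parameters between Muon and AdamW, assuming fnmatch-style wildcard matching on parameter names; the helper name and the 2-D-weight filter are illustrative assumptions, not the trainer's actual logic:

```python
from fnmatch import fnmatch

import torch.nn as nn

def split_params_for_muon(model: nn.Module, patterns: list[str]):
    """Return (muon_params, adamw_params) by matching names against wildcard patterns."""
    muon_params, adamw_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if param.ndim == 2 and any(fnmatch(name, pat) for pat in patterns):
            muon_params.append(param)   # 2-D hidden weights -> Muon
        else:
            adamw_params.append(param)  # embeddings, biases, norms, in/out layers -> AdamW
    return muon_params, adamw_params
```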
@@ -91,23 +114,7 @@ def load_config(yaml_path: str | Path) -> dict[str, Any]:
if config is None:
config = {}

# Recursively expand environment variables in all string values
def _expand_env_vars(value: Any) -> Any:
"""Recursively expand $VAR and ${VAR} in strings within nested structures."""
if isinstance(value, dict):
return {k: _expand_env_vars(v) for k, v in value.items()}
if isinstance(value, list):
return [_expand_env_vars(v) for v in value]
if isinstance(value, tuple):
return tuple(_expand_env_vars(v) for v in value)
if isinstance(value, str):
# os.path.expandvars leaves unknown vars unchanged, which is desired
return os.path.expandvars(value)
return value

expanded_config = _expand_env_vars(config)

return cast(dict[str, Any], expanded_config)
return cast(dict[str, Any], config)


def setup_distributed() -> tuple[torch.device, int, int, int]:
@@ -190,4 +197,4 @@ def setup_experiment_dirs(exp_dir: str, config: dict[str, Any]) -> tuple[str, st
# Wait for all processes to catch up
dist.barrier()

return run_dir, ckpt_dir
return run_dir, ckpt_dir