diff --git a/configs/flux_inference.yaml b/configs/flux_inference.yaml
index 404ddeb..73a9f19 100644
--- a/configs/flux_inference.yaml
+++ b/configs/flux_inference.yaml
@@ -127,8 +127,8 @@ trainer:
     warmup_steps: 2000

     # AdamW Optimizer Settings
-    max_lr: 0.0001 # Maximum learning rate for AdamW optimizer
-    min_lr: 0.00001 # Minimum learning rate for cosine decay schedule
+    adam_max_lr: 0.0001 # Maximum learning rate for AdamW optimizer
+    min_lr_ratio: 0.1 # Minimum LR as a fraction of the max LR for the cosine decay schedule
     weight_decay: 0.0 # L2 regularization weight decay coefficient
     adam_betas: [ 0.9, 0.95 ] # Beta coefficients for AdamW momentum terms [beta1, beta2]

diff --git a/configs/flux_tiny_imagenet.yaml b/configs/flux_tiny_imagenet.yaml
index f2f4e52..cef1530 100644
--- a/configs/flux_tiny_imagenet.yaml
+++ b/configs/flux_tiny_imagenet.yaml
@@ -111,8 +111,8 @@ trainer:
     warmup_steps: 1000

     # AdamW Optimizer Settings
-    max_lr: 0.0003 # Maximum learning rate for AdamW optimizer
-    min_lr: 0.00001 # Minimum learning rate for cosine decay schedule
+    adam_max_lr: 0.0003 # Maximum learning rate for AdamW optimizer
+    min_lr_ratio: 0.1 # Minimum LR as a fraction of the max LR for the cosine decay schedule
     weight_decay: 0.0 # L2 regularization weight decay coefficient
     adam_betas: [ 0.9, 0.95 ] # Beta coefficients for AdamW momentum terms [beta1, beta2]

diff --git a/configs/flux_tiny_imagenet_muon.yaml b/configs/flux_tiny_imagenet_muon.yaml
new file mode 100644
index 0000000..6e30b31
--- /dev/null
+++ b/configs/flux_tiny_imagenet_muon.yaml
@@ -0,0 +1,216 @@
+# Example configuration for Flow Matching training
+
+model:
+  # Structured component specifications
+  vae:
+    module: "models.flux_vae.AutoEncoder"
+    params:
+      in_channels: 3
+      z_channels: 16
+      scale_factor: 0.3611
+      shift_factor: 0.1159
+
+  text_encoder:
+    module: "models.vanilla_embedder.VanillaEmbedder"
+    params:
+      vocab_size: 1001
+      embedding_dim: 768
+      return_datum_lens: true
+      # Check the "data" section of this file for the correct paths
+      embeddings_path: "/mnt/localssd/data/imagenet/meta.pt::clip_embeddings"
+      txt_to_label_path: "/mnt/localssd/data/imagenet/meta.pt::txt_to_label"
+
+  clip_encoder:
+    module: "models.vanilla_embedder.VanillaEmbedder"
+    params:
+      vocab_size: 1001
+      embedding_dim: 768
+      return_datum_lens: false
+      # Check the "data" section of this file for the correct paths
+      embeddings_path: "/mnt/localssd/data/imagenet/meta.pt::clip_embeddings"
+      txt_to_label_path: "/mnt/localssd/data/imagenet/meta.pt::txt_to_label"
+
+  patchifier:
+    module: "models.patchifier.Patchifier"
+    params:
+      patch_size: [ 1, 2, 2 ] # [frames, height, width] - DiT typical
+      vae_latent_channels: 16 # VAE latent channels
+      # must agree with vae
+      vae_compression_factors: [ 1, 8, 8 ] # VAE compression factors [frames, height, width]
+
+  denoiser:
+    module: "models.flux_denoiser.FluxDenoiser"
+    params:
+      d_model: 1024
+      d_head: 64
+      # n_ds_blocks: 19
+      # n_ss_blocks: 38
+      n_ds_blocks: 8
+      n_ss_blocks: 16
+      d_txt: 768
+      d_vec: 768
+      # must match vae_latent_channels * prod(vae_compression_factors) in patchifier
+      d_img: 64
+      # must have sum equal to d_head;
+      # must have number of elements equal to patch_size in patchifier
+      rope_axis_dim: [ 8, 28, 28 ] # tyx coordinates
+      guidance_embed: false
+    fsdp:
+      meta_device_init: true
+      shard_size: 8
+      param_dtype: "bf16"
+      reduce_dtype: "fp32"
+      ac_freq: 0
+      blocks_attr: [ "double_blocks", "single_blocks" ]
+      reshard_after_forward_policy: "default"
+      blocks_per_shard_group: 12 # -1
+
+  time_sampler:
+    module: "utils_fm.noiser.TimeSampler"
+    params:
+      use_logit_normal: true
+      mu: 0.0 # Mean of the logit normal distribution
+      sigma: 1.0 # Standard deviation of the logit normal distribution
+
+  time_warper:
+    module: "utils_fm.noiser.TimeWarper"
+    params:
+      base_len: 256 # Base sequence length
+      base_shift: 0.5 # Base shift parameter for time warping
+      max_len: 4096 # Maximum sequence length
+      max_shift: 1.15 # Maximum shift parameter for time warping
+
+  time_weighter:
+    module: "utils_fm.noiser.TimeWeighter"
+    params:
+      use_logit_normal: false
+      mu: 0.0 # Mean of the logit normal distribution
+      sigma: 1.0 # Standard deviation of the logit normal distribution
+
+  flow_noiser:
+    module: "utils_fm.noiser.FlowNoiser"
+    params:
+      compute_dtype: "fp32" # Internal computation dtype: "fp32", "fp16", "bf16"
+
+  balancer:
+    use_dit_balancer: false # Use DIT balancer for sequence length balancing
+    dit_balancer_specs: "g1n8" # Bag specifications for DIT balancer
+    dit_balancer_gamma: 0.5 # Gamma parameter for DIT workload estimator
+
+trainer:
+  module: "trainers.dit_trainer.DiTTrainer"
+  params:
+    # Text dropout probability
+    txt_drop_prob: 0.1
+
+    # EMA Settings
+    ema_decay: 0.999
+
+    # Training Schedule
+    max_steps: 1_000_000
+    warmup_steps: 200
+
+    # Optimizer Settings
+    adam_max_lr: 0.0003 # Maximum learning rate for AdamW optimizer
+    adam_betas: [ 0.9, 0.95 ] # Beta coefficients for AdamW momentum terms [beta1, beta2]
+    use_muon: true
+    muon_max_lr: 0.02
+    muon_mu: 0.95
+    muon_adjust_lr: "spectral_norm"
+    muon_param_patterns:
+      - "double_blocks.*.txt_attn.qkv.weight"
+      - "double_blocks.*.txt_attn.proj.weight"
+      - "double_blocks.*.img_attn.qkv.weight"
+      - "double_blocks.*.img_attn.proj.weight"
+      - "double_blocks.*.txt_mlp.*.weight"
+      - "double_blocks.*.img_mlp.*.weight"
+      - "single_blocks.*.linear1.weight"
+      - "single_blocks.*.linear2.weight"
+      # Note: Excludes txt_in, img_in, final_layer.linear (input/output projections)
+      # Note: Excludes modulation weights
+      # Note: Excludes all biases by explicitly matching only .weight
+
+    min_lr_ratio: 0.1 # Minimum LR as a fraction of each group's max LR for the cosine decay schedule
+    weight_decay: 0.0 # L2 regularization weight decay coefficient
+
+    # Gradient accumulation settings
+    total_batch_size: 1024
+
+    # Gradient Safeguarding Settings
+    gradient_clip_norm: 1.0
+    grad_norm_spike_threshold: 2.0
+    grad_norm_spike_detection_start_step: 1000
+
+    # Checkpoint Settings
+    init_ckpt: null # Optional: "path/to/checkpoint"
+    init_ckpt_load_plan: "ckpt_model:mem_model,ckpt_ema:mem_ema,ckpt_optimizer:mem_optimizer,ckpt_scheduler:mem_scheduler,ckpt_step:mem_step"
+    ckpt_freq: 2000
+    exp_dir: "./experiments/flux_tiny_imagenet_muon"
+
+    # Logging Settings
+    wandb_mode: "disabled" # online, offline, or disabled (disabled = no wandb logging)
+    wandb_project: "minFM"
+    wandb_name: "flux_tiny_imagenet_muon" # Optional: experiment name, defaults to wandb auto-naming
+    # wandb_entity: # Optional: wandb entity/organization
+    # wandb_host: # Optional: hostname for a custom-hosted wandb setup
+    log_freq: 20
+
+    # Validation Settings
+    val_freq: 10_000
+    val_num_samples: 10_000
+
+    # Inference Settings
+    inference_at_start: false
+    inference_then_exit: false
+    inference_freq: 2000
+
+inferencer:
+  ckpt_dir: "./experiments/flux_tiny_imagenet/checkpoints/step_00098000"
+  inference_ops_args:
+    use_ema: false
+    prompt_file: "./resources/inference_imagenet_prompts.txt"
+    output_dir: "./experiments/inference_results_flux_tiny_imagenet"
+    img_fhw: [ 1, 256, 256 ]
+    samples_per_prompt: 4
+    num_steps: 50
+    neg_prompt: ""
cfg_scale: 5.0 + eta: 1.0 + file_ext: "jpg" + per_gpu_bs: 16 + guidance: null + sample_method: "ddim" + save_as_npz: false + +### Use the following inference setup for computing FID scores +### You can try different cfg_scale +### Usually lower cfg_scale leads to better FID scores, but visual quality may be worse +# inferencer: +# ckpt_dir: "./experiments/flux_tiny_imagenet/step_00380000" +# inference_ops_args: +# use_ema: true +# prompt_file: "./resources/inference_imagenet_1kcls.txt" +# output_dir: "./experiments/inference_results_flux_tiny_imagenet-cfg5" +# img_fhw: [ 1, 256, 256 ] +# samples_per_prompt: 50 +# num_steps: 50 +# neg_prompt: "" +# cfg_scale: 5.0 +# eta: 1.0 +# file_ext: "jpg" +# per_gpu_bs: 16 +# guidance: null +# sample_method: "ddim" +# save_as_npz: true + + +data: + module: "data.imagenet.ImagenetDataModule" + params: + batch_size: 128 + resolution: 256 + num_workers: 16 + p_horizon_flip: 0.5 + data_root_dir: "$MINFM_DATA_DIR/imagenet" + image_metas_path: "$MINFM_DATA_DIR/imagenet/ilsvrc2012_meta.pt::image_metas" + label_to_txt_path: "$MINFM_DATA_DIR/imagenet/ilsvrc2012_meta.pt::label_to_txt" diff --git a/pyproject.toml b/pyproject.toml index cef7dd5..3297d4a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dependencies = [ "knapformer==0.1.1", "torch_fidelity==0.4.0-beta", "scipy==1.15.3", + "dion==0.1.0", ] [project.optional-dependencies] @@ -65,6 +66,7 @@ dev-dependencies = [ flash-attn = { url = "https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.2/flash_attn-2.8.2+cu12torch2.7cxx11abiTRUE-cp310-cp310-linux_x86_64.whl" } knapformer = { git = "https://github.com/Kai-46/KnapFormer.git" } torch_fidelity = { git = "https://github.com/toshas/torch-fidelity.git" } +dion = { git = "https://github.com/Kai-46/dion.git" } # Ruff Configuration [tool.ruff] diff --git a/trainers/__init__.py b/trainers/__init__.py index 6797d02..8ba8204 100644 --- a/trainers/__init__.py +++ b/trainers/__init__.py @@ -23,12 +23,35 @@ class BaseTrainerParams(BaseParams): """Base parameters for all trainers - contains common training configuration.""" # Learning rate and optimizer settings - max_lr: float = 0.0001 - min_lr: float = 0.00001 + # AdamW + adam_max_lr: float = 0.0003 + adam_betas: tuple[float, float] = field(default_factory=lambda: (0.9, 0.95)) + # Muon + use_muon: bool = False + muon_max_lr: float = 0.02 + muon_mu: float = 0.95 + muon_adjust_lr: str = "spectral_norm" + muon_param_patterns: list[str] = field( + default_factory=lambda: [ + "double_blocks.*.txt_attn.qkv.weight", + "double_blocks.*.txt_attn.proj.weight", + "double_blocks.*.img_attn.qkv.weight", + "double_blocks.*.img_attn.proj.weight", + "double_blocks.*.txt_mlp.*.weight", + "double_blocks.*.img_mlp.*.weight", + "single_blocks.*.linear1.weight", # Contains qkv and mlp_in + "single_blocks.*.linear2.weight", # Contains proj and mlp_out + # Note: Excludes txt_in, img_in, final_layer.linear (input/output projections) + # Note: Exclude modulation weights + # Note: Excludes all biases by explicitly matching only .weight + ] + ) + + # Shared between AdamW and Muon + min_lr_ratio: float = 0.1 warmup_steps: int = 2000 max_steps: int = 1_000_000 weight_decay: float = 0.0 - adam_betas: tuple[float, float] = field(default_factory=lambda: (0.9, 0.95)) # Gradient accumulation settings total_batch_size: int = -1 @@ -91,23 +114,7 @@ def load_config(yaml_path: str | Path) -> dict[str, Any]: if config is None: config = {} - # Recursively expand environment variables in all string values - def 
_expand_env_vars(value: Any) -> Any: - """Recursively expand $VAR and ${VAR} in strings within nested structures.""" - if isinstance(value, dict): - return {k: _expand_env_vars(v) for k, v in value.items()} - if isinstance(value, list): - return [_expand_env_vars(v) for v in value] - if isinstance(value, tuple): - return tuple(_expand_env_vars(v) for v in value) - if isinstance(value, str): - # os.path.expandvars leaves unknown vars unchanged, which is desired - return os.path.expandvars(value) - return value - - expanded_config = _expand_env_vars(config) - - return cast(dict[str, Any], expanded_config) + return cast(dict[str, Any], config) def setup_distributed() -> tuple[torch.device, int, int, int]: @@ -190,4 +197,4 @@ def setup_experiment_dirs(exp_dir: str, config: dict[str, Any]) -> tuple[str, st # Wait for all processes to catch up dist.barrier() - return run_dir, ckpt_dir + return run_dir, ckpt_dir \ No newline at end of file diff --git a/trainers/dit_trainer.py b/trainers/dit_trainer.py index 5372527..ade2a78 100644 --- a/trainers/dit_trainer.py +++ b/trainers/dit_trainer.py @@ -9,7 +9,6 @@ import torch import torch.distributed as dist import torch.nn as nn -from torch.optim import AdamW # Local imports from data import DataStreamer @@ -22,7 +21,7 @@ from utils.fsdp import fwd_only_mode from utils.log import TrackingLogger, WandbLogger, get_logger, get_pbar, human_readable_number from utils.lr import LinearWarmupCosineDecayScheduler -from utils.optim import create_parameter_groups +from utils.optim import create_optimizer from utils.prof import Profiler from . import BaseTrainerParams, setup_distributed, setup_experiment_dirs @@ -138,11 +137,16 @@ def train_init(self) -> None: # Create optimizer (only for trainable denoiser parameters) logger.info("Setting up optimizer...") - self.optimizer = AdamW( - create_parameter_groups(self.denoiser, self.params.weight_decay), - lr=self.params.max_lr, - betas=self.params.adam_betas, - fused=True, + self.optimizer = create_optimizer( + model=self.denoiser, + use_muon=self.params.use_muon, + adam_lr=self.params.adam_max_lr, + adam_betas=self.params.adam_betas, + muon_param_patterns=self.params.muon_param_patterns, + muon_lr=self.params.muon_max_lr, + muon_mu=self.params.muon_mu, + muon_adjust_lr=self.params.muon_adjust_lr, + weight_decay=self.params.weight_decay, ) # Create learning rate scheduler @@ -151,8 +155,7 @@ def train_init(self) -> None: optimizer=self.optimizer, warmup_steps=self.params.warmup_steps, total_steps=self.params.max_steps, - max_lr=self.params.max_lr, - min_lr=self.params.min_lr, + min_lr_ratio=self.params.min_lr_ratio, ) # Switch to rank-dependent seed @@ -268,7 +271,7 @@ def train_one_step(self) -> None: ) # Optimizer step (only if gradient is healthy) - lr = self.scheduler.get_last_lr()[0] + lrs = self.scheduler.get_last_lr() # List of learning rates for each param group if not has_bad_grad: self.optimizer.step() update_ema(self.denoiser, self.ema_denoiser, self.params.ema_decay) @@ -280,14 +283,21 @@ def train_one_step(self) -> None: self.scheduler.step() # Update tracking logger - self.tracking_logger.log( - { - "has_bad_grad": has_bad_grad, - "lr": lr, - "grad_norm": grad_norm, - "step_duration": time.time() - train_step_tic, - } - ) + log_dict = { + "has_bad_grad": has_bad_grad, + "grad_norm": grad_norm, + "step_duration": time.time() - train_step_tic, + } + + # Log learning rates for each parameter group + if len(lrs) == 1: + log_dict["lr"] = lrs[0] + else: + # Multiple parameter groups - log each separately + for 
i, lr in enumerate(lrs): + log_dict[f"lr_{i}"] = lr + + self.tracking_logger.log(log_dict) if dump_trace: trace_path = ( @@ -329,7 +339,7 @@ def run_fwdbwd( print_data_summary: bool = False, txt_drop_prob: float = 0.1, ) -> torch.Tensor: - accum_batch_size = 0 + n_accum_steps = 0 while True: # Get next batch (FMDataContext) - infinite iterator, no StopIteration with tracking_logger.log_time("time/data"): @@ -347,14 +357,20 @@ def run_fwdbwd( with tracking_logger.log_time("time/trainable_ops_fwd"): fm_data_context = self.trainable_ops(fm_data_context) + global_batch_size = self.trainable_ops.global_batch_size + if total_batch_size <= 0: + target_n_accum_steps = 1 + else: + target_n_accum_steps = (total_batch_size + global_batch_size - 1) // global_batch_size + if not skip_backward: with tracking_logger.log_time("time/trainable_ops_bwd"): - fm_data_context.loss.backward() + (fm_data_context.loss / target_n_accum_steps).backward() tracking_logger.log({"loss_vec": fm_data_context.loss_vec, "num_tokens": fm_data_context.num_tokens}) - accum_batch_size += self.trainable_ops.global_batch_size - if accum_batch_size >= total_batch_size: + n_accum_steps += 1 + if n_accum_steps >= target_n_accum_steps: break @torch.no_grad() @@ -396,7 +412,17 @@ def log_metrics(self) -> None: # Compute average loss across all batch elements and GPUs self.tracking_logger.flush() avg_loss = self.tracking_logger["loss_vec", "mean"] - avg_lr = self.tracking_logger["lr", "mean"] + + # Handle learning rates (might be single or multiple) + lr_metrics = {} + if "lr" in self.tracking_logger.stats: + lr_metrics["lr"] = self.tracking_logger["lr", "mean"] + else: + # Multiple learning rates + for key in self.tracking_logger.stats: + if key.startswith("lr_"): + lr_metrics[key] = self.tracking_logger[key, "mean"] + max_grad_norm = self.tracking_logger["grad_norm", "max"] bad_grad_count = self.tracking_logger["has_bad_grad", "sum"] tps = self.tracking_logger["num_tokens", "sum"] / self.tracking_logger["step_duration", "sum"] * self.world_size @@ -406,30 +432,44 @@ def log_metrics(self) -> None: max_trainable_ops_bwd_time = self.tracking_logger["time/trainable_ops_bwd", "max"] if dist.get_rank() == 0: + # Format learning rate display + if "lr" in lr_metrics: + lr_display = f"LR: {lr_metrics['lr']:.2e}" + else: + lr_values = [f"{lr:.2e}" for key, lr in sorted(lr_metrics.items())] + lr_display = f"LRs: [{', '.join(lr_values)}]" + logger.info( - f"Step {self.step:6d} | Loss: {avg_loss:.4f} | LR: {avg_lr:.2e} | GradNorm: {max_grad_norm:.2f} | " + f"Step {self.step:6d} | Loss: {avg_loss:.4f} | {lr_display} | GradNorm: {max_grad_norm:.2f} | " f"TPS: {human_readable_number(tps)} | Data: {max_data_time:.3f}s | Frozen: {max_frozen_ops_time:.3f}s | " f"TrainableFwd: {max_trainable_ops_fwd_time:.3f}s | TrainableBwd: {max_trainable_ops_bwd_time:.3f}s" ) - self.wandb_logger.log( - { - "train/loss": avg_loss, - "train/learning_rate": avg_lr, - "train/gradient_norm": max_grad_norm, - "train/bad_grad_count": bad_grad_count, - "train/tps": tps, - "time/data": max_data_time, - "time/frozen_ops": max_frozen_ops_time, - "time/trainable_ops_fwd": max_trainable_ops_fwd_time, - "time/trainable_ops_bwd": max_trainable_ops_bwd_time, - }, - step=self.step, - ) + + # Prepare wandb metrics + wandb_metrics = { + "train/loss": avg_loss, + "train/gradient_norm": max_grad_norm, + "train/bad_grad_count": bad_grad_count, + "train/tps": tps, + "time/data": max_data_time, + "time/frozen_ops": max_frozen_ops_time, + "time/trainable_ops_fwd": max_trainable_ops_fwd_time, 
+ "time/trainable_ops_bwd": max_trainable_ops_bwd_time, + } + + # Add learning rate(s) to wandb metrics + if "lr" in lr_metrics: + wandb_metrics["train/learning_rate"] = lr_metrics["lr"] + else: + for key, lr in lr_metrics.items(): + wandb_metrics[f"train/learning_rate_{key.split('_')[1]}"] = lr + + self.wandb_logger.log(wandb_metrics, step=self.step) def in_train_inference(self) -> None: - assert ( - self.latent_fm is not None and self.exp_dir is not None and self.step is not None - ), "LatentFM, exp_dir, and step must be provided in training mode" + assert self.latent_fm is not None and self.exp_dir is not None and self.step is not None, ( + "LatentFM, exp_dir, and step must be provided in training mode" + ) logger.info("Setting up InferenceOps...") inference_ops = InferenceOps(lfm=self.latent_fm) @@ -471,4 +511,4 @@ def inference(self, config: dict[str, Any]) -> None: logger.info("Running inference...") inference_ops(**config["inferencer"]["inference_ops_args"]) - dist.destroy_process_group() + dist.destroy_process_group() \ No newline at end of file diff --git a/utils/fsdp.py b/utils/fsdp.py index ae45508..dee656e 100644 --- a/utils/fsdp.py +++ b/utils/fsdp.py @@ -64,9 +64,9 @@ def apply_ac(model: nn.Module, ac_freq: int, blocks_attr: str | list[str] = "blo for attr in blocks_attr: # Retrieve container of blocks (e.g.​ transformer layers) blocks_container: nn.ModuleDict | nn.ModuleList = attrgetter(attr)(model) - assert isinstance( - blocks_container, nn.ModuleDict | nn.ModuleList - ), f"model.{attr} must be a nn.ModuleDict or nn.ModuleList, but got {type(blocks_container)}" + assert isinstance(blocks_container, nn.ModuleDict | nn.ModuleList), ( + f"model.{attr} must be a nn.ModuleDict or nn.ModuleList, but got {type(blocks_container)}" + ) logger.info( f"Applying activation checkpointing to {attr} of {type(model)} with ac_freq {ac_freq}; number of blocks: {len(blocks_container)}" @@ -135,9 +135,9 @@ def apply_fsdp( block_list = [] for attr in blocks_attr: blocks_container: nn.ModuleDict | nn.ModuleList = attrgetter(attr)(model) - assert isinstance( - blocks_container, nn.ModuleDict | nn.ModuleList - ), f"model.{blocks_attr} must be a nn.ModuleDict or nn.ModuleList, but got {type(blocks_container)}" + assert isinstance(blocks_container, nn.ModuleDict | nn.ModuleList), ( + f"model.{blocks_attr} must be a nn.ModuleDict or nn.ModuleList, but got {type(blocks_container)}" + ) logger.info( f"Applying FSDP to {attr} of {type(model)}, number of blocks: {len(blocks_container)}, " f"blocks_per_shard_group={blocks_per_shard_group}, " @@ -197,6 +197,9 @@ def apply_fsdp( # Apply FSDP to the root model fully_shard(model, **fsdp_config, reshard_after_forward=reshard_after_forward) # type: ignore + # Hacky: store fsdp_config in the model + model.fsdp_config = fsdp_config + # if reshard_after_forward_policy == "always": # from torch.distributed.fsdp._fully_shard._fully_shard import FSDPModule @@ -312,9 +315,9 @@ def dist_model_setup( if shard_size is None: shard_size = dist.get_world_size() - assert ( - dist.get_world_size() % shard_size == 0 - ), f"world_size {dist.get_world_size()} must be divisible by shard_size {shard_size}" + assert dist.get_world_size() % shard_size == 0, ( + f"world_size {dist.get_world_size()} must be divisible by shard_size {shard_size}" + ) dp_mesh = init_device_mesh( "cuda", (dist.get_world_size() // shard_size, shard_size), @@ -356,4 +359,4 @@ def fwd_only_mode(model: nn.Module): if isinstance(model, FSDPModule): model.reshard() if was_training: - model.train() + 
model.train() \ No newline at end of file diff --git a/utils/lr.py b/utils/lr.py index b677384..f5988d3 100644 --- a/utils/lr.py +++ b/utils/lr.py @@ -1,73 +1,81 @@ +import math import torch.optim as optim -from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR, LinearLR, SequentialLR +from torch.optim.lr_scheduler import LambdaLR class LinearWarmupCosineDecayScheduler: """ Learning rate scheduler with linear warmup, cosine decay, and fixed minimum learning rate. - Uses SequentialLR with three phases: - 1. LinearLR for warmup (if warmup_steps > 0) - 2. CosineAnnealingLR for decay - 3. LambdaLR to maintain min_lr after decay completes + Supports per-parameter-group learning rates where each group's learning rate is + automatically inferred from the optimizer. All groups follow the same warmup and + decay schedule pattern, decaying to min_lr_ratio * their respective max_lr. """ - def __init__( - self, optimizer: optim.Optimizer, warmup_steps: int, total_steps: int, max_lr: float, min_lr: float = 0.0 - ): + def __init__(self, optimizer: optim.Optimizer, warmup_steps: int, total_steps: int, min_lr_ratio: float = 0.0): """ Initialize the learning rate scheduler. Args: - optimizer: PyTorch optimizer + optimizer: PyTorch optimizer with parameter groups already configured with their learning rates warmup_steps: Number of steps for linear warmup (can be 0) total_steps: Total number of training steps for warmup + cosine decay - max_lr: Maximum learning rate - min_lr: Minimum learning rate at the end of cosine decay (maintained afterwards) + min_lr_ratio: Ratio to compute min_lr as fraction of each group's max_lr + (default: 0.0, meaning decay to 0; 0.1 means decay to 10% of max_lr) """ if warmup_steps < 0: raise ValueError("warmup_steps must be non-negative") if total_steps <= warmup_steps: raise ValueError("total_steps must be greater than warmup_steps") - if max_lr < min_lr: - raise ValueError("max_lr must be greater than or equal to min_lr") + if not 0.0 <= min_lr_ratio <= 1.0: + raise ValueError("min_lr_ratio must be between 0.0 and 1.0") self.optimizer = optimizer self.warmup_steps = warmup_steps self.total_steps = total_steps self.decay_steps = total_steps - warmup_steps - self.max_lr = max_lr - self.min_lr = min_lr - - # Set optimizer's initial lr to max_lr - for param_group in optimizer.param_groups: - param_group["lr"] = max_lr - - # Create warmup scheduler (no-op if warmup_steps is 0) - if warmup_steps > 0: - # Use a very small but non-zero start_factor (PyTorch requires > 0) - warmup_scheduler = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps) - else: - # No-op warmup: start and end at max_lr for 1 step - warmup_scheduler = LinearLR(optimizer, start_factor=1.0, end_factor=1.0, total_iters=1) - - # Create cosine decay scheduler - cosine_scheduler = CosineAnnealingLR(optimizer, T_max=self.decay_steps, eta_min=min_lr) - - # Create fixed learning rate scheduler to maintain min_lr after decay completes - # LambdaLR with constant function that returns min_lr/max_lr ratio - fixed_lr_factor = min_lr / max_lr - fixed_scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: fixed_lr_factor) - - # Create SequentialLR with three phases - warmup_milestone = max(warmup_steps, 1) # Use 1 if warmup_steps is 0 - decay_milestone = total_steps - - self.scheduler = SequentialLR( - optimizer, - schedulers=[warmup_scheduler, cosine_scheduler, fixed_scheduler], - milestones=[warmup_milestone, decay_milestone], - ) + self.min_lr_ratio = min_lr_ratio + + # Extract max 
learning rates from optimizer's parameter groups + self.max_lrs = [group["lr"] for group in optimizer.param_groups] + + # Compute min learning rates as ratio of max learning rates + self.min_lrs = [max_lr * min_lr_ratio for max_lr in self.max_lrs] + + # Create lambda functions for each parameter group + lr_lambdas = [] + for group_max_lr, group_min_lr in zip(self.max_lrs, self.min_lrs): + lr_lambdas.append(self._make_lr_lambda(group_max_lr, group_min_lr)) + + # Create LambdaLR scheduler with per-group lambdas + self.scheduler = LambdaLR(optimizer, lr_lambda=lr_lambdas) + + def _make_lr_lambda(self, group_max_lr: float, group_min_lr: float): + """ + Create a lambda function for a specific parameter group. + + The lambda function returns a multiplicative factor that is applied to the + base learning rate (group_max_lr) stored in the optimizer. + """ + + def lr_lambda(step: int) -> float: + if step < self.warmup_steps: + # Linear warmup from near-zero to 1.0 + if self.warmup_steps > 0: + return max(1e-8, step / self.warmup_steps) + else: + return 1.0 + elif step < self.total_steps: + # Cosine decay from 1.0 to min_lr/max_lr ratio + progress = (step - self.warmup_steps) / self.decay_steps + cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress)) + min_ratio = group_min_lr / group_max_lr if group_max_lr > 0 else 0.0 + return min_ratio + (1.0 - min_ratio) * cosine_decay + else: + # Fixed at min_lr/max_lr ratio after total_steps + return group_min_lr / group_max_lr if group_max_lr > 0 else 0.0 + + return lr_lambda def step(self) -> None: """Update learning rate for the next step.""" @@ -79,8 +87,8 @@ def get_last_lr(self) -> list: def state_dict(self) -> dict: """Return scheduler state.""" - return dict(self.scheduler.state_dict()) + return self.scheduler.state_dict() def load_state_dict(self, state_dict: dict) -> None: """Load scheduler state.""" - self.scheduler.load_state_dict(state_dict) + self.scheduler.load_state_dict(state_dict) \ No newline at end of file diff --git a/utils/optim.py b/utils/optim.py index 4ac41ea..12c54f5 100644 --- a/utils/optim.py +++ b/utils/optim.py @@ -2,53 +2,313 @@ Optimizer utilities for training. """ -from torch import nn +from dataclasses import dataclass +import fnmatch + +from dion import Muon +import torch +from torch.distributed.fsdp._fully_shard._fully_shard import FSDPModule +import torch.nn as nn +from torch.optim import AdamW, Optimizer from utils.log import get_logger logger = get_logger(__name__) -def create_parameter_groups(model: nn.Module, weight_decay: float) -> list[dict]: +@dataclass +class MatchParamResult: + matched_param_names: list[str] + matched_params: list[nn.Parameter] + unmatched_param_names: list[str] + unmatched_params: list[nn.Parameter] + + def summarize(self): + """ + Print a detailed summary table of all matched and unmatched parameters. 
+ """ + + def format_shape(shape: torch.Size) -> str: + return "x".join(str(d) for d in shape) if len(shape) > 0 else "scalar" + + def format_number(num: int) -> str: + """Format large numbers with commas""" + return f"{num:,}" + + def build_param_table(params: list[nn.Parameter], names: list[str], title: str) -> tuple[str, int]: + """Build a formatted table of parameters as a string""" + lines = [] + + if not params: + lines.append(f"\n{title}: None") + return "\n".join(lines), 0 + + # Calculate total elements + total_elements = sum(p.numel() for p in params) + + # Find the maximum name length for proper alignment + max_name_len = max(len(name) for name in names) if names else 0 + max_name_len = max(max_name_len, len("Parameter Name")) # At least as wide as the header + + lines.append(f"\n{title} ({len(params)} parameters, {format_number(total_elements)} total elements):") + lines.append("-" * (max_name_len + 65)) + + # Header + header = f"{'Parameter Name':<{max_name_len}} | {'Shape':^15} | {'Elements':^12} | {'Dtype':^10} | {'Device':^10}" + lines.append(header) + lines.append("-" * (max_name_len + 65)) + + # Build each parameter row + for name, param in zip(names, params, strict=False): + shape_str = format_shape(param.shape) + elements_str = format_number(param.numel()) + dtype_str = str(param.dtype).replace("torch.", "") + device_str = str(param.device) + + row = f"{name:<{max_name_len}} | {shape_str:^15} | {elements_str:^12} | {dtype_str:^10} | {device_str:^10}" + lines.append(row) + + lines.append("-" * (max_name_len + 65)) + return "\n".join(lines), total_elements + + # Build the entire summary as a string + summary_lines = [] + + # Summary header + summary_lines.append("\n" + "=" * 120) + summary_lines.append("PARAMETER MATCHING SUMMARY") + summary_lines.append("=" * 120) + + # Build matched parameters table + matched_table, matched_elements = build_param_table( + self.matched_params, self.matched_param_names, "MATCHED PARAMETERS" + ) + summary_lines.append(matched_table) + + # Build unmatched parameters table + unmatched_table, unmatched_elements = build_param_table( + self.unmatched_params, self.unmatched_param_names, "UNMATCHED PARAMETERS" + ) + summary_lines.append(unmatched_table) + + # Overall summary + total_params = len(self.matched_params) + len(self.unmatched_params) + total_elements = matched_elements + unmatched_elements + + summary_lines.append("\n" + "=" * 120) + summary_lines.append("OVERALL SUMMARY") + summary_lines.append("=" * 120) + + if total_params > 0: + matched_param_pct = len(self.matched_params) / total_params * 100 + unmatched_param_pct = len(self.unmatched_params) / total_params * 100 + + if total_elements > 0: + matched_elem_pct = matched_elements / total_elements * 100 + unmatched_elem_pct = unmatched_elements / total_elements * 100 + else: + matched_elem_pct = unmatched_elem_pct = 0 + + summary_data = [ + [ + "Matched", + f"{len(self.matched_params)}", + f"{matched_param_pct:.1f}%", + f"{format_number(matched_elements)}", + f"{matched_elem_pct:.1f}%", + ], + [ + "Unmatched", + f"{len(self.unmatched_params)}", + f"{unmatched_param_pct:.1f}%", + f"{format_number(unmatched_elements)}", + f"{unmatched_elem_pct:.1f}%", + ], + ["Total", f"{total_params}", "100.0%", f"{format_number(total_elements)}", "100.0%"], + ] + + # Build summary table + summary_lines.append( + f"{'Category':<20} | {'# Parameters':^15} | {'% Parameters':^15} | {'# Elements':^15} | {'% Elements':^15}" + ) + summary_lines.append("-" * 85) + for row in summary_data: + 
summary_lines.append(f"{row[0]:<20} | {row[1]:^15} | {row[2]:^15} | {row[3]:^15} | {row[4]:^15}")
+            summary_lines.append("=" * 120)
+        else:
+            summary_lines.append("No parameters found!")
+
+        # Single logger call with the complete summary
+        logger.info("\n".join(summary_lines))
+
+
+def match_param_patterns(
+    model: nn.Module, param_patterns: list[str] | None = None, include_frozen: bool = False
+) -> MatchParamResult:
     """
-    Create parameter groups for optimizer with selective weight decay.
+    Split a model's parameters into those matching any of the given fnmatch-style
+    patterns and those that do not.
+
+    Example:
+        param_patterns = [
+            # Attention layers in double blocks (weight matrices only)
+            "double_blocks.*.txt_attn.qkv.weight",
+            "double_blocks.*.txt_attn.proj.weight",
+            "double_blocks.*.img_attn.qkv.weight",
+            "double_blocks.*.img_attn.proj.weight",
+            # MLP layers in double blocks (weight matrices only)
+            "double_blocks.*.txt_mlp.*.weight",
+            "double_blocks.*.img_mlp.*.weight",
+            # Attention layers in single blocks (weight matrices only)
+            "single_blocks.*.linear1.weight",  # Contains qkv and mlp_in
+            "single_blocks.*.linear2.weight",  # Contains proj and mlp_out
+            # Note: Excludes txt_in, img_in, final_layer.linear (input/output projections)
+            # Note: Excludes all biases by explicitly matching only .weight
+        ]
+    """
+    matched_params = []
+    matched_param_names = []
+    unmatched_params = []
+    unmatched_param_names = []
+
+    # Iterate through all named parameters
+    for name, param in model.named_parameters():
+        if not param.requires_grad and not include_frozen:
+            continue
-    Weight decay is only applied to parameters with ndim > 1 (e.g., weight matrices).
-    Parameters with ndim <= 1 (e.g., biases, layer norm parameters) get no weight decay.
+        # Check if this parameter matches any of the given patterns (a None/empty list matches nothing)
+        matched = False
-    Args:
+        for pattern in param_patterns or []:
+            # Support wildcards in patterns
+            if fnmatch.fnmatch(name, pattern):
+                matched = True
+                break
-        model: The model to create parameter groups for
-        weight_decay: Weight decay value to apply to parameters with ndim > 1
+        if matched:
+            matched_params.append(param)
+            matched_param_names.append(name)
+        else:
+            unmatched_params.append(param)
+            unmatched_param_names.append(name)
-    Returns:
-        List of parameter group dictionaries suitable for optimizer initialization
+    return MatchParamResult(
+        matched_param_names=matched_param_names,
+        matched_params=matched_params,
+        unmatched_param_names=unmatched_param_names,
+        unmatched_params=unmatched_params,
+    )
+
+
+def create_muon_optimizer(
+    model: nn.Module,
+    adam_lr: float,
+    adam_betas: tuple[float, float],
+    muon_param_patterns: list[str],
+    muon_lr: float,
+    muon_mu: float,
+    muon_adjust_lr: str,
+    weight_decay: float,
+) -> Optimizer:
+    """
+    Create a single dion Muon optimizer whose parameter groups run Muon on the matched
+    weight matrices and AdamW on everything else, each with its own learning rate.
""" - # Separate parameters by dimensionality for selective weight decay - decay_params = [] - no_decay_params = [] + # https://github.com/pytorch/pytorch/blob/6c05ea6475beaf3acc05e1bda0f3f8fe3bdc1d49/torch/distributed/fsdp/_fully_shard/_fsdp_common.py#L52 + assert isinstance(model, FSDPModule), f"Model must be an FSDPModule, but got {type(model)}" + assert hasattr(model, "fsdp_config"), "Model must have fsdp_config" + shard_mesh = model.fsdp_config["mesh"]["shard"] + + match_param_result = match_param_patterns(model, muon_param_patterns) + match_param_result.summarize() - for param in model.parameters(): - if param.requires_grad: + muon_params = match_param_result.matched_params + muon_param_groups = [ + { + "params": muon_params, + "algorithm": "muon", + "lr": muon_lr, + "mu": muon_mu, + "adjust_lr": muon_adjust_lr, + "weight_decay": weight_decay, + }, + ] + + adam_params = match_param_result.unmatched_params + if weight_decay > 0: + adam_params_decay, adam_params_no_decay = [], [] + for param in adam_params: if param.ndim > 1: - decay_params.append(param) + adam_params_decay.append(param) else: - no_decay_params.append(param) + adam_params_no_decay.append(param) + adam_param_groups = [ + { + "params": adam_params_decay, + "algorithm": "adamw", + "lr": adam_lr, + "betas": adam_betas, + "weight_decay": weight_decay, + }, + { + "params": adam_params_no_decay, + "algorithm": "adamw", + "lr": adam_lr, + "betas": adam_betas, + "weight_decay": 0.0, + }, + ] + else: + adam_param_groups = [ + { + "params": adam_params, + "algorithm": "adamw", + "lr": adam_lr, + "betas": adam_betas, + "weight_decay": weight_decay, + }, + ] - # Count parameters for logging - decay_param_count = sum(p.numel() for p in decay_params) - no_decay_param_count = sum(p.numel() for p in no_decay_params) - total_params = decay_param_count + no_decay_param_count + param_groups = muon_param_groups + adam_param_groups - logger.info( - f"Parameter groups: {decay_param_count:,} parameters with weight decay, " - f"{no_decay_param_count:,} without weight decay " - f"(total: {total_params:,} parameters)" - ) + muon_optimizer = Muon(param_groups, distributed_mesh=shard_mesh, nesterov=True, use_triton=True) + + return muon_optimizer + + +def create_adam_optimizer(model: nn.Module, lr: float, betas: tuple[float, float], weight_decay: float) -> Optimizer: + """ + Create AdamW optimizer for the model. 
+    """
+    # Only optimize trainable parameters (matches the previous create_parameter_groups behavior)
+    trainable_params = [p for p in model.parameters() if p.requires_grad]
+    if weight_decay > 0:
+        adam_param_decay, adam_param_no_decay = [], []
+        for param in trainable_params:
+            if param.ndim > 1:
+                adam_param_decay.append(param)
+            else:
+                adam_param_no_decay.append(param)
+        param_groups = [
+            {"params": adam_param_decay, "lr": lr, "betas": betas, "weight_decay": weight_decay},
+            {"params": adam_param_no_decay, "lr": lr, "betas": betas, "weight_decay": 0.0},
+        ]
+    else:
+        param_groups = [
+            {"params": trainable_params, "lr": lr, "betas": betas, "weight_decay": weight_decay},
+        ]
+
+    return AdamW(param_groups, fused=True)
-    # Create parameter groups with different weight decay settings
-    param_groups = [
-        {"params": decay_params, "weight_decay": weight_decay},
-        {"params": no_decay_params, "weight_decay": 0.0},
-    ]
-    return param_groups
+
+
+def create_optimizer(
+    model: nn.Module,
+    use_muon: bool,
+    adam_lr: float,
+    adam_betas: tuple[float, float],
+    muon_param_patterns: list[str],
+    muon_lr: float,
+    muon_mu: float,
+    muon_adjust_lr: str,
+    weight_decay: float,
+) -> Optimizer:
+    """Create either a dion Muon optimizer (Muon + AdamW groups) or a plain fused AdamW optimizer."""
+    if use_muon:
+        return create_muon_optimizer(
+            model, adam_lr, adam_betas, muon_param_patterns, muon_lr, muon_mu, muon_adjust_lr, weight_decay
+        )
+    else:
+        return create_adam_optimizer(model, adam_lr, adam_betas, weight_decay)
\ No newline at end of file
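
Usage sketch (reviewer note, not part of the patch): the reworked scheduler no longer takes max_lr/min_lr; it reads each parameter group's lr from the optimizer and decays every group to min_lr_ratio times its own maximum, which is what makes the mixed Muon/AdamW learning rates work. Below is a minimal CPU-runnable illustration, assuming utils/lr.py from this patch is importable; the toy model and hand-built parameter groups stand in for the FSDP-wrapped denoiser and create_optimizer(), and the 0.02 / 0.0003 values mirror muon_max_lr / adam_max_lr from the new config.

import torch
import torch.nn as nn
from torch.optim import AdamW

from utils.lr import LinearWarmupCosineDecayScheduler

# Toy stand-in for the denoiser: one "Muon-like" group and one "AdamW" group.
model = nn.Sequential(nn.Linear(16, 32), nn.Linear(32, 16))
optimizer = AdamW(
    [
        {"params": model[0].parameters(), "lr": 0.02},    # mirrors muon_max_lr
        {"params": model[1].parameters(), "lr": 0.0003},  # mirrors adam_max_lr
    ],
    betas=(0.9, 0.95),
)

# Each group warms up linearly for 10 steps, cosine-decays until step 100,
# then holds at 10% of its own max LR.
scheduler = LinearWarmupCosineDecayScheduler(
    optimizer, warmup_steps=10, total_steps=100, min_lr_ratio=0.1
)

for step in range(120):
    optimizer.zero_grad()
    model(torch.randn(4, 16)).sum().backward()
    optimizer.step()
    scheduler.step()
    if step in (0, 9, 50, 99, 119):
        # get_last_lr() returns one LR per parameter group, e.g. [2.00e-03, 3.00e-05] at the end.
        print(step, [f"{lr:.2e}" for lr in scheduler.get_last_lr()])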