grad_norm=NaN During NFT Training on Flux1.d-dev #134

@rlustc

Description

During DiffusionNFT training, grad_norm=NaN occurs frequently, which makes my training completely infeasible. The training YAML file is as follows:

# Environment Configuration
launcher: "accelerate"  # Options: accelerate
config_file: config/accelerate_configs/fsdp_full_shard.yaml  # Path to distributed config file (optional)
num_processes: 8  # Number of processes to launch (overrides config file)
main_process_port: 29500
mixed_precision: "bf16"  # Options: no, fp16, bf16

# Data Configuration
data:
  dataset_dir: "dataset/pickscore"  # Path to dataset folder
  preprocessing_batch_size: 8  # Batch size for preprocessing
  dataloader_num_workers: 16  # Number of workers for DataLoader
  force_reprocess: false  # Force reprocessing of the dataset
  cache_dir: "~/.cache/flow_factory/datasets" # Cache directory for preprocessed datasets
  max_dataset_size: 1024  # Limit the maximum number of samples in the dataset
  sampler_type: "auto"  # Options: auto, distributed_k_repeat, group_contiguous

# Model Configuration
model:
  finetune_type: 'full' # Options: full, lora
  target_modules: "default" # Options: all, default, or list of module names like ["to_k", "to_q", "to_v", "to_out.0"]
  model_name_or_path: "/data/aigc/liangyzh_intern/Lirui/Flux1.0-dev"  # HuggingFace model ID or local path
  model_type: "flux1"
  resume_path: null # Path to load previous checkpoint/lora adapter
  resume_type: null # Options: lora, full, state. Null to auto-detect based on `finetune_type`

log:
  run_name: null  # Run name (auto: {model_type}_{finetune_type}_{trainer_type}_{timestamp})
  project: "Flow-Factory"  # Project name for logging
  logging_backend: "tensorboard"  # Options: wandb, swanlab, none
  save_dir: "saves/"  # Directory to save model checkpoints and logs
  save_freq: 40  # Save frequency in epochs (0 to disable)
  save_model_only: true  # Save only the model weights (not optimizer, scheduler, etc.)

# Training Configuration
train:
  # Trainer settings
  trainer_type: 'nft'
  advantage_aggregation: 'sum' # Options: 'sum', 'gdpo'
  nft_beta: 1
  # `Old` Policy settings
  off_policy: true # Whether to use ema parameters for sampling off-policy data.
  ema_decay_schedule: "piecewise_linear"  # Decay schedule for EMA. Options: ['constant', 'power', 'linear', 'piecewise_linear', 'cosine', 'warmup_cosine']
  flat_steps: 0
  ramp_rate: 0.001
  ema_decay: 0.5  # EMA decay rate (0 to disable)
  ema_update_interval: 1  # EMA update interval (in epochs)
  ema_device: "cpu"  # Device to store EMA model (options: cpu, cuda)
  # Training Timestep distribution
  num_train_timesteps: 2 # Set to null to train on all steps
  time_sampling_strategy: discrete # Options: uniform, logit_normal, discrete, discrete_with_init, discrete_wo_init
  time_shift: 3.0
  timestep_range: 0.7 # Select fraction of timesteps to train on
  # KL div
  kl_type: 'v-based'
  kl_beta: 0 # KL divergence beta, 0 to disable
  ref_param_device: 'cpu' # Options: cpu, cuda
  # Clipping
  adv_clip_range: 5.0  # Advantage clipping range

  # Sampling
  resolution: 384  # Can be int or [height, width]
  num_inference_steps: 8  # Number of timesteps
  guidance_scale: 3.5  # Guidance scale for sampling

  # Batch and sampling
  per_device_batch_size: 1  # Batch size per device
  group_size: 16  # Group size for GRPO sampling
  global_std: false  # Use global std for advantage normalization
  unique_sample_num_per_epoch: 48  # Number of unique samples per epoch
  gradient_step_per_epoch: 1  # Gradient steps per epoch. The first step is on-policy, the rest are off-policy.
  gradient_accumulation_steps: auto  # Options: auto, or positive integer. When set, `gradient_step_per_epoch` is ignored.
    
  # Optimization
  learning_rate: 1.0e-5  # Initial learning rate
  adam_weight_decay: 1.0e-4  # AdamW weight decay
  adam_betas: [0.9, 0.999]  # AdamW betas
  adam_epsilon: 1.0e-8  # AdamW epsilon
  max_grad_norm: 1.0  # Max gradient norm for clipping

  # Gradient checkpointing
  enable_gradient_checkpointing: true  # Enable gradient checkpointing to save memory with extra compute

  # Seed
  seed: 42  # Random seed

# Scheduler Configuration
scheduler:
  dynamics_type: "ODE"  # Options: Flow-SDE, Dance-SDE, CPS, ODE

# Evaluation settings
eval:
  resolution: 1024  # Evaluation resolution
  per_device_batch_size: 1  # Eval batch size
  guidance_scale: 3.5  # Guidance scale for sampling
  num_inference_steps: 28  # Number of eval timesteps
  eval_freq: 20  # Eval frequency in epochs (0 to disable)
  seed: 42  # Eval seed (defaults to training seed)

# Reward Model Configuration
rewards:
  - name: "hps"
    reward_model: "HPSv2"
    hps_ckpt_path: "/data/aigc/liangyzh_intern/zqni/DanceGRPO-main/HPSv2/ckpt_all/HPS_v2.1_compressed.pt"
    clip_pretrained_path: "/data/aigc/liangyzh_intern/CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin"
    hps_version: "v2.1"
    batch_size: 16
    dtype: bfloat16
    device: "cuda"
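A minimal diagnostic sketch (my own, not part of Flow-Factory) that may help narrow this down: after `loss.backward()` and before the optimizer step, list every parameter whose gradient contains NaN or Inf, so the first offending module can be identified. Skipping the optimizer step whenever this list is non-empty is a common bf16 workaround. The function name `find_nonfinite_grads` is hypothetical.

```python
# Hypothetical diagnostic helper (not a Flow-Factory API): scan all
# parameter gradients for NaN/Inf before the optimizer step.
import torch

def find_nonfinite_grads(model: torch.nn.Module) -> list:
    """Return names of parameters whose gradients contain NaN or Inf."""
    bad = []
    for name, p in model.named_parameters():
        if p.grad is not None and not torch.isfinite(p.grad).all():
            bad.append(name)
    return bad

# Toy demonstration: inject a NaN gradient into a tiny model.
model = torch.nn.Linear(4, 2)
model(torch.randn(3, 4)).sum().backward()
model.weight.grad[0, 0] = float("nan")  # simulate an overflow
print(find_nonfinite_grads(model))  # -> ['weight']
```

If the NaNs consistently originate in the same module, that points to an overflow there under bf16 rather than a bad batch; logging `find_nonfinite_grads(model)` on the first failing step should make that distinction clear.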
