diff --git a/README.md b/README.md
index 0901c11..c9fbdb1 100644
--- a/README.md
+++ b/README.md
@@ -91,6 +91,7 @@ This experimental feature leverages `diffusers`'s `transformer.set_attention_bac
 | AWM | awm | [Advantage Weighted Matching](https://arxiv.org/abs/2509.25050) |
 | DGPO | dgpo | [DGPO](https://arxiv.org/abs/2510.08425) |
 | GRPO-Guard | grpo-guard | [GRPO-Guard](https://arxiv.org/abs/2510.22319) |
+| CRD | crd | [Centered Reward Distillation](https://arxiv.org/abs/2603.14128) ([Blog (Chinese)](https://mp.weixin.qq.com/s/fpTi7PPi3APSNJQ2kXN3Dw)) |

 See [`Algorithm Guidance`](guidance/algorithms.md) for more information.

@@ -155,7 +156,7 @@ We provide a set of guidance documents to help you understand the framework and
 | Document | Description |
 |---|---|
 | [Workflow](guidance/workflow.md) | End-to-end training pipeline: the overall stages from data preprocessing to policy optimization |
-| [Algorithms](guidance/algorithms.md) | Supported RL algorithms (GRPO, GRPO-Guard, DiffusionNFT, AWM, DPO, DGPO) and their configurations |
+| [Algorithms](guidance/algorithms.md) | Supported RL algorithms (GRPO, GRPO-Guard, DiffusionNFT, AWM, DPO, DGPO, CRD) and their configurations |
 | [Rewards](guidance/rewards.md) | Reward model system: built-in models, custom rewards, and remote reward servers |
 | [New Model](guidance/new_model.md) | How to add support for a new Diffusion/Flow-Matching model |
diff --git a/examples/README.md b/examples/README.md
index 6479922..96f5784 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -10,7 +10,7 @@ examples/{algorithm}/{finetune_type}/{model_type}/{variant}.yaml

 | Level | Description | Examples |
 |-------|-------------|---------|
-| `algorithm` | Training algorithm | `grpo`, `nft`, `awm`, `dgpo`, `dpo` |
+| `algorithm` | Training algorithm | `grpo`, `nft`, `awm`, `dgpo`, `dpo`, `crd` |
 | `finetune_type` | Parameter-efficient or full | `lora`, `full` |
 | `model_type` | Model family (underscore-separated) | `flux1`, `sd3_5`, `wan21`, `ltx2` |
 | `variant` | Config variant | `default.yaml`, `nocfg.yaml`, `t2v.yaml` |
diff --git a/guidance/workflow.md b/guidance/workflow.md
index 23aa6f4..983abba 100644
--- a/guidance/workflow.md
+++ b/guidance/workflow.md
@@ -220,11 +220,12 @@ BaseSample(
 | **DiffusionNFT** | `False` | `[-1]` (final only) | Only needs final clean latent $x_1$; log-prob not required |
 | **AWM** | `False` | `[-1]` (final only) | Same as NFT; log-prob computed later during optimization |
 | **DGPO** | `False` | `[-1]` (final only) | Same trajectory policy as NFT/AWM; optimization uses fresh `TimeSampler` timesteps |
+| **CRD** | `False` | `[-1]` (final only) | Same trajectory policy as NFT/AWM; reward distillation against CFG-guided teacher reference |

 ### Key Points

 - **Selective trajectory recording**: `trajectory_indices` controls which denoising steps are stored. For GRPO, only steps corresponding to `train_timesteps` are kept to reduce memory.
-- **SDE dynamics for exploration**: GRPO injects noise during sampling via SDE formulation, enabling the log-probability computation required for policy gradients. NFT, AWM, and DGPO use decoupled sampling (typically ODE) with `compute_log_prob=False`.
+- **SDE dynamics for exploration**: GRPO injects noise during sampling via SDE formulation, enabling the log-probability computation required for policy gradients. NFT, AWM, DGPO, and CRD use decoupled sampling (typically ODE) with `compute_log_prob=False`.
 - **Off-policy sampling**: NFT optionally uses EMA parameters for sampling (`off_policy: true`), while the current policy is optimized — stabilizing training.

@@ -399,6 +400,7 @@ def optimize(self, samples):
 | **DiffusionNFT** | Samples fresh timesteps; interpolates $x_t = (1-t)x_1 + t\epsilon$; contrastive objective with normalized rewards |
 | **AWM** | Samples fresh timesteps; weights velocity matching loss by advantage; PPO clipping + EMA-KL regularization |
 | **DGPO** | Samples fresh timesteps via `TimeSampler`; applies group-level preference objective with optional PPO clipping and EMA-reference KL |
+| **CRD** | Samples fresh timesteps; reward distillation against CFG-guided teacher with adaptive KL; old/sampling model snapshots and centered advantages |
 | **DPO** | Preference loss on chosen/rejected pairs; pairs formed at the start of `optimize` after advantages |

 ### Key Points
@@ -406,7 +408,7 @@ def optimize(self, samples):
 - **Inner epochs**: Samples can be reused for multiple optimization passes (`num_inner_epochs`), amortizing the cost of sampling.
 - **Gradient accumulation**: The `accelerator.accumulate()` context handles gradient accumulation across timesteps and micro-batches, with optimizer steps only at sync boundaries.
 - **KL regularization**: Optional penalty keeping the policy close to a reference model (or EMA model for AWM), preventing reward hacking.
-- **Per-timestep iteration**: GRPO iterates over each stored trajectory timestep, computing loss at each. NFT, AWM, and DGPO sample fresh timesteps independently of the sampling trajectory.
+- **Per-timestep iteration**: GRPO iterates over each stored trajectory timestep, computing loss at each. NFT, AWM, DGPO, and CRD sample fresh timesteps independently of the sampling trajectory.

 ## Putting It All Together
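Note on the new `crd` entry in the examples layout: the sketch below is illustrative only. The file path follows the documented `examples/{algorithm}/{finetune_type}/{model_type}/{variant}.yaml` convention and the values echo the workflow tables above (`compute_log_prob` disabled, `trajectory_indices` limited to the final latent), but the key names and nesting are assumptions pieced together from identifiers mentioned in these docs, not the repo's verified config schema; an existing example config (e.g. a `dgpo` or `nft` variant) is the authoritative reference.

```yaml
# examples/crd/lora/flux1/default.yaml  (hypothetical path following the
# examples/{algorithm}/{finetune_type}/{model_type}/{variant}.yaml convention)
# Key names below are assumptions; consult an existing example config for the real schema.
algorithm: crd             # algorithm name registered in the README table
finetune_type: lora        # parameter-efficient fine-tuning

sampling:
  compute_log_prob: false  # CRD, like NFT/AWM/DGPO, needs no log-probs during sampling
  trajectory_indices: [-1] # keep only the final clean latent x_1

optimization:
  num_inner_epochs: 1      # samples may be reused across optimization passes
```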