Changes from all commits (32 commits)
- 6d5aa7c: Update train.sh (zhuole1025, Apr 21, 2025)
- 5349e4a: Update README.md (zhuole1025, Apr 21, 2025)
- e5ef38f: update tts&README (Apr 21, 2025)
- cac945c: Update README.md (zhuole1025, Apr 21, 2025)
- 74439b1: update README (Apr 21, 2025)
- fbf8c7b: update readme (Apr 21, 2025)
- b066a5b: Update README.md (sayakpaul, Apr 23, 2025)
- 79b6007: clean codebase (zhuole1025, Apr 23, 2025)
- aa7f955: Update README.md (zhuole1025, Apr 23, 2025)
- 0ad8db6: Update README.md (sayakpaul, Apr 23, 2025)
- fd7b1c8: Merge pull request #1 from Diffusion-CoT/wds-dataset (sayakpaul, Apr 24, 2025)
- 737a7ed: hardware. (sayakpaul, Apr 24, 2025)
- f2ab75f: note on number of shards. (sayakpaul, Apr 24, 2025)
- 43ab885: Update README.md (liangbingzhao, Apr 24, 2025)
- f8b6898: Update README.md (liangbingzhao, Apr 30, 2025)
- c29e81f: Update README.md (liangbingzhao, Apr 30, 2025)
- 99dad78: Update README.md (sayakpaul, May 10, 2025)
- 592bddb: Update README.md (sayakpaul, May 10, 2025)
- 9ca1df8: Update README.md (liangbingzhao, May 10, 2025)
- df30a02: Update README.md (liangbingzhao, May 10, 2025)
- f6d0d38: Update README.md to detail the paths (sayakpaul, May 10, 2025)
- 4267707: Merge pull request #3 from Diffusion-CoT/sayakpaul-patch-1 (liangbingzhao, May 10, 2025)
- 0e4b318: Update train.py (zhuole1025, May 11, 2025)
- e4984f2: Update README.md (sayakpaul, May 12, 2025)
- 8772415: additional fixes in the readme. (sayakpaul, May 12, 2025)
- d7f3c1d: fixes in tts_reflectionflow.py (sayakpaul, May 12, 2025)
- 398aecc: delete unneeded file. (sayakpaul, May 12, 2025)
- fd73244: change to samples from midimg (sayakpaul, May 12, 2025)
- 38ef230: Merge pull request #4 from Diffusion-CoT/better-readme-instructions (liangbingzhao, May 12, 2025)
- c9985bd: benchmarking code. (May 12, 2025)
- b2e9bbd: fixes (May 12, 2025)
- 492e023: update (May 12, 2025)
152 changes: 131 additions & 21 deletions README.md
@@ -1,11 +1,11 @@
<div align="center" style="font-family: charter;">
<h1><i>From Reflection to Perfection:</i>:</br>Scaling Inference-Time Optimization for Text-to-Image Diffusion Models via Reflection Tuning</h1>
<h1><i>From Reflection to Perfection:</i></br>Scaling Inference-Time Optimization for Text-to-Image Diffusion Models via Reflection Tuning</h1>



<a href="tmp" target="_blank">
<a href="https://arxiv.org/abs/2504.16080" target="_blank">
<img alt="arXiv" src="https://img.shields.io/badge/arXiv-ReflectionFlow-red?logo=arxiv" height="20" /></a>
<a href="https://liangbingzhao.github.io/reflection2perfection/" target="_blank">
<a href="https://diffusion-cot.github.io/reflection2perfection/" target="_blank">
<img alt="Website" src="https://img.shields.io/badge/🌎_Website-ReflectionFlow-blue.svg" height="20" /></a>
<a href="https://huggingface.co/collections/diffusion-cot/reflectionflow-release-6803e14352b1b13a16aeda44" target="_blank">
<img alt="HF Dataset: ReflectionFlow" src="https://img.shields.io/badge/%F0%9F%A4%97%20_Hugging Face-ReflectionFlow-ffc107?color=ffc107&logoColor=white" height="20" /></a>
@@ -40,46 +40,156 @@

## :fire: News

- [2025/4/??] Release [paper](tmp).
- [2025/4/??] Release GenRef dataset, as well as the training and evaluation code.
- [2025/4/23] Release [paper](https://arxiv.org/abs/2504.16080).
- [2025/4/20] Release GenRef dataset, model checkpoints, as well as the training and inference code.

## ✨ Quick Start

### Installation

Coming soon.
1. **Environment setup**
```bash
conda create -n ReflectionFlow python=3.10
conda activate ReflectionFlow
```
2. **Requirements installation**
```bash
pip install -r requirements.txt
```

## 🚀 GenRef Dataset
## 🚀 Models and Datasets

### Introduction
### Datasets
| Name | Description | Link |
| --- | --- | --- |
| GenRef-wds | WebDataset format of full GenRef | [HuggingFace](https://huggingface.co/datasets/diffusion-cot/GenRef-wds) |
| GenRef-CoT | Chain-of-Thought reflection dataset | [HuggingFace](https://huggingface.co/datasets/diffusion-cot/GenRef-CoT) |

### Models
| Name | Description | Finetune Data | Link |
| --- | --- | --- | --- |
| FLUX Corrector | Main FLUX-based "text image -> image" model | GenRef-wds | [HuggingFace](https://huggingface.co/diffusion-cot/FLUX-Corrector) |
| Reflection Generator | Qwen-based reflection generator | GenRef-CoT | [HuggingFace](https://huggingface.co/diffusion-cot/Reflection-Generator) |
| Image Verifier | Qwen-based image verifier | GenRef-CoT | [HuggingFace](https://huggingface.co/diffusion-cot/Image-Verifier) |

Coming soon.

### Evaluation on VLM
## 🤖 Reflection Tuning

Coming soon.
[`train_flux/config.yaml`](./train_flux/config.yaml) exposes all the arguments that control the training-time configuration.

### Evaluation on LMM
First, get the data. You can either download the `webdataset` shards from [`diffusion-cot/GenRef-wds`](https://huggingface.co/datasets/diffusion-cot/GenRef-wds) or directly pass URLs.

Coming soon.
When using local paths, set `path` under `[train][dataset]` to a glob pattern: `DATA_DIR/genref_*.tar`. The current `config.yaml` configures training to stream from the `diffusion-cot/GenRef-wds` repository. You can also reduce the number of tar shards you stream for easier debugging: just change `genref_{0..208}.tar` to something like `genref_{0..4}.tar`, depending on how many shards you want to use.
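The `{0..208}` brace pattern is expanded into individual shard names before streaming, which is why shrinking the range shrinks the dataset. A minimal sketch of that expansion (the helper name is ours, not part of the repo):

```python
import re

def expand_shard_pattern(url: str) -> list[str]:
    """Expand a WebDataset-style brace range, e.g. 'genref_{0..4}.tar'
    -> ['genref_0.tar', ..., 'genref_4.tar']."""
    match = re.search(r"\{(\d+)\.\.(\d+)\}", url)
    if match is None:
        return [url]  # no brace range: a single shard
    start, end = int(match.group(1)), int(match.group(2))
    return [url[:match.start()] + str(i) + url[match.end():]
            for i in range(start, end + 1)]

# Streaming fewer shards for debugging just shrinks this list:
shards = expand_shard_pattern("genref_{0..4}.tar")
print(shards)  # ['genref_0.tar', 'genref_1.tar', 'genref_2.tar', 'genref_3.tar', 'genref_4.tar']
```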

Run the following command for training the FLUX Corrector:

## 🤖 Reflection Tuning
Coming soon.
```bash
bash train_flux/train.sh
```

We tested our implementation on a single node of 8×80GB A100s and H100s. We acknowledge that there are opportunities for optimization, but we didn't prioritize them in this release.

>[!NOTE]
> Validation during training is yet to be implemented.

## ⚡ Inference Time Scaling
Coming Soon

### Introduction
We provide code for inference-time scaling of our reflection-tuned models. Currently, we support:
* GPT-4o as verifier, reflection generator, and prompt refiner.
* [NVILA-2B](https://huggingface.co/Efficient-Large-Model/NVILA-Lite-2B-Verifier) verifier from SANA.
* Our [reflection generator](https://huggingface.co/diffusion-cot/Reflection-Generator).

### Setup
First, you need to set up the following:

```bash
export OPENAI_API_KEY=your_api_key
# if you want to use NVILA as verifier
pip install transformers==4.46
pip install git+https://github.com/bfshi/scaling_on_scales.git
```
Then set `FLUX_PATH` and `LORA_PATH` in the config file of your choice from [tts/configs](./tts/configs/). `FLUX_PATH` points to a local copy of [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev/tree/main), which can be downloaded like so:

```py
from huggingface_hub import snapshot_download

local_dir = "SOME_DIR"
snapshot_download(repo_id="black-forest-labs/FLUX.1-dev", local_dir=local_dir)
```

The `LORA_PATH` is our [corrector model](https://huggingface.co/diffusion-cot/FLUX-Corrector) path.

If you want to use our finetuned reflection generator, first install [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory). Then download the model from [here](https://huggingface.co/diffusion-cot/Reflection-Generator) and point `model_name_or_path` in `tts/configs/our_reflectionmodel.yaml` to the reflection generator path. Specifically, the path should look like `Reflection-Generator/infer/30000`. Next, host the model with:

```bash
API_PORT=8001 CUDA_VISIBLE_DEVICES=0 llamafactory-cli api configs/our_reflectionmodel.yaml
```
Then change the `name` of `reflection_args` in the pipeline config file (for example, [tts/configs/flux.1_dev_gptscore.json](./tts/configs/flux.1_dev_gptscore.json)) to `ours`.

> [!NOTE]
> When using our reflection generator model, consider using at least two GPUs to better allocate resources.
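Once hosted, the reflection generator sits behind an OpenAI-compatible chat-completions endpoint on port 8001. A sketch of building a request against it; the endpoint path, model name, and prompt wording here are our assumptions, not taken from the repo:

```python
import json

def build_reflection_request(prompt: str, model: str = "reflection-generator") -> dict:
    """Build an OpenAI-compatible chat-completions payload asking the hosted
    model to critique a generated image against its text prompt."""
    return {
        "model": model,  # placeholder: use whatever name llamafactory serves
        "messages": [
            {"role": "user",
             "content": f"The image was generated from the prompt: {prompt!r}. "
                        "Describe what is wrong and how to fix it."},
        ],
        "temperature": 0.2,
    }

payload = build_reflection_request("a red cube on top of a blue sphere")
# POST this as JSON to http://localhost:8001/v1/chat/completions
print(json.dumps(payload, indent=2))
```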

### Run

First, run `tts_t2i_noise_scaling.py` to generate the naive noise-scaling results:

```bash
export OUTPUT_DIR=output_dir
cd tts
python tts_t2i_noise_scaling.py --output_dir=$OUTPUT_DIR --meta_path=geneval/evaluation_metadata.jsonl --pipeline_config_path=configs/flux.1_dev_gptscore.json
```

Next, you can run the following command to generate the results of reflection tuning:

```bash
export NEW_OUTPUT_DIR=reflection_tuning_dir
python tts_reflectionflow.py --imgpath=$OUTPUT_DIR --pipeline_config_path=configs/flux.1_dev_gptscore.json --output_dir=$NEW_OUTPUT_DIR
```
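Conceptually, the reflection pass iterates generate → verify → reflect, keeping the best-scoring image. A pure-Python sketch of that control flow with pluggable callables; the function and parameter names are ours, while the real scripts wire in FLUX, the verifier, and the reflection generator:

```python
from typing import Callable

def reflection_loop(prompt: str,
                    generate: Callable[[str, str], str],
                    verify: Callable[[str, str], float],
                    reflect: Callable[[str, str], str],
                    rounds: int = 3) -> tuple[str, float]:
    """Iteratively refine an image: generate, score it, ask the reflection
    model for a critique, and regenerate conditioned on that critique."""
    feedback = ""  # no reflection available on the first round
    best_image, best_score = "", float("-inf")
    for _ in range(rounds):
        image = generate(prompt, feedback)
        score = verify(prompt, image)
        if score > best_score:
            best_image, best_score = image, score
        feedback = reflect(prompt, image)
    return best_image, best_score

# Dummy stand-ins to show the control flow (no models involved):
best, score = reflection_loop(
    "a red cube",
    generate=lambda p, f: f"img({p}|{f})",
    verify=lambda p, img: len(img),   # a longer "image" scores higher here
    reflect=lambda p, img: "add more red",
)
print(best, score)
```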

We also provide code for noise & prompt scaling only:

```bash
python tts_t2i_noise_prompt_scaling.py --output_dir=$OUTPUT_DIR --meta_path=geneval/evaluation_metadata.jsonl --pipeline_config_path=configs/flux.1_dev_gptscore.json
```

You can also switch to [tts/configs/flux.1_dev_nvilascore.json](./tts/configs/flux.1_dev_nvilascore.json) to use the NVILA verifier.

By default, we use prompts from [tts/config/geneval/evaluation_metadata.jsonl](./tts/config/geneval/evaluation_metadata.jsonl). If you don't want to use all of them, specify the `--start_index` and `--end_index` CLI args.

### NVILA Verifier Filter

After generation, we provide code that uses the NVILA verifier to filter the generated samples and report results for different sampling budgets.

```bash
python verifier_filter.py --imgpath=$OUTPUT_DIR --pipeline_config_path=configs/flux.1_dev_nvilascore.json
```
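Filtering amounts to scoring every candidate image for a prompt and keeping the best among the first k, so one generation run can report best-of-1 through best-of-N results. A minimal sketch of that selection (the names and numbers are ours, not the script's):

```python
def best_of_k(scores: dict[str, list[float]], k: int) -> dict[str, float]:
    """For each prompt, take the max verifier score among the first k
    samples, simulating a best-of-k sampling budget."""
    return {prompt: max(s[:k]) for prompt, s in scores.items() if s[:k]}

# Verifier scores for 4 samples per prompt (dummy numbers):
scores = {
    "a red cube": [0.41, 0.77, 0.52, 0.90],
    "two dogs":   [0.35, 0.33, 0.81, 0.60],
}
print(best_of_k(scores, 1))  # {'a red cube': 0.41, 'two dogs': 0.35}
print(best_of_k(scores, 4))  # {'a red cube': 0.9, 'two dogs': 0.81}
```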

## 🤝 Acknowledgement

We are deeply grateful to the following GitHub repositories; their valuable code and efforts have been incredibly helpful:

Coming Soon.

* OminiControl (https://github.com/Yuanshi9815/OminiControl)
* Flux-TTS (https://github.com/sayakpaul/tt-scale-flux)


## ✏️ Citation

If you find ReflectionFlow useful for your your research and applications, please cite using this BibTeX:
If you find ReflectionFlow useful for your research and applications, please cite using this BibTeX:

```bibtex
@misc{zhuo2025reflectionperfectionscalinginferencetime,
title={From Reflection to Perfection: Scaling Inference-Time Optimization for Text-to-Image Diffusion Models via Reflection Tuning},
author={Le Zhuo and Liangbing Zhao and Sayak Paul and Yue Liao and Renrui Zhang and Yi Xin and Peng Gao and Mohamed Elhoseiny and Hongsheng Li},
year={2025},
eprint={2504.16080},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2504.16080},
}
```
3 changes: 2 additions & 1 deletion requirements.txt
@@ -7,4 +7,5 @@ lightning
datasets
torchvision
prodigyopt
wandb
wandb
webdataset
40 changes: 11 additions & 29 deletions train_flux/config.yaml
@@ -1,6 +1,6 @@
model_path: "black-forest-labs/FLUX.1-dev"
dtype: "bfloat16"
cache_dir: "/mnt/petrelfs/zhuole/.cache/huggingface/hub/"
cache_dir: "CACHE_DIR"

model:
union_cond_attn: true
@@ -23,25 +23,7 @@ train:
resume_training_from_checkpoint_path: ""
dataset:
type: "img"
path: {
"general": [
"/mnt/petrelfs/zhuole/data/metadata_clean/flux_pro_detailed_prompt_pairs_train.json", # 35344
"/mnt/petrelfs/zhuole/data/metadata_clean/flux_pro_short_prompt_pairs_train.json", # 53551
"/mnt/petrelfs/zhuole/data/metadata_clean/id_prompt_pairs_train.json", # 41440
"/mnt/petrelfs/zhuole/data/metadata_clean/zl2m_v2_pairs_train.json" # 11513
],
"length": [
"/mnt/petrelfs/zhuole/data/metadata_clean/flux_pro_pairs_train.json", # 78939
"/mnt/petrelfs/zhuole/data/metadata_clean/zl2m_prompt_pairs_train.json" # 16292
],
"rule": [
"/mnt/petrelfs/zhuole/data/metadata_clean/geneval_pairs_train.json", # 129321
"/mnt/petrelfs/zhuole/data/metadata_clean/t2i_pairs_train.json" # 55388
],
"editing": [
"/mnt/petrelfs/zhuole/data/metadata_clean/editing_pairs_train.json" # 616409
]
}
path: "pipe:curl -s -f -L https://huggingface.co/datasets/diffusion-cot/GenRef-wds/resolve/main/genref_{0..208}.tar"
split_ratios: {
"general": [0.1, 0.3],
"length": [0.1, 0.3],
@@ -50,15 +50,15 @@
}
training_stages: [0, 5000]
root_dir: ""
val_path: {
"general": ["/mnt/petrelfs/zhuole/ReflectionFlow_/val.json"]
}
val_root_dir: ""
condition_size: 512
target_size: 1024
drop_text_prob: 0.1
drop_image_prob: 0.1
drop_reflection_prob: 0.1
# val_path: {
# "general": "VAL_TARS"
# }
# val_root_dir: ""
# condition_size: 512
# target_size: 1024
# drop_text_prob: 0.1
# drop_image_prob: 0.1
# drop_reflection_prob: 0.1

wandb:
project: "ReflectionFlow"
Binary file removed train_flux/flux/__pycache__/__init__.cpython-310.pyc
Binary file removed train_flux/flux/__pycache__/block.cpython-310.pyc
Binary file removed train_flux/flux/__pycache__/block.cpython-311.pyc
Binary file removed train_flux/flux/__pycache__/generate.cpython-310.pyc
Binary file removed train_flux/flux/__pycache__/generate.cpython-311.pyc
76 changes: 0 additions & 76 deletions train_flux/runs/test/20250421-230029/config.yaml

This file was deleted.

Binary file removed train_flux/runs/test/20250421-230029/val/0_cot_1.jpg
Binary file removed train_flux/runs/test/20250421-230029/val/0_cot_2.jpg
4 changes: 2 additions & 2 deletions train_flux/train.sh
@@ -4,7 +4,7 @@
export XFL_CONFIG=config.yaml

# Specify the WANDB API key
export WANDB_API_KEY="<redacted>"
export WANDB_API_KEY=""

export TOKENIZERS_PARALLELISM=true
accelerate launch --main_process_port 41353 -m train.train
accelerate launch --main_process_port 41353 -m train.train
Binary file removed train_flux/train/__pycache__/__init__.cpython-310.pyc
Binary file removed train_flux/train/__pycache__/data.cpython-310.pyc
Binary file removed train_flux/train/__pycache__/model.cpython-310.pyc
Binary file removed train_flux/train/__pycache__/train.cpython-310.pyc
2 changes: 1 addition & 1 deletion train_flux/train/callbacks.py
@@ -44,7 +44,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
self.total_steps += 1

# Update split ratios
trainer.train_dataloader.dataset._update_split_ratios(self.total_steps)
trainer.train_dataloader.dataset._update_split_ratios()

# Print training progress every n steps
if self.use_wandb: