Merged
37 changes: 32 additions & 5 deletions README.md
@@ -67,26 +67,53 @@
Our method takes as input a front-view image, a natural-language navigation command with a system prompt, and the ego-vehicle states, and outputs an 8-waypoint future trajectory spanning 4 seconds through parallel denoising. The model is first trained via supervised fine-tuning to learn accurate trajectory prediction. We then apply simulator-guided GRPO to further optimize closed-loop behavior. The GRPO reward function integrates safety constraints (collision avoidance, drivable-area compliance) with performance objectives (ego-progress, time-to-collision, comfort).


-## Preparation
-
-### Environment
## Quick Start

### Installation

Clone the repo:

```sh
git clone https://github.com/fudan-generative-vision/WAM-Flow.git
cd WAM-Flow
```

Install dependencies:

```sh
conda create --name wam-flow python=3.10
conda activate wam-flow
pip install -r requirements.txt
```

-## Training

### Model Download

Download models using huggingface-cli:

```sh
-sh script/sft_debug.sh
pip install "huggingface_hub[cli]"
huggingface-cli download fudan-generative-ai/WAM-Flow --local-dir ./pretrained_model/wam-flow
huggingface-cli download LucasJinWang/FUDOKI --local-dir ./pretrained_model/fudoki
```
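
Before running inference, it can help to confirm that the downloads landed where `script/infer.sh` expects them. The following is a minimal sketch, not part of the repository: the function name `check_pretrained` is hypothetical, and the three paths are the ones `script/infer.sh` reads, assuming the default `--local-dir` values above.

```sh
# Sketch: verify the files that script/infer.sh reads are present.
# check_pretrained is a hypothetical helper, not part of the repo.
check_pretrained() {
    root="${1:-pretrained_model}"
    missing=0
    for path in \
        "$root/wam-flow/navsim" \
        "$root/fudoki/text_embedding.pt" \
        "$root/fudoki/image_embedding.pt"
    do
        if [ ! -e "$path" ]; then
            # Report each missing path on stderr and keep checking.
            echo "missing: $path" >&2
            missing=$((missing + 1))
        fi
    done
    # Succeed only when nothing is missing.
    [ "$missing" -eq 0 ]
}
```

For example, `check_pretrained ./pretrained_model && sh script/infer.sh` would only start inference once all three paths exist.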

-## Inference


### Inference

```sh
sh script/infer.sh
```
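
`script/infer.sh` hardcodes a single image path. To try other images without editing the script, one option is a small wrapper that passes the image as an argument. This is a sketch under assumptions: `run_infer` and the `DRY_RUN` variable are hypothetical, while the `infer.py` flags and default paths are copied from `script/infer.sh`.

```sh
# Sketch of a wrapper around infer.py; flags copied from script/infer.sh.
# run_infer and DRY_RUN are hypothetical, not part of the repo.
run_infer() {
    image_path="$1"
    ckpt_path="${CKPT_PATH:-pretrained_model/wam-flow/navsim}"
    fudoki_path="${FUDOKI_PATH:-pretrained_model/fudoki}"
    # When DRY_RUN is set and non-empty, print the command instead of running it.
    ${DRY_RUN:+echo} torchrun --nproc_per_node 1 infer.py \
        --checkpoint_path "$ckpt_path" \
        --image_path "$image_path" \
        --processor_path "$fudoki_path" \
        --text_embedding_path "$fudoki_path/text_embedding.pt" \
        --image_embedding_path "$fudoki_path/image_embedding.pt" \
        --discrete_fm_steps 2 \
        --seed 123
}
```

Setting `DRY_RUN=1` before calling `run_infer path/to/image.jpg` prints the full `torchrun` command, which is a cheap way to check the arguments before using a GPU.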


### Training

```bash
sh script/sft_debug.sh
```
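
`script/sft_debug.sh` and `script/sft_navsim.sh` differ only in the config file they pass to `train.py`. A single launcher that takes the config as an argument could look like the sketch below. `launch_sft` and `DRY_RUN` are hypothetical names, and deriving the output directory from the config file name is an assumption of this sketch, not the repository's behavior; the `accelerate launch` flags are copied from the two scripts.

```sh
# Sketch: one launcher for any config; launch_sft and DRY_RUN are hypothetical.
launch_sft() {
    config="$1"
    # Assumption: name the output directory after the config file,
    # e.g. config/sft_navsim.yaml -> output/train/sft_navsim.
    name=$(basename "$config" .yaml)
    output_dir="output/train/$name"
    # When DRY_RUN is set and non-empty, print the command instead of running it.
    ${DRY_RUN:+echo} accelerate launch \
        --config_file ./config/accelerate_config_ds2.yaml \
        --machine_rank 0 \
        --main_process_port 12345 \
        --num_machines "${NUM_NODES:-1}" \
        --num_processes "${NUM_GPUS:-1}" \
        train.py \
        --config "$config" \
        --output_dir "$output_dir"
}
```

With this shape, `launch_sft config/debug.yaml` and `launch_sft config/sft_navsim.yaml` write to separate output directories, so a debug run does not overwrite a full training run.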



## 📝 Citation

14 changes: 14 additions & 0 deletions script/infer.sh
@@ -0,0 +1,14 @@
#!/bin/bash

CKPT_PATH="pretrained_model/wam-flow/navsim"
FUDOKI_PATH="pretrained_model/fudoki"
IMAGE_PATH="data/navsim_data/sensor_blobs/test/2021.09.09.17.18.51_veh-48_00889_01147/CAM_F0/9a6f0331d98258a0.jpg"

torchrun --nproc_per_node 1 infer.py \
    --checkpoint_path $CKPT_PATH \
    --image_path $IMAGE_PATH \
    --processor_path $FUDOKI_PATH \
    --text_embedding_path $FUDOKI_PATH/text_embedding.pt \
    --image_embedding_path $FUDOKI_PATH/image_embedding.pt \
    --discrete_fm_steps 2 \
    --seed 123
17 changes: 17 additions & 0 deletions script/sft_debug.sh
@@ -0,0 +1,17 @@
#!/bin/bash

NUM_NODES=1
NUM_GPUS=1

config=config/debug.yaml
output_dir=output/train/debug

accelerate launch \
    --config_file ./config/accelerate_config_ds2.yaml \
    --machine_rank 0 \
    --main_process_port 12345 \
    --num_machines $NUM_NODES \
    --num_processes $NUM_GPUS \
    train.py \
    --config $config \
    --output_dir $output_dir
17 changes: 17 additions & 0 deletions script/sft_navsim.sh
@@ -0,0 +1,17 @@
#!/bin/bash

NUM_NODES=1
NUM_GPUS=1

config=config/sft_navsim.yaml
output_dir=output/train/debug

accelerate launch \
    --config_file ./config/accelerate_config_ds2.yaml \
    --machine_rank 0 \
    --main_process_port 12345 \
    --num_machines $NUM_NODES \
    --num_processes $NUM_GPUS \
    train.py \
    --config $config \
    --output_dir $output_dir