diff --git a/README.md b/README.md index 4a85185..7e560fa 100644 --- a/README.md +++ b/README.md @@ -67,26 +67,53 @@ Our method takes as input a front-view image, a natural-language navigation command with a system prompt, and the ego-vehicle states, and outputs an 8-waypoint future trajectory spanning 4 seconds through parallel denoising. The model is first trained via supervised fine-tuning to learn accurate trajectory prediction. We then apply simulatorguided GRPO to further optimize closed-loop behavior. The GRPO reward function integrates safety constraints (collision avoidance, drivable-area compliance) with performance objectives (ego-progress, time-to-collision, comfort). -## Preparation -### Environment +## Quick Start + +### Installation + +Clone the repo: + +```sh +git clone https://github.com/fudan-generative-vision/WAM-Flow.git +cd WAM-Flow +``` + +Install dependencies: + ```sh conda create --name wam-flow python=3.10 conda activate wam-flow pip install -r requirements.txt ``` -## Training + +### Model Download + +Download models using huggingface-cli: + ```sh -sh script/sft_debug.sh +pip install "huggingface_hub[cli]" +huggingface-cli download fudan-generative-ai/WAM-Flow --local-dir ./pretrained_model/wam-flow +huggingface-cli download LucasJinWang/FUDOKI --local-dir ./pretrained_model/fudoki ``` -## Inference + + +### Inference + ```sh sh script/infer.sh ``` +### Training + +```bash +sh script/sft_debug.sh +``` + + ## 📝 Citation diff --git a/script/infer.sh b/script/infer.sh new file mode 100644 index 0000000..114b5f4 --- /dev/null +++ b/script/infer.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +CKPT_PATH="pretrained_model/wam-flow/navsim" +FUDOKI_PATH="pretrained_model/fudoki" +IMAGE_PATH="data/navsim_data/sensor_blobs/test/2021.09.09.17.18.51_veh-48_00889_01147/CAM_F0/9a6f0331d98258a0.jpg" + +torchrun --nproc_per_node 1 infer.py \ + --checkpoint_path $CKPT_PATH \ + --image_path $IMAGE_PATH \ + --processor_path $FUDOKI_PATH \ + --text_embedding_path $FUDOKI_PATH/text_embedding.pt \ + --image_embedding_path $FUDOKI_PATH/image_embedding.pt \ + --discrete_fm_steps 2 \ + --seed 123 \ No newline at end of file diff --git a/script/sft_debug.sh b/script/sft_debug.sh new file mode 100644 index 0000000..fd310a6 --- /dev/null +++ b/script/sft_debug.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +NUM_NODES=1 +NUM_GPUS=1 + +config=config/debug.yaml +output_dir=output/train/debug + +accelerate launch \ + --config_file ./config/accelerate_config_ds2.yaml \ + --machine_rank 0 \ + --main_process_port 12345 \ + --num_machines $NUM_NODES \ + --num_processes $NUM_GPUS \ + train.py \ + --config $config \ + --output_dir $output_dir diff --git a/script/sft_navsim.sh b/script/sft_navsim.sh new file mode 100644 index 0000000..aedda23 --- /dev/null +++ b/script/sft_navsim.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +NUM_NODES=1 +NUM_GPUS=1 + +config=config/sft_navsim.yaml +output_dir=output/train/debug + +accelerate launch \ + --config_file ./config/accelerate_config_ds2.yaml \ + --machine_rank 0 \ + --main_process_port 12345 \ + --num_machines $NUM_NODES \ + --num_processes $NUM_GPUS \ + train.py \ + --config $config \ + --output_dir $output_dir