From 1303a69bb235671a76fb02d5a4ea7ca57469ede4 Mon Sep 17 00:00:00 2001 From: DDVD Date: Mon, 11 Aug 2025 14:54:55 -0400 Subject: [PATCH 001/232] Adapt to Our Datasets (#1) --- .../run_qwen2_5_vl-7b_climb_no_thinking.sh | 54 ++ examples/format_prompt/README.md | 63 ++ examples/format_prompt/default.jinja | 1 + examples/format_prompt/no_thinking.jinja | 1 + .../generation/run_deepseek7b_mutli_node.sh | 0 .../generation/run_deepseek_v2_lite_math.sh | 0 examples/gmpo_trainer/run_qwen2_5-7b_math.sh | 0 examples/gmpo_trainer/test_dapo_7b_math.sh | 0 .../gmpo_trainer/test_dapo_qwen3_30b_math.sh | 0 .../run_deepseek671b_math_megatron.sh | 0 examples/grpo_trainer/run_deepseek7b_llm.sh | 0 .../grpo_trainer/run_deepseek7b_llm_math.sh | 0 .../run_deepseek7b_llm_math_megatron.sh | 0 .../run_deepseek7b_llm_seq_balance.sh | 0 examples/grpo_trainer/run_minicpmo2_6.sh | 0 .../run_moonlight16b_math_megatron.sh | 0 examples/grpo_trainer/run_qwen2-7b.sh | 0 examples/grpo_trainer/run_qwen2-7b_math.sh | 0 .../run_qwen2-7b_math_megatron.sh | 0 .../grpo_trainer/run_qwen2-7b_seq_balance.sh | 0 .../run_qwen2-7b_seq_balance_math_megatron.sh | 0 .../grpo_trainer/run_qwen2-7b_sgl_megatron.sh | 0 .../run_qwen2_5-3b_gsm8k_grpo_lora.sh | 0 .../run_qwen2_5-7b_math_megatron_diff_tp.sh | 0 .../grpo_trainer/run_qwen2_5_32b_grpo_npu.sh | 0 .../run_qwen2_5_7b_grpo_discrete_prof_npu.sh | 0 .../run_qwen2_5_7b_grpo_e2e_prof_npu.sh | 0 .../grpo_trainer/run_qwen2_5_7b_grpo_npu.sh | 0 .../run_qwen2_5_vl-7b-megatron.sh | 0 examples/grpo_trainer/run_qwen2_5_vl-7b.sh | 0 .../grpo_trainer/run_qwen2_5_vl-7b_climb.sh | 54 ++ .../grpo_trainer/run_qwen2_5_vl-7b_lora.sh | 0 .../run_qwen2_5_vl-7b_seq_balance.sh | 0 .../grpo_trainer/run_qwen2_5_vl_32b_npu.sh | 0 .../grpo_trainer/run_qwen2_5_vl_3b_npu.sh | 0 .../grpo_trainer/run_qwen2_5_vl_7b_npu.sh | 0 .../grpo_trainer/run_qwen3-236b_megatron.sh | 0 examples/grpo_trainer/run_qwen3-8b.sh | 0 .../grpo_trainer/run_qwen3moe-30b_megatron.sh | 0 examples/ppo_trainer/run_deepseek7b_llm.sh | 0 .../run_deepseek7b_llm_modelscope.sh | 0 .../ppo_trainer/run_deepseek7b_llm_pfppo.sh | 0 .../run_deepseek7b_llm_sandbox_fusion.sh | 0 .../ppo_trainer/run_deepseek7b_llm_sp2.sh | 0 .../ppo_trainer/run_deepseek_full_hh_rlhf.sh | 0 .../run_deepseek_math_gsm8k_megatron.sh | 0 .../run_deepseek_math_gsm8k_megatron_nsys.sh | 0 examples/ppo_trainer/run_gemma.sh | 0 .../run_moonlight16b_a3b_gsm8k_megatron.sh | 0 .../run_qwen1.5_moe_a2.7b-gsm8k_megatron.sh | 0 .../run_qwen2-7b_math_gsm8k_megatron.sh | 0 examples/ppo_trainer/run_qwen2-7b_rm.sh | 0 .../run_qwen2-7b_rm_seq_balance.sh | 0 ...n_qwen2-7b_rm_seq_balance_fused_kernels.sh | 0 .../run_qwen2-7b_rm_seq_balance_nsys.sh | 0 .../ppo_trainer/run_qwen2-7b_seq_balance.sh | 0 .../run_qwen2-7b_sglang_seq_balance.sh | 0 examples/ppo_trainer/run_qwen2.5-32b.sh | 0 .../run_qwen2-7b_math_rf.sh | 0 .../run_qwen2-7b_math_rf_baseline.sh | 0 .../run_qwen2.5-3b_seq_balance.sh | 0 .../run_qwen2.5-7b_seq_balance.sh | 0 examples/reward_function/dapo.py | 163 ++++++ examples/reward_function/evaluation.py | 552 ++++++++++++++++++ examples/reward_function/math.py | 49 ++ examples/reward_function/medical.py | 460 +++++++++++++++ examples/reward_function/r1v.py | 50 ++ examples/rloo_trainer/run_qwen2-7b.sh | 0 examples/sft/gsm8k/run_deepseek_6b7.sh | 0 examples/sft/gsm8k/run_gemma_2b.sh | 0 examples/sft/gsm8k/run_gemma_7b.sh | 0 .../gsm8k/run_qwen2_5_05b_sft_peft_sp2_npu.sh | 0 examples/sft/gsm8k/run_qwen_05_peft.sh | 0 examples/sft/gsm8k/run_qwen_05_sp2.sh | 0 
examples/sft/gsm8k/run_qwen_05_sp2_liger.sh | 0 examples/sft/multiturn/run_qwen_05_sp2.sh | 0 .../geo3k/run_qwen2.5-3b_geo3k_multiturn.sh | 0 .../run_qwen2.5-3b_geo3k_multiturn_4xgpu.sh | 0 ...run_qwen2.5-3b_megatron_geo3k_multiturn.sh | 0 ...n2.5-0.5b_gsm8k_multiturn_w_interaction.sh | 0 .../run_qwen2.5-3b_gsm8k_multiturn.sh | 0 .../run_qwen2.5-3b_gsm8k_multiturn_4xgpu.sh | 0 .../run_qwen2.5-3b_gsm8k_tool_agent_mlflow.sh | 0 ...run_qwen2.5-3b_megatron_gsm8k_multiturn.sh | 0 .../run_qwen2_3b_dapo_multiturn.sh | 0 ...un_qwen2.5-3b_instruct_search_multiturn.sh | 0 .../split_placement/run_deepseek7b_llm.sh | 0 .../qwen2-0.5b_grpo-lora_1_h100_fsdp_vllm.sh | 0 .../qwen2-1.5b_grpo-lora_1_h100_fsdp_vllm.sh | 0 .../qwen2-14b_grpo-lora_2_h100_fsdp_vllm.sh | 0 .../14b/qwen2_14b_grpo_4_h800_fsdp_vllm.sh | 0 .../qwen2-32b_grpo-lora_4_h100_fsdp_vllm.sh | 0 .../32b/qwen2_32B_grpo_8_h20_megatron_vllm.sh | 0 .../3b/qwen2-3b_grpo-lora_1_h100_fsdp_vllm.sh | 0 .../70b/qwen2-70b_grpo_32_h20_fsdp_vllm.sh | 0 .../70b/qwen2-70b_grpo_32_h800_fsdp_vllm.sh | 0 .../qwen2-72b_grpo-lora_8_h100_fsdp_vllm.sh | 0 .../7b/qwen2-7b_grpo-lora_1_h100_fsdp_vllm.sh | 0 .../7b/qwen2-7b_grpo_2_h800_fsdp_vllm.sh | 0 scripts/process_mosei_annotations.py | 74 +++ .../_generated_ppo_megatron_trainer.yaml | 1 + .../config/_generated_ppo_trainer.yaml | 1 + verl/trainer/config/data/legacy_data.yaml | 5 + verl/trainer/ppo/core_algos.py | 180 +++++- verl/trainer/ppo/ray_trainer.py | 110 +++- verl/utils/dataset/rl_dataset.py | 130 ++++- verl/utils/dataset/vision_utils.py | 23 +- verl/workers/actor/dp_actor.py | 4 +- 108 files changed, 1942 insertions(+), 33 deletions(-) create mode 100755 examples/drpo_trainer/run_qwen2_5_vl-7b_climb_no_thinking.sh create mode 100644 examples/format_prompt/README.md create mode 100644 examples/format_prompt/default.jinja create mode 100644 examples/format_prompt/no_thinking.jinja mode change 100644 => 100755 examples/generation/run_deepseek7b_mutli_node.sh mode change 100644 => 100755 examples/generation/run_deepseek_v2_lite_math.sh mode change 100644 => 100755 examples/gmpo_trainer/run_qwen2_5-7b_math.sh mode change 100644 => 100755 examples/gmpo_trainer/test_dapo_7b_math.sh mode change 100644 => 100755 examples/gmpo_trainer/test_dapo_qwen3_30b_math.sh mode change 100644 => 100755 examples/grpo_trainer/run_deepseek671b_math_megatron.sh mode change 100644 => 100755 examples/grpo_trainer/run_deepseek7b_llm.sh mode change 100644 => 100755 examples/grpo_trainer/run_deepseek7b_llm_math.sh mode change 100644 => 100755 examples/grpo_trainer/run_deepseek7b_llm_math_megatron.sh mode change 100644 => 100755 examples/grpo_trainer/run_deepseek7b_llm_seq_balance.sh mode change 100644 => 100755 examples/grpo_trainer/run_minicpmo2_6.sh mode change 100644 => 100755 examples/grpo_trainer/run_moonlight16b_math_megatron.sh mode change 100644 => 100755 examples/grpo_trainer/run_qwen2-7b.sh mode change 100644 => 100755 examples/grpo_trainer/run_qwen2-7b_math.sh mode change 100644 => 100755 examples/grpo_trainer/run_qwen2-7b_math_megatron.sh mode change 100644 => 100755 examples/grpo_trainer/run_qwen2-7b_seq_balance.sh mode change 100644 => 100755 examples/grpo_trainer/run_qwen2-7b_seq_balance_math_megatron.sh mode change 100644 => 100755 examples/grpo_trainer/run_qwen2-7b_sgl_megatron.sh mode change 100644 => 100755 examples/grpo_trainer/run_qwen2_5-3b_gsm8k_grpo_lora.sh mode change 100644 => 100755 examples/grpo_trainer/run_qwen2_5-7b_math_megatron_diff_tp.sh mode change 100644 => 100755 
examples/grpo_trainer/run_qwen2_5_32b_grpo_npu.sh mode change 100644 => 100755 examples/grpo_trainer/run_qwen2_5_7b_grpo_discrete_prof_npu.sh mode change 100644 => 100755 examples/grpo_trainer/run_qwen2_5_7b_grpo_e2e_prof_npu.sh mode change 100644 => 100755 examples/grpo_trainer/run_qwen2_5_7b_grpo_npu.sh mode change 100644 => 100755 examples/grpo_trainer/run_qwen2_5_vl-7b-megatron.sh mode change 100644 => 100755 examples/grpo_trainer/run_qwen2_5_vl-7b.sh create mode 100755 examples/grpo_trainer/run_qwen2_5_vl-7b_climb.sh mode change 100644 => 100755 examples/grpo_trainer/run_qwen2_5_vl-7b_lora.sh mode change 100644 => 100755 examples/grpo_trainer/run_qwen2_5_vl-7b_seq_balance.sh mode change 100644 => 100755 examples/grpo_trainer/run_qwen2_5_vl_32b_npu.sh mode change 100644 => 100755 examples/grpo_trainer/run_qwen2_5_vl_3b_npu.sh mode change 100644 => 100755 examples/grpo_trainer/run_qwen2_5_vl_7b_npu.sh mode change 100644 => 100755 examples/grpo_trainer/run_qwen3-236b_megatron.sh mode change 100644 => 100755 examples/grpo_trainer/run_qwen3-8b.sh mode change 100644 => 100755 examples/grpo_trainer/run_qwen3moe-30b_megatron.sh mode change 100644 => 100755 examples/ppo_trainer/run_deepseek7b_llm.sh mode change 100644 => 100755 examples/ppo_trainer/run_deepseek7b_llm_modelscope.sh mode change 100644 => 100755 examples/ppo_trainer/run_deepseek7b_llm_pfppo.sh mode change 100644 => 100755 examples/ppo_trainer/run_deepseek7b_llm_sandbox_fusion.sh mode change 100644 => 100755 examples/ppo_trainer/run_deepseek7b_llm_sp2.sh mode change 100644 => 100755 examples/ppo_trainer/run_deepseek_full_hh_rlhf.sh mode change 100644 => 100755 examples/ppo_trainer/run_deepseek_math_gsm8k_megatron.sh mode change 100644 => 100755 examples/ppo_trainer/run_deepseek_math_gsm8k_megatron_nsys.sh mode change 100644 => 100755 examples/ppo_trainer/run_gemma.sh mode change 100644 => 100755 examples/ppo_trainer/run_moonlight16b_a3b_gsm8k_megatron.sh mode change 100644 => 100755 examples/ppo_trainer/run_qwen1.5_moe_a2.7b-gsm8k_megatron.sh mode change 100644 => 100755 examples/ppo_trainer/run_qwen2-7b_math_gsm8k_megatron.sh mode change 100644 => 100755 examples/ppo_trainer/run_qwen2-7b_rm.sh mode change 100644 => 100755 examples/ppo_trainer/run_qwen2-7b_rm_seq_balance.sh mode change 100644 => 100755 examples/ppo_trainer/run_qwen2-7b_rm_seq_balance_fused_kernels.sh mode change 100644 => 100755 examples/ppo_trainer/run_qwen2-7b_rm_seq_balance_nsys.sh mode change 100644 => 100755 examples/ppo_trainer/run_qwen2-7b_seq_balance.sh mode change 100644 => 100755 examples/ppo_trainer/run_qwen2-7b_sglang_seq_balance.sh mode change 100644 => 100755 examples/ppo_trainer/run_qwen2.5-32b.sh mode change 100644 => 100755 examples/reinforce_plus_plus_trainer/run_qwen2-7b_math_rf.sh mode change 100644 => 100755 examples/reinforce_plus_plus_trainer/run_qwen2-7b_math_rf_baseline.sh mode change 100644 => 100755 examples/remax_trainer/run_qwen2.5-3b_seq_balance.sh mode change 100644 => 100755 examples/remax_trainer/run_qwen2.5-7b_seq_balance.sh create mode 100644 examples/reward_function/dapo.py create mode 100644 examples/reward_function/evaluation.py create mode 100644 examples/reward_function/math.py create mode 100644 examples/reward_function/medical.py create mode 100644 examples/reward_function/r1v.py mode change 100644 => 100755 examples/rloo_trainer/run_qwen2-7b.sh mode change 100644 => 100755 examples/sft/gsm8k/run_deepseek_6b7.sh mode change 100644 => 100755 examples/sft/gsm8k/run_gemma_2b.sh mode change 100644 => 100755 
examples/sft/gsm8k/run_gemma_7b.sh mode change 100644 => 100755 examples/sft/gsm8k/run_qwen2_5_05b_sft_peft_sp2_npu.sh mode change 100644 => 100755 examples/sft/gsm8k/run_qwen_05_peft.sh mode change 100644 => 100755 examples/sft/gsm8k/run_qwen_05_sp2.sh mode change 100644 => 100755 examples/sft/gsm8k/run_qwen_05_sp2_liger.sh mode change 100644 => 100755 examples/sft/multiturn/run_qwen_05_sp2.sh mode change 100644 => 100755 examples/sglang_multiturn/geo3k/run_qwen2.5-3b_geo3k_multiturn.sh mode change 100644 => 100755 examples/sglang_multiturn/geo3k/run_qwen2.5-3b_geo3k_multiturn_4xgpu.sh mode change 100644 => 100755 examples/sglang_multiturn/geo3k/run_qwen2.5-3b_megatron_geo3k_multiturn.sh mode change 100644 => 100755 examples/sglang_multiturn/run_qwen2.5-0.5b_gsm8k_multiturn_w_interaction.sh mode change 100644 => 100755 examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn.sh mode change 100644 => 100755 examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn_4xgpu.sh mode change 100644 => 100755 examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_tool_agent_mlflow.sh mode change 100644 => 100755 examples/sglang_multiturn/run_qwen2.5-3b_megatron_gsm8k_multiturn.sh mode change 100644 => 100755 examples/sglang_multiturn/run_qwen2_3b_dapo_multiturn.sh mode change 100644 => 100755 examples/sglang_multiturn/search_r1_like/run_qwen2.5-3b_instruct_search_multiturn.sh mode change 100644 => 100755 examples/split_placement/run_deepseek7b_llm.sh mode change 100644 => 100755 examples/tuning/0.5b/qwen2-0.5b_grpo-lora_1_h100_fsdp_vllm.sh mode change 100644 => 100755 examples/tuning/1.5b/qwen2-1.5b_grpo-lora_1_h100_fsdp_vllm.sh mode change 100644 => 100755 examples/tuning/14b/qwen2-14b_grpo-lora_2_h100_fsdp_vllm.sh mode change 100644 => 100755 examples/tuning/14b/qwen2_14b_grpo_4_h800_fsdp_vllm.sh mode change 100644 => 100755 examples/tuning/32b/qwen2-32b_grpo-lora_4_h100_fsdp_vllm.sh mode change 100644 => 100755 examples/tuning/32b/qwen2_32B_grpo_8_h20_megatron_vllm.sh mode change 100644 => 100755 examples/tuning/3b/qwen2-3b_grpo-lora_1_h100_fsdp_vllm.sh mode change 100644 => 100755 examples/tuning/70b/qwen2-70b_grpo_32_h20_fsdp_vllm.sh mode change 100644 => 100755 examples/tuning/70b/qwen2-70b_grpo_32_h800_fsdp_vllm.sh mode change 100644 => 100755 examples/tuning/70b/qwen2-72b_grpo-lora_8_h100_fsdp_vllm.sh mode change 100644 => 100755 examples/tuning/7b/qwen2-7b_grpo-lora_1_h100_fsdp_vllm.sh mode change 100644 => 100755 examples/tuning/7b/qwen2-7b_grpo_2_h800_fsdp_vllm.sh create mode 100644 scripts/process_mosei_annotations.py diff --git a/examples/drpo_trainer/run_qwen2_5_vl-7b_climb_no_thinking.sh b/examples/drpo_trainer/run_qwen2_5_vl-7b_climb_no_thinking.sh new file mode 100755 index 00000000000..87c368ad1a2 --- /dev/null +++ b/examples/drpo_trainer/run_qwen2_5_vl-7b_climb_no_thinking.sh @@ -0,0 +1,54 @@ +set -x +ENGINE=${1:-vllm} + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=drpo \ + data.train_files=/home/dvdai/orcd/scratch/high_modality/geom_train_upsampled_new.jsonl \ + data.val_files=/home/dvdai/orcd/scratch/high_modality/geom_valid_mini_new.jsonl \ + data.train_batch_size=512 \ + data.max_prompt_length=4096 \ + data.max_response_length=4096 \ + data.filter_overlong_prompts=False \ + data.truncation='left' \ + data.image_key=images \ + data.video_key=videos \ + data.prompt_key=problem \ + data.format_prompt=examples/format_prompt/no_thinking.jinja \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=2e-6 \ + 
actor_rollout_ref.model.use_remove_padding=False \
+    actor_rollout_ref.actor.ppo_mini_batch_size=128 \
+    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \
+    actor_rollout_ref.actor.use_kl_loss=True \
+    actor_rollout_ref.actor.kl_loss_coef=1e-8 \
+    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
+    actor_rollout_ref.actor.entropy_coeff=0 \
+    actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
+    actor_rollout_ref.model.enable_gradient_checkpointing=True \
+    actor_rollout_ref.actor.fsdp_config.param_offload=False \
+    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
+    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
+    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
+    actor_rollout_ref.rollout.name=$ENGINE \
+    actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \
+    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
+    actor_rollout_ref.rollout.enable_chunked_prefill=False \
+    actor_rollout_ref.rollout.enforce_eager=False \
+    actor_rollout_ref.rollout.free_cache_engine=True \
+    actor_rollout_ref.rollout.n=5 \
+    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
+    actor_rollout_ref.ref.fsdp_config.param_offload=True \
+    algorithm.use_kl_in_reward=False \
+    custom_reward_function.path=examples/reward_function/medical.py \
+    custom_reward_function.name=medical_compute_score_batch \
+    reward_model.reward_manager=batch \
+    trainer.critic_warmup=0 \
+    trainer.logger='["console","wandb"]' \
+    trainer.project_name='verl_climb' \
+    trainer.experiment_name='drpo_nothinking' \
+    trainer.n_gpus_per_node=4 \
+    trainer.nnodes=1 \
+    trainer.save_freq=20 \
+    trainer.val_before_train=False \
+    trainer.test_freq=5 \
+    trainer.total_epochs=15 $@
diff --git a/examples/format_prompt/README.md b/examples/format_prompt/README.md
new file mode 100644
index 00000000000..412c5a558e3
--- /dev/null
+++ b/examples/format_prompt/README.md
@@ -0,0 +1,63 @@
+# Format Prompt Templates
+
+This directory contains Jinja2 templates for formatting prompts in RLHF datasets.
+
+## Overview
+
+The format prompt feature allows you to apply custom formatting to each prompt in your dataset using Jinja2 templates. This is useful when you want to add consistent instructions or formatting to all prompts without modifying the original dataset.
+
+## Default Template
+
+The default template (`default.jinja`) appends the following instruction to each prompt:
+
+```
+{{ content }}You FIRST think about the reasoning process as an internal monologue and then provide the final answer. The reasoning process MUST BE enclosed within <think> </think> tags. The final answer MUST BE put in \boxed{}.
+```
+
+## Usage
+
+To use a format prompt template, specify the `format_prompt` parameter in your data configuration:
+
+```yaml
+data:
+  # ... other data config ...
+  format_prompt: examples/format_prompt/default.jinja  # Path to your template file
+```
+
+Or set it to `null` to disable format prompting:
+
+```yaml
+data:
+  format_prompt: null
+```
+
+## Creating Custom Templates
+
+To create a custom format prompt:
+
+1. Create a new `.jinja` file in this directory or elsewhere
+2. Use `{{ content }}` as the placeholder for the original prompt content
+3. Add your custom formatting around it
+
+Example custom template:
+
+```jinja
+{{ content }}
+
+Please solve this problem step by step:
+1. Understand the problem
+2. Plan your approach
+3. Execute the solution
+4. Verify your answer
+```
+
+## Template Variables
+
+Currently, the template receives one variable:
+- `content`: The original prompt text
+
+## Notes
+
+- The template is applied during dataset preprocessing
+- If the template file is not found, the system will use the original prompt without formatting
+- For multimodal datasets (images/videos), the formatting is applied to text segments only
\ No newline at end of file
diff --git a/examples/format_prompt/default.jinja b/examples/format_prompt/default.jinja
new file mode 100644
index 00000000000..be95b0ef441
--- /dev/null
+++ b/examples/format_prompt/default.jinja
@@ -0,0 +1 @@
+{{ content }}You FIRST think about the reasoning process as an internal monologue and then provide the final answer. The reasoning process MUST BE enclosed within <think> </think> tags. The final answer MUST BE put in \boxed{}.
\ No newline at end of file
diff --git a/examples/format_prompt/no_thinking.jinja b/examples/format_prompt/no_thinking.jinja
new file mode 100644
index 00000000000..39a137c9384
--- /dev/null
+++ b/examples/format_prompt/no_thinking.jinja
@@ -0,0 +1 @@
+{{ content }}You MUST provide the final answer directly without any extra information. Enclose the final answer in \boxed{}.
\ No newline at end of file
diff --git a/examples/generation/run_deepseek7b_mutli_node.sh b/examples/generation/run_deepseek7b_mutli_node.sh
old mode 100644
new mode 100755
diff --git a/examples/generation/run_deepseek_v2_lite_math.sh b/examples/generation/run_deepseek_v2_lite_math.sh
old mode 100644
new mode 100755
diff --git a/examples/gmpo_trainer/run_qwen2_5-7b_math.sh b/examples/gmpo_trainer/run_qwen2_5-7b_math.sh
old mode 100644
new mode 100755
diff --git a/examples/gmpo_trainer/test_dapo_7b_math.sh b/examples/gmpo_trainer/test_dapo_7b_math.sh
old mode 100644
new mode 100755
diff --git a/examples/gmpo_trainer/test_dapo_qwen3_30b_math.sh b/examples/gmpo_trainer/test_dapo_qwen3_30b_math.sh
old mode 100644
new mode 100755
diff --git a/examples/grpo_trainer/run_deepseek671b_math_megatron.sh b/examples/grpo_trainer/run_deepseek671b_math_megatron.sh
old mode 100644
new mode 100755
diff --git a/examples/grpo_trainer/run_deepseek7b_llm.sh b/examples/grpo_trainer/run_deepseek7b_llm.sh
old mode 100644
new mode 100755
diff --git a/examples/grpo_trainer/run_deepseek7b_llm_math.sh b/examples/grpo_trainer/run_deepseek7b_llm_math.sh
old mode 100644
new mode 100755
diff --git a/examples/grpo_trainer/run_deepseek7b_llm_math_megatron.sh b/examples/grpo_trainer/run_deepseek7b_llm_math_megatron.sh
old mode 100644
new mode 100755
diff --git a/examples/grpo_trainer/run_deepseek7b_llm_seq_balance.sh b/examples/grpo_trainer/run_deepseek7b_llm_seq_balance.sh
old mode 100644
new mode 100755
diff --git a/examples/grpo_trainer/run_minicpmo2_6.sh b/examples/grpo_trainer/run_minicpmo2_6.sh
old mode 100644
new mode 100755
diff --git a/examples/grpo_trainer/run_moonlight16b_math_megatron.sh b/examples/grpo_trainer/run_moonlight16b_math_megatron.sh
old mode 100644
new mode 100755
diff --git a/examples/grpo_trainer/run_qwen2-7b.sh b/examples/grpo_trainer/run_qwen2-7b.sh
old mode 100644
new mode 100755
diff --git a/examples/grpo_trainer/run_qwen2-7b_math.sh b/examples/grpo_trainer/run_qwen2-7b_math.sh
old mode 100644
new mode 100755
diff --git a/examples/grpo_trainer/run_qwen2-7b_math_megatron.sh b/examples/grpo_trainer/run_qwen2-7b_math_megatron.sh
old mode 100644
new mode 100755
diff --git a/examples/grpo_trainer/run_qwen2-7b_seq_balance.sh b/examples/grpo_trainer/run_qwen2-7b_seq_balance.sh
old mode 100644
new
mode 100755 diff --git a/examples/grpo_trainer/run_qwen2-7b_seq_balance_math_megatron.sh b/examples/grpo_trainer/run_qwen2-7b_seq_balance_math_megatron.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_qwen2-7b_sgl_megatron.sh b/examples/grpo_trainer/run_qwen2-7b_sgl_megatron.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_qwen2_5-3b_gsm8k_grpo_lora.sh b/examples/grpo_trainer/run_qwen2_5-3b_gsm8k_grpo_lora.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_qwen2_5-7b_math_megatron_diff_tp.sh b/examples/grpo_trainer/run_qwen2_5-7b_math_megatron_diff_tp.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_qwen2_5_32b_grpo_npu.sh b/examples/grpo_trainer/run_qwen2_5_32b_grpo_npu.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_qwen2_5_7b_grpo_discrete_prof_npu.sh b/examples/grpo_trainer/run_qwen2_5_7b_grpo_discrete_prof_npu.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_qwen2_5_7b_grpo_e2e_prof_npu.sh b/examples/grpo_trainer/run_qwen2_5_7b_grpo_e2e_prof_npu.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_qwen2_5_7b_grpo_npu.sh b/examples/grpo_trainer/run_qwen2_5_7b_grpo_npu.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_qwen2_5_vl-7b-megatron.sh b/examples/grpo_trainer/run_qwen2_5_vl-7b-megatron.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_qwen2_5_vl-7b.sh b/examples/grpo_trainer/run_qwen2_5_vl-7b.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_qwen2_5_vl-7b_climb.sh b/examples/grpo_trainer/run_qwen2_5_vl-7b_climb.sh new file mode 100755 index 00000000000..761abd09784 --- /dev/null +++ b/examples/grpo_trainer/run_qwen2_5_vl-7b_climb.sh @@ -0,0 +1,54 @@ +set -x +ENGINE=${1:-vllm} + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=/home/dvdai/orcd/scratch/high_modality/geom_train_upsampled_new.jsonl \ + data.val_files=/home/dvdai/orcd/scratch/high_modality/geom_valid_mini_new.jsonl \ + data.train_batch_size=512 \ + data.max_prompt_length=4096 \ + data.max_response_length=4096 \ + data.filter_overlong_prompts=False \ + data.truncation='left' \ + data.image_key=images \ + data.video_key=videos \ + data.prompt_key=problem \ + data.format_prompt=examples/format_prompt/default.jinja \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=False \ + actor_rollout_ref.actor.ppo_mini_batch_size=128 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=1e-8 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=$ENGINE \ + actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.enforce_eager=False \ + 
actor_rollout_ref.rollout.free_cache_engine=True \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + custom_reward_function.path=examples/reward_function/medical.py \ + custom_reward_function.name=medical_compute_score_batch \ + reward_model.reward_manager=batch \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_climb' \ + trainer.experiment_name='qwen2_5_vl_7b_function_rm' \ + trainer.n_gpus_per_node=4 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.val_before_train=False \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ diff --git a/examples/grpo_trainer/run_qwen2_5_vl-7b_lora.sh b/examples/grpo_trainer/run_qwen2_5_vl-7b_lora.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_qwen2_5_vl-7b_seq_balance.sh b/examples/grpo_trainer/run_qwen2_5_vl-7b_seq_balance.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_qwen2_5_vl_32b_npu.sh b/examples/grpo_trainer/run_qwen2_5_vl_32b_npu.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_qwen2_5_vl_3b_npu.sh b/examples/grpo_trainer/run_qwen2_5_vl_3b_npu.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_qwen2_5_vl_7b_npu.sh b/examples/grpo_trainer/run_qwen2_5_vl_7b_npu.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_qwen3-236b_megatron.sh b/examples/grpo_trainer/run_qwen3-236b_megatron.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_qwen3-8b.sh b/examples/grpo_trainer/run_qwen3-8b.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_qwen3moe-30b_megatron.sh b/examples/grpo_trainer/run_qwen3moe-30b_megatron.sh old mode 100644 new mode 100755 diff --git a/examples/ppo_trainer/run_deepseek7b_llm.sh b/examples/ppo_trainer/run_deepseek7b_llm.sh old mode 100644 new mode 100755 diff --git a/examples/ppo_trainer/run_deepseek7b_llm_modelscope.sh b/examples/ppo_trainer/run_deepseek7b_llm_modelscope.sh old mode 100644 new mode 100755 diff --git a/examples/ppo_trainer/run_deepseek7b_llm_pfppo.sh b/examples/ppo_trainer/run_deepseek7b_llm_pfppo.sh old mode 100644 new mode 100755 diff --git a/examples/ppo_trainer/run_deepseek7b_llm_sandbox_fusion.sh b/examples/ppo_trainer/run_deepseek7b_llm_sandbox_fusion.sh old mode 100644 new mode 100755 diff --git a/examples/ppo_trainer/run_deepseek7b_llm_sp2.sh b/examples/ppo_trainer/run_deepseek7b_llm_sp2.sh old mode 100644 new mode 100755 diff --git a/examples/ppo_trainer/run_deepseek_full_hh_rlhf.sh b/examples/ppo_trainer/run_deepseek_full_hh_rlhf.sh old mode 100644 new mode 100755 diff --git a/examples/ppo_trainer/run_deepseek_math_gsm8k_megatron.sh b/examples/ppo_trainer/run_deepseek_math_gsm8k_megatron.sh old mode 100644 new mode 100755 diff --git a/examples/ppo_trainer/run_deepseek_math_gsm8k_megatron_nsys.sh b/examples/ppo_trainer/run_deepseek_math_gsm8k_megatron_nsys.sh old mode 100644 new mode 100755 diff --git a/examples/ppo_trainer/run_gemma.sh b/examples/ppo_trainer/run_gemma.sh old mode 100644 new mode 100755 diff --git a/examples/ppo_trainer/run_moonlight16b_a3b_gsm8k_megatron.sh b/examples/ppo_trainer/run_moonlight16b_a3b_gsm8k_megatron.sh old mode 100644 new mode 100755 diff --git a/examples/ppo_trainer/run_qwen1.5_moe_a2.7b-gsm8k_megatron.sh b/examples/ppo_trainer/run_qwen1.5_moe_a2.7b-gsm8k_megatron.sh old mode 100644 new mode 100755 diff --git 
a/examples/ppo_trainer/run_qwen2-7b_math_gsm8k_megatron.sh b/examples/ppo_trainer/run_qwen2-7b_math_gsm8k_megatron.sh old mode 100644 new mode 100755 diff --git a/examples/ppo_trainer/run_qwen2-7b_rm.sh b/examples/ppo_trainer/run_qwen2-7b_rm.sh old mode 100644 new mode 100755 diff --git a/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance.sh b/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance.sh old mode 100644 new mode 100755 diff --git a/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance_fused_kernels.sh b/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance_fused_kernels.sh old mode 100644 new mode 100755 diff --git a/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance_nsys.sh b/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance_nsys.sh old mode 100644 new mode 100755 diff --git a/examples/ppo_trainer/run_qwen2-7b_seq_balance.sh b/examples/ppo_trainer/run_qwen2-7b_seq_balance.sh old mode 100644 new mode 100755 diff --git a/examples/ppo_trainer/run_qwen2-7b_sglang_seq_balance.sh b/examples/ppo_trainer/run_qwen2-7b_sglang_seq_balance.sh old mode 100644 new mode 100755 diff --git a/examples/ppo_trainer/run_qwen2.5-32b.sh b/examples/ppo_trainer/run_qwen2.5-32b.sh old mode 100644 new mode 100755 diff --git a/examples/reinforce_plus_plus_trainer/run_qwen2-7b_math_rf.sh b/examples/reinforce_plus_plus_trainer/run_qwen2-7b_math_rf.sh old mode 100644 new mode 100755 diff --git a/examples/reinforce_plus_plus_trainer/run_qwen2-7b_math_rf_baseline.sh b/examples/reinforce_plus_plus_trainer/run_qwen2-7b_math_rf_baseline.sh old mode 100644 new mode 100755 diff --git a/examples/remax_trainer/run_qwen2.5-3b_seq_balance.sh b/examples/remax_trainer/run_qwen2.5-3b_seq_balance.sh old mode 100644 new mode 100755 diff --git a/examples/remax_trainer/run_qwen2.5-7b_seq_balance.sh b/examples/remax_trainer/run_qwen2.5-7b_seq_balance.sh old mode 100644 new mode 100755 diff --git a/examples/reward_function/dapo.py b/examples/reward_function/dapo.py new file mode 100644 index 00000000000..9285cd1d0fd --- /dev/null +++ b/examples/reward_function/dapo.py @@ -0,0 +1,163 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
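+#
+# A worked example of the soft overlong penalty implemented below (illustrative
+# numbers, not taken from any config in this patch): with max_response_length=4096
+# and overlong_buffer_length=512, responses up to 3584 tokens incur no penalty,
+# a 3840-token response scores (3584 - 3840) / 512 = -0.5, and any response longer
+# than 4096 tokens receives the full -1.0 penalty, which is then scaled by
+# overlong_penalty_factor when added to the accuracy reward.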
+
+import re
+from typing import Any
+
+
+# Constants for normalization
+SUBSTITUTIONS = [
+    ("an ", ""),
+    ("a ", ""),
+    (".$", "$"),
+    ("\\$", ""),
+    (r"\ ", ""),
+    (" ", ""),
+    ("mbox", "text"),
+    (",\\text{and}", ","),
+    ("\\text{and}", ","),
+    ("\\text{m}", "\\text{}"),
+]
+
+REMOVED_EXPRESSIONS = [
+    "square",
+    "ways",
+    "integers",
+    "dollars",
+    "mph",
+    "inches",
+    "hours",
+    "km",
+    "units",
+    "\\ldots",
+    "sue",
+    "points",
+    "feet",
+    "minutes",
+    "digits",
+    "cents",
+    "degrees",
+    "cm",
+    "gm",
+    "pounds",
+    "meters",
+    "meals",
+    "edges",
+    "students",
+    "childrentickets",
+    "multiples",
+    "\\text{s}",
+    "\\text{.}",
+    "\\text{\ns}",
+    "\\text{}^2",
+    "\\text{}^3",
+    "\\text{\n}",
+    "\\text{}",
+    r"\mathrm{th}",
+    r"^\circ",
+    r"^{\circ}",
+    r"\;",
+    r",\!",
+    "{,}",
+    '"',
+    "\\dots",
+]
+
+
+def normalize_final_answer(final_answer: str) -> str:
+    """Normalize a final answer to a quantitative reasoning question.
+
+    Args:
+        final_answer: The answer string to normalize
+
+    Returns:
+        Normalized answer string
+    """
+    final_answer = final_answer.split("=")[-1]
+
+    # Apply substitutions and removals
+    for before, after in SUBSTITUTIONS:
+        final_answer = final_answer.replace(before, after)
+    for expr in REMOVED_EXPRESSIONS:
+        final_answer = final_answer.replace(expr, "")
+
+    # Extract and normalize LaTeX math
+    final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer)
+    final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer)
+    final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer)
+    final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer)
+    final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer)
+
+    # Normalize shorthand TeX:
+    #  \fracab -> \frac{a}{b}
+    #  \frac{abc}{bef} -> \frac{abc}{bef}
+    #  \fracabc -> \frac{a}{b}c
+    #  \sqrta -> \sqrt{a}
+    #  \sqrtab -> \sqrt{a}b
+    final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer)
+    final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer)
+    final_answer = final_answer.replace("$", "")
+
+    # Normalize numbers
+    if final_answer.replace(",", "").isdigit():
+        final_answer = final_answer.replace(",", "")
+
+    return final_answer.strip()
+
+
+def accuracy_reward(response: str, ground_truth: str) -> float:
+    match = re.findall(r"(?i)Answer\s*:\s*([^\n]+)", response)
+    answer = match[-1] if match else "[INVALID]"
+    if normalize_final_answer(answer) == normalize_final_answer(ground_truth):
+        return 1.0
+    else:
+        return -1.0
+
+
+def soft_overlong_punishment(response_length: int, max_response_length: int, overlong_buffer_length: int) -> float:
+    expected_len = max_response_length - overlong_buffer_length
+    if response_length <= expected_len:
+        return 0.0
+    elif response_length <= max_response_length:
+        return (expected_len - response_length) / overlong_buffer_length
+    else:
+        return -1.0
+
+
+def compute_score(
+    reward_inputs: list[dict[str, Any]],
+    max_response_length: int,
+    overlong_buffer_length: int,
+    overlong_penalty_factor: float,
+) -> list[dict[str, float]]:
+    if not isinstance(reward_inputs, list):
+        raise ValueError("Please use `reward_type=batch` for dapo reward function.")
+
+    scores = []
+    for reward_input in reward_inputs:
+        response = reward_input["response"][-300:]  # The longest answer in MATH-500 has 159 characters
+        accuracy_score = accuracy_reward(response, reward_input["ground_truth"])
+        overlong_score = soft_overlong_punishment(
+            reward_input["response_length"], max_response_length, overlong_buffer_length
+        )
+        scores.append(
+            {
+                "overall": accuracy_score + overlong_score * overlong_penalty_factor,
+                "accuracy": accuracy_score,
+                "overlong": overlong_score,
+                "accuracy_normalized": 0.5 * (accuracy_score + 1.0),
+            }
+        )
+
+    return scores
diff --git a/examples/reward_function/evaluation.py b/examples/reward_function/evaluation.py
new file mode 100644
index 00000000000..45ec549d862
--- /dev/null
+++ b/examples/reward_function/evaluation.py
@@ -0,0 +1,552 @@
+import datetime
+import json
+import os
+import re
+import statistics
+from collections import defaultdict
+from typing import Dict, List, Set
+
+
+def parse_conditions(text: str) -> Set[str]:
+    """
+    Parse medical conditions from text, handling various separators.
+
+    Args:
+        text (str): Text containing medical conditions.
+
+    Returns:
+        Set[str]: Set of individual medical conditions.
+    """
+    # Remove any boxing notation if present
+    text = text.replace("\\boxed{", "").replace("}", "")
+
+    # Split by common separators
+    for sep in [", ", " and ", " & ", ",", "&"]:
+        if sep in text:
+            return set(cond.strip() for cond in text.split(sep))
+
+    # If no separator found, treat as single condition
+    return {text.strip()}
+
+
+def extract_boxed_content(text: str) -> str:
+    """
+    Extract content within \boxed{} or similar boxing notations.
+
+    Args:
+        text (str): Text containing potentially boxed content.
+
+    Returns:
+        str: Extracted boxed content or the original text if no box found.
+    """
+    # Look for LaTeX \boxed{} notation
+    boxed_match = re.search(r"\\boxed{([^}]*)}", text)
+    if boxed_match:
+        return boxed_match.group(1)
+
+    # Look for markdown boxed notation (e.g., [boxed content])
+    markdown_match = re.search(r"\[(.*?)\]", text)
+    if markdown_match:
+        return markdown_match.group(1)
+
+    # Return the text as is if no boxed content is found
+    return text
+
+
+def compute_class_metrics(class_name: str, confusion_matrix: Dict[str, int]) -> Dict[str, float]:
+    """
+    Compute metrics for a single class based on its confusion matrix.
+
+    Args:
+        class_name (str): Name of the class.
+        confusion_matrix (Dict[str, int]): Confusion matrix with tp, fp, fn, tn.
+
+    Returns:
+        Dict[str, float]: Dictionary of metrics for this class.
+ """ + tp = confusion_matrix["tp"] + fp = confusion_matrix["fp"] + fn = confusion_matrix["fn"] + tn = confusion_matrix["tn"] + + # Calculate metrics (avoid division by zero) + precision = tp / (tp + fp) if (tp + fp) > 0 else 0 + recall = tp / (tp + fn) if (tp + fn) > 0 else 0 + sensitivity = recall # sensitivity is the same as recall + specificity = tn / (tn + fp) if (tn + fp) > 0 else 0 + f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0 + accuracy = (tp + tn) / (tp + tn + fp + fn) if (tp + tn + fp + fn) > 0 else 0 + + return { + "precision": precision, + "recall": recall, + "sensitivity": sensitivity, + "specificity": specificity, + "f1": f1, + "accuracy": accuracy, + "count": confusion_matrix["count"], + "confusion_matrix": {"tp": tp, "fp": fp, "fn": fn, "tn": tn}, + } + + +def gender(predictions: List[str], ground_truths: List[str], demographics: List[str]) -> Dict[str, float]: + groups = {"male": {"preds": [], "gts": []}, "female": {"preds": [], "gts": []}} + + for pred, gt, demo in zip(predictions, ground_truths, demographics): + if demo is not None and "female" in demo.lower(): + groups["female"]["preds"].append(pred) + groups["female"]["gts"].append(gt) + elif demo is not None and "male" in demo.lower(): + groups["male"]["preds"].append(pred) + groups["male"]["gts"].append(gt) + + results = {} + acc_values = [] + f1_values = [] + + for sex in ["male", "female"]: + preds = groups[sex]["preds"] + gts = groups[sex]["gts"] + if len(preds) == 0: + continue + metrics = compute_dataset_metrics(preds, gts)["dataset_metrics"] + acc = metrics["accuracy"] + f1 = metrics["f1"] + results[f"{sex}/accuracy"] = acc + results[f"{sex}/f1"] = f1 + acc_values.append(acc) + f1_values.append(f1) + print(f"{sex}: accuracy = {acc:.4f}, f1 = {f1:.4f}") + + if len(acc_values) >= 2: + acc_diff = abs(acc_values[0] - acc_values[1]) + results["acc_diff for sex"] = acc_diff + results["std_accuracy for sex"] = statistics.stdev(acc_values) + print(f"Accuracy max diff for sex = {acc_diff:.4f}") + print(f"std of accuracy for sex = {results['std_accuracy for sex']:.4f}") + + if len(f1_values) >= 2: + f1_diff = abs(f1_values[0] - f1_values[1]) + results["f1_diff for sex"] = f1_diff + results["std_f1 for sex"] = statistics.stdev(f1_values) + print(f"F1 max diff for sex = {f1_diff:.4f}") + print(f"std of f1 for sex = {results['std_f1 for sex']:.4f}") + + return results + + +def parent(predictions: List[str], ground_truths: List[str], demographics: List[str]) -> Dict[str, float]: + groups = {} + for pred, gt, demo in zip(predictions, ground_truths, demographics): + if demo is not None and "father" in demo.lower(): + if ( + demo.split("father:")[1].strip().split()[0] not in groups + and demo.split("father:")[1].strip().split()[0] != "NAN" + ): + groups[demo.split("father:")[1].strip().split()[0]] = {"preds": [], "gts": []} + groups[demo.split("father:")[1].strip().split()[0]]["preds"].append(pred) + groups[demo.split("father:")[1].strip().split()[0]]["gts"].append(gt) + else: + groups[demo.split("father:")[1].strip().split()[0]]["preds"].append(pred) + groups[demo.split("father:")[1].strip().split()[0]]["gts"].append(gt) + if demo is not None and "mother" in demo.lower(): + if ( + demo.split("mother:")[1].strip().split()[0] not in groups + and demo.split("mother:")[1].strip().split()[0] != "NAN" + ): + groups[demo.split("mother:")[1].strip().split()[0]] = {"preds": [], "gts": []} + groups[demo.split("mother:")[1].strip().split()[0]]["preds"].append(pred) + 
groups[demo.split("mother:")[1].strip().split()[0]]["gts"].append(gt) + else: + groups[demo.split("father:")[1].strip().split()[0]]["preds"].append(pred) + groups[demo.split("father:")[1].strip().split()[0]]["gts"].append(gt) + + results = {} + acc_values = [] + f1_values = [] + + for race in groups: + preds = groups[race]["preds"] + gts = groups[race]["gts"] + if len(preds) == 0: + continue + metrics = compute_dataset_metrics(preds, gts)["dataset_metrics"] + acc = metrics["accuracy"] + f1 = metrics["f1"] + results[f"{race}/accuracy"] = acc + results[f"{race}/f1"] = f1 + acc_values.append(acc) + f1_values.append(f1) + print(f"{race}: accuracy = {acc:.4f}, f1 = {f1:.4f}") + + if len(acc_values) >= 2: + acc_diff = max(acc_values) - min(acc_values) + results["acc_diff"] = acc_diff + print(f"Accuracy max diff for parent = {acc_diff:.4f}") + std_acc = statistics.stdev(acc_values) + results["std_accuracy"] = std_acc + print(f"std of accuracy for parent = {std_acc:.4f}") + + if len(f1_values) >= 2: + f1_diff = max(f1_values) - min(f1_values) + results["f1_diff"] = f1_diff + print(f"F1 max diff for parent = {f1_diff:.4f}") + std_f1 = statistics.stdev(f1_values) + results["std_f1"] = std_f1 + print(f"std of f1 for parent = {std_f1:.4f}") + + return results + + +def age(predictions: List[str], ground_truths: List[str], demographics: List[str]) -> Dict[str, float]: + groups = { + "a1": {"preds": [], "gts": []}, + "a2": {"preds": [], "gts": []}, + "a3": {"preds": [], "gts": []}, + "a4": {"preds": [], "gts": []}, + } + + for pred, gt, demo in zip(predictions, ground_truths, demographics): + if demo is not None and "age" in demo.lower(): + try: + age_str = demo.split("age:")[1].strip().split()[0].replace(",", "") + age_val = float(age_str) + except (IndexError, ValueError): + continue + + if age_val <= 25: + groups["a1"]["preds"].append(pred) + groups["a1"]["gts"].append(gt) + elif 25 < age_val <= 50: + groups["a2"]["preds"].append(pred) + groups["a2"]["gts"].append(gt) + elif 50 < age_val <= 75: + groups["a3"]["preds"].append(pred) + groups["a3"]["gts"].append(gt) + elif 75 < age_val: + groups["a4"]["preds"].append(pred) + groups["a4"]["gts"].append(gt) + + results = {} + acc_values = [] + f1_values = [] + + for group in ["a1", "a2", "a3", "a4"]: + preds = groups[group]["preds"] + gts = groups[group]["gts"] + if len(preds) == 0: + continue + metrics = compute_dataset_metrics(preds, gts)["dataset_metrics"] + acc = metrics["accuracy"] + f1 = metrics["f1"] + results[f"{group}/accuracy"] = acc + results[f"{group}/f1"] = f1 + acc_values.append(acc) + f1_values.append(f1) + + if len(acc_values) >= 2: + results["acc_diff"] = max(acc_values) - min(acc_values) + results["std_accuracy"] = statistics.stdev(acc_values) + + if len(f1_values) >= 2: + results["f1_diff"] = max(f1_values) - min(f1_values) + results["std_f1"] = statistics.stdev(f1_values) + + for group in ["a1", "a2", "a3", "a4"]: + acc = results.get(f"{group}/accuracy") + f1 = results.get(f"{group}/f1") + if acc is not None and f1 is not None: + print(f"{group}: accuracy = {acc:.4f}, f1 = {f1:.4f}") + + if "acc_diff" in results: + print(f"Accuracy max diff = {results['acc_diff']:.4f}") + print(f"std of accuracy for age = {results['std_accuracy']:.4f}") + if "f1_diff" in results: + print(f"F1 max diff = {results['f1_diff']:.4f}") + print(f"std of f1 for age = {results['std_f1']:.4f}") + + return results +def compute_confusion_matrices(predictions: List[str], ground_truths: List[str]) -> Dict[str, Dict[str, int]]: + """ + Compute confusion matrices for 
each class. + + Args: + predictions (List[str]): List of model predictions. + ground_truths (List[str]): List of ground truth labels. + + Returns: + Dict[str, Dict[str, int]]: Confusion matrices for each class. + """ + # Initialize counters for each condition + all_conditions = set() + condition_matrices = defaultdict(lambda: {"tp": 0, "fp": 0, "fn": 0, "tn": 0, "count": 0}) + + # First pass: identify all unique conditions + for gt in ground_truths: + gt_conditions = parse_conditions(gt) + all_conditions.update(gt_conditions) + + for pred in predictions: + pred_answer = extract_boxed_content(pred) + if pred_answer != "None": + pred_conditions = parse_conditions(pred_answer) + all_conditions.update(pred_conditions) + + # Second pass: compute confusion matrices + for pred, gt in zip(predictions, ground_truths): + pred_answer = extract_boxed_content(pred) + if pred_answer == "None": + pred_conditions = set() + else: + pred_conditions = parse_conditions(pred_answer) + + gt_conditions = parse_conditions(gt) + + # For each possible condition + for condition in all_conditions: + condition_present_in_gt = condition in gt_conditions + condition_present_in_pred = condition in pred_conditions + + if condition_present_in_gt: + condition_matrices[condition]["count"] += 1 + + if condition_present_in_gt and condition_present_in_pred: + # True positive + condition_matrices[condition]["tp"] += 1 + elif condition_present_in_gt and not condition_present_in_pred: + # False negative + condition_matrices[condition]["fn"] += 1 + elif not condition_present_in_gt and condition_present_in_pred: + # False positive + condition_matrices[condition]["fp"] += 1 + else: + # True negative + condition_matrices[condition]["tn"] += 1 + + return condition_matrices + + +def compute_dataset_metrics(predictions: List[str], ground_truths: List[str]) -> Dict[str, Dict]: + """ + Compute metrics for a single dataset, with class-wise averaging. + + Args: + predictions (List[str]): List of model predictions for this dataset. + ground_truths (List[str]): List of ground truth labels for this dataset. + + Returns: + Dict[str, Dict]: Class metrics and averaged dataset metrics. 
+ """ + # Compute confusion matrices for each class + class_matrices = compute_confusion_matrices(predictions, ground_truths) + + # Compute metrics for each class + class_metrics = {} + active_classes = 0 + + # Accumulators for dataset-level metrics + dataset_metrics = { + "precision": 0.0, + "recall": 0.0, + "sensitivity": 0.0, + "specificity": 0.0, + "f1": 0.0, + "accuracy": 0.0, + } + + # Compute metrics for each class and accumulate for dataset average + for class_name, matrix in class_matrices.items(): + # Skip classes that never appear in ground truth + if matrix["count"] == 0: + continue + + active_classes += 1 + metrics = compute_class_metrics(class_name, matrix) + class_metrics[class_name] = metrics + + # Accumulate for dataset average (equal class weighting) + for metric_name in dataset_metrics.keys(): + dataset_metrics[metric_name] += metrics[metric_name] + + # Calculate dataset average (equal class weighting) + if active_classes > 0: + for metric_name in dataset_metrics.keys(): + dataset_metrics[metric_name] /= active_classes + + # Add class metrics to the result + result = {"class_metrics": class_metrics, "dataset_metrics": dataset_metrics, "active_classes": active_classes} + + return result + + +def compute_metrics_by_data_source( + predictions: List[str], + ground_truths: List[str], + data_sources: List[str], + datasets: List[str], + demographics: List[str], +) -> Dict[str, float]: + """ + Compute hierarchical metrics: class -> dataset -> data source -> global. + + Args: + predictions (List[str]): List of model predictions. + ground_truths (List[str]): List of ground truth labels. + data_sources (List[str]): List of data sources for each example. + datasets (List[str]): List of dataset identifiers for each example. + demographics (List[str]): List of demographic information for each example. 
+ + Returns: + Dict[str, float]: Flattened dictionary of metrics at all levels with keys: + - "val/{metric}" for global metrics + - "{data_source}/{metric}" for data source metrics + - "{data_source}/{dataset}/{metric}" for dataset metrics + """ + # Save inputs to json for debugging under outputs/ + + output_dir = "outputs" + os.makedirs(output_dir, exist_ok=True) + input_data = { + "predictions": predictions, + "ground_truths": ground_truths, + "data_sources": data_sources, + "datasets": datasets, + "demographics": demographics, + } + # name is time in yyyy-mm-dd_hh-mm-ss format + with open( + os.path.join(output_dir, f"input_data_{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.json"), "w" + ) as f: + json.dump(input_data, f, indent=4) + + # Group examples by data source and dataset + grouped_data = defaultdict(lambda: defaultdict(lambda: {"preds": [], "gts": []})) + + for pred, gt, source, dataset in zip(predictions, ground_truths, data_sources, datasets): + grouped_data[source][dataset]["preds"].append(pred) + grouped_data[source][dataset]["gts"].append(gt) + + # Initialize the flattened result dictionary + result = {} + + # Initialize global metrics accumulators + global_metrics = { + "precision": 0.0, + "recall": 0.0, + "sensitivity": 0.0, + "specificity": 0.0, + "f1": 0.0, + "accuracy": 0.0, + } + + # Compute metrics for each dataset within each data source + total_data_sources = 0 + + for source_name, source_datasets in grouped_data.items(): + # Initialize metrics accumulators for this data source + source_metrics = { + "precision": 0.0, + "recall": 0.0, + "sensitivity": 0.0, + "specificity": 0.0, + "f1": 0.0, + "accuracy": 0.0, + } + + total_datasets_in_source = 0 + + for dataset_name, dataset_data in source_datasets.items(): + # Compute metrics for this dataset + dataset_result = compute_dataset_metrics(dataset_data["preds"], dataset_data["gts"]) + + # Store dataset-level metrics with the format "data_source/dataset/metric" + for metric_name, metric_value in dataset_result["dataset_metrics"].items(): + result[f"{source_name}/{dataset_name}/{metric_name}"] = metric_value + + # Skip empty datasets + if dataset_result["active_classes"] == 0: + continue + + total_datasets_in_source += 1 + + # Accumulate metrics for data source average (equal dataset weighting) + for metric_name in source_metrics.keys(): + source_metrics[metric_name] += dataset_result["dataset_metrics"][metric_name] + + # Calculate data source average (equal dataset weighting) + if total_datasets_in_source > 0: + for metric_name in source_metrics.keys(): + source_metrics[metric_name] /= total_datasets_in_source + + # Store data source metrics with the format "data_source/metric" + for metric_name, metric_value in source_metrics.items(): + result[f"{source_name}/{metric_name}"] = metric_value + + total_data_sources += 1 + + # Accumulate for global metrics (equal data source weighting) + for metric_name in global_metrics.keys(): + global_metrics[metric_name] += source_metrics[metric_name] + + # Calculate global average (equal data source weighting) + if total_data_sources > 0: + for metric_name in global_metrics.keys(): + global_metrics[metric_name] /= total_data_sources + + # Store global metrics with the format "val/metric" + for metric_name, metric_value in global_metrics.items(): + result[f"val/{metric_name}"] = metric_value + + gender_results = gender(predictions, ground_truths, demographics) + for k, v in gender_results.items(): + result[f"fairness/gender/{k}"] = v + + age_results = age(predictions, 
ground_truths, demographics) + for k, v in age_results.items(): + result[f"fairness/age/{k}"] = v + + parent_results = parent(predictions, ground_truths, demographics) + for k, v in parent_results.items(): + result[f"fairness/parent/{k}"] = v + + + std_acc_values = [] + std_f1_values = [] + try: + + std_acc_values.append(gender_results["std_accuracy for sex"]) + std_f1_values.append(gender_results["std_f1 for sex"]) + + + std_acc_values.append(age_results["std_accuracy"]) + std_f1_values.append(age_results["std_f1"]) + + std_acc_values.append(parent_results["std_accuracy"]) + std_f1_values.append(parent_results["std_f1"]) + + result["fairness/avg_std_accuracy"] = sum(std_acc_values) / len(std_acc_values) + result["fairness/avg_std_f1"] = sum(std_f1_values) / len(std_f1_values) + except KeyError: + print("Some fairness metrics do not have standard deviation values, skipping average calculation.") + + return result + + +if __name__ == "__main__": + outputs_dir = "../../outputs" + output_files = [f for f in os.listdir(outputs_dir) if f.startswith("input_data_") and f.endswith(".json")] + if not output_files: + print("No output files found in the outputs directory.") + else: + latest_file = max(output_files, key=lambda f: os.path.getmtime(os.path.join(outputs_dir, f))) + with open(os.path.join(outputs_dir, latest_file), "r") as f: + input_data = json.load(f) + + predictions = input_data["predictions"] + ground_truths = input_data["ground_truths"] + data_sources = input_data["data_sources"] + datasets = input_data["datasets"] + demographics = input_data["demographics"] + + metrics = compute_metrics_by_data_source(predictions, ground_truths, data_sources, datasets, demographics) + print(json.dumps(metrics, indent=4)) \ No newline at end of file diff --git a/examples/reward_function/math.py b/examples/reward_function/math.py new file mode 100644 index 00000000000..ea75e3e91b5 --- /dev/null +++ b/examples/reward_function/math.py @@ -0,0 +1,49 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
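+#
+# Example of a response that earns both rewards below (illustrative, assuming the
+# default <think> format prompt from examples/format_prompt/default.jinja):
+#   "<think>2 + 2 = 4</think> The final answer is \boxed{4}"
+# format_reward requires the <think>...</think>...\boxed{} structure to fully
+# match; accuracy_reward grades the boxed content against the ground truth via
+# mathruler's grade_answer.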
+
+import re
+from typing import Any
+
+from mathruler.grader import extract_boxed_content, grade_answer
+
+
+def format_reward(response: str) -> float:
+    pattern = re.compile(r"<think>.*</think>.*\\boxed\{.*\}.*", re.DOTALL)
+    format_match = re.fullmatch(pattern, response)
+    return 1.0 if format_match else 0.0
+
+
+def accuracy_reward(response: str, ground_truth: str) -> float:
+    answer = extract_boxed_content(response)
+    return 1.0 if grade_answer(answer, ground_truth) else 0.0
+
+
+def compute_score(reward_inputs: list[dict[str, Any]], format_weight: float = 0.1) -> list[dict[str, float]]:
+    if not isinstance(reward_inputs, list):
+        raise ValueError("Please use `reward_type=batch` for math reward function.")
+
+    scores = []
+    for reward_input in reward_inputs:
+        response = re.sub(r"\s*(<|>|/)\s*", r"\1", reward_input["response"])  # handle qwen2.5vl-32b format
+        format_score = format_reward(response)
+        accuracy_score = accuracy_reward(response, reward_input["ground_truth"])
+        scores.append(
+            {
+                "overall": (1 - format_weight) * accuracy_score + format_weight * format_score,
+                "format": format_score,
+                "accuracy": accuracy_score,
+            }
+        )
+
+    return scores
diff --git a/examples/reward_function/medical.py b/examples/reward_function/medical.py
new file mode 100644
index 00000000000..aeeac05b019
--- /dev/null
+++ b/examples/reward_function/medical.py
@@ -0,0 +1,460 @@
+import re
+import json
+from typing import Dict, List
+
+import numpy
+import torch
+import numpy as np
+from mathruler.grader import extract_boxed_content
+import wandb
+import random
+
+
+def parse_conditions(text):
+    # Remove any boxing notation if present
+    text = text.replace("\\boxed{", "").replace("}", "")
+
+    # Split by common separators
+    for sep in [", ", " and ", " & ", ",", "&"]:
+        if sep in text:
+            return set(cond.strip() for cond in text.split(sep))
+
+    # If no separator found, treat as single condition
+    return {text.strip()}
+
+
+def parse_json(json_output):
+    """
+    Parse out the markdown fencing from JSON code blocks.
+    """
+    # Look for content between ```json and ```
+    lines = json_output.splitlines()
+    for i, line in enumerate(lines):
+        if line == "```json" or line.strip() == "```":
+            json_output = "\n".join(lines[i + 1:])  # Remove everything up to and including ```json
+            if "```" in json_output:
+                json_output = json_output.split("```")[0]  # Remove everything after the closing ```
+            break  # Exit the loop once a code block marker is found
+    return json_output
+
+
+def extract_json_from_response(text):
+    """
+    Extract JSON content from markdown code blocks in the response.
+
+    Args:
+        text: The model's response text
+
+    Returns:
+        Parsed JSON object or None if no valid JSON found
+    """
+    # Find content between ```json and ```
+    json_pattern = r"```(?:json)?\s*([\s\S]*?)```"
+    matches = re.findall(json_pattern, text)
+
+    if not matches:
+        return None
+
+    # Try to parse each match as JSON
+    for match in matches:
+        try:
+            parsed_json = json.loads(match.strip())
+            return parsed_json
+        except json.JSONDecodeError:
+            continue
+
+    # If we couldn't parse any match as valid JSON, try with ast.literal_eval
+    import ast
+    for match in matches:
+        try:
+            # Clean up the match a bit
+            cleaned = match.strip().replace("'", "\"")
+            parsed_json = ast.literal_eval(cleaned)
+            return parsed_json
+        except (ValueError, SyntaxError):
+            continue
+
+    return None
+
+
+def bbox_to_mask(bbox, height, width):
+    """
+    Convert bounding box to binary mask.
+
+    Args:
+        bbox: Bounding box in format [x1, y1, x2, y2]
+        height: Height of the mask
+        width: Width of the mask
+
+    Returns:
+        Binary mask of shape (height, width)
+    """
+    mask = torch.zeros((height, width), dtype=torch.float32)
+
+    # Ensure bbox coordinates are within image boundaries
+    x1 = max(0, min(int(bbox[0]), width - 1))
+    y1 = max(0, min(int(bbox[1]), height - 1))
+    x2 = max(0, min(int(bbox[2]), width - 1))
+    y2 = max(0, min(int(bbox[3]), height - 1))
+
+    # Handle cases where x1>x2 or y1>y2
+    if x1 > x2:
+        x1, x2 = x2, x1
+    if y1 > y2:
+        y1, y2 = y2, y1
+
+    # Set the box region to 1
+    if x1 < x2 and y1 < y2:  # Ensure valid box dimensions
+        mask[y1:y2 + 1, x1:x2 + 1] = 1.0
+
+    return mask
+
+
+def calculate_bbox_iou(pred_bboxes, seg_mask=None, gt_bbox=None):
+    """
+    Calculate IoU between predicted bounding boxes and ground truth (segmentation mask or bbox).
+
+    Args:
+        pred_bboxes: List of predicted bounding boxes in format [x1, y1, x2, y2]
+        seg_mask: Ground truth segmentation mask tensor
+        gt_bbox: Ground truth bounding box in format [x1, y1, x2, y2]
+
+    Returns:
+        Mean IoU score across all bounding boxes
+    """
+    if not pred_bboxes:
+        return 0.0
+
+    # If single layer bbox, wrap it in a list
+    if not isinstance(pred_bboxes[0], list):
+        pred_bboxes = [pred_bboxes]
+
+    if seg_mask is not None and isinstance(seg_mask, numpy.ndarray):
+        seg_mask = torch.from_numpy(seg_mask)
+
+    # Not none and not all zero
+    if seg_mask is not None and torch.sum(seg_mask) > 0:
+        # Get mask dimensions
+        if len(seg_mask.shape) == 3:  # Channel dimension
+            height, width = seg_mask.shape[1], seg_mask.shape[2]
+        else:
+            height, width = seg_mask.shape[0], seg_mask.shape[1]
+
+        # Convert segmentation mask to binary (1 for any positive value)
+        binary_seg_mask = (seg_mask > 0).float()
+
+        total_iou = 0.0
+        for bbox in pred_bboxes:
+            if len(bbox) < 4:
+                continue
+            # Convert bbox to mask
+            try:
+                bbox_mask = bbox_to_mask(bbox, height, width)
+            except Exception:
+                continue
+
+            # Calculate intersection and union
+            intersection = torch.sum(bbox_mask * binary_seg_mask)
+            union = torch.sum(torch.clamp(bbox_mask + binary_seg_mask, 0, 1))
+
+            # Calculate IoU (as a Python float so the running total stays numeric)
+            iou = (intersection / union).item() if union > 0 else 0.0
+            total_iou += iou
+
+        # Return mean IoU
+        return total_iou / len(pred_bboxes)
+
+    elif gt_bbox is not None:
+        # Convert the ground-truth box to a plain list once, outside the loop
+        # (calling .tolist() inside the loop would fail on the second
+        # iteration, after gt_bbox has already become a list)
+        if hasattr(gt_bbox, "tolist"):
+            gt_bbox = gt_bbox.tolist()
+
+        # Calculate IoU directly between bounding boxes
+        total_iou = 0.0
+        for pred_bbox in pred_bboxes:
+            if len(pred_bbox) < 4:
+                continue
+            # Calculate intersection
+            x1 = max(pred_bbox[0], gt_bbox[0])
+            y1 = max(pred_bbox[1], gt_bbox[1])
+            x2 = min(pred_bbox[2], gt_bbox[2])
+            y2 = min(pred_bbox[3], gt_bbox[3])
+
+            # Check if boxes overlap
+            if x1 >= x2 or y1 >= y2:
+                iou = 0.0
+            else:
+                # Calculate areas
+                intersection = (x2 - x1) * (y2 - y1)
+                pred_area = (pred_bbox[2] - pred_bbox[0]) * (pred_bbox[3] - pred_bbox[1])
+                gt_area = (gt_bbox[2] - gt_bbox[0]) * (gt_bbox[3] - gt_bbox[1])
+                union = pred_area + gt_area - intersection
+
+                # Calculate IoU
+                iou = intersection / union if union > 0 else 0.0
+
+            total_iou += iou
+
+        # Return mean IoU
+        return total_iou / len(pred_bboxes)
+
+    else:
+        # Neither segmentation mask nor ground truth bbox provided
+        return 0.0
+
+
+def evaluate_bbox_format(predict_str):
+    """
+    Evaluate the format correctness of the bounding box JSON in the response.
+    Returns a score based on how well the response follows the expected format.
+
+    Args:
+        predict_str: The model's prediction string
+
+    Returns:
+        Format score between 0.0 and 1.0
+    """
+    format_score = 0.0
+
+    # Check if response contains a code block
+    if "```" in predict_str:
+        format_score += 0.2  # 20% for having a code block
+
+        # Check if it's specifically marked as JSON
+        if "```json" in predict_str:
+            format_score += 0.1  # Additional 10% for correct JSON marker
+
+    # Try to extract and parse JSON
+    json_str = parse_json(predict_str)
+    if not json_str:
+        return format_score  # Failed to find JSON content
+
+    try:
+        # Try to parse as JSON
+        parsed_json = None
+        try:
+            parsed_json = json.loads(json_str)
+            format_score += 0.2  # Additional 20% for valid JSON
+        except json.JSONDecodeError:
+            # Try with ast.literal_eval as fallback
+            try:
+                cleaned = json_str.replace("'", "\"")
+                parsed_json = ast.literal_eval(cleaned)
+                format_score += 0.1  # Only 10% for requiring fallback parsing
+            except (ValueError, SyntaxError):
+                return format_score  # Failed to parse
+
+        # Check if it's a list of objects
+        if not isinstance(parsed_json, list):
+            return format_score
+
+        format_score += 0.1  # Additional 10% for being a list
+
+        # Check each item for proper bbox structure
+        valid_items = 0
+        total_items = len(parsed_json)
+
+        for item in parsed_json:
+            if not isinstance(item, dict):
+                continue
+
+            # Check for required fields
+            has_bbox = "bbox_2d" in item
+            has_label = "label" in item
+
+            if has_bbox and has_label:
+                bbox = item["bbox_2d"]
+                # Check bbox format [x1, y1, x2, y2]
+                if (isinstance(bbox, list) and len(bbox) == 4 and
+                        all(isinstance(coord, (int, float)) for coord in bbox)):
+                    valid_items += 1
+
+        # Add up to 40% based on proportion of valid items
+        if total_items > 0:
+            format_score += 0.4 * (valid_items / total_items)
+
+    except Exception:
+        # Any other parsing issues
+        pass
+
+    return format_score
+
+
+def medical_compute_score(predict_str: str, ground_truth: str, segmentation_mask=None, bbox=None) -> Dict[str, float]:
+    """
+    Compute medical scoring including standard score, bounding box IoU, and format score.
+
+    Args:
+        predict_str: The model's prediction string
+        ground_truth: The ground truth string
+        segmentation_mask: Ground truth segmentation mask tensor
+        bbox: Ground truth bounding box
+
+    Returns:
+        Dict with the weighted "overall" score plus the individual standard
+        (F1), IoU, and format component scores
+    """
+    # Calculate standard score
+    answer = extract_boxed_content(predict_str)
+    if answer == "None":
+        standard_score = 0.0  # no answer
+    else:
+        # Parse both prediction and ground truth into sets of conditions
+        predicted_conditions = parse_conditions(answer)
+        ground_truth_conditions = parse_conditions(ground_truth)
+
+        # Calculate true positives, false positives, and false negatives
+        true_positives = len(predicted_conditions.intersection(ground_truth_conditions))
+        false_positives = len(predicted_conditions - ground_truth_conditions)
+        false_negatives = len(ground_truth_conditions - predicted_conditions)
+
+        # Calculate F1 score components
+        precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
+        recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
+
+        # Calculate F1 score (harmonic mean of precision and recall)
+        standard_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
+
+    # Calculate format score (how well the JSON follows the expected format)
+    format_score = evaluate_bbox_format(predict_str)
+
+    # Length score: 0.001 per character, flat 1.0 above 600 characters
+    # (roughly 100+ words)
+    if len(predict_str) > 600:
+        length_score = 1
+    else:
+        length_score = len(predict_str) * 0.001
+
+    # Calculate bounding box IoU score
+    iou_score = 0.0
+    # Extract predicted bounding boxes from the response
+    json_data = extract_json_from_response(predict_str)
+    if json_data:
+        # Extract bounding boxes from the JSON
+        try:
+            pred_bboxes = []
+            if isinstance(json_data, list):
+                for item in json_data:
+                    if isinstance(item, dict) and "bbox_2d" in item:
+                        pred_bboxes.append(item["bbox_2d"])
+            elif isinstance(json_data, dict) and "bbox_2d" in json_data:
+                pred_bboxes.append(json_data["bbox_2d"])
+            elif isinstance(json_data, dict) and "objects_of_interest" in json_data:
+                for item in json_data["objects_of_interest"]:
+                    if isinstance(item, dict) and "bbox_2d" in item:
+                        pred_bboxes.append(item["bbox_2d"])
+
+            if random.random() < 0.0005:  # log roughly 0.05% of samples
+                print("[Bounding Box] ", json_data)
+                print("[Formatted Bounding Box] ", pred_bboxes)
+                print("[GT Bounding Box] ", bbox)
+
+            # Calculate IoU between predicted boxes and ground truth
+            if pred_bboxes:
+                iou_score = calculate_bbox_iou(pred_bboxes, segmentation_mask, bbox)
+        except Exception:
+            pass
+
+    scores = {
+        "overall": 0.6 * standard_score + 0.2 * iou_score + 0.1 * format_score + 0.1 * length_score,
+        "standard_score": standard_score,
+        "iou_score": iou_score,
+        "format_score": format_score,
+    }
+    return scores
+
+
+def medical_compute_score_batch(
+    data_sources: List[str],
+    solution_strs: List[str],
+    ground_truths: List[str],
+    extra_infos: List[dict],
+    **kwargs,
+) -> List[Dict[str, float]]:
+    """
+    Compute medical scoring for batch inputs including standard score, bounding box IoU, and format score.
+
+    Args:
+        data_sources: List of data sources (e.g., file paths or identifiers)
+        solution_strs: List of model prediction strings
+        ground_truths: List of ground truth strings
+        extra_infos: List of per-sample extra information dicts (may carry
+            segmentation masks and bounding boxes)
+
+    Returns:
+        List of score dictionaries
+    """
+    batch_scores = []
+
+    for data_source, predict_str, ground_truth, extra_info in zip(data_sources, solution_strs, ground_truths, extra_infos):
+        # Pull the optional ground-truth localization targets out of extra_info
+        # (assumed to be stored under "segmentation_mask" / "bbox" when
+        # present); without them the IoU term falls back to 0.
+        segmentation_mask = None
+        bbox = None
+        if isinstance(extra_info, dict):
+            segmentation_mask = extra_info.get("segmentation_mask")
+            bbox = extra_info.get("bbox")
+
+        # Calculate standard score
+        answer = extract_boxed_content(predict_str)
+        if answer == "None":
+            standard_score = 0.0  # no answer
+        else:
+            # Parse both prediction and ground truth into sets of conditions
+            predicted_conditions = parse_conditions(answer)
+            ground_truth_conditions = parse_conditions(ground_truth)
+
+            # Calculate true positives, false positives, and false negatives
+            true_positives = len(predicted_conditions.intersection(ground_truth_conditions))
+            false_positives = len(predicted_conditions - ground_truth_conditions)
+            false_negatives = len(ground_truth_conditions - predicted_conditions)
+
+            # Calculate F1 score components
+            precision = (
+                true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
+            )
+            recall = (
+                true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
+            )
+
+            # Calculate F1 score (harmonic mean of precision and recall)
+            standard_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
+
+        # Calculate format score (how well the JSON follows the expected format)
+        format_score = evaluate_bbox_format(predict_str)
+
+        # Length score: 0.001 per character, flat 1.0 above 600 characters
+        # (roughly 100+ words)
+        if len(predict_str) > 600:
+            length_score = 1
+        else:
+            length_score = len(predict_str) * 0.001
+
+        # Calculate bounding box IoU score
+        iou_score = 0.0
+        # Extract predicted bounding boxes from the response
+        json_data = extract_json_from_response(predict_str)
+        if json_data:
+            # Extract bounding boxes from the JSON
+            try:
+                pred_bboxes = []
+                if isinstance(json_data, list):
+                    for item in json_data:
+                        if isinstance(item, dict) and "bbox_2d" in item:
+                            pred_bboxes.append(item["bbox_2d"])
+                elif isinstance(json_data, dict) and "bbox_2d" in json_data:
+                    pred_bboxes.append(json_data["bbox_2d"])
+                elif isinstance(json_data, dict) and "objects_of_interest" in json_data:
+                    for item in json_data["objects_of_interest"]:
+                        if isinstance(item, dict) and "bbox_2d" in item:
+                            pred_bboxes.append(item["bbox_2d"])
+
+                if random.random() < 0.005:  # log roughly 0.5% of samples
+                    print("[Bounding Box] ", json_data)
+                    print("[Formatted Bounding Box] ", pred_bboxes)
+                    print("[GT Bounding Box] ", bbox)
+
+                # Calculate IoU between predicted boxes and ground truth
+                if pred_bboxes:
+                    iou_score = calculate_bbox_iou(pred_bboxes, segmentation_mask, bbox)
+            except Exception:
+                pass
+
+        # Note: unlike the single-sample scorer, length_score is reported but
+        # not folded into the weighted "score".
+        scores = {
+            "score": 0.5 * standard_score + 0.3 * iou_score + 0.1 * format_score,
+            "standard_score": standard_score,
+            "iou_score": iou_score,
+            "format_score": format_score,
+            "length_score": length_score,
+        }
+        batch_scores.append(scores)
+
+    return batch_scores
\ No newline at end of file
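For orientation, here is a minimal usage sketch of the single-sample scorer above. The inputs are made up for illustration; it assumes the repo root is on PYTHONPATH, and segmentation_mask/bbox default to None, in which case the IoU term is 0:

    from examples.reward_function.medical import medical_compute_score

    prediction = (
        "The findings suggest \\boxed{pneumonia, edema}.\n"
        '```json\n[{"bbox_2d": [10, 20, 60, 80], "label": "opacity"}]\n```'
    )
    scores = medical_compute_score(prediction, ground_truth="pneumonia")
    # standard_score is the F1 between the boxed condition set and the ground
    # truth (2/3 here); format_score rewards the well-formed JSON block.
    print(scores["overall"], scores["standard_score"], scores["format_score"])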
diff --git a/examples/reward_function/r1v.py b/examples/reward_function/r1v.py
new file mode 100644
index 00000000000..6a28548b292
--- /dev/null
+++ b/examples/reward_function/r1v.py
@@ -0,0 +1,50 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from typing import Any
+
+from mathruler.grader import grade_answer
+
+
+def format_reward(response: str) -> float:
+    pattern = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL)
+    format_match = re.fullmatch(pattern, response)
+    return 1.0 if format_match else 0.0
+
+
+def accuracy_reward(response: str, ground_truth: str) -> float:
+    try:
+        content_match = re.search(r"<answer>(.*?)</answer>", response)
+        given_answer = content_match.group(1).strip() if content_match else response.strip()
+        if grade_answer(given_answer, ground_truth.strip()):
+            return 1.0
+
+    except Exception:
+        pass
+
+    return 0.0
+
+
+def compute_score(reward_input: dict[str, Any], format_weight: float = 0.5) -> dict[str, float]:
+    if not isinstance(reward_input, dict):
+        raise ValueError("Please use `reward_type=sequential` for r1v reward function.")
+
+    format_score = format_reward(reward_input["response"])
+    accuracy_score = accuracy_reward(reward_input["response"], reward_input["ground_truth"])
+    return {
+        "overall": (1 - format_weight) * accuracy_score + format_weight * format_score,
+        "format": format_score,
+        "accuracy": accuracy_score,
+    }
diff --git a/examples/rloo_trainer/run_qwen2-7b.sh b/examples/rloo_trainer/run_qwen2-7b.sh
old mode 100644
new mode 100755
diff --git a/examples/sft/gsm8k/run_deepseek_6b7.sh b/examples/sft/gsm8k/run_deepseek_6b7.sh
old mode 100644
new mode 100755
diff --git a/examples/sft/gsm8k/run_gemma_2b.sh b/examples/sft/gsm8k/run_gemma_2b.sh
old mode 100644
new mode 100755
diff --git a/examples/sft/gsm8k/run_gemma_7b.sh b/examples/sft/gsm8k/run_gemma_7b.sh
old mode 100644
new mode 100755
diff --git a/examples/sft/gsm8k/run_qwen2_5_05b_sft_peft_sp2_npu.sh b/examples/sft/gsm8k/run_qwen2_5_05b_sft_peft_sp2_npu.sh
old mode 100644
new mode 100755
diff --git a/examples/sft/gsm8k/run_qwen_05_peft.sh b/examples/sft/gsm8k/run_qwen_05_peft.sh
old mode 100644
new mode 100755
diff --git a/examples/sft/gsm8k/run_qwen_05_sp2.sh b/examples/sft/gsm8k/run_qwen_05_sp2.sh
old mode 100644
new mode 100755
diff --git a/examples/sft/gsm8k/run_qwen_05_sp2_liger.sh b/examples/sft/gsm8k/run_qwen_05_sp2_liger.sh
old mode 100644
new mode 100755
diff --git a/examples/sft/multiturn/run_qwen_05_sp2.sh b/examples/sft/multiturn/run_qwen_05_sp2.sh
old mode 100644
new mode 100755
diff --git a/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_geo3k_multiturn.sh b/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_geo3k_multiturn.sh
old mode 100644
new mode 100755
diff --git a/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_geo3k_multiturn_4xgpu.sh b/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_geo3k_multiturn_4xgpu.sh
old mode 100644
new mode 100755
diff --git a/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_megatron_geo3k_multiturn.sh b/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_megatron_geo3k_multiturn.sh
old mode 100644
new mode 100755
diff --git a/examples/sglang_multiturn/run_qwen2.5-0.5b_gsm8k_multiturn_w_interaction.sh
b/examples/sglang_multiturn/run_qwen2.5-0.5b_gsm8k_multiturn_w_interaction.sh old mode 100644 new mode 100755 diff --git a/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn.sh b/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn.sh old mode 100644 new mode 100755 diff --git a/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn_4xgpu.sh b/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn_4xgpu.sh old mode 100644 new mode 100755 diff --git a/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_tool_agent_mlflow.sh b/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_tool_agent_mlflow.sh old mode 100644 new mode 100755 diff --git a/examples/sglang_multiturn/run_qwen2.5-3b_megatron_gsm8k_multiturn.sh b/examples/sglang_multiturn/run_qwen2.5-3b_megatron_gsm8k_multiturn.sh old mode 100644 new mode 100755 diff --git a/examples/sglang_multiturn/run_qwen2_3b_dapo_multiturn.sh b/examples/sglang_multiturn/run_qwen2_3b_dapo_multiturn.sh old mode 100644 new mode 100755 diff --git a/examples/sglang_multiturn/search_r1_like/run_qwen2.5-3b_instruct_search_multiturn.sh b/examples/sglang_multiturn/search_r1_like/run_qwen2.5-3b_instruct_search_multiturn.sh old mode 100644 new mode 100755 diff --git a/examples/split_placement/run_deepseek7b_llm.sh b/examples/split_placement/run_deepseek7b_llm.sh old mode 100644 new mode 100755 diff --git a/examples/tuning/0.5b/qwen2-0.5b_grpo-lora_1_h100_fsdp_vllm.sh b/examples/tuning/0.5b/qwen2-0.5b_grpo-lora_1_h100_fsdp_vllm.sh old mode 100644 new mode 100755 diff --git a/examples/tuning/1.5b/qwen2-1.5b_grpo-lora_1_h100_fsdp_vllm.sh b/examples/tuning/1.5b/qwen2-1.5b_grpo-lora_1_h100_fsdp_vllm.sh old mode 100644 new mode 100755 diff --git a/examples/tuning/14b/qwen2-14b_grpo-lora_2_h100_fsdp_vllm.sh b/examples/tuning/14b/qwen2-14b_grpo-lora_2_h100_fsdp_vllm.sh old mode 100644 new mode 100755 diff --git a/examples/tuning/14b/qwen2_14b_grpo_4_h800_fsdp_vllm.sh b/examples/tuning/14b/qwen2_14b_grpo_4_h800_fsdp_vllm.sh old mode 100644 new mode 100755 diff --git a/examples/tuning/32b/qwen2-32b_grpo-lora_4_h100_fsdp_vllm.sh b/examples/tuning/32b/qwen2-32b_grpo-lora_4_h100_fsdp_vllm.sh old mode 100644 new mode 100755 diff --git a/examples/tuning/32b/qwen2_32B_grpo_8_h20_megatron_vllm.sh b/examples/tuning/32b/qwen2_32B_grpo_8_h20_megatron_vllm.sh old mode 100644 new mode 100755 diff --git a/examples/tuning/3b/qwen2-3b_grpo-lora_1_h100_fsdp_vllm.sh b/examples/tuning/3b/qwen2-3b_grpo-lora_1_h100_fsdp_vllm.sh old mode 100644 new mode 100755 diff --git a/examples/tuning/70b/qwen2-70b_grpo_32_h20_fsdp_vllm.sh b/examples/tuning/70b/qwen2-70b_grpo_32_h20_fsdp_vllm.sh old mode 100644 new mode 100755 diff --git a/examples/tuning/70b/qwen2-70b_grpo_32_h800_fsdp_vllm.sh b/examples/tuning/70b/qwen2-70b_grpo_32_h800_fsdp_vllm.sh old mode 100644 new mode 100755 diff --git a/examples/tuning/70b/qwen2-72b_grpo-lora_8_h100_fsdp_vllm.sh b/examples/tuning/70b/qwen2-72b_grpo-lora_8_h100_fsdp_vllm.sh old mode 100644 new mode 100755 diff --git a/examples/tuning/7b/qwen2-7b_grpo-lora_1_h100_fsdp_vllm.sh b/examples/tuning/7b/qwen2-7b_grpo-lora_1_h100_fsdp_vllm.sh old mode 100644 new mode 100755 diff --git a/examples/tuning/7b/qwen2-7b_grpo_2_h800_fsdp_vllm.sh b/examples/tuning/7b/qwen2-7b_grpo_2_h800_fsdp_vllm.sh old mode 100644 new mode 100755 diff --git a/scripts/process_mosei_annotations.py b/scripts/process_mosei_annotations.py new file mode 100644 index 00000000000..6d655e0998b --- /dev/null +++ b/scripts/process_mosei_annotations.py @@ -0,0 +1,74 @@ +import json +import tqdm + + +def 
process_mosei_annotations(annotation_path: str) -> None: + data = [] + with open(annotation_path, "r") as f: # jsonl file + for line in f: + entry = json.loads(line.strip()) + data.append(entry) + + formatted_data = [] + for sample in tqdm.tqdm(data): + image_path = sample["image"] + video_id = image_path.split("/")[1].split("_")[0] + clip_id = image_path.split("_")[-1].split(".")[0] + raw_video_path = f"Raw/{video_id}/{clip_id}.mp4" + + problem: str = sample["conversations"][0]["value"] + question_statement = problem.index("What is ") + question_str = problem[question_statement:] + answer_str = sample["conversations"][1]["value"] + + new_entry = { + "videos": [raw_video_path], + "problem": question_str, + "answer": answer_str, + } + + # avoid adding if the video and problem already exists + if not any( + entry["videos"] == new_entry["videos"] and entry["problem"] == new_entry["problem"] + for entry in formatted_data + ): + formatted_data.append(new_entry) + + formatted_data = sorted(formatted_data, key=lambda entry: entry["videos"]) + + output_path = annotation_path.replace(".jsonl", "_formatted.jsonl") + with open(output_path, "w") as f: + for entry in formatted_data: + f.write(json.dumps(entry) + "\n") + + # Add train test split of 80-20, calling it annotations_train.jsonl and annotations_test.jsonl + split_index = int(0.8 * len(formatted_data)) + train_data = formatted_data[:split_index] + test_data = formatted_data[split_index:] + folder_name = annotation_path.rsplit("/", 1)[0] if "/" in annotation_path else "." + train_output_path = f"{folder_name}/annotations_train.jsonl" + test_output_path = f"{folder_name}/annotations_test.jsonl" + + with open(train_output_path, "w") as f: + for entry in train_data: + f.write(json.dumps(entry) + "\n") + + with open(test_output_path, "w") as f: + for entry in test_data: + f.write(json.dumps(entry) + "\n") + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Process MOSEI annotations") + parser.add_argument( + "--annotation_path", + type=str, + default="mosei_annotations.jsonl", + help="Path to the MOSEI annotations file (default: mosei_annotations.jsonl)" + ) + + args = parser.parse_args() + + process_mosei_annotations(args.annotation_path) + print(f"Processed annotations saved to {args.annotation_path.replace('.jsonl', '_formatted.jsonl')}") diff --git a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml index 27233a87994..7e1b62b4cee 100644 --- a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml @@ -264,6 +264,7 @@ data: train_files: ~/data/rlhf/gsm8k/train.parquet val_files: ~/data/rlhf/gsm8k/test.parquet prompt_key: prompt + format_prompt: examples/format_prompt/default.jinja reward_fn_key: data_source max_prompt_length: 512 max_response_length: 512 diff --git a/verl/trainer/config/_generated_ppo_trainer.yaml b/verl/trainer/config/_generated_ppo_trainer.yaml index bca4e51679c..d2378a75223 100644 --- a/verl/trainer/config/_generated_ppo_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_trainer.yaml @@ -237,6 +237,7 @@ data: train_files: ~/data/rlhf/gsm8k/train.parquet val_files: ~/data/rlhf/gsm8k/test.parquet prompt_key: prompt + format_prompt: examples/format_prompt/default.jinja reward_fn_key: data_source max_prompt_length: 512 max_response_length: 512 diff --git a/verl/trainer/config/data/legacy_data.yaml b/verl/trainer/config/data/legacy_data.yaml index 
9a5ce8f0dd1..ffeaa5cb19c 100644 --- a/verl/trainer/config/data/legacy_data.yaml +++ b/verl/trainer/config/data/legacy_data.yaml @@ -16,6 +16,11 @@ val_files: ~/data/rlhf/gsm8k/test.parquet # The field in the dataset where the prompt is located. Default is 'prompt'. prompt_key: prompt +# Path to the format prompt template file. If null, uses the default format prompt. +# The template should be a Jinja2 template that will be applied to each prompt. +# Example: examples/format_prompt/default.jinja +format_prompt: examples/format_prompt/default.jinja + # The field used to select the reward function (if using different ones per example). reward_fn_key: data_source diff --git a/verl/trainer/ppo/core_algos.py b/verl/trainer/ppo/core_algos.py index 7ec622036d9..2ea8bfb6305 100644 --- a/verl/trainer/ppo/core_algos.py +++ b/verl/trainer/ppo/core_algos.py @@ -20,9 +20,11 @@ __all__ = ["register_adv_est", "get_adv_estimator_fn", "AdvantageEstimator"] +import math from collections import defaultdict from enum import Enum -from typing import Any, Callable, Optional +from sklearn.cluster import KMeans +from typing import Any, Callable, Optional, Dict, List, Tuple import numpy as np import torch @@ -101,6 +103,7 @@ class AdvantageEstimator(str, Enum): OPO = "opo" GRPO_PASSK = "grpo_passk" GPG = "gpg" + DRPO = "drpo" ADV_ESTIMATOR_REGISTRY: dict[str, Any] = {} @@ -324,6 +327,181 @@ def compute_grpo_outcome_advantage( return scores, scores +EPS_DEFAULT: float = 1e-6 + +# Per‑domain question history ------------------------------------------------ # +# domain_qstats[dom] = { +# "vectors": List[np.ndarray] # shape = (Q, R) +# "q_ids": List[int], # question ids in same order as vectors +# "count": int, # #questions accumulated so far +# } +# --------------------------------------------------------------------------- # +domain_qstats: Dict[Any, Dict[str, Any]] = defaultdict(lambda: { + "vectors": [], + "q_ids": [], + "count": 0, +}) + +global_running_stats: Dict[str, int] = {"q_count": 0} + +# --------------------------------------------------------------------------- # +# Helpers # +# --------------------------------------------------------------------------- # + +def _select_k_elbow(vals: np.ndarray, k_max: int = 10, tol: float = 0.10) -> int: + """k‑means elbow pick on multi‑dimensional points.""" + unique_cnt = len(np.unique(vals, axis=0)) + k_cap = min(k_max, unique_cnt) + ks = range(1, k_cap + 1) + inertias = [KMeans(n_clusters=k, n_init="auto", random_state=0).fit(vals).inertia_ for k in ks] + if len(inertias) == 1: + return 1 + drops = np.diff(inertias) * -1.0 + for i in range(1, len(drops)): + if drops[i] < tol * drops[i - 1]: + return i + 1 + return ks[-1] + + +def _cluster_info_question(vectors: List[np.ndarray]) -> Tuple[float, np.ndarray, np.ndarray, np.ndarray]: + """K‑means on question‑level vectors. 
+ + Returns + ------- + mu_d : float – inverse‑cluster‑size weighted mean of the centroid means + assignments : (Q,) – cluster index for each question vector + counts : (k,) – cluster sizes + centroids : (k,R) – cluster centroid vectors + """ + if len(vectors) == 0: + return 0.0, np.empty(0, int), np.empty(0), np.empty((0, 0)) + + X = np.stack(vectors, axis=0) # (Q,R) – R inferred from data + k_opt = _select_k_elbow(X, k_max=20) + km = KMeans(n_clusters=k_opt, n_init="auto", random_state=0).fit(X) + + centroids = km.cluster_centers_ # (k,R) + assignments = km.labels_ # (Q,) + _, counts = np.unique(assignments, return_counts=True) + counts = counts.astype(float) + + centroid_means = centroids.mean(axis=1) # (k,) + weights = 1.0 / counts + mu_d = float((weights * centroid_means).sum() / weights.sum()) + + # Debug ------------------------------------------------------------- # + print( + f"[KMEANS‑Q] k={k_opt} | centroid_means=" + f"[{', '.join(f'{m:.3f}' for m in centroid_means)}] | counts={counts.tolist()} | μ_d={mu_d:.3f}" + ) + + return mu_d, assignments, counts, centroids + + +@register_adv_est(AdvantageEstimator.DRPO) +def compute_drpo_outcome_advantage( + token_level_rewards: torch.Tensor, # (B,L) + response_mask: torch.Tensor, # (B,L) + index: np.ndarray[str], # (B,) question ids + domain_info: np.ndarray, # (B,) domain ids + epsilon: float = EPS_DEFAULT, +): + """DRPO with question‑level clustering.""" + + B, L = token_level_rewards.shape + + # 1) raw rollout‑level rewards -------------------------------------- # + raw_scores = token_level_rewards.sum(dim=-1) # (B,) + + # 2) collect rollouts per question for this mini‑batch -------------- # + q2rollouts: Dict[str, List[float]] = defaultdict(list) + q2domain: Dict[str, Any] = {} + for i in range(B): + qid: str = index[i] + q2rollouts[qid].append(raw_scores[i].item()) + q2domain[qid] = domain_info[i] + + # ensure consistent rollout count ----------------------------------- # + rollout_lens = {len(v) for v in q2rollouts.values()} + assert len(rollout_lens) == 1, "Inconsistent rollout counts per question in batch!" 
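+    # Worked example (illustrative): with rollout count R = 4, a mini-batch of
+    # three questions might give q2rollouts = {"q0": [1, 0, 1, 1],
+    # "q1": [0, 0, 0, 1], "q2": [1, 1, 1, 1]}; each question then becomes one
+    # length-R reward vector below, and these per-question vectors are what
+    # get clustered domain-by-domain in step 5.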
+
+    # build vector per question ----------------------------------------- #
+    q_vectors = {qid: np.asarray(v, dtype=np.float32) for qid, v in q2rollouts.items()}
+
+    # 3) update per-domain question history ----------------------------- #
+    for qid, vec in q_vectors.items():
+        dom = q2domain[qid]
+        dstat = domain_qstats[dom]
+        dstat["vectors"].append(vec)
+        dstat["q_ids"].append(qid)
+        dstat["count"] += 1
+        global_running_stats["q_count"] += 1
+
+    # 4) GRPO normalisation (within-question) --------------------------- #
+    scores = raw_scores.clone()
+    id2mean = {qid: torch.mean(torch.tensor(v)) for qid, v in q2rollouts.items()}
+    id2std = {qid: torch.std(torch.tensor(v)) for qid, v in q2rollouts.items()}
+    for i in range(B):
+        qid: str = index[i]
+        scores[i] = (scores[i] - id2mean[qid]) / (id2std[qid] + epsilon)
+    before_scale_score = scores.clone()
+
+    # 5) Domain-wise question clustering -------------------------------- #
+    domain_cluster_cache: Dict[Any, Dict[str, Any]] = {}
+    for dom, dstat in domain_qstats.items():
+        if dstat["count"] == 0:
+            continue
+        mu_d, assign, counts, centroids = _cluster_info_question(dstat["vectors"])
+        domain_cluster_cache[dom] = {
+            "mu_d": mu_d,
+            "assign": assign,
+            "counts": counts,
+            "centroids": centroids,
+            "q_ids": dstat["q_ids"],
+        }
+
+    # 6) Apply scaling --------------------------------------------------- #
+    scaling_factors: List[float] = []
+    for i in range(B):
+        qid: str = index[i]
+        dom = q2domain[qid]
+        cache = domain_cluster_cache[dom]
+
+        # map qid -> cluster idx ---------------------------------------- #
+        q_idx = cache["q_ids"].index(qid)
+        cluster_idx = cache["assign"][q_idx]
+
+        N_d = float(domain_qstats[dom]["count"])
+        mu_d = cache["mu_d"]
+        T_d = max(math.sqrt(N_d) * mu_d, epsilon)
+
+        N_c = float(cache["counts"][cluster_idx])
+        mu_c = float(cache["centroids"][cluster_idx].mean())
+
+        factor = T_d * math.sqrt(N_c) * mu_c
+        # Guard against a zero or near-zero centroid mean, which would blow
+        # the normalised scores up to inf/NaN.
+        if abs(factor) < epsilon:
+            factor = epsilon
+        scaling_factors.append(factor)
+        scores[i] = scores[i] / factor
+
+    # divide scores by std of scores
+    scores_std = torch.std(scores)
+    scores = scores / (scores_std + epsilon)
+
+    # Debug report -------------------------------------------------------- #
+    print("--------------Hierarchical scaling report--------------")
+    dom2scale: Dict[Any, List[torch.Tensor]] = defaultdict(list)
+    for i in range(B):
+        dom2scale[domain_info[i]].append(scores[i] / (before_scale_score[i] + epsilon))
+    for dom, lst in dom2scale.items():
+        avg_sf = torch.mean(torch.stack(lst)).item()
+        print(f"[HDRPO] domain = {dom:<15} | mean overall scale = {avg_sf:6.3f}")
+
+    # Print global reward mean
+    print(f"[HDRPO] global reward mean = {torch.mean(scores):.3f}")
+
+    returns = scores.unsqueeze(-1) * response_mask
+    return returns, returns
+
+
 @register_adv_est(AdvantageEstimator.GRPO_PASSK)  # or simply: @register_adv_est("grpo_passk")
 def compute_grpo_passk_outcome_advantage(
     token_level_rewards: torch.Tensor,
diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py
index 97b68684d5c..bb783854aaf 100644
--- a/verl/trainer/ppo/ray_trainer.py
+++ b/verl/trainer/ppo/ray_trainer.py
@@ -27,11 +27,13 @@
 from dataclasses import dataclass, field
 from enum import Enum
 from pprint import pprint
-from typing import Optional
+from typing import Optional, Dict
 
 import numpy as np
 import ray
 import torch
+import ujson
+import wandb
 from omegaconf import OmegaConf, open_dict
 from torch.utils.data import Dataset, Sampler
 from torchdata.stateful_dataloader import StatefulDataLoader
@@ -61,6 +63,7 @@
 from verl.utils.seqlen_balancing import
get_seqlen_balanced_partitions, log_seqlen_unbalance from verl.utils.torch_functional import masked_mean from verl.utils.tracking import ValidationGenerationsLogger +from examples.reward_function.evaluation import compute_metrics_by_data_source WorkerType = type[Worker] @@ -271,6 +274,18 @@ def compute_advantage( ) data.batch["advantages"] = advantages data.batch["returns"] = returns + elif adv_estimator == AdvantageEstimator.DRPO: + grpo_calculation_mask = data.batch["response_mask"] + domain_info = data.non_tensor_batch["dataset"] + + advantages, returns = core_algos.compute_drpo_outcome_advantage( + token_level_rewards=data.batch["token_level_rewards"], + response_mask=grpo_calculation_mask, + index=data.non_tensor_batch["uid"], + domain_info=domain_info + ) + data.batch["advantages"] = advantages + data.batch["returns"] = returns else: # handle all other adv estimator type other than GAE and GRPO adv_estimator_fn = core_algos.get_adv_estimator_fn(adv_estimator) @@ -573,7 +588,7 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl except Exception as e: print(f"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}") - def _dump_generations(self, inputs, outputs, gts, scores, reward_extra_infos_dict, dump_path): + def _dump_generations(self, inputs, outputs, gts, scores, reward_extra_infos_dict, dump_path, **kwargs): """Dump rollout/validation samples as JSONL.""" os.makedirs(dump_path, exist_ok=True) filename = os.path.join(dump_path, f"{self.global_steps}.jsonl") @@ -591,6 +606,14 @@ def _dump_generations(self, inputs, outputs, gts, scores, reward_extra_infos_dic if len(v) == n: base_data[k] = v + for k, v in kwargs.items(): + if isinstance(v, np.ndarray): + base_data[k] = v.tolist() + elif hasattr(v, 'cpu'): # Check if it's a torch tensor + base_data[k] = v.cpu().numpy().tolist() + else: + base_data[k] = v + lines = [] for i in range(n): entry = {k: v[i] for k, v in base_data.items()} @@ -636,6 +659,14 @@ def _validate(self): sample_scores = [] sample_turns = [] + # New lists for metric calculation + all_predictions = [] + all_ground_truths = [] + all_data_sources = [] + all_demographics = [] + all_datasets = [] + data_source_lst = [] + for test_data in self.val_dataloader: test_batch = DataProto.from_single_dict(test_data) @@ -658,6 +689,9 @@ def _validate(self): item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) for item in test_batch ] sample_gts.extend(ground_truths) + data_sources = test_batch.non_tensor_batch.get("data_source", ["unknown"] * len(input_texts)) + datasets = test_batch.non_tensor_batch.get("dataset", ["unknown"] * len(input_texts)) + demographics = test_batch.non_tensor_batch.get("demo", ["unknown"] * len(input_texts)) batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"] non_tensor_batch_keys_to_pop = ["raw_prompt_ids"] @@ -708,6 +742,16 @@ def _validate(self): output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids] sample_outputs.extend(output_texts) + # Collect for metrics calculation + all_predictions.extend(output_texts) + all_ground_truths.extend(ground_truths) + all_data_sources.extend(data_sources) + all_datasets.extend(datasets) + all_demographics.extend(demographics) + data_source_lst.append( + test_batch.non_tensor_batch.get("data_source", ["unknown"] * len(input_texts)) + ) + test_batch = test_batch.union(test_output_gen_batch) test_batch.meta_info["validate"] = True @@ -730,27 +774,23 @@ def _validate(self): if 
"__num_turns__" in test_batch.non_tensor_batch: sample_turns.append(test_batch.non_tensor_batch["__num_turns__"]) - data_source_lst.append(test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0])) - self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores) - # dump generations - val_data_dir = self.config.trainer.get("validation_data_dir", None) - if val_data_dir: - self._dump_generations( - inputs=sample_inputs, - outputs=sample_outputs, - gts=sample_gts, - scores=sample_scores, - reward_extra_infos_dict=reward_extra_infos_dict, - dump_path=val_data_dir, - ) + # Per data source metrics + metrics = compute_metrics_by_data_source(all_predictions, all_ground_truths, + all_data_sources, all_datasets, all_demographics) + wandb.log(metrics, step=self.global_steps) for key_info, lst in reward_extra_infos_dict.items(): assert len(lst) == 0 or len(lst) == len(sample_scores), f"{key_info}: {len(lst)=}, {len(sample_scores)=}" data_sources = np.concatenate(data_source_lst, axis=0) + # convert to list for easier processing + data_sources = data_sources.tolist() + print(f"size of sample_scores: {len(sample_scores)}, size of sample_outputs: {len(sample_outputs)}," + f" size of sample_gts: {len(sample_gts)}, size of sample_inputs: {len(sample_inputs)}" + f", size of data_sources: {len(data_sources)}, size of sample_turns: {len(sample_turns)}") data_src2var2metric2val = process_validation_metrics(data_sources, sample_inputs, reward_extra_infos_dict) metric_dict = {} for data_source, var2metric2val in data_src2var2metric2val.items(): @@ -769,6 +809,20 @@ def _validate(self): pfx = f"{metric_sec}/{data_source}/{var_name}/{metric_name}" metric_dict[pfx] = metric_val + # dump generations + val_data_dir = self.config.trainer.get("validation_data_dir", self.config.trainer.default_local_dir) + if val_data_dir: + self._dump_generations( + inputs=sample_inputs, + outputs=sample_outputs, + gts=sample_gts, + scores=sample_scores, + reward_extra_infos_dict=reward_extra_infos_dict, + dump_path=val_data_dir, + datasets=all_datasets, + data_paths=data_sources, + ) + if len(sample_turns) > 0: sample_turns = np.concatenate(sample_turns) metric_dict["val-aux/num_turns/min"] = sample_turns.min() @@ -777,6 +831,32 @@ def _validate(self): return metric_dict + def save_generations(self, sample_datapaths, sample_datasets, sample_inputs, sample_labels, sample_outputs, + sample_scores): + generation_save_folder = os.path.join(self.config.trainer.default_local_dir, + f"global_step_{self.global_steps}") + if not os.path.exists(generation_save_folder): + os.makedirs(generation_save_folder, exist_ok=True) + with open(os.path.join(generation_save_folder, "generations.jsonl"), "w") as f: + for i in range(len(sample_inputs)): + try: + short_answer = sample_outputs[i].split("boxed{")[1].split("}")[0] + except IndexError: + short_answer = '' + answer_is_correct = short_answer == sample_labels[i] + f.write( + ujson.dumps({ + "input": sample_inputs[i], + "generations": sample_outputs[i], + "short_answer": short_answer, + "answer_is_correct": answer_is_correct, + "label": sample_labels[i], + "score": sample_scores[i], + "dataset": sample_datasets[i], + "datapath": sample_datapaths[i], + }) + "\n" + ) + def init_workers(self): """Initialize distributed training workers using Ray backend. 
diff --git a/verl/utils/dataset/rl_dataset.py b/verl/utils/dataset/rl_dataset.py
index 2c19385c2b3..6024ba32ca7 100644
--- a/verl/utils/dataset/rl_dataset.py
+++ b/verl/utils/dataset/rl_dataset.py
@@ -24,6 +24,7 @@
 import datasets
 import numpy as np
 import torch
+from jinja2 import Template
 from omegaconf import DictConfig, ListConfig
 from torch.utils.data import Dataset
 from transformers import PreTrainedTokenizer, ProcessorMixin
@@ -107,6 +108,10 @@ def __init__(
         self.return_full_prompt = config.get("return_full_prompt", False)
         self.truncation = config.get("truncation", "error")
         self.filter_overlong_prompts = config.get("filter_overlong_prompts", True)
+        if isinstance(data_files, str):
+            self.base_dir = os.path.dirname(os.path.abspath(data_files))
+        else:
+            self.base_dir = os.path.dirname(os.path.abspath(data_files[0]))
 
         self.num_workers = config.get("filter_overlong_prompts_workers", max(1, os.cpu_count() // 4))
         self.num_workers = min(self.num_workers, os.cpu_count())
@@ -116,10 +121,22 @@
         self.filter_prompts = config.get("filter_prompts", True)
         self.serialize_dataset = False
         self.return_multi_modal_inputs = config.get("return_multi_modal_inputs", True)
+
+        # Load format prompt from file if specified
+        self.format_prompt_path = config.get("format_prompt", "examples/format_prompt/default.jinja")
+        self.format_prompt = self._load_format_prompt()
 
         self._download()
         self._read_files_and_tokenize()
 
+    def _load_format_prompt(self) -> Optional[Template]:
+        """Load format prompt from file if specified."""
+        if self.format_prompt_path:
+            with open(self.format_prompt_path, "r", encoding="utf-8") as f:
+                template_content = f.read()
+            return Template(template_content)
+        return None
+
     def _download(self, use_origin_parquet=False):
         from verl.utils.fs import copy_to_local
 
@@ -131,7 +148,12 @@ def _read_files_and_tokenize(self):
         dataframes = []
         for parquet_file in self.data_files:
             # read parquet files and cache
-            dataframe = datasets.load_dataset("parquet", data_files=parquet_file)["train"]
+            if parquet_file.endswith(".parquet"):
+                dataframe = datasets.load_dataset("parquet", data_files=parquet_file)["train"]
+            elif parquet_file.endswith(".json") or parquet_file.endswith(".jsonl"):
+                dataframe = datasets.load_dataset("json", data_files=parquet_file)["train"]
+            else:
+                raise ValueError(f"Unsupported file format: {parquet_file}. Only .parquet, .json, .jsonl are supported.")
             dataframes.append(dataframe)
         self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes)
 
@@ -188,11 +210,35 @@ def __len__(self):
         return len(self.dataframe)
 
     def _build_messages(self, example: dict):
-        messages: list = example.pop(self.prompt_key)
+        messages: list = example.get(self.prompt_key)
+        if isinstance(messages, str):
+            messages = [messages]
 
         if self.image_key in example or self.video_key in example:
+            new_messages = []
             for message in messages:
-                content = message["content"]
+                new_message = copy.deepcopy(message)
+                if isinstance(new_message, str):
+                    new_message = {"role": "user", "content": new_message}
+                content = new_message["content"]
+
+                # Apply format prompt to the entire content first if template is loaded
+                if self.format_prompt:
+                    content = self.format_prompt.render(content=content)
+
+                image_count = len(example.get(self.image_key, []))
+                video_count = len(example.get(self.video_key, []))
+                image_tag_count = content.count("<image>")
+                video_tag_count = content.count("<video>")
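To illustrate the new data.format_prompt hook, here is a minimal sketch of how the loaded Jinja template wraps a raw prompt. The template text below is hypothetical; the actual one-line examples/format_prompt/default.jinja ships with this patch:

    from jinja2 import Template

    # Hypothetical stand-in for examples/format_prompt/default.jinja
    template = Template(
        "{{ content }} Please reason step by step, and put your final answer within \\boxed{}."
    )
    print(template.render(content="What is 2 + 2?"))
    # -> What is 2 + 2? Please reason step by step, and put your final answer within \boxed{}.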