# test_stateful_modality_sampler_hardcoded.py

import json
from typing import Dict, Any, List, Iterator
from torch.utils.data import Dataset, BatchSampler
import random

# ==== ADJUST PATHS below to match your repo structure ====
# from verl.utils.dataset.modality_sampler import ModalitySignatureBatchSampler
from torchdata.stateful_dataloader import StatefulDataLoader
from collections import defaultdict, deque
import torch
import numpy as np

# ---------- HARD-CODED PATHS + CONFIG ----------
JSONL_PATH = "/Users/keane/Desktop/research/human-behavior/data/all/sigs_no_lmvd_discretized_v3_template_prompts.jsonl"
TRAIN_BS = 4
VAL_BS = 4
SEED = 42
TRUNCATE_RATIO = 0.001  # for quick testing; set to 1.0 to disable
# ---------------------------------------------

# TODO: Please remove text only; everything should be text_only


class ModalitySignatureBatchSampler(BatchSampler):
    """Round-robin batch sampler that keeps every batch homogeneous by modality signature.

    - Shuffles within each signature when ``shuffle=True`` (train).
    - Each yielded batch contains indices of a single modality_signature.
    - A signature whose batches run out is pruned and round-robin continues
      over the remaining signatures.
    """

    def __init__(
        self,
        indices_by_sig: Dict[str, List[int]],
        batch_size: int,
        drop_last: bool = True,
        seed: int = 42,
        shuffle: bool = True,
    ):
        # Defensive copies so callers may keep mutating their own index lists.
        self.indices_by_sig = {sig: list(idxs) for sig, idxs in indices_by_sig.items()}
        self.batch_size = int(batch_size)
        self.drop_last = drop_last
        self.shuffle = shuffle
        self.rng = random.Random(seed)
        self.sigs = list(self.indices_by_sig.keys())

    def _batches_for(self, pool: List[int]) -> List[List[int]]:
        # Slice the pool into consecutive chunks of batch_size; a short tail
        # chunk is dropped when drop_last is requested, empty chunks never appear.
        chunks = [pool[i:i + self.batch_size] for i in range(0, len(pool), self.batch_size)]
        return [c for c in chunks if c and not (self.drop_last and len(c) < self.batch_size)]

    def __iter__(self) -> Iterator[List[int]]:
        # Fresh per-signature pools; in-signature shuffle for training epochs.
        pools = {sig: list(idxs) for sig, idxs in self.indices_by_sig.items()}
        if self.shuffle:
            for sig in pools:
                self.rng.shuffle(pools[sig])

        # One FIFO of ready batches per signature.
        queues = {sig: deque(self._batches_for(pools[sig])) for sig in self.sigs}

        # Round-robin order: rotated start per epoch for variety when shuffling,
        # deterministic sorted order otherwise.
        order = list(self.sigs)
        if self.shuffle:
            if order:
                pivot = self.rng.randrange(len(order))
                order = order[pivot:] + order[:pivot]
        else:
            order = sorted(order)

        # Signatures that still hold batches; rotated as batches are consumed.
        ring = deque(sig for sig in order if queues[sig])

        while ring:
            sig = ring.popleft()
            queue = queues[sig]
            if not queue:
                # Exhausted signature: prune it from the rotation.
                print(f"Ran-Out: Pruning modality signature: {sig}")
                continue
            yield queue.popleft()
            if queue:
                ring.append(sig)  # still has batches -> stays in rotation

    def __len__(self) -> int:
        # Total number of batches across all signatures (after drop_last handling).
        total = 0
        for pool in self.indices_by_sig.values():
            full, rem = divmod(len(pool), self.batch_size)
            total += full if (self.drop_last or rem == 0) else full + 1
        return total


def rl_collate_fn(data_list: list[dict]) -> dict:
    """Collate sample dicts: stack tensor fields, wrap the rest as object arrays.

    Args:
        data_list: List of dicts mapping feature names to torch.Tensor or
            arbitrary Python values.

    Returns:
        Dict where tensor entries are stacked into a torch.Tensor of shape
        (batch_size, dims) and non-tensor entries become np.ndarray of dtype
        object with shape (batch_size,).
    """
    tensor_cols = defaultdict(list)
    object_cols = defaultdict(list)

    for sample in data_list:
        for name, value in sample.items():
            bucket = tensor_cols if isinstance(value, torch.Tensor) else object_cols
            bucket[name].append(value)

    stacked = {name: torch.stack(vals, dim=0) for name, vals in tensor_cols.items()}
    # np.fromiter keeps a flat 1-D object array even for list-valued samples.
    wrapped = {
        name: np.fromiter(vals, dtype=object, count=len(vals))
        for name, vals in object_cols.items()
    }
    return {**stacked, **wrapped}


def create_rl_sampler(data_config, dataset, split: str = "train"):
    """Create a sampler for the dataset, grouping strictly by existing modality_signature."""
    import torch
    from torch.utils.data import RandomSampler, SequentialSampler

    cfg_key = "modality_batching" if split == "train" else "val_modality_batching"
    mb_cfg = data_config.get(cfg_key)

    # (keep curriculum path if you actually use it; omitted here for brevity)

    if mb_cfg and mb_cfg.get("enabled", False):
        by_sig: Dict[str, List[int]] = {}
        for i in range(len(dataset)):
            row = dataset.dataframe[i] if hasattr(dataset, "dataframe") else dataset[i]
            sig = row.get("modality_signature")
            if sig is None:
                print(f"[WARNING] Row {i} missing 'modality_signature'. Skipping.")
                continue
            by_sig.setdefault(sig, []).append(i)

        default_bs_key = "train_batch_size" if split == "train" else "val_batch_size"
        batch_size = mb_cfg.get("batch_size", data_config.get(default_bs_key))
        drop_last = mb_cfg.get("drop_last", split == "train")

        return ModalitySignatureBatchSampler(
            indices_by_sig=by_sig,
            batch_size=int(batch_size),
            drop_last=drop_last,
            seed=data_config.get("seed", 42),
            shuffle=(split == "train"),
        )

    # Fallbacks: seeded random sampling for training, sequential otherwise.
    if data_config.get("shuffle", True) and split == "train":
        generator = torch.Generator()
        generator.manual_seed(data_config.get("seed", 1))
        return RandomSampler(data_source=dataset, generator=generator)
    return SequentialSampler(data_source=dataset)
class JsonlDataset(Dataset):
    """In-memory JSONL dataset keeping only rows that carry a 'modality_signature'."""

    def __init__(self, jsonl_path: str, truncate_ratio: float = TRUNCATE_RATIO, seed: int = SEED):
        """
        Loads ONLY entries that already have 'modality_signature'.
        Optionally keeps a proportion per signature for fast debugging.
        """
        kept: List[Dict[str, Any]] = []
        with open(jsonl_path, "r", encoding="utf-8") as fh:
            for raw in fh:
                raw = raw.strip()
                if not raw:
                    continue
                record = json.loads(raw)
                if record.get("modality_signature") is None:
                    print("[WARNING] Entry missing 'modality_signature'. Skipping.")
                    continue  # skip missing
                kept.append(record)

        # Bucket rows by signature, then keep only a fraction of each bucket.
        sig_to_rows: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
        for record in kept:
            sig_to_rows[record["modality_signature"]].append(record)

        rng = random.Random(seed)
        sampled: List[Dict[str, Any]] = []
        for sig, rows in sig_to_rows.items():
            if truncate_ratio >= 1.0:
                sampled.extend(rows)
                continue
            keep_n = max(1, int(len(rows) * truncate_ratio))
            rng.shuffle(rows)
            sampled.extend(rows[:keep_n])

        self.rows = sampled
        self.dataframe = self  # preserve your API

        # simple stats
        counts = {sig: sum(1 for r in self.rows if r["modality_signature"] == sig) for sig in sig_to_rows}
        print(f"[DEBUG] After truncation (ratio={truncate_ratio}), total {len(self.rows)}. Per-signature: {counts}")

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        return self.rows[idx]


def assert_homogeneous(batch_list: List[Dict[str, Any]]):
    """Raise AssertionError unless every sample shares one modality_signature."""
    sigs = {sample.get("modality_signature") for sample in batch_list}
    if len(sigs) != 1:
        raise AssertionError(f"Non-homogeneous batch signatures: {sigs}")


def collate_with_guard(batch_list):
    """Collate after verifying the batch is signature-homogeneous."""
    assert_homogeneous(batch_list)
    return rl_collate_fn(batch_list)


def build_cfg(train_bs: int, val_bs: int, seed: int = 42):
    """Build a dot-accessible config dict mirroring the trainer's data config."""
    class Dot(dict):
        __getattr__ = dict.get
        __setattr__ = dict.__setitem__
        __delattr__ = dict.__delitem__

    return Dot({
        "train_batch_size": train_bs,
        "val_batch_size": val_bs,
        "shuffle": True,
        "seed": seed,
        "dataloader_num_workers": 0,
        "validation_shuffle": False,
        "sampler": None,
        "modality_batching": {"enabled": True, "batch_size": train_bs, "drop_last": True},
        "val_modality_batching": {"enabled": True, "batch_size": val_bs, "drop_last": False},
    })


def build_loader(dataset, data_cfg, split: str):
    """Wire dataset + sampler into a StatefulDataLoader for the given split."""
    sampler = create_rl_sampler(data_cfg, dataset, split=split)
    if isinstance(sampler, BatchSampler):
        # Homogeneous modality batching: the batch sampler owns batch shape.
        return StatefulDataLoader(
            dataset=dataset,
            batch_sampler=sampler,
            num_workers=data_cfg["dataloader_num_workers"],
            collate_fn=collate_with_guard,
        )
    batch_size = data_cfg.get("train_batch_size" if split == "train" else "val_batch_size")
    return StatefulDataLoader(
        dataset=dataset,
        sampler=sampler,
        batch_size=batch_size,
        num_workers=data_cfg["dataloader_num_workers"],
        drop_last=(split == "train"),
        shuffle=False,  # a sampler is always supplied, so shuffling stays off
        collate_fn=collate_with_guard,
    )


def main():
    ds = JsonlDataset(JSONL_PATH)
    per_sig = {sig: sum(1 for r in ds.rows if r['modality_signature'] == sig)
               for sig in sorted({r['modality_signature'] for r in ds.rows})}
    print(f"Dataset size: {len(ds)}; per-signature counts:", per_sig)

    cfg = build_cfg(TRAIN_BS, VAL_BS, SEED)

    # TRAIN
    train_loader = build_loader(ds, cfg, split="train")
    print("\n[TRAIN] Iteration 1")
    n_train_batches = sum(1 for _ in train_loader)  # iterating as you would with the train loader
    print(f"train steps: {n_train_batches} (drop_last=True)")

    # New epoch: the step count must be reproducible across loaders.
    train_loader2 = build_loader(ds, cfg, split="train")
    n_train_batches2 = sum(1 for _ in train_loader2)
    assert n_train_batches == n_train_batches2
    print("[TRAIN] Iteration 2: step count consistent")

    # VAL
    val_loader = build_loader(ds, cfg, split="val")
    print("\n[VAL] Iteration 1")
    n_val_batches = sum(1 for _ in val_loader)
    print(f"val steps: {n_val_batches} (drop_last=False)")

    # Stateful resume check (if supported)
    if hasattr(train_loader, "state_dict"):
        print("\n[STATEFUL] Testing resume mid-epoch")
        train_loader3 = build_loader(ds, cfg, split="train")
        it = iter(train_loader3)
        next(it)
        next(it)  # consume 2
        sd = train_loader3.state_dict()
        train_loader4 = build_loader(ds, cfg, split="train")
        train_loader4.load_state_dict(sd)
        resumed = sum(1 for _ in train_loader4)
        print(f"resumed batches after 2 consumed: {resumed}")

    print("\nOK: StatefulDataLoader + sampler test finished.")


if __name__ == "__main__":
    main()
__name__ == "__main__": + main() \ No newline at end of file diff --git a/dvd_requirements.txt b/dvd_requirements.txt new file mode 100644 index 00000000000..7210f325c6b --- /dev/null +++ b/dvd_requirements.txt @@ -0,0 +1,282 @@ +# torch==2.7.1+cu126 +absl-py==2.3.0 +accelerate==1.8.1 +aiohappyeyeballs==2.6.1 +aiohttp==3.12.13 +aiohttp-cors==0.8.1 +aiosignal==1.3.2 +airportsdata==20250622 +alembic==1.16.2 +aliyun-python-sdk-core==2.16.0 +aliyun-python-sdk-kms==2.16.5 +annotated-types==0.7.0 +antlr4-python3-runtime==4.9.3 +anyio==4.9.0 +astor==0.8.1 +attrs==25.3.0 +audioread==3.0.1 +av==14.4.0 +beautifulsoup4==4.13.4 +blake3==1.0.5 +cachetools==5.5.2 +cbor2==5.6.5 +certifi==2025.8.3 +cffi==1.17.1 +chardet==5.2.0 +charset-normalizer==3.4.2 +click==8.2.1 +cloudpickle==3.1.1 +codetiming==1.4.0 +color-matcher==0.6.0 +colorama==0.4.6 +colorful==0.5.6 +comfyui-embedded-docs==0.2.3 +comfyui_frontend_package==1.23.4 +comfyui_workflow_templates==0.1.32 +compressed-tensors==0.10.2 +contourpy==1.3.2 +crcmod==1.7 +cryptography==45.0.5 +cuda-bindings==12.9.0 +cuda-python==12.9.0 +cupy-cuda12x==13.4.1 +cycler==0.12.1 +datasets==4.0.0 +ddt==1.7.2 +decorator==5.2.1 +Deprecated==1.2.18 +depyf==0.19.0 +diffusers==0.34.0 +dill==0.3.8 +diskcache==5.6.3 +distlib==0.3.9 +distro==1.9.0 +dnspython==2.7.0 +docutils==0.21.2 +einops==0.8.1 +email_validator==2.2.0 +fastapi==0.115.14 +fastapi-cli==0.0.7 +fastrlock==0.8.3 +ffmpeg-python==0.2.0 +# filelock==3.19.1 +flash_attn==2.8.0.post2 +flatbuffers==25.2.10 +fonttools==4.58.4 +frozenlist==1.7.0 +fsspec==2025.3.0 +ftfy==6.3.1 +future==1.0.0 +gguf==0.17.1 +gitdb==4.0.12 +GitPython==3.1.44 +google-api-core==2.25.1 +google-auth==2.40.3 +googleapis-common-protos==1.70.0 +greenlet==3.2.3 +grpcio==1.73.1 +h11==0.16.0 +heavyball==1.7.2 +hf-xet==1.1.5 +httpcore==1.0.9 +httptools==0.6.4 +httpx==0.28.1 +huggingface-hub==0.34.1 +hydra-core==1.3.2 +idna==3.10 +imageio==2.37.0 +# importlib_metadata==8.7.0 +interegular==0.3.3 +jax==0.7.0 +jaxlib==0.7.0 
+Jinja2==3.1.6 +jiter==0.10.0 +jmespath==0.10.0 +joblib==1.5.1 +jsonschema==4.24.0 +jsonschema-specifications==2025.4.1 +kiwisolver==1.4.8 +kornia==0.8.1 +kornia_rs==0.1.9 +lark==1.2.2 +lazy_loader==0.4 +librosa==0.11.0 +liger_kernel==0.5.10 +llguidance==0.7.30 +llvmlite==0.44.0 +lm-format-enforcer==0.10.11 +Mako==1.3.10 +Markdown==3.8.2 +markdown-it-py==3.0.0 +MarkupSafe==3.0.2 +mathruler==0.1.0 +matplotlib==3.10.3 +matrix-client==0.4.0 +mdurl==0.1.2 +mediapipe==0.10.21 +mistral_common==1.8.3 +ml_dtypes==0.5.3 +model-index==0.1.11 +mpi4py==4.1.0 +mpmath==1.3.0 +msgpack==1.1.1 +msgspec==0.19.0 +multidict==6.6.0 +multiprocess==0.70.16 +natsort==8.4.0 +nest-asyncio==1.6.0 +networkx==3.5 +ninja==1.11.1.4 +numba==0.61.2 +numpy==1.26.4 +nvidia-cublas-cu12==12.6.4.1 +nvidia-cuda-cupti-cu12==12.6.80 +nvidia-cuda-nvrtc-cu12==12.6.77 +nvidia-cuda-runtime-cu12==12.6.77 +nvidia-cudnn-cu12==9.5.1.17 +nvidia-cufft-cu12==11.3.0.4 +nvidia-cufile-cu12==1.11.1.6 +nvidia-curand-cu12==10.3.7.77 +nvidia-cusolver-cu12==11.7.1.2 +nvidia-cusparse-cu12==12.5.4.2 +nvidia-cusparselt-cu12==0.6.3 +nvidia-ml-py==12.575.51 +nvidia-nccl-cu12==2.26.2 +nvidia-nvjitlink-cu12==12.6.85 +nvidia-nvshmem-cu12==3.3.9 +nvidia-nvtx-cu12==12.6.77 +olefile==0.47 +omegaconf==2.3.0 +openai==1.90.0 +opencensus==0.11.4 +opencensus-context==0.1.3 +opencv-contrib-python==4.11.0.86 +opencv-python-headless==4.11.0.86 +opendatalab==0.0.10 +openmim==0.3.9 +opentelemetry-api==1.26.0 +opentelemetry-exporter-otlp-proto-grpc==1.26.0 +opentelemetry-exporter-otlp-proto-http==1.26.0 +opentelemetry-exporter-prometheus==0.41b0 +opentelemetry-proto==1.26.0 +opentelemetry-sdk==1.26.0 +# opentelemetry-semantic-conventions==0.55b1 +openxlab==0.1.2 +opt_einsum==3.4.0 +ordered-set==4.1.0 +orjson==3.10.18 +oss2==2.17.0 +outlines==0.1.11 +outlines_core==0.2.10 +packaging==24.2 +pandas==2.3.0 +partial-json-parser==0.2.1.1.post6 +peft==0.15.2 +piexif==1.1.3 +pillow==11.2.1 +pip==25.1 +platformdirs==4.3.8 +pooch==1.8.2 
+prometheus_client==0.22.1 +prometheus-fastapi-instrumentator==7.1.0 +propcache==0.3.2 +proto-plus==1.26.1 +protobuf==4.25.8 +psutil==7.0.0 +py-cpuinfo==9.0.0 +py-spy==0.4.0 +pyarrow==20.0.0 +pyasn1==0.6.1 +pyasn1_modules==0.4.2 +pybase64==1.4.1 +pycountry==24.6.1 +pycparser==2.22 +pycryptodome==3.23.0 +pydantic==2.11.7 +pydantic_core==2.33.2 +pydantic-extra-types==2.10.5 +pydantic-settings==2.10.1 +PyGithub==2.6.1 +Pygments==2.19.2 +PyJWT==2.10.1 +pylatexenc==2.10 +pyloudnorm==0.1.1 +PyNaCl==1.5.0 +pynvml==12.0.0 +pyparsing==3.2.3 +python-dateutil==2.9.0.post0 +python-dotenv==1.1.1 +python-json-logger==4.0.0.dev0 +python-multipart==0.0.20 +pytz==2023.4 +PyYAML==6.0.2 +pyzmq==27.0.0 +qwen-vl-utils==0.0.11 +ray==2.47.1 +referencing==0.36.2 +regex==2024.11.6 +# requests==2.28.2 +rich==14.1.0 +rich-toolkit==0.14.7 +rpds-py==0.25.1 +rsa==4.9.1 +safetensors==0.6.0rc0 +# sageattention==2.2.0 +scikit-image==0.25.2 +scikit-learn==1.7.0 +scipy==1.16.0 +sentencepiece==0.2.0 +sentry-sdk==2.32.0 +setproctitle==1.3.6 +setuptools==79.0.1 +shellingham==1.5.4 +six==1.17.0 +smart-open==7.1.0 +smmap==5.0.2 +sniffio==1.3.1 +sounddevice==0.5.2 +soundfile==0.13.1 +soupsieve==2.7 +soxr==0.5.0.post1 +spandrel==0.4.1 +SQLAlchemy==2.0.41 +starlette==0.46.2 +sympy==1.14.0 +tabulate==0.9.0 +tensorboard==2.19.0 +tensorboard-data-server==0.7.2 +tensordict==0.8.3 +threadpoolctl==3.6.0 +tifffile==2025.6.11 +tiktoken==0.9.0 +timm==1.0.16 +tokenizers==0.21.2 +toml==0.10.2 +# torchaudio==2.7.1 +# torchcodec==0.4.0+cu126 +torchdata==0.11.0 +torchsde==0.2.6 +# torchvision==0.22.1 +# tqdm==4.65.2 +trampoline==0.1.2 +transformers==4.54.0 +triton==3.3.1 +typer==0.16.0 +typing_extensions==4.14.0 +typing-inspection==0.4.1 +tzdata==2025.2 +ujson==5.10.0 +urllib3==1.26.20 +uv==0.7.19 +uvicorn==0.34.3 +uvloop==0.21.0 +virtualenv==20.31.2 +vllm==0.8.4 +wandb==0.20.1 +watchdog==6.0.0 +watchfiles==1.1.0 +wcwidth==0.2.13 +websockets==15.0.1 +Werkzeug==3.1.3 +wfdb==4.3.0 +wheel==0.45.1 \ No newline at end of file 
diff --git a/examples/drpo_trainer/run_qwen2_5_vl-7b_climb_no_thinking.sh b/examples/drpo_trainer/run_qwen2_5_vl-7b_climb_no_thinking.sh new file mode 100755 index 00000000000..87c368ad1a2 --- /dev/null +++ b/examples/drpo_trainer/run_qwen2_5_vl-7b_climb_no_thinking.sh @@ -0,0 +1,54 @@ +set -x +ENGINE=${1:-vllm} + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=drpo \ + data.train_files=/home/dvdai/orcd/scratch/high_modality/geom_train_upsampled_new.jsonl \ + data.val_files=/home/dvdai/orcd/scratch/high_modality/geom_valid_mini_new.jsonl \ + data.train_batch_size=512 \ + data.max_prompt_length=4096 \ + data.max_response_length=4096 \ + data.filter_overlong_prompts=False \ + data.truncation='left' \ + data.image_key=images \ + data.video_key=videos \ + data.prompt_key=problem \ + data.format_prompt=examples/format_prompt/no_thinking.jinja \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=2e-6 \ + actor_rollout_ref.model.use_remove_padding=False \ + actor_rollout_ref.actor.ppo_mini_batch_size=128 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=1e-8 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=$ENGINE \ + actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.enforce_eager=False \ + 
actor_rollout_ref.rollout.free_cache_engine=True \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + custom_reward_function.path=examples/reward_function/medical.py \ + custom_reward_function.name=medical_compute_score_batch \ + reward_model.reward_manager=batch \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_climb' \ + trainer.experiment_name='drpo_nothinking' \ + trainer.n_gpus_per_node=4 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.val_before_train=False \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ diff --git a/examples/format_prompt/README.md b/examples/format_prompt/README.md new file mode 100644 index 00000000000..412c5a558e3 --- /dev/null +++ b/examples/format_prompt/README.md @@ -0,0 +1,63 @@ +# Format Prompt Templates + +This directory contains Jinja2 templates for formatting prompts in RLHF datasets. + +## Overview + +The format prompt feature allows you to apply custom formatting to each prompt in your dataset using Jinja2 templates. This is useful when you want to add consistent instructions or formatting to all prompts without modifying the original dataset. + +## Default Template + +The default template (`default.jinja`) appends the following instruction to each prompt: + +``` +{{ content }}You FIRST think about the reasoning process as an internal monologue and then provide the final answer. The reasoning process MUST BE enclosed within tags. The final answer MUST BE put in \boxed{}. +``` + +## Usage + +To use a format prompt template, specify the `format_prompt` parameter in your data configuration: + +```yaml +data: + # ... other data config ... 
+ format_prompt: examples/format_prompt/default.jinja # Path to your template file +``` + +Or set it to `null` to disable format prompting: + +```yaml +data: + format_prompt: null +``` + +## Creating Custom Templates + +To create a custom format prompt: + +1. Create a new `.jinja` file in this directory or elsewhere +2. Use `{{ content }}` as the placeholder for the original prompt content +3. Add your custom formatting around it + +Example custom template: + +```jinja +{{ content }} + +Please solve this problem step by step: +1. Understand the problem +2. Plan your approach +3. Execute the solution +4. Verify your answer +``` + +## Template Variables + +Currently, the template receives one variable: +- `content`: The original prompt text + +## Notes + +- The template is applied during dataset preprocessing +- If the template file is not found, the system will use the original prompt without formatting +- For multimodal datasets (images/videos), the formatting is applied to text segments only \ No newline at end of file diff --git a/examples/format_prompt/default.jinja b/examples/format_prompt/default.jinja new file mode 100644 index 00000000000..be95b0ef441 --- /dev/null +++ b/examples/format_prompt/default.jinja @@ -0,0 +1 @@ +{{ content }}You FIRST think about the reasoning process as an internal monologue and then provide the final answer. The reasoning process MUST BE enclosed within tags. The final answer MUST BE put in \boxed{}. \ No newline at end of file diff --git a/examples/format_prompt/no_thinking.jinja b/examples/format_prompt/no_thinking.jinja new file mode 100644 index 00000000000..39a137c9384 --- /dev/null +++ b/examples/format_prompt/no_thinking.jinja @@ -0,0 +1 @@ +{{ content }}You MUST provide the final answer directly without any extra information. Enclose the final answer in \boxed{}. 
\ No newline at end of file diff --git a/examples/generation/run_deepseek7b_mutli_node.sh b/examples/generation/run_deepseek7b_mutli_node.sh old mode 100644 new mode 100755 diff --git a/examples/generation/run_deepseek_v2_lite_math.sh b/examples/generation/run_deepseek_v2_lite_math.sh old mode 100644 new mode 100755 diff --git a/examples/gmpo_trainer/run_qwen2_5-7b_math.sh b/examples/gmpo_trainer/run_qwen2_5-7b_math.sh old mode 100644 new mode 100755 diff --git a/examples/gmpo_trainer/test_dapo_7b_math.sh b/examples/gmpo_trainer/test_dapo_7b_math.sh old mode 100644 new mode 100755 diff --git a/examples/gmpo_trainer/test_dapo_qwen3_30b_math.sh b/examples/gmpo_trainer/test_dapo_qwen3_30b_math.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/_debug_keane_run_qwen2_5_omni-7b_hb_all_modalities.sh b/examples/grpo_trainer/_debug_keane_run_qwen2_5_omni-7b_hb_all_modalities.sh new file mode 100755 index 00000000000..3522b2fea39 --- /dev/null +++ b/examples/grpo_trainer/_debug_keane_run_qwen2_5_omni-7b_hb_all_modalities.sh @@ -0,0 +1,121 @@ +set -x + +unset ROCR_VISIBLE_DEVICES + +# actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct +# actor_rollout_ref.model.path=Qwen/Qwen2.5-Omni-7B +# data.train_files=/scratch/keane/human_behaviour/human_behaviour_data/train_no_meld_no_chalearn_vision_v2_template_prompts.jsonl \ +# data.val_files=/scratch/keane/human_behaviour/human_behaviour_data/val_no_meld_no_chalearn_vision_v2_template_prompts.jsonl \ +# data.modalities=\'audio,videos\' \ + +# SETTING OF SAVE PATH: trainer.default_local_dir= /scratch/keane/human_behaviour/2_models_hb_vision_only +# SETTING OF THE LOAD PATH from directory of checkpoints is also: trainer.default_local_dir + +# TRAINING FROM scratch: trainer.resume_mode == "disable" (default will save into default_local_dir) + +# TRAINING AUTOMATICALLY (i.e. 
from scratch or from latest checkpoint) : + # trainer.resume_mode == "auto" and then the model will take the latest ckpt from trainer.default_hdfs_dir + +# TRAINING from specific CHECKPOINT: trainer.resume_mode == "resume_path" and then specify trainer.resume_from_path + # Setting of path to resume training from trainer.resume_from_path (exact path of checkpoint) + # the model will take from resume_from_path directly (absolute path), and ignore default_hdfs_dir + +# for validation, set val_before_train=True ; make sure that the checkpoint is loaded and put val_only=True +# the checkpoint should already be loaded before that +# and then we will just evaluate + +# ALTERNATIVES +# /scratch/keane/human_behaviour/human_behaviour_data/discretized_no_lmvd_no_chsimsv2_v3_template_prompts.jsonl +# /scratch/keane/human_behaviour/human_behaviour_data/discretized_no_lmvd_no_chsimsv2_no_chalearn_v3_template_prompts.jsonl + +# when resuming training from a loaded checkpoint cuda OOM error + +# alt: /scratch/keane/human_behaviour/human_behaviour_data/0.1_train_no_lmvd_discretized_v3_template_prompts.jsonl +# org: /scratch/keane/human_behaviour/human_behaviour_data/train_no_lmvd_discretized_v3_template_prompts.jsonl + +# LORA: + # actor_rollout_ref.model.use_shm=True \ + # actor_rollout_ref.model.lora_rank=32 \ + # actor_rollout_ref.model.lora_alpha=32 \ + # actor_rollout_ref.rollout.load_format=safetensors \ + # actor_rollout_ref.model.target_modules=all-linear \ + # actor_rollout_ref.rollout.layered_summon=True \ + +# Set PyTorch CUDA memory allocator policies +# export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True,max_split_size_mb=128 + +# data: +# train_batch_size: 8 +# val_batch_size: 8 + +# NOTE: THESE NEED TO BE TOGGLED AS TRUE +# train_modality_batching: +# enabled: true +# drop_last: true + +# val_modality_batching: +# enabled: true +# drop_last: false + + +PYTHONUNBUFFERED=1 HYDRA_FULL_ERROR=1 PYTHONPATH="/home/keaneong/human-behavior/verl:$PYTHONPATH" 
NCCL_ASYNC_ERROR_HANDLING=1 python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=/scratch/keane/human_behaviour/human_behaviour_data/sigs_train_no_lmvd_discretized_v3_template_prompts.jsonl \ + data.val_files=/scratch/keane/human_behaviour/human_behaviour_data/sigs_val_no_lmvd_discretized_v3_template_prompts.jsonl \ + data.train_batch_size=6 \ + data.val_batch_size=6 \ + data.max_prompt_length=4096 \ + data.max_response_length=4096 \ + data.filter_overlong_prompts=False \ + data.truncation='left' \ + data.image_key=images \ + data.video_key=videos \ + data.prompt_key=problem \ + data.dataloader_num_workers=8 \ + data.modalities=\'audio,videos\' \ + data.train_modality_batching.enabled=True \ + data.train_modality_batching.drop_last=True \ + data.val_modality_batching.enabled=True \ + data.val_modality_batching.drop_last=True \ + data.format_prompt=/home/keaneong/human-behavior/verl/examples/format_prompt/default.jinja \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-Omni-7B \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=False \ + actor_rollout_ref.actor.ppo_mini_batch_size=3 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.actor.use_kl_loss=False \ + actor_rollout_ref.actor.kl_loss_coef=0 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + 
actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.rollout.free_cache_engine=True \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + actor_rollout_ref.rollout.max_model_len=4096 \ + actor_rollout_ref.rollout.max_num_batched_tokens=4096 \ + algorithm.use_kl_in_reward=False \ + custom_reward_function.path=/home/keaneong/human-behavior/verl/examples/reward_function/human_behaviour.py \ + custom_reward_function.name=human_behaviour_compute_score_batch \ + reward_model.reward_manager=batch \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_hb' \ + trainer.experiment_name='mixed_modal_omni' \ + trainer.n_gpus_per_node=3 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.val_before_train=False \ + trainer.test_freq=10 \ + trainer.total_epochs=1 $@ \ + trainer.default_local_dir=/scratch/keane/human_behaviour/mixed_modal_verl_models_hb_omni \ No newline at end of file diff --git a/examples/grpo_trainer/_human_behaviour.sh b/examples/grpo_trainer/_human_behaviour.sh new file mode 100755 index 00000000000..512a4e845b1 --- /dev/null +++ b/examples/grpo_trainer/_human_behaviour.sh @@ -0,0 +1,55 @@ +set -x +ENGINE=${1:-vllm} + +PYTHONUNBUFFERED=1 HYDRA_FULL_ERROR=1 python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=/scratch/keane/human_behaviour/human_behaviour_data/train_no_meld_no_chalearn_vision_v2_template_prompts.jsonl \ + data.val_files=/scratch/keane/human_behaviour/human_behaviour_data/val_no_meld_no_chalearn_vision_v2_template_prompts.jsonl \ + data.train_batch_size=512 \ + data.val_batch_size=128 \ + data.max_prompt_length=4096 \ + data.max_response_length=4096 \ + data.filter_overlong_prompts=False \ + data.truncation='left' \ + data.image_key=images \ + data.video_key=videos \ + 
 data.prompt_key=problem \ + data.format_prompt=/home/keaneong/human-behavior/verl/examples/format_prompt/default.jinja \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=False \ + actor_rollout_ref.actor.ppo_mini_batch_size=128 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=1e-8 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=$ENGINE \ + actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.rollout.free_cache_engine=True \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + custom_reward_function.path=/home/keaneong/human-behavior/verl/examples/reward_function/medical.py \ + custom_reward_function.name=medical_compute_score_batch \ + reward_model.reward_manager=batch \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_human_behaviour' \ + trainer.experiment_name='qwen2_5_vl_7b_function_trial' \ + trainer.n_gpus_per_node=2 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.val_before_train=False \ + trainer.test_freq=5 \ + 
trainer.total_epochs=15 $@ \ No newline at end of file diff --git a/examples/grpo_trainer/_keane_run_qwen2_5_omni-7b_hb_all_modalities.sh b/examples/grpo_trainer/_keane_run_qwen2_5_omni-7b_hb_all_modalities.sh new file mode 100755 index 00000000000..9bebaf58293 --- /dev/null +++ b/examples/grpo_trainer/_keane_run_qwen2_5_omni-7b_hb_all_modalities.sh @@ -0,0 +1,120 @@ +set -x + +unset ROCR_VISIBLE_DEVICES + +# actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct +# actor_rollout_ref.model.path=Qwen/Qwen2.5-Omni-7B +# data.train_files=/scratch/keane/human_behaviour/human_behaviour_data/train_no_meld_no_chalearn_vision_v2_template_prompts.jsonl \ +# data.val_files=/scratch/keane/human_behaviour/human_behaviour_data/val_no_meld_no_chalearn_vision_v2_template_prompts.jsonl \ +# data.modalities=\'audio,videos\' \ + +# SETTING OF SAVE PATH: trainer.default_local_dir= /scratch/keane/human_behaviour/2_models_hb_vision_only +# SETTING OF THE LOAD PATH from directory of checkpoints is also: trainer.default_local_dir + +# TRAINING FROM scratch: trainer.resume_mode == "disable" (default will save into default_local_dir) + +# TRAINING AUTOMATICALLY (i.e. 
from scratch or from latest checkpoint) : + # trainer.resume_mode == "auto" and then the model will take the latest ckpt from trainer.default_hdfs_dir + +# TRAINING from specific CHECKPOINT: trainer.resume_mode == "resume_path" and then specify trainer.resume_from_path + # Setting of path to resume training from trainer.resume_from_path (exact path of checkpoint) + # the model will take from resume_from_path directly (absolute path), and ignore default_hdfs_dir + +# for validation, set val_before_train=True ; make sure that the checkpoint is loaded and put val_only=True +# the checkpoint should already be loaded before that +# and then we will just evaluate + +# ALTERNATIVES +# /scratch/keane/human_behaviour/human_behaviour_data/discretized_no_lmvd_no_chsimsv2_v3_template_prompts.jsonl +# /scratch/keane/human_behaviour/human_behaviour_data/discretized_no_lmvd_no_chsimsv2_no_chalearn_v3_template_prompts.jsonl + +# when resuming training from a loaded checkpoint cuda OOM error + +# alt: /scratch/keane/human_behaviour/human_behaviour_data/0.1_train_no_lmvd_discretized_v3_template_prompts.jsonl +# org: /scratch/keane/human_behaviour/human_behaviour_data/train_no_lmvd_discretized_v3_template_prompts.jsonl + +# LORA: + # actor_rollout_ref.model.use_shm=True \ + # actor_rollout_ref.model.lora_rank=32 \ + # actor_rollout_ref.model.lora_alpha=32 \ + # actor_rollout_ref.rollout.load_format=safetensors \ + # actor_rollout_ref.model.target_modules=all-linear \ + # actor_rollout_ref.rollout.layered_summon=True \ + +# Set PyTorch CUDA memory allocator policies +# export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True,max_split_size_mb=128 + +# data: +# train_batch_size: 8 +# val_batch_size: 8 + +# train_modality_batching: +# enabled: true +# drop_last: true + +# val_modality_batching: +# enabled: true +# drop_last: false + + +PYTHONUNBUFFERED=1 HYDRA_FULL_ERROR=1 PYTHONPATH="/home/keaneong/human-behavior/verl:$PYTHONPATH" NCCL_ASYNC_ERROR_HANDLING=1 python3 -m 
verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=/scratch/keane/human_behaviour/human_behaviour_data/sigs_train_no_lmvd_discretized_v3_template_prompts.jsonl \ + data.val_files=/scratch/keane/human_behaviour/human_behaviour_data/sigs_val_no_lmvd_discretized_v3_template_prompts.jsonl \ + data.train_batch_size=288 \ + data.val_batch_size=144 \ + data.max_prompt_length=4096 \ + data.max_response_length=4096 \ + data.filter_overlong_prompts=False \ + data.truncation='left' \ + data.image_key=images \ + data.video_key=videos \ + data.prompt_key=problem \ + data.dataloader_num_workers=8 \ + data.modalities=\'audio,videos\' \ + data.train_modality_batching.enabled=True \ + data.train_modality_batching.drop_last=True \ + data.val_modality_batching.enabled=True \ + data.val_modality_batching.drop_last=True \ + data.format_prompt=/home/keaneong/human-behavior/verl/examples/format_prompt/default.jinja \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-Omni-7B \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=False \ + actor_rollout_ref.actor.ppo_mini_batch_size=72 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.actor.use_kl_loss=False \ + actor_rollout_ref.actor.kl_loss_coef=0 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + 
actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.rollout.free_cache_engine=True \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + actor_rollout_ref.rollout.max_model_len=4096 \ + actor_rollout_ref.rollout.max_num_batched_tokens=4096 \ + algorithm.use_kl_in_reward=False \ + custom_reward_function.path=/home/keaneong/human-behavior/verl/examples/reward_function/human_behaviour.py \ + custom_reward_function.name=human_behaviour_compute_score_batch \ + reward_model.reward_manager=batch \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='mixed_modal_verl_hb' \ + trainer.experiment_name='mixed_modal_omni' \ + trainer.n_gpus_per_node=3 \ + trainer.nnodes=1 \ + trainer.save_freq=10 \ + trainer.val_before_train=False \ + trainer.test_freq=5 \ + trainer.total_epochs=1 \ + trainer.default_local_dir=/scratch/keane/human_behaviour/mixed_modal_verl_models_hb_omni $@ \ No newline at end of file diff --git a/examples/grpo_trainer/keane_vl_only_run_qwen2_5_vl-7b_hb_all_modalities.sh b/examples/grpo_trainer/keane_vl_only_run_qwen2_5_vl-7b_hb_all_modalities.sh new file mode 100755 index 00000000000..13cf20b279d --- /dev/null +++ b/examples/grpo_trainer/keane_vl_only_run_qwen2_5_vl-7b_hb_all_modalities.sh @@ -0,0 +1,66 @@ +set -x + +unset ROCR_VISIBLE_DEVICES + +# NOTE: be careful when setting filter_overlong_prompts; because this removes the prompts from the max_prompt_length + +# actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct +# actor_rollout_ref.model.path=Qwen/Qwen2.5-Omni-7B +# data.train_files=/scratch/keane/human_behaviour/human_behaviour_data/train_no_meld_no_chalearn_vision_v2_template_prompts.jsonl \ +# data.val_files=/scratch/keane/human_behaviour/human_behaviour_data/val_no_meld_no_chalearn_vision_v2_template_prompts.jsonl \ +# data.modalities=\'audio,videos\' \ + +PYTHONUNBUFFERED=1 
HYDRA_FULL_ERROR=1 PYTHONPATH="/home/keaneong/human-behavior/verl:$PYTHONPATH" NCCL_ASYNC_ERROR_HANDLING=1 python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=/scratch/keane/human_behaviour/human_behaviour_data/old_train_template_prompts.jsonl \ + data.val_files=/scratch/keane/human_behaviour/human_behaviour_data/old_val_template_prompts.jsonl \ + data.train_batch_size=64 \ + data.val_batch_size=64 \ + data.max_prompt_length=3072 \ + data.max_response_length=1536 \ + data.filter_overlong_prompts=False \ + data.truncation='left' \ + data.image_key=images \ + data.video_key=videos \ + data.prompt_key=problem \ + data.dataloader_num_workers=0 \ + data.modalities=\'videos\' \ + data.format_prompt=/home/keaneong/human-behavior/verl/examples/format_prompt/default.jinja \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=False \ + actor_rollout_ref.actor.ppo_mini_batch_size=64 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.actor.use_kl_loss=False \ + actor_rollout_ref.actor.kl_loss_coef=1e-8 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.rollout.free_cache_engine=True \ + 
actor_rollout_ref.rollout.n=3 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + custom_reward_function.path=/home/keaneong/human-behavior/verl/examples/reward_function/medical.py \ + custom_reward_function.name=medical_compute_score_batch \ + reward_model.reward_manager=batch \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_hb' \ + trainer.experiment_name='vision_only' \ + trainer.n_gpus_per_node=3 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.val_before_train=False \ + trainer.test_freq=1 \ + trainer.total_epochs=15 $@ \ No newline at end of file diff --git a/examples/grpo_trainer/run_deepseek671b_math_megatron.sh b/examples/grpo_trainer/run_deepseek671b_math_megatron.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_deepseek7b_llm.sh b/examples/grpo_trainer/run_deepseek7b_llm.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_deepseek7b_llm_math.sh b/examples/grpo_trainer/run_deepseek7b_llm_math.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_deepseek7b_llm_math_megatron.sh b/examples/grpo_trainer/run_deepseek7b_llm_math_megatron.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_deepseek7b_llm_seq_balance.sh b/examples/grpo_trainer/run_deepseek7b_llm_seq_balance.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_minicpmo2_6.sh b/examples/grpo_trainer/run_minicpmo2_6.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_moonlight16b_math_megatron.sh b/examples/grpo_trainer/run_moonlight16b_math_megatron.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_qwen2-7b.sh b/examples/grpo_trainer/run_qwen2-7b.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_qwen2-7b_math.sh b/examples/grpo_trainer/run_qwen2-7b_math.sh 
old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_qwen2-7b_math_megatron.sh b/examples/grpo_trainer/run_qwen2-7b_math_megatron.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_qwen2-7b_seq_balance.sh b/examples/grpo_trainer/run_qwen2-7b_seq_balance.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_qwen2-7b_seq_balance_math_megatron.sh b/examples/grpo_trainer/run_qwen2-7b_seq_balance_math_megatron.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_qwen2-7b_sgl_megatron.sh b/examples/grpo_trainer/run_qwen2-7b_sgl_megatron.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_qwen2_5-3b_gsm8k_grpo_lora.sh b/examples/grpo_trainer/run_qwen2_5-3b_gsm8k_grpo_lora.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_qwen2_5-7b_math_megatron_diff_tp.sh b/examples/grpo_trainer/run_qwen2_5-7b_math_megatron_diff_tp.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_qwen2_5_32b_grpo_npu.sh b/examples/grpo_trainer/run_qwen2_5_32b_grpo_npu.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_qwen2_5_7b_grpo_discrete_prof_npu.sh b/examples/grpo_trainer/run_qwen2_5_7b_grpo_discrete_prof_npu.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_qwen2_5_7b_grpo_e2e_prof_npu.sh b/examples/grpo_trainer/run_qwen2_5_7b_grpo_e2e_prof_npu.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_qwen2_5_7b_grpo_npu.sh b/examples/grpo_trainer/run_qwen2_5_7b_grpo_npu.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_qwen2_5_vl-7b-megatron.sh b/examples/grpo_trainer/run_qwen2_5_vl-7b-megatron.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_qwen2_5_vl-7b.sh b/examples/grpo_trainer/run_qwen2_5_vl-7b.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_qwen2_5_vl-7b_climb.sh 
b/examples/grpo_trainer/run_qwen2_5_vl-7b_climb.sh new file mode 100755 index 00000000000..761abd09784 --- /dev/null +++ b/examples/grpo_trainer/run_qwen2_5_vl-7b_climb.sh @@ -0,0 +1,54 @@ +set -x +ENGINE=${1:-vllm} + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=/home/dvdai/orcd/scratch/high_modality/geom_train_upsampled_new.jsonl \ + data.val_files=/home/dvdai/orcd/scratch/high_modality/geom_valid_mini_new.jsonl \ + data.train_batch_size=512 \ + data.max_prompt_length=4096 \ + data.max_response_length=4096 \ + data.filter_overlong_prompts=False \ + data.truncation='left' \ + data.image_key=images \ + data.video_key=videos \ + data.prompt_key=problem \ + data.format_prompt=examples/format_prompt/default.jinja \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=False \ + actor_rollout_ref.actor.ppo_mini_batch_size=128 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=1e-8 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=$ENGINE \ + actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.rollout.free_cache_engine=True \ + actor_rollout_ref.rollout.n=5 \ + 
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + custom_reward_function.path=examples/reward_function/medical.py \ + custom_reward_function.name=medical_compute_score_batch \ + reward_model.reward_manager=batch \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_climb' \ + trainer.experiment_name='qwen2_5_vl_7b_function_rm' \ + trainer.n_gpus_per_node=4 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.val_before_train=False \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ diff --git a/examples/grpo_trainer/run_qwen2_5_vl-7b_hb.sh b/examples/grpo_trainer/run_qwen2_5_vl-7b_hb.sh new file mode 100755 index 00000000000..3f1f0796af4 --- /dev/null +++ b/examples/grpo_trainer/run_qwen2_5_vl-7b_hb.sh @@ -0,0 +1,53 @@ +set -x + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=/scratch/keane/human_behaviour/human_behaviour_data/old_train_template_prompts.jsonl \ + data.val_files=/scratch/keane/human_behaviour/human_behaviour_data/old_val_template_prompts.jsonl \ + data.train_batch_size=512 \ + data.max_prompt_length=4096 \ + data.max_response_length=4096 \ + data.filter_overlong_prompts=False \ + data.truncation='left' \ + data.image_key=images \ + data.video_key=videos \ + data.prompt_key=problem \ + data.format_prompt=examples/format_prompt/default.jinja \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=False \ + actor_rollout_ref.actor.ppo_mini_batch_size=128 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=1e-8 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ + 
actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.rollout.free_cache_engine=True \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + custom_reward_function.path=examples/reward_function/medical.py \ + custom_reward_function.name=medical_compute_score_batch \ + reward_model.reward_manager=batch \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_hb' \ + trainer.experiment_name='vision_only' \ + trainer.n_gpus_per_node=2 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.val_before_train=False \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ diff --git a/examples/grpo_trainer/run_qwen2_5_vl-7b_hb_all_modalities.sh b/examples/grpo_trainer/run_qwen2_5_vl-7b_hb_all_modalities.sh new file mode 100755 index 00000000000..fd2d60ba553 --- /dev/null +++ b/examples/grpo_trainer/run_qwen2_5_vl-7b_hb_all_modalities.sh @@ -0,0 +1,56 @@ +set -x + +PYTHONUNBUFFERED=1 HYDRA_FULL_ERROR=1 python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=/scratch/keane/human_behaviour/human_behaviour_data/old_train_template_prompts.jsonl \ + data.val_files=/scratch/keane/human_behaviour/human_behaviour_data/old_val_template_prompts.jsonl \ + data.train_batch_size=128 \ + data.val_batch_size=128 \ + 
data.max_prompt_length=4096 \ + data.max_response_length=4096 \ + data.filter_overlong_prompts=False \ + data.truncation='left' \ + data.image_key=images \ + data.video_key=videos \ + data.prompt_key=problem \ + data.dataloader_num_workers=0 \ + data.modalities=\'audio,videos\' \ + data.format_prompt=examples/format_prompt/default.jinja \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-Omni-7B \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=False \ + actor_rollout_ref.actor.ppo_mini_batch_size=128 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.actor.use_kl_loss=False \ + actor_rollout_ref.actor.kl_loss_coef=1e-8 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.rollout.free_cache_engine=True \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + custom_reward_function.path=examples/reward_function/medical.py \ + custom_reward_function.name=medical_compute_score_batch \ + reward_model.reward_manager=batch \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_hb' \ + 
trainer.experiment_name='vision_only' \ + trainer.n_gpus_per_node=2 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.val_before_train=False \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ diff --git a/examples/grpo_trainer/run_qwen2_5_vl-7b_lora.sh b/examples/grpo_trainer/run_qwen2_5_vl-7b_lora.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_qwen2_5_vl-7b_seq_balance.sh b/examples/grpo_trainer/run_qwen2_5_vl-7b_seq_balance.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_qwen2_5_vl_32b_npu.sh b/examples/grpo_trainer/run_qwen2_5_vl_32b_npu.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_qwen2_5_vl_3b_npu.sh b/examples/grpo_trainer/run_qwen2_5_vl_3b_npu.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_qwen2_5_vl_7b_npu.sh b/examples/grpo_trainer/run_qwen2_5_vl_7b_npu.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_qwen3-236b_megatron.sh b/examples/grpo_trainer/run_qwen3-236b_megatron.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_qwen3-8b.sh b/examples/grpo_trainer/run_qwen3-8b.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_qwen3moe-30b_megatron.sh b/examples/grpo_trainer/run_qwen3moe-30b_megatron.sh old mode 100644 new mode 100755 diff --git a/examples/ppo_trainer/run_deepseek7b_llm.sh b/examples/ppo_trainer/run_deepseek7b_llm.sh old mode 100644 new mode 100755 diff --git a/examples/ppo_trainer/run_deepseek7b_llm_modelscope.sh b/examples/ppo_trainer/run_deepseek7b_llm_modelscope.sh old mode 100644 new mode 100755 diff --git a/examples/ppo_trainer/run_deepseek7b_llm_pfppo.sh b/examples/ppo_trainer/run_deepseek7b_llm_pfppo.sh old mode 100644 new mode 100755 diff --git a/examples/ppo_trainer/run_deepseek7b_llm_sandbox_fusion.sh b/examples/ppo_trainer/run_deepseek7b_llm_sandbox_fusion.sh old mode 100644 new mode 100755 diff --git 
a/examples/ppo_trainer/run_deepseek7b_llm_sp2.sh b/examples/ppo_trainer/run_deepseek7b_llm_sp2.sh old mode 100644 new mode 100755 diff --git a/examples/ppo_trainer/run_deepseek_full_hh_rlhf.sh b/examples/ppo_trainer/run_deepseek_full_hh_rlhf.sh old mode 100644 new mode 100755 diff --git a/examples/ppo_trainer/run_deepseek_math_gsm8k_megatron.sh b/examples/ppo_trainer/run_deepseek_math_gsm8k_megatron.sh old mode 100644 new mode 100755 diff --git a/examples/ppo_trainer/run_deepseek_math_gsm8k_megatron_nsys.sh b/examples/ppo_trainer/run_deepseek_math_gsm8k_megatron_nsys.sh old mode 100644 new mode 100755 diff --git a/examples/ppo_trainer/run_gemma.sh b/examples/ppo_trainer/run_gemma.sh old mode 100644 new mode 100755 diff --git a/examples/ppo_trainer/run_moonlight16b_a3b_gsm8k_megatron.sh b/examples/ppo_trainer/run_moonlight16b_a3b_gsm8k_megatron.sh old mode 100644 new mode 100755 diff --git a/examples/ppo_trainer/run_qwen1.5_moe_a2.7b-gsm8k_megatron.sh b/examples/ppo_trainer/run_qwen1.5_moe_a2.7b-gsm8k_megatron.sh old mode 100644 new mode 100755 diff --git a/examples/ppo_trainer/run_qwen2-7b_math_gsm8k_megatron.sh b/examples/ppo_trainer/run_qwen2-7b_math_gsm8k_megatron.sh old mode 100644 new mode 100755 diff --git a/examples/ppo_trainer/run_qwen2-7b_rm.sh b/examples/ppo_trainer/run_qwen2-7b_rm.sh old mode 100644 new mode 100755 diff --git a/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance.sh b/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance.sh old mode 100644 new mode 100755 diff --git a/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance_fused_kernels.sh b/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance_fused_kernels.sh old mode 100644 new mode 100755 diff --git a/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance_nsys.sh b/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance_nsys.sh old mode 100644 new mode 100755 diff --git a/examples/ppo_trainer/run_qwen2-7b_seq_balance.sh b/examples/ppo_trainer/run_qwen2-7b_seq_balance.sh old mode 100644 new mode 100755 diff --git 
a/examples/ppo_trainer/run_qwen2-7b_sglang_seq_balance.sh b/examples/ppo_trainer/run_qwen2-7b_sglang_seq_balance.sh old mode 100644 new mode 100755 diff --git a/examples/ppo_trainer/run_qwen2.5-32b.sh b/examples/ppo_trainer/run_qwen2.5-32b.sh old mode 100644 new mode 100755 diff --git a/examples/reinforce_plus_plus_trainer/run_qwen2-7b_math_rf.sh b/examples/reinforce_plus_plus_trainer/run_qwen2-7b_math_rf.sh old mode 100644 new mode 100755 diff --git a/examples/reinforce_plus_plus_trainer/run_qwen2-7b_math_rf_baseline.sh b/examples/reinforce_plus_plus_trainer/run_qwen2-7b_math_rf_baseline.sh old mode 100644 new mode 100755 diff --git a/examples/remax_trainer/run_qwen2.5-3b_seq_balance.sh b/examples/remax_trainer/run_qwen2.5-3b_seq_balance.sh old mode 100644 new mode 100755 diff --git a/examples/remax_trainer/run_qwen2.5-7b_seq_balance.sh b/examples/remax_trainer/run_qwen2.5-7b_seq_balance.sh old mode 100644 new mode 100755 diff --git a/examples/reward_function/dapo.py b/examples/reward_function/dapo.py new file mode 100644 index 00000000000..9285cd1d0fd --- /dev/null +++ b/examples/reward_function/dapo.py @@ -0,0 +1,163 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import re
from typing import Any


# Ordered literal substitutions applied to an answer string before unit
# stripping. Order matters (e.g. ".$" is rewritten before "$" is handled).
SUBSTITUTIONS = [
    ("an ", ""),
    ("a ", ""),
    (".$", "$"),
    ("\\$", ""),
    (r"\ ", ""),
    (" ", ""),
    ("mbox", "text"),
    (",\\text{and}", ","),
    ("\\text{and}", ","),
    ("\\text{m}", "\\text{}"),
]

# Unit words and LaTeX fragments that are deleted outright from answers.
REMOVED_EXPRESSIONS = [
    "square",
    "ways",
    "integers",
    "dollars",
    "mph",
    "inches",
    "hours",
    "km",
    "units",
    "\\ldots",
    "sue",
    "points",
    "feet",
    "minutes",
    "digits",
    "cents",
    "degrees",
    "cm",
    "gm",
    "pounds",
    "meters",
    "meals",
    "edges",
    "students",
    "childrentickets",
    "multiples",
    "\\text{s}",
    "\\text{.}",
    "\\text{\ns}",
    "\\text{}^2",
    "\\text{}^3",
    "\\text{\n}",
    "\\text{}",
    r"\mathrm{th}",
    r"^\circ",
    r"^{\circ}",
    r"\;",
    r",\!",
    "{,}",
    '"',
    "\\dots",
]


def normalize_final_answer(final_answer: str) -> str:
    """Normalize a final answer to a quantitative reasoning question.

    Args:
        final_answer: The answer string to normalize.

    Returns:
        Normalized answer string.
    """
    # Keep only the right-hand side of an equation, if one is present.
    answer = final_answer.split("=")[-1]

    # Apply the literal substitutions, then delete unit words/fragments.
    for old, new in SUBSTITUTIONS:
        answer = answer.replace(old, new)
    for expr in REMOVED_EXPRESSIONS:
        answer = answer.replace(expr, "")

    # Extract/normalize LaTeX math, then expand shorthand TeX:
    #   \fracab -> \frac{a}{b}   \fracabc -> \frac{a}{b}c
    #   \sqrta  -> \sqrt{a}      \sqrtab  -> \sqrt{a}b
    rewrites = (
        (r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$"),
        (r"(\\text\{)(.*?)(\})", "\\2"),
        (r"(\\textbf\{)(.*?)(\})", "\\2"),
        (r"(\\overline\{)(.*?)(\})", "\\2"),
        (r"(\\boxed\{)(.*)(\})", "\\2"),
        (r"(frac)([^{])(.)", "frac{\\2}{\\3}"),
        (r"(sqrt)([^{])", "sqrt{\\2}"),
    )
    for pattern, repl in rewrites:
        answer = re.sub(pattern, repl, answer)
    answer = answer.replace("$", "")

    # Collapse thousands separators in purely numeric answers.
    if answer.replace(",", "").isdigit():
        answer = answer.replace(",", "")

    return answer.strip()


def accuracy_reward(response: str, ground_truth: str) -> float:
    """Return 1.0 if the last "Answer: ..." in *response* matches
    *ground_truth* after normalization, else -1.0."""
    hits = re.findall(r"(?i)Answer\s*:\s*([^\n]+)", response)
    answer = hits[-1] if hits else "[INVALID]"
    return 1.0 if normalize_final_answer(answer) == normalize_final_answer(ground_truth) else -1.0


def soft_overlong_punishment(response_length: int, max_response_length: int, overlong_buffer_length: int):
    """Soft length penalty: 0 up to the expected length, linear down to -1
    inside the overlong buffer, and a flat -1 beyond the hard cap."""
    expected_len = max_response_length - overlong_buffer_length
    if response_length > max_response_length:
        return -1.0
    if response_length > expected_len:
        return (expected_len - response_length) / overlong_buffer_length
    return 0.0


def compute_score(
    reward_inputs: list[dict[str, Any]],
    max_response_length: int,
    overlong_buffer_length: int,
    overlong_penalty_factor: float,
) -> list[dict[str, float]]:
    """Batched DAPO reward: per-sample accuracy plus a scaled overlong penalty.

    Args:
        reward_inputs: List of dicts with "response", "ground_truth",
            and "response_length" keys.
        max_response_length: Hard cap on response length.
        overlong_buffer_length: Width of the soft-penalty zone below the cap.
        overlong_penalty_factor: Multiplier applied to the overlong score.

    Returns:
        One dict per sample with "overall", "accuracy", "overlong",
        and "accuracy_normalized" (accuracy mapped from [-1, 1] to [0, 1]).

    Raises:
        ValueError: If *reward_inputs* is not a list (batch reward required).
    """
    if not isinstance(reward_inputs, list):
        raise ValueError("Please use `reward_type=batch` for dapo reward function.")

    scores = []
    for item in reward_inputs:
        # The longest answer in MATH-500 has 159 characters, so the tail suffices.
        tail = item["response"][-300:]
        acc = accuracy_reward(tail, item["ground_truth"])
        overlong = soft_overlong_punishment(
            item["response_length"], max_response_length, overlong_buffer_length
        )
        scores.append(
            {
                "overall": acc + overlong * overlong_penalty_factor,
                "accuracy": acc,
                "overlong": overlong,
                "accuracy_normalized": 0.5 * (acc + 1.0),
            }
        )

    return scores
datetime +import json +import os +from collections import defaultdict +from typing import Dict, List, Set +import statistics + +def parse_conditions(text: str) -> Set[str]: + """ + Parse medical conditions from text, handling various separators. + + Args: + text (str): Text containing medical conditions. + + Returns: + Set[str]: Set of individual medical conditions. + """ + # Remove any boxing notation if present + text = text.replace("\\boxed{", "").replace("}", "") + + # Split by common separators + for sep in [", ", " and ", " & ", ",", "&"]: + if sep in text: + return set(cond.strip() for cond in text.split(sep)) + + # If no separator found, treat as single condition + return {text.strip()} + + +def extract_boxed_content(text: str) -> str: + """ + Extract content within \boxed{} or similar boxing notations. + + Args: + text (str): Text containing potentially boxed content. + + Returns: + str: Extracted boxed content or the original text if no box found. + """ + import re + + # Look for LaTeX \boxed{} notation + boxed_match = re.search(r"\\boxed{([^}]*)}", text) + if boxed_match: + return boxed_match.group(1) + + # Look for markdown boxed notation (e.g., [boxed content]) + markdown_match = re.search(r"\[(.*?)\]", text) + if markdown_match: + return markdown_match.group(1) + + # Return the text as is if no boxed content is found + return text + + +def compute_class_metrics(class_name: str, confusion_matrix: Dict[str, int]) -> Dict[str, float]: + """ + Compute metrics for a single class based on its confusion matrix. + + Args: + class_name (str): Name of the class. + confusion_matrix (Dict[str, int]): Confusion matrix with tp, fp, fn, tn. + + Returns: + Dict[str, float]: Dictionary of metrics for this class. 
+ """ + tp = confusion_matrix["tp"] + fp = confusion_matrix["fp"] + fn = confusion_matrix["fn"] + tn = confusion_matrix["tn"] + + # Calculate metrics (avoid division by zero) + precision = tp / (tp + fp) if (tp + fp) > 0 else 0 + recall = tp / (tp + fn) if (tp + fn) > 0 else 0 + sensitivity = recall # sensitivity is the same as recall + specificity = tn / (tn + fp) if (tn + fp) > 0 else 0 + f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0 + accuracy = (tp + tn) / (tp + tn + fp + fn) if (tp + tn + fp + fn) > 0 else 0 + + return { + "precision": precision, + "recall": recall, + "sensitivity": sensitivity, + "specificity": specificity, + "f1": f1, + "accuracy": accuracy, + "count": confusion_matrix["count"], + "confusion_matrix": {"tp": tp, "fp": fp, "fn": fn, "tn": tn}, + } + + +def gender(predictions: List[str], ground_truths: List[str], demographics: List[str]) -> Dict[str, float]: + groups = {"male": {"preds": [], "gts": []}, "female": {"preds": [], "gts": []}} + + for pred, gt, demo in zip(predictions, ground_truths, demographics): + if demo is not None and "female" in demo.lower(): + groups["female"]["preds"].append(pred) + groups["female"]["gts"].append(gt) + elif demo is not None and "male" in demo.lower(): + groups["male"]["preds"].append(pred) + groups["male"]["gts"].append(gt) + + results = {} + acc_values = [] + f1_values = [] + + for sex in ["male", "female"]: + preds = groups[sex]["preds"] + gts = groups[sex]["gts"] + if len(preds) == 0: + continue + metrics = compute_dataset_metrics(preds, gts)["dataset_metrics"] + acc = metrics["accuracy"] + f1 = metrics["f1"] + results[f"{sex}/accuracy"] = acc + results[f"{sex}/f1"] = f1 + acc_values.append(acc) + f1_values.append(f1) + print(f"{sex}: accuracy = {acc:.4f}, f1 = {f1:.4f}") + + if len(acc_values) >= 2: + acc_diff = abs(acc_values[0] - acc_values[1]) + results["acc_diff for sex"] = acc_diff + results["std_accuracy for sex"] = statistics.stdev(acc_values) + 
def parent(predictions: List[str], ground_truths: List[str], demographics: List[str]) -> Dict[str, float]:
    """
    Compute per-group accuracy/F1 and fairness gaps by parent attribute.

    Groups examples by the token immediately following "father:" / "mother:"
    in the demographic string (the "NAN" placeholder is skipped), scores each
    group with compute_dataset_metrics, and reports per-group metrics plus the
    max difference and standard deviation across groups.

    Args:
        predictions: Model prediction strings.
        ground_truths: Ground-truth label strings.
        demographics: Demographic strings (may contain None entries).

    Returns:
        Dict[str, float]: "{group}/accuracy", "{group}/f1" per group, plus
        "acc_diff", "std_accuracy", "f1_diff", "std_f1" when >= 2 groups exist.
    """
    groups: Dict[str, Dict[str, List[str]]] = {}

    def _add_to_group(marker: str, demo: str, pred: str, gt: str) -> None:
        # Token right after "father:" / "mother:", e.g. "father: White ..." -> "White".
        token = demo.split(marker)[1].strip().split()[0]
        if token == "NAN":
            # BUG FIX: the original never created a "NAN" group but still fell
            # into the append branch, raising KeyError. Skip the placeholder.
            return
        bucket = groups.setdefault(token, {"preds": [], "gts": []})
        bucket["preds"].append(pred)
        bucket["gts"].append(gt)

    for pred, gt, demo in zip(predictions, ground_truths, demographics):
        if demo is None:
            continue
        lowered = demo.lower()
        if "father" in lowered:
            _add_to_group("father:", demo, pred, gt)
        if "mother" in lowered:
            # BUG FIX: the original's else-branch split on "father:" here
            # (copy-paste), mixing mother tokens into father groups.
            _add_to_group("mother:", demo, pred, gt)

    results: Dict[str, float] = {}
    acc_values: List[float] = []
    f1_values: List[float] = []

    for race, bucket in groups.items():
        if not bucket["preds"]:
            continue
        metrics = compute_dataset_metrics(bucket["preds"], bucket["gts"])["dataset_metrics"]
        acc = metrics["accuracy"]
        f1 = metrics["f1"]
        results[f"{race}/accuracy"] = acc
        results[f"{race}/f1"] = f1
        acc_values.append(acc)
        f1_values.append(f1)
        print(f"{race}: accuracy = {acc:.4f}, f1 = {f1:.4f}")

    if len(acc_values) >= 2:
        acc_diff = max(acc_values) - min(acc_values)
        results["acc_diff"] = acc_diff
        print(f"Accuracy max diff for parent = {acc_diff:.4f}")
        std_acc = statistics.stdev(acc_values)
        results["std_accuracy"] = std_acc
        print(f"std of accuracy for parent = {std_acc:.4f}")

    if len(f1_values) >= 2:
        f1_diff = max(f1_values) - min(f1_values)
        results["f1_diff"] = f1_diff
        print(f"F1 max diff for parent = {f1_diff:.4f}")
        std_f1 = statistics.stdev(f1_values)
        results["std_f1"] = std_f1
        print(f"std of f1 for parent = {std_f1:.4f}")

    return results
metrics["accuracy"] + f1 = metrics["f1"] + results[f"{group}/accuracy"] = acc + results[f"{group}/f1"] = f1 + acc_values.append(acc) + f1_values.append(f1) + + if len(acc_values) >= 2: + results["acc_diff"] = max(acc_values) - min(acc_values) + results["std_accuracy"] = statistics.stdev(acc_values) + + if len(f1_values) >= 2: + results["f1_diff"] = max(f1_values) - min(f1_values) + results["std_f1"] = statistics.stdev(f1_values) + + for group in ["a1", "a2", "a3", "a4"]: + acc = results.get(f"{group}/accuracy") + f1 = results.get(f"{group}/f1") + if acc is not None and f1 is not None: + print(f"{group}: accuracy = {acc:.4f}, f1 = {f1:.4f}") + + if "acc_diff" in results: + print(f"Accuracy max diff = {results['acc_diff']:.4f}") + print(f"std of accuracy for age = {results['std_accuracy']:.4f}") + if "f1_diff" in results: + print(f"F1 max diff = {results['f1_diff']:.4f}") + print(f"std of f1 for age = {results['std_f1']:.4f}") + + return results +def compute_confusion_matrices(predictions: List[str], ground_truths: List[str]) -> Dict[str, Dict[str, int]]: + """ + Compute confusion matrices for each class. + + Args: + predictions (List[str]): List of model predictions. + ground_truths (List[str]): List of ground truth labels. + + Returns: + Dict[str, Dict[str, int]]: Confusion matrices for each class. 
def compute_confusion_matrices(predictions: List[str], ground_truths: List[str]) -> Dict[str, Dict[str, int]]:
    """
    Compute confusion matrices for each class.

    Args:
        predictions (List[str]): List of model predictions.
        ground_truths (List[str]): List of ground truth labels.

    Returns:
        Dict[str, Dict[str, int]]: Confusion matrices for each class.
    """
    condition_matrices = defaultdict(lambda: {"tp": 0, "fp": 0, "fn": 0, "tn": 0, "count": 0})

    # Pass 1: collect the full label vocabulary from ground truths and
    # from non-empty ("None") predictions alike.
    all_conditions: Set[str] = set()
    for gt in ground_truths:
        all_conditions |= parse_conditions(gt)
    for pred in predictions:
        boxed = extract_boxed_content(pred)
        if boxed != "None":
            all_conditions |= parse_conditions(boxed)

    # Pass 2: tally per-condition confusion counts for every example.
    for pred, gt in zip(predictions, ground_truths):
        boxed = extract_boxed_content(pred)
        pred_conditions = parse_conditions(boxed) if boxed != "None" else set()
        gt_conditions = parse_conditions(gt)

        for condition in all_conditions:
            in_gt = condition in gt_conditions
            in_pred = condition in pred_conditions

            # "count" tracks how often the condition appears in ground truth.
            if in_gt:
                condition_matrices[condition]["count"] += 1

            if in_gt and in_pred:
                condition_matrices[condition]["tp"] += 1
            elif in_gt:
                condition_matrices[condition]["fn"] += 1
            elif in_pred:
                condition_matrices[condition]["fp"] += 1
            else:
                condition_matrices[condition]["tn"] += 1

    return condition_matrices
def compute_dataset_metrics(predictions: List[str], ground_truths: List[str]) -> Dict[str, Dict]:
    """
    Compute metrics for a single dataset, with class-wise averaging.

    Args:
        predictions (List[str]): List of model predictions for this dataset.
        ground_truths (List[str]): List of ground truth labels for this dataset.

    Returns:
        Dict[str, Dict]: Class metrics and averaged dataset metrics.
    """
    class_matrices = compute_confusion_matrices(predictions, ground_truths)

    metric_names = ("precision", "recall", "sensitivity", "specificity", "f1", "accuracy")
    dataset_metrics = dict.fromkeys(metric_names, 0.0)
    class_metrics: Dict[str, Dict] = {}
    active_classes = 0

    for class_name, matrix in class_matrices.items():
        # Classes that never appear in ground truth are excluded from the average.
        if matrix["count"] == 0:
            continue

        active_classes += 1
        metrics = compute_class_metrics(class_name, matrix)
        class_metrics[class_name] = metrics

        # Accumulate with equal class weighting (macro average).
        for name in metric_names:
            dataset_metrics[name] += metrics[name]

    if active_classes > 0:
        for name in metric_names:
            dataset_metrics[name] /= active_classes

    return {"class_metrics": class_metrics, "dataset_metrics": dataset_metrics, "active_classes": active_classes}
+ + Returns: + Dict[str, float]: Flattened dictionary of metrics at all levels with keys: + - "val/{metric}" for global metrics + - "{data_source}/{metric}" for data source metrics + - "{data_source}/{dataset}/{metric}" for dataset metrics + """ + # Save inputs to json for debugging under outputs/ + + output_dir = "outputs" + os.makedirs(output_dir, exist_ok=True) + input_data = { + "predictions": predictions, + "ground_truths": ground_truths, + "data_sources": data_sources, + "datasets": datasets, + "demographics": demographics, + } + # name is time in yyyy-mm-dd_hh-mm-ss format + with open( + os.path.join(output_dir, f"input_data_{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.json"), "w" + ) as f: + json.dump(input_data, f, indent=4) + + # Group examples by data source and dataset + grouped_data = defaultdict(lambda: defaultdict(lambda: {"preds": [], "gts": []})) + + for pred, gt, source, dataset in zip(predictions, ground_truths, data_sources, datasets): + grouped_data[source][dataset]["preds"].append(pred) + grouped_data[source][dataset]["gts"].append(gt) + + # Initialize the flattened result dictionary + result = {} + + # Initialize global metrics accumulators + global_metrics = { + "precision": 0.0, + "recall": 0.0, + "sensitivity": 0.0, + "specificity": 0.0, + "f1": 0.0, + "accuracy": 0.0, + } + + # Compute metrics for each dataset within each data source + total_data_sources = 0 + + for source_name, source_datasets in grouped_data.items(): + # Initialize metrics accumulators for this data source + source_metrics = { + "precision": 0.0, + "recall": 0.0, + "sensitivity": 0.0, + "specificity": 0.0, + "f1": 0.0, + "accuracy": 0.0, + } + + total_datasets_in_source = 0 + + for dataset_name, dataset_data in source_datasets.items(): + # Compute metrics for this dataset + dataset_result = compute_dataset_metrics(dataset_data["preds"], dataset_data["gts"]) + + # Store dataset-level metrics with the format "data_source/dataset/metric" + for metric_name, 
metric_value in dataset_result["dataset_metrics"].items(): + result[f"{source_name}/{dataset_name}/{metric_name}"] = metric_value + + # Skip empty datasets + if dataset_result["active_classes"] == 0: + continue + + total_datasets_in_source += 1 + + # Accumulate metrics for data source average (equal dataset weighting) + for metric_name in source_metrics.keys(): + source_metrics[metric_name] += dataset_result["dataset_metrics"][metric_name] + + # Calculate data source average (equal dataset weighting) + if total_datasets_in_source > 0: + for metric_name in source_metrics.keys(): + source_metrics[metric_name] /= total_datasets_in_source + + # Store data source metrics with the format "data_source/metric" + for metric_name, metric_value in source_metrics.items(): + result[f"{source_name}/{metric_name}"] = metric_value + + total_data_sources += 1 + + # Accumulate for global metrics (equal data source weighting) + for metric_name in global_metrics.keys(): + global_metrics[metric_name] += source_metrics[metric_name] + + # Calculate global average (equal data source weighting) + if total_data_sources > 0: + for metric_name in global_metrics.keys(): + global_metrics[metric_name] /= total_data_sources + + # Store global metrics with the format "val/metric" + for metric_name, metric_value in global_metrics.items(): + result[f"val/{metric_name}"] = metric_value + + gender_results = gender(predictions, ground_truths, demographics) + for k, v in gender_results.items(): + result[f"fairness/gender/{k}"] = v + + age_results = age(predictions, ground_truths, demographics) + for k, v in age_results.items(): + result[f"fairness/age/{k}"] = v + + parent_results = parent(predictions, ground_truths, demographics) + for k, v in parent_results.items(): + result[f"fairness/parent/{k}"] = v + + + std_acc_values = [] + std_f1_values = [] + try: + + std_acc_values.append(gender_results["std_accuracy for sex"]) + std_f1_values.append(gender_results["std_f1 for sex"]) + + + 
std_acc_values.append(age_results["std_accuracy"]) + std_f1_values.append(age_results["std_f1"]) + + std_acc_values.append(parent_results["std_accuracy"]) + std_f1_values.append(parent_results["std_f1"]) + + result["fairness/avg_std_accuracy"] = sum(std_acc_values) / len(std_acc_values) + result["fairness/avg_std_f1"] = sum(std_f1_values) / len(std_f1_values) + except KeyError: + print("Some fairness metrics do not have standard deviation values, skipping average calculation.") + + return result + + +if __name__ == "__main__": + outputs_dir = "../../outputs" + output_files = [f for f in os.listdir(outputs_dir) if f.startswith("input_data_") and f.endswith(".json")] + if not output_files: + print("No output files found in the outputs directory.") + else: + latest_file = max(output_files, key=lambda f: os.path.getmtime(os.path.join(outputs_dir, f))) + with open(os.path.join(outputs_dir, latest_file), "r") as f: + input_data = json.load(f) + + predictions = input_data["predictions"] + ground_truths = input_data["ground_truths"] + data_sources = input_data["data_sources"] + datasets = input_data["datasets"] + demographics = input_data["demographics"] + + metrics = compute_metrics_by_data_source(predictions, ground_truths, data_sources, datasets, demographics) + print(json.dumps(metrics, indent=4)) \ No newline at end of file diff --git a/examples/reward_function/human_behaviour.py b/examples/reward_function/human_behaviour.py new file mode 100644 index 00000000000..9221dbf7eda --- /dev/null +++ b/examples/reward_function/human_behaviour.py @@ -0,0 +1,103 @@ +from typing import List, Dict +import re + +def extract_boxed_content(text: str) -> str: + """ + Extract content within \boxed{} or similar boxing notations. + + Args: + text (str): Text containing potentially boxed content. + + Returns: + str: Extracted boxed content or the original text if no box found. 
def format_reward(response: str) -> float:
    """
    Check whether the response matches the expected format.
    Requires a \\boxed{...} span somewhere in the response.
    NOTE(review): the ".*.*" prefix suggests literal tags were lost from this
    pattern (e.g. <think></think>) — confirm the intended format contract.
    """
    pattern = re.compile(r".*.*\\boxed\{.*\}.*", re.DOTALL)
    format_match = re.fullmatch(pattern, response)
    return 1.0 if format_match else 0.0


def accuracy_reward(response: str, ground_truth: str) -> float:
    """
    Simple accuracy: exact match to ground truth string.
    """
    return 1.0 if response == ground_truth else 0.0


def human_behaviour_compute_score_batch(
    data_sources: List[str],
    solution_strs: List[str],
    ground_truths: List[str],
    extra_infos: List[str],
    **kwargs,
) -> List[Dict[str, float]]:
    """
    Compute human behaviour scoring for batch inputs.

    Args:
        data_sources: List of data sources (unused here, but kept for interface compatibility)
        solution_strs: List of model prediction strings
        ground_truths: List of ground truth strings
        extra_infos: List of extra information (unused here, kept for compatibility)

    Returns:
        List of score dicts with keys "score", "standard_score", "format_score".
    """
    batch_scores = []
    format_weight = 0.1

    for data_source, predict_str, ground_truth, extra_info in zip(
        data_sources, solution_strs, ground_truths, extra_infos
    ):
        # Normalize response formatting (e.g., qwen2.5vl inserts spaces around tags)
        full_response = re.sub(r"\s*(<|>|/)\s*", r"\1", predict_str)
        pred_label = extract_boxed_content(full_response).lower()  # handle qwen2.5vl-32b format

        format_score = format_reward(full_response)
        # BUG FIX: the original lowercased ground_truth only AFTER calling
        # accuracy_reward, so a capitalized label (e.g. "Anger") could never
        # match the lowercased prediction. Lowercase before comparing.
        standard_score = accuracy_reward(pred_label, ground_truth.lower())

        # Weighted overall score
        overall_score = (1 - format_weight) * standard_score + format_weight * format_score

        batch_scores.append(
            {
                "score": overall_score,
                "standard_score": standard_score,
                "format_score": format_score,
            }
        )

    return batch_scores
+ ) + + data_sources = ["sample_audio.wav"] + solution_strs = [response_str] + ground_truths = ["anger"] + extra_infos = [""] + + scores = human_behaviour_compute_score_batch(data_sources, solution_strs, ground_truths, extra_infos) + print(scores) diff --git a/examples/reward_function/human_behaviour_alt.py b/examples/reward_function/human_behaviour_alt.py new file mode 100644 index 00000000000..765dbeb7b16 --- /dev/null +++ b/examples/reward_function/human_behaviour_alt.py @@ -0,0 +1,140 @@ +import re +import json +from typing import Dict, List + +import numpy +import torch +import numpy as np +from mathruler.grader import extract_boxed_content +import wandb +import random + + +def parse_conditions(text): + # Remove any boxing notation if present + text = text.replace("\\boxed{", "").replace("}", "") + + # Split by common separators + for sep in [", ", " and ", " & ", ",", "&"]: + if sep in text: + return set(cond.strip() for cond in text.split(sep)) + + # If no separator found, treat as single condition + return {text.strip()} + + +def parse_json(json_output): + """ + Parsing out the markdown fencing from JSON code blocks. + """ + # Look for content between ```json and ``` + lines = json_output.splitlines() + for i, line in enumerate(lines): + if line == "```json" or line.strip() == "```": + json_output = "\n".join(lines[i + 1:]) # Remove everything before ```json + if "```" in json_output: + json_output = json_output.split("```")[0] # Remove everything after the closing ``` + break # Exit the loop once code block marker is found + return json_output + + +def extract_json_from_response(text): + """ + Extract JSON content from markdown code blocks in the response. 
+ + Args: + text: The model's response text + + Returns: + Parsed JSON object or None if no valid JSON found + """ + # Find content between ```json and ``` + json_pattern = r"```(?:json)?\s*([\s\S]*?)```" + matches = re.findall(json_pattern, text) + + if not matches: + return None + + # Try to parse each match as JSON + for match in matches: + try: + parsed_json = json.loads(match.strip()) + return parsed_json + except json.JSONDecodeError: + continue + + # If we couldn't parse any match as valid JSON, try with ast.literal_eval + import ast + for match in matches: + try: + # Clean up the match a bit + cleaned = match.strip().replace("'", "\"") + parsed_json = ast.literal_eval(cleaned) + return parsed_json + except: + continue + + return None + + +def medical_compute_score_batch(data_sources: List[str], solution_strs: List[str], ground_truths: List[str], extra_infos: List[str], **kwargs) -> List[Dict[str, float]]: + """ + Compute medical scoring for batch inputs including standard score, bounding box IoU, and format score. 
def medical_compute_score_batch(data_sources: List[str], solution_strs: List[str], ground_truths: List[str], extra_infos: List[str], **kwargs) -> List[Dict[str, float]]:
    """
    Compute batch F1-style medical scoring on \\boxed{...} condition lists.

    Args:
        data_sources: List of data sources (unused, kept for interface parity)
        solution_strs: List of model prediction strings
        ground_truths: List of ground truth strings
        extra_infos: List of extra information (unused here)

    Returns:
        List of score dicts with keys "score", "standard_score",
        "format_score", "length_score".
    """
    batch_scores = []

    for predict_str, ground_truth in zip(solution_strs, ground_truths):
        # Set-based F1 between predicted and ground-truth condition lists.
        answer = extract_boxed_content(predict_str)
        if answer == "None":
            standard_score = 0.0  # no boxed answer given
        else:
            predicted_conditions = parse_conditions(answer)
            ground_truth_conditions = parse_conditions(ground_truth)

            true_positives = len(predicted_conditions & ground_truth_conditions)
            false_positives = len(predicted_conditions - ground_truth_conditions)
            false_negatives = len(ground_truth_conditions - predicted_conditions)

            precision = (
                true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
            )
            recall = (
                true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
            )
            # F1: harmonic mean of precision and recall.
            standard_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

        # BUG FIX: the original called evaluate_bbox_format() and read
        # iou_score, neither of which exists in this module — every call
        # raised NameError. This text-only variant has no bbox ground truth,
        # so both terms are zero here.
        # NOTE(review): if bbox/format scoring is wanted, port
        # evaluate_bbox_format from medical.py; also confirm the score
        # weights, which sum to 0.9 rather than 1.0.
        format_score = 0.0
        iou_score = 0.0

        # Length shaping: saturates at 1.0 above ~600 chars (~200 words).
        length_score = 1 if len(predict_str) > 600 else len(predict_str) * 0.001

        batch_scores.append(
            {
                "score": 0.5 * standard_score + 0.3 * iou_score + 0.1 * format_score,
                "standard_score": standard_score,
                "format_score": format_score,
                "length_score": length_score,
            }
        )

    return batch_scores
+ "accuracy": accuracy_score, + } + ) + + return scores \ No newline at end of file diff --git a/examples/reward_function/medical.py b/examples/reward_function/medical.py new file mode 100644 index 00000000000..aeeac05b019 --- /dev/null +++ b/examples/reward_function/medical.py @@ -0,0 +1,460 @@ +import re +import json +from typing import Dict, List + +import numpy +import torch +import numpy as np +from mathruler.grader import extract_boxed_content +import wandb +import random + + +def parse_conditions(text): + # Remove any boxing notation if present + text = text.replace("\\boxed{", "").replace("}", "") + + # Split by common separators + for sep in [", ", " and ", " & ", ",", "&"]: + if sep in text: + return set(cond.strip() for cond in text.split(sep)) + + # If no separator found, treat as single condition + return {text.strip()} + + +def parse_json(json_output): + """ + Parsing out the markdown fencing from JSON code blocks. + """ + # Look for content between ```json and ``` + lines = json_output.splitlines() + for i, line in enumerate(lines): + if line == "```json" or line.strip() == "```": + json_output = "\n".join(lines[i + 1:]) # Remove everything before ```json + if "```" in json_output: + json_output = json_output.split("```")[0] # Remove everything after the closing ``` + break # Exit the loop once code block marker is found + return json_output + + +def extract_json_from_response(text): + """ + Extract JSON content from markdown code blocks in the response. 
def bbox_to_mask(bbox, height, width):
    """
    Convert bounding box to binary mask.

    Args:
        bbox: Bounding box in format [x1, y1, x2, y2]
        height: Height of the mask
        width: Width of the mask

    Returns:
        Binary float mask of shape (height, width)
    """
    mask = torch.zeros((height, width), dtype=torch.float32)

    # Clamp each coordinate into the image, then order the corners so the
    # slice below is well-formed even for inverted boxes.
    x_lo = max(0, min(int(bbox[0]), width - 1))
    y_lo = max(0, min(int(bbox[1]), height - 1))
    x_hi = max(0, min(int(bbox[2]), width - 1))
    y_hi = max(0, min(int(bbox[3]), height - 1))
    if x_lo > x_hi:
        x_lo, x_hi = x_hi, x_lo
    if y_lo > y_hi:
        y_lo, y_hi = y_hi, y_lo

    # Degenerate (zero-area) boxes leave the mask empty, matching the
    # original behavior; inclusive on both ends otherwise.
    if x_lo < x_hi and y_lo < y_hi:
        mask[y_lo:y_hi + 1, x_lo:x_hi + 1] = 1.0

    return mask
def calculate_bbox_iou(pred_bboxes, seg_mask=None, gt_bbox=None):
    """
    Calculate IoU between predicted bounding boxes and ground truth
    (segmentation mask or bbox).

    Args:
        pred_bboxes: One box [x1, y1, x2, y2] or a list of such boxes
        seg_mask: Ground truth segmentation mask (numpy array or tensor)
        gt_bbox: Ground truth bounding box [x1, y1, x2, y2]
                 (list/tuple or array-like with .tolist())

    Returns:
        Mean IoU score across all predicted boxes (0.0 if nothing to compare)
    """
    if not pred_bboxes:
        return 0.0

    # A single flat box is wrapped so we always iterate over a list of boxes.
    if not isinstance(pred_bboxes[0], list):
        pred_bboxes = [pred_bboxes]

    if seg_mask is not None and isinstance(seg_mask, numpy.ndarray):
        seg_mask = torch.from_numpy(seg_mask)

    # Mask path: only if a mask is present and not all-zero.
    if seg_mask is not None and torch.sum(seg_mask) > 0:
        if len(seg_mask.shape) == 3:  # leading channel dimension
            height, width = seg_mask.shape[1], seg_mask.shape[2]
        else:
            height, width = seg_mask.shape[0], seg_mask.shape[1]

        # Any positive value counts as foreground.
        binary_seg_mask = (seg_mask > 0).float()

        total_iou = 0.0
        for bbox in pred_bboxes:
            if len(bbox) < 4:
                continue
            try:
                bbox_mask = bbox_to_mask(bbox, height, width)
            except Exception:
                continue

            intersection = torch.sum(bbox_mask * binary_seg_mask)
            union = torch.sum(torch.clamp(bbox_mask + binary_seg_mask, 0, 1))
            total_iou += intersection / union if union > 0 else 0.0

        return total_iou / len(pred_bboxes)

    if gt_bbox is not None:
        # BUG FIX: the original called gt_bbox.tolist() INSIDE the loop, so
        # with two or more predicted boxes the second iteration invoked
        # .tolist() on a plain list and raised AttributeError. Convert once,
        # up front, and accept plain lists/tuples as well.
        if hasattr(gt_bbox, "tolist"):
            gt_bbox = gt_bbox.tolist()
        else:
            gt_bbox = list(gt_bbox)

        total_iou = 0.0
        for pred_bbox in pred_bboxes:
            if len(pred_bbox) < 4:
                continue

            # Intersection rectangle corners.
            x1 = max(pred_bbox[0], gt_bbox[0])
            y1 = max(pred_bbox[1], gt_bbox[1])
            x2 = min(pred_bbox[2], gt_bbox[2])
            y2 = min(pred_bbox[3], gt_bbox[3])

            if x1 >= x2 or y1 >= y2:
                iou = 0.0  # boxes do not overlap
            else:
                intersection = (x2 - x1) * (y2 - y1)
                pred_area = (pred_bbox[2] - pred_bbox[0]) * (pred_bbox[3] - pred_bbox[1])
                gt_area = (gt_bbox[2] - gt_bbox[0]) * (gt_bbox[3] - gt_bbox[1])
                union = pred_area + gt_area - intersection
                iou = intersection / union if union > 0 else 0.0

            total_iou += iou

        return total_iou / len(pred_bboxes)

    # Neither segmentation mask nor ground truth bbox provided.
    return 0.0
in item + has_label = "label" in item + + if has_bbox and has_label: + bbox = item["bbox_2d"] + # Check bbox format [x1, y1, x2, y2] + if (isinstance(bbox, list) and len(bbox) == 4 and + all(isinstance(coord, (int, float)) for coord in bbox)): + valid_items += 1 + + # Add up to 40% based on proportion of valid items + if total_items > 0: + format_score += 0.4 * (valid_items / total_items) + + except Exception: + # Any other parsing issues + pass + + return format_score + + +def medical_compute_score(predict_str: str, ground_truth: str, segmentation_mask=None, bbox=None) -> Dict[str, float]: + """ + Compute medical scoring including standard score, bounding box IoU, and format score. + + Args: + predict_str: The model's prediction string + ground_truth: The ground truth string + segmentation_mask: Ground truth segmentation mask tensor + bbox: Ground truth bounding box + + Returns: + Tuple of (standard_score, bbox_score) + Note: bbox_score is a combination of IoU score and format score + """ + # Calculate standard score + answer = extract_boxed_content(predict_str) + if answer == "None": + standard_score = 0.0 # no answer + else: + # Parse both prediction and ground truth into sets of conditions + predicted_conditions = parse_conditions(answer) + ground_truth_conditions = parse_conditions(ground_truth) + + # Calculate true positives, false positives, and false negatives + true_positives = len(predicted_conditions.intersection(ground_truth_conditions)) + false_positives = len(predicted_conditions - ground_truth_conditions) + false_negatives = len(ground_truth_conditions - predicted_conditions) + + # Calculate F1 score components + precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0 + recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0 + + # Calculate F1 score (harmonic mean of precision and recall) + standard_score = 2 * (precision * recall) / 
(precision + recall) if (precision + recall) > 0 else 0 + + # Calculate format score (how well the JSON follows the expected format) + format_score = evaluate_bbox_format(predict_str) + + # length score + if len(predict_str) > 600: # ~200 words + length_score = 1 + else: + length_score = len(predict_str) * 0.001 + + + # Calculate bounding box IoU score + iou_score = 0.0 + # Extract predicted bounding boxes from the response + json_data = extract_json_from_response(predict_str) + if json_data: + # Extract bounding boxes from the JSON + try: + pred_bboxes = [] + if isinstance(json_data, list): + for item in json_data: + if isinstance(item, dict) and "bbox_2d" in item: + pred_bboxes.append(item["bbox_2d"]) + elif isinstance(json_data, dict) and "bbox_2d" in json_data: + pred_bboxes.append(json_data["bbox_2d"]) + elif isinstance(json_data, dict) and 'objects_of_interest' in json_data: + for item in json_data['objects_of_interest']: + if isinstance(item, dict) and "bbox_2d" in item: + pred_bboxes.append(item["bbox_2d"]) + # else: + # print("Error: Invalid JSON format") + if random.random() < 0.0005: # print every 0.5% + print("[Bounding Box] ", json_data) + print("[Formatted Bounding Box] ", pred_bboxes) + print('[GT Bounding Box] ', bbox) + + # Calculate IoU between predicted boxes and ground truth + if pred_bboxes: + iou_score = calculate_bbox_iou(pred_bboxes, segmentation_mask, bbox) + except: + pass + # traceback.print_exc() + + scores = { + "overall": 0.6 * standard_score + 0.2 * iou_score + 0.1 * format_score + 0.1 * length_score, + "standard_score": standard_score, + "iou_score": iou_score, + "format_score": format_score, + } + return scores + + +def medical_compute_score_batch(data_sources: List[str], solution_strs: List[str], ground_truths: List[str], extra_infos: List[str], **kwargs) -> List[Dict[str, float]]: + """ + Compute medical scoring for batch inputs including standard score, bounding box IoU, and format score. 
+ + Args: + data_sources: List of data sources (e.g., file paths or identifiers) + solution_strs: List of model prediction strings + ground_truths: List of ground truth strings + extra_infos: List of extra information (e.g., segmentation masks, bounding boxes) + + Returns: + List of score dictionaries + """ + batch_scores = [] + + for data_source, predict_str, ground_truth, extra_info in zip(data_sources, solution_strs, ground_truths, extra_infos): + segmentation_mask = None + bbox = None + + # Calculate standard score + answer = extract_boxed_content(predict_str) + if answer == "None": + standard_score = 0.0 # no answer + else: + # Parse both prediction and ground truth into sets of conditions + predicted_conditions = parse_conditions(answer) + ground_truth_conditions = parse_conditions(ground_truth) + + # Calculate true positives, false positives, and false negatives + true_positives = len(predicted_conditions.intersection(ground_truth_conditions)) + false_positives = len(predicted_conditions - ground_truth_conditions) + false_negatives = len(ground_truth_conditions - predicted_conditions) + + # Calculate F1 score components + precision = ( + true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0 + ) + recall = true_positives / (true_positives + false_negatives) if ( + true_positives + false_negatives) > 0 else 0 + + # Calculate F1 score (harmonic mean of precision and recall) + standard_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0 + + # Calculate format score (how well the JSON follows the expected format) + format_score = evaluate_bbox_format(predict_str) + + # length score + if len(predict_str) > 600: # ~200 words + length_score = 1 + else: + length_score = len(predict_str) * 0.001 + + # Calculate bounding box IoU score + iou_score = 0.0 + # Extract predicted bounding boxes from the response + json_data = extract_json_from_response(predict_str) + if json_data: + # 
Extract bounding boxes from the JSON + try: + pred_bboxes = [] + if isinstance(json_data, list): + for item in json_data: + if isinstance(item, dict) and "bbox_2d" in item: + pred_bboxes.append(item["bbox_2d"]) + elif isinstance(json_data, dict) and "bbox_2d" in json_data: + pred_bboxes.append(json_data["bbox_2d"]) + elif isinstance(json_data, dict) and "objects_of_interest" in json_data: + for item in json_data["objects_of_interest"]: + if isinstance(item, dict) and "bbox_2d" in item: + pred_bboxes.append(item["bbox_2d"]) + + if random.random() < 0.005: # print every 0.5% + print("[Bounding Box] ", json_data) + print("[Formatted Bounding Box] ", pred_bboxes) + print("[GT Bounding Box] ", bbox) + + # Calculate IoU between predicted boxes and ground truth + if pred_bboxes: + iou_score = calculate_bbox_iou(pred_bboxes, segmentation_mask, bbox) + except: + pass + + scores = { + "score": 0.5 * standard_score + 0.3 * iou_score + 0.1 * format_score, + "standard_score": standard_score, + "iou_score": iou_score, + "format_score": format_score, + "length_score": length_score, + } + batch_scores.append(scores) + + return batch_scores \ No newline at end of file diff --git a/examples/reward_function/r1v.py b/examples/reward_function/r1v.py new file mode 100644 index 00000000000..6a28548b292 --- /dev/null +++ b/examples/reward_function/r1v.py @@ -0,0 +1,50 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def format_reward(response: str) -> float:
    """Return 1.0 iff the whole response is ``<think>...</think>`` followed by ``<answer>...</answer>``.

    NOTE(review): the previous pattern ``r".*?\\s*.*?"`` had lost its tags
    (likely markup-stripping) and matched *every* string, so the format
    reward was unconditionally 1.0. Restored the intended tag structure.
    """
    pattern = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL)
    format_match = re.fullmatch(pattern, response)
    return 1.0 if format_match else 0.0


def accuracy_reward(response: str, ground_truth: str) -> float:
    """Return 1.0 if the extracted ``<answer>`` content matches ``ground_truth`` per ``grade_answer``.

    Falls back to grading the whole response when no answer tags are present.
    NOTE(review): the previous pattern ``r"(.*?)"`` matched the empty string
    at position 0, so the graded answer was always "". Restored the tags.
    """
    try:
        content_match = re.search(r"<answer>(.*?)</answer>", response)
        given_answer = content_match.group(1).strip() if content_match else response.strip()
        if grade_answer(given_answer, ground_truth.strip()):
            return 1.0

    except Exception:
        # grade_answer may raise on exotic inputs; treat as incorrect.
        pass

    return 0.0


def compute_score(reward_input: dict[str, Any], format_weight: float = 0.5) -> dict[str, float]:
    """Blend format and accuracy rewards for a single rollout.

    Args:
        reward_input: Dict with ``response`` and ``ground_truth`` keys.
        format_weight: Weight of the format term; accuracy gets ``1 - format_weight``.

    Returns:
        Dict with ``overall``, ``format``, and ``accuracy`` scores.

    Raises:
        ValueError: If ``reward_input`` is not a dict (sequential reward type required).
    """
    if not isinstance(reward_input, dict):
        raise ValueError("Please use `reward_type=sequential` for r1v reward function.")

    format_score = format_reward(reward_input["response"])
    accuracy_score = accuracy_reward(reward_input["response"], reward_input["ground_truth"])
    return {
        "overall": (1 - format_weight) * accuracy_score + format_weight * format_score,
        "format": format_score,
        "accuracy": accuracy_score,
    }
100755 diff --git a/examples/sft/gsm8k/run_qwen_05_sp2_liger.sh b/examples/sft/gsm8k/run_qwen_05_sp2_liger.sh old mode 100644 new mode 100755 diff --git a/examples/sft/multiturn/run_qwen_05_sp2.sh b/examples/sft/multiturn/run_qwen_05_sp2.sh old mode 100644 new mode 100755 diff --git a/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_geo3k_multiturn.sh b/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_geo3k_multiturn.sh old mode 100644 new mode 100755 diff --git a/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_geo3k_multiturn_4xgpu.sh b/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_geo3k_multiturn_4xgpu.sh old mode 100644 new mode 100755 diff --git a/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_megatron_geo3k_multiturn.sh b/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_megatron_geo3k_multiturn.sh old mode 100644 new mode 100755 diff --git a/examples/sglang_multiturn/run_qwen2.5-0.5b_gsm8k_multiturn_w_interaction.sh b/examples/sglang_multiturn/run_qwen2.5-0.5b_gsm8k_multiturn_w_interaction.sh old mode 100644 new mode 100755 diff --git a/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn.sh b/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn.sh old mode 100644 new mode 100755 diff --git a/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn_4xgpu.sh b/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn_4xgpu.sh old mode 100644 new mode 100755 diff --git a/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_tool_agent_mlflow.sh b/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_tool_agent_mlflow.sh old mode 100644 new mode 100755 diff --git a/examples/sglang_multiturn/run_qwen2.5-3b_megatron_gsm8k_multiturn.sh b/examples/sglang_multiturn/run_qwen2.5-3b_megatron_gsm8k_multiturn.sh old mode 100644 new mode 100755 diff --git a/examples/sglang_multiturn/run_qwen2_3b_dapo_multiturn.sh b/examples/sglang_multiturn/run_qwen2_3b_dapo_multiturn.sh old mode 100644 new mode 100755 diff --git 
a/examples/sglang_multiturn/search_r1_like/run_qwen2.5-3b_instruct_search_multiturn.sh b/examples/sglang_multiturn/search_r1_like/run_qwen2.5-3b_instruct_search_multiturn.sh old mode 100644 new mode 100755 diff --git a/examples/split_placement/run_deepseek7b_llm.sh b/examples/split_placement/run_deepseek7b_llm.sh old mode 100644 new mode 100755 diff --git a/examples/tuning/0.5b/qwen2-0.5b_grpo-lora_1_h100_fsdp_vllm.sh b/examples/tuning/0.5b/qwen2-0.5b_grpo-lora_1_h100_fsdp_vllm.sh old mode 100644 new mode 100755 diff --git a/examples/tuning/1.5b/qwen2-1.5b_grpo-lora_1_h100_fsdp_vllm.sh b/examples/tuning/1.5b/qwen2-1.5b_grpo-lora_1_h100_fsdp_vllm.sh old mode 100644 new mode 100755 diff --git a/examples/tuning/14b/qwen2-14b_grpo-lora_2_h100_fsdp_vllm.sh b/examples/tuning/14b/qwen2-14b_grpo-lora_2_h100_fsdp_vllm.sh old mode 100644 new mode 100755 diff --git a/examples/tuning/14b/qwen2_14b_grpo_4_h800_fsdp_vllm.sh b/examples/tuning/14b/qwen2_14b_grpo_4_h800_fsdp_vllm.sh old mode 100644 new mode 100755 diff --git a/examples/tuning/32b/qwen2-32b_grpo-lora_4_h100_fsdp_vllm.sh b/examples/tuning/32b/qwen2-32b_grpo-lora_4_h100_fsdp_vllm.sh old mode 100644 new mode 100755 diff --git a/examples/tuning/32b/qwen2_32B_grpo_8_h20_megatron_vllm.sh b/examples/tuning/32b/qwen2_32B_grpo_8_h20_megatron_vllm.sh old mode 100644 new mode 100755 diff --git a/examples/tuning/3b/qwen2-3b_grpo-lora_1_h100_fsdp_vllm.sh b/examples/tuning/3b/qwen2-3b_grpo-lora_1_h100_fsdp_vllm.sh old mode 100644 new mode 100755 diff --git a/examples/tuning/70b/qwen2-70b_grpo_32_h20_fsdp_vllm.sh b/examples/tuning/70b/qwen2-70b_grpo_32_h20_fsdp_vllm.sh old mode 100644 new mode 100755 diff --git a/examples/tuning/70b/qwen2-70b_grpo_32_h800_fsdp_vllm.sh b/examples/tuning/70b/qwen2-70b_grpo_32_h800_fsdp_vllm.sh old mode 100644 new mode 100755 diff --git a/examples/tuning/70b/qwen2-72b_grpo-lora_8_h100_fsdp_vllm.sh b/examples/tuning/70b/qwen2-72b_grpo-lora_8_h100_fsdp_vllm.sh old mode 100644 new mode 100755 diff --git 
def process_mosei_annotations(annotation_path: str) -> None:
    """Convert raw MOSEI jsonl annotations into ``{videos, problem, answer}`` records.

    Reads ``annotation_path`` (jsonl), derives each clip's raw video path from
    the image filename, extracts the "What is ..." question and its answer,
    de-duplicates on (video, problem), sorts by video path, and writes:
      * ``<name>_formatted.jsonl`` — all records
      * ``annotations_train.jsonl`` / ``annotations_test.jsonl`` — 80/20 split
    in the same folder as ``annotation_path``.

    Raises:
        ValueError: If a conversation prompt lacks the "What is " marker
        (``str.index`` raises; surfacing malformed inputs loudly is intended).
    """
    data = []
    with open(annotation_path, "r") as f:  # jsonl file: one record per line
        for line in f:
            data.append(json.loads(line.strip()))

    formatted_data = []
    # Track (video_path, question) pairs in a set for O(1) duplicate checks
    # instead of re-scanning formatted_data for every sample (was O(n^2)).
    seen: set = set()
    for sample in tqdm.tqdm(data):
        image_path = sample["image"]
        # image path looks like "<dir>/<video_id>_..._<clip_id>.<ext>" — TODO confirm
        video_id = image_path.split("/")[1].split("_")[0]
        clip_id = image_path.split("_")[-1].split(".")[0]
        raw_video_path = f"Raw/{video_id}/{clip_id}.mp4"

        problem: str = sample["conversations"][0]["value"]
        question_statement = problem.index("What is ")
        question_str = problem[question_statement:]
        answer_str = sample["conversations"][1]["value"]

        key = (raw_video_path, question_str)
        if key in seen:
            continue  # skip duplicate (video, question) pairs
        seen.add(key)
        formatted_data.append({
            "videos": [raw_video_path],
            "problem": question_str,
            "answer": answer_str,
        })

    # Deterministic ordering: sort by the (single-element) video path list.
    formatted_data = sorted(formatted_data, key=lambda entry: entry["videos"])

    output_path = annotation_path.replace(".jsonl", "_formatted.jsonl")
    with open(output_path, "w") as f:
        for entry in formatted_data:
            f.write(json.dumps(entry) + "\n")

    # 80/20 train/test split over the sorted records.
    split_index = int(0.8 * len(formatted_data))
    train_data = formatted_data[:split_index]
    test_data = formatted_data[split_index:]
    folder_name = annotation_path.rsplit("/", 1)[0] if "/" in annotation_path else "."
    train_output_path = f"{folder_name}/annotations_train.jsonl"
    test_output_path = f"{folder_name}/annotations_test.jsonl"

    with open(train_output_path, "w") as f:
        for entry in train_data:
            f.write(json.dumps(entry) + "\n")

    with open(test_output_path, "w") as f:
        for entry in test_data:
            f.write(json.dumps(entry) + "\n")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Process MOSEI annotations")
    parser.add_argument(
        "--annotation_path",
        type=str,
        default="mosei_annotations.jsonl",
        help="Path to the MOSEI annotations file (default: mosei_annotations.jsonl)"
    )

    args = parser.parse_args()

    process_mosei_annotations(args.annotation_path)
    print(f"Processed annotations saved to {args.annotation_path.replace('.jsonl', '_formatted.jsonl')}")
AutoModelForCausalLM elif "ForConditionalGeneration" in self.model_config.architectures[0]: diff --git a/verl/models/transformers/qwen2_vl.py b/verl/models/transformers/qwen2_vl.py index 2866a02353e..2597ef6faa5 100644 --- a/verl/models/transformers/qwen2_vl.py +++ b/verl/models/transformers/qwen2_vl.py @@ -58,9 +58,20 @@ def get_rope_index( """ spatial_merge_size = processor.image_processor.merge_size tokens_per_second = 2 + + # Try old token names first, then fall back to new ones image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>") + if image_token_id is None: + image_token_id = processor.tokenizer.convert_tokens_to_ids("<|vision_pad|>") # New tokenizer uses vision_pad for images + video_token_id = processor.tokenizer.convert_tokens_to_ids("<|video_pad|>") + if video_token_id is None: + video_token_id = processor.tokenizer.convert_tokens_to_ids("<|vision_pad|>") # New tokenizer uses vision_pad for videos + vision_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|vision_start|>") + if vision_start_token_id is None: + vision_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|vision_bos|>") # New tokenizer uses vision_bos + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): if attention_mask is None: attention_mask = torch.ones_like(input_ids) @@ -69,10 +80,20 @@ def get_rope_index( image_index, video_index = 0, 0 input_ids = input_ids[attention_mask == 1] image_nums, video_nums = 0, 0 - vision_start_indices = torch.argwhere(input_ids == vision_start_token_id) - vision_tokens = input_ids[vision_start_indices + 1] - image_nums = (vision_tokens == image_token_id).sum() - video_nums = (vision_tokens == video_token_id).sum() + + # Ensure vision_start_token_id is valid before comparison + if vision_start_token_id is not None: + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id) + else: + vision_start_indices = torch.empty((0, 1), dtype=torch.long, 
device=input_ids.device) + + # Handle case where there are vision tokens + if vision_start_indices.numel() > 0: + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum().item() if image_token_id is not None else 0 + video_nums = (vision_tokens == video_token_id).sum().item() if video_token_id is not None else 0 + else: + image_nums, video_nums = 0, 0 input_tokens = input_ids.tolist() llm_pos_ids_list: list = [] st = 0 diff --git a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml index 27233a87994..29d33429668 100644 --- a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml @@ -264,6 +264,7 @@ data: train_files: ~/data/rlhf/gsm8k/train.parquet val_files: ~/data/rlhf/gsm8k/test.parquet prompt_key: prompt + format_prompt: examples/format_prompt/default.jinja reward_fn_key: data_source max_prompt_length: 512 max_response_length: 512 @@ -280,6 +281,14 @@ data: truncation: error image_key: images video_key: videos + audio_key: audios + modalities: images,videos + train_modality_batching: + enabled: false + drop_last: false + val_modality_batching: + enabled: false + drop_last: false trust_remote_code: false custom_cls: path: null diff --git a/verl/trainer/config/_generated_ppo_trainer.yaml b/verl/trainer/config/_generated_ppo_trainer.yaml index bca4e51679c..c468f6db4bb 100644 --- a/verl/trainer/config/_generated_ppo_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_trainer.yaml @@ -237,6 +237,7 @@ data: train_files: ~/data/rlhf/gsm8k/train.parquet val_files: ~/data/rlhf/gsm8k/test.parquet prompt_key: prompt + format_prompt: examples/format_prompt/default.jinja reward_fn_key: data_source max_prompt_length: 512 max_response_length: 512 @@ -253,6 +254,14 @@ data: truncation: error image_key: images video_key: videos + audio_key: audios + modalities: images,videos # list of modalities to 
process + train_modality_batching: + enabled: false + drop_last: false + val_modality_batching: + enabled: false + drop_last: false trust_remote_code: false custom_cls: path: null diff --git a/verl/trainer/config/data/legacy_data.yaml b/verl/trainer/config/data/legacy_data.yaml index 9a5ce8f0dd1..19d93f89b12 100644 --- a/verl/trainer/config/data/legacy_data.yaml +++ b/verl/trainer/config/data/legacy_data.yaml @@ -16,6 +16,11 @@ val_files: ~/data/rlhf/gsm8k/test.parquet # The field in the dataset where the prompt is located. Default is 'prompt'. prompt_key: prompt +# Path to the format prompt template file. If null, uses the default format prompt. +# The template should be a Jinja2 template that will be applied to each prompt. +# Example: examples/format_prompt/default.jinja +format_prompt: examples/format_prompt/default.jinja + # The field used to select the reward function (if using different ones per example). reward_fn_key: data_source @@ -72,6 +77,21 @@ image_key: images # The field in the multi-modal dataset where the video is located. video_key: videos +# The field in the multi-modal dataset where the audio is located. Default is 'audios'. +audio_key: audios + +# Comma-separated list of modalities to process. Default is 'images,videos'. +# Available modalities: images, videos, audios +# Example: 'images,videos,audios' to enable all modalities +modalities: images,videos + +train_modality_batching: + enabled: false + drop_last: false +val_modality_batching: + enabled: false + drop_last: false + # If the remote tokenizer has a Python file, this flag determines whether to allow using it. 
trust_remote_code: False diff --git a/verl/trainer/config/rollout/rollout.yaml b/verl/trainer/config/rollout/rollout.yaml index de218572e4e..33bcaa89a65 100644 --- a/verl/trainer/config/rollout/rollout.yaml +++ b/verl/trainer/config/rollout/rollout.yaml @@ -47,10 +47,10 @@ free_cache_engine: True tensor_model_parallel_size: 2 # max number of tokens in a batch -max_num_batched_tokens: 8192 +max_num_batched_tokens: 1536 # max length for rollout -max_model_len: null +max_model_len: 1536 # max length of sequences max_num_seqs: 1024 diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py index 071035b33fd..0e0a23ce1a7 100644 --- a/verl/trainer/main_ppo.py +++ b/verl/trainer/main_ppo.py @@ -22,12 +22,14 @@ import ray from omegaconf import OmegaConf +from typing import Dict, List from verl.experimental.dataset.sampler import AbstractSampler from verl.trainer.constants_ppo import get_ppo_ray_runtime_env from verl.trainer.ppo.ray_trainer import RayPPOTrainer from verl.trainer.ppo.reward import load_reward_manager from verl.utils.device import is_cuda_available from verl.utils.import_utils import load_extern_type +from verl.utils.dataset.modality_sampler import ModalitySignatureBatchSampler @hydra.main(config_path="config", config_name="ppo_trainer", version_base=None) @@ -58,6 +60,7 @@ def run_ppo(config) -> None: ray.init( runtime_env=get_ppo_ray_runtime_env(), num_cpus=config.ray_init.num_cpus, + dashboard_host="0.0.0.0", ) # Create a remote instance of the TaskRunner class, and @@ -74,6 +77,8 @@ def run_ppo(config) -> None: runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote() else: runner = TaskRunner.remote() + + # RUN THE TRAINING using runner.run ray.get(runner.run.remote(config)) # [Optional] get the path of the timeline trace file from the configuration, default to None @@ -254,11 +259,17 @@ def run(self, config): from verl.utils.dataset.rl_dataset import collate_fn # Create training and validation datasets. 
+ # This is done by reading from the train and the val files train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor, is_train=True) val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor, is_train=False) - train_sampler = create_rl_sampler(config.data, train_dataset) + train_sampler = create_rl_sampler(config.data, train_dataset, split="train") + + # print(f"Using train sampler: {train_sampler}") + # print(f"Using val dataset: {val_dataset}") + # print(f"Using train dataset: {train_dataset}") # Initialize the PPO trainer. + # TODO: train sampler is fed into this; and is used to shuffle the training dataset (and extending to validation) trainer = RayPPOTrainer( config=config, tokenizer=tokenizer, @@ -328,8 +339,50 @@ def create_rl_dataset(data_paths, data_config, tokenizer, processor, is_train=Tr return dataset - -def create_rl_sampler(data_config, dataset): +# NOTE: This is the old rl_sampler +# def create_rl_sampler(data_config, dataset): +# """Create a sampler for the dataset. + +# Arguments: +# data_config: The data config. +# dataset (Dataset): The dataset. + +# Returns: +# sampler (Sampler): The sampler. +# """ +# import torch +# from torch.utils.data import RandomSampler, SequentialSampler + +# if data_config.sampler is not None and data_config.sampler.get("class_path", None) is not None: +# curriculum_class = load_extern_type( +# data_config.sampler.class_path, +# data_config.sampler.class_name, +# ) +# sampler = curriculum_class( +# data_source=dataset, +# data_config=data_config, +# ) +# assert isinstance(sampler, AbstractSampler) +# assert data_config.get("dataloader_num_workers", 8) == 0, ( +# "If using curriculum, num_workers must be 0 to prevent data caching. " +# "If the dataloader caches data before the batch is done the " +# "curriculum sampler won't have the opportunity to reorder it. " +# ) + +# # Use a sampler to facilitate checkpoint resumption. 
+# # If shuffling is enabled in the data configuration, create a random sampler. +# elif data_config.shuffle: +# train_dataloader_generator = torch.Generator() +# train_dataloader_generator.manual_seed(data_config.get("seed", 1)) +# sampler = RandomSampler(data_source=dataset, generator=train_dataloader_generator) +# else: +# # If shuffling is disabled, use a sequential sampler to iterate through the dataset in order. +# sampler = SequentialSampler(data_source=dataset) + +# return sampler + +# NOTE: This is your implementation +def create_rl_sampler(data_config, dataset, split: str = "train"): """Create a sampler for the dataset. Arguments: @@ -342,6 +395,10 @@ def create_rl_sampler(data_config, dataset): import torch from torch.utils.data import RandomSampler, SequentialSampler + # modality batching config parse + mb_cfg = data_config.get("train_modality_batching") if split == "train" \ + else data_config.get("val_modality_batching") + if data_config.sampler is not None and data_config.sampler.get("class_path", None) is not None: curriculum_class = load_extern_type( data_config.sampler.class_path, @@ -358,18 +415,52 @@ def create_rl_sampler(data_config, dataset): "curriculum sampler won't have the opportunity to reorder it. " ) + if mb_cfg and mb_cfg.get("enabled", False): + print(f"Creating our modality sampler for split: {split}") + # by_sig is actually the collation of dataset indices grouped by their modality signature + by_sig: Dict[str, List[int]] = {} + # essentially getting "modality_signature" from the jsonl dataset + for i in range(len(dataset)): + row = dataset.dataframe[i] if hasattr(dataset, "dataframe") else dataset[i] + sig = row.get("modality_signature") + if sig is None: + print(f"[WARNING] Row {i} missing 'modality_signature'. 
Skipping.") + continue + by_sig.setdefault(sig, []).append(i) + + # batch_size = mb_cfg.get("batch_size", data_config.get( + # "train_batch_size" if split=="train" else "val_batch_size" + # )) + + batch_size = data_config.get("train_batch_size" if split=="train" else "val_batch_size") + + drop_last = mb_cfg.get("drop_last") + + # shuffle if split (meaning that we shuffle the samples within each batch) + shuffle = (split == "train") + + print(f"Creating our modality sampler for split: {split}, batch_size: {batch_size}, drop_last: {drop_last}, shuffle: {shuffle}") + + sampler = ModalitySignatureBatchSampler( + indices_by_sig=by_sig, + batch_size=int(batch_size), + drop_last=drop_last, + seed=data_config.get("seed", 42), + shuffle=shuffle, + ) + # Use a sampler to facilitate checkpoint resumption. # If shuffling is enabled in the data configuration, create a random sampler. - elif data_config.shuffle: + elif data_config.shuffle and split == "train": train_dataloader_generator = torch.Generator() train_dataloader_generator.manual_seed(data_config.get("seed", 1)) sampler = RandomSampler(data_source=dataset, generator=train_dataloader_generator) + else: # If shuffling is disabled, use a sequential sampler to iterate through the dataset in order. 
# Numerical floor used wherever a divisor could be ~0.
EPS_DEFAULT: float = 1e-6

# Per-domain question history, mutated across calls to the DRPO estimator:
#   domain_qstats[dom] = {
#       "vectors": List[np.ndarray]  # one (R,) reward vector per question (R rollouts)
#       "q_ids":   List[int],        # question ids, same order as "vectors"
#       "count":   int,              # number of questions accumulated so far
#   }
# NOTE(review): this module-level state grows without bound over training and
# is not process-safe — confirm each trainer worker owns its own copy.
domain_qstats: Dict[Any, Dict[str, Any]] = defaultdict(lambda: {
    "vectors": [],
    "q_ids": [],
    "count": 0,
})

# Global counter of questions ever folded into domain_qstats.
global_running_stats: Dict[str, int] = {"q_count": 0}


def _select_k_elbow(vals: np.ndarray, k_max: int = 10, tol: float = 0.10) -> int:
    """Pick a k-means cluster count via the elbow heuristic.

    Fits KMeans for k = 1..min(k_max, #unique points) and returns the first k
    whose inertia drop falls below ``tol`` times the previous drop; falls back
    to the largest k tried.
    """
    unique_cnt = len(np.unique(vals, axis=0))
    k_cap = min(k_max, unique_cnt)  # can't have more clusters than unique points
    ks = range(1, k_cap + 1)
    inertias = [KMeans(n_clusters=k, n_init="auto", random_state=0).fit(vals).inertia_ for k in ks]
    if len(inertias) == 1:
        return 1
    drops = np.diff(inertias) * -1.0  # inertia decreases, so negate to get positive drops
    for i in range(1, len(drops)):
        if drops[i] < tol * drops[i - 1]:  # elbow: marginal gain collapsed
            return i + 1
    return ks[-1]


def _cluster_info_question(vectors: List[np.ndarray]) -> Tuple[float, np.ndarray, np.ndarray, np.ndarray]:
    """K-means over question-level reward vectors.

    Returns
    -------
    mu_d : float – inverse-cluster-size weighted mean of the centroid means
    assignments : (Q,) – cluster index for each question vector
    counts : (k,) – cluster sizes
    centroids : (k,R) – cluster centroid vectors
    """
    if len(vectors) == 0:
        # Empty history: degenerate outputs so callers can proceed safely.
        return 0.0, np.empty(0, int), np.empty(0), np.empty((0, 0))

    X = np.stack(vectors, axis=0)  # (Q,R) – R inferred from data
    k_opt = _select_k_elbow(X, k_max=20)
    km = KMeans(n_clusters=k_opt, n_init="auto", random_state=0).fit(X)

    centroids = km.cluster_centers_  # (k,R)
    assignments = km.labels_  # (Q,)
    _, counts = np.unique(assignments, return_counts=True)
    counts = counts.astype(float)

    # Weight each centroid's mean reward by 1/cluster_size, so small (rare)
    # clusters contribute more to the domain-level statistic mu_d.
    centroid_means = centroids.mean(axis=1)  # (k,)
    weights = 1.0 / counts
    mu_d = float((weights * centroid_means).sum() / weights.sum())

    # Debug ------------------------------------------------------------- #
    print(
        f"[KMEANS‑Q] k={k_opt} | centroid_means="
        f"[{', '.join(f'{m:.3f}' for m in centroid_means)}] | counts={counts.tolist()} | μ_d={mu_d:.3f}"
    )

    return mu_d, assignments, counts, centroids


@register_adv_est(AdvantageEstimator.DRPO)
def compute_drpo_outcome_advantage(
    token_level_rewards: torch.Tensor,  # (B,L)
    response_mask: torch.Tensor,  # (B,L)
    index: np.ndarray,  # (B,) question ids
    domain_info: np.ndarray,  # (B,) domain ids
    epsilon: float = EPS_DEFAULT,
):
    """DRPO outcome advantage: GRPO normalisation plus hierarchical
    (domain → question-cluster) scaling derived from accumulated history.

    Returns ``(returns, returns)`` where returns is (B,L): the per-rollout
    scalar advantage broadcast over response tokens via ``response_mask``.
    Side effects: appends this batch's question vectors to the module-level
    ``domain_qstats`` / ``global_running_stats`` and prints debug reports.
    """

    B, L = token_level_rewards.shape

    # 1) raw rollout-level rewards: sum token rewards over the sequence ---- #
    raw_scores = token_level_rewards.sum(dim=-1)  # (B,)

    # 2) group this mini-batch's rollout rewards by question id ------------ #
    q2rollouts: Dict[str, List[float]] = defaultdict(list)
    q2domain: Dict[str, Any] = {}
    for i in range(B):
        qid: str = index[i]
        q2rollouts[qid].append(raw_scores[i].item())
        q2domain[qid] = domain_info[i]

    # every question must have the same number of rollouts in the batch ---- #
    rollout_lens = {len(v) for v in q2rollouts.values()}
    assert len(rollout_lens) == 1, "Inconsistent rollout counts per question in batch!"

    # one (R,) reward vector per question ---------------------------------- #
    q_vectors = {qid: np.asarray(v, dtype=np.float32) for qid, v in q2rollouts.items()}

    # 3) fold this batch into the per-domain question history -------------- #
    for qid, vec in q_vectors.items():
        dom = q2domain[qid]
        dstat = domain_qstats[dom]
        dstat["vectors"].append(vec)
        dstat["q_ids"].append(qid)
        dstat["count"] += 1
        global_running_stats["q_count"] += 1

    # 4) GRPO normalisation: z-score each rollout within its question ------ #
    scores = raw_scores.clone()
    id2mean = {qid: torch.mean(torch.tensor(v)) for qid, v in q2rollouts.items()}
    id2std = {qid: torch.std(torch.tensor(v)) for qid, v in q2rollouts.items()}
    for i in range(B):
        qid: str = index[i]
        scores[i] = (scores[i] - id2mean[qid]) / (id2std[qid] + epsilon)
    before_scale_score = scores.clone()  # kept for the debug scale report below

    # 5) re-cluster each domain's full question history -------------------- #
    domain_cluster_cache: Dict[Any, Dict[str, Any]] = {}
    for dom, dstat in domain_qstats.items():
        if dstat["count"] == 0:
            continue
        mu_d, assign, counts, centroids = _cluster_info_question(dstat["vectors"])
        domain_cluster_cache[dom] = {
            "mu_d": mu_d,
            "assign": assign,
            "counts": counts,
            "centroids": centroids,
            "q_ids": dstat["q_ids"],
        }

    # 6) scale each rollout by its domain/cluster factor -------------------- #
    scaling_factors: List[float] = []
    for i in range(B):
        qid: str = index[i]
        dom = q2domain[qid]
        cache = domain_cluster_cache[dom]

        # map qid → its cluster within the domain history
        # NOTE(review): list.index is O(Q) per rollout and returns the FIRST
        # occurrence if a qid was seen in earlier batches — confirm intended.
        q_idx = cache["q_ids"].index(qid)
        cluster_idx = cache["assign"][q_idx]

        N_d = float(domain_qstats[dom]["count"])
        mu_d = cache["mu_d"]
        T_d = max(math.sqrt(N_d) * mu_d, epsilon)  # domain temperature, floored at epsilon

        N_c = float(cache["counts"][cluster_idx])
        mu_c = float(cache["centroids"][cluster_idx].mean())

        # NOTE(review): unlike T_d, `factor` itself is not floored — a near-zero
        # or negative mu_c can blow up or flip the sign of the advantage.
        factor = T_d * math.sqrt(N_c) * mu_c
        scaling_factors.append(factor)
        scores[i] = scores[i] / factor

    # renormalise the batch to unit std after hierarchical scaling ---------- #
    scores_std = torch.std(scores)
    scores = scores / (scores_std + epsilon)

    # Debug report -------------------------------------------------------- #
    print("--------------Hierarchical scaling report--------------")
    dom2scale: Dict[Any, List[torch.Tensor]] = defaultdict(list)
    for i in range(B):
        dom2scale[domain_info[i]].append(scores[i] / (before_scale_score[i] + epsilon))
    for dom, lst in dom2scale.items():
        avg_sf = torch.mean(torch.stack(lst)).item()
        print(f"[HDRPO] domain = {dom:<15} | mean overall scale = {avg_sf:6.3f}")

    # Print global reward mean
    print(f"[HDRPO] global reward mean = {torch.mean(scores):.3f}")

    # broadcast the scalar advantage over response tokens; DRPO uses the same
    # tensor for both advantages and returns.
    returns = scores.unsqueeze(-1) * response_mask
    return returns, returns
+ """ + # TODO: we have to make sure the batch size is divisible by the dp size + from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler + + if train_dataset is None: + train_dataset = create_rl_dataset( + self.config.data.train_files, self.config.data, self.tokenizer, self.processor + ) + if val_dataset is None: + val_dataset = create_rl_dataset( + self.config.data.val_files, self.config.data, self.tokenizer, self.processor + ) + self.train_dataset, self.val_dataset = train_dataset, val_dataset + + if train_sampler is None: + # TODO; you can essentially specify the type of sampler here, based also on the data split + train_sampler = create_rl_sampler(self.config.data, self.train_dataset) + if collate_fn is None: + from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn + + collate_fn = default_collate_fn + + num_workers = self.config.data["dataloader_num_workers"] + + ## TODO: trainer_sampler is pretty much the rl_sampler here, which shuffles the dataset + # TODO: the sampler is now placed into this (which is essentially your sampler) + self.train_dataloader = StatefulDataLoader( + dataset=self.train_dataset, + batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size), + num_workers=num_workers, + # shuffle=False, + drop_last=True, + collate_fn=collate_fn, + sampler=train_sampler, + ) + + val_batch_size = self.config.data.val_batch_size # Prefer config value if set + if val_batch_size is None: + val_batch_size = len(self.val_dataset) + + # TODO: validation data is shuffled here as well + self.val_dataloader = StatefulDataLoader( + dataset=self.val_dataset, + batch_size=val_batch_size, + num_workers=num_workers, + shuffle=self.config.data.get("validation_shuffle", True), + drop_last=False, + collate_fn=collate_fn, + ) + + assert len(self.train_dataloader) >= 1, "Train dataloader is empty!" + assert len(self.val_dataloader) >= 1, "Validation dataloader is empty!" 
+ + print( + f"Size of train dataloader: {len(self.train_dataloader)}, Size of val dataloader: " + f"{len(self.val_dataloader)}" + ) + + total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs + + if self.config.trainer.total_training_steps is not None: + total_training_steps = self.config.trainer.total_training_steps + + self.total_training_steps = total_training_steps + print(f"Total training steps: {self.total_training_steps}") + + try: + OmegaConf.set_struct(self.config, True) + with open_dict(self.config): + if OmegaConf.select(self.config, "actor_rollout_ref.actor.optim"): + self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps + if OmegaConf.select(self.config, "critic.optim"): + self.config.critic.optim.total_training_steps = total_training_steps + except Exception as e: + print(f"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}") \ No newline at end of file diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 97b68684d5c..28201700860 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -27,13 +27,15 @@ from dataclasses import dataclass, field from enum import Enum from pprint import pprint -from typing import Optional +from typing import Optional, Dict import numpy as np import ray import torch +import ujson +import wandb from omegaconf import OmegaConf, open_dict -from torch.utils.data import Dataset, Sampler +from torch.utils.data import Dataset, Sampler, BatchSampler, SequentialSampler from torchdata.stateful_dataloader import StatefulDataLoader from tqdm import tqdm @@ -61,9 +63,11 @@ from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance from verl.utils.torch_functional import masked_mean from verl.utils.tracking import ValidationGenerationsLogger +from examples.reward_function.evaluation import compute_metrics_by_data_source WorkerType = type[Worker] 
+debug_file = "/home/keaneong/human-behavior/verl/examples/grpo_trainer/debug_log.txt" class Role(Enum): """ @@ -271,6 +275,18 @@ def compute_advantage( ) data.batch["advantages"] = advantages data.batch["returns"] = returns + elif adv_estimator == AdvantageEstimator.DRPO: + grpo_calculation_mask = data.batch["response_mask"] + domain_info = data.non_tensor_batch["dataset"] + + advantages, returns = core_algos.compute_drpo_outcome_advantage( + token_level_rewards=data.batch["token_level_rewards"], + response_mask=grpo_calculation_mask, + index=data.non_tensor_batch["uid"], + domain_info=domain_info + ) + data.batch["advantages"] = advantages + data.batch["returns"] = returns else: # handle all other adv estimator type other than GAE and GRPO adv_estimator_fn = core_algos.get_adv_estimator_fn(adv_estimator) @@ -315,6 +331,7 @@ def __init__( val_dataset: Optional[Dataset] = None, collate_fn=None, train_sampler: Optional[Sampler] = None, + val_sampler: Optional[Sampler] = None, device_name=None, ): """ @@ -382,7 +399,7 @@ def __init__( self.use_critic = False self._validate_config() - self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler) + self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler, val_sampler) def _validate_config(self): config = self.config @@ -499,7 +516,7 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str): print("[validate_config] All configuration checks passed successfully!") - def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]): + def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler], val_sampler: Optional[Sampler]): """ Creates the train and validation dataloaders. 
""" @@ -517,7 +534,11 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl self.train_dataset, self.val_dataset = train_dataset, val_dataset if train_sampler is None: - train_sampler = create_rl_sampler(self.config.data, self.train_dataset) + train_sampler = create_rl_sampler(self.config.data, self.train_dataset, split="train") + + if val_sampler is None: + val_sampler = create_rl_sampler(self.config.data, self.val_dataset, split="val") + if collate_fn is None: from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn @@ -525,27 +546,54 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl num_workers = self.config.data["dataloader_num_workers"] - self.train_dataloader = StatefulDataLoader( - dataset=self.train_dataset, - batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size), - num_workers=num_workers, - drop_last=True, - collate_fn=collate_fn, - sampler=train_sampler, - ) - val_batch_size = self.config.data.val_batch_size # Prefer config value if set - if val_batch_size is None: - val_batch_size = len(self.val_dataset) - - self.val_dataloader = StatefulDataLoader( - dataset=self.val_dataset, - batch_size=val_batch_size, - num_workers=num_workers, - shuffle=self.config.data.get("validation_shuffle", True), - drop_last=False, - collate_fn=collate_fn, - ) + if isinstance(train_sampler, BatchSampler): + self.train_dataloader = StatefulDataLoader( + dataset=self.train_dataset, + batch_sampler=train_sampler, + num_workers=num_workers, + collate_fn=collate_fn, + ) + else: + # Else if it is not a batch sampler, we can specify the batch size directly + self.train_dataloader = StatefulDataLoader( + dataset=self.train_dataset, + batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size), + num_workers=num_workers, + # shuffle=False, + drop_last=True, + collate_fn=collate_fn, + sampler=train_sampler, + ) + if isinstance(val_sampler, 
BatchSampler): + # BatchSampler path: DO NOT pass batch_size/shuffle/drop_last + self.val_dataloader = StatefulDataLoader( + dataset=self.val_dataset, + batch_sampler=val_sampler, + num_workers=num_workers, + collate_fn=collate_fn, + ) + else: + # Plain Sampler path: compute val_batch_size (None -> len(dataset)) + # This plain sampler path, if you trace the instance of val_sampler, + # should be that of a sequential sampler. Break if it is not. + if not isinstance(val_sampler, SequentialSampler): + raise ValueError("Validation sampler is not a SequentialSampler") + + val_batch_size = self.config.data.val_batch_size + if val_batch_size is None: + val_batch_size = len(self.val_dataset) + + self.val_dataloader = StatefulDataLoader( + dataset=self.val_dataset, + sampler=val_sampler, + batch_size=val_batch_size, + num_workers=num_workers, + drop_last=False, # keep all val samples + collate_fn=collate_fn, + # Deterministic val preferred; if you want to honor a config flag, keep it here: + shuffle=self.config.data.get("validation_shuffle", False), + ) assert len(self.train_dataloader) >= 1, "Train dataloader is empty!" assert len(self.val_dataloader) >= 1, "Validation dataloader is empty!" @@ -573,7 +621,7 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl except Exception as e: print(f"Warning: Could not set total_training_steps in config. Structure missing? 
Error: {e}") - def _dump_generations(self, inputs, outputs, gts, scores, reward_extra_infos_dict, dump_path): + def _dump_generations(self, inputs, outputs, gts, scores, reward_extra_infos_dict, dump_path, **kwargs): """Dump rollout/validation samples as JSONL.""" os.makedirs(dump_path, exist_ok=True) filename = os.path.join(dump_path, f"{self.global_steps}.jsonl") @@ -591,6 +639,14 @@ def _dump_generations(self, inputs, outputs, gts, scores, reward_extra_infos_dic if len(v) == n: base_data[k] = v + for k, v in kwargs.items(): + if isinstance(v, np.ndarray): + base_data[k] = v.tolist() + elif hasattr(v, 'cpu'): # Check if it's a torch tensor + base_data[k] = v.cpu().numpy().tolist() + else: + base_data[k] = v + lines = [] for i in range(n): entry = {k: v[i] for k, v in base_data.items()} @@ -636,6 +692,14 @@ def _validate(self): sample_scores = [] sample_turns = [] + # New lists for metric calculation + all_predictions = [] + all_ground_truths = [] + all_data_sources = [] + all_demographics = [] + all_datasets = [] + data_source_lst = [] + for test_data in self.val_dataloader: test_batch = DataProto.from_single_dict(test_data) @@ -658,6 +722,9 @@ def _validate(self): item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) for item in test_batch ] sample_gts.extend(ground_truths) + data_sources = test_batch.non_tensor_batch.get("data_source", ["unknown"] * len(input_texts)) + datasets = test_batch.non_tensor_batch.get("dataset", ["unknown"] * len(input_texts)) + demographics = test_batch.non_tensor_batch.get("demo", ["unknown"] * len(input_texts)) batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"] non_tensor_batch_keys_to_pop = ["raw_prompt_ids"] @@ -708,6 +775,16 @@ def _validate(self): output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids] sample_outputs.extend(output_texts) + # Collect for metrics calculation + all_predictions.extend(output_texts) + all_ground_truths.extend(ground_truths) + 
all_data_sources.extend(data_sources) + all_datasets.extend(datasets) + all_demographics.extend(demographics) + data_source_lst.append( + test_batch.non_tensor_batch.get("data_source", ["unknown"] * len(input_texts)) + ) + test_batch = test_batch.union(test_output_gen_batch) test_batch.meta_info["validate"] = True @@ -730,27 +807,23 @@ def _validate(self): if "__num_turns__" in test_batch.non_tensor_batch: sample_turns.append(test_batch.non_tensor_batch["__num_turns__"]) - data_source_lst.append(test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0])) - self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores) - # dump generations - val_data_dir = self.config.trainer.get("validation_data_dir", None) - if val_data_dir: - self._dump_generations( - inputs=sample_inputs, - outputs=sample_outputs, - gts=sample_gts, - scores=sample_scores, - reward_extra_infos_dict=reward_extra_infos_dict, - dump_path=val_data_dir, - ) + # Per data source metrics + metrics = compute_metrics_by_data_source(all_predictions, all_ground_truths, + all_data_sources, all_datasets, all_demographics) + wandb.log(metrics, step=self.global_steps) for key_info, lst in reward_extra_infos_dict.items(): assert len(lst) == 0 or len(lst) == len(sample_scores), f"{key_info}: {len(lst)=}, {len(sample_scores)=}" data_sources = np.concatenate(data_source_lst, axis=0) + # convert to list for easier processing + data_sources = data_sources.tolist() + print(f"size of sample_scores: {len(sample_scores)}, size of sample_outputs: {len(sample_outputs)}," + f" size of sample_gts: {len(sample_gts)}, size of sample_inputs: {len(sample_inputs)}" + f", size of data_sources: {len(data_sources)}, size of sample_turns: {len(sample_turns)}") data_src2var2metric2val = process_validation_metrics(data_sources, sample_inputs, reward_extra_infos_dict) metric_dict = {} for data_source, var2metric2val in data_src2var2metric2val.items(): @@ -769,6 +842,20 @@ def 
_validate(self): pfx = f"{metric_sec}/{data_source}/{var_name}/{metric_name}" metric_dict[pfx] = metric_val + # dump generations + val_data_dir = self.config.trainer.get("validation_data_dir", self.config.trainer.default_local_dir) + if val_data_dir: + self._dump_generations( + inputs=sample_inputs, + outputs=sample_outputs, + gts=sample_gts, + scores=sample_scores, + reward_extra_infos_dict=reward_extra_infos_dict, + dump_path=val_data_dir, + datasets=all_datasets, + data_paths=data_sources, + ) + if len(sample_turns) > 0: sample_turns = np.concatenate(sample_turns) metric_dict["val-aux/num_turns/min"] = sample_turns.min() @@ -777,6 +864,32 @@ def _validate(self): return metric_dict + def save_generations(self, sample_datapaths, sample_datasets, sample_inputs, sample_labels, sample_outputs, + sample_scores): + generation_save_folder = os.path.join(self.config.trainer.default_local_dir, + f"global_step_{self.global_steps}") + if not os.path.exists(generation_save_folder): + os.makedirs(generation_save_folder, exist_ok=True) + with open(os.path.join(generation_save_folder, "generations.jsonl"), "w") as f: + for i in range(len(sample_inputs)): + try: + short_answer = sample_outputs[i].split("boxed{")[1].split("}")[0] + except IndexError: + short_answer = '' + answer_is_correct = short_answer == sample_labels[i] + f.write( + ujson.dumps({ + "input": sample_inputs[i], + "generations": sample_outputs[i], + "short_answer": short_answer, + "answer_is_correct": answer_is_correct, + "label": sample_labels[i], + "score": sample_scores[i], + "dataset": sample_datasets[i], + "datapath": sample_datapaths[i], + }) + "\n" + ) + def init_workers(self): """Initialize distributed training workers using Ray backend. 
@@ -883,6 +996,8 @@ def init_workers(self): ) def _save_checkpoint(self): + + ## TO SAVE CHECKPOINT from verl.utils.fs import local_mkdir_safe # path: given_path + `/global_step_{global_steps}` + `/actor` @@ -1089,8 +1204,23 @@ def fit(self): ) next_step_profile = False + for epoch in range(self.config.trainer.total_epochs): - for batch_dict in self.train_dataloader: + # i = 0 + for batch_idx, batch_dict in enumerate(self.train_dataloader): + #--- DEBUG: log batch content into debug_file --- + + + if debug_file is not None: + with open(debug_file, "a", encoding="utf-8") as f: + log_entry = { + "epoch": int(epoch), + "batch_idx": int(batch_idx), + "modality_signatures": batch_dict.get("modality_signatures", []), + "prompts": batch_dict.get("debug_prompts", []), + } + f.write(json.dumps(log_entry, ensure_ascii=False, default=lambda o: o.tolist() if isinstance(o, np.ndarray) else str(o)) + "\n") + metrics = {} timing_raw = {} @@ -1111,7 +1241,28 @@ def fit(self): # pop those keys for generation batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"] non_tensor_batch_keys_to_pop = ["raw_prompt_ids"] + + + # if "input_ids" in batch.batch: + # print(f"[DEBUG] input_ids shape: {batch.batch['input_ids'].shape}") + # print(f"[DEBUG] First sequence tokens: {batch.batch['input_ids'][0][:10].tolist()}") + + + # if "input_ids" in batch.batch: + # with open(debug_file, "a") as f: # append mode + # f.write(f"[DEBUG] Epoch {epoch}, Iter {i}\n") + # f.write(f"input_ids shape: {batch.batch['input_ids'].shape}\n") + # f.write(f"First sequence tokens: {batch.batch['input_ids'][0][:10].tolist()}\n\n") + + # if i == 5: + # raise ValueError( + # f"Debugging error at iteration 4\n" + # f"input_ids shape: {batch.batch['input_ids'].shape}\n" + # f"First 10 tokens: {batch.batch['input_ids'][0][:10].tolist()}" + # ) + if "multi_modal_data" in batch.non_tensor_batch: + # TODO: Fix the audio generation for this non_tensor_batch_keys_to_pop.append("multi_modal_data") if "raw_prompt" 
def process_audio(
    audio: Union[str, dict],
    processor=None,
    max_seconds: float = None,  # truncate the waveform to this many seconds
) -> Tuple[torch.Tensor, int]:
    """Load an audio file, convert to mono, resample, and clip to max_seconds.

    Args:
        audio: path string, or a dict whose "audio" key holds the path.
        processor: optional HF processor; its feature_extractor.sampling_rate
            (when present) decides the target sample rate, else 16 kHz.
        max_seconds: if set, the waveform is truncated to this duration.

    Returns:
        (waveform, sample_rate): 1-D float tensor and its sample rate. On any
        loading/processing error an all-zero fallback waveform is returned.
    """
    if isinstance(audio, dict):
        audio_path = audio.get("audio", audio)
    else:
        audio_path = audio

    try:
        # Load -> (channels, time), original sample rate.
        audio_data, original_sr = torchaudio.load(audio_path)

        # Target sample rate from the processor when available, else 16 kHz.
        if processor and hasattr(processor, "feature_extractor") and hasattr(
            processor.feature_extractor, "sampling_rate"
        ):
            target_sr = processor.feature_extractor.sampling_rate
        else:
            target_sr = 16000

        if original_sr != target_sr:
            resampler = torchaudio.transforms.Resample(original_sr, target_sr)
            audio_data = resampler(audio_data)

        # Collapse to mono: average channels, or squeeze the single channel.
        if audio_data.shape[0] > 1:
            audio_data = audio_data.mean(dim=0, keepdim=False)
        else:
            audio_data = audio_data.squeeze(0)

        # Truncate to max_seconds.
        # FIX: the old code raised ValueError right after clipping, which the
        # except block below converted into an all-zero dummy waveform — every
        # over-length clip silently became silence. Just truncate.
        if max_seconds:
            max_samples = int(max_seconds * target_sr)
            if audio_data.shape[0] > max_samples:
                audio_data = audio_data[:max_samples]

        return audio_data, target_sr

    except Exception as e:
        print(f"Error processing audio {audio_path}: {e}")
        # FIX: max_seconds may be None; the old fallback int(16000 * None)
        # crashed inside the handler. Default to one second of silence.
        fallback_samples = int(16000 * max_seconds) if max_seconds else 16000
        return torch.zeros((fallback_samples,), dtype=torch.float32), 16000
hasattr(processor.feature_extractor, +# 'sampling_rate'): +# target_sr = processor.feature_extractor.sampling_rate +# else: + +# target_sr = 16000 +# # print(f"KEANE: Processing audio {audio_path} with sampling rate, {target_sr}") +# # Resample if needed +# # NOTE: This is essentially the resampling of the audio sample rate +# if original_sr != target_sr: +# resampler = torchaudio.transforms.Resample(original_sr, target_sr) +# audio_data = resampler(audio_data) + +# # Convert to mono if stereo +# # NOTE: This is essentially the conversion of stereo audio to mono, so that we only have one channel +# if audio_data.shape[0] > 1: +# audio_data = audio_data.mean(dim=0, keepdim=False) +# else: +# audio_data = audio_data.squeeze(0) + +# # Debug prints +# # print( +# # f"KEANE: Finished processing {audio_path} -> " +# # f"waveform shape {audio_data.shape}, dtype {audio_data.dtype}, " +# # # f"min {audio_data.min().item():.4f}, max {audio_data.max().item():.4f}" +# # ) +# # print(f"KEANE: Returning tuple (waveform, sr={target_sr})") + +# return audio_data, target_sr +# except Exception as e: +# print(f"Error processing audio {audio_path}: {e}") +# dummy_audio = torch.zeros((1000,), dtype=torch.float32) +# return dummy_audio, 16000 + + +# def process_audio(audio: str | dict, processor=None) -> np.ndarray: +# """ +# NOTE: Keane's implementation +# """ +# if isinstance(audio, dict): +# audio_path = audio.get("audio", audio) +# else: +# audio_path = audio + +# try: +# # Load audio -> (channels, time), sample_rate +# audio_data, original_sr = torchaudio.load(audio_path) + +# # Target sampling rate (from processor or default to 16k) +# if ( +# processor +# and hasattr(processor, "feature_extractor") +# and hasattr(processor.feature_extractor, "sampling_rate") +# ): +# target_sr = processor.feature_extractor.sampling_rate +# else: +# target_sr = 16000 + +# print(f"KEANE: Processing audio {audio_path} with target sampling rate {target_sr}") + +# # Resample if needed +# if 
class ModalitySignatureBatchSampler(BatchSampler):
    """
    Round-robin batch sampler over modality signatures.

    - Every yielded batch is homogeneous by modality signature.
    - Within each signature the index pool is shuffled when shuffle=True.
    - A signature whose batches are exhausted is pruned and the round-robin
      continues over the remaining signatures.
    """

    def __init__(
        self,
        indices_by_sig: Dict[str, List[int]],
        batch_size: int,
        drop_last: bool = True,
        seed: int = 42,
        shuffle: bool = True,
    ):
        # NOTE: BatchSampler.__init__ is deliberately not called — this class
        # keeps its own state and fully overrides __iter__/__len__.
        batch_size = int(batch_size)
        if batch_size <= 0:
            # FIX: a non-positive batch size would make _batches_for loop
            # incorrectly / divmod misbehave; fail fast instead.
            raise ValueError(f"batch_size must be a positive int, got {batch_size}")
        self.indices_by_sig = {s: list(v) for s, v in indices_by_sig.items()}
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.shuffle = shuffle
        self.rng = random.Random(seed)
        self.sigs = list(self.indices_by_sig.keys())

    def _batches_for(self, pool: List[int]) -> List[List[int]]:
        """Chunk one signature's index pool into batches, honoring drop_last."""
        batches = []
        for start in range(0, len(pool), self.batch_size):
            chunk = pool[start:start + self.batch_size]
            if len(chunk) < self.batch_size and self.drop_last:
                continue
            if chunk:
                batches.append(chunk)
        return batches

    def __iter__(self) -> Iterator[List[int]]:
        # Fresh pools each epoch; optional shuffle within each signature.
        pools = {s: list(v) for s, v in self.indices_by_sig.items()}
        if self.shuffle:
            for s in pools:
                self.rng.shuffle(pools[s])

        # Per-signature queues of ready-made homogeneous batches.
        per_sig_batches = {s: deque(self._batches_for(pools[s])) for s in self.sigs}

        # Round-robin order: rotated start when shuffling, sorted otherwise.
        order = list(self.sigs)
        if self.shuffle:
            k = self.rng.randrange(len(order)) if order else 0
            order = order[k:] + order[:k]
        else:
            order = sorted(order)

        # Only signatures that actually produced batches enter the rotation;
        # a signature is re-appended only while its queue is non-empty, so an
        # exhausted signature is pruned automatically.
        # FIX: removed the unreachable `else` branch — a signature could never
        # be popped from `active` with an already-empty queue.
        active = deque(s for s in order if per_sig_batches[s])
        while active:
            sig = active.popleft()
            queue = per_sig_batches[sig]
            yield queue.popleft()
            if queue:
                active.append(sig)

    def __len__(self) -> int:
        """Total number of batches across all signatures (after drop_last)."""
        total = 0
        for pool in self.indices_by_sig.values():
            full, rem = divmod(len(pool), self.batch_size)
            total += full + (0 if self.drop_last or rem == 0 else 1)
        return total
""" + # data list is the batch list + # NOTE: we assert homogeneous if the modality signatures are not homogeneous + assert_homogeneous(data_list) # assert if not homogeneous + tensors = defaultdict(list) non_tensors = defaultdict(list) @@ -99,14 +125,29 @@ def __init__( self.config = config self.cache_dir = os.path.expanduser(config.get("cache_dir", "~/.cache/verl/rlhf")) + + # Essentially getting all the different keys. self.prompt_key = config.get("prompt_key", "prompt") self.image_key = config.get("image_key", "images") self.video_key = config.get("video_key", "videos") + + # NOTE: SET AUDIO KEY AS AUDIOS + self.audio_key = config.get("audio_key", "audios") + + # NOTE: SET MODALITIES, split the images and videos + self.modalities = set(config.get("modalities", "images,videos").split(",")) + self.max_prompt_length = config.get("max_prompt_length", 1024) self.return_raw_chat = config.get("return_raw_chat", False) self.return_full_prompt = config.get("return_full_prompt", False) self.truncation = config.get("truncation", "error") + + # TODO: Check whether this is true self.filter_overlong_prompts = config.get("filter_overlong_prompts", True) + if isinstance(data_files, str): + self.base_dir = os.path.dirname(os.path.abspath(data_files)) + else: + self.base_dir = os.path.dirname(os.path.abspath(data_files[0])) self.num_workers = config.get("filter_overlong_prompts_workers", max(1, os.cpu_count() // 4)) self.num_workers = min(self.num_workers, os.cpu_count()) @@ -116,9 +157,21 @@ def __init__( self.filter_prompts = config.get("filter_prompts", True) self.serialize_dataset = False self.return_multi_modal_inputs = config.get("return_multi_modal_inputs", True) + + # Load format prompt from file if specified + self.format_prompt_path = config.get("format_prompt", "examples/format_prompt/default.jinja") + self.format_prompt = self._load_format_prompt() self._download() - self._read_files_and_tokenize() + self._read_files_and_tokenize() # essentially this is prepared 
first before _getitem + + def _load_format_prompt(self) -> Optional[Template]: + """Load format prompt from file if specified.""" + if self.format_prompt_path: + with open(self.format_prompt_path, 'r', encoding='utf-8') as f: + template_content = f.read() + return Template(template_content) + return None def _download(self, use_origin_parquet=False): from verl.utils.fs import copy_to_local @@ -129,40 +182,92 @@ def _download(self, use_origin_parquet=False): def _read_files_and_tokenize(self): dataframes = [] + + features = datasets.Features({ + "problem": datasets.Value("string"), + "answer": datasets.Value("string"), + "images": datasets.Sequence(datasets.Value("string")), + "videos": datasets.Sequence(datasets.Value("string")), + "audios": datasets.Sequence(datasets.Value("string")), # <- force list of strings + "dataset": datasets.Value("string"), + "texts": datasets.Sequence(datasets.Value("string")), + "modality_signature": datasets.Value("string"), + }) + for parquet_file in self.data_files: # read parquet files and cache - dataframe = datasets.load_dataset("parquet", data_files=parquet_file)["train"] + if parquet_file.endswith(".parquet"): + dataframe = datasets.load_dataset("parquet", data_files=parquet_file, features=features)["train"] + elif parquet_file.endswith(".json") or parquet_file.endswith(".jsonl"): + dataframe = datasets.load_dataset("json", data_files=parquet_file, features=features)["train"] + else: + raise ValueError(f"Unsupported file format: {parquet_file}. 
Only .parquet, .json, .jsonl are supported.") dataframes.append(dataframe) self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes) print(f"dataset len: {len(self.dataframe)}") + # PROCESSING THE DATAFRAME for TRAINING self.dataframe = self.maybe_filter_out_long_prompts(self.dataframe) def maybe_filter_out_long_prompts(self, dataframe: datasets.Dataset = None): - # filter out too long prompts + # NOTE: filter out too long prompts, because the prompts can become very long + # when the audio is appended. + if self.filter_overlong_prompts: + # NOTE: FILTER OUT THE LONG PROMPTS SO THAT THEY FIT THE LENGTH tokenizer = self.tokenizer processor = self.processor prompt_key = self.prompt_key image_key = self.image_key video_key = self.video_key + audio_key = self.audio_key if processor is not None: + # print(f"KEANE: PROCESSOR FOUND") from verl.utils.dataset.vision_utils import process_image, process_video + from verl.utils.dataset.audio_utils import process_audio def doc2len(doc) -> int: messages = self._build_messages(doc) raw_prompt = self.processor.apply_chat_template( messages, add_generation_prompt=True, tokenize=False ) - images = [process_image(image) for image in doc[image_key]] if image_key in doc else None - videos = [process_video(video) for video in doc[video_key]] if video_key in doc else None - - return len(processor(text=[raw_prompt], images=images, videos=videos)["input_ids"][0]) + processor_kwargs = {"text": [raw_prompt]} + + if "images" in self.modalities and image_key in doc and len(doc[image_key]) > 0: + images = [process_image(image) for image in doc[image_key]] + processor_kwargs["images"] = images + + if "videos" in self.modalities and video_key in doc and len(doc[video_key]) > 0: + videos = [process_video(video) for video in doc[video_key]] + processor_kwargs["videos"] = videos + + if "audio" in self.modalities and audio_key in doc and doc.get(audio_key, None) is not None and len(doc[audio_key]) > 0: + # processing of audio + # 
print(f"KEANE: Processing audio within rl dataset file") + # audios = [process_audio(audio, processor) for audio in doc[audio_key]] + # processor_kwargs["audio"] = audios + + # PATCH + audios = [] + audio_tuples = [] # Keep tuples for multi_modal_data + for audio in doc.get(self.audio_key): + audio_path = os.path.join(self.base_dir, audio) if isinstance(audio, str) else audio + audio_data, sampling_rate = process_audio(audio_path, self.processor) + audio_tuples.append((audio_data, sampling_rate)) + # audios.append(audio_data.numpy()) # Convert to numpy array for Whisper + audios.append(audio_data.detach().cpu().numpy().astype("float32")) + + processor_kwargs["audio"] = audios # Pass numpy arrays to processor + # TODO: cannot process the audio inputs + # print(f"KEANE: Processor class is {processor.__class__.__name__}") + # print(f"KEANE: Printing the processor_kwargs, {processor_kwargs}") + # Assume that all are in tensors already, hence there is no return_tensors = "pt" + return len(processor(**processor_kwargs)["input_ids"][0]) else: - + # print(f"KEANE: PROCESSOR NOT FOUND") def doc2len(doc) -> int: return len(tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True)) @@ -188,58 +293,319 @@ def __len__(self): return len(self.dataframe) def _build_messages(self, example: dict): - messages: list = example.pop(self.prompt_key) - - if self.image_key in example or self.video_key in example: + """ + This appears to be called twice, once during maybe_filter_out_long_prompts, and another time during getitems + """ + messages: list = example.get(self.prompt_key) + if isinstance(messages, str): + messages = [messages] + + # NOTE: Before building, check if there is multimodal content + has_multimodal = ( + ("images" in self.modalities and self.image_key in example) or + ("videos" in self.modalities and self.video_key in example) or + ("audio" in self.modalities and self.audio_key in example) + ) + + if has_multimodal: + new_messages = [] for message in 
messages: - content = message["content"] - content_list = [] - segments = re.split("(|