diff --git a/_unit_test_modality_sampler.py b/_unit_test_modality_sampler.py
new file mode 100644
index 00000000000..323d2461439
--- /dev/null
+++ b/_unit_test_modality_sampler.py
@@ -0,0 +1,308 @@
+# test_stateful_modality_sampler_hardcoded.py
+
+import json
+from typing import Dict, Any, List, Iterator
+from torch.utils.data import Dataset, BatchSampler
+import random
+
+# ==== ADJUST PATHS below to match your repo structure ====
+# from verl.utils.dataset.modality_sampler import ModalitySignatureBatchSampler
+from torchdata.stateful_dataloader import StatefulDataLoader
+from collections import defaultdict, deque
+import torch
+import numpy as np
+
+# ---------- HARD-CODED PATHS + CONFIG ----------
+JSONL_PATH = "/Users/keane/Desktop/research/human-behavior/data/all/sigs_no_lmvd_discretized_v3_template_prompts.jsonl"
+TRAIN_BS = 4
+VAL_BS = 4
+SEED = 42
+TRUNCATE_RATIO = 0.001 # for quick testing; set to 1.0 to disable
+# ---------------------------------------------
+
+# TODO: remove the legacy "text only" signature label; every text-only row should use the canonical "text_only" signature
+
+class ModalitySignatureBatchSampler(BatchSampler):
+ """
+ Round-robin across modality signatures, pruning exhausted signatures.
+ - Shuffles within each signature if shuffle=True (train).
+ - Each yielded batch is homogeneous by modality_signature.
+ - If a signature runs out of batches, it is removed and RR continues.
+ """
+ def __init__(
+ self,
+ indices_by_sig: Dict[str, List[int]],
+ batch_size: int,
+ drop_last: bool = True,
+ seed: int = 42,
+ shuffle: bool = True,
+ ):
+ self.indices_by_sig = {s: list(v) for s, v in indices_by_sig.items()}
+ self.batch_size = int(batch_size)
+ self.drop_last = drop_last
+ self.shuffle = shuffle
+ self.rng = random.Random(seed)
+ self.sigs = list(self.indices_by_sig.keys())
+
+ def _batches_for(self, pool: List[int]) -> List[List[int]]:
+ n = len(pool)
+ batches = []
+ for start in range(0, n, self.batch_size):
+ chunk = pool[start:start + self.batch_size]
+ if len(chunk) < self.batch_size and self.drop_last:
+ continue
+ if chunk:
+ batches.append(chunk)
+ return batches
+
+ def __iter__(self) -> Iterator[List[int]]:
+ # Fresh pools + optional shuffle within each signature
+ pools = {s: list(v) for s, v in self.indices_by_sig.items()}
+ for s in pools:
+ if self.shuffle:
+ self.rng.shuffle(pools[s])
+
+ # Build per-signature batch queues; essentially a dictionary with batches of each different modality signature
+ per_sig_batches = {s: deque(self._batches_for(pools[s])) for s in self.sigs}
+
+ # Establish RR order
+ order = list(self.sigs)
+ if self.shuffle:
+ # rotate start signature per epoch for variety (keeps RR structure)
+ k = self.rng.randrange(len(order)) if order else 0
+ order = order[k:] + order[:k]
+ else:
+ order = sorted(order)
+
+ # Active signatures as a deque for easy rotation
+ active = deque([s for s in order if len(per_sig_batches[s]) > 0])
+
+ while active:
+ s = active.popleft() # take the queue's leftmost element (modality signature)
+ q = per_sig_batches[s] # access all of the batched stuff
+ if q:
+ yield q.popleft() # yield that batch
+ # if still has batches, push to the end to continue RR
+ if q:
+ active.append(s) # reappend the modality signature to the active queue
+ # if q is empty, we simply don't re-append s → pruned automatically
+ else:
+ print(f"Ran-Out: Pruning modality signature: {s}")
+
+ def __len__(self) -> int:
+ # Total number of batches across all signatures (after drop_last handling)
+ total = 0
+ for pool in self.indices_by_sig.values():
+ full, rem = divmod(len(pool), self.batch_size)
+ total += full + (0 if self.drop_last or rem == 0 else 1)
+ return total
+
+
+def rl_collate_fn(data_list: list[dict]) -> dict:
+ """
+ Collate a batch of sample dicts into batched tensors and arrays.
+
+ Args:
+ data_list: List of dicts mapping feature names to torch.Tensor or other values.
+
+ Returns:
+ Dict where tensor entries are stacked into a torch.Tensor of shape
+ (batch_size, dims) and non-tensor entries are converted to
+ np.ndarray of dtype object with shape (batch_size,).
+ """
+ tensors = defaultdict(list)
+ non_tensors = defaultdict(list)
+
+ for data in data_list:
+ for key, val in data.items():
+ if isinstance(val, torch.Tensor):
+ tensors[key].append(val)
+ else:
+ non_tensors[key].append(val)
+
+ for key, val in tensors.items():
+ tensors[key] = torch.stack(val, dim=0)
+
+ for key, val in non_tensors.items():
+ non_tensors[key] = np.fromiter(val, dtype=object, count=len(val))
+
+ return {**tensors, **non_tensors}
+
+def create_rl_sampler(data_config, dataset, split: str = "train"):
+ """Create a sampler for the dataset, grouping strictly by existing modality_signature."""
+ import torch
+ from torch.utils.data import RandomSampler, SequentialSampler
+
+ mb_cfg = data_config.get("modality_batching") if split == "train" \
+ else data_config.get("val_modality_batching")
+
+ # (keep curriculum path if you actually use it; omitted here for brevity)
+
+ if mb_cfg and mb_cfg.get("enabled", False):
+ by_sig: Dict[str, List[int]] = {}
+ for i in range(len(dataset)):
+ row = dataset.dataframe[i] if hasattr(dataset, "dataframe") else dataset[i]
+ sig = row.get("modality_signature")
+ if sig is None:
+ print(f"[WARNING] Row {i} missing 'modality_signature'. Skipping.")
+ continue
+ by_sig.setdefault(sig, []).append(i)
+
+ batch_size = mb_cfg.get("batch_size", data_config.get(
+ "train_batch_size" if split=="train" else "val_batch_size"
+ ))
+ drop_last = mb_cfg.get("drop_last", split=="train")
+ shuffle = (split == "train")
+
+ return ModalitySignatureBatchSampler(
+ indices_by_sig=by_sig,
+ batch_size=int(batch_size),
+ drop_last=drop_last,
+ seed=data_config.get("seed", 42),
+ shuffle=shuffle,
+ )
+
+ # Fallbacks
+ if data_config.get("shuffle", True) and split == "train":
+ g = torch.Generator(); g.manual_seed(data_config.get("seed", 1))
+ return RandomSampler(data_source=dataset, generator=g)
+ else:
+ return SequentialSampler(data_source=dataset)
+
+class JsonlDataset(Dataset):
+ def __init__(self, jsonl_path: str, truncate_ratio: float = TRUNCATE_RATIO, seed: int = SEED):
+ """
+ Loads ONLY entries that already have 'modality_signature'.
+ Optionally keeps a proportion per signature for fast debugging.
+ """
+ all_rows: List[Dict[str, Any]] = []
+ with open(jsonl_path, "r", encoding="utf-8") as f:
+ for ln in f:
+ ln = ln.strip()
+ if not ln:
+ continue
+ ex = json.loads(ln)
+ sig = ex.get("modality_signature")
+ if sig is None:
+ print(f"[WARNING] Entry missing 'modality_signature'. Skipping.")
+ continue # skip missing
+ all_rows.append(ex)
+
+ # Group by signature and truncate per signature
+ sig_to_rows: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
+ for ex in all_rows:
+ sig_to_rows[ex["modality_signature"]].append(ex)
+
+ rng = random.Random(seed)
+ truncated_rows: List[Dict[str, Any]] = []
+ for sig, rows in sig_to_rows.items():
+ if truncate_ratio >= 1.0:
+ truncated_rows.extend(rows)
+ continue
+ keep_n = max(1, int(len(rows) * truncate_ratio))
+ rng.shuffle(rows)
+ truncated_rows.extend(rows[:keep_n])
+
+ self.rows = truncated_rows
+ self.dataframe = self # preserve your API
+
+ # simple stats
+ counts = {sig: sum(1 for r in self.rows if r["modality_signature"] == sig) for sig in sig_to_rows}
+ print(f"[DEBUG] After truncation (ratio={truncate_ratio}), total {len(self.rows)}. Per-signature: {counts}")
+
+ def __len__(self):
+ return len(self.rows)
+
+ def __getitem__(self, idx):
+ return self.rows[idx]
+
+
+def assert_homogeneous(batch_list: List[Dict[str, Any]]):
+ sigs = {b.get("modality_signature") for b in batch_list}
+ if len(sigs) != 1:
+ raise AssertionError(f"Non-homogeneous batch signatures: {sigs}")
+
+def collate_with_guard(batch_list):
+ assert_homogeneous(batch_list)
+ return rl_collate_fn(batch_list)
+
+def build_cfg(train_bs: int, val_bs: int, seed: int = 42):
+ class Dot(dict):
+ __getattr__ = dict.get
+ __setattr__ = dict.__setitem__
+ __delattr__ = dict.__delitem__
+ return Dot({
+ "train_batch_size": train_bs,
+ "val_batch_size": val_bs,
+ "shuffle": True,
+ "seed": seed,
+ "dataloader_num_workers": 0,
+ "validation_shuffle": False,
+ "sampler": None,
+ "modality_batching": {"enabled": True, "batch_size": train_bs, "drop_last": True},
+ "val_modality_batching": {"enabled": True, "batch_size": val_bs, "drop_last": False},
+ })
+
+def build_loader(dataset, data_cfg, split: str):
+ sampler_or_batch = create_rl_sampler(data_cfg, dataset, split=split)
+ if isinstance(sampler_or_batch, BatchSampler):
+ return StatefulDataLoader(
+ dataset=dataset,
+ batch_sampler=sampler_or_batch,
+ num_workers=data_cfg["dataloader_num_workers"],
+ collate_fn=collate_with_guard,
+ )
+ else:
+ bs = data_cfg.get("train_batch_size" if split == "train" else "val_batch_size")
+ return StatefulDataLoader(
+ dataset=dataset,
+ sampler=sampler_or_batch,
+ batch_size=bs,
+ num_workers=data_cfg["dataloader_num_workers"],
+ drop_last=(split == "train"),
+            shuffle=False,  # a sampler is always supplied here, so DataLoader shuffle must stay off for every split
+ collate_fn=collate_with_guard,
+ )
+
+def main():
+ ds = JsonlDataset(JSONL_PATH)
+ print(f"Dataset size: {len(ds)}; per-signature counts:",
+ {sig: sum(1 for r in ds.rows if r['modality_signature']==sig)
+ for sig in sorted({r['modality_signature'] for r in ds.rows})})
+
+ cfg = build_cfg(TRAIN_BS, VAL_BS, SEED)
+
+ # TRAIN
+ train_loader = build_loader(ds, cfg, split="train")
+ print("\n[TRAIN] Iteration 1")
+ n_train_batches = sum(1 for _ in train_loader) # iterating as you would with the train loader
+ print(f"train steps: {n_train_batches} (drop_last=True)")
+
+ # New epoch
+ train_loader2 = build_loader(ds, cfg, split="train")
+ n_train_batches2 = sum(1 for _ in train_loader2)
+ assert n_train_batches == n_train_batches2
+ print("[TRAIN] Iteration 2: step count consistent")
+
+ # VAL
+ val_loader = build_loader(ds, cfg, split="val")
+ print("\n[VAL] Iteration 1")
+ n_val_batches = sum(1 for _ in val_loader)
+ print(f"val steps: {n_val_batches} (drop_last=False)")
+
+ # Stateful resume check (if supported)
+ if hasattr(train_loader, "state_dict"):
+ print("\n[STATEFUL] Testing resume mid-epoch")
+ train_loader3 = build_loader(ds, cfg, split="train")
+ it = iter(train_loader3)
+ next(it); next(it) # consume 2
+ sd = train_loader3.state_dict()
+ train_loader4 = build_loader(ds, cfg, split="train")
+ train_loader4.load_state_dict(sd)
+ resumed = sum(1 for _ in train_loader4)
+ print(f"resumed batches after 2 consumed: {resumed}")
+
+ print("\nOK: StatefulDataLoader + sampler test finished.")
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/dvd_requirements.txt b/dvd_requirements.txt
new file mode 100644
index 00000000000..7210f325c6b
--- /dev/null
+++ b/dvd_requirements.txt
@@ -0,0 +1,282 @@
+# torch==2.7.1+cu126
+absl-py==2.3.0
+accelerate==1.8.1
+aiohappyeyeballs==2.6.1
+aiohttp==3.12.13
+aiohttp-cors==0.8.1
+aiosignal==1.3.2
+airportsdata==20250622
+alembic==1.16.2
+aliyun-python-sdk-core==2.16.0
+aliyun-python-sdk-kms==2.16.5
+annotated-types==0.7.0
+antlr4-python3-runtime==4.9.3
+anyio==4.9.0
+astor==0.8.1
+attrs==25.3.0
+audioread==3.0.1
+av==14.4.0
+beautifulsoup4==4.13.4
+blake3==1.0.5
+cachetools==5.5.2
+cbor2==5.6.5
+certifi==2025.8.3
+cffi==1.17.1
+chardet==5.2.0
+charset-normalizer==3.4.2
+click==8.2.1
+cloudpickle==3.1.1
+codetiming==1.4.0
+color-matcher==0.6.0
+colorama==0.4.6
+colorful==0.5.6
+comfyui-embedded-docs==0.2.3
+comfyui_frontend_package==1.23.4
+comfyui_workflow_templates==0.1.32
+compressed-tensors==0.10.2
+contourpy==1.3.2
+crcmod==1.7
+cryptography==45.0.5
+cuda-bindings==12.9.0
+cuda-python==12.9.0
+cupy-cuda12x==13.4.1
+cycler==0.12.1
+datasets==4.0.0
+ddt==1.7.2
+decorator==5.2.1
+Deprecated==1.2.18
+depyf==0.19.0
+diffusers==0.34.0
+dill==0.3.8
+diskcache==5.6.3
+distlib==0.3.9
+distro==1.9.0
+dnspython==2.7.0
+docutils==0.21.2
+einops==0.8.1
+email_validator==2.2.0
+fastapi==0.115.14
+fastapi-cli==0.0.7
+fastrlock==0.8.3
+ffmpeg-python==0.2.0
+# filelock==3.19.1
+flash_attn==2.8.0.post2
+flatbuffers==25.2.10
+fonttools==4.58.4
+frozenlist==1.7.0
+fsspec==2025.3.0
+ftfy==6.3.1
+future==1.0.0
+gguf==0.17.1
+gitdb==4.0.12
+GitPython==3.1.44
+google-api-core==2.25.1
+google-auth==2.40.3
+googleapis-common-protos==1.70.0
+greenlet==3.2.3
+grpcio==1.73.1
+h11==0.16.0
+heavyball==1.7.2
+hf-xet==1.1.5
+httpcore==1.0.9
+httptools==0.6.4
+httpx==0.28.1
+huggingface-hub==0.34.1
+hydra-core==1.3.2
+idna==3.10
+imageio==2.37.0
+# importlib_metadata==8.7.0
+interegular==0.3.3
+jax==0.7.0
+jaxlib==0.7.0
+Jinja2==3.1.6
+jiter==0.10.0
+jmespath==0.10.0
+joblib==1.5.1
+jsonschema==4.24.0
+jsonschema-specifications==2025.4.1
+kiwisolver==1.4.8
+kornia==0.8.1
+kornia_rs==0.1.9
+lark==1.2.2
+lazy_loader==0.4
+librosa==0.11.0
+liger_kernel==0.5.10
+llguidance==0.7.30
+llvmlite==0.44.0
+lm-format-enforcer==0.10.11
+Mako==1.3.10
+Markdown==3.8.2
+markdown-it-py==3.0.0
+MarkupSafe==3.0.2
+mathruler==0.1.0
+matplotlib==3.10.3
+matrix-client==0.4.0
+mdurl==0.1.2
+mediapipe==0.10.21
+mistral_common==1.8.3
+ml_dtypes==0.5.3
+model-index==0.1.11
+mpi4py==4.1.0
+mpmath==1.3.0
+msgpack==1.1.1
+msgspec==0.19.0
+multidict==6.6.0
+multiprocess==0.70.16
+natsort==8.4.0
+nest-asyncio==1.6.0
+networkx==3.5
+ninja==1.11.1.4
+numba==0.61.2
+numpy==1.26.4
+nvidia-cublas-cu12==12.6.4.1
+nvidia-cuda-cupti-cu12==12.6.80
+nvidia-cuda-nvrtc-cu12==12.6.77
+nvidia-cuda-runtime-cu12==12.6.77
+nvidia-cudnn-cu12==9.5.1.17
+nvidia-cufft-cu12==11.3.0.4
+nvidia-cufile-cu12==1.11.1.6
+nvidia-curand-cu12==10.3.7.77
+nvidia-cusolver-cu12==11.7.1.2
+nvidia-cusparse-cu12==12.5.4.2
+nvidia-cusparselt-cu12==0.6.3
+nvidia-ml-py==12.575.51
+nvidia-nccl-cu12==2.26.2
+nvidia-nvjitlink-cu12==12.6.85
+nvidia-nvshmem-cu12==3.3.9
+nvidia-nvtx-cu12==12.6.77
+olefile==0.47
+omegaconf==2.3.0
+openai==1.90.0
+opencensus==0.11.4
+opencensus-context==0.1.3
+opencv-contrib-python==4.11.0.86
+opencv-python-headless==4.11.0.86
+opendatalab==0.0.10
+openmim==0.3.9
+opentelemetry-api==1.26.0
+opentelemetry-exporter-otlp-proto-grpc==1.26.0
+opentelemetry-exporter-otlp-proto-http==1.26.0
+opentelemetry-exporter-prometheus==0.41b0
+opentelemetry-proto==1.26.0
+opentelemetry-sdk==1.26.0
+# opentelemetry-semantic-conventions==0.55b1
+openxlab==0.1.2
+opt_einsum==3.4.0
+ordered-set==4.1.0
+orjson==3.10.18
+oss2==2.17.0
+outlines==0.1.11
+outlines_core==0.2.10
+packaging==24.2
+pandas==2.3.0
+partial-json-parser==0.2.1.1.post6
+peft==0.15.2
+piexif==1.1.3
+pillow==11.2.1
+pip==25.1
+platformdirs==4.3.8
+pooch==1.8.2
+prometheus_client==0.22.1
+prometheus-fastapi-instrumentator==7.1.0
+propcache==0.3.2
+proto-plus==1.26.1
+protobuf==4.25.8
+psutil==7.0.0
+py-cpuinfo==9.0.0
+py-spy==0.4.0
+pyarrow==20.0.0
+pyasn1==0.6.1
+pyasn1_modules==0.4.2
+pybase64==1.4.1
+pycountry==24.6.1
+pycparser==2.22
+pycryptodome==3.23.0
+pydantic==2.11.7
+pydantic_core==2.33.2
+pydantic-extra-types==2.10.5
+pydantic-settings==2.10.1
+PyGithub==2.6.1
+Pygments==2.19.2
+PyJWT==2.10.1
+pylatexenc==2.10
+pyloudnorm==0.1.1
+PyNaCl==1.5.0
+pynvml==12.0.0
+pyparsing==3.2.3
+python-dateutil==2.9.0.post0
+python-dotenv==1.1.1
+python-json-logger==4.0.0.dev0
+python-multipart==0.0.20
+pytz==2023.4
+PyYAML==6.0.2
+pyzmq==27.0.0
+qwen-vl-utils==0.0.11
+ray==2.47.1
+referencing==0.36.2
+regex==2024.11.6
+# requests==2.28.2
+rich==14.1.0
+rich-toolkit==0.14.7
+rpds-py==0.25.1
+rsa==4.9.1
+safetensors==0.6.0rc0
+# sageattention==2.2.0
+scikit-image==0.25.2
+scikit-learn==1.7.0
+scipy==1.16.0
+sentencepiece==0.2.0
+sentry-sdk==2.32.0
+setproctitle==1.3.6
+setuptools==79.0.1
+shellingham==1.5.4
+six==1.17.0
+smart-open==7.1.0
+smmap==5.0.2
+sniffio==1.3.1
+sounddevice==0.5.2
+soundfile==0.13.1
+soupsieve==2.7
+soxr==0.5.0.post1
+spandrel==0.4.1
+SQLAlchemy==2.0.41
+starlette==0.46.2
+sympy==1.14.0
+tabulate==0.9.0
+tensorboard==2.19.0
+tensorboard-data-server==0.7.2
+tensordict==0.8.3
+threadpoolctl==3.6.0
+tifffile==2025.6.11
+tiktoken==0.9.0
+timm==1.0.16
+tokenizers==0.21.2
+toml==0.10.2
+# torchaudio==2.7.1
+# torchcodec==0.4.0+cu126
+torchdata==0.11.0
+torchsde==0.2.6
+# torchvision==0.22.1
+# tqdm==4.65.2
+trampoline==0.1.2
+transformers==4.54.0
+triton==3.3.1
+typer==0.16.0
+typing_extensions==4.14.0
+typing-inspection==0.4.1
+tzdata==2025.2
+ujson==5.10.0
+urllib3==1.26.20
+uv==0.7.19
+uvicorn==0.34.3
+uvloop==0.21.0
+virtualenv==20.31.2
+vllm==0.8.4
+wandb==0.20.1
+watchdog==6.0.0
+watchfiles==1.1.0
+wcwidth==0.2.13
+websockets==15.0.1
+Werkzeug==3.1.3
+wfdb==4.3.0
+wheel==0.45.1
\ No newline at end of file
diff --git a/examples/grpo_trainer/_debug_keane_run_qwen2_5_omni-7b_hb_all_modalities.sh b/examples/grpo_trainer/_debug_keane_run_qwen2_5_omni-7b_hb_all_modalities.sh
new file mode 100755
index 00000000000..3522b2fea39
--- /dev/null
+++ b/examples/grpo_trainer/_debug_keane_run_qwen2_5_omni-7b_hb_all_modalities.sh
@@ -0,0 +1,121 @@
+set -x
+
+unset ROCR_VISIBLE_DEVICES
+
+# actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct
+# actor_rollout_ref.model.path=Qwen/Qwen2.5-Omni-7B
+# data.train_files=/scratch/keane/human_behaviour/human_behaviour_data/train_no_meld_no_chalearn_vision_v2_template_prompts.jsonl \
+# data.val_files=/scratch/keane/human_behaviour/human_behaviour_data/val_no_meld_no_chalearn_vision_v2_template_prompts.jsonl \
+# data.modalities=\'audio,videos\' \
+
+# SETTING OF SAVE PATH: trainer.default_local_dir= /scratch/keane/human_behaviour/2_models_hb_vision_only
+# SETTING OF THE LOAD PATH from directory of checkpoints is also: trainer.default_local_dir
+
+# TRAINING FROM scratch: trainer.resume_mode == "disable" (default will save into default_local_dir)
+
+# TRAINING AUTOMATICALLY (i.e. from scratch or from latest checkpoint) :
+ # trainer.resume_mode == "auto" and then the model will take the latest ckpt from trainer.default_hdfs_dir
+
+# TRAINING from specific CHECKPOINT: trainer.resume_mode == "resume_path" and then specify trainer.resume_from_path
+ # Setting of path to resume training from trainer.resume_from_path (exact path of checkpoint)
+ # the model will take from resume_from_path directly (absolute path), and ignore default_hdfs_dir
+
+# for validation, set val_before_train=True ; make sure that the checkpoint is loaded and put val_only=True
+# the checkpoint should already be loaded before that
+# and then we will just evaluate
+
+# ALTERNATIVES
+# /scratch/keane/human_behaviour/human_behaviour_data/discretized_no_lmvd_no_chsimsv2_v3_template_prompts.jsonl
+# /scratch/keane/human_behaviour/human_behaviour_data/discretized_no_lmvd_no_chsimsv2_no_chalearn_v3_template_prompts.jsonl
+
+# when resuming training from a loaded checkpoint cuda OOM error
+
+# alt: /scratch/keane/human_behaviour/human_behaviour_data/0.1_train_no_lmvd_discretized_v3_template_prompts.jsonl
+# org: /scratch/keane/human_behaviour/human_behaviour_data/train_no_lmvd_discretized_v3_template_prompts.jsonl
+
+# LORA:
+ # actor_rollout_ref.model.use_shm=True \
+ # actor_rollout_ref.model.lora_rank=32 \
+ # actor_rollout_ref.model.lora_alpha=32 \
+ # actor_rollout_ref.rollout.load_format=safetensors \
+ # actor_rollout_ref.model.target_modules=all-linear \
+ # actor_rollout_ref.rollout.layered_summon=True \
+
+# Set PyTorch CUDA memory allocator policies
+# export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True,max_split_size_mb=128
+
+# data:
+# train_batch_size: 8
+# val_batch_size: 8
+
+# NOTE: THESE NEED TO BE TOGGLED AS TRUE
+# train_modality_batching:
+# enabled: true
+# drop_last: true
+
+# val_modality_batching:
+# enabled: true
+# drop_last: false
+
+
+PYTHONUNBUFFERED=1 HYDRA_FULL_ERROR=1 PYTHONPATH="/home/keaneong/human-behavior/verl:$PYTHONPATH" NCCL_ASYNC_ERROR_HANDLING=1 python3 -m verl.trainer.main_ppo \
+ algorithm.adv_estimator=grpo \
+ data.train_files=/scratch/keane/human_behaviour/human_behaviour_data/sigs_train_no_lmvd_discretized_v3_template_prompts.jsonl \
+ data.val_files=/scratch/keane/human_behaviour/human_behaviour_data/sigs_val_no_lmvd_discretized_v3_template_prompts.jsonl \
+ data.train_batch_size=6 \
+ data.val_batch_size=6 \
+ data.max_prompt_length=4096 \
+ data.max_response_length=4096 \
+ data.filter_overlong_prompts=False \
+ data.truncation='left' \
+ data.image_key=images \
+ data.video_key=videos \
+ data.prompt_key=problem \
+ data.dataloader_num_workers=8 \
+ data.modalities=\'audio,videos\' \
+ data.train_modality_batching.enabled=True \
+ data.train_modality_batching.drop_last=True \
+ data.val_modality_batching.enabled=True \
+    data.val_modality_batching.drop_last=False \
+ data.format_prompt=/home/keaneong/human-behavior/verl/examples/format_prompt/default.jinja \
+ actor_rollout_ref.model.path=Qwen/Qwen2.5-Omni-7B \
+ actor_rollout_ref.actor.optim.lr=1e-6 \
+ actor_rollout_ref.model.use_remove_padding=False \
+ actor_rollout_ref.actor.ppo_mini_batch_size=3 \
+ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
+ actor_rollout_ref.actor.use_kl_loss=False \
+ actor_rollout_ref.actor.kl_loss_coef=0 \
+ actor_rollout_ref.actor.kl_loss_type=low_var_kl \
+ actor_rollout_ref.actor.entropy_coeff=0 \
+ actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
+ actor_rollout_ref.model.enable_gradient_checkpointing=True \
+ actor_rollout_ref.actor.fsdp_config.param_offload=False \
+ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
+ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \
+ actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
+ actor_rollout_ref.rollout.name=vllm \
+ actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
+ actor_rollout_ref.rollout.enable_chunked_prefill=False \
+ actor_rollout_ref.rollout.enforce_eager=False \
+ actor_rollout_ref.rollout.free_cache_engine=True \
+ actor_rollout_ref.rollout.n=5 \
+ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \
+ actor_rollout_ref.ref.fsdp_config.param_offload=True \
+ actor_rollout_ref.rollout.max_model_len=4096 \
+ actor_rollout_ref.rollout.max_num_batched_tokens=4096 \
+ algorithm.use_kl_in_reward=False \
+ custom_reward_function.path=/home/keaneong/human-behavior/verl/examples/reward_function/human_behaviour.py \
+ custom_reward_function.name=human_behaviour_compute_score_batch \
+ reward_model.reward_manager=batch \
+ trainer.critic_warmup=0 \
+ trainer.logger='["console","wandb"]' \
+ trainer.project_name='verl_hb' \
+ trainer.experiment_name='mixed_modal_omni' \
+ trainer.n_gpus_per_node=3 \
+ trainer.nnodes=1 \
+ trainer.save_freq=-1 \
+ trainer.val_before_train=False \
+ trainer.test_freq=10 \
+    trainer.total_epochs=1 \
+    trainer.default_local_dir=/scratch/keane/human_behaviour/mixed_modal_verl_models_hb_omni "$@"
\ No newline at end of file
diff --git a/examples/grpo_trainer/_human_behaviour.sh b/examples/grpo_trainer/_human_behaviour.sh
new file mode 100755
index 00000000000..512a4e845b1
--- /dev/null
+++ b/examples/grpo_trainer/_human_behaviour.sh
@@ -0,0 +1,55 @@
+set -x
+ENGINE=${1:-vllm}
+
+PYTHONUNBUFFERED=1 HYDRA_FULL_ERROR=1 python3 -m verl.trainer.main_ppo \
+ algorithm.adv_estimator=grpo \
+ data.train_files=/scratch/keane/human_behaviour/human_behaviour_data/train_no_meld_no_chalearn_vision_v2_template_prompts.jsonl \
+ data.val_files=/scratch/keane/human_behaviour/human_behaviour_data/val_no_meld_no_chalearn_vision_v2_template_prompts.jsonl \
+ data.train_batch_size=512 \
+ data.val_batch_size=128 \
+ data.max_prompt_length=4096 \
+ data.max_response_length=4096 \
+ data.filter_overlong_prompts=False \
+ data.truncation='left' \
+ data.image_key=images \
+ data.video_key=videos \
+ data.prompt_key=problem \
+ data.format_prompt=/home/keaneong/human-behavior/verl/examples/format_prompt/default.jinja \
+ actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \
+ actor_rollout_ref.actor.optim.lr=1e-6 \
+ actor_rollout_ref.model.use_remove_padding=False \
+ actor_rollout_ref.actor.ppo_mini_batch_size=128 \
+ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \
+ actor_rollout_ref.actor.use_kl_loss=True \
+ actor_rollout_ref.actor.kl_loss_coef=1e-8 \
+ actor_rollout_ref.actor.kl_loss_type=low_var_kl \
+ actor_rollout_ref.actor.entropy_coeff=0 \
+ actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
+ actor_rollout_ref.model.enable_gradient_checkpointing=True \
+ actor_rollout_ref.actor.fsdp_config.param_offload=False \
+ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
+ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
+ actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
+ actor_rollout_ref.rollout.name=$ENGINE \
+ actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
+ actor_rollout_ref.rollout.enable_chunked_prefill=False \
+ actor_rollout_ref.rollout.enforce_eager=False \
+ actor_rollout_ref.rollout.free_cache_engine=True \
+ actor_rollout_ref.rollout.n=5 \
+ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
+ actor_rollout_ref.ref.fsdp_config.param_offload=True \
+ algorithm.use_kl_in_reward=False \
+ custom_reward_function.path=/home/keaneong/human-behavior/verl/examples/reward_function/medical.py \
+ custom_reward_function.name=medical_compute_score_batch \
+ reward_model.reward_manager=batch \
+ trainer.critic_warmup=0 \
+ trainer.logger='["console","wandb"]' \
+ trainer.project_name='verl_human_behaviour' \
+ trainer.experiment_name='qwen2_5_vl_7b_function_trial' \
+    trainer.n_gpus_per_node=2 \
+ trainer.nnodes=1 \
+ trainer.save_freq=20 \
+ trainer.val_before_train=False \
+ trainer.test_freq=5 \
+    trainer.total_epochs=15 "$@"
\ No newline at end of file
diff --git a/examples/grpo_trainer/_keane_run_qwen2_5_omni-7b_hb_all_modalities.sh b/examples/grpo_trainer/_keane_run_qwen2_5_omni-7b_hb_all_modalities.sh
new file mode 100755
index 00000000000..9bebaf58293
--- /dev/null
+++ b/examples/grpo_trainer/_keane_run_qwen2_5_omni-7b_hb_all_modalities.sh
@@ -0,0 +1,120 @@
+set -x
+
+unset ROCR_VISIBLE_DEVICES
+
+# actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct
+# actor_rollout_ref.model.path=Qwen/Qwen2.5-Omni-7B
+# data.train_files=/scratch/keane/human_behaviour/human_behaviour_data/train_no_meld_no_chalearn_vision_v2_template_prompts.jsonl \
+# data.val_files=/scratch/keane/human_behaviour/human_behaviour_data/val_no_meld_no_chalearn_vision_v2_template_prompts.jsonl \
+# data.modalities=\'audio,videos\' \
+
+# SETTING OF SAVE PATH: trainer.default_local_dir= /scratch/keane/human_behaviour/2_models_hb_vision_only
+# SETTING OF THE LOAD PATH from directory of checkpoints is also: trainer.default_local_dir
+
+# TRAINING FROM scratch: trainer.resume_mode == "disable" (default will save into default_local_dir)
+
+# TRAINING AUTOMATICALLY (i.e. from scratch or from latest checkpoint) :
+ # trainer.resume_mode == "auto" and then the model will take the latest ckpt from trainer.default_hdfs_dir
+
+# TRAINING from specific CHECKPOINT: trainer.resume_mode == "resume_path" and then specify trainer.resume_from_path
+ # Setting of path to resume training from trainer.resume_from_path (exact path of checkpoint)
+ # the model will take from resume_from_path directly (absolute path), and ignore default_hdfs_dir
+
+# for validation, set val_before_train=True ; make sure that the checkpoint is loaded and put val_only=True
+# the checkpoint should already be loaded before that
+# and then we will just evaluate
+
+# ALTERNATIVES
+# /scratch/keane/human_behaviour/human_behaviour_data/discretized_no_lmvd_no_chsimsv2_v3_template_prompts.jsonl
+# /scratch/keane/human_behaviour/human_behaviour_data/discretized_no_lmvd_no_chsimsv2_no_chalearn_v3_template_prompts.jsonl
+
+# when resuming training from a loaded checkpoint cuda OOM error
+
+# alt: /scratch/keane/human_behaviour/human_behaviour_data/0.1_train_no_lmvd_discretized_v3_template_prompts.jsonl
+# org: /scratch/keane/human_behaviour/human_behaviour_data/train_no_lmvd_discretized_v3_template_prompts.jsonl
+
+# LORA:
+ # actor_rollout_ref.model.use_shm=True \
+ # actor_rollout_ref.model.lora_rank=32 \
+ # actor_rollout_ref.model.lora_alpha=32 \
+ # actor_rollout_ref.rollout.load_format=safetensors \
+ # actor_rollout_ref.model.target_modules=all-linear \
+ # actor_rollout_ref.rollout.layered_summon=True \
+
+# Set PyTorch CUDA memory allocator policies
+# export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True,max_split_size_mb=128
+
+# data:
+# train_batch_size: 8
+# val_batch_size: 8
+
+# train_modality_batching:
+# enabled: true
+# drop_last: true
+
+# val_modality_batching:
+# enabled: true
+# drop_last: false
+
+
+PYTHONUNBUFFERED=1 HYDRA_FULL_ERROR=1 PYTHONPATH="/home/keaneong/human-behavior/verl:$PYTHONPATH" NCCL_ASYNC_ERROR_HANDLING=1 python3 -m verl.trainer.main_ppo \
+ algorithm.adv_estimator=grpo \
+ data.train_files=/scratch/keane/human_behaviour/human_behaviour_data/sigs_train_no_lmvd_discretized_v3_template_prompts.jsonl \
+ data.val_files=/scratch/keane/human_behaviour/human_behaviour_data/sigs_val_no_lmvd_discretized_v3_template_prompts.jsonl \
+ data.train_batch_size=288 \
+ data.val_batch_size=144 \
+ data.max_prompt_length=4096 \
+ data.max_response_length=4096 \
+ data.filter_overlong_prompts=False \
+ data.truncation='left' \
+ data.image_key=images \
+ data.video_key=videos \
+ data.prompt_key=problem \
+ data.dataloader_num_workers=8 \
+ data.modalities=\'audio,videos\' \
+ data.train_modality_batching.enabled=True \
+ data.train_modality_batching.drop_last=True \
+ data.val_modality_batching.enabled=True \
+    data.val_modality_batching.drop_last=False \
+ data.format_prompt=/home/keaneong/human-behavior/verl/examples/format_prompt/default.jinja \
+ actor_rollout_ref.model.path=Qwen/Qwen2.5-Omni-7B \
+ actor_rollout_ref.actor.optim.lr=1e-6 \
+ actor_rollout_ref.model.use_remove_padding=False \
+ actor_rollout_ref.actor.ppo_mini_batch_size=72 \
+ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \
+ actor_rollout_ref.actor.use_kl_loss=False \
+ actor_rollout_ref.actor.kl_loss_coef=0 \
+ actor_rollout_ref.actor.kl_loss_type=low_var_kl \
+ actor_rollout_ref.actor.entropy_coeff=0 \
+ actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
+ actor_rollout_ref.model.enable_gradient_checkpointing=True \
+ actor_rollout_ref.actor.fsdp_config.param_offload=False \
+ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
+ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \
+ actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
+ actor_rollout_ref.rollout.name=vllm \
+ actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
+ actor_rollout_ref.rollout.enable_chunked_prefill=False \
+ actor_rollout_ref.rollout.enforce_eager=False \
+ actor_rollout_ref.rollout.free_cache_engine=True \
+ actor_rollout_ref.rollout.n=5 \
+ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \
+ actor_rollout_ref.ref.fsdp_config.param_offload=True \
+ actor_rollout_ref.rollout.max_model_len=4096 \
+ actor_rollout_ref.rollout.max_num_batched_tokens=4096 \
+ algorithm.use_kl_in_reward=False \
+ custom_reward_function.path=/home/keaneong/human-behavior/verl/examples/reward_function/human_behaviour.py \
+ custom_reward_function.name=human_behaviour_compute_score_batch \
+ reward_model.reward_manager=batch \
+ trainer.critic_warmup=0 \
+ trainer.logger='["console","wandb"]' \
+ trainer.project_name='mixed_modal_verl_hb' \
+ trainer.experiment_name='mixed_modal_omni' \
+ trainer.n_gpus_per_node=3 \
+ trainer.nnodes=1 \
+ trainer.save_freq=10 \
+ trainer.val_before_train=False \
+ trainer.test_freq=5 \
+    trainer.total_epochs=1 \
+    trainer.default_local_dir=/scratch/keane/human_behaviour/mixed_modal_verl_models_hb_omni "$@"
\ No newline at end of file
diff --git a/examples/grpo_trainer/keane_vl_only_run_qwen2_5_vl-7b_hb_all_modalities.sh b/examples/grpo_trainer/keane_vl_only_run_qwen2_5_vl-7b_hb_all_modalities.sh
new file mode 100755
index 00000000000..13cf20b279d
--- /dev/null
+++ b/examples/grpo_trainer/keane_vl_only_run_qwen2_5_vl-7b_hb_all_modalities.sh
@@ -0,0 +1,66 @@
+set -x
+
+unset ROCR_VISIBLE_DEVICES
+
+# NOTE: be careful when enabling filter_overlong_prompts; it drops any prompt whose tokenized length exceeds max_prompt_length
+
+# actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct
+# actor_rollout_ref.model.path=Qwen/Qwen2.5-Omni-7B
+# data.train_files=/scratch/keane/human_behaviour/human_behaviour_data/train_no_meld_no_chalearn_vision_v2_template_prompts.jsonl \
+# data.val_files=/scratch/keane/human_behaviour/human_behaviour_data/val_no_meld_no_chalearn_vision_v2_template_prompts.jsonl \
+# data.modalities=\'audio,videos\' \
+
+PYTHONUNBUFFERED=1 HYDRA_FULL_ERROR=1 PYTHONPATH="/home/keaneong/human-behavior/verl:$PYTHONPATH" NCCL_ASYNC_ERROR_HANDLING=1 python3 -m verl.trainer.main_ppo \
+ algorithm.adv_estimator=grpo \
+ data.train_files=/scratch/keane/human_behaviour/human_behaviour_data/old_train_template_prompts.jsonl \
+ data.val_files=/scratch/keane/human_behaviour/human_behaviour_data/old_val_template_prompts.jsonl \
+ data.train_batch_size=64 \
+ data.val_batch_size=64 \
+ data.max_prompt_length=3072 \
+ data.max_response_length=1536 \
+ data.filter_overlong_prompts=False \
+ data.truncation='left' \
+ data.image_key=images \
+ data.video_key=videos \
+ data.prompt_key=problem \
+ data.dataloader_num_workers=0 \
+ data.modalities=\'videos\' \
+ data.format_prompt=/home/keaneong/human-behavior/verl/examples/format_prompt/default.jinja \
+ actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \
+ actor_rollout_ref.actor.optim.lr=1e-6 \
+ actor_rollout_ref.model.use_remove_padding=False \
+ actor_rollout_ref.actor.ppo_mini_batch_size=64 \
+ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
+ actor_rollout_ref.actor.use_kl_loss=False \
+ actor_rollout_ref.actor.kl_loss_coef=1e-8 \
+ actor_rollout_ref.actor.kl_loss_type=low_var_kl \
+ actor_rollout_ref.actor.entropy_coeff=0 \
+ actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
+ actor_rollout_ref.model.enable_gradient_checkpointing=True \
+ actor_rollout_ref.actor.fsdp_config.param_offload=False \
+ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
+ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \
+ actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
+ actor_rollout_ref.rollout.name=vllm \
+ actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
+ actor_rollout_ref.rollout.enable_chunked_prefill=False \
+ actor_rollout_ref.rollout.enforce_eager=False \
+ actor_rollout_ref.rollout.free_cache_engine=True \
+ actor_rollout_ref.rollout.n=3 \
+ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \
+ actor_rollout_ref.ref.fsdp_config.param_offload=True \
+ algorithm.use_kl_in_reward=False \
+ custom_reward_function.path=/home/keaneong/human-behavior/verl/examples/reward_function/medical.py \
+ custom_reward_function.name=medical_compute_score_batch \
+ reward_model.reward_manager=batch \
+ trainer.critic_warmup=0 \
+ trainer.logger='["console","wandb"]' \
+ trainer.project_name='verl_hb' \
+ trainer.experiment_name='vision_only' \
+ trainer.n_gpus_per_node=3 \
+ trainer.nnodes=1 \
+ trainer.save_freq=20 \
+ trainer.val_before_train=False \
+ trainer.test_freq=1 \
+ trainer.total_epochs=15 $@
\ No newline at end of file
diff --git a/examples/grpo_trainer/run_deepseek671b_math_megatron_96gb.sh b/examples/grpo_trainer/run_deepseek671b_math_megatron_96gb.sh
old mode 100644
new mode 100755
diff --git a/examples/grpo_trainer/run_qwen2_5_vl-7b_hb.sh b/examples/grpo_trainer/run_qwen2_5_vl-7b_hb.sh
new file mode 100755
index 00000000000..3f1f0796af4
--- /dev/null
+++ b/examples/grpo_trainer/run_qwen2_5_vl-7b_hb.sh
@@ -0,0 +1,53 @@
+set -x
+
+python3 -m verl.trainer.main_ppo \
+ algorithm.adv_estimator=grpo \
+ data.train_files=/scratch/keane/human_behaviour/human_behaviour_data/old_train_template_prompts.jsonl \
+ data.val_files=/scratch/keane/human_behaviour/human_behaviour_data/old_val_template_prompts.jsonl \
+ data.train_batch_size=512 \
+ data.max_prompt_length=4096 \
+ data.max_response_length=4096 \
+ data.filter_overlong_prompts=False \
+ data.truncation='left' \
+ data.image_key=images \
+ data.video_key=videos \
+ data.prompt_key=problem \
+ data.format_prompt=examples/format_prompt/default.jinja \
+ actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \
+ actor_rollout_ref.actor.optim.lr=1e-6 \
+ actor_rollout_ref.model.use_remove_padding=False \
+ actor_rollout_ref.actor.ppo_mini_batch_size=128 \
+ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \
+ actor_rollout_ref.actor.use_kl_loss=True \
+ actor_rollout_ref.actor.kl_loss_coef=1e-8 \
+ actor_rollout_ref.actor.kl_loss_type=low_var_kl \
+ actor_rollout_ref.actor.entropy_coeff=0 \
+ actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
+ actor_rollout_ref.model.enable_gradient_checkpointing=True \
+ actor_rollout_ref.actor.fsdp_config.param_offload=False \
+ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
+ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
+ actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
+ actor_rollout_ref.rollout.name=vllm \
+ actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
+ actor_rollout_ref.rollout.enable_chunked_prefill=False \
+ actor_rollout_ref.rollout.enforce_eager=False \
+ actor_rollout_ref.rollout.free_cache_engine=True \
+ actor_rollout_ref.rollout.n=5 \
+ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
+ actor_rollout_ref.ref.fsdp_config.param_offload=True \
+ algorithm.use_kl_in_reward=False \
+ custom_reward_function.path=examples/reward_function/medical.py \
+ custom_reward_function.name=medical_compute_score_batch \
+ reward_model.reward_manager=batch \
+ trainer.critic_warmup=0 \
+ trainer.logger='["console","wandb"]' \
+ trainer.project_name='verl_hb' \
+ trainer.experiment_name='vision_only' \
+ trainer.n_gpus_per_node=2 \
+ trainer.nnodes=1 \
+ trainer.save_freq=20 \
+ trainer.val_before_train=False \
+ trainer.test_freq=5 \
+ trainer.total_epochs=15 $@
diff --git a/examples/grpo_trainer/run_qwen2_5_vl-7b_hb_all_modalities.sh b/examples/grpo_trainer/run_qwen2_5_vl-7b_hb_all_modalities.sh
new file mode 100755
index 00000000000..fd2d60ba553
--- /dev/null
+++ b/examples/grpo_trainer/run_qwen2_5_vl-7b_hb_all_modalities.sh
@@ -0,0 +1,56 @@
+set -x
+
+PYTHONUNBUFFERED=1 HYDRA_FULL_ERROR=1 python3 -m verl.trainer.main_ppo \
+ algorithm.adv_estimator=grpo \
+ data.train_files=/scratch/keane/human_behaviour/human_behaviour_data/old_train_template_prompts.jsonl \
+ data.val_files=/scratch/keane/human_behaviour/human_behaviour_data/old_val_template_prompts.jsonl \
+ data.train_batch_size=128 \
+ data.val_batch_size=128 \
+ data.max_prompt_length=4096 \
+ data.max_response_length=4096 \
+ data.filter_overlong_prompts=False \
+ data.truncation='left' \
+ data.image_key=images \
+ data.video_key=videos \
+ data.prompt_key=problem \
+ data.dataloader_num_workers=0 \
+ data.modalities=\'audio,videos\' \
+ data.format_prompt=examples/format_prompt/default.jinja \
+ actor_rollout_ref.model.path=Qwen/Qwen2.5-Omni-7B \
+ actor_rollout_ref.actor.optim.lr=1e-6 \
+ actor_rollout_ref.model.use_remove_padding=False \
+ actor_rollout_ref.actor.ppo_mini_batch_size=128 \
+ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \
+ actor_rollout_ref.actor.use_kl_loss=False \
+ actor_rollout_ref.actor.kl_loss_coef=1e-8 \
+ actor_rollout_ref.actor.kl_loss_type=low_var_kl \
+ actor_rollout_ref.actor.entropy_coeff=0 \
+ actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
+ actor_rollout_ref.model.enable_gradient_checkpointing=True \
+ actor_rollout_ref.actor.fsdp_config.param_offload=False \
+ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
+ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
+ actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
+ actor_rollout_ref.rollout.name=vllm \
+ actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
+ actor_rollout_ref.rollout.enable_chunked_prefill=False \
+ actor_rollout_ref.rollout.enforce_eager=False \
+ actor_rollout_ref.rollout.free_cache_engine=True \
+ actor_rollout_ref.rollout.n=5 \
+ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
+ actor_rollout_ref.ref.fsdp_config.param_offload=True \
+ algorithm.use_kl_in_reward=False \
+ custom_reward_function.path=examples/reward_function/medical.py \
+ custom_reward_function.name=medical_compute_score_batch \
+ reward_model.reward_manager=batch \
+ trainer.critic_warmup=0 \
+ trainer.logger='["console","wandb"]' \
+ trainer.project_name='verl_hb' \
+ trainer.experiment_name='vision_only' \
+ trainer.n_gpus_per_node=2 \
+ trainer.nnodes=1 \
+ trainer.save_freq=20 \
+ trainer.val_before_train=False \
+ trainer.test_freq=5 \
+ trainer.total_epochs=15 $@
diff --git a/examples/reward_function/evaluation.py b/examples/reward_function/evaluation.py
index 3d20e757a8a..716efaa2e98 100644
--- a/examples/reward_function/evaluation.py
+++ b/examples/reward_function/evaluation.py
@@ -253,9 +253,6 @@ def parent(predictions: List[str], ground_truths: List[str], demographics: List[
print(f"std of accuracy for parent = {std_acc:.4f}")
if len(f1_values) >= 2:
- f1_diff = max(f1_values) - min(f1_values)
- results["f1_diff"] = f1_diff
- print(f"F1 max diff for parent = {f1_diff:.4f}")
f1_std = statistics.stdev(f1_values)
results["f1_std"] = f1_std
print(f"std of f1 for parent = {f1_std:.4f}")
diff --git a/examples/reward_function/human_behaviour.py b/examples/reward_function/human_behaviour.py
new file mode 100644
index 00000000000..9221dbf7eda
--- /dev/null
+++ b/examples/reward_function/human_behaviour.py
@@ -0,0 +1,103 @@
+from typing import List, Dict
+import re
+
+def extract_boxed_content(text: str) -> str:
+ """
+    Extract content within \\boxed{} or similar boxing notations.
+
+ Args:
+ text (str): Text containing potentially boxed content.
+
+ Returns:
+ str: Extracted boxed content or the original text if no box found.
+ """
+
+ # Look for LaTeX \boxed{} notation
+ boxed_match = re.search(r"\\boxed{([^}]*)}", text)
+ if boxed_match:
+ return boxed_match.group(1)
+
+ # Look for markdown boxed notation (e.g., [boxed content])
+ markdown_match = re.search(r"\[(.*?)\]", text)
+ if markdown_match:
+ return markdown_match.group(1)
+
+ # Return the text as is if no boxed content is found
+ return text
+
+def format_reward(response: str) -> float:
+ """
+ Check whether the response matches the expected format.
+    Here we require something like ... ... \\boxed{...}
+ """
+ pattern = re.compile(r".*.*\\boxed\{.*\}.*", re.DOTALL)
+ format_match = re.fullmatch(pattern, response)
+ return 1.0 if format_match else 0.0
+
+def accuracy_reward(response: str, ground_truth: str) -> float:
+ """
+ Simple accuracy: exact match to ground truth string.
+ """
+ return 1.0 if response == ground_truth else 0.0
+
+def human_behaviour_compute_score_batch(
+ data_sources: List[str],
+ solution_strs: List[str],
+ ground_truths: List[str],
+ extra_infos: List[str],
+ **kwargs
+) -> List[Dict[str, float]]:
+ """
+ Compute human behaviour scoring for batch inputs.
+
+ Args:
+ data_sources: List of data sources (unused here, but kept for interface compatibility)
+ solution_strs: List of model prediction strings
+ ground_truths: List of ground truth strings
+ extra_infos: List of extra information (unused here, kept for compatibility)
+
+ Returns:
+ List of score dictionaries
+ """
+ batch_scores = []
+ format_weight = 0.1
+
+ for data_source, predict_str, ground_truth, extra_info in zip(data_sources, solution_strs, ground_truths, extra_infos):
+ # Normalize response formatting (e.g., qwen2.5vl quirks)
+ full_response = re.sub(r"\s*(<|>|/)\s*", r"\1", predict_str)
+ pred_label = extract_boxed_content(full_response).lower() # handle qwen2.5vl-32b format
+
+ print(pred_label)
+ # Compute individual components
+ format_score = format_reward(full_response)
+        ground_truth = ground_truth.lower()  # lowercase BEFORE comparing, since pred_label is lowered
+        standard_score = accuracy_reward(pred_label, ground_truth)
+
+
+ # Weighted overall score
+ overall_score = (1 - format_weight) * standard_score + format_weight * format_score
+
+ scores = {
+ "score": overall_score,
+ "standard_score": standard_score,
+ "format_score": format_score,
+ }
+ batch_scores.append(scores)
+
+ return batch_scores
+
+
+if __name__ == "__main__":
+ response_str = (
+ "Well, I've listened to the speech recording. It sounds like the speaker is expressing anger. "
+ "You know, the tone and the way the words are said seem to indicate frustration or annoyance. "
+ "So, I'd say the emotion is anger.\\boxed{anger}If you have any other questions or need more help, feel free to let me know."
+ )
+
+ data_sources = ["sample_audio.wav"]
+ solution_strs = [response_str]
+ ground_truths = ["anger"]
+ extra_infos = [""]
+
+ scores = human_behaviour_compute_score_batch(data_sources, solution_strs, ground_truths, extra_infos)
+ print(scores)
diff --git a/examples/reward_function/human_behaviour_alt.py b/examples/reward_function/human_behaviour_alt.py
new file mode 100644
index 00000000000..765dbeb7b16
--- /dev/null
+++ b/examples/reward_function/human_behaviour_alt.py
@@ -0,0 +1,140 @@
+import re
+import json
+from typing import Dict, List
+
+import numpy
+import torch
+import numpy as np
+from mathruler.grader import extract_boxed_content
+import wandb
+import random
+
+
+def parse_conditions(text):
+ # Remove any boxing notation if present
+ text = text.replace("\\boxed{", "").replace("}", "")
+
+ # Split by common separators
+ for sep in [", ", " and ", " & ", ",", "&"]:
+ if sep in text:
+ return set(cond.strip() for cond in text.split(sep))
+
+ # If no separator found, treat as single condition
+ return {text.strip()}
+
+
+def parse_json(json_output):
+ """
+ Parsing out the markdown fencing from JSON code blocks.
+ """
+ # Look for content between ```json and ```
+ lines = json_output.splitlines()
+ for i, line in enumerate(lines):
+ if line == "```json" or line.strip() == "```":
+ json_output = "\n".join(lines[i + 1:]) # Remove everything before ```json
+ if "```" in json_output:
+ json_output = json_output.split("```")[0] # Remove everything after the closing ```
+ break # Exit the loop once code block marker is found
+ return json_output
+
+
+def extract_json_from_response(text):
+ """
+ Extract JSON content from markdown code blocks in the response.
+
+ Args:
+ text: The model's response text
+
+ Returns:
+ Parsed JSON object or None if no valid JSON found
+ """
+ # Find content between ```json and ```
+ json_pattern = r"```(?:json)?\s*([\s\S]*?)```"
+ matches = re.findall(json_pattern, text)
+
+ if not matches:
+ return None
+
+ # Try to parse each match as JSON
+ for match in matches:
+ try:
+ parsed_json = json.loads(match.strip())
+ return parsed_json
+ except json.JSONDecodeError:
+ continue
+
+ # If we couldn't parse any match as valid JSON, try with ast.literal_eval
+ import ast
+ for match in matches:
+ try:
+ # Clean up the match a bit
+ cleaned = match.strip().replace("'", "\"")
+ parsed_json = ast.literal_eval(cleaned)
+ return parsed_json
+ except:
+ continue
+
+ return None
+
+
+def medical_compute_score_batch(data_sources: List[str], solution_strs: List[str], ground_truths: List[str], extra_infos: List[str], **kwargs) -> List[Dict[str, float]]:
+ """
+ Compute medical scoring for batch inputs including standard score, bounding box IoU, and format score.
+
+ Args:
+ data_sources: List of data sources (e.g., file paths or identifiers)
+ solution_strs: List of model prediction strings
+ ground_truths: List of ground truth strings
+ extra_infos: List of extra information (e.g., segmentation masks, bounding boxes)
+
+ Returns:
+ List of score dictionaries
+ """
+ batch_scores = []
+
+ for data_source, predict_str, ground_truth, extra_info in zip(data_sources, solution_strs, ground_truths, extra_infos):
+ segmentation_mask = None
+ bbox = None
+
+ # Calculate standard score
+ answer = extract_boxed_content(predict_str)
+ if answer == "None":
+ standard_score = 0.0 # no answer
+ else:
+ # Parse both prediction and ground truth into sets of conditions
+ predicted_conditions = parse_conditions(answer)
+ ground_truth_conditions = parse_conditions(ground_truth)
+
+ # Calculate true positives, false positives, and false negatives
+ true_positives = len(predicted_conditions.intersection(ground_truth_conditions))
+ false_positives = len(predicted_conditions - ground_truth_conditions)
+ false_negatives = len(ground_truth_conditions - predicted_conditions)
+
+ # Calculate F1 score components
+ precision = (
+ true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
+ )
+ recall = true_positives / (true_positives + false_negatives) if (
+ true_positives + false_negatives) > 0 else 0
+
+ # Calculate F1 score (harmonic mean of precision and recall)
+ standard_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
+
+        # Calculate format score (whether the response contains a \\boxed{...} answer)
+        format_score = 1.0 if re.search(r"\\boxed\{.*?\}", predict_str) else 0.0
+
+ # length score
+ if len(predict_str) > 600: # ~200 words
+ length_score = 1
+ else:
+ length_score = len(predict_str) * 0.001
+
+ scores = {
+ "score": 0.5 * standard_score + 0.3 * iou_score + 0.1 * format_score,
+ "standard_score": standard_score,
+ "format_score": format_score,
+ "length_score": length_score,
+ }
+ batch_scores.append(scores)
+
+ return batch_scores
\ No newline at end of file
diff --git a/examples/reward_function/math.py b/examples/reward_function/math.py
new file mode 100644
index 00000000000..f39b32a5699
--- /dev/null
+++ b/examples/reward_function/math.py
@@ -0,0 +1,49 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from typing import Any
+
+from mathruler.grader import extract_boxed_content, grade_answer
+
+
+def format_reward(response: str) -> float:
+ pattern = re.compile(r".*.*\\boxed\{.*\}.*", re.DOTALL)
+ format_match = re.fullmatch(pattern, response)
+ return 1.0 if format_match else 0.0
+
+
+def accuracy_reward(response: str, ground_truth: str) -> float:
+ answer = extract_boxed_content(response)
+ return 1.0 if grade_answer(answer, ground_truth) else 0.0
+
+
+def compute_score(reward_inputs: list[dict[str, Any]], format_weight: float = 0.1) -> list[dict[str, float]]:
+ if not isinstance(reward_inputs, list):
+ raise ValueError("Please use `reward_type=batch` for math reward function.")
+
+ scores = []
+ for reward_input in reward_inputs:
+ response = re.sub(r"\s*(<|>|/)\s*", r"\1", reward_input["response"]) # handle qwen2.5vl-32b format
+ format_score = format_reward(response)
+ accuracy_score = accuracy_reward(response, reward_input["ground_truth"])
+ scores.append(
+ {
+ "overall": (1 - format_weight) * accuracy_score + format_weight * format_score,
+ "format": format_score,
+ "accuracy": accuracy_score,
+ }
+ )
+
+ return scores
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 8ee2c46ff0f..ab17283f24d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,7 +3,7 @@ accelerate
codetiming
datasets
dill
-flash-attn
+# flash-attn
hydra-core
liger-kernel
numpy
@@ -17,6 +17,7 @@ ray[default]
tensordict
torchdata
transformers
+vllm
# vllm==0.8.4
wandb
packaging
diff --git a/verl/model_merger/base_model_merger.py b/verl/model_merger/base_model_merger.py
index b46f40f879b..185adbd7e74 100644
--- a/verl/model_merger/base_model_merger.py
+++ b/verl/model_merger/base_model_merger.py
@@ -188,7 +188,7 @@ def __init__(self, config: ModelMergerConfig):
self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code
)
- def get_transformers_auto_model_class(self):
+ def get_transformers_auto_model_class(self):
has_remote_code = hasattr(self.model_config, "auto_map") and any(
self.model_config.architectures[0] in val for val in self.model_config.auto_map.values()
)
@@ -208,6 +208,9 @@ def get_transformers_auto_model_class(self):
else:
if "ForTokenClassification" in self.model_config.architectures[0]:
return AutoModelForTokenClassification
+ elif "Qwen2.5-Omni" in self.hf_model_config_path:
+ from transformers import Qwen2_5OmniThinkerForConditionalGeneration
+ return Qwen2_5OmniThinkerForConditionalGeneration
elif "ForCausalLM" in self.model_config.architectures[0]:
return AutoModelForCausalLM
elif "ForConditionalGeneration" in self.model_config.architectures[0]:
diff --git a/verl/models/transformers/qwen2_vl.py b/verl/models/transformers/qwen2_vl.py
index 0cc7820d114..e9c73a87f01 100644
--- a/verl/models/transformers/qwen2_vl.py
+++ b/verl/models/transformers/qwen2_vl.py
@@ -73,9 +73,20 @@ def get_rope_index(
"""
spatial_merge_size = processor.image_processor.merge_size
tokens_per_second = 2
+
+ # Try old token names first, then fall back to new ones
image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>")
+ if image_token_id is None:
+ image_token_id = processor.tokenizer.convert_tokens_to_ids("<|vision_pad|>") # New tokenizer uses vision_pad for images
+
video_token_id = processor.tokenizer.convert_tokens_to_ids("<|video_pad|>")
+ if video_token_id is None:
+ video_token_id = processor.tokenizer.convert_tokens_to_ids("<|vision_pad|>") # New tokenizer uses vision_pad for videos
+
vision_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|vision_start|>")
+ if vision_start_token_id is None:
+ vision_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|vision_bos|>") # New tokenizer uses vision_bos
+
if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
@@ -84,10 +95,20 @@ def get_rope_index(
image_index, video_index = 0, 0
input_ids = input_ids[attention_mask == 1]
image_nums, video_nums = 0, 0
- vision_start_indices = torch.argwhere(input_ids == vision_start_token_id)
- vision_tokens = input_ids[vision_start_indices + 1]
- image_nums = (vision_tokens == image_token_id).sum()
- video_nums = (vision_tokens == video_token_id).sum()
+
+ # Ensure vision_start_token_id is valid before comparison
+ if vision_start_token_id is not None:
+ vision_start_indices = torch.argwhere(input_ids == vision_start_token_id)
+ else:
+ vision_start_indices = torch.empty((0, 1), dtype=torch.long, device=input_ids.device)
+
+ # Handle case where there are vision tokens
+ if vision_start_indices.numel() > 0:
+ vision_tokens = input_ids[vision_start_indices + 1]
+ image_nums = (vision_tokens == image_token_id).sum().item() if image_token_id is not None else 0
+ video_nums = (vision_tokens == video_token_id).sum().item() if video_token_id is not None else 0
+ else:
+ image_nums, video_nums = 0, 0
input_tokens = input_ids.tolist()
llm_pos_ids_list: list = []
st = 0
diff --git a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml
index fe22ae148e3..bd141bc9c54 100644
--- a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml
+++ b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml
@@ -285,6 +285,14 @@ data:
truncation: error
image_key: images
video_key: videos
+ audio_key: audios
+ modalities: images,videos
+ train_modality_batching:
+ enabled: false
+ drop_last: false
+ val_modality_batching:
+ enabled: false
+ drop_last: false
trust_remote_code: false
custom_cls:
path: null
diff --git a/verl/trainer/config/_generated_ppo_trainer.yaml b/verl/trainer/config/_generated_ppo_trainer.yaml
index 43d0a28878a..8e30ee74892 100644
--- a/verl/trainer/config/_generated_ppo_trainer.yaml
+++ b/verl/trainer/config/_generated_ppo_trainer.yaml
@@ -282,6 +282,14 @@ data:
truncation: error
image_key: images
video_key: videos
+ audio_key: audios
+ modalities: images,videos # list of modalities to process
+ train_modality_batching:
+ enabled: false
+ drop_last: false
+ val_modality_batching:
+ enabled: false
+ drop_last: false
trust_remote_code: false
custom_cls:
path: null
diff --git a/verl/trainer/config/data/legacy_data.yaml b/verl/trainer/config/data/legacy_data.yaml
index f6bd5aa26a2..06809587be2 100644
--- a/verl/trainer/config/data/legacy_data.yaml
+++ b/verl/trainer/config/data/legacy_data.yaml
@@ -77,6 +77,21 @@ image_key: images
# The field in the multi-modal dataset where the video is located.
video_key: videos
+# The field in the multi-modal dataset where the audio is located. Default is 'audios'.
+audio_key: audios
+
+# Comma-separated list of modalities to process. Default is 'images,videos'.
+# Available modalities: images, videos, audios
+# Example: 'images,videos,audios' to enable all modalities
+modalities: images,videos
+
+train_modality_batching:
+ enabled: false
+ drop_last: false
+val_modality_batching:
+ enabled: false
+ drop_last: false
+
# If the remote tokenizer has a Python file, this flag determines whether to allow using it.
trust_remote_code: False
diff --git a/verl/trainer/config/rollout/rollout.yaml b/verl/trainer/config/rollout/rollout.yaml
index 8622cb68790..16f19c6fd48 100644
--- a/verl/trainer/config/rollout/rollout.yaml
+++ b/verl/trainer/config/rollout/rollout.yaml
@@ -50,10 +50,10 @@ free_cache_engine: True
tensor_model_parallel_size: 2
# max number of tokens in a batch
-max_num_batched_tokens: 8192
+max_num_batched_tokens: 1536
# max length for rollout
-max_model_len: null
+max_model_len: 1536
# max length of sequences
max_num_seqs: 1024
diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py
index fcf748ba99e..f16d30987ca 100644
--- a/verl/trainer/main_ppo.py
+++ b/verl/trainer/main_ppo.py
@@ -22,6 +22,7 @@
import ray
from omegaconf import OmegaConf
+from typing import Dict, List
from verl.experimental.dataset.sampler import AbstractSampler
from verl.trainer.constants_ppo import get_ppo_ray_runtime_env
from verl.trainer.ppo.ray_trainer import RayPPOTrainer
@@ -30,6 +31,7 @@
from verl.utils.config import validate_config
from verl.utils.device import is_cuda_available
from verl.utils.import_utils import load_extern_type
+from verl.utils.dataset.modality_sampler import ModalitySignatureBatchSampler
@hydra.main(config_path="config", config_name="ppo_trainer", version_base=None)
@@ -82,6 +84,8 @@ def run_ppo(config) -> None:
runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote()
else:
runner = TaskRunner.remote()
+
+ # RUN THE TRAINING using runner.run
ray.get(runner.run.remote(config))
# [Optional] get the path of the timeline trace file from the configuration, default to None
@@ -282,11 +286,17 @@ def run(self, config):
from verl.utils.dataset.rl_dataset import collate_fn
# Create training and validation datasets.
+ # This is done by reading from the train and the val files
train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor, is_train=True)
val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor, is_train=False)
- train_sampler = create_rl_sampler(config.data, train_dataset)
+ train_sampler = create_rl_sampler(config.data, train_dataset, split="train")
+
+ # print(f"Using train sampler: {train_sampler}")
+ # print(f"Using val dataset: {val_dataset}")
+ # print(f"Using train dataset: {train_dataset}")
# Initialize the PPO trainer.
+ # TODO: train sampler is fed into this; and is used to shuffle the training dataset (and extending to validation)
trainer = RayPPOTrainer(
config=config,
tokenizer=tokenizer,
@@ -356,8 +366,50 @@ def create_rl_dataset(data_paths, data_config, tokenizer, processor, is_train=Tr
return dataset
-
-def create_rl_sampler(data_config, dataset):
+# NOTE: This is the old rl_sampler
+# def create_rl_sampler(data_config, dataset):
+# """Create a sampler for the dataset.
+
+# Arguments:
+# data_config: The data config.
+# dataset (Dataset): The dataset.
+
+# Returns:
+# sampler (Sampler): The sampler.
+# """
+# import torch
+# from torch.utils.data import RandomSampler, SequentialSampler
+
+# if data_config.sampler is not None and data_config.sampler.get("class_path", None) is not None:
+# curriculum_class = load_extern_type(
+# data_config.sampler.class_path,
+# data_config.sampler.class_name,
+# )
+# sampler = curriculum_class(
+# data_source=dataset,
+# data_config=data_config,
+# )
+# assert isinstance(sampler, AbstractSampler)
+# assert data_config.get("dataloader_num_workers", 8) == 0, (
+# "If using curriculum, num_workers must be 0 to prevent data caching. "
+# "If the dataloader caches data before the batch is done the "
+# "curriculum sampler won't have the opportunity to reorder it. "
+# )
+
+# # Use a sampler to facilitate checkpoint resumption.
+# # If shuffling is enabled in the data configuration, create a random sampler.
+# elif data_config.shuffle:
+# train_dataloader_generator = torch.Generator()
+# train_dataloader_generator.manual_seed(data_config.get("seed", 1))
+# sampler = RandomSampler(data_source=dataset, generator=train_dataloader_generator)
+# else:
+# # If shuffling is disabled, use a sequential sampler to iterate through the dataset in order.
+# sampler = SequentialSampler(data_source=dataset)
+
+# return sampler
+
+# NOTE: This is your implementation
+def create_rl_sampler(data_config, dataset, split: str = "train"):
"""Create a sampler for the dataset.
Arguments:
@@ -370,6 +422,10 @@ def create_rl_sampler(data_config, dataset):
import torch
from torch.utils.data import RandomSampler, SequentialSampler
+ # modality batching config parse
+ mb_cfg = data_config.get("train_modality_batching") if split == "train" \
+ else data_config.get("val_modality_batching")
+
if data_config.sampler is not None and data_config.sampler.get("class_path", None) is not None:
curriculum_class = load_extern_type(
data_config.sampler.class_path,
@@ -386,18 +442,52 @@ def create_rl_sampler(data_config, dataset):
"curriculum sampler won't have the opportunity to reorder it. "
)
+ if mb_cfg and mb_cfg.get("enabled", False):
+ print(f"Creating our modality sampler for split: {split}")
+ # by_sig is actually the collation of dataset indices grouped by their modality signature
+ by_sig: Dict[str, List[int]] = {}
+ # essentially getting "modality_signature" from the jsonl dataset
+ for i in range(len(dataset)):
+ row = dataset.dataframe[i] if hasattr(dataset, "dataframe") else dataset[i]
+ sig = row.get("modality_signature")
+ if sig is None:
+ print(f"[WARNING] Row {i} missing 'modality_signature'. Skipping.")
+ continue
+ by_sig.setdefault(sig, []).append(i)
+
+ # batch_size = mb_cfg.get("batch_size", data_config.get(
+ # "train_batch_size" if split=="train" else "val_batch_size"
+ # ))
+
+ batch_size = data_config.get("train_batch_size" if split=="train" else "val_batch_size")
+
+ drop_last = mb_cfg.get("drop_last")
+
+    # shuffle only for the train split (shuffles samples within each modality-signature pool)
+ shuffle = (split == "train")
+
+ print(f"Creating our modality sampler for split: {split}, batch_size: {batch_size}, drop_last: {drop_last}, shuffle: {shuffle}")
+
+ sampler = ModalitySignatureBatchSampler(
+ indices_by_sig=by_sig,
+ batch_size=int(batch_size),
+ drop_last=drop_last,
+ seed=data_config.get("seed", 42),
+ shuffle=shuffle,
+ )
+
# Use a sampler to facilitate checkpoint resumption.
# If shuffling is enabled in the data configuration, create a random sampler.
- elif data_config.shuffle:
+ elif data_config.shuffle and split == "train":
train_dataloader_generator = torch.Generator()
train_dataloader_generator.manual_seed(data_config.get("seed", 1))
sampler = RandomSampler(data_source=dataset, generator=train_dataloader_generator)
+
else:
# If shuffling is disabled, use a sequential sampler to iterate through the dataset in order.
sampler = SequentialSampler(data_source=dataset)
return sampler
-
if __name__ == "__main__":
main()
diff --git a/verl/trainer/ppo/org_functions.py b/verl/trainer/ppo/org_functions.py
new file mode 100644
index 00000000000..f9e3eb86248
--- /dev/null
+++ b/verl/trainer/ppo/org_functions.py
@@ -0,0 +1,78 @@
+# NOTE(review): this file ("org_functions.py") is a snapshot of the original
+# _create_dataloader kept for reference next to the modified ray_trainer.py copy.
+def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]):
+ """
+ Creates the train and validation dataloaders.
+
+ Builds datasets from config when not supplied, resolves sampler/collate_fn
+ defaults, wraps both splits in StatefulDataLoader (checkpoint-resumable),
+ and derives/propagates total_training_steps into the optimizer configs.
+ """
+ # TODO: we have to make sure the batch size is divisible by the dp size
+ from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler
+
+ # Lazily construct datasets from config when the caller did not inject them.
+ if train_dataset is None:
+ train_dataset = create_rl_dataset(
+ self.config.data.train_files, self.config.data, self.tokenizer, self.processor
+ )
+ if val_dataset is None:
+ val_dataset = create_rl_dataset(
+ self.config.data.val_files, self.config.data, self.tokenizer, self.processor
+ )
+ self.train_dataset, self.val_dataset = train_dataset, val_dataset
+
+ if train_sampler is None:
+ # TODO: you can essentially specify the type of sampler here, based also on the data split
+ train_sampler = create_rl_sampler(self.config.data, self.train_dataset)
+ if collate_fn is None:
+ from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn
+
+ collate_fn = default_collate_fn
+
+ num_workers = self.config.data["dataloader_num_workers"]
+
+ # Train loader: ordering is fully delegated to train_sampler
+ # (shuffle is therefore left unset on the DataLoader itself).
+ self.train_dataloader = StatefulDataLoader(
+ dataset=self.train_dataset,
+ batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size),
+ num_workers=num_workers,
+ # shuffle=False,
+ drop_last=True,
+ collate_fn=collate_fn,
+ sampler=train_sampler,
+ )
+
+ val_batch_size = self.config.data.val_batch_size # Prefer config value if set
+ if val_batch_size is None:
+ # Fall back to a single batch containing the whole validation set.
+ val_batch_size = len(self.val_dataset)
+
+ # NOTE(review): validation_shuffle defaults to True here, so validation
+ # batches come out in nondeterministic order unless the config overrides it
+ # - confirm this is intended for eval reproducibility.
+ self.val_dataloader = StatefulDataLoader(
+ dataset=self.val_dataset,
+ batch_size=val_batch_size,
+ num_workers=num_workers,
+ shuffle=self.config.data.get("validation_shuffle", True),
+ drop_last=False,
+ collate_fn=collate_fn,
+ )
+
+ # NOTE(review): assert statements are stripped under `python -O`; raise
+ # explicitly if empty loaders must be a hard error in production.
+ assert len(self.train_dataloader) >= 1, "Train dataloader is empty!"
+ assert len(self.val_dataloader) >= 1, "Validation dataloader is empty!"
+
+ print(
+ f"Size of train dataloader: {len(self.train_dataloader)}, Size of val dataloader: "
+ f"{len(self.val_dataloader)}"
+ )
+
+ # Default schedule length: one pass over the train loader per epoch;
+ # an explicit trainer.total_training_steps overrides it below.
+ total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs
+
+ if self.config.trainer.total_training_steps is not None:
+ total_training_steps = self.config.trainer.total_training_steps
+
+ self.total_training_steps = total_training_steps
+ print(f"Total training steps: {self.total_training_steps}")
+
+ # Push the resolved step count into actor/critic optimizer configs (needed
+ # e.g. for LR schedulers); tolerate configs that lack those sections.
+ try:
+ OmegaConf.set_struct(self.config, True)
+ with open_dict(self.config):
+ if OmegaConf.select(self.config, "actor_rollout_ref.actor.optim"):
+ self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps
+ if OmegaConf.select(self.config, "critic.optim"):
+ self.config.critic.optim.total_training_steps = total_training_steps
+ except Exception as e:
+ print(f"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}")
\ No newline at end of file
diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py
index 620f4b050d7..6ed2244e500 100644
--- a/verl/trainer/ppo/ray_trainer.py
+++ b/verl/trainer/ppo/ray_trainer.py
@@ -34,7 +34,7 @@
import ujson
import wandb
from omegaconf import OmegaConf, open_dict
-from torch.utils.data import Dataset, Sampler
+from torch.utils.data import Dataset, Sampler, BatchSampler, SequentialSampler
from torchdata.stateful_dataloader import StatefulDataLoader
from tqdm import tqdm
@@ -64,6 +64,23 @@
from verl.utils.tracking import ValidationGenerationsLogger
from examples.reward_function.evaluation import compute_metrics_by_data_source
+# Alias for "a Worker subclass" used in role -> worker-class mappings.
+WorkerType = type[Worker]
+
+# NOTE(review): hard-coded, user-specific absolute path; make this configurable
+# (env var / trainer config) or default it to None before merging.
+debug_file = "/home/keaneong/human-behavior/verl/examples/grpo_trainer/debug_log.txt"
+
+class Role(Enum):
+ """
+ To create more roles dynamically, you can subclass Role and add new members
+ """
+
+ # NOTE(review): a standard Enum that already defines members cannot be
+ # subclassed in Python; new roles must be added to this class directly
+ # (the docstring's subclassing suggestion will not work as written).
+ Actor = 0
+ Rollout = 1
+ ActorRollout = 2
+ Critic = 3
+ RefPolicy = 4
+ RewardModel = 5
+ ActorRolloutRef = 6
+
@dataclass
class ResourcePoolManager:
@@ -327,6 +344,7 @@ def __init__(
val_dataset: Optional[Dataset] = None,
collate_fn=None,
train_sampler: Optional[Sampler] = None,
+ val_sampler: Optional[Sampler] = None,
device_name=None,
):
"""
@@ -402,7 +420,11 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl
self.train_dataset, self.val_dataset = train_dataset, val_dataset
if train_sampler is None:
- train_sampler = create_rl_sampler(self.config.data, self.train_dataset)
+ train_sampler = create_rl_sampler(self.config.data, self.train_dataset, split="train")
+
+ if val_sampler is None:
+ val_sampler = create_rl_sampler(self.config.data, self.val_dataset, split="val")
+
if collate_fn is None:
from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn
@@ -410,16 +432,41 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl
num_workers = self.config.data["dataloader_num_workers"]
- self.train_dataloader = StatefulDataLoader(
- dataset=self.train_dataset,
- batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size),
- num_workers=num_workers,
- drop_last=True,
- collate_fn=collate_fn,
- sampler=train_sampler,
- )
- val_batch_size = self.config.data.val_batch_size # Prefer config value if set
+ if isinstance(train_sampler, BatchSampler):
+ self.train_dataloader = StatefulDataLoader(
+ dataset=self.train_dataset,
+ batch_sampler=train_sampler,
+ num_workers=num_workers,
+ collate_fn=collate_fn,
+ )
+ else:
+ # Else if it is not a batch sampler, we can specify the batch size directly
+ self.train_dataloader = StatefulDataLoader(
+ dataset=self.train_dataset,
+ batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size),
+ num_workers=num_workers,
+ # shuffle=False,
+ drop_last=True,
+ collate_fn=collate_fn,
+ sampler=train_sampler,
+ )
+ if isinstance(val_sampler, BatchSampler):
+ # BatchSampler path: DO NOT pass batch_size/shuffle/drop_last
+ self.val_dataloader = StatefulDataLoader(
+ dataset=self.val_dataset,
+ batch_sampler=val_sampler,
+ num_workers=num_workers,
+ collate_fn=collate_fn,
+ )
+ else:
+ # Plain Sampler path: compute val_batch_size (None -> len(dataset))
+ # This plain sampler path, if you trace the instance of val_sampler,
+ # should be that of a sequential sampler. Break if it is not.
+ if not isinstance(val_sampler, SequentialSampler):
+ raise ValueError("Validation sampler is not a SequentialSampler")
+
+ val_batch_size = self.config.data.val_batch_size
if val_batch_size is None:
val_batch_size = len(self.val_dataset)
@@ -839,6 +886,8 @@ def init_workers(self):
)
def _save_checkpoint(self):
+
+ ## TO SAVE CHECKPOINT
from verl.utils.fs import local_mkdir_safe
# path: given_path + `/global_step_{global_steps}` + `/actor`
@@ -1045,8 +1094,23 @@ def fit(self):
)
next_step_profile = False
+
for epoch in range(self.config.trainer.total_epochs):
- for batch_dict in self.train_dataloader:
+ # i = 0
+ for batch_idx, batch_dict in enumerate(self.train_dataloader):
+ #--- DEBUG: log batch content into debug_file ---
+
+
+ if debug_file is not None:
+ with open(debug_file, "a", encoding="utf-8") as f:
+ log_entry = {
+ "epoch": int(epoch),
+ "batch_idx": int(batch_idx),
+ "modality_signatures": batch_dict.get("modality_signatures", []),
+ "prompts": batch_dict.get("debug_prompts", []),
+ }
+ f.write(json.dumps(log_entry, ensure_ascii=False, default=lambda o: o.tolist() if isinstance(o, np.ndarray) else str(o)) + "\n")
+
metrics = {}
timing_raw = {}
@@ -1072,10 +1136,15 @@ def fit(self):
is_last_step = self.global_steps >= self.total_training_steps
+ # TODO: double check the gen_batch
+ # print(f"gen_batch", gen_batch)
+ # i += 1
+
with marked_timer("step", timing_raw):
# generate a batch
with marked_timer("gen", timing_raw, color="red"):
if not self.async_rollout_mode:
+ # TODO: Fix the audio generation for this
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
else:
gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch)
diff --git a/verl/utils/dataset/audio_utils.py b/verl/utils/dataset/audio_utils.py
new file mode 100644
index 00000000000..5a29550922f
--- /dev/null
+++ b/verl/utils/dataset/audio_utils.py
@@ -0,0 +1,175 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Tuple, Union
+import torch
+import torchaudio
+
+def process_audio(
+ audio: Union[str, dict],
+ processor=None,
+ max_seconds: float = None # keep audio to this many seconds max
+) -> Tuple[torch.Tensor, int]:
+ """
+ Load audio, convert to mono, resample, and clip to max_seconds.
+
+ Returns (waveform, sample_rate) where waveform is a 1-D float tensor.
+ Any failure inside the try-block is swallowed and replaced by silence
+ (see the except-branch at the bottom).
+ """
+ if isinstance(audio, dict):
+ # NOTE(review): the .get() fallback is the dict ITSELF, which
+ # torchaudio.load cannot open - was a different path key intended?
+ audio_path = audio.get("audio", audio)
+ else:
+ audio_path = audio
+
+ try:
+ # Load waveform as (channels, time) plus its native sample rate.
+ audio_data, original_sr = torchaudio.load(audio_path)
+
+ # Resample if needed: prefer the processor's expected rate, else 16 kHz.
+ if processor and hasattr(processor, 'feature_extractor') and hasattr(processor.feature_extractor, 'sampling_rate'):
+ target_sr = processor.feature_extractor.sampling_rate
+ else:
+ target_sr = 16000
+
+ if original_sr != target_sr:
+ resampler = torchaudio.transforms.Resample(original_sr, target_sr)
+ audio_data = resampler(audio_data)
+ else:
+ # no-op assignment: original_sr == target_sr already holds here
+ target_sr = original_sr
+
+ # Convert to mono (average the channels), then drop the channel dim.
+ if audio_data.shape[0] > 1:
+ audio_data = audio_data.mean(dim=0, keepdim=False)
+ else:
+ audio_data = audio_data.squeeze(0)
+
+ # Clip to max_seconds
+ if max_seconds:
+ max_samples = int(max_seconds * target_sr)
+
+ print(f"Processing Audio {audio_path}, shape={audio_data.shape}, "
+ f"sr={target_sr}, max_samples={max_samples}")
+ # ValueError("Audio was processed")
+
+ if audio_data.shape[0] > max_samples:
+ print("Clipping audio to max_seconds")
+ audio_data = audio_data[:max_samples]
+ # BUG(review): this raise is caught by the blanket `except` below,
+ # so every file that needs clipping silently returns max_seconds of
+ # ZEROS instead of the clipped waveform. Looks like leftover debug
+ # code - delete the raise to return the clipped audio.
+ raise ValueError("Audio data was clipped to max_seconds")
+
+ return audio_data, target_sr
+
+ except Exception as e:
+ print(f"Error processing audio {audio_path}: {e}")
+ # NOTE(review): if max_seconds is None this fallback itself raises
+ # TypeError (16000 * None), so the "safe" path crashes exactly when no
+ # clip length was configured.
+ dummy_audio = torch.zeros((int(16000 * max_seconds),), dtype=torch.float32)
+ return dummy_audio, 16000
+
+
+# def process_audio(audio: str | dict, processor=None) -> Tuple[torch.Tensor, int]:
+# if isinstance(audio, dict):
+# # TODO: to check whether the keys are correct here
+# audio_path = audio.get("audio", audio)
+# else:
+# audio_path = audio
+
+# try:
+# # Load audio
+# # NOTE: accepts waveform and sample rate;
+# audio_data, original_sr = torchaudio.load(audio_path)
+
+# # Get target sampling rate
+# # NOTE: sample rate is basically the amount of audio samples captured per second
+# # 16000 means 16000 samples are taken in every second
+# if processor and hasattr(processor, 'feature_extractor') and hasattr(processor.feature_extractor,
+# 'sampling_rate'):
+# target_sr = processor.feature_extractor.sampling_rate
+# else:
+
+# target_sr = 16000
+# # print(f"KEANE: Processing audio {audio_path} with sampling rate, {target_sr}")
+# # Resample if needed
+# # NOTE: This is essentially the resampling of the audio sample rate
+# if original_sr != target_sr:
+# resampler = torchaudio.transforms.Resample(original_sr, target_sr)
+# audio_data = resampler(audio_data)
+
+# # Convert to mono if stereo
+# # NOTE: This is essentially the conversion of stereo audio to mono, so that we only have one channel
+# if audio_data.shape[0] > 1:
+# audio_data = audio_data.mean(dim=0, keepdim=False)
+# else:
+# audio_data = audio_data.squeeze(0)
+
+# # Debug prints
+# # print(
+# # f"KEANE: Finished processing {audio_path} -> "
+# # f"waveform shape {audio_data.shape}, dtype {audio_data.dtype}, "
+# # # f"min {audio_data.min().item():.4f}, max {audio_data.max().item():.4f}"
+# # )
+# # print(f"KEANE: Returning tuple (waveform, sr={target_sr})")
+
+# return audio_data, target_sr
+# except Exception as e:
+# print(f"Error processing audio {audio_path}: {e}")
+# dummy_audio = torch.zeros((1000,), dtype=torch.float32)
+# return dummy_audio, 16000
+
+
+# def process_audio(audio: str | dict, processor=None) -> np.ndarray:
+# """
+# NOTE: Keane's implementation
+# """
+# if isinstance(audio, dict):
+# audio_path = audio.get("audio", audio)
+# else:
+# audio_path = audio
+
+# try:
+# # Load audio -> (channels, time), sample_rate
+# audio_data, original_sr = torchaudio.load(audio_path)
+
+# # Target sampling rate (from processor or default to 16k)
+# if (
+# processor
+# and hasattr(processor, "feature_extractor")
+# and hasattr(processor.feature_extractor, "sampling_rate")
+# ):
+# target_sr = processor.feature_extractor.sampling_rate
+# else:
+# target_sr = 16000
+
+# print(f"KEANE: Processing audio {audio_path} with target sampling rate {target_sr}")
+
+# # Resample if needed
+# if original_sr != target_sr:
+# resampler = torchaudio.transforms.Resample(original_sr, target_sr)
+# audio_data = resampler(audio_data)
+
+# # Convert to mono if stereo
+# if audio_data.shape[0] > 1:
+# audio_data = audio_data.mean(dim=0, keepdim=False)
+# else:
+# audio_data = audio_data.squeeze(0)
+
+# # Convert to numpy float32 (1-D)
+# audio_np = audio_data.detach().cpu().numpy().astype(np.float32)
+
+# print(
+# f"KEANE: Finished {audio_path} -> "
+# f"waveform shape {audio_np.shape}, dtype {audio_np.dtype}"
+# )
+
+# # NOTE: we only need to return the numpy array and not the sampling rate
+
+# return audio_np
+
+# except Exception as e:
+# print(f"Error processing audio {audio_path}: {e}")
+# return np.zeros((1000,), dtype=np.float32) # dummy 1-D waveform
diff --git a/verl/utils/dataset/modality_sampler.py b/verl/utils/dataset/modality_sampler.py
new file mode 100644
index 00000000000..94e34b00b7b
--- /dev/null
+++ b/verl/utils/dataset/modality_sampler.py
@@ -0,0 +1,79 @@
+import random
+from typing import Dict, List, Iterator, Optional
+from collections import defaultdict, deque
+from torch.utils.data import BatchSampler
+
+class ModalitySignatureBatchSampler(BatchSampler):
+ """
+ Round-robin across modality signatures, pruning exhausted signatures.
+ - Shuffles within each signature if shuffle=True (train).
+ - Each yielded batch is homogeneous by modality_signature.
+ - If a signature runs out of batches, it is removed and RR continues.
+ """
+ def __init__(
+ self,
+ indices_by_sig: Dict[str, List[int]],
+ batch_size: int,
+ drop_last: bool = True,
+ seed: int = 42,
+ shuffle: bool = True,
+ ):
+ # NOTE(review): BatchSampler.__init__ is intentionally NOT called; every
+ # attribute/method it would provide is redefined here, but its argument
+ # validation is skipped too (e.g. batch_size <= 0 is not rejected).
+ # copy each index list so external mutation can't affect the sampler
+ self.indices_by_sig = {s: list(v) for s, v in indices_by_sig.items()}
+ self.batch_size = int(batch_size)
+ self.drop_last = drop_last
+ self.shuffle = shuffle
+ # single RNG seeded once: each __iter__ continues the same stream, so
+ # shuffles differ per epoch while the whole run stays reproducible
+ self.rng = random.Random(seed)
+ self.sigs = list(self.indices_by_sig.keys())
+
+ def _batches_for(self, pool: List[int]) -> List[List[int]]:
+ # Slice one signature's index pool into batch_size chunks; a short tail
+ # chunk is dropped when drop_last is set, otherwise kept.
+ n = len(pool)
+ batches = []
+ for start in range(0, n, self.batch_size):
+ chunk = pool[start:start + self.batch_size]
+ if len(chunk) < self.batch_size and self.drop_last:
+ continue
+ if chunk:
+ batches.append(chunk)
+ return batches
+
+ def __iter__(self) -> Iterator[List[int]]:
+ # Fresh pools + optional shuffle within each signature
+ pools = {s: list(v) for s, v in self.indices_by_sig.items()}
+ for s in pools:
+ if self.shuffle:
+ self.rng.shuffle(pools[s])
+
+ # Build per-signature batch queues; essentially a dictionary with batches of each different modality signature
+ per_sig_batches = {s: deque(self._batches_for(pools[s])) for s in self.sigs}
+
+ # Establish RR order
+ order = list(self.sigs)
+ if self.shuffle:
+ # rotate start signature per epoch for variety (keeps RR structure)
+ k = self.rng.randrange(len(order)) if order else 0
+ order = order[k:] + order[:k]
+ else:
+ # deterministic eval order: alphabetical by signature
+ order = sorted(order)
+
+ # Active signatures as a deque for easy rotation
+ active = deque([s for s in order if len(per_sig_batches[s]) > 0])
+
+ while active:
+ s = active.popleft() # take the queue's leftmost element (modality signature)
+ q = per_sig_batches[s] # access all of the batched stuff
+ if q:
+ yield q.popleft() # yield that batch
+ # if still has batches, push to the end to continue RR
+ if q:
+ active.append(s) # reappend the modality signature to the active queue
+ # if q is empty, we simply don't re-append s → pruned automatically
+ else:
+ # NOTE(review): unreachable - a signature only enters `active`
+ # while its deque is non-empty; harmless defensive branch.
+ print(f"Ran-Out: Pruning modality signature: {s}")
+
+ def __len__(self) -> int:
+ # Total number of batches across all signatures (after drop_last handling)
+ # full batches always count; a remainder adds one batch only when
+ # drop_last is False - this matches exactly what __iter__ yields.
+ total = 0
+ for pool in self.indices_by_sig.values():
+ full, rem = divmod(len(pool), self.batch_size)
+ total += full + (0 if self.drop_last or rem == 0 else 1)
+ return total
+# NOTE(review): file is missing a trailing newline at EOF
+# ("\ No newline at end of file" in the patch).
\ No newline at end of file
diff --git a/verl/utils/dataset/rl_dataset.py b/verl/utils/dataset/rl_dataset.py
index c27a89133c7..61a4b0f4f73 100644
--- a/verl/utils/dataset/rl_dataset.py
+++ b/verl/utils/dataset/rl_dataset.py
@@ -20,8 +20,7 @@
import os
import re
from collections import defaultdict
-from typing import Optional
-
+from typing import Optional,Dict, Any, List
import datasets
import numpy as np
import torch
@@ -29,12 +28,34 @@
from omegaconf import DictConfig, ListConfig
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer, ProcessorMixin
+import warnings
import verl.utils.torch_functional as verl_F
from verl.utils.model import compute_position_id_with_mask
+import time, os, math, warnings
logger = logging.getLogger(__name__)
+def _tok_est_from_hw(H, W):
+ # Rough visual-token count for an HxW image: one "token" per 28x28 patch,
+ # ceiling on each axis so partial patches still count.
+ return math.ceil(H/28) * math.ceil(W/28)
+
+def _sec_from_array(arr, sr):
+ # Duration of a waveform in seconds, rounded to milliseconds.
+ try:
+ return round(len(arr) / float(sr), 3)
+ except Exception:
+ # NOTE(review): failure returns the STRING "?" while the happy path
+ # returns a float - fine for log formatting, unsafe for arithmetic.
+ return "?"
+
+def _p99(xs):
+ # Approximate 99th percentile via nearest-rank with a floored index
+ # (slightly below the true p99 for small inputs); returns 0 when empty.
+ xs = sorted(xs)
+ if not xs: return 0
+ k = int(0.99*(len(xs)-1))
+ return xs[k]
+
+def assert_homogeneous(batch_list: List[Dict[str, Any]]):
+ # Guard: every sample in the batch must share one modality_signature.
+ # Items missing the key contribute None (still counted as a signature);
+ # an empty batch_list yields zero signatures and raises as well.
+ sigs = {b.get("modality_signature") for b in batch_list}
+ if len(sigs) != 1:
+ raise AssertionError(f"Non-homogeneous batch signatures: {sigs}")
def processor_supports_video(processor: ProcessorMixin) -> bool:
"""
@@ -78,9 +99,13 @@ def collate_fn(data_list: list[dict]) -> dict:
Returns:
Dict where tensor entries are stacked into a torch.Tensor of shape
- (batch_size, \*dims) and non-tensor entries are converted to
+ (batch_size, dims) and non-tensor entries are converted to
np.ndarray of dtype object with shape (batch_size,).
"""
+ # data list is the batch list
+ # NOTE: fail fast if the batch mixes different modality signatures
+ assert_homogeneous(data_list) # assert if not homogeneous
+
tensors = defaultdict(list)
non_tensors = defaultdict(list)
@@ -134,13 +159,24 @@ def __init__(
self.config = config
self.cache_dir = os.path.expanduser(config.get("cache_dir", "~/.cache/verl/rlhf"))
+
+ # Essentially getting all the different keys.
self.prompt_key = config.get("prompt_key", "prompt")
self.image_key = config.get("image_key", "images")
self.video_key = config.get("video_key", "videos")
+
+ # NOTE: SET AUDIO KEY AS AUDIOS
+ self.audio_key = config.get("audio_key", "audios")
+
+ # NOTE: SET MODALITIES, split the images and videos
+ self.modalities = set(config.get("modalities", "images,videos").split(","))
+
self.max_prompt_length = config.get("max_prompt_length", 1024)
self.return_raw_chat = config.get("return_raw_chat", False)
self.return_full_prompt = config.get("return_full_prompt", False)
self.truncation = config.get("truncation", "error")
+
+ # TODO: Check whether this is true
self.filter_overlong_prompts = config.get("filter_overlong_prompts", True)
if isinstance(data_files, str):
self.base_dir = os.path.dirname(os.path.abspath(data_files))
@@ -156,13 +192,25 @@ def __init__(
self.filter_prompts = config.get("filter_prompts", True)
self.serialize_dataset = False
self.return_multi_modal_inputs = config.get("return_multi_modal_inputs", True)
+
+ # Load format prompt from file if specified
+ self.format_prompt_path = config.get("format_prompt", "examples/format_prompt/default.jinja")
+ self.format_prompt = self._load_format_prompt()
# Load format prompt from file if specified
self.format_prompt_path = config.get("format_prompt", "examples/format_prompt/default.jinja")
self.format_prompt = self._load_format_prompt()
self._download()
- self._read_files_and_tokenize()
+ self._read_files_and_tokenize() # essentially this is prepared first before _getitem
+
+ # NOTE(review): DUPLICATE - an identical _load_format_prompt already exists
+ # (see the unchanged context lines right below this hunk), and the
+ # format_prompt attributes are likewise initialized twice in __init__;
+ # keep only one copy of each.
+ def _load_format_prompt(self) -> Optional[Template]:
+ """Load format prompt from file if specified."""
+ if self.format_prompt_path:
+ # read the Jinja template source and compile it once at init time
+ with open(self.format_prompt_path, 'r', encoding='utf-8') as f:
+ template_content = f.read()
+ return Template(template_content)
+ return None
def _load_format_prompt(self) -> Optional[Template]:
"""Load format prompt from file if specified."""
@@ -181,6 +229,18 @@ def _download(self, use_origin_parquet=False):
def _read_files_and_tokenize(self):
dataframes = []
+
+ features = datasets.Features({
+ "problem": datasets.Value("string"),
+ "answer": datasets.Value("string"),
+ "images": datasets.Sequence(datasets.Value("string")),
+ "videos": datasets.Sequence(datasets.Value("string")),
+ "audios": datasets.Sequence(datasets.Value("string")), # <- force list of strings
+ "dataset": datasets.Value("string"),
+ "texts": datasets.Sequence(datasets.Value("string")),
+ "modality_signature": datasets.Value("string"),
+ })
+
for parquet_file in self.data_files:
# read parquet files and cache
if parquet_file.endswith(".parquet"):
@@ -194,51 +254,78 @@ def _read_files_and_tokenize(self):
print(f"dataset len: {len(self.dataframe)}")
+ # PROCESSING THE DATAFRAME for TRAINING
self.dataframe = self.maybe_filter_out_long_prompts(self.dataframe)
def maybe_filter_out_long_prompts(self, dataframe: datasets.Dataset = None):
- # filter out too long prompts
+ # NOTE: filter out too long prompts, because the prompts can become very long
+ # when the audio is appended.
+
if self.filter_overlong_prompts:
+ # NOTE: FILTER OUT THE LONG PROMPTS SO THAT THEY FIT THE LENGTH
tokenizer = self.tokenizer
processor = self.processor
prompt_key = self.prompt_key
image_key = self.image_key
video_key = self.video_key
+ audio_key = self.audio_key
if processor is not None:
+ # print(f"KEANE: PROCESSOR FOUND")
from verl.utils.dataset.vision_utils import process_image, process_video
+ from verl.utils.dataset.audio_utils import process_audio
def doc2len(doc) -> int:
messages = self._build_messages(doc)
raw_prompt = self.processor.apply_chat_template(
messages, add_generation_prompt=True, tokenize=False, **self.apply_chat_template_kwargs
)
- images = [process_image(image) for image in doc[image_key]] if image_key in doc else None
- videos = [process_video(video) for video in doc[video_key]] if video_key in doc else None
-
- # Handle video-to-image conversion for processors that don't support video
- if videos and not processor_supports_video(processor):
- # Convert video frames to images
- if images is None:
- images = []
- for video_tensor in videos:
- # video_tensor is shape [n_frames, 3, H, W]
- for frame_idx in range(video_tensor.shape[0]):
- frame = video_tensor[frame_idx] # [3, H, W]
- frame_np = frame.permute(1, 2, 0).numpy() # [H, W, 3]
- from PIL import Image
- frame_image = Image.fromarray(frame_np.astype('uint8'), 'RGB')
- images.append(frame_image)
- videos = None
+ processor_kwargs = {"text": [raw_prompt]}
- # Call processor with appropriate parameters
- if processor_supports_video(processor):
- return len(processor(text=[raw_prompt], images=images, videos=videos)["input_ids"][0])
- else:
- return len(processor(text=[raw_prompt], images=images)["input_ids"][0])
+ if "images" in self.modalities and image_key in doc and len(doc[image_key]) > 0:
+ images = [process_image(image) for image in doc[image_key]]
+ processor_kwargs["images"] = images
+
+ if "videos" in self.modalities and video_key in doc and len(doc[video_key]) > 0:
+ videos = [process_video(video) for video in doc[video_key]]
+ # Handle video-to-image conversion for processors that don't support video
+ if videos and not processor_supports_video(processor):
+ # Convert video frames to images
+ if images is None:
+ images = []
+ for video_tensor in videos:
+ # video_tensor is shape [n_frames, 3, H, W]
+ for frame_idx in range(video_tensor.shape[0]):
+ frame = video_tensor[frame_idx] # [3, H, W]
+ frame_np = frame.permute(1, 2, 0).numpy() # [H, W, 3]
+ from PIL import Image
+ frame_image = Image.fromarray(frame_np.astype('uint8'), 'RGB')
+ images.append(frame_image)
+ videos = None
+ processor_kwargs["videos"] = videos
+
+ if "audio" in self.modalities and audio_key in doc and doc.get(audio_key, None) is not None and len(doc[audio_key]) > 0:
+ # processing of audio
+ # print(f"KEANE: Processing audio within rl dataset file")
+ # audios = [process_audio(audio, processor) for audio in doc[audio_key]]
+ # processor_kwargs["audio"] = audios
+
+ # PATCH
+ audios = []
+ audio_tuples = [] # Keep tuples for multi_modal_data
+ for audio in doc.get(self.audio_key):
+ audio_path = os.path.join(self.base_dir, audio) if isinstance(audio, str) else audio
+ audio_data, sampling_rate = process_audio(audio_path, self.processor)
+ audio_tuples.append((audio_data, sampling_rate))
+ # audios.append(audio_data.numpy()) # Convert to numpy array for Whisper
+ audios.append(audio_data.detach().cpu().numpy().astype("float32"))
+
+ processor_kwargs["audio"] = audios # Pass numpy arrays to processor
+
+ return len(processor(**processor_kwargs)["input_ids"][0])
else:
-
+ # print(f"KEANE: PROCESSOR NOT FOUND")
def doc2len(doc) -> int:
return len(
tokenizer.apply_chat_template(
@@ -267,31 +354,40 @@ def resume_dataset_state(self):
def __len__(self):
return len(self.dataframe)
- def _build_messages(self, example: dict, convert_video_to_images: bool = False):
+ def _build_messages(self, example: dict):
+ """
+ This appears to be called twice, once during maybe_filter_out_long_prompts, and another time during getitems
+ """
messages: list = example.get(self.prompt_key)
if isinstance(messages, str):
messages = [messages]
- format_prompt = ("You FIRST think about the reasoning process as an internal monologue and then "
- "provide the final answer. The reasoning process MUST BE enclosed within "
- " tags. The final answer MUST BE put in \\boxed{}.")
-
- if self.image_key in example or self.video_key in example:
+ # NOTE: Before building, check if there is multimodal content
+ has_multimodal = (
+ ("images" in self.modalities and self.image_key in example) or
+ ("videos" in self.modalities and self.video_key in example) or
+ ("audio" in self.modalities and self.audio_key in example)
+ )
+
+ if has_multimodal:
new_messages = []
for message in messages:
new_message = copy.deepcopy(message)
if isinstance(new_message, str):
new_message = {"role": "user", "content": new_message}
content = new_message["content"]
-
# Apply format prompt to the entire content first if template is loaded
if self.format_prompt:
content = self.format_prompt.render(content=content)
image_count = len(example.get(self.image_key, []))
video_count = len(example.get(self.video_key, []))
+ audio_count = len(example.get(self.audio_key, []))
image_tag_count = content.count("")
video_tag_count = content.count("