diff --git a/_unit_test_modality_sampler.py b/_unit_test_modality_sampler.py new file mode 100644 index 00000000000..323d2461439 --- /dev/null +++ b/_unit_test_modality_sampler.py @@ -0,0 +1,308 @@ +# test_stateful_modality_sampler_hardcoded.py + +import json +from typing import Dict, Any, List, Iterator +from torch.utils.data import Dataset, BatchSampler +import random + +# ==== ADJUST PATHS below to match your repo structure ==== +# from verl.utils.dataset.modality_sampler import ModalitySignatureBatchSampler +from torchdata.stateful_dataloader import StatefulDataLoader +from collections import defaultdict, deque +import torch +import numpy as np + +# ---------- HARD-CODED PATHS + CONFIG ---------- +JSONL_PATH = "/Users/keane/Desktop/research/human-behavior/data/all/sigs_no_lmvd_discretized_v3_template_prompts.jsonl" +TRAIN_BS = 4 +VAL_BS = 4 +SEED = 42 +TRUNCATE_RATIO = 0.001 # for quick testing; set to 1.0 to disable +# --------------------------------------------- + +# TODO: Please remove text only; everything should be text_only + +class ModalitySignatureBatchSampler(BatchSampler): + """ + Round-robin across modality signatures, pruning exhausted signatures. + - Shuffles within each signature if shuffle=True (train). + - Each yielded batch is homogeneous by modality_signature. + - If a signature runs out of batches, it is removed and RR continues. 
+ """ + def __init__( + self, + indices_by_sig: Dict[str, List[int]], + batch_size: int, + drop_last: bool = True, + seed: int = 42, + shuffle: bool = True, + ): + self.indices_by_sig = {s: list(v) for s, v in indices_by_sig.items()} + self.batch_size = int(batch_size) + self.drop_last = drop_last + self.shuffle = shuffle + self.rng = random.Random(seed) + self.sigs = list(self.indices_by_sig.keys()) + + def _batches_for(self, pool: List[int]) -> List[List[int]]: + n = len(pool) + batches = [] + for start in range(0, n, self.batch_size): + chunk = pool[start:start + self.batch_size] + if len(chunk) < self.batch_size and self.drop_last: + continue + if chunk: + batches.append(chunk) + return batches + + def __iter__(self) -> Iterator[List[int]]: + # Fresh pools + optional shuffle within each signature + pools = {s: list(v) for s, v in self.indices_by_sig.items()} + for s in pools: + if self.shuffle: + self.rng.shuffle(pools[s]) + + # Build per-signature batch queues; essentially a dictionary with batches of each different modality signature + per_sig_batches = {s: deque(self._batches_for(pools[s])) for s in self.sigs} + + # Establish RR order + order = list(self.sigs) + if self.shuffle: + # rotate start signature per epoch for variety (keeps RR structure) + k = self.rng.randrange(len(order)) if order else 0 + order = order[k:] + order[:k] + else: + order = sorted(order) + + # Active signatures as a deque for easy rotation + active = deque([s for s in order if len(per_sig_batches[s]) > 0]) + + while active: + s = active.popleft() # take the queue's leftmost element (modality signature) + q = per_sig_batches[s] # access all of the batched stuff + if q: + yield q.popleft() # yield that batch + # if still has batches, push to the end to continue RR + if q: + active.append(s) # reappend the modality signature to the active queue + # if q is empty, we simply don't re-append s → pruned automatically + else: + print(f"Ran-Out: Pruning modality signature: {s}") + + def 
__len__(self) -> int: + # Total number of batches across all signatures (after drop_last handling) + total = 0 + for pool in self.indices_by_sig.values(): + full, rem = divmod(len(pool), self.batch_size) + total += full + (0 if self.drop_last or rem == 0 else 1) + return total + + +def rl_collate_fn(data_list: list[dict]) -> dict: + """ + Collate a batch of sample dicts into batched tensors and arrays. + + Args: + data_list: List of dicts mapping feature names to torch.Tensor or other values. + + Returns: + Dict where tensor entries are stacked into a torch.Tensor of shape + (batch_size, dims) and non-tensor entries are converted to + np.ndarray of dtype object with shape (batch_size,). + """ + tensors = defaultdict(list) + non_tensors = defaultdict(list) + + for data in data_list: + for key, val in data.items(): + if isinstance(val, torch.Tensor): + tensors[key].append(val) + else: + non_tensors[key].append(val) + + for key, val in tensors.items(): + tensors[key] = torch.stack(val, dim=0) + + for key, val in non_tensors.items(): + non_tensors[key] = np.fromiter(val, dtype=object, count=len(val)) + + return {**tensors, **non_tensors} + +def create_rl_sampler(data_config, dataset, split: str = "train"): + """Create a sampler for the dataset, grouping strictly by existing modality_signature.""" + import torch + from torch.utils.data import RandomSampler, SequentialSampler + + mb_cfg = data_config.get("modality_batching") if split == "train" \ + else data_config.get("val_modality_batching") + + # (keep curriculum path if you actually use it; omitted here for brevity) + + if mb_cfg and mb_cfg.get("enabled", False): + by_sig: Dict[str, List[int]] = {} + for i in range(len(dataset)): + row = dataset.dataframe[i] if hasattr(dataset, "dataframe") else dataset[i] + sig = row.get("modality_signature") + if sig is None: + print(f"[WARNING] Row {i} missing 'modality_signature'. 
Skipping.") + continue + by_sig.setdefault(sig, []).append(i) + + batch_size = mb_cfg.get("batch_size", data_config.get( + "train_batch_size" if split=="train" else "val_batch_size" + )) + drop_last = mb_cfg.get("drop_last", split=="train") + shuffle = (split == "train") + + return ModalitySignatureBatchSampler( + indices_by_sig=by_sig, + batch_size=int(batch_size), + drop_last=drop_last, + seed=data_config.get("seed", 42), + shuffle=shuffle, + ) + + # Fallbacks + if data_config.get("shuffle", True) and split == "train": + g = torch.Generator(); g.manual_seed(data_config.get("seed", 1)) + return RandomSampler(data_source=dataset, generator=g) + else: + return SequentialSampler(data_source=dataset) + +class JsonlDataset(Dataset): + def __init__(self, jsonl_path: str, truncate_ratio: float = TRUNCATE_RATIO, seed: int = SEED): + """ + Loads ONLY entries that already have 'modality_signature'. + Optionally keeps a proportion per signature for fast debugging. + """ + all_rows: List[Dict[str, Any]] = [] + with open(jsonl_path, "r", encoding="utf-8") as f: + for ln in f: + ln = ln.strip() + if not ln: + continue + ex = json.loads(ln) + sig = ex.get("modality_signature") + if sig is None: + print(f"[WARNING] Entry missing 'modality_signature'. 
Skipping.") + continue # skip missing + all_rows.append(ex) + + # Group by signature and truncate per signature + sig_to_rows: Dict[str, List[Dict[str, Any]]] = defaultdict(list) + for ex in all_rows: + sig_to_rows[ex["modality_signature"]].append(ex) + + rng = random.Random(seed) + truncated_rows: List[Dict[str, Any]] = [] + for sig, rows in sig_to_rows.items(): + if truncate_ratio >= 1.0: + truncated_rows.extend(rows) + continue + keep_n = max(1, int(len(rows) * truncate_ratio)) + rng.shuffle(rows) + truncated_rows.extend(rows[:keep_n]) + + self.rows = truncated_rows + self.dataframe = self # preserve your API + + # simple stats + counts = {sig: sum(1 for r in self.rows if r["modality_signature"] == sig) for sig in sig_to_rows} + print(f"[DEBUG] After truncation (ratio={truncate_ratio}), total {len(self.rows)}. Per-signature: {counts}") + + def __len__(self): + return len(self.rows) + + def __getitem__(self, idx): + return self.rows[idx] + + +def assert_homogeneous(batch_list: List[Dict[str, Any]]): + sigs = {b.get("modality_signature") for b in batch_list} + if len(sigs) != 1: + raise AssertionError(f"Non-homogeneous batch signatures: {sigs}") + +def collate_with_guard(batch_list): + assert_homogeneous(batch_list) + return rl_collate_fn(batch_list) + +def build_cfg(train_bs: int, val_bs: int, seed: int = 42): + class Dot(dict): + __getattr__ = dict.get + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__ + return Dot({ + "train_batch_size": train_bs, + "val_batch_size": val_bs, + "shuffle": True, + "seed": seed, + "dataloader_num_workers": 0, + "validation_shuffle": False, + "sampler": None, + "modality_batching": {"enabled": True, "batch_size": train_bs, "drop_last": True}, + "val_modality_batching": {"enabled": True, "batch_size": val_bs, "drop_last": False}, + }) + +def build_loader(dataset, data_cfg, split: str): + sampler_or_batch = create_rl_sampler(data_cfg, dataset, split=split) + if isinstance(sampler_or_batch, BatchSampler): + return 
StatefulDataLoader( + dataset=dataset, + batch_sampler=sampler_or_batch, + num_workers=data_cfg["dataloader_num_workers"], + collate_fn=collate_with_guard, + ) + else: + bs = data_cfg.get("train_batch_size" if split == "train" else "val_batch_size") + return StatefulDataLoader( + dataset=dataset, + sampler=sampler_or_batch, + batch_size=bs, + num_workers=data_cfg["dataloader_num_workers"], + drop_last=(split == "train"), + shuffle=False if split == "val" else False, + collate_fn=collate_with_guard, + ) + +def main(): + ds = JsonlDataset(JSONL_PATH) + print(f"Dataset size: {len(ds)}; per-signature counts:", + {sig: sum(1 for r in ds.rows if r['modality_signature']==sig) + for sig in sorted({r['modality_signature'] for r in ds.rows})}) + + cfg = build_cfg(TRAIN_BS, VAL_BS, SEED) + + # TRAIN + train_loader = build_loader(ds, cfg, split="train") + print("\n[TRAIN] Iteration 1") + n_train_batches = sum(1 for _ in train_loader) # iterating as you would with the train loader + print(f"train steps: {n_train_batches} (drop_last=True)") + + # New epoch + train_loader2 = build_loader(ds, cfg, split="train") + n_train_batches2 = sum(1 for _ in train_loader2) + assert n_train_batches == n_train_batches2 + print("[TRAIN] Iteration 2: step count consistent") + + # VAL + val_loader = build_loader(ds, cfg, split="val") + print("\n[VAL] Iteration 1") + n_val_batches = sum(1 for _ in val_loader) + print(f"val steps: {n_val_batches} (drop_last=False)") + + # Stateful resume check (if supported) + if hasattr(train_loader, "state_dict"): + print("\n[STATEFUL] Testing resume mid-epoch") + train_loader3 = build_loader(ds, cfg, split="train") + it = iter(train_loader3) + next(it); next(it) # consume 2 + sd = train_loader3.state_dict() + train_loader4 = build_loader(ds, cfg, split="train") + train_loader4.load_state_dict(sd) + resumed = sum(1 for _ in train_loader4) + print(f"resumed batches after 2 consumed: {resumed}") + + print("\nOK: StatefulDataLoader + sampler test finished.") + +if 
__name__ == "__main__": + main() \ No newline at end of file diff --git a/dvd_requirements.txt b/dvd_requirements.txt new file mode 100644 index 00000000000..7210f325c6b --- /dev/null +++ b/dvd_requirements.txt @@ -0,0 +1,282 @@ +# torch==2.7.1+cu126 +absl-py==2.3.0 +accelerate==1.8.1 +aiohappyeyeballs==2.6.1 +aiohttp==3.12.13 +aiohttp-cors==0.8.1 +aiosignal==1.3.2 +airportsdata==20250622 +alembic==1.16.2 +aliyun-python-sdk-core==2.16.0 +aliyun-python-sdk-kms==2.16.5 +annotated-types==0.7.0 +antlr4-python3-runtime==4.9.3 +anyio==4.9.0 +astor==0.8.1 +attrs==25.3.0 +audioread==3.0.1 +av==14.4.0 +beautifulsoup4==4.13.4 +blake3==1.0.5 +cachetools==5.5.2 +cbor2==5.6.5 +certifi==2025.8.3 +cffi==1.17.1 +chardet==5.2.0 +charset-normalizer==3.4.2 +click==8.2.1 +cloudpickle==3.1.1 +codetiming==1.4.0 +color-matcher==0.6.0 +colorama==0.4.6 +colorful==0.5.6 +comfyui-embedded-docs==0.2.3 +comfyui_frontend_package==1.23.4 +comfyui_workflow_templates==0.1.32 +compressed-tensors==0.10.2 +contourpy==1.3.2 +crcmod==1.7 +cryptography==45.0.5 +cuda-bindings==12.9.0 +cuda-python==12.9.0 +cupy-cuda12x==13.4.1 +cycler==0.12.1 +datasets==4.0.0 +ddt==1.7.2 +decorator==5.2.1 +Deprecated==1.2.18 +depyf==0.19.0 +diffusers==0.34.0 +dill==0.3.8 +diskcache==5.6.3 +distlib==0.3.9 +distro==1.9.0 +dnspython==2.7.0 +docutils==0.21.2 +einops==0.8.1 +email_validator==2.2.0 +fastapi==0.115.14 +fastapi-cli==0.0.7 +fastrlock==0.8.3 +ffmpeg-python==0.2.0 +# filelock==3.19.1 +flash_attn==2.8.0.post2 +flatbuffers==25.2.10 +fonttools==4.58.4 +frozenlist==1.7.0 +fsspec==2025.3.0 +ftfy==6.3.1 +future==1.0.0 +gguf==0.17.1 +gitdb==4.0.12 +GitPython==3.1.44 +google-api-core==2.25.1 +google-auth==2.40.3 +googleapis-common-protos==1.70.0 +greenlet==3.2.3 +grpcio==1.73.1 +h11==0.16.0 +heavyball==1.7.2 +hf-xet==1.1.5 +httpcore==1.0.9 +httptools==0.6.4 +httpx==0.28.1 +huggingface-hub==0.34.1 +hydra-core==1.3.2 +idna==3.10 +imageio==2.37.0 +# importlib_metadata==8.7.0 +interegular==0.3.3 +jax==0.7.0 +jaxlib==0.7.0 
+Jinja2==3.1.6 +jiter==0.10.0 +jmespath==0.10.0 +joblib==1.5.1 +jsonschema==4.24.0 +jsonschema-specifications==2025.4.1 +kiwisolver==1.4.8 +kornia==0.8.1 +kornia_rs==0.1.9 +lark==1.2.2 +lazy_loader==0.4 +librosa==0.11.0 +liger_kernel==0.5.10 +llguidance==0.7.30 +llvmlite==0.44.0 +lm-format-enforcer==0.10.11 +Mako==1.3.10 +Markdown==3.8.2 +markdown-it-py==3.0.0 +MarkupSafe==3.0.2 +mathruler==0.1.0 +matplotlib==3.10.3 +matrix-client==0.4.0 +mdurl==0.1.2 +mediapipe==0.10.21 +mistral_common==1.8.3 +ml_dtypes==0.5.3 +model-index==0.1.11 +mpi4py==4.1.0 +mpmath==1.3.0 +msgpack==1.1.1 +msgspec==0.19.0 +multidict==6.6.0 +multiprocess==0.70.16 +natsort==8.4.0 +nest-asyncio==1.6.0 +networkx==3.5 +ninja==1.11.1.4 +numba==0.61.2 +numpy==1.26.4 +nvidia-cublas-cu12==12.6.4.1 +nvidia-cuda-cupti-cu12==12.6.80 +nvidia-cuda-nvrtc-cu12==12.6.77 +nvidia-cuda-runtime-cu12==12.6.77 +nvidia-cudnn-cu12==9.5.1.17 +nvidia-cufft-cu12==11.3.0.4 +nvidia-cufile-cu12==1.11.1.6 +nvidia-curand-cu12==10.3.7.77 +nvidia-cusolver-cu12==11.7.1.2 +nvidia-cusparse-cu12==12.5.4.2 +nvidia-cusparselt-cu12==0.6.3 +nvidia-ml-py==12.575.51 +nvidia-nccl-cu12==2.26.2 +nvidia-nvjitlink-cu12==12.6.85 +nvidia-nvshmem-cu12==3.3.9 +nvidia-nvtx-cu12==12.6.77 +olefile==0.47 +omegaconf==2.3.0 +openai==1.90.0 +opencensus==0.11.4 +opencensus-context==0.1.3 +opencv-contrib-python==4.11.0.86 +opencv-python-headless==4.11.0.86 +opendatalab==0.0.10 +openmim==0.3.9 +opentelemetry-api==1.26.0 +opentelemetry-exporter-otlp-proto-grpc==1.26.0 +opentelemetry-exporter-otlp-proto-http==1.26.0 +opentelemetry-exporter-prometheus==0.41b0 +opentelemetry-proto==1.26.0 +opentelemetry-sdk==1.26.0 +# opentelemetry-semantic-conventions==0.55b1 +openxlab==0.1.2 +opt_einsum==3.4.0 +ordered-set==4.1.0 +orjson==3.10.18 +oss2==2.17.0 +outlines==0.1.11 +outlines_core==0.2.10 +packaging==24.2 +pandas==2.3.0 +partial-json-parser==0.2.1.1.post6 +peft==0.15.2 +piexif==1.1.3 +pillow==11.2.1 +pip==25.1 +platformdirs==4.3.8 +pooch==1.8.2 
+prometheus_client==0.22.1 +prometheus-fastapi-instrumentator==7.1.0 +propcache==0.3.2 +proto-plus==1.26.1 +protobuf==4.25.8 +psutil==7.0.0 +py-cpuinfo==9.0.0 +py-spy==0.4.0 +pyarrow==20.0.0 +pyasn1==0.6.1 +pyasn1_modules==0.4.2 +pybase64==1.4.1 +pycountry==24.6.1 +pycparser==2.22 +pycryptodome==3.23.0 +pydantic==2.11.7 +pydantic_core==2.33.2 +pydantic-extra-types==2.10.5 +pydantic-settings==2.10.1 +PyGithub==2.6.1 +Pygments==2.19.2 +PyJWT==2.10.1 +pylatexenc==2.10 +pyloudnorm==0.1.1 +PyNaCl==1.5.0 +pynvml==12.0.0 +pyparsing==3.2.3 +python-dateutil==2.9.0.post0 +python-dotenv==1.1.1 +python-json-logger==4.0.0.dev0 +python-multipart==0.0.20 +pytz==2023.4 +PyYAML==6.0.2 +pyzmq==27.0.0 +qwen-vl-utils==0.0.11 +ray==2.47.1 +referencing==0.36.2 +regex==2024.11.6 +# requests==2.28.2 +rich==14.1.0 +rich-toolkit==0.14.7 +rpds-py==0.25.1 +rsa==4.9.1 +safetensors==0.6.0rc0 +# sageattention==2.2.0 +scikit-image==0.25.2 +scikit-learn==1.7.0 +scipy==1.16.0 +sentencepiece==0.2.0 +sentry-sdk==2.32.0 +setproctitle==1.3.6 +setuptools==79.0.1 +shellingham==1.5.4 +six==1.17.0 +smart-open==7.1.0 +smmap==5.0.2 +sniffio==1.3.1 +sounddevice==0.5.2 +soundfile==0.13.1 +soupsieve==2.7 +soxr==0.5.0.post1 +spandrel==0.4.1 +SQLAlchemy==2.0.41 +starlette==0.46.2 +sympy==1.14.0 +tabulate==0.9.0 +tensorboard==2.19.0 +tensorboard-data-server==0.7.2 +tensordict==0.8.3 +threadpoolctl==3.6.0 +tifffile==2025.6.11 +tiktoken==0.9.0 +timm==1.0.16 +tokenizers==0.21.2 +toml==0.10.2 +# torchaudio==2.7.1 +# torchcodec==0.4.0+cu126 +torchdata==0.11.0 +torchsde==0.2.6 +# torchvision==0.22.1 +# tqdm==4.65.2 +trampoline==0.1.2 +transformers==4.54.0 +triton==3.3.1 +typer==0.16.0 +typing_extensions==4.14.0 +typing-inspection==0.4.1 +tzdata==2025.2 +ujson==5.10.0 +urllib3==1.26.20 +uv==0.7.19 +uvicorn==0.34.3 +uvloop==0.21.0 +virtualenv==20.31.2 +vllm==0.8.4 +wandb==0.20.1 +watchdog==6.0.0 +watchfiles==1.1.0 +wcwidth==0.2.13 +websockets==15.0.1 +Werkzeug==3.1.3 +wfdb==4.3.0 +wheel==0.45.1 \ No newline at end of file 
diff --git a/examples/grpo_trainer/_debug_keane_run_qwen2_5_omni-7b_hb_all_modalities.sh b/examples/grpo_trainer/_debug_keane_run_qwen2_5_omni-7b_hb_all_modalities.sh
new file mode 100755
index 00000000000..3522b2fea39
--- /dev/null
+++ b/examples/grpo_trainer/_debug_keane_run_qwen2_5_omni-7b_hb_all_modalities.sh
@@ -0,0 +1,121 @@
set -x

unset ROCR_VISIBLE_DEVICES

# actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct
# actor_rollout_ref.model.path=Qwen/Qwen2.5-Omni-7B
# data.train_files=/scratch/keane/human_behaviour/human_behaviour_data/train_no_meld_no_chalearn_vision_v2_template_prompts.jsonl \
# data.val_files=/scratch/keane/human_behaviour/human_behaviour_data/val_no_meld_no_chalearn_vision_v2_template_prompts.jsonl \
# data.modalities=\'audio,videos\' \

# SETTING OF SAVE PATH: trainer.default_local_dir=/scratch/keane/human_behaviour/2_models_hb_vision_only
# SETTING OF THE LOAD PATH from directory of checkpoints is also: trainer.default_local_dir

# TRAINING FROM scratch: trainer.resume_mode == "disable" (default will save into default_local_dir)

# TRAINING AUTOMATICALLY (i.e. from scratch or from latest checkpoint):
#   trainer.resume_mode == "auto" and then the model will take the latest ckpt from trainer.default_hdfs_dir

# TRAINING from specific CHECKPOINT: trainer.resume_mode == "resume_path" and then specify trainer.resume_from_path
#   Setting of path to resume training from trainer.resume_from_path (exact path of checkpoint)
#   the model will take from resume_from_path directly (absolute path), and ignore default_hdfs_dir

# for validation, set val_before_train=True ; make sure that the checkpoint is loaded and put val_only=True
# the checkpoint should already be loaded before that
# and then we will just evaluate

# ALTERNATIVES
# /scratch/keane/human_behaviour/human_behaviour_data/discretized_no_lmvd_no_chsimsv2_v3_template_prompts.jsonl
# /scratch/keane/human_behaviour/human_behaviour_data/discretized_no_lmvd_no_chsimsv2_no_chalearn_v3_template_prompts.jsonl

# when resuming training from a loaded checkpoint cuda OOM error

# alt: /scratch/keane/human_behaviour/human_behaviour_data/0.1_train_no_lmvd_discretized_v3_template_prompts.jsonl
# org: /scratch/keane/human_behaviour/human_behaviour_data/train_no_lmvd_discretized_v3_template_prompts.jsonl

# LORA:
#   actor_rollout_ref.model.use_shm=True \
#   actor_rollout_ref.model.lora_rank=32 \
#   actor_rollout_ref.model.lora_alpha=32 \
#   actor_rollout_ref.rollout.load_format=safetensors \
#   actor_rollout_ref.model.target_modules=all-linear \
#   actor_rollout_ref.rollout.layered_summon=True \

# Set PyTorch CUDA memory allocator policies
# export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True,max_split_size_mb=128

# data:
#   train_batch_size: 8
#   val_batch_size: 8

# NOTE: THESE NEED TO BE TOGGLED AS TRUE
# train_modality_batching:
#   enabled: true
#   drop_last: true

# val_modality_batching:
#   enabled: true
#   drop_last: false


PYTHONUNBUFFERED=1 HYDRA_FULL_ERROR=1 PYTHONPATH="/home/keaneong/human-behavior/verl:$PYTHONPATH" NCCL_ASYNC_ERROR_HANDLING=1 python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=grpo \
    data.train_files=/scratch/keane/human_behaviour/human_behaviour_data/sigs_train_no_lmvd_discretized_v3_template_prompts.jsonl \
    data.val_files=/scratch/keane/human_behaviour/human_behaviour_data/sigs_val_no_lmvd_discretized_v3_template_prompts.jsonl \
    data.train_batch_size=6 \
    data.val_batch_size=6 \
    data.max_prompt_length=4096 \
    data.max_response_length=4096 \
    data.filter_overlong_prompts=False \
    data.truncation='left' \
    data.image_key=images \
    data.video_key=videos \
    data.prompt_key=problem \
    data.dataloader_num_workers=8 \
    data.modalities=\'audio,videos\' \
    data.train_modality_batching.enabled=True \
    data.train_modality_batching.drop_last=True \
    data.val_modality_batching.enabled=True \
    data.val_modality_batching.drop_last=True \
    data.format_prompt=/home/keaneong/human-behavior/verl/examples/format_prompt/default.jinja \
    actor_rollout_ref.model.path=Qwen/Qwen2.5-Omni-7B \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=False \
    actor_rollout_ref.actor.ppo_mini_batch_size=3 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
    actor_rollout_ref.actor.use_kl_loss=False \
    actor_rollout_ref.actor.kl_loss_coef=0 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
    actor_rollout_ref.rollout.enable_chunked_prefill=False \
    actor_rollout_ref.rollout.enforce_eager=False \
    actor_rollout_ref.rollout.free_cache_engine=True \
    actor_rollout_ref.rollout.n=5 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    actor_rollout_ref.rollout.max_model_len=4096 \
    actor_rollout_ref.rollout.max_num_batched_tokens=4096 \
    algorithm.use_kl_in_reward=False \
    custom_reward_function.path=/home/keaneong/human-behavior/verl/examples/reward_function/human_behaviour.py \
    custom_reward_function.name=human_behaviour_compute_score_batch \
    reward_model.reward_manager=batch \
    trainer.critic_warmup=0 \
    trainer.logger='["console","wandb"]' \
    trainer.project_name='verl_hb' \
    trainer.experiment_name='mixed_modal_omni' \
    trainer.n_gpus_per_node=3 \
    trainer.nnodes=1 \
    trainer.save_freq=-1 \
    trainer.val_before_train=False \
    trainer.test_freq=10 \
    trainer.total_epochs=1 "$@" \
    trainer.default_local_dir=/scratch/keane/human_behaviour/mixed_modal_verl_models_hb_omni
\ No newline at end of file
diff --git a/examples/grpo_trainer/_human_behaviour.sh b/examples/grpo_trainer/_human_behaviour.sh
new file mode 100755
index 00000000000..512a4e845b1
--- /dev/null
+++ b/examples/grpo_trainer/_human_behaviour.sh
@@ -0,0 +1,55 @@
set -x
ENGINE=${1:-vllm}

PYTHONUNBUFFERED=1 HYDRA_FULL_ERROR=1 python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=grpo \
    data.train_files=/scratch/keane/human_behaviour/human_behaviour_data/train_no_meld_no_chalearn_vision_v2_template_prompts.jsonl \
    data.val_files=/scratch/keane/human_behaviour/human_behaviour_data/val_no_meld_no_chalearn_vision_v2_template_prompts.jsonl \
    data.train_batch_size=512 \
    data.val_batch_size=128 \
    data.max_prompt_length=4096 \
    data.max_response_length=4096 \
    data.filter_overlong_prompts=False \
    data.truncation='left' \
    data.image_key=images \
    data.video_key=videos \
    data.prompt_key=problem \
    data.format_prompt=/home/keaneong/human-behavior/verl/examples/format_prompt/default.jinja \
    actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=False \
    actor_rollout_ref.actor.ppo_mini_batch_size=128 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=1e-8 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
    actor_rollout_ref.rollout.name=$ENGINE \
    actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
    actor_rollout_ref.rollout.enable_chunked_prefill=False \
    actor_rollout_ref.rollout.enforce_eager=False \
    actor_rollout_ref.rollout.free_cache_engine=True \
    actor_rollout_ref.rollout.n=5 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.use_kl_in_reward=False \
    custom_reward_function.path=/home/keaneong/human-behavior/verl/examples/reward_function/medical.py \
    custom_reward_function.name=medical_compute_score_batch \
    reward_model.reward_manager=batch \
    trainer.critic_warmup=0 \
    trainer.logger='["console","wandb"]' \
    trainer.project_name='verl_human_behaviour' \
    trainer.experiment_name='qwen2_5_vl_7b_function_trial' \
    trainer.n_gpus_per_node=2 \
    trainer.nnodes=1 \
    trainer.save_freq=20 \
    trainer.val_before_train=False \
    trainer.test_freq=5 \
    trainer.total_epochs=15 "$@"
\ No newline at end of file
diff --git a/examples/grpo_trainer/_keane_run_qwen2_5_omni-7b_hb_all_modalities.sh b/examples/grpo_trainer/_keane_run_qwen2_5_omni-7b_hb_all_modalities.sh
new file mode 100755
index 00000000000..9bebaf58293
--- /dev/null
+++ b/examples/grpo_trainer/_keane_run_qwen2_5_omni-7b_hb_all_modalities.sh
@@ -0,0 +1,120 @@
set -x

unset ROCR_VISIBLE_DEVICES

# actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct
# actor_rollout_ref.model.path=Qwen/Qwen2.5-Omni-7B
# data.train_files=/scratch/keane/human_behaviour/human_behaviour_data/train_no_meld_no_chalearn_vision_v2_template_prompts.jsonl \
# data.val_files=/scratch/keane/human_behaviour/human_behaviour_data/val_no_meld_no_chalearn_vision_v2_template_prompts.jsonl \
# data.modalities=\'audio,videos\' \

# SETTING OF SAVE PATH: trainer.default_local_dir=/scratch/keane/human_behaviour/2_models_hb_vision_only
# SETTING OF THE LOAD PATH from directory of checkpoints is also: trainer.default_local_dir

# TRAINING FROM scratch: trainer.resume_mode == "disable" (default will save into default_local_dir)

# TRAINING AUTOMATICALLY (i.e. from scratch or from latest checkpoint):
#   trainer.resume_mode == "auto" and then the model will take the latest ckpt from trainer.default_hdfs_dir

# TRAINING from specific CHECKPOINT: trainer.resume_mode == "resume_path" and then specify trainer.resume_from_path
#   Setting of path to resume training from trainer.resume_from_path (exact path of checkpoint)
#   the model will take from resume_from_path directly (absolute path), and ignore default_hdfs_dir

# for validation, set val_before_train=True ; make sure that the checkpoint is loaded and put val_only=True
# the checkpoint should already be loaded before that
# and then we will just evaluate

# ALTERNATIVES
# /scratch/keane/human_behaviour/human_behaviour_data/discretized_no_lmvd_no_chsimsv2_v3_template_prompts.jsonl
# /scratch/keane/human_behaviour/human_behaviour_data/discretized_no_lmvd_no_chsimsv2_no_chalearn_v3_template_prompts.jsonl

# when resuming training from a loaded checkpoint cuda OOM error

# alt: /scratch/keane/human_behaviour/human_behaviour_data/0.1_train_no_lmvd_discretized_v3_template_prompts.jsonl
# org: /scratch/keane/human_behaviour/human_behaviour_data/train_no_lmvd_discretized_v3_template_prompts.jsonl

# LORA:
#   actor_rollout_ref.model.use_shm=True \
#   actor_rollout_ref.model.lora_rank=32 \
#   actor_rollout_ref.model.lora_alpha=32 \
#   actor_rollout_ref.rollout.load_format=safetensors \
#   actor_rollout_ref.model.target_modules=all-linear \
#   actor_rollout_ref.rollout.layered_summon=True \

# Set PyTorch CUDA memory allocator policies
# export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True,max_split_size_mb=128

# data:
#   train_batch_size: 8
#   val_batch_size: 8

# train_modality_batching:
#   enabled: true
#   drop_last: true

# val_modality_batching:
#   enabled: true
#   drop_last: false


PYTHONUNBUFFERED=1 HYDRA_FULL_ERROR=1 PYTHONPATH="/home/keaneong/human-behavior/verl:$PYTHONPATH" NCCL_ASYNC_ERROR_HANDLING=1 python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=grpo \
    data.train_files=/scratch/keane/human_behaviour/human_behaviour_data/sigs_train_no_lmvd_discretized_v3_template_prompts.jsonl \
    data.val_files=/scratch/keane/human_behaviour/human_behaviour_data/sigs_val_no_lmvd_discretized_v3_template_prompts.jsonl \
    data.train_batch_size=288 \
    data.val_batch_size=144 \
    data.max_prompt_length=4096 \
    data.max_response_length=4096 \
    data.filter_overlong_prompts=False \
    data.truncation='left' \
    data.image_key=images \
    data.video_key=videos \
    data.prompt_key=problem \
    data.dataloader_num_workers=8 \
    data.modalities=\'audio,videos\' \
    data.train_modality_batching.enabled=True \
    data.train_modality_batching.drop_last=True \
    data.val_modality_batching.enabled=True \
    data.val_modality_batching.drop_last=True \
    data.format_prompt=/home/keaneong/human-behavior/verl/examples/format_prompt/default.jinja \
    actor_rollout_ref.model.path=Qwen/Qwen2.5-Omni-7B \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=False \
    actor_rollout_ref.actor.ppo_mini_batch_size=72 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \
    actor_rollout_ref.actor.use_kl_loss=False \
    actor_rollout_ref.actor.kl_loss_coef=0 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
    actor_rollout_ref.rollout.enable_chunked_prefill=False \
actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.rollout.free_cache_engine=True \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + actor_rollout_ref.rollout.max_model_len=4096 \ + actor_rollout_ref.rollout.max_num_batched_tokens=4096 \ + algorithm.use_kl_in_reward=False \ + custom_reward_function.path=/home/keaneong/human-behavior/verl/examples/reward_function/human_behaviour.py \ + custom_reward_function.name=human_behaviour_compute_score_batch \ + reward_model.reward_manager=batch \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='mixed_modal_verl_hb' \ + trainer.experiment_name='mixed_modal_omni' \ + trainer.n_gpus_per_node=3 \ + trainer.nnodes=1 \ + trainer.save_freq=10 \ + trainer.val_before_train=False \ + trainer.test_freq=5 \ + trainer.total_epochs=1 $@ \ + trainer.default_local_dir=/scratch/keane/human_behaviour/mixed_modal_verl_models_hb_omni \ No newline at end of file diff --git a/examples/grpo_trainer/keane_vl_only_run_qwen2_5_vl-7b_hb_all_modalities.sh b/examples/grpo_trainer/keane_vl_only_run_qwen2_5_vl-7b_hb_all_modalities.sh new file mode 100755 index 00000000000..13cf20b279d --- /dev/null +++ b/examples/grpo_trainer/keane_vl_only_run_qwen2_5_vl-7b_hb_all_modalities.sh @@ -0,0 +1,66 @@ +set -x + +unset ROCR_VISIBLE_DEVICES + +# NOTE: be careful when setting filter_overlong_prompts; because this removes the prompts from the max_prompt_length + +# actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct +# actor_rollout_ref.model.path=Qwen/Qwen2.5-Omni-7B +# data.train_files=/scratch/keane/human_behaviour/human_behaviour_data/train_no_meld_no_chalearn_vision_v2_template_prompts.jsonl \ +# data.val_files=/scratch/keane/human_behaviour/human_behaviour_data/val_no_meld_no_chalearn_vision_v2_template_prompts.jsonl \ +# data.modalities=\'audio,videos\' \ + +PYTHONUNBUFFERED=1 
HYDRA_FULL_ERROR=1 PYTHONPATH="/home/keaneong/human-behavior/verl:$PYTHONPATH" NCCL_ASYNC_ERROR_HANDLING=1 python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=/scratch/keane/human_behaviour/human_behaviour_data/old_train_template_prompts.jsonl \ + data.val_files=/scratch/keane/human_behaviour/human_behaviour_data/old_val_template_prompts.jsonl \ + data.train_batch_size=64 \ + data.val_batch_size=64 \ + data.max_prompt_length=3072 \ + data.max_response_length=1536 \ + data.filter_overlong_prompts=False \ + data.truncation='left' \ + data.image_key=images \ + data.video_key=videos \ + data.prompt_key=problem \ + data.dataloader_num_workers=0 \ + data.modalities=\'videos\' \ + data.format_prompt=/home/keaneong/human-behavior/verl/examples/format_prompt/default.jinja \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=False \ + actor_rollout_ref.actor.ppo_mini_batch_size=64 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.actor.use_kl_loss=False \ + actor_rollout_ref.actor.kl_loss_coef=1e-8 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.rollout.free_cache_engine=True \ + 
actor_rollout_ref.rollout.n=3 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + custom_reward_function.path=/home/keaneong/human-behavior/verl/examples/reward_function/medical.py \ + custom_reward_function.name=medical_compute_score_batch \ + reward_model.reward_manager=batch \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_hb' \ + trainer.experiment_name='vision_only' \ + trainer.n_gpus_per_node=3 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.val_before_train=False \ + trainer.test_freq=1 \ + trainer.total_epochs=15 $@ \ No newline at end of file diff --git a/examples/grpo_trainer/run_deepseek671b_math_megatron_96gb.sh b/examples/grpo_trainer/run_deepseek671b_math_megatron_96gb.sh old mode 100644 new mode 100755 diff --git a/examples/grpo_trainer/run_qwen2_5_vl-7b_hb.sh b/examples/grpo_trainer/run_qwen2_5_vl-7b_hb.sh new file mode 100755 index 00000000000..3f1f0796af4 --- /dev/null +++ b/examples/grpo_trainer/run_qwen2_5_vl-7b_hb.sh @@ -0,0 +1,53 @@ +set -x + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=/scratch/keane/human_behaviour/human_behaviour_data/old_train_template_prompts.jsonl \ + data.val_files=/scratch/keane/human_behaviour/human_behaviour_data/old_val_template_prompts.jsonl \ + data.train_batch_size=512 \ + data.max_prompt_length=4096 \ + data.max_response_length=4096 \ + data.filter_overlong_prompts=False \ + data.truncation='left' \ + data.image_key=images \ + data.video_key=videos \ + data.prompt_key=problem \ + data.format_prompt=examples/format_prompt/default.jinja \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=False \ + actor_rollout_ref.actor.ppo_mini_batch_size=128 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + 
actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=1e-8 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.rollout.free_cache_engine=True \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + custom_reward_function.path=examples/reward_function/medical.py \ + custom_reward_function.name=medical_compute_score_batch \ + reward_model.reward_manager=batch \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_hb' \ + trainer.experiment_name='vision_only' \ + trainer.n_gpus_per_node=2 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.val_before_train=False \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ diff --git a/examples/grpo_trainer/run_qwen2_5_vl-7b_hb_all_modalities.sh b/examples/grpo_trainer/run_qwen2_5_vl-7b_hb_all_modalities.sh new file mode 100755 index 00000000000..fd2d60ba553 --- /dev/null +++ b/examples/grpo_trainer/run_qwen2_5_vl-7b_hb_all_modalities.sh @@ -0,0 +1,56 @@ +set -x + +PYTHONUNBUFFERED=1 HYDRA_FULL_ERROR=1 python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + 
data.train_files=/scratch/keane/human_behaviour/human_behaviour_data/old_train_template_prompts.jsonl \ + data.val_files=/scratch/keane/human_behaviour/human_behaviour_data/old_val_template_prompts.jsonl \ + data.train_batch_size=128 \ + data.val_batch_size=128 \ + data.max_prompt_length=4096 \ + data.max_response_length=4096 \ + data.filter_overlong_prompts=False \ + data.truncation='left' \ + data.image_key=images \ + data.video_key=videos \ + data.prompt_key=problem \ + data.dataloader_num_workers=0 \ + data.modalities=\'audio,videos\' \ + data.format_prompt=examples/format_prompt/default.jinja \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-Omni-7B \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=False \ + actor_rollout_ref.actor.ppo_mini_batch_size=128 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.actor.use_kl_loss=False \ + actor_rollout_ref.actor.kl_loss_coef=1e-8 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.rollout.free_cache_engine=True \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + 
custom_reward_function.path=examples/reward_function/medical.py \ + custom_reward_function.name=medical_compute_score_batch \ + reward_model.reward_manager=batch \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_hb' \ + trainer.experiment_name='vision_only' \ + trainer.n_gpus_per_node=2 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.val_before_train=False \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ diff --git a/examples/reward_function/evaluation.py b/examples/reward_function/evaluation.py index 3d20e757a8a..716efaa2e98 100644 --- a/examples/reward_function/evaluation.py +++ b/examples/reward_function/evaluation.py @@ -253,9 +253,6 @@ def parent(predictions: List[str], ground_truths: List[str], demographics: List[ print(f"std of accuracy for parent = {std_acc:.4f}") if len(f1_values) >= 2: - f1_diff = max(f1_values) - min(f1_values) - results["f1_diff"] = f1_diff - print(f"F1 max diff for parent = {f1_diff:.4f}") f1_std = statistics.stdev(f1_values) results["f1_std"] = f1_std print(f"std of f1 for parent = {f1_std:.4f}") diff --git a/examples/reward_function/human_behaviour.py b/examples/reward_function/human_behaviour.py new file mode 100644 index 00000000000..9221dbf7eda --- /dev/null +++ b/examples/reward_function/human_behaviour.py @@ -0,0 +1,103 @@ +from typing import List, Dict +import re + +def extract_boxed_content(text: str) -> str: + """ + Extract content within \boxed{} or similar boxing notations. + + Args: + text (str): Text containing potentially boxed content. + + Returns: + str: Extracted boxed content or the original text if no box found. 
+ """ + + # Look for LaTeX \boxed{} notation + boxed_match = re.search(r"\\boxed{([^}]*)}", text) + if boxed_match: + return boxed_match.group(1) + + # Look for markdown boxed notation (e.g., [boxed content]) + markdown_match = re.search(r"\[(.*?)\]", text) + if markdown_match: + return markdown_match.group(1) + + # Return the text as is if no boxed content is found + return text + +def format_reward(response: str) -> float: + """ + Check whether the response matches the expected format. + Here we require something like ... ... \boxed{...} + """ + pattern = re.compile(r".*.*\\boxed\{.*\}.*", re.DOTALL) + format_match = re.fullmatch(pattern, response) + return 1.0 if format_match else 0.0 + +def accuracy_reward(response: str, ground_truth: str) -> float: + """ + Simple accuracy: exact match to ground truth string. + """ + return 1.0 if response == ground_truth else 0.0 + +def human_behaviour_compute_score_batch( + data_sources: List[str], + solution_strs: List[str], + ground_truths: List[str], + extra_infos: List[str], + **kwargs +) -> List[Dict[str, float]]: + """ + Compute human behaviour scoring for batch inputs. 
+ + Args: + data_sources: List of data sources (unused here, but kept for interface compatibility) + solution_strs: List of model prediction strings + ground_truths: List of ground truth strings + extra_infos: List of extra information (unused here, kept for compatibility) + + Returns: + List of score dictionaries + """ + batch_scores = [] + format_weight = 0.1 + + for data_source, predict_str, ground_truth, extra_info in zip(data_sources, solution_strs, ground_truths, extra_infos): + # Normalize response formatting (e.g., qwen2.5vl quirks) + full_response = re.sub(r"\s*(<|>|/)\s*", r"\1", predict_str) + pred_label = extract_boxed_content(full_response).lower() # handle qwen2.5vl-32b format + + print(pred_label) + # Compute individual components + format_score = format_reward(full_response) + standard_score = accuracy_reward(pred_label, ground_truth) + + ground_truth = ground_truth.lower() + + # Weighted overall score + overall_score = (1 - format_weight) * standard_score + format_weight * format_score + + scores = { + "score": overall_score, + "standard_score": standard_score, + "format_score": format_score, + } + batch_scores.append(scores) + + return batch_scores + + +if __name__ == "__main__": + response_str = ( + "Well, I've listened to the speech recording. It sounds like the speaker is expressing anger. " + "You know, the tone and the way the words are said seem to indicate frustration or annoyance. " + "So, I'd say the emotion is anger.\\boxed{anger}If you have any other questions or need more help, feel free to let me know." 
+ ) + + data_sources = ["sample_audio.wav"] + solution_strs = [response_str] + ground_truths = ["anger"] + extra_infos = [""] + + scores = human_behaviour_compute_score_batch(data_sources, solution_strs, ground_truths, extra_infos) + print(scores) diff --git a/examples/reward_function/human_behaviour_alt.py b/examples/reward_function/human_behaviour_alt.py new file mode 100644 index 00000000000..765dbeb7b16 --- /dev/null +++ b/examples/reward_function/human_behaviour_alt.py @@ -0,0 +1,140 @@ +import re +import json +from typing import Dict, List + +import numpy +import torch +import numpy as np +from mathruler.grader import extract_boxed_content +import wandb +import random + + +def parse_conditions(text): + # Remove any boxing notation if present + text = text.replace("\\boxed{", "").replace("}", "") + + # Split by common separators + for sep in [", ", " and ", " & ", ",", "&"]: + if sep in text: + return set(cond.strip() for cond in text.split(sep)) + + # If no separator found, treat as single condition + return {text.strip()} + + +def parse_json(json_output): + """ + Parsing out the markdown fencing from JSON code blocks. + """ + # Look for content between ```json and ``` + lines = json_output.splitlines() + for i, line in enumerate(lines): + if line == "```json" or line.strip() == "```": + json_output = "\n".join(lines[i + 1:]) # Remove everything before ```json + if "```" in json_output: + json_output = json_output.split("```")[0] # Remove everything after the closing ``` + break # Exit the loop once code block marker is found + return json_output + + +def extract_json_from_response(text): + """ + Extract JSON content from markdown code blocks in the response. 
+ + Args: + text: The model's response text + + Returns: + Parsed JSON object or None if no valid JSON found + """ + # Find content between ```json and ``` + json_pattern = r"```(?:json)?\s*([\s\S]*?)```" + matches = re.findall(json_pattern, text) + + if not matches: + return None + + # Try to parse each match as JSON + for match in matches: + try: + parsed_json = json.loads(match.strip()) + return parsed_json + except json.JSONDecodeError: + continue + + # If we couldn't parse any match as valid JSON, try with ast.literal_eval + import ast + for match in matches: + try: + # Clean up the match a bit + cleaned = match.strip().replace("'", "\"") + parsed_json = ast.literal_eval(cleaned) + return parsed_json + except: + continue + + return None + + +def medical_compute_score_batch(data_sources: List[str], solution_strs: List[str], ground_truths: List[str], extra_infos: List[str], **kwargs) -> List[Dict[str, float]]: + """ + Compute medical scoring for batch inputs including standard score, bounding box IoU, and format score. 
+ + Args: + data_sources: List of data sources (e.g., file paths or identifiers) + solution_strs: List of model prediction strings + ground_truths: List of ground truth strings + extra_infos: List of extra information (e.g., segmentation masks, bounding boxes) + + Returns: + List of score dictionaries + """ + batch_scores = [] + + for data_source, predict_str, ground_truth, extra_info in zip(data_sources, solution_strs, ground_truths, extra_infos): + segmentation_mask = None + bbox = None + + # Calculate standard score + answer = extract_boxed_content(predict_str) + if answer == "None": + standard_score = 0.0 # no answer + else: + # Parse both prediction and ground truth into sets of conditions + predicted_conditions = parse_conditions(answer) + ground_truth_conditions = parse_conditions(ground_truth) + + # Calculate true positives, false positives, and false negatives + true_positives = len(predicted_conditions.intersection(ground_truth_conditions)) + false_positives = len(predicted_conditions - ground_truth_conditions) + false_negatives = len(ground_truth_conditions - predicted_conditions) + + # Calculate F1 score components + precision = ( + true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0 + ) + recall = true_positives / (true_positives + false_negatives) if ( + true_positives + false_negatives) > 0 else 0 + + # Calculate F1 score (harmonic mean of precision and recall) + standard_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0 + + # Calculate format score (how well the JSON follows the expected format) + format_score = evaluate_bbox_format(predict_str) + + # length score + if len(predict_str) > 600: # ~200 words + length_score = 1 + else: + length_score = len(predict_str) * 0.001 + + scores = { + "score": 0.5 * standard_score + 0.3 * iou_score + 0.1 * format_score, + "standard_score": standard_score, + "format_score": format_score, + "length_score": length_score, + } 
+ batch_scores.append(scores) + + return batch_scores \ No newline at end of file diff --git a/examples/reward_function/math.py b/examples/reward_function/math.py new file mode 100644 index 00000000000..f39b32a5699 --- /dev/null +++ b/examples/reward_function/math.py @@ -0,0 +1,49 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import Any + +from mathruler.grader import extract_boxed_content, grade_answer + + +def format_reward(response: str) -> float: + pattern = re.compile(r".*.*\\boxed\{.*\}.*", re.DOTALL) + format_match = re.fullmatch(pattern, response) + return 1.0 if format_match else 0.0 + + +def accuracy_reward(response: str, ground_truth: str) -> float: + answer = extract_boxed_content(response) + return 1.0 if grade_answer(answer, ground_truth) else 0.0 + + +def compute_score(reward_inputs: list[dict[str, Any]], format_weight: float = 0.1) -> list[dict[str, float]]: + if not isinstance(reward_inputs, list): + raise ValueError("Please use `reward_type=batch` for math reward function.") + + scores = [] + for reward_input in reward_inputs: + response = re.sub(r"\s*(<|>|/)\s*", r"\1", reward_input["response"]) # handle qwen2.5vl-32b format + format_score = format_reward(response) + accuracy_score = accuracy_reward(response, reward_input["ground_truth"]) + scores.append( + { + "overall": (1 - format_weight) * accuracy_score + format_weight * format_score, + "format": format_score, 
+ "accuracy": accuracy_score, + } + ) + + return scores \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 8ee2c46ff0f..ab17283f24d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ accelerate codetiming datasets dill -flash-attn +# flash-attn hydra-core liger-kernel numpy @@ -17,6 +17,7 @@ ray[default] tensordict torchdata transformers +vllm # vllm==0.8.4 wandb packaging diff --git a/verl/model_merger/base_model_merger.py b/verl/model_merger/base_model_merger.py index b46f40f879b..185adbd7e74 100644 --- a/verl/model_merger/base_model_merger.py +++ b/verl/model_merger/base_model_merger.py @@ -188,7 +188,7 @@ def __init__(self, config: ModelMergerConfig): self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code ) - def get_transformers_auto_model_class(self): + def get_transformers_auto_model_class(self): has_remote_code = hasattr(self.model_config, "auto_map") and any( self.model_config.architectures[0] in val for val in self.model_config.auto_map.values() ) @@ -208,6 +208,9 @@ def get_transformers_auto_model_class(self): else: if "ForTokenClassification" in self.model_config.architectures[0]: return AutoModelForTokenClassification + elif "Qwen2.5-Omni" in self.hf_model_config_path: + from transformers import Qwen2_5OmniThinkerForConditionalGeneration + return Qwen2_5OmniThinkerForConditionalGeneration elif "ForCausalLM" in self.model_config.architectures[0]: return AutoModelForCausalLM elif "ForConditionalGeneration" in self.model_config.architectures[0]: diff --git a/verl/models/transformers/qwen2_vl.py b/verl/models/transformers/qwen2_vl.py index 0cc7820d114..e9c73a87f01 100644 --- a/verl/models/transformers/qwen2_vl.py +++ b/verl/models/transformers/qwen2_vl.py @@ -73,9 +73,20 @@ def get_rope_index( """ spatial_merge_size = processor.image_processor.merge_size tokens_per_second = 2 + + # Try old token names first, then fall back to new ones image_token_id = 
processor.tokenizer.convert_tokens_to_ids("<|image_pad|>") + if image_token_id is None: + image_token_id = processor.tokenizer.convert_tokens_to_ids("<|vision_pad|>") # New tokenizer uses vision_pad for images + video_token_id = processor.tokenizer.convert_tokens_to_ids("<|video_pad|>") + if video_token_id is None: + video_token_id = processor.tokenizer.convert_tokens_to_ids("<|vision_pad|>") # New tokenizer uses vision_pad for videos + vision_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|vision_start|>") + if vision_start_token_id is None: + vision_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|vision_bos|>") # New tokenizer uses vision_bos + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): if attention_mask is None: attention_mask = torch.ones_like(input_ids) @@ -84,10 +95,20 @@ def get_rope_index( image_index, video_index = 0, 0 input_ids = input_ids[attention_mask == 1] image_nums, video_nums = 0, 0 - vision_start_indices = torch.argwhere(input_ids == vision_start_token_id) - vision_tokens = input_ids[vision_start_indices + 1] - image_nums = (vision_tokens == image_token_id).sum() - video_nums = (vision_tokens == video_token_id).sum() + + # Ensure vision_start_token_id is valid before comparison + if vision_start_token_id is not None: + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id) + else: + vision_start_indices = torch.empty((0, 1), dtype=torch.long, device=input_ids.device) + + # Handle case where there are vision tokens + if vision_start_indices.numel() > 0: + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum().item() if image_token_id is not None else 0 + video_nums = (vision_tokens == video_token_id).sum().item() if video_token_id is not None else 0 + else: + image_nums, video_nums = 0, 0 input_tokens = input_ids.tolist() llm_pos_ids_list: list = [] st = 0 diff --git 
a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml index fe22ae148e3..bd141bc9c54 100644 --- a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml @@ -285,6 +285,14 @@ data: truncation: error image_key: images video_key: videos + audio_key: audios + modalities: images,videos + train_modality_batching: + enabled: false + drop_last: false + val_modality_batching: + enabled: false + drop_last: false trust_remote_code: false custom_cls: path: null diff --git a/verl/trainer/config/_generated_ppo_trainer.yaml b/verl/trainer/config/_generated_ppo_trainer.yaml index 43d0a28878a..8e30ee74892 100644 --- a/verl/trainer/config/_generated_ppo_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_trainer.yaml @@ -282,6 +282,14 @@ data: truncation: error image_key: images video_key: videos + audio_key: audios + modalities: images,videos # list of modalities to process + train_modality_batching: + enabled: false + drop_last: false + val_modality_batching: + enabled: false + drop_last: false trust_remote_code: false custom_cls: path: null diff --git a/verl/trainer/config/data/legacy_data.yaml b/verl/trainer/config/data/legacy_data.yaml index f6bd5aa26a2..06809587be2 100644 --- a/verl/trainer/config/data/legacy_data.yaml +++ b/verl/trainer/config/data/legacy_data.yaml @@ -77,6 +77,21 @@ image_key: images # The field in the multi-modal dataset where the video is located. video_key: videos +# The field in the multi-modal dataset where the audio is located. Default is 'audios'. +audio_key: audios + +# Comma-separated list of modalities to process. Default is 'images,videos'. 
+# Available modalities: images, videos, audios +# Example: 'images,videos,audios' to enable all modalities +modalities: images,videos + +train_modality_batching: + enabled: false + drop_last: false +val_modality_batching: + enabled: false + drop_last: false + # If the remote tokenizer has a Python file, this flag determines whether to allow using it. trust_remote_code: False diff --git a/verl/trainer/config/rollout/rollout.yaml b/verl/trainer/config/rollout/rollout.yaml index 8622cb68790..16f19c6fd48 100644 --- a/verl/trainer/config/rollout/rollout.yaml +++ b/verl/trainer/config/rollout/rollout.yaml @@ -50,10 +50,10 @@ free_cache_engine: True tensor_model_parallel_size: 2 # max number of tokens in a batch -max_num_batched_tokens: 8192 +max_num_batched_tokens: 1536 # max length for rollout -max_model_len: null +max_model_len: 1536 # max length of sequences max_num_seqs: 1024 diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py index fcf748ba99e..f16d30987ca 100644 --- a/verl/trainer/main_ppo.py +++ b/verl/trainer/main_ppo.py @@ -22,6 +22,7 @@ import ray from omegaconf import OmegaConf +from typing import Dict, List from verl.experimental.dataset.sampler import AbstractSampler from verl.trainer.constants_ppo import get_ppo_ray_runtime_env from verl.trainer.ppo.ray_trainer import RayPPOTrainer @@ -30,6 +31,7 @@ from verl.utils.config import validate_config from verl.utils.device import is_cuda_available from verl.utils.import_utils import load_extern_type +from verl.utils.dataset.modality_sampler import ModalitySignatureBatchSampler @hydra.main(config_path="config", config_name="ppo_trainer", version_base=None) @@ -82,6 +84,8 @@ def run_ppo(config) -> None: runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote() else: runner = TaskRunner.remote() + + # RUN THE TRAINING using runner.run ray.get(runner.run.remote(config)) # [Optional] get the path of the timeline trace file from the configuration, default to None @@ -282,11 +286,17 @@ 
def run(self, config): from verl.utils.dataset.rl_dataset import collate_fn # Create training and validation datasets. + # This is done by reading from the train and the val files train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor, is_train=True) val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor, is_train=False) - train_sampler = create_rl_sampler(config.data, train_dataset) + train_sampler = create_rl_sampler(config.data, train_dataset, split="train") + + # print(f"Using train sampler: {train_sampler}") + # print(f"Using val dataset: {val_dataset}") + # print(f"Using train dataset: {train_dataset}") # Initialize the PPO trainer. + # TODO: train sampler is fed into this; and is used to shuffle the training dataset (and extending to validation) trainer = RayPPOTrainer( config=config, tokenizer=tokenizer, @@ -356,8 +366,50 @@ def create_rl_dataset(data_paths, data_config, tokenizer, processor, is_train=Tr return dataset - -def create_rl_sampler(data_config, dataset): +# NOTE: This is the old rl_sampler +# def create_rl_sampler(data_config, dataset): +# """Create a sampler for the dataset. + +# Arguments: +# data_config: The data config. +# dataset (Dataset): The dataset. + +# Returns: +# sampler (Sampler): The sampler. +# """ +# import torch +# from torch.utils.data import RandomSampler, SequentialSampler + +# if data_config.sampler is not None and data_config.sampler.get("class_path", None) is not None: +# curriculum_class = load_extern_type( +# data_config.sampler.class_path, +# data_config.sampler.class_name, +# ) +# sampler = curriculum_class( +# data_source=dataset, +# data_config=data_config, +# ) +# assert isinstance(sampler, AbstractSampler) +# assert data_config.get("dataloader_num_workers", 8) == 0, ( +# "If using curriculum, num_workers must be 0 to prevent data caching. 
def create_rl_sampler(data_config, dataset, split: str = "train"):
    """Create a sampler for the dataset.

    Arguments:
        data_config: The data config (supports attribute access and ``.get()``).
        dataset (Dataset): The dataset to sample from.
        split (str): ``"train"`` or ``"val"``; selects which modality-batching
            config section is consulted and whether shuffling applies.

    Returns:
        sampler (Sampler): One of, in priority order: a user-supplied curriculum
        sampler, a ModalitySignatureBatchSampler, a seeded RandomSampler, or a
        SequentialSampler.
    """
    import torch
    from torch.utils.data import RandomSampler, SequentialSampler

    # Modality-batching config for the requested split.
    mb_key = "train_modality_batching" if split == "train" else "val_modality_batching"
    mb_cfg = data_config.get(mb_key)

    if data_config.sampler is not None and data_config.sampler.get("class_path", None) is not None:
        # A user-supplied curriculum sampler takes precedence over everything else.
        # BUGFIX: this branch was previously a standalone `if` followed by a
        # separate if/elif/else chain, so the curriculum sampler was silently
        # overwritten by SequentialSampler whenever modality batching was off.
        curriculum_class = load_extern_type(
            data_config.sampler.class_path,
            data_config.sampler.class_name,
        )
        sampler = curriculum_class(
            data_source=dataset,
            data_config=data_config,
        )
        assert isinstance(sampler, AbstractSampler)
        assert data_config.get("dataloader_num_workers", 8) == 0, (
            "If using curriculum, num_workers must be 0 to prevent data caching. "
            "If the dataloader caches data before the batch is done the "
            "curriculum sampler won't have the opportunity to reorder it. "
        )
    elif mb_cfg and mb_cfg.get("enabled", False):
        print(f"Creating our modality sampler for split: {split}")
        # Group dataset indices by modality signature so every yielded batch
        # is homogeneous in modality.
        by_sig: dict[str, list[int]] = {}
        for i in range(len(dataset)):
            # Prefer the raw dataframe row (cheap) over __getitem__ (tokenizes).
            row = dataset.dataframe[i] if hasattr(dataset, "dataframe") else dataset[i]
            sig = row.get("modality_signature")
            if sig is None:
                print(f"[WARNING] Row {i} missing 'modality_signature'. Skipping.")
                continue
            by_sig.setdefault(sig, []).append(i)

        batch_size = data_config.get("train_batch_size" if split == "train" else "val_batch_size")
        drop_last = mb_cfg.get("drop_last")
        # Only shuffle within each signature during training.
        shuffle = split == "train"

        print(
            f"Creating our modality sampler for split: {split}, batch_size: {batch_size}, "
            f"drop_last: {drop_last}, shuffle: {shuffle}"
        )
        sampler = ModalitySignatureBatchSampler(
            indices_by_sig=by_sig,
            batch_size=int(batch_size),
            drop_last=drop_last,
            seed=data_config.get("seed", 42),
            shuffle=shuffle,
        )
    elif data_config.shuffle and split == "train":
        # Seeded RandomSampler so checkpoint resumption is reproducible.
        train_dataloader_generator = torch.Generator()
        train_dataloader_generator.manual_seed(data_config.get("seed", 1))
        sampler = RandomSampler(data_source=dataset, generator=train_dataloader_generator)
    else:
        # Deterministic order for validation and non-shuffled training.
        sampler = SequentialSampler(data_source=dataset)

    return sampler
+ """ + # TODO: we have to make sure the batch size is divisible by the dp size + from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler + + if train_dataset is None: + train_dataset = create_rl_dataset( + self.config.data.train_files, self.config.data, self.tokenizer, self.processor + ) + if val_dataset is None: + val_dataset = create_rl_dataset( + self.config.data.val_files, self.config.data, self.tokenizer, self.processor + ) + self.train_dataset, self.val_dataset = train_dataset, val_dataset + + if train_sampler is None: + # TODO; you can essentially specify the type of sampler here, based also on the data split + train_sampler = create_rl_sampler(self.config.data, self.train_dataset) + if collate_fn is None: + from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn + + collate_fn = default_collate_fn + + num_workers = self.config.data["dataloader_num_workers"] + + ## TODO: trainer_sampler is pretty much the rl_sampler here, which shuffles the dataset + # TODO: the sampler is now placed into this (which is essentially your sampler) + self.train_dataloader = StatefulDataLoader( + dataset=self.train_dataset, + batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size), + num_workers=num_workers, + # shuffle=False, + drop_last=True, + collate_fn=collate_fn, + sampler=train_sampler, + ) + + val_batch_size = self.config.data.val_batch_size # Prefer config value if set + if val_batch_size is None: + val_batch_size = len(self.val_dataset) + + # TODO: validation data is shuffled here as well + self.val_dataloader = StatefulDataLoader( + dataset=self.val_dataset, + batch_size=val_batch_size, + num_workers=num_workers, + shuffle=self.config.data.get("validation_shuffle", True), + drop_last=False, + collate_fn=collate_fn, + ) + + assert len(self.train_dataloader) >= 1, "Train dataloader is empty!" + assert len(self.val_dataloader) >= 1, "Validation dataloader is empty!" 
+ + print( + f"Size of train dataloader: {len(self.train_dataloader)}, Size of val dataloader: " + f"{len(self.val_dataloader)}" + ) + + total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs + + if self.config.trainer.total_training_steps is not None: + total_training_steps = self.config.trainer.total_training_steps + + self.total_training_steps = total_training_steps + print(f"Total training steps: {self.total_training_steps}") + + try: + OmegaConf.set_struct(self.config, True) + with open_dict(self.config): + if OmegaConf.select(self.config, "actor_rollout_ref.actor.optim"): + self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps + if OmegaConf.select(self.config, "critic.optim"): + self.config.critic.optim.total_training_steps = total_training_steps + except Exception as e: + print(f"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}") \ No newline at end of file diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 620f4b050d7..6ed2244e500 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -34,7 +34,7 @@ import ujson import wandb from omegaconf import OmegaConf, open_dict -from torch.utils.data import Dataset, Sampler +from torch.utils.data import Dataset, Sampler, BatchSampler, SequentialSampler from torchdata.stateful_dataloader import StatefulDataLoader from tqdm import tqdm @@ -64,6 +64,23 @@ from verl.utils.tracking import ValidationGenerationsLogger from examples.reward_function.evaluation import compute_metrics_by_data_source +WorkerType = type[Worker] + +debug_file = "/home/keaneong/human-behavior/verl/examples/grpo_trainer/debug_log.txt" + +class Role(Enum): + """ + To create more roles dynamically, you can subclass Role and add new members + """ + + Actor = 0 + Rollout = 1 + ActorRollout = 2 + Critic = 3 + RefPolicy = 4 + RewardModel = 5 + ActorRolloutRef = 6 + @dataclass class 
ResourcePoolManager: @@ -327,6 +344,7 @@ def __init__( val_dataset: Optional[Dataset] = None, collate_fn=None, train_sampler: Optional[Sampler] = None, + val_sampler: Optional[Sampler] = None, device_name=None, ): """ @@ -402,7 +420,11 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl self.train_dataset, self.val_dataset = train_dataset, val_dataset if train_sampler is None: - train_sampler = create_rl_sampler(self.config.data, self.train_dataset) + train_sampler = create_rl_sampler(self.config.data, self.train_dataset, split="train") + + if val_sampler is None: + val_sampler = create_rl_sampler(self.config.data, self.val_dataset, split="val") + if collate_fn is None: from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn @@ -410,16 +432,41 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl num_workers = self.config.data["dataloader_num_workers"] - self.train_dataloader = StatefulDataLoader( - dataset=self.train_dataset, - batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size), - num_workers=num_workers, - drop_last=True, - collate_fn=collate_fn, - sampler=train_sampler, - ) - val_batch_size = self.config.data.val_batch_size # Prefer config value if set + if isinstance(train_sampler, BatchSampler): + self.train_dataloader = StatefulDataLoader( + dataset=self.train_dataset, + batch_sampler=train_sampler, + num_workers=num_workers, + collate_fn=collate_fn, + ) + else: + # Else if it is not a batch sampler, we can specify the batch size directly + self.train_dataloader = StatefulDataLoader( + dataset=self.train_dataset, + batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size), + num_workers=num_workers, + # shuffle=False, + drop_last=True, + collate_fn=collate_fn, + sampler=train_sampler, + ) + if isinstance(val_sampler, BatchSampler): + # BatchSampler path: DO NOT pass batch_size/shuffle/drop_last + self.val_dataloader = 
StatefulDataLoader( + dataset=self.val_dataset, + batch_sampler=val_sampler, + num_workers=num_workers, + collate_fn=collate_fn, + ) + else: + # Plain Sampler path: compute val_batch_size (None -> len(dataset)) + # This plain sampler path, if you trace the instance of val_sampler, + # should be that of a sequential sampler. Break if it is not. + if not isinstance(val_sampler, SequentialSampler): + raise ValueError("Validation sampler is not a SequentialSampler") + + val_batch_size = self.config.data.val_batch_size if val_batch_size is None: val_batch_size = len(self.val_dataset) @@ -839,6 +886,8 @@ def init_workers(self): ) def _save_checkpoint(self): + + ## TO SAVE CHECKPOINT from verl.utils.fs import local_mkdir_safe # path: given_path + `/global_step_{global_steps}` + `/actor` @@ -1045,8 +1094,23 @@ def fit(self): ) next_step_profile = False + for epoch in range(self.config.trainer.total_epochs): - for batch_dict in self.train_dataloader: + # i = 0 + for batch_idx, batch_dict in enumerate(self.train_dataloader): + #--- DEBUG: log batch content into debug_file --- + + + if debug_file is not None: + with open(debug_file, "a", encoding="utf-8") as f: + log_entry = { + "epoch": int(epoch), + "batch_idx": int(batch_idx), + "modality_signatures": batch_dict.get("modality_signatures", []), + "prompts": batch_dict.get("debug_prompts", []), + } + f.write(json.dumps(log_entry, ensure_ascii=False, default=lambda o: o.tolist() if isinstance(o, np.ndarray) else str(o)) + "\n") + metrics = {} timing_raw = {} @@ -1072,10 +1136,15 @@ def fit(self): is_last_step = self.global_steps >= self.total_training_steps + # TODO: double check the gen_batch + # print(f"gen_batch", gen_batch) + # i += 1 + with marked_timer("step", timing_raw): # generate a batch with marked_timer("gen", timing_raw, color="red"): if not self.async_rollout_mode: + # TODO: Fix the audio generation for this gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) else: gen_batch_output = 
def process_audio(
    audio: Union[str, dict],
    processor=None,
    max_seconds: float = None,  # clip audio to this many seconds, if given
) -> Tuple[torch.Tensor, int]:
    """Load an audio file, resample to the processor's rate, convert to mono,
    and optionally clip to ``max_seconds``.

    Arguments:
        audio: Path to the audio file, or a dict carrying the path under the
            ``"audio"`` key.
        processor: Optional HF processor; ``processor.feature_extractor.sampling_rate``
            is used as the target rate when present (defaults to 16 kHz).
        max_seconds: If set, truncate the waveform to this duration.

    Returns:
        (waveform, sampling_rate): 1-D float tensor and its sample rate. On any
        failure a silent dummy waveform at 16 kHz is returned instead of raising,
        so one bad file does not abort training.
    """
    if isinstance(audio, dict):
        # NOTE(review): falls back to the dict itself when the "audio" key is
        # absent — presumably upstream always sets it; confirm against callers.
        audio_path = audio.get("audio", audio)
    else:
        audio_path = audio

    try:
        audio_data, original_sr = torchaudio.load(audio_path)

        # Target sampling rate: prefer the processor's feature extractor, else 16 kHz.
        if processor and hasattr(processor, "feature_extractor") and hasattr(
            processor.feature_extractor, "sampling_rate"
        ):
            target_sr = processor.feature_extractor.sampling_rate
        else:
            target_sr = 16000

        if original_sr != target_sr:
            audio_data = torchaudio.transforms.Resample(original_sr, target_sr)(audio_data)
        else:
            target_sr = original_sr

        # Down-mix stereo to mono; squeeze the channel dim for mono input.
        if audio_data.shape[0] > 1:
            audio_data = audio_data.mean(dim=0, keepdim=False)
        else:
            audio_data = audio_data.squeeze(0)

        # Clip to max_seconds. BUGFIX: the original raised
        # ValueError("Audio data was clipped to max_seconds") right after
        # clipping (debug leftover), which sent every clipped file down the
        # dummy-audio fallback; now the clipped waveform is returned.
        if max_seconds:
            max_samples = int(max_seconds * target_sr)
            if audio_data.shape[0] > max_samples:
                audio_data = audio_data[:max_samples]

        return audio_data, target_sr

    except Exception as e:
        # Best-effort fallback: return silence so a bad file doesn't kill training.
        # BUGFIX: the original computed int(16000 * max_seconds) here, which
        # raised TypeError when max_seconds was None; default to 1 second.
        print(f"Error processing audio {audio_path}: {e}")
        duration = max_seconds if max_seconds else 1.0
        return torch.zeros((int(16000 * duration),), dtype=torch.float32), 16000
hasattr(processor.feature_extractor, +# 'sampling_rate'): +# target_sr = processor.feature_extractor.sampling_rate +# else: + +# target_sr = 16000 +# # print(f"KEANE: Processing audio {audio_path} with sampling rate, {target_sr}") +# # Resample if needed +# # NOTE: This is essentially the resampling of the audio sample rate +# if original_sr != target_sr: +# resampler = torchaudio.transforms.Resample(original_sr, target_sr) +# audio_data = resampler(audio_data) + +# # Convert to mono if stereo +# # NOTE: This is essentially the conversion of stereo audio to mono, so that we only have one channel +# if audio_data.shape[0] > 1: +# audio_data = audio_data.mean(dim=0, keepdim=False) +# else: +# audio_data = audio_data.squeeze(0) + +# # Debug prints +# # print( +# # f"KEANE: Finished processing {audio_path} -> " +# # f"waveform shape {audio_data.shape}, dtype {audio_data.dtype}, " +# # # f"min {audio_data.min().item():.4f}, max {audio_data.max().item():.4f}" +# # ) +# # print(f"KEANE: Returning tuple (waveform, sr={target_sr})") + +# return audio_data, target_sr +# except Exception as e: +# print(f"Error processing audio {audio_path}: {e}") +# dummy_audio = torch.zeros((1000,), dtype=torch.float32) +# return dummy_audio, 16000 + + +# def process_audio(audio: str | dict, processor=None) -> np.ndarray: +# """ +# NOTE: Keane's implementation +# """ +# if isinstance(audio, dict): +# audio_path = audio.get("audio", audio) +# else: +# audio_path = audio + +# try: +# # Load audio -> (channels, time), sample_rate +# audio_data, original_sr = torchaudio.load(audio_path) + +# # Target sampling rate (from processor or default to 16k) +# if ( +# processor +# and hasattr(processor, "feature_extractor") +# and hasattr(processor.feature_extractor, "sampling_rate") +# ): +# target_sr = processor.feature_extractor.sampling_rate +# else: +# target_sr = 16000 + +# print(f"KEANE: Processing audio {audio_path} with target sampling rate {target_sr}") + +# # Resample if needed +# if 
class ModalitySignatureBatchSampler(BatchSampler):
    """Round-robin batch sampler over modality signatures.

    - Every yielded batch is homogeneous in ``modality_signature``.
    - Indices are shuffled within each signature when ``shuffle=True`` (train).
    - A signature whose batches are exhausted is pruned and round-robin
      continues over the remaining signatures.
    """

    def __init__(
        self,
        indices_by_sig: Dict[str, List[int]],
        batch_size: int,
        drop_last: bool = True,
        seed: int = 42,
        shuffle: bool = True,
    ):
        # Copy the index lists so callers can't mutate our pools afterwards.
        self.indices_by_sig = {s: list(v) for s, v in indices_by_sig.items()}
        self.batch_size = int(batch_size)
        self.drop_last = drop_last
        self.shuffle = shuffle
        self.rng = random.Random(seed)
        self.sigs = list(self.indices_by_sig.keys())

    def _batches_for(self, pool: List[int]) -> List[List[int]]:
        """Chunk ``pool`` into batch_size pieces, honouring drop_last."""
        batches = []
        for start in range(0, len(pool), self.batch_size):
            chunk = pool[start:start + self.batch_size]
            # Keep full chunks; keep a short tail only when drop_last is off.
            if chunk and not (self.drop_last and len(chunk) < self.batch_size):
                batches.append(chunk)
        return batches

    def __iter__(self) -> Iterator[List[int]]:
        # Fresh pools each epoch; shuffle within each signature for training.
        pools = {s: list(v) for s, v in self.indices_by_sig.items()}
        for s in pools:
            if self.shuffle:
                self.rng.shuffle(pools[s])

        # Per-signature queues of ready-made batches.
        per_sig_batches = {s: deque(self._batches_for(pools[s])) for s in self.sigs}

        # Round-robin order: rotated start per epoch for variety when shuffling,
        # deterministic sorted order otherwise.
        order = list(self.sigs)
        if self.shuffle:
            k = self.rng.randrange(len(order)) if order else 0
            order = order[k:] + order[:k]
        else:
            order = sorted(order)

        # Active signatures; a signature is only (re-)queued while it still has
        # batches, so its deque is never empty when popped. BUGFIX: removed the
        # unreachable `else: print("Ran-Out: ...")` branch from the original.
        active = deque(s for s in order if per_sig_batches[s])
        while active:
            s = active.popleft()
            q = per_sig_batches[s]
            yield q.popleft()
            if q:
                active.append(s)  # still has batches: rejoin the round-robin

    def __len__(self) -> int:
        """Total number of batches across all signatures (after drop_last)."""
        total = 0
        for pool in self.indices_by_sig.values():
            full, rem = divmod(len(pool), self.batch_size)
            total += full + (0 if self.drop_last or rem == 0 else 1)
        return total
""" + # data list is the batch list + # NOTE: we assert homogeneous if the modality signatures are not homogeneous + assert_homogeneous(data_list) # assert if not homogeneous + tensors = defaultdict(list) non_tensors = defaultdict(list) @@ -134,13 +159,24 @@ def __init__( self.config = config self.cache_dir = os.path.expanduser(config.get("cache_dir", "~/.cache/verl/rlhf")) + + # Essentially getting all the different keys. self.prompt_key = config.get("prompt_key", "prompt") self.image_key = config.get("image_key", "images") self.video_key = config.get("video_key", "videos") + + # NOTE: SET AUDIO KEY AS AUDIOS + self.audio_key = config.get("audio_key", "audios") + + # NOTE: SET MODALITIES, split the images and videos + self.modalities = set(config.get("modalities", "images,videos").split(",")) + self.max_prompt_length = config.get("max_prompt_length", 1024) self.return_raw_chat = config.get("return_raw_chat", False) self.return_full_prompt = config.get("return_full_prompt", False) self.truncation = config.get("truncation", "error") + + # TODO: Check whether this is true self.filter_overlong_prompts = config.get("filter_overlong_prompts", True) if isinstance(data_files, str): self.base_dir = os.path.dirname(os.path.abspath(data_files)) @@ -156,13 +192,25 @@ def __init__( self.filter_prompts = config.get("filter_prompts", True) self.serialize_dataset = False self.return_multi_modal_inputs = config.get("return_multi_modal_inputs", True) + + # Load format prompt from file if specified + self.format_prompt_path = config.get("format_prompt", "examples/format_prompt/default.jinja") + self.format_prompt = self._load_format_prompt() # Load format prompt from file if specified self.format_prompt_path = config.get("format_prompt", "examples/format_prompt/default.jinja") self.format_prompt = self._load_format_prompt() self._download() - self._read_files_and_tokenize() + self._read_files_and_tokenize() # essentially this is prepared first before _getitem + + def 
_load_format_prompt(self) -> Optional[Template]: + """Load format prompt from file if specified.""" + if self.format_prompt_path: + with open(self.format_prompt_path, 'r', encoding='utf-8') as f: + template_content = f.read() + return Template(template_content) + return None def _load_format_prompt(self) -> Optional[Template]: """Load format prompt from file if specified.""" @@ -181,6 +229,18 @@ def _download(self, use_origin_parquet=False): def _read_files_and_tokenize(self): dataframes = [] + + features = datasets.Features({ + "problem": datasets.Value("string"), + "answer": datasets.Value("string"), + "images": datasets.Sequence(datasets.Value("string")), + "videos": datasets.Sequence(datasets.Value("string")), + "audios": datasets.Sequence(datasets.Value("string")), # <- force list of strings + "dataset": datasets.Value("string"), + "texts": datasets.Sequence(datasets.Value("string")), + "modality_signature": datasets.Value("string"), + }) + for parquet_file in self.data_files: # read parquet files and cache if parquet_file.endswith(".parquet"): @@ -194,51 +254,78 @@ def _read_files_and_tokenize(self): print(f"dataset len: {len(self.dataframe)}") + # PROCESSING THE DATAFRAME for TRAINING self.dataframe = self.maybe_filter_out_long_prompts(self.dataframe) def maybe_filter_out_long_prompts(self, dataframe: datasets.Dataset = None): - # filter out too long prompts + # NOTE: filter out too long prompts, because the prompts can become very long + # when the audio is appended. 
+ if self.filter_overlong_prompts: + # NOTE: FILTER OUT THE LONG PROMPTS SO THAT THEY FIT THE LENGTH tokenizer = self.tokenizer processor = self.processor prompt_key = self.prompt_key image_key = self.image_key video_key = self.video_key + audio_key = self.audio_key if processor is not None: + # print(f"KEANE: PROCESSOR FOUND") from verl.utils.dataset.vision_utils import process_image, process_video + from verl.utils.dataset.audio_utils import process_audio def doc2len(doc) -> int: messages = self._build_messages(doc) raw_prompt = self.processor.apply_chat_template( messages, add_generation_prompt=True, tokenize=False, **self.apply_chat_template_kwargs ) - images = [process_image(image) for image in doc[image_key]] if image_key in doc else None - videos = [process_video(video) for video in doc[video_key]] if video_key in doc else None - - # Handle video-to-image conversion for processors that don't support video - if videos and not processor_supports_video(processor): - # Convert video frames to images - if images is None: - images = [] - for video_tensor in videos: - # video_tensor is shape [n_frames, 3, H, W] - for frame_idx in range(video_tensor.shape[0]): - frame = video_tensor[frame_idx] # [3, H, W] - frame_np = frame.permute(1, 2, 0).numpy() # [H, W, 3] - from PIL import Image - frame_image = Image.fromarray(frame_np.astype('uint8'), 'RGB') - images.append(frame_image) - videos = None + processor_kwargs = {"text": [raw_prompt]} - # Call processor with appropriate parameters - if processor_supports_video(processor): - return len(processor(text=[raw_prompt], images=images, videos=videos)["input_ids"][0]) - else: - return len(processor(text=[raw_prompt], images=images)["input_ids"][0]) + if "images" in self.modalities and image_key in doc and len(doc[image_key]) > 0: + images = [process_image(image) for image in doc[image_key]] + processor_kwargs["images"] = images + + if "videos" in self.modalities and video_key in doc and len(doc[video_key]) > 0: + videos = 
[process_video(video) for video in doc[video_key]] + # Handle video-to-image conversion for processors that don't support video + if videos and not processor_supports_video(processor): + # Convert video frames to images + if images is None: + images = [] + for video_tensor in videos: + # video_tensor is shape [n_frames, 3, H, W] + for frame_idx in range(video_tensor.shape[0]): + frame = video_tensor[frame_idx] # [3, H, W] + frame_np = frame.permute(1, 2, 0).numpy() # [H, W, 3] + from PIL import Image + frame_image = Image.fromarray(frame_np.astype('uint8'), 'RGB') + images.append(frame_image) + videos = None + processor_kwargs["videos"] = videos + + if "audio" in self.modalities and audio_key in doc and doc.get(audio_key, None) is not None and len(doc[audio_key]) > 0: + # processing of audio + # print(f"KEANE: Processing audio within rl dataset file") + # audios = [process_audio(audio, processor) for audio in doc[audio_key]] + # processor_kwargs["audio"] = audios + + # PATCH + audios = [] + audio_tuples = [] # Keep tuples for multi_modal_data + for audio in doc.get(self.audio_key): + audio_path = os.path.join(self.base_dir, audio) if isinstance(audio, str) else audio + audio_data, sampling_rate = process_audio(audio_path, self.processor) + audio_tuples.append((audio_data, sampling_rate)) + # audios.append(audio_data.numpy()) # Convert to numpy array for Whisper + audios.append(audio_data.detach().cpu().numpy().astype("float32")) + + processor_kwargs["audio"] = audios # Pass numpy arrays to processor + + return len(processor(**processor_kwargs)["input_ids"][0]) else: - + # print(f"KEANE: PROCESSOR NOT FOUND") def doc2len(doc) -> int: return len( tokenizer.apply_chat_template( @@ -267,31 +354,40 @@ def resume_dataset_state(self): def __len__(self): return len(self.dataframe) - def _build_messages(self, example: dict, convert_video_to_images: bool = False): + def _build_messages(self, example: dict): + """ + This appears to be called twice, once during 
maybe_filter_out_long_prompts, and another time during getitems + """ messages: list = example.get(self.prompt_key) if isinstance(messages, str): messages = [messages] - format_prompt = ("You FIRST think about the reasoning process as an internal monologue and then " - "provide the final answer. The reasoning process MUST BE enclosed within " - " tags. The final answer MUST BE put in \\boxed{}.") - - if self.image_key in example or self.video_key in example: + # NOTE: Before building, check if there is multimodal content + has_multimodal = ( + ("images" in self.modalities and self.image_key in example) or + ("videos" in self.modalities and self.video_key in example) or + ("audio" in self.modalities and self.audio_key in example) + ) + + if has_multimodal: new_messages = [] for message in messages: new_message = copy.deepcopy(message) if isinstance(new_message, str): new_message = {"role": "user", "content": new_message} content = new_message["content"] - # Apply format prompt to the entire content first if template is loaded if self.format_prompt: content = self.format_prompt.render(content=content) image_count = len(example.get(self.image_key, [])) video_count = len(example.get(self.video_key, [])) + audio_count = len(example.get(self.audio_key, [])) image_tag_count = content.count("") video_tag_count = content.count("