From 7bd1f4ec2b70380c7191529d4965ad62d4438bef Mon Sep 17 00:00:00 2001
From: Tony
Date: Sun, 15 Mar 2026 04:09:00 +0000
Subject: [PATCH] Add batched and distributed eval

---
 scripts/eval.sh                     |  26 ++-
 src/configs/config.py               |   3 +
 src/dataloaders/build_dataloader.py |   4 +-
 src/elms/connectors/linear_proj.py  |   4 +-
 src/main_evaluator.py               |  46 ++--
 src/runners/evaluator.py            | 349 +++++++++++++++++++++++-----
 6 files changed, 354 insertions(+), 78 deletions(-)

diff --git a/scripts/eval.sh b/scripts/eval.sh
index ddc10d9..623f255 100644
--- a/scripts/eval.sh
+++ b/scripts/eval.sh
@@ -1,6 +1,28 @@
+# Single GPU, batched
 CUDA_VISIBLE_DEVICES=0 uv run src/main_evaluator.py \
 --data_representation signal \
 --data ecg-qa-ptbxl-250-2500 \
 --llm llama-3.2-1b-instruct \
---elm fuyu \
---elm_ckpt src/runs/pretrain/llama-3.2-1b-instruct_None/2/checkpoints/epoch_best.pt
\ No newline at end of file
+--elm llava \
+--peft \
+--encoder st_mem \
+--num_workers 4 \
+--eval_batch_size 8 \
+--system_prompt src/dataloaders/system_prompts/system_prompt.txt \
+--elm_ckpt src/runs/llama-3.2-1b-instruct_st_mem/ecg-instruct-45k-250-2500/4/checkpoints/epoch_best.pt
+
+# Multi-GPU, distributed + batched (uncomment to use)
+# CUDA_VISIBLE_DEVICES=0,1 uv run -m torch.distributed.run \
+# --nproc_per_node=2 \
+# src/main_evaluator.py \
+# --distributed \
+# --data_representation signal \
+# --data ecg-qa-ptbxl-250-2500 \
+# --llm llama-3.2-1b-instruct \
+# --elm llava \
+# --peft \
+# --encoder st_mem \
+# --num_workers 4 \
+# --eval_batch_size 8 \
+# --system_prompt src/dataloaders/system_prompts/system_prompt.txt \
+# --elm_ckpt src/runs/llama-3.2-1b-instruct_st_mem/ecg-instruct-45k-250-2500/4/checkpoints/epoch_best.pt
diff --git a/src/configs/config.py b/src/configs/config.py
index e3db9e2..aef227d 100644
--- a/src/configs/config.py
+++ b/src/configs/config.py
@@ -51,6 +51,9 @@ def get_args(mode: Mode) -> argparse.Namespace:
     parser.add_argument("--llm_input_len", type=int, default=2048, help="LLM Input Sequence Length")
     parser.add_argument("--min_ecg_tokens_len", type=int, default=512, help="Minimum ECG token length to consider")
     parser.add_argument("--norm_eps", type=float, default=1e-6, help="Please choose the normalization epsilon")
+    if mode in {"eval", "inference"}:
+        parser.add_argument("--eval_batch_size", type=int, default=1, help="Batch size for batched generation during eval/inference")
+
     if mode == "train":
         parser.add_argument("--optimizer", type=str, default="adam", choices=["adam", "adamw", "muon"], help="Optimizer type")
         parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
diff --git a/src/dataloaders/build_dataloader.py b/src/dataloaders/build_dataloader.py
index fddfe75..7422a90 100644
--- a/src/dataloaders/build_dataloader.py
+++ b/src/dataloaders/build_dataloader.py
@@ -41,10 +41,12 @@ def build_torch_dataloader(self, torch_dataset):
     elif "eval" in self.args.mode:
         torch_data_loader = DataLoader(
             torch_dataset,
-            batch_size=1,  # batched inference/eval not implemented
+            batch_size=1,  # per-sample flattening pass; batched generation happens in the evaluator
             shuffle=False,
+            num_workers=self.args.num_workers,
             pin_memory=torch.cuda.is_available(),
             collate_fn=self.collate_fn,
+            persistent_workers=(self.args.num_workers > 0),
         )
     return torch_data_loader
diff --git a/src/elms/connectors/linear_proj.py b/src/elms/connectors/linear_proj.py
index 3d2e7af..b10be72 100644
--- a/src/elms/connectors/linear_proj.py
+++ b/src/elms/connectors/linear_proj.py
@@ -10,7 +10,7 @@ def __init__(self, projection_dim, llm_id):
         self.projection = nn.Linear(projection_dim, HF_LLMS[llm_id]["model_hidden_size"]).to(dtype=self.input_dtype)

     def forward(self, ecg_signal):
-        return self.projection(ecg_signal.to(dtype=self.input_dtype))
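+        # Read the dtype from the live weight instead of a cached self.input_dtype,
+        # so the cast stays correct if the module is later moved via .to()/.half().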
+        return self.projection(ecg_signal.to(dtype=self.projection.weight.dtype))

     def project(self, signal_embeds):
-        return self.projection(signal_embeds.to(dtype=self.input_dtype))
+        return self.projection(signal_embeds.to(dtype=self.projection.weight.dtype))
diff --git a/src/main_evaluator.py b/src/main_evaluator.py
index fed77bd..7888956 100644
--- a/src/main_evaluator.py
+++ b/src/main_evaluator.py
@@ -5,7 +5,7 @@
 from pathlib import Path

 from configs.config import get_args
-from utils.gpu_manager import GPUSetup
+from utils.gpu_manager import GPUSetup, init_dist, cleanup, is_main
 from utils.seed_manager import set_seed
 from dataloaders.build_dataloader import BuildDataLoader
 from elms.build_elm import BuildELM
@@ -18,6 +18,10 @@ def main():
     mode = "eval"
     args = get_args(mode)
     args.mode = mode
+
+    if getattr(args, "distributed", False):
+        init_dist()
+
     # folds = ["1", "2", "3", "4", "5"]
     # seeds = [1337, 1338, 1339, 1340, 1341]
     folds = ["1"]
@@ -33,9 +37,12 @@ def main():
     sys_prompt_name = Path(args.system_prompt).stem
     data_name = "_".join(args.data)
     results_file = os.path.join(checkpoint_dir, f"{ckpt_file_name}_{data_name}_{sys_prompt_name}_{args.perturb}.json")
+    debug_path = results_file.replace(".json", "_debug.txt")
+    debug_file = open(debug_path, "w") if is_main() else None
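+    # Only rank 0 opens a real handle; the other ranks pass debug_file=None into
+    # evaluate(), so under DDP the dump is written exactly once.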
     for fold in folds:
         for seed in seeds:
-            print(f"Evaluating fold {fold} with seed {seed}")
+            if is_main():
+                print(f"Evaluating fold {fold} with seed {seed}")
             args.fold = fold
             args.seed = seed
             set_seed(args.seed)
@@ -47,26 +54,35 @@ def main():
             elm = gpu_setup.setup_gpu(elm_components["elm"], elm_components["find_unused_parameters"])
             if args.dev:
                 gpu_setup.print_model_device(elm, f"{args.llm}_{args.encoder}")
-            out = evaluate(elm, dataloader, args)
+            out = evaluate(elm, dataloader, args, debug_file=debug_file)
             all_metrics.append(out)
-            if len(all_metrics) == 1:
+            if is_main() and len(all_metrics) == 1:
                 examples_path = results_file.replace(".json", "_examples.json")
                 examples = [{"prompt": p, "predicted": h, "ground_truth": r} for p, h, r in zip(out["prompts"], out["hypotheses"], out["references"])]
                 with open(examples_path, "w") as ef:
                     json.dump(examples, ef, indent=2)
                 print(f"Saved {len(examples)} eval examples to {examples_path}")
-            if "confusion_matrix" in out:
-                cm_path = results_file.replace(".json", f"{fold}_{seed}.png")
-                save_confusion_matrix_png(out["confusion_matrix"], cm_path)
-                other_path = results_file.replace(".json", f"{fold}_{seed}_other.png")
-                save_other_outputs_histogram_png(out["other_output_counts"], other_path, top_k = 10)
-                incorrect_path = results_file.replace(".json", f"{fold}_{seed}_incorrect.png")
-                save_incorrect_predictions_histogram_png(out["references"], out["hypotheses"], incorrect_path)
-    statistical_results = run_statistical_analysis(all_metrics)
-    with open(results_file, "w") as f:
-        json.dump(statistical_results, f, indent=2)
-    print(f"Saved evaluation results to {results_file}")
+            if is_main():
+                if "confusion_matrix" in out:
+                    cm_path = results_file.replace(".json", f"{fold}_{seed}.png")
+                    save_confusion_matrix_png(out["confusion_matrix"], cm_path)
+                    other_path = results_file.replace(".json", f"{fold}_{seed}_other.png")
+                    save_other_outputs_histogram_png(out["other_output_counts"], other_path, top_k=10)
+                    incorrect_path = results_file.replace(".json", f"{fold}_{seed}_incorrect.png")
+                    save_incorrect_predictions_histogram_png(out["references"], out["hypotheses"], incorrect_path)
+    if debug_file is not None:
+        debug_file.close()
+        if is_main():
+            print(f"Saved debug dump to {debug_path}")
+    if is_main():
+        statistical_results = run_statistical_analysis(all_metrics)
+        with open(results_file, "w") as f:
+            json.dump(statistical_results, f, indent=2)
+        print(f"Saved evaluation results to {results_file}")
+
+    if getattr(args, "distributed", False):
+        cleanup()


 if __name__ == "__main__":
diff --git a/src/runners/evaluator.py b/src/runners/evaluator.py
index 3a3a3cc..b97ba45 100644
--- a/src/runners/evaluator.py
+++ b/src/runners/evaluator.py
@@ -1,11 +1,15 @@
+import functools
 import numpy as np
 import scipy.stats as stats
 from tqdm import tqdm
 import torch
+import torch.distributed as dist
+from torch.utils.data import Dataset, DataLoader
+from torch.utils.data.distributed import DistributedSampler
 from collections import Counter
 import string

-from utils.gpu_manager import is_main, train_dev_break
+from utils.gpu_manager import is_main, get_world_size, get_rank
 from runners.helper import batch_to_device

@@ -219,70 +223,298 @@ def save_incorrect_predictions_histogram_png(references, hypotheses, path, top_k
     plt.close(fig)
     print(f"Saved incorrect-predictions histogram to {path}")

-def evaluate(elm, dataloader, args):
-    show_progress = is_main()
-    elm.eval()
-    needs_signal_injection = args.elm in ("llava", "base_elm", "patch_elm")
+
+# ---------------------------------------------------------------------------
+# Phase 1 helpers: flatten all (sample, turn) pairs on CPU
+# ---------------------------------------------------------------------------
+
+def _extract_encoder_out_item(encoder_tokenizer_out, b):
+    """Extract single-sample encoder outputs from a batched dict (batch dim squeezed)."""
+    out = {}
+    for k, v in encoder_tokenizer_out.items():
+        if isinstance(v, dict):
+            out[k] = _extract_encoder_out_item(v, b)
+        else:
+            out[k] = v[b]  # remove batch dim
+    return out
+
+
+def flatten_eval_turns(dataloader, args, needs_signal_injection):
+    """Iterate through the dataset (batch_size=1) and produce a flat list of turn-level work items."""
+    dataset = dataloader.dataset
+    turns = []
+    global_idx = 0
+    progress = tqdm(
         dataloader,
-        desc=f"LLM: {args.llm} ENCODER: {args.encoder}",
-        disable=not show_progress,
+        desc="Flattening turns",
+        disable=not is_main(),
         leave=False,
     )
+
+    for batch_idx, batch in enumerate(progress):
+        B = batch["elm_input_ids"].shape[0]
+        for b in range(B):
+            full_ids = batch["elm_input_ids"][b].tolist()
+            full_attn = batch["elm_attention_mask"][b].tolist()
+
+            if needs_signal_injection:
+                signal_indices = batch["signal_id_indices"][b]
+                encoder_out_item = _extract_encoder_out_item(batch["encoder_tokenizer_out"], b)
+
+            ranges = dataset.get_response_ranges(full_ids)
+            gt_texts = dataset.get_ground_truth_responses(full_ids, ranges)
+
+            if getattr(args, "dev", False) and is_main():
+                print(f"\n--- Batch {batch_idx}, Sample {b} ---")
+                print(f"Total turns: {len(ranges)}")
+                dataset.assert_range_alignment(full_ids, ranges)
+
+            for turn_idx, ((s, _), gt) in enumerate(zip(ranges, gt_texts)):
+                sub_ids = full_ids[:s]
+                sub_attn = full_attn[:s]
+
+                turn = {
+                    "global_idx": global_idx,
+                    "sample_idx": batch_idx * B + b,
+                    "turn_idx": turn_idx,
+                    "prefix_ids": sub_ids,
+                    "prefix_attn": sub_attn,
+                    "gt_text": gt,
+                }
+
+                if needs_signal_injection:
+                    masked_indices = signal_indices.clone()
+                    masked_indices[masked_indices >= len(sub_ids)] = -1
+                    turn["signal_id_indices"] = masked_indices
+                    turn["encoder_tokenizer_out"] = encoder_out_item
+
+                turns.append(turn)
+                global_idx += 1
+
+    if is_main():
+        print(f"Flattened {global_idx} turns from {len(dataloader.dataset)} samples")
+    return turns
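+
+# A work item produced above is a plain dict, e.g. (illustrative values only):
+#   {"global_idx": 42, "sample_idx": 7, "turn_idx": 1,
+#    "prefix_ids": [...tokens up to this turn's response start...],
+#    "prefix_attn": [...same length as prefix_ids...],
+#    "gt_text": "sinus rhythm",
+#    "signal_id_indices": tensor([...]),           # only with signal injection
+#    "encoder_tokenizer_out": {...per-sample...}}  # only with signal injection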
+
+# ---------------------------------------------------------------------------
+# Phase 2 helpers: batched generation
+# ---------------------------------------------------------------------------
+
+class TurnDataset(Dataset):
+    """Wraps the flat turn list for DataLoader batching."""
+    def __init__(self, turns):
+        self.turns = turns
+
+    def __len__(self):
+        return len(self.turns)
+
+    def __getitem__(self, idx):
+        return self.turns[idx]
+
+
+def _stack_encoder_out(items):
+    """Stack per-sample encoder outputs along a new batch dimension."""
+    if not items:
+        return {}
+    keys = items[0].keys()
+    out = {}
+    for k in keys:
+        vals = [item[k] for item in items]
+        if isinstance(vals[0], dict):
+            out[k] = _stack_encoder_out(vals)
+        elif isinstance(vals[0], torch.Tensor):
+            out[k] = torch.stack(vals, dim=0)
+        elif isinstance(vals[0], np.ndarray):
+            out[k] = torch.from_numpy(np.stack(vals, axis=0))
+        else:
+            out[k] = vals
+    return out
+
+
+def eval_collate_fn(batch, pad_token_id):
+    """Collate turn dicts into a left-padded batch for generation."""
+    max_len = max(len(item["prefix_ids"]) for item in batch)
+
+    all_input_ids = []
+    all_attn_masks = []
+    all_prefix_lens = []
+    all_global_idxs = []
+    all_gt_texts = []
+    all_original_prefix_ids = []
+    has_signal = "signal_id_indices" in batch[0]
+    all_signal_indices = []
+    all_encoder_outs = []
+
+    for item in batch:
+        prefix_ids = item["prefix_ids"]
+        prefix_attn = item["prefix_attn"]
+        prefix_len = len(prefix_ids)
+        pad_len = max_len - prefix_len
+
+        # Left-pad input_ids and attention mask
+        padded_ids = [pad_token_id] * pad_len + prefix_ids
+        padded_attn = [0] * pad_len + prefix_attn
+
+        all_input_ids.append(padded_ids)
+        all_attn_masks.append(padded_attn)
+        all_prefix_lens.append(prefix_len)
+        all_global_idxs.append(item["global_idx"])
+        all_gt_texts.append(item["gt_text"])
+        all_original_prefix_ids.append(prefix_ids)
+
+        if has_signal:
+            # Shift signal indices right by pad_len (left-padding offset)
+            indices = item["signal_id_indices"].clone()
+            valid = indices >= 0
+            indices[valid] += pad_len
+            all_signal_indices.append(indices)
+            all_encoder_outs.append(item["encoder_tokenizer_out"])
+
+    out = {
+        "elm_input_ids": torch.tensor(all_input_ids, dtype=torch.int64),
+        "elm_attention_mask": torch.tensor(all_attn_masks, dtype=torch.float32),
+        "prefix_len": all_prefix_lens,
+        "global_idx": all_global_idxs,
+        "gt_text": all_gt_texts,
+        "original_prefix_ids": all_original_prefix_ids,
+    }
+
+    if has_signal:
+        out["signal_id_indices"] = torch.stack(all_signal_indices, dim=0)
+        out["encoder_tokenizer_out"] = _stack_encoder_out(all_encoder_outs)
+
+    return out
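+
+# Left-padding example (pad_token_id=0): prefixes of lengths 3 and 5 collate to
+#   input_ids: [[ 0,  0, 11, 12, 13],    attention: [[0, 0, 1, 1, 1],
+#               [21, 22, 23, 24, 25]]                [1, 1, 1, 1, 1]]
+# A signal index of 1 in the shorter row is shifted to 1 + pad_len = 3, so it
+# still points at the same token after padding.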
"num_pairs": 0, + "metrics": {"ACC": 0.0, "F1": 0.0}, + "prompts": [], + "references": [], + "hypotheses": [], + } + + # --- Phase 2: Batched generation --- + turn_dataset = TurnDataset(turns) + pad_token_id = dataset.llm_tokenizer.pad_token_id + collate = functools.partial(eval_collate_fn, pad_token_id=pad_token_id) + + if distributed: + sampler = DistributedSampler( + turn_dataset, + num_replicas=get_world_size(), + rank=get_rank(), + shuffle=False, + drop_last=False, + ) + else: + sampler = None + + eval_batch_size = getattr(args, "eval_batch_size", 1) + gen_loader = DataLoader( + turn_dataset, + batch_size=eval_batch_size, + shuffle=False, + sampler=sampler, + collate_fn=collate, + num_workers=0, # data already in memory + pin_memory=False, + ) + + # Unwrap DDP for generate() + gen_model = elm.module if hasattr(elm, "module") else elm + + local_results = [] + with torch.no_grad(): - for batch_idx, batch in enumerate(progress): - B = batch["elm_input_ids"].shape[0] - for b in range(B): - full_ids = batch["elm_input_ids"][b].tolist() - full_attn = batch["elm_attention_mask"][b].tolist() - if needs_signal_injection: - signal_indices = batch["signal_id_indices"][b] - full_encoder_tokenizer_out = index_nested(batch["encoder_tokenizer_out"], b) - ranges = dataset.get_response_ranges(full_ids) - gt_texts = dataset.get_ground_truth_responses(full_ids, ranges) - if getattr(args, "dev", False): - print(f"\n--- Batch {batch_idx}, Sample {b} ---") - print(f"Total turns: {len(ranges)}") - dataset.assert_range_alignment(full_ids, ranges) - for turn_idx, ((s, _), gt) in enumerate(zip(ranges, gt_texts)): - sub_ids = full_ids[:s] - sub_attn = full_attn[:s] - gen_batch = { - "elm_input_ids": torch.tensor(sub_ids, dtype=torch.int64).unsqueeze(0), - "elm_attention_mask": torch.tensor(sub_attn, dtype=torch.float32).unsqueeze(0), - } - if needs_signal_injection: - gen_batch["encoder_tokenizer_out"] = full_encoder_tokenizer_out - # Mask out signal indices that fall outside the truncated sequence - truncated_len = len(sub_ids) - masked_indices = signal_indices.clone() - masked_indices[masked_indices >= truncated_len] = -1 - gen_batch["signal_id_indices"] = masked_indices - gen_batch = {k: batch_to_device(v, device) for k, v in gen_batch.items()} - gen_out = elm.generate(**gen_batch)[0].cpu().tolist() - gen_txt = dataset.get_generated_response_for_turn(sub_ids, gen_out) - if getattr(args, "dev", False): - print(f"\nTurn {turn_idx + 1}:") - print(f"\nGround Truth:\n{gt}") - print(f"\nGenerated:\n{gen_txt}") - print("-" * 100) - if gt and gen_txt: - all_prompts.append(dataset.llm_tokenizer.decode(sub_ids, skip_special_tokens=True).strip()) - all_refs.append(gt) - all_hyps.append(gen_txt) - # if train_dev_break(getattr(args, "dev", False), batch, 0): - # break - # if batch_idx == 10: - # break - # input() + for batch in tqdm(gen_loader, desc=f"Generating (bs={eval_batch_size})", + disable=not is_main(), leave=False): + gen_batch = { + "elm_input_ids": batch["elm_input_ids"].to(device), + "elm_attention_mask": batch["elm_attention_mask"].to(device), + } + if needs_signal_injection: + enc_out = batch_to_device(batch["encoder_tokenizer_out"], device) + gen_batch["encoder_tokenizer_out"] = enc_out + gen_batch["signal_id_indices"] = batch["signal_id_indices"].to(device) + + gen_out = gen_model.generate(**gen_batch) # [B, output_seq_len] + + B = gen_out.shape[0] + for i in range(B): + gidx = batch["global_idx"][i] + prefix_ids = batch["original_prefix_ids"][i] + gt = batch["gt_text"][i] + + gen_ids = 
+    # --- Phase 3: Gather, deduplicate, reorder ---
+    if distributed:
+        all_results_nested = [None] * get_world_size()
+        dist.all_gather_object(all_results_nested, local_results)
+        all_results = [item for sublist in all_results_nested for item in sublist]
+    else:
+        all_results = local_results
+
+    # Sort by global_idx for deterministic ordering
+    all_results.sort(key=lambda x: x[0])
+
+    # Deduplicate (DistributedSampler with drop_last=False pads with duplicates)
+    seen = set()
+    deduped = []
+    for item in all_results:
+        if item[0] not in seen:
+            seen.add(item[0])
+            deduped.append(item)
+    all_results = deduped
+
+    # Build final lists (only keep pairs where both gt and gen are non-empty)
+    all_refs, all_hyps, all_prompts = [], [], []
+    example_idx = 0
+    for gidx, gt, gen_txt, prefix_ids in all_results:
+        if gt and gen_txt:
+            prompt_txt = dataset.llm_tokenizer.decode(prefix_ids, skip_special_tokens=True).strip()
+            all_prompts.append(prompt_txt)
+            all_refs.append(gt)
+            all_hyps.append(gen_txt)
+
+            # Keep the debug dump inside this branch so prompt_txt is always bound
+            if debug_file is not None and is_main():
+                debug_file.write(f"{'='*80}\n")
+                debug_file.write(f"EXAMPLE {example_idx} | global_idx={gidx}\n")
+                debug_file.write(f"{'='*80}\n")
+                debug_file.write(f"PROMPT ({len(prefix_ids)} tokens):\n{prompt_txt[:500]}\n\n")
+                debug_file.write(f"GROUND TRUTH:\n{gt}\n\n")
+                debug_file.write(f"GENERATED:\n{gen_txt}\n\n")
+                debug_file.write(f"MATCH: {gt == gen_txt}\n\n")
+                debug_file.flush()
+                example_idx += 1
+
     results = evaluate_strings(all_refs, all_hyps)
-    print("\n=== N-Turn Evaluation (generated vs. gold response only) ===")
-    print(f"Pairs: {len(all_refs)}")
-    print(f"ACC: {results['ACC']:.4f}")
-    print(f"F1: {results['F1']:.4f}")
+    if is_main():
+        print("\n=== N-Turn Evaluation (generated vs. gold response only) ===")
+        print(f"Pairs: {len(all_refs)}")
+        print(f"ACC: {results['ACC']:.4f}")
+        print(f"F1: {results['F1']:.4f}")
+
     out = {
         "num_pairs": len(all_refs),
         "metrics": results,
@@ -292,7 +524,8 @@ def evaluate(elm, dataloader, args):
     }
     if any(d.startswith("ecg-comp") for d in args.data):
         per_class_acc, confusion_matrix, other_counts = compute_classification_metrics(all_refs, all_hyps)
-        print_classification_metrics(per_class_acc, confusion_matrix)
+        if is_main():
+            print_classification_metrics(per_class_acc, confusion_matrix)
         results["per_class_acc"] = per_class_acc
         out["confusion_matrix"] = confusion_matrix
         out["other_output_counts"] = other_counts