Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions scripts/eval.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,28 @@
# Single GPU, batched
CUDA_VISIBLE_DEVICES=0 uv run src/main_evaluator.py \
--data_representation signal \
--data ecg-qa-ptbxl-250-2500 \
--llm llama-3.2-1b-instruct \
--elm fuyu \
--elm_ckpt src/runs/pretrain/llama-3.2-1b-instruct_None/2/checkpoints/epoch_best.pt
--elm llava \
--peft \
--encoder st_mem \
--num_workers 4 \
--eval_batch_size 8 \
--system_prompt src/dataloaders/system_prompts/system_prompt.txt \
--elm_ckpt src/runs/llama-3.2-1b-instruct_st_mem/ecg-instruct-45k-250-2500/4/checkpoints/epoch_best.pt

# Multi-GPU, distributed + batched (uncomment to use)
# CUDA_VISIBLE_DEVICES=0,1 uv run -m torch.distributed.run \
# --nproc_per_node=2 \
# src/main_evaluator.py \
# --distributed \
# --data_representation signal \
# --data ecg-qa-ptbxl-250-2500 \
# --llm llama-3.2-1b-instruct \
# --elm llava \
# --peft \
# --encoder st_mem \
# --num_workers 4 \
# --eval_batch_size 8 \
# --system_prompt src/dataloaders/system_prompts/system_prompt.txt \
# --elm_ckpt src/runs/llama-3.2-1b-instruct_st_mem/ecg-instruct-45k-250-2500/4/checkpoints/epoch_best.pt
3 changes: 3 additions & 0 deletions src/configs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ def get_args(mode: Mode) -> argparse.Namespace:
parser.add_argument("--llm_input_len", type=int, default=2048, help="LLM Input Sequence Length")
parser.add_argument("--min_ecg_tokens_len", type=int, default=512, help="Minimum ECG token length to consider")
parser.add_argument("--norm_eps", type=float, default=1e-6, help="Please choose the normalization epsilon")
if mode in {"eval", "inference"}:
parser.add_argument("--eval_batch_size", type=int, default=1, help="Batch size for batched generation during eval/inference")

if mode == "train":
parser.add_argument("--optimizer", type=str, default="adam", choices=["adam", "adamw", "muon"], help="Optimizer type")
parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
Expand Down
4 changes: 3 additions & 1 deletion src/dataloaders/build_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,12 @@ def build_torch_dataloader(self, torch_dataset):
elif "eval" in self.args.mode:
torch_data_loader = DataLoader(
torch_dataset,
batch_size=1, # batched inference/eval not implemented
batch_size=1,
shuffle=False,
num_workers=self.args.num_workers,
pin_memory=torch.cuda.is_available(),
collate_fn=self.collate_fn,
persistent_workers=(self.args.num_workers > 0),
)
return torch_data_loader

Expand Down
4 changes: 2 additions & 2 deletions src/elms/connectors/linear_proj.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def __init__(self, projection_dim, llm_id):
self.projection = nn.Linear(projection_dim,
HF_LLMS[llm_id]["model_hidden_size"]).to(dtype=self.input_dtype)
def forward(self, ecg_signal):
return self.projection(ecg_signal.to(dtype=self.input_dtype))
return self.projection(ecg_signal.to(dtype=self.projection.weight.dtype))

def project(self, signal_embeds):
return self.projection(signal_embeds.to(dtype=self.input_dtype))
return self.projection(signal_embeds.to(dtype=self.projection.weight.dtype))
46 changes: 31 additions & 15 deletions src/main_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pathlib import Path

from configs.config import get_args
from utils.gpu_manager import GPUSetup
from utils.gpu_manager import GPUSetup, init_dist, cleanup, is_main
from utils.seed_manager import set_seed
from dataloaders.build_dataloader import BuildDataLoader
from elms.build_elm import BuildELM
Expand All @@ -18,6 +18,10 @@ def main():
mode = "eval"
args = get_args(mode)
args.mode = mode

if getattr(args, "distributed", False):
init_dist()

# folds = ["1", "2", "3", "4", "5"]
# seeds = [1337, 1338, 1339, 1340, 1341]
folds = ["1"]
Expand All @@ -33,9 +37,12 @@ def main():
sys_prompt_name = Path(args.system_prompt).stem
data_name = "_".join(args.data)
results_file = os.path.join(checkpoint_dir, f"{ckpt_file_name}_{data_name}_{sys_prompt_name}_{args.perturb}.json")
debug_path = results_file.replace(".json", "_debug.txt")
debug_file = open(debug_path, "w") if is_main() else None
for fold in folds:
for seed in seeds:
print(f"Evaluating fold {fold} with seed {seed}")
if is_main():
print(f"Evaluating fold {fold} with seed {seed}")
args.fold = fold
args.seed = seed
set_seed(args.seed)
Expand All @@ -47,26 +54,35 @@ def main():
elm = gpu_setup.setup_gpu(elm_components["elm"], elm_components["find_unused_parameters"])
if args.dev:
gpu_setup.print_model_device(elm, f"{args.llm}_{args.encoder}")
out = evaluate(elm, dataloader, args)
out = evaluate(elm, dataloader, args, debug_file=debug_file)
all_metrics.append(out)
if len(all_metrics) == 1:
if is_main() and len(all_metrics) == 1:
examples_path = results_file.replace(".json", "_examples.json")
examples = [{"prompt": p, "predicted": h, "ground_truth": r}
for p, h, r in zip(out["prompts"], out["hypotheses"], out["references"])]
with open(examples_path, "w") as ef:
json.dump(examples, ef, indent=2)
print(f"Saved {len(examples)} eval examples to {examples_path}")
if "confusion_matrix" in out:
cm_path = results_file.replace(".json", f"{fold}_{seed}.png")
save_confusion_matrix_png(out["confusion_matrix"], cm_path)
other_path = results_file.replace(".json", f"{fold}_{seed}_other.png")
save_other_outputs_histogram_png(out["other_output_counts"], other_path, top_k = 10)
incorrect_path = results_file.replace(".json", f"{fold}_{seed}_incorrect.png")
save_incorrect_predictions_histogram_png(out["references"], out["hypotheses"], incorrect_path)
statistical_results = run_statistical_analysis(all_metrics)
with open(results_file, "w") as f:
json.dump(statistical_results, f, indent=2)
print(f"Saved evaluation results to {results_file}")
if is_main():
if "confusion_matrix" in out:
cm_path = results_file.replace(".json", f"{fold}_{seed}.png")
save_confusion_matrix_png(out["confusion_matrix"], cm_path)
other_path = results_file.replace(".json", f"{fold}_{seed}_other.png")
save_other_outputs_histogram_png(out["other_output_counts"], other_path, top_k = 10)
incorrect_path = results_file.replace(".json", f"{fold}_{seed}_incorrect.png")
save_incorrect_predictions_histogram_png(out["references"], out["hypotheses"], incorrect_path)
if debug_file is not None:
debug_file.close()
if is_main():
print(f"Saved debug dump to {debug_path}")
if is_main():
statistical_results = run_statistical_analysis(all_metrics)
with open(results_file, "w") as f:
json.dump(statistical_results, f, indent=2)
print(f"Saved evaluation results to {results_file}")

if getattr(args, "distributed", False):
cleanup()


if __name__ == "__main__":
Expand Down
Loading