From 5d8315bf4bb267482f7f2ef37ec7b7e420278844 Mon Sep 17 00:00:00 2001 From: Samanvya Tripathi Date: Wed, 25 Mar 2026 21:31:59 -0400 Subject: [PATCH] fix(runners): guard against empty loss_history in all trainers When max_steps < logging_steps, loss_history is empty, causing IndexError in notebooks and CLI. Now falls back to [training_loss] from TRL's result object when no log entries contain loss. Fixes #35 --- src/alignrl/dpo.py | 2 ++ src/alignrl/grpo.py | 2 ++ src/alignrl/sft.py | 2 ++ 3 files changed, 6 insertions(+) diff --git a/src/alignrl/dpo.py b/src/alignrl/dpo.py index 6c3bc53..82b68ba 100644 --- a/src/alignrl/dpo.py +++ b/src/alignrl/dpo.py @@ -106,6 +106,8 @@ def train(self) -> TrainResult: trainer.save_model(str(output_dir / "final")) loss_history = [log["loss"] for log in trainer.state.log_history if "loss" in log] + if not loss_history: + loss_history = [result.training_loss] train_result = TrainResult( output_dir=output_dir / "final", diff --git a/src/alignrl/grpo.py b/src/alignrl/grpo.py index bc2aa0d..1921161 100644 --- a/src/alignrl/grpo.py +++ b/src/alignrl/grpo.py @@ -135,6 +135,8 @@ def train(self) -> TrainResult: trainer.save_model(str(output_dir / "final")) loss_history = [log["loss"] for log in trainer.state.log_history if "loss" in log] + if not loss_history: + loss_history = [result.training_loss] reward_history = [ log.get("reward", 0.0) for log in trainer.state.log_history if "reward" in log ] diff --git a/src/alignrl/sft.py b/src/alignrl/sft.py index 321a337..073e6b6 100644 --- a/src/alignrl/sft.py +++ b/src/alignrl/sft.py @@ -113,6 +113,8 @@ def train(self) -> TrainResult: trainer.save_model(str(output_dir / "final")) loss_history = [log["loss"] for log in trainer.state.log_history if "loss" in log] + if not loss_history: + loss_history = [result.training_loss] train_result = TrainResult( output_dir=output_dir / "final",