diff --git a/src/alignrl/dpo.py b/src/alignrl/dpo.py
index 6c3bc53..82b68ba 100644
--- a/src/alignrl/dpo.py
+++ b/src/alignrl/dpo.py
@@ -106,6 +106,8 @@ def train(self) -> TrainResult:
         trainer.save_model(str(output_dir / "final"))
 
         loss_history = [log["loss"] for log in trainer.state.log_history if "loss" in log]
+        if not loss_history:
+            loss_history = [result.training_loss]
 
         train_result = TrainResult(
             output_dir=output_dir / "final",
diff --git a/src/alignrl/grpo.py b/src/alignrl/grpo.py
index bc2aa0d..1921161 100644
--- a/src/alignrl/grpo.py
+++ b/src/alignrl/grpo.py
@@ -135,6 +135,8 @@ def train(self) -> TrainResult:
         trainer.save_model(str(output_dir / "final"))
 
         loss_history = [log["loss"] for log in trainer.state.log_history if "loss" in log]
+        if not loss_history:
+            loss_history = [result.training_loss]
         reward_history = [
             log.get("reward", 0.0) for log in trainer.state.log_history if "reward" in log
         ]
diff --git a/src/alignrl/sft.py b/src/alignrl/sft.py
index 321a337..073e6b6 100644
--- a/src/alignrl/sft.py
+++ b/src/alignrl/sft.py
@@ -113,6 +113,8 @@ def train(self) -> TrainResult:
         trainer.save_model(str(output_dir / "final"))
 
         loss_history = [log["loss"] for log in trainer.state.log_history if "loss" in log]
+        if not loss_history:
+            loss_history = [result.training_loss]
 
         train_result = TrainResult(
             output_dir=output_dir / "final",