From 04fb3d14bcf76809999d5999da59ebb18b58cd1b Mon Sep 17 00:00:00 2001 From: Josh Veitch-Michaelis Date: Fri, 2 Jan 2026 02:49:23 +0000 Subject: [PATCH] detach losses when logging --- src/deepforest/main.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/deepforest/main.py b/src/deepforest/main.py index 0c6fa9b1a..2a7bb0660 100644 --- a/src/deepforest/main.py +++ b/src/deepforest/main.py @@ -676,10 +676,12 @@ def training_step(self, batch, batch_idx): # Log loss for key, value in loss_dict.items(): - self.log(f"train_{key}", value, on_epoch=True, batch_size=len(images)) + self.log( + f"train_{key}", value.detach(), on_epoch=True, batch_size=len(images) + ) # Log sum of losses - self.log("train_loss", losses, on_epoch=True, batch_size=len(images)) + self.log("train_loss", losses.detach(), on_epoch=True, batch_size=len(images)) return losses @@ -699,9 +701,11 @@ def validation_step(self, batch, batch_idx): # Log losses try: for key, value in loss_dict.items(): - self.log(f"val_{key}", value, on_epoch=True, batch_size=len(images)) + self.log( + f"val_{key}", value.detach(), on_epoch=True, batch_size=len(images) + ) - self.log("val_loss", losses, on_epoch=True, batch_size=len(images)) + self.log("val_loss", losses.detach(), on_epoch=True, batch_size=len(images)) except MisconfigurationException: pass