diff --git a/src/deepforest/main.py b/src/deepforest/main.py index 0c6fa9b1a..2a7bb0660 100644 --- a/src/deepforest/main.py +++ b/src/deepforest/main.py @@ -676,10 +676,12 @@ def training_step(self, batch, batch_idx): # Log loss for key, value in loss_dict.items(): - self.log(f"train_{key}", value, on_epoch=True, batch_size=len(images)) + self.log( + f"train_{key}", value.detach(), on_epoch=True, batch_size=len(images) + ) # Log sum of losses - self.log("train_loss", losses, on_epoch=True, batch_size=len(images)) + self.log("train_loss", losses.detach(), on_epoch=True, batch_size=len(images)) return losses @@ -699,9 +701,11 @@ def validation_step(self, batch, batch_idx): # Log losses try: for key, value in loss_dict.items(): - self.log(f"val_{key}", value, on_epoch=True, batch_size=len(images)) + self.log( + f"val_{key}", value.detach(), on_epoch=True, batch_size=len(images) + ) - self.log("val_loss", losses, on_epoch=True, batch_size=len(images)) + self.log("val_loss", losses.detach(), on_epoch=True, batch_size=len(images)) except MisconfigurationException: pass