Skip to content

Commit 62a49eb

Browse files
author
sfluegel
committed
add optional detach from gradients
1 parent 45c54f4 commit 62a49eb

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

chebai/loss/semantic.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def __init__(
6666
"sum", "max", "mean", "log-sum", "log-max", "log-mean"
6767
] = "sum",
6868
multiply_with_base_loss: bool = True,
69+
no_grads: bool = False,
6970
):
7071
super().__init__()
7172
# automatically choose labeled subset for implication filter in case of mixed dataset
@@ -103,6 +104,7 @@ def __init__(
103104
self.start_at_epoch = start_at_epoch
104105
self.violations_per_cls_aggregator = violations_per_cls_aggregator
105106
self.multiply_with_base_loss = multiply_with_base_loss
107+
self.no_grads = no_grads
106108

107109
def _calculate_unaggregated_fuzzy_loss(
108110
self,
@@ -214,6 +216,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple:
214216
**kwargs,
215217
)
216218
)
219+
if self.no_grads:
220+
fuzzy_loss = fuzzy_loss.detach()
217221
loss_components["unweighted_fuzzy_loss"] = unweighted_fuzzy_mean
218222
loss_components["weighted_fuzzy_loss"] = weighted_fuzzy_mean
219223
if self.base_loss is None or target is None:
@@ -374,6 +378,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple:
374378
**kwargs,
375379
)
376380
)
381+
if self.no_grads:
382+
impl_loss = impl_loss.detach()
377383
loss_components["unweighted_implication_loss"] = unweighted_impl_mean
378384
loss_components["weighted_implication_loss"] = weighted_impl_mean
379385

@@ -388,6 +394,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple:
388394
**kwargs,
389395
)
390396
)
397+
if self.no_grads:
398+
disj_loss = disj_loss.detach()
391399
loss_components["unweighted_disjointness_loss"] = unweighted_disj_mean
392400
loss_components["weighted_disjointness_loss"] = weighted_disj_mean
393401

0 commit comments

Comments
 (0)