@@ -66,6 +66,7 @@ def __init__(
6666 "sum" , "max" , "mean" , "log-sum" , "log-max" , "log-mean"
6767 ] = "sum" ,
6868 multiply_with_base_loss : bool = True ,
69+ no_grads : bool = False ,
6970 ):
7071 super ().__init__ ()
7172 # automatically choose labeled subset for implication filter in case of mixed dataset
@@ -103,6 +104,7 @@ def __init__(
103104 self .start_at_epoch = start_at_epoch
104105 self .violations_per_cls_aggregator = violations_per_cls_aggregator
105106 self .multiply_with_base_loss = multiply_with_base_loss
107+ self .no_grads = no_grads
106108
107109 def _calculate_unaggregated_fuzzy_loss (
108110 self ,
@@ -214,6 +216,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple:
214216 ** kwargs ,
215217 )
216218 )
219+ if self .no_grads :
220+ fuzzy_loss = fuzzy_loss .detach ()
217221 loss_components ["unweighted_fuzzy_loss" ] = unweighted_fuzzy_mean
218222 loss_components ["weighted_fuzzy_loss" ] = weighted_fuzzy_mean
219223 if self .base_loss is None or target is None :
@@ -374,6 +378,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple:
374378 ** kwargs ,
375379 )
376380 )
381+ if self .no_grads :
382+ impl_loss = impl_loss .detach ()
377383 loss_components ["unweighted_implication_loss" ] = unweighted_impl_mean
378384 loss_components ["weighted_implication_loss" ] = weighted_impl_mean
379385
@@ -388,6 +394,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple:
388394 ** kwargs ,
389395 )
390396 )
397+ if self .no_grads :
398+ disj_loss = disj_loss .detach ()
391399 loss_components ["unweighted_disjointness_loss" ] = unweighted_disj_mean
392400 loss_components ["weighted_disjointness_loss" ] = weighted_disj_mean
393401