1010class BCEWeighted (torch .nn .BCEWithLogitsLoss ):
1111 """
1212 BCEWithLogitsLoss with weights automatically computed according to the beta parameter.
13- If beta is None or data_extractor is None, the loss is unweighted.
1413
1514 This class computes weights based on the formula from the paper by Cui et al. (2019):
1615 https://openaccess.thecvf.com/content_CVPR_2019/papers/Cui_Class-Balanced_Loss_Based_on_Effective_Number_of_Samples_CVPR_2019_paper.pdf
@@ -22,7 +21,7 @@ class BCEWeighted(torch.nn.BCEWithLogitsLoss):
2221
2322 def __init__ (
2423 self ,
25- beta : Optional [ float ] = None ,
24+ beta : float = 0.99 ,
2625 data_extractor : Optional [XYBaseDataModule ] = None ,
2726 ** kwargs ,
2827 ):
@@ -32,11 +31,26 @@ def __init__(
3231 if isinstance (data_extractor , LabeledUnlabeledMixed ):
3332 data_extractor = data_extractor .labeled
3433 self .data_extractor = data_extractor
34+
35+ assert (
36+ isinstance (beta , float ) and beta > 0.0
37+ ), f"Beta parameter must be a float with value greater than 0.0, for loss class { self .__class__ .__name__ } ."
38+
39+ assert (
40+ self .data_extractor is not None
41+ ), f"Data extractor must be provided if this loss class ({ self .__class__ .__name__ } ) is used."
42+
43+ assert all (
44+ os .path .exists (os .path .join (self .data_extractor .processed_dir , file_name ))
45+ for file_name in self .data_extractor .processed_file_names
46+ ), "Dataset files not found. Make sure the dataset is processed before using this loss."
47+
3548 assert (
3649 isinstance (self .data_extractor , _ChEBIDataExtractor )
3750 or self .data_extractor is None
3851 )
3952 super ().__init__ (** kwargs )
53+ self .pos_weight : Optional [torch .Tensor ] = None
4054
4155 def set_pos_weight (self , input : torch .Tensor ) -> None :
4256 """
@@ -45,17 +59,7 @@ def set_pos_weight(self, input: torch.Tensor) -> None:
4559 Args:
4660 input (torch.Tensor): The input tensor for which to set the positive weights.
4761 """
48- if (
49- self .beta is not None
50- and self .data_extractor is not None
51- and all (
52- os .path .exists (
53- os .path .join (self .data_extractor .processed_dir , file_name )
54- )
55- for file_name in self .data_extractor .processed_file_names
56- )
57- and self .pos_weight is None
58- ):
62+ if self .pos_weight is None :
5963 print (
6064 f"Computing loss-weights based on v{ self .data_extractor .chebi_version } dataset (beta={ self .beta } )"
6165 )
@@ -96,3 +100,9 @@ def forward(
96100 """
97101 self .set_pos_weight (input )
98102 return super ().forward (input , target )
103+
104+
105+ class UnWeightedBCEWithLogitsLoss (torch .nn .BCEWithLogitsLoss ):
106+ def forward (self , input , target , ** kwargs ):
107+ # As the custom passed kwargs are not used in BCEWithLogitsLoss, we can ignore them
108+ return super ().forward (input , target )
0 commit comments