Skip to content

Commit 678ae19

Browse files
authored
Merge pull request #132 from ChEB-AI/fix/weighted_bce_loss
Fix for weighted BCE loss
2 parents 97839d9 + a86e0f7 commit 678ae19

File tree

6 files changed

+34
-16
lines changed

6 files changed

+34
-16
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,3 +176,7 @@ lightning_logs
176176
logs
177177
.isort.cfg
178178
/.vscode
179+
180+
*.out
181+
*.err
182+
*.sh

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ The `classes_path` is the path to the dataset's `raw/classes.txt` file that cont
8181
You can evaluate a model trained on the ontology extension task in one of two ways:
8282

8383
### 1. Using the Jupyter Notebook
84-
An example notebook is provided at `tutorials/eval_model_basic.ipynb`.
84+
An example notebook is provided at `tutorials/eval_model_basic.ipynb`.
8585
- Load your finetuned model and run the evaluation cells to compute metrics on the test set.
8686

8787
### 2. Using the Lightning CLI

chebai/cli.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,10 @@ def call_data_methods(data: Type[XYBaseDataModule]):
7070
"data.num_of_labels", "trainer.callbacks.init_args.num_labels"
7171
)
7272

73+
parser.link_arguments(
74+
"data", "model.init_args.criterion.init_args.data_extractor"
75+
)
76+
7377
@staticmethod
7478
def subcommands() -> Dict[str, Set[str]]:
7579
"""

chebai/loss/bce_weighted.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
class BCEWeighted(torch.nn.BCEWithLogitsLoss):
1111
"""
1212
BCEWithLogitsLoss with weights automatically computed according to the beta parameter.
13-
If beta is None or data_extractor is None, the loss is unweighted.
1413
1514
This class computes weights based on the formula from the paper by Cui et al. (2019):
1615
https://openaccess.thecvf.com/content_CVPR_2019/papers/Cui_Class-Balanced_Loss_Based_on_Effective_Number_of_Samples_CVPR_2019_paper.pdf
@@ -22,7 +21,7 @@ class BCEWeighted(torch.nn.BCEWithLogitsLoss):
2221

2322
def __init__(
2423
self,
25-
beta: Optional[float] = None,
24+
beta: float = 0.99,
2625
data_extractor: Optional[XYBaseDataModule] = None,
2726
**kwargs,
2827
):
@@ -32,11 +31,26 @@ def __init__(
3231
if isinstance(data_extractor, LabeledUnlabeledMixed):
3332
data_extractor = data_extractor.labeled
3433
self.data_extractor = data_extractor
34+
35+
assert (
36+
isinstance(beta, float) and beta > 0.0
37+
), f"Beta parameter must be a float with value greater than 0.0, for loss class {self.__class__.__name__}."
38+
39+
assert (
40+
self.data_extractor is not None
41+
), f"Data extractor must be provided if this loss class ({self.__class__.__name__}) is used."
42+
43+
assert all(
44+
os.path.exists(os.path.join(self.data_extractor.processed_dir, file_name))
45+
for file_name in self.data_extractor.processed_file_names
46+
), "Dataset files not found. Make sure the dataset is processed before using this loss."
47+
3548
assert (
3649
isinstance(self.data_extractor, _ChEBIDataExtractor)
3750
or self.data_extractor is None
3851
)
3952
super().__init__(**kwargs)
53+
self.pos_weight: Optional[torch.Tensor] = None
4054

4155
def set_pos_weight(self, input: torch.Tensor) -> None:
4256
"""
@@ -45,17 +59,7 @@ def set_pos_weight(self, input: torch.Tensor) -> None:
4559
Args:
4660
input (torch.Tensor): The input tensor for which to set the positive weights.
4761
"""
48-
if (
49-
self.beta is not None
50-
and self.data_extractor is not None
51-
and all(
52-
os.path.exists(
53-
os.path.join(self.data_extractor.processed_dir, file_name)
54-
)
55-
for file_name in self.data_extractor.processed_file_names
56-
)
57-
and self.pos_weight is None
58-
):
62+
if self.pos_weight is None:
5963
print(
6064
f"Computing loss-weights based on v{self.data_extractor.chebi_version} dataset (beta={self.beta})"
6165
)
@@ -96,3 +100,9 @@ def forward(
96100
"""
97101
self.set_pos_weight(input)
98102
return super().forward(input, target)
103+
104+
105+
class UnWeightedBCEWithLogitsLoss(torch.nn.BCEWithLogitsLoss):
106+
def forward(self, input, target, **kwargs):
107+
# As the custom passed kwargs are not used in BCEWithLogitsLoss, we can ignore them
108+
return super().forward(input, target)

tests/unit/cli/bce_loss.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
class_path: chebai.loss.bce_weighted.UnWeightedBCEWithLogitsLoss

tests/unit/cli/testCLI.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@ def setUp(self):
1515
"--model.pass_loss_kwargs=false",
1616
"--trainer.min_epochs=1",
1717
"--trainer.max_epochs=1",
18-
"--model.criterion=configs/loss/bce.yml",
19-
"--model.criterion.init_args.beta=0.99",
18+
"--model.criterion=tests/unit/cli/bce_loss.yml",
2019
]
2120

2221
def test_mlp_on_chebai_cli(self):

0 commit comments

Comments
 (0)