Skip to content

Commit e9de192

Browse files
author
sfluegel
committed
reformat using black
1 parent 85ce29f commit e9de192

File tree

3 files changed

+22
-8
lines changed

3 files changed

+22
-8
lines changed

chebai/callbacks/prediction_callback.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,23 @@
33
import os
44
import pickle
55

6+
67
class PredictionWriter(BasePredictionWriter):
78
def __init__(self, output_dir, write_interval):
89
super().__init__(write_interval)
910
self.output_dir = output_dir
1011
self.prediction_file_name = "predictions.pkl"
1112

1213
def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
13-
results = [dict(ident=row["data"]["idents"][0], predictions=torch.sigmoid(row["output"]["logits"]).numpy(),
14-
labels=row["labels"][0].numpy() if row["labels"] is not None else None) for row in predictions]
15-
with open(os.path.join(self.output_dir, self.prediction_file_name), "wb") as fout:
14+
results = [
15+
dict(
16+
ident=row["data"]["idents"][0],
17+
predictions=torch.sigmoid(row["output"]["logits"]).numpy(),
18+
labels=row["labels"][0].numpy() if row["labels"] is not None else None,
19+
)
20+
for row in predictions
21+
]
22+
with open(
23+
os.path.join(self.output_dir, self.prediction_file_name), "wb"
24+
) as fout:
1625
pickle.dump(results, fout)

chebai/loss/bce_weighted.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,15 @@ def __init__(self, beta: float = None, data_extractor: _ChEBIDataExtractor = Non
1515
self.data_extractor = data_extractor
1616
super().__init__()
1717

18-
1918
def set_pos_weight(self, input):
2019
if (
2120
self.beta is not None
2221
and self.data_extractor is not None
2322
and all(
2423
os.path.exists(os.path.join(self.data_extractor.raw_dir, raw_file))
2524
for raw_file in self.data_extractor.raw_file_names
26-
) and self.pos_weight is None
25+
)
26+
and self.pos_weight is None
2727
):
2828
complete_data = pd.concat(
2929
[
@@ -42,9 +42,13 @@ def set_pos_weight(self, input):
4242
value_counts = []
4343
for c in complete_data.columns[3:]:
4444
value_counts.append(len([v for v in complete_data[c] if v]))
45-
weights = [(1 - self.beta) / (1 - pow(self.beta, value)) for value in value_counts]
45+
weights = [
46+
(1 - self.beta) / (1 - pow(self.beta, value)) for value in value_counts
47+
]
4648
mean = sum(weights) / len(weights)
47-
self.pos_weight = torch.tensor([w / mean for w in weights], device=input.device)
49+
self.pos_weight = torch.tensor(
50+
[w / mean for w in weights], device=input.device
51+
)
4852

4953
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
5054
self.set_pos_weight(input)

chebai/models/electra.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818

1919
logging.getLogger("pysmiles").setLevel(logging.CRITICAL)
2020

21-
from chebai.loss.semantic import DisjointLoss as ElectraChEBIDisjointLoss # noqa
21+
from chebai.loss.semantic import DisjointLoss as ElectraChEBIDisjointLoss # noqa
22+
2223

2324
class ElectraPre(ChebaiBaseNet):
2425
NAME = "ElectraPre"

0 commit comments

Comments
 (0)