
Commit 9e5a827

Merge pull request #8 from ChEB-AI/features-sfluegel
Minor fixes
2 parents: 0ccd7fb + bf7d98b · commit 9e5a827

File tree

7 files changed: +26 additions, -18 deletions

.pre-commit-config.yaml
chebai/loss/__init__.py
chebai/molecule.py
chebai/preprocessing/datasets/chebi.py
chebai/result/classification.py
configs/weightings/chebi50_v227.yml
setup.py


.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
@@ -4,6 +4,6 @@ repos:
 #    hooks:
 #    - id: isort
 - repo: https://github.com/psf/black
-  rev: "22.10.0"
+  rev: "24.2.0"
   hooks:
   - id: black
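
For context: bumping black from 22.10.0 to 24.2.0 switches the hook to black's newer stable style, which prefers wrapping a long assignment's right-hand side in parentheses instead of splitting the subscript on the left-hand side. That is what produces the purely cosmetic hunks in chebai/molecule.py and chebai/preprocessing/datasets/chebi.py below. A minimal sketch (not from the commit) that reproduces the difference through black's Python API, assuming both black versions are available to try:

# Sketch only: run one of the affected lines, in its original context, through
# black. With black==24.2.0 this prints the parenthesized right-hand side shown
# in the chebai/molecule.py hunk below; with black==22.10.0 it prints the
# split-subscript form that is being replaced. The class name is a dummy.
import black

SRC = '''\
class _Dummy:
    def create_feature_vectors(self):
        for idx in range(self.no_of_atoms):
            self.local_input_vector[idx, idx, :length_of_atom_features] = self.get_atom_features(idx)
'''

print(black.format_str(SRC, mode=black.Mode()))  # line length defaults to 88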

chebai/loss/__init__.py

Whitespace-only changes.

chebai/molecule.py

Lines changed: 6 additions & 6 deletions
@@ -87,18 +87,18 @@ def create_feature_vectors(self):
         for idx in range(self.no_of_atoms):
             sorted_path = self.directed_graphs[idx, :, :]
 
-            self.local_input_vector[
-                idx, idx, :length_of_atom_features
-            ] = self.get_atom_features(idx)
+            self.local_input_vector[idx, idx, :length_of_atom_features] = (
+                self.get_atom_features(idx)
+            )
 
             no_of_incoming_edges = {}
             for i in range(self.no_of_atoms - 1):
                 node1 = sorted_path[i, 0]
                 node2 = sorted_path[i, 1]
 
-                self.local_input_vector[
-                    idx, node1, :length_of_atom_features
-                ] = self.get_atom_features(node1)
+                self.local_input_vector[idx, node1, :length_of_atom_features] = (
+                    self.get_atom_features(node1)
+                )
 
                 if node2 in no_of_incoming_edges:
                     index = no_of_incoming_edges[node2]

chebai/preprocessing/datasets/chebi.py

Lines changed: 8 additions & 8 deletions
@@ -9,11 +9,11 @@
     "JCI_500_COLUMNS_INT",
 ]
 
-import queue
 from abc import ABC
 from collections import OrderedDict
 import os
 import pickle
+import queue
 
 from iterstrat.ml_stratifiers import (
     MultilabelStratifiedKFold,
@@ -219,7 +219,7 @@ def setup_processed(self):
         self._setup_pruned_test_set()
 
     def get_test_split(self, df: pd.DataFrame):
-        print("Split dataset into train (including val) / test")
+        print("Get test data split")
 
         df_list = df.values.tolist()
         df_list = [row[3:] for row in df_list]
@@ -247,8 +247,8 @@ def get_train_val_splits_given_test(self, df: pd.DataFrame, test_df: pd.DataFram
         print(f"Split dataset into train / val with given test set")
 
         df_trainval = df
-        test_smiles = test_df["SMILES"].tolist()
-        mask = [smiles not in test_smiles for smiles in df_trainval["SMILES"]]
+        test_ids = test_df["id"].tolist()
+        mask = [trainval_id not in test_ids for trainval_id in df_trainval["id"]]
         df_trainval = df_trainval[mask]
         df_trainval_list = df_trainval.values.tolist()
         df_trainval_list = [row[3:] for row in df_trainval_list]
@@ -265,9 +265,9 @@ def get_train_val_splits_given_test(self, df: pd.DataFrame, test_df: pd.DataFram
             df_validation = df_trainval.iloc[val_ids]
             df_train = df_trainval.iloc[train_ids]
             folds[self.raw_file_names_dict[f"fold_{fold}_train"]] = df_train
-            folds[
-                self.raw_file_names_dict[f"fold_{fold}_validation"]
-            ] = df_validation
+            folds[self.raw_file_names_dict[f"fold_{fold}_validation"]] = (
+                df_validation
+            )
 
         return folds
 
@@ -513,7 +513,7 @@ def extract_class_hierarchy(self, chebi_path):
         g.add_edges_from([(p, q["id"]) for q in elements for p in q["parents"]])
 
         g = nx.transitive_closure_dag(g)
-        g = g.subgraph(nx.descendants(g, self.top_class_id))
+        g = g.subgraph(list(nx.descendants(g, self.top_class_id)) + [self.top_class_id])
         print("Compute transitive closure")
         return g
 
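
Two of these hunks change behaviour rather than formatting: the train/val filter now excludes test entries by their "id" column instead of by SMILES string, and the hierarchy pruning now keeps the top class itself. The latter matters because nx.descendants(g, n) returns only the nodes reachable from n, never n itself, so the old subgraph call silently dropped the root of the hierarchy. A minimal sketch on a toy graph (node names are made up):

# Sketch (toy graph, hypothetical node ids): nx.descendants excludes the source
# node, so the top class must be appended before taking the subgraph.
import networkx as nx

g = nx.DiGraph()
g.add_edges_from([("TOP", "A"), ("A", "B"), ("TOP", "C")])

descendants = nx.descendants(g, "TOP")
print(descendants)  # {'A', 'B', 'C'} -- 'TOP' itself is not included

pruned_old = g.subgraph(descendants)                  # root is lost
pruned_new = g.subgraph(list(descendants) + ["TOP"])  # root is kept
print("TOP" in pruned_old, "TOP" in pruned_new)       # False True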

chebai/result/classification.py

Lines changed: 5 additions & 3 deletions
@@ -60,7 +60,7 @@ def evaluate_model(
         collated.x = collated.to_x(model.device)
         collated.y = collated.to_y(model.device)
         processable_data = model._process_batch(collated, 0)
-        model_output = model(processable_data)
+        model_output = model(processable_data, **processable_data["model_kwargs"])
         preds, labels = model._get_prediction_and_labels(
             processable_data, processable_data["labels"], model_output
         )
@@ -166,6 +166,8 @@ def print_metrics(preds, labels, device, classes=None, top_k=10, markdown_output
 
     zeros = []
    for i, f1 in enumerate(classwise_f1):
-        if f1 == 0.0 and torch.sum(labels[:, i]):
+        if f1 == 0.0 and torch.sum(labels[:, i]) != 0:
             zeros.append(f"{classes[i] if classes is not None else i}")
-    print(f'Classes with F1-score == 0 (and non-zero labels): {", ".join(zeros)}')
+    print(
+        f'Found {len(zeros)} classes with F1-score == 0 (and non-zero labels): {", ".join(zeros)}'
+    )
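
For context, print_metrics flags classes that have positive labels in the evaluation data but still end up with an F1 score of zero; the commit only makes the non-zero-label check explicit and adds a count to the message. A small self-contained sketch of that reporting pattern, assuming torchmetrics' per-class F1 (the actual module may compute classwise_f1 differently):

# Sketch (assumed inputs: multilabel probabilities and 0/1 labels; chebai's own
# metric computation may differ).
import torch
from torchmetrics.classification import MultilabelF1Score

preds = torch.tensor([[0.9, 0.1, 0.2, 0.8],
                      [0.8, 0.2, 0.1, 0.3]])
labels = torch.tensor([[1, 0, 1, 1],
                       [1, 0, 1, 0]])

# average=None yields one F1 score per class.
classwise_f1 = MultilabelF1Score(num_labels=4, average=None)(preds, labels)

zeros = []
for i, f1 in enumerate(classwise_f1):
    # Only flag classes that actually occur in the labels.
    if f1 == 0.0 and torch.sum(labels[:, i]) != 0:
        zeros.append(str(i))

print(
    f'Found {len(zeros)} classes with F1-score == 0 (and non-zero labels): {", ".join(zeros)}'
)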

configs/weightings/chebi50_v227.yml

Lines changed: 3 additions & 0 deletions
Large diffs are not rendered by default.

setup.py

Lines changed: 3 additions & 0 deletions
@@ -42,6 +42,9 @@
         "lightning",
         "jsonargparse[signatures]>=4.17.0",
         "omegaconf",
+        "seaborn",
+        "deepsmiles",
+        "iterative-stratification",
     ],
     extras_require={"dev": ["black", "isort", "pre-commit"]},
 )
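
Of the three new runtime dependencies, iterative-stratification is the one tied to the code changes above: it provides the iterstrat package that chebai/preprocessing/datasets/chebi.py imports for its multilabel-stratified train/val folds. A minimal usage sketch with toy data (not the dataset module's actual splitting code):

# Sketch (toy data): multilabel-stratified k-fold splitting via the
# iterative-stratification package (import name: iterstrat).
import numpy as np
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold

X = np.arange(8).reshape(8, 1)                   # 8 samples, one dummy feature
y = np.array([[1, 0], [1, 0], [0, 1], [0, 1],    # two binary labels per sample
              [1, 1], [1, 1], [0, 0], [1, 0]])

mskf = MultilabelStratifiedKFold(n_splits=4, shuffle=True, random_state=0)
for fold, (train_idx, val_idx) in enumerate(mskf.split(X, y)):
    # Each fold keeps the per-label positive rate roughly balanced.
    print(f"fold {fold}: train={train_idx.tolist()} val={val_idx.tolist()}")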

0 commit comments
