
Commit 04481e4

Implement Tox21 dataset as used in the competition

1 parent 497d614

5 files changed, +151 -52 lines

chebai/cli.py

Lines changed: 3 additions & 2 deletions
@@ -15,11 +15,12 @@ def cli():
 @click.argument("experiment")
 @click.argument("batch_size", type=click.INT)
 @click.option("-g", "--group", default="default")
+@click.option("--version", default=None)
 @click.argument("args", nargs=-1)
-def train(experiment, batch_size, group, args):
+def train(experiment, batch_size, group, version, args):
     """Run experiment identified by EXPERIMENT in batches of size BATCH_SIZE."""
     try:
-        ex = experiments.EXPERIMENTS[experiment](batch_size, group)
+        ex = experiments.EXPERIMENTS[experiment](batch_size, group, version=version)
     except KeyError:
         raise Exception(
             "Experiment ID not found. The following are available:"

chebai/experiments.py

Lines changed: 15 additions & 9 deletions
@@ -25,8 +25,9 @@ def __init_subclass__(cls, **kwargs):
         if cls.identifier() is not None:
             EXPERIMENTS[cls.identifier()] = cls

-    def __init__(self, batch_size, *args, **kwargs):
+    def __init__(self, batch_size, *args, version=None, **kwargs):
         self.dataset = self.build_dataset(batch_size)
+        self.version=version

     @classmethod
     def identifier(cls) -> str:
@@ -46,6 +47,7 @@ def train(self, batch_size, *args):
             self.MODEL.NAME,
             loss=self.LOSS,
             model_kwargs=self.model_kwargs(*args),
+            version=self.version
         )

     def test(self, ckpt_path, *args):
@@ -264,28 +266,32 @@ def model_kwargs(self, *args) -> Dict:
         return d


-class ElectraOnTox21Bloat(ElectraOnTox21):
+class ElectraOnTox21MoleculeNet(_ElectraExperiment):
     MODEL = electra.Electra
     LOSS = torch.nn.BCEWithLogitsLoss

     @classmethod
     def identifier(cls) -> str:
-        return "Electra+Tox21Bloat"
+        return "Electra+Tox21MN"

     def build_dataset(self, batch_size) -> datasets.XYBaseDataModule:
-        return datasets.Tox21BloatChem(batch_size)
+        return datasets.Tox21MolNetChem(batch_size)

+    def model_kwargs(self, *args) -> Dict:
+        d = super().model_kwargs(*args)
+        d["config"]["hidden_dropout_prob"] = 0.4
+        d["config"]["word_dropout"] = 0.2
+        d["optimizer_kwargs"]["weight_decay"] = 1e-4
+        return d

-class ElectraOnTox21Ext(_ElectraExperiment):
-    MODEL = electra.Electra
-    LOSS = torch.nn.BCEWithLogitsLoss

+class ElectraOnTox21Challenge(_ElectraExperiment):
     @classmethod
     def identifier(cls) -> str:
-        return "Electra+Tox21Ext"
+        return "Electra+Tox21Chal"

     def build_dataset(self, batch_size) -> datasets.XYBaseDataModule:
-        return datasets.Tox21ExtendedChem(batch_size)
+        return datasets.Tox21ChallengeChem(batch_size)


 class ElectraBPEOnJCIExt(_ElectraExperiment):
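
Where the new regularisation values end up, sketched against the stock transformers API (the commit's own wiring is in chebai/models/electra.py below): `hidden_dropout_prob` is a regular ElectraConfig field, `word_dropout` is read back out of the config dict by chebai's Electra wrapper and applied as an extra dropout on the token inputs, and `weight_decay` is forwarded to the optimizer. The three values mirror the override above; everything else in the sketch is illustrative.

# Minimal sketch, assuming only `transformers` is installed; "word_dropout" is
# not a transformers field, it is consumed separately by the chebai wrapper.
from transformers import ElectraConfig

config_kwargs = {"hidden_dropout_prob": 0.4, "word_dropout": 0.2}
config = ElectraConfig(**config_kwargs, output_attentions=True)

print(config.hidden_dropout_prob)            # 0.4 -> dropout inside every transformer layer
print(config_kwargs.get("word_dropout", 0))  # 0.2 -> nn.Dropout on the token inputs (see electra.py)
optimizer_kwargs = {"weight_decay": 1e-4}    # forwarded to whichever optimizer the model builds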

chebai/models/base.py

Lines changed: 4 additions & 1 deletion
@@ -146,6 +146,7 @@ def run(
     model_kwargs: dict = None,
     loss=torch.nn.BCELoss,
     weighted=False,
+    version=None
 ):
     if model_args is None:
         model_args = []
@@ -177,7 +178,9 @@ def run(
     else:
         trainer_kwargs = dict(gpus=0)

-    tb_logger = pl_loggers.TensorBoardLogger("logs/", name=name)
+    tb_logger = pl_loggers.TensorBoardLogger("logs/", name=name, version=version)
+    if os.path.isdir(tb_logger.log_dir):
+        raise IOError("Fixed logging directory does already exist:", tb_logger.log_dir)
     best_checkpoint_callback = ModelCheckpoint(
         dirpath=os.path.join(tb_logger.log_dir, "best_checkpoints"),
         filename="{epoch}-{val_F1Score_micro:.4f}--{val_loss:.4f}",

chebai/models/electra.py

Lines changed: 3 additions & 4 deletions
@@ -151,10 +151,9 @@ def __init__(self, **kwargs):
         self.config = ElectraConfig(**kwargs["config"], output_attentions=True)
         self.word_dropout = nn.Dropout(kwargs["config"].get("word_dropout", 0))
         if pretrained_checkpoint:
-            elpre = ElectraPre.load_from_checkpoint(pretrained_checkpoint)
-            with TemporaryDirectory() as td:
-                elpre.as_pretrained.save_pretrained(td)
-                self.electra = ElectraModel.from_pretrained(td, config=self.config)
+            with open(pretrained_checkpoint, "rb") as fin:
+                model_dict = torch.load(fin,map_location=self.device)
+                self.electra = ElectraModel.from_pretrained(None, state_dict=model_dict['state_dict'], config=self.config)
         else:
             self.electra = ElectraModel(config=self.config)
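
Background for the new loading path (a sketch, not from the commit): a pytorch-lightning .ckpt file is a torch-serialised dict whose "state_dict" entry maps parameter names to tensors, so it can be handed to `ElectraModel.from_pretrained(None, state_dict=..., config=...)` without the save-and-reload round trip through a TemporaryDirectory that the removed lines used. The checkpoint path below is a placeholder.

import torch

# "electra_pretrained.ckpt" is a placeholder; map_location avoids requiring the
# device the checkpoint was written on (the commit passes self.device instead of "cpu").
ckpt = torch.load("electra_pretrained.ckpt", map_location="cpu")
print(sorted(ckpt.keys()))       # typically includes 'epoch', 'global_step', 'state_dict', ...
state_dict = ckpt["state_dict"]  # parameter name -> tensor, as saved during pretraining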

chebai/preprocessing/datasets/tox21.py

Lines changed: 126 additions & 36 deletions
@@ -13,12 +13,12 @@
 from chebai.preprocessing import reader as dr
 import pysmiles
 import numpy as np
-import rdkit
+from rdkit import Chem
 import zipfile
 import shutil


-class Tox21Base(XYBaseDataModule):
+class Tox21MolNet(XYBaseDataModule):
     HEADERS = [
         "NR-AR",
         "NR-AR-LBD",
@@ -36,7 +36,122 @@ class Tox21Base(XYBaseDataModule):

     @property
     def _name(self):
-        return "tox21"
+        return "Tox21mn"
+
+    @property
+    def label_number(self):
+        return 12
+
+    @property
+    def raw_file_names(self):
+        return ["tox21.csv"]
+
+    @property
+    def processed_file_names(self):
+        return ["test.pt", "train.pt", "validation.pt"]
+
+    def download(self):
+        with NamedTemporaryFile("rb") as gout:
+            request.urlretrieve(
+                "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/tox21.csv.gz",
+                gout.name,
+            )
+            with gzip.open(gout.name) as gfile:
+                with open(os.path.join(self.raw_dir, "tox21.csv"), "wt") as fout:
+                    fout.write(gfile.read().decode())
+
+    def setup_processed(self):
+        print("Create splits")
+        data = self._load_data_from_file(os.path.join(self.raw_dir, f"tox21.csv"))
+        groups = np.array([d["group"] for d in data])
+        if not all(g is None for g in groups):
+            split_size = int(len(set(groups)) * self.train_split)
+            os.makedirs(self.processed_dir, exist_ok=True)
+            splitter = GroupShuffleSplit(train_size=split_size, n_splits=1)
+
+            train_split_index, temp_split_index = next(
+                splitter.split(data, groups=groups)
+            )
+
+            split_groups = groups[temp_split_index]
+
+            splitter = GroupShuffleSplit(
+                train_size=int(len(set(split_groups)) * self.train_split), n_splits=1
+            )
+            test_split_index, validation_split_index = next(
+                splitter.split(temp_split_index, groups=split_groups)
+            )
+            train_split = [data[i] for i in train_split_index]
+            test_split = [
+                d
+                for d in (data[temp_split_index[i]] for i in test_split_index)
+                if d["original"]
+            ]
+            validation_split = [
+                d
+                for d in (data[temp_split_index[i]] for i in validation_split_index)
+                if d["original"]
+            ]
+        else:
+            train_split, test_split = train_test_split(
+                data, train_size=self.train_split, shuffle=True
+            )
+            test_split, validation_split = train_test_split(
+                test_split, train_size=0.5, shuffle=True
+            )
+        for k, split in [
+            ("test", test_split),
+            ("train", train_split),
+            ("validation", validation_split),
+        ]:
+            print("transform", k)
+            torch.save(
+                split,
+                os.path.join(self.processed_dir, f"{k}.pt"),
+            )
+
+    def setup(self, **kwargs):
+        if any(
+            not os.path.isfile(os.path.join(self.raw_dir, f))
+            for f in self.raw_file_names
+        ):
+            self.download()
+        if any(
+            not os.path.isfile(os.path.join(self.processed_dir, f))
+            for f in self.processed_file_names
+        ):
+            self.setup_processed()
+
+    def _load_dict(self, input_file_path):
+        with open(input_file_path, "r") as input_file:
+            reader = csv.DictReader(input_file)
+            for row in reader:
+                smiles = row["smiles"]
+                labels = [
+                    bool(int(l)) if l else None for l in (row[k] for k in self.HEADERS)
+                ]
+                yield dict(features=smiles, labels=labels, ident=row["mol_id"])
+
+
+class Tox21Challenge(XYBaseDataModule):
+    HEADERS = [
+        "NR-AR",
+        "NR-AR-LBD",
+        "NR-AhR",
+        "NR-Aromatase",
+        "NR-ER",
+        "NR-ER-LBD",
+        "NR-PPAR-gamma",
+        "SR-ARE",
+        "SR-ATAD5",
+        "SR-HSE",
+        "SR-MMP",
+        "SR-p53",
+    ]
+
+    @property
+    def _name(self):
+        return "tox21chal"

     @property
     def label_number(self):
@@ -81,23 +196,24 @@ def _retrieve_file(self, url, target_file, compression=None):
                 shutil.move(os.path.join(td.name, f), target_path)

     def _load_data_from_file(self, path):
-        sdf = rdkit.Chem.SDMolSupplier(path)
+        sdf = Chem.SDMolSupplier(path)
         data = []
         for mol in sdf:
             if mol is not None:
                 d = dict(
                     labels=[int(mol.GetProp(h)) if h in mol.GetPropNames() else None for h in self.HEADERS],
                     ident=[mol.GetProp(k) for k in ("DSSTox_CID", "Compound ID") if k in mol.GetPropNames()][0],
-                    features=rdkit.Chem.MolToSmiles(mol))
+                    features=Chem.MolToSmiles(mol))
                 data.append(self.reader.to_data(d))
         return data

     def setup_processed(self):
         for k in ("train", "validation"):
-            torch.save(self._load_data_from_file(os.path.join(self.raw_dir, f"{k}.sdf")), os.path.join(self.processed_dir, f"{k}.pt"))
+            d = self._load_data_from_file(os.path.join(self.raw_dir, f"{k}.sdf"))
+            torch.save(d, os.path.join(self.processed_dir, f"{k}.pt"))

         with open(os.path.join(self.raw_dir, f"test.smiles")) as fin:
-            headers = next(fin)
+            next(fin)
             test_smiles = dict(reversed(row.strip().split("\t")) for row in fin)
         with open(os.path.join(self.raw_dir, f"test_results.txt")) as fin:
             headers = next(fin).strip().split("\t")
@@ -128,35 +244,9 @@ def _load_dict(self, input_file_path):
                 yield dict(features=smiles, labels=labels, ident=row["mol_id"])


-class Tox21Chem(Tox21Base):
+class Tox21ChallengeChem(Tox21Challenge):
     READER = dr.ChemDataReader


-class Tox21Graph(Tox21Base):
-    READER = dr.GraphReader
-
-
-
-class Tox21ExtendedChem(MergedDataset):
-    MERGED = [Tox21Chem, Hazardous, JCIExtendedTokenData]
-
-    @property
-    def limits(self):
-        return [None, 5000, 5000]
-
-    def _process_data(self, subset_id, data):
-        res = dict(
-            features=data["features"], labels=data["labels"], ident=data["ident"]
-        )
-        # Feature: non-toxic
-        if subset_id == 0:
-            res["labels"] = [not any(res["labels"])]
-        elif subset_id == 1:
-            res["labels"] = [False]
-        elif subset_id == 2:
-            res["labels"] = [True]
-        return res
-
-    @property
-    def label_number(self):
-        return 1
+class Tox21MolNetChem(Tox21MolNet):
+    READER = dr.ChemDataReader
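
The core of Tox21MolNet.setup_processed is the two-stage grouped split; below is a minimal sketch with toy data (the real code splits the parsed tox21 records, uses self.train_split, and additionally keeps only d["original"] entries in test and validation).

# Toy reproduction of the grouped splitting above: every group ends up in
# exactly one of train/test/validation.
import numpy as np
from sklearn.model_selection import GroupShuffleSplit

data = [{"ident": i, "group": g} for i, g in enumerate("AAABBCCDDEEFF")]
groups = np.array([d["group"] for d in data])
train_split = 0.8

splitter = GroupShuffleSplit(train_size=int(len(set(groups)) * train_split), n_splits=1)
train_index, temp_index = next(splitter.split(data, groups=groups))

split_groups = groups[temp_index]
splitter = GroupShuffleSplit(train_size=int(len(set(split_groups)) * train_split), n_splits=1)
test_index, validation_index = next(splitter.split(temp_index, groups=split_groups))

train = [data[i] for i in train_index]
test = [data[temp_index[i]] for i in test_index]
validation = [data[temp_index[i]] for i in validation_index]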

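And the RDKit side of Tox21Challenge._load_data_from_file, reduced to the essentials; the assay property names are the HEADERS used above, while the SDF path is a placeholder.

# Sketch of the per-molecule extraction ("train.sdf" is a placeholder path).
from rdkit import Chem

HEADERS = ["NR-AR", "NR-AR-LBD", "NR-AhR", "NR-Aromatase", "NR-ER", "NR-ER-LBD",
           "NR-PPAR-gamma", "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53"]

for mol in Chem.SDMolSupplier("train.sdf"):
    if mol is None:  # RDKit yields None for entries it cannot parse
        continue
    labels = [int(mol.GetProp(h)) if h in mol.GetPropNames() else None for h in HEADERS]
    smiles = Chem.MolToSmiles(mol)  # canonical SMILES, later fed to the tokeniser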