Commit 497d614

Implement Tox21 dataset as used in the competition

1 parent 11819d1 commit 497d614

2 files changed: 58 additions, 59 deletions

chebai/preprocessing/bin/tokens.txt

Lines changed: 2 additions & 1 deletion
@@ -891,4 +891,5 @@ I
 [o]
 [s]
 [SH]
-[as]
+[as]
+[c+]
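
For context: tokens.txt is a one-token-per-line vocabulary, and [c+] is appended after [as] (the last line is re-added because the original file lacked a trailing newline). Appending rather than inserting keeps existing token indices stable, as in this hypothetical lookup sketch (the actual chebai reader logic may differ):

# Hypothetical index build for a one-token-per-line vocabulary file.
# Appending "[c+]" keeps every pre-existing token's index unchanged,
# which matters if embeddings were already trained against those indices.
with open("chebai/preprocessing/bin/tokens.txt") as f:
    token_to_index = {line.strip(): i for i, line in enumerate(f) if line.strip()}

print(token_to_index["[as]"])  # same index as before this commit
print(token_to_index["[c+]"])  # the newly appended, final index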

chebai/preprocessing/datasets/tox21.py

Lines changed: 56 additions & 58 deletions
@@ -3,7 +3,7 @@
 from chebai.preprocessing.datasets.chebi import JCIExtendedTokenData
 from chebai.preprocessing.datasets.pubchem import Hazardous
 from chebai.preprocessing.datasets.base import XYBaseDataModule, MergedDataset
-from tempfile import NamedTemporaryFile
+from tempfile import NamedTemporaryFile, TemporaryDirectory
 from urllib import request
 from sklearn.model_selection import train_test_split, GroupShuffleSplit
 import gzip
@@ -13,6 +13,9 @@
 from chebai.preprocessing import reader as dr
 import pysmiles
 import numpy as np
+import rdkit.Chem
+import zipfile
+import shutil


 class Tox21Base(XYBaseDataModule):
@@ -41,71 +44,66 @@ def label_number(self):

     @property
     def raw_file_names(self):
-        return ["tox21.csv"]
+        return ["train.sdf", "validation.sdf", "validation.smiles", "test.smiles", "test_results.txt"]

     @property
     def processed_file_names(self):
         return ["test.pt", "train.pt", "validation.pt"]

     def download(self):
-        with NamedTemporaryFile("rb") as gout:
-            request.urlretrieve(
-                "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/tox21.csv.gz",
-                gout.name,
-            )
-            with gzip.open(gout.name) as gfile:
-                with open(os.path.join(self.raw_dir, "tox21.csv"), "wt") as fout:
-                    fout.write(gfile.read().decode())
+        self._retrieve_file("https://tripod.nih.gov/tox21/challenge/download?id=tox21_10k_data_allsdf&sec=", "train.sdf", compression="zip")
+        self._retrieve_file("https://tripod.nih.gov/tox21/challenge/download?id=tox21_10k_challenge_testsdf&sec=",
+                            "validation.sdf", compression="zip")
+        self._retrieve_file("https://tripod.nih.gov/tox21/challenge/download?id=tox21_10k_challenge_scoresmiles&sec=",
+                            "test.smiles")
+        self._retrieve_file("https://tripod.nih.gov/tox21/challenge/download?id=tox21_10k_challenge_scoretxt&sec=",
+                            "test_results.txt")
+
+    def _retrieve_file(self, url, target_file, compression=None):
+        target_path = os.path.join(self.raw_dir, target_file)
+        if not os.path.isfile(target_path):
+            with NamedTemporaryFile("rb") as gout:
+                if compression is None:
+                    download_path = target_path
+                else:
+                    download_path = gout.name
+                request.urlretrieve(
+                    url,
+                    download_path,
+                )
+                if compression == "zip":
+                    td = TemporaryDirectory()
+                    with zipfile.ZipFile(download_path, 'r') as zip_ref:
+                        zip_ref.extractall(td.name)
+                    files_in_zip = os.listdir(td.name)
+                    f = files_in_zip[0]
+                    assert len(files_in_zip) == 1
+                    shutil.move(os.path.join(td.name, f), target_path)
83+
def _load_data_from_file(self, path):
84+
sdf = rdkit.Chem.SDMolSupplier(path)
85+
data = []
86+
for mol in sdf:
87+
if mol is not None:
88+
d = dict(
89+
labels=[int(mol.GetProp(h)) if h in mol.GetPropNames() else None for h in self.HEADERS],
90+
ident=[mol.GetProp(k) for k in ("DSSTox_CID", "Compound ID") if k in mol.GetPropNames() ][0],
91+
features=rdkit.Chem.MolToSmiles(mol))
92+
data.append(self.reader.to_data(d))
93+
return data
5994
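
The label extraction above relies on RDKit exposing each SD tag as a string property on the molecule. A small runnable illustration of that pattern (the assay names are stand-ins for self.HEADERS, not the commit's actual list):

import rdkit.Chem

# Hypothetical illustration of the GetProp/GetPropNames pattern above.
mol = rdkit.Chem.MolFromSmiles("CCO")
mol.SetProp("NR-AR", "0")      # stand-in for one Tox21 assay tag
headers = ["NR-AR", "NR-ER"]   # stand-in subset of self.HEADERS
labels = [int(mol.GetProp(h)) if h in mol.GetPropNames() else None
          for h in headers]
print(labels)  # [0, None] -- assays missing from the record become None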

     def setup_processed(self):
-        print("Create splits")
-        data = self._load_data_from_file(os.path.join(self.raw_dir, f"tox21.csv"))
-        groups = np.array([d["group"] for d in data])
-        if not all(g is None for g in groups):
-            split_size = int(len(set(groups)) * self.train_split)
-            os.makedirs(self.processed_dir, exist_ok=True)
-            splitter = GroupShuffleSplit(train_size=split_size, n_splits=1)
-
-            train_split_index, temp_split_index = next(
-                splitter.split(data, groups=groups)
-            )
-
-            split_groups = groups[temp_split_index]
-
-            splitter = GroupShuffleSplit(
-                train_size=int(len(set(split_groups)) * self.train_split), n_splits=1
-            )
-            test_split_index, validation_split_index = next(
-                splitter.split(temp_split_index, groups=split_groups)
-            )
-            train_split = [data[i] for i in train_split_index]
-            test_split = [
-                d
-                for d in (data[temp_split_index[i]] for i in test_split_index)
-                if d["original"]
-            ]
-            validation_split = [
-                d
-                for d in (data[temp_split_index[i]] for i in validation_split_index)
-                if d["original"]
-            ]
-        else:
-            train_split, test_split = train_test_split(
-                data, train_size=self.train_split, shuffle=True
-            )
-            test_split, validation_split = train_test_split(
-                test_split, train_size=0.5, shuffle=True
-            )
-        for k, split in [
-            ("test", test_split),
-            ("train", train_split),
-            ("validation", validation_split),
-        ]:
-            print("transform", k)
-            torch.save(
-                split,
-                os.path.join(self.processed_dir, f"{k}.pt"),
-            )
+        for k in ("train", "validation"):
+            torch.save(self._load_data_from_file(os.path.join(self.raw_dir, f"{k}.sdf")), os.path.join(self.processed_dir, f"{k}.pt"))
+
+        with open(os.path.join(self.raw_dir, f"test.smiles")) as fin:
+            headers = next(fin)
+            test_smiles = dict(reversed(row.strip().split("\t")) for row in fin)
+        with open(os.path.join(self.raw_dir, f"test_results.txt")) as fin:
+            headers = next(fin).strip().split("\t")
+            test_results = {k["Sample ID"]: [int(k[h]) if k[h] != "x" else None for h in self.HEADERS] for k in (dict(zip(headers, row.strip().split("\t"), strict=True)) for row in fin if row)}
+        test_data = [self.reader.to_data(dict(features=test_smiles[k], labels=test_results[k], ident=k)) for k in test_smiles]
+        torch.save(test_data, os.path.join(self.processed_dir, f"test.pt"))
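
The test split is stitched together from two tab-separated files. A runnable sketch of the two parsing tricks above, with inline stand-ins for the file contents (column layouts inferred from the code: SMILES before Sample ID in test.smiles, and "x" marking a missing score in test_results.txt):

# Hypothetical stand-ins for the rows after each file's header line.
smiles_rows = ["CCO\tNCGC001\n", "c1ccccc1\tNCGC002\n"]
results_rows = ["NCGC001\t0\t1\n", "NCGC002\tx\t0\n"]
headers = ["Sample ID", "NR-AR", "NR-ER"]  # parsed test_results.txt header
HEADERS = ["NR-AR", "NR-ER"]               # assay columns only

# reversed() flips each (smiles, sample_id) pair into (sample_id, smiles).
test_smiles = dict(reversed(row.strip().split("\t")) for row in smiles_rows)

# "x" entries become None; everything else parses as an integer label.
test_results = {
    rec["Sample ID"]: [int(rec[h]) if rec[h] != "x" else None for h in HEADERS]
    for rec in (dict(zip(headers, row.strip().split("\t"))) for row in results_rows)
}
print(test_smiles)   # {'NCGC001': 'CCO', 'NCGC002': 'c1ccccc1'}
print(test_results)  # {'NCGC001': [0, 1], 'NCGC002': [None, 0]}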

     def setup(self, **kwargs):
         if any(
