@@ -3,7 +3,7 @@
 from chebai.preprocessing.datasets.chebi import JCIExtendedTokenData
 from chebai.preprocessing.datasets.pubchem import Hazardous
 from chebai.preprocessing.datasets.base import XYBaseDataModule, MergedDataset
-from tempfile import NamedTemporaryFile
+from tempfile import NamedTemporaryFile, TemporaryDirectory
 from urllib import request
 from sklearn.model_selection import train_test_split, GroupShuffleSplit
 import gzip
@@ -13,6 +13,9 @@
 from chebai.preprocessing import reader as dr
 import pysmiles
 import numpy as np
+from rdkit import Chem
+import zipfile
+import shutil
 
 
 class Tox21Base(XYBaseDataModule):
@@ -41,71 +44,66 @@ def label_number(self): |
 
     @property
     def raw_file_names(self):
-        return ["tox21.csv"]
+        return ["train.sdf", "validation.sdf", "test.smiles", "test_results.txt"]
 
     @property
     def processed_file_names(self):
         return ["test.pt", "train.pt", "validation.pt"]
 
     def download(self):
-        with NamedTemporaryFile("rb") as gout:
-            request.urlretrieve(
-                "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/tox21.csv.gz",
-                gout.name,
-            )
-            with gzip.open(gout.name) as gfile:
-                with open(os.path.join(self.raw_dir, "tox21.csv"), "wt") as fout:
-                    fout.write(gfile.read().decode())
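+        # The challenge server ships the SDF sets as zip archives; the SMILES and score files are plain text.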
+        self._retrieve_file(
+            "https://tripod.nih.gov/tox21/challenge/download?id=tox21_10k_data_allsdf&sec=",
+            "train.sdf", compression="zip")
+        self._retrieve_file(
+            "https://tripod.nih.gov/tox21/challenge/download?id=tox21_10k_challenge_testsdf&sec=",
+            "validation.sdf", compression="zip")
+        self._retrieve_file(
+            "https://tripod.nih.gov/tox21/challenge/download?id=tox21_10k_challenge_scoresmiles&sec=",
+            "test.smiles")
+        self._retrieve_file(
+            "https://tripod.nih.gov/tox21/challenge/download?id=tox21_10k_challenge_scoretxt&sec=",
+            "test_results.txt")
+
+    def _retrieve_file(self, url, target_file, compression=None):
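+        # Download url into raw_dir as target_file; pass compression="zip" for single-file zip archives.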
+        target_path = os.path.join(self.raw_dir, target_file)
+        if not os.path.isfile(target_path):
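+            # Plain files go straight to raw_dir; archives are downloaded to a temp file and unpacked.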
+            if compression is None:
+                request.urlretrieve(url, target_path)
+            elif compression == "zip":
+                with NamedTemporaryFile() as gout:
+                    request.urlretrieve(url, gout.name)
+                    with TemporaryDirectory() as td:
+                        with zipfile.ZipFile(gout.name, "r") as zip_ref:
+                            zip_ref.extractall(td)
+                        files_in_zip = os.listdir(td)
+                        assert len(files_in_zip) == 1
+                        shutil.move(os.path.join(td, files_in_zip[0]), target_path)
+
+    def _load_data_from_file(self, path):
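+        # RDKit returns None for molecules it cannot parse, so those entries are skipped.
+        # Missing assay properties become None labels; the ident is taken from
+        # DSSTox_CID (training data) or Compound ID (validation data).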
+        data = []
+        for mol in Chem.SDMolSupplier(path):
+            if mol is not None:
+                d = dict(
+                    labels=[int(mol.GetProp(h)) if mol.HasProp(h) else None
+                            for h in self.HEADERS],
+                    ident=next(mol.GetProp(k) for k in ("DSSTox_CID", "Compound ID")
+                               if mol.HasProp(k)),
+                    features=Chem.MolToSmiles(mol),
+                )
+                data.append(self.reader.to_data(d))
+        return data
 
     def setup_processed(self):
-        print("Create splits")
-        data = self._load_data_from_file(os.path.join(self.raw_dir, f"tox21.csv"))
-        groups = np.array([d["group"] for d in data])
-        if not all(g is None for g in groups):
-            split_size = int(len(set(groups)) * self.train_split)
-            os.makedirs(self.processed_dir, exist_ok=True)
-            splitter = GroupShuffleSplit(train_size=split_size, n_splits=1)
-
-            train_split_index, temp_split_index = next(
-                splitter.split(data, groups=groups)
-            )
-
-            split_groups = groups[temp_split_index]
-
-            splitter = GroupShuffleSplit(
-                train_size=int(len(set(split_groups)) * self.train_split), n_splits=1
-            )
-            test_split_index, validation_split_index = next(
-                splitter.split(temp_split_index, groups=split_groups)
-            )
-            train_split = [data[i] for i in train_split_index]
-            test_split = [
-                d
-                for d in (data[temp_split_index[i]] for i in test_split_index)
-                if d["original"]
-            ]
-            validation_split = [
-                d
-                for d in (data[temp_split_index[i]] for i in validation_split_index)
-                if d["original"]
-            ]
-        else:
-            train_split, test_split = train_test_split(
-                data, train_size=self.train_split, shuffle=True
-            )
-            test_split, validation_split = train_test_split(
-                test_split, train_size=0.5, shuffle=True
-            )
-        for k, split in [
-            ("test", test_split),
-            ("train", train_split),
-            ("validation", validation_split),
-        ]:
-            print("transform", k)
-            torch.save(
-                split,
-                os.path.join(self.processed_dir, f"{k}.pt"),
-            )
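+        # The challenge distributes train and validation pre-split as SDF files, so no group splitting is needed.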
+        os.makedirs(self.processed_dir, exist_ok=True)
+        for k in ("train", "validation"):
+            torch.save(self._load_data_from_file(os.path.join(self.raw_dir, f"{k}.sdf")),
+                       os.path.join(self.processed_dir, f"{k}.pt"))
+
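+        # The test split joins the score SMILES with the published results; "x" marks a missing assay value.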
+        with open(os.path.join(self.raw_dir, "test.smiles")) as fin:
+            next(fin)  # skip the header row
+            test_smiles = dict(reversed(line.strip().split("\t")) for line in fin)
+        with open(os.path.join(self.raw_dir, "test_results.txt")) as fin:
+            headers = next(fin).strip().split("\t")
+            rows = (dict(zip(headers, line.strip().split("\t"), strict=True))
+                    for line in fin if line.strip())
+            test_results = {row["Sample ID"]: [int(row[h]) if row[h] != "x" else None
+                                               for h in self.HEADERS]
+                            for row in rows}
+        test_data = [self.reader.to_data(dict(features=test_smiles[k], labels=test_results[k],
+                                              ident=k))
+                     for k in test_smiles]
+        torch.save(test_data, os.path.join(self.processed_dir, "test.pt"))
 
     def setup(self, **kwargs):
         if any(