1313from chebai .preprocessing import reader as dr
1414import pysmiles
1515import numpy as np
16- import rdkit
16+ from rdkit import Chem
1717import zipfile
1818import shutil
1919
2020
21- class Tox21Base (XYBaseDataModule ):
21+ class Tox21MolNet (XYBaseDataModule ):
2222 HEADERS = [
2323 "NR-AR" ,
2424 "NR-AR-LBD" ,
@@ -36,7 +36,122 @@ class Tox21Base(XYBaseDataModule):
3636
3737 @property
3838 def _name (self ):
39- return "tox21"
39+ return "Tox21mn"
40+
41+ @property
42+ def label_number (self ):
43+ return 12
44+
45+ @property
46+ def raw_file_names (self ):
47+ return ["tox21.csv" ]
48+
49+ @property
50+ def processed_file_names (self ):
51+ return ["test.pt" , "train.pt" , "validation.pt" ]
52+
53+ def download (self ):
54+ with NamedTemporaryFile ("rb" ) as gout :
55+ request .urlretrieve (
56+ "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/tox21.csv.gz" ,
57+ gout .name ,
58+ )
59+ with gzip .open (gout .name ) as gfile :
60+ with open (os .path .join (self .raw_dir , "tox21.csv" ), "wt" ) as fout :
61+ fout .write (gfile .read ().decode ())
62+
63+ def setup_processed (self ):
64+ print ("Create splits" )
65+ data = self ._load_data_from_file (os .path .join (self .raw_dir , f"tox21.csv" ))
66+ groups = np .array ([d ["group" ] for d in data ])
67+ if not all (g is None for g in groups ):
68+ split_size = int (len (set (groups )) * self .train_split )
69+ os .makedirs (self .processed_dir , exist_ok = True )
70+ splitter = GroupShuffleSplit (train_size = split_size , n_splits = 1 )
71+
72+ train_split_index , temp_split_index = next (
73+ splitter .split (data , groups = groups )
74+ )
75+
76+ split_groups = groups [temp_split_index ]
77+
78+ splitter = GroupShuffleSplit (
79+ train_size = int (len (set (split_groups )) * self .train_split ), n_splits = 1
80+ )
81+ test_split_index , validation_split_index = next (
82+ splitter .split (temp_split_index , groups = split_groups )
83+ )
84+ train_split = [data [i ] for i in train_split_index ]
85+ test_split = [
86+ d
87+ for d in (data [temp_split_index [i ]] for i in test_split_index )
88+ if d ["original" ]
89+ ]
90+ validation_split = [
91+ d
92+ for d in (data [temp_split_index [i ]] for i in validation_split_index )
93+ if d ["original" ]
94+ ]
95+ else :
96+ train_split , test_split = train_test_split (
97+ data , train_size = self .train_split , shuffle = True
98+ )
99+ test_split , validation_split = train_test_split (
100+ test_split , train_size = 0.5 , shuffle = True
101+ )
102+ for k , split in [
103+ ("test" , test_split ),
104+ ("train" , train_split ),
105+ ("validation" , validation_split ),
106+ ]:
107+ print ("transform" , k )
108+ torch .save (
109+ split ,
110+ os .path .join (self .processed_dir , f"{ k } .pt" ),
111+ )
112+
113+ def setup (self , ** kwargs ):
114+ if any (
115+ not os .path .isfile (os .path .join (self .raw_dir , f ))
116+ for f in self .raw_file_names
117+ ):
118+ self .download ()
119+ if any (
120+ not os .path .isfile (os .path .join (self .processed_dir , f ))
121+ for f in self .processed_file_names
122+ ):
123+ self .setup_processed ()
124+
125+ def _load_dict (self , input_file_path ):
126+ with open (input_file_path , "r" ) as input_file :
127+ reader = csv .DictReader (input_file )
128+ for row in reader :
129+ smiles = row ["smiles" ]
130+ labels = [
131+ bool (int (l )) if l else None for l in (row [k ] for k in self .HEADERS )
132+ ]
133+ yield dict (features = smiles , labels = labels , ident = row ["mol_id" ])
134+
135+
136+ class Tox21Challenge (XYBaseDataModule ):
137+ HEADERS = [
138+ "NR-AR" ,
139+ "NR-AR-LBD" ,
140+ "NR-AhR" ,
141+ "NR-Aromatase" ,
142+ "NR-ER" ,
143+ "NR-ER-LBD" ,
144+ "NR-PPAR-gamma" ,
145+ "SR-ARE" ,
146+ "SR-ATAD5" ,
147+ "SR-HSE" ,
148+ "SR-MMP" ,
149+ "SR-p53" ,
150+ ]
151+
152+ @property
153+ def _name (self ):
154+ return "tox21chal"
40155
41156 @property
42157 def label_number (self ):
@@ -81,23 +196,24 @@ def _retrieve_file(self, url, target_file, compression=None):
81196 shutil .move (os .path .join (td .name , f ), target_path )
82197
83198 def _load_data_from_file (self , path ):
84- sdf = rdkit . Chem .SDMolSupplier (path )
199+ sdf = Chem .SDMolSupplier (path )
85200 data = []
86201 for mol in sdf :
87202 if mol is not None :
88203 d = dict (
89204 labels = [int (mol .GetProp (h )) if h in mol .GetPropNames () else None for h in self .HEADERS ],
90205 ident = [mol .GetProp (k ) for k in ("DSSTox_CID" , "Compound ID" ) if k in mol .GetPropNames () ][0 ],
91- features = rdkit . Chem .MolToSmiles (mol ))
206+ features = Chem .MolToSmiles (mol ))
92207 data .append (self .reader .to_data (d ))
93208 return data
94209
95210 def setup_processed (self ):
96211 for k in ("train" , "validation" ):
97- torch .save (self ._load_data_from_file (os .path .join (self .raw_dir , f"{ k } .sdf" )), os .path .join (self .processed_dir , f"{ k } .pt" ))
212+ d = self ._load_data_from_file (os .path .join (self .raw_dir , f"{ k } .sdf" ))
213+ torch .save (d , os .path .join (self .processed_dir , f"{ k } .pt" ))
98214
99215 with open (os .path .join (self .raw_dir , f"test.smiles" )) as fin :
100- headers = next (fin )
216+ next (fin )
101217 test_smiles = dict (reversed (row .strip ().split ("\t " )) for row in fin )
102218 with open (os .path .join (self .raw_dir , f"test_results.txt" )) as fin :
103219 headers = next (fin ).strip ().split ("\t " )
@@ -128,35 +244,9 @@ def _load_dict(self, input_file_path):
128244 yield dict (features = smiles , labels = labels , ident = row ["mol_id" ])
129245
130246
131- class Tox21Chem ( Tox21Base ):
247+ class Tox21ChallengeChem ( Tox21Challenge ):
132248 READER = dr .ChemDataReader
133249
134250
135- class Tox21Graph (Tox21Base ):
136- READER = dr .GraphReader
137-
138-
139-
140- class Tox21ExtendedChem (MergedDataset ):
141- MERGED = [Tox21Chem , Hazardous , JCIExtendedTokenData ]
142-
143- @property
144- def limits (self ):
145- return [None , 5000 , 5000 ]
146-
147- def _process_data (self , subset_id , data ):
148- res = dict (
149- features = data ["features" ], labels = data ["labels" ], ident = data ["ident" ]
150- )
151- # Feature: non-toxic
152- if subset_id == 0 :
153- res ["labels" ] = [not any (res ["labels" ])]
154- elif subset_id == 1 :
155- res ["labels" ] = [False ]
156- elif subset_id == 2 :
157- res ["labels" ] = [True ]
158- return res
159-
160- @property
161- def label_number (self ):
162- return 1
251+ class Tox21MolNetChem (Tox21MolNet ):
252+ READER = dr .ChemDataReader
0 commit comments