|
17 | 17 | from sklearn.model_selection import train_test_split |
18 | 18 | import requests |
19 | 19 | import torch |
| 20 | +import time |
| 21 | +import numpy as np |
20 | 22 | import tqdm |
21 | 23 | from datetime import datetime |
22 | 24 |
|
@@ -56,7 +58,7 @@ def identifier(self): |
56 | 58 |
|
@property
def split_label(self):
    """Return the label used to identify this dataset split.

    Returns:
        ``str(self._k)`` when ``_k`` is truthy and differs from ``FULL``;
        otherwise the literal string ``"full"``.
    """
    # A falsy ``_k`` and the explicit ``FULL`` sentinel both denote the
    # complete dataset, so they share a single label.
    if self._k and self._k != self.FULL:
        return str(self._k)
    return "full"
@@ -211,6 +213,77 @@ def download(self): |
211 | 213 | ) |
212 | 214 |
|
213 | 215 |
|
class PubChemKMeans(PubChem):
    """PubChem sample whose molecules are grouped by k-means clustering.

    ``download`` fetches a random PubChem sample, converts every parseable
    SMILES string into a fingerprint, clusters the fingerprints with k-means
    and stores fingerprints, cluster labels and cluster centers as pickle
    files inside ``raw_dir``.
    """

    def __init__(self, *args, n_clusters=1e4, random_size=1e6, **kwargs):
        """Configure the clustering.

        Args:
            *args: Forwarded unchanged to ``PubChem.__init__``.
            n_clusters: Number of k-means clusters; converted with ``int``.
            random_size: Size of the random PubChem sample that is clustered;
                converted with ``int`` and passed to the parent as ``k``.
            **kwargs: Forwarded unchanged to ``PubChem.__init__``.
        """
        self.n_clusters = int(n_clusters)
        super().__init__(*args, k=int(random_size), **kwargs)

    @property
    def _name(self):
        """Return the fixed name of this dataset."""
        return "PubchemKMeans"

    def download(self):
        """Download the data and, for a sampled dataset, cluster it.

        When ``_k`` equals ``PubChem.FULL`` only the parent download runs.
        Otherwise a random ``PubChem`` dataset of size ``_k`` is downloaded,
        fingerprints are computed (or read from ``fingerprints.pkl`` when that
        cache already exists), and the k-means results are written to
        ``fingerprints_labeled_<n_clusters>.pkl`` and
        ``cluster_centers_<n_clusters>.pkl`` in ``raw_dir``.
        """
        if self._k == PubChem.FULL:
            super().download()
            return

        print(f"Loading random dataset (size: {self._k})...")
        random_dataset = PubChem(k=self._k)
        random_dataset.download()

        # The cache files below are written into raw_dir, so it has to exist.
        os.makedirs(self.raw_dir, exist_ok=True)
        fingerprints_path = os.path.join(self.raw_dir, "fingerprints.pkl")
        if os.path.exists(fingerprints_path):
            # NOTE(review): unpickling executes arbitrary code; this is only
            # safe because the file is a cache written by this method.
            df = pd.read_pickle(fingerprints_path)
        else:
            df = self._compute_fingerprints(
                os.path.join(random_dataset.raw_dir, "smiles.txt")
            )
            # Passing the path lets pandas open and close the file itself.
            df.to_pickle(fingerprints_path)

        kmeans = self._cluster(df)

        # The label array is assigned positionally to the remaining rows.
        df["label"] = kmeans.labels_
        df.to_pickle(
            os.path.join(self.raw_dir, f"fingerprints_labeled_{self.n_clusters}.pkl")
        )
        cluster_df = pd.DataFrame(data={"centers": list(kmeans.cluster_centers_)})
        cluster_df.to_pickle(
            os.path.join(self.raw_dir, f"cluster_centers_{self.n_clusters}.pkl")
        )

    @staticmethod
    def _compute_fingerprints(smiles_path):
        """Build a ``smiles``/``fps`` DataFrame from a tab-separated file.

        The SMILES string is read from the second tab-separated column of
        each line. Rows whose SMILES cannot be parsed are dropped.
        """
        with open(smiles_path, "r") as f_in:
            random_smiles = [line.split("\t")[1].strip() for line in f_in]
        fpgen = AllChem.GetRDKitFPGenerator()
        print("Converting SMILES to molecules...")
        mols = [Chem.MolFromSmiles(s) for s in tqdm.tqdm(random_smiles)]
        print("Generating Fingerprints...")
        fps = [
            fpgen.GetFingerprint(m) if m is not None else None
            for m in tqdm.tqdm(mols)
        ]
        df = pd.DataFrame(
            {"smiles": random_smiles, "fps": fps}, columns=["smiles", "fps"]
        )
        # Unparseable molecules have a None fingerprint and are removed here.
        return df.dropna()

    def _cluster(self, df):
        """Fit k-means on the ``fps`` column of ``df`` and return the model."""
        fps = np.array([list(vec) for vec in df["fps"]])
        print("Starting k-means clustering...")
        start_time = time.perf_counter()
        kmeans = KMeans(n_clusters=self.n_clusters, random_state=0, n_init="auto")
        kmeans.fit(fps)
        print(f"Finished k-means in {time.perf_counter() - start_time:.2f} seconds")
        return kmeans
| 285 | + |
| 286 | + |
214 | 287 | class PubChemDissimilarSMILES(PubChemDissimilar): |
215 | 288 | READER = dr.ChemDataReader |
216 | 289 |
|
@@ -310,8 +383,8 @@ def download(self): |
310 | 383 |
|
311 | 384 |
|
if __name__ == "__main__":
    # Script entry point: create the k-means clustered PubChem dataset with
    # its default settings and run its download step.
    kmeans_data = PubChemKMeans()
    kmeans_data.download()
315 | 388 |
|
316 | 389 |
|
317 | 390 | class SWJPreChem(PubChem): |
|
0 commit comments