|
12 | 12 | import random |
13 | 13 | import shutil |
14 | 14 | import tempfile |
| 15 | +from scipy import spatial |
15 | 16 |
16 | 17 | import pandas as pd |
17 | 18 | from sklearn.model_selection import train_test_split |
28 | 29 | ChEBIOver50, |
29 | 30 | ChEBIOver100, |
30 | 31 | ChEBIOverX, |
| 32 | + _ChEBIDataExtractor, |
31 | 33 | ) |
32 | 34 | from rdkit import Chem, DataStructs |
33 | 35 | from rdkit.Chem import AllChem |
| 36 | +from sklearn.cluster import KMeans |
34 | 37 |
35 | 38 |
36 | 39 | class PubChem(XYBaseDataModule): |
@@ -215,73 +218,247 @@ def download(self): |
215 | 218 |
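| | + # PubChemKMeans draws a random sample from PubChem, clusters the molecules' |
| | + # fingerprints with k-means, and derives train/validation/test splits from |
| | + # disjoint groups of clusters (optionally avoiding overlap with a ChEBI dataset). |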
216 | 219 | class PubChemKMeans(PubChem): |
217 | 220 |
218 | | - def __init__(self, *args, n_clusters=1e4, random_size=1e6, **kwargs): |
| 221 | + def __init__( |
| 222 | + self, |
| 223 | + *args, |
| 224 | + n_clusters=1e4, |
| 225 | + random_size=1e6, |
| 226 | + exclude_data_from: _ChEBIDataExtractor = None, |
| 227 | + validation_size_limit=4000, |
| 228 | + include_min_n_clusters=100, |
| 229 | + **kwargs, |
| 230 | + ): |
219 | 231 | """n_clusters: number of k-means clusters built from a random sample of size random_size,
220 | 232 | exclude_data_from: dataset whose molecules disqualify the clusters they fall into,
221 | 233 | include_min_n_clusters: minimum number of clusters to keep after exclusion,
222 | 234 | validation_size_limit: maximum number of entries in the validation split"""
223 | 235 | self.n_clusters = int(n_clusters) |
| 236 | + self.exclude_data_from = exclude_data_from |
| 237 | + self.validation_size_limit = validation_size_limit |
| 238 | + self.include_min_n_clusters = include_min_n_clusters |
224 | 239 | super(PubChemKMeans, self).__init__(*args, k=int(random_size), **kwargs) |
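| | + # caches for the lazily computed artifacts (fingerprints, clusters, splits) |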
| 240 | + self._fingerprints = None |
| 241 | + self._cluster_centers = None |
| 242 | + self._fingerprints_clustered = None |
| 243 | + self._exclusion_data_clustered = None |
| 244 | + self._cluster_centers_superclustered = None |
225 | 245 |
226 | 246 | @property |
227 | 247 | def _name(self): |
228 | 248 | return f"PubchemKMeans" |
229 | 249 |
230 | | - def download(self): |
231 | | - if self._k == PubChem.FULL: |
232 | | - super().download() |
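| | + # label identifying this configuration, e.g. "10000_centers_out_of_1000000" |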
| 250 | + @property |
| 251 | + def split_label(self): |
| 252 | + if self._k and self._k != self.FULL: |
| 253 | + return f"{self.n_clusters}_centers_out_of_{self._k}" |
233 | 254 | else: |
234 | | - print(f"Loading random dataset (size: {self._k})...") |
235 | | - random_dataset = PubChem(k=self._k) |
236 | | - random_dataset.download() |
| 255 | + return f"{self.n_clusters}_centers_out_of_full" |
| 256 | + |
| 257 | + @property |
| 258 | + def raw_file_names(self): |
| 259 | + return ["train.txt", "validation.txt", "test.txt"] |
| 260 | + |
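| | + # RDKit fingerprints for the random PubChem sample, computed once and |
| | + # cached as a pickle in raw_dir |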
| 261 | + @property |
| 262 | + def fingerprints(self): |
| 263 | + if self._fingerprints is None: |
237 | 264 | fingerprints_path = os.path.join(self.raw_dir, "fingerprints.pkl") |
238 | 265 | if not os.path.exists(fingerprints_path): |
| 266 | + print(f"No fingerprints found...") |
| 267 | + print(f"Loading random dataset (size: {self._k})...") |
| 268 | + random_dataset = PubChem(k=self._k) |
| 269 | + random_dataset.download() |
239 | 270 | with open( |
240 | 271 | os.path.join(random_dataset.raw_dir, "smiles.txt"), "r" |
241 | 272 | ) as f_in: |
242 | 273 | random_smiles = [s.split("\t")[1].strip() for s in f_in.readlines()] |
243 | 274 | fpgen = AllChem.GetRDKitFPGenerator() |
244 | | - selected_smiles = [] |
245 | 275 | print(f"Converting SMILES to molecules...") |
246 | 276 | mols = [Chem.MolFromSmiles(s) for s in tqdm.tqdm(random_smiles)] |
247 | 277 | print(f"Generating Fingerprints...") |
248 | 278 | fps = [ |
249 | 279 | fpgen.GetFingerprint(m) if m is not None else m |
250 | 280 | for m in tqdm.tqdm(mols) |
251 | 281 | ] |
252 | | - similarity = [] |
253 | 282 | d = {"smiles": random_smiles, "fps": fps} |
254 | | - df = pd.DataFrame(d, columns=["smiles", "fps"]) |
255 | | - df = df.dropna() |
256 | | - df.to_pickle(open(fingerprints_path, "wb")) |
| 283 | + fingerprints_df = pd.DataFrame(d, columns=["smiles", "fps"]) |
| 284 | + fingerprints_df = fingerprints_df.dropna() |
| 285 | + fingerprints_df.to_pickle(open(fingerprints_path, "wb")) |
| 286 | + self._fingerprints = fingerprints_df |
257 | 287 | else: |
258 | | - df = pd.read_pickle(open(fingerprints_path, "rb")) |
259 | | - fps = np.array([list(vec) for vec in df["fps"].tolist()]) |
260 | | - print(f"Starting k-means clustering...") |
261 | | - start_time = time.perf_counter() |
262 | | - kmeans = KMeans(n_clusters=self.n_clusters, random_state=0, n_init="auto") |
263 | | - kmeans.fit(fps) |
264 | | - print(f"Finished k-means in {time.perf_counter() - start_time:.2f} seconds") |
265 | | - df["label"] = kmeans.labels_ |
266 | | - df.to_pickle( |
267 | | - open( |
268 | | - os.path.join( |
269 | | - self.raw_dir, f"fingerprints_labeled_{self.n_clusters}.pkl" |
270 | | - ), |
271 | | - "wb", |
272 | | - ) |
| 288 | + self._fingerprints = pd.read_pickle(open(fingerprints_path, "rb")) |
| 289 | + return self._fingerprints |
| 290 | + |
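| | + # run k-means over the sampled fingerprints; pickles and returns both the |
| | + # cluster centers and the per-molecule cluster labels |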
| 291 | + def _build_clusters(self): |
| 292 | + fingerprints_clustered_path = os.path.join( |
| 293 | + self.raw_dir, "fingerprints_clustered.pkl" |
| 294 | + ) |
| 295 | + cluster_centers_path = os.path.join(self.raw_dir, f"cluster_centers.pkl") |
| 296 | + print(f"Starting k-means clustering...") |
| 297 | + start_time = time.perf_counter() |
| 298 | + kmeans = KMeans(n_clusters=self.n_clusters, random_state=0, n_init="auto") |
| 299 | + fps = np.array([list(vec) for vec in self.fingerprints["fps"].tolist()]) |
| 300 | + kmeans.fit(fps) |
| 301 | + print(f"Finished k-means in {time.perf_counter() - start_time:.2f} seconds") |
| 302 | + fingerprints_df = self.fingerprints |
| 303 | + fingerprints_df["label"] = kmeans.labels_ |
| 304 | + fingerprints_df.to_pickle( |
| 305 | + open( |
| 306 | + fingerprints_clustered_path, |
| 307 | + "wb", |
273 | 308 | ) |
274 | | - cluster_df = pd.DataFrame( |
275 | | - data={"centers": [center for center in kmeans.cluster_centers_]} |
| 309 | + ) |
| 310 | + cluster_df = pd.DataFrame( |
| 311 | + data={"centers": [center for center in kmeans.cluster_centers_]} |
| 312 | + ) |
| 313 | + cluster_df.to_pickle( |
| 314 | + open( |
| 315 | + cluster_centers_path, |
| 316 | + "wb", |
276 | 317 | ) |
277 | | - cluster_df.to_pickle( |
278 | | - open( |
279 | | - os.path.join( |
280 | | - self.raw_dir, f"cluster_centers_{self.n_clusters}.pkl" |
281 | | - ), |
282 | | - "wb", |
| 318 | + ) |
| 319 | + |
| 320 | + return cluster_df, fingerprints_df |
| 321 | + |
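| | + # drop every cluster that contains molecules from the exclusion dataset, |
| | + # but always keep at least include_min_n_clusters clusters |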
| 322 | + def _exclude_clusters(self, cluster_centers): |
| 323 | + exclusion_data_path = os.path.join(self.raw_dir, "exclusion_data_clustered.pkl") |
| 324 | + cluster_centers_np = np.array( |
| 325 | + [ |
| 326 | + [cci for cci in cluster_center] |
| 327 | + for cluster_center in cluster_centers["centers"] |
| 328 | + ] |
| 329 | + ) |
| 330 | + if self.exclude_data_from is not None: |
| 331 | + if not os.path.exists(exclusion_data_path): |
| 332 | + print(f"Loading data for exclusion of clusters...") |
| 333 | + raw_chebi = [] |
| 334 | + for filename in self.exclude_data_from.raw_file_names: |
| 335 | + raw_chebi.append( |
| 336 | + pd.read_pickle( |
| 337 | + open( |
| 338 | + os.path.join(self.exclude_data_from.raw_dir, filename), |
| 339 | + "rb", |
| 340 | + ) |
| 341 | + ) |
| 342 | + ) |
| 343 | + raw_chebi = pd.concat(raw_chebi) |
| 344 | + raw_chebi_smiles = np.array(raw_chebi["SMILES"]) |
| 345 | + fpgen = AllChem.GetRDKitFPGenerator() |
| 346 | + print(f"Converting SMILES to molecules...") |
| 347 | + mols = [Chem.MolFromSmiles(s) for s in tqdm.tqdm(raw_chebi_smiles)] |
| 348 | + print(f"Generating Fingerprints...") |
| 349 | + chebi_fps = [ |
| 350 | + fpgen.GetFingerprint(m) if m is not None else m |
| 351 | + for m in tqdm.tqdm(mols) |
| 352 | + ] |
| 353 | + print(f"Finding cluster for each instance from exclusion-data") |
| 354 | + chebi_fps = np.array([list(fp) for fp in chebi_fps if fp is not None]) |
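| | + # a KD-tree over the centers assigns each exclusion molecule to its nearest |
| | + # (Euclidean) cluster center, matching the k-means assignment rule |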
| 355 | + tree = spatial.KDTree(cluster_centers_np) |
| 356 | + chebi_clusters = [tree.query(fp)[1] for fp in chebi_fps] |
| 357 | + chebi_clusters_df = pd.DataFrame( |
| 358 | + {"fp": [fp for fp in chebi_fps], "center_id": chebi_clusters}, |
| 359 | + columns=["fp", "center_id"], |
283 | 360 | ) |
| 361 | + chebi_clusters_df.to_pickle(open(exclusion_data_path, "wb")) |
| 362 | + else: |
| 363 | + chebi_clusters_df = pd.read_pickle(open(exclusion_data_path, "rb")) |
| 364 | + # filter pubchem clusters and remove all that contain data from the exclusion set |
| 365 | + print(f"Removing clusters with data from exclusion-set") |
| 366 | + counts = chebi_clusters_df["center_id"].value_counts() |
| 367 | + cluster_centers["n_chebi_instances"] = counts |
| 368 | + cluster_centers["n_chebi_instances"] = cluster_centers["n_chebi_instances"].fillna(0) |
| 369 | + cluster_centers.sort_values( |
| 370 | + by="n_chebi_instances", ascending=False, inplace=True |
284 | 371 | ) |
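| | + # prefer clusters untouched by the exclusion set; if too few remain, fall |
| | + # back to the least-affected clusters (the tail of the sort above) |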
| 372 | + zero_centers = cluster_centers[cluster_centers["n_chebi_instances"] == 0] |
| 373 | + if len(zero_centers) > self.include_min_n_clusters: |
| 374 | + cluster_centers = zero_centers |
| 375 | + else: |
| 376 | + cluster_centers = cluster_centers[-self.include_min_n_clusters :] |
| 377 | + return cluster_centers |
| 378 | + |
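| | + # lazy accessors: load pickled k-means results or build them on first use |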
| 379 | + @property |
| 380 | + def cluster_centers(self): |
| 381 | + cluster_centers_path = os.path.join(self.raw_dir, f"cluster_centers.pkl") |
| 382 | + if self._cluster_centers is None: |
| 383 | + if os.path.exists(cluster_centers_path): |
| 384 | + self._cluster_centers = pd.read_pickle(open(cluster_centers_path, "rb")) |
| 385 | + else: |
| 386 | + self._cluster_centers = self._build_clusters()[0] |
| 387 | + return self._cluster_centers |
| 388 | + |
| 389 | + @property |
| 390 | + def fingerprints_clustered(self): |
| 391 | + fingerprints_path = os.path.join(self.raw_dir, f"fingerprints_clustered.pkl") |
| 392 | + if self._fingerprints_clustered is None: |
| 393 | + if os.path.exists(fingerprints_path): |
| 394 | + self._fingerprints_clustered = pd.read_pickle( |
| 395 | + open(fingerprints_path, "rb") |
| 396 | + ) |
| 397 | + else: |
| 398 | + self._fingerprints_clustered = self._build_clusters()[1] |
| 399 | + return self._fingerprints_clustered |
| 400 | + |
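| | + # group the filtered cluster centers into three superclusters, one per |
| | + # data split |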
| 401 | + @property |
| 402 | + def cluster_centers_superclustered(self): |
| 403 | + cluster_centers_path = os.path.join( |
| 404 | + self.raw_dir, f"cluster_centers_superclustered.pkl" |
| 405 | + ) |
| 406 | + if self._cluster_centers_superclustered is None: |
| 407 | + if not os.path.exists(cluster_centers_path): |
| 408 | + clusters_filtered = self._exclude_clusters(self.cluster_centers) |
| 409 | + print(f"Superclustering PubChem clusters") |
| 410 | + kmeans = KMeans(n_clusters=3, random_state=0, n_init="auto") |
| 411 | + clusters_np = np.array( |
| 412 | + [[cci for cci in center] for center in clusters_filtered["centers"]] |
| 413 | + ) |
| 414 | + kmeans.fit(clusters_np) |
| 415 | + clusters_filtered["label"] = kmeans.labels_ |
| 416 | + clusters_filtered.to_pickle(open(cluster_centers_path, "wb")) |
| 424 | + self._cluster_centers_superclustered = clusters_filtered |
| 425 | + else: |
| 426 | + self._cluster_centers_superclustered = pd.read_pickle( |
| 427 | + open(cluster_centers_path, "rb") |
| 428 | + ) |
| 434 | + return self._cluster_centers_superclustered |
| 435 | + |
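| | + # write one file per split; since each split corresponds to a supercluster, |
| | + # train, validation and test occupy separate regions of fingerprint space |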
| 436 | + def download(self): |
| 437 | + if self._k == PubChem.FULL: |
| 438 | + super().download() |
| 439 | + else: |
| 440 | + if not all( |
| 441 | + os.path.exists(os.path.join(self.raw_dir, file)) |
| 442 | + for file in self.raw_file_names |
| 443 | + ): |
| 444 | + fingerprints = self.fingerprints_clustered |
| 445 | + fingerprints["big_cluster_assignment"] = fingerprints["label"].apply( |
| 446 | + lambda l: ( |
| 447 | + -1 |
| 448 | + if l not in self.cluster_centers_superclustered.index |
| 449 | + else self.cluster_centers_superclustered.loc[int(l), "label"] |
| 450 | + ) |
| 451 | + ) |
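| | + # -1 marks molecules whose cluster was filtered out; they are skipped below |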
| 452 | + fp_grouped = fingerprints.groupby("big_cluster_assignment") |
| 453 | + splits = [fp_grouped.get_group(g) for g in fp_grouped.groups if g != -1] |
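| | + # sort splits by size: the smallest becomes validation (capped below), the largest train |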
| 454 | + splits.sort(key=lambda x: len(x)) |
| 455 | + splits[0] = splits[0][: self.validation_size_limit] |
| 456 | + for i, name in enumerate(["validation", "test", "train"]): |
| 457 | + with open(os.path.join(self.raw_dir, f"{name}.txt"), "w") as f: |
| 458 | + for idx, row in splits[i].iterrows(): |
| 459 | + f.write(f"{idx}\t{row['smiles']}\n") |
285 | 462 |
286 | 463 |
287 | 464 | class PubChemDissimilarSMILES(PubChemDissimilar): |
@@ -383,7 +560,12 @@ def download(self): |
383 | 560 |
384 | 561 |
385 | 562 | if __name__ == "__main__": |
386 | | - kmeans_data = PubChemKMeans() |
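| | + # small example configuration; the defaults cluster a 1e6-molecule sample into 1e4 clusters |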
| 563 | + kmeans_data = PubChemKMeans( |
| 564 | + n_clusters=100, |
| 565 | + random_size=10000, |
| 566 | + exclude_data_from=ChEBIOver100(chebi_version=231), |
| 567 | + include_min_n_clusters=10, |
| 568 | + ) |
387 | 569 | kmeans_data.download() |
388 | 570 |
389 | 571 |