Skip to content

Commit e06b9e6

Browse files
author
sfluegel
committed
add complete kmeans pubchem dataset
1 parent d088f68 commit e06b9e6

File tree

1 file changed

+218
-36
lines changed

1 file changed

+218
-36
lines changed

chebai/preprocessing/datasets/pubchem.py

Lines changed: 218 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import random
1313
import shutil
1414
import tempfile
15+
from scipy import spatial
1516

1617
import pandas as pd
1718
from sklearn.model_selection import train_test_split
@@ -28,9 +29,11 @@
2829
ChEBIOver50,
2930
ChEBIOver100,
3031
ChEBIOverX,
32+
_ChEBIDataExtractor,
3133
)
3234
from rdkit import Chem, DataStructs
3335
from rdkit.Chem import AllChem
36+
from sklearn.cluster import KMeans
3437

3538

3639
class PubChem(XYBaseDataModule):
class PubChemKMeans(PubChem):
    """PubChem subset built by k-means clustering of molecular fingerprints.

    A random PubChem sample of ``random_size`` molecules is fingerprinted
    (RDKit fingerprints) and grouped into ``n_clusters`` k-means clusters.
    Optionally, clusters containing molecules from ``exclude_data_from`` are
    removed. The surviving cluster centers are clustered again into 3
    super-clusters, which become the validation / test / train splits
    (assigned by ascending split size).
    """

    def __init__(
        self,
        *args,
        n_clusters=1e4,
        random_size=1e6,
        exclude_data_from: _ChEBIDataExtractor = None,
        validation_size_limit=4000,
        include_min_n_clusters=100,
        **kwargs,
    ):
        """
        Args:
            n_clusters: number of k-means clusters built from the random sample.
            random_size: size of the random PubChem sample to draw from;
                forwarded to the parent class as ``k``.
            exclude_data_from: dataset whose raw SMILES mark clusters for
                removal (e.g. a ChEBI extractor); ``None`` disables exclusion.
            validation_size_limit: cap on the number of entries written to the
                first split group (intended as the validation-set limit).
            include_min_n_clusters: minimum number of clusters to keep when
                exclusion would otherwise remove too many.
        """
        self.n_clusters = int(n_clusters)
        self.exclude_data_from = exclude_data_from
        self.validation_size_limit = validation_size_limit
        self.include_min_n_clusters = include_min_n_clusters
        super().__init__(*args, k=int(random_size), **kwargs)
        # lazily-built caches, populated by the properties below
        self._fingerprints = None
        self._cluster_centers = None
        self._fingerprints_clustered = None
        self._exclusion_data_clustered = None
        self._cluster_centers_superclustered = None

    @property
    def _name(self):
        return "PubchemKMeans"

    @property
    def split_label(self):
        # identifies this particular clustering configuration on disk
        if self._k and self._k != self.FULL:
            return f"{self.n_clusters}_centers_out_of_{self._k}"
        else:
            return f"{self.n_clusters}_centers_out_of_full"

    @property
    def raw_file_names(self):
        return ["train.txt", "validation.txt", "test.txt"]

    @property
    def fingerprints(self):
        """DataFrame with columns ``smiles`` and ``fps`` (RDKit fingerprints).

        Built from a random PubChem sample on first access and cached on disk;
        molecules that fail to parse are dropped.
        """
        if self._fingerprints is None:
            fingerprints_path = os.path.join(self.raw_dir, "fingerprints.pkl")
            if not os.path.exists(fingerprints_path):
                print("No fingerprints found...")
                print(f"Loading random dataset (size: {self._k})...")
                random_dataset = PubChem(k=self._k)
                random_dataset.download()
                with open(
                    os.path.join(random_dataset.raw_dir, "smiles.txt"), "r"
                ) as f_in:
                    random_smiles = [s.split("\t")[1].strip() for s in f_in.readlines()]
                fpgen = AllChem.GetRDKitFPGenerator()
                print("Converting SMILES to molecules...")
                mols = [Chem.MolFromSmiles(s) for s in tqdm.tqdm(random_smiles)]
                print("Generating Fingerprints...")
                # unparseable molecules yield None here and are removed by dropna()
                fps = [
                    fpgen.GetFingerprint(m) if m is not None else m
                    for m in tqdm.tqdm(mols)
                ]
                d = {"smiles": random_smiles, "fps": fps}
                fingerprints_df = pd.DataFrame(d, columns=["smiles", "fps"])
                fingerprints_df = fingerprints_df.dropna()
                # pass the path directly: handing to_pickle an open() handle
                # leaks the file descriptor
                fingerprints_df.to_pickle(fingerprints_path)
                self._fingerprints = fingerprints_df
            else:
                self._fingerprints = pd.read_pickle(fingerprints_path)
        return self._fingerprints

    def _build_clusters(self):
        """Run k-means on the fingerprints; persist and return the results.

        Returns:
            tuple: (cluster-centers DataFrame, fingerprints DataFrame with an
            added ``label`` column holding each molecule's cluster id).
        """
        fingerprints_clustered_path = os.path.join(
            self.raw_dir, "fingerprints_clustered.pkl"
        )
        cluster_centers_path = os.path.join(self.raw_dir, "cluster_centers.pkl")
        print("Starting k-means clustering...")
        start_time = time.perf_counter()
        kmeans = KMeans(n_clusters=self.n_clusters, random_state=0, n_init="auto")
        # explicit bit vectors: KMeans needs a dense numeric array
        fps = np.array([list(vec) for vec in self.fingerprints["fps"].tolist()])
        kmeans.fit(fps)
        print(f"Finished k-means in {time.perf_counter() - start_time:.2f} seconds")
        fingerprints_df = self.fingerprints
        fingerprints_df["label"] = kmeans.labels_
        fingerprints_df.to_pickle(fingerprints_clustered_path)
        cluster_df = pd.DataFrame(
            data={"centers": [center for center in kmeans.cluster_centers_]}
        )
        cluster_df.to_pickle(cluster_centers_path)

        return cluster_df, fingerprints_df

    def _exclude_clusters(self, cluster_centers):
        """Drop clusters that contain molecules from ``exclude_data_from``.

        Each exclusion molecule is fingerprinted and assigned to its nearest
        cluster center (KD-tree nearest-neighbour query). Clusters with zero
        exclusion hits are kept; if fewer than ``include_min_n_clusters``
        remain, the clusters with the fewest hits are kept instead.
        """
        exclusion_data_path = os.path.join(self.raw_dir, "exclusion_data_clustered.pkl")
        cluster_centers_np = np.array(
            [
                [cci for cci in cluster_center]
                for cluster_center in cluster_centers["centers"]
            ]
        )
        if self.exclude_data_from is not None:
            if not os.path.exists(exclusion_data_path):
                print("Loading data for exclusion of clusters...")
                raw_chebi = []
                for filename in self.exclude_data_from.raw_file_names:
                    raw_chebi.append(
                        pd.read_pickle(
                            os.path.join(self.exclude_data_from.raw_dir, filename)
                        )
                    )
                raw_chebi = pd.concat(raw_chebi)
                raw_chebi_smiles = np.array(raw_chebi["SMILES"])
                fpgen = AllChem.GetRDKitFPGenerator()
                print("Converting SMILES to molecules...")
                mols = [Chem.MolFromSmiles(s) for s in tqdm.tqdm(raw_chebi_smiles)]
                print("Generating Fingerprints...")
                chebi_fps = [
                    fpgen.GetFingerprint(m) if m is not None else m
                    for m in tqdm.tqdm(mols)
                ]
                print("Finding cluster for each instance from exclusion-data")
                chebi_fps = np.array([list(fp) for fp in chebi_fps if fp is not None])
                tree = spatial.KDTree(cluster_centers_np)
                # query() accepts a batch of points; [1] is the index array
                chebi_clusters = list(tree.query(chebi_fps)[1])
                chebi_clusters_df = pd.DataFrame(
                    {"fp": [fp for fp in chebi_fps], "center_id": chebi_clusters},
                    columns=["fp", "center_id"],
                )
                chebi_clusters_df.to_pickle(exclusion_data_path)
            else:
                chebi_clusters_df = pd.read_pickle(exclusion_data_path)
            # filter pubchem clusters and remove all that contain data from the exclusion set
            print("Removing clusters with data from exclusion-set")
            counts = chebi_clusters_df["center_id"].value_counts()
            # direct assignment instead of inplace fillna on a column slice,
            # which pandas may apply to a temporary copy
            cluster_centers["n_chebi_instances"] = counts
            cluster_centers["n_chebi_instances"] = cluster_centers[
                "n_chebi_instances"
            ].fillna(0)
            cluster_centers.sort_values(
                by="n_chebi_instances", ascending=False, inplace=True
            )
            zero_centers = cluster_centers[cluster_centers["n_chebi_instances"] == 0]
            if len(zero_centers) > self.include_min_n_clusters:
                cluster_centers = zero_centers
            else:
                # sorted descending, so the tail has the fewest exclusion hits
                cluster_centers = cluster_centers[-self.include_min_n_clusters :]
        return cluster_centers

    @property
    def cluster_centers(self):
        """Cluster-centers DataFrame, loaded from disk or built on demand."""
        cluster_centers_path = os.path.join(self.raw_dir, "cluster_centers.pkl")
        if self._cluster_centers is None:
            if os.path.exists(cluster_centers_path):
                self._cluster_centers = pd.read_pickle(cluster_centers_path)
            else:
                self._cluster_centers = self._build_clusters()[0]
        return self._cluster_centers

    @property
    def fingerprints_clustered(self):
        """Fingerprints DataFrame with cluster labels, loaded or built on demand."""
        fingerprints_path = os.path.join(self.raw_dir, "fingerprints_clustered.pkl")
        if self._fingerprints_clustered is None:
            if os.path.exists(fingerprints_path):
                self._fingerprints_clustered = pd.read_pickle(fingerprints_path)
            else:
                self._fingerprints_clustered = self._build_clusters()[1]
        return self._fingerprints_clustered

    @property
    def cluster_centers_superclustered(self):
        """Filtered cluster centers grouped into 3 super-clusters (splits)."""
        cluster_centers_path = os.path.join(
            self.raw_dir, "cluster_centers_superclustered.pkl"
        )
        if self._cluster_centers_superclustered is None:
            if not os.path.exists(cluster_centers_path):
                clusters_filtered = self._exclude_clusters(self.cluster_centers)
                print("Superclustering PubChem clusters")
                # 3 super-clusters -> validation / test / train
                kmeans = KMeans(n_clusters=3, random_state=0, n_init="auto")
                clusters_np = np.array(
                    [[cci for cci in center] for center in clusters_filtered["centers"]]
                )
                kmeans.fit(clusters_np)
                clusters_filtered["label"] = kmeans.labels_
                clusters_filtered.to_pickle(cluster_centers_path)
                self._cluster_centers_superclustered = clusters_filtered
            else:
                self._cluster_centers_superclustered = pd.read_pickle(
                    cluster_centers_path
                )
        return self._cluster_centers_superclustered

    def download(self):
        """Write train/validation/test raw files based on the super-clusters."""
        if self._k == PubChem.FULL:
            super().download()
        else:
            if not all(
                os.path.exists(os.path.join(self.raw_dir, file))
                for file in self.raw_file_names
            ):
                fingerprints = self.fingerprints_clustered
                superclusters = self.cluster_centers_superclustered
                # map each molecule's cluster to its super-cluster; -1 marks
                # clusters that were removed by the exclusion step
                fingerprints["big_cluster_assignment"] = fingerprints["label"].apply(
                    lambda l: (
                        -1
                        if l not in superclusters.index
                        else superclusters.loc[int(l), "label"]
                    )
                )
                fp_grouped = fingerprints.groupby("big_cluster_assignment")
                splits = [fp_grouped.get_group(g) for g in fp_grouped.groups if g != -1]
                # NOTE(review): the size limit is applied to the first group
                # BEFORE sorting by size, so it is not guaranteed to hit the
                # group that ends up as "validation" — confirm intended behaviour.
                splits[0] = splits[0][: self.validation_size_limit]
                splits.sort(key=lambda x: len(x))
                for i, name in enumerate(["validation", "test", "train"]):
                    # mode "w" creates the file, so no separate "x"-mode
                    # pre-creation is needed
                    with open(os.path.join(self.raw_dir, f"{name}.txt"), "w") as f:
                        for idx, row in splits[i].iterrows():
                            f.write(f"{idx}\t{row['smiles']}\n")
287464
class PubChemDissimilarSMILES(PubChemDissimilar):
@@ -383,7 +560,12 @@ def download(self):
383560

384561

385562
if __name__ == "__main__":
    # Small manual run: 100 clusters from a 10k-molecule sample, excluding
    # clusters that overlap with ChEBIOver100 data.
    exclusion_source = ChEBIOver100(chebi_version=231)
    kmeans_data = PubChemKMeans(
        n_clusters=100,
        random_size=10000,
        exclude_data_from=exclusion_source,
        include_min_n_clusters=10,
    )
    kmeans_data.download()
388570

389571

0 commit comments

Comments
 (0)