Skip to content

Commit d088f68

Browse files
author
sfluegel
committed
add kmeans pubchem dataset (clustering stage)
1 parent 81df748 commit d088f68

File tree

1 file changed

+76
-3
lines changed

1 file changed

+76
-3
lines changed

chebai/preprocessing/datasets/pubchem.py

Lines changed: 76 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from sklearn.model_selection import train_test_split
1818
import requests
1919
import torch
20+
import time
21+
import numpy as np
2022
import tqdm
2123
from datetime import datetime
2224

@@ -56,7 +58,7 @@ def identifier(self):
5658

5759
@property
5860
def split_label(self):
59-
if self._k:
61+
if self._k and self._k != self.FULL:
6062
return str(self._k)
6163
else:
6264
return "full"
@@ -211,6 +213,77 @@ def download(self):
211213
)
212214

213215

216+
class PubChemKMeans(PubChem):
217+
218+
def __init__(self, *args, n_clusters=1e4, random_size=1e6, **kwargs):
219+
"""k: number of entries in this dataset,
220+
n_random_subsets: number of subsets of random data from which to draw
221+
the most dissimilar molecules,
222+
random_size_factor: size of random subsets (in total) in relation to k"""
223+
self.n_clusters = int(n_clusters)
224+
super(PubChemKMeans, self).__init__(*args, k=int(random_size), **kwargs)
225+
226+
@property
227+
def _name(self):
228+
return f"PubchemKMeans"
229+
230+
def download(self):
231+
if self._k == PubChem.FULL:
232+
super().download()
233+
else:
234+
print(f"Loading random dataset (size: {self._k})...")
235+
random_dataset = PubChem(k=self._k)
236+
random_dataset.download()
237+
fingerprints_path = os.path.join(self.raw_dir, "fingerprints.pkl")
238+
if not os.path.exists(fingerprints_path):
239+
with open(
240+
os.path.join(random_dataset.raw_dir, "smiles.txt"), "r"
241+
) as f_in:
242+
random_smiles = [s.split("\t")[1].strip() for s in f_in.readlines()]
243+
fpgen = AllChem.GetRDKitFPGenerator()
244+
selected_smiles = []
245+
print(f"Converting SMILES to molecules...")
246+
mols = [Chem.MolFromSmiles(s) for s in tqdm.tqdm(random_smiles)]
247+
print(f"Generating Fingerprints...")
248+
fps = [
249+
fpgen.GetFingerprint(m) if m is not None else m
250+
for m in tqdm.tqdm(mols)
251+
]
252+
similarity = []
253+
d = {"smiles": random_smiles, "fps": fps}
254+
df = pd.DataFrame(d, columns=["smiles", "fps"])
255+
df = df.dropna()
256+
df.to_pickle(open(fingerprints_path, "wb"))
257+
else:
258+
df = pd.read_pickle(open(fingerprints_path, "rb"))
259+
fps = np.array([list(vec) for vec in df["fps"].tolist()])
260+
print(f"Starting k-means clustering...")
261+
start_time = time.perf_counter()
262+
kmeans = KMeans(n_clusters=self.n_clusters, random_state=0, n_init="auto")
263+
kmeans.fit(fps)
264+
print(f"Finished k-means in {time.perf_counter() - start_time:.2f} seconds")
265+
df["label"] = kmeans.labels_
266+
df.to_pickle(
267+
open(
268+
os.path.join(
269+
self.raw_dir, f"fingerprints_labeled_{self.n_clusters}.pkl"
270+
),
271+
"wb",
272+
)
273+
)
274+
cluster_df = pd.DataFrame(
275+
data={"centers": [center for center in kmeans.cluster_centers_]}
276+
)
277+
cluster_df.to_pickle(
278+
open(
279+
os.path.join(
280+
self.raw_dir, f"cluster_centers_{self.n_clusters}.pkl"
281+
),
282+
"wb",
283+
)
284+
)
285+
286+
214287
class PubChemDissimilarSMILES(PubChemDissimilar):
215288
READER = dr.ChemDataReader
216289

@@ -310,8 +383,8 @@ def download(self):
310383

311384

312385
if __name__ == "__main__":
313-
haz = Hazardous()
314-
haz.setup_processed()
386+
kmeans_data = PubChemKMeans()
387+
kmeans_data.download()
315388

316389

317390
class SWJPreChem(PubChem):

0 commit comments

Comments
 (0)