-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathsampling_qed.py
More file actions
54 lines (41 loc) · 1.59 KB
/
sampling_qed.py
File metadata and controls
54 lines (41 loc) · 1.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import time
from opt_qed import QEDTask
from rxnflow.base import RxnFlowSampler
from rxnflow.config import Config, init_empty
# NOTE: example setting
# Number of molecules requested in each sampler.sample() call made by the script section.
NUM_SAMPLES = 200
# Device string passed as the third argument when the sampler is constructed.
DEVICE = "cpu" # or 'cuda'
class QEDSampler(RxnFlowSampler):
    """Sampler whose task object is a ``QEDTask`` built from the sampler config."""

    def setup_task(self) -> None:
        """Instantiate ``QEDTask`` with ``self.cfg`` and store it on ``self.task``."""
        qed_task = QEDTask(cfg=self.cfg)
        self.task = qed_task
def timed_sample(sampler, num_samples: int, *, calc_reward: bool):
    """Run one ``sampler.sample`` call, report its size and duration, return it.

    Args:
        sampler: Object exposing ``sample(num_samples, calc_reward=...)``.
        num_samples: Number of samples to request.
        calc_reward: Forwarded unchanged to ``sampler.sample``.

    Returns:
        Whatever ``sampler.sample`` returned (must support ``len``).
    """
    # perf_counter is monotonic, so the measured interval cannot be skewed by
    # system clock adjustments the way time.time() can.
    tick_st = time.perf_counter()
    res = sampler.sample(num_samples, calc_reward=calc_reward)
    elapsed = time.perf_counter() - tick_st
    print(f"Generated Molecules: {len(res)}")
    print(f"Sampling: {elapsed:.3f} sec")
    return res


def write_smiles(samples, path: str) -> None:
    """Write one ``<smiles>\\tsample<idx>`` line per sample to ``path``.

    Args:
        samples: Iterable of mappings that each provide a ``"smiles"`` entry.
        path: Destination file; overwritten if it already exists.
    """
    # Explicit encoding keeps the output independent of the platform locale.
    with open(path, "w", encoding="utf-8") as w:
        w.writelines(
            f"{sample['smiles']}\tsample{idx}\n" for idx, sample in enumerate(samples)
        )


def write_rewards_csv(samples, path: str) -> None:
    """Write an ``index,SMILES,QED`` table (header ``,SMILES,QED``) to ``path``.

    Args:
        samples: Iterable of mappings providing ``"smiles"`` and
            ``["info"]["reward"][0]``; the latter is formatted with ``.3f``.
        path: Destination file; overwritten if it already exists.
    """
    with open(path, "w", encoding="utf-8") as w:
        w.write(",SMILES,QED\n")
        w.writelines(
            f"{idx},{sample['smiles']},{sample['info']['reward'][0]:.3f}\n"
            for idx, sample in enumerate(samples)
        )


def main() -> None:
    """Build the sampler from a checkpoint and run both example sampling modes."""
    # change config from training
    config = init_empty(Config())
    config.algo.num_from_policy = 100  # 64 -> 100
    config.env_dir = "./data/envs/stock"  # if you want to use catalog, just remove this line
    ckpt_path = "./logs/example/qed/model_state.pt"

    # construct sampler
    sampler = QEDSampler(config, ckpt_path, DEVICE)

    # type1: generate molecules only
    res = timed_sample(sampler, NUM_SAMPLES, calc_reward=False)
    write_smiles(res, "./example-qed.smi")

    # type2: generate molecules with their rewards
    res = timed_sample(sampler, NUM_SAMPLES, calc_reward=True)
    write_rewards_csv(res, "./example-qed.csv")


if __name__ == "__main__":
    main()