Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion remote/DiffCSP-official
2 changes: 1 addition & 1 deletion remote/riemannian-fm
109 changes: 109 additions & 0 deletions scripts_model/build_solvent_matrix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import argparse
import ast
import csv
from pathlib import Path


def parse_solvent_label(label):
    """Parse a solvent column header into a list of solvent name strings.

    Headers may be a plain name ("water"), a Python set/list/tuple literal
    ("{'water', 'ethanol'}"), or a frozenset repr ("frozenset({'water'})").

    Returns [] for blank/None input or a non-collection literal (e.g. an
    int); if the text is not a parseable literal, the raw label is kept as
    a single-element list (best-effort fallback for bare solvent names).
    """
    label = (label or "").strip()
    if not label:
        return []
    # Unwrap the "frozenset(...)" repr so the inner set literal can be parsed.
    if label.startswith("frozenset(") and label.endswith(")"):
        set_literal = label[len("frozenset("):-1]
    else:
        set_literal = label
    try:
        solvents = ast.literal_eval(set_literal)
    except (ValueError, SyntaxError):
        # Not a Python literal (e.g. a bare solvent name): keep the raw label.
        # (Narrowed from `except Exception`; these are the documented
        # literal_eval failures for malformed string input. `label` is
        # already known non-empty here, so no extra guard is needed.)
        return [label]
    if isinstance(solvents, (set, frozenset, list, tuple)):
        return [str(s).strip() for s in solvents if str(s).strip()]
    if isinstance(solvents, str):
        return [solvents.strip()] if solvents.strip() else []
    return []


def load_vocab(vocab_path):
    """Read the solvent vocabulary file: one name per line, in file order.

    Blank lines and surrounding whitespace are ignored.
    """
    with open(vocab_path) as handle:
        return [entry for entry in (raw.strip() for raw in handle) if entry]


def load_solvent_map(solvent_csv, vocab_set):
    """Map CSD refcodes to the vocabulary solvents they occur with.

    The CSV header row holds solvent-system labels (one column per system,
    in the format accepted by parse_solvent_label); each body cell holds a
    refcode belonging to that column's solvent system.

    Returns a dict: refcode -> set of solvent names restricted to vocab_set.
    Returns {} for an empty file.
    """
    with open(solvent_csv, newline="") as f:
        rows = list(csv.reader(f))
    if not rows:
        return {}
    # Parse each header label once and pre-filter against the vocabulary;
    # this is identical for every data row, so it is hoisted out of the loop.
    col_solvents = [
        [s for s in parse_solvent_label(label) if s in vocab_set]
        for label in rows[0]
    ]
    refcode_to_solvents = {}
    for row in rows[1:]:
        # Slice to the header width: a ragged row wider than the header
        # would otherwise raise IndexError on col_solvents[idx].
        for idx, cell in enumerate(row[:len(col_solvents)]):
            cell = (cell or "").strip().strip('"')
            if not cell:
                continue
            solvents = col_solvents[idx]
            if not solvents:
                continue
            refcode_to_solvents.setdefault(cell, set()).update(solvents)
    return refcode_to_solvents


def build_matrix(split_csv, solvent_map, vocab, out_csv):
    """Write a binary material_id x solvent matrix CSV for one data split.

    Material ids are read from the split CSV's "material_id" column,
    deduplicated while preserving first-seen order. One 0/1 row per
    material is written over the vocab columns (1 = solvent present in
    solvent_map for that material). Prints a short summary to stdout.
    """
    material_ids = []
    seen = set()
    with open(split_csv, newline="") as handle:
        for record in csv.DictReader(handle):
            material_id = (record.get("material_id") or "").strip()
            if material_id and material_id not in seen:
                seen.add(material_id)
                material_ids.append(material_id)

    out_path = Path(out_csv)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with_solvent = 0
    with out_path.open("w", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(["material_id"] + vocab)
        for material_id in material_ids:
            solvents = solvent_map.get(material_id, set())
            if solvents:
                with_solvent += 1
            writer.writerow(
                [material_id] + [int(name in solvents) for name in vocab]
            )

    print(f"Rows: {len(material_ids)}")
    print(f"Rows with solvents: {with_solvent}")
    print(f"Output: {out_path}")


def main():
    """CLI entry point: assemble one split's solvent matrix CSV."""
    parser = argparse.ArgumentParser(
        description="Build solvent matrix CSV per split from CSD_solvents."
    )
    # All four paths are mandatory; no defaults are meaningful here.
    for flag in ("--split_csv", "--solvent_csv", "--vocab_path", "--out_csv"):
        parser.add_argument(flag, required=True)
    args = parser.parse_args()

    vocab = load_vocab(args.vocab_path)
    if not vocab:
        raise ValueError(f"Empty vocab: {args.vocab_path}")
    solvent_map = load_solvent_map(args.solvent_csv, set(vocab))
    build_matrix(args.split_csv, solvent_map, vocab, args.out_csv)


if __name__ == "__main__":
    main()
30 changes: 18 additions & 12 deletions scripts_model/conf/data/organic.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,10 @@ datamodule:

datasets:
train:
_target_: diffcsp.pl_data.dataset.CrystDataset
_target_: orgflow.csp_ext.dataset.CrystDataset
name: OrganicTrainSet
path: ${oc.env:PROJECT_ROOT}/data/organic/train_drug.csv
# save_path: ${oc.env:PROJECT_ROOT}/data/organic/train_organic/
save_path: ${oc.env:PROJECT_ROOT}/data/organic/train_organic_drug.pt
path: ${oc.env:PROJECT_ROOT}/data/organic/train_drug_orgflow.csv
save_path: ${oc.env:PROJECT_ROOT}/data/organic/train_drug_orgflow.pt
prop: density
niggli: false
primitive: false
Expand All @@ -41,15 +40,17 @@ datamodule:
tolerance: 0.01
use_space_group: false
use_pos_index: false
solvent_matrix_path: ${oc.env:PROJECT_ROOT}/data/organic/solvents/train_solvent_matrix.csv
solvent_vocab_path: ${oc.env:PROJECT_ROOT}/data/organic/solvents/solvent_vocab.txt
drop_no_solvent: true
lattice_scale_method: scale_length
preprocess_workers: 30

val:
- _target_: diffcsp.pl_data.dataset.CrystDataset
- _target_: orgflow.csp_ext.dataset.CrystDataset
name: OrganicValSet
path: ${oc.env:PROJECT_ROOT}/data/organic/val_drug.csv
# save_path: ${oc.env:PROJECT_ROOT}/data/organic/val_organic/
save_path: ${oc.env:PROJECT_ROOT}/data/organic/val_organic_drug.pt
path: ${oc.env:PROJECT_ROOT}/data/organic/val_drug_orgflow.csv
save_path: ${oc.env:PROJECT_ROOT}/data/organic/val_drug_orgflow.pt
prop: density
niggli: false
primitive: false
Expand All @@ -58,15 +59,17 @@ datamodule:
tolerance: 0.01
use_space_group: false
use_pos_index: false
solvent_matrix_path: ${oc.env:PROJECT_ROOT}/data/organic/solvents/val_solvent_matrix.csv
solvent_vocab_path: ${oc.env:PROJECT_ROOT}/data/organic/solvents/solvent_vocab.txt
drop_no_solvent: true
lattice_scale_method: scale_length
preprocess_workers: 30

test:
- _target_: diffcsp.pl_data.dataset.CrystDataset
- _target_: orgflow.csp_ext.dataset.CrystDataset
name: OrganicTestSet
path: ${oc.env:PROJECT_ROOT}/data/organic/polymorphics_subset.csv
# save_path: ${oc.env:PROJECT_ROOT}/data/organic/test_organic/
save_path: ${oc.env:PROJECT_ROOT}/data/organic/polymorphics.pt
path: ${oc.env:PROJECT_ROOT}/data/organic/test_drug_orgflow.csv
save_path: ${oc.env:PROJECT_ROOT}/data/organic/test_drug_orgflow.pt
prop: density
niggli: false
primitive: false
Expand All @@ -75,6 +78,9 @@ datamodule:
tolerance: 0.01
use_space_group: false
use_pos_index: false
solvent_matrix_path: ${oc.env:PROJECT_ROOT}/data/organic/solvents/test_solvent_matrix.csv
solvent_vocab_path: ${oc.env:PROJECT_ROOT}/data/organic/solvents/solvent_vocab.txt
drop_no_solvent: true
lattice_scale_method: scale_length
preprocess_workers: 30

Expand Down
14 changes: 7 additions & 7 deletions scripts_model/conf/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ optim:
# RFM
optimizer:
_target_: torch.optim.AdamW
lr: 0.0003
weight_decay: 0.0
lr_scheduler:
_target_: torch.optim.lr_scheduler.CosineAnnealingLR
T_max: ${data.train_max_epochs}
eta_min: 1e-5
lr: 0.0002
weight_decay: 0.0002
# lr_scheduler:
# _target_: torch.optim.lr_scheduler.CosineAnnealingLR
# T_max: ${data.train_max_epochs}
# eta_min: 1e-5
interval: epoch
ema_decay: 0.999

Expand All @@ -46,7 +46,7 @@ train:
pl_trainer:
fast_dev_run: False # Enable this for debug purposes
strategy: ddp
devices: 2
devices: 3
accelerator: gpu
precision: 32
# max_steps: 10000
Expand Down
11 changes: 10 additions & 1 deletion scripts_model/conf/model/null_params.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
cost_coord: 600.
cost_lattice: 1.
cost_type: 0.0 # Always should be 0 for our CSP task
cost_bond: 0.007
cost_bond: 0.005
cost_solvent: 4
solvent_num_classes: 79
solvent_embed_dim: 512
solvent_pred_hidden_dim: 768
solvent_pred_num_layers: 4
solvent_dropout: 0.1
solvent_pred_layernorm: true
solvent_pos_weight: auto
solvent_pos_weight_max: 20.0
affine_combine_costs: true
target_distribution: conditional
self_cond: false
Expand Down
10 changes: 8 additions & 2 deletions scripts_model/conf/vectorfield/rfm_cspnet.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
_target_: orgflow.model.arch.CSPNet
hidden_dim: 128
hidden_dim: 512
time_dim: 256
num_layers: 12 # Testing more layers for organic data
num_layers: 3 # Reduced depth (was 12) to pair with the wider 512-dim hidden size
act_fn: silu
dis_emb: sin
num_freqs: 128
Expand All @@ -17,3 +17,9 @@ represent_num_atoms: false
represent_angle_edge_to_lattice: true
self_edges: false
self_cond: ${model.self_cond}
solvent_num_classes: ${model.solvent_num_classes}
solvent_embed_dim: ${model.solvent_embed_dim}
solvent_pred_hidden_dim: ${model.solvent_pred_hidden_dim}
solvent_pred_num_layers: ${model.solvent_pred_num_layers}
solvent_dropout: ${model.solvent_dropout}
solvent_pred_layernorm: ${model.solvent_pred_layernorm}
Loading