From e09eadcd669622e1bd4dfce43eb145999a44714d Mon Sep 17 00:00:00 2001 From: AlvandVahedi Date: Sun, 25 Jan 2026 21:59:03 +0000 Subject: [PATCH] Adding solvent classification in new branch --- remote/DiffCSP-official | 2 +- remote/riemannian-fm | 2 +- scripts_model/build_solvent_matrix.py | 109 ++++ scripts_model/conf/data/organic.yaml | 30 +- scripts_model/conf/default.yaml | 14 +- scripts_model/conf/model/null_params.yaml | 11 +- .../conf/vectorfield/rfm_cspnet.yaml | 10 +- scripts_model/evaluate.py | 471 +++++++++++++++++- src/orgflow/csp_ext/dataset.py | 83 +++ src/orgflow/model/arch.py | 122 +++++ src/orgflow/model/model_pl.py | 178 ++++++- 11 files changed, 1000 insertions(+), 32 deletions(-) create mode 100644 scripts_model/build_solvent_matrix.py diff --git a/remote/DiffCSP-official b/remote/DiffCSP-official index f6deb99..b81853f 160000 --- a/remote/DiffCSP-official +++ b/remote/DiffCSP-official @@ -1 +1 @@ -Subproject commit f6deb99050ab5f4e17cb87758914e972f723c785 +Subproject commit b81853f06e3c1f9b95fc064b8bcb889b6bebc9ff diff --git a/remote/riemannian-fm b/remote/riemannian-fm index f6fa72f..b0d01e9 160000 --- a/remote/riemannian-fm +++ b/remote/riemannian-fm @@ -1 +1 @@ -Subproject commit f6fa72f9590093ffff7b71d4c761ad07d1618cba +Subproject commit b0d01e966bc0b9fc87ad92424b6a3b8667b4fa85 diff --git a/scripts_model/build_solvent_matrix.py b/scripts_model/build_solvent_matrix.py new file mode 100644 index 0000000..7ff553e --- /dev/null +++ b/scripts_model/build_solvent_matrix.py @@ -0,0 +1,109 @@ +import argparse +import ast +import csv +from pathlib import Path + + +def parse_solvent_label(label): + label = (label or "").strip() + if not label: + return [] + if label.startswith("frozenset(") and label.endswith(")"): + set_literal = label[len("frozenset("):-1] + else: + set_literal = label + try: + solvents = ast.literal_eval(set_literal) + except Exception: + return [label] if label else [] + if isinstance(solvents, (set, frozenset, list, tuple)): + 
return [str(s).strip() for s in solvents if str(s).strip()] + if isinstance(solvents, str): + return [solvents.strip()] if solvents.strip() else [] + return [] + + +def load_vocab(vocab_path): + vocab = [] + with open(vocab_path) as f: + for line in f: + val = line.strip() + if val: + vocab.append(val) + return vocab + + +def load_solvent_map(solvent_csv, vocab_set): + with open(solvent_csv, newline="") as f: + reader = csv.reader(f) + rows = list(reader) + if not rows: + return {} + header = rows[0] + col_solvents = [parse_solvent_label(label) for label in header] + refcode_to_solvents = {} + for row in rows[1:]: + for idx, cell in enumerate(row): + cell = (cell or "").strip().strip('"') + if not cell: + continue + solvents = [s for s in col_solvents[idx] if s in vocab_set] + if not solvents: + continue + entry = refcode_to_solvents.setdefault(cell, set()) + entry.update(solvents) + return refcode_to_solvents + + +def build_matrix(split_csv, solvent_map, vocab, out_csv): + seen = set() + material_ids = [] + with open(split_csv, newline="") as f: + reader = csv.DictReader(f) + for row in reader: + mid = (row.get("material_id") or "").strip() + if not mid or mid in seen: + continue + seen.add(mid) + material_ids.append(mid) + + out_path = Path(out_csv) + out_path.parent.mkdir(parents=True, exist_ok=True) + with out_path.open("w", newline="") as f: + writer = csv.writer(f) + writer.writerow(["material_id"] + vocab) + with_solvent = 0 + for mid in material_ids: + solvents = solvent_map.get(mid, set()) + row = [mid] + if solvents: + with_solvent += 1 + for solv in vocab: + row.append(1 if solv in solvents else 0) + writer.writerow(row) + + print(f"Rows: {len(material_ids)}") + print(f"Rows with solvents: {with_solvent}") + print(f"Output: {out_path}") + + +def main(): + parser = argparse.ArgumentParser( + description="Build solvent matrix CSV per split from CSD_solvents." 
+ ) + parser.add_argument("--split_csv", required=True) + parser.add_argument("--solvent_csv", required=True) + parser.add_argument("--vocab_path", required=True) + parser.add_argument("--out_csv", required=True) + args = parser.parse_args() + + vocab = load_vocab(args.vocab_path) + if not vocab: + raise ValueError(f"Empty vocab: {args.vocab_path}") + vocab_set = set(vocab) + solvent_map = load_solvent_map(args.solvent_csv, vocab_set) + build_matrix(args.split_csv, solvent_map, vocab, args.out_csv) + + +if __name__ == "__main__": + main() diff --git a/scripts_model/conf/data/organic.yaml b/scripts_model/conf/data/organic.yaml index 4fa2c6f..3a3ae3b 100644 --- a/scripts_model/conf/data/organic.yaml +++ b/scripts_model/conf/data/organic.yaml @@ -28,11 +28,10 @@ datamodule: datasets: train: - _target_: diffcsp.pl_data.dataset.CrystDataset + _target_: orgflow.csp_ext.dataset.CrystDataset name: OrganicTrainSet - path: ${oc.env:PROJECT_ROOT}/data/organic/train_drug.csv -# save_path: ${oc.env:PROJECT_ROOT}/data/organic/train_organic/ - save_path: ${oc.env:PROJECT_ROOT}/data/organic/train_organic_drug.pt + path: ${oc.env:PROJECT_ROOT}/data/organic/train_drug_orgflow.csv + save_path: ${oc.env:PROJECT_ROOT}/data/organic/train_drug_orgflow.pt prop: density niggli: false primitive: false @@ -41,15 +40,17 @@ datamodule: tolerance: 0.01 use_space_group: false use_pos_index: false + solvent_matrix_path: ${oc.env:PROJECT_ROOT}/data/organic/solvents/train_solvent_matrix.csv + solvent_vocab_path: ${oc.env:PROJECT_ROOT}/data/organic/solvents/solvent_vocab.txt + drop_no_solvent: true lattice_scale_method: scale_length preprocess_workers: 30 val: - - _target_: diffcsp.pl_data.dataset.CrystDataset + - _target_: orgflow.csp_ext.dataset.CrystDataset name: OrganicValSet - path: ${oc.env:PROJECT_ROOT}/data/organic/val_drug.csv -# save_path: ${oc.env:PROJECT_ROOT}/data/organic/val_organic/ - save_path: ${oc.env:PROJECT_ROOT}/data/organic/val_organic_drug.pt + path: 
${oc.env:PROJECT_ROOT}/data/organic/val_drug_orgflow.csv + save_path: ${oc.env:PROJECT_ROOT}/data/organic/val_drug_orgflow.pt prop: density niggli: false primitive: false @@ -58,15 +59,17 @@ datamodule: tolerance: 0.01 use_space_group: false use_pos_index: false + solvent_matrix_path: ${oc.env:PROJECT_ROOT}/data/organic/solvents/val_solvent_matrix.csv + solvent_vocab_path: ${oc.env:PROJECT_ROOT}/data/organic/solvents/solvent_vocab.txt + drop_no_solvent: true lattice_scale_method: scale_length preprocess_workers: 30 test: - - _target_: diffcsp.pl_data.dataset.CrystDataset + - _target_: orgflow.csp_ext.dataset.CrystDataset name: OrganicTestSet - path: ${oc.env:PROJECT_ROOT}/data/organic/polymorphics_subset.csv -# save_path: ${oc.env:PROJECT_ROOT}/data/organic/test_organic/ - save_path: ${oc.env:PROJECT_ROOT}/data/organic/polymorphics.pt + path: ${oc.env:PROJECT_ROOT}/data/organic/test_drug_orgflow.csv + save_path: ${oc.env:PROJECT_ROOT}/data/organic/test_drug_orgflow.pt prop: density niggli: false primitive: false @@ -75,6 +78,9 @@ datamodule: tolerance: 0.01 use_space_group: false use_pos_index: false + solvent_matrix_path: ${oc.env:PROJECT_ROOT}/data/organic/solvents/test_solvent_matrix.csv + solvent_vocab_path: ${oc.env:PROJECT_ROOT}/data/organic/solvents/solvent_vocab.txt + drop_no_solvent: true lattice_scale_method: scale_length preprocess_workers: 30 diff --git a/scripts_model/conf/default.yaml b/scripts_model/conf/default.yaml index 6b96661..3871c20 100755 --- a/scripts_model/conf/default.yaml +++ b/scripts_model/conf/default.yaml @@ -26,12 +26,12 @@ optim: # RFM optimizer: _target_: torch.optim.AdamW - lr: 0.0003 - weight_decay: 0.0 - lr_scheduler: - _target_: torch.optim.lr_scheduler.CosineAnnealingLR - T_max: ${data.train_max_epochs} - eta_min: 1e-5 + lr: 0.0002 + weight_decay: 0.0002 + # lr_scheduler: + # _target_: torch.optim.lr_scheduler.CosineAnnealingLR + # T_max: ${data.train_max_epochs} + # eta_min: 1e-5 interval: epoch ema_decay: 0.999 @@ -46,7 
+46,7 @@ train: pl_trainer: fast_dev_run: False # Enable this for debug purposes strategy: ddp - devices: 2 + devices: 3 accelerator: gpu precision: 32 # max_steps: 10000 diff --git a/scripts_model/conf/model/null_params.yaml b/scripts_model/conf/model/null_params.yaml index d73ee91..c49f58f 100644 --- a/scripts_model/conf/model/null_params.yaml +++ b/scripts_model/conf/model/null_params.yaml @@ -1,7 +1,16 @@ cost_coord: 600. cost_lattice: 1. cost_type: 0.0 # Always should be 0 for our CSP task -cost_bond: 0.007 +cost_bond: 0.005 +cost_solvent: 4 +solvent_num_classes: 79 +solvent_embed_dim: 512 +solvent_pred_hidden_dim: 768 +solvent_pred_num_layers: 4 +solvent_dropout: 0.1 +solvent_pred_layernorm: true +solvent_pos_weight: auto +solvent_pos_weight_max: 20.0 affine_combine_costs: true target_distribution: conditional self_cond: false diff --git a/scripts_model/conf/vectorfield/rfm_cspnet.yaml b/scripts_model/conf/vectorfield/rfm_cspnet.yaml index d10302b..828dac2 100644 --- a/scripts_model/conf/vectorfield/rfm_cspnet.yaml +++ b/scripts_model/conf/vectorfield/rfm_cspnet.yaml @@ -1,7 +1,7 @@ _target_: orgflow.model.arch.CSPNet -hidden_dim: 128 +hidden_dim: 512 time_dim: 256 -num_layers: 12 # Testing more layers for organic data +num_layers: 3 # Testing more layers for organic data act_fn: silu dis_emb: sin num_freqs: 128 @@ -17,3 +17,9 @@ represent_num_atoms: false represent_angle_edge_to_lattice: true self_edges: false self_cond: ${model.self_cond} +solvent_num_classes: ${model.solvent_num_classes} +solvent_embed_dim: ${model.solvent_embed_dim} +solvent_pred_hidden_dim: ${model.solvent_pred_hidden_dim} +solvent_pred_num_layers: ${model.solvent_pred_num_layers} +solvent_dropout: ${model.solvent_dropout} +solvent_pred_layernorm: ${model.solvent_pred_layernorm} diff --git a/scripts_model/evaluate.py b/scripts_model/evaluate.py index 960e6f9..abee859 100755 --- a/scripts_model/evaluate.py +++ b/scripts_model/evaluate.py @@ -6,18 +6,26 @@ 
mp.set_sharing_strategy('file_system') # Test to see if it helps debug the error (TODO: remove this line and the line above it if it didn't fix the error) from copy import deepcopy +import csv +import json from pathlib import Path from typing import Any, Literal, Sequence import click import pytorch_lightning as pl import torch +import numpy as np +from torch.nn import functional as F from pytorch_lightning.callbacks import BasePredictionWriter from pytorch_lightning.loggers.wandb import WandbLogger from torch_geometric.data import Batch, Data, DataLoader import wandb -from diffcsp.script_utils import GenDataset +from omegaconf import OmegaConf +try: + from diffcsp.script_utils import GenDataset +except ModuleNotFoundError: + GenDataset = None from orgflow.model.eval_utils import ( CSPDataset, get_loaders, @@ -29,9 +37,7 @@ load_project_from_wandb, register_omega_conf_resolvers, ) -from orgflow.old_eval.generation_metrics import compute_generation_metrics -from orgflow.old_eval.lattice_metrics import compute_lattice_metrics -from orgflow.old_eval.reconstruction_metrics import compute_reconstruction_metrics +from manifm.ema import EMA TASKS_TYPE = Literal[ "reconstruct", "recon_trajectory", "generate", "gen_trajectory", "pred" @@ -766,6 +772,460 @@ def predict( ) +@cli.command(name="solvent_metrics") +@click.argument("checkpoint", type=Path) +@click.option("--stage", type=click.Choice(STAGES, case_sensitive=False), default="val") +@click.option("--batch_size", type=int, default=None) +@click.option("--threshold", type=float, default=0.3) +@click.option( + "--subdir", type=str, default="", help="subdir name at level of checkpoint" +) +@click.option( + "--max_batches", + type=int, + default=None, + help="optional cap on number of batches to evaluate", +) +def solvent_metrics( + checkpoint: Path, + stage: STAGE_TYPE, + batch_size: int | None, + threshold: float, + subdir: str, + max_batches: int | None, +) -> None: + cfg, model = load_model(checkpoint) + stage = 
stage.lower() + + if batch_size is None: + batch_size = getattr(cfg.data.datamodule.batch_size, stage) + print(f"Using {batch_size=} from default cfg") + else: + setattr(cfg.data.datamodule.batch_size, stage, batch_size) + print(f"Using custom {batch_size=}") + + loaders = get_loaders(cfg) + loader = loaders[STAGES.index(stage)] + target_dir = get_target_dir(checkpoint, subdir) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = model.to(device) + model.eval() + vecfield = model.model.model if isinstance(model.model, EMA) else model.model + cspnet = vecfield.cspnet + if cspnet.solvent_pred is None: + raise RuntimeError( + "Solvent prediction head is not initialized. " + "Set cfg.model.solvent_num_classes and pass it to CSPNet." + ) + + total_loss = 0.0 + total_elems = 0 + total_tp = 0 + total_fp = 0 + total_fn = 0 + total_pos = 0 + total_graphs = 0 + num_classes = None + + with torch.no_grad(): + for batch_idx, batch in enumerate(loader): + if max_batches is not None and batch_idx >= max_batches: + break + if not hasattr(batch, "solvent_vec"): + raise ValueError( + "batch is missing solvent_vec; check solvent_matrix_path " + "and drop_no_solvent settings." 
+ ) + batch = batch.to(device) + targets = batch.solvent_vec + if targets.dim() == 3 and targets.size(1) == 1: + targets = targets.squeeze(1) + if targets.dim() == 1: + num_graphs = batch.num_atoms.size(0) if hasattr(batch, "num_atoms") else 1 + if targets.numel() % num_graphs == 0: + targets = targets.view(num_graphs, -1) + else: + targets = targets.unsqueeze(0) + + x1_full, _, dims_full, mask_full = model.manifold_getter( + batch.batch, + batch.atom_types, + batch.frac_coords, + batch.lengths, + batch.angles, + split_manifold=False, + ) + atom_types, frac_coords, lattices = model.manifold_getter.flatrep_to_georep( + x1_full, dims=dims_full, mask_a_or_f=mask_full + ) + non_zscored_lattice = ( + lattices.clone() if cspnet.represent_angle_edge_to_lattice else None + ) + if hasattr(vecfield, "lat_x_t_mean"): + lattices = (lattices - vecfield.lat_x_t_mean) / vecfield.lat_x_t_std + num_graphs = batch.num_atoms.size(0) + t_solvent = torch.ones( + (num_graphs, 1), + dtype=atom_types.dtype, + device=atom_types.device, + ) + logits = cspnet.solvent_logits_from_georep( + t_solvent, + atom_types, + frac_coords, + lattices, + batch.num_atoms, + batch.batch, + non_zscored_lattice, + edge_index=batch.edge_index, + to_jimages=batch.to_jimages, + ) + if logits.size(1) != targets.size(1): + raise ValueError( + "Solvent class count mismatch: " + f"logits={logits.size(1)} targets={targets.size(1)}." 
+ ) + num_classes = logits.size(1) + targets = targets.to(device=logits.device, dtype=logits.dtype) + + loss = F.binary_cross_entropy_with_logits(logits, targets, reduction="sum") + total_loss += float(loss.item()) + total_elems += int(targets.numel()) + + probs = torch.sigmoid(logits) + preds = probs >= threshold + tgt_bool = targets > 0.5 + total_tp += int((preds & tgt_bool).sum().item()) + total_fp += int((preds & ~tgt_bool).sum().item()) + total_fn += int((~preds & tgt_bool).sum().item()) + total_pos += int(tgt_bool.sum().item()) + total_graphs += int(targets.size(0)) + + eps = 1e-12 + precision = total_tp / (total_tp + total_fp + eps) + recall = total_tp / (total_tp + total_fn + eps) + f1 = (2.0 * precision * recall) / (precision + recall + eps) + bce = total_loss / max(1, total_elems) + avg_pos_per_graph = total_pos / max(1, total_graphs) + + metrics = { + f"{stage}/solvent_bce": bce, + f"{stage}/solvent_precision": precision, + f"{stage}/solvent_recall": recall, + f"{stage}/solvent_f1": f1, + f"{stage}/solvent_pos_per_graph": avg_pos_per_graph, + f"{stage}/solvent_num_graphs": float(total_graphs), + f"{stage}/solvent_num_classes": float(num_classes or 0), + f"{stage}/solvent_threshold": float(threshold), + } + + metrics_path = target_dir / f"solvent_metrics_{stage}.json" + metrics_path.write_text(json.dumps(metrics, indent=2)) + print(f"Saved solvent metrics to {metrics_path}") + print(metrics) + + +@cli.command(name="solvent_predictions") +@click.argument("checkpoint", type=Path) +@click.option("--stage", type=click.Choice(STAGES, case_sensitive=False), default="val") +@click.option("--batch_size", type=int, default=None) +@click.option("--threshold", type=float, default=0.3) +@click.option( + "--subdir", type=str, default="", help="subdir name at level of checkpoint" +) +@click.option( + "--max_batches", + type=int, + default=None, + help="optional cap on number of batches to evaluate", +) +@click.option( + "--out_csv", + type=str, + default=None, + 
help="optional path to write CSV (defaults to checkpoint subdir)", +) +@click.option( + "--out_json", + type=str, + default=None, + help="optional path to write per-sample JSONL metrics (defaults to checkpoint subdir)", +) +@click.option( + "--out_summary_json", + type=str, + default=None, + help="optional path to write aggregate JSON metrics (defaults to checkpoint subdir)", +) +def solvent_predictions( + checkpoint: Path, + stage: STAGE_TYPE, + batch_size: int | None, + threshold: float, + subdir: str, + max_batches: int | None, + out_csv: str | None, + out_json: str | None, + out_summary_json: str | None, +) -> None: + cfg, model = load_model(checkpoint) + stage = stage.lower() + + if batch_size is None: + batch_size = getattr(cfg.data.datamodule.batch_size, stage) + print(f"Using {batch_size=} from default cfg") + else: + setattr(cfg.data.datamodule.batch_size, stage, batch_size) + print(f"Using custom {batch_size=}") + + ds_cfg = cfg.data.datamodule.datasets[stage] + if OmegaConf.is_list(ds_cfg): + ds_cfg = ds_cfg[0] + solvent_vocab_path = ds_cfg.get("solvent_vocab_path", None) + if solvent_vocab_path is None: + raise ValueError("solvent_vocab_path is missing from dataset config.") + vocab = [ + line.strip() + for line in Path(solvent_vocab_path).read_text().splitlines() + if line.strip() + ] + if not vocab: + raise ValueError(f"solvent_vocab_path is empty: {solvent_vocab_path}") + + loaders = get_loaders(cfg) + loader = loaders[STAGES.index(stage)] + target_dir = get_target_dir(checkpoint, subdir) + out_path = Path(out_csv) if out_csv else target_dir / f"solvent_predictions_{stage}.csv" + out_json_path = ( + Path(out_json) + if out_json + else target_dir / f"solvent_predictions_{stage}.jsonl" + ) + out_summary_path = ( + Path(out_summary_json) + if out_summary_json + else target_dir / f"solvent_predictions_{stage}_summary.json" + ) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = model.to(device) + model.eval() + vecfield = 
model.model.model if isinstance(model.model, EMA) else model.model + cspnet = vecfield.cspnet + if cspnet.solvent_pred is None: + raise RuntimeError( + "Solvent prediction head is not initialized. " + "Set cfg.model.solvent_num_classes and pass it to CSPNet." + ) + + out_path.parent.mkdir(parents=True, exist_ok=True) + with out_path.open("w", newline="") as f_csv, out_json_path.open("w") as f_json: + writer = csv.writer(f_csv) + writer.writerow( + [ + "material_id", + "target_solvents", + "pred_solvents", + "pred_probs", + "target_count", + "pred_count", + ] + ) + total_tp = 0 + total_fp = 0 + total_fn = 0 + total_tn = 0 + total_samples = 0 + total_classes = 0 + sum_precision = 0.0 + sum_recall = 0.0 + sum_f1 = 0.0 + sum_jaccard = 0.0 + sum_target_count = 0 + sum_pred_count = 0 + with torch.no_grad(): + for batch_idx, batch in enumerate(loader): + if max_batches is not None and batch_idx >= max_batches: + break + if not hasattr(batch, "solvent_vec"): + raise ValueError( + "batch is missing solvent_vec; check solvent_matrix_path " + "and drop_no_solvent settings." 
+ ) + batch = batch.to(device) + targets = batch.solvent_vec + if targets.dim() == 3 and targets.size(1) == 1: + targets = targets.squeeze(1) + if targets.dim() == 1: + num_graphs = batch.num_atoms.size(0) if hasattr(batch, "num_atoms") else 1 + if targets.numel() % num_graphs == 0: + targets = targets.view(num_graphs, -1) + else: + targets = targets.unsqueeze(0) + + x1_full, _, dims_full, mask_full = model.manifold_getter( + batch.batch, + batch.atom_types, + batch.frac_coords, + batch.lengths, + batch.angles, + split_manifold=False, + ) + atom_types, frac_coords, lattices = model.manifold_getter.flatrep_to_georep( + x1_full, dims=dims_full, mask_a_or_f=mask_full + ) + non_zscored_lattice = ( + lattices.clone() if cspnet.represent_angle_edge_to_lattice else None + ) + if hasattr(vecfield, "lat_x_t_mean"): + lattices = (lattices - vecfield.lat_x_t_mean) / vecfield.lat_x_t_std + num_graphs = batch.num_atoms.size(0) + t_solvent = torch.ones( + (num_graphs, 1), + dtype=atom_types.dtype, + device=atom_types.device, + ) + logits = cspnet.solvent_logits_from_georep( + t_solvent, + atom_types, + frac_coords, + lattices, + batch.num_atoms, + batch.batch, + non_zscored_lattice, + edge_index=batch.edge_index, + to_jimages=batch.to_jimages, + ) + if logits.size(1) != targets.size(1): + raise ValueError( + "Solvent class count mismatch: " + f"logits={logits.size(1)} targets={targets.size(1)}." 
+ ) + + probs = torch.sigmoid(logits).detach().cpu().numpy() + targets_np = targets.detach().cpu().numpy() + + material_ids = getattr(batch, "material_id", None) + if material_ids is None: + material_ids = [""] * num_graphs + elif isinstance(material_ids, (list, tuple)): + material_ids = [str(x) for x in material_ids] + elif torch.is_tensor(material_ids): + material_ids = [str(x.item()) for x in material_ids] + else: + material_ids = [str(material_ids)] + if len(material_ids) != num_graphs: + raise ValueError( + f"material_id count mismatch: {len(material_ids)} != {num_graphs}" + ) + + for i in range(num_graphs): + target_idx = np.where(targets_np[i] > 0.5)[0].tolist() + pred_mask = probs[i] >= threshold + pred_idx = np.where(pred_mask)[0].tolist() + target_names = [vocab[j] for j in target_idx] + pred_names = [vocab[j] for j in pred_idx] + pred_probs_str = ";".join( + f"{vocab[j]}:{probs[i][j]:.3f}" for j in pred_idx + ) + writer.writerow( + [ + material_ids[i], + ";".join(target_names), + ";".join(pred_names), + pred_probs_str, + len(target_names), + len(pred_names), + ] + ) + + target_mask = targets_np[i] > 0.5 + tp = int(np.logical_and(pred_mask, target_mask).sum()) + fp = int(np.logical_and(pred_mask, ~target_mask).sum()) + fn = int(np.logical_and(~pred_mask, target_mask).sum()) + tn = int((~pred_mask & ~target_mask).sum()) + precision = tp / (tp + fp + 1e-12) + recall = tp / (tp + fn + 1e-12) + f1 = (2.0 * precision * recall) / (precision + recall + 1e-12) + acc = (tp + tn) / max(1, targets_np.shape[1]) + jaccard = tp / (tp + fp + fn + 1e-12) + record = { + "material_id": material_ids[i], + "target_solvents": target_names, + "pred_solvents": pred_names, + "pred_probs": { + vocab[j]: float(probs[i][j]) for j in pred_idx + }, + "threshold": float(threshold), + "thresholds_source": "global", + "tp": tp, + "fp": fp, + "fn": fn, + "tn": tn, + "precision": precision, + "recall": recall, + "f1": f1, + "accuracy": acc, + "jaccard": jaccard, + "target_count": 
len(target_names), + "pred_count": len(pred_names), + "num_classes": int(targets_np.shape[1]), + } + f_json.write(json.dumps(record) + "\n") + total_tp += tp + total_fp += fp + total_fn += fn + total_tn += tn + total_samples += 1 + total_classes += int(targets_np.shape[1]) + sum_precision += precision + sum_recall += recall + sum_f1 += f1 + sum_jaccard += jaccard + sum_target_count += len(target_names) + sum_pred_count += len(pred_names) + + print(f"Wrote per-sample solvent predictions to {out_path}") + print(f"Wrote per-sample solvent metrics to {out_json_path}") + if total_samples > 0 and total_classes > 0: + micro_precision = total_tp / (total_tp + total_fp + 1e-12) + micro_recall = total_tp / (total_tp + total_fn + 1e-12) + micro_f1 = (2.0 * micro_precision * micro_recall) / ( + micro_precision + micro_recall + 1e-12 + ) + micro_acc = (total_tp + total_tn) / (total_tp + total_tn + total_fp + total_fn + 1e-12) + micro_jaccard = total_tp / (total_tp + total_fp + total_fn + 1e-12) + macro_precision = sum_precision / total_samples + macro_recall = sum_recall / total_samples + macro_f1 = sum_f1 / total_samples + macro_jaccard = sum_jaccard / total_samples + summary = { + "stage": stage, + "threshold": float(threshold), + "thresholds_source": "global", + "num_samples": total_samples, + "num_classes": int(total_classes / total_samples), + "avg_target_count": sum_target_count / total_samples, + "avg_pred_count": sum_pred_count / total_samples, + "tp": total_tp, + "fp": total_fp, + "fn": total_fn, + "tn": total_tn, + "micro_precision": micro_precision, + "micro_recall": micro_recall, + "micro_f1": micro_f1, + "micro_accuracy": micro_acc, + "micro_jaccard": micro_jaccard, + "macro_precision": macro_precision, + "macro_recall": macro_recall, + "macro_f1": macro_f1, + "macro_jaccard": macro_jaccard, + } + out_summary_path.write_text(json.dumps(summary, indent=2)) + print(f"Wrote aggregate solvent metrics to {out_summary_path}") + + def _get_consolidated_path(directory: 
Path, task: str) -> str: return directory / f"consolidated_{task}.pt" @@ -939,6 +1399,7 @@ def _reconstruction_metrics_wandb( global_step: int, stage: STAGE_TYPE, ) -> dict[str, float]: + from orgflow.old_eval.reconstruction_metrics import compute_reconstruction_metrics recon_metrics = {} if consolidated_reconstruction_path.exists(): # should be 20 evals @@ -980,6 +1441,7 @@ def _generation_metrics_wandb( n_subsamples: int, stage: STAGE_TYPE, ) -> dict[str, float]: + from orgflow.old_eval.generation_metrics import compute_generation_metrics gen_metrics = {} if consolidated_generation_path.exists(): tmp = compute_generation_metrics( @@ -1116,6 +1578,7 @@ def lattice_metrics( subdir: str, stage: STAGE_TYPE, ) -> None: + from orgflow.old_eval.lattice_metrics import compute_lattice_metrics # log_wandb = not do_not_log_wandb target_dir = get_target_dir(checkpoint, subdir) diff --git a/src/orgflow/csp_ext/dataset.py b/src/orgflow/csp_ext/dataset.py index 6c753c3..ddd9b60 100644 --- a/src/orgflow/csp_ext/dataset.py +++ b/src/orgflow/csp_ext/dataset.py @@ -33,13 +33,52 @@ def __init__(self, name: ValueNode, path: ValueNode, self.use_space_group = use_space_group self.use_pos_index = use_pos_index self.tolerance = tolerance + self.debug = bool(kwargs.get("debug", False)) + self.solvent_matrix_path = kwargs.get("solvent_matrix_path", None) + self.solvent_vocab_path = kwargs.get("solvent_vocab_path", None) + self.drop_no_solvent = bool(kwargs.get("drop_no_solvent", False)) + self.solvent_matrix = {} + self.solvent_vocab = [] + self.num_solvent_classes = 0 self.preprocess(save_path, preprocess_workers, prop) + self._load_solvent_matrix() + if self.drop_no_solvent and self.solvent_matrix: + before = len(self.cached_data) + self.cached_data = [ + d + for d in self.cached_data + if self._solvent_sum(d.get("mp_id")) > 0.0 + ] + after = len(self.cached_data) + if self.debug: + print(f"[dataset] drop_no_solvent kept={after} dropped={before - after}") 
add_scaled_lattice_prop(self.cached_data, lattice_scale_method) self.lattice_scaler = None self.scaler = None + if self.debug: + has_solvent = bool(self.solvent_matrix) + print( + f"[dataset] name={self.name} rows={len(self.df)} cached={len(self.cached_data)} " + f"path={self.path} prop={self.prop}" + ) + print( + f"[dataset] solvent_matrix={self.solvent_matrix_path} " + f"classes={self.num_solvent_classes} has_solvent={has_solvent}" + ) + if self.cached_data: + sample = self.cached_data[0] + keys = sorted(sample.keys()) + print(f"[dataset] cached_keys_sample={keys[:8]} total_keys={len(keys)}") + if has_solvent: + sample_id = sample.get("mp_id") + sample_sum = self._solvent_sum(sample_id) + print( + f"[dataset] solvent_vec_sample_id={sample_id} sum={sample_sum}" + ) + self.__repr__() def preprocess(self, save_path, preprocess_workers, prop): @@ -93,6 +132,7 @@ def __getitem__(self, index): # cif=self.df.iloc[index]['CIF'], # To calculate the additional loss in model_pl # smiles=self.df.iloc[index]['SMILES'] ) + data.material_id = str(data_dict.get("mp_id", "")) if 'bond_mean' in data_dict: data.bond_mean = torch.tensor(data_dict['bond_mean'], dtype=torch.float32) # [E] @@ -154,8 +194,51 @@ def __getitem__(self, index): pos_dic[atom] = pos_dic.get(atom, 0) + 1 indexes.append(pos_dic[atom] - 1) data.index = torch.LongTensor(indexes) + if self.solvent_matrix: + mp_id = data_dict.get("mp_id") + vec = self.solvent_matrix.get(mp_id) + has_solvent = vec is not None + if vec is None: + vec = np.zeros(self.num_solvent_classes, dtype=np.float32) + data.solvent_vec = torch.tensor(vec, dtype=torch.float32).view(1, -1) + data.has_solvent = torch.tensor([has_solvent], dtype=torch.bool) return data + def _solvent_sum(self, mp_id): + if mp_id is None or not self.solvent_matrix: + return 0.0 + vec = self.solvent_matrix.get(mp_id) + if vec is None: + return 0.0 + return float(np.sum(vec)) + + def _load_solvent_matrix(self): + if not self.solvent_matrix_path: + return + solvent_df = 
pd.read_csv(self.solvent_matrix_path) + if "material_id" not in solvent_df.columns: + raise ValueError( + "solvent_matrix_path must include a 'material_id' column." + ) + solvent_cols = [c for c in solvent_df.columns if c != "material_id"] + if self.solvent_vocab_path: + vocab = [line.strip() for line in Path(self.solvent_vocab_path).read_text().splitlines() if line.strip()] + if vocab: + solvent_cols = [v for v in vocab if v in solvent_cols] + self.solvent_vocab = solvent_cols + else: + self.solvent_vocab = solvent_cols + else: + self.solvent_vocab = solvent_cols + self.num_solvent_classes = len(self.solvent_vocab) + if self.num_solvent_classes == 0: + return + solvent_df = solvent_df.set_index("material_id") + matrix = solvent_df[self.solvent_vocab] + self.solvent_matrix = { + idx: row.values.astype(np.float32) for idx, row in matrix.iterrows() + } + def __repr__(self) -> str: return f"CrystDataset({self.name=}, {self.path=})" diff --git a/src/orgflow/model/arch.py b/src/orgflow/model/arch.py index 96247e0..bcd966d 100644 --- a/src/orgflow/model/arch.py +++ b/src/orgflow/model/arch.py @@ -252,6 +252,12 @@ def __init__( represent_angle_edge_to_lattice: bool = False, self_edges: bool = True, self_cond: bool = False, + solvent_num_classes: int = 0, + solvent_embed_dim: int | None = None, + solvent_pred_hidden_dim: int | None = None, + solvent_pred_num_layers: int = 2, + solvent_dropout: float = 0.0, + solvent_pred_layernorm: bool = False, ): nn.Module.__init__(self) assert not ( @@ -333,6 +339,36 @@ def __init__( self.final_layer_norm = nn.LayerNorm(hidden_dim) self.represent_angle_edge_to_lattice = represent_angle_edge_to_lattice self.self_edges = self_edges + self.solvent_num_classes = int(solvent_num_classes) + self.solvent_embed_dim = int(solvent_embed_dim or hidden_dim) + self.solvent_pred_hidden_dim = int(solvent_pred_hidden_dim or hidden_dim) + self.solvent_pred_num_layers = int(solvent_pred_num_layers) + self.solvent_dropout = float(solvent_dropout) + 
self.solvent_pred_layernorm = bool(solvent_pred_layernorm) + self.solvent_embed = None + self.solvent_pred = None + if self.solvent_num_classes > 0: + graph_dim = hidden_dim * (2 if concat_sum_pool else 1) + self.solvent_embed = nn.Sequential( + nn.Linear(self.solvent_num_classes, self.solvent_embed_dim), + self.act_fn, + nn.Linear(self.solvent_embed_dim, graph_dim), + ) + if self.solvent_pred_num_layers <= 1: + self.solvent_pred = nn.Linear(graph_dim, self.solvent_num_classes) + else: + layers = [] + in_dim = graph_dim + for _ in range(self.solvent_pred_num_layers - 1): + layers.append(nn.Linear(in_dim, self.solvent_pred_hidden_dim)) + if self.solvent_pred_layernorm: + layers.append(nn.LayerNorm(self.solvent_pred_hidden_dim)) + layers.append(self.act_fn) + if self.solvent_dropout > 0.0: + layers.append(nn.Dropout(self.solvent_dropout)) + in_dim = self.solvent_pred_hidden_dim + layers.append(nn.Linear(in_dim, self.solvent_num_classes)) + self.solvent_pred = nn.Sequential(*layers) def gen_edges(self, num_atoms, frac_coords, lattices, node2graph, edge_index=None, to_jimages=None @@ -457,6 +493,7 @@ def forward( num_atoms, node2graph, non_zscored_lattice, + solvent_vec=None, edge_index=None, to_jimages=None, ): @@ -521,6 +558,20 @@ def forward( ) else: graph_features = scatter(node_features, node2graph, dim=0, reduce="mean") + if self.solvent_embed is not None and solvent_vec is not None: + solvent_vec = solvent_vec.to(graph_features) + if solvent_vec.dim() == 1: + batch_size = graph_features.size(0) + if solvent_vec.numel() % batch_size == 0: + solvent_vec = solvent_vec.view(batch_size, -1) + else: + solvent_vec = solvent_vec.unsqueeze(0) + if solvent_vec.size(0) != graph_features.size(0): + raise ValueError( + "solvent_vec batch mismatch: " + f"{solvent_vec.size(0)} != {graph_features.size(0)}" + ) + graph_features = graph_features + self.solvent_embed(solvent_vec) lattice_out = self.lattice_out(graph_features) if self.lattice_manifold == "non_symmetric": 
def solvent_logits_from_georep(
    self,
    t,
    atom_types,
    frac_coords,
    lattices,
    num_atoms,
    node2graph,
    non_zscored_lattice,
    edge_index=None,
    to_jimages=None,
):
    """Run the CSPNet encoder on a geometric representation and return
    per-graph solvent logits from the auxiliary prediction head.

    This deliberately replays the encoder portion of ``forward`` (edge
    generation, node embedding, message-passing layers, pooling) but
    stops at the pooled graph features and feeds them to
    ``self.solvent_pred`` instead of the lattice/coord/type output heads.

    Args:
        t: per-graph time values fed to ``self.time_emb``
           (assumed shape ``(num_graphs, 1)`` or broadcastable — TODO confirm
           against ``forward``'s call convention).
        atom_types: per-atom type representation accepted by
            ``self.node_embedding``.
        frac_coords: per-atom fractional coordinates.
        lattices: per-graph lattice representation (already z-scored by the
            caller when the model z-scores lattices — see model_pl usage).
        num_atoms: atoms-per-graph counts, one entry per graph.
        node2graph: atom -> graph index map (PyG ``batch`` vector).
        non_zscored_lattice: raw lattice used only when
            ``represent_angle_edge_to_lattice`` is set; may be ``None``
            otherwise.
        edge_index, to_jimages: optional precomputed connectivity passed
            through to ``gen_edges``.

    Returns:
        Tensor of shape ``(num_graphs, solvent_num_classes)`` of raw
        (pre-sigmoid) logits.

    Raises:
        RuntimeError: if the solvent head was not built (i.e.
            ``solvent_num_classes`` was 0 at construction).
    """
    if self.solvent_pred is None:
        raise RuntimeError("Solvent prediction head is not initialized.")
    t_emb = self.time_emb(t)
    # Broadcast the time embedding to one row per graph.
    t_emb = t_emb.expand(num_atoms.shape[0], -1)

    edges, frac_diff = self.gen_edges(
        num_atoms, frac_coords, lattices, node2graph, edge_index, to_jimages
    )
    # Map each edge to its owning graph via its source node.
    edge2graph = node2graph[edges[0]]

    if self.represent_angle_edge_to_lattice:
        if self.self_cond:
            # Self-conditioning packs (current, predicted) lattices along the
            # last dim; split them apart here.
            nzsl, nzsl_pred = torch.tensor_split(non_zscored_lattice, 2, dim=-1)
            l = self._convert_lin_to_lattice(nzsl)
            # An all-zero prediction half means "no self-cond signal yet".
            if (torch.zeros_like(nzsl_pred) == nzsl_pred).all():
                l_pred = None
            else:
                l_pred = self._convert_lin_to_lattice(nzsl_pred)
        else:
            l = self._convert_lin_to_lattice(non_zscored_lattice)
            l_pred = None
    else:
        l = None
        l_pred = None

    node_features = self.node_embedding(atom_types)
    # Repeat each graph's time embedding for every atom in that graph.
    t_per_atom = t_emb.repeat_interleave(num_atoms, dim=0)
    node_features = torch.cat([node_features, t_per_atom], dim=1)
    node_features = self.atom_latent_emb(node_features)
    for i in range(0, self.num_layers):
        node_features = self._modules["csp_layer_%d" % i](
            node_features,
            lattices,
            edges,
            edge2graph,
            frac_diff,
            num_atoms,
            non_zscored_lattice=l,
            non_zscored_lattice_pred=l_pred,
        )
    if self.ln:
        node_features = self.final_layer_norm(node_features)

    # Pool atoms -> graph. concat_sum_pool doubles the feature width
    # (mean || sum), matching the graph_dim used to size solvent_pred
    # in __init__.
    if self.concat_sum_pool:
        graph_features = torch.concat(
            [
                scatter(node_features, node2graph, dim=0, reduce="mean"),
                scatter(node_features, node2graph, dim=0, reduce="sum"),
            ],
            dim=-1,
        )
    else:
        graph_features = scatter(node_features, node2graph, dim=0, reduce="mean")
    return self.solvent_pred(graph_features)
+703,7 @@ def _conjugated_forward( t: torch.Tensor, x: torch.Tensor, cond: torch.Tensor | None, + solvent_vec: torch.Tensor | None = None, edge_index: torch.Tensor | None = None, to_jimages: torch.Tensor | None = None, ) -> ManifoldGetterOut: @@ -637,6 +756,7 @@ def _conjugated_forward( num_atoms, node2graph, non_zscored_lattice, + solvent_vec=solvent_vec, edge_index=edge_index, to_jimages=to_jimages, ) @@ -671,6 +791,7 @@ def forward( x: torch.Tensor, manifold: Manifold, cond: torch.Tensor | None = None, + solvent_vec: torch.Tensor | None = None, edge_index: torch.Tensor | None = None, to_jimages: torch.Tensor | None = None, ) -> torch.Tensor: @@ -684,6 +805,7 @@ def forward( cond = manifold.projx(cond) v, *_ = self._conjugated_forward( num_atoms, node2graph, dims, mask_a_or_f, t, x, cond, + solvent_vec=solvent_vec, edge_index=edge_index, to_jimages=to_jimages, ) diff --git a/src/orgflow/model/model_pl.py b/src/orgflow/model/model_pl.py index 680c2ec..2129c18 100644 --- a/src/orgflow/model/model_pl.py +++ b/src/orgflow/model/model_pl.py @@ -5,14 +5,18 @@ import sys import warnings from functools import partial +from pathlib import Path from typing import Any, Literal import hydra import pytorch_lightning as pl import torch +import pandas as pd +import numpy as np +from torch.nn import functional as F from geoopt.manifolds.euclidean import Euclidean from geoopt.manifolds.product import ProductManifold -from omegaconf import DictConfig +from omegaconf import DictConfig, OmegaConf from pytorch_lightning.utilities.warnings import PossibleUserWarning from torch.func import jvp, vjp from torch_geometric.data import Data @@ -82,6 +86,7 @@ def __init__(self, cfg: DictConfig): cost_cross_ent = cfg.model.cost_cross_ent cost_bond = float(getattr(cfg.model, "cost_bond", 0.0)) + cost_solvent = float(getattr(cfg.model, "cost_solvent", 0.0)) self.costs = { "loss_a": cost_type, @@ -89,11 +94,14 @@ def __init__(self, cfg: DictConfig): "loss_l": cfg.model.cost_lattice, "loss_ce": 
def _init_solvent_pos_weight(self) -> torch.Tensor | None:
    """Resolve ``cfg.model.solvent_pos_weight`` into a per-class BCE
    ``pos_weight`` tensor for the multilabel solvent loss.

    Supported config values:
      * falsy / solvent loss disabled -> ``None`` (no reweighting)
      * list/tuple                    -> explicit per-class weights
      * positive number               -> the same weight for every class
      * ``"auto"``                    -> negative/positive frequency ratio
                                         computed from the training split's
                                         solvent matrix CSV

    Returns:
        A float32 tensor of shape ``(solvent_num_classes,)`` or ``None``.

    Raises:
        ValueError: on an unknown mode string, a boolean value, a weight
            vector whose length disagrees with ``solvent_num_classes``, or
            a missing/invalid solvent matrix in ``auto`` mode.
    """
    mode = getattr(self.cfg.model, "solvent_pos_weight", None)
    # No reweighting when unset or the solvent loss carries zero cost.
    if not mode or float(self.costs.get("loss_solv", 0.0)) == 0.0:
        return None
    # bool is a subclass of int; `solvent_pos_weight: true` would otherwise
    # fall through the numeric branch and silently produce an all-ones
    # weight vector. Reject it as a configuration error instead.
    if isinstance(mode, bool):
        raise ValueError(
            "solvent_pos_weight must be 'auto', a number, or a list — got a bool."
        )
    expected = int(getattr(self.cfg.model, "solvent_num_classes", 0))
    if isinstance(mode, (list, tuple)):
        weights = torch.tensor(mode, dtype=torch.float32)
        # Keep explicit lists consistent with the 'auto' branch: fail fast
        # here rather than deep inside binary_cross_entropy_with_logits.
        if expected and weights.numel() != expected:
            raise ValueError(
                f"solvent_pos_weight size mismatch: {weights.numel()} != {expected}"
            )
        return weights
    if isinstance(mode, (int, float)):
        if mode <= 0:
            return None
        return torch.full(
            (int(self.cfg.model.solvent_num_classes),),
            float(mode),
            dtype=torch.float32,
        )
    if isinstance(mode, str) and mode.lower() != "auto":
        raise ValueError(
            f"Unknown solvent_pos_weight mode '{mode}'. Use 'auto' or a list."
        )

    # --- 'auto': derive per-class weights from training-set frequencies ---
    ds_cfg = self.cfg.data.datamodule.datasets.train
    if OmegaConf.is_list(ds_cfg):
        # Multiple train datasets configured: use the first one's matrix.
        ds_cfg = ds_cfg[0]
    solvent_matrix_path = ds_cfg.get("solvent_matrix_path", None)
    solvent_vocab_path = ds_cfg.get("solvent_vocab_path", None)
    if not solvent_matrix_path:
        raise ValueError(
            "solvent_pos_weight=auto needs train.solvent_matrix_path."
        )
    df = pd.read_csv(solvent_matrix_path)
    if "material_id" not in df.columns:
        raise ValueError(
            "solvent_matrix_path must include a 'material_id' column."
        )
    solvent_cols = [c for c in df.columns if c != "material_id"]
    if solvent_vocab_path:
        # Optional vocab file restricts (and orders) the solvent columns,
        # mirroring the dataset-side loading logic.
        vocab = [
            line.strip()
            for line in Path(solvent_vocab_path).read_text().splitlines()
            if line.strip()
        ]
        if vocab:
            solvent_cols = [v for v in vocab if v in solvent_cols]
    if not solvent_cols:
        raise ValueError("No solvent columns found for pos_weight.")
    mat = df[solvent_cols].to_numpy(dtype=np.float32)
    pos = mat.sum(axis=0)
    total = mat.shape[0]
    neg = total - pos
    # Standard BCE pos_weight = #negatives / #positives per class; clip the
    # denominator so classes absent from the split don't divide by zero.
    pos_weight = neg / np.clip(pos, 1.0, None)
    max_w = getattr(self.cfg.model, "solvent_pos_weight_max", None)
    if max_w is not None:
        # Cap extreme weights for very rare classes.
        pos_weight = np.minimum(pos_weight, float(max_w))
    if expected and len(pos_weight) != expected:
        raise ValueError(
            f"solvent_pos_weight size mismatch: {len(pos_weight)} != {expected}"
        )
    print(
        f"[solvent] pos_weight auto from {solvent_matrix_path} "
        f"classes={len(pos_weight)} max={max_w}"
    )
    return torch.tensor(pos_weight, dtype=torch.float32)
manifold.random(*x1.shape, dtype=x1.dtype, device=x1.device) else: x0 = x0.to(x1) + solvent_vec = getattr(batch, "solvent_vec", None) return self.finish_sampling( x0=x0, @@ -228,6 +304,7 @@ def sample( node2graph=batch.batch, edge_index=batch.edge_index, to_jimages=batch.to_jimages, + solvent_vec=solvent_vec, mask_a_or_f=mask_a_or_f, num_steps=num_steps, entire_traj=entire_traj, @@ -331,9 +408,10 @@ def finish_sampling( node2graph: torch.LongTensor, edge_index: torch.LongTensor, to_jimages: torch.Tensor, - mask_a_or_f: torch.BoolTensor, - num_steps: int, - entire_traj: bool, + solvent_vec: torch.Tensor | None = None, + mask_a_or_f: torch.BoolTensor | None = None, + num_steps: int = 1_000, + entire_traj: bool = False, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: # bind the real graph connectivity into the vector field vecfield = partial( @@ -344,6 +422,7 @@ def finish_sampling( mask_a_or_f=mask_a_or_f, edge_index=edge_index, to_jimages=to_jimages, + solvent_vec=solvent_vec, ) compute_traj_velo_norms = self.cfg.integrate.get( @@ -485,6 +564,7 @@ def compute_exact_loglikelihood( node2graph=batch.batch, dims=dims, mask_a_or_f=mask_a_or_f, + solvent_vec=getattr(batch, "solvent_vec", None), ) def odefunc(t, tensor): @@ -604,6 +684,7 @@ def rfm_loss_fn( node2graph=batch.batch, dims=dims, mask_a_or_f=mask_a_or_f, + solvent_vec=getattr(batch, "solvent_vec", None), ) N = x1.shape[0] @@ -726,6 +807,93 @@ def rfm_loss_fn( else: loss_bond_raw = torch.zeros((), dtype=dists.dtype, device=dists.device) # ------------------------------------------------ + # Solvent prediction loss (multilabel BCE) + loss_solv = torch.zeros((), dtype=diff.dtype, device=diff.device) + if self.costs.get("loss_solv", 0.0) > 0.0 and hasattr(batch, "solvent_vec"): + targets = batch.solvent_vec + if targets.numel() > 0: + if targets.dim() == 3 and targets.size(1) == 1: + targets = targets.squeeze(1) + if targets.dim() == 1: + num_graphs = batch.num_atoms.size(0) if 
hasattr(batch, "num_atoms") else 1 + if targets.numel() % num_graphs == 0: + targets = targets.view(num_graphs, -1) + else: + targets = targets.unsqueeze(0) + has_solvent = getattr(batch, "has_solvent", None) + mask = None + if torch.is_tensor(has_solvent): + if has_solvent.dim() == 2 and has_solvent.size(1) == 1: + has_solvent = has_solvent.squeeze(1) + if has_solvent.dim() == 0: + has_solvent = has_solvent.unsqueeze(0) + if has_solvent.numel() == targets.size(0): + mask = has_solvent.to(device=targets.device, dtype=torch.bool) + vecfield = self.model.model if isinstance(self.model, EMA) else self.model + cspnet = vecfield.cspnet + if cspnet.solvent_pred is None: + raise RuntimeError( + "Solvent prediction head is not initialized. " + "Set cfg.model.solvent_num_classes and pass it to CSPNet." + ) + x1_full, _, dims_full, mask_full = self.manifold_getter( + batch.batch, + batch.atom_types, + batch.frac_coords, + batch.lengths, + batch.angles, + split_manifold=False, + ) + atom_types, frac_coords, lattices = self.manifold_getter.flatrep_to_georep( + x1_full, dims=dims_full, mask_a_or_f=mask_full + ) + non_zscored_lattice = ( + lattices.clone() + if cspnet.represent_angle_edge_to_lattice + else None + ) + if hasattr(vecfield, "lat_x_t_mean"): + lattices = (lattices - vecfield.lat_x_t_mean) / vecfield.lat_x_t_std + num_graphs = batch.num_atoms.size(0) + t_solvent = torch.ones( + (num_graphs, 1), + dtype=atom_types.dtype, + device=atom_types.device, + ) + logits = cspnet.solvent_logits_from_georep( + t_solvent, + atom_types, + frac_coords, + lattices, + batch.num_atoms, + batch.batch, + non_zscored_lattice, + edge_index=batch.edge_index, + to_jimages=batch.to_jimages, + ) + if logits.size(1) != targets.size(1): + raise ValueError( + "Solvent class count mismatch: " + f"logits={logits.size(1)} targets={targets.size(1)}. " + "Update cfg.model.solvent_num_classes to match vocab." 
+ ) + targets = targets.to(device=logits.device, dtype=logits.dtype) + if mask is not None: + if mask.any(): + logits = logits[mask] + targets = targets[mask] + else: + logits = None + if logits is not None: + if self.solvent_pos_weight is not None: + pos_weight = self.solvent_pos_weight.to( + device=logits.device, dtype=logits.dtype + ) + loss_solv = F.binary_cross_entropy_with_logits( + logits, targets, pos_weight=pos_weight + ) + else: + loss_solv = F.binary_cross_entropy_with_logits(logits, targets) # Original inner products s = 0 @@ -754,6 +922,7 @@ def rfm_loss_fn( + self.costs["loss_l"] * loss_l + self.costs["loss_ce"] * loss_ce + cost_bond * loss_bond_raw # new term + + self.costs["loss_solv"] * loss_solv ) return { @@ -767,6 +936,7 @@ def rfm_loss_fn( "unscaled/loss_l": loss_l, "unscaled/loss_ce": loss_ce, "loss_bond": cost_bond * loss_bond_raw, # log scaled + "loss_solv": self.costs["loss_solv"] * loss_solv, } def training_step(self, batch: Data, batch_idx: int):