From 70569028a0390a74423e045d284125abed00ac15 Mon Sep 17 00:00:00 2001 From: isayev Date: Sat, 21 Feb 2026 23:14:46 -0500 Subject: [PATCH 1/2] Add .worktrees/ to gitignore --- .gitignore | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/.gitignore b/.gitignore index 15b49d7..9e6b271 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,24 @@ +# Claude Code and AI assistant files +CLAUDE.md +.claude/ + +# Development, roadmap, and planning files +ROADMAP.md +ROADMAP*.md +PLAN.md +PLAN*.md +PLANNING.md +PLANNING*.md +TODO.md +TODO*.md +DEVLOG.md +DEVLOG*.md +*.plan.md +*.roadmap.md + +# Git worktrees +.worktrees/ + .idea /scripts/results/ /results/ From 0830a3302322a1019e3e1423b5146bd00f25deb9 Mon Sep 17 00:00:00 2001 From: isayev Date: Sun, 22 Feb 2026 14:51:17 -0500 Subject: [PATCH 2/2] =?UTF-8?q?feat:=20inference=20optimization=20?= =?UTF-8?q?=E2=80=94=20new=20API,=20perf=20improvements,=20bug=20fixes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## New megalodon.inference package Public API (`from megalodon.inference import ...`): - `validate_smiles(smiles)` — validates SMILES before featurization; rejects salts, unsupported elements (LoQI vocab: 17 atoms), radicals - `generate_conformers(smiles_list, model, cfg, n_confs, batch_size=48, max_atoms_per_batch)` — batched conformer generation returning a structured `ConformerGenerationResult` with per-SMILES conformer lists, error records, and `.to_sdf()` serialization; default batch_size=48 (sweep-validated optimum on L40S: 8.7 conf/s at 83% SM utilization, 1.2 GB peak) - `ffd_pack_indices / pack_batches` — First-Fit-Decreasing atom-count bin- packing to minimise padding waste on heterogeneous molecule sets - `ConformerGenerationResult / MoleculeProcessingError` — typed result objects ## Performance improvements (src/megalodon/models/ and dynamics/) - Pre-build time tensors before diffusion loop (eliminates per-step alloc) - Register `freqs` as buffer in TimestepEmbedder (CPU→GPU transfer gone) - Cache time embeddings per discrete timestep in MegaFNV3Conf (24/25 MLP calls eliminated per sample) - Precompute attention mask once before diffusion loop (25 recomputations eliminated) - Pre-encode null variable one-hots before sample loop (25× redundant F.one_hot calls eliminated) - Skip softmax for discrete_null pass-through logits in `separate_discrete_variables` - Convert attn_mask to float additive bias (0.0 / -inf) enabling efficient Flash Attention dispatch - Skip ETKDG 3D embedding in app inference path (coordinates are overwritten by diffusion prior anyway) ## Bug fixes - batch_preprocessor argument typo in sample_conformers.py - Duplicate `lin_edge1` layer definition in fn_model.py - Stray `torch.max` expressions with discarded results in fn_model.py - `ModuleDict[key] is None` check (was `.get()` / `not in`, both wrong for nn.ModuleDict with None values) - `DataLoader` import moved from deprecated `torch_geometric.data` to `torch_geometric.loader` - `copy(base_data)` → `base_data.clone()` in featurization (shallow copy shared tensor storage, causing in-place mutation bugs) - Inline `_ATOM_ENCODER` to avoid 8s pytorch-lightning transitive import - `Chem.SetUseLegacyStereoPerception(True)` in package `__init__.py` to match training-time stereo assignment - Restore `--skip_eval` CLI arg as no-op for backward compat - Preserve `_Name` from SDF mol inputs in pickle output IDs - Add warning comment on Z-branch float-mask incompatibility in fn_model.py ## Scripts / tooling - `scripts/sample_conformers.py` refactored to use `generate_conformers()` API - `scripts/benchmark_inference.py` — timing + accuracy benchmark (20 curated drug-like molecules, batch sweep, FFD vs fixed, SMILES round-trip check) - `scripts/sustained_perf_test.py` — large-scale sustained-load test using real ChEMBL3D test-set SMILES with stratified size sampling - `scripts/batch_size_sweep.py` — batch-size sweep with live nvidia-smi GPU SM% / memory-BW% sampling; identifies throughput knee and efficiency optimum --- app/utils.py | 21 +- scripts/batch_size_sweep.py | 312 ++++++++++++++++++++ scripts/benchmark_inference.py | 343 ++++++++++++++++++++++ scripts/sample_conformers.py | 102 +++---- scripts/sustained_perf_test.py | 344 +++++++++++++++++++++++ src/megalodon/dynamics/fn_model.py | 77 +++-- src/megalodon/inference/__init__.py | 34 +++ src/megalodon/inference/batching.py | 89 ++++++ src/megalodon/inference/featurization.py | 231 +++++++++++++++ src/megalodon/inference/generation.py | 140 +++++++++ src/megalodon/inference/result.py | 64 +++++ src/megalodon/inference/validation.py | 46 +++ src/megalodon/models/denoising_models.py | 1 + src/megalodon/models/module.py | 67 ++++- 14 files changed, 1766 insertions(+), 105 deletions(-) create mode 100644 scripts/batch_size_sweep.py create mode 100644 scripts/benchmark_inference.py create mode 100644 scripts/sustained_perf_test.py create mode 100644 src/megalodon/inference/__init__.py create mode 100644 src/megalodon/inference/batching.py create mode 100644 src/megalodon/inference/featurization.py create mode 100644 src/megalodon/inference/generation.py create mode 100644 src/megalodon/inference/result.py create mode 100644 src/megalodon/inference/validation.py diff --git a/app/utils.py b/app/utils.py index bbb0595..b20d88f 100644 --- a/app/utils.py +++ b/app/utils.py @@ -5,6 +5,7 @@ import sys import torch import numpy as np +from copy import copy from rdkit import Chem from rdkit.Chem import AllChem from torch_geometric.data import Data, Batch @@ -103,7 +104,7 @@ def add_stereo_bonds(mol, chi_bonds, ez_bonds, edge_index=None, edge_attr=None, return edge_index, edge_attr -def mol_to_torch_geometric_simple(mol, smiles): +def mol_to_torch_geometric_simple(mol, smiles, from_3d=True): """ Convert RDKit molecule to PyTorch Geometric Data object with stereochemistry edges. @@ -138,7 +139,7 @@ def mol_to_torch_geometric_simple(mol, smiles): # Add stereochemistry edges (CRITICAL for LoQI model!) chi_bonds = [7, 8] # R/S stereochemistry edge types ez_bonds = {Chem.BondStereo.STEREOE: 5, Chem.BondStereo.STEREOZ: 6} # E/Z edge types - edge_index, edge_attr = add_stereo_bonds(mol, chi_bonds, ez_bonds, edge_index, edge_attr, from_3D=True) + edge_index, edge_attr = add_stereo_bonds(mol, chi_bonds, ez_bonds, edge_index, edge_attr, from_3D=from_3d) return Data( x=atom_types, @@ -167,17 +168,17 @@ def generate_conformers_batch(smiles, model, cfg, n_confs=10): """ try: # Create base molecule - mol = smiles_to_mol(smiles, add_hs=True, embed_3d=True) + # embed_3d=False: coordinates are overwritten by Gaussian prior in the diffusion model. + # Stereochemistry is read from SMILES directly (AssignStereochemistry, not From3D). + mol = smiles_to_mol(smiles, add_hs=True, embed_3d=False) if mol is None: return None, None, "Invalid SMILES string or failed to embed 3D coordinates" - # Create data list for batch processing - data_list = [] - reference_mols = [] - for _ in range(n_confs): - data = mol_to_torch_geometric_simple(mol, smiles) - data_list.append(data) - reference_mols.append(Chem.Mol(mol)) # Copy of original molecule for reference + # Build PyG graph once; topology is identical for all conformers + # from_3D=False since we skipped ETKDG; stereo is read from SMILES + base_data = mol_to_torch_geometric_simple(mol, smiles, from_3d=False) + data_list = [copy(base_data) for _ in range(n_confs)] + reference_mols = [Chem.Mol(mol) for _ in range(n_confs)] # Create batch and move to device batch = Batch.from_data_list(data_list).to(model.device) diff --git a/scripts/batch_size_sweep.py b/scripts/batch_size_sweep.py new file mode 100644 index 0000000..5938855 --- /dev/null +++ b/scripts/batch_size_sweep.py @@ -0,0 +1,312 @@ +""" +Batch-size sweep: find optimal throughput and GPU utilization. + +Samples GPU SM% and memory-BW% every second via `nvidia-smi dmon` while +generation runs, then reports min/mean/max for each batch size. + +Usage: + conda run -n loqi python scripts/batch_size_sweep.py \ + --config scripts/conf/loqi/loqi.yaml \ + --ckpt /home/olexandr/geoopt/data/loqi.ckpt \ + --dataset_root /home/olexandr/geoopt/data/LoQI/chembl3d_stereo \ + --smiles_pickle /home/olexandr/geoopt/data/LoQI/chembl3d_stereo/processed/test_smiles.pickle +""" + +import sys, time, argparse, pickle, random, subprocess, threading, re +sys.path.insert(0, "src") + +import torch +import numpy as np +from omegaconf import OmegaConf + +from megalodon.inference import generate_conformers, validate_smiles + + +# ── GPU sampling via nvidia-smi dmon ───────────────────────────────────────── + +class GpuMonitor: + """Background thread that reads nvidia-smi dmon output every ~1s.""" + + def __init__(self, gpu_idx: int = 0): + self.gpu_idx = gpu_idx + self._proc = None + self._thread = None + self._samples: list[dict] = [] + self._running = False + + def start(self): + self._samples = [] + self._running = True + # -s um: sm + memory utilization; -d 1: 1-second interval; -i N: GPU index + cmd = ["nvidia-smi", "dmon", "-s", "um", "-d", "1", "-i", str(self.gpu_idx)] + self._proc = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, text=True + ) + self._thread = threading.Thread(target=self._read, daemon=True) + self._thread.start() + + def _read(self): + for line in self._proc.stdout: + if not self._running: + break + line = line.strip() + if line.startswith("#") or not line: + continue + parts = line.split() + # format: gpu_idx sm mem enc dec jpg ofa fb bar1 ccpm + if len(parts) >= 3 and parts[0].isdigit(): + try: + self._samples.append({"sm": int(parts[1]), "mem": int(parts[2]), + "fb_mb": int(parts[7]) if len(parts) > 7 else 0}) + except ValueError: + pass + + def stop(self) -> dict: + self._running = False + if self._proc: + self._proc.terminate() + self._proc.wait() + if self._thread: + self._thread.join(timeout=2) + + if not self._samples: + return {"sm_mean": 0, "sm_max": 0, "mem_mean": 0, "mem_max": 0, "fb_max_mb": 0, "n": 0} + + sm = [s["sm"] for s in self._samples] + mem = [s["mem"] for s in self._samples] + fb = [s["fb_mb"] for s in self._samples] + return { + "sm_mean": int(np.mean(sm)), + "sm_max": int(np.max(sm)), + "mem_mean": int(np.mean(mem)), + "mem_max": int(np.max(mem)), + "fb_max_mb": int(np.max(fb)), + "n": len(sm), + } + + +# ── helpers ────────────────────────────────────────────────────────────────── + +def sync(): + torch.cuda.synchronize() + +def gpu_device(): + return torch.cuda.current_device() + + +def load_model(ckpt_path, config_path, dataset_root): + from megalodon.models.module import Graph3DInterpolantModel + from megalodon.data.batch_preprocessor import BatchPreProcessor + cfg = OmegaConf.load(config_path) + if dataset_root: + OmegaConf.update(cfg, "data.dataset_root", dataset_root, merge=True) + model = Graph3DInterpolantModel.load_from_checkpoint( + ckpt_path, + loss_params=cfg.loss, + interpolant_params=cfg.interpolant, + sampling_params=cfg.sample, + batch_preprocessor=BatchPreProcessor(cfg.data.aug_rotations, cfg.data.scale_coords), + strict=False, + ) + return model.to("cuda").eval(), cfg + + +def load_smiles(smiles_pickle, n_mols, seed=42): + with open(smiles_pickle, "rb") as f: + all_smiles = pickle.load(f) + rng = random.Random(seed) + rng.shuffle(all_smiles) + + validated = [] + atom_counts = [] + for smi in all_smiles: + if len(validated) >= n_mols: + break + mol, err = validate_smiles(smi) + if err is None: + validated.append(smi) + atom_counts.append(mol.GetNumAtoms()) + + return validated, np.array(atom_counts) + + +def bar(val, max_val=100, width=20, fill="█", empty="░"): + n = int(val / max_val * width) + return fill * n + empty * (width - n) + + +def section(title): + print(f"\n{'═'*66}") + print(f" {title}") + print(f"{'═'*66}") + + +# ── sweep ──────────────────────────────────────────────────────────────────── + +def run_sweep(model, cfg, smiles_list, n_confs, gpu_idx, batch_sizes): + results = [] + + for bs in batch_sizes: + # Skip if previous run OOM'd + if results and results[-1].get("oom"): + results.append({"bs": bs, "oom": True}) + continue + + mon = GpuMonitor(gpu_idx) + torch.cuda.reset_peak_memory_stats() + sync() + + try: + mon.start() + t0 = time.perf_counter() + result = generate_conformers( + smiles_list=smiles_list, + model=model, + cfg=cfg, + n_confs=n_confs, + batch_size=bs, + ) + sync() + elapsed = time.perf_counter() - t0 + except torch.cuda.OutOfMemoryError: + mon.stop() + results.append({"bs": bs, "oom": True}) + torch.cuda.empty_cache() + continue + + gpu_stats = mon.stop() + peak_mb = torch.cuda.max_memory_allocated() / 1e6 + n_ok = result.n_success + rate = n_ok / elapsed + + results.append({ + "bs": bs, + "elapsed": elapsed, + "rate": rate, + "ms_per_conf": elapsed / n_ok * 1000, + "n_ok": n_ok, + "peak_mb": peak_mb, + "sm_mean": gpu_stats["sm_mean"], + "sm_max": gpu_stats["sm_max"], + "mem_mean": gpu_stats["mem_mean"], + "mem_max": gpu_stats["mem_max"], + "n_samples": gpu_stats["n"], + "oom": False, + }) + + return results + + +def print_results(results, n_mols, n_confs): + best_rate = max((r["rate"] for r in results if not r.get("oom")), default=1) + + print(f"\n {'bs':>5} {'conf/s':>7} {'ms/conf':>8} " + f"{'peak MB':>8} {'SM% avg':>8} {'SM% max':>8} " + f"{'BW% avg':>8} {'throughput bar'}") + print(f" {'─'*5} {'─'*7} {'─'*8} {'─'*8} {'─'*8} {'─'*8} {'─'*8} {'─'*20}") + + for r in results: + bs = r["bs"] + if r.get("oom"): + print(f" {bs:>5} {'OOM':>7}") + continue + rel = r["rate"] / best_rate + tbar = bar(rel, max_val=1.0, width=22) + print(f" {bs:>5} {r['rate']:>7.1f} {r['ms_per_conf']:>8.1f} " + f"{r['peak_mb']:>8.0f} {r['sm_mean']:>8d} {r['sm_max']:>8d} " + f"{r['mem_mean']:>8d} {tbar} {rel*100:.0f}%") + + # Find optimal: best throughput per memory dollar (rate / peak_mb) + valid = [r for r in results if not r.get("oom")] + if not valid: + return + best_throughput = max(valid, key=lambda r: r["rate"]) + best_efficiency = max(valid, key=lambda r: r["rate"] / r["peak_mb"]) + + print(f"\n Optimal throughput: batch_size={best_throughput['bs']} " + f"→ {best_throughput['rate']:.1f} conf/s " + f"(SM avg {best_throughput['sm_mean']}% " + f"peak {best_throughput['peak_mb']:.0f} MB)") + print(f" Optimal efficiency: batch_size={best_efficiency['bs']} " + f"→ {best_efficiency['rate']:.1f} conf/s / {best_efficiency['peak_mb']:.0f} MB " + f"= {best_efficiency['rate']/best_efficiency['peak_mb']*1000:.2f} conf·s⁻¹·GB⁻¹") + + +def print_utilization_timeline(results): + """Show how SM utilization changes with batch size.""" + section("GPU SM% utilization by batch size") + valid = [r for r in results if not r.get("oom")] + if not valid: + return + max_sm = max(r["sm_max"] for r in valid) or 1 + for r in valid: + b = bar(r["sm_mean"], max_val=100, width=30) + b2 = bar(r["sm_max"], max_val=100, width=30) + print(f" bs={r['bs']:>4} avg {r['sm_mean']:>3}% {b} " + f"max {r['sm_max']:>3}% {b2}") + + +# ── main ───────────────────────────────────────────────────────────────────── + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--config", required=True) + parser.add_argument("--ckpt", required=True) + parser.add_argument("--dataset_root", default=None) + parser.add_argument("--smiles_pickle", required=True) + parser.add_argument("--n_mols", type=int, default=200, + help="Number of molecules to use per sweep run") + parser.add_argument("--n_confs", type=int, default=1) + parser.add_argument("--gpu", type=int, default=0, + help="GPU index for nvidia-smi monitoring") + parser.add_argument("--batch_sizes", type=str, default="1,2,4,8,12,16,24,32,48,64,96,128", + help="Comma-separated batch sizes to sweep") + args = parser.parse_args() + + batch_sizes = [int(x) for x in args.batch_sizes.split(",")] + + # ── Load molecules ──────────────────────────────────────────────────────── + section(f"Loading {args.n_mols} molecules") + smiles_list, atom_counts = load_smiles(args.smiles_pickle, args.n_mols) + print(f" Loaded {len(smiles_list)} molecules") + print(f" Atom counts: min={atom_counts.min()} max={atom_counts.max()} " + f"mean={atom_counts.mean():.1f} median={np.median(atom_counts):.1f}") + + # Buckets + for label, lo, hi in [("tiny ≤20", 0, 20), ("small 21-40", 21, 40), + ("medium 41-60", 41, 60), ("large >60", 61, 9999)]: + n = ((atom_counts >= lo) & (atom_counts <= hi)).sum() + print(f" {label:14s}: {n:4d} ({100*n/len(atom_counts):.1f}%)") + + # ── Load model ──────────────────────────────────────────────────────────── + section("Loading model") + model, cfg = load_model(args.ckpt, args.config, args.dataset_root) + gpu_name = torch.cuda.get_device_name(args.gpu) + gpu_total_mb = torch.cuda.get_device_properties(args.gpu).total_memory / 1e6 + print(f" GPU {args.gpu}: {gpu_name} ({gpu_total_mb:.0f} MB total)") + print(f" Model: {sum(p.numel() for p in model.parameters())/1e6:.1f}M params") + + # ── Warm-up ─────────────────────────────────────────────────────────────── + section("Warm-up (16 molecules, discarded)") + _ = generate_conformers(smiles_list[:16], model, cfg, n_confs=1, batch_size=8) + sync() + print(" Done.") + + # ── Sweep: n_confs=1 ───────────────────────────────────────────────────── + section(f"Sweep — {len(smiles_list)} mols × {args.n_confs} conf " + f"(batch sizes: {batch_sizes})") + results = run_sweep(model, cfg, smiles_list, args.n_confs, args.gpu, batch_sizes) + print_results(results, len(smiles_list), args.n_confs) + print_utilization_timeline(results) + + # ── Repeat with n_confs=5 for multi-conf view ───────────────────────────── + if args.n_confs == 1: + section(f"Sweep — {len(smiles_list)} mols × 5 conf (multi-conformer view)") + results5 = run_sweep(model, cfg, smiles_list, 5, args.gpu, batch_sizes) + print_results(results5, len(smiles_list), 5) + + section("Done") + + +if __name__ == "__main__": + main() diff --git a/scripts/benchmark_inference.py b/scripts/benchmark_inference.py new file mode 100644 index 0000000..e921292 --- /dev/null +++ b/scripts/benchmark_inference.py @@ -0,0 +1,343 @@ +""" +Benchmark script for the new megalodon.inference API. + +Tests: + 1. Validation pipeline (bad SMILES, unsupported elements, etc.) + 2. Single-conformer batched generation (batch_size control) + 3. Multi-conformer generation (n_confs > 1 per molecule) + 4. FFD atom-count bin-packing vs fixed batch_size + 5. Timing breakdown at each stage + 6. Accuracy probe: SMILES round-trip identity, atom-count conservation + +Usage: + conda run -n loqi python scripts/benchmark_inference.py \ + --config scripts/conf/loqi/loqi.yaml \ + --ckpt /home/olexandr/geoopt/data/loqi.ckpt +""" + +import sys, time, argparse +sys.path.insert(0, "src") + +import torch +from omegaconf import OmegaConf +from rdkit import Chem + +# ── Diverse drug-like SMILES ──────────────────────────────────────────────── +SMILES_SET = { + # Name → SMILES (all single-fragment, drug-like) + "aspirin": "CC(=O)Oc1ccccc1C(=O)O", + "caffeine": "Cn1cnc2c1c(=O)n(C)c(=O)n2C", + "ibuprofen": "CC(C)Cc1ccc(cc1)C(C)C(=O)O", + "paracetamol": "CC(=O)Nc1ccc(O)cc1", + "naproxen": "COc1ccc2cc(ccc2c1)C(C)C(=O)O", + "ciprofloxacin":"O=C(O)c1cn(C2CC2)c2cc(N3CCNCC3)c(F)cc2c1=O", + "metformin": "CN(C)C(=N)NC(=N)N", + "atorvastatin": "CC(C)c1c(C(=O)Nc2ccccc2F)c(-c2ccccc2)c(-c2ccc(F)cc2)n1CCC(O)CC(O)CC(=O)O", + "lisinopril": "NCCCC[C@@H](N[C@@H](CCc1ccccc1)C(=O)O)C(=O)N1CCC[C@H]1C(=O)O", + "tamoxifen": "CCC(=C(c1ccccc1)c1ccc(OCCN(C)C)cc1)c1ccccc1", + "morphine": "OC1=CC=C2C[C@H]3N(CCc4cc5c(cc4O3)OCC5)C[C@@H]2C1", # complex ring system + "imatinib": "Cc1ccc(NC(=O)c2ccc(CN3CCN(C)CC3)cc2)cc1Nc1nccc(-c2cccnc2)n1", # big drug + "benzene": "c1ccccc1", # tiny + "dopamine": "NCCc1ccc(O)c(O)c1", + "serotonin": "NCCc1c[nH]c2ccc(O)cc12", + # Edge cases for validation + "has_fluorine": "Fc1ccccc1", + "has_chlorine": "Clc1ccccc1", + "has_bromine": "Brc1ccccc1", + "has_phosphorus": "OP(O)(O)=O", + "has_sulfur": "c1ccc(S)cc1", +} + +BAD_SMILES = { + "empty": "", + "garbage": "NOT_A_SMILES_XYZ", + "salt": "CC(=O)O.[Na+]", + "radical": "[CH3]", # radical (unlikely but valid test) + "unsupported": "[Pt](Cl)(Cl)(Cl)Cl", # Pt not in vocab +} + +ALL_GOOD_SMILES = list(SMILES_SET.values()) +ALL_NAMES = list(SMILES_SET.keys()) + + +def section(title: str): + print(f"\n{'═'*60}") + print(f" {title}") + print(f"{'═'*60}") + + +def run_validation_tests(): + section("1. Validation Pipeline") + from megalodon.inference import validate_smiles + + # Good SMILES + good_pass = 0 + for name, smi in SMILES_SET.items(): + mol, err = validate_smiles(smi) + if mol is not None and err is None: + good_pass += 1 + else: + print(f" UNEXPECTED FAIL [{name}]: {err}") + print(f" Good SMILES: {good_pass}/{len(SMILES_SET)} passed validation") + + # Bad SMILES + bad_caught = 0 + for name, smi in BAD_SMILES.items(): + mol, err = validate_smiles(smi) + if err is not None and mol is None: + bad_caught += 1 + print(f" [CAUGHT] {name}: {err[:80]}") + else: + print(f" MISSED {name}: expected failure but got mol") + print(f" Bad SMILES: {bad_caught}/{len(BAD_SMILES)} correctly rejected") + + +def load_model(ckpt_path: str, config_path: str, dataset_root: str = None): + section("2. Model Loading") + from megalodon.models.module import Graph3DInterpolantModel + from megalodon.data.batch_preprocessor import BatchPreProcessor + + cfg = OmegaConf.load(config_path) + if dataset_root is not None: + OmegaConf.update(cfg, "data.dataset_root", dataset_root, merge=True) + t0 = time.perf_counter() + model = Graph3DInterpolantModel.load_from_checkpoint( + ckpt_path, + loss_params=cfg.loss, + interpolant_params=cfg.interpolant, + sampling_params=cfg.sample, + batch_preprocessor=BatchPreProcessor(cfg.data.aug_rotations, cfg.data.scale_coords), + strict=False, # 'freqs' buffers added in inference-opt; not in older checkpoints + ) + model = model.to("cuda").eval() + t1 = time.perf_counter() + print(f" Model loaded in {t1-t0:.2f}s") + n_params = sum(p.numel() for p in model.parameters()) + print(f" Parameters: {n_params/1e6:.1f}M") + return model, cfg + + +def run_single_conf_benchmark(model, cfg, smiles_list, names): + section("3. Single-Conformer Batched Generation") + from megalodon.inference import generate_conformers + + for batch_size in [4, 8, 16]: + torch.cuda.synchronize() + t0 = time.perf_counter() + result = generate_conformers( + smiles_list=smiles_list, + model=model, + cfg=cfg, + n_confs=1, + batch_size=batch_size, + ) + torch.cuda.synchronize() + elapsed = time.perf_counter() - t0 + + n_ok = result.n_success + n_err = result.n_errors + rate = n_ok / elapsed + print(f" batch_size={batch_size:2d} | {elapsed:6.2f}s | " + f"{n_ok}/{len(smiles_list)} OK | {n_err} errors | " + f"{rate:.1f} conformers/s") + if result.errors: + for e in result.errors: + print(f" SKIP [{names[e.index]}]: {e.error[:70]}") + return result # return last result for accuracy checks + + +def run_multi_conf_benchmark(model, cfg, smiles_list, names): + section("4. Multi-Conformer Generation (n_confs=5)") + from megalodon.inference import generate_conformers + + torch.cuda.synchronize() + t0 = time.perf_counter() + result = generate_conformers( + smiles_list=smiles_list, + model=model, + cfg=cfg, + n_confs=5, + batch_size=16, + ) + torch.cuda.synchronize() + elapsed = time.perf_counter() - t0 + + total_confs = result.n_success + n_mol = len(result.conformers) + print(f" Generated {total_confs} conformers for {n_mol} molecules in {elapsed:.2f}s") + print(f" {elapsed/total_confs*1000:.1f} ms/conformer | {total_confs/elapsed:.1f} conformers/s") + print(f" Conformers per molecule:") + for smi, mols in result.conformers.items(): + name = names[smiles_list.index(smi)] if smi in smiles_list else smi[:20] + print(f" {name}: {len(mols)} conformers") + return result + + +def run_variable_nconfs_benchmark(model, cfg, smiles_list, names): + section("5. Variable n_confs per molecule") + from megalodon.inference import generate_conformers + + # Give small molecules more conformers, big ones fewer + from megalodon.inference import validate_smiles + n_confs_list = [] + for smi in smiles_list: + mol, _ = validate_smiles(smi) + n_atoms = mol.GetNumAtoms() if mol else 30 + n_confs_list.append(max(1, 10 - n_atoms // 10)) + + torch.cuda.synchronize() + t0 = time.perf_counter() + result = generate_conformers( + smiles_list=smiles_list, + model=model, + cfg=cfg, + n_confs=n_confs_list, + batch_size=16, + ) + torch.cuda.synchronize() + elapsed = time.perf_counter() - t0 + + total = result.n_success + print(f" {total} conformers total in {elapsed:.2f}s ({total/elapsed:.1f}/s)") + print(f" Per-mol requested vs generated:") + for (smi, n_req), name in zip(zip(smiles_list, n_confs_list), names): + n_gen = len(result.conformers.get(smi, [])) + print(f" {name:20s}: requested={n_req}, generated={n_gen}") + return result + + +def run_ffd_benchmark(model, cfg, smiles_list, names): + section("6. FFD Bin-Packing vs Fixed batch_size") + from megalodon.inference import generate_conformers, validate_smiles, ffd_pack_indices + + # Compute atom counts for reference + print(" Atom counts per molecule:") + total_atoms = 0 + for name, smi in zip(names, smiles_list): + mol, _ = validate_smiles(smi) + n = mol.GetNumAtoms() if mol else 0 + total_atoms += n + print(f" {name:20s}: {n:3d} atoms") + print(f" Total atoms: {total_atoms}, avg: {total_atoms/len(smiles_list):.1f}") + + # FFD with max_atoms_per_batch + for max_atoms in [128, 256, 512]: + torch.cuda.synchronize() + t0 = time.perf_counter() + result = generate_conformers( + smiles_list=smiles_list, + model=model, + cfg=cfg, + n_confs=1, + max_atoms_per_batch=max_atoms, + ) + torch.cuda.synchronize() + elapsed = time.perf_counter() - t0 + print(f" FFD max_atoms={max_atoms:4d} | {elapsed:6.2f}s | " + f"{result.n_success}/{len(smiles_list)} OK") + + # Fixed batch_size for comparison + for bs in [8, 16]: + torch.cuda.synchronize() + t0 = time.perf_counter() + result = generate_conformers( + smiles_list=smiles_list, + model=model, + cfg=cfg, + n_confs=1, + batch_size=bs, + ) + torch.cuda.synchronize() + elapsed = time.perf_counter() - t0 + print(f" Fixed batch_size={bs:4d} | {elapsed:6.2f}s | " + f"{result.n_success}/{len(smiles_list)} OK") + + +def run_accuracy_checks(result, smiles_list, names): + section("7. Accuracy Checks") + from megalodon.inference import validate_smiles + + n_checked = 0 + n_correct_atoms = 0 + n_has_conf = 0 + n_3d_coords_nonzero = 0 + atom_errors = [] + + for smi, conf_mols in result.conformers.items(): + ref_mol, _ = validate_smiles(smi) + ref_atoms = ref_mol.GetNumAtoms() if ref_mol else None + for i, gen_mol in enumerate(conf_mols): + n_checked += 1 + if gen_mol is None: + continue + if gen_mol.GetNumConformers() > 0: + n_has_conf += 1 + pos = gen_mol.GetConformer().GetPositions() + if pos.std() > 0.1: + n_3d_coords_nonzero += 1 + gen_atoms = gen_mol.GetNumAtoms() + if ref_atoms is not None and gen_atoms == ref_atoms: + n_correct_atoms += 1 + elif ref_atoms is not None: + name = names[smiles_list.index(smi)] if smi in smiles_list else smi[:20] + atom_errors.append(f" {name}: ref={ref_atoms}, gen={gen_atoms}") + + print(f" Conformers checked: {n_checked}") + print(f" Has 3D conformer: {n_has_conf}/{n_checked}") + print(f" 3D coords non-trivial: {n_3d_coords_nonzero}/{n_has_conf}") + print(f" Correct atom count: {n_correct_atoms}/{n_checked}") + if atom_errors: + print(" Atom count mismatches:") + for e in atom_errors: + print(e) + + # SMILES round-trip + n_roundtrip = 0 + for smi, conf_mols in result.conformers.items(): + for gen_mol in conf_mols: + if gen_mol is None: + continue + try: + gen_smi = Chem.MolToSmiles(Chem.RemoveHs(gen_mol)) + ref_smi = Chem.MolToSmiles(Chem.MolFromSmiles(smi)) + if gen_smi == ref_smi: + n_roundtrip += 1 + except Exception: + pass + print(f" SMILES round-trip match: {n_roundtrip}/{n_checked}") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--config", required=True) + parser.add_argument("--ckpt", required=True) + parser.add_argument("--dataset_root", default=None, + help="Override data.dataset_root in config (needed if config has stale path)") + parser.add_argument("--quick", action="store_true", + help="Use only 6 small molecules for fast iteration") + args = parser.parse_args() + + smiles_list = ALL_GOOD_SMILES + names = ALL_NAMES + + if args.quick: + keep = ["aspirin", "caffeine", "ibuprofen", "paracetamol", "naproxen", "dopamine"] + idx = [ALL_NAMES.index(k) for k in keep if k in ALL_NAMES] + smiles_list = [ALL_GOOD_SMILES[i] for i in idx] + names = [ALL_NAMES[i] for i in idx] + print(f"[quick mode] Using {len(smiles_list)} small molecules") + + run_validation_tests() + model, cfg = load_model(args.ckpt, args.config, dataset_root=args.dataset_root) + + result_single = run_single_conf_benchmark(model, cfg, smiles_list, names) + result_multi = run_multi_conf_benchmark(model, cfg, smiles_list, names) + run_variable_nconfs_benchmark(model, cfg, smiles_list, names) + run_ffd_benchmark(model, cfg, smiles_list, names) + run_accuracy_checks(result_multi, smiles_list, names) + + section("Done") + + +if __name__ == "__main__": + main() diff --git a/scripts/sample_conformers.py b/scripts/sample_conformers.py index 56ee051..aee99b6 100644 --- a/scripts/sample_conformers.py +++ b/scripts/sample_conformers.py @@ -1,22 +1,16 @@ import os -import pickle from argparse import ArgumentParser from rdkit import Chem from rdkit.Chem import AllChem -from tqdm import tqdm from torch_geometric.data import DataLoader import torch import numpy as np from omegaconf import OmegaConf -from copy import deepcopy +from copy import copy, deepcopy from torch_geometric.data import Data from megalodon.models.module import Graph3DInterpolantModel from megalodon.data.batch_preprocessor import BatchPreProcessor -from megalodon.data.statistics import Statistics -from megalodon.metrics.conformer_evaluation_callback import ( - ConformerEvaluationCallback, write_coords_to_mol, convert_coords_to_np -) from megalodon.metrics.molecule_evaluation_callback import full_atom_encoder @@ -174,9 +168,10 @@ def mols_to_data_list(mols, n_confs=1, use_3d=True): pos = mol.GetConformer().GetPositions() if use_3d and mol.GetNumConformers() > 0 else None + # Build topology once, then replicate cheaply + base_data = raw_to_pyg(Chem.Mol(mol), pos, use_3d=use_3d) for _ in range(n_confs): - data = raw_to_pyg(Chem.Mol(mol), pos, use_3d=use_3d) - data_list.append(data) + data_list.append(copy(base_data)) return data_list @@ -187,9 +182,9 @@ def main(): parser.add_argument("--ckpt", type=str, required=True) parser.add_argument("--output", type=str, required=True) parser.add_argument("--n_confs", type=int, default=1) - parser.add_argument("--batch_size", type=int, default=32) - parser.add_argument("--no_3d", action="store_true", help="Skip 3D embedding generation") - parser.add_argument("--skip_eval", action="store_true", help="Skip evaluation") + parser.add_argument("--batch_size", type=int, default=48) + parser.add_argument("--no_3d", action="store_true", help="Skip 3D embedding for SDF input preprocessing (does not affect diffusion-based coordinate generation)") + parser.add_argument("--skip_eval", action="store_true", help="(Ignored) Evaluation metrics are not computed by this script; kept for backward compatibility") args = parser.parse_args() # Load model @@ -199,70 +194,57 @@ def main(): loss_params=cfg.loss, interpolant_params=cfg.interpolant, sampling_params=cfg.sample, - batch_preporcessor=BatchPreProcessor(cfg.data.aug_rotations, cfg.data.scale_coords) + batch_preprocessor=BatchPreProcessor(cfg.data.aug_rotations, cfg.data.scale_coords) ) model = model.to("cuda").eval() - # Load molecules and replicate them n_confs times + # Load molecules use_3d = not args.no_3d mols = load_rdkit_molecules(args.input, use_3d=use_3d) - data_list = mols_to_data_list(mols, n_confs=args.n_confs, use_3d=use_3d) - loader = DataLoader(data_list, batch_size=args.batch_size) - # Sampling + # Build SMILES list and preserve _Name property from SDF inputs. + # Chem.MolToSmiles() produces canonical SMILES, which becomes the key in + # ConformerGenerationResult.conformers. Map canonical SMILES back to the + # original mol name so pickle output stays compatible with the old format. + smiles_list = [Chem.MolToSmiles(m) for m in mols] + smiles_to_name = { + smi: (m.GetProp("_Name") if m.HasProp("_Name") else smi) + for smi, m in zip(smiles_list, mols) + } + + # Sampling via inference API + from megalodon.inference import generate_conformers + result = generate_conformers( + smiles_list=smiles_list, + model=model, + cfg=cfg, + n_confs=args.n_confs, + batch_size=args.batch_size, + ) + generated = [] - references = [] if not args.skip_eval else None ids = [] - - for batch in tqdm(loader, desc="Sampling"): - batch = batch.to(model.device) - sample = model.sample(batch=batch, timesteps=cfg.interpolant.timesteps, pre_format=True) - coords_list = convert_coords_to_np(sample) - mols_gen = [write_coords_to_mol(mol, coords) for mol, coords in zip(batch["mol"], coords_list)] - generated.extend(mols_gen) - if not args.skip_eval: - references.extend(batch["mol"]) - ids.extend([m.GetProp("_Name") if m.HasProp("_Name") else "NA" for m in batch["mol"]]) + for smiles, conf_mols in result.conformers.items(): + generated.extend(conf_mols) + name = smiles_to_name.get(smiles, smiles) + ids.extend([name] * len(conf_mols)) + + for err in result.errors: + print(f"WARNING: skipped SMILES at index {err.index}: {err.error}") # Save output if args.output.endswith(".sdf"): - from rdkit.Chem import SDWriter - writer = SDWriter(args.output) - for mol in generated: - writer.write(mol) - writer.close() + with open(args.output, "w") as f: + f.write(result.to_sdf()) else: + import pickle output_dict = {"generated": generated, "ids": ids} - if references is not None: - output_dict["reference"] = references with open(args.output, "wb") as f: pickle.dump(output_dict, f) - # Evaluate only if references are available and evaluation is not skipped - if not args.skip_eval and references: - stats = Statistics.load_statistics(cfg.data.dataset_root + "/processed", "train") - eval_cb = ConformerEvaluationCallback( - timesteps=cfg.evaluation.timesteps, - compute_3D_metrics=cfg.evaluation.compute_3D_metrics, - compute_energy_metrics=cfg.evaluation.compute_energy_metrics, - energy_metrics_args=OmegaConf.to_container(cfg.evaluation.energy_metrics_args, - resolve=True), - statistics=stats, - scale_coords=cfg.evaluation.scale_coords, - compute_stereo_metrics=True - ) - for gen, ref in zip(generated, references): - if ref.GetNumConformers() == 0: - ref.AddConformer(Chem.Conformer(ref.GetNumAtoms())) - conf = gen.GetConformer(0) - pos = conf.GetPositions() - conf.SetPositions(pos) - ref.AddConformer(conf) - results = eval_cb.evaluate_molecules(generated, reference_molecules=references, device=model.device) - print("Evaluation Results:") - print(results) - - print(f"Generated {len(generated)} conformers for {len(set(ids))} unique molecules.") + print(f"Generated {result.n_success} conformers for " + f"{len(result.conformers)} unique molecules " + f"({result.n_errors} SMILES failed validation).") if __name__ == "__main__": diff --git a/scripts/sustained_perf_test.py b/scripts/sustained_perf_test.py new file mode 100644 index 0000000..e73650a --- /dev/null +++ b/scripts/sustained_perf_test.py @@ -0,0 +1,344 @@ +""" +Sustained performance test using real ChEMBL3D test-set SMILES. + +Samples molecules across the full atom-count range, runs large batches, +measures GPU memory, throughput, and timing breakdowns. + +Usage: + conda run -n loqi python scripts/sustained_perf_test.py \ + --config scripts/conf/loqi/loqi.yaml \ + --ckpt /home/olexandr/geoopt/data/loqi.ckpt \ + --dataset_root /home/olexandr/geoopt/data/LoQI/chembl3d_stereo \ + --smiles_pickle /home/olexandr/geoopt/data/LoQI/chembl3d_stereo/processed/test_smiles.pickle \ + --n_mols 500 \ + --n_confs 1 +""" + +import sys, time, argparse, pickle, random, textwrap +sys.path.insert(0, "src") + +import torch +import numpy as np +from omegaconf import OmegaConf +from rdkit import Chem + +from megalodon.inference import generate_conformers, validate_smiles, ffd_pack_indices + + +# ── helpers ───────────────────────────────────────────────────────────────── + +def gpu_mem_gb(): + if torch.cuda.is_available(): + return torch.cuda.max_memory_allocated() / 1e9 + return 0.0 + +def reset_peak(): + if torch.cuda.is_available(): + torch.cuda.reset_peak_memory_stats() + +def sync(): + if torch.cuda.is_available(): + torch.cuda.synchronize() + + +def load_model(ckpt_path, config_path, dataset_root): + from megalodon.models.module import Graph3DInterpolantModel + from megalodon.data.batch_preprocessor import BatchPreProcessor + cfg = OmegaConf.load(config_path) + if dataset_root: + OmegaConf.update(cfg, "data.dataset_root", dataset_root, merge=True) + model = Graph3DInterpolantModel.load_from_checkpoint( + ckpt_path, + loss_params=cfg.loss, + interpolant_params=cfg.interpolant, + sampling_params=cfg.sample, + batch_preprocessor=BatchPreProcessor(cfg.data.aug_rotations, cfg.data.scale_coords), + strict=False, + ) + return model.to("cuda").eval(), cfg + + +def sample_smiles(smiles_pickle, n_mols, seed=42): + """ + Load SMILES and stratify by atom count so we get a representative mix + of tiny (≤20), small (21-40), medium (41-60), and large (>60 atom) mols. + """ + with open(smiles_pickle, "rb") as f: + all_smiles = pickle.load(f) + + rng = random.Random(seed) + rng.shuffle(all_smiles) + + # Validate and count atoms + validated = [] + for smi in all_smiles: + if len(validated) >= n_mols * 4: # over-sample then stratify + break + mol, err = validate_smiles(smi) + if err is None: + n_atoms = mol.GetNumAtoms() + validated.append((smi, n_atoms)) + + # Stratify: 25% per bracket + buckets = { + "tiny": [s for s, n in validated if n <= 20], + "small": [s for s, n in validated if 21 <= n <= 40], + "medium": [s for s, n in validated if 41 <= n <= 60], + "large": [s for s, n in validated if n > 60], + } + per_bucket = n_mols // 4 + result = [] + for label, pool in buckets.items(): + chosen = pool[:per_bucket] + result.extend(chosen) + print(f" {label:6s}: {len(chosen):4d} molecules " + f"(pool {len(pool)})") + + # Top up from any bucket if needed + remainder = n_mols - len(result) + if remainder > 0: + all_valid = [s for s, _ in validated if s not in set(result)] + result.extend(all_valid[:remainder]) + + rng.shuffle(result) + print(f" Total sampled: {len(result)}") + return result + + +def atom_count_stats(smiles_list): + counts = [] + for smi in smiles_list: + mol = Chem.MolFromSmiles(smi) + if mol: + mol = Chem.AddHs(mol) + counts.append(mol.GetNumAtoms()) + arr = np.array(counts) + return arr + + +def print_histogram(counts, bins=8, label="Atom count distribution"): + hist, edges = np.histogram(counts, bins=bins) + peak = max(hist) + bar_width = 30 + print(f"\n {label}:") + for i, (lo, hi, h) in enumerate(zip(edges, edges[1:], hist)): + bar = "█" * int(h / peak * bar_width) + print(f" {lo:4.0f}–{hi:4.0f}: {bar:<{bar_width}} {h}") + print(f" min={counts.min():.0f} max={counts.max():.0f} " + f"mean={counts.mean():.1f} median={np.median(counts):.1f}") + + +def section(title): + print(f"\n{'═'*62}") + print(f" {title}") + print(f"{'═'*62}") + + +# ── benchmark runs ──────────────────────────────────────────────────────────── + +def run_batch_size_sweep(model, cfg, smiles_list, n_confs, label=""): + section(f"Batch-size sweep — {len(smiles_list)} mols × {n_confs} conf {label}") + results = {} + for bs in [8, 16, 32, 64]: + reset_peak() + sync() + t0 = time.perf_counter() + result = generate_conformers( + smiles_list=smiles_list, + model=model, + cfg=cfg, + n_confs=n_confs, + batch_size=bs, + ) + sync() + elapsed = time.perf_counter() - t0 + peak_gb = gpu_mem_gb() + n_ok = result.n_success + rate = n_ok / elapsed + results[bs] = dict(elapsed=elapsed, rate=rate, peak_gb=peak_gb, + n_ok=n_ok, n_err=result.n_errors) + print(f" batch_size={bs:3d} | {elapsed:7.2f}s | " + f"{n_ok}/{len(smiles_list)*n_confs} OK | " + f"{rate:6.1f} conf/s | peak GPU {peak_gb:.2f} GB") + return results + + +def run_ffd_sweep(model, cfg, smiles_list, n_confs): + section(f"FFD atom-count sweep — {len(smiles_list)} mols × {n_confs} conf") + + # Show what FFD bins look like at each cap + mol_atoms = atom_count_stats(smiles_list) + for cap in [256, 512, 1024, 2048]: + bins = ffd_pack_indices(mol_atoms.tolist(), cap) + sizes = [len(b) for b in bins] + print(f" cap={cap:5d}: {len(bins):3d} bins, " + f"mols/bin min={min(sizes)} max={max(sizes)} avg={np.mean(sizes):.1f}") + print() + + for cap in [256, 512, 1024, 2048]: + reset_peak() + sync() + t0 = time.perf_counter() + result = generate_conformers( + smiles_list=smiles_list, + model=model, + cfg=cfg, + n_confs=n_confs, + max_atoms_per_batch=cap, + ) + sync() + elapsed = time.perf_counter() - t0 + peak_gb = gpu_mem_gb() + n_ok = result.n_success + print(f" max_atoms={cap:5d} | {elapsed:7.2f}s | " + f"{n_ok}/{len(smiles_list)*n_confs} OK | " + f"{n_ok/elapsed:6.1f} conf/s | peak GPU {peak_gb:.2f} GB") + + +def run_multi_conf_scaling(model, cfg, smiles_list): + section(f"Multi-conformer scaling — {len(smiles_list)} unique molecules") + for n_confs in [1, 5, 10, 20]: + reset_peak() + sync() + t0 = time.perf_counter() + result = generate_conformers( + smiles_list=smiles_list, + model=model, + cfg=cfg, + n_confs=n_confs, + batch_size=32, + ) + sync() + elapsed = time.perf_counter() - t0 + peak_gb = gpu_mem_gb() + total_confs = result.n_success + print(f" n_confs={n_confs:3d} | {elapsed:7.2f}s | " + f"{total_confs} conformers | " + f"{total_confs/elapsed:6.1f} conf/s | " + f"{elapsed/total_confs*1000:.1f} ms/conf | " + f"peak GPU {peak_gb:.2f} GB") + + +def run_sustained_load(model, cfg, smiles_list, n_confs, batch_size, n_rounds=5): + section(f"Sustained load — {n_rounds} rounds × {len(smiles_list)} mols × {n_confs} conf " + f"(batch_size={batch_size})") + round_times = [] + total_confs = 0 + t_wall_start = time.perf_counter() + for i in range(n_rounds): + reset_peak() + sync() + t0 = time.perf_counter() + result = generate_conformers( + smiles_list=smiles_list, + model=model, + cfg=cfg, + n_confs=n_confs, + batch_size=batch_size, + ) + sync() + elapsed = time.perf_counter() - t0 + peak_gb = gpu_mem_gb() + total_confs += result.n_success + round_times.append(elapsed) + print(f" round {i+1}/{n_rounds}: {elapsed:.2f}s " + f"{result.n_success/elapsed:.1f} conf/s " + f"peak {peak_gb:.2f} GB") + + wall = time.perf_counter() - t_wall_start + arr = np.array(round_times) + print(f"\n Rounds: mean={arr.mean():.2f}s std={arr.std():.2f}s " + f"cv={arr.std()/arr.mean()*100:.1f}%") + print(f" Total: {total_confs} conformers in {wall:.1f}s " + f"({total_confs/wall:.1f} sustained conf/s)") + + +def run_accuracy_probe(model, cfg, smiles_list, n_confs=3): + section(f"Accuracy probe — {len(smiles_list)} mols × {n_confs} conf") + result = generate_conformers( + smiles_list=smiles_list, + model=model, + cfg=cfg, + n_confs=n_confs, + batch_size=32, + ) + + n_checked = n_correct_atoms = n_has_3d = n_roundtrip = 0 + atom_errors = [] + + for smi, conf_mols in result.conformers.items(): + ref_mol = Chem.AddHs(Chem.MolFromSmiles(smi)) + ref_n = ref_mol.GetNumAtoms() + for gen_mol in conf_mols: + n_checked += 1 + if gen_mol is None: + continue + if gen_mol.GetNumConformers() > 0: + n_has_3d += 1 + if gen_mol.GetNumAtoms() == ref_n: + n_correct_atoms += 1 + else: + atom_errors.append((smi, ref_n, gen_mol.GetNumAtoms())) + try: + gen_smi = Chem.MolToSmiles(Chem.RemoveHs(gen_mol)) + if gen_smi == Chem.MolToSmiles(Chem.MolFromSmiles(smi)): + n_roundtrip += 1 + except Exception: + pass + + print(f" Conformers checked: {n_checked}") + print(f" Has 3D conformer: {n_has_3d}/{n_checked} " + f"({100*n_has_3d/max(n_checked,1):.1f}%)") + print(f" Correct atom count: {n_correct_atoms}/{n_checked} " + f"({100*n_correct_atoms/max(n_checked,1):.1f}%)") + print(f" SMILES round-trip: {n_roundtrip}/{n_checked} " + f"({100*n_roundtrip/max(n_checked,1):.1f}%)") + if atom_errors: + print(f" Atom count mismatches ({len(atom_errors)}):") + for smi, r, g in atom_errors[:5]: + print(f" {smi[:50]}: ref={r} gen={g}") + print(f" Errors (failed validation): {result.n_errors}") + + +# ── main ───────────────────────────────────────────────────────────────────── + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--config", required=True) + parser.add_argument("--ckpt", required=True) + parser.add_argument("--dataset_root", default=None) + parser.add_argument("--smiles_pickle", required=True) + parser.add_argument("--n_mols", type=int, default=500) + parser.add_argument("--n_confs", type=int, default=1) + parser.add_argument("--seed", type=int, default=42) + args = parser.parse_args() + + section(f"Setup — sampling {args.n_mols} molecules from ChEMBL3D test set") + smiles_list = sample_smiles(args.smiles_pickle, args.n_mols, seed=args.seed) + + counts = atom_count_stats(smiles_list) + print_histogram(counts, bins=10) + + print(f"\n Loading model...") + model, cfg = load_model(args.ckpt, args.config, args.dataset_root) + n_params = sum(p.numel() for p in model.parameters()) + device = next(model.parameters()).device + print(f" {n_params/1e6:.1f}M params on {device} " + f"({torch.cuda.get_device_name(0)})") + + # Warm-up: one small batch so CUDA kernels are compiled + section("Warm-up (10 molecules)") + _ = generate_conformers(smiles_list[:10], model, cfg, n_confs=1, batch_size=8) + print(" Done.") + + run_batch_size_sweep(model, cfg, smiles_list, args.n_confs) + run_ffd_sweep(model, cfg, smiles_list, args.n_confs) + run_multi_conf_scaling(model, cfg, smiles_list[:100]) # 100 unique mols + run_sustained_load(model, cfg, smiles_list, args.n_confs, batch_size=32, n_rounds=5) + run_accuracy_probe(model, cfg, smiles_list[:200], n_confs=3) + + section("Done") + + +if __name__ == "__main__": + main() diff --git a/src/megalodon/dynamics/fn_model.py b/src/megalodon/dynamics/fn_model.py index 2dc4deb..e5d910c 100644 --- a/src/megalodon/dynamics/fn_model.py +++ b/src/megalodon/dynamics/fn_model.py @@ -126,7 +126,7 @@ class TimestepEmbedder(nn.Module): Embeds scalar timesteps into vector representations. """ - def __init__(self, hidden_size, frequency_embedding_size=256): + def __init__(self, hidden_size, frequency_embedding_size=256, max_period=10000): super().__init__() self.mlp = nn.Sequential( nn.Linear(frequency_embedding_size, hidden_size, bias=True), @@ -134,24 +134,20 @@ def __init__(self, hidden_size, frequency_embedding_size=256): nn.Linear(hidden_size, hidden_size, bias=True), ) self.frequency_embedding_size = frequency_embedding_size + half = frequency_embedding_size // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ) + self.register_buffer("freqs", freqs) - @staticmethod - def timestep_embedding(t, dim, max_period=10000): + def timestep_embedding(self, t, dim): """ - Create sinusoidal timestep embeddings. + Create sinusoidal timestep embeddings using pre-registered frequency buffer. :param t: a 1-D Tensor of N indices, one per batch element. - These may be fractional. :param dim: the dimension of the output. - :param max_period: controls the minimum frequency of the embeddings. :return: an (N, D) Tensor of positional embeddings. """ - # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py - half = dim // 2 - freqs = torch.exp( - -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( - device=t.device - ) - args = t[:, None].float() * freqs[None] + args = t[:, None].float() * self.freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) @@ -438,7 +434,6 @@ def __init__( self.mask_z = mask_z self.lin_edge0 = nn.Linear(hidden_size, edge_hidden_size, bias=False) - self.lin_edge1 = nn.Linear(hidden_size, hidden_size, bias=False) self.lin_edge1 = nn.Linear(edge_hidden_size + dist_size, edge_hidden_size, bias=False) self.ffn_norm_edge = BatchLayerNorm(edge_hidden_size) self.ffn_edge = swiglu_ffn_edge(edge_hidden_size, bias=False) @@ -463,6 +458,7 @@ def forward( dist: torch.Tensor = None, edge_batch: torch.Tensor = None, Z: torch.Tensor = None, + precomputed_attn_mask: torch.Tensor = None, ): """ This assume pytorch geometric batching so batch size of 1 so skip rotary as it depends on having an actual batch @@ -515,15 +511,24 @@ def forward( Q, K, V = map(reshaper, (Q, K, V)) if x.dim() == 2: - attn_mask = batch.unsqueeze(0) == batch.unsqueeze(1) - attn_mask = attn_mask.unsqueeze(0).unsqueeze( - 0 - ) # ! if float it is added as the biasbut would still need a mask s -infs? + if precomputed_attn_mask is not None: + attn_mask = precomputed_attn_mask + else: + attn_mask = (batch.unsqueeze(0) == batch.unsqueeze(1)).unsqueeze(0).unsqueeze(0) else: attn_mask = batch if Z is not None: if x.dim() == 2: + # NOTE: This branch assumes attn_mask is a boolean tensor (True=attend, + # False=block) and converts it to a float additive bias via == 0 / == 1 + # checks. If the caller already passes a float additive bias (as + # module.py's sample() now does), those checks will not behave as + # intended: masked_fill(== 0) would zero-out valid positions and + # masked_fill(== 1) would be a no-op. LoQI's config never sets Z + # (pair embeddings), so this branch is unreachable in practice, but + # it must be updated before pair embeddings are used with batched + # inference. mask = torch.ones((x.size(0), x.size(0))) if self.mask_z: mask.fill_diagonal_(0) @@ -644,7 +649,6 @@ def __init__( self.dist_projection = nn.Linear(n_vector_features, dist_size, bias=False) def forward(self, batch, X, H, E_idx, E, t): - torch.max(batch) + 1 pos = self.coord_emb(X.unsqueeze(-1)) # N x 3 x K H = self.atom_embedder(H) @@ -739,33 +743,50 @@ def __init__( self.dist_projection = nn.Linear(n_vector_features, dist_size, bias=False) self.return_features = return_features + self._te_cache: dict = {} - def forward(self, batch, X, H, E_idx, E, t): - torch.max(batch) + 1 + def clear_te_cache(self) -> None: + """Clear the timestep embedding cache. Call between inference runs if timesteps change.""" + self._te_cache.clear() + + def forward(self, batch, X, H, E_idx, E, t, precomputed_attn_mask=None): pos = self.coord_emb(X.unsqueeze(-1)) # N x 3 x K H = self.atom_embedder(H) E = self.edge_embedder(E) # should be + n_vector_features not + 1 edge_batch = batch[E_idx[0]] - te_h = self.node_time_embedding(t) - te_e = self.edge_time_embedding(t) - # te_h = te[batch] - # te_e = te[batch[E_idx[0]]] + # Cache is only safe when all molecules share the same timestep (inference path). + # During training, each molecule gets a different random t — fall back to full call. + if (t == t[0]).all(): + t_key = t[0].item() + if t_key not in self._te_cache: + self._te_cache[t_key] = ( + self.node_time_embedding(t[:1]), # shape [1, hidden_dim] + self.edge_time_embedding(t[:1]), # shape [1, edge_hidden_dim] + ) + te_h_single, te_e_single = self._te_cache[t_key] + te_h = te_h_single.expand(t.shape[0], -1) + te_e = te_e_single.expand(t.shape[0], -1) + else: + te_h = self.node_time_embedding(t) + te_e = self.edge_time_embedding(t) edge_attr = E for layer_index in range(len(self.dit_layers)): proj_pos = self.dist_projection(pos) distances = coord2distfn(proj_pos, E_idx, self.scale_dist_features, batch) # E x K # import ipdb; ipdb.set_trace() - H, edge_attr = self.dit_layers[layer_index](batch, H, te_h, edge_attr, E_idx, te_e, - distances, edge_batch) + H, edge_attr = self.dit_layers[layer_index]( + batch, H, te_h, edge_attr, E_idx, te_e, distances, edge_batch, + precomputed_attn_mask=precomputed_attn_mask, + ) pos = self.egnn_layers[layer_index](batch, pos, H, E_idx, edge_attr, te_e) X = self.coord_pred(pos).squeeze(-1) x = X - scatter_mean(X, index=batch, dim=0)[batch] out = { - "x_hat": x, + "x_hat": x, "H": H } return out \ No newline at end of file diff --git a/src/megalodon/inference/__init__.py b/src/megalodon/inference/__init__.py new file mode 100644 index 0000000..97ca84c --- /dev/null +++ b/src/megalodon/inference/__init__.py @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from rdkit import Chem + +# Use legacy stereo perception to match training-time stereo assignments. +# scripts/sample_conformers.py sets this same flag at module level (line 23). +Chem.SetUseLegacyStereoPerception(True) + +from megalodon.inference.batching import ffd_pack_indices, pack_batches +from megalodon.inference.generation import generate_conformers +from megalodon.inference.result import ConformerGenerationResult, MoleculeProcessingError +from megalodon.inference.validation import validate_smiles + +__all__ = [ + "generate_conformers", + "ConformerGenerationResult", + "MoleculeProcessingError", + "validate_smiles", + "pack_batches", + "ffd_pack_indices", +] diff --git a/src/megalodon/inference/batching.py b/src/megalodon/inference/batching.py new file mode 100644 index 0000000..c88b6fd --- /dev/null +++ b/src/megalodon/inference/batching.py @@ -0,0 +1,89 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Atom-count-aware batching utilities for LoQI inference. + +Problem: molecules have highly variable sizes (10–100+ atoms with Hs). +PyG's DataLoader batches by graph count, which can lead to: +- Batches with many small molecules (wasted parallelism) OR +- Batches with one huge molecule and many idle cores + +Solution: First-Fit-Decreasing (FFD) bin-packing by atom count. +Sort molecules largest-first, then greedily fill bins up to max_atoms_per_batch. +This minimises the number of batches and reduces wasted GPU memory. +""" + +from typing import List + +from torch_geometric.data import Batch, Data + + +def ffd_pack_indices(atom_counts: List[int], max_atoms_per_batch: int) -> List[List[int]]: + """ + First-Fit-Decreasing bin-packing by atom count. + + Args: + atom_counts: Atom count for each Data item. + max_atoms_per_batch: Soft upper bound on total atoms per batch. + A single molecule whose atom count already exceeds this value + is placed in its own bin (the bound is never enforced across + bin boundaries, only during placement decisions). + + Returns: + List of index lists, one per bin. + Each inner list contains indices into the original data_list. + Indices within each bin are sorted ascending for reproducibility. + """ + if not atom_counts: + return [] + + # Sort descending by atom count + sorted_indices = sorted(range(len(atom_counts)), key=lambda i: atom_counts[i], reverse=True) + + bins: List[List] = [] # each entry: [current_total_atoms, [idx, ...]] + + for idx in sorted_indices: + n_atoms = atom_counts[idx] + placed = False + for b in bins: + if b[0] + n_atoms <= max_atoms_per_batch: + b[0] += n_atoms + b[1].append(idx) + placed = True + break + if not placed: + bins.append([n_atoms, [idx]]) + + return [sorted(b[1]) for b in bins] + + +def pack_batches(data_list: List[Data], max_atoms_per_batch: int) -> List[Batch]: + """ + Pack data_list into Batch objects using FFD atom-count bin-packing. + + Args: + data_list: List of PyG Data objects to batch. + max_atoms_per_batch: Maximum total atoms per Batch. + + Returns: + List of Batch objects, one per bin. + Items within each Batch appear in ascending original-index order. + """ + if not data_list: + return [] + atom_counts = [data.x.size(0) for data in data_list] + bin_indices = ffd_pack_indices(atom_counts, max_atoms_per_batch) + return [Batch.from_data_list([data_list[i] for i in indices]) for indices in bin_indices] diff --git a/src/megalodon/inference/featurization.py b/src/megalodon/inference/featurization.py new file mode 100644 index 0000000..bd71077 --- /dev/null +++ b/src/megalodon/inference/featurization.py @@ -0,0 +1,231 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, Tuple + +import numpy as np +import torch +from rdkit import Chem +from torch_geometric.data import Data + +# Atom type encoder: maps element symbol to integer index. +# Must match full_atom_encoder in src/megalodon/data/data_utils.py exactly. +_ATOM_ENCODER = { + "H": 0, "B": 1, "C": 2, "N": 3, "O": 4, "F": 5, + "Al": 6, "Si": 7, "P": 8, "S": 9, "Cl": 10, "As": 11, + "Br": 12, "I": 13, "Hg": 14, "Bi": 15, "Se": 16, +} + + +def _add_stereo_bonds( + mol: Chem.Mol, + chi_bonds: List[int], + ez_bonds: Dict, + edge_index: torch.Tensor, + edge_attr: torch.Tensor, + from_3D: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Add stereochemistry virtual edges to the molecular graph. + + Copied from scripts/sample_conformers.py to avoid circular imports. + from_3D=False reads stereo from SMILES (no 3D conformer needed). + """ + result = [] + if from_3D and mol.GetNumConformers() > 0: + Chem.AssignStereochemistryFrom3D(mol, replaceExistingTags=True) + else: + Chem.AssignStereochemistry(mol, cleanIt=True, force=True) + + for bond in mol.GetBonds(): + stereo = bond.GetStereo() + if bond.GetBondType() == Chem.BondType.DOUBLE and stereo in ez_bonds: + idx_3, idx_4 = bond.GetStereoAtoms() + atom_1, atom_2 = bond.GetBeginAtom(), bond.GetEndAtom() + idx_1, idx_2 = atom_1.GetIdx(), atom_2.GetIdx() + idx_5 = [n.GetIdx() for n in atom_1.GetNeighbors() if n.GetIdx() not in {idx_2, idx_3}] + idx_6 = [n.GetIdx() for n in atom_2.GetNeighbors() if n.GetIdx() not in {idx_1, idx_4}] + inv_stereo = Chem.BondStereo.STEREOE if stereo == Chem.BondStereo.STEREOZ else Chem.BondStereo.STEREOZ + result.extend([(idx_3, idx_4, ez_bonds[stereo]), (idx_4, idx_3, ez_bonds[stereo])]) + if idx_5: + result.extend([(idx_5[0], idx_4, ez_bonds[inv_stereo]), (idx_4, idx_5[0], ez_bonds[inv_stereo])]) + if idx_6: + result.extend([(idx_3, idx_6[0], ez_bonds[inv_stereo]), (idx_6[0], idx_3, ez_bonds[inv_stereo])]) + if idx_5 and idx_6: + result.extend([(idx_5[0], idx_6[0], ez_bonds[stereo]), (idx_6[0], idx_5[0], ez_bonds[stereo])]) + + if bond.GetBeginAtom().HasProp("_CIPCode"): + chirality = bond.GetBeginAtom().GetProp("_CIPCode") + neighbors = bond.GetBeginAtom().GetNeighbors() + if all(n.HasProp("_CIPRank") for n in neighbors): + sorted_neighbors = sorted(neighbors, key=lambda x: int(x.GetProp("_CIPRank")), reverse=True) + sorted_neighbors = [a.GetIdx() for a in sorted_neighbors] + a, b, c = sorted_neighbors[:3] if chirality == "R" else sorted_neighbors[:3][::-1] + d = sorted_neighbors[-1] + result.extend([ + (a, d, chi_bonds[0]), (b, d, chi_bonds[0]), (c, d, chi_bonds[0]), + (d, a, chi_bonds[0]), (d, b, chi_bonds[0]), (d, c, chi_bonds[0]), + (b, a, chi_bonds[1]), (c, b, chi_bonds[1]), (a, c, chi_bonds[1]), + ]) + + if not result: + return edge_index, edge_attr + new_edge_index = torch.tensor([[i, j] for i, j, _ in result], dtype=torch.long).T + new_edge_attr = torch.tensor([b for _, _, b in result], dtype=torch.uint8) + edge_index = torch.cat([edge_index, new_edge_index], dim=1) + edge_attr = torch.cat([edge_attr, new_edge_attr]) + return edge_index, edge_attr + + +def _mol_to_pyg(mol: Chem.Mol, smiles: str) -> Data: + """Convert RDKit molecule (with Hs, no 3D conformer) to PyG Data. + + Zero-filled coordinates — the diffusion model replaces them during sampling. + Stereochemistry is read from SMILES (from_3D=False). + """ + Chem.SanitizeMol(mol) + Chem.Kekulize(mol, clearAromaticFlags=True) + + adj = torch.from_numpy(Chem.rdmolops.GetAdjacencyMatrix(mol, useBO=True)) + edge_index = adj.nonzero().contiguous().T + bond_types = adj[edge_index[0], edge_index[1]] + bond_types[bond_types == 1.5] = 4 + edge_attr = bond_types.to(torch.uint8) + + pos = torch.zeros((mol.GetNumAtoms(), 3), dtype=torch.float32) + + atom_types = torch.tensor( + [_ATOM_ENCODER[atom.GetSymbol()] for atom in mol.GetAtoms()], + dtype=torch.uint8, + ) + charges = torch.tensor( + [atom.GetFormalCharge() for atom in mol.GetAtoms()], + dtype=torch.int8, + ) + + chi_bonds = [7, 8] + ez_bonds = {Chem.BondStereo.STEREOE: 5, Chem.BondStereo.STEREOZ: 6} + edge_index, edge_attr = _add_stereo_bonds( + mol, chi_bonds, ez_bonds, edge_index, edge_attr, from_3D=False + ) + + return Data( + x=atom_types, + edge_index=edge_index, + edge_attr=edge_attr.to(torch.uint8), + pos=pos, + charges=charges, + smiles=smiles, + mol=mol, + chemblid=mol.GetProp("_Name") if mol.HasProp("_Name") else "", + ) + + +def build_data_list( + valid_entries: List[Tuple[int, str, Chem.Mol]], + n_confs_per_mol: List[int], +) -> Tuple[List[Data], List[int], List[int]]: + """Build a flat list of PyG Data objects with parallel identity tracking lists. + + Args: + valid_entries: list of (smiles_idx, smiles, rdkit_mol_with_hs) + n_confs_per_mol: list of int, one per entry in valid_entries + + Returns: + (data_list, source_smiles_indices, conformer_indices) + All three lists are parallel — index i refers to the same item across them. + """ + data_list: List[Data] = [] + source_smiles_indices: List[int] = [] + conformer_indices: List[int] = [] + + for (smiles_idx, smiles, mol), n_confs in zip(valid_entries, n_confs_per_mol): + base_data = _mol_to_pyg(mol, smiles) + for conf_idx in range(n_confs): + data_list.append(base_data.clone()) + source_smiles_indices.append(smiles_idx) + conformer_indices.append(conf_idx) + + return data_list, source_smiles_indices, conformer_indices + + +def debatch_conformers( + generated_mols: List, + source_smiles_indices: List[int], + smiles_list: List[str], +) -> Dict[str, List]: + """Reconstruct {smiles: [mol, ...]} from flat generated list + parallel tracking. + + Args: + generated_mols: flat list, parallel to source_smiles_indices + source_smiles_indices: source_smiles_indices[i] = index into smiles_list + smiles_list: original list of SMILES (all, including failed ones) + + Returns: + {smiles_string: [mol, ...]} — keys only for successfully processed SMILES + """ + result: Dict[str, List] = {} + for mol, src_idx in zip(generated_mols, source_smiles_indices): + smiles = smiles_list[src_idx] + if smiles not in result: + result[smiles] = [] + result[smiles].append(mol) + return result + + +def _convert_coords_to_np(out): + """ + Converts the output dictionary containing 'x' (coordinates) and 'batch' (molecule indices) + into a list of NumPy arrays, where each entry represents coordinates for one molecule. + + Parameters: + out (dict): Dictionary containing: + - 'x' (torch.Tensor): Tensor of atomic coordinates (N, 3) + - 'batch' (torch.Tensor): Tensor indicating molecule index for each atom + + Returns: + List[np.ndarray]: List where each element is a NumPy array (M, 3) for a molecule. + """ + coords_list = [] + + x = out["x"].cpu().numpy() # Convert tensor to NumPy (N, 3) + batch = out["batch"].cpu().numpy() # Convert tensor to NumPy (N,) + + unique_mols = np.unique(batch) # Get unique molecule indices + + for mol_id in unique_mols: + coords_list.append(x[batch == mol_id]) # Select coordinates for each molecule + + return coords_list + + +def _write_coords_to_mol(mol, coord): + """ + Embeds 3D coordinates into an RDKit molecule and assigns stereochemistry. + """ + + # Deserialize RDKit molecule + rdkit_mol = Chem.Mol(mol) + + rdkit_mol.RemoveAllConformers() + conf = Chem.Conformer(rdkit_mol.GetNumAtoms()) + + coords = np.asarray(coord) + + for i in range(rdkit_mol.GetNumAtoms()): + conf.SetAtomPosition(i, (float(coords[i][0]), float(coords[i][1]), float(coords[i][2]))) + + rdkit_mol.AddConformer(conf) + + return rdkit_mol diff --git a/src/megalodon/inference/generation.py b/src/megalodon/inference/generation.py new file mode 100644 index 0000000..17822ba --- /dev/null +++ b/src/megalodon/inference/generation.py @@ -0,0 +1,140 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Public API for LoQI conformer generation. + +Usage: + result = generate_conformers( + smiles_list=["c1ccccc1", "CC(=O)O"], + model=model, + cfg=cfg, + n_confs=10, # uniform: 10 conformers per molecule + ) + sdf = result.to_sdf() + for smiles, mols in result.conformers.items(): + print(f"{smiles}: {len(mols)} conformers") + for err in result.errors: + print(f" FAILED {err.smiles}: {err.error}") +""" + +from typing import List, Optional, Union + +import torch +from omegaconf import DictConfig +from torch_geometric.loader import DataLoader + +from megalodon.inference.batching import pack_batches +from megalodon.inference.featurization import ( + build_data_list, + debatch_conformers, + _convert_coords_to_np, + _write_coords_to_mol, +) +from megalodon.inference.result import ConformerGenerationResult, MoleculeProcessingError +from megalodon.inference.validation import validate_smiles + + +def generate_conformers( + smiles_list: List[str], + model: "Graph3DInterpolantModel", + cfg: DictConfig, + n_confs: Union[int, List[int]] = 1, + batch_size: int = 48, + max_atoms_per_batch: Optional[int] = None, +) -> ConformerGenerationResult: + """ + Generate 3D conformers for a list of SMILES strings. + + Each SMILES is validated before GPU processing. Invalid SMILES are isolated + into the `errors` field without crashing the rest of the batch. + + Args: + smiles_list: List of SMILES strings to process. + model: Loaded Graph3DInterpolantModel (already on GPU, in eval mode). + cfg: OmegaConf config (needs cfg.interpolant.timesteps). + n_confs: Number of conformers per molecule. Either a single int + (same for all) or a list of ints (one per SMILES). + batch_size: Maximum number of graphs per DataLoader batch. + max_atoms_per_batch: Optional int. When provided, uses FFD atom-count bin-packing + instead of graph-count batching. Prevents OOM for mixed molecule + sizes. Overrides batch_size when specified. + + Returns: + ConformerGenerationResult with .conformers and .errors. + """ + if not smiles_list: + return ConformerGenerationResult(conformers={}, errors=[]) + + # Normalise n_confs to a per-molecule list + if isinstance(n_confs, int): + n_confs_list = [n_confs] * len(smiles_list) + else: + if len(n_confs) != len(smiles_list): + raise ValueError( + f"n_confs list length ({len(n_confs)}) must match " + f"smiles_list length ({len(smiles_list)})" + ) + n_confs_list = list(n_confs) + + # --- Validation pass: isolate bad SMILES before touching GPU --- + errors: List[MoleculeProcessingError] = [] + valid_entries = [] + valid_n_confs: List[int] = [] + + for idx, smiles in enumerate(smiles_list): + mol, err_msg = validate_smiles(smiles) + if err_msg is not None: + errors.append(MoleculeProcessingError(smiles=smiles, error=err_msg, index=idx)) + else: + valid_entries.append((idx, smiles, mol)) + valid_n_confs.append(n_confs_list[idx]) + + if not valid_entries: + return ConformerGenerationResult(conformers={}, errors=errors) + + # --- Build PyG data list with identity tracking --- + data_list, source_smiles_indices, _conf_indices = build_data_list( + valid_entries, valid_n_confs + ) + + # --- GPU sampling --- + if max_atoms_per_batch is not None: + batches = pack_batches(data_list, max_atoms_per_batch) + else: + loader = DataLoader(data_list, batch_size=batch_size, shuffle=False) + batches = list(loader) + + all_generated = [] + + with torch.no_grad(): + for batch in batches: + batch = batch.to(model.device) + sample = model.sample( + batch=batch, + timesteps=cfg.interpolant.timesteps, + pre_format=True, + ) + coords_list = _convert_coords_to_np(sample) + mols_gen = [ + _write_coords_to_mol(mol, coords) + for mol, coords in zip(batch["mol"], coords_list) + ] + all_generated.extend(mols_gen) + + # --- Reconstruct per-SMILES result --- + conformers = debatch_conformers(all_generated, source_smiles_indices, smiles_list) + + return ConformerGenerationResult(conformers=conformers, errors=errors) diff --git a/src/megalodon/inference/result.py b/src/megalodon/inference/result.py new file mode 100644 index 0000000..f5a28d8 --- /dev/null +++ b/src/megalodon/inference/result.py @@ -0,0 +1,64 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +from dataclasses import dataclass, field +from typing import Dict, List + +from rdkit import Chem +from rdkit.Chem import SDWriter + + +@dataclass +class MoleculeProcessingError: + """Records why a SMILES string could not be processed.""" + smiles: str + error: str + index: int # position in the original smiles_list + + +@dataclass +class ConformerGenerationResult: + """ + Structured output from generate_conformers(). + + conformers: {smiles: [RDKit Mol, ...]} — one list per successfully processed SMILES. + Each Mol has a 3D conformer embedded by LoQI. + errors: list of MoleculeProcessingError for SMILES that failed validation or sampling. + """ + conformers: Dict[str, List] # {smiles: [Chem.Mol]} + errors: List[MoleculeProcessingError] = field(default_factory=list) + + @property + def n_success(self) -> int: + return sum(len(v) for v in self.conformers.values()) + + @property + def n_errors(self) -> int: + return len(self.errors) + + def to_sdf(self) -> str: + """Serialize all conformers to SDF format (string). Returns empty string if none.""" + buf = io.StringIO() + writer = SDWriter(buf) + for smiles, mols in self.conformers.items(): + for i, mol in enumerate(mols): + if mol is not None and mol.GetNumConformers() > 0: + mol_copy = Chem.Mol(mol) + mol_copy.SetProp("SMILES", smiles) + mol_copy.SetProp("conformer_idx", str(i)) + writer.write(mol_copy) + writer.close() + return buf.getvalue() diff --git a/src/megalodon/inference/validation.py b/src/megalodon/inference/validation.py new file mode 100644 index 0000000..72ab2b8 --- /dev/null +++ b/src/megalodon/inference/validation.py @@ -0,0 +1,46 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple + +from rdkit import Chem + +SUPPORTED_ELEMENTS = { + "H", "B", "C", "N", "O", "F", "Al", "Si", + "P", "S", "Cl", "As", "Br", "I", "Hg", "Bi", "Se" +} + + +def validate_smiles(smiles: str) -> Tuple[Optional[Chem.Mol], Optional[str]]: + """ + Validate a SMILES string for LoQI compatibility. + + Returns: + (mol, None) — RDKit Mol with Hs added, ready for featurization + (None, error) — string describing why the SMILES is invalid + """ + if not smiles or not smiles.strip(): + return None, "Empty SMILES string" + + mol = Chem.MolFromSmiles(smiles) + if mol is None: + return None, f"RDKit failed to parse SMILES: {smiles!r}" + + # Disconnected fragments (e.g. salts, mixtures) are not supported + if len(Chem.GetMolFrags(mol)) > 1: + return None, "Disconnected fragments (e.g. salts) are not supported" + + # Add hydrogens before element check — some implicit Hs become explicit + mol_h = Chem.AddHs(mol) + + for atom in mol_h.GetAtoms(): + sym = atom.GetSymbol() + if sym not in SUPPORTED_ELEMENTS: + return None, f"Unsupported element: {sym!r} (not in LoQI atom vocabulary)" + if atom.GetNumRadicalElectrons() > 0: + return None, ( + f"Radical electrons on {atom.GetSymbol()} (index {atom.GetIdx()}) " + f"are not supported" + ) + + return mol_h, None diff --git a/src/megalodon/models/denoising_models.py b/src/megalodon/models/denoising_models.py index 831f7f8..2bfdcb6 100644 --- a/src/megalodon/models/denoising_models.py +++ b/src/megalodon/models/denoising_models.py @@ -243,6 +243,7 @@ def forward(self, batch, time, conditional_batch=None, timesteps=None, return_fe E=batch["edge_attr_t"], E_idx=batch["edge_index"], t=time, + precomputed_attn_mask=batch.get("_attn_mask"), ) out["h_logits"] = batch["h_t"] out["edge_attr_logits"] = batch["edge_attr_t"] diff --git a/src/megalodon/models/module.py b/src/megalodon/models/module.py index 69f83e2..440659c 100644 --- a/src/megalodon/models/module.py +++ b/src/megalodon/models/module.py @@ -23,8 +23,6 @@ from lightning import pytorch as pl from omegaconf import DictConfig, OmegaConf from torch_geometric.utils import dense_to_sparse, sort_edge_index -from tqdm import tqdm - from megalodon.dynamics.utils import InterpolantLossFunction from megalodon.interpolant.builder import build_interpolant from megalodon.models.denoising_models import ModelBuilder @@ -301,8 +299,11 @@ def separate_discrete_variables(self, out, batch): for _key in combined_keys: if f"{_key}_og" in batch: batch[f"{_key}_t"] = batch[f"{_key}_og"] - if self.interpolants[_key] and self.interpolants[_key].prior_type in ["absorb", - "mask"]: + # Skip softmax for null interpolants — their logits are pass-throughs never used + if self.interpolants[_key] is None: + out[f"{_key}_hat"] = out[f"{_key}_logits"] + continue + if self.interpolants[_key].prior_type in ["absorb", "mask"]: logits = out[f"{_key}_logits"].clone() logits[:, -1] = -1e9 else: @@ -543,14 +544,62 @@ def sample(self, num_samples=None, timesteps=500, time_discretization="linear", else: shape = (total_num_atoms, interpolant.num_classes) data[f"{key}_t"] = prior[key] = interpolant.prior(batch_index, shape, self.device) + + # Pre-encode null discrete variables once before the loop — their values never change + # during diffusion, so one-hot conversion can be done here instead of 25x inside the loop. + # Only applicable when batch is not None (i.e. data[key_t] holds integer class indices). + # When batch is None, data[key_t] is already a 2D float zero tensor, not integer indices. + _null_discrete_keys = set() + if batch is not None: + for key, interpolant in self.interpolants.items(): + if interpolant is None: + interp_param = self.interpolant_param_variables[key] + if interp_param.interpolant_type is not None and "discrete" in interp_param.interpolant_type: + num_classes = interp_param.num_classes + data[f"{key}_t"] = F.one_hot(data[f"{key}_t"].long(), num_classes).float() + _null_discrete_keys.add(key) + + # Determine whether any non-null discrete variables exist (need in-loop one_hot) + _has_nonnull_discrete = any( + interpolant is not None + and self.interpolant_param_variables[key].interpolant_type is not None + and "discrete" in self.interpolant_param_variables[key].interpolant_type + for key, interpolant in self.interpolants.items() + ) + # Iterate through time, query the dynamics, apply interpolant step update + # Pre-build all time tensors before the loop to avoid 25x H2D transfers + if time_type == "continuous": + time_tensors = [ + torch.full((num_samples,), float(timeline[i]), device=self.device) + for i in range(len(DT)) + ] + else: + time_tensors = [ + torch.full((num_samples,), int(timeline[i]), device=self.device, dtype=torch.long) + for i in range(len(DT)) + ] + + # Precompute float additive bias mask once — batch_index never changes during sampling. + # Float format (0.0 = attend, -inf = block) enables efficient attention dispatch in SDPA. + _bool_mask = (batch_index.unsqueeze(0) == batch_index.unsqueeze(1)) + attn_mask = torch.zeros( + 1, 1, _bool_mask.size(0), _bool_mask.size(0), + device=self.device, dtype=torch.float32, + ) + attn_mask.masked_fill_(~_bool_mask.unsqueeze(0).unsqueeze(0), float('-inf')) + del _bool_mask + out = {} - for idx in tqdm(list(range(len(DT))), total=len(DT)): + for idx in range(len(DT)): t = timeline[idx] dt = DT[idx] - time = torch.tensor([t] * num_samples).to(self.device) - data = self.one_hot(data) + time = time_tensors[idx] + # Null discrete variables are already one-hot (pre-encoded above). + # Only call one_hot for non-null discrete variables if any exist. + if _has_nonnull_discrete: + data = self.one_hot(data) # Apply Self Conditioning pre_conditioning_variables = {} # ! Try turning off self conditioning --> fixed some but still had edge blow ups can try adding norms here TODO @@ -560,6 +609,7 @@ def sample(self, num_samples=None, timesteps=500, time_discretization="linear", data = self.aggregate_discrete_variables(data) data["batch"] = batch_index data["edge_index"] = edge_index + data["_attn_mask"] = attn_mask out = self.dynamics(data, time, conditional_batch=out, timesteps=timesteps) # ! Error is for FM sampling EQGAT is producing NANs in discrete logits out, data = self.separate_discrete_variables(out, data) @@ -568,6 +618,9 @@ def sample(self, num_samples=None, timesteps=500, time_discretization="linear", data[key] = pre_conditioning_variables[key] for key, interpolant in self.interpolants.items(): if interpolant is None: + if key in _null_discrete_keys: + # Already one-hot encoded before the loop; skip reset to raw integers. + continue prior[key] = batch[key] data[f"{key}_t"] = prior[key] continue