diff --git a/pyproject.toml b/pyproject.toml index dec4906e..94feeaac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,19 +1,19 @@ [project] name = "arc-state" -version = "0.9.31" +version = "0.9.32" description = "State is a machine learning model that predicts cellular perturbation response across diverse contexts." readme = "README.md" authors = [ { name = "Abhinav Adduri", email = "abhinav.adduri@arcinstitute.org" }, { name = "Yusuf Roohani", email = "yusuf.roohani@arcinstitute.org" }, { name = "Noam Teyssier", email = "noam.teyssier@arcinstitute.org" }, - { name = "Rajesh Ilango" }, + { name = "Rajesh Ilango", email = "rilango@gmail.com" }, { name = "Dhruv Gautam", email = "dhruvgautam@berkeley.edu" }, ] requires-python = ">=3.10,<3.13" dependencies = [ "anndata>=0.11.4", - "cell-load>=0.8.3", + "cell-load>=0.8.7", "numpy>=2.2.6", "pandas>=2.2.3", "pyyaml>=6.0.2", @@ -27,11 +27,15 @@ dependencies = [ "geomloss>=0.2.6", "transformers>=4.52.3", "peft>=0.11.0", - "cell-eval>=0.5.22", + "cell-eval>=0.6.2", "ipykernel>=6.30.1", "scipy>=1.15.0", ] +[tool.uv.sources] +cell-load = {path = "/home/aadduri/cell-load"} +cell-eval = {git = "https://github.com/ArcInstitute/cell-eval", branch = "aadduri/aupr_curves"} + [project.optional-dependencies] vectordb = [ "lancedb>=0.24.0" diff --git a/scripts/state_embed_anndata.py b/scripts/state_embed_anndata.py deleted file mode 100644 index 3af33ab0..00000000 --- a/scripts/state_embed_anndata.py +++ /dev/null @@ -1,50 +0,0 @@ -#!/usr/bin/env python3 -""" -VCI Model Embedding Script - -This script computes embeddings for an input anndata file using a pre-trained VCI model checkpoint. -It can be run from any directory and outputs the embedded anndata to a specified location. - -Usage: - python embed_vci.py --checkpoint PATH_TO_CHECKPOINT --input INPUT_ANNDATA --output OUTPUT_ANNDATA - -Example: - python embed_vci.py --checkpoint /path/to/model.ckpt --input data.h5ad --output embedded_data.h5ad -""" - -import argparse -import os - -from omegaconf import OmegaConf - -from state_sets.state.inference import Inference - - -# Parse command line arguments -def parse_args(): - parser = argparse.ArgumentParser(description="Compute embeddings for anndata using a VCI model") - parser.add_argument("--checkpoint", required=True, help="Path to the model checkpoint file") - parser.add_argument("--config", required=True, help="Path to the model training config") - parser.add_argument("--input", required=True, help="Path to input anndata file (h5ad)") - parser.add_argument("--output", required=True, help="Path to output embedded anndata file (h5ad)") - parser.add_argument("--dataset-name", default="perturbation", help="Dataset name to be used in dataloader creation") - parser.add_argument("--gpu", action="store_true", help="Use GPU if available") - parser.add_argument("--filter", action="store_true", help="Filter gene set to our esm embeddings only.") - parser.add_argument("--embed-key", help="Name of key to store") - - return parser.parse_args() - - -def main(): - # Parse command line arguments - args = parse_args() - - conf = OmegaConf.load(args.config) - inferer = Inference(conf) - inferer.load_model(args.checkpoint) - os.makedirs(os.path.dirname(args.output), exist_ok=True) - inferer.encode_adata(args.input, args.output, emb_key=args.embed_key, dataset_name=args.dataset_name) - - -if __name__ == "__main__": - main() diff --git a/src/state/__main__.py b/src/state/__main__.py index 0a7f9236..20068d8d 100644 --- a/src/state/__main__.py +++ b/src/state/__main__.py @@ -11,6 +11,7 @@ run_emb_query, run_emb_preprocess, run_emb_eval, + run_tx_combo, run_tx_infer, run_tx_predict, run_tx_preprocess_infer, @@ -124,6 +125,8 @@ def main(): case "infer": # Run inference using argparse, similar to predict run_tx_infer(args) + case "combo": + run_tx_combo(args) case "preprocess_train": # Run preprocessing using argparse run_tx_preprocess_train(args.adata, args.output, args.num_hvgs) diff --git a/src/state/_cli/__init__.py b/src/state/_cli/__init__.py index 2507d565..af9b4107 100644 --- a/src/state/_cli/__init__.py +++ b/src/state/_cli/__init__.py @@ -1,6 +1,7 @@ from ._emb import add_arguments_emb, run_emb_fit, run_emb_transform, run_emb_query, run_emb_preprocess, run_emb_eval from ._tx import ( add_arguments_tx, + run_tx_combo, run_tx_infer, run_tx_predict, run_tx_preprocess_infer, @@ -16,6 +17,7 @@ "run_tx_infer", "run_tx_preprocess_train", "run_tx_preprocess_infer", + "run_tx_combo", "run_emb_fit", "run_emb_query", "run_emb_transform", diff --git a/src/state/_cli/_emb/_transform.py b/src/state/_cli/_emb/_transform.py index 26b38cfa..becd7480 100644 --- a/src/state/_cli/_emb/_transform.py +++ b/src/state/_cli/_emb/_transform.py @@ -3,8 +3,16 @@ def add_arguments_transform(parser: ap.ArgumentParser): """Add arguments for state embedding CLI.""" - parser.add_argument("--model-folder", required=True, help="Path to the model checkpoint folder") - parser.add_argument("--checkpoint", required=False, help="Path to the specific model checkpoint") + parser.add_argument( + "--model-folder", + required=False, + help="Path to the model checkpoint folder (required if --checkpoint is not provided)", + ) + parser.add_argument( + "--checkpoint", + required=False, + help="Path to the specific model checkpoint (required if --model-folder is not provided)", + ) parser.add_argument( "--config", required=False, @@ -46,6 +54,7 @@ def run_emb_transform(args: ap.ArgumentParser): import glob import logging import os + import numpy as np import torch from omegaconf import OmegaConf @@ -60,13 +69,19 @@ def run_emb_transform(args: ap.ArgumentParser): logger.error("Either --output or --lancedb must be provided") raise ValueError("Either --output or --lancedb must be provided") - # look in the model folder with glob for *.ckpt, get the first one, and print it - model_files = glob.glob(os.path.join(args.model_folder, "*.ckpt")) - if not model_files: - logger.error(f"No model checkpoint found in {args.model_folder}") - raise FileNotFoundError(f"No model checkpoint found in {args.model_folder}") - if not args.checkpoint: - args.checkpoint = model_files[-1] + # Resolve checkpoint path, allowing either --checkpoint, --model-folder, or both + checkpoint_path = args.checkpoint + if args.model_folder: + model_files = glob.glob(os.path.join(args.model_folder, "*.ckpt")) + if not model_files and not checkpoint_path: + logger.error(f"No model checkpoint found in {args.model_folder}") + raise FileNotFoundError(f"No model checkpoint found in {args.model_folder}") + if not checkpoint_path and model_files: + checkpoint_path = model_files[-1] + if not checkpoint_path: + logger.error("Either --checkpoint or --model-folder must be provided") + raise ValueError("Either --checkpoint or --model-folder must be provided") + args.checkpoint = checkpoint_path logger.info(f"Using model checkpoint: {args.checkpoint}") # Create inference object @@ -79,7 +94,7 @@ def run_emb_transform(args: ap.ArgumentParser): if args.protein_embeddings: logger.info(f"Using protein embeddings override: {args.protein_embeddings}") protein_embeds = torch.load(args.protein_embeddings, weights_only=False, map_location="cpu") - else: + elif args.model_folder: # Try auto-detect in model folder try: exact_path = os.path.join(args.model_folder, "protein_embeddings.pt") @@ -110,6 +125,12 @@ def run_emb_transform(args: ap.ArgumentParser): logger.info(f"Loading model from checkpoint: {args.checkpoint}") inferer.load_model(args.checkpoint) + save_as_npy = False + output_target = args.output + if args.output: + _, ext = os.path.splitext(args.output) + save_as_npy = ext.lower() == ".npy" + # Create output directory if it doesn't exist if args.output: output_dir = os.path.dirname(args.output) @@ -120,13 +141,16 @@ def run_emb_transform(args: ap.ArgumentParser): # Generate embeddings logger.info(f"Computing embeddings for {args.input}") if args.output: - logger.info(f"Output will be saved to {args.output}") + if save_as_npy: + logger.info(f"Output embeddings will be saved to {args.output} as a NumPy array") + else: + logger.info(f"Output will be saved to {args.output}") if args.lancedb: logger.info(f"Embeddings will be saved to LanceDB at {args.lancedb}") - inferer.encode_adata( + embeddings = inferer.encode_adata( input_adata_path=args.input, - output_adata_path=args.output, + output_adata_path=None if save_as_npy else output_target, emb_key=args.embed_key, batch_size=args.batch_size if getattr(args, "batch_size", None) is not None else None, lancedb_path=args.lancedb, @@ -134,4 +158,11 @@ def run_emb_transform(args: ap.ArgumentParser): lancedb_batch_size=args.lancedb_batch_size, ) + if save_as_npy: + if embeddings is None: + logger.error("Failed to generate embeddings for NumPy output") + raise RuntimeError("Embedding generation returned no data") + np.save(args.output, embeddings) + logger.info(f"Saved embeddings matrix with shape {embeddings.shape} to {args.output}") + logger.info("Embedding computation completed successfully!") diff --git a/src/state/_cli/_tx/__init__.py b/src/state/_cli/_tx/__init__.py index 975fba42..cdadc904 100644 --- a/src/state/_cli/_tx/__init__.py +++ b/src/state/_cli/_tx/__init__.py @@ -5,6 +5,7 @@ from ._preprocess_infer import add_arguments_preprocess_infer, run_tx_preprocess_infer from ._preprocess_train import add_arguments_preprocess_train, run_tx_preprocess_train from ._train import add_arguments_train, run_tx_train +from ._combo import add_arguments_combo, run_tx_combo __all__ = [ "run_tx_train", @@ -12,6 +13,7 @@ "run_tx_infer", "run_tx_preprocess_train", "run_tx_preprocess_infer", + "run_tx_combo", "add_arguments_tx", ] @@ -24,3 +26,4 @@ def add_arguments_tx(parser: ap.ArgumentParser): add_arguments_infer(subparsers.add_parser("infer")) add_arguments_preprocess_train(subparsers.add_parser("preprocess_train")) add_arguments_preprocess_infer(subparsers.add_parser("preprocess_infer")) + add_arguments_combo(subparsers.add_parser("combo")) diff --git a/src/state/_cli/_tx/_combo.py b/src/state/_cli/_tx/_combo.py new file mode 100644 index 00000000..7b5dc07a --- /dev/null +++ b/src/state/_cli/_tx/_combo.py @@ -0,0 +1,635 @@ +import argparse as ap + + +def add_arguments_combo(parser: ap.ArgumentParser) -> None: + """CLI for two-stage perturbation combination sweeps.""" + + parser.add_argument("--model-dir", type=str, required=True, help="Path to the trained model directory.") + parser.add_argument( + "--checkpoint", + type=str, + default=None, + help=( + "Optional checkpoint path. If omitted, defaults to /checkpoints/last.ckpt " + "(falling back to final.ckpt if needed)." + ), + ) + parser.add_argument("--adata", type=str, required=True, help="Path to input AnnData file (.h5ad).") + parser.add_argument( + "--embed-key", + type=str, + default=None, + help="Optional key in adata.obsm for input features (defaults to adata.X).", + ) + parser.add_argument( + "--pert-col", + type=str, + required=True, + help="Column in adata.obs containing perturbation labels.", + ) + parser.add_argument( + "--control-pert", + type=str, + required=True, + help="Label of the control perturbation (used to construct the base control set).", + ) + parser.add_argument( + "--cell-type", + type=str, + required=True, + help="Target cell type value to filter before running the combo sweep.", + ) + parser.add_argument( + "--celltype-col", + type=str, + default=None, + help=( + "Optional column name in adata.obs for cell types. If omitted, attempts to detect using the " + "training config or common fallbacks." + ), + ) + parser.add_argument( + "--cell-set-len", + type=int, + default=None, + help="Override the model cell_set_len when constructing the fixed control set.", + ) + parser.add_argument( + "--batch-col", + type=str, + default=None, + help=( + "Optional batch column in adata.obs. If omitted, attempts to detect from training config " + "or common fallbacks when the model uses a batch encoder." + ), + ) + parser.add_argument( + "--inner-batch-size", + type=int, + default=1, + help="Number of target perturbations to evaluate simultaneously in the second pass.", + ) + parser.add_argument("--seed", type=int, default=0, help="Random seed for control sampling.") + parser.add_argument( + "--output-folder", + type=str, + default=None, + help=( + "Directory where per-perturbation AnnData outputs (.h5ad) are written." + " Defaults to _combo/ alongside the input file." + ), + ) + parser.add_argument("--quiet", action="store_true", help="Reduce logging verbosity.") + + +def run_tx_combo(args: ap.Namespace) -> None: + import logging + import os + import pickle + import re + + import anndata as ad + import numpy as np + import pandas as pd + import scanpy as sc + import torch + import yaml + + from tqdm import tqdm + + from ...tx.models.state_transition import StateTransitionPerturbationModel + + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + if args.quiet: + logger.setLevel(logging.WARNING) + + def _load_config(cfg_path: str) -> dict: + if not os.path.exists(cfg_path): + raise FileNotFoundError(f"Could not find config file: {cfg_path}") + with open(cfg_path, "r", encoding="utf-8") as handle: + return yaml.safe_load(handle) + + def _pick_first_present(columns: pd.Index, candidates: list[str | None]) -> str | None: + for key in candidates: + if key and key in columns: + return key + return None + + def _to_dense(matrix) -> np.ndarray: + try: + import scipy.sparse as sp # type: ignore + + if sp.issparse(matrix): + return matrix.toarray() + except Exception: + pass + return np.asarray(matrix) + + def _normalize_pert_vector(raw_vec, expected_dim: int) -> torch.Tensor: + if raw_vec is None: + return torch.zeros(expected_dim, dtype=torch.float32) + if torch.is_tensor(raw_vec): + return raw_vec.detach().float() + vec_np = np.asarray(raw_vec) + return torch.tensor(vec_np, dtype=torch.float32) + + def _flatten_tensor(tensor: torch.Tensor | None) -> torch.Tensor | None: + if tensor is None: + return None + if tensor.dim() == 3 and tensor.shape[0] == 1: + return tensor.squeeze(0) + return tensor + + def _tensor_to_numpy(tensor: torch.Tensor | None) -> np.ndarray | None: + if tensor is None: + return None + flat = _flatten_tensor(tensor) + if flat is None: + return None + return flat.detach().cpu().numpy().astype(np.float32) + + def _argmax_index_from_any(value, expected_dim: int | None = None) -> int | None: + if value is None: + return None + try: + if torch.is_tensor(value): + if value.ndim == 0: + return int(value.item()) + if value.ndim == 1: + return int(torch.argmax(value).item()) + return None + except Exception: + return None + try: + arr = np.asarray(value) + if arr.ndim == 0: + return int(arr.item()) + if arr.ndim == 1: + return int(arr.argmax()) + except Exception: + pass + if isinstance(value, (int, np.integer)): + return int(value) + if isinstance(value, (list, tuple)): + try: + arr = np.asarray(value) + if arr.ndim == 1: + return int(arr.argmax()) + except Exception: + return None + return None + + model_dir = os.path.abspath(args.model_dir) + config_path = os.path.join(model_dir, "config.yaml") + cfg = _load_config(config_path) + + var_dims_path = os.path.join(model_dir, "var_dims.pkl") + if not os.path.exists(var_dims_path): + raise FileNotFoundError(f"Missing var_dims.pkl at {var_dims_path}") + with open(var_dims_path, "rb") as handle: + var_dims = pickle.load(handle) + + input_dim = int(var_dims.get("input_dim", 0)) + if input_dim <= 0: + raise ValueError("input_dim missing from var_dims.pkl; cannot determine feature dimension") + + pert_dim = int(var_dims.get("pert_dim", 0)) + if pert_dim <= 0: + raise ValueError("pert_dim missing from var_dims.pkl; cannot build perturbation embeddings") + + batch_dim_entry = var_dims.get("batch_dim") + batch_dim = int(batch_dim_entry) if batch_dim_entry is not None else None + + pert_map_path = os.path.join(model_dir, "pert_onehot_map.pt") + if not os.path.exists(pert_map_path): + raise FileNotFoundError(f"Missing pert_onehot_map.pt at {pert_map_path}") + pert_onehot_map = torch.load(pert_map_path, weights_only=False) + + batch_onehot_map_path = os.path.join(model_dir, "batch_onehot_map.pkl") + batch_onehot_map = None + if os.path.exists(batch_onehot_map_path): + with open(batch_onehot_map_path, "rb") as handle: + batch_onehot_map = pickle.load(handle) + + checkpoint_path = args.checkpoint + if checkpoint_path is None: + default_last = os.path.join(model_dir, "checkpoints", "last.ckpt") + default_final = os.path.join(model_dir, "checkpoints", "final.ckpt") + checkpoint_path = default_last if os.path.exists(default_last) else default_final + elif not os.path.isabs(checkpoint_path): + candidate = os.path.join(model_dir, checkpoint_path) + checkpoint_path = candidate if os.path.exists(candidate) else checkpoint_path + + if not os.path.exists(checkpoint_path): + raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}") + + model = StateTransitionPerturbationModel.load_from_checkpoint(checkpoint_path) + model.eval() + device = next(model.parameters()).device + cell_set_len = args.cell_set_len or getattr(model, "cell_sentence_len", 256) + + uses_batch_encoder = getattr(model, "batch_encoder", None) is not None + if uses_batch_encoder and (batch_dim is None or batch_dim <= 0): + raise ValueError("Model uses a batch encoder but batch_dim missing from var_dims.pkl") + if uses_batch_encoder and batch_onehot_map is None: + raise FileNotFoundError( + "Model uses a batch encoder but batch_onehot_map.pkl was not found in the model directory" + ) + + logger.info("Loaded model from %s (cell_set_len=%d)", checkpoint_path, cell_set_len) + + adata = sc.read_h5ad(args.adata) + + data_kwargs = {} + try: + data_kwargs = cfg.get("data", {}).get("kwargs", {}) # type: ignore[assignment] + except AttributeError: + data_kwargs = {} + + celltype_col = args.celltype_col + if celltype_col is None: + cfg_celltype = None + try: + cfg_celltype = data_kwargs.get("cell_type_key") + except Exception: + cfg_celltype = None + candidates = [ + cfg_celltype, + "cell_type", + "celltype", + "cell_type_name", + "celltype_name", + "cellType", + "ctype", + ] + celltype_col = _pick_first_present(adata.obs.columns, candidates) + if celltype_col is None: + raise ValueError("Could not determine cell type column; provide --celltype-col explicitly.") + if celltype_col not in adata.obs: + raise KeyError(f"Column '{celltype_col}' not found in adata.obs") + + if args.pert_col not in adata.obs: + raise KeyError(f"Perturbation column '{args.pert_col}' not found in adata.obs") + + adata_ct = adata[adata.obs[celltype_col].astype(str) == str(args.cell_type)].copy() + if adata_ct.n_obs == 0: + raise ValueError(f"No cells found with cell type '{args.cell_type}' in column '{celltype_col}'") + + pert_series = adata_ct.obs[args.pert_col].astype(str) + control_mask = pert_series == str(args.control_pert) + control_indices = np.where(control_mask)[0] + if len(control_indices) == 0: + raise ValueError( + f"No control cells with perturbation '{args.control_pert}' found in column '{args.pert_col}' " + f"for cell type '{args.cell_type}'" + ) + + perts_all = pd.unique(pert_series) + perts = [p for p in perts_all if p != str(args.control_pert)] + if len(perts) == 0: + raise ValueError("No non-control perturbations found in filtered AnnData") + + batch_indices_all: np.ndarray | None = None + batch_col = args.batch_col if args.batch_col is not None else data_kwargs.get("batch_col") + if uses_batch_encoder: + candidate_batch_cols: list[str] = [] + if batch_col is not None: + candidate_batch_cols.append(batch_col) + if isinstance(data_kwargs.get("batch_col"), str): + candidate_batch_cols.append(data_kwargs.get("batch_col")) + candidate_batch_cols.extend( + [ + "gem_group", + "gemgroup", + "batch", + "donor", + "plate", + "experiment", + "lane", + "batch_id", + ] + ) + resolved_batch_col = next((col for col in candidate_batch_cols if col in adata_ct.obs), None) + if resolved_batch_col is None: + raise ValueError( + "Model uses a batch encoder but no batch column was found. Provide --batch-col explicitly." + ) + batch_col = resolved_batch_col + raw_batch_labels = adata_ct.obs[batch_col].astype(str).values + + assert batch_onehot_map is not None + label_to_idx: dict[str, int] = {} + if isinstance(batch_onehot_map, dict): + for key, value in batch_onehot_map.items(): + idx = _argmax_index_from_any(value, batch_dim) + if idx is not None: + label_to_idx[str(key)] = idx + + if not label_to_idx and batch_dim is not None: + unique_labels = sorted(set(raw_batch_labels)) + label_to_idx = {lab: min(i, batch_dim - 1) for i, lab in enumerate(unique_labels)} + + if not label_to_idx: + raise ValueError("Unable to construct batch label mapping; batch_onehot_map is empty or invalid") + + fallback_idx = sorted(label_to_idx.values())[0] + batch_indices_all = np.zeros(len(raw_batch_labels), dtype=np.int64) + misses = 0 + for i, lab in enumerate(raw_batch_labels): + idx = label_to_idx.get(lab) + if idx is None: + batch_indices_all[i] = fallback_idx + misses += 1 + else: + batch_indices_all[i] = idx + + if misses: + logger.warning( + "Batch column '%s': %d/%d labels missing from saved mapping; using fallback index %d", + batch_col, + misses, + len(raw_batch_labels), + fallback_idx, + ) + logger.info( + "Using batch column '%s' with %d unique mapped labels", + batch_col, + len(np.unique(batch_indices_all)), + ) + + cfg_embed_key = data_kwargs.get("embed_key") + explicit_embed_key = args.embed_key is not None + + candidate_order: list[str | None] = [] + seen_keys: set[str | None] = set() + + def _append_candidate(key: str | None) -> None: + if key in seen_keys: + return + seen_keys.add(key) + candidate_order.append(key) + + if explicit_embed_key: + _append_candidate(args.embed_key) + else: + if isinstance(cfg_embed_key, str): + _append_candidate(cfg_embed_key) + _append_candidate(None) + for fallback_key in ("X_hvg", "X_state", "X_state_basal", "X_state_pred", "X_pca", "X_latent"): + if fallback_key in adata_ct.obsm: + _append_candidate(fallback_key) + + selection_errors: list[str] = [] + features = None + used_embed_key: str | None = None + + for candidate in candidate_order: + matrix = None + label = "adata.X" if candidate is None else f"adata.obsm['{candidate}']" + + if candidate is None: + matrix = _to_dense(adata_ct.X) + else: + if candidate not in adata_ct.obsm: + if explicit_embed_key: + raise KeyError(f"Embedding key '{candidate}' not found in adata.obsm") + selection_errors.append(f"{label} missing") + continue + matrix = np.asarray(adata_ct.obsm[candidate]) + + if matrix.shape[0] != adata_ct.n_obs: + msg = f"{label} row count {matrix.shape[0]} != filtered AnnData cells {adata_ct.n_obs}" + if explicit_embed_key: + raise ValueError(msg) + selection_errors.append(msg) + continue + + if matrix.shape[1] != input_dim: + msg = f"{label} feature dimension {matrix.shape[1]} != model input_dim {input_dim}" + if explicit_embed_key: + raise ValueError( + msg + + ". Provide --embed-key pointing to a representation with matching dimension or preprocess the input." + ) + selection_errors.append(msg) + continue + + features = matrix + used_embed_key = candidate + break + + if features is None: + tried = ", ".join(["adata.X" if c is None else f"adata.obsm['{c}']" for c in candidate_order]) or "(none)" + detail = "; ".join(selection_errors) if selection_errors else "No suitable feature representation found." + raise ValueError( + f"Unable to find a feature matrix matching the model input dimension. Tried: {tried}. {detail}" + ) + + if used_embed_key is None: + logger.info("Using adata.X (%d cells x %d features) as input features", features.shape[0], features.shape[1]) + else: + logger.info( + "Using adata.obsm['%s'] (%d cells x %d features) as input features", + used_embed_key, + features.shape[0], + features.shape[1], + ) + + features = features.astype(np.float32, copy=False) + + rng = np.random.default_rng(args.seed) + replace = len(control_indices) < cell_set_len + sampled_idx = rng.choice(control_indices, size=cell_set_len, replace=replace) + control_features = features[sampled_idx] + + default_vec = _normalize_pert_vector(pert_onehot_map.get(str(args.control_pert)), pert_dim) + if default_vec.numel() != pert_dim: + default_vec = torch.zeros(pert_dim, dtype=torch.float32) + + control_batch_tensor = None + if batch_indices_all is not None: + control_batch_tensor = torch.tensor(batch_indices_all[sampled_idx], dtype=torch.long, device=device) + + def _pert_batch_tensor(name: str) -> torch.Tensor: + raw_vec = pert_onehot_map.get(name) + vec = _normalize_pert_vector(raw_vec, pert_dim) if raw_vec is not None else default_vec + if vec.dim() == 0: + vec = vec.unsqueeze(0) + vec = vec.reshape(-1) + if vec.numel() != pert_dim: + raise ValueError(f"Perturbation vector for '{name}' has incorrect dimension {vec.numel()} != {pert_dim}") + return vec.unsqueeze(0).repeat(cell_set_len, 1).to(device) + + pert_batch_vectors = {name: _pert_batch_tensor(name) for name in perts} + + control_tensor = torch.tensor(control_features, dtype=torch.float32, device=device) + + use_counts: bool | None = None + inner_batch_size = max(1, int(args.inner_batch_size)) + + def _default_output_dir(path: str) -> str: + base_dir = os.path.dirname(os.path.abspath(path)) + base_name = os.path.splitext(os.path.basename(path))[0] + return os.path.join(base_dir, f"{base_name}_combo") + + output_dir = args.output_folder or _default_output_dir(args.adata) + output_dir = os.path.abspath(output_dir) + os.makedirs(output_dir, exist_ok=True) + logger.info("Writing per-perturbation combo outputs to %s", output_dir) + + def _sanitize_filename(label: str) -> str: + sanitized = re.sub(r"[^0-9A-Za-z_.-]+", "_", label) + sanitized = sanitized.strip("._") + return sanitized or "perturbation" + + used_output_names: dict[str, int] = {} + written_files: list[str] = [] + skipped_perts: list[str] = [] + + try: + existing_output_names = { + os.path.splitext(fname)[0] + for fname in os.listdir(output_dir) + if fname.endswith(".h5ad") + } + except OSError: + existing_output_names = set() + + num_target_perts = len(perts) + + with torch.no_grad(): + progress_total = num_target_perts * num_target_perts + progress_bar = tqdm( + total=progress_total, + desc="Combo sweeps", + unit="combo", + disable=args.quiet, + ) + for pert1 in perts: + base_name = _sanitize_filename(pert1) + occurrence_idx = used_output_names.get(base_name, -1) + 1 + used_output_names[base_name] = occurrence_idx + output_name = base_name if occurrence_idx == 0 else f"{base_name}_{occurrence_idx}" + output_path = os.path.join(output_dir, f"{output_name}.h5ad") + + if output_name in existing_output_names or os.path.exists(output_path): + skipped_perts.append(pert1) + progress_bar.update(num_target_perts) + logger.info("Skipping combos for %s; existing output at %s", pert1, output_path) + continue + + per_pert_X_blocks: list[np.ndarray] = [] + per_pert_latent_blocks: list[np.ndarray] = [] + per_pert_obs_rows: list[dict[str, str | int]] = [] + + first_batch = { + "ctrl_cell_emb": control_tensor.clone(), + "pert_emb": pert_batch_vectors[pert1], + "pert_name": [pert1] * cell_set_len, + } + if control_batch_tensor is not None: + first_batch["batch"] = control_batch_tensor.clone() + first_out = model.predict_step(first_batch, batch_idx=0, padded=False) + first_latent_tensor = _flatten_tensor(first_out.get("preds")) + if first_latent_tensor is None: + raise RuntimeError("Model predict_step did not return 'preds' tensor") + first_latent_tensor = first_latent_tensor.detach().to(device) + + for chunk_start in range(0, len(perts), inner_batch_size): + chunk_perts = perts[chunk_start : chunk_start + inner_batch_size] + chunk_size = len(chunk_perts) + + ctrl_chunk = torch.cat([first_latent_tensor.clone() for _ in chunk_perts], dim=0) + pert_chunk = torch.cat([pert_batch_vectors[p] for p in chunk_perts], dim=0) + names_chunk = [p for p in chunk_perts for _ in range(cell_set_len)] + + second_batch = { + "ctrl_cell_emb": ctrl_chunk, + "pert_emb": pert_chunk, + "pert_name": names_chunk, + } + + if control_batch_tensor is not None: + batch_chunk = control_batch_tensor.repeat(chunk_size) + second_batch["batch"] = batch_chunk + + second_out = model.predict_step(second_batch, batch_idx=0, padded=True) + + latent_np = _tensor_to_numpy(second_out.get("preds")) + counts_np = _tensor_to_numpy(second_out.get("pert_cell_counts_preds")) + + if latent_np is None: + raise RuntimeError("Second-stage prediction missing 'preds' output") + + latent_np = latent_np.reshape(chunk_size, cell_set_len, -1) + counts_np = counts_np.reshape(chunk_size, cell_set_len, -1) if counts_np is not None else None + + if use_counts is None: + use_counts = counts_np is not None + elif use_counts and counts_np is None: + raise RuntimeError("Inconsistent decoder outputs across perturbations; expected counts predictions") + + for idx_chunk, pert2 in enumerate(chunk_perts): + latent_slice = latent_np[idx_chunk].astype(np.float32) + if use_counts: + assert counts_np is not None + per_pert_X_blocks.append(counts_np[idx_chunk].astype(np.float32)) + else: + per_pert_X_blocks.append(latent_slice) + per_pert_latent_blocks.append(latent_slice) + + for cell_idx in range(cell_set_len): + per_pert_obs_rows.append({"pert1": pert1, "pert2": pert2, "cell_index": cell_idx}) + + progress_bar.update(1) + + X_matrix = np.vstack(per_pert_X_blocks) if per_pert_X_blocks else np.empty((0, 0), dtype=np.float32) + latent_matrix = ( + np.vstack(per_pert_latent_blocks) if per_pert_latent_blocks else np.empty((0, 0), dtype=np.float32) + ) + obs_df = pd.DataFrame(per_pert_obs_rows) + + feature_dim = 0 + if use_counts and X_matrix.size > 0: + feature_dim = X_matrix.shape[1] + elif latent_matrix.size > 0: + feature_dim = latent_matrix.shape[1] + elif X_matrix.size > 0: + feature_dim = X_matrix.shape[1] + + gene_names = var_dims.get("gene_names") + if ( + use_counts + and feature_dim > 0 + and isinstance(gene_names, (list, tuple)) + and len(gene_names) == feature_dim + ): + var_index = pd.Index([str(name) for name in gene_names], name="gene") + else: + var_index = pd.Index([f"feature_{i}" for i in range(feature_dim)], name="feature") + var_df = pd.DataFrame(index=var_index) + + combo_adata = ad.AnnData(X=X_matrix, obs=obs_df, var=var_df) + combo_adata.obsm["latent_preds"] = latent_matrix + combo_adata.uns["cell_type"] = str(args.cell_type) + combo_adata.uns["perturbations"] = perts + combo_adata.uns["pert1"] = pert1 + combo_adata.uns["control_pert"] = str(args.control_pert) + combo_adata.uns["cell_set_len"] = cell_set_len + combo_adata.uns["input_embed_key"] = used_embed_key if used_embed_key is not None else "X" + if uses_batch_encoder and batch_col is not None: + combo_adata.uns["batch_col"] = batch_col + combo_adata.uns["inner_batch_size"] = inner_batch_size + combo_adata.uns["sampled_control_indices"] = adata_ct.obs_names[sampled_idx].tolist() + + combo_adata.write_h5ad(output_path) + written_files.append(output_path) + existing_output_names.add(output_name) + logger.info("Saved combos for %s with %d cells to %s", pert1, combo_adata.n_obs, output_path) + + progress_bar.close() + + logger.info("Finished writing %d combo files to %s", len(written_files), output_dir) + if skipped_perts: + logger.info("Skipped %d perturbations with existing combo outputs", len(skipped_perts)) diff --git a/src/state/_cli/_tx/_infer.py b/src/state/_cli/_tx/_infer.py index a4f05d04..9927623a 100644 --- a/src/state/_cli/_tx/_infer.py +++ b/src/state/_cli/_tx/_infer.py @@ -1,4 +1,5 @@ import argparse +import ast from typing import Dict, List, Optional import pandas as pd @@ -23,11 +24,17 @@ def add_arguments_infer(parser: argparse.ArgumentParser): default="drugname_drugconc", help="Column in adata.obs for perturbation labels", ) + parser.add_argument( + "--dosages", + type=str, + default=None, + help="Optional list of dosages (floats) to materialize for each perturbation, e.g. \"[0.1, 0.5, 1.0]\".", + ) parser.add_argument( "--output", type=str, default=None, - help="Path to output AnnData file (.h5ad). Defaults to _simulated.h5ad", + help="Path to output file (.h5ad or .npy). Defaults to _simulated.h5ad", ) parser.add_argument( "--model-dir", @@ -82,6 +89,29 @@ def add_arguments_infer(parser: argparse.ArgumentParser): default=None, help="Path to TSV file with columns 'perturbation' and 'num_cells' to pad the adata with additional perturbation cells copied from random controls.", ) + parser.add_argument( + "--all-perts", + action="store_true", + help="If set, add virtual copies of control cells for every perturbation in the saved one-hot map so all perturbations are simulated.", + ) + parser.add_argument( + "--virtual-cells-per-pert", + type=int, + default=None, + help="When using --all-perts, limit the number of control cells cloned for each virtual perturbation to this many (default: use all available controls).", + ) + parser.add_argument( + "--min-cells", + type=int, + default=None, + help="Ensure each perturbation has at least this many cells by padding with virtual controls (if needed).", + ) + parser.add_argument( + "--max-cells", + type=int, + default=None, + help="Upper bound on cells per perturbation after padding; subsamples excess cells if necessary.", + ) def run_tx_infer(args: argparse.Namespace): @@ -151,12 +181,115 @@ def argmax_index_from_any(v, expected_dim: Optional[int]) -> Optional[int]: return int(v) return None + def parse_dosage_argument(arg_value: Optional[str]) -> List[float]: + if arg_value is None: + return [] + if isinstance(arg_value, (list, tuple)): + candidate_values = arg_value + else: + text = str(arg_value).strip() + if not text: + return [] + parsed = None + try: + parsed = ast.literal_eval(text) + except (ValueError, SyntaxError): + parsed = None + if isinstance(parsed, (list, tuple)): + candidate_values = parsed + else: + text = text.strip("[]") + parts = [p for p in text.replace(",", " ").split() if p] + candidate_values = parts + deduped: List[float] = [] + seen: set[float] = set() + for value in candidate_values: + try: + val = float(value) + except (TypeError, ValueError): + raise ValueError(f"Invalid dosage value '{value}' in --dosages argument.") + key = round(val, 12) + if key not in seen: + seen.add(key) + deduped.append(val) + return deduped + + def extend_perturbation_map_for_dosages( + pert_map: Dict[str, torch.Tensor], + requested_dosages: List[float], + *, + control_label: Optional[str], + quiet: bool, + ) -> List[str]: + if not requested_dosages: + return [] + + def almost_equal(a: float, b: float, tol: float = 1e-9) -> bool: + return abs(a - b) <= tol + + canonical_vectors: Dict[tuple[str, Optional[str]], Dict[str, object]] = {} + for key, vec in pert_map.items(): + key_str = str(key) + if control_label is not None and key_str == control_label: + continue + try: + parsed = ast.literal_eval(key_str) + except (ValueError, SyntaxError): + continue + if not isinstance(parsed, (list, tuple)) or len(parsed) != 1: + continue + entry = parsed[0] + if not isinstance(entry, (list, tuple)) or len(entry) < 2: + continue + pert_name = str(entry[0]) + unit = str(entry[2]) if len(entry) > 2 else None + try: + dose_val = float(entry[1]) + except (TypeError, ValueError): + continue + + base_key = (pert_name, unit) + base_info = canonical_vectors.setdefault( + base_key, + { + "template_key": key_str, + "unit": unit, + "existing": [], + "vector": vec, + }, + ) + existing: List[float] = base_info["existing"] # type: ignore[assignment] + if not any(almost_equal(dose_val, existing_dose) for existing_dose in existing): + existing.append(dose_val) + # Prefer first encountered vector as canonical; assume all are equivalent + + added_keys: List[str] = [] + for (pert_name, unit), info in canonical_vectors.items(): + vector: torch.Tensor = info["vector"] # type: ignore[assignment] + existing: List[float] = info["existing"] # type: ignore[assignment] + for dosage in requested_dosages: + if any(almost_equal(dosage, existing_dose) for existing_dose in existing): + continue + if unit is None: + new_entry = [(pert_name, float(dosage))] + else: + new_entry = [(pert_name, float(dosage), unit)] + key_str = str(new_entry) + if key_str in pert_map: + continue + pert_map[key_str] = vector.clone() + added_keys.append(key_str) + if added_keys and not quiet: + print(f"Extended perturbation map with {len(added_keys)} dosage variants.") + return added_keys + def prepare_batch( ctrl_basal_np: np.ndarray, pert_onehots: torch.Tensor, batch_indices: Optional[torch.Tensor], pert_names: List[str], device: torch.device, + pert_dosage: Optional[float] = None, ) -> Dict[str, torch.Tensor | List[str]]: """ Construct a model batch with variable-length sentence (B=1, S=T, ...). @@ -170,6 +303,14 @@ def prepare_batch( } if batch_indices is not None: batch["batch"] = batch_indices.to(device) # [T] + if pert_dosage is not None: + seq_len = X_batch.shape[0] + batch["pert_dosage"] = torch.full( + (seq_len,), + float(pert_dosage), + dtype=torch.float32, + device=device, + ) return batch def pad_adata_with_tsv( @@ -321,6 +462,13 @@ def pad_adata_with_tsv( control_pert = "non-targeting" if not args.quiet: print(f"Control perturbation: {control_pert}") + control_pert_str = str(control_pert) + + requested_dosages = parse_dosage_argument(args.dosages) + if requested_dosages and not args.quiet: + print(f"Requested dosages: {requested_dosages}") + if requested_dosages and not args.all_perts and not args.quiet: + print("Note: --dosages provided without --all-perts; only dosages present in AnnData will be simulated.") # choose cell type column if args.celltype_col is None: @@ -361,6 +509,16 @@ def pad_adata_with_tsv( if not os.path.exists(pert_onehot_map_path): raise FileNotFoundError(f"Missing pert_onehot_map.pt at {pert_onehot_map_path}") pert_onehot_map: Dict[str, torch.Tensor] = torch.load(pert_onehot_map_path, weights_only=False) + added_dosage_keys = extend_perturbation_map_for_dosages( + pert_map=pert_onehot_map, + requested_dosages=requested_dosages, + control_label=control_pert_str, + quiet=args.quiet, + ) + if requested_dosages and not added_dosage_keys and not args.quiet: + print("No new dosage variants were added; requested values may already exist in the perturbation map.") + pert_name_lookup: Dict[str, object] = {str(k): k for k in pert_onehot_map.keys()} + pert_names_in_map: List[str] = list(pert_name_lookup.keys()) batch_onehot_map_path = os.path.join(args.model_dir, "batch_onehot_map.pkl") batch_onehot_map = None @@ -423,6 +581,145 @@ def pad_adata_with_tsv( if not args.quiet: print(f"Filtered to {adata.n_obs} cells (from {n0}) for cell types: {keep_cts}") + needs_virtual_padding = args.all_perts or (args.min_cells is not None) or (args.max_cells is not None) + if needs_virtual_padding: + if args.pert_col not in adata.obs: + raise KeyError(f"Perturbation column '{args.pert_col}' not found in adata.obs") + + adata.obs = adata.obs.copy() + adata.obs[args.pert_col] = adata.obs[args.pert_col].astype(str) + + # optionally expand controls to cover every perturbation in the map + if args.all_perts: + observed_perts = set(adata.obs[args.pert_col].values) + missing_perts = [p for p in pert_names_in_map if p not in observed_perts] + + if missing_perts: + ctrl_mask_all_perts = adata.obs[args.pert_col] == control_pert_str + if not bool(np.any(ctrl_mask_all_perts)): + raise ValueError( + "--all-perts requested, but no control cells are available to template new perturbations." + ) + + ctrl_template = adata[ctrl_mask_all_perts].copy() + ctrl_template.obs = ctrl_template.obs.copy() + ctrl_template.obs[args.pert_col] = ctrl_template.obs[args.pert_col].astype(str) + + if args.virtual_cells_per_pert is not None: + if args.virtual_cells_per_pert <= 0: + raise ValueError("--virtual-cells-per-pert must be a positive integer if provided.") + if ctrl_template.n_obs > args.virtual_cells_per_pert: + virtual_rng = np.random.RandomState(args.seed) + sampled_idx = virtual_rng.choice( + ctrl_template.n_obs, size=args.virtual_cells_per_pert, replace=False + ) + ctrl_template = ctrl_template[sampled_idx].copy() + ctrl_template.obs = ctrl_template.obs.copy() + ctrl_template.obs[args.pert_col] = ctrl_template.obs[args.pert_col].astype(str) + if not args.quiet: + print( + "--all-perts: limiting virtual control template to " + f"{ctrl_template.n_obs} cells per perturbation (requested {args.virtual_cells_per_pert})." + ) + + virtual_blocks: List["sc.AnnData"] = [] + for pert_name in missing_perts: + clone = ctrl_template.copy() + clone.obs = clone.obs.copy() + clone.obs[args.pert_col] = pert_name + clone.obs_names = [f"{obs_name}__virt_{pert_name}" for obs_name in clone.obs_names] + virtual_blocks.append(clone) + + adata = sc.concat([adata, *virtual_blocks], axis=0, join="inner") + + if not args.quiet: + preview = ", ".join(missing_perts[:5]) + if len(missing_perts) > 5: + preview += ", ..." + print( + f"Added virtual control copies for {len(missing_perts)} perturbations" + f" ({preview if preview else 'n/a'}). Total cells: {adata.n_obs}." + ) + elif not args.quiet: + print("--all-perts requested, but all perturbations already present in AnnData.") + + # ensure each perturbation meets the minimum count by cloning controls + if args.min_cells is not None: + if args.min_cells <= 0: + raise ValueError("--min-cells must be a positive integer if provided.") + + ctrl_mask_min_cells = adata.obs[args.pert_col] == control_pert_str + if not bool(np.any(ctrl_mask_min_cells)): + raise ValueError("--min-cells requested, but no control cells are available for cloning.") + + pad_rng = np.random.RandomState(args.seed) + ctrl_pool = adata[ctrl_mask_min_cells].copy() + ctrl_pool.obs = ctrl_pool.obs.copy() + virtual_blocks: List["sc.AnnData"] = [] + + pert_counts = adata.obs[args.pert_col].value_counts() + for pert_name, count in pert_counts.items(): + deficit = int(args.min_cells) - int(count) + if deficit <= 0: + continue + + sampled_idx = pad_rng.choice(ctrl_pool.n_obs, size=deficit, replace=True) + clone = ctrl_pool[sampled_idx].copy() + clone.obs = clone.obs.copy() + clone.obs[args.pert_col] = pert_name + base_names = list(clone.obs_names) + clone.obs_names = [f"{obs_name}__virt_{pert_name}__pad{idx + 1}" for idx, obs_name in enumerate(base_names)] + virtual_blocks.append(clone) + + if virtual_blocks: + adata = sc.concat([adata, *virtual_blocks], axis=0, join="inner") + if not args.quiet: + preview = ", ".join( + [f"{pert}:{args.min_cells}" for pert, cnt in pert_counts.items() if int(cnt) < int(args.min_cells)][ + :5 + ] + ) + if len(virtual_blocks) > 5: + preview += ", ..." + total_added = sum(vb.n_obs for vb in virtual_blocks) + print( + f"Added {total_added} padding cells to meet --min-cells " + f"(examples: {preview if preview else 'n/a'}). Total cells: {adata.n_obs}." + ) + elif not args.quiet: + print("--min-cells set, but all perturbations already meet the threshold.") + + # cap the number of cells per perturbation by subsampling + if args.max_cells is not None: + if args.max_cells <= 0: + raise ValueError("--max-cells must be a positive integer if provided.") + if args.min_cells is not None and args.max_cells < args.min_cells: + raise ValueError("--max-cells cannot be smaller than --min-cells.") + + trim_rng = np.random.RandomState(args.seed + 1) + keep_mask = np.ones(adata.n_obs, dtype=bool) + pert_labels = adata.obs[args.pert_col].values + + unique_perts = np.unique(pert_labels) + for pert_name in unique_perts: + idxs = np.where(pert_labels == pert_name)[0] + if len(idxs) <= args.max_cells: + continue + + chosen = trim_rng.choice(idxs, size=args.max_cells, replace=False) + chosen = np.sort(chosen) + drop = np.setdiff1d(idxs, chosen, assume_unique=True) + keep_mask[drop] = False + + if not np.all(keep_mask): + original_n = adata.n_obs + adata = adata[keep_mask].copy() + if not args.quiet: + total_dropped = original_n - adata.n_obs + print( + f"Subsampled perturbations exceeding --max-cells; dropped {total_dropped} cells. Total cells: {adata.n_obs}." + ) + # select features: embeddings or genes if args.embed_key is None: X_in = to_dense(adata.X) # [N, E_in] @@ -491,7 +788,7 @@ def pad_adata_with_tsv( rng = np.random.RandomState(args.seed) # Identify control vs non-control - ctl_mask = pert_names_all == str(control_pert) + ctl_mask = pert_names_all == control_pert_str n_controls = int(ctl_mask.sum()) n_total = adata.n_obs n_nonctl = n_total - n_controls @@ -500,12 +797,26 @@ def pad_adata_with_tsv( # Where we will write predictions (initialize with originals; we overwrite all rows, including controls) if writes_to[0] == ".X": - sim_X = X_in.copy() + sim_X = X_in.astype(np.float32, copy=True) out_target = "X" else: - sim_obsm = X_in.copy() + sim_obsm = X_in.astype(np.float32, copy=True) out_target = f"obsm['{writes_to[1]}']" + counts_expected = output_space in {"gene", "all"} + counts_out_target: Optional[str] = None + counts_obsm_key: Optional[str] = None + sim_counts: Optional[np.ndarray] = None + counts_written = False + + if output_space == "gene": + counts_out_target = "obsm['X_hvg']" + counts_obsm_key = "X_hvg" + elif output_space == "all": + counts_out_target = "X" + if writes_to[0] == ".X": + sim_counts = sim_X + # Group labels for set-to-set behavior if args.celltype_col and args.celltype_col in adata.obs: group_labels = adata.obs[args.celltype_col].astype(str).values @@ -525,8 +836,9 @@ def group_control_indices(group_name: str) -> np.ndarray: return grp_ctl if len(grp_ctl) > 0 else all_control_indices # default pert vector when unmapped label shows up - if control_pert in pert_onehot_map: - default_pert_vec = pert_onehot_map[control_pert].float().clone() + control_map_key = pert_name_lookup.get(control_pert_str, control_pert) + if control_map_key in pert_onehot_map: + default_pert_vec = pert_onehot_map[control_map_key].float().clone() else: default_pert_vec = torch.zeros(pert_dim, dtype=torch.float32) if pert_dim and pert_dim > 0: @@ -568,11 +880,17 @@ def group_control_indices(group_name: str) -> np.ndarray: continue # one-hot vector for this perturbation (repeat across window) - vec = pert_onehot_map.get(p, None) + map_key = pert_name_lookup.get(p, p) + vec = pert_onehot_map.get(map_key, None) if vec is None: vec = default_pert_vec if not args.quiet: print(f" (group {g}) pert '{p}' not in mapping; using control fallback one-hot.") + dosage_value = ( + StateTransitionPerturbationModel._parse_dosage_from_name(p) + if getattr(model, "use_dosage_encoder", False) + else None + ) start = 0 while start < len(idxs): @@ -600,6 +918,7 @@ def group_control_indices(group_name: str) -> np.ndarray: batch_indices=bi, pert_names=[p] * win_size, device=model_device, + pert_dosage=dosage_value, ) batch_out = model.predict_step(batch, batch_idx=0, padded=False) @@ -615,6 +934,43 @@ def group_control_indices(group_name: str) -> np.ndarray: else: preds = batch_out["preds"].detach().cpu().numpy().astype(np.float32) # [win, D] + counts_preds = None + if counts_expected and ("pert_cell_counts_preds" in batch_out): + counts_tensor = batch_out.get("pert_cell_counts_preds") + if counts_tensor is not None: + counts_preds = counts_tensor.detach().cpu().numpy().astype(np.float32) + + if counts_preds is not None: + if sim_counts is None: + target_dim = counts_preds.shape[1] + if output_space == "gene": + if counts_obsm_key and counts_obsm_key in adata.obsm: + existing = np.asarray(adata.obsm[counts_obsm_key]) + if existing.shape[1] == target_dim: + sim_counts = existing.astype(np.float32, copy=True) + else: + if not args.quiet: + print( + f"Dimension mismatch for existing obsm['{counts_obsm_key}'] " + f"(got {existing.shape[1]} vs predictions {target_dim}). " + "Reinitializing storage with zeros." + ) + sim_counts = np.zeros((n_total, target_dim), dtype=np.float32) + else: + sim_counts = np.zeros((n_total, target_dim), dtype=np.float32) + else: # output_space == "all" + if writes_to[0] == ".X": + sim_counts = sim_X + else: + sim_counts = np.zeros((n_total, target_dim), dtype=np.float32) + if sim_counts.shape[1] != counts_preds.shape[1]: + raise ValueError( + "Predicted counts dimension mismatch: " + f"expected {sim_counts.shape[1]} but got {counts_preds.shape[1]}" + ) + sim_counts[idx_window, :] = counts_preds + counts_written = True + # 6) Write predictions for these rows (controls included) if writes_to[0] == ".X": if preds.shape[1] == sim_X.shape[1]: @@ -650,15 +1006,48 @@ def group_control_indices(group_name: str) -> np.ndarray: # ----------------------- # 5) Persist the updated AnnData # ----------------------- + output_path = args.output or args.adata.replace(".h5ad", "_simulated.h5ad") + output_is_npy = output_path.lower().endswith(".npy") + + if counts_expected and not counts_written and not args.quiet: + print( + "Warning: Model configured to produce gene counts, but no predicted counts were returned; " + "counts will not be saved." + ) + + pred_matrix = None if writes_to[0] == ".X": if out_target == "X": adata.X = sim_X + pred_matrix = sim_X + elif out_target.startswith("obsm['") and out_target.endswith("']"): + pred_key = out_target[6:-2] + pred_matrix = adata.obsm.get(pred_key) + else: + pred_matrix = sim_X else: if out_target == f"obsm['{writes_to[1]}']": adata.obsm[writes_to[1]] = sim_obsm - - output_path = args.output or args.adata.replace(".h5ad", "_simulated.h5ad") - adata.write_h5ad(output_path) + pred_matrix = sim_obsm + elif out_target.startswith("obsm['") and out_target.endswith("']"): + pred_key = out_target[6:-2] + pred_matrix = adata.obsm.get(pred_key) + else: + pred_matrix = sim_obsm + + if counts_written and sim_counts is not None: + if output_space == "gene": + key = counts_obsm_key or "X_hvg" + adata.obsm[key] = sim_counts + elif output_space == "all": + adata.X = sim_counts + + if output_is_npy: + if pred_matrix is None: + raise ValueError("Predictions matrix is unavailable; cannot write .npy output") + np.save(output_path, np.asarray(pred_matrix)) + else: + adata.write_h5ad(output_path) # ----------------------- # 6) Summary @@ -667,5 +1056,12 @@ def group_control_indices(group_name: str) -> np.ndarray: print(f"Input cells: {n_total}") print(f"Controls simulated: {n_controls}") print(f"Treated simulated: {n_nonctl}") - print(f"Wrote predictions to adata.{out_target}") - print(f"Saved: {output_path}") + if output_is_npy: + shape_str = " x ".join(str(dim) for dim in pred_matrix.shape) if pred_matrix is not None else "unknown" + print(f"Wrote predictions array (shape: {shape_str})") + print(f"Saved NumPy file: {output_path}") + else: + print(f"Wrote predictions to adata.{out_target}") + print(f"Saved: {output_path}") + if counts_written and counts_out_target: + print(f"Saved count predictions to adata.{counts_out_target}") diff --git a/src/state/_cli/_tx/_predict.py b/src/state/_cli/_tx/_predict.py index a41192cc..2182b702 100644 --- a/src/state/_cli/_tx/_predict.py +++ b/src/state/_cli/_tx/_predict.py @@ -12,6 +12,12 @@ def add_arguments_predict(parser: ap.ArgumentParser): required=True, help="Path to the output_dir containing the config.yaml file that was saved during training.", ) + parser.add_argument( + "--toml", + type=str, + default=None, + help="Optional path to a TOML data config to use instead of the training config.", + ) parser.add_argument( "--checkpoint", type=str, @@ -40,6 +46,12 @@ def add_arguments_predict(parser: ap.ArgumentParser): help="If set, only run prediction without evaluation metrics.", ) + parser.add_argument( + "--split-batch", + action="store_true", + help="If set, compute metrics separately for each (cell type, batch) pair.", + ) + parser.add_argument( "--shared-only", action="store_true", @@ -124,12 +136,34 @@ def load_config(cfg_path: str) -> dict: cfg = load_config(config_path) logger.info(f"Loaded config from {config_path}") + if args.toml: + data_section = cfg.get("data") + if data_section is None or "kwargs" not in data_section: + raise KeyError( + "The loaded config does not contain data.kwargs, unable to override toml_config_path." + ) + cfg["data"]["kwargs"]["toml_config_path"] = args.toml + logger.info("Overriding data.kwargs.toml_config_path to %s", args.toml) + # 2. Find run output directory & load data module run_output_dir = os.path.join(cfg["output_dir"], cfg["name"]) data_module_path = os.path.join(run_output_dir, "data_module.torch") if not os.path.exists(data_module_path): raise FileNotFoundError(f"Could not find data module at {data_module_path}?") data_module = PerturbationDataModule.load_state(data_module_path) + if args.toml: + if not os.path.exists(args.toml): + raise FileNotFoundError(f"Could not find TOML config file at {args.toml}") + from cell_load.config import ExperimentConfig + + logger.info("Reloading data module configuration from %s", args.toml) + data_module.toml_config_path = args.toml + data_module.config = ExperimentConfig.from_toml(args.toml) + data_module.config.validate() + data_module.train_datasets = [] + data_module.val_datasets = [] + data_module.test_datasets = [] + data_module._setup_global_maps() data_module.setup(stage="test") logger.info("Loaded data module from %s", data_module_path) @@ -175,6 +209,10 @@ def load_config(cfg_path: str) -> dict: from ...tx.models.decoder_only import DecoderOnlyPerturbationModel ModelClass = DecoderOnlyPerturbationModel + elif model_class_name.lower() == "pseudobulk": + from ...tx.models.pseudobulk import PseudobulkPerturbationModel + + ModelClass = PseudobulkPerturbationModel else: raise ValueError(f"Unknown model class: {model_class_name}") @@ -284,17 +322,70 @@ def load_config(cfg_path: str) -> dict: else: all_celltypes.append(batch_preds["celltype_name"]) - # Handle gem_group - if isinstance(batch_preds["batch"], list): - all_gem_groups.extend([str(x) for x in batch_preds["batch"]]) - elif isinstance(batch_preds["batch"], torch.Tensor): - all_gem_groups.extend([str(x) for x in batch_preds["batch"].cpu().numpy()]) - else: - all_gem_groups.append(str(batch_preds["batch"])) + batch_size = batch_preds["preds"].shape[0] + + # Handle gem_group - prefer human-readable batch names when available + def normalize_batch_labels(values): + if values is None: + return None + if isinstance(values, torch.Tensor): + values = values.detach().cpu().numpy() + if isinstance(values, np.ndarray): + if values.ndim == 2: + if values.shape[0] != batch_size: + return None + if values.shape[1] == 1: + flat = values.reshape(batch_size) + return [str(x) for x in flat.tolist()] + indices = values.argmax(axis=1) + return [str(int(x)) for x in indices.tolist()] + if values.ndim == 1: + if values.shape[0] != batch_size: + return None + return [str(x) for x in values.tolist()] + if values.ndim == 0: + return [str(values.item())] * batch_size + return None + if isinstance(values, (list, tuple)): + if len(values) != batch_size: + return None + normalized = [] + for item in values: + if isinstance(item, torch.Tensor): + item = item.detach().cpu().numpy() + if isinstance(item, np.ndarray): + if item.ndim == 0: + normalized.append(str(item.item())) + continue + if item.ndim == 1: + if item.size == 1: + normalized.append(str(item.item())) + elif np.count_nonzero(item) == 1: + normalized.append(str(int(item.argmax()))) + else: + normalized.append(str(item.tolist())) + continue + normalized.append(str(item)) + return normalized + return [str(values)] * batch_size + + batch_name_candidates = ( + batch.get("batch_name"), + batch_preds.get("batch_name"), + batch_preds.get("batch"), + ) + + batch_labels = None + for candidate in batch_name_candidates: + batch_labels = normalize_batch_labels(candidate) + if batch_labels is not None: + break + if batch_labels is None: + batch_labels = ["None"] * batch_size + all_gem_groups.extend(batch_labels) batch_pred_np = batch_preds["preds"].cpu().numpy().astype(np.float32) batch_real_np = batch_preds["pert_cell_emb"].cpu().numpy().astype(np.float32) - batch_size = batch_pred_np.shape[0] final_preds[current_idx : current_idx + batch_size, :] = batch_pred_np final_reals[current_idx : current_idx + batch_size, :] = batch_real_np current_idx += batch_size @@ -328,16 +419,16 @@ def load_config(cfg_path: str) -> dict: var = pd.DataFrame({"gene_names": gene_names}) if final_X_hvg is not None: - if len(gene_names) != final_pert_cell_counts_preds.shape[1]: - gene_names = np.load( - "/large_storage/ctc/userspace/aadduri/datasets/tahoe_19k_to_2k_names.npy", allow_pickle=True - ) - var = pd.DataFrame({"gene_names": gene_names}) + # if len(gene_names) != final_pert_cell_counts_preds.shape[1]: + # gene_names = np.load( + # "/large_storage/ctc/userspace/aadduri/datasets/tahoe_19k_to_2k_names.npy", allow_pickle=True + # ) + # var = pd.DataFrame({"gene_names": gene_names}) # Create adata for predictions - using the decoded gene expression values - adata_pred = anndata.AnnData(X=final_pert_cell_counts_preds, obs=obs, var=var) + adata_pred = anndata.AnnData(X=final_pert_cell_counts_preds, obs=obs) # Create adata for real - using the true gene expression values - adata_real = anndata.AnnData(X=final_X_hvg, obs=obs, var=var) + adata_real = anndata.AnnData(X=final_X_hvg, obs=obs) # add the embedding predictions adata_pred.obsm[data_module.embed_key] = final_preds @@ -384,7 +475,10 @@ def load_config(cfg_path: str) -> dict: ) # Save the AnnData objects - results_dir = os.path.join(args.output_dir, "eval_" + os.path.basename(args.checkpoint)) + if args.eval_train_data: + results_dir = os.path.join(args.output_dir, "eval_train_" + os.path.basename(args.checkpoint)) + else: + results_dir = os.path.join(args.output_dir, "eval_" + os.path.basename(args.checkpoint)) os.makedirs(results_dir, exist_ok=True) adata_pred_path = os.path.join(results_dir, "adata_pred.h5ad") adata_real_path = os.path.join(results_dir, "adata_real.h5ad") @@ -409,6 +503,7 @@ def load_config(cfg_path: str) -> dict: ) pdex_kwargs = dict(exp_post_agg=True, is_log1p=True) + for ct in ct_split_real.keys(): real_ct = ct_split_real[ct] pred_ct = ct_split_pred[ct] diff --git a/src/state/_cli/_tx/_train.py b/src/state/_cli/_tx/_train.py index 5db050e8..385451ea 100644 --- a/src/state/_cli/_tx/_train.py +++ b/src/state/_cli/_tx/_train.py @@ -108,6 +108,11 @@ def run_tx_train(cfg: DictConfig): elif cfg["model"]["name"].lower() == "scvi": cfg["data"]["kwargs"]["transform"] = None + output_space = cfg["data"]["kwargs"].get("output_space", "gene") + assert output_space in {"embedding", "gene", "all"}, ( + f"data.kwargs.output_space must be one of 'embedding', 'gene', or 'all'; got {output_space!r}" + ) + data_module: PerturbationDataModule = get_datamodule( cfg["data"]["name"], cfg["data"]["kwargs"], @@ -125,23 +130,27 @@ def run_tx_train(cfg: DictConfig): print("batch size:", dl.batch_size) var_dims = data_module.get_var_dims() # {"gene_dim": …, "hvg_dim": …} - if cfg["data"]["kwargs"]["output_space"] == "gene": + if output_space == "gene": gene_dim = var_dims.get("hvg_dim", 2000) # fallback if key missing else: gene_dim = var_dims.get("gene_dim", 2000) # fallback if key missing latent_dim = var_dims["output_dim"] # same as model.output_dim hidden_dims = cfg["model"]["kwargs"].get("decoder_hidden_dims", [1024, 1024, 512]) - decoder_cfg = dict( - latent_dim=latent_dim, - gene_dim=gene_dim, - hidden_dims=hidden_dims, - dropout=cfg["model"]["kwargs"].get("decoder_dropout", 0.1), - residual_decoder=cfg["model"]["kwargs"].get("residual_decoder", False), - ) + if output_space in {"gene", "all"}: + decoder_cfg = dict( + latent_dim=latent_dim, + gene_dim=gene_dim, + hidden_dims=hidden_dims, + dropout=cfg["model"]["kwargs"].get("decoder_dropout", 0.1), + residual_decoder=cfg["model"]["kwargs"].get("residual_decoder", False), + ) - # tuck it into the kwargs that will reach the LightningModule - cfg["model"]["kwargs"]["decoder_cfg"] = decoder_cfg + # tuck it into the kwargs that will reach the LightningModule + cfg["model"]["kwargs"]["decoder_cfg"] = decoder_cfg + else: + cfg["model"]["kwargs"].pop("decoder_cfg", None) + cfg["model"]["kwargs"]["gene_decoder_bool"] = False # Save the onehot maps as pickle files instead of storing in config cell_type_onehot_map_path = join(run_output_dir, "cell_type_onehot_map.pkl") @@ -225,7 +234,7 @@ def run_tx_train(cfg: DictConfig): callbacks.append(mfu_cb) - # Add CumulativeFLOPSCallback to track cumulative FLOPs + if "cumulative_flops_use_backward" in cfg["training"] and cfg["model"]["name"] == "state": cumulative_flops_use_backward = cfg["training"]["cumulative_flops_use_backward"] cumulative_flops_cb = CumulativeFLOPSCallback(use_backward=cumulative_flops_use_backward) callbacks.append(cumulative_flops_cb) @@ -259,14 +268,14 @@ def run_tx_train(cfg: DictConfig): plugins=plugins, callbacks=callbacks, gradient_clip_val=cfg["training"]["gradient_clip_val"] if cfg["model"]["name"].lower() != "cpa" else None, - use_distributed_sampler=False, # Prevent Lightning from wrapping PerturbationBatchSampler with DistributedSampler + accumulate_grad_batches=cfg["training"].get("gradient_accumulation_steps", 1), + use_distributed_sampler=False, ) # Align logging cadence with rolling MFU window (and W&B logging) if "log_every_n_steps" in cfg["training"]: trainer_kwargs["log_every_n_steps"] = cfg["training"]["log_every_n_steps"] - # Build trainer print(f"Building trainer with kwargs: {trainer_kwargs}") trainer = pl.Trainer(**trainer_kwargs) @@ -332,7 +341,9 @@ def run_tx_train(cfg: DictConfig): pert_encoder_weight_key = "pert_encoder.0.weight" if pert_encoder_weight_key in checkpoint_state: checkpoint_pert_dim = checkpoint_state[pert_encoder_weight_key].shape[1] - if checkpoint_pert_dim != model.pert_dim: + + # if the cell embedding dim doesn't match, or if it was HVGs, rebuild for transfer learning + if checkpoint_pert_dim != model.pert_dim or cfg["data"]["kwargs"]["embed_key"] == "X_hvg": print( f"pert_encoder input dimension mismatch: model.pert_dim = {model.pert_dim} but checkpoint expects {checkpoint_pert_dim}. Overriding model's pert_dim and rebuilding pert_encoder." ) diff --git a/src/state/configs/model/pseudobulk.yaml b/src/state/configs/model/pseudobulk.yaml new file mode 100644 index 00000000..8d22cead --- /dev/null +++ b/src/state/configs/model/pseudobulk.yaml @@ -0,0 +1,54 @@ +name: pseudobulk +checkpoint: null +device: cuda + +kwargs: + cell_set_len: 512 + blur: 0.05 + hidden_dim: 768 # hidden dimension going into the transformer backbone + loss: energy + confidence_head: False + n_encoder_layers: 1 + n_decoder_layers: 1 + predict_residual: True + softplus: True + freeze_pert_backbone: False + transformer_decoder: False + finetune_vci_decoder: False + residual_decoder: False + batch_encoder: False + nb_decoder: False + mask_attn: False + use_effect_gating_token: False + distributional_loss: energy + init_from: null + transformer_backbone_key: llama + transformer_backbone_kwargs: + max_position_embeddings: ${model.kwargs.cell_set_len} + n_positions: ${model.kwargs.cell_set_len} + hidden_size: ${model.kwargs.hidden_dim} + intermediate_size: 3072 + num_hidden_layers: 8 + num_attention_heads: 12 + num_key_value_heads: 12 + head_dim: 64 + use_cache: false + attention_dropout: 0.0 + hidden_dropout: 0.0 + layer_norm_eps: 1e-6 + pad_token_id: 0 + bos_token_id: 1 + eos_token_id: 2 + tie_word_embeddings: false + rotary_dim: 0 + use_rotary_embeddings: false + lora: + enable: false + r: 16 + alpha: 32 + dropout: 0.05 + bias: none + target: auto + adapt_mlp: false + task_type: FEATURE_EXTRACTION + merge_on_eval: false diff --git a/src/state/configs/model/state.yaml b/src/state/configs/model/state.yaml index e9b3e34d..e47217e6 100644 --- a/src/state/configs/model/state.yaml +++ b/src/state/configs/model/state.yaml @@ -5,9 +5,9 @@ device: cuda kwargs: cell_set_len: 512 blur: 0.05 - hidden_dim: 696 # hidden dimension going into the transformer backbone + hidden_dim: 768 loss: energy - confidence_head: False + confidence_token: False n_encoder_layers: 1 n_decoder_layers: 1 predict_residual: True @@ -20,19 +20,30 @@ kwargs: use_batch_token: False nb_decoder: False mask_attn: False + + # --- Dose handling --- + dosage: False + hill_prior: False + dose_momentum: 0.01 # NEW: EMA for log10-dose mean/std + dose_strength_init: 1.0 # NEW: initial strength of FiLM modulation + dose_smooth_weight: 0.01 # NEW: curvature penalty across doses + use_effect_gating_token: False distributional_loss: energy init_from: null + mmd_num_chunks: 1 + randomize_mmd_chunks: false + transformer_backbone_key: llama transformer_backbone_kwargs: - bidirectional_attention: false + bidirectional_attention: true # was false; matches the CLI you used max_position_embeddings: ${model.kwargs.cell_set_len} hidden_size: ${model.kwargs.hidden_dim} - intermediate_size: 2784 + intermediate_size: 3072 num_hidden_layers: 8 num_attention_heads: 12 num_key_value_heads: 12 - head_dim: 58 + head_dim: 64 use_cache: false attention_dropout: 0.0 hidden_dropout: 0.0 @@ -43,6 +54,7 @@ kwargs: tie_word_embeddings: false rotary_dim: 0 use_rotary_embeddings: false + lora: enable: false r: 16 diff --git a/src/state/configs/model/state_sm.yaml b/src/state/configs/model/state_sm.yaml index 77ddfd1f..11fd84ef 100644 --- a/src/state/configs/model/state_sm.yaml +++ b/src/state/configs/model/state_sm.yaml @@ -24,6 +24,8 @@ kwargs: distributional_loss: energy gene_decoder_bool: False init_from: null + mmd_num_chunks: 1 + randomize_mmd_chunks: false transformer_backbone_key: llama transformer_backbone_kwargs: bidirectional_attention: false diff --git a/src/state/configs/state-defaults.yaml b/src/state/configs/state-defaults.yaml index 8414ec8b..8968115b 100644 --- a/src/state/configs/state-defaults.yaml +++ b/src/state/configs/state-defaults.yaml @@ -19,8 +19,8 @@ experiment: ddp_timeout: 3600 checkpoint: path: /scratch/ctc/ML/vci/checkpoint/pretrain - save_top_k: 4 - monitor: trainer/train_loss + save_top_k: 2 + monitor: validation/val_loss every_n_train_steps: 1000 wandb: enable: true diff --git a/src/state/configs/training/default.yaml b/src/state/configs/training/default.yaml index 3b31cd27..a1fe5d7d 100644 --- a/src/state/configs/training/default.yaml +++ b/src/state/configs/training/default.yaml @@ -7,6 +7,7 @@ train_seed: 42 val_freq: 2000 ckpt_every_n_steps: 2000 gradient_clip_val: 10 # 0 means no clipping +gradient_accumulation_steps: 1 loss_fn: mse devices: 1 # Number of GPUs to use for training strategy: auto # DDP strategy for multi-GPU training @@ -16,4 +17,4 @@ mfu_kwargs: use_backward: true logging_interval: 10 window_size: 2 -cumulative_flops_use_backward: true \ No newline at end of file +cumulative_flops_use_backward: true diff --git a/src/state/emb/finetune_decoder.py b/src/state/emb/finetune_decoder.py index 8fdaf819..50880a9d 100644 --- a/src/state/emb/finetune_decoder.py +++ b/src/state/emb/finetune_decoder.py @@ -1,6 +1,11 @@ +# src/state/emb/finetune_decoder.py + import logging +from typing import Dict, List, Optional, Tuple + import torch from torch import nn +from omegaconf import OmegaConf from vci.nn.model import StateEmbeddingModel from vci.train.trainer import get_embeddings @@ -9,150 +14,317 @@ log = logging.getLogger(__name__) -class Finetune: - def __init__(self, cfg, learning_rate=1e-4): +class Finetune(nn.Module): + def __init__( + self, + cfg: Optional[OmegaConf] = None, + learning_rate: float = 1e-4, + read_depth: float = 4.0, + train_binary_decoder: bool = False, + ): """ - Initialize the Finetune class for fine-tuning the binary decoder of a pre-trained model. - - Parameters: - ----------- - cfg : OmegaConf - Configuration object containing model settings - learning_rate : float - Learning rate for fine-tuning the binary decoder + Helper module that loads a pretrained SE/VCI checkpoint and exposes: + - get_gene_embedding(genes): returns gene/task embeddings with differentiable + replacement for any genes missing from pretrained protein embeddings + - get_counts(cell_embs, genes): runs the pretrained binary decoder in a vectorized way + + Args: + cfg: OmegaConf for the SE model (if not embedded in checkpoint) + learning_rate: (kept for API compatibility; not used directly here) + read_depth: initial value for a learnable read depth scalar (if RDA enabled) """ - self.model = None + super().__init__() + self.model: Optional[StateEmbeddingModel] = None self.collator = None - self.protein_embeds = None + self.protein_embeds: Optional[Dict[str, torch.Tensor]] = None self._vci_conf = cfg self.learning_rate = learning_rate - self.cached_gene_embeddings = {} - self.device = None - - def load_model(self, checkpoint): - """ - Load a pre-trained model from a checkpoint and prepare it for fine-tuning. - - Parameters: - ----------- - checkpoint : str - Path to the checkpoint file - """ - if self.model: - raise ValueError("Model already initialized") - - # Import locally to avoid circular imports - - # Load and initialize model for eval - self.model = StateEmbeddingModel.load_from_checkpoint(checkpoint, strict=False) - - # Ensure model uses the provided config, not the stored one - if self._vci_conf is not None: - self.model.update_config(self._vci_conf) - - self.device = self.model.device - - # Load protein embeddings - all_pe = get_embeddings(self._vci_conf) - all_pe.requires_grad = False - self.model.pe_embedding = nn.Embedding.from_pretrained(all_pe) - self.model.pe_embedding.to(self.device) + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.train_binary_decoder = train_binary_decoder - # Load protein embeddings - self.protein_embeds = torch.load(get_embedding_cfg(self._vci_conf).all_embeddings) + # --- Learnable read-depth scalar used when RDA is enabled --- + self.read_depth = nn.Parameter(torch.tensor(float(read_depth), dtype=torch.float), requires_grad=True) - # Freeze all parameters - for param in self.model.parameters(): - param.requires_grad = False + # --- Caching & state for gene embeddings and missing-gene handling --- + self.cached_gene_embeddings: Dict[Tuple[str, ...], torch.Tensor] = {} - # Enable gradients only for binary decoder - for param in self.model.binary_decoder.parameters(): - param.requires_grad = False + self.missing_table: Optional[nn.Embedding] = None + self._last_missing_count: int = 0 + self._last_missing_dim: int = 0 - # Ensure the binary decoder is in training mode so gradients are enabled. - self.model.binary_decoder.eval() + # Cache present masks and index maps per gene set + self._present_mask_cache: Dict[Tuple[str, ...], torch.Tensor] = {} + self._missing_index_map_cache: Dict[Tuple[str, ...], torch.Tensor] = {} - def get_gene_embedding(self, genes): + # ------------------------- + # Loading / setup + # ------------------------- + def load_model(self, checkpoint: str): """ - Get embeddings for a list of genes, with caching to avoid recomputation. - - Parameters: - ----------- - genes : list - List of gene names/identifiers - - Returns: - -------- - torch.Tensor - Tensor of gene embeddings + Load a pre-trained SE model from a single checkpoint path and prepare it. + Prefers embedded cfg in checkpoint; falls back to provided cfg if needed. """ - # Cache key based on genes tuple - cache_key = tuple(genes) - - # Return cached embeddings if available - if cache_key in self.cached_gene_embeddings: - return self.cached_gene_embeddings[cache_key] + if self.model is not None: + raise ValueError("Model already initialized") - # Compute gene embeddings - protein_embeds = [self.protein_embeds[x] if x in self.protein_embeds else torch.zeros(5120) for x in genes] - protein_embeds = torch.stack(protein_embeds).to(self.device) - gene_embeds = self.model.gene_embedding_layer(protein_embeds) + # Resolve configuration: prefer embedded cfg in checkpoint + cfg_to_use = self._vci_conf + if cfg_to_use is None: + try: + ckpt = torch.load(checkpoint, map_location="cpu", weights_only=False) + if isinstance(ckpt, dict) and "cfg_yaml" in ckpt: + cfg_to_use = OmegaConf.create(ckpt["cfg_yaml"]) # type: ignore + elif isinstance(ckpt, dict) and "hyper_parameters" in ckpt: + hp = ckpt.get("hyper_parameters", {}) or {} + if isinstance(hp, dict) and len(hp) > 0: + try: + cfg_to_use = OmegaConf.create(hp["cfg"]) if "cfg" in hp else OmegaConf.create(hp) + except Exception: + cfg_to_use = OmegaConf.create(hp) + except Exception as e: + log.warning(f"Could not extract config from checkpoint: {e}") + if cfg_to_use is None: + raise ValueError( + "No config found in checkpoint and no override provided. " + "Provide SE cfg or a full checkpoint with embedded config." + ) + self._vci_conf = cfg_to_use + + # Load model; allow passing cfg to constructor like inference + self.model = StateEmbeddingModel.load_from_checkpoint(checkpoint, dropout=0.0, strict=False, cfg=self._vci_conf) + self.device = self.model.device # type: ignore + + # Try to extract packaged protein embeddings from checkpoint + packaged_pe = None + try: + ckpt2 = torch.load(checkpoint, map_location="cpu", weights_only=False) + if isinstance(ckpt2, dict) and "protein_embeds_dict" in ckpt2: + packaged_pe = ckpt2["protein_embeds_dict"] + except Exception: + pass + + # Resolve protein embeddings for pe_embedding weights + all_pe = packaged_pe or get_embeddings(self._vci_conf) + if isinstance(all_pe, dict): + # For the model's token embedding table, we only need the stacked array. + stacked = torch.vstack(list(all_pe.values())) + else: + stacked = all_pe + stacked.requires_grad = False + self.model.pe_embedding = nn.Embedding.from_pretrained(stacked) # type: ignore + self.model.pe_embedding.to(self.device) # type: ignore + + # Keep a mapping from gene name -> raw protein embedding vector + self.protein_embeds = packaged_pe + if self.protein_embeds is None: + # Fallback to configured path on disk + self.protein_embeds = torch.load(get_embedding_cfg(self._vci_conf).all_embeddings, weights_only=False) + + # Freeze SE model; optionally unfreeze just the binary decoder + for p in self.model.parameters(): # type: ignore + p.requires_grad = False + for p in self.model.binary_decoder.parameters(): # type: ignore + p.requires_grad = self.train_binary_decoder + self.model.binary_decoder.train(mode=self.train_binary_decoder) # type: ignore + + # ------------------------- + # Gene utilities + # ------------------------- + def _auto_detect_gene_column(self, adata): + """Auto-detect the gene column with highest overlap with protein embeddings.""" + if self.protein_embeds is None: + log.warning("No protein embeddings available for auto-detection, using index") + return None + + protein_genes = set(self.protein_embeds.keys()) + best_column = None + best_overlap = 0 + + # Index first + index_genes = set(getattr(adata.var, "index", [])) + overlap = len(protein_genes.intersection(index_genes)) + if overlap > best_overlap: + best_overlap = overlap + best_column = None # None => use index + + # All columns + for col in adata.var.columns: + try: + col_vals = adata.var[col].dropna().astype(str) + except Exception: + continue + col_genes = set(col_vals) + overlap = len(protein_genes.intersection(col_genes)) + if overlap > best_overlap: + best_overlap = overlap + best_column = col + + return best_column + + def genes_from_adata(self, adata) -> List[str]: + """Return list of gene names from AnnData using auto-detected column/index.""" + col = self._auto_detect_gene_column(adata) + if col is None: + return list(map(str, adata.var.index.values)) + return list(adata.var[col].astype(str).values) + + def _ensure_missing_table( + self, + genes_key: Tuple[str, ...], + gene_embed_dim: int, + present_mask: torch.Tensor, + ): + """ + Make sure self.missing_table matches the current gene set's missing count & dim. + Builds a per-position index map (for missing genes) and caches the mask + map. + """ + # Build / cache index map for missing positions (pos -> 0..(n_missing-1)) + if genes_key in self._missing_index_map_cache and genes_key in self._present_mask_cache: + return # already prepared for this gene set + + # Identify missing positions + present = present_mask.bool().tolist() + missing_positions = [i for i, ok in enumerate(present) if not ok] + n_missing = len(missing_positions) + + # Cache mask for this gene set (on device) + self._present_mask_cache[genes_key] = present_mask + + if n_missing == 0: + # No missing genes -> trivial index map of zeros (unused) + self._missing_index_map_cache[genes_key] = torch.zeros(len(genes_key), dtype=torch.long, device=self.device) + return + + # (Re)create the missing table if shape changed + if ( + self.missing_table is None + or self._last_missing_count != n_missing + or self._last_missing_dim != gene_embed_dim + ): + self.missing_table = nn.Embedding(n_missing, gene_embed_dim) + nn.init.normal_(self.missing_table.weight, mean=0.0, std=0.02) + # Ensure the embedding table lives on the same device as inputs/masks + self.missing_table.to(present_mask.device) + self._last_missing_count = n_missing + self._last_missing_dim = gene_embed_dim + + # Build a position -> compact missing index map + inv = {pos: j for j, pos in enumerate(missing_positions)} + idx_map = [inv.get(i, 0) for i in range(len(genes_key))] + self._missing_index_map_cache[genes_key] = torch.tensor(idx_map, dtype=torch.long, device=present_mask.device) + + def get_gene_embedding(self, genes: List[str]) -> torch.Tensor: + """ + Return gene/task embeddings for 'genes'. + For genes missing from the pretrained protein embeddings dictionary, we replace the + post-ESM embedding with a learnable vector from `self.missing_table` via torch.where. + + Caching: + - If no genes are missing, the post-ESM embeddings are cached and reused. + - If some genes are missing, we recompute each call so gradients flow into + self.missing_table (no caching of the final tensor). + """ + if self.model is None: + raise RuntimeError("Model not loaded. Call load_model(checkpoint) first.") + if self.protein_embeds is None: + # Should have been set in load_model; keep a defensive fallback: + self.protein_embeds = torch.load(get_embedding_cfg(self._vci_conf).all_embeddings, weights_only=False) + + genes_key = tuple(genes) + + # Fast path: if we saw this gene set before and no missing genes were involved, reuse cache + if genes_key in self.cached_gene_embeddings: + return self.cached_gene_embeddings[genes_key].to(self.device) + + # Build a [G, embed_size] tensor of raw protein embeddings (zeros where missing) + # Determine the raw protein embedding size + try: + example_vec = next(iter(self.protein_embeds.values())) + embed_size = int(example_vec.shape[-1]) + except Exception: + embed_size = get_embedding_cfg(self._vci_conf).size # fallback + + raw_list = [ + self.protein_embeds[g] if g in self.protein_embeds else torch.zeros(embed_size) # type: ignore + for g in genes + ] + protein_embeds = torch.stack(raw_list).to(self.device) + + # Project through the model's gene embedding layer (post-ESM projection) + gene_embeds_raw = self.model.gene_embedding_layer(protein_embeds) # type: ignore # [G, d_model] + gene_embeds_raw = gene_embeds_raw.to(self.device) + d_model = int(gene_embeds_raw.shape[-1]) + + # Present mask: True where gene exists in pretrained protein_embeds + present_mask = torch.tensor([g in self.protein_embeds for g in genes], device=self.device).unsqueeze(1) + + # Prepare missing-table and position index map if needed + self._ensure_missing_table(genes_key, d_model, present_mask.squeeze(1)) + + # If we have a non-empty missing_table for this gene set, replace missing rows + idx_map = self._missing_index_map_cache[genes_key] + # Safety: if the missing table exists but is on a different device, move it + if self.missing_table is not None and self.missing_table.weight.device != idx_map.device: + self.missing_table.to(idx_map.device) + if self.missing_table is not None and self._last_missing_count > 0: + learned_full = self.missing_table(idx_map) # [G, d_model] + # Differentiable replacement: keep present rows from gene_embeds_raw, else take learned_full + gene_embeds = torch.where(present_mask, gene_embeds_raw, learned_full) + else: + gene_embeds = gene_embeds_raw + + # Cache only when there are no missing genes for this set (so the tensor is static) + if self._last_missing_count == 0: + self.cached_gene_embeddings[genes_key] = gene_embeds.detach().clone() - # Cache and return - self.cached_gene_embeddings[cache_key] = gene_embeds return gene_embeds - def get_counts(self, cell_embs, genes, read_depth=None, batch_size=32): + # ------------------------- + # Counts decoding (vectorized over genes) + # ------------------------- + def get_counts(self, cell_embs, genes: List[str], batch_size: int = 32) -> torch.Tensor: """ - Generate predictions with the binary decoder with gradients enabled. - - Parameters: - - cell_embs: A tensor or array of cell embeddings. - - genes: List of gene names. - - read_depth: Optional read depth for RDA normalization. - - batch_size: Batch size for processing. + Generate predictions with the (pretrained) binary decoder. This is vectorized + over all genes (no per-gene loops). Returns: - A single tensor of shape [N, num_genes] where N is the total number of cells. + Tensor of shape [Ncells, Ngenes] """ + if self.model is None: + raise RuntimeError("Model not loaded. Call load_model(checkpoint) first.") - # Convert cell_embs to a tensor on the correct device. - cell_embs = torch.tensor(cell_embs, dtype=torch.float, device=self.device) - - # Check if RDA is enabled. - use_rda = getattr(self.model.cfg.model, "rda", False) - if use_rda and read_depth is None: - read_depth = 1000.0 + # Convert cell_embs to a tensor on the correct device (no detach here) + cell_embs = torch.as_tensor(cell_embs, dtype=torch.float, device=self.device) - # Retrieve gene embeddings (cached if available). - gene_embeds = self.get_gene_embedding(genes) + # RDA must be enabled to use read_depth + use_rda = getattr(self.model.cfg.model, "rda", False) # type: ignore + assert use_rda, "RDA must be enabled to use get_counts (model.cfg.model.rda == True)." - # List to collect the output predictions for each batch. - output_batches = [] + # Retrieve (and possibly learn) gene embeddings (with differentiable missing replacement) + gene_embeds = self.get_gene_embedding(genes) # [G, d_model] - # Loop over cell embeddings in batches. + outputs = [] for i in range(0, cell_embs.size(0), batch_size): - # Determine batch indices. end_idx = min(i + batch_size, cell_embs.size(0)) - cell_embeds_batch = cell_embs[i:end_idx] + cell_batch = cell_embs[i:end_idx] # [B, E_cell] + + # NOTE: Learnable read depth scalar, expanded to batch (keeps gradient) + task_counts = self.read_depth.expand(cell_batch.shape[0]).to(cell_batch.dtype).to(cell_batch.device) - # Set up task counts if using RDA. - if use_rda: - task_counts = torch.full((cell_embeds_batch.shape[0],), read_depth, device=self.device) - else: - task_counts = None + # Build [B, G, *] pairwise features and decode + merged = self.model.resize_batch(cell_batch, gene_embeds, task_counts) # type: ignore - # Resize the batch using the model's method. - merged_embs = self.model.resize_batch(cell_embeds_batch, gene_embeds, task_counts) + # Align dtype with decoder weights to avoid mixed-precision issues on CPU + dec_param_dtype = next(self.model.binary_decoder.parameters()).dtype # type: ignore + if merged.dtype != dec_param_dtype: + merged = merged.to(dec_param_dtype) - # Forward pass through the binary decoder. - logprobs_batch = self.model.binary_decoder(merged_embs) + logprobs_batch = self.model.binary_decoder(merged) # type: ignore - # If the output has an extra singleton dimension (e.g., [B, gene_dim, 1]), squeeze it. + # Squeeze trailing singleton if present: [B, G, 1] -> [B, G] if logprobs_batch.dim() == 3 and logprobs_batch.size(-1) == 1: logprobs_batch = logprobs_batch.squeeze(-1) - output_batches.append(logprobs_batch) + outputs.append(logprobs_batch) - # Concatenate all batch outputs along the first dimension. - return torch.cat(output_batches, dim=0) + return torch.cat(outputs, dim=0) diff --git a/src/state/emb/inference.py b/src/state/emb/inference.py index d042864f..7df4fe1a 100644 --- a/src/state/emb/inference.py +++ b/src/state/emb/inference.py @@ -277,6 +277,8 @@ def encode_adata( log.info(f"Successfully saved {len(all_embeddings)} embeddings to LanceDB") + return all_embeddings + def _convert_to_csr(self, adata): """Convert the adata.X matrix to CSR format if it's not already.""" from scipy.sparse import csr_matrix, issparse diff --git a/src/state/tx/callbacks/cumulative_flops.py b/src/state/tx/callbacks/cumulative_flops.py index 720083e8..6ab05ce1 100644 --- a/src/state/tx/callbacks/cumulative_flops.py +++ b/src/state/tx/callbacks/cumulative_flops.py @@ -36,18 +36,32 @@ def __init__( self._batch_count: int = 0 def _trainstep_forward_backward(self, model: LightningModule, batch: Any) -> torch.Tensor: - """Encapsulate calling StateTransitionPerturbationModel.training_step and backward. + """Call the model's training_step (handling optional args) and run backward if configured.""" - This intentionally targets StateTransitionPerturbationModel's signature and - performs both forward and backward to capture full FLOPs. - - !!WARNING!! - This has only been tested with StateTransitionPerturbationModel. Behavior with any other model has not been verified. - """ model.zero_grad(set_to_none=True) - loss: torch.Tensor = model.training_step(batch, 0, padded=True) # type: ignore + + try: + loss_out = model.training_step(batch, 0, padded=True) + except TypeError: + loss_out = model.training_step(batch, 0) + + if isinstance(loss_out, dict): + loss = loss_out.get("loss") + if loss is None: + raise RuntimeError( + "CumulativeFLOPSCallback expected training_step to return a Tensor or dict containing 'loss'." + ) + else: + loss = loss_out + + if not isinstance(loss, torch.Tensor): # pragma: no cover - defensive guard + raise TypeError( + "CumulativeFLOPSCallback requires training_step to return a Tensor (or dict with 'loss' Tensor)." + ) + if self.use_backward: loss.backward() + return loss def _measure_flops_once(self, trainer: Trainer, pl_module: Any, batch: Any) -> None: diff --git a/src/state/tx/models/base.py b/src/state/tx/models/base.py index 10378ef8..635536c7 100644 --- a/src/state/tx/models/base.py +++ b/src/state/tx/models/base.py @@ -123,12 +123,12 @@ class PerturbationModel(ABC, LightningModule): Args: input_dim: Dimension of input features (genes or embeddings) hidden_dim: Hidden dimension for neural network layers - output_dim: Dimension of output (always gene space) + output_dim: Dimension of output (gene space or embedding space) pert_dim: Dimension of perturbation embeddings dropout: Dropout rate lr: Learning rate for optimizer loss_fn: Loss function ('mse' or custom nn.Module) - output_space: 'gene' or 'latent' + output_space: 'gene', 'all', or 'embedding' """ def __init__( @@ -174,6 +174,10 @@ def __init__( self.embed_key = embed_key self.output_space = output_space + if self.output_space not in {"embedding", "gene", "all"}: + raise ValueError( + f"Unsupported output_space '{self.output_space}'. Expected one of 'embedding', 'gene', or 'all'." + ) self.batch_size = batch_size self.control_pert = control_pert @@ -182,6 +186,18 @@ def __init__( self.dropout = dropout self.lr = lr self.loss_fn = get_loss_fn(loss_fn) + + if self.output_space == "embedding": + self.gene_decoder_bool = False + self.decoder_cfg = None + # keep hyperparameters metadata consistent with the actual model state + try: + if hasattr(self, "hparams"): + self.hparams["gene_decoder_bool"] = False # type: ignore[index] + self.hparams["decoder_cfg"] = None # type: ignore[index] + except Exception: + pass + self._build_decoder() def transfer_batch_to_device(self, batch, device, dataloader_idx: int): @@ -216,6 +232,28 @@ def on_load_checkpoint(self, checkpoint: dict[str, tp.Any]) -> None: if self.gene_decoder_bool == False: self.gene_decoder = None return + + # When finetuning with the pretrained VCI decoder, keep the existing + # FinetuneVCICountsDecoder instance. Overwriting it with a freshly + # constructed LatentToGeneDecoder would make the checkpoint weights + # incompatible and surface load_state_dict errors. + finetune_decoder_active = False + hparams = getattr(self, "hparams", None) + if hparams is not None: + if hasattr(hparams, "get"): + finetune_decoder_active = bool(hparams.get("finetune_vci_decoder", False)) + else: + finetune_decoder_active = bool(getattr(hparams, "finetune_vci_decoder", False)) + if not finetune_decoder_active: + finetune_decoder_active = bool(getattr(self, "finetune_vci_decoder", False)) + + if finetune_decoder_active: + # Preserve decoder_cfg for completeness but avoid rebuilding the module. + if "decoder_cfg" in checkpoint.get("hyper_parameters", {}): + self.decoder_cfg = checkpoint["hyper_parameters"]["decoder_cfg"] + logger.info("Finetune VCI decoder active; keeping existing decoder during checkpoint load") + return + if not decoder_already_configured and "decoder_cfg" in checkpoint["hyper_parameters"]: self.decoder_cfg = checkpoint["hyper_parameters"]["decoder_cfg"] self.gene_decoder = LatentToGeneDecoder(**self.decoder_cfg) diff --git a/src/state/tx/models/context_mean.py b/src/state/tx/models/context_mean.py index 7491dbcd..386bf0a3 100644 --- a/src/state/tx/models/context_mean.py +++ b/src/state/tx/models/context_mean.py @@ -91,7 +91,14 @@ def on_fit_start(self): return # Initialize dictionary to accumulate sum and count for each cell type. - celltype_sums = defaultdict(lambda: {"sum": torch.zeros(self.output_dim), "count": 0, "control_sum": torch.zeros(self.output_dim), "control_count": 0}) + celltype_sums = defaultdict( + lambda: { + "sum": torch.zeros(self.output_dim), + "count": 0, + "control_sum": torch.zeros(self.output_dim), + "control_count": 0, + } + ) with torch.no_grad(): for batch in train_loader: @@ -127,7 +134,9 @@ def on_fit_start(self): if stats["control_count"] > 0: # Use control cell average as fallback for cell types with no perturbations self.celltype_pert_means[ct_name] = stats["control_sum"] / stats["control_count"] - logger.info(f"ContextMean: Using control cell average for cell type '{ct_name}' (no perturbations found, {stats['control_count']} control cells used).") + logger.info( + f"ContextMean: Using control cell average for cell type '{ct_name}' (no perturbations found, {stats['control_count']} control cells used)." + ) else: logger.warning(f"No perturbed or control cells found for cell type {ct_name}.") continue diff --git a/src/state/tx/models/decoders.py b/src/state/tx/models/decoders.py index b7caa741..ae06565c 100644 --- a/src/state/tx/models/decoders.py +++ b/src/state/tx/models/decoders.py @@ -1,7 +1,10 @@ import logging +import os +from typing import Optional import torch import torch.nn as nn + from omegaconf import OmegaConf from ...emb.finetune_decoder import Finetune @@ -12,116 +15,179 @@ class FinetuneVCICountsDecoder(nn.Module): def __init__( self, - genes, - # model_loc="/large_storage/ctc/userspace/aadduri/vci/checkpoint/rda_tabular_counts_2048_new/step=950000.ckpt", - # config="/large_storage/ctc/userspace/aadduri/vci/checkpoint/rda_tabular_counts_2048_new/tahoe_config.yaml", - model_loc="/home/aadduri/vci_pretrain/vci_1.4.2.ckpt", - config="/large_storage/ctc/userspace/aadduri/vci/checkpoint/large_1e-4_rda_tabular_counts_2048/crossds_config.yaml", - read_depth=1200, - latent_dim=1024, # dimension of pretrained vci model - hidden_dims=[512, 512, 512], # hidden dimensions of the decoder - dropout=0.1, - basal_residual=False, + genes=None, + adata=None, + # checkpoint: Optional[str] = "/large_storage/ctc/userspace/aadduri/SE-600M/se600m_epoch15.ckpt", + # config: Optional[str] = "/large_storage/ctc/userspace/aadduri/SE-600M/config.yaml", + checkpoint: Optional[str] = "/home/aadduri/vci_pretrain/vci_1.4.4/vci_1.4.4_v7.ckpt", + config: Optional[str] = "/home/aadduri/vci_pretrain/vci_1.4.4/config.yaml", + latent_dim: int = 1034, # total input dim (cell emb + optional ds emb) + read_depth: float = 4.0, + ds_emb_dim: int = 10, # dataset embedding dim at the tail of input + hidden_dim: int = 512, + dropout: float = 0.1, + basal_residual: bool = False, + train_binary_decoder: bool = True, ): super().__init__() + # Initialize finetune helper and model from a single checkpoint + if config is None: + raise ValueError( + "FinetuneVCICountsDecoder requires a VCI/SE config. Set kwargs.vci_config or env STATE_VCI_CONFIG." + ) + self.finetune = Finetune(cfg=OmegaConf.load(config), train_binary_decoder=train_binary_decoder) + self.finetune.load_model(checkpoint) + # Resolve genes: prefer explicit list; else infer from anndata if provided + if genes is None and adata is not None: + try: + genes = self.finetune.genes_from_adata(adata) + except Exception as e: + raise ValueError(f"Failed to infer genes from AnnData: {e}") + if genes is None: + raise ValueError("FinetuneVCICountsDecoder requires 'genes' or 'adata' to derive gene names") self.genes = genes - self.model_loc = model_loc - self.config = config - self.finetune = Finetune(OmegaConf.load(self.config)) - self.finetune.load_model(self.model_loc) - self.read_depth = nn.Parameter(torch.tensor(read_depth, dtype=torch.float), requires_grad=False) + # Keep read_depth as a learnable parameter so decoded counts can adapt + self.read_depth = nn.Parameter(torch.tensor(read_depth, dtype=torch.float), requires_grad=True) self.basal_residual = basal_residual - - # layers = [ - # nn.Linear(latent_dim, hidden_dims[0]), - # ] - - # self.gene_lora = nn.Sequential(*layers) + self.ds_emb_dim = int(ds_emb_dim) if ds_emb_dim is not None else 0 + self.input_total_dim = int(latent_dim) self.latent_decoder = nn.Sequential( - nn.Linear(latent_dim, hidden_dims[0]), - nn.LayerNorm(hidden_dims[0]), + nn.Linear(latent_dim, hidden_dim), + nn.LayerNorm(hidden_dim), nn.GELU(), nn.Dropout(dropout), - nn.Linear(hidden_dims[0], hidden_dims[1]), - nn.LayerNorm(hidden_dims[1]), + nn.Linear(hidden_dim, hidden_dim), + nn.LayerNorm(hidden_dim), nn.GELU(), nn.Dropout(dropout), - nn.Linear(hidden_dims[1], len(self.genes)), - nn.ReLU(), + nn.Linear(hidden_dim, len(self.genes)), ) self.gene_decoder_proj = nn.Sequential( nn.Linear(len(self.genes), 128), + nn.LayerNorm(128), + nn.GELU(), + nn.Linear(128, 128), + nn.LayerNorm(128), + nn.GELU(), + nn.Dropout(dropout), nn.Linear(128, len(self.genes)), ) - self.binary_decoder = self.finetune.model.binary_decoder - for param in self.binary_decoder.parameters(): - param.requires_grad = False + self.binary_decoder = self.finetune.model.binary_decoder # type: ignore + + # Validate that all requested genes exist in the pretrained checkpoint's embeddings + pe = getattr(self.finetune, "protein_embeds", {}) + self.present_mask = [g in pe for g in self.genes] + self.missing_positions = [i for i, g in enumerate(self.genes) if g not in pe] + self.missing_genes = [self.genes[i] for i in self.missing_positions] + total_req = len(self.genes) + found = total_req - len(self.missing_positions) + total_pe = len(pe) if hasattr(pe, "__len__") else -1 + miss_pct = (len(self.missing_positions) / total_req) if total_req > 0 else 0.0 + logger.info( + f"FinetuneVCICountsDecoder gene check: requested={total_req}, found={found}, missing={len(self.missing_positions)} ({miss_pct:.1%}), all_embeddings_size={total_pe}" + ) + + # Create learnable embeddings for missing genes in the post-ESM gene embedding space + if len(self.missing_positions) > 0: + # Infer gene embedding output dimension by a dry-run through gene_embedding_layer + try: + sample_vec = next(iter(pe.values())).to(self.finetune.model.device) + if sample_vec.dim() == 1: + sample_vec = sample_vec.unsqueeze(0) + gene_embed_dim = self.finetune.model.gene_embedding_layer(sample_vec).shape[-1] + except Exception: + # Conservative fallback + gene_embed_dim = 1024 + + self.missing_table = nn.Embedding(len(self.missing_positions), gene_embed_dim) + nn.init.normal_(self.missing_table.weight, mean=0.0, std=0.02) + # For user visibility + try: + self.finetune.missing_genes = self.missing_genes + except Exception: + pass + else: + # Register a dummy buffer so attributes exist + self.missing_table = None + + # Ensure the wrapped Finetune helper creates its own missing-table parameters + # prior to Lightning's checkpoint load. Otherwise the checkpoint will contain + # weights like `gene_decoder.finetune.missing_table.weight` that are absent + # from a freshly constructed module, triggering "unexpected key" errors. + try: + with torch.no_grad(): + self.finetune.get_gene_embedding(self.genes) + except Exception as exc: + logger.debug(f"Deferred Finetune missing-table initialization failed: {exc}") def gene_dim(self): return len(self.genes) def forward(self, x: torch.Tensor) -> torch.Tensor: - # x is [B, S, latent_dim]. - if len(x.shape) != 3: + # x is [B, S, total_dim] + if x.dim() != 3: x = x.unsqueeze(0) - batch_size, seq_len, latent_dim = x.shape - x = x.view(batch_size * seq_len, latent_dim) - - # Get gene embeddings + batch_size, seq_len, total_dim = x.shape + x_flat = x.reshape(batch_size * seq_len, total_dim) + + # Split cell and dataset embeddings + if self.ds_emb_dim > 0: + cell_embeds = x_flat[:, : total_dim - self.ds_emb_dim] + ds_emb = x_flat[:, total_dim - self.ds_emb_dim : total_dim] + else: + cell_embeds = x_flat + ds_emb = None + + # Prepare gene embeddings (replace any missing with learned vectors) gene_embeds = self.finetune.get_gene_embedding(self.genes) - - # Handle RDA task counts + if self.missing_table is not None and len(self.missing_positions) > 0: + device = gene_embeds.device + learned = self.missing_table.weight.to(device) + idx = torch.tensor(self.missing_positions, device=device, dtype=torch.long) + gene_embeds = gene_embeds.clone() + gene_embeds.index_copy_(0, idx, learned) + # Ensure embeddings live on the same device as cell_embeds + if gene_embeds.device != cell_embeds.device: + gene_embeds = gene_embeds.to(cell_embeds.device) + + # RDA read depth vector (if enabled in SE model) use_rda = getattr(self.finetune.model.cfg.model, "rda", False) - # Define your sub-batch size (tweak this based on your available memory) - sub_batch_size = 16 - logprob_chunks = [] # to store outputs of each sub-batch - - for i in range(0, x.shape[0], sub_batch_size): - # Get the sub-batch of latent vectors - x_sub = x[i : i + sub_batch_size] - - # Create task_counts for the sub-batch if needed - if use_rda: - # task_counts_sub = torch.full( - # (x_sub.shape[0],), self.read_depth, device=x.device - # ) - task_counts_sub = torch.ones((x_sub.shape[0],), device=x.device) * self.read_depth - else: - task_counts_sub = None - - # Compute merged embeddings for the sub-batch - merged_embs_sub = self.finetune.model.resize_batch(x_sub, gene_embeds, task_counts_sub) - - # Run the binary decoder on the sub-batch - logprobs_sub = self.binary_decoder(merged_embs_sub) - - # Squeeze the singleton dimension if needed - if logprobs_sub.dim() == 3 and logprobs_sub.size(-1) == 1: - logprobs_sub = logprobs_sub.squeeze(-1) - - # Collect the results - logprob_chunks.append(logprobs_sub) - - # Concatenate the sub-batches back together - logprobs = torch.cat(logprob_chunks, dim=0) + task_counts = None + if use_rda: + task_counts = torch.full((cell_embeds.shape[0],), self.read_depth.item(), device=cell_embeds.device) + + # Binary decoder forward with safe dtype handling. + # - On CUDA: enable bf16 autocast for speed. + # - On CPU: ensure inputs match decoder weight dtype to avoid BF16/FP32 mismatch. + device_type = "cuda" if cell_embeds.is_cuda else "cpu" + with torch.autocast(device_type=device_type, dtype=torch.bfloat16, enabled=(device_type == "cuda")): + merged = self.finetune.model.resize_batch( + cell_embeds=cell_embeds, task_embeds=gene_embeds, task_counts=task_counts, ds_emb=ds_emb + ) + + # Align input dtype with decoder weights when autocast is not active (e.g., CPU path) + dec_param_dtype = next(self.binary_decoder.parameters()).dtype + if device_type != "cuda" and merged.dtype != dec_param_dtype: + merged = merged.to(dec_param_dtype) + + logprobs = self.binary_decoder(merged) + if logprobs.dim() == 3 and logprobs.size(-1) == 1: + logprobs = logprobs.squeeze(-1) # Reshape back to [B, S, gene_dim] decoded_gene = logprobs.view(batch_size, seq_len, len(self.genes)) - decoded_gene = decoded_gene + self.gene_decoder_proj(decoded_gene) - # decoded_gene = torch.nn.functional.relu(decoded_gene) - - # # normalize the sum of decoded_gene to be read depth - # decoded_gene = decoded_gene / decoded_gene.sum(dim=2, keepdim=True) * self.read_depth - # decoded_gene = self.gene_lora(decoded_gene) - # TODO: fix this to work with basal counts - - # add logic for basal_residual: - decoded_x = self.latent_decoder(x) - decoded_x = decoded_x.view(batch_size, seq_len, len(self.genes)) + # Match dtype for post-decoder projection to avoid mixed-dtype matmul + proj_param_dtype = next(self.gene_decoder_proj.parameters()).dtype + if decoded_gene.dtype != proj_param_dtype: + decoded_gene = decoded_gene.to(proj_param_dtype) + decoded_gene = decoded_gene + self.gene_decoder_proj(decoded_gene) - # Pass through the additional decoder layers - return decoded_gene + decoded_x + # Optional residual from latent decoder (operates on full input features) + ld_param_dtype = next(self.latent_decoder.parameters()).dtype + x_flat_for_ld = x_flat if x_flat.dtype == ld_param_dtype else x_flat.to(ld_param_dtype) + decoded_x = self.latent_decoder(x_flat_for_ld).view(batch_size, seq_len, len(self.genes)) + return torch.nn.functional.relu(decoded_gene + decoded_x) diff --git a/src/state/tx/models/pseudobulk.py b/src/state/tx/models/pseudobulk.py index c63eb43e..5494dea2 100644 --- a/src/state/tx/models/pseudobulk.py +++ b/src/state/tx/models/pseudobulk.py @@ -115,36 +115,95 @@ def __init__( control_pert = kwargs.get("control_pert", "non-targeting") if kwargs.get("finetune_vci_decoder", False): - gene_names = [] - - if output_space == "gene": - # hvg's but for which dataset? - if "DMSO_TF" in control_pert: - gene_names = np.load( - "/large_storage/ctc/userspace/aadduri/datasets/tahoe_19k_to_2k_names.npy", allow_pickle=True - ) - elif "non-targeting" in control_pert: - temp = ad.read_h5ad("/large_storage/ctc/userspace/aadduri/datasets/hvg/replogle/jurkat.h5") - gene_names = temp.var.index.values - else: - assert output_space == "all" - if "DMSO_TF" in control_pert: - gene_names = np.load( - "/large_storage/ctc/userspace/aadduri/datasets/tahoe_19k_names.npy", allow_pickle=True - ) - elif "non-targeting" in control_pert: - # temp = ad.read_h5ad('/scratch/ctc/ML/vci/paper_replogle/jurkat.h5') - # gene_names = temp.var.index.values - temp = ad.read_h5ad("/large_storage/ctc/userspace/aadduri/cross_dataset/replogle/jurkat.h5") - gene_names = temp.var.index.values - + # Prefer the gene names supplied by the data module (aligned to training output) + gene_names = self.gene_names + if gene_names is None: + raise ValueError( + "finetune_vci_decoder=True but model.gene_names is None. " + "Please provide gene_names via data module var_dims." + ) + + n_genes = len(gene_names) + logger.info( + f"Initializing FinetuneVCICountsDecoder with {n_genes} genes (output_space={output_space}; " + + ("HVG subset" if output_space == "gene" else "all genes") + + ")" + ) self.gene_decoder = FinetuneVCICountsDecoder( genes=gene_names, - # latent_dim=self.output_dim + (self.batch_dim or 0), + checkpoint=kwargs.get("vci_checkpoint", None), ) print(self) + def _decoder_in_features(self) -> Optional[int]: + """ + Best-effort inspection of the decoder's expected input dimensionality. + Returns None if it cannot be determined reliably. + """ + gd = self.gene_decoder + if gd is None: + return None + # LatentToGeneDecoder (non-residual): has .decoder (Sequential) starting with Linear + if hasattr(gd, "decoder") and isinstance(getattr(gd, "decoder"), nn.Sequential): + seq = gd.decoder + for m in seq: + if isinstance(m, nn.Linear): + return m.in_features + return None + # LatentToGeneDecoder (residual): has .blocks (ModuleList) of Sequentials, first starts with Linear + if hasattr(gd, "blocks"): + blocks = getattr(gd, "blocks") + if len(blocks) > 0 and isinstance(blocks[0], nn.Sequential) and isinstance(blocks[0][0], nn.Linear): + return blocks[0][0].in_features + return None + # NBDecoder: has .encoder (Sequential) starting with Linear + if hasattr(gd, "encoder") and isinstance(getattr(gd, "encoder"), nn.Sequential): + seq = gd.encoder + for m in seq: + if isinstance(m, nn.Linear): + return m.in_features + return None + return None + + def _maybe_concat_batch(self, latent: torch.Tensor, batch: torch.Tensor, padded: bool) -> torch.Tensor: + """ + Concatenate batch covariates to the latent only if the decoder expects them. + This avoids shape mismatches at inference when loading a checkpointed decoder + that was trained without batch concatenation. + """ + if self.gene_decoder is None or self.batch_dim is None: + return latent + + expected_in = self._decoder_in_features() + last_dim = latent.size(-1) + + # Prepare batch tensor to match latent shape + if latent.dim() == 2: + batch_var = batch.reshape(latent.shape[0], -1) + else: + batch_var = batch.reshape(latent.shape[0], latent.shape[1], -1) + + # Decide whether to concatenate based on the decoder's input expectation + if expected_in is None: + # Fallback to previous behavior: concatenate for non-VCI decoders + return torch.cat([latent, batch_var], dim=-1) + + if expected_in == last_dim: + # Decoder expects just the latent; do NOT concat + return latent + elif expected_in == last_dim + batch_var.size(-1): + # Decoder expects latent + batch covariates; concat + return torch.cat([latent, batch_var], dim=-1) + else: + # Mismatch: give a clear error message to guide the user + raise RuntimeError( + f"Decoder input dim mismatch: got latent size {last_dim}" + f" (batch_dim={batch_var.size(-1)}), but decoder expects {expected_in}." + " This usually means the checkpointed decoder was trained without" + " concatenating batch covariates, while predict is attempting to." + ) + def _build_networks(self): """ Here we instantiate the actual GPT2-based model. @@ -281,10 +340,8 @@ def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int, padded=T # with torch.no_grad(): # latent_preds = pred.detach() # Detach to prevent gradient flow back to main model - batch_var = batch["batch"].reshape(latent_preds.shape[0], latent_preds.shape[1], -1) - # concatenate on the last axis - if self.batch_dim is not None and not isinstance(self.gene_decoder, FinetuneVCICountsDecoder): - latent_preds = torch.cat([latent_preds, batch_var], dim=-1) + if not isinstance(self.gene_decoder, FinetuneVCICountsDecoder): + latent_preds = self._maybe_concat_batch(latent_preds, batch["batch"], padded=True) if isinstance(self.gene_decoder, NBDecoder): mu, theta = self.gene_decoder(latent_preds) @@ -314,8 +371,8 @@ def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Non target = batch["pert_cell_emb"] target = target.reshape(-1, self.cell_sentence_len, self.output_dim) - loss = self.loss_fn(pred, target).mean() - self.log("val_loss", loss) + loss = torch.nanmean(self.loss_fn(pred, target)) + self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) if self.gene_decoder is not None and "pert_cell_counts" in batch: gene_targets = batch["pert_cell_counts"] @@ -328,7 +385,10 @@ def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Non gene_targets = batch["pert_cell_counts"].reshape_as(mu) decoder_loss = nb_nll(gene_targets, mu, theta) else: - pert_cell_counts_preds = self.gene_decoder(latent_preds) # verify this is automatically detached + # Match decoder input dims + if not isinstance(self.gene_decoder, FinetuneVCICountsDecoder): + latent_preds = self._maybe_concat_batch(latent_preds, batch["batch"], padded=True) + pert_cell_counts_preds = self.gene_decoder(latent_preds) # Get decoder predictions pert_cell_counts_preds = pert_cell_counts_preds.reshape(-1, self.cell_sentence_len, self.gene_dim) @@ -368,17 +428,14 @@ def predict_step(self, batch, batch_idx, padded=True, **kwargs): basal_hvg = batch.get("ctrl_cell_counts", None) if self.gene_decoder is not None: - if latent_output.dim() == 2: - batch_var = batch["batch"].reshape(latent_output.shape[0], -1) - else: - batch_var = batch["batch"].reshape(latent_output.shape[0], latent_output.shape[1], -1) - # concatenate on the last axis - if self.batch_dim is not None and not isinstance(self.gene_decoder, FinetuneVCICountsDecoder): - latent_output = torch.cat([latent_output, batch_var], dim=-1) if isinstance(self.gene_decoder, NBDecoder): + # NB decoder already configured with latent_dim including batch if needed mu, _ = self.gene_decoder(latent_output) pert_cell_counts_preds = mu else: + # Only concat batch covariates if decoder expects them + if not isinstance(self.gene_decoder, FinetuneVCICountsDecoder): + latent_output = self._maybe_concat_batch(latent_output, batch["batch"], padded=padded) pert_cell_counts_preds = self.gene_decoder(latent_output) output_dict["pert_cell_counts_preds"] = pert_cell_counts_preds diff --git a/src/state/tx/models/state_transition.py b/src/state/tx/models/state_transition.py index ecce5e29..412cac82 100644 --- a/src/state/tx/models/state_transition.py +++ b/src/state/tx/models/state_transition.py @@ -1,5 +1,6 @@ +import ast import logging -from typing import Dict, Optional +import math import anndata as ad import numpy as np @@ -8,7 +9,7 @@ import torch.nn.functional as F from geomloss import SamplesLoss -from typing import Tuple +from typing import Dict, Optional, Tuple from .base import PerturbationModel from .decoders import FinetuneVCICountsDecoder @@ -20,9 +21,7 @@ class CombinedLoss(nn.Module): - """ - Combined Sinkhorn + Energy loss - """ + """Combined Sinkhorn + Energy loss.""" def __init__(self, sinkhorn_weight=0.001, energy_weight=1.0, blur=0.05): super().__init__() @@ -98,6 +97,28 @@ def extract_confidence_prediction(self, transformer_output: torch.Tensor) -> Tup return main_output, confidence_pred +class HillGate(nn.Module): + """ + Monotone, saturating gate w(d) for dose d that multiplies the residual. + For each hidden unit h: + w_h(d) = softplus(Emax_h) * (d/EC50_h)^{n_h} / (1 + (d/EC50_h)^{n_h}) + The forward expects log10-dose with shape [B,S,1] (standardized is fine). + """ + def __init__(self, hidden_dim: int): + super().__init__() + self.log_ec50 = nn.Parameter(torch.zeros(hidden_dim)) # EC50 = exp(log_ec50) > 0 + self.emax = nn.Parameter(torch.ones(hidden_dim)) # softplus(emax) > 0 + self.hill = nn.Parameter(torch.ones(hidden_dim)) # softplus(hill) > 0 + + def forward(self, log10_d: torch.Tensor) -> torch.Tensor: + # Convert log10-dose to linear space and broadcast to hidden dim + d = (10.0 ** log10_d).clamp_min(1e-12) # [B,S,1] + n = F.softplus(self.hill).view(1, 1, -1) # [1,1,H] + emax = F.softplus(self.emax).view(1, 1, -1) # [1,1,H] + ec50 = self.log_ec50.exp().view(1, 1, -1) # [1,1,H] + ratio_pow = (d / ec50).pow(n) # [B,S,H] + w = emax * ratio_pow / (1.0 + ratio_pow + 1e-12) # [B,S,H], in (0, emax] + return w class StateTransitionPerturbationModel(PerturbationModel): """ @@ -164,6 +185,8 @@ def __init__( self.distributional_loss = distributional_loss self.gene_dim = gene_dim + self.mmd_num_chunks = max(int(kwargs.get("mmd_num_chunks", 1)), 1) + self.randomize_mmd_chunks = bool(kwargs.get("randomize_mmd_chunks", False)) # Build the distributional loss from geomloss blur = kwargs.get("blur", 0.05) @@ -173,7 +196,7 @@ def __init__( elif loss_name == "mse": self.loss_fn = nn.MSELoss() elif loss_name == "se": - sinkhorn_weight = kwargs.get("sinkhorn_weight", 0.01) # 1/100 = 0.01 + sinkhorn_weight = kwargs.get("sinkhorn_weight", 0.01) energy_weight = kwargs.get("energy_weight", 1.0) self.loss_fn = CombinedLoss(sinkhorn_weight=sinkhorn_weight, energy_weight=energy_weight, blur=blur) elif loss_name == "sinkhorn": @@ -197,6 +220,44 @@ def __init__( ) self.batch_dim = batch_dim + # Optional batch predictor ablation: learns a single batch token added to every position, + # and adds an auxiliary per-token batch classification head + CE loss. + self.batch_predictor = bool(kwargs.get("batch_predictor", False)) + # If batch_encoder is enabled, disable batch_predictor per request + if self.batch_encoder is not None and self.batch_predictor: + logger.warning( + "Both model.kwargs.batch_encoder and model.kwargs.batch_predictor are True. " + "Disabling batch_predictor and proceeding with batch_encoder." + ) + self.batch_predictor = False + try: + # Keep hparams in sync if available + self.hparams["batch_predictor"] = False # type: ignore[index] + except Exception: + pass + + self.batch_predictor_weight = float(kwargs.get("batch_predictor_weight", 0.1)) + self.batch_predictor_num_classes: Optional[int] = batch_dim if self.batch_predictor else None + if self.batch_predictor: + if self.batch_predictor_num_classes is None: + raise ValueError("batch_predictor=True requires a valid `batch_dim` (number of batch classes).") + # A single learnable batch token that is added to each position + self.batch_token = nn.Parameter(torch.randn(1, 1, self.hidden_dim)) + # Simple per-token classifier from transformer hidden to batch classes + self.batch_classifier = build_mlp( + in_dim=self.hidden_dim, + out_dim=self.batch_predictor_num_classes, + hidden_dim=self.hidden_dim, + n_layers=4, + dropout=self.dropout, + activation=self.activation_class, + ) + else: + self.batch_token = None + self.batch_classifier = None + # Internal cache for last token features (B, S, H) from transformer for aux loss + self._token_features: Optional[torch.Tensor] = None + # if the model is outputting to counts space, apply relu # otherwise its in embedding space and we don't want to is_gene_space = kwargs["embed_key"] == "X_hvg" or kwargs["embed_key"] is None @@ -250,6 +311,36 @@ def __init__( if kwargs.get("confidence_token", False): self.confidence_token = ConfidenceToken(hidden_dim=self.hidden_dim, dropout=self.dropout) self.confidence_loss_fn = nn.MSELoss() + self.confidence_target_scale = float(kwargs.get("confidence_target_scale", 10.0)) + self.confidence_weight = float(kwargs.get("confidence_weight", 0.01)) + else: + self.confidence_target_scale = None + self.confidence_weight = 0.0 + + self.use_dosage_encoder = bool(kwargs.get("dosage", False)) + self.dosage_encoder = nn.Linear(1, self.hidden_dim) if self.use_dosage_encoder else None + self._warned_missing_dosage = False + # Feature flag: pharmacologically-informed dose handling + self.use_hill_prior = bool(kwargs.get("hill_prior", False)) + if self.use_hill_prior: + # Running stats for standardized log10-dose + self.register_buffer("dose_mean", torch.tensor(0.0)) + self.register_buffer("dose_std", torch.tensor(1.0)) + self.dose_momentum = float(kwargs.get("dose_momentum", 0.01)) + # Strength of FiLM modulation + self.dose_strength = nn.Parameter(torch.tensor(float(kwargs.get("dose_strength_init", 1.0)))) + # FiLM network on standardized log10-dose + self.dose_film = nn.Sequential( + nn.Linear(1, 128), + nn.SiLU(), + nn.Linear(128, 2 * self.hidden_dim), + ) + # Hill/Emax gate on residual + self.hill_gate = HillGate(self.hidden_dim) + # Small curvature penalty across doses of the same drug + self.dose_smooth_weight = float(kwargs.get("dose_smooth_weight", 0.01)) + else: + self.dose_smooth_weight = 0.0 # Backward-compat: accept legacy key `freeze_pert` self.freeze_pert_backbone = kwargs.get("freeze_pert_backbone", kwargs.get("freeze_pert", False)) @@ -274,32 +365,22 @@ def __init__( control_pert = kwargs.get("control_pert", "non-targeting") if kwargs.get("finetune_vci_decoder", False): # TODO: This will go very soon - gene_names = [] - - if output_space == "gene": - # hvg's but for which dataset? - if "DMSO_TF" in control_pert: - gene_names = np.load( - "/large_storage/ctc/userspace/aadduri/datasets/tahoe_19k_to_2k_names.npy", allow_pickle=True - ) - elif "non-targeting" in control_pert: - temp = ad.read_h5ad("/large_storage/ctc/userspace/aadduri/datasets/hvg/replogle/jurkat.h5") - # gene_names = temp.var.index.values - else: - assert output_space == "all" - if "DMSO_TF" in control_pert: - gene_names = np.load( - "/large_storage/ctc/userspace/aadduri/datasets/tahoe_19k_names.npy", allow_pickle=True - ) - elif "non-targeting" in control_pert: - # temp = ad.read_h5ad('/scratch/ctc/ML/vci/paper_replogle/jurkat.h5') - # gene_names = temp.var.index.values - temp = ad.read_h5ad("/large_storage/ctc/userspace/aadduri/cross_dataset/replogle/jurkat.h5") - gene_names = temp.var.index.values + # Prefer the gene names supplied by the data module (aligned to training output) + gene_names = self.gene_names + if gene_names is None: + raise ValueError( + "finetune_vci_decoder=True but model.gene_names is None. " + "Please provide gene_names via data module var_dims." + ) + n_genes = len(gene_names) + logger.info( + f"Initializing FinetuneVCICountsDecoder with {n_genes} genes (output_space={output_space}; " + + ("HVG subset" if output_space == "gene" else "all genes") + + ")" + ) self.gene_decoder = FinetuneVCICountsDecoder( genes=gene_names, - # latent_dim=self.output_dim + (self.batch_dim or 0), ) print(self) @@ -398,6 +479,27 @@ def forward(self, batch: dict, padded=True) -> torch.Tensor: combined_input = pert_embedding + control_cells # Shape: [B, S, hidden_dim] seq_input = combined_input # Shape: [B, S, hidden_dim] + logd_norm: Optional[torch.Tensor] = None + if self.use_dosage_encoder: + dosage_tensor = self._prepare_dosage_tensor(batch, seq_input.device, pert.shape[:2]) + if dosage_tensor is not None: + if self.use_hill_prior: + logd = torch.log10(dosage_tensor.clamp_min(1e-9)) # [B,S,1] + if self.training: + with torch.no_grad(): + bmean = logd.mean() + bstd = logd.std(unbiased=False).clamp_min(1e-6) + self.dose_mean = (1 - self.dose_momentum) * self.dose_mean + self.dose_momentum * bmean + self.dose_std = (1 - self.dose_momentum) * self.dose_std + self.dose_momentum * bstd + logd_norm = (logd - self.dose_mean) / (self.dose_std + 1e-6) + film_params = self.dose_film(logd_norm) + gamma, beta = film_params.chunk(2, dim=-1) + gamma = F.softplus(gamma) + seq_input = (1 + self.dose_strength * (gamma - 1)) * seq_input + self.dose_strength * beta + elif self.dosage_encoder is not None: + dosage_features = self.dosage_encoder(torch.log1p(dosage_tensor)) + seq_input = seq_input + dosage_features + if self.batch_encoder is not None: # Extract batch indices (assume they are integers or convert from one-hot) batch_indices = batch["batch"] @@ -431,11 +533,11 @@ def forward(self, batch: dict, padded=True) -> torch.Tensor: if self.hparams.get("mask_attn", False): batch_size, seq_length, _ = seq_input.shape device = seq_input.device - self.transformer_backbone._attn_implementation = "eager" # pyright: ignore[reportAttributeAccessIssue, reportArgumentType] + self.transformer_backbone._attn_implementation = "eager" # pyright: ignore[reportAttributeAccessIssue, reportArgumentType] # create a [1,1,S,S] mask (now S+1 if confidence token is used) base = torch.eye(seq_length, device=device, dtype=torch.bool).view(1, 1, seq_length, seq_length) - + # Get number of attention heads from model config num_heads = self.transformer_backbone.config.num_attention_heads @@ -471,6 +573,12 @@ def forward(self, batch: dict, padded=True) -> torch.Tensor: res_pred = transformer_output self._batch_token_cache = None + # Apply a monotone, saturating gate on the residual if enabled + if self.use_hill_prior and logd_norm is not None: + res_pred = res_pred * self.hill_gate(logd_norm) # [B,S,H]×[B,S,H] + # Cache token features for auxiliary batch prediction loss (B, S, H) + self._token_features = res_pred + # add to basal if predicting residual if self.predict_residual and self.output_space == "all": # Project control_cells to hidden_dim space to match res_pred @@ -485,8 +593,6 @@ def forward(self, batch: dict, padded=True) -> torch.Tensor: # apply relu if specified and we output to HVG space is_gene_space = self.hparams["embed_key"] == "X_hvg" or self.hparams["embed_key"] is None - # logger.info(f"DEBUG: is_gene_space: {is_gene_space}") - # logger.info(f"DEBUG: self.gene_decoder: {self.gene_decoder}") if is_gene_space or self.gene_decoder is None: out_pred = self.relu(out_pred) @@ -497,6 +603,179 @@ def forward(self, batch: dict, padded=True) -> torch.Tensor: else: return output + def _prepare_dosage_tensor( + self, batch: Dict[str, torch.Tensor], device: torch.device, shape: Tuple[int, int] + ) -> Optional[torch.Tensor]: + """Return dosage tensor shaped for broadcasting or None if unavailable.""" + + if not self.use_dosage_encoder: + return None + + dosage_values = batch.get("pert_dosage") + + if dosage_values is not None: + if torch.is_tensor(dosage_values): + dosage_tensor = dosage_values.to(device=device, dtype=torch.float32) + else: + dosage_tensor = torch.as_tensor(dosage_values, device=device, dtype=torch.float32) + else: + pert_names = batch.get("pert_name") + if pert_names is None: + if not self._warned_missing_dosage: + logger.warning( + "Dosage encoder enabled but no dosage information found in batch; skipping dosage term." + ) + self._warned_missing_dosage = True + return None + + if isinstance(pert_names, torch.Tensor): + pert_names = pert_names.tolist() + if not isinstance(pert_names, (list, tuple)): + pert_names = [pert_names] + + dosage_list = [self._parse_dosage_from_name(name) for name in pert_names] + dosage_tensor = torch.tensor(dosage_list, device=device, dtype=torch.float32) + + if not self._warned_missing_dosage: + logger.warning( + "Falling back to parsing dosage from perturbation names; consider providing 'pert_dosage'." + ) + self._warned_missing_dosage = True + + dosage_tensor = dosage_tensor.flatten() + expected_elems = shape[0] * shape[1] + + if dosage_tensor.numel() == expected_elems: + return dosage_tensor.reshape(shape[0], shape[1], 1) + + if dosage_tensor.numel() == shape[0] and shape[1] > 0: + return dosage_tensor.view(shape[0], 1, 1).expand(shape[0], shape[1], 1) + + if shape[0] == 1 and dosage_tensor.numel() == shape[1]: + return dosage_tensor.view(1, shape[1], 1) + + logger.warning( + "Dosage tensor has %d elements but expected either %d or %d; skipping dosage term for this batch.", + dosage_tensor.numel(), + expected_elems, + shape[0], + ) + return None + + @staticmethod + def _parse_dosage_from_name(name: Optional[str]) -> float: + """Extract dosage value from perturbation name string.""" + + if not isinstance(name, str): + return 0.0 + + try: + parsed = ast.literal_eval(name) + except (ValueError, SyntaxError): + return 0.0 + + try: + if isinstance(parsed, (list, tuple)) and len(parsed) > 0: + first_entry = parsed[0] + if isinstance(first_entry, (list, tuple)) and len(first_entry) > 1: + return float(first_entry[1]) + except (TypeError, ValueError): + pass + + return 0.0 + + @staticmethod + def _parse_drug_from_name(name: Optional[str]) -> str: + """Best-effort extraction of the base drug identifier from a perturbation name.""" + if not isinstance(name, str): + return "unknown" + try: + parsed = ast.literal_eval(name) + if isinstance(parsed, (list, tuple)) and len(parsed) > 0: + first_entry = parsed[0] + if isinstance(first_entry, (list, tuple)) and len(first_entry) > 0: + return str(first_entry[0]) + except (ValueError, SyntaxError): + pass + # Fallback: strip common separators if present + return name.split("@")[0].split("|")[0] + + def _dose_smoothness_loss(self, batch: Dict[str, torch.Tensor], pred: torch.Tensor, padded: bool) -> torch.Tensor: + """ + Encourage a smooth (low curvature) trajectory across log-dose for the same drug within the minibatch. + 'pred' is [B,S,D] (set of cells per dose). We reduce over S (cells) first. + Requires 'pert_dosage' (or parseable names) to be present; otherwise returns 0. + """ + if not self.use_hill_prior or self.dose_smooth_weight <= 0.0: + return pred.new_tensor(0.0) + + B, S, D = pred.shape + device = pred.device + + # One dose per sentence + dose = self._prepare_dosage_tensor(batch, device, (B, S)) + if dose is None: + return pred.new_tensor(0.0) + dose_per_sentence = dose[:, 0, 0] # [B] + + # One drug label per sentence (best effort) + groups = None + names = batch.get("pert_name", None) + if names is not None: + if isinstance(names, torch.Tensor): + names_list = names.reshape(-1).tolist() + else: + names_list = list(names) + if len(names_list) >= B * S: + per_sentence = [names_list[i * S] for i in range(B)] + elif len(names_list) >= B: + per_sentence = [names_list[i] for i in range(B)] + else: + per_sentence = None + if per_sentence is not None: + groups = [self._parse_drug_from_name(n) for n in per_sentence] + if groups is None: + groups = ["__all__"] * B # fall back to one pooled group + + # Reduce each sentence to a set-level vector + set_pred = pred.mean(dim=1) # [B, D] + + buckets: Dict[str, list] = {} + for i, g in enumerate(groups): + buckets.setdefault(g, []).append((dose_per_sentence[i].item(), i)) + + losses = [] + for _, lst in buckets.items(): + if len(lst) < 3: + continue + lst.sort(key=lambda x: x[0]) # ascending dose + idx = [i for _, i in lst] + series = set_pred[idx] # [Nd, D] + second = series[2:] - 2 * series[1:-1] + series[:-2] + losses.append((second ** 2).mean()) + + if not losses: + return pred.new_tensor(0.0) + return torch.stack(losses).mean() + + def _compute_distribution_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """Apply the primary distributional loss, optionally chunking feature dimensions for SamplesLoss.""" + + if isinstance(self.loss_fn, SamplesLoss) and self.mmd_num_chunks > 1: + feature_dim = pred.shape[-1] + num_chunks = min(self.mmd_num_chunks, feature_dim) + if num_chunks > 1 and feature_dim > 0: + if self.randomize_mmd_chunks and self.training: + perm = torch.randperm(feature_dim, device=pred.device) + pred = pred.index_select(-1, perm) + target = target.index_select(-1, perm) + pred_chunks = torch.chunk(pred, num_chunks, dim=-1) + target_chunks = torch.chunk(target, num_chunks, dim=-1) + chunk_losses = [self.loss_fn(p_chunk, t_chunk) for p_chunk, t_chunk in zip(pred_chunks, target_chunks)] + return torch.stack(chunk_losses, dim=0).nanmean(dim=0) + + return self.loss_fn(pred, target) + def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int, padded=True) -> torch.Tensor: """Training step logic for both main model and decoder.""" # Get model predictions (in latent space) @@ -515,7 +794,8 @@ def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int, padded=T pred = pred.reshape(1, -1, self.output_dim) target = target.reshape(1, -1, self.output_dim) - main_loss = self.loss_fn(pred, target).nanmean() + per_set_main_losses = self._compute_distribution_loss(pred, target) + main_loss = torch.nanmean(per_set_main_losses) self.log("train_loss", main_loss) # Log individual loss components if using combined loss @@ -579,6 +859,7 @@ def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int, padded=T self.log("train/batch_token_loss", ce_loss) total_loss = total_loss + self.batch_token_weight * ce_loss + # Auxiliary batch prediction loss (per token), if enabled if self.gene_decoder is not None and "pert_cell_counts" in batch: gene_targets = batch["pert_cell_counts"] # Train decoder to map latent predictions to gene space @@ -603,7 +884,8 @@ def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int, padded=T else: gene_targets = gene_targets.reshape(1, -1, self.gene_decoder.gene_dim()) - decoder_loss = self.loss_fn(pert_cell_counts_preds, gene_targets).mean() + decoder_per_set = self._compute_distribution_loss(pert_cell_counts_preds, gene_targets) + decoder_loss = decoder_per_set.mean() # Log decoder loss self.log("decoder_loss", decoder_loss) @@ -611,25 +893,18 @@ def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int, padded=T total_loss = total_loss + self.decoder_loss_weight * decoder_loss if confidence_pred is not None: - # Detach main loss to prevent gradients flowing through it - loss_target = total_loss.detach().clone().unsqueeze(0) * 10 - - # Ensure proper shapes for confidence loss computation - if confidence_pred.dim() == 2: # [B, 1] - loss_target = loss_target.unsqueeze(0).expand(confidence_pred.size(0), 1) - else: # confidence_pred is [B,] - loss_target = loss_target.unsqueeze(0).expand(confidence_pred.size(0)) - - # Compute confidence loss - confidence_loss = self.confidence_loss_fn(confidence_pred.squeeze(), loss_target.squeeze()) + confidence_pred_vals = confidence_pred + if confidence_pred_vals.dim() > 1: + confidence_pred_vals = confidence_pred_vals.squeeze(-1) + confidence_targets = per_set_main_losses.detach() + if self.confidence_target_scale is not None: + confidence_targets = confidence_targets * self.confidence_target_scale + confidence_targets = confidence_targets.to(confidence_pred_vals.device) + + confidence_loss = self.confidence_weight * self.confidence_loss_fn(confidence_pred_vals, confidence_targets) self.log("train/confidence_loss", confidence_loss) - self.log("train/actual_loss", loss_target.mean()) - - # Add to total loss with weighting - confidence_weight = 0.1 # You can make this configurable - total_loss = total_loss + confidence_weight * confidence_loss + self.log("train/actual_loss", confidence_targets.mean()) - # Add to total loss total_loss = total_loss + confidence_loss if self.regularization > 0.0: @@ -645,6 +920,12 @@ def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int, padded=T # Add regularization to total loss total_loss = total_loss + self.regularization * l1_loss + if self.use_hill_prior and self.dose_smooth_weight > 0.0: + with torch.no_grad() if not self.training else torch.enable_grad(): + smooth_loss = self._dose_smoothness_loss(batch, pred, padded=padded) + self.log("train/dose_smooth_loss", smooth_loss) + total_loss = total_loss + self.dose_smooth_weight * smooth_loss + return total_loss def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> None: @@ -658,7 +939,8 @@ def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Non target = batch["pert_cell_emb"] target = target.reshape(-1, self.cell_sentence_len, self.output_dim) - loss = self.loss_fn(pred, target).mean() + per_set_main_losses = self._compute_distribution_loss(pred, target) + loss = torch.nanmean(per_set_main_losses) self.log("val_loss", loss) # Log individual loss components if using combined loss @@ -685,26 +967,31 @@ def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Non -1, self.cell_sentence_len, self.gene_decoder.gene_dim() ) gene_targets = gene_targets.reshape(-1, self.cell_sentence_len, self.gene_decoder.gene_dim()) - decoder_loss = self.loss_fn(pert_cell_counts_preds, gene_targets).mean() + decoder_per_set = self._compute_distribution_loss(pert_cell_counts_preds, gene_targets) + decoder_loss = decoder_per_set.mean() # Log the validation metric self.log("val/decoder_loss", decoder_loss) loss = loss + self.decoder_loss_weight * decoder_loss if confidence_pred is not None: - # Detach main loss to prevent gradients flowing through it - loss_target = loss.detach().clone() * 10 - - # Ensure proper shapes for confidence loss computation - if confidence_pred.dim() == 2: # [B, 1] - loss_target = loss_target.unsqueeze(0).expand(confidence_pred.size(0), 1) - else: # confidence_pred is [B,] - loss_target = loss_target.unsqueeze(0).expand(confidence_pred.size(0)) - - # Compute confidence loss - confidence_loss = self.confidence_loss_fn(confidence_pred.squeeze(), loss_target.squeeze()) + confidence_pred_vals = confidence_pred + if confidence_pred_vals.dim() > 1: + confidence_pred_vals = confidence_pred_vals.squeeze(-1) + confidence_targets = per_set_main_losses.detach() + if self.confidence_target_scale is not None: + confidence_targets = confidence_targets * self.confidence_target_scale + confidence_targets = confidence_targets.to(confidence_pred_vals.device) + + confidence_loss = self.confidence_weight * self.confidence_loss_fn(confidence_pred_vals, confidence_targets) self.log("val/confidence_loss", confidence_loss) - self.log("val/actual_loss", loss_target.mean()) + self.log("val/actual_loss", confidence_targets.mean()) + + # Validation analogue of curvature penalty + if self.use_hill_prior and self.dose_smooth_weight > 0.0: + smooth_loss = self._dose_smoothness_loss(batch, pred, padded=True) + self.log("val/dose_smooth_loss", smooth_loss) + loss = loss + self.dose_smooth_weight * smooth_loss return {"loss": loss, "predictions": pred} @@ -717,21 +1004,20 @@ def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> None: target = batch["pert_cell_emb"] pred = pred.reshape(1, -1, self.output_dim) target = target.reshape(1, -1, self.output_dim) - loss = self.loss_fn(pred, target).mean() + per_set_main_losses = self._compute_distribution_loss(pred, target) + loss = torch.nanmean(per_set_main_losses) self.log("test_loss", loss) if confidence_pred is not None: - # Detach main loss to prevent gradients flowing through it - loss_target = loss.detach().clone() * 10.0 - - # Ensure proper shapes for confidence loss computation - if confidence_pred.dim() == 2: # [B, 1] - loss_target = loss_target.unsqueeze(0).expand(confidence_pred.size(0), 1) - else: # confidence_pred is [B,] - loss_target = loss_target.unsqueeze(0).expand(confidence_pred.size(0)) - - # Compute confidence loss - confidence_loss = self.confidence_loss_fn(confidence_pred.squeeze(), loss_target.squeeze()) + confidence_pred_vals = confidence_pred + if confidence_pred_vals.dim() > 1: + confidence_pred_vals = confidence_pred_vals.squeeze(-1) + confidence_targets = per_set_main_losses.detach() + if self.confidence_target_scale is not None: + confidence_targets = confidence_targets * self.confidence_target_scale + confidence_targets = confidence_targets.to(confidence_pred_vals.device) + + confidence_loss = self.confidence_weight * self.confidence_loss_fn(confidence_pred_vals, confidence_targets) self.log("test/confidence_loss", confidence_loss) def predict_step(self, batch, batch_idx, padded=True, **kwargs): diff --git a/src/state/tx/models/utils.py b/src/state/tx/models/utils.py index 47185a83..fdfdf577 100644 --- a/src/state/tx/models/utils.py +++ b/src/state/tx/models/utils.py @@ -160,16 +160,12 @@ def apply_lora(model: PreTrainedModel, backbone_key: str, lora_cfg: dict | None) return model if LoraConfig is None or get_peft_model is None: - raise ImportError( - "peft is not installed but `lora.enable` is True. Add `peft` to dependencies." - ) + raise ImportError("peft is not installed but `lora.enable` is True. Add `peft` to dependencies.") target = lora_cfg.get("target", "auto") adapt_mlp = bool(lora_cfg.get("adapt_mlp", False)) target_modules = ( - lora_cfg.get("target_modules") - if target != "auto" - else _default_lora_targets(backbone_key, adapt_mlp) + lora_cfg.get("target_modules") if target != "auto" else _default_lora_targets(backbone_key, adapt_mlp) ) # Build PEFT LoRA config @@ -230,13 +226,13 @@ def __init__(self, config: LlamaConfig): self.rotary_emb = NoRoPE( head_dim=config.head_dim, ) - + # Explicitly disable causal attention self.config.is_causal = False # force every layer to be non-causal for layer in self.layers: if hasattr(layer, "self_attn"): - layer.self_attn.is_causal = False # pyright: ignore[reportAttributeAccessIssue, reportArgumentType] + layer.self_attn.is_causal = False # pyright: ignore[reportAttributeAccessIssue, reportArgumentType] def _update_causal_mask( self, @@ -265,7 +261,7 @@ def forward( **flash_attn_kwargs, ): flash_attn_kwargs["is_causal"] = False - + # If no attention_mask is provided, create an all-ones mask (no masking) # This ensures bidirectional attention with correct device/dtype if attention_mask is None: diff --git a/src/state/tx/utils/__init__.py b/src/state/tx/utils/__init__.py index 7a35c853..e4fcf9ac 100644 --- a/src/state/tx/utils/__init__.py +++ b/src/state/tx/utils/__init__.py @@ -127,7 +127,7 @@ def get_loggers( return loggers -def get_checkpoint_callbacks(output_dir: str, name: str, val_freq: int, ckpt_every_n_steps: int): +def get_checkpoint_callbacks(output_dir: str, name: str, val_freq: int, _ckpt_every_n_steps: int): """ Create checkpoint callbacks based on validation frequency. @@ -136,28 +136,18 @@ def get_checkpoint_callbacks(output_dir: str, name: str, val_freq: int, ckpt_eve checkpoint_dir = join(output_dir, name, "checkpoints") callbacks = [] - # Save best checkpoint based on validation loss + # Save only the best checkpoint (by val_loss) plus the latest checkpoint best_ckpt = ModelCheckpoint( dirpath=checkpoint_dir, - filename="step={step}-val_loss={val_loss:.4f}", - save_last="link", # Will create last.ckpt symlink to best checkpoint + filename="best", + save_last=True, monitor="val_loss", mode="min", - save_top_k=1, # Only keep the best checkpoint + save_top_k=1, every_n_train_steps=val_freq, ) callbacks.append(best_ckpt) - # Also save periodic checkpoints (without affecting the "last" symlink) - periodic_ckpt = ModelCheckpoint( - dirpath=checkpoint_dir, - filename="{step}", - save_last=False, # Don't create/update symlink - every_n_train_steps=ckpt_every_n_steps, - save_top_k=-1, # Keep all periodic checkpoints - ) - callbacks.append(periodic_ckpt) - return callbacks diff --git a/tests/test_bidirectional_models.py b/tests/test_bidirectional_models.py index 5b41cc9a..9a0576be 100644 --- a/tests/test_bidirectional_models.py +++ b/tests/test_bidirectional_models.py @@ -24,10 +24,10 @@ def small_llama_config(): def test_llama_bidirectional_config_is_non_causal(small_llama_config): """Test that LlamaBidirectionalModel sets is_causal to False.""" model = LlamaBidirectionalModel(small_llama_config) - + # Check that the model config is non-causal assert model.config.is_causal is False - + # Check that all attention layers are non-causal for layer in model.layers: if hasattr(layer, "self_attn"): @@ -37,13 +37,13 @@ def test_llama_bidirectional_config_is_non_causal(small_llama_config): def test_llama_bidirectional_update_causal_mask_returns_none(small_llama_config): """Test that _update_causal_mask returns None, disabling causal masking.""" model = LlamaBidirectionalModel(small_llama_config) - + # Create dummy inputs batch_size, seq_len = 2, 8 attention_mask = torch.ones(batch_size, seq_len) input_tensor = torch.randn(batch_size, seq_len, small_llama_config.hidden_size) cache_position = torch.arange(seq_len) - + # Call _update_causal_mask result = model._update_causal_mask( attention_mask=attention_mask, @@ -52,7 +52,7 @@ def test_llama_bidirectional_update_causal_mask_returns_none(small_llama_config) past_key_values=None, output_attentions=False, ) - + # Should return None (no causal masking) assert result is None @@ -98,169 +98,163 @@ def test_get_transformer_backbone_llama_bidirectional_flag(): def test_llama_bidirectional_attention_vs_causal(small_llama_config): """ Test that bidirectional attention produces different outputs than causal attention. - + This is the key test: in bidirectional attention, later tokens should affect earlier token representations, which doesn't happen in causal attention. """ torch.manual_seed(42) - + # Create both bidirectional and standard (causal) models bidirectional_model = LlamaBidirectionalModel(small_llama_config) causal_model = LlamaModel(small_llama_config) - + # Copy weights from bidirectional to causal to ensure same initialization causal_model.load_state_dict(bidirectional_model.state_dict(), strict=False) - + # Create input batch_size, seq_len = 2, 8 inputs_embeds = torch.randn(batch_size, seq_len, small_llama_config.hidden_size) - + # Set models to eval mode bidirectional_model.eval() causal_model.eval() - + with torch.no_grad(): # Get outputs from bidirectional model bidirectional_output = bidirectional_model(inputs_embeds=inputs_embeds) - + # Get outputs from causal model causal_output = causal_model(inputs_embeds=inputs_embeds) - + # The outputs should be different because bidirectional allows all tokens # to attend to each other, while causal only allows attending to past tokens - assert not torch.allclose( - bidirectional_output.last_hidden_state, - causal_output.last_hidden_state, - atol=1e-5 - ), "Bidirectional and causal outputs should differ" + assert not torch.allclose(bidirectional_output.last_hidden_state, causal_output.last_hidden_state, atol=1e-5), ( + "Bidirectional and causal outputs should differ" + ) def test_llama_bidirectional_future_tokens_affect_past(small_llama_config): """ Test that future tokens affect past token representations in bidirectional model. - + This is the core property of bidirectional attention: changing a future token should change the representation of past tokens. """ torch.manual_seed(42) - + model = LlamaBidirectionalModel(small_llama_config) model.eval() - + batch_size, seq_len = 1, 6 hidden_size = small_llama_config.hidden_size - + # Create two inputs that differ only in the last token inputs_embeds_1 = torch.randn(batch_size, seq_len, hidden_size) inputs_embeds_2 = inputs_embeds_1.clone() - + # Modify only the last token embedding in the second input inputs_embeds_2[:, -1, :] = torch.randn(batch_size, hidden_size) - + with torch.no_grad(): output_1 = model(inputs_embeds=inputs_embeds_1) output_2 = model(inputs_embeds=inputs_embeds_2) - + # Check that the first tokens' representations differ between the two inputs # This demonstrates that the last token (future) affects the first token (past) first_token_repr_1 = output_1.last_hidden_state[:, 0, :] first_token_repr_2 = output_2.last_hidden_state[:, 0, :] - - assert not torch.allclose(first_token_repr_1, first_token_repr_2, atol=1e-5), \ + + assert not torch.allclose(first_token_repr_1, first_token_repr_2, atol=1e-5), ( "First token representation should change when last token changes (bidirectional attention)" + ) def test_llama_bidirectional_first_token_differs_across_batch(small_llama_config): """ Test that first token representations differ across batch when sequences differ. - + This is a critical test for bidirectional attention: in causal attention, the first token can only attend to itself, so if all sequences have the same first token, they would produce identical first token representations. - + In bidirectional attention, the first token attends to all tokens in the sequence, so different sequences should produce different first token representations even when the first tokens themselves are identical. """ torch.manual_seed(42) - + model = LlamaBidirectionalModel(small_llama_config) model.eval() - + batch_size, seq_len = 4, 8 hidden_size = small_llama_config.hidden_size - + # Create a batch where ALL sequences have the SAME first token embedding # but DIFFERENT subsequent tokens inputs_embeds = torch.randn(batch_size, seq_len, hidden_size) - + # Make the first token identical across all sequences shared_first_token = torch.randn(1, hidden_size) inputs_embeds[:, 0, :] = shared_first_token - + with torch.no_grad(): output = model(inputs_embeds=inputs_embeds) - + # Extract first token representations for all sequences first_token_reprs = output.last_hidden_state[:, 0, :] # Shape: (batch_size, hidden_size) - + # In bidirectional attention, these should all be DIFFERENT # because each attends to different subsequent tokens # Check that not all first tokens are the same for i in range(batch_size): for j in range(i + 1, batch_size): - assert not torch.allclose( - first_token_reprs[i], - first_token_reprs[j], - atol=1e-5 - ), f"First token representations for sequences {i} and {j} should differ in bidirectional attention" - + assert not torch.allclose(first_token_reprs[i], first_token_reprs[j], atol=1e-5), ( + f"First token representations for sequences {i} and {j} should differ in bidirectional attention" + ) + # Additional check: variance across batch should be substantial variance_per_dim = torch.var(first_token_reprs, dim=0) mean_variance = variance_per_dim.mean() - assert mean_variance > 1e-4, \ + assert mean_variance > 1e-4, ( "First token representations should have substantial variance across batch in bidirectional attention" + ) def test_llama_bidirectional_symmetric_position_influence(small_llama_config): """ - Test that in bidirectional attention, position i affects position j + Test that in bidirectional attention, position i affects position j as much as position j affects position i (roughly symmetric). """ torch.manual_seed(42) - + model = LlamaBidirectionalModel(small_llama_config) model.eval() - + batch_size, seq_len = 1, 4 hidden_size = small_llama_config.hidden_size - + # Create base input base_input = torch.randn(batch_size, seq_len, hidden_size) - + # Modify position 0 and see effect on position 2 input_modify_0 = base_input.clone() input_modify_0[:, 0, :] = torch.randn(batch_size, hidden_size) - + # Modify position 2 and see effect on position 0 input_modify_2 = base_input.clone() input_modify_2[:, 2, :] = torch.randn(batch_size, hidden_size) - + with torch.no_grad(): output_base = model(inputs_embeds=base_input) output_modify_0 = model(inputs_embeds=input_modify_0) output_modify_2 = model(inputs_embeds=input_modify_2) - + # Calculate how much position 2 changes when position 0 changes - effect_0_on_2 = torch.norm( - output_modify_0.last_hidden_state[:, 2, :] - output_base.last_hidden_state[:, 2, :] - ) - + effect_0_on_2 = torch.norm(output_modify_0.last_hidden_state[:, 2, :] - output_base.last_hidden_state[:, 2, :]) + # Calculate how much position 0 changes when position 2 changes - effect_2_on_0 = torch.norm( - output_modify_2.last_hidden_state[:, 0, :] - output_base.last_hidden_state[:, 0, :] - ) - + effect_2_on_0 = torch.norm(output_modify_2.last_hidden_state[:, 0, :] - output_base.last_hidden_state[:, 0, :]) + # In bidirectional attention, these effects should both be non-zero # (demonstrating mutual influence, unlike in causal attention) assert effect_0_on_2 > 0.01, "Position 0 should affect position 2" @@ -271,13 +265,13 @@ def test_llama_bidirectional_forward_with_input_ids(small_llama_config): """Test that forward pass works with input_ids.""" model = LlamaBidirectionalModel(small_llama_config) model.eval() - + batch_size, seq_len = 2, 10 input_ids = torch.randint(0, small_llama_config.vocab_size, (batch_size, seq_len)) - + with torch.no_grad(): output = model(input_ids=input_ids) - + # Check output shape assert output.last_hidden_state.shape == (batch_size, seq_len, small_llama_config.hidden_size) @@ -286,18 +280,18 @@ def test_llama_bidirectional_forward_with_attention_mask(small_llama_config): """Test that forward pass respects attention mask for padding.""" model = LlamaBidirectionalModel(small_llama_config) model.eval() - + batch_size, seq_len = 2, 10 hidden_size = small_llama_config.hidden_size inputs_embeds = torch.randn(batch_size, seq_len, hidden_size) - + # Create attention mask: first sequence has padding at positions 8-9 attention_mask = torch.ones(batch_size, seq_len) attention_mask[0, 8:] = 0 # Mask out last 2 positions for first sequence - + with torch.no_grad(): output = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask) - + # Check that output shape is correct assert output.last_hidden_state.shape == (batch_size, seq_len, hidden_size) @@ -306,24 +300,24 @@ def test_llama_bidirectional_is_causal_false_in_forward(small_llama_config): """Test that is_causal=False is passed in flash_attn_kwargs during forward.""" model = LlamaBidirectionalModel(small_llama_config) model.eval() - + batch_size, seq_len = 1, 8 inputs_embeds = torch.randn(batch_size, seq_len, small_llama_config.hidden_size) - + # Monkey-patch the parent's forward to capture flash_attn_kwargs original_forward = LlamaModel.forward captured_kwargs = {} - + def capture_forward(self, **kwargs): captured_kwargs.update(kwargs) return original_forward(self, **kwargs) - + LlamaModel.forward = capture_forward # type: ignore - + try: with torch.no_grad(): model(inputs_embeds=inputs_embeds) - + # Check that is_causal was set to False assert "is_causal" in captured_kwargs assert captured_kwargs["is_causal"] is False @@ -335,9 +329,8 @@ def capture_forward(self, **kwargs): def test_llama_bidirectional_no_rope(small_llama_config): """Test that NoRoPE is used instead of standard rotary embeddings.""" from state.tx.models.utils import NoRoPE - + model = LlamaBidirectionalModel(small_llama_config) - + # Check that rotary_emb is an instance of NoRoPE assert isinstance(model.rotary_emb, NoRoPE) -