From 721d37d2b32cb611a581a422d0be568d17799f0b Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Sat, 13 Sep 2025 17:23:45 -0700 Subject: [PATCH 01/38] transfer with hvg --- src/state/_cli/_tx/_predict.py | 5 ++++- src/state/_cli/_tx/_train.py | 4 +++- src/state/configs/model/state.yaml | 6 +++--- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/state/_cli/_tx/_predict.py b/src/state/_cli/_tx/_predict.py index a41192cc..b7191037 100644 --- a/src/state/_cli/_tx/_predict.py +++ b/src/state/_cli/_tx/_predict.py @@ -384,7 +384,10 @@ def load_config(cfg_path: str) -> dict: ) # Save the AnnData objects - results_dir = os.path.join(args.output_dir, "eval_" + os.path.basename(args.checkpoint)) + if args.eval_train_data: + results_dir = os.path.join(args.output_dir, "eval_train_" + os.path.basename(args.checkpoint)) + else: + results_dir = os.path.join(args.output_dir, "eval_" + os.path.basename(args.checkpoint)) os.makedirs(results_dir, exist_ok=True) adata_pred_path = os.path.join(results_dir, "adata_pred.h5ad") adata_real_path = os.path.join(results_dir, "adata_real.h5ad") diff --git a/src/state/_cli/_tx/_train.py b/src/state/_cli/_tx/_train.py index 5db050e8..98ef7665 100644 --- a/src/state/_cli/_tx/_train.py +++ b/src/state/_cli/_tx/_train.py @@ -332,7 +332,9 @@ def run_tx_train(cfg: DictConfig): pert_encoder_weight_key = "pert_encoder.0.weight" if pert_encoder_weight_key in checkpoint_state: checkpoint_pert_dim = checkpoint_state[pert_encoder_weight_key].shape[1] - if checkpoint_pert_dim != model.pert_dim: + + # if the cell embedding dim doesn't match, or if it was HVGs, rebuild for transfer learning + if checkpoint_pert_dim != model.pert_dim or cfg["data"]["kwargs"]["embed_key"] == "X_hvg": print( f"pert_encoder input dimension mismatch: model.pert_dim = {model.pert_dim} but checkpoint expects {checkpoint_pert_dim}. Overriding model's pert_dim and rebuilding pert_encoder." ) diff --git a/src/state/configs/model/state.yaml b/src/state/configs/model/state.yaml index e9b3e34d..9de7263b 100644 --- a/src/state/configs/model/state.yaml +++ b/src/state/configs/model/state.yaml @@ -5,7 +5,7 @@ device: cuda kwargs: cell_set_len: 512 blur: 0.05 - hidden_dim: 696 # hidden dimension going into the transformer backbone + hidden_dim: 768 # hidden dimension going into the transformer backbone loss: energy confidence_head: False n_encoder_layers: 1 @@ -28,11 +28,11 @@ kwargs: bidirectional_attention: false max_position_embeddings: ${model.kwargs.cell_set_len} hidden_size: ${model.kwargs.hidden_dim} - intermediate_size: 2784 + intermediate_size: 3072 num_hidden_layers: 8 num_attention_heads: 12 num_key_value_heads: 12 - head_dim: 58 + head_dim: 64 use_cache: false attention_dropout: 0.0 hidden_dropout: 0.0 From 0325c7fea05dad43aa3498dd4e28ee6dda9774c2 Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Sat, 13 Sep 2025 19:22:36 -0700 Subject: [PATCH 02/38] updated files to initial version with fine tuning the decoder --- src/state/emb/finetune_decoder.py | 51 ++++++++++++++++++++++++- src/state/tx/models/decoders.py | 44 ++++++++++++++++++--- src/state/tx/models/pseudobulk.py | 40 +++++++------------ src/state/tx/models/state_transition.py | 38 +++++++----------- 4 files changed, 116 insertions(+), 57 deletions(-) diff --git a/src/state/emb/finetune_decoder.py b/src/state/emb/finetune_decoder.py index 8fdaf819..cf081f55 100644 --- a/src/state/emb/finetune_decoder.py +++ b/src/state/emb/finetune_decoder.py @@ -72,6 +72,47 @@ def load_model(self, checkpoint): # Ensure the binary decoder is in training mode so gradients are enabled. self.model.binary_decoder.eval() + def _auto_detect_gene_column(self, adata): + """Auto-detect the gene column with highest overlap with protein embeddings. + + Returns None to indicate var.index, or a string column name in var. + """ + if self.protein_embeds is None: + log.warning("No protein embeddings available for auto-detection, using index") + return None + + protein_genes = set(self.protein_embeds.keys()) + best_column = None + best_overlap = 0 + + # Check index first + index_genes = set(getattr(adata.var, "index", [])) + overlap = len(protein_genes.intersection(index_genes)) + if overlap > best_overlap: + best_overlap = overlap + best_column = None # None means use index + + # Check all columns in var + for col in adata.var.columns: + try: + col_vals = adata.var[col].dropna().astype(str) + except Exception: + continue + col_genes = set(col_vals) + overlap = len(protein_genes.intersection(col_genes)) + if overlap > best_overlap: + best_overlap = overlap + best_column = col + + return best_column + + def genes_from_adata(self, adata): + """Return list of gene names from AnnData using auto-detected column/index.""" + col = self._auto_detect_gene_column(adata) + if col is None: + return list(map(str, adata.var.index.values)) + return list(adata.var[col].astype(str).values) + def get_gene_embedding(self, genes): """ Get embeddings for a list of genes, with caching to avoid recomputation. @@ -93,8 +134,16 @@ def get_gene_embedding(self, genes): if cache_key in self.cached_gene_embeddings: return self.cached_gene_embeddings[cache_key] + # Strict validation: ensure all genes are present in pretrained all_embeddings + missing = [g for g in genes if g not in self.protein_embeds] + if len(missing) > 0: + raise ValueError( + f"Finetune.get_gene_embedding: {len(missing)} gene(s) not found in pretrained all_embeddings: " + f"{missing[:10]}{' ...' if len(missing) > 10 else ''}." + ) + # Compute gene embeddings - protein_embeds = [self.protein_embeds[x] if x in self.protein_embeds else torch.zeros(5120) for x in genes] + protein_embeds = [self.protein_embeds[x] for x in genes] protein_embeds = torch.stack(protein_embeds).to(self.device) gene_embeds = self.model.gene_embedding_layer(protein_embeds) diff --git a/src/state/tx/models/decoders.py b/src/state/tx/models/decoders.py index b7caa741..0b1f5c22 100644 --- a/src/state/tx/models/decoders.py +++ b/src/state/tx/models/decoders.py @@ -12,7 +12,8 @@ class FinetuneVCICountsDecoder(nn.Module): def __init__( self, - genes, + genes=None, + adata=None, # model_loc="/large_storage/ctc/userspace/aadduri/vci/checkpoint/rda_tabular_counts_2048_new/step=950000.ckpt", # config="/large_storage/ctc/userspace/aadduri/vci/checkpoint/rda_tabular_counts_2048_new/tahoe_config.yaml", model_loc="/home/aadduri/vci_pretrain/vci_1.4.2.ckpt", @@ -24,12 +25,22 @@ def __init__( basal_residual=False, ): super().__init__() - self.genes = genes + # Initialize finetune helper and model self.model_loc = model_loc self.config = config self.finetune = Finetune(OmegaConf.load(self.config)) self.finetune.load_model(self.model_loc) - self.read_depth = nn.Parameter(torch.tensor(read_depth, dtype=torch.float), requires_grad=False) + # Resolve genes: prefer explicit list; else infer from anndata if provided + if genes is None and adata is not None: + try: + genes = self.finetune.genes_from_adata(adata) + except Exception as e: + raise ValueError(f"Failed to infer genes from AnnData: {e}") + if genes is None: + raise ValueError("FinetuneVCICountsDecoder requires 'genes' or 'adata' to derive gene names") + self.genes = genes + # Keep read_depth as a learnable parameter so decoded counts can adapt + self.read_depth = nn.Parameter(torch.tensor(read_depth, dtype=torch.float), requires_grad=True) self.basal_residual = basal_residual # layers = [ @@ -60,6 +71,27 @@ def __init__( for param in self.binary_decoder.parameters(): param.requires_grad = False + # Validate that all requested genes exist in the pretrained checkpoint's embeddings + pe = getattr(self.finetune, "protein_embeds", {}) + present = [g for g in self.genes if g in pe] + missing = [g for g in self.genes if g not in pe] + if len(missing) > 0: + total_req = len(self.genes) + total_pe = len(pe) if hasattr(pe, "__len__") else -1 + found = total_req - len(missing) + miss_pct = (len(missing) / total_req) if total_req > 0 else 1.0 + logger.error( + f"FinetuneVCICountsDecoder gene check: requested={total_req}, found={found}, missing={len(missing)} ({miss_pct:.1%}), all_embeddings_size={total_pe}" + ) + logger.error( + f"Examples missing: {missing[:10]}{' ...' if len(missing) > 10 else ''}; examples present: {present[:10]}{' ...' if len(present) > 10 else ''}" + ) + raise ValueError( + f"FinetuneVCICountsDecoder: {len(missing)} gene(s) not found in pretrained all_embeddings. " + f"Requested={total_req}, Found={found}, Missing={len(missing)} ({miss_pct:.1%}). " + f"First missing: {missing[:10]}{' ...' if len(missing) > 10 else ''}." + ) + def gene_dim(self): return len(self.genes) @@ -113,8 +145,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: decoded_gene = decoded_gene + self.gene_decoder_proj(decoded_gene) # decoded_gene = torch.nn.functional.relu(decoded_gene) - # # normalize the sum of decoded_gene to be read depth - # decoded_gene = decoded_gene / decoded_gene.sum(dim=2, keepdim=True) * self.read_depth + # Normalize the sum of decoded_gene to be read depth (learnable) + # Guard against divide-by-zero by adding small epsilon + eps = 1e-6 + decoded_gene = decoded_gene / (decoded_gene.sum(dim=2, keepdim=True) + eps) * self.read_depth # decoded_gene = self.gene_lora(decoded_gene) # TODO: fix this to work with basal counts diff --git a/src/state/tx/models/pseudobulk.py b/src/state/tx/models/pseudobulk.py index c63eb43e..971db977 100644 --- a/src/state/tx/models/pseudobulk.py +++ b/src/state/tx/models/pseudobulk.py @@ -115,33 +115,21 @@ def __init__( control_pert = kwargs.get("control_pert", "non-targeting") if kwargs.get("finetune_vci_decoder", False): - gene_names = [] - - if output_space == "gene": - # hvg's but for which dataset? - if "DMSO_TF" in control_pert: - gene_names = np.load( - "/large_storage/ctc/userspace/aadduri/datasets/tahoe_19k_to_2k_names.npy", allow_pickle=True - ) - elif "non-targeting" in control_pert: - temp = ad.read_h5ad("/large_storage/ctc/userspace/aadduri/datasets/hvg/replogle/jurkat.h5") - gene_names = temp.var.index.values - else: - assert output_space == "all" - if "DMSO_TF" in control_pert: - gene_names = np.load( - "/large_storage/ctc/userspace/aadduri/datasets/tahoe_19k_names.npy", allow_pickle=True - ) - elif "non-targeting" in control_pert: - # temp = ad.read_h5ad('/scratch/ctc/ML/vci/paper_replogle/jurkat.h5') - # gene_names = temp.var.index.values - temp = ad.read_h5ad("/large_storage/ctc/userspace/aadduri/cross_dataset/replogle/jurkat.h5") - gene_names = temp.var.index.values - - self.gene_decoder = FinetuneVCICountsDecoder( - genes=gene_names, - # latent_dim=self.output_dim + (self.batch_dim or 0), + # Prefer the gene names supplied by the data module (aligned to training output) + gene_names = self.gene_names + if gene_names is None: + raise ValueError( + "finetune_vci_decoder=True but model.gene_names is None. " + "Please provide gene_names via data module var_dims." + ) + + n_genes = len(gene_names) + logger.info( + f"Initializing FinetuneVCICountsDecoder with {n_genes} genes (output_space={output_space}; " + + ("HVG subset" if output_space == "gene" else "all genes") + + ")" ) + self.gene_decoder = FinetuneVCICountsDecoder(genes=gene_names) print(self) diff --git a/src/state/tx/models/state_transition.py b/src/state/tx/models/state_transition.py index ecce5e29..d9e5cdbd 100644 --- a/src/state/tx/models/state_transition.py +++ b/src/state/tx/models/state_transition.py @@ -274,33 +274,21 @@ def __init__( control_pert = kwargs.get("control_pert", "non-targeting") if kwargs.get("finetune_vci_decoder", False): # TODO: This will go very soon - gene_names = [] + # Prefer the gene names supplied by the data module (aligned to training output) + gene_names = self.gene_names + if gene_names is None: + raise ValueError( + "finetune_vci_decoder=True but model.gene_names is None. " + "Please provide gene_names via data module var_dims." + ) - if output_space == "gene": - # hvg's but for which dataset? - if "DMSO_TF" in control_pert: - gene_names = np.load( - "/large_storage/ctc/userspace/aadduri/datasets/tahoe_19k_to_2k_names.npy", allow_pickle=True - ) - elif "non-targeting" in control_pert: - temp = ad.read_h5ad("/large_storage/ctc/userspace/aadduri/datasets/hvg/replogle/jurkat.h5") - # gene_names = temp.var.index.values - else: - assert output_space == "all" - if "DMSO_TF" in control_pert: - gene_names = np.load( - "/large_storage/ctc/userspace/aadduri/datasets/tahoe_19k_names.npy", allow_pickle=True - ) - elif "non-targeting" in control_pert: - # temp = ad.read_h5ad('/scratch/ctc/ML/vci/paper_replogle/jurkat.h5') - # gene_names = temp.var.index.values - temp = ad.read_h5ad("/large_storage/ctc/userspace/aadduri/cross_dataset/replogle/jurkat.h5") - gene_names = temp.var.index.values - - self.gene_decoder = FinetuneVCICountsDecoder( - genes=gene_names, - # latent_dim=self.output_dim + (self.batch_dim or 0), + n_genes = len(gene_names) + logger.info( + f"Initializing FinetuneVCICountsDecoder with {n_genes} genes (output_space={output_space}; " + + ("HVG subset" if output_space == "gene" else "all genes") + + ")" ) + self.gene_decoder = FinetuneVCICountsDecoder(genes=gene_names) print(self) def _build_networks(self, lora_cfg=None): From f45e06d937810ce2f4812462663be08e6939462f Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Sat, 13 Sep 2025 22:55:20 -0700 Subject: [PATCH 03/38] fixed model to use the pretrained decoder but with a for loop --- src/state/emb/finetune_decoder.py | 108 +++++++++++++-------- src/state/tx/models/decoders.py | 123 ++++++++++++++---------- src/state/tx/models/pseudobulk.py | 5 +- src/state/tx/models/state_transition.py | 4 +- 4 files changed, 150 insertions(+), 90 deletions(-) diff --git a/src/state/emb/finetune_decoder.py b/src/state/emb/finetune_decoder.py index cf081f55..d5d01fb2 100644 --- a/src/state/emb/finetune_decoder.py +++ b/src/state/emb/finetune_decoder.py @@ -1,6 +1,7 @@ import logging import torch from torch import nn +from omegaconf import OmegaConf from vci.nn.model import StateEmbeddingModel from vci.train.trainer import get_embeddings @@ -10,7 +11,7 @@ class Finetune: - def __init__(self, cfg, learning_rate=1e-4): + def __init__(self, cfg=None, learning_rate=1e-4): """ Initialize the Finetune class for fine-tuning the binary decoder of a pre-trained model. @@ -29,47 +30,71 @@ def __init__(self, cfg, learning_rate=1e-4): self.cached_gene_embeddings = {} self.device = None - def load_model(self, checkpoint): + def load_model(self, checkpoint: str): """ - Load a pre-trained model from a checkpoint and prepare it for fine-tuning. - - Parameters: - ----------- - checkpoint : str - Path to the checkpoint file + Load a pre-trained SE model from a single checkpoint path and prepare + it for use. Mirrors the transform/inference loader behavior: extract + config and embeddings from the checkpoint if present, otherwise fallbacks. """ if self.model: raise ValueError("Model already initialized") - # Import locally to avoid circular imports - - # Load and initialize model for eval - self.model = StateEmbeddingModel.load_from_checkpoint(checkpoint, strict=False) - - # Ensure model uses the provided config, not the stored one - if self._vci_conf is not None: - self.model.update_config(self._vci_conf) - + # Resolve configuration: prefer embedded cfg in checkpoint + cfg_to_use = self._vci_conf + if cfg_to_use is None: + try: + ckpt = torch.load(checkpoint, map_location="cpu", weights_only=False) + if isinstance(ckpt, dict) and "cfg_yaml" in ckpt: + cfg_to_use = OmegaConf.create(ckpt["cfg_yaml"]) # type: ignore + elif isinstance(ckpt, dict) and "hyper_parameters" in ckpt: + hp = ckpt.get("hyper_parameters", {}) or {} + # Some checkpoints may have a cfg-like structure in hyper_parameters + if isinstance(hp, dict) and len(hp) > 0: + try: + cfg_to_use = OmegaConf.create(hp["cfg"]) if "cfg" in hp else OmegaConf.create(hp) + except Exception: + cfg_to_use = OmegaConf.create(hp) + except Exception as e: + log.warning(f"Could not extract config from checkpoint: {e}") + if cfg_to_use is None: + raise ValueError("No config found in checkpoint and no override provided. Provide SE cfg or a full checkpoint.") + + self._vci_conf = cfg_to_use + + # Load model; allow passing cfg to constructor like inference + self.model = StateEmbeddingModel.load_from_checkpoint( + checkpoint, dropout=0.0, strict=False, cfg=self._vci_conf + ) self.device = self.model.device - # Load protein embeddings - all_pe = get_embeddings(self._vci_conf) + # Try to extract packaged protein embeddings from checkpoint + packaged_pe = None + try: + ckpt2 = torch.load(checkpoint, map_location="cpu", weights_only=False) + if isinstance(ckpt2, dict) and "protein_embeds_dict" in ckpt2: + packaged_pe = ckpt2["protein_embeds_dict"] + except Exception: + pass + + # Resolve protein embeddings for pe_embedding weights + all_pe = packaged_pe or get_embeddings(self._vci_conf) + if isinstance(all_pe, dict): + all_pe = torch.vstack(list(all_pe.values())) all_pe.requires_grad = False self.model.pe_embedding = nn.Embedding.from_pretrained(all_pe) self.model.pe_embedding.to(self.device) - # Load protein embeddings - self.protein_embeds = torch.load(get_embedding_cfg(self._vci_conf).all_embeddings) - - # Freeze all parameters - for param in self.model.parameters(): - param.requires_grad = False - - # Enable gradients only for binary decoder - for param in self.model.binary_decoder.parameters(): - param.requires_grad = False - - # Ensure the binary decoder is in training mode so gradients are enabled. + # Keep a mapping from gene name -> protein embedding vector + self.protein_embeds = packaged_pe + if self.protein_embeds is None: + # Fallback to configured path + self.protein_embeds = torch.load(get_embedding_cfg(self._vci_conf).all_embeddings, weights_only=False) + + # Freeze SE model and decoder + for p in self.model.parameters(): + p.requires_grad = False + for p in self.model.binary_decoder.parameters(): + p.requires_grad = False self.model.binary_decoder.eval() def _auto_detect_gene_column(self, adata): @@ -134,16 +159,23 @@ def get_gene_embedding(self, genes): if cache_key in self.cached_gene_embeddings: return self.cached_gene_embeddings[cache_key] - # Strict validation: ensure all genes are present in pretrained all_embeddings + # Compute gene embeddings; fallback to zero vectors for missing genes. missing = [g for g in genes if g not in self.protein_embeds] if len(missing) > 0: - raise ValueError( - f"Finetune.get_gene_embedding: {len(missing)} gene(s) not found in pretrained all_embeddings: " - f"{missing[:10]}{' ...' if len(missing) > 10 else ''}." + try: + embed_size = next(iter(self.protein_embeds.values())).shape[-1] + except Exception: + embed_size = 5120 + # Log once per call to aid debugging + log.warning( + f"Finetune.get_gene_embedding: {len(missing)} gene(s) missing from pretrained embeddings; using zeros as placeholders. " + f"First missing: {missing[:10]}{' ...' if len(missing) > 10 else ''}." ) - # Compute gene embeddings - protein_embeds = [self.protein_embeds[x] for x in genes] + protein_embeds = [ + self.protein_embeds[x] if x in self.protein_embeds else torch.zeros(embed_size) + for x in genes + ] protein_embeds = torch.stack(protein_embeds).to(self.device) gene_embeds = self.model.gene_embedding_layer(protein_embeds) @@ -171,7 +203,7 @@ def get_counts(self, cell_embs, genes, read_depth=None, batch_size=32): # Check if RDA is enabled. use_rda = getattr(self.model.cfg.model, "rda", False) if use_rda and read_depth is None: - read_depth = 1000.0 + read_depth = 4.0 # Retrieve gene embeddings (cached if available). gene_embeds = self.get_gene_embedding(genes) diff --git a/src/state/tx/models/decoders.py b/src/state/tx/models/decoders.py index 0b1f5c22..96862bfc 100644 --- a/src/state/tx/models/decoders.py +++ b/src/state/tx/models/decoders.py @@ -1,4 +1,5 @@ import logging +import os import torch import torch.nn as nn @@ -14,22 +15,24 @@ def __init__( self, genes=None, adata=None, - # model_loc="/large_storage/ctc/userspace/aadduri/vci/checkpoint/rda_tabular_counts_2048_new/step=950000.ckpt", - # config="/large_storage/ctc/userspace/aadduri/vci/checkpoint/rda_tabular_counts_2048_new/tahoe_config.yaml", - model_loc="/home/aadduri/vci_pretrain/vci_1.4.2.ckpt", - config="/large_storage/ctc/userspace/aadduri/vci/checkpoint/large_1e-4_rda_tabular_counts_2048/crossds_config.yaml", - read_depth=1200, - latent_dim=1024, # dimension of pretrained vci model - hidden_dims=[512, 512, 512], # hidden dimensions of the decoder + # checkpoint: str = "/large_storage/ctc/userspace/aadduri/SE-600M/se600m_epoch15.ckpt", + # config: str = "/large_storage/ctc/userspace/aadduri/SE-600M/config.yaml", + checkpoint: str = "/home/aadduri/vci_pretrain/vci_1.4.4/vci_1.4.4_v7.ckpt", + config: str = "/home/aadduri/vci_pretrain/vci_1.4.4/config.yaml", + read_depth=4.0, + latent_dim=1034, # dimension of pretrained vci model + hidden_dim=512, # hidden dimensions of the decoder dropout=0.1, basal_residual=False, ): super().__init__() - # Initialize finetune helper and model - self.model_loc = model_loc - self.config = config - self.finetune = Finetune(OmegaConf.load(self.config)) - self.finetune.load_model(self.model_loc) + # Initialize finetune helper and model from a single checkpoint + if checkpoint is None: + raise ValueError( + "FinetuneVCICountsDecoder requires a VCI/SE checkpoint. Set kwargs.vci_checkpoint or env STATE_VCI_CHECKPOINT." + ) + self.finetune = Finetune(cfg=OmegaConf.load(config)) + self.finetune.load_model(checkpoint) # Resolve genes: prefer explicit list; else infer from anndata if provided if genes is None and adata is not None: try: @@ -50,20 +53,25 @@ def __init__( # self.gene_lora = nn.Sequential(*layers) self.latent_decoder = nn.Sequential( - nn.Linear(latent_dim, hidden_dims[0]), - nn.LayerNorm(hidden_dims[0]), + nn.Linear(latent_dim, hidden_dim), + nn.LayerNorm(hidden_dim), nn.GELU(), nn.Dropout(dropout), - nn.Linear(hidden_dims[0], hidden_dims[1]), - nn.LayerNorm(hidden_dims[1]), + nn.Linear(hidden_dim, hidden_dim), + nn.LayerNorm(hidden_dim), nn.GELU(), nn.Dropout(dropout), - nn.Linear(hidden_dims[1], len(self.genes)), - nn.ReLU(), + nn.Linear(hidden_dim, len(self.genes)), ) self.gene_decoder_proj = nn.Sequential( nn.Linear(len(self.genes), 128), + nn.LayerNorm(128), + nn.GELU(), + nn.Linear(128, 128), + nn.LayerNorm(128), + nn.GELU(), + nn.Dropout(dropout), nn.Linear(128, len(self.genes)), ) @@ -73,24 +81,39 @@ def __init__( # Validate that all requested genes exist in the pretrained checkpoint's embeddings pe = getattr(self.finetune, "protein_embeds", {}) - present = [g for g in self.genes if g in pe] - missing = [g for g in self.genes if g not in pe] - if len(missing) > 0: - total_req = len(self.genes) - total_pe = len(pe) if hasattr(pe, "__len__") else -1 - found = total_req - len(missing) - miss_pct = (len(missing) / total_req) if total_req > 0 else 1.0 - logger.error( - f"FinetuneVCICountsDecoder gene check: requested={total_req}, found={found}, missing={len(missing)} ({miss_pct:.1%}), all_embeddings_size={total_pe}" - ) - logger.error( - f"Examples missing: {missing[:10]}{' ...' if len(missing) > 10 else ''}; examples present: {present[:10]}{' ...' if len(present) > 10 else ''}" - ) - raise ValueError( - f"FinetuneVCICountsDecoder: {len(missing)} gene(s) not found in pretrained all_embeddings. " - f"Requested={total_req}, Found={found}, Missing={len(missing)} ({miss_pct:.1%}). " - f"First missing: {missing[:10]}{' ...' if len(missing) > 10 else ''}." - ) + self.present_mask = [g in pe for g in self.genes] + self.missing_positions = [i for i, g in enumerate(self.genes) if g not in pe] + self.missing_genes = [self.genes[i] for i in self.missing_positions] + total_req = len(self.genes) + found = total_req - len(self.missing_positions) + total_pe = len(pe) if hasattr(pe, "__len__") else -1 + miss_pct = (len(self.missing_positions) / total_req) if total_req > 0 else 0.0 + logger.info( + f"FinetuneVCICountsDecoder gene check: requested={total_req}, found={found}, missing={len(self.missing_positions)} ({miss_pct:.1%}), all_embeddings_size={total_pe}" + ) + + # Create learnable embeddings for missing genes in the post-ESM gene embedding space + if len(self.missing_positions) > 0: + # Infer gene embedding output dimension by a dry-run through gene_embedding_layer + try: + sample_vec = next(iter(pe.values())).to(self.finetune.model.device) + if sample_vec.dim() == 1: + sample_vec = sample_vec.unsqueeze(0) + gene_embed_dim = self.finetune.model.gene_embedding_layer(sample_vec).shape[-1] + except Exception: + # Conservative fallback + gene_embed_dim = 1024 + + self.missing_table = nn.Embedding(len(self.missing_positions), gene_embed_dim) + nn.init.normal_(self.missing_table.weight, mean=0.0, std=0.02) + # For user visibility + try: + self.finetune.missing_genes = self.missing_genes + except Exception: + pass + else: + # Register a dummy buffer so attributes exist + self.missing_table = None def gene_dim(self): return len(self.genes) @@ -104,6 +127,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Get gene embeddings gene_embeds = self.finetune.get_gene_embedding(self.genes) + # Replace missing gene rows with learnable embeddings + if self.missing_table is not None and len(self.missing_positions) > 0: + device = gene_embeds.device + learned = self.missing_table.weight.to(device) + idx = torch.tensor(self.missing_positions, device=device, dtype=torch.long) + gene_embeds = gene_embeds.clone() + gene_embeds.index_copy_(0, idx, learned) # Handle RDA task counts use_rda = getattr(self.finetune.model.cfg.model, "rda", False) @@ -117,15 +147,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Create task_counts for the sub-batch if needed if use_rda: - # task_counts_sub = torch.full( - # (x_sub.shape[0],), self.read_depth, device=x.device - # ) task_counts_sub = torch.ones((x_sub.shape[0],), device=x.device) * self.read_depth else: task_counts_sub = None # Compute merged embeddings for the sub-batch - merged_embs_sub = self.finetune.model.resize_batch(x_sub, gene_embeds, task_counts_sub) + # resize_batch(cell_embeds, task_embeds, task_counts=None, sampled_rda=None, ds_emb=None) + cell_embeds = x_sub[:, :-10] + ds_emb = x_sub[:, -10:] + merged_embs_sub = self.finetune.model.resize_batch( + cell_embeds=cell_embeds, task_embeds=gene_embeds, task_counts=task_counts_sub, ds_emb=ds_emb + ) # Run the binary decoder on the sub-batch logprobs_sub = self.binary_decoder(merged_embs_sub) @@ -143,19 +175,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Reshape back to [B, S, gene_dim] decoded_gene = logprobs.view(batch_size, seq_len, len(self.genes)) decoded_gene = decoded_gene + self.gene_decoder_proj(decoded_gene) - # decoded_gene = torch.nn.functional.relu(decoded_gene) - - # Normalize the sum of decoded_gene to be read depth (learnable) - # Guard against divide-by-zero by adding small epsilon - eps = 1e-6 - decoded_gene = decoded_gene / (decoded_gene.sum(dim=2, keepdim=True) + eps) * self.read_depth - - # decoded_gene = self.gene_lora(decoded_gene) - # TODO: fix this to work with basal counts # add logic for basal_residual: decoded_x = self.latent_decoder(x) decoded_x = decoded_x.view(batch_size, seq_len, len(self.genes)) # Pass through the additional decoder layers - return decoded_gene + decoded_x + return torch.nn.functional.relu(decoded_gene + decoded_x) diff --git a/src/state/tx/models/pseudobulk.py b/src/state/tx/models/pseudobulk.py index 971db977..82309c33 100644 --- a/src/state/tx/models/pseudobulk.py +++ b/src/state/tx/models/pseudobulk.py @@ -129,7 +129,10 @@ def __init__( + ("HVG subset" if output_space == "gene" else "all genes") + ")" ) - self.gene_decoder = FinetuneVCICountsDecoder(genes=gene_names) + self.gene_decoder = FinetuneVCICountsDecoder( + genes=gene_names, + checkpoint=kwargs.get("vci_checkpoint", None), + ) print(self) diff --git a/src/state/tx/models/state_transition.py b/src/state/tx/models/state_transition.py index d9e5cdbd..0e7ba952 100644 --- a/src/state/tx/models/state_transition.py +++ b/src/state/tx/models/state_transition.py @@ -288,7 +288,9 @@ def __init__( + ("HVG subset" if output_space == "gene" else "all genes") + ")" ) - self.gene_decoder = FinetuneVCICountsDecoder(genes=gene_names) + self.gene_decoder = FinetuneVCICountsDecoder( + genes=gene_names, + ) print(self) def _build_networks(self, lora_cfg=None): From 3f41fc771404d8eb6c263043568c5b2d92a31df7 Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Sun, 14 Sep 2025 14:38:18 -0700 Subject: [PATCH 04/38] working impl for 1024 and 2048 dim state impls --- src/state/tx/models/decoders.py | 114 +++++++++++++++++--------------- 1 file changed, 60 insertions(+), 54 deletions(-) diff --git a/src/state/tx/models/decoders.py b/src/state/tx/models/decoders.py index 96862bfc..bc04e66e 100644 --- a/src/state/tx/models/decoders.py +++ b/src/state/tx/models/decoders.py @@ -1,8 +1,10 @@ import logging import os +from typing import Optional import torch import torch.nn as nn + from omegaconf import OmegaConf from ...emb.finetune_decoder import Finetune @@ -15,21 +17,22 @@ def __init__( self, genes=None, adata=None, - # checkpoint: str = "/large_storage/ctc/userspace/aadduri/SE-600M/se600m_epoch15.ckpt", - # config: str = "/large_storage/ctc/userspace/aadduri/SE-600M/config.yaml", - checkpoint: str = "/home/aadduri/vci_pretrain/vci_1.4.4/vci_1.4.4_v7.ckpt", - config: str = "/home/aadduri/vci_pretrain/vci_1.4.4/config.yaml", - read_depth=4.0, - latent_dim=1034, # dimension of pretrained vci model - hidden_dim=512, # hidden dimensions of the decoder - dropout=0.1, - basal_residual=False, + # checkpoint: Optional[str] = "/large_storage/ctc/userspace/aadduri/SE-600M/se600m_epoch15.ckpt", + # config: Optional[str] = "/large_storage/ctc/userspace/aadduri/SE-600M/config.yaml", + checkpoint: Optional[str] = "/home/aadduri/vci_pretrain/vci_1.4.4/vci_1.4.4_v7.ckpt", + config: Optional[str] = "/home/aadduri/vci_pretrain/vci_1.4.4/config.yaml", + latent_dim: int = 1034, # total input dim (cell emb + optional ds emb) + read_depth: float = 4.0, + ds_emb_dim: int = 10, # dataset embedding dim at the tail of input + hidden_dim: int = 512, + dropout: float = 0.1, + basal_residual: bool = False, ): super().__init__() # Initialize finetune helper and model from a single checkpoint - if checkpoint is None: + if config is None: raise ValueError( - "FinetuneVCICountsDecoder requires a VCI/SE checkpoint. Set kwargs.vci_checkpoint or env STATE_VCI_CHECKPOINT." + "FinetuneVCICountsDecoder requires a VCI/SE config. Set kwargs.vci_config or env STATE_VCI_CONFIG." ) self.finetune = Finetune(cfg=OmegaConf.load(config)) self.finetune.load_model(checkpoint) @@ -45,6 +48,8 @@ def __init__( # Keep read_depth as a learnable parameter so decoded counts can adapt self.read_depth = nn.Parameter(torch.tensor(read_depth, dtype=torch.float), requires_grad=True) self.basal_residual = basal_residual + self.ds_emb_dim = int(ds_emb_dim) if ds_emb_dim is not None else 0 + self.input_total_dim = int(latent_dim) # layers = [ # nn.Linear(latent_dim, hidden_dims[0]), @@ -119,66 +124,67 @@ def gene_dim(self): return len(self.genes) def forward(self, x: torch.Tensor) -> torch.Tensor: - # x is [B, S, latent_dim]. - if len(x.shape) != 3: + # x is [B, S, total_dim] + if x.dim() != 3: x = x.unsqueeze(0) - batch_size, seq_len, latent_dim = x.shape - x = x.view(batch_size * seq_len, latent_dim) + batch_size, seq_len, total_dim = x.shape + x_flat = x.reshape(batch_size * seq_len, total_dim) + + # Split cell and dataset embeddings + if self.ds_emb_dim > 0: + cell_embeds = x_flat[:, : total_dim - self.ds_emb_dim] + ds_emb = x_flat[:, total_dim - self.ds_emb_dim : total_dim] + else: + cell_embeds = x_flat + ds_emb = None - # Get gene embeddings + # Prepare gene embeddings (replace any missing with learned vectors) gene_embeds = self.finetune.get_gene_embedding(self.genes) - # Replace missing gene rows with learnable embeddings if self.missing_table is not None and len(self.missing_positions) > 0: device = gene_embeds.device learned = self.missing_table.weight.to(device) idx = torch.tensor(self.missing_positions, device=device, dtype=torch.long) gene_embeds = gene_embeds.clone() gene_embeds.index_copy_(0, idx, learned) + # Ensure embeddings live on the same device as cell_embeds + if gene_embeds.device != cell_embeds.device: + gene_embeds = gene_embeds.to(cell_embeds.device) - # Handle RDA task counts + # RDA read depth vector (if enabled in SE model) use_rda = getattr(self.finetune.model.cfg.model, "rda", False) - # Define your sub-batch size (tweak this based on your available memory) - sub_batch_size = 16 - logprob_chunks = [] # to store outputs of each sub-batch - - for i in range(0, x.shape[0], sub_batch_size): - # Get the sub-batch of latent vectors - x_sub = x[i : i + sub_batch_size] - - # Create task_counts for the sub-batch if needed - if use_rda: - task_counts_sub = torch.ones((x_sub.shape[0],), device=x.device) * self.read_depth - else: - task_counts_sub = None - - # Compute merged embeddings for the sub-batch - # resize_batch(cell_embeds, task_embeds, task_counts=None, sampled_rda=None, ds_emb=None) - cell_embeds = x_sub[:, :-10] - ds_emb = x_sub[:, -10:] - merged_embs_sub = self.finetune.model.resize_batch( - cell_embeds=cell_embeds, task_embeds=gene_embeds, task_counts=task_counts_sub, ds_emb=ds_emb + task_counts = None + if use_rda: + task_counts = torch.full((cell_embeds.shape[0],), self.read_depth.item(), device=cell_embeds.device) + + # Binary decoder forward with safe dtype handling. + # - On CUDA: enable bf16 autocast for speed. + # - On CPU: ensure inputs match decoder weight dtype to avoid BF16/FP32 mismatch. + device_type = "cuda" if cell_embeds.is_cuda else "cpu" + with torch.autocast(device_type=device_type, dtype=torch.bfloat16, enabled=(device_type == "cuda")): + merged = self.finetune.model.resize_batch( + cell_embeds=cell_embeds, task_embeds=gene_embeds, task_counts=task_counts, ds_emb=ds_emb ) - # Run the binary decoder on the sub-batch - logprobs_sub = self.binary_decoder(merged_embs_sub) - - # Squeeze the singleton dimension if needed - if logprobs_sub.dim() == 3 and logprobs_sub.size(-1) == 1: - logprobs_sub = logprobs_sub.squeeze(-1) + # Align input dtype with decoder weights when autocast is not active (e.g., CPU path) + dec_param_dtype = next(self.binary_decoder.parameters()).dtype + if device_type != "cuda" and merged.dtype != dec_param_dtype: + merged = merged.to(dec_param_dtype) - # Collect the results - logprob_chunks.append(logprobs_sub) - - # Concatenate the sub-batches back together - logprobs = torch.cat(logprob_chunks, dim=0) + logprobs = self.binary_decoder(merged) + if logprobs.dim() == 3 and logprobs.size(-1) == 1: + logprobs = logprobs.squeeze(-1) # Reshape back to [B, S, gene_dim] decoded_gene = logprobs.view(batch_size, seq_len, len(self.genes)) - decoded_gene = decoded_gene + self.gene_decoder_proj(decoded_gene) - # add logic for basal_residual: - decoded_x = self.latent_decoder(x) - decoded_x = decoded_x.view(batch_size, seq_len, len(self.genes)) + # Match dtype for post-decoder projection to avoid mixed-dtype matmul + proj_param_dtype = next(self.gene_decoder_proj.parameters()).dtype + if decoded_gene.dtype != proj_param_dtype: + decoded_gene = decoded_gene.to(proj_param_dtype) + decoded_gene = decoded_gene + self.gene_decoder_proj(decoded_gene) - # Pass through the additional decoder layers + # Optional residual from latent decoder (operates on full input features) + ld_param_dtype = next(self.latent_decoder.parameters()).dtype + x_flat_for_ld = x_flat if x_flat.dtype == ld_param_dtype else x_flat.to(ld_param_dtype) + decoded_x = self.latent_decoder(x_flat_for_ld).view(batch_size, seq_len, len(self.genes)) return torch.nn.functional.relu(decoded_gene + decoded_x) From 16fed4ebc72e3fae7b04a61ec1af102a9fa65303 Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Sun, 14 Sep 2025 17:49:23 -0700 Subject: [PATCH 05/38] updated with correct fine tuning, gradient accumulation --- src/state/_cli/_tx/_train.py | 2 +- src/state/configs/training/default.yaml | 3 +- src/state/emb/finetune_decoder.py | 317 +++++++++++++++--------- src/state/tx/models/decoders.py | 13 +- 4 files changed, 208 insertions(+), 127 deletions(-) diff --git a/src/state/_cli/_tx/_train.py b/src/state/_cli/_tx/_train.py index 98ef7665..c3f475e6 100644 --- a/src/state/_cli/_tx/_train.py +++ b/src/state/_cli/_tx/_train.py @@ -259,7 +259,7 @@ def run_tx_train(cfg: DictConfig): plugins=plugins, callbacks=callbacks, gradient_clip_val=cfg["training"]["gradient_clip_val"] if cfg["model"]["name"].lower() != "cpa" else None, - use_distributed_sampler=False, # Prevent Lightning from wrapping PerturbationBatchSampler with DistributedSampler + accumulate_grad_batches=cfg["training"].get("gradient_accumulation_steps", 1), ) # Align logging cadence with rolling MFU window (and W&B logging) diff --git a/src/state/configs/training/default.yaml b/src/state/configs/training/default.yaml index 3b31cd27..a1fe5d7d 100644 --- a/src/state/configs/training/default.yaml +++ b/src/state/configs/training/default.yaml @@ -7,6 +7,7 @@ train_seed: 42 val_freq: 2000 ckpt_every_n_steps: 2000 gradient_clip_val: 10 # 0 means no clipping +gradient_accumulation_steps: 1 loss_fn: mse devices: 1 # Number of GPUs to use for training strategy: auto # DDP strategy for multi-GPU training @@ -16,4 +17,4 @@ mfu_kwargs: use_backward: true logging_interval: 10 window_size: 2 -cumulative_flops_use_backward: true \ No newline at end of file +cumulative_flops_use_backward: true diff --git a/src/state/emb/finetune_decoder.py b/src/state/emb/finetune_decoder.py index d5d01fb2..42af047a 100644 --- a/src/state/emb/finetune_decoder.py +++ b/src/state/emb/finetune_decoder.py @@ -1,4 +1,8 @@ +# src/state/emb/finetune_decoder.py + import logging +from typing import Dict, List, Optional, Tuple + import torch from torch import nn from omegaconf import OmegaConf @@ -10,33 +14,51 @@ log = logging.getLogger(__name__) -class Finetune: - def __init__(self, cfg=None, learning_rate=1e-4): +class Finetune(nn.Module): + def __init__(self, cfg: Optional[OmegaConf] = None, learning_rate: float = 1e-4, read_depth: float = 4.0, train_binary_decoder: bool = False): """ - Initialize the Finetune class for fine-tuning the binary decoder of a pre-trained model. - - Parameters: - ----------- - cfg : OmegaConf - Configuration object containing model settings - learning_rate : float - Learning rate for fine-tuning the binary decoder + Helper module that loads a pretrained SE/VCI checkpoint and exposes: + - get_gene_embedding(genes): returns gene/task embeddings with differentiable + replacement for any genes missing from pretrained protein embeddings + - get_counts(cell_embs, genes): runs the pretrained binary decoder in a vectorized way + + Args: + cfg: OmegaConf for the SE model (if not embedded in checkpoint) + learning_rate: (kept for API compatibility; not used directly here) + read_depth: initial value for a learnable read depth scalar (if RDA enabled) """ - self.model = None + super().__init__() + self.model: Optional[StateEmbeddingModel] = None self.collator = None - self.protein_embeds = None + self.protein_embeds: Optional[Dict[str, torch.Tensor]] = None self._vci_conf = cfg self.learning_rate = learning_rate - self.cached_gene_embeddings = {} - self.device = None - + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.train_binary_decoder = train_binary_decoder + + # --- Learnable read-depth scalar used when RDA is enabled --- + self.read_depth = nn.Parameter(torch.tensor(float(read_depth), dtype=torch.float), requires_grad=True) + + # --- Caching & state for gene embeddings and missing-gene handling --- + self.cached_gene_embeddings: Dict[Tuple[str, ...], torch.Tensor] = {} + + self.missing_table: Optional[nn.Embedding] = None + self._last_missing_count: int = 0 + self._last_missing_dim: int = 0 + + # Cache present masks and index maps per gene set + self._present_mask_cache: Dict[Tuple[str, ...], torch.Tensor] = {} + self._missing_index_map_cache: Dict[Tuple[str, ...], torch.Tensor] = {} + + # ------------------------- + # Loading / setup + # ------------------------- def load_model(self, checkpoint: str): """ - Load a pre-trained SE model from a single checkpoint path and prepare - it for use. Mirrors the transform/inference loader behavior: extract - config and embeddings from the checkpoint if present, otherwise fallbacks. + Load a pre-trained SE model from a single checkpoint path and prepare it. + Prefers embedded cfg in checkpoint; falls back to provided cfg if needed. """ - if self.model: + if self.model is not None: raise ValueError("Model already initialized") # Resolve configuration: prefer embedded cfg in checkpoint @@ -48,7 +70,6 @@ def load_model(self, checkpoint: str): cfg_to_use = OmegaConf.create(ckpt["cfg_yaml"]) # type: ignore elif isinstance(ckpt, dict) and "hyper_parameters" in ckpt: hp = ckpt.get("hyper_parameters", {}) or {} - # Some checkpoints may have a cfg-like structure in hyper_parameters if isinstance(hp, dict) and len(hp) > 0: try: cfg_to_use = OmegaConf.create(hp["cfg"]) if "cfg" in hp else OmegaConf.create(hp) @@ -57,15 +78,17 @@ def load_model(self, checkpoint: str): except Exception as e: log.warning(f"Could not extract config from checkpoint: {e}") if cfg_to_use is None: - raise ValueError("No config found in checkpoint and no override provided. Provide SE cfg or a full checkpoint.") - + raise ValueError( + "No config found in checkpoint and no override provided. " + "Provide SE cfg or a full checkpoint with embedded config." + ) self._vci_conf = cfg_to_use # Load model; allow passing cfg to constructor like inference self.model = StateEmbeddingModel.load_from_checkpoint( checkpoint, dropout=0.0, strict=False, cfg=self._vci_conf ) - self.device = self.model.device + self.device = self.model.device # type: ignore # Try to extract packaged protein embeddings from checkpoint packaged_pe = None @@ -79,29 +102,32 @@ def load_model(self, checkpoint: str): # Resolve protein embeddings for pe_embedding weights all_pe = packaged_pe or get_embeddings(self._vci_conf) if isinstance(all_pe, dict): - all_pe = torch.vstack(list(all_pe.values())) - all_pe.requires_grad = False - self.model.pe_embedding = nn.Embedding.from_pretrained(all_pe) - self.model.pe_embedding.to(self.device) - - # Keep a mapping from gene name -> protein embedding vector + # For the model's token embedding table, we only need the stacked array. + stacked = torch.vstack(list(all_pe.values())) + else: + stacked = all_pe + stacked.requires_grad = False + self.model.pe_embedding = nn.Embedding.from_pretrained(stacked) # type: ignore + self.model.pe_embedding.to(self.device) # type: ignore + + # Keep a mapping from gene name -> raw protein embedding vector self.protein_embeds = packaged_pe if self.protein_embeds is None: - # Fallback to configured path + # Fallback to configured path on disk self.protein_embeds = torch.load(get_embedding_cfg(self._vci_conf).all_embeddings, weights_only=False) - # Freeze SE model and decoder - for p in self.model.parameters(): - p.requires_grad = False - for p in self.model.binary_decoder.parameters(): + # Freeze SE model; optionally unfreeze just the binary decoder + for p in self.model.parameters(): # type: ignore p.requires_grad = False - self.model.binary_decoder.eval() + for p in self.model.binary_decoder.parameters(): # type: ignore + p.requires_grad = self.train_binary_decoder + self.model.binary_decoder.train(mode=self.train_binary_decoder) # type: ignore + # ------------------------- + # Gene utilities + # ------------------------- def _auto_detect_gene_column(self, adata): - """Auto-detect the gene column with highest overlap with protein embeddings. - - Returns None to indicate var.index, or a string column name in var. - """ + """Auto-detect the gene column with highest overlap with protein embeddings.""" if self.protein_embeds is None: log.warning("No protein embeddings available for auto-detection, using index") return None @@ -110,14 +136,14 @@ def _auto_detect_gene_column(self, adata): best_column = None best_overlap = 0 - # Check index first + # Index first index_genes = set(getattr(adata.var, "index", [])) overlap = len(protein_genes.intersection(index_genes)) if overlap > best_overlap: best_overlap = overlap - best_column = None # None means use index + best_column = None # None => use index - # Check all columns in var + # All columns for col in adata.var.columns: try: col_vals = adata.var[col].dropna().astype(str) @@ -131,109 +157,170 @@ def _auto_detect_gene_column(self, adata): return best_column - def genes_from_adata(self, adata): + def genes_from_adata(self, adata) -> List[str]: """Return list of gene names from AnnData using auto-detected column/index.""" col = self._auto_detect_gene_column(adata) if col is None: return list(map(str, adata.var.index.values)) return list(adata.var[col].astype(str).values) - def get_gene_embedding(self, genes): + def _ensure_missing_table( + self, + genes_key: Tuple[str, ...], + gene_embed_dim: int, + present_mask: torch.Tensor, + ): """ - Get embeddings for a list of genes, with caching to avoid recomputation. - - Parameters: - ----------- - genes : list - List of gene names/identifiers - - Returns: - -------- - torch.Tensor - Tensor of gene embeddings + Make sure self.missing_table matches the current gene set's missing count & dim. + Builds a per-position index map (for missing genes) and caches the mask + map. + """ + # Build / cache index map for missing positions (pos -> 0..(n_missing-1)) + if genes_key in self._missing_index_map_cache and genes_key in self._present_mask_cache: + return # already prepared for this gene set + + # Identify missing positions + present = present_mask.bool().tolist() + missing_positions = [i for i, ok in enumerate(present) if not ok] + n_missing = len(missing_positions) + + # Cache mask for this gene set (on device) + self._present_mask_cache[genes_key] = present_mask + + if n_missing == 0: + # No missing genes -> trivial index map of zeros (unused) + self._missing_index_map_cache[genes_key] = torch.zeros(len(genes_key), dtype=torch.long, device=self.device) + return + + # (Re)create the missing table if shape changed + if ( + self.missing_table is None + or self._last_missing_count != n_missing + or self._last_missing_dim != gene_embed_dim + ): + self.missing_table = nn.Embedding(n_missing, gene_embed_dim) + nn.init.normal_(self.missing_table.weight, mean=0.0, std=0.02) + # Ensure the embedding table lives on the same device as inputs/masks + self.missing_table.to(present_mask.device) + self._last_missing_count = n_missing + self._last_missing_dim = gene_embed_dim + + # Build a position -> compact missing index map + inv = {pos: j for j, pos in enumerate(missing_positions)} + idx_map = [inv.get(i, 0) for i in range(len(genes_key))] + self._missing_index_map_cache[genes_key] = torch.tensor(idx_map, dtype=torch.long, device=present_mask.device) + + def get_gene_embedding(self, genes: List[str]) -> torch.Tensor: """ - # Cache key based on genes tuple - cache_key = tuple(genes) + Return gene/task embeddings for 'genes'. + For genes missing from the pretrained protein embeddings dictionary, we replace the + post-ESM embedding with a learnable vector from `self.missing_table` via torch.where. + + Caching: + - If no genes are missing, the post-ESM embeddings are cached and reused. + - If some genes are missing, we recompute each call so gradients flow into + self.missing_table (no caching of the final tensor). + """ + if self.model is None: + raise RuntimeError("Model not loaded. Call load_model(checkpoint) first.") + if self.protein_embeds is None: + # Should have been set in load_model; keep a defensive fallback: + self.protein_embeds = torch.load(get_embedding_cfg(self._vci_conf).all_embeddings, weights_only=False) - # Return cached embeddings if available - if cache_key in self.cached_gene_embeddings: - return self.cached_gene_embeddings[cache_key] + genes_key = tuple(genes) - # Compute gene embeddings; fallback to zero vectors for missing genes. - missing = [g for g in genes if g not in self.protein_embeds] - if len(missing) > 0: - try: - embed_size = next(iter(self.protein_embeds.values())).shape[-1] - except Exception: - embed_size = 5120 - # Log once per call to aid debugging - log.warning( - f"Finetune.get_gene_embedding: {len(missing)} gene(s) missing from pretrained embeddings; using zeros as placeholders. " - f"First missing: {missing[:10]}{' ...' if len(missing) > 10 else ''}." - ) + # Fast path: if we saw this gene set before and no missing genes were involved, reuse cache + if genes_key in self.cached_gene_embeddings: + return self.cached_gene_embeddings[genes_key].to(self.device) - protein_embeds = [ - self.protein_embeds[x] if x in self.protein_embeds else torch.zeros(embed_size) - for x in genes + # Build a [G, embed_size] tensor of raw protein embeddings (zeros where missing) + # Determine the raw protein embedding size + try: + example_vec = next(iter(self.protein_embeds.values())) + embed_size = int(example_vec.shape[-1]) + except Exception: + embed_size = get_embedding_cfg(self._vci_conf).size # fallback + + raw_list = [ + self.protein_embeds[g] if g in self.protein_embeds else torch.zeros(embed_size) # type: ignore + for g in genes ] - protein_embeds = torch.stack(protein_embeds).to(self.device) - gene_embeds = self.model.gene_embedding_layer(protein_embeds) + protein_embeds = torch.stack(raw_list).to(self.device) + + # Project through the model's gene embedding layer (post-ESM projection) + gene_embeds_raw = self.model.gene_embedding_layer(protein_embeds) # type: ignore # [G, d_model] + gene_embeds_raw = gene_embeds_raw.to(self.device) + d_model = int(gene_embeds_raw.shape[-1]) + + # Present mask: True where gene exists in pretrained protein_embeds + present_mask = torch.tensor([g in self.protein_embeds for g in genes], device=self.device).unsqueeze(1) + + # Prepare missing-table and position index map if needed + self._ensure_missing_table(genes_key, d_model, present_mask.squeeze(1)) + + # If we have a non-empty missing_table for this gene set, replace missing rows + idx_map = self._missing_index_map_cache[genes_key] + # Safety: if the missing table exists but is on a different device, move it + if self.missing_table is not None and self.missing_table.weight.device != idx_map.device: + self.missing_table.to(idx_map.device) + if self.missing_table is not None and self._last_missing_count > 0: + learned_full = self.missing_table(idx_map) # [G, d_model] + # Differentiable replacement: keep present rows from gene_embeds_raw, else take learned_full + gene_embeds = torch.where(present_mask, gene_embeds_raw, learned_full) + else: + gene_embeds = gene_embeds_raw + + # Cache only when there are no missing genes for this set (so the tensor is static) + if self._last_missing_count == 0: + self.cached_gene_embeddings[genes_key] = gene_embeds.detach().clone() - # Cache and return - self.cached_gene_embeddings[cache_key] = gene_embeds return gene_embeds - def get_counts(self, cell_embs, genes, read_depth=None, batch_size=32): + # ------------------------- + # Counts decoding (vectorized over genes) + # ------------------------- + def get_counts(self, cell_embs, genes: List[str], batch_size: int = 32) -> torch.Tensor: """ - Generate predictions with the binary decoder with gradients enabled. - - Parameters: - - cell_embs: A tensor or array of cell embeddings. - - genes: List of gene names. - - read_depth: Optional read depth for RDA normalization. - - batch_size: Batch size for processing. + Generate predictions with the (pretrained) binary decoder. This is vectorized + over all genes (no per-gene loops). Returns: - A single tensor of shape [N, num_genes] where N is the total number of cells. + Tensor of shape [Ncells, Ngenes] """ + if self.model is None: + raise RuntimeError("Model not loaded. Call load_model(checkpoint) first.") - # Convert cell_embs to a tensor on the correct device. - cell_embs = torch.tensor(cell_embs, dtype=torch.float, device=self.device) - - # Check if RDA is enabled. - use_rda = getattr(self.model.cfg.model, "rda", False) - if use_rda and read_depth is None: - read_depth = 4.0 + # Convert cell_embs to a tensor on the correct device (no detach here) + cell_embs = torch.as_tensor(cell_embs, dtype=torch.float, device=self.device) - # Retrieve gene embeddings (cached if available). - gene_embeds = self.get_gene_embedding(genes) + # RDA must be enabled to use read_depth + use_rda = getattr(self.model.cfg.model, "rda", False) # type: ignore + assert use_rda, "RDA must be enabled to use get_counts (model.cfg.model.rda == True)." - # List to collect the output predictions for each batch. - output_batches = [] + # Retrieve (and possibly learn) gene embeddings (with differentiable missing replacement) + gene_embeds = self.get_gene_embedding(genes) # [G, d_model] - # Loop over cell embeddings in batches. + outputs = [] for i in range(0, cell_embs.size(0), batch_size): - # Determine batch indices. end_idx = min(i + batch_size, cell_embs.size(0)) - cell_embeds_batch = cell_embs[i:end_idx] + cell_batch = cell_embs[i:end_idx] # [B, E_cell] + + # NOTE: Learnable read depth scalar, expanded to batch (keeps gradient) + task_counts = self.read_depth.expand(cell_batch.shape[0]).to(cell_batch.dtype).to(cell_batch.device) - # Set up task counts if using RDA. - if use_rda: - task_counts = torch.full((cell_embeds_batch.shape[0],), read_depth, device=self.device) - else: - task_counts = None + # Build [B, G, *] pairwise features and decode + merged = self.model.resize_batch(cell_batch, gene_embeds, task_counts) # type: ignore - # Resize the batch using the model's method. - merged_embs = self.model.resize_batch(cell_embeds_batch, gene_embeds, task_counts) + # Align dtype with decoder weights to avoid mixed-precision issues on CPU + dec_param_dtype = next(self.model.binary_decoder.parameters()).dtype # type: ignore + if merged.dtype != dec_param_dtype: + merged = merged.to(dec_param_dtype) - # Forward pass through the binary decoder. - logprobs_batch = self.model.binary_decoder(merged_embs) + logprobs_batch = self.model.binary_decoder(merged) # type: ignore - # If the output has an extra singleton dimension (e.g., [B, gene_dim, 1]), squeeze it. + # Squeeze trailing singleton if present: [B, G, 1] -> [B, G] if logprobs_batch.dim() == 3 and logprobs_batch.size(-1) == 1: logprobs_batch = logprobs_batch.squeeze(-1) - output_batches.append(logprobs_batch) + outputs.append(logprobs_batch) - # Concatenate all batch outputs along the first dimension. - return torch.cat(output_batches, dim=0) + return torch.cat(outputs, dim=0) diff --git a/src/state/tx/models/decoders.py b/src/state/tx/models/decoders.py index bc04e66e..c42fa91a 100644 --- a/src/state/tx/models/decoders.py +++ b/src/state/tx/models/decoders.py @@ -27,6 +27,7 @@ def __init__( hidden_dim: int = 512, dropout: float = 0.1, basal_residual: bool = False, + train_binary_decoder: bool = True, ): super().__init__() # Initialize finetune helper and model from a single checkpoint @@ -34,7 +35,7 @@ def __init__( raise ValueError( "FinetuneVCICountsDecoder requires a VCI/SE config. Set kwargs.vci_config or env STATE_VCI_CONFIG." ) - self.finetune = Finetune(cfg=OmegaConf.load(config)) + self.finetune = Finetune(cfg=OmegaConf.load(config), train_binary_decoder=train_binary_decoder) self.finetune.load_model(checkpoint) # Resolve genes: prefer explicit list; else infer from anndata if provided if genes is None and adata is not None: @@ -51,12 +52,6 @@ def __init__( self.ds_emb_dim = int(ds_emb_dim) if ds_emb_dim is not None else 0 self.input_total_dim = int(latent_dim) - # layers = [ - # nn.Linear(latent_dim, hidden_dims[0]), - # ] - - # self.gene_lora = nn.Sequential(*layers) - self.latent_decoder = nn.Sequential( nn.Linear(latent_dim, hidden_dim), nn.LayerNorm(hidden_dim), @@ -80,9 +75,7 @@ def __init__( nn.Linear(128, len(self.genes)), ) - self.binary_decoder = self.finetune.model.binary_decoder - for param in self.binary_decoder.parameters(): - param.requires_grad = False + self.binary_decoder = self.finetune.model.binary_decoder # type: ignore # Validate that all requested genes exist in the pretrained checkpoint's embeddings pe = getattr(self.finetune, "protein_embeds", {}) From 5e0e715848a634baae5b8e3a3cd4da9d62213682 Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Sun, 14 Sep 2025 19:20:44 -0700 Subject: [PATCH 06/38] added batch predictor implementation --- src/state/tx/models/state_transition.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/state/tx/models/state_transition.py b/src/state/tx/models/state_transition.py index 0e7ba952..e6980f44 100644 --- a/src/state/tx/models/state_transition.py +++ b/src/state/tx/models/state_transition.py @@ -197,6 +197,9 @@ def __init__( ) self.batch_dim = batch_dim + # Internal cache for last token features (B, S, H) from transformer for aux loss + self._token_features: Optional[torch.Tensor] = None + # if the model is outputting to counts space, apply relu # otherwise its in embedding space and we don't want to is_gene_space = kwargs["embed_key"] == "X_hvg" or kwargs["embed_key"] is None @@ -461,6 +464,9 @@ def forward(self, batch: dict, padded=True) -> torch.Tensor: res_pred = transformer_output self._batch_token_cache = None + # Cache token features for auxiliary batch prediction loss (B, S, H) + self._token_features = res_pred + # add to basal if predicting residual if self.predict_residual and self.output_space == "all": # Project control_cells to hidden_dim space to match res_pred @@ -475,8 +481,6 @@ def forward(self, batch: dict, padded=True) -> torch.Tensor: # apply relu if specified and we output to HVG space is_gene_space = self.hparams["embed_key"] == "X_hvg" or self.hparams["embed_key"] is None - # logger.info(f"DEBUG: is_gene_space: {is_gene_space}") - # logger.info(f"DEBUG: self.gene_decoder: {self.gene_decoder}") if is_gene_space or self.gene_decoder is None: out_pred = self.relu(out_pred) @@ -569,6 +573,7 @@ def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int, padded=T self.log("train/batch_token_loss", ce_loss) total_loss = total_loss + self.batch_token_weight * ce_loss + # Auxiliary batch prediction loss (per token), if enabled if self.gene_decoder is not None and "pert_cell_counts" in batch: gene_targets = batch["pert_cell_counts"] # Train decoder to map latent predictions to gene space From f6307af5bbded0beaa05c3d3a8e293f2db3e33df Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Sun, 14 Sep 2025 19:56:47 -0700 Subject: [PATCH 07/38] added pseudobulk model file --- src/state/configs/model/pseudobulk.yaml | 54 +++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 src/state/configs/model/pseudobulk.yaml diff --git a/src/state/configs/model/pseudobulk.yaml b/src/state/configs/model/pseudobulk.yaml new file mode 100644 index 00000000..8d22cead --- /dev/null +++ b/src/state/configs/model/pseudobulk.yaml @@ -0,0 +1,54 @@ +name: pseudobulk +checkpoint: null +device: cuda + +kwargs: + cell_set_len: 512 + blur: 0.05 + hidden_dim: 768 # hidden dimension going into the transformer backbone + loss: energy + confidence_head: False + n_encoder_layers: 1 + n_decoder_layers: 1 + predict_residual: True + softplus: True + freeze_pert_backbone: False + transformer_decoder: False + finetune_vci_decoder: False + residual_decoder: False + batch_encoder: False + nb_decoder: False + mask_attn: False + use_effect_gating_token: False + distributional_loss: energy + init_from: null + transformer_backbone_key: llama + transformer_backbone_kwargs: + max_position_embeddings: ${model.kwargs.cell_set_len} + n_positions: ${model.kwargs.cell_set_len} + hidden_size: ${model.kwargs.hidden_dim} + intermediate_size: 3072 + num_hidden_layers: 8 + num_attention_heads: 12 + num_key_value_heads: 12 + head_dim: 64 + use_cache: false + attention_dropout: 0.0 + hidden_dropout: 0.0 + layer_norm_eps: 1e-6 + pad_token_id: 0 + bos_token_id: 1 + eos_token_id: 2 + tie_word_embeddings: false + rotary_dim: 0 + use_rotary_embeddings: false + lora: + enable: false + r: 16 + alpha: 32 + dropout: 0.05 + bias: none + target: auto + adapt_mlp: false + task_type: FEATURE_EXTRACTION + merge_on_eval: false From 15d95a885a5eb94c0e764fb0a8b8f676c0a623d8 Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Tue, 16 Sep 2025 11:09:43 -0700 Subject: [PATCH 08/38] updated to fix pseudobulk model during eval --- src/state/_cli/_tx/_predict.py | 4 ++ src/state/tx/models/pseudobulk.py | 90 ++++++++++++++++++++++++++----- 2 files changed, 82 insertions(+), 12 deletions(-) diff --git a/src/state/_cli/_tx/_predict.py b/src/state/_cli/_tx/_predict.py index b7191037..31b0e5aa 100644 --- a/src/state/_cli/_tx/_predict.py +++ b/src/state/_cli/_tx/_predict.py @@ -175,6 +175,10 @@ def load_config(cfg_path: str) -> dict: from ...tx.models.decoder_only import DecoderOnlyPerturbationModel ModelClass = DecoderOnlyPerturbationModel + elif model_class_name.lower() == "pseudobulk": + from ...tx.models.pseudobulk import PseudobulkPerturbationModel + + ModelClass = PseudobulkPerturbationModel else: raise ValueError(f"Unknown model class: {model_class_name}") diff --git a/src/state/tx/models/pseudobulk.py b/src/state/tx/models/pseudobulk.py index 82309c33..4da11f99 100644 --- a/src/state/tx/models/pseudobulk.py +++ b/src/state/tx/models/pseudobulk.py @@ -136,6 +136,74 @@ def __init__( print(self) + def _decoder_in_features(self) -> Optional[int]: + """ + Best-effort inspection of the decoder's expected input dimensionality. + Returns None if it cannot be determined reliably. + """ + gd = self.gene_decoder + if gd is None: + return None + # LatentToGeneDecoder (non-residual): has .decoder (Sequential) starting with Linear + if hasattr(gd, "decoder") and isinstance(getattr(gd, "decoder"), nn.Sequential): + seq = gd.decoder + for m in seq: + if isinstance(m, nn.Linear): + return m.in_features + return None + # LatentToGeneDecoder (residual): has .blocks (ModuleList) of Sequentials, first starts with Linear + if hasattr(gd, "blocks"): + blocks = getattr(gd, "blocks") + if len(blocks) > 0 and isinstance(blocks[0], nn.Sequential) and isinstance(blocks[0][0], nn.Linear): + return blocks[0][0].in_features + return None + # NBDecoder: has .encoder (Sequential) starting with Linear + if hasattr(gd, "encoder") and isinstance(getattr(gd, "encoder"), nn.Sequential): + seq = gd.encoder + for m in seq: + if isinstance(m, nn.Linear): + return m.in_features + return None + return None + + def _maybe_concat_batch(self, latent: torch.Tensor, batch: torch.Tensor, padded: bool) -> torch.Tensor: + """ + Concatenate batch covariates to the latent only if the decoder expects them. + This avoids shape mismatches at inference when loading a checkpointed decoder + that was trained without batch concatenation. + """ + if self.gene_decoder is None or self.batch_dim is None: + return latent + + expected_in = self._decoder_in_features() + last_dim = latent.size(-1) + + # Prepare batch tensor to match latent shape + if latent.dim() == 2: + batch_var = batch.reshape(latent.shape[0], -1) + else: + batch_var = batch.reshape(latent.shape[0], latent.shape[1], -1) + + # Decide whether to concatenate based on the decoder's input expectation + if expected_in is None: + # Fallback to previous behavior: concatenate for non-VCI decoders + return torch.cat([latent, batch_var], dim=-1) + + if expected_in == last_dim: + # Decoder expects just the latent; do NOT concat + return latent + elif expected_in == last_dim + batch_var.size(-1): + # Decoder expects latent + batch covariates; concat + return torch.cat([latent, batch_var], dim=-1) + else: + # Mismatch: give a clear error message to guide the user + raise RuntimeError( + f"Decoder input dim mismatch: got latent size {last_dim}" + f" (batch_dim={batch_var.size(-1)}), but decoder expects {expected_in}." + " This usually means the checkpointed decoder was trained without" + " concatenating batch covariates, while predict is attempting to." + ) + def _build_networks(self): """ Here we instantiate the actual GPT2-based model. @@ -272,10 +340,8 @@ def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int, padded=T # with torch.no_grad(): # latent_preds = pred.detach() # Detach to prevent gradient flow back to main model - batch_var = batch["batch"].reshape(latent_preds.shape[0], latent_preds.shape[1], -1) - # concatenate on the last axis - if self.batch_dim is not None and not isinstance(self.gene_decoder, FinetuneVCICountsDecoder): - latent_preds = torch.cat([latent_preds, batch_var], dim=-1) + if not isinstance(self.gene_decoder, FinetuneVCICountsDecoder): + latent_preds = self._maybe_concat_batch(latent_preds, batch["batch"], padded=True) if isinstance(self.gene_decoder, NBDecoder): mu, theta = self.gene_decoder(latent_preds) @@ -319,7 +385,10 @@ def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Non gene_targets = batch["pert_cell_counts"].reshape_as(mu) decoder_loss = nb_nll(gene_targets, mu, theta) else: - pert_cell_counts_preds = self.gene_decoder(latent_preds) # verify this is automatically detached + # Match decoder input dims + if not isinstance(self.gene_decoder, FinetuneVCICountsDecoder): + latent_preds = self._maybe_concat_batch(latent_preds, batch["batch"], padded=True) + pert_cell_counts_preds = self.gene_decoder(latent_preds) # Get decoder predictions pert_cell_counts_preds = pert_cell_counts_preds.reshape(-1, self.cell_sentence_len, self.gene_dim) @@ -359,17 +428,14 @@ def predict_step(self, batch, batch_idx, padded=True, **kwargs): basal_hvg = batch.get("ctrl_cell_counts", None) if self.gene_decoder is not None: - if latent_output.dim() == 2: - batch_var = batch["batch"].reshape(latent_output.shape[0], -1) - else: - batch_var = batch["batch"].reshape(latent_output.shape[0], latent_output.shape[1], -1) - # concatenate on the last axis - if self.batch_dim is not None and not isinstance(self.gene_decoder, FinetuneVCICountsDecoder): - latent_output = torch.cat([latent_output, batch_var], dim=-1) if isinstance(self.gene_decoder, NBDecoder): + # NB decoder already configured with latent_dim including batch if needed mu, _ = self.gene_decoder(latent_output) pert_cell_counts_preds = mu else: + # Only concat batch covariates if decoder expects them + if not isinstance(self.gene_decoder, FinetuneVCICountsDecoder): + latent_output = self._maybe_concat_batch(latent_output, batch["batch"], padded=padded) pert_cell_counts_preds = self.gene_decoder(latent_output) output_dict["pert_cell_counts_preds"] = pert_cell_counts_preds From f40291df74a5bc7a66ff8c1a4e4ad9670a292647 Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Tue, 16 Sep 2025 13:09:31 -0700 Subject: [PATCH 09/38] fixed vci fine tune decoder setup --- pyproject.toml | 2 +- src/state/tx/models/base.py | 22 ++++++++++++++++++++++ src/state/tx/models/decoders.py | 10 ++++++++++ 3 files changed, 33 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index dec4906e..2c421d32 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ authors = [ { name = "Abhinav Adduri", email = "abhinav.adduri@arcinstitute.org" }, { name = "Yusuf Roohani", email = "yusuf.roohani@arcinstitute.org" }, { name = "Noam Teyssier", email = "noam.teyssier@arcinstitute.org" }, - { name = "Rajesh Ilango" }, + { name = "Rajesh Ilango", email = "rilango@gmail.com" }, { name = "Dhruv Gautam", email = "dhruvgautam@berkeley.edu" }, ] requires-python = ">=3.10,<3.13" diff --git a/src/state/tx/models/base.py b/src/state/tx/models/base.py index 10378ef8..53e6ce70 100644 --- a/src/state/tx/models/base.py +++ b/src/state/tx/models/base.py @@ -216,6 +216,28 @@ def on_load_checkpoint(self, checkpoint: dict[str, tp.Any]) -> None: if self.gene_decoder_bool == False: self.gene_decoder = None return + + # When finetuning with the pretrained VCI decoder, keep the existing + # FinetuneVCICountsDecoder instance. Overwriting it with a freshly + # constructed LatentToGeneDecoder would make the checkpoint weights + # incompatible and surface load_state_dict errors. + finetune_decoder_active = False + hparams = getattr(self, "hparams", None) + if hparams is not None: + if hasattr(hparams, "get"): + finetune_decoder_active = bool(hparams.get("finetune_vci_decoder", False)) + else: + finetune_decoder_active = bool(getattr(hparams, "finetune_vci_decoder", False)) + if not finetune_decoder_active: + finetune_decoder_active = bool(getattr(self, "finetune_vci_decoder", False)) + + if finetune_decoder_active: + # Preserve decoder_cfg for completeness but avoid rebuilding the module. + if "decoder_cfg" in checkpoint.get("hyper_parameters", {}): + self.decoder_cfg = checkpoint["hyper_parameters"]["decoder_cfg"] + logger.info("Finetune VCI decoder active; keeping existing decoder during checkpoint load") + return + if not decoder_already_configured and "decoder_cfg" in checkpoint["hyper_parameters"]: self.decoder_cfg = checkpoint["hyper_parameters"]["decoder_cfg"] self.gene_decoder = LatentToGeneDecoder(**self.decoder_cfg) diff --git a/src/state/tx/models/decoders.py b/src/state/tx/models/decoders.py index c42fa91a..8c8616b7 100644 --- a/src/state/tx/models/decoders.py +++ b/src/state/tx/models/decoders.py @@ -113,6 +113,16 @@ def __init__( # Register a dummy buffer so attributes exist self.missing_table = None + # Ensure the wrapped Finetune helper creates its own missing-table parameters + # prior to Lightning's checkpoint load. Otherwise the checkpoint will contain + # weights like `gene_decoder.finetune.missing_table.weight` that are absent + # from a freshly constructed module, triggering "unexpected key" errors. + try: + with torch.no_grad(): + self.finetune.get_gene_embedding(self.genes) + except Exception as exc: + logger.debug(f"Deferred Finetune missing-table initialization failed: {exc}") + def gene_dim(self): return len(self.genes) From 5aa2252e0608ea149d2d8048c2893324282e2aa3 Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Tue, 16 Sep 2025 13:09:54 -0700 Subject: [PATCH 10/38] bumping semvar --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 2c421d32..122f2024 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "arc-state" -version = "0.9.31" +version = "0.9.32" description = "State is a machine learning model that predicts cellular perturbation response across diverse contexts." readme = "README.md" authors = [ From 848859cd3a359d637ffa7672c949f09cd8a4bfdb Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Tue, 16 Sep 2025 18:36:43 -0700 Subject: [PATCH 11/38] fixed using embedding for the output space --- src/state/_cli/_tx/_train.py | 29 +++++++++++++------- src/state/emb/finetune_decoder.py | 14 ++++++---- src/state/tx/models/base.py | 20 ++++++++++++-- src/state/tx/models/decoders.py | 4 +-- src/state/tx/models/state_transition.py | 35 +++++++++++++++++++++++++ src/state/tx/models/utils.py | 8 ++---- 6 files changed, 85 insertions(+), 25 deletions(-) diff --git a/src/state/_cli/_tx/_train.py b/src/state/_cli/_tx/_train.py index c3f475e6..054a2bd0 100644 --- a/src/state/_cli/_tx/_train.py +++ b/src/state/_cli/_tx/_train.py @@ -108,6 +108,11 @@ def run_tx_train(cfg: DictConfig): elif cfg["model"]["name"].lower() == "scvi": cfg["data"]["kwargs"]["transform"] = None + output_space = cfg["data"]["kwargs"].get("output_space", "gene") + assert output_space in {"embedding", "gene", "all"}, ( + f"data.kwargs.output_space must be one of 'embedding', 'gene', or 'all'; got {output_space!r}" + ) + data_module: PerturbationDataModule = get_datamodule( cfg["data"]["name"], cfg["data"]["kwargs"], @@ -125,23 +130,27 @@ def run_tx_train(cfg: DictConfig): print("batch size:", dl.batch_size) var_dims = data_module.get_var_dims() # {"gene_dim": …, "hvg_dim": …} - if cfg["data"]["kwargs"]["output_space"] == "gene": + if output_space == "gene": gene_dim = var_dims.get("hvg_dim", 2000) # fallback if key missing else: gene_dim = var_dims.get("gene_dim", 2000) # fallback if key missing latent_dim = var_dims["output_dim"] # same as model.output_dim hidden_dims = cfg["model"]["kwargs"].get("decoder_hidden_dims", [1024, 1024, 512]) - decoder_cfg = dict( - latent_dim=latent_dim, - gene_dim=gene_dim, - hidden_dims=hidden_dims, - dropout=cfg["model"]["kwargs"].get("decoder_dropout", 0.1), - residual_decoder=cfg["model"]["kwargs"].get("residual_decoder", False), - ) + if output_space in {"gene", "all"}: + decoder_cfg = dict( + latent_dim=latent_dim, + gene_dim=gene_dim, + hidden_dims=hidden_dims, + dropout=cfg["model"]["kwargs"].get("decoder_dropout", 0.1), + residual_decoder=cfg["model"]["kwargs"].get("residual_decoder", False), + ) - # tuck it into the kwargs that will reach the LightningModule - cfg["model"]["kwargs"]["decoder_cfg"] = decoder_cfg + # tuck it into the kwargs that will reach the LightningModule + cfg["model"]["kwargs"]["decoder_cfg"] = decoder_cfg + else: + cfg["model"]["kwargs"].pop("decoder_cfg", None) + cfg["model"]["kwargs"]["gene_decoder_bool"] = False # Save the onehot maps as pickle files instead of storing in config cell_type_onehot_map_path = join(run_output_dir, "cell_type_onehot_map.pkl") diff --git a/src/state/emb/finetune_decoder.py b/src/state/emb/finetune_decoder.py index 42af047a..50880a9d 100644 --- a/src/state/emb/finetune_decoder.py +++ b/src/state/emb/finetune_decoder.py @@ -15,7 +15,13 @@ class Finetune(nn.Module): - def __init__(self, cfg: Optional[OmegaConf] = None, learning_rate: float = 1e-4, read_depth: float = 4.0, train_binary_decoder: bool = False): + def __init__( + self, + cfg: Optional[OmegaConf] = None, + learning_rate: float = 1e-4, + read_depth: float = 4.0, + train_binary_decoder: bool = False, + ): """ Helper module that loads a pretrained SE/VCI checkpoint and exposes: - get_gene_embedding(genes): returns gene/task embeddings with differentiable @@ -45,7 +51,7 @@ def __init__(self, cfg: Optional[OmegaConf] = None, learning_rate: float = 1e-4, self.missing_table: Optional[nn.Embedding] = None self._last_missing_count: int = 0 self._last_missing_dim: int = 0 - + # Cache present masks and index maps per gene set self._present_mask_cache: Dict[Tuple[str, ...], torch.Tensor] = {} self._missing_index_map_cache: Dict[Tuple[str, ...], torch.Tensor] = {} @@ -85,9 +91,7 @@ def load_model(self, checkpoint: str): self._vci_conf = cfg_to_use # Load model; allow passing cfg to constructor like inference - self.model = StateEmbeddingModel.load_from_checkpoint( - checkpoint, dropout=0.0, strict=False, cfg=self._vci_conf - ) + self.model = StateEmbeddingModel.load_from_checkpoint(checkpoint, dropout=0.0, strict=False, cfg=self._vci_conf) self.device = self.model.device # type: ignore # Try to extract packaged protein embeddings from checkpoint diff --git a/src/state/tx/models/base.py b/src/state/tx/models/base.py index 53e6ce70..635536c7 100644 --- a/src/state/tx/models/base.py +++ b/src/state/tx/models/base.py @@ -123,12 +123,12 @@ class PerturbationModel(ABC, LightningModule): Args: input_dim: Dimension of input features (genes or embeddings) hidden_dim: Hidden dimension for neural network layers - output_dim: Dimension of output (always gene space) + output_dim: Dimension of output (gene space or embedding space) pert_dim: Dimension of perturbation embeddings dropout: Dropout rate lr: Learning rate for optimizer loss_fn: Loss function ('mse' or custom nn.Module) - output_space: 'gene' or 'latent' + output_space: 'gene', 'all', or 'embedding' """ def __init__( @@ -174,6 +174,10 @@ def __init__( self.embed_key = embed_key self.output_space = output_space + if self.output_space not in {"embedding", "gene", "all"}: + raise ValueError( + f"Unsupported output_space '{self.output_space}'. Expected one of 'embedding', 'gene', or 'all'." + ) self.batch_size = batch_size self.control_pert = control_pert @@ -182,6 +186,18 @@ def __init__( self.dropout = dropout self.lr = lr self.loss_fn = get_loss_fn(loss_fn) + + if self.output_space == "embedding": + self.gene_decoder_bool = False + self.decoder_cfg = None + # keep hyperparameters metadata consistent with the actual model state + try: + if hasattr(self, "hparams"): + self.hparams["gene_decoder_bool"] = False # type: ignore[index] + self.hparams["decoder_cfg"] = None # type: ignore[index] + except Exception: + pass + self._build_decoder() def transfer_batch_to_device(self, batch, device, dataloader_idx: int): diff --git a/src/state/tx/models/decoders.py b/src/state/tx/models/decoders.py index 8c8616b7..ae06565c 100644 --- a/src/state/tx/models/decoders.py +++ b/src/state/tx/models/decoders.py @@ -23,7 +23,7 @@ def __init__( config: Optional[str] = "/home/aadduri/vci_pretrain/vci_1.4.4/config.yaml", latent_dim: int = 1034, # total input dim (cell emb + optional ds emb) read_depth: float = 4.0, - ds_emb_dim: int = 10, # dataset embedding dim at the tail of input + ds_emb_dim: int = 10, # dataset embedding dim at the tail of input hidden_dim: int = 512, dropout: float = 0.1, basal_residual: bool = False, @@ -75,7 +75,7 @@ def __init__( nn.Linear(128, len(self.genes)), ) - self.binary_decoder = self.finetune.model.binary_decoder # type: ignore + self.binary_decoder = self.finetune.model.binary_decoder # type: ignore # Validate that all requested genes exist in the pretrained checkpoint's embeddings pe = getattr(self.finetune, "protein_embeds", {}) diff --git a/src/state/tx/models/state_transition.py b/src/state/tx/models/state_transition.py index e6980f44..641fdc6d 100644 --- a/src/state/tx/models/state_transition.py +++ b/src/state/tx/models/state_transition.py @@ -197,6 +197,41 @@ def __init__( ) self.batch_dim = batch_dim + # Optional batch predictor ablation: learns a single batch token added to every position, + # and adds an auxiliary per-token batch classification head + CE loss. + self.batch_predictor = bool(kwargs.get("batch_predictor", False)) + # If batch_encoder is enabled, disable batch_predictor per request + if self.batch_encoder is not None and self.batch_predictor: + logger.warning( + "Both model.kwargs.batch_encoder and model.kwargs.batch_predictor are True. " + "Disabling batch_predictor and proceeding with batch_encoder." + ) + self.batch_predictor = False + try: + # Keep hparams in sync if available + self.hparams["batch_predictor"] = False # type: ignore[index] + except Exception: + pass + + self.batch_predictor_weight = float(kwargs.get("batch_predictor_weight", 0.1)) + self.batch_predictor_num_classes: Optional[int] = batch_dim if self.batch_predictor else None + if self.batch_predictor: + if self.batch_predictor_num_classes is None: + raise ValueError("batch_predictor=True requires a valid `batch_dim` (number of batch classes).") + # A single learnable batch token that is added to each position + self.batch_token = nn.Parameter(torch.randn(1, 1, self.hidden_dim)) + # Simple per-token classifier from transformer hidden to batch classes + self.batch_classifier = build_mlp( + in_dim=self.hidden_dim, + out_dim=self.batch_predictor_num_classes, + hidden_dim=self.hidden_dim, + n_layers=4, + dropout=self.dropout, + activation=self.activation_class, + ) + else: + self.batch_token = None + self.batch_classifier = None # Internal cache for last token features (B, S, H) from transformer for aux loss self._token_features: Optional[torch.Tensor] = None diff --git a/src/state/tx/models/utils.py b/src/state/tx/models/utils.py index 47185a83..6deb4a28 100644 --- a/src/state/tx/models/utils.py +++ b/src/state/tx/models/utils.py @@ -160,16 +160,12 @@ def apply_lora(model: PreTrainedModel, backbone_key: str, lora_cfg: dict | None) return model if LoraConfig is None or get_peft_model is None: - raise ImportError( - "peft is not installed but `lora.enable` is True. Add `peft` to dependencies." - ) + raise ImportError("peft is not installed but `lora.enable` is True. Add `peft` to dependencies.") target = lora_cfg.get("target", "auto") adapt_mlp = bool(lora_cfg.get("adapt_mlp", False)) target_modules = ( - lora_cfg.get("target_modules") - if target != "auto" - else _default_lora_targets(backbone_key, adapt_mlp) + lora_cfg.get("target_modules") if target != "auto" else _default_lora_targets(backbone_key, adapt_mlp) ) # Build PEFT LoRA config From b137ae64070bdbbe796567bc6fdc9036a32fef91 Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Fri, 19 Sep 2025 20:43:19 +0000 Subject: [PATCH 12/38] included fix to transform wit new checkpoints --- scripts/state_embed_anndata.py | 50 ------------------------------- src/state/_cli/_emb/_transform.py | 34 ++++++++++++++------- 2 files changed, 24 insertions(+), 60 deletions(-) delete mode 100644 scripts/state_embed_anndata.py diff --git a/scripts/state_embed_anndata.py b/scripts/state_embed_anndata.py deleted file mode 100644 index 3af33ab0..00000000 --- a/scripts/state_embed_anndata.py +++ /dev/null @@ -1,50 +0,0 @@ -#!/usr/bin/env python3 -""" -VCI Model Embedding Script - -This script computes embeddings for an input anndata file using a pre-trained VCI model checkpoint. -It can be run from any directory and outputs the embedded anndata to a specified location. - -Usage: - python embed_vci.py --checkpoint PATH_TO_CHECKPOINT --input INPUT_ANNDATA --output OUTPUT_ANNDATA - -Example: - python embed_vci.py --checkpoint /path/to/model.ckpt --input data.h5ad --output embedded_data.h5ad -""" - -import argparse -import os - -from omegaconf import OmegaConf - -from state_sets.state.inference import Inference - - -# Parse command line arguments -def parse_args(): - parser = argparse.ArgumentParser(description="Compute embeddings for anndata using a VCI model") - parser.add_argument("--checkpoint", required=True, help="Path to the model checkpoint file") - parser.add_argument("--config", required=True, help="Path to the model training config") - parser.add_argument("--input", required=True, help="Path to input anndata file (h5ad)") - parser.add_argument("--output", required=True, help="Path to output embedded anndata file (h5ad)") - parser.add_argument("--dataset-name", default="perturbation", help="Dataset name to be used in dataloader creation") - parser.add_argument("--gpu", action="store_true", help="Use GPU if available") - parser.add_argument("--filter", action="store_true", help="Filter gene set to our esm embeddings only.") - parser.add_argument("--embed-key", help="Name of key to store") - - return parser.parse_args() - - -def main(): - # Parse command line arguments - args = parse_args() - - conf = OmegaConf.load(args.config) - inferer = Inference(conf) - inferer.load_model(args.checkpoint) - os.makedirs(os.path.dirname(args.output), exist_ok=True) - inferer.encode_adata(args.input, args.output, emb_key=args.embed_key, dataset_name=args.dataset_name) - - -if __name__ == "__main__": - main() diff --git a/src/state/_cli/_emb/_transform.py b/src/state/_cli/_emb/_transform.py index 26b38cfa..54ba75e2 100644 --- a/src/state/_cli/_emb/_transform.py +++ b/src/state/_cli/_emb/_transform.py @@ -3,8 +3,16 @@ def add_arguments_transform(parser: ap.ArgumentParser): """Add arguments for state embedding CLI.""" - parser.add_argument("--model-folder", required=True, help="Path to the model checkpoint folder") - parser.add_argument("--checkpoint", required=False, help="Path to the specific model checkpoint") + parser.add_argument( + "--model-folder", + required=False, + help="Path to the model checkpoint folder (required if --checkpoint is not provided)", + ) + parser.add_argument( + "--checkpoint", + required=False, + help="Path to the specific model checkpoint (required if --model-folder is not provided)", + ) parser.add_argument( "--config", required=False, @@ -60,13 +68,19 @@ def run_emb_transform(args: ap.ArgumentParser): logger.error("Either --output or --lancedb must be provided") raise ValueError("Either --output or --lancedb must be provided") - # look in the model folder with glob for *.ckpt, get the first one, and print it - model_files = glob.glob(os.path.join(args.model_folder, "*.ckpt")) - if not model_files: - logger.error(f"No model checkpoint found in {args.model_folder}") - raise FileNotFoundError(f"No model checkpoint found in {args.model_folder}") - if not args.checkpoint: - args.checkpoint = model_files[-1] + # Resolve checkpoint path, allowing either --checkpoint, --model-folder, or both + checkpoint_path = args.checkpoint + if args.model_folder: + model_files = glob.glob(os.path.join(args.model_folder, "*.ckpt")) + if not model_files and not checkpoint_path: + logger.error(f"No model checkpoint found in {args.model_folder}") + raise FileNotFoundError(f"No model checkpoint found in {args.model_folder}") + if not checkpoint_path and model_files: + checkpoint_path = model_files[-1] + if not checkpoint_path: + logger.error("Either --checkpoint or --model-folder must be provided") + raise ValueError("Either --checkpoint or --model-folder must be provided") + args.checkpoint = checkpoint_path logger.info(f"Using model checkpoint: {args.checkpoint}") # Create inference object @@ -79,7 +93,7 @@ def run_emb_transform(args: ap.ArgumentParser): if args.protein_embeddings: logger.info(f"Using protein embeddings override: {args.protein_embeddings}") protein_embeds = torch.load(args.protein_embeddings, weights_only=False, map_location="cpu") - else: + elif args.model_folder: # Try auto-detect in model folder try: exact_path = os.path.join(args.model_folder, "protein_embeddings.pt") From bd5bd299dbf0ef5b26ffb59730ca61242d217ced Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Sun, 21 Sep 2025 18:33:36 +0000 Subject: [PATCH 13/38] distributed sampler commit --- src/state/_cli/_tx/_train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/state/_cli/_tx/_train.py b/src/state/_cli/_tx/_train.py index 054a2bd0..86229dd5 100644 --- a/src/state/_cli/_tx/_train.py +++ b/src/state/_cli/_tx/_train.py @@ -269,6 +269,7 @@ def run_tx_train(cfg: DictConfig): callbacks=callbacks, gradient_clip_val=cfg["training"]["gradient_clip_val"] if cfg["model"]["name"].lower() != "cpa" else None, accumulate_grad_batches=cfg["training"].get("gradient_accumulation_steps", 1), + use_distributed_sampler=False, ) # Align logging cadence with rolling MFU window (and W&B logging) From 154b55562d07323c4117109bdd069b7c3d7d7ea4 Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Sun, 21 Sep 2025 22:01:03 +0000 Subject: [PATCH 14/38] bump cell load requirement --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 122f2024..5da63a58 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ authors = [ requires-python = ">=3.10,<3.13" dependencies = [ "anndata>=0.11.4", - "cell-load>=0.8.3", + "cell-load>=0.8.4", "numpy>=2.2.6", "pandas>=2.2.3", "pyyaml>=6.0.2", From 43eff469ef41afe1111a30be675cec50e7843625 Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Tue, 23 Sep 2025 17:57:49 +0000 Subject: [PATCH 15/38] updated so that .npy output for files will write out just embeddings for transform, not the whole adata --- pyproject.toml | 2 +- src/state/_cli/_emb/_transform.py | 23 +++++++++++++++++--- src/state/_cli/_tx/_infer.py | 35 ++++++++++++++++++++++++++----- src/state/emb/inference.py | 2 ++ 4 files changed, 53 insertions(+), 9 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5da63a58..17545e4e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ authors = [ requires-python = ">=3.10,<3.13" dependencies = [ "anndata>=0.11.4", - "cell-load>=0.8.4", + "cell-load>=0.8.5", "numpy>=2.2.6", "pandas>=2.2.3", "pyyaml>=6.0.2", diff --git a/src/state/_cli/_emb/_transform.py b/src/state/_cli/_emb/_transform.py index 54ba75e2..becd7480 100644 --- a/src/state/_cli/_emb/_transform.py +++ b/src/state/_cli/_emb/_transform.py @@ -54,6 +54,7 @@ def run_emb_transform(args: ap.ArgumentParser): import glob import logging import os + import numpy as np import torch from omegaconf import OmegaConf @@ -124,6 +125,12 @@ def run_emb_transform(args: ap.ArgumentParser): logger.info(f"Loading model from checkpoint: {args.checkpoint}") inferer.load_model(args.checkpoint) + save_as_npy = False + output_target = args.output + if args.output: + _, ext = os.path.splitext(args.output) + save_as_npy = ext.lower() == ".npy" + # Create output directory if it doesn't exist if args.output: output_dir = os.path.dirname(args.output) @@ -134,13 +141,16 @@ def run_emb_transform(args: ap.ArgumentParser): # Generate embeddings logger.info(f"Computing embeddings for {args.input}") if args.output: - logger.info(f"Output will be saved to {args.output}") + if save_as_npy: + logger.info(f"Output embeddings will be saved to {args.output} as a NumPy array") + else: + logger.info(f"Output will be saved to {args.output}") if args.lancedb: logger.info(f"Embeddings will be saved to LanceDB at {args.lancedb}") - inferer.encode_adata( + embeddings = inferer.encode_adata( input_adata_path=args.input, - output_adata_path=args.output, + output_adata_path=None if save_as_npy else output_target, emb_key=args.embed_key, batch_size=args.batch_size if getattr(args, "batch_size", None) is not None else None, lancedb_path=args.lancedb, @@ -148,4 +158,11 @@ def run_emb_transform(args: ap.ArgumentParser): lancedb_batch_size=args.lancedb_batch_size, ) + if save_as_npy: + if embeddings is None: + logger.error("Failed to generate embeddings for NumPy output") + raise RuntimeError("Embedding generation returned no data") + np.save(args.output, embeddings) + logger.info(f"Saved embeddings matrix with shape {embeddings.shape} to {args.output}") + logger.info("Embedding computation completed successfully!") diff --git a/src/state/_cli/_tx/_infer.py b/src/state/_cli/_tx/_infer.py index a4f05d04..530cee1f 100644 --- a/src/state/_cli/_tx/_infer.py +++ b/src/state/_cli/_tx/_infer.py @@ -27,7 +27,7 @@ def add_arguments_infer(parser: argparse.ArgumentParser): "--output", type=str, default=None, - help="Path to output AnnData file (.h5ad). Defaults to _simulated.h5ad", + help="Path to output file (.h5ad or .npy). Defaults to _simulated.h5ad", ) parser.add_argument( "--model-dir", @@ -650,15 +650,35 @@ def group_control_indices(group_name: str) -> np.ndarray: # ----------------------- # 5) Persist the updated AnnData # ----------------------- + output_path = args.output or args.adata.replace(".h5ad", "_simulated.h5ad") + output_is_npy = output_path.lower().endswith(".npy") + + pred_matrix = None if writes_to[0] == ".X": if out_target == "X": adata.X = sim_X + pred_matrix = sim_X + elif out_target.startswith("obsm['") and out_target.endswith("']"): + pred_key = out_target[6:-2] + pred_matrix = adata.obsm.get(pred_key) + else: + pred_matrix = sim_X else: if out_target == f"obsm['{writes_to[1]}']": adata.obsm[writes_to[1]] = sim_obsm + pred_matrix = sim_obsm + elif out_target.startswith("obsm['") and out_target.endswith("']"): + pred_key = out_target[6:-2] + pred_matrix = adata.obsm.get(pred_key) + else: + pred_matrix = sim_obsm - output_path = args.output or args.adata.replace(".h5ad", "_simulated.h5ad") - adata.write_h5ad(output_path) + if output_is_npy: + if pred_matrix is None: + raise ValueError("Predictions matrix is unavailable; cannot write .npy output") + np.save(output_path, np.asarray(pred_matrix)) + else: + adata.write_h5ad(output_path) # ----------------------- # 6) Summary @@ -667,5 +687,10 @@ def group_control_indices(group_name: str) -> np.ndarray: print(f"Input cells: {n_total}") print(f"Controls simulated: {n_controls}") print(f"Treated simulated: {n_nonctl}") - print(f"Wrote predictions to adata.{out_target}") - print(f"Saved: {output_path}") + if output_is_npy: + shape_str = " x ".join(str(dim) for dim in pred_matrix.shape) if pred_matrix is not None else "unknown" + print(f"Wrote predictions array (shape: {shape_str})") + print(f"Saved NumPy file: {output_path}") + else: + print(f"Wrote predictions to adata.{out_target}") + print(f"Saved: {output_path}") diff --git a/src/state/emb/inference.py b/src/state/emb/inference.py index d042864f..7df4fe1a 100644 --- a/src/state/emb/inference.py +++ b/src/state/emb/inference.py @@ -277,6 +277,8 @@ def encode_adata( log.info(f"Successfully saved {len(all_embeddings)} embeddings to LanceDB") + return all_embeddings + def _convert_to_csr(self, adata): """Convert the adata.X matrix to CSR format if it's not already.""" from scipy.sparse import csr_matrix, issparse From 9819805623532408a21853057d0efad77b288a21 Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Fri, 26 Sep 2025 16:56:50 +0000 Subject: [PATCH 16/38] removed periodic checkpointing --- src/state/configs/state-defaults.yaml | 4 ++-- src/state/tx/utils/__init__.py | 18 ++++-------------- 2 files changed, 6 insertions(+), 16 deletions(-) diff --git a/src/state/configs/state-defaults.yaml b/src/state/configs/state-defaults.yaml index 8414ec8b..8968115b 100644 --- a/src/state/configs/state-defaults.yaml +++ b/src/state/configs/state-defaults.yaml @@ -19,8 +19,8 @@ experiment: ddp_timeout: 3600 checkpoint: path: /scratch/ctc/ML/vci/checkpoint/pretrain - save_top_k: 4 - monitor: trainer/train_loss + save_top_k: 2 + monitor: validation/val_loss every_n_train_steps: 1000 wandb: enable: true diff --git a/src/state/tx/utils/__init__.py b/src/state/tx/utils/__init__.py index 7a35c853..ae49016c 100644 --- a/src/state/tx/utils/__init__.py +++ b/src/state/tx/utils/__init__.py @@ -127,7 +127,7 @@ def get_loggers( return loggers -def get_checkpoint_callbacks(output_dir: str, name: str, val_freq: int, ckpt_every_n_steps: int): +def get_checkpoint_callbacks(output_dir: str, name: str, val_freq: int, _ckpt_every_n_steps: int): """ Create checkpoint callbacks based on validation frequency. @@ -136,28 +136,18 @@ def get_checkpoint_callbacks(output_dir: str, name: str, val_freq: int, ckpt_eve checkpoint_dir = join(output_dir, name, "checkpoints") callbacks = [] - # Save best checkpoint based on validation loss + # Save only the two best checkpoints (by val_loss) plus the latest checkpoint best_ckpt = ModelCheckpoint( dirpath=checkpoint_dir, filename="step={step}-val_loss={val_loss:.4f}", - save_last="link", # Will create last.ckpt symlink to best checkpoint + save_last=True, monitor="val_loss", mode="min", - save_top_k=1, # Only keep the best checkpoint + save_top_k=2, every_n_train_steps=val_freq, ) callbacks.append(best_ckpt) - # Also save periodic checkpoints (without affecting the "last" symlink) - periodic_ckpt = ModelCheckpoint( - dirpath=checkpoint_dir, - filename="{step}", - save_last=False, # Don't create/update symlink - every_n_train_steps=ckpt_every_n_steps, - save_top_k=-1, # Keep all periodic checkpoints - ) - callbacks.append(periodic_ckpt) - return callbacks From e5bf17b45652d069d9068cdcb350d2aadfa47eaa Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Mon, 29 Sep 2025 01:51:46 +0000 Subject: [PATCH 17/38] updated checkpoint logic to only store a best.ckpt --- pyproject.toml | 8 ++++++-- src/state/_cli/_tx/_predict.py | 14 +++++++------- src/state/tx/utils/__init__.py | 6 +++--- 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 17545e4e..2ae77fdf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ authors = [ requires-python = ">=3.10,<3.13" dependencies = [ "anndata>=0.11.4", - "cell-load>=0.8.5", + "cell-load>=0.8.7", "numpy>=2.2.6", "pandas>=2.2.3", "pyyaml>=6.0.2", @@ -27,11 +27,15 @@ dependencies = [ "geomloss>=0.2.6", "transformers>=4.52.3", "peft>=0.11.0", - "cell-eval>=0.5.22", + "cell-eval>=0.5.45", "ipykernel>=6.30.1", "scipy>=1.15.0", ] +[tool.uv.sources] +cell-load = {path = "/home/aadduri/cell-load"} +cell-eval = {path = "/home/aadduri/cell-eval"} + [project.optional-dependencies] vectordb = [ "lancedb>=0.24.0" diff --git a/src/state/_cli/_tx/_predict.py b/src/state/_cli/_tx/_predict.py index 31b0e5aa..5ea1d3c4 100644 --- a/src/state/_cli/_tx/_predict.py +++ b/src/state/_cli/_tx/_predict.py @@ -332,16 +332,16 @@ def load_config(cfg_path: str) -> dict: var = pd.DataFrame({"gene_names": gene_names}) if final_X_hvg is not None: - if len(gene_names) != final_pert_cell_counts_preds.shape[1]: - gene_names = np.load( - "/large_storage/ctc/userspace/aadduri/datasets/tahoe_19k_to_2k_names.npy", allow_pickle=True - ) - var = pd.DataFrame({"gene_names": gene_names}) + # if len(gene_names) != final_pert_cell_counts_preds.shape[1]: + # gene_names = np.load( + # "/large_storage/ctc/userspace/aadduri/datasets/tahoe_19k_to_2k_names.npy", allow_pickle=True + # ) + # var = pd.DataFrame({"gene_names": gene_names}) # Create adata for predictions - using the decoded gene expression values - adata_pred = anndata.AnnData(X=final_pert_cell_counts_preds, obs=obs, var=var) + adata_pred = anndata.AnnData(X=final_pert_cell_counts_preds, obs=obs) # Create adata for real - using the true gene expression values - adata_real = anndata.AnnData(X=final_X_hvg, obs=obs, var=var) + adata_real = anndata.AnnData(X=final_X_hvg, obs=obs) # add the embedding predictions adata_pred.obsm[data_module.embed_key] = final_preds diff --git a/src/state/tx/utils/__init__.py b/src/state/tx/utils/__init__.py index ae49016c..e4fcf9ac 100644 --- a/src/state/tx/utils/__init__.py +++ b/src/state/tx/utils/__init__.py @@ -136,14 +136,14 @@ def get_checkpoint_callbacks(output_dir: str, name: str, val_freq: int, _ckpt_ev checkpoint_dir = join(output_dir, name, "checkpoints") callbacks = [] - # Save only the two best checkpoints (by val_loss) plus the latest checkpoint + # Save only the best checkpoint (by val_loss) plus the latest checkpoint best_ckpt = ModelCheckpoint( dirpath=checkpoint_dir, - filename="step={step}-val_loss={val_loss:.4f}", + filename="best", save_last=True, monitor="val_loss", mode="min", - save_top_k=2, + save_top_k=1, every_n_train_steps=val_freq, ) callbacks.append(best_ckpt) From 0df3dbc9d443f736d970758e9e56d33ac959138b Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Mon, 29 Sep 2025 17:22:50 +0000 Subject: [PATCH 18/38] added split batch option that relies on cell-eval having the split batch option --- pyproject.toml | 2 +- src/state/_cli/_tx/_predict.py | 113 +++++++++++++++++++++++++++------ 2 files changed, 95 insertions(+), 20 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2ae77fdf..8ab6f2d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ dependencies = [ "geomloss>=0.2.6", "transformers>=4.52.3", "peft>=0.11.0", - "cell-eval>=0.5.45", + "cell-eval>=0.5.46", "ipykernel>=6.30.1", "scipy>=1.15.0", ] diff --git a/src/state/_cli/_tx/_predict.py b/src/state/_cli/_tx/_predict.py index 5ea1d3c4..c314063e 100644 --- a/src/state/_cli/_tx/_predict.py +++ b/src/state/_cli/_tx/_predict.py @@ -40,6 +40,12 @@ def add_arguments_predict(parser: ap.ArgumentParser): help="If set, only run prediction without evaluation metrics.", ) + parser.add_argument( + "--split-batch", + action="store_true", + help="If set, compute metrics separately for each (cell type, batch) pair.", + ) + parser.add_argument( "--shared-only", action="store_true", @@ -67,7 +73,7 @@ def run_tx_predict(args: ap.ArgumentParser): # Cell-eval for metrics computation from cell_eval import MetricsEvaluator - from cell_eval.utils import split_anndata_on_celltype + from cell_eval.utils import build_celltype_split_specs from cell_load.data_modules import PerturbationDataModule from tqdm import tqdm @@ -288,17 +294,70 @@ def load_config(cfg_path: str) -> dict: else: all_celltypes.append(batch_preds["celltype_name"]) - # Handle gem_group - if isinstance(batch_preds["batch"], list): - all_gem_groups.extend([str(x) for x in batch_preds["batch"]]) - elif isinstance(batch_preds["batch"], torch.Tensor): - all_gem_groups.extend([str(x) for x in batch_preds["batch"].cpu().numpy()]) - else: - all_gem_groups.append(str(batch_preds["batch"])) + batch_size = batch_preds["preds"].shape[0] + + # Handle gem_group - prefer human-readable batch names when available + def normalize_batch_labels(values): + if values is None: + return None + if isinstance(values, torch.Tensor): + values = values.detach().cpu().numpy() + if isinstance(values, np.ndarray): + if values.ndim == 2: + if values.shape[0] != batch_size: + return None + if values.shape[1] == 1: + flat = values.reshape(batch_size) + return [str(x) for x in flat.tolist()] + indices = values.argmax(axis=1) + return [str(int(x)) for x in indices.tolist()] + if values.ndim == 1: + if values.shape[0] != batch_size: + return None + return [str(x) for x in values.tolist()] + if values.ndim == 0: + return [str(values.item())] * batch_size + return None + if isinstance(values, (list, tuple)): + if len(values) != batch_size: + return None + normalized = [] + for item in values: + if isinstance(item, torch.Tensor): + item = item.detach().cpu().numpy() + if isinstance(item, np.ndarray): + if item.ndim == 0: + normalized.append(str(item.item())) + continue + if item.ndim == 1: + if item.size == 1: + normalized.append(str(item.item())) + elif np.count_nonzero(item) == 1: + normalized.append(str(int(item.argmax()))) + else: + normalized.append(str(item.tolist())) + continue + normalized.append(str(item)) + return normalized + return [str(values)] * batch_size + + batch_name_candidates = ( + batch.get("batch_name"), + batch_preds.get("batch_name"), + batch_preds.get("batch"), + ) + + batch_labels = None + for candidate in batch_name_candidates: + batch_labels = normalize_batch_labels(candidate) + if batch_labels is not None: + break + if batch_labels is None: + batch_labels = ["None"] * batch_size + all_gem_groups.extend(batch_labels) batch_pred_np = batch_preds["preds"].cpu().numpy().astype(np.float32) batch_real_np = batch_preds["pert_cell_emb"].cpu().numpy().astype(np.float32) - batch_size = batch_pred_np.shape[0] final_preds[current_idx : current_idx + batch_size, :] = batch_pred_np final_reals[current_idx : current_idx + batch_size, :] = batch_real_np current_idx += batch_size @@ -408,25 +467,41 @@ def load_config(cfg_path: str) -> dict: control_pert = data_module.get_control_pert() - ct_split_real = split_anndata_on_celltype(adata=adata_real, celltype_col=data_module.cell_type_key) - ct_split_pred = split_anndata_on_celltype(adata=adata_pred, celltype_col=data_module.cell_type_key) + batch_key = data_module.batch_col if args.split_batch else None + if args.split_batch: + if not batch_key: + raise ValueError("--split-batch requested but no batch column is configured on the data module.") + logger.info( + "Splitting evaluation by cell type and batch column '%s'", batch_key + ) - assert len(ct_split_real) == len(ct_split_pred), ( - f"Number of celltypes in real and pred anndata must match: {len(ct_split_real)} != {len(ct_split_pred)}" + split_specs = build_celltype_split_specs( + real=adata_real, + pred=adata_pred, + celltype_col=data_module.cell_type_key, + batch_key=batch_key, ) pdex_kwargs = dict(exp_post_agg=True, is_log1p=True) - for ct in ct_split_real.keys(): - real_ct = ct_split_real[ct] - pred_ct = ct_split_pred[ct] + for split in split_specs: + batch_suffix = ( + f", batch={split.batch}" + if split.batch is not None and not pd.isna(split.batch) + else "" + ) + logger.info( + "Evaluating metrics for celltype=%s%s", + split.celltype, + batch_suffix, + ) evaluator = MetricsEvaluator( - adata_pred=pred_ct, - adata_real=real_ct, + adata_pred=split.pred, + adata_real=split.real, control_pert=control_pert, pert_col=data_module.pert_col, outdir=results_dir, - prefix=ct, + prefix=split.label, pdex_kwargs=pdex_kwargs, batch_size=2048, ) From 19933ed244e082fcdb963bc25441ded03a58fa05 Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Thu, 2 Oct 2025 10:21:10 -0700 Subject: [PATCH 19/38] udpated cell eval to 0.6 --- pyproject.toml | 3 +- src/state/configs/model/state.yaml | 2 +- src/state/tx/models/state_transition.py | 89 ++++++++++++------------- 3 files changed, 44 insertions(+), 50 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8ab6f2d8..84cdb1d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,14 +27,13 @@ dependencies = [ "geomloss>=0.2.6", "transformers>=4.52.3", "peft>=0.11.0", - "cell-eval>=0.5.46", + "cell-eval>=0.6.0", "ipykernel>=6.30.1", "scipy>=1.15.0", ] [tool.uv.sources] cell-load = {path = "/home/aadduri/cell-load"} -cell-eval = {path = "/home/aadduri/cell-eval"} [project.optional-dependencies] vectordb = [ diff --git a/src/state/configs/model/state.yaml b/src/state/configs/model/state.yaml index 9de7263b..86e52b95 100644 --- a/src/state/configs/model/state.yaml +++ b/src/state/configs/model/state.yaml @@ -7,7 +7,7 @@ kwargs: blur: 0.05 hidden_dim: 768 # hidden dimension going into the transformer backbone loss: energy - confidence_head: False + confidence_token: False n_encoder_layers: 1 n_decoder_layers: 1 predict_residual: True diff --git a/src/state/tx/models/state_transition.py b/src/state/tx/models/state_transition.py index 641fdc6d..f7fc1ed6 100644 --- a/src/state/tx/models/state_transition.py +++ b/src/state/tx/models/state_transition.py @@ -20,9 +20,7 @@ class CombinedLoss(nn.Module): - """ - Combined Sinkhorn + Energy loss - """ + """Combined Sinkhorn + Energy loss.""" def __init__(self, sinkhorn_weight=0.001, energy_weight=1.0, blur=0.05): super().__init__() @@ -173,7 +171,7 @@ def __init__( elif loss_name == "mse": self.loss_fn = nn.MSELoss() elif loss_name == "se": - sinkhorn_weight = kwargs.get("sinkhorn_weight", 0.01) # 1/100 = 0.01 + sinkhorn_weight = kwargs.get("sinkhorn_weight", 0.01) energy_weight = kwargs.get("energy_weight", 1.0) self.loss_fn = CombinedLoss(sinkhorn_weight=sinkhorn_weight, energy_weight=energy_weight, blur=blur) elif loss_name == "sinkhorn": @@ -288,6 +286,11 @@ def __init__( if kwargs.get("confidence_token", False): self.confidence_token = ConfidenceToken(hidden_dim=self.hidden_dim, dropout=self.dropout) self.confidence_loss_fn = nn.MSELoss() + self.confidence_target_scale = float(kwargs.get("confidence_target_scale", 10.0)) + self.confidence_weight = float(kwargs.get("confidence_weight", 0.01)) + else: + self.confidence_target_scale = None + self.confidence_weight = 0.0 # Backward-compat: accept legacy key `freeze_pert` self.freeze_pert_backbone = kwargs.get("freeze_pert_backbone", kwargs.get("freeze_pert", False)) @@ -544,7 +547,8 @@ def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int, padded=T pred = pred.reshape(1, -1, self.output_dim) target = target.reshape(1, -1, self.output_dim) - main_loss = self.loss_fn(pred, target).nanmean() + per_set_main_losses = self.loss_fn(pred, target) + main_loss = torch.nanmean(per_set_main_losses) self.log("train_loss", main_loss) # Log individual loss components if using combined loss @@ -641,25 +645,18 @@ def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int, padded=T total_loss = total_loss + self.decoder_loss_weight * decoder_loss if confidence_pred is not None: - # Detach main loss to prevent gradients flowing through it - loss_target = total_loss.detach().clone().unsqueeze(0) * 10 - - # Ensure proper shapes for confidence loss computation - if confidence_pred.dim() == 2: # [B, 1] - loss_target = loss_target.unsqueeze(0).expand(confidence_pred.size(0), 1) - else: # confidence_pred is [B,] - loss_target = loss_target.unsqueeze(0).expand(confidence_pred.size(0)) - - # Compute confidence loss - confidence_loss = self.confidence_loss_fn(confidence_pred.squeeze(), loss_target.squeeze()) + confidence_pred_vals = confidence_pred + if confidence_pred_vals.dim() > 1: + confidence_pred_vals = confidence_pred_vals.squeeze(-1) + confidence_targets = per_set_main_losses.detach() + if self.confidence_target_scale is not None: + confidence_targets = confidence_targets * self.confidence_target_scale + confidence_targets = confidence_targets.to(confidence_pred_vals.device) + + confidence_loss = self.confidence_weight * self.confidence_loss_fn(confidence_pred_vals, confidence_targets) self.log("train/confidence_loss", confidence_loss) - self.log("train/actual_loss", loss_target.mean()) + self.log("train/actual_loss", confidence_targets.mean()) - # Add to total loss with weighting - confidence_weight = 0.1 # You can make this configurable - total_loss = total_loss + confidence_weight * confidence_loss - - # Add to total loss total_loss = total_loss + confidence_loss if self.regularization > 0.0: @@ -688,7 +685,8 @@ def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Non target = batch["pert_cell_emb"] target = target.reshape(-1, self.cell_sentence_len, self.output_dim) - loss = self.loss_fn(pred, target).mean() + per_set_main_losses = self.loss_fn(pred, target) + loss = torch.nanmean(per_set_main_losses) self.log("val_loss", loss) # Log individual loss components if using combined loss @@ -722,19 +720,17 @@ def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Non loss = loss + self.decoder_loss_weight * decoder_loss if confidence_pred is not None: - # Detach main loss to prevent gradients flowing through it - loss_target = loss.detach().clone() * 10 - - # Ensure proper shapes for confidence loss computation - if confidence_pred.dim() == 2: # [B, 1] - loss_target = loss_target.unsqueeze(0).expand(confidence_pred.size(0), 1) - else: # confidence_pred is [B,] - loss_target = loss_target.unsqueeze(0).expand(confidence_pred.size(0)) - - # Compute confidence loss - confidence_loss = self.confidence_loss_fn(confidence_pred.squeeze(), loss_target.squeeze()) + confidence_pred_vals = confidence_pred + if confidence_pred_vals.dim() > 1: + confidence_pred_vals = confidence_pred_vals.squeeze(-1) + confidence_targets = per_set_main_losses.detach() + if self.confidence_target_scale is not None: + confidence_targets = confidence_targets * self.confidence_target_scale + confidence_targets = confidence_targets.to(confidence_pred_vals.device) + + confidence_loss = self.confidence_weight * self.confidence_loss_fn(confidence_pred_vals, confidence_targets) self.log("val/confidence_loss", confidence_loss) - self.log("val/actual_loss", loss_target.mean()) + self.log("val/actual_loss", confidence_targets.mean()) return {"loss": loss, "predictions": pred} @@ -747,21 +743,20 @@ def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> None: target = batch["pert_cell_emb"] pred = pred.reshape(1, -1, self.output_dim) target = target.reshape(1, -1, self.output_dim) - loss = self.loss_fn(pred, target).mean() + per_set_main_losses = self.loss_fn(pred, target) + loss = torch.nanmean(per_set_main_losses) self.log("test_loss", loss) if confidence_pred is not None: - # Detach main loss to prevent gradients flowing through it - loss_target = loss.detach().clone() * 10.0 - - # Ensure proper shapes for confidence loss computation - if confidence_pred.dim() == 2: # [B, 1] - loss_target = loss_target.unsqueeze(0).expand(confidence_pred.size(0), 1) - else: # confidence_pred is [B,] - loss_target = loss_target.unsqueeze(0).expand(confidence_pred.size(0)) - - # Compute confidence loss - confidence_loss = self.confidence_loss_fn(confidence_pred.squeeze(), loss_target.squeeze()) + confidence_pred_vals = confidence_pred + if confidence_pred_vals.dim() > 1: + confidence_pred_vals = confidence_pred_vals.squeeze(-1) + confidence_targets = per_set_main_losses.detach() + if self.confidence_target_scale is not None: + confidence_targets = confidence_targets * self.confidence_target_scale + confidence_targets = confidence_targets.to(confidence_pred_vals.device) + + confidence_loss = self.confidence_weight * self.confidence_loss_fn(confidence_pred_vals, confidence_targets) self.log("test/confidence_loss", confidence_loss) def predict_step(self, batch, batch_idx, padded=True, **kwargs): From 1738e6d3a650b8cd12f69bfe01cf023cbd99c35d Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Thu, 2 Oct 2025 15:15:24 -0700 Subject: [PATCH 20/38] added multi head MMD for optimization --- src/state/configs/model/state.yaml | 2 ++ src/state/configs/model/state_sm.yaml | 2 ++ src/state/tx/models/state_transition.py | 32 +++++++++++++++++++++---- 3 files changed, 31 insertions(+), 5 deletions(-) diff --git a/src/state/configs/model/state.yaml b/src/state/configs/model/state.yaml index 86e52b95..98c15880 100644 --- a/src/state/configs/model/state.yaml +++ b/src/state/configs/model/state.yaml @@ -23,6 +23,8 @@ kwargs: use_effect_gating_token: False distributional_loss: energy init_from: null + mmd_num_chunks: 1 + randomize_mmd_chunks: false transformer_backbone_key: llama transformer_backbone_kwargs: bidirectional_attention: false diff --git a/src/state/configs/model/state_sm.yaml b/src/state/configs/model/state_sm.yaml index 77ddfd1f..11fd84ef 100644 --- a/src/state/configs/model/state_sm.yaml +++ b/src/state/configs/model/state_sm.yaml @@ -24,6 +24,8 @@ kwargs: distributional_loss: energy gene_decoder_bool: False init_from: null + mmd_num_chunks: 1 + randomize_mmd_chunks: false transformer_backbone_key: llama transformer_backbone_kwargs: bidirectional_attention: false diff --git a/src/state/tx/models/state_transition.py b/src/state/tx/models/state_transition.py index f7fc1ed6..a9f9a3ab 100644 --- a/src/state/tx/models/state_transition.py +++ b/src/state/tx/models/state_transition.py @@ -162,6 +162,8 @@ def __init__( self.distributional_loss = distributional_loss self.gene_dim = gene_dim + self.mmd_num_chunks = max(int(kwargs.get("mmd_num_chunks", 1)), 1) + self.randomize_mmd_chunks = bool(kwargs.get("randomize_mmd_chunks", False)) # Build the distributional loss from geomloss blur = kwargs.get("blur", 0.05) @@ -529,6 +531,24 @@ def forward(self, batch: dict, padded=True) -> torch.Tensor: else: return output + def _compute_distribution_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """Apply the primary distributional loss, optionally chunking feature dimensions for SamplesLoss.""" + + if isinstance(self.loss_fn, SamplesLoss) and self.mmd_num_chunks > 1: + feature_dim = pred.shape[-1] + num_chunks = min(self.mmd_num_chunks, feature_dim) + if num_chunks > 1 and feature_dim > 0: + if self.randomize_mmd_chunks and self.training: + perm = torch.randperm(feature_dim, device=pred.device) + pred = pred.index_select(-1, perm) + target = target.index_select(-1, perm) + pred_chunks = torch.chunk(pred, num_chunks, dim=-1) + target_chunks = torch.chunk(target, num_chunks, dim=-1) + chunk_losses = [self.loss_fn(p_chunk, t_chunk) for p_chunk, t_chunk in zip(pred_chunks, target_chunks)] + return torch.stack(chunk_losses, dim=0).nanmean(dim=0) + + return self.loss_fn(pred, target) + def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int, padded=True) -> torch.Tensor: """Training step logic for both main model and decoder.""" # Get model predictions (in latent space) @@ -547,7 +567,7 @@ def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int, padded=T pred = pred.reshape(1, -1, self.output_dim) target = target.reshape(1, -1, self.output_dim) - per_set_main_losses = self.loss_fn(pred, target) + per_set_main_losses = self._compute_distribution_loss(pred, target) main_loss = torch.nanmean(per_set_main_losses) self.log("train_loss", main_loss) @@ -637,7 +657,8 @@ def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int, padded=T else: gene_targets = gene_targets.reshape(1, -1, self.gene_decoder.gene_dim()) - decoder_loss = self.loss_fn(pert_cell_counts_preds, gene_targets).mean() + decoder_per_set = self._compute_distribution_loss(pert_cell_counts_preds, gene_targets) + decoder_loss = decoder_per_set.mean() # Log decoder loss self.log("decoder_loss", decoder_loss) @@ -685,7 +706,7 @@ def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Non target = batch["pert_cell_emb"] target = target.reshape(-1, self.cell_sentence_len, self.output_dim) - per_set_main_losses = self.loss_fn(pred, target) + per_set_main_losses = self._compute_distribution_loss(pred, target) loss = torch.nanmean(per_set_main_losses) self.log("val_loss", loss) @@ -713,7 +734,8 @@ def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Non -1, self.cell_sentence_len, self.gene_decoder.gene_dim() ) gene_targets = gene_targets.reshape(-1, self.cell_sentence_len, self.gene_decoder.gene_dim()) - decoder_loss = self.loss_fn(pert_cell_counts_preds, gene_targets).mean() + decoder_per_set = self._compute_distribution_loss(pert_cell_counts_preds, gene_targets) + decoder_loss = decoder_per_set.mean() # Log the validation metric self.log("val/decoder_loss", decoder_loss) @@ -743,7 +765,7 @@ def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> None: target = batch["pert_cell_emb"] pred = pred.reshape(1, -1, self.output_dim) target = target.reshape(1, -1, self.output_dim) - per_set_main_losses = self.loss_fn(pred, target) + per_set_main_losses = self._compute_distribution_loss(pred, target) loss = torch.nanmean(per_set_main_losses) self.log("test_loss", loss) From 33becb5a20bd4b3e6d5f6d37e6563aff23174566 Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Fri, 3 Oct 2025 18:22:03 +0000 Subject: [PATCH 21/38] removed batch split for cell eval --- src/state/_cli/_tx/_predict.py | 39 +++++++++++----------------------- 1 file changed, 12 insertions(+), 27 deletions(-) diff --git a/src/state/_cli/_tx/_predict.py b/src/state/_cli/_tx/_predict.py index c314063e..cc6cabf5 100644 --- a/src/state/_cli/_tx/_predict.py +++ b/src/state/_cli/_tx/_predict.py @@ -73,7 +73,7 @@ def run_tx_predict(args: ap.ArgumentParser): # Cell-eval for metrics computation from cell_eval import MetricsEvaluator - from cell_eval.utils import build_celltype_split_specs + from cell_eval.utils import split_anndata_on_celltype from cell_load.data_modules import PerturbationDataModule from tqdm import tqdm @@ -467,41 +467,26 @@ def normalize_batch_labels(values): control_pert = data_module.get_control_pert() - batch_key = data_module.batch_col if args.split_batch else None - if args.split_batch: - if not batch_key: - raise ValueError("--split-batch requested but no batch column is configured on the data module.") - logger.info( - "Splitting evaluation by cell type and batch column '%s'", batch_key - ) + ct_split_real = split_anndata_on_celltype(adata=adata_real, celltype_col=data_module.cell_type_key) + ct_split_pred = split_anndata_on_celltype(adata=adata_pred, celltype_col=data_module.cell_type_key) - split_specs = build_celltype_split_specs( - real=adata_real, - pred=adata_pred, - celltype_col=data_module.cell_type_key, - batch_key=batch_key, + assert len(ct_split_real) == len(ct_split_pred), ( + f"Number of celltypes in real and pred anndata must match: {len(ct_split_real)} != {len(ct_split_pred)}" ) pdex_kwargs = dict(exp_post_agg=True, is_log1p=True) - for split in split_specs: - batch_suffix = ( - f", batch={split.batch}" - if split.batch is not None and not pd.isna(split.batch) - else "" - ) - logger.info( - "Evaluating metrics for celltype=%s%s", - split.celltype, - batch_suffix, - ) + + for ct in ct_split_real.keys(): + real_ct = ct_split_real[ct] + pred_ct = ct_split_pred[ct] evaluator = MetricsEvaluator( - adata_pred=split.pred, - adata_real=split.real, + adata_pred=pred_ct, + adata_real=real_ct, control_pert=control_pert, pert_col=data_module.pert_col, outdir=results_dir, - prefix=split.label, + prefix=ct, pdex_kwargs=pdex_kwargs, batch_size=2048, ) From d3fa2271b8491682e490384b63a47c680ac34324 Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Fri, 3 Oct 2025 18:22:29 +0000 Subject: [PATCH 22/38] chore: formatting --- src/state/_cli/_tx/_train.py | 3 +- src/state/tx/models/context_mean.py | 13 ++- src/state/tx/models/state_transition.py | 4 +- src/state/tx/models/utils.py | 6 +- tests/test_bidirectional_models.py | 143 +++++++++++------------- 5 files changed, 85 insertions(+), 84 deletions(-) diff --git a/src/state/_cli/_tx/_train.py b/src/state/_cli/_tx/_train.py index 86229dd5..2c9f10a0 100644 --- a/src/state/_cli/_tx/_train.py +++ b/src/state/_cli/_tx/_train.py @@ -269,14 +269,13 @@ def run_tx_train(cfg: DictConfig): callbacks=callbacks, gradient_clip_val=cfg["training"]["gradient_clip_val"] if cfg["model"]["name"].lower() != "cpa" else None, accumulate_grad_batches=cfg["training"].get("gradient_accumulation_steps", 1), - use_distributed_sampler=False, + use_distributed_sampler=False, ) # Align logging cadence with rolling MFU window (and W&B logging) if "log_every_n_steps" in cfg["training"]: trainer_kwargs["log_every_n_steps"] = cfg["training"]["log_every_n_steps"] - # Build trainer print(f"Building trainer with kwargs: {trainer_kwargs}") trainer = pl.Trainer(**trainer_kwargs) diff --git a/src/state/tx/models/context_mean.py b/src/state/tx/models/context_mean.py index 7491dbcd..386bf0a3 100644 --- a/src/state/tx/models/context_mean.py +++ b/src/state/tx/models/context_mean.py @@ -91,7 +91,14 @@ def on_fit_start(self): return # Initialize dictionary to accumulate sum and count for each cell type. - celltype_sums = defaultdict(lambda: {"sum": torch.zeros(self.output_dim), "count": 0, "control_sum": torch.zeros(self.output_dim), "control_count": 0}) + celltype_sums = defaultdict( + lambda: { + "sum": torch.zeros(self.output_dim), + "count": 0, + "control_sum": torch.zeros(self.output_dim), + "control_count": 0, + } + ) with torch.no_grad(): for batch in train_loader: @@ -127,7 +134,9 @@ def on_fit_start(self): if stats["control_count"] > 0: # Use control cell average as fallback for cell types with no perturbations self.celltype_pert_means[ct_name] = stats["control_sum"] / stats["control_count"] - logger.info(f"ContextMean: Using control cell average for cell type '{ct_name}' (no perturbations found, {stats['control_count']} control cells used).") + logger.info( + f"ContextMean: Using control cell average for cell type '{ct_name}' (no perturbations found, {stats['control_count']} control cells used)." + ) else: logger.warning(f"No perturbed or control cells found for cell type {ct_name}.") continue diff --git a/src/state/tx/models/state_transition.py b/src/state/tx/models/state_transition.py index a9f9a3ab..72d7020b 100644 --- a/src/state/tx/models/state_transition.py +++ b/src/state/tx/models/state_transition.py @@ -464,11 +464,11 @@ def forward(self, batch: dict, padded=True) -> torch.Tensor: if self.hparams.get("mask_attn", False): batch_size, seq_length, _ = seq_input.shape device = seq_input.device - self.transformer_backbone._attn_implementation = "eager" # pyright: ignore[reportAttributeAccessIssue, reportArgumentType] + self.transformer_backbone._attn_implementation = "eager" # pyright: ignore[reportAttributeAccessIssue, reportArgumentType] # create a [1,1,S,S] mask (now S+1 if confidence token is used) base = torch.eye(seq_length, device=device, dtype=torch.bool).view(1, 1, seq_length, seq_length) - + # Get number of attention heads from model config num_heads = self.transformer_backbone.config.num_attention_heads diff --git a/src/state/tx/models/utils.py b/src/state/tx/models/utils.py index 6deb4a28..fdfdf577 100644 --- a/src/state/tx/models/utils.py +++ b/src/state/tx/models/utils.py @@ -226,13 +226,13 @@ def __init__(self, config: LlamaConfig): self.rotary_emb = NoRoPE( head_dim=config.head_dim, ) - + # Explicitly disable causal attention self.config.is_causal = False # force every layer to be non-causal for layer in self.layers: if hasattr(layer, "self_attn"): - layer.self_attn.is_causal = False # pyright: ignore[reportAttributeAccessIssue, reportArgumentType] + layer.self_attn.is_causal = False # pyright: ignore[reportAttributeAccessIssue, reportArgumentType] def _update_causal_mask( self, @@ -261,7 +261,7 @@ def forward( **flash_attn_kwargs, ): flash_attn_kwargs["is_causal"] = False - + # If no attention_mask is provided, create an all-ones mask (no masking) # This ensures bidirectional attention with correct device/dtype if attention_mask is None: diff --git a/tests/test_bidirectional_models.py b/tests/test_bidirectional_models.py index 5b41cc9a..9a0576be 100644 --- a/tests/test_bidirectional_models.py +++ b/tests/test_bidirectional_models.py @@ -24,10 +24,10 @@ def small_llama_config(): def test_llama_bidirectional_config_is_non_causal(small_llama_config): """Test that LlamaBidirectionalModel sets is_causal to False.""" model = LlamaBidirectionalModel(small_llama_config) - + # Check that the model config is non-causal assert model.config.is_causal is False - + # Check that all attention layers are non-causal for layer in model.layers: if hasattr(layer, "self_attn"): @@ -37,13 +37,13 @@ def test_llama_bidirectional_config_is_non_causal(small_llama_config): def test_llama_bidirectional_update_causal_mask_returns_none(small_llama_config): """Test that _update_causal_mask returns None, disabling causal masking.""" model = LlamaBidirectionalModel(small_llama_config) - + # Create dummy inputs batch_size, seq_len = 2, 8 attention_mask = torch.ones(batch_size, seq_len) input_tensor = torch.randn(batch_size, seq_len, small_llama_config.hidden_size) cache_position = torch.arange(seq_len) - + # Call _update_causal_mask result = model._update_causal_mask( attention_mask=attention_mask, @@ -52,7 +52,7 @@ def test_llama_bidirectional_update_causal_mask_returns_none(small_llama_config) past_key_values=None, output_attentions=False, ) - + # Should return None (no causal masking) assert result is None @@ -98,169 +98,163 @@ def test_get_transformer_backbone_llama_bidirectional_flag(): def test_llama_bidirectional_attention_vs_causal(small_llama_config): """ Test that bidirectional attention produces different outputs than causal attention. - + This is the key test: in bidirectional attention, later tokens should affect earlier token representations, which doesn't happen in causal attention. """ torch.manual_seed(42) - + # Create both bidirectional and standard (causal) models bidirectional_model = LlamaBidirectionalModel(small_llama_config) causal_model = LlamaModel(small_llama_config) - + # Copy weights from bidirectional to causal to ensure same initialization causal_model.load_state_dict(bidirectional_model.state_dict(), strict=False) - + # Create input batch_size, seq_len = 2, 8 inputs_embeds = torch.randn(batch_size, seq_len, small_llama_config.hidden_size) - + # Set models to eval mode bidirectional_model.eval() causal_model.eval() - + with torch.no_grad(): # Get outputs from bidirectional model bidirectional_output = bidirectional_model(inputs_embeds=inputs_embeds) - + # Get outputs from causal model causal_output = causal_model(inputs_embeds=inputs_embeds) - + # The outputs should be different because bidirectional allows all tokens # to attend to each other, while causal only allows attending to past tokens - assert not torch.allclose( - bidirectional_output.last_hidden_state, - causal_output.last_hidden_state, - atol=1e-5 - ), "Bidirectional and causal outputs should differ" + assert not torch.allclose(bidirectional_output.last_hidden_state, causal_output.last_hidden_state, atol=1e-5), ( + "Bidirectional and causal outputs should differ" + ) def test_llama_bidirectional_future_tokens_affect_past(small_llama_config): """ Test that future tokens affect past token representations in bidirectional model. - + This is the core property of bidirectional attention: changing a future token should change the representation of past tokens. """ torch.manual_seed(42) - + model = LlamaBidirectionalModel(small_llama_config) model.eval() - + batch_size, seq_len = 1, 6 hidden_size = small_llama_config.hidden_size - + # Create two inputs that differ only in the last token inputs_embeds_1 = torch.randn(batch_size, seq_len, hidden_size) inputs_embeds_2 = inputs_embeds_1.clone() - + # Modify only the last token embedding in the second input inputs_embeds_2[:, -1, :] = torch.randn(batch_size, hidden_size) - + with torch.no_grad(): output_1 = model(inputs_embeds=inputs_embeds_1) output_2 = model(inputs_embeds=inputs_embeds_2) - + # Check that the first tokens' representations differ between the two inputs # This demonstrates that the last token (future) affects the first token (past) first_token_repr_1 = output_1.last_hidden_state[:, 0, :] first_token_repr_2 = output_2.last_hidden_state[:, 0, :] - - assert not torch.allclose(first_token_repr_1, first_token_repr_2, atol=1e-5), \ + + assert not torch.allclose(first_token_repr_1, first_token_repr_2, atol=1e-5), ( "First token representation should change when last token changes (bidirectional attention)" + ) def test_llama_bidirectional_first_token_differs_across_batch(small_llama_config): """ Test that first token representations differ across batch when sequences differ. - + This is a critical test for bidirectional attention: in causal attention, the first token can only attend to itself, so if all sequences have the same first token, they would produce identical first token representations. - + In bidirectional attention, the first token attends to all tokens in the sequence, so different sequences should produce different first token representations even when the first tokens themselves are identical. """ torch.manual_seed(42) - + model = LlamaBidirectionalModel(small_llama_config) model.eval() - + batch_size, seq_len = 4, 8 hidden_size = small_llama_config.hidden_size - + # Create a batch where ALL sequences have the SAME first token embedding # but DIFFERENT subsequent tokens inputs_embeds = torch.randn(batch_size, seq_len, hidden_size) - + # Make the first token identical across all sequences shared_first_token = torch.randn(1, hidden_size) inputs_embeds[:, 0, :] = shared_first_token - + with torch.no_grad(): output = model(inputs_embeds=inputs_embeds) - + # Extract first token representations for all sequences first_token_reprs = output.last_hidden_state[:, 0, :] # Shape: (batch_size, hidden_size) - + # In bidirectional attention, these should all be DIFFERENT # because each attends to different subsequent tokens # Check that not all first tokens are the same for i in range(batch_size): for j in range(i + 1, batch_size): - assert not torch.allclose( - first_token_reprs[i], - first_token_reprs[j], - atol=1e-5 - ), f"First token representations for sequences {i} and {j} should differ in bidirectional attention" - + assert not torch.allclose(first_token_reprs[i], first_token_reprs[j], atol=1e-5), ( + f"First token representations for sequences {i} and {j} should differ in bidirectional attention" + ) + # Additional check: variance across batch should be substantial variance_per_dim = torch.var(first_token_reprs, dim=0) mean_variance = variance_per_dim.mean() - assert mean_variance > 1e-4, \ + assert mean_variance > 1e-4, ( "First token representations should have substantial variance across batch in bidirectional attention" + ) def test_llama_bidirectional_symmetric_position_influence(small_llama_config): """ - Test that in bidirectional attention, position i affects position j + Test that in bidirectional attention, position i affects position j as much as position j affects position i (roughly symmetric). """ torch.manual_seed(42) - + model = LlamaBidirectionalModel(small_llama_config) model.eval() - + batch_size, seq_len = 1, 4 hidden_size = small_llama_config.hidden_size - + # Create base input base_input = torch.randn(batch_size, seq_len, hidden_size) - + # Modify position 0 and see effect on position 2 input_modify_0 = base_input.clone() input_modify_0[:, 0, :] = torch.randn(batch_size, hidden_size) - + # Modify position 2 and see effect on position 0 input_modify_2 = base_input.clone() input_modify_2[:, 2, :] = torch.randn(batch_size, hidden_size) - + with torch.no_grad(): output_base = model(inputs_embeds=base_input) output_modify_0 = model(inputs_embeds=input_modify_0) output_modify_2 = model(inputs_embeds=input_modify_2) - + # Calculate how much position 2 changes when position 0 changes - effect_0_on_2 = torch.norm( - output_modify_0.last_hidden_state[:, 2, :] - output_base.last_hidden_state[:, 2, :] - ) - + effect_0_on_2 = torch.norm(output_modify_0.last_hidden_state[:, 2, :] - output_base.last_hidden_state[:, 2, :]) + # Calculate how much position 0 changes when position 2 changes - effect_2_on_0 = torch.norm( - output_modify_2.last_hidden_state[:, 0, :] - output_base.last_hidden_state[:, 0, :] - ) - + effect_2_on_0 = torch.norm(output_modify_2.last_hidden_state[:, 0, :] - output_base.last_hidden_state[:, 0, :]) + # In bidirectional attention, these effects should both be non-zero # (demonstrating mutual influence, unlike in causal attention) assert effect_0_on_2 > 0.01, "Position 0 should affect position 2" @@ -271,13 +265,13 @@ def test_llama_bidirectional_forward_with_input_ids(small_llama_config): """Test that forward pass works with input_ids.""" model = LlamaBidirectionalModel(small_llama_config) model.eval() - + batch_size, seq_len = 2, 10 input_ids = torch.randint(0, small_llama_config.vocab_size, (batch_size, seq_len)) - + with torch.no_grad(): output = model(input_ids=input_ids) - + # Check output shape assert output.last_hidden_state.shape == (batch_size, seq_len, small_llama_config.hidden_size) @@ -286,18 +280,18 @@ def test_llama_bidirectional_forward_with_attention_mask(small_llama_config): """Test that forward pass respects attention mask for padding.""" model = LlamaBidirectionalModel(small_llama_config) model.eval() - + batch_size, seq_len = 2, 10 hidden_size = small_llama_config.hidden_size inputs_embeds = torch.randn(batch_size, seq_len, hidden_size) - + # Create attention mask: first sequence has padding at positions 8-9 attention_mask = torch.ones(batch_size, seq_len) attention_mask[0, 8:] = 0 # Mask out last 2 positions for first sequence - + with torch.no_grad(): output = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask) - + # Check that output shape is correct assert output.last_hidden_state.shape == (batch_size, seq_len, hidden_size) @@ -306,24 +300,24 @@ def test_llama_bidirectional_is_causal_false_in_forward(small_llama_config): """Test that is_causal=False is passed in flash_attn_kwargs during forward.""" model = LlamaBidirectionalModel(small_llama_config) model.eval() - + batch_size, seq_len = 1, 8 inputs_embeds = torch.randn(batch_size, seq_len, small_llama_config.hidden_size) - + # Monkey-patch the parent's forward to capture flash_attn_kwargs original_forward = LlamaModel.forward captured_kwargs = {} - + def capture_forward(self, **kwargs): captured_kwargs.update(kwargs) return original_forward(self, **kwargs) - + LlamaModel.forward = capture_forward # type: ignore - + try: with torch.no_grad(): model(inputs_embeds=inputs_embeds) - + # Check that is_causal was set to False assert "is_causal" in captured_kwargs assert captured_kwargs["is_causal"] is False @@ -335,9 +329,8 @@ def capture_forward(self, **kwargs): def test_llama_bidirectional_no_rope(small_llama_config): """Test that NoRoPE is used instead of standard rotary embeddings.""" from state.tx.models.utils import NoRoPE - + model = LlamaBidirectionalModel(small_llama_config) - + # Check that rotary_emb is an instance of NoRoPE assert isinstance(model.rotary_emb, NoRoPE) - From 116cc8bd208dc2f0592fb8788649b4f919e50b9d Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Fri, 3 Oct 2025 23:37:16 +0000 Subject: [PATCH 23/38] added combo perturbation for rpe1 --- src/state/__main__.py | 3 + src/state/_cli/__init__.py | 2 + src/state/_cli/_tx/__init__.py | 3 + src/state/_cli/_tx/_combo.py | 574 +++++++++++++++++++++++++++++++++ 4 files changed, 582 insertions(+) create mode 100644 src/state/_cli/_tx/_combo.py diff --git a/src/state/__main__.py b/src/state/__main__.py index 0a7f9236..20068d8d 100644 --- a/src/state/__main__.py +++ b/src/state/__main__.py @@ -11,6 +11,7 @@ run_emb_query, run_emb_preprocess, run_emb_eval, + run_tx_combo, run_tx_infer, run_tx_predict, run_tx_preprocess_infer, @@ -124,6 +125,8 @@ def main(): case "infer": # Run inference using argparse, similar to predict run_tx_infer(args) + case "combo": + run_tx_combo(args) case "preprocess_train": # Run preprocessing using argparse run_tx_preprocess_train(args.adata, args.output, args.num_hvgs) diff --git a/src/state/_cli/__init__.py b/src/state/_cli/__init__.py index 2507d565..af9b4107 100644 --- a/src/state/_cli/__init__.py +++ b/src/state/_cli/__init__.py @@ -1,6 +1,7 @@ from ._emb import add_arguments_emb, run_emb_fit, run_emb_transform, run_emb_query, run_emb_preprocess, run_emb_eval from ._tx import ( add_arguments_tx, + run_tx_combo, run_tx_infer, run_tx_predict, run_tx_preprocess_infer, @@ -16,6 +17,7 @@ "run_tx_infer", "run_tx_preprocess_train", "run_tx_preprocess_infer", + "run_tx_combo", "run_emb_fit", "run_emb_query", "run_emb_transform", diff --git a/src/state/_cli/_tx/__init__.py b/src/state/_cli/_tx/__init__.py index 975fba42..cdadc904 100644 --- a/src/state/_cli/_tx/__init__.py +++ b/src/state/_cli/_tx/__init__.py @@ -5,6 +5,7 @@ from ._preprocess_infer import add_arguments_preprocess_infer, run_tx_preprocess_infer from ._preprocess_train import add_arguments_preprocess_train, run_tx_preprocess_train from ._train import add_arguments_train, run_tx_train +from ._combo import add_arguments_combo, run_tx_combo __all__ = [ "run_tx_train", @@ -12,6 +13,7 @@ "run_tx_infer", "run_tx_preprocess_train", "run_tx_preprocess_infer", + "run_tx_combo", "add_arguments_tx", ] @@ -24,3 +26,4 @@ def add_arguments_tx(parser: ap.ArgumentParser): add_arguments_infer(subparsers.add_parser("infer")) add_arguments_preprocess_train(subparsers.add_parser("preprocess_train")) add_arguments_preprocess_infer(subparsers.add_parser("preprocess_infer")) + add_arguments_combo(subparsers.add_parser("combo")) diff --git a/src/state/_cli/_tx/_combo.py b/src/state/_cli/_tx/_combo.py new file mode 100644 index 00000000..ef167368 --- /dev/null +++ b/src/state/_cli/_tx/_combo.py @@ -0,0 +1,574 @@ +import argparse as ap + + +def add_arguments_combo(parser: ap.ArgumentParser) -> None: + """CLI for two-stage perturbation combination sweeps.""" + + parser.add_argument("--model-dir", type=str, required=True, help="Path to the trained model directory.") + parser.add_argument( + "--checkpoint", + type=str, + default=None, + help=( + "Optional checkpoint path. If omitted, defaults to /checkpoints/last.ckpt " + "(falling back to final.ckpt if needed)." + ), + ) + parser.add_argument("--adata", type=str, required=True, help="Path to input AnnData file (.h5ad).") + parser.add_argument( + "--embed-key", + type=str, + default=None, + help="Optional key in adata.obsm for input features (defaults to adata.X).", + ) + parser.add_argument( + "--pert-col", + type=str, + required=True, + help="Column in adata.obs containing perturbation labels.", + ) + parser.add_argument( + "--control-pert", + type=str, + required=True, + help="Label of the control perturbation (used to construct the base control set).", + ) + parser.add_argument( + "--cell-type", + type=str, + required=True, + help="Target cell type value to filter before running the combo sweep.", + ) + parser.add_argument( + "--celltype-col", + type=str, + default=None, + help=( + "Optional column name in adata.obs for cell types. If omitted, attempts to detect using the " + "training config or common fallbacks." + ), + ) + parser.add_argument( + "--cell-set-len", + type=int, + default=None, + help="Override the model cell_set_len when constructing the fixed control set.", + ) + parser.add_argument( + "--batch-col", + type=str, + default=None, + help=( + "Optional batch column in adata.obs. If omitted, attempts to detect from training config " + "or common fallbacks when the model uses a batch encoder." + ), + ) + parser.add_argument( + "--inner-batch-size", + type=int, + default=1, + help="Number of target perturbations to evaluate simultaneously in the second pass.", + ) + parser.add_argument("--seed", type=int, default=0, help="Random seed for control sampling.") + parser.add_argument( + "--output", + type=str, + default=None, + help="Path to output AnnData file (.h5ad). Defaults to _combo.h5ad", + ) + parser.add_argument("--quiet", action="store_true", help="Reduce logging verbosity.") + + +def run_tx_combo(args: ap.Namespace) -> None: + import logging + import os + import pickle + + import anndata as ad + import numpy as np + import pandas as pd + import scanpy as sc + import torch + import yaml + + from tqdm import tqdm + + from ...tx.models.state_transition import StateTransitionPerturbationModel + + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + if args.quiet: + logger.setLevel(logging.WARNING) + + def _load_config(cfg_path: str) -> dict: + if not os.path.exists(cfg_path): + raise FileNotFoundError(f"Could not find config file: {cfg_path}") + with open(cfg_path, "r", encoding="utf-8") as handle: + return yaml.safe_load(handle) + + def _pick_first_present(columns: pd.Index, candidates: list[str | None]) -> str | None: + for key in candidates: + if key and key in columns: + return key + return None + + def _to_dense(matrix) -> np.ndarray: + try: + import scipy.sparse as sp # type: ignore + + if sp.issparse(matrix): + return matrix.toarray() + except Exception: + pass + return np.asarray(matrix) + + def _normalize_pert_vector(raw_vec, expected_dim: int) -> torch.Tensor: + if raw_vec is None: + return torch.zeros(expected_dim, dtype=torch.float32) + if torch.is_tensor(raw_vec): + return raw_vec.detach().float() + vec_np = np.asarray(raw_vec) + return torch.tensor(vec_np, dtype=torch.float32) + + def _flatten_tensor(tensor: torch.Tensor | None) -> torch.Tensor | None: + if tensor is None: + return None + if tensor.dim() == 3 and tensor.shape[0] == 1: + return tensor.squeeze(0) + return tensor + + def _tensor_to_numpy(tensor: torch.Tensor | None) -> np.ndarray | None: + if tensor is None: + return None + flat = _flatten_tensor(tensor) + if flat is None: + return None + return flat.detach().cpu().numpy().astype(np.float32) + + def _argmax_index_from_any(value, expected_dim: int | None = None) -> int | None: + if value is None: + return None + try: + if torch.is_tensor(value): + if value.ndim == 0: + return int(value.item()) + if value.ndim == 1: + return int(torch.argmax(value).item()) + return None + except Exception: + return None + try: + arr = np.asarray(value) + if arr.ndim == 0: + return int(arr.item()) + if arr.ndim == 1: + return int(arr.argmax()) + except Exception: + pass + if isinstance(value, (int, np.integer)): + return int(value) + if isinstance(value, (list, tuple)): + try: + arr = np.asarray(value) + if arr.ndim == 1: + return int(arr.argmax()) + except Exception: + return None + return None + + model_dir = os.path.abspath(args.model_dir) + config_path = os.path.join(model_dir, "config.yaml") + cfg = _load_config(config_path) + + var_dims_path = os.path.join(model_dir, "var_dims.pkl") + if not os.path.exists(var_dims_path): + raise FileNotFoundError(f"Missing var_dims.pkl at {var_dims_path}") + with open(var_dims_path, "rb") as handle: + var_dims = pickle.load(handle) + + input_dim = int(var_dims.get("input_dim", 0)) + if input_dim <= 0: + raise ValueError("input_dim missing from var_dims.pkl; cannot determine feature dimension") + + pert_dim = int(var_dims.get("pert_dim", 0)) + if pert_dim <= 0: + raise ValueError("pert_dim missing from var_dims.pkl; cannot build perturbation embeddings") + + batch_dim_entry = var_dims.get("batch_dim") + batch_dim = int(batch_dim_entry) if batch_dim_entry is not None else None + + pert_map_path = os.path.join(model_dir, "pert_onehot_map.pt") + if not os.path.exists(pert_map_path): + raise FileNotFoundError(f"Missing pert_onehot_map.pt at {pert_map_path}") + pert_onehot_map = torch.load(pert_map_path, weights_only=False) + + batch_onehot_map_path = os.path.join(model_dir, "batch_onehot_map.pkl") + batch_onehot_map = None + if os.path.exists(batch_onehot_map_path): + with open(batch_onehot_map_path, "rb") as handle: + batch_onehot_map = pickle.load(handle) + + checkpoint_path = args.checkpoint + if checkpoint_path is None: + default_last = os.path.join(model_dir, "checkpoints", "last.ckpt") + default_final = os.path.join(model_dir, "checkpoints", "final.ckpt") + checkpoint_path = default_last if os.path.exists(default_last) else default_final + elif not os.path.isabs(checkpoint_path): + candidate = os.path.join(model_dir, checkpoint_path) + checkpoint_path = candidate if os.path.exists(candidate) else checkpoint_path + + if not os.path.exists(checkpoint_path): + raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}") + + model = StateTransitionPerturbationModel.load_from_checkpoint(checkpoint_path) + model.eval() + device = next(model.parameters()).device + cell_set_len = args.cell_set_len or getattr(model, "cell_sentence_len", 256) + + uses_batch_encoder = getattr(model, "batch_encoder", None) is not None + if uses_batch_encoder and (batch_dim is None or batch_dim <= 0): + raise ValueError("Model uses a batch encoder but batch_dim missing from var_dims.pkl") + if uses_batch_encoder and batch_onehot_map is None: + raise FileNotFoundError( + "Model uses a batch encoder but batch_onehot_map.pkl was not found in the model directory" + ) + + logger.info("Loaded model from %s (cell_set_len=%d)", checkpoint_path, cell_set_len) + + adata = sc.read_h5ad(args.adata) + + data_kwargs = {} + try: + data_kwargs = cfg.get("data", {}).get("kwargs", {}) # type: ignore[assignment] + except AttributeError: + data_kwargs = {} + + celltype_col = args.celltype_col + if celltype_col is None: + cfg_celltype = None + try: + cfg_celltype = data_kwargs.get("cell_type_key") + except Exception: + cfg_celltype = None + candidates = [ + cfg_celltype, + "cell_type", + "celltype", + "cell_type_name", + "celltype_name", + "cellType", + "ctype", + ] + celltype_col = _pick_first_present(adata.obs.columns, candidates) + if celltype_col is None: + raise ValueError("Could not determine cell type column; provide --celltype-col explicitly.") + if celltype_col not in adata.obs: + raise KeyError(f"Column '{celltype_col}' not found in adata.obs") + + if args.pert_col not in adata.obs: + raise KeyError(f"Perturbation column '{args.pert_col}' not found in adata.obs") + + adata_ct = adata[adata.obs[celltype_col].astype(str) == str(args.cell_type)].copy() + if adata_ct.n_obs == 0: + raise ValueError(f"No cells found with cell type '{args.cell_type}' in column '{celltype_col}'") + + pert_series = adata_ct.obs[args.pert_col].astype(str) + control_mask = pert_series == str(args.control_pert) + control_indices = np.where(control_mask)[0] + if len(control_indices) == 0: + raise ValueError( + f"No control cells with perturbation '{args.control_pert}' found in column '{args.pert_col}' " + f"for cell type '{args.cell_type}'" + ) + + perts_all = pd.unique(pert_series) + perts = [p for p in perts_all if p != str(args.control_pert)] + if len(perts) == 0: + raise ValueError("No non-control perturbations found in filtered AnnData") + + batch_indices_all: np.ndarray | None = None + batch_col = args.batch_col if args.batch_col is not None else data_kwargs.get("batch_col") + if uses_batch_encoder: + candidate_batch_cols: list[str] = [] + if batch_col is not None: + candidate_batch_cols.append(batch_col) + if isinstance(data_kwargs.get("batch_col"), str): + candidate_batch_cols.append(data_kwargs.get("batch_col")) + candidate_batch_cols.extend( + [ + "gem_group", + "gemgroup", + "batch", + "donor", + "plate", + "experiment", + "lane", + "batch_id", + ] + ) + resolved_batch_col = next((col for col in candidate_batch_cols if col in adata_ct.obs), None) + if resolved_batch_col is None: + raise ValueError( + "Model uses a batch encoder but no batch column was found. Provide --batch-col explicitly." + ) + batch_col = resolved_batch_col + raw_batch_labels = adata_ct.obs[batch_col].astype(str).values + + assert batch_onehot_map is not None + label_to_idx: dict[str, int] = {} + if isinstance(batch_onehot_map, dict): + for key, value in batch_onehot_map.items(): + idx = _argmax_index_from_any(value, batch_dim) + if idx is not None: + label_to_idx[str(key)] = idx + + if not label_to_idx and batch_dim is not None: + unique_labels = sorted(set(raw_batch_labels)) + label_to_idx = {lab: min(i, batch_dim - 1) for i, lab in enumerate(unique_labels)} + + if not label_to_idx: + raise ValueError("Unable to construct batch label mapping; batch_onehot_map is empty or invalid") + + fallback_idx = sorted(label_to_idx.values())[0] + batch_indices_all = np.zeros(len(raw_batch_labels), dtype=np.int64) + misses = 0 + for i, lab in enumerate(raw_batch_labels): + idx = label_to_idx.get(lab) + if idx is None: + batch_indices_all[i] = fallback_idx + misses += 1 + else: + batch_indices_all[i] = idx + + if misses: + logger.warning( + "Batch column '%s': %d/%d labels missing from saved mapping; using fallback index %d", + batch_col, + misses, + len(raw_batch_labels), + fallback_idx, + ) + logger.info( + "Using batch column '%s' with %d unique mapped labels", + batch_col, + len(np.unique(batch_indices_all)), + ) + + cfg_embed_key = data_kwargs.get("embed_key") + explicit_embed_key = args.embed_key is not None + + candidate_order: list[str | None] = [] + seen_keys: set[str | None] = set() + + def _append_candidate(key: str | None) -> None: + if key in seen_keys: + return + seen_keys.add(key) + candidate_order.append(key) + + if explicit_embed_key: + _append_candidate(args.embed_key) + else: + if isinstance(cfg_embed_key, str): + _append_candidate(cfg_embed_key) + _append_candidate(None) + for fallback_key in ("X_hvg", "X_state", "X_state_basal", "X_state_pred", "X_pca", "X_latent"): + if fallback_key in adata_ct.obsm: + _append_candidate(fallback_key) + + selection_errors: list[str] = [] + features = None + used_embed_key: str | None = None + + for candidate in candidate_order: + matrix = None + label = "adata.X" if candidate is None else f"adata.obsm['{candidate}']" + + if candidate is None: + matrix = _to_dense(adata_ct.X) + else: + if candidate not in adata_ct.obsm: + if explicit_embed_key: + raise KeyError(f"Embedding key '{candidate}' not found in adata.obsm") + selection_errors.append(f"{label} missing") + continue + matrix = np.asarray(adata_ct.obsm[candidate]) + + if matrix.shape[0] != adata_ct.n_obs: + msg = f"{label} row count {matrix.shape[0]} != filtered AnnData cells {adata_ct.n_obs}" + if explicit_embed_key: + raise ValueError(msg) + selection_errors.append(msg) + continue + + if matrix.shape[1] != input_dim: + msg = f"{label} feature dimension {matrix.shape[1]} != model input_dim {input_dim}" + if explicit_embed_key: + raise ValueError( + msg + + ". Provide --embed-key pointing to a representation with matching dimension or preprocess the input." + ) + selection_errors.append(msg) + continue + + features = matrix + used_embed_key = candidate + break + + if features is None: + tried = ", ".join(["adata.X" if c is None else f"adata.obsm['{c}']" for c in candidate_order]) or "(none)" + detail = "; ".join(selection_errors) if selection_errors else "No suitable feature representation found." + raise ValueError( + f"Unable to find a feature matrix matching the model input dimension. Tried: {tried}. {detail}" + ) + + if used_embed_key is None: + logger.info("Using adata.X (%d cells x %d features) as input features", features.shape[0], features.shape[1]) + else: + logger.info( + "Using adata.obsm['%s'] (%d cells x %d features) as input features", + used_embed_key, + features.shape[0], + features.shape[1], + ) + + features = features.astype(np.float32, copy=False) + + rng = np.random.default_rng(args.seed) + replace = len(control_indices) < cell_set_len + sampled_idx = rng.choice(control_indices, size=cell_set_len, replace=replace) + control_features = features[sampled_idx] + + default_vec = _normalize_pert_vector(pert_onehot_map.get(str(args.control_pert)), pert_dim) + if default_vec.numel() != pert_dim: + default_vec = torch.zeros(pert_dim, dtype=torch.float32) + + control_batch_tensor = None + if batch_indices_all is not None: + control_batch_tensor = torch.tensor(batch_indices_all[sampled_idx], dtype=torch.long, device=device) + + def _pert_batch_tensor(name: str) -> torch.Tensor: + raw_vec = pert_onehot_map.get(name) + vec = _normalize_pert_vector(raw_vec, pert_dim) if raw_vec is not None else default_vec + if vec.dim() == 0: + vec = vec.unsqueeze(0) + vec = vec.reshape(-1) + if vec.numel() != pert_dim: + raise ValueError(f"Perturbation vector for '{name}' has incorrect dimension {vec.numel()} != {pert_dim}") + return vec.unsqueeze(0).repeat(cell_set_len, 1).to(device) + + pert_batch_vectors = {name: _pert_batch_tensor(name) for name in perts} + + control_tensor = torch.tensor(control_features, dtype=torch.float32, device=device) + + use_counts: bool | None = None + X_blocks: list[np.ndarray] = [] + latent_blocks: list[np.ndarray] = [] + obs_rows: list[dict[str, str | int]] = [] + + inner_batch_size = max(1, int(args.inner_batch_size)) + + with torch.no_grad(): + progress_total = len(perts) * len(perts) + progress_bar = tqdm( + total=progress_total, + desc="Combo sweeps", + unit="combo", + disable=args.quiet, + ) + for pert1 in perts: + first_batch = { + "ctrl_cell_emb": control_tensor.clone(), + "pert_emb": pert_batch_vectors[pert1], + "pert_name": [pert1] * cell_set_len, + } + if control_batch_tensor is not None: + first_batch["batch"] = control_batch_tensor.clone() + first_out = model.predict_step(first_batch, batch_idx=0, padded=False) + first_latent_tensor = _flatten_tensor(first_out.get("preds")) + if first_latent_tensor is None: + raise RuntimeError("Model predict_step did not return 'preds' tensor") + first_latent_tensor = first_latent_tensor.detach().to(device) + + for chunk_start in range(0, len(perts), inner_batch_size): + chunk_perts = perts[chunk_start : chunk_start + inner_batch_size] + chunk_size = len(chunk_perts) + + ctrl_chunk = torch.cat([first_latent_tensor.clone() for _ in chunk_perts], dim=0) + pert_chunk = torch.cat([pert_batch_vectors[p] for p in chunk_perts], dim=0) + names_chunk = [p for p in chunk_perts for _ in range(cell_set_len)] + + second_batch = { + "ctrl_cell_emb": ctrl_chunk, + "pert_emb": pert_chunk, + "pert_name": names_chunk, + } + + if control_batch_tensor is not None: + batch_chunk = control_batch_tensor.repeat(chunk_size) + second_batch["batch"] = batch_chunk + + second_out = model.predict_step(second_batch, batch_idx=0, padded=True) + + latent_np = _tensor_to_numpy(second_out.get("preds")) + counts_np = _tensor_to_numpy(second_out.get("pert_cell_counts_preds")) + + if latent_np is None: + raise RuntimeError("Second-stage prediction missing 'preds' output") + + latent_np = latent_np.reshape(chunk_size, cell_set_len, -1) + counts_np = counts_np.reshape(chunk_size, cell_set_len, -1) if counts_np is not None else None + + if use_counts is None: + use_counts = counts_np is not None + elif use_counts and counts_np is None: + raise RuntimeError("Inconsistent decoder outputs across perturbations; expected counts predictions") + + for idx_chunk, pert2 in enumerate(chunk_perts): + latent_slice = latent_np[idx_chunk].astype(np.float32) + if use_counts: + assert counts_np is not None + X_blocks.append(counts_np[idx_chunk].astype(np.float32)) + else: + X_blocks.append(latent_slice) + latent_blocks.append(latent_slice) + + for cell_idx in range(cell_set_len): + obs_rows.append({"pert1": pert1, "pert2": pert2, "cell_index": cell_idx}) + + progress_bar.update(1) + + progress_bar.close() + + X_matrix = np.vstack(X_blocks) if X_blocks else np.empty((0, 0), dtype=np.float32) + latent_matrix = np.vstack(latent_blocks) if latent_blocks else np.empty((0, 0), dtype=np.float32) + obs_df = pd.DataFrame(obs_rows) + + feature_dim = X_matrix.shape[1] if X_matrix.size > 0 else latent_matrix.shape[1] + gene_names = var_dims.get("gene_names") + if use_counts and isinstance(gene_names, (list, tuple)) and len(gene_names) == feature_dim: + var_index = pd.Index([str(name) for name in gene_names], name="gene") + else: + var_index = pd.Index([f"feature_{i}" for i in range(feature_dim)], name="feature") + var_df = pd.DataFrame(index=var_index) + + combo_adata = ad.AnnData(X=X_matrix, obs=obs_df, var=var_df) + combo_adata.obsm["latent_preds"] = latent_matrix + combo_adata.uns["cell_type"] = str(args.cell_type) + combo_adata.uns["perturbations"] = perts + combo_adata.uns["control_pert"] = str(args.control_pert) + combo_adata.uns["cell_set_len"] = cell_set_len + combo_adata.uns["input_embed_key"] = used_embed_key if used_embed_key is not None else "X" + if uses_batch_encoder and batch_col is not None: + combo_adata.uns["batch_col"] = batch_col + combo_adata.uns["inner_batch_size"] = inner_batch_size + combo_adata.uns["sampled_control_indices"] = adata_ct.obs_names[sampled_idx].tolist() + + output_path = args.output or args.adata.replace(".h5ad", "_combo.h5ad") + output_path = os.path.abspath(output_path) + output_dir = os.path.dirname(output_path) + if output_dir and not os.path.exists(output_dir): + os.makedirs(output_dir, exist_ok=True) + combo_adata.write_h5ad(output_path) + + logger.info("Saved combo AnnData with %d cells to %s", combo_adata.n_obs, output_path) From 7e7f0227ced0d2604ab8dfb1aab4232ae79e2382 Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Tue, 7 Oct 2025 22:03:11 +0000 Subject: [PATCH 24/38] updated infer to take max and min cells, and all perts argument --- src/state/_cli/_tx/_combo.py | 132 +++++++++++++------- src/state/_cli/_tx/_infer.py | 153 +++++++++++++++++++++++- src/state/configs/model/state.yaml | 1 + src/state/tx/models/state_transition.py | 94 ++++++++++++++- 4 files changed, 332 insertions(+), 48 deletions(-) diff --git a/src/state/_cli/_tx/_combo.py b/src/state/_cli/_tx/_combo.py index ef167368..7b55d929 100644 --- a/src/state/_cli/_tx/_combo.py +++ b/src/state/_cli/_tx/_combo.py @@ -71,10 +71,13 @@ def add_arguments_combo(parser: ap.ArgumentParser) -> None: ) parser.add_argument("--seed", type=int, default=0, help="Random seed for control sampling.") parser.add_argument( - "--output", + "--output-folder", type=str, default=None, - help="Path to output AnnData file (.h5ad). Defaults to _combo.h5ad", + help=( + "Directory where per-perturbation AnnData outputs (.h5ad) are written." + " Defaults to _combo/ alongside the input file." + ), ) parser.add_argument("--quiet", action="store_true", help="Reduce logging verbosity.") @@ -83,6 +86,7 @@ def run_tx_combo(args: ap.Namespace) -> None: import logging import os import pickle + import re import anndata as ad import numpy as np @@ -462,12 +466,26 @@ def _pert_batch_tensor(name: str) -> torch.Tensor: control_tensor = torch.tensor(control_features, dtype=torch.float32, device=device) use_counts: bool | None = None - X_blocks: list[np.ndarray] = [] - latent_blocks: list[np.ndarray] = [] - obs_rows: list[dict[str, str | int]] = [] - inner_batch_size = max(1, int(args.inner_batch_size)) + def _default_output_dir(path: str) -> str: + base_dir = os.path.dirname(os.path.abspath(path)) + base_name = os.path.splitext(os.path.basename(path))[0] + return os.path.join(base_dir, f"{base_name}_combo") + + output_dir = args.output_folder or _default_output_dir(args.adata) + output_dir = os.path.abspath(output_dir) + os.makedirs(output_dir, exist_ok=True) + logger.info("Writing per-perturbation combo outputs to %s", output_dir) + + def _sanitize_filename(label: str) -> str: + sanitized = re.sub(r"[^0-9A-Za-z_.-]+", "_", label) + sanitized = sanitized.strip("._") + return sanitized or "perturbation" + + used_output_names: dict[str, int] = {} + written_files: list[str] = [] + with torch.no_grad(): progress_total = len(perts) * len(perts) progress_bar = tqdm( @@ -477,6 +495,10 @@ def _pert_batch_tensor(name: str) -> torch.Tensor: disable=args.quiet, ) for pert1 in perts: + per_pert_X_blocks: list[np.ndarray] = [] + per_pert_latent_blocks: list[np.ndarray] = [] + per_pert_obs_rows: list[dict[str, str | int]] = [] + first_batch = { "ctrl_cell_emb": control_tensor.clone(), "pert_emb": pert_batch_vectors[pert1], @@ -528,47 +550,73 @@ def _pert_batch_tensor(name: str) -> torch.Tensor: latent_slice = latent_np[idx_chunk].astype(np.float32) if use_counts: assert counts_np is not None - X_blocks.append(counts_np[idx_chunk].astype(np.float32)) + per_pert_X_blocks.append(counts_np[idx_chunk].astype(np.float32)) else: - X_blocks.append(latent_slice) - latent_blocks.append(latent_slice) + per_pert_X_blocks.append(latent_slice) + per_pert_latent_blocks.append(latent_slice) for cell_idx in range(cell_set_len): - obs_rows.append({"pert1": pert1, "pert2": pert2, "cell_index": cell_idx}) + per_pert_obs_rows.append({"pert1": pert1, "pert2": pert2, "cell_index": cell_idx}) progress_bar.update(1) - progress_bar.close() + X_matrix = ( + np.vstack(per_pert_X_blocks) + if per_pert_X_blocks + else np.empty((0, 0), dtype=np.float32) + ) + latent_matrix = ( + np.vstack(per_pert_latent_blocks) + if per_pert_latent_blocks + else np.empty((0, 0), dtype=np.float32) + ) + obs_df = pd.DataFrame(per_pert_obs_rows) + + feature_dim = 0 + if use_counts and X_matrix.size > 0: + feature_dim = X_matrix.shape[1] + elif latent_matrix.size > 0: + feature_dim = latent_matrix.shape[1] + elif X_matrix.size > 0: + feature_dim = X_matrix.shape[1] + + gene_names = var_dims.get("gene_names") + if ( + use_counts + and feature_dim > 0 + and isinstance(gene_names, (list, tuple)) + and len(gene_names) == feature_dim + ): + var_index = pd.Index([str(name) for name in gene_names], name="gene") + else: + var_index = pd.Index([f"feature_{i}" for i in range(feature_dim)], name="feature") + var_df = pd.DataFrame(index=var_index) + + combo_adata = ad.AnnData(X=X_matrix, obs=obs_df, var=var_df) + combo_adata.obsm["latent_preds"] = latent_matrix + combo_adata.uns["cell_type"] = str(args.cell_type) + combo_adata.uns["perturbations"] = perts + combo_adata.uns["pert1"] = pert1 + combo_adata.uns["control_pert"] = str(args.control_pert) + combo_adata.uns["cell_set_len"] = cell_set_len + combo_adata.uns["input_embed_key"] = used_embed_key if used_embed_key is not None else "X" + if uses_batch_encoder and batch_col is not None: + combo_adata.uns["batch_col"] = batch_col + combo_adata.uns["inner_batch_size"] = inner_batch_size + combo_adata.uns["sampled_control_indices"] = adata_ct.obs_names[sampled_idx].tolist() + + output_name = _sanitize_filename(pert1) + if output_name in used_output_names: + used_output_names[output_name] += 1 + output_name = f"{output_name}_{used_output_names[output_name]}" + else: + used_output_names[output_name] = 0 - X_matrix = np.vstack(X_blocks) if X_blocks else np.empty((0, 0), dtype=np.float32) - latent_matrix = np.vstack(latent_blocks) if latent_blocks else np.empty((0, 0), dtype=np.float32) - obs_df = pd.DataFrame(obs_rows) + output_path = os.path.join(output_dir, f"{output_name}.h5ad") + combo_adata.write_h5ad(output_path) + written_files.append(output_path) + logger.info("Saved combos for %s with %d cells to %s", pert1, combo_adata.n_obs, output_path) - feature_dim = X_matrix.shape[1] if X_matrix.size > 0 else latent_matrix.shape[1] - gene_names = var_dims.get("gene_names") - if use_counts and isinstance(gene_names, (list, tuple)) and len(gene_names) == feature_dim: - var_index = pd.Index([str(name) for name in gene_names], name="gene") - else: - var_index = pd.Index([f"feature_{i}" for i in range(feature_dim)], name="feature") - var_df = pd.DataFrame(index=var_index) - - combo_adata = ad.AnnData(X=X_matrix, obs=obs_df, var=var_df) - combo_adata.obsm["latent_preds"] = latent_matrix - combo_adata.uns["cell_type"] = str(args.cell_type) - combo_adata.uns["perturbations"] = perts - combo_adata.uns["control_pert"] = str(args.control_pert) - combo_adata.uns["cell_set_len"] = cell_set_len - combo_adata.uns["input_embed_key"] = used_embed_key if used_embed_key is not None else "X" - if uses_batch_encoder and batch_col is not None: - combo_adata.uns["batch_col"] = batch_col - combo_adata.uns["inner_batch_size"] = inner_batch_size - combo_adata.uns["sampled_control_indices"] = adata_ct.obs_names[sampled_idx].tolist() - - output_path = args.output or args.adata.replace(".h5ad", "_combo.h5ad") - output_path = os.path.abspath(output_path) - output_dir = os.path.dirname(output_path) - if output_dir and not os.path.exists(output_dir): - os.makedirs(output_dir, exist_ok=True) - combo_adata.write_h5ad(output_path) - - logger.info("Saved combo AnnData with %d cells to %s", combo_adata.n_obs, output_path) + progress_bar.close() + + logger.info("Finished writing %d combo files to %s", len(written_files), output_dir) diff --git a/src/state/_cli/_tx/_infer.py b/src/state/_cli/_tx/_infer.py index 530cee1f..4d6ccdc2 100644 --- a/src/state/_cli/_tx/_infer.py +++ b/src/state/_cli/_tx/_infer.py @@ -82,6 +82,23 @@ def add_arguments_infer(parser: argparse.ArgumentParser): default=None, help="Path to TSV file with columns 'perturbation' and 'num_cells' to pad the adata with additional perturbation cells copied from random controls.", ) + parser.add_argument( + "--all-perts", + action="store_true", + help="If set, add virtual copies of control cells for every perturbation in the saved one-hot map so all perturbations are simulated.", + ) + parser.add_argument( + "--min-cells", + type=int, + default=None, + help="Ensure each perturbation has at least this many cells by padding with virtual controls (if needed).", + ) + parser.add_argument( + "--max-cells", + type=int, + default=None, + help="Upper bound on cells per perturbation after padding; subsamples excess cells if necessary.", + ) def run_tx_infer(args: argparse.Namespace): @@ -321,6 +338,7 @@ def pad_adata_with_tsv( control_pert = "non-targeting" if not args.quiet: print(f"Control perturbation: {control_pert}") + control_pert_str = str(control_pert) # choose cell type column if args.celltype_col is None: @@ -361,6 +379,8 @@ def pad_adata_with_tsv( if not os.path.exists(pert_onehot_map_path): raise FileNotFoundError(f"Missing pert_onehot_map.pt at {pert_onehot_map_path}") pert_onehot_map: Dict[str, torch.Tensor] = torch.load(pert_onehot_map_path, weights_only=False) + pert_name_lookup: Dict[str, object] = {str(k): k for k in pert_onehot_map.keys()} + pert_names_in_map: List[str] = list(pert_name_lookup.keys()) batch_onehot_map_path = os.path.join(args.model_dir, "batch_onehot_map.pkl") batch_onehot_map = None @@ -423,6 +443,129 @@ def pad_adata_with_tsv( if not args.quiet: print(f"Filtered to {adata.n_obs} cells (from {n0}) for cell types: {keep_cts}") + needs_virtual_padding = args.all_perts or (args.min_cells is not None) or (args.max_cells is not None) + if needs_virtual_padding: + if args.pert_col not in adata.obs: + raise KeyError(f"Perturbation column '{args.pert_col}' not found in adata.obs") + + adata.obs = adata.obs.copy() + adata.obs[args.pert_col] = adata.obs[args.pert_col].astype(str) + + # optionally expand controls to cover every perturbation in the map + if args.all_perts: + observed_perts = set(adata.obs[args.pert_col].values) + missing_perts = [p for p in pert_names_in_map if p not in observed_perts] + + if missing_perts: + ctrl_mask_all_perts = adata.obs[args.pert_col] == control_pert_str + if not bool(np.any(ctrl_mask_all_perts)): + raise ValueError( + "--all-perts requested, but no control cells are available to template new perturbations." + ) + + ctrl_template = adata[ctrl_mask_all_perts].copy() + ctrl_template.obs = ctrl_template.obs.copy() + ctrl_template.obs[args.pert_col] = ctrl_template.obs[args.pert_col].astype(str) + + virtual_blocks: List["sc.AnnData"] = [] + for pert_name in missing_perts: + clone = ctrl_template.copy() + clone.obs = clone.obs.copy() + clone.obs[args.pert_col] = pert_name + clone.obs_names = [f"{obs_name}__virt_{pert_name}" for obs_name in clone.obs_names] + virtual_blocks.append(clone) + + adata = sc.concat([adata, *virtual_blocks], axis=0, join="same", label=None, index_unique=None) + + if not args.quiet: + preview = ", ".join(missing_perts[:5]) + if len(missing_perts) > 5: + preview += ", ..." + print( + f"Added virtual control copies for {len(missing_perts)} perturbations" + f" ({preview if preview else 'n/a'}). Total cells: {adata.n_obs}." + ) + elif not args.quiet: + print("--all-perts requested, but all perturbations already present in AnnData.") + + # ensure each perturbation meets the minimum count by cloning controls + if args.min_cells is not None: + if args.min_cells <= 0: + raise ValueError("--min-cells must be a positive integer if provided.") + + ctrl_mask_min_cells = adata.obs[args.pert_col] == control_pert_str + if not bool(np.any(ctrl_mask_min_cells)): + raise ValueError("--min-cells requested, but no control cells are available for cloning.") + + pad_rng = np.random.RandomState(args.seed) + ctrl_pool = adata[ctrl_mask_min_cells].copy() + ctrl_pool.obs = ctrl_pool.obs.copy() + virtual_blocks: List["sc.AnnData"] = [] + + pert_counts = adata.obs[args.pert_col].value_counts() + for pert_name, count in pert_counts.items(): + deficit = int(args.min_cells) - int(count) + if deficit <= 0: + continue + + sampled_idx = pad_rng.choice(ctrl_pool.n_obs, size=deficit, replace=True) + clone = ctrl_pool[sampled_idx].copy() + clone.obs = clone.obs.copy() + clone.obs[args.pert_col] = pert_name + base_names = list(clone.obs_names) + clone.obs_names = [ + f"{obs_name}__virt_{pert_name}__pad{idx+1}" + for idx, obs_name in enumerate(base_names) + ] + virtual_blocks.append(clone) + + if virtual_blocks: + adata = sc.concat([adata, *virtual_blocks], axis=0, join="same", label=None, index_unique=None) + if not args.quiet: + preview = ", ".join( + [f"{pert}:{args.min_cells}" for pert, cnt in pert_counts.items() if int(cnt) < int(args.min_cells)][:5] + ) + if len(virtual_blocks) > 5: + preview += ", ..." + total_added = sum(vb.n_obs for vb in virtual_blocks) + print( + f"Added {total_added} padding cells to meet --min-cells " + f"(examples: {preview if preview else 'n/a'}). Total cells: {adata.n_obs}." + ) + elif not args.quiet: + print("--min-cells set, but all perturbations already meet the threshold.") + + # cap the number of cells per perturbation by subsampling + if args.max_cells is not None: + if args.max_cells <= 0: + raise ValueError("--max-cells must be a positive integer if provided.") + if args.min_cells is not None and args.max_cells < args.min_cells: + raise ValueError("--max-cells cannot be smaller than --min-cells.") + + trim_rng = np.random.RandomState(args.seed + 1) + keep_mask = np.ones(adata.n_obs, dtype=bool) + pert_labels = adata.obs[args.pert_col].values + + unique_perts = np.unique(pert_labels) + for pert_name in unique_perts: + idxs = np.where(pert_labels == pert_name)[0] + if len(idxs) <= args.max_cells: + continue + + chosen = trim_rng.choice(idxs, size=args.max_cells, replace=False) + chosen = np.sort(chosen) + drop = np.setdiff1d(idxs, chosen, assume_unique=True) + keep_mask[drop] = False + + if not np.all(keep_mask): + original_n = adata.n_obs + adata = adata[keep_mask].copy() + if not args.quiet: + total_dropped = original_n - adata.n_obs + print( + f"Subsampled perturbations exceeding --max-cells; dropped {total_dropped} cells. Total cells: {adata.n_obs}." + ) + # select features: embeddings or genes if args.embed_key is None: X_in = to_dense(adata.X) # [N, E_in] @@ -491,7 +634,7 @@ def pad_adata_with_tsv( rng = np.random.RandomState(args.seed) # Identify control vs non-control - ctl_mask = pert_names_all == str(control_pert) + ctl_mask = pert_names_all == control_pert_str n_controls = int(ctl_mask.sum()) n_total = adata.n_obs n_nonctl = n_total - n_controls @@ -525,8 +668,9 @@ def group_control_indices(group_name: str) -> np.ndarray: return grp_ctl if len(grp_ctl) > 0 else all_control_indices # default pert vector when unmapped label shows up - if control_pert in pert_onehot_map: - default_pert_vec = pert_onehot_map[control_pert].float().clone() + control_map_key = pert_name_lookup.get(control_pert_str, control_pert) + if control_map_key in pert_onehot_map: + default_pert_vec = pert_onehot_map[control_map_key].float().clone() else: default_pert_vec = torch.zeros(pert_dim, dtype=torch.float32) if pert_dim and pert_dim > 0: @@ -568,7 +712,8 @@ def group_control_indices(group_name: str) -> np.ndarray: continue # one-hot vector for this perturbation (repeat across window) - vec = pert_onehot_map.get(p, None) + map_key = pert_name_lookup.get(p, p) + vec = pert_onehot_map.get(map_key, None) if vec is None: vec = default_pert_vec if not args.quiet: diff --git a/src/state/configs/model/state.yaml b/src/state/configs/model/state.yaml index 98c15880..c5128319 100644 --- a/src/state/configs/model/state.yaml +++ b/src/state/configs/model/state.yaml @@ -20,6 +20,7 @@ kwargs: use_batch_token: False nb_decoder: False mask_attn: False + dosage: False use_effect_gating_token: False distributional_loss: energy init_from: null diff --git a/src/state/tx/models/state_transition.py b/src/state/tx/models/state_transition.py index 72d7020b..4050479b 100644 --- a/src/state/tx/models/state_transition.py +++ b/src/state/tx/models/state_transition.py @@ -1,5 +1,6 @@ +import ast import logging -from typing import Dict, Optional +from typing import Dict, Optional, Tuple import anndata as ad import numpy as np @@ -8,7 +9,6 @@ import torch.nn.functional as F from geomloss import SamplesLoss -from typing import Tuple from .base import PerturbationModel from .decoders import FinetuneVCICountsDecoder @@ -294,6 +294,13 @@ def __init__( self.confidence_target_scale = None self.confidence_weight = 0.0 + self.use_dosage_encoder = bool(kwargs.get("dosage", False)) + if self.use_dosage_encoder: + self.dosage_encoder = nn.Linear(1, self.hidden_dim) + else: + self.dosage_encoder = None + self._warned_missing_dosage = False + # Backward-compat: accept legacy key `freeze_pert` self.freeze_pert_backbone = kwargs.get("freeze_pert_backbone", kwargs.get("freeze_pert", False)) if self.freeze_pert_backbone: @@ -431,6 +438,12 @@ def forward(self, batch: dict, padded=True) -> torch.Tensor: combined_input = pert_embedding + control_cells # Shape: [B, S, hidden_dim] seq_input = combined_input # Shape: [B, S, hidden_dim] + if self.use_dosage_encoder: + dosage_tensor = self._prepare_dosage_tensor(batch, seq_input.device, pert.shape[:2]) + if dosage_tensor is not None: + dosage_features = self.dosage_encoder(torch.log1p(dosage_tensor)) + seq_input = seq_input + dosage_features + if self.batch_encoder is not None: # Extract batch indices (assume they are integers or convert from one-hot) batch_indices = batch["batch"] @@ -531,6 +544,83 @@ def forward(self, batch: dict, padded=True) -> torch.Tensor: else: return output + def _prepare_dosage_tensor( + self, batch: Dict[str, torch.Tensor], device: torch.device, shape: Tuple[int, int] + ) -> Optional[torch.Tensor]: + """Return dosage tensor shaped for broadcasting or None if unavailable.""" + + if not self.use_dosage_encoder: + return None + + dosage_values = batch.get("pert_dosage") + + if dosage_values is not None: + if torch.is_tensor(dosage_values): + dosage_tensor = dosage_values.to(device=device, dtype=torch.float32) + else: + dosage_tensor = torch.as_tensor(dosage_values, device=device, dtype=torch.float32) + else: + pert_names = batch.get("pert_name") + if pert_names is None: + if not self._warned_missing_dosage: + logger.warning("Dosage encoder enabled but no dosage information found in batch; skipping dosage term.") + self._warned_missing_dosage = True + return None + + if isinstance(pert_names, torch.Tensor): + pert_names = pert_names.tolist() + if not isinstance(pert_names, (list, tuple)): + pert_names = [pert_names] + + dosage_list = [self._parse_dosage_from_name(name) for name in pert_names] + dosage_tensor = torch.tensor(dosage_list, device=device, dtype=torch.float32) + + if not self._warned_missing_dosage: + logger.warning("Falling back to parsing dosage from perturbation names; consider providing 'pert_dosage'.") + self._warned_missing_dosage = True + + dosage_tensor = dosage_tensor.flatten() + expected_elems = shape[0] * shape[1] + + if dosage_tensor.numel() == expected_elems: + return dosage_tensor.reshape(shape[0], shape[1], 1) + + if dosage_tensor.numel() == shape[0] and shape[1] > 0: + return dosage_tensor.view(shape[0], 1, 1).expand(shape[0], shape[1], 1) + + if shape[0] == 1 and dosage_tensor.numel() == shape[1]: + return dosage_tensor.view(1, shape[1], 1) + + logger.warning( + "Dosage tensor has %d elements but expected either %d or %d; skipping dosage term for this batch.", + dosage_tensor.numel(), + expected_elems, + shape[0], + ) + return None + + @staticmethod + def _parse_dosage_from_name(name: Optional[str]) -> float: + """Extract dosage value from perturbation name string.""" + + if not isinstance(name, str): + return 0.0 + + try: + parsed = ast.literal_eval(name) + except (ValueError, SyntaxError): + return 0.0 + + try: + if isinstance(parsed, (list, tuple)) and len(parsed) > 0: + first_entry = parsed[0] + if isinstance(first_entry, (list, tuple)) and len(first_entry) > 1: + return float(first_entry[1]) + except (TypeError, ValueError): + pass + + return 0.0 + def _compute_distribution_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """Apply the primary distributional loss, optionally chunking feature dimensions for SamplesLoss.""" From ed00c06ba6511b570f68acc6c80325bdf994e63b Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Wed, 8 Oct 2025 05:31:51 +0000 Subject: [PATCH 25/38] small fix --- src/state/_cli/_tx/_infer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/state/_cli/_tx/_infer.py b/src/state/_cli/_tx/_infer.py index 4d6ccdc2..0d7ba11d 100644 --- a/src/state/_cli/_tx/_infer.py +++ b/src/state/_cli/_tx/_infer.py @@ -475,7 +475,7 @@ def pad_adata_with_tsv( clone.obs_names = [f"{obs_name}__virt_{pert_name}" for obs_name in clone.obs_names] virtual_blocks.append(clone) - adata = sc.concat([adata, *virtual_blocks], axis=0, join="same", label=None, index_unique=None) + adata = sc.concat([adata, *virtual_blocks], axis=0, join="inner") if not args.quiet: preview = ", ".join(missing_perts[:5]) @@ -520,7 +520,7 @@ def pad_adata_with_tsv( virtual_blocks.append(clone) if virtual_blocks: - adata = sc.concat([adata, *virtual_blocks], axis=0, join="same", label=None, index_unique=None) + adata = sc.concat([adata, *virtual_blocks], axis=0, join="inner") if not args.quiet: preview = ", ".join( [f"{pert}:{args.min_cells}" for pert, cnt in pert_counts.items() if int(cnt) < int(args.min_cells)][:5] From 8ded763e7efbd338f9174ffe22fa15a9c8c59996 Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Wed, 8 Oct 2025 05:32:13 +0000 Subject: [PATCH 26/38] chore: formatting --- src/state/_cli/_tx/_combo.py | 10 ++-------- src/state/_cli/_tx/_infer.py | 9 ++++----- src/state/tx/models/state_transition.py | 8 ++++++-- 3 files changed, 12 insertions(+), 15 deletions(-) diff --git a/src/state/_cli/_tx/_combo.py b/src/state/_cli/_tx/_combo.py index 7b55d929..5cd4da73 100644 --- a/src/state/_cli/_tx/_combo.py +++ b/src/state/_cli/_tx/_combo.py @@ -560,15 +560,9 @@ def _sanitize_filename(label: str) -> str: progress_bar.update(1) - X_matrix = ( - np.vstack(per_pert_X_blocks) - if per_pert_X_blocks - else np.empty((0, 0), dtype=np.float32) - ) + X_matrix = np.vstack(per_pert_X_blocks) if per_pert_X_blocks else np.empty((0, 0), dtype=np.float32) latent_matrix = ( - np.vstack(per_pert_latent_blocks) - if per_pert_latent_blocks - else np.empty((0, 0), dtype=np.float32) + np.vstack(per_pert_latent_blocks) if per_pert_latent_blocks else np.empty((0, 0), dtype=np.float32) ) obs_df = pd.DataFrame(per_pert_obs_rows) diff --git a/src/state/_cli/_tx/_infer.py b/src/state/_cli/_tx/_infer.py index 0d7ba11d..53952edd 100644 --- a/src/state/_cli/_tx/_infer.py +++ b/src/state/_cli/_tx/_infer.py @@ -513,17 +513,16 @@ def pad_adata_with_tsv( clone.obs = clone.obs.copy() clone.obs[args.pert_col] = pert_name base_names = list(clone.obs_names) - clone.obs_names = [ - f"{obs_name}__virt_{pert_name}__pad{idx+1}" - for idx, obs_name in enumerate(base_names) - ] + clone.obs_names = [f"{obs_name}__virt_{pert_name}__pad{idx + 1}" for idx, obs_name in enumerate(base_names)] virtual_blocks.append(clone) if virtual_blocks: adata = sc.concat([adata, *virtual_blocks], axis=0, join="inner") if not args.quiet: preview = ", ".join( - [f"{pert}:{args.min_cells}" for pert, cnt in pert_counts.items() if int(cnt) < int(args.min_cells)][:5] + [f"{pert}:{args.min_cells}" for pert, cnt in pert_counts.items() if int(cnt) < int(args.min_cells)][ + :5 + ] ) if len(virtual_blocks) > 5: preview += ", ..." diff --git a/src/state/tx/models/state_transition.py b/src/state/tx/models/state_transition.py index 4050479b..4001a453 100644 --- a/src/state/tx/models/state_transition.py +++ b/src/state/tx/models/state_transition.py @@ -563,7 +563,9 @@ def _prepare_dosage_tensor( pert_names = batch.get("pert_name") if pert_names is None: if not self._warned_missing_dosage: - logger.warning("Dosage encoder enabled but no dosage information found in batch; skipping dosage term.") + logger.warning( + "Dosage encoder enabled but no dosage information found in batch; skipping dosage term." + ) self._warned_missing_dosage = True return None @@ -576,7 +578,9 @@ def _prepare_dosage_tensor( dosage_tensor = torch.tensor(dosage_list, device=device, dtype=torch.float32) if not self._warned_missing_dosage: - logger.warning("Falling back to parsing dosage from perturbation names; consider providing 'pert_dosage'.") + logger.warning( + "Falling back to parsing dosage from perturbation names; consider providing 'pert_dosage'." + ) self._warned_missing_dosage = True dosage_tensor = dosage_tensor.flatten() From 358b1a42234e5dd750fb28c745b45a18b9f780cc Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Wed, 8 Oct 2025 12:15:46 -0700 Subject: [PATCH 27/38] Fix pseudobulk model for wandb tracking --- src/state/_cli/_tx/_train.py | 2 +- src/state/tx/callbacks/cumulative_flops.py | 30 ++++++++++++++++------ src/state/tx/models/pseudobulk.py | 4 +-- 3 files changed, 25 insertions(+), 11 deletions(-) diff --git a/src/state/_cli/_tx/_train.py b/src/state/_cli/_tx/_train.py index 2c9f10a0..08d0ef9a 100644 --- a/src/state/_cli/_tx/_train.py +++ b/src/state/_cli/_tx/_train.py @@ -234,7 +234,7 @@ def run_tx_train(cfg: DictConfig): callbacks.append(mfu_cb) - # Add CumulativeFLOPSCallback to track cumulative FLOPs + if "cumulative_flops_use_backward" in cfg["training"]: cumulative_flops_use_backward = cfg["training"]["cumulative_flops_use_backward"] cumulative_flops_cb = CumulativeFLOPSCallback(use_backward=cumulative_flops_use_backward) callbacks.append(cumulative_flops_cb) diff --git a/src/state/tx/callbacks/cumulative_flops.py b/src/state/tx/callbacks/cumulative_flops.py index 720083e8..6ab05ce1 100644 --- a/src/state/tx/callbacks/cumulative_flops.py +++ b/src/state/tx/callbacks/cumulative_flops.py @@ -36,18 +36,32 @@ def __init__( self._batch_count: int = 0 def _trainstep_forward_backward(self, model: LightningModule, batch: Any) -> torch.Tensor: - """Encapsulate calling StateTransitionPerturbationModel.training_step and backward. + """Call the model's training_step (handling optional args) and run backward if configured.""" - This intentionally targets StateTransitionPerturbationModel's signature and - performs both forward and backward to capture full FLOPs. - - !!WARNING!! - This has only been tested with StateTransitionPerturbationModel. Behavior with any other model has not been verified. - """ model.zero_grad(set_to_none=True) - loss: torch.Tensor = model.training_step(batch, 0, padded=True) # type: ignore + + try: + loss_out = model.training_step(batch, 0, padded=True) + except TypeError: + loss_out = model.training_step(batch, 0) + + if isinstance(loss_out, dict): + loss = loss_out.get("loss") + if loss is None: + raise RuntimeError( + "CumulativeFLOPSCallback expected training_step to return a Tensor or dict containing 'loss'." + ) + else: + loss = loss_out + + if not isinstance(loss, torch.Tensor): # pragma: no cover - defensive guard + raise TypeError( + "CumulativeFLOPSCallback requires training_step to return a Tensor (or dict with 'loss' Tensor)." + ) + if self.use_backward: loss.backward() + return loss def _measure_flops_once(self, trainer: Trainer, pl_module: Any, batch: Any) -> None: diff --git a/src/state/tx/models/pseudobulk.py b/src/state/tx/models/pseudobulk.py index 4da11f99..5494dea2 100644 --- a/src/state/tx/models/pseudobulk.py +++ b/src/state/tx/models/pseudobulk.py @@ -371,8 +371,8 @@ def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Non target = batch["pert_cell_emb"] target = target.reshape(-1, self.cell_sentence_len, self.output_dim) - loss = self.loss_fn(pred, target).mean() - self.log("val_loss", loss) + loss = torch.nanmean(self.loss_fn(pred, target)) + self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) if self.gene_decoder is not None and "pert_cell_counts" in batch: gene_targets = batch["pert_cell_counts"] From edddc31eacf96d8cdabb50ce55ed0d8a2cd3ad57 Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Wed, 8 Oct 2025 19:33:43 +0000 Subject: [PATCH 28/38] added dosage arg --- src/state/_cli/_tx/_combo.py | 37 +++++++--- src/state/_cli/_tx/_infer.py | 138 +++++++++++++++++++++++++++++++++++ 2 files changed, 166 insertions(+), 9 deletions(-) diff --git a/src/state/_cli/_tx/_combo.py b/src/state/_cli/_tx/_combo.py index 5cd4da73..7b5dc07a 100644 --- a/src/state/_cli/_tx/_combo.py +++ b/src/state/_cli/_tx/_combo.py @@ -485,9 +485,21 @@ def _sanitize_filename(label: str) -> str: used_output_names: dict[str, int] = {} written_files: list[str] = [] + skipped_perts: list[str] = [] + + try: + existing_output_names = { + os.path.splitext(fname)[0] + for fname in os.listdir(output_dir) + if fname.endswith(".h5ad") + } + except OSError: + existing_output_names = set() + + num_target_perts = len(perts) with torch.no_grad(): - progress_total = len(perts) * len(perts) + progress_total = num_target_perts * num_target_perts progress_bar = tqdm( total=progress_total, desc="Combo sweeps", @@ -495,6 +507,18 @@ def _sanitize_filename(label: str) -> str: disable=args.quiet, ) for pert1 in perts: + base_name = _sanitize_filename(pert1) + occurrence_idx = used_output_names.get(base_name, -1) + 1 + used_output_names[base_name] = occurrence_idx + output_name = base_name if occurrence_idx == 0 else f"{base_name}_{occurrence_idx}" + output_path = os.path.join(output_dir, f"{output_name}.h5ad") + + if output_name in existing_output_names or os.path.exists(output_path): + skipped_perts.append(pert1) + progress_bar.update(num_target_perts) + logger.info("Skipping combos for %s; existing output at %s", pert1, output_path) + continue + per_pert_X_blocks: list[np.ndarray] = [] per_pert_latent_blocks: list[np.ndarray] = [] per_pert_obs_rows: list[dict[str, str | int]] = [] @@ -599,18 +623,13 @@ def _sanitize_filename(label: str) -> str: combo_adata.uns["inner_batch_size"] = inner_batch_size combo_adata.uns["sampled_control_indices"] = adata_ct.obs_names[sampled_idx].tolist() - output_name = _sanitize_filename(pert1) - if output_name in used_output_names: - used_output_names[output_name] += 1 - output_name = f"{output_name}_{used_output_names[output_name]}" - else: - used_output_names[output_name] = 0 - - output_path = os.path.join(output_dir, f"{output_name}.h5ad") combo_adata.write_h5ad(output_path) written_files.append(output_path) + existing_output_names.add(output_name) logger.info("Saved combos for %s with %d cells to %s", pert1, combo_adata.n_obs, output_path) progress_bar.close() logger.info("Finished writing %d combo files to %s", len(written_files), output_dir) + if skipped_perts: + logger.info("Skipped %d perturbations with existing combo outputs", len(skipped_perts)) diff --git a/src/state/_cli/_tx/_infer.py b/src/state/_cli/_tx/_infer.py index 53952edd..82c3545a 100644 --- a/src/state/_cli/_tx/_infer.py +++ b/src/state/_cli/_tx/_infer.py @@ -1,4 +1,5 @@ import argparse +import ast from typing import Dict, List, Optional import pandas as pd @@ -23,6 +24,12 @@ def add_arguments_infer(parser: argparse.ArgumentParser): default="drugname_drugconc", help="Column in adata.obs for perturbation labels", ) + parser.add_argument( + "--dosages", + type=str, + default=None, + help="Optional list of dosages (floats) to materialize for each perturbation, e.g. \"[0.1, 0.5, 1.0]\".", + ) parser.add_argument( "--output", type=str, @@ -168,12 +175,115 @@ def argmax_index_from_any(v, expected_dim: Optional[int]) -> Optional[int]: return int(v) return None + def parse_dosage_argument(arg_value: Optional[str]) -> List[float]: + if arg_value is None: + return [] + if isinstance(arg_value, (list, tuple)): + candidate_values = arg_value + else: + text = str(arg_value).strip() + if not text: + return [] + parsed = None + try: + parsed = ast.literal_eval(text) + except (ValueError, SyntaxError): + parsed = None + if isinstance(parsed, (list, tuple)): + candidate_values = parsed + else: + text = text.strip("[]") + parts = [p for p in text.replace(",", " ").split() if p] + candidate_values = parts + deduped: List[float] = [] + seen: set[float] = set() + for value in candidate_values: + try: + val = float(value) + except (TypeError, ValueError): + raise ValueError(f"Invalid dosage value '{value}' in --dosages argument.") + key = round(val, 12) + if key not in seen: + seen.add(key) + deduped.append(val) + return deduped + + def extend_perturbation_map_for_dosages( + pert_map: Dict[str, torch.Tensor], + requested_dosages: List[float], + *, + control_label: Optional[str], + quiet: bool, + ) -> List[str]: + if not requested_dosages: + return [] + + def almost_equal(a: float, b: float, tol: float = 1e-9) -> bool: + return abs(a - b) <= tol + + canonical_vectors: Dict[tuple[str, Optional[str]], Dict[str, object]] = {} + for key, vec in pert_map.items(): + key_str = str(key) + if control_label is not None and key_str == control_label: + continue + try: + parsed = ast.literal_eval(key_str) + except (ValueError, SyntaxError): + continue + if not isinstance(parsed, (list, tuple)) or len(parsed) != 1: + continue + entry = parsed[0] + if not isinstance(entry, (list, tuple)) or len(entry) < 2: + continue + pert_name = str(entry[0]) + unit = str(entry[2]) if len(entry) > 2 else None + try: + dose_val = float(entry[1]) + except (TypeError, ValueError): + continue + + base_key = (pert_name, unit) + base_info = canonical_vectors.setdefault( + base_key, + { + "template_key": key_str, + "unit": unit, + "existing": [], + "vector": vec, + }, + ) + existing: List[float] = base_info["existing"] # type: ignore[assignment] + if not any(almost_equal(dose_val, existing_dose) for existing_dose in existing): + existing.append(dose_val) + # Prefer first encountered vector as canonical; assume all are equivalent + + added_keys: List[str] = [] + for (pert_name, unit), info in canonical_vectors.items(): + vector: torch.Tensor = info["vector"] # type: ignore[assignment] + existing: List[float] = info["existing"] # type: ignore[assignment] + for dosage in requested_dosages: + if any(almost_equal(dosage, existing_dose) for existing_dose in existing): + continue + if unit is None: + new_entry = [(pert_name, float(dosage))] + else: + new_entry = [(pert_name, float(dosage), unit)] + key_str = str(new_entry) + if key_str in pert_map: + continue + pert_map[key_str] = vector.clone() + added_keys.append(key_str) + if added_keys and not quiet: + print(f"Extended perturbation map with {len(added_keys)} dosage variants.") + return added_keys + def prepare_batch( ctrl_basal_np: np.ndarray, pert_onehots: torch.Tensor, batch_indices: Optional[torch.Tensor], pert_names: List[str], device: torch.device, + pert_dosage: Optional[float] = None, ) -> Dict[str, torch.Tensor | List[str]]: """ Construct a model batch with variable-length sentence (B=1, S=T, ...). @@ -187,6 +297,14 @@ def prepare_batch( } if batch_indices is not None: batch["batch"] = batch_indices.to(device) # [T] + if pert_dosage is not None: + seq_len = X_batch.shape[0] + batch["pert_dosage"] = torch.full( + (seq_len,), + float(pert_dosage), + dtype=torch.float32, + device=device, + ) return batch def pad_adata_with_tsv( @@ -340,6 +458,12 @@ def pad_adata_with_tsv( print(f"Control perturbation: {control_pert}") control_pert_str = str(control_pert) + requested_dosages = parse_dosage_argument(args.dosages) + if requested_dosages and not args.quiet: + print(f"Requested dosages: {requested_dosages}") + if requested_dosages and not args.all_perts and not args.quiet: + print("Note: --dosages provided without --all-perts; only dosages present in AnnData will be simulated.") + # choose cell type column if args.celltype_col is None: ct_from_cfg = None @@ -379,6 +503,14 @@ def pad_adata_with_tsv( if not os.path.exists(pert_onehot_map_path): raise FileNotFoundError(f"Missing pert_onehot_map.pt at {pert_onehot_map_path}") pert_onehot_map: Dict[str, torch.Tensor] = torch.load(pert_onehot_map_path, weights_only=False) + added_dosage_keys = extend_perturbation_map_for_dosages( + pert_map=pert_onehot_map, + requested_dosages=requested_dosages, + control_label=control_pert_str, + quiet=args.quiet, + ) + if requested_dosages and not added_dosage_keys and not args.quiet: + print("No new dosage variants were added; requested values may already exist in the perturbation map.") pert_name_lookup: Dict[str, object] = {str(k): k for k in pert_onehot_map.keys()} pert_names_in_map: List[str] = list(pert_name_lookup.keys()) @@ -717,6 +849,11 @@ def group_control_indices(group_name: str) -> np.ndarray: vec = default_pert_vec if not args.quiet: print(f" (group {g}) pert '{p}' not in mapping; using control fallback one-hot.") + dosage_value = ( + StateTransitionPerturbationModel._parse_dosage_from_name(p) + if getattr(model, "use_dosage_encoder", False) + else None + ) start = 0 while start < len(idxs): @@ -744,6 +881,7 @@ def group_control_indices(group_name: str) -> np.ndarray: batch_indices=bi, pert_names=[p] * win_size, device=model_device, + pert_dosage=dosage_value, ) batch_out = model.predict_step(batch, batch_idx=0, padded=False) From 6983bfecdc1b49dee9c8258c12803a58daad0990 Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Wed, 8 Oct 2025 23:13:54 +0000 Subject: [PATCH 29/38] added hill prior --- src/state/configs/model/state.yaml | 15 ++++- src/state/tx/models/state_transition.py | 89 ++++++++++++++++++++++++- 2 files changed, 100 insertions(+), 4 deletions(-) diff --git a/src/state/configs/model/state.yaml b/src/state/configs/model/state.yaml index c5128319..8b6baf5d 100644 --- a/src/state/configs/model/state.yaml +++ b/src/state/configs/model/state.yaml @@ -5,7 +5,7 @@ device: cuda kwargs: cell_set_len: 512 blur: 0.05 - hidden_dim: 768 # hidden dimension going into the transformer backbone + hidden_dim: 768 loss: energy confidence_token: False n_encoder_layers: 1 @@ -20,15 +20,23 @@ kwargs: use_batch_token: False nb_decoder: False mask_attn: False - dosage: False + + # --- Dose handling --- + dosage: False # was False + hill_prior: True # NEW: turn on FiLM + Hill/Emax gate + smoothing + dose_momentum: 0.01 # NEW: EMA for log10-dose mean/std + dose_strength_init: 1.0 # NEW: initial strength of FiLM modulation + dose_smooth_weight: 0.01 # NEW: curvature penalty across doses + use_effect_gating_token: False distributional_loss: energy init_from: null mmd_num_chunks: 1 randomize_mmd_chunks: false + transformer_backbone_key: llama transformer_backbone_kwargs: - bidirectional_attention: false + bidirectional_attention: true # was false; matches the CLI you used max_position_embeddings: ${model.kwargs.cell_set_len} hidden_size: ${model.kwargs.hidden_dim} intermediate_size: 3072 @@ -46,6 +54,7 @@ kwargs: tie_word_embeddings: false rotary_dim: 0 use_rotary_embeddings: false + lora: enable: false r: 16 diff --git a/src/state/tx/models/state_transition.py b/src/state/tx/models/state_transition.py index 4001a453..21aaf67c 100644 --- a/src/state/tx/models/state_transition.py +++ b/src/state/tx/models/state_transition.py @@ -1,6 +1,6 @@ import ast import logging -from typing import Dict, Optional, Tuple +import math import anndata as ad import numpy as np @@ -9,6 +9,7 @@ import torch.nn.functional as F from geomloss import SamplesLoss +from typing import Dict, Optional, Tuple from .base import PerturbationModel from .decoders import FinetuneVCICountsDecoder @@ -625,6 +626,80 @@ def _parse_dosage_from_name(name: Optional[str]) -> float: return 0.0 + @staticmethod + def _parse_drug_from_name(name: Optional[str]) -> str: + """Best-effort extraction of the base drug identifier from a perturbation name.""" + if not isinstance(name, str): + return "unknown" + try: + parsed = ast.literal_eval(name) + if isinstance(parsed, (list, tuple)) and len(parsed) > 0: + first_entry = parsed[0] + if isinstance(first_entry, (list, tuple)) and len(first_entry) > 0: + return str(first_entry[0]) + except (ValueError, SyntaxError): + pass + # Fallback: strip common separators if present + return name.split("@")[0].split("|")[0] + + def _dose_smoothness_loss(self, batch: Dict[str, torch.Tensor], pred: torch.Tensor, padded: bool) -> torch.Tensor: + """ + Encourage a smooth (low curvature) trajectory across log-dose for the same drug within the minibatch. + 'pred' is [B,S,D] (set of cells per dose). We reduce over S (cells) first. + Requires 'pert_dosage' (or parseable names) to be present; otherwise returns 0. + """ + if not self.use_hill_prior or self.dose_smooth_weight <= 0.0: + return pred.new_tensor(0.0) + + B, S, D = pred.shape + device = pred.device + + # One dose per sentence + dose = self._prepare_dosage_tensor(batch, device, (B, S)) + if dose is None: + return pred.new_tensor(0.0) + dose_per_sentence = dose[:, 0, 0] # [B] + + # One drug label per sentence (best effort) + groups = None + names = batch.get("pert_name", None) + if names is not None: + if isinstance(names, torch.Tensor): + names_list = names.reshape(-1).tolist() + else: + names_list = list(names) + if len(names_list) >= B * S: + per_sentence = [names_list[i * S] for i in range(B)] + elif len(names_list) >= B: + per_sentence = [names_list[i] for i in range(B)] + else: + per_sentence = None + if per_sentence is not None: + groups = [self._parse_drug_from_name(n) for n in per_sentence] + if groups is None: + groups = ["__all__"] * B # fall back to one pooled group + + # Reduce each sentence to a set-level vector + set_pred = pred.mean(dim=1) # [B, D] + + buckets: Dict[str, list] = {} + for i, g in enumerate(groups): + buckets.setdefault(g, []).append((dose_per_sentence[i].item(), i)) + + losses = [] + for _, lst in buckets.items(): + if len(lst) < 3: + continue + lst.sort(key=lambda x: x[0]) # ascending dose + idx = [i for _, i in lst] + series = set_pred[idx] # [Nd, D] + second = series[2:] - 2 * series[1:-1] + series[:-2] + losses.append((second ** 2).mean()) + + if not losses: + return pred.new_tensor(0.0) + return torch.stack(losses).mean() + def _compute_distribution_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """Apply the primary distributional loss, optionally chunking feature dimensions for SamplesLoss.""" @@ -787,6 +862,12 @@ def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int, padded=T # Add regularization to total loss total_loss = total_loss + self.regularization * l1_loss + if self.use_hill_prior and self.dose_smooth_weight > 0.0: + with torch.no_grad() if not self.training else torch.enable_grad(): + smooth_loss = self._dose_smoothness_loss(batch, pred, padded=padded) + self.log("train/dose_smooth_loss", smooth_loss) + total_loss = total_loss + self.dose_smooth_weight * smooth_loss + return total_loss def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> None: @@ -848,6 +929,12 @@ def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Non self.log("val/confidence_loss", confidence_loss) self.log("val/actual_loss", confidence_targets.mean()) + # Validation analogue of curvature penalty + if self.use_hill_prior and self.dose_smooth_weight > 0.0: + smooth_loss = self._dose_smoothness_loss(batch, pred, padded=True) + self.log("val/dose_smooth_loss", smooth_loss) + loss = loss + self.dose_smooth_weight * smooth_loss + return {"loss": loss, "predictions": pred} def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> None: From 647901a51bd4c98908e11a06757bb62ca153b983 Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Wed, 8 Oct 2025 23:15:48 +0000 Subject: [PATCH 30/38] fix indentation --- src/state/tx/models/state_transition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/state/tx/models/state_transition.py b/src/state/tx/models/state_transition.py index 21aaf67c..a60ded5e 100644 --- a/src/state/tx/models/state_transition.py +++ b/src/state/tx/models/state_transition.py @@ -634,7 +634,7 @@ def _parse_drug_from_name(name: Optional[str]) -> str: try: parsed = ast.literal_eval(name) if isinstance(parsed, (list, tuple)) and len(parsed) > 0: - first_entry = parsed[0] + first_entry = parsed[0] if isinstance(first_entry, (list, tuple)) and len(first_entry) > 0: return str(first_entry[0]) except (ValueError, SyntaxError): From 132bd00e81fa5064d7c0ebe24ccf49c92e3a3f78 Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Wed, 8 Oct 2025 23:16:42 +0000 Subject: [PATCH 31/38] fix indentation --- src/state/tx/models/state_transition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/state/tx/models/state_transition.py b/src/state/tx/models/state_transition.py index a60ded5e..097a5b1b 100644 --- a/src/state/tx/models/state_transition.py +++ b/src/state/tx/models/state_transition.py @@ -661,7 +661,7 @@ def _dose_smoothness_loss(self, batch: Dict[str, torch.Tensor], pred: torch.Tens dose_per_sentence = dose[:, 0, 0] # [B] # One drug label per sentence (best effort) - groups = None + groups = None names = batch.get("pert_name", None) if names is not None: if isinstance(names, torch.Tensor): From 3819fe564cb7f2da79a43ce0e44c2750355f990f Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Wed, 8 Oct 2025 23:26:39 +0000 Subject: [PATCH 32/38] updated dosage impl --- src/state/tx/models/state_transition.py | 74 +++++++++++++++++++++++-- 1 file changed, 70 insertions(+), 4 deletions(-) diff --git a/src/state/tx/models/state_transition.py b/src/state/tx/models/state_transition.py index 097a5b1b..959b10d5 100644 --- a/src/state/tx/models/state_transition.py +++ b/src/state/tx/models/state_transition.py @@ -97,6 +97,28 @@ def extract_confidence_prediction(self, transformer_output: torch.Tensor) -> Tup return main_output, confidence_pred +class HillGate(nn.Module): + """ + Monotone, saturating gate w(d) for dose d that multiplies the residual. + For each hidden unit h: + w_h(d) = softplus(Emax_h) * (d/EC50_h)^{n_h} / (1 + (d/EC50_h)^{n_h}) + The forward expects log10-dose with shape [B,S,1] (standardized is fine). + """ + def __init__(self, hidden_dim: int): + super().__init__() + self.log_ec50 = nn.Parameter(torch.zeros(hidden_dim)) # EC50 = exp(log_ec50) > 0 + self.emax = nn.Parameter(torch.ones(hidden_dim)) # softplus(emax) > 0 + self.hill = nn.Parameter(torch.ones(hidden_dim)) # softplus(hill) > 0 + + def forward(self, log10_d: torch.Tensor) -> torch.Tensor: + # Convert log10-dose to linear space and broadcast to hidden dim + d = (10.0 ** log10_d).clamp_min(1e-12) # [B,S,1] + n = F.softplus(self.hill).view(1, 1, -1) # [1,1,H] + emax = F.softplus(self.emax).view(1, 1, -1) # [1,1,H] + ec50 = self.log_ec50.exp().view(1, 1, -1) # [1,1,H] + ratio_pow = (d / ec50).pow(n) # [B,S,H] + w = emax * ratio_pow / (1.0 + ratio_pow + 1e-12) # [B,S,H], in (0, emax] + return w class StateTransitionPerturbationModel(PerturbationModel): """ @@ -296,11 +318,52 @@ def __init__( self.confidence_weight = 0.0 self.use_dosage_encoder = bool(kwargs.get("dosage", False)) - if self.use_dosage_encoder: - self.dosage_encoder = nn.Linear(1, self.hidden_dim) + # Feature flag: pharmacologically-informed dose handling + self.use_hill_prior = bool(kwargs.get("hill_prior", False)) + if self.use_hill_prior: + # Running stats for standardized log10-dose + self.register_buffer("dose_mean", torch.tensor(0.0)) + self.register_buffer("dose_std", torch.tensor(1.0)) + self.dose_momentum = float(kwargs.get("dose_momentum", 0.01)) + # Strength of FiLM modulation + self.dose_strength = nn.Parameter(torch.tensor(float(kwargs.get("dose_strength_init", 1.0)))) + # FiLM network on standardized log10-dose + self.dose_film = nn.Sequential( + nn.Linear(1, 128), + nn.SiLU(), + nn.Linear(128, 2 * self.hidden_dim), + ) + # Hill/Emax gate on residual + self.hill_gate = HillGate(self.hidden_dim) + # Small curvature penalty across doses of the same drug + self.dose_smooth_weight = float(kwargs.get("dose_smooth_weight", 0.01)) else: - self.dosage_encoder = None - self._warned_missing_dosage = False + self.dose_smooth_weight = 0.0 + + # Dose conditioning + logd_norm = None # keep to apply Hill gate after the transformer + if self.use_dosage_encoder: + dosage_tensor = self._prepare_dosage_tensor(batch, seq_input.device, pert.shape[:2]) + if dosage_tensor is not None: + if self.use_hill_prior: + # log10 and standardize with running stats + logd = torch.log10(dosage_tensor.clamp_min(1e-9)) # [B,S,1] + if self.training: + with torch.no_grad(): + bmean = logd.mean() + bstd = logd.std(unbiased=False).clamp_min(1e-6) + self.dose_mean = (1 - self.dose_momentum) * self.dose_mean + self.dose_momentum * bmean + self.dose_std = (1 - self.dose_momentum) * self.dose_std + self.dose_momentum * bstd + logd_norm = (logd - self.dose_mean) / (self.dose_std + 1e-6) + # FiLM modulation: seq_input <- (1 + α*(γ-1))*x + α*β + film_params = self.dose_film(logd_norm) # [B,S,2H] + gamma, beta = film_params.chunk(2, dim=-1) + gamma = F.softplus(gamma) # positive scale + seq_input = (1 + self.dose_strength * (gamma - 1)) * seq_input + self.dose_strength * beta + else: + # Legacy additive dose features + dosage_features = self.dosage_encoder(torch.log1p(dosage_tensor)) + seq_input = seq_input + dosage_features # Backward-compat: accept legacy key `freeze_pert` self.freeze_pert_backbone = kwargs.get("freeze_pert_backbone", kwargs.get("freeze_pert", False)) @@ -518,6 +581,9 @@ def forward(self, batch: dict, padded=True) -> torch.Tensor: res_pred = transformer_output self._batch_token_cache = None + # Apply a monotone, saturating gate on the residual if enabled + if self.use_hill_prior and logd_norm is not None: + res_pred = res_pred * self.hill_gate(logd_norm) # [B,S,H]×[B,S,H] # Cache token features for auxiliary batch prediction loss (B, S, H) self._token_features = res_pred From 41f38fdd8546f4d8e3e82bd390f81d65c6a43301 Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Thu, 9 Oct 2025 10:04:47 -0700 Subject: [PATCH 33/38] added virtual cells per pert limiter --- src/state/_cli/_tx/_infer.py | 23 +++++++++++++ src/state/configs/model/state.yaml | 2 +- src/state/tx/models/state_transition.py | 46 ++++++++++--------------- 3 files changed, 43 insertions(+), 28 deletions(-) diff --git a/src/state/_cli/_tx/_infer.py b/src/state/_cli/_tx/_infer.py index 82c3545a..dc14415a 100644 --- a/src/state/_cli/_tx/_infer.py +++ b/src/state/_cli/_tx/_infer.py @@ -94,6 +94,12 @@ def add_arguments_infer(parser: argparse.ArgumentParser): action="store_true", help="If set, add virtual copies of control cells for every perturbation in the saved one-hot map so all perturbations are simulated.", ) + parser.add_argument( + "--virtual-cells-per-pert", + type=int, + default=None, + help="When using --all-perts, limit the number of control cells cloned for each virtual perturbation to this many (default: use all available controls).", + ) parser.add_argument( "--min-cells", type=int, @@ -599,6 +605,23 @@ def pad_adata_with_tsv( ctrl_template.obs = ctrl_template.obs.copy() ctrl_template.obs[args.pert_col] = ctrl_template.obs[args.pert_col].astype(str) + if args.virtual_cells_per_pert is not None: + if args.virtual_cells_per_pert <= 0: + raise ValueError("--virtual-cells-per-pert must be a positive integer if provided.") + if ctrl_template.n_obs > args.virtual_cells_per_pert: + virtual_rng = np.random.RandomState(args.seed) + sampled_idx = virtual_rng.choice( + ctrl_template.n_obs, size=args.virtual_cells_per_pert, replace=False + ) + ctrl_template = ctrl_template[sampled_idx].copy() + ctrl_template.obs = ctrl_template.obs.copy() + ctrl_template.obs[args.pert_col] = ctrl_template.obs[args.pert_col].astype(str) + if not args.quiet: + print( + "--all-perts: limiting virtual control template to " + f"{ctrl_template.n_obs} cells per perturbation (requested {args.virtual_cells_per_pert})." + ) + virtual_blocks: List["sc.AnnData"] = [] for pert_name in missing_perts: clone = ctrl_template.copy() diff --git a/src/state/configs/model/state.yaml b/src/state/configs/model/state.yaml index 8b6baf5d..9ce6f52e 100644 --- a/src/state/configs/model/state.yaml +++ b/src/state/configs/model/state.yaml @@ -23,7 +23,7 @@ kwargs: # --- Dose handling --- dosage: False # was False - hill_prior: True # NEW: turn on FiLM + Hill/Emax gate + smoothing + hill_prior: False# NEW: turn on FiLM + Hill/Emax gate + smoothing dose_momentum: 0.01 # NEW: EMA for log10-dose mean/std dose_strength_init: 1.0 # NEW: initial strength of FiLM modulation dose_smooth_weight: 0.01 # NEW: curvature penalty across doses diff --git a/src/state/tx/models/state_transition.py b/src/state/tx/models/state_transition.py index 959b10d5..412cac82 100644 --- a/src/state/tx/models/state_transition.py +++ b/src/state/tx/models/state_transition.py @@ -318,6 +318,8 @@ def __init__( self.confidence_weight = 0.0 self.use_dosage_encoder = bool(kwargs.get("dosage", False)) + self.dosage_encoder = nn.Linear(1, self.hidden_dim) if self.use_dosage_encoder else None + self._warned_missing_dosage = False # Feature flag: pharmacologically-informed dose handling self.use_hill_prior = bool(kwargs.get("hill_prior", False)) if self.use_hill_prior: @@ -340,31 +342,6 @@ def __init__( else: self.dose_smooth_weight = 0.0 - # Dose conditioning - logd_norm = None # keep to apply Hill gate after the transformer - if self.use_dosage_encoder: - dosage_tensor = self._prepare_dosage_tensor(batch, seq_input.device, pert.shape[:2]) - if dosage_tensor is not None: - if self.use_hill_prior: - # log10 and standardize with running stats - logd = torch.log10(dosage_tensor.clamp_min(1e-9)) # [B,S,1] - if self.training: - with torch.no_grad(): - bmean = logd.mean() - bstd = logd.std(unbiased=False).clamp_min(1e-6) - self.dose_mean = (1 - self.dose_momentum) * self.dose_mean + self.dose_momentum * bmean - self.dose_std = (1 - self.dose_momentum) * self.dose_std + self.dose_momentum * bstd - logd_norm = (logd - self.dose_mean) / (self.dose_std + 1e-6) - # FiLM modulation: seq_input <- (1 + α*(γ-1))*x + α*β - film_params = self.dose_film(logd_norm) # [B,S,2H] - gamma, beta = film_params.chunk(2, dim=-1) - gamma = F.softplus(gamma) # positive scale - seq_input = (1 + self.dose_strength * (gamma - 1)) * seq_input + self.dose_strength * beta - else: - # Legacy additive dose features - dosage_features = self.dosage_encoder(torch.log1p(dosage_tensor)) - seq_input = seq_input + dosage_features - # Backward-compat: accept legacy key `freeze_pert` self.freeze_pert_backbone = kwargs.get("freeze_pert_backbone", kwargs.get("freeze_pert", False)) if self.freeze_pert_backbone: @@ -502,11 +479,26 @@ def forward(self, batch: dict, padded=True) -> torch.Tensor: combined_input = pert_embedding + control_cells # Shape: [B, S, hidden_dim] seq_input = combined_input # Shape: [B, S, hidden_dim] + logd_norm: Optional[torch.Tensor] = None if self.use_dosage_encoder: dosage_tensor = self._prepare_dosage_tensor(batch, seq_input.device, pert.shape[:2]) if dosage_tensor is not None: - dosage_features = self.dosage_encoder(torch.log1p(dosage_tensor)) - seq_input = seq_input + dosage_features + if self.use_hill_prior: + logd = torch.log10(dosage_tensor.clamp_min(1e-9)) # [B,S,1] + if self.training: + with torch.no_grad(): + bmean = logd.mean() + bstd = logd.std(unbiased=False).clamp_min(1e-6) + self.dose_mean = (1 - self.dose_momentum) * self.dose_mean + self.dose_momentum * bmean + self.dose_std = (1 - self.dose_momentum) * self.dose_std + self.dose_momentum * bstd + logd_norm = (logd - self.dose_mean) / (self.dose_std + 1e-6) + film_params = self.dose_film(logd_norm) + gamma, beta = film_params.chunk(2, dim=-1) + gamma = F.softplus(gamma) + seq_input = (1 + self.dose_strength * (gamma - 1)) * seq_input + self.dose_strength * beta + elif self.dosage_encoder is not None: + dosage_features = self.dosage_encoder(torch.log1p(dosage_tensor)) + seq_input = seq_input + dosage_features if self.batch_encoder is not None: # Extract batch indices (assume they are integers or convert from one-hot) From 598e2990a7ec4620e587a522757299d8244e3f06 Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Tue, 14 Oct 2025 20:28:54 +0000 Subject: [PATCH 34/38] updating flops callback logic --- src/state/_cli/_tx/_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/state/_cli/_tx/_train.py b/src/state/_cli/_tx/_train.py index 08d0ef9a..385451ea 100644 --- a/src/state/_cli/_tx/_train.py +++ b/src/state/_cli/_tx/_train.py @@ -234,7 +234,7 @@ def run_tx_train(cfg: DictConfig): callbacks.append(mfu_cb) - if "cumulative_flops_use_backward" in cfg["training"]: + if "cumulative_flops_use_backward" in cfg["training"] and cfg["model"]["name"] == "state": cumulative_flops_use_backward = cfg["training"]["cumulative_flops_use_backward"] cumulative_flops_cb = CumulativeFLOPSCallback(use_backward=cumulative_flops_use_backward) callbacks.append(cumulative_flops_cb) From da50a88d6a7d5e51319ed29da94aac03a30fbc1a Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Tue, 14 Oct 2025 20:46:21 +0000 Subject: [PATCH 35/38] added toml option for predict script --- src/state/_cli/_tx/_predict.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/src/state/_cli/_tx/_predict.py b/src/state/_cli/_tx/_predict.py index cc6cabf5..2182b702 100644 --- a/src/state/_cli/_tx/_predict.py +++ b/src/state/_cli/_tx/_predict.py @@ -12,6 +12,12 @@ def add_arguments_predict(parser: ap.ArgumentParser): required=True, help="Path to the output_dir containing the config.yaml file that was saved during training.", ) + parser.add_argument( + "--toml", + type=str, + default=None, + help="Optional path to a TOML data config to use instead of the training config.", + ) parser.add_argument( "--checkpoint", type=str, @@ -130,12 +136,34 @@ def load_config(cfg_path: str) -> dict: cfg = load_config(config_path) logger.info(f"Loaded config from {config_path}") + if args.toml: + data_section = cfg.get("data") + if data_section is None or "kwargs" not in data_section: + raise KeyError( + "The loaded config does not contain data.kwargs, unable to override toml_config_path." + ) + cfg["data"]["kwargs"]["toml_config_path"] = args.toml + logger.info("Overriding data.kwargs.toml_config_path to %s", args.toml) + # 2. Find run output directory & load data module run_output_dir = os.path.join(cfg["output_dir"], cfg["name"]) data_module_path = os.path.join(run_output_dir, "data_module.torch") if not os.path.exists(data_module_path): raise FileNotFoundError(f"Could not find data module at {data_module_path}?") data_module = PerturbationDataModule.load_state(data_module_path) + if args.toml: + if not os.path.exists(args.toml): + raise FileNotFoundError(f"Could not find TOML config file at {args.toml}") + from cell_load.config import ExperimentConfig + + logger.info("Reloading data module configuration from %s", args.toml) + data_module.toml_config_path = args.toml + data_module.config = ExperimentConfig.from_toml(args.toml) + data_module.config.validate() + data_module.train_datasets = [] + data_module.val_datasets = [] + data_module.test_datasets = [] + data_module._setup_global_maps() data_module.setup(stage="test") logger.info("Loaded data module from %s", data_module_path) From f0509b456b5468cb0c1703a4c6f6c1762e157730 Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Thu, 16 Oct 2025 22:43:08 +0000 Subject: [PATCH 36/38] updated yaml to fix typo --- src/state/configs/model/state.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/state/configs/model/state.yaml b/src/state/configs/model/state.yaml index 9ce6f52e..e47217e6 100644 --- a/src/state/configs/model/state.yaml +++ b/src/state/configs/model/state.yaml @@ -22,8 +22,8 @@ kwargs: mask_attn: False # --- Dose handling --- - dosage: False # was False - hill_prior: False# NEW: turn on FiLM + Hill/Emax gate + smoothing + dosage: False + hill_prior: False dose_momentum: 0.01 # NEW: EMA for log10-dose mean/std dose_strength_init: 1.0 # NEW: initial strength of FiLM modulation dose_smooth_weight: 0.01 # NEW: curvature penalty across doses From a8cbbce5b204d342c1a76a5c77fe14e59ef5337f Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Thu, 16 Oct 2025 20:41:08 -0700 Subject: [PATCH 37/38] updated infer to also output predicted counts --- src/state/_cli/_tx/_infer.py | 70 ++++++++++++++++++++++++++++++++++-- 1 file changed, 68 insertions(+), 2 deletions(-) diff --git a/src/state/_cli/_tx/_infer.py b/src/state/_cli/_tx/_infer.py index dc14415a..9927623a 100644 --- a/src/state/_cli/_tx/_infer.py +++ b/src/state/_cli/_tx/_infer.py @@ -797,12 +797,26 @@ def pad_adata_with_tsv( # Where we will write predictions (initialize with originals; we overwrite all rows, including controls) if writes_to[0] == ".X": - sim_X = X_in.copy() + sim_X = X_in.astype(np.float32, copy=True) out_target = "X" else: - sim_obsm = X_in.copy() + sim_obsm = X_in.astype(np.float32, copy=True) out_target = f"obsm['{writes_to[1]}']" + counts_expected = output_space in {"gene", "all"} + counts_out_target: Optional[str] = None + counts_obsm_key: Optional[str] = None + sim_counts: Optional[np.ndarray] = None + counts_written = False + + if output_space == "gene": + counts_out_target = "obsm['X_hvg']" + counts_obsm_key = "X_hvg" + elif output_space == "all": + counts_out_target = "X" + if writes_to[0] == ".X": + sim_counts = sim_X + # Group labels for set-to-set behavior if args.celltype_col and args.celltype_col in adata.obs: group_labels = adata.obs[args.celltype_col].astype(str).values @@ -920,6 +934,43 @@ def group_control_indices(group_name: str) -> np.ndarray: else: preds = batch_out["preds"].detach().cpu().numpy().astype(np.float32) # [win, D] + counts_preds = None + if counts_expected and ("pert_cell_counts_preds" in batch_out): + counts_tensor = batch_out.get("pert_cell_counts_preds") + if counts_tensor is not None: + counts_preds = counts_tensor.detach().cpu().numpy().astype(np.float32) + + if counts_preds is not None: + if sim_counts is None: + target_dim = counts_preds.shape[1] + if output_space == "gene": + if counts_obsm_key and counts_obsm_key in adata.obsm: + existing = np.asarray(adata.obsm[counts_obsm_key]) + if existing.shape[1] == target_dim: + sim_counts = existing.astype(np.float32, copy=True) + else: + if not args.quiet: + print( + f"Dimension mismatch for existing obsm['{counts_obsm_key}'] " + f"(got {existing.shape[1]} vs predictions {target_dim}). " + "Reinitializing storage with zeros." + ) + sim_counts = np.zeros((n_total, target_dim), dtype=np.float32) + else: + sim_counts = np.zeros((n_total, target_dim), dtype=np.float32) + else: # output_space == "all" + if writes_to[0] == ".X": + sim_counts = sim_X + else: + sim_counts = np.zeros((n_total, target_dim), dtype=np.float32) + if sim_counts.shape[1] != counts_preds.shape[1]: + raise ValueError( + "Predicted counts dimension mismatch: " + f"expected {sim_counts.shape[1]} but got {counts_preds.shape[1]}" + ) + sim_counts[idx_window, :] = counts_preds + counts_written = True + # 6) Write predictions for these rows (controls included) if writes_to[0] == ".X": if preds.shape[1] == sim_X.shape[1]: @@ -958,6 +1009,12 @@ def group_control_indices(group_name: str) -> np.ndarray: output_path = args.output or args.adata.replace(".h5ad", "_simulated.h5ad") output_is_npy = output_path.lower().endswith(".npy") + if counts_expected and not counts_written and not args.quiet: + print( + "Warning: Model configured to produce gene counts, but no predicted counts were returned; " + "counts will not be saved." + ) + pred_matrix = None if writes_to[0] == ".X": if out_target == "X": @@ -978,6 +1035,13 @@ def group_control_indices(group_name: str) -> np.ndarray: else: pred_matrix = sim_obsm + if counts_written and sim_counts is not None: + if output_space == "gene": + key = counts_obsm_key or "X_hvg" + adata.obsm[key] = sim_counts + elif output_space == "all": + adata.X = sim_counts + if output_is_npy: if pred_matrix is None: raise ValueError("Predictions matrix is unavailable; cannot write .npy output") @@ -999,3 +1063,5 @@ def group_control_indices(group_name: str) -> np.ndarray: else: print(f"Wrote predictions to adata.{out_target}") print(f"Saved: {output_path}") + if counts_written and counts_out_target: + print(f"Saved count predictions to adata.{counts_out_target}") From ed98abf307ee64b158933ce549eb0cee9a4f9b34 Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Tue, 21 Oct 2025 16:37:34 -0700 Subject: [PATCH 38/38] updated cell eval to 0.6.2 --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 84cdb1d4..94feeaac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,13 +27,14 @@ dependencies = [ "geomloss>=0.2.6", "transformers>=4.52.3", "peft>=0.11.0", - "cell-eval>=0.6.0", + "cell-eval>=0.6.2", "ipykernel>=6.30.1", "scipy>=1.15.0", ] [tool.uv.sources] cell-load = {path = "/home/aadduri/cell-load"} +cell-eval = {git = "https://github.com/ArcInstitute/cell-eval", branch = "aadduri/aupr_curves"} [project.optional-dependencies] vectordb = [