Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
721d37d
transfer with hvg
abhinadduri Sep 14, 2025
0325c7f
updated files to initial version with fine tuning the decoder
abhinadduri Sep 14, 2025
f45e06d
fixed model to use the pretrained decoder but with a for loop
abhinadduri Sep 14, 2025
3f41fc7
working impl for 1024 and 2048 dim state impls
abhinadduri Sep 14, 2025
16fed4e
updated with correct fine tuning, gradient accumulation
abhinadduri Sep 15, 2025
5e0e715
added batch predictor implementation
abhinadduri Sep 15, 2025
f6307af
added pseudobulk model file
abhinadduri Sep 15, 2025
15d95a8
updated to fix pseudobulk model during eval
abhinadduri Sep 16, 2025
f40291d
fixed vci fine tune decoder setup
abhinadduri Sep 16, 2025
5aa2252
bumping semvar
abhinadduri Sep 16, 2025
848859c
fixed using embedding for the output space
abhinadduri Sep 17, 2025
b137ae6
included fix to transform wit new checkpoints
abhinadduri Sep 19, 2025
bd5bd29
distributed sampler commit
abhinadduri Sep 21, 2025
154b555
bump cell load requirement
abhinadduri Sep 21, 2025
43eff46
updated so that .npy output for files will write out just embeddings …
abhinadduri Sep 23, 2025
9819805
removed periodic checkpointing
abhinadduri Sep 26, 2025
e5bf17b
updated checkpoint logic to only store a best.ckpt
abhinadduri Sep 29, 2025
0df3dbc
added split batch option that relies on cell-eval having the split ba…
abhinadduri Sep 29, 2025
19933ed
udpated cell eval to 0.6
abhinadduri Oct 2, 2025
1738e6d
added multi head MMD for optimization
abhinadduri Oct 2, 2025
33becb5
removed batch split for cell eval
abhinadduri Oct 3, 2025
d3fa227
chore: formatting
abhinadduri Oct 3, 2025
116cc8b
added combo perturbation for rpe1
abhinadduri Oct 3, 2025
7e7f022
updated infer to take max and min cells, and all perts argument
abhinadduri Oct 7, 2025
ed00c06
small fix
abhinadduri Oct 8, 2025
8ded763
chore: formatting
abhinadduri Oct 8, 2025
358b1a4
Fix pseudobulk model for wandb tracking
abhinadduri Oct 8, 2025
edddc31
added dosage arg
abhinadduri Oct 8, 2025
6983bfe
added hill prior
abhinadduri Oct 8, 2025
647901a
fix indentation
abhinadduri Oct 8, 2025
132bd00
fix indentation
abhinadduri Oct 8, 2025
3819fe5
updated dosage impl
abhinadduri Oct 8, 2025
41f38fd
added virtual cells per pert limiter
abhinadduri Oct 9, 2025
598e299
updating flops callback logic
abhinadduri Oct 14, 2025
da50a88
added toml option for predict script
abhinadduri Oct 14, 2025
f0509b4
updated yaml to fix typo
abhinadduri Oct 16, 2025
a8cbbce
updated infer to also output predicted counts
abhinadduri Oct 17, 2025
bdb3d02
Merge branch 'aadduri/simulation' of github.com:ArcInstitute/state-se…
abhinadduri Oct 17, 2025
ed98abf
updated cell eval to 0.6.2
abhinadduri Oct 21, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
[project]
name = "arc-state"
version = "0.9.31"
version = "0.9.32"
description = "State is a machine learning model that predicts cellular perturbation response across diverse contexts."
readme = "README.md"
authors = [
{ name = "Abhinav Adduri", email = "abhinav.adduri@arcinstitute.org" },
{ name = "Yusuf Roohani", email = "yusuf.roohani@arcinstitute.org" },
{ name = "Noam Teyssier", email = "noam.teyssier@arcinstitute.org" },
{ name = "Rajesh Ilango" },
{ name = "Rajesh Ilango", email = "rilango@gmail.com" },
{ name = "Dhruv Gautam", email = "dhruvgautam@berkeley.edu" },
]
requires-python = ">=3.10,<3.13"
dependencies = [
"anndata>=0.11.4",
"cell-load>=0.8.3",
"cell-load>=0.8.7",
"numpy>=2.2.6",
"pandas>=2.2.3",
"pyyaml>=6.0.2",
Expand All @@ -27,11 +27,15 @@ dependencies = [
"geomloss>=0.2.6",
"transformers>=4.52.3",
"peft>=0.11.0",
"cell-eval>=0.5.22",
"cell-eval>=0.6.2",
"ipykernel>=6.30.1",
"scipy>=1.15.0",
]

[tool.uv.sources]
cell-load = {path = "/home/aadduri/cell-load"}
cell-eval = {git = "https://github.com/ArcInstitute/cell-eval", branch = "aadduri/aupr_curves"}

[project.optional-dependencies]
vectordb = [
"lancedb>=0.24.0"
Expand Down
50 changes: 0 additions & 50 deletions scripts/state_embed_anndata.py

This file was deleted.

3 changes: 3 additions & 0 deletions src/state/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
run_emb_query,
run_emb_preprocess,
run_emb_eval,
run_tx_combo,
run_tx_infer,
run_tx_predict,
run_tx_preprocess_infer,
Expand Down Expand Up @@ -124,6 +125,8 @@ def main():
case "infer":
# Run inference using argparse, similar to predict
run_tx_infer(args)
case "combo":
run_tx_combo(args)
case "preprocess_train":
# Run preprocessing using argparse
run_tx_preprocess_train(args.adata, args.output, args.num_hvgs)
Expand Down
2 changes: 2 additions & 0 deletions src/state/_cli/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from ._emb import add_arguments_emb, run_emb_fit, run_emb_transform, run_emb_query, run_emb_preprocess, run_emb_eval
from ._tx import (
add_arguments_tx,
run_tx_combo,
run_tx_infer,
run_tx_predict,
run_tx_preprocess_infer,
Expand All @@ -16,6 +17,7 @@
"run_tx_infer",
"run_tx_preprocess_train",
"run_tx_preprocess_infer",
"run_tx_combo",
"run_emb_fit",
"run_emb_query",
"run_emb_transform",
Expand Down
57 changes: 44 additions & 13 deletions src/state/_cli/_emb/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,16 @@

def add_arguments_transform(parser: ap.ArgumentParser):
"""Add arguments for state embedding CLI."""
parser.add_argument("--model-folder", required=True, help="Path to the model checkpoint folder")
parser.add_argument("--checkpoint", required=False, help="Path to the specific model checkpoint")
parser.add_argument(
"--model-folder",
required=False,
help="Path to the model checkpoint folder (required if --checkpoint is not provided)",
)
parser.add_argument(
"--checkpoint",
required=False,
help="Path to the specific model checkpoint (required if --model-folder is not provided)",
)
parser.add_argument(
"--config",
required=False,
Expand Down Expand Up @@ -46,6 +54,7 @@ def run_emb_transform(args: ap.ArgumentParser):
import glob
import logging
import os
import numpy as np

import torch
from omegaconf import OmegaConf
Expand All @@ -60,13 +69,19 @@ def run_emb_transform(args: ap.ArgumentParser):
logger.error("Either --output or --lancedb must be provided")
raise ValueError("Either --output or --lancedb must be provided")

# look in the model folder with glob for *.ckpt, get the first one, and print it
model_files = glob.glob(os.path.join(args.model_folder, "*.ckpt"))
if not model_files:
logger.error(f"No model checkpoint found in {args.model_folder}")
raise FileNotFoundError(f"No model checkpoint found in {args.model_folder}")
if not args.checkpoint:
args.checkpoint = model_files[-1]
# Resolve checkpoint path, allowing either --checkpoint, --model-folder, or both
checkpoint_path = args.checkpoint
if args.model_folder:
model_files = glob.glob(os.path.join(args.model_folder, "*.ckpt"))
if not model_files and not checkpoint_path:
logger.error(f"No model checkpoint found in {args.model_folder}")
raise FileNotFoundError(f"No model checkpoint found in {args.model_folder}")
if not checkpoint_path and model_files:
checkpoint_path = model_files[-1]
if not checkpoint_path:
logger.error("Either --checkpoint or --model-folder must be provided")
raise ValueError("Either --checkpoint or --model-folder must be provided")
args.checkpoint = checkpoint_path
logger.info(f"Using model checkpoint: {args.checkpoint}")

# Create inference object
Expand All @@ -79,7 +94,7 @@ def run_emb_transform(args: ap.ArgumentParser):
if args.protein_embeddings:
logger.info(f"Using protein embeddings override: {args.protein_embeddings}")
protein_embeds = torch.load(args.protein_embeddings, weights_only=False, map_location="cpu")
else:
elif args.model_folder:
# Try auto-detect in model folder
try:
exact_path = os.path.join(args.model_folder, "protein_embeddings.pt")
Expand Down Expand Up @@ -110,6 +125,12 @@ def run_emb_transform(args: ap.ArgumentParser):
logger.info(f"Loading model from checkpoint: {args.checkpoint}")
inferer.load_model(args.checkpoint)

save_as_npy = False
output_target = args.output
if args.output:
_, ext = os.path.splitext(args.output)
save_as_npy = ext.lower() == ".npy"

# Create output directory if it doesn't exist
if args.output:
output_dir = os.path.dirname(args.output)
Expand All @@ -120,18 +141,28 @@ def run_emb_transform(args: ap.ArgumentParser):
# Generate embeddings
logger.info(f"Computing embeddings for {args.input}")
if args.output:
logger.info(f"Output will be saved to {args.output}")
if save_as_npy:
logger.info(f"Output embeddings will be saved to {args.output} as a NumPy array")
else:
logger.info(f"Output will be saved to {args.output}")
if args.lancedb:
logger.info(f"Embeddings will be saved to LanceDB at {args.lancedb}")

inferer.encode_adata(
embeddings = inferer.encode_adata(
input_adata_path=args.input,
output_adata_path=args.output,
output_adata_path=None if save_as_npy else output_target,
emb_key=args.embed_key,
batch_size=args.batch_size if getattr(args, "batch_size", None) is not None else None,
lancedb_path=args.lancedb,
update_lancedb=args.lancedb_update,
lancedb_batch_size=args.lancedb_batch_size,
)

if save_as_npy:
if embeddings is None:
logger.error("Failed to generate embeddings for NumPy output")
raise RuntimeError("Embedding generation returned no data")
np.save(args.output, embeddings)
logger.info(f"Saved embeddings matrix with shape {embeddings.shape} to {args.output}")

logger.info("Embedding computation completed successfully!")
3 changes: 3 additions & 0 deletions src/state/_cli/_tx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
from ._preprocess_infer import add_arguments_preprocess_infer, run_tx_preprocess_infer
from ._preprocess_train import add_arguments_preprocess_train, run_tx_preprocess_train
from ._train import add_arguments_train, run_tx_train
from ._combo import add_arguments_combo, run_tx_combo

__all__ = [
"run_tx_train",
"run_tx_predict",
"run_tx_infer",
"run_tx_preprocess_train",
"run_tx_preprocess_infer",
"run_tx_combo",
"add_arguments_tx",
]

Expand All @@ -24,3 +26,4 @@ def add_arguments_tx(parser: ap.ArgumentParser):
add_arguments_infer(subparsers.add_parser("infer"))
add_arguments_preprocess_train(subparsers.add_parser("preprocess_train"))
add_arguments_preprocess_infer(subparsers.add_parser("preprocess_infer"))
add_arguments_combo(subparsers.add_parser("combo"))
Loading