Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions chemogenetic.toml

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "cell-load"
version = "0.10.3"
version = "0.11.0"
description = "Dataloaders for training models on huge single-cell datasets"
readme = "README.md"
authors = [
Expand Down
58 changes: 39 additions & 19 deletions src/cell_load/_cli/filter_on_target_knockdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@
"""CLI for filter_on_target_knockdown function."""

import argparse
import logging
import sys
from pathlib import Path

import anndata
import scanpy as sc

logger = logging.getLogger(__name__)


def preprocess_state_paper(adata_pp: anndata.AnnData) -> anndata.AnnData:
"""
Expand All @@ -21,14 +24,14 @@ def preprocess_state_paper(adata_pp: anndata.AnnData) -> anndata.AnnData:
Returns:
Preprocessed AnnData object
"""
print("Applying state paper preprocessing...")
logger.info("Applying state paper preprocessing.")

# 1. Normalize to 10k read depth
print(" - Normalizing to 10k read depth...")
logger.info("Normalizing to 10k read depth.")
sc.pp.normalize_total(adata_pp, target_sum=1e4)

# 2. Log transform
print(" - Log transforming...")
logger.info("Applying log1p transform.")
sc.pp.log1p(adata_pp)

return adata_pp
Expand Down Expand Up @@ -105,29 +108,44 @@ def main():
help="Apply preprocessing as in state paper: normalize to 10k read depth "
"and log transform before writing output",
)
parser.add_argument(
"--log-level",
type=str,
default="WARNING",
choices=["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"],
help="Python logging level (default: WARNING)",
)

args = parser.parse_args()
logging.basicConfig(
level=getattr(logging, args.log_level.upper()),
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)

# Validate input file exists
input_path = Path(args.input)
if not input_path.exists():
print(f"Error: Input file {args.input} does not exist", file=sys.stderr)
logger.error("Input file %s does not exist.", args.input)
sys.exit(1)

# Load data
try:
print(f"Loading data from {args.input}...")
logger.info("Loading data from %s.", args.input)
adata = anndata.read_h5ad(args.input)
print(f"Loaded AnnData with {adata.n_obs} cells and {adata.n_vars} genes")
logger.info(
"Loaded AnnData with %d cells and %d genes.",
adata.n_obs,
adata.n_vars,
)
except Exception as e:
print(f"Error loading data: {e}", file=sys.stderr)
logger.exception("Error loading data: %s", e)
sys.exit(1)

# Import and apply filter
try:
from ..utils.data_utils import filter_on_target_knockdown

print("Applying on-target knockdown filter...")
logger.info("Applying on-target knockdown filter.")
filtered_adata = filter_on_target_knockdown(
adata,
perturbation_column=args.perturbation_column,
Expand All @@ -139,20 +157,22 @@ def main():
var_gene_name=args.var_gene_name,
)

print(
f"Filtered to {filtered_adata.n_obs} cells and {filtered_adata.n_vars} genes"
logger.info(
"Filtered to %d cells and %d genes.",
filtered_adata.n_obs,
filtered_adata.n_vars,
)

except Exception as e:
print(f"Error applying filter: {e}", file=sys.stderr)
logger.exception("Error applying filter: %s", e)
sys.exit(1)

# Apply preprocessing if requested
if args.preprocess:
try:
filtered_adata = preprocess_state_paper(filtered_adata)
except Exception as e:
print(f"Error during preprocessing: {e}", file=sys.stderr)
logger.exception("Error during preprocessing: %s", e)
sys.exit(1)

# Save output
Expand All @@ -167,19 +187,19 @@ def main():
column_values = filtered_adata.var[filtered_adata.var.index.name].values
if not all(index_values == column_values):
# Rename the index to avoid conflict
print(
" - Fixing var index name conflict: "
f"{filtered_adata.var.index.name} -> "
f"{filtered_adata.var.index.name}_index"
logger.info(
"Fixing var index name conflict: %s -> %s",
filtered_adata.var.index.name,
f"{filtered_adata.var.index.name}_index",
)
filtered_adata.var.index.name = f"{filtered_adata.var.index.name}_index"

print(f"Saving filtered data to {args.output}...")
logger.info("Saving filtered data to %s.", args.output)
filtered_adata.write_h5ad(args.output)
print("Done!")
logger.info("Done.")

except Exception as e:
print(f"Error saving output: {e}", file=sys.stderr)
logger.exception("Error saving output: %s", e)
sys.exit(1)


Expand Down
34 changes: 29 additions & 5 deletions src/cell_load/data_modules/perturbation_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import glob
import re
import sys

from functools import partial
from pathlib import Path
Expand All @@ -17,6 +18,8 @@
from ..config import ExperimentConfig
from ..dataset import MetadataConcatDataset, PerturbationDataset
from ..mapping_strategies import BatchMappingStrategy, RandomMappingStrategy

_OUTPUT_SPACE_ALIASES: dict[str, str] = {"hvg": "gene", "transcriptome": "all"}
from ..utils.data_utils import (
GlobalH5MetadataCache,
generate_onehot_map,
Expand Down Expand Up @@ -61,15 +64,19 @@ def __init__(
toml_config_path: str,
batch_size: int = 128,
num_workers: int = 8,
pin_memory: bool = False,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The default for pin_memory has been changed from True (hardcoded in the previous version) to False. Memory pinning is generally recommended when training on GPUs as it speeds up data transfer from CPU to GPU. If the primary use case is GPU training, consider keeping the default as True.

random_seed: int = 42, # this should be removed by seed everything
pert_col: str = "gene",
batch_col: str = "gem_group",
cell_type_key: str = "cell_type",
control_pert: str = "non-targeting",
embed_key: Literal["X_hvg", "X_state"] | None = None,
output_space: Literal["gene", "all", "embedding"] = "gene",
output_space: Literal[
"gene", "all", "embedding", "hvg", "transcriptome"
] = "gene",
downsample: float | None = None,
downsample_cells: int | None = None,
balance_outliers: bool = False,
is_log1p: bool = True,
basal_mapping_strategy: Literal["batch", "random"] = "random",
n_basal_samples: int = 1,
Expand All @@ -81,6 +88,8 @@ def __init__(
additional_obs: list[str] | None = None,
use_consecutive_loading: bool = False,
h5_open_kwargs: dict | None = None,
show_progress: bool = True,
collate_dtype: str = "float16",
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

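The default value for collate_dtype is set to "float16". This is a significant change from the previous implicit default of float32 (via torch.FloatTensor). While this reduces memory usage, it may lead to precision issues or compatibility problems with models expecting float32 inputs. Note that load_state also falls back to "float16" when a saved state has no collate_dtype entry, so states saved before this change would be restored with a different dtype than the one they were created with. Consider whether "float32" would be a safer default for a general-purpose dataloader, or ensure this change is clearly documented.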

**kwargs, # missing perturbation_features_file and store_raw_basal for backwards compatibility
):
"""
Expand Down Expand Up @@ -109,6 +118,7 @@ def __init__(
val_subsample_fraction: Fraction of validation subsets to keep (subsamples self.val_datasets)
use_consecutive_loading: Whether to form cell sets from consecutive indices for faster IO
h5_open_kwargs: Optional kwargs to pass to h5py.File (e.g., rdcc_nbytes)
show_progress: Whether to display tqdm progress during dataset setup
"""
super().__init__()

Expand All @@ -120,6 +130,7 @@ def __init__(
# Experiment level params
self.batch_size = batch_size
self.num_workers = num_workers
self.pin_memory = bool(pin_memory)
self.random_seed = random_seed
self.rng = np.random.default_rng(random_seed)
self.drop_last = drop_last
Expand All @@ -145,7 +156,7 @@ def __init__(
self.cell_type_key = cell_type_key
self.control_pert = control_pert
self.embed_key = embed_key
self.output_space = output_space
self.output_space = _OUTPUT_SPACE_ALIASES.get(output_space, output_space)
if self.output_space not in {"gene", "all", "embedding"}:
raise ValueError(
f"output_space must be one of 'gene', 'all', or 'embedding'; got {self.output_space!r}"
Expand All @@ -168,6 +179,7 @@ def __init__(
if downsample_cells <= 0:
raise ValueError("downsample_cells must be a positive integer or None.")
self.downsample_cells = downsample_cells
self.balance_outliers = bool(balance_outliers)
self.is_log1p = bool(is_log1p)

# Sampling and mapping
Expand All @@ -185,6 +197,8 @@ def __init__(
self.barcode = kwargs.get("barcode", False)
self.additional_obs = additional_obs
self.h5_open_kwargs = h5_open_kwargs
self.show_progress = bool(show_progress)
self.collate_dtype = collate_dtype
if self.use_consecutive_loading:
self._set_h5_cache_env_defaults()

Expand Down Expand Up @@ -295,6 +309,7 @@ def save_state(self, filepath: str):
"additional_obs": self.additional_obs,
"use_consecutive_loading": self.use_consecutive_loading,
"h5_open_kwargs": self.h5_open_kwargs,
"collate_dtype": self.collate_dtype,
}

torch.save(save_dict, filepath)
Expand Down Expand Up @@ -339,6 +354,7 @@ def load_state(cls, filepath: str):
"barcode": save_dict.pop("barcode", True),
"use_consecutive_loading": save_dict.pop("use_consecutive_loading", False),
"h5_open_kwargs": save_dict.pop("h5_open_kwargs", None),
"collate_dtype": save_dict.pop("collate_dtype", "float16"),
}

# Create new instance with all the saved parameters
Expand Down Expand Up @@ -469,6 +485,7 @@ def _create_dataloader(

batch_size = batch_size or (1 if test else self.batch_size)

is_training = datasets is self.train_datasets
sampler = PerturbationBatchSampler(
dataset=ds,
batch_size=batch_size,
Expand All @@ -478,14 +495,15 @@ def _create_dataloader(
use_batch=use_batch,
use_consecutive_loading=self.use_consecutive_loading,
downsample_cells=self.downsample_cells,
balance_outliers=self.balance_outliers if is_training else False,
)

return DataLoader(
ds,
batch_sampler=sampler,
num_workers=self.num_workers,
collate_fn=collate_fn,
pin_memory=True,
pin_memory=getattr(self, "pin_memory", False),
prefetch_factor=4 if not test and self.num_workers > 0 else None,
persistent_workers=bool(self.num_workers > 0 and not test),
worker_init_fn=_worker_init_fn if self.num_workers > 0 else None,
Expand Down Expand Up @@ -629,6 +647,7 @@ def _create_base_dataset(
is_log1p=self.is_log1p,
cell_sentence_len=self.cell_sentence_len,
h5_open_kwargs=self.h5_open_kwargs,
collate_dtype=self.collate_dtype,
)

def _setup_datasets(self):
Expand All @@ -647,8 +666,13 @@ def _setup_datasets(self):
total_files += len(files)

pbar = (
tqdm(total=total_files, desc="Processing datasets", leave=False)
if total_files > 0
tqdm(
total=total_files,
desc="Processing datasets",
leave=False,
file=sys.stderr,
)
if (self.show_progress and total_files > 0)
else None
)

Expand Down
Loading
Loading