1 change: 1 addition & 0 deletions README.md
@@ -287,6 +287,7 @@ filtered_adata.write_h5ad("filtered_data.h5ad")
- **`should_yield_control_cells`**: Include control cells in output (default: `true`)
- **`num_workers`**: Number of workers for data loading (default: 8)
- **`batch_size`**: Batch size for training (default: 128)
- **`val_subsample_fraction`**: Fraction of validation subsets to keep (e.g., `0.01` keeps ~1% of `val_datasets`)
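
A minimal, hedged sketch of passing these options to the data module from Python; the exact constructor signature (including the config argument shown as `toml_config_path` below) is an assumption, so check the current `PerturbationDataModule` docstring for the authoritative parameter list:

```python
from cell_load.data_modules.perturbation_dataloader import PerturbationDataModule

# Illustrative only: argument names other than the documented options
# (num_workers, batch_size, should_yield_control_cells, val_subsample_fraction)
# are assumptions, not the confirmed API.
dm = PerturbationDataModule(
    toml_config_path="experiment.toml",  # hypothetical config argument
    num_workers=8,                       # default: 8
    batch_size=128,                      # default: 128
    should_yield_control_cells=True,     # default: true
    val_subsample_fraction=0.01,         # keep ~1% of val_datasets
)
dm.setup()
```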

### Usage

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "cell-load"
version = "0.8.8"
version = "0.10.2"
description = "Dataloaders for training models on huge single-cell datasets"
readme = "README.md"
authors = [
2 changes: 0 additions & 2 deletions src/cell_load/config.py
@@ -70,7 +70,6 @@ def get_fewshot_celltypes(self, dataset: str) -> dict[str, dict[str, list[str]]]
if key.startswith(f"{dataset}."):
celltype = key.split(".", 1)[1]
result[celltype] = pert_config
print(dataset, celltype, {k: len(v) for k, v in pert_config.items()})
return result

def validate(self) -> None:
@@ -84,7 +83,6 @@ def validate(self) -> None:

# Check that dataset paths exist
for dataset, path in self.datasets.items():
print(path)
if not Path(path).exists():
logger.warning(f"Dataset path does not exist: {path}")

179 changes: 161 additions & 18 deletions src/cell_load/data_modules/perturbation_dataloader.py
@@ -1,4 +1,5 @@
import logging
import os
import glob
import re

@@ -10,7 +11,7 @@
import numpy as np
import torch
from lightning.pytorch import LightningDataModule
from torch.utils.data import DataLoader, Dataset
from torch.utils.data import DataLoader, Dataset, get_worker_info
from tqdm import tqdm

from ..config import ExperimentConfig
@@ -26,6 +27,28 @@
logger = logging.getLogger(__name__)


def _worker_init_fn(worker_id: int) -> None:
for var in (
"OMP_NUM_THREADS",
"MKL_NUM_THREADS",
"OPENBLAS_NUM_THREADS",
"NUMEXPR_NUM_THREADS",
):
os.environ.setdefault(var, "1")
try:
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
except RuntimeError:
pass

worker_info = get_worker_info()
if worker_info is None:
return
dataset = worker_info.dataset
if hasattr(dataset, "ensure_h5_open"):
dataset.ensure_h5_open()
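
The hook above is duck-typed: any wrapped dataset that exposes `ensure_h5_open()` gets its HDF5 handle (re)opened inside each worker process, so forked workers never share a handle. A minimal sketch of such a hook, assuming hypothetical attribute names (`h5_path`, `h5_open_kwargs`, `_h5`) rather than the actual PerturbationDataset internals:

```python
import h5py
from torch.utils.data import Dataset


class H5BackedDataset(Dataset):
    """Illustrative only: lazily (re)opens an HDF5 file per worker."""

    def __init__(self, h5_path: str, h5_open_kwargs: dict | None = None):
        self.h5_path = h5_path
        self.h5_open_kwargs = h5_open_kwargs or {}
        self._h5 = None  # opened lazily so child processes get their own handle

    def ensure_h5_open(self) -> None:
        # Called from _worker_init_fn; each worker ends up with its own handle.
        if self._h5 is None:
            self._h5 = h5py.File(self.h5_path, "r", **self.h5_open_kwargs)
```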


class PerturbationDataModule(LightningDataModule):
"""
A unified data module that sets up train/val/test splits for multiple dataset/celltype
@@ -46,14 +69,18 @@ def __init__(
embed_key: Literal["X_hvg", "X_state"] | None = None,
output_space: Literal["gene", "all", "embedding"] = "gene",
downsample: float | None = None,
is_log1p: bool = False,
downsample_cells: int | None = None,
is_log1p: bool = True,
basal_mapping_strategy: Literal["batch", "random"] = "random",
n_basal_samples: int = 1,
should_yield_control_cells: bool = True,
cell_sentence_len: int = 512,
cache_perturbation_control_pairs: bool = False,
drop_last: bool = False,
val_subsample_fraction: float | None = None,
additional_obs: list[str] | None = None,
use_consecutive_loading: bool = False,
h5_open_kwargs: dict | None = None,
**kwargs,  # accepts perturbation_features_file and store_raw_basal for backwards compatibility
):
"""
@@ -69,12 +96,18 @@ def __init__(
random_seed: For reproducible splits & sampling
embed_key: Embedding key or matrix in the H5 file to use for featurizing cells
output_space: The output space for model predictions (gene, all genes, or embedding-only)
downsample: Fraction of counts to retain via binomial downsampling (only for output_space="all")
is_log1p: Whether raw counts in X are log1p-transformed (auto-set if uns/log1p is present)
downsample: If <=1, fraction of counts to retain via binomial downsampling; if >1, target
read depth per cell (only for output_space="all")
downsample_cells: Max cells per (cell_type, perturbation[, batch]) group; groups with
fewer cells are left unchanged
is_log1p: Whether raw counts in X are log1p-transformed (default True; auto-set if uns/log1p is present)
basal_mapping_strategy: One of {"batch","random","nearest","ot"}
n_basal_samples: Number of control cells to sample per perturbed cell
cache_perturbation_control_pairs: If True, cache perturbation-control pairs at the start of training and reuse them.
drop_last: Whether to drop the last sentence set if it is smaller than cell_sentence_len
val_subsample_fraction: Fraction of validation subsets to keep (subsamples self.val_datasets)
use_consecutive_loading: Whether to form cell sets from consecutive indices for faster IO
h5_open_kwargs: Optional kwargs to pass to h5py.File (e.g., rdcc_nbytes)
"""
super().__init__()

@@ -89,6 +122,21 @@ def __init__(
self.random_seed = random_seed
self.rng = np.random.default_rng(random_seed)
self.drop_last = drop_last
self.use_consecutive_loading = use_consecutive_loading
if val_subsample_fraction is None:
self.val_subsample_fraction = None
else:
try:
val_subsample_fraction = float(val_subsample_fraction)
except (TypeError, ValueError) as exc:
raise ValueError(
"val_subsample_fraction must be a float in (0, 1]."
) from exc
if not (0.0 < val_subsample_fraction <= 1.0):
raise ValueError(
f"val_subsample_fraction must be in (0, 1]; got {val_subsample_fraction!r}"
)
self.val_subsample_fraction = val_subsample_fraction

# H5 field names
self.pert_col = pert_col
@@ -102,7 +150,24 @@
f"output_space must be one of 'gene', 'all', or 'embedding'; got {self.output_space!r}"
)
self.downsample = downsample
self.is_log1p = is_log1p
if downsample_cells is None:
self.downsample_cells = None
else:
if isinstance(downsample_cells, bool):
raise ValueError("downsample_cells must be a positive integer or None.")
if isinstance(downsample_cells, float):
if not downsample_cells.is_integer():
raise ValueError(
"downsample_cells must be a positive integer or None."
)
downsample_cells = int(downsample_cells)
elif not isinstance(downsample_cells, (int, np.integer)):
raise ValueError("downsample_cells must be a positive integer or None.")
downsample_cells = int(downsample_cells)
if downsample_cells <= 0:
raise ValueError("downsample_cells must be a positive integer or None.")
self.downsample_cells = downsample_cells
self.is_log1p = bool(is_log1p)

# Sampling and mapping
self.n_basal_samples = n_basal_samples
@@ -113,11 +178,14 @@
# Optional behaviors
self.map_controls = kwargs.get("map_controls", True)
self.perturbation_features_file = kwargs.get("perturbation_features_file")
self.int_counts = kwargs.get("int_counts", False)
self.exp_counts = kwargs.get("exp_counts", False)
self.normalize_counts = kwargs.get("normalize_counts", False)
self.store_raw_basal = kwargs.get("store_raw_basal", False)
self.barcode = kwargs.get("barcode", False)
self.additional_obs = additional_obs
self.h5_open_kwargs = h5_open_kwargs
if self.use_consecutive_loading:
self._set_h5_cache_env_defaults()

logger.info(
f"Initializing DataModule: batch_size={batch_size}, workers={num_workers}, "
@@ -160,6 +228,12 @@ def _get_reference_dataset(self) -> PerturbationDataset:
return datasets[0].dataset
raise ValueError("No datasets available to extract metadata.")

@staticmethod
def _set_h5_cache_env_defaults() -> None:
os.environ.setdefault("CELL_LOAD_H5_RDCC_NBYTES", str(64 * 1024 * 1024))
os.environ.setdefault("CELL_LOAD_H5_RDCC_NSLOTS", "1000003")
os.environ.setdefault("CELL_LOAD_H5_RDCC_W0", "0.75")

def get_var_names(self):
"""
Get the variable names (gene names) from the first available dataset.
@@ -174,6 +248,7 @@ def setup(self, stage: str | None = None):
"""
if len(self.train_datasets) == 0:
self._setup_datasets()
self._apply_val_subsample()
logger.info(
"Done! Train / Val / Test splits: %d / %d / %d",
len(self.train_datasets),
@@ -201,6 +276,7 @@ def save_state(self, filepath: str):
"embed_key": self.embed_key,
"output_space": self.output_space,
"downsample": self.downsample,
"downsample_cells": self.downsample_cells,
"is_log1p": self.is_log1p,
"basal_mapping_strategy": self.basal_mapping_strategy,
"n_basal_samples": self.n_basal_samples,
@@ -210,11 +286,14 @@
# Include the optional behaviors
"map_controls": self.map_controls,
"perturbation_features_file": self.perturbation_features_file,
"int_counts": self.int_counts,
"exp_counts": self.exp_counts,
"normalize_counts": self.normalize_counts,
"store_raw_basal": self.store_raw_basal,
"barcode": self.barcode,
"val_subsample_fraction": self.val_subsample_fraction,
"additional_obs": self.additional_obs,
"use_consecutive_loading": self.use_consecutive_loading,
"h5_open_kwargs": self.h5_open_kwargs,
}

torch.save(save_dict, filepath)
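
A short usage sketch of the persistence round trip (the file name is arbitrary, and restoring in a fresh process followed by `setup()` is an assumption about intended usage):

```python
# Persist the fully configured data module, then rebuild it elsewhere.
dm.save_state("datamodule_state.pt")

restored = PerturbationDataModule.load_state("datamodule_state.pt")
restored.setup()
train_loader = restored.train_dataloader()  # standard Lightning hook, assumed implemented
```
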
@@ -253,10 +332,12 @@ def load_state(cls, filepath: str):
"perturbation_features_file": save_dict.pop(
"perturbation_features_file", None
),
"int_counts": save_dict.pop("int_counts", False),
"exp_counts": save_dict.pop("exp_counts", False),
"normalize_counts": save_dict.pop("normalize_counts", False),
"store_raw_basal": save_dict.pop("store_raw_basal", False),
"barcode": save_dict.pop("barcode", True),
"use_consecutive_loading": save_dict.pop("use_consecutive_loading", False),
"h5_open_kwargs": save_dict.pop("h5_open_kwargs", None),
}

# Create new instance with all the saved parameters
@@ -379,8 +460,8 @@ def _create_dataloader(
batch_size: int | None = None,
):
"""Create a DataLoader with appropriate configuration."""
use_int_counts = "int_counts" in self.__dict__ and self.int_counts
collate_fn = partial(PerturbationDataset.collate_fn, int_counts=use_int_counts)
use_exp_counts = "exp_counts" in self.__dict__ and self.exp_counts
collate_fn = partial(PerturbationDataset.collate_fn, exp_counts=use_exp_counts)

ds = MetadataConcatDataset(datasets)
use_batch = self.basal_mapping_strategy == "batch"
@@ -394,6 +475,8 @@
cell_sentence_len=self.cell_sentence_len,
test=test,
use_batch=use_batch,
use_consecutive_loading=self.use_consecutive_loading,
downsample_cells=self.downsample_cells,
)

return DataLoader(
@@ -403,6 +486,8 @@
collate_fn=collate_fn,
pin_memory=True,
prefetch_factor=4 if not test and self.num_workers > 0 else None,
persistent_workers=bool(self.num_workers > 0 and not test),
worker_init_fn=_worker_init_fn if self.num_workers > 0 else None,
)

def _setup_global_maps(self):
@@ -421,7 +506,6 @@ def _setup_global_maps(self):
for dataset_name in self.config.get_all_datasets():
dataset_path = Path(self.config.datasets[dataset_name])
files = self._find_dataset_files(dataset_path)

for _fname, fpath in files.items():
with h5py.File(fpath, "r") as f:
uns = f.get("uns")
@@ -465,6 +549,19 @@ def _setup_global_maps(self):
"Detected uns/log1p in at least one dataset; setting is_log1p=True."
)

if self.is_log1p:
if seen_log1p:
logger.warning(
"is_log1p mode is ENABLED. Detected uns/log1p metadata (example: %s).",
log1p_example,
)
else:
logger.warning(
"is_log1p mode is ENABLED by configuration/default, but no uns/log1p metadata was detected."
)
else:
logger.warning("is_log1p mode is DISABLED.")

# Create one-hot maps
if self.perturbation_features_file:
# Load the custom featurizations from a torch file
@@ -502,6 +599,7 @@ def _create_base_dataset(
mapping_kwargs["cache_perturbation_control_pairs"] = (
self.cache_perturbation_control_pairs
)
mapping_kwargs["use_consecutive_loading"] = self.use_consecutive_loading

return PerturbationDataset(
name=dataset_name,
@@ -528,6 +626,8 @@
additional_obs=self.additional_obs,
downsample=self.downsample,
is_log1p=self.is_log1p,
cell_sentence_len=self.cell_sentence_len,
h5_open_kwargs=self.h5_open_kwargs,
)

def _setup_datasets(self):
@@ -536,9 +636,26 @@
Uses H5MetadataCache for faster metadata access.
"""

for dataset_name in self.config.get_all_datasets():
dataset_names = list(self.config.get_all_datasets())
dataset_files: dict[str, dict[str, Path]] = {}
total_files = 0
for dataset_name in dataset_names:
dataset_path = Path(self.config.datasets[dataset_name])
files = self._find_dataset_files(dataset_path)
dataset_files[dataset_name] = files
total_files += len(files)

pbar = (
tqdm(total=total_files, desc="Processing datasets", leave=False)
if total_files > 0
else None
)

for dataset_name in dataset_names:
files = dataset_files[dataset_name]

if pbar is not None:
pbar.set_description(f"Processing {dataset_name}")

# Get configuration for this dataset
zeroshot_celltypes = self.config.get_zeroshot_celltypes(dataset_name)
@@ -551,9 +668,7 @@
logger.info(f" - Fewshot cell types: {list(fewshot_celltypes.keys())}")

# Process each file in the dataset
for fname, fpath in tqdm(
list(files.items()), desc=f"Processing {dataset_name}"
):
for fname, fpath in files.items():
# Create metadata cache
cache = GlobalH5MetadataCache().get_cache(
str(fpath),
@@ -600,12 +715,40 @@
val_sum += counts["val"]
test_sum += counts["test"]

tqdm.write(
f"Processed {fname}: {train_sum} train, {val_sum} val, {test_sum} test"
)
if pbar is not None:
pbar.update(1)
pbar.set_postfix_str(
f"{train_sum} train, {val_sum} val, {test_sum} test"
)

logger.info("\n")

if pbar is not None:
pbar.close()

def _apply_val_subsample(self) -> None:
"""Subsample validation datasets based on val_subsample_fraction."""
if self.val_subsample_fraction is None or len(self.val_datasets) == 0:
return

total = len(self.val_datasets)
keep = max(1, int(np.ceil(total * self.val_subsample_fraction)))
if keep >= total:
return

# Deterministic ordering for reproducibility.
ordered = sorted(
self.val_datasets,
key=lambda subset: (subset.dataset.name, str(subset.dataset.h5_path)),
)
self.val_datasets = ordered[:keep]
logger.info(
"Subsampled validation datasets: kept %d/%d (val_subsample_fraction=%.4f).",
keep,
total,
self.val_subsample_fraction,
)
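
For intuition on the keep count above: the module retains max(1, ceil(total * fraction)) validation subsets, so even a very small fraction never empties the validation split. A quick check of that arithmetic:

```python
import numpy as np


def kept(total: int, fraction: float) -> int:
    # Mirrors the keep computation in _apply_val_subsample.
    return max(1, int(np.ceil(total * fraction)))


assert kept(250, 0.01) == 3   # ceil(2.5) -> 3 subsets kept
assert kept(40, 0.01) == 1    # at least one subset is always kept
assert kept(40, 1.0) == 40    # fraction of 1.0 keeps everything
```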

def _split_fewshot_celltype(
self,
ds: PerturbationDataset,