diff --git a/README.md b/README.md
index 67778ac..115a402 100644
--- a/README.md
+++ b/README.md
@@ -287,6 +287,7 @@ filtered_adata.write_h5ad("filtered_data.h5ad")
 - **`should_yield_control_cells`**: Include control cells in output (default: `true`)
 - **`num_workers`**: Number of workers for data loading (default: 8)
 - **`batch_size`**: Batch size for training (default: 128)
+- **`val_subsample_fraction`**: Fraction of validation subsets to keep (e.g., `0.01` keeps ~1% of `val_datasets`)
 
 ### Usage
 
diff --git a/pyproject.toml b/pyproject.toml
index 93cd94c..5317f42 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "cell-load"
-version = "0.8.8"
+version = "0.10.2"
 description = "Dataloaders for training models on huge single-cell datasets"
 readme = "README.md"
 authors = [
diff --git a/src/cell_load/config.py b/src/cell_load/config.py
index 6dad97a..ddc430e 100644
--- a/src/cell_load/config.py
+++ b/src/cell_load/config.py
@@ -70,7 +70,6 @@ def get_fewshot_celltypes(self, dataset: str) -> dict[str, dict[str, list[str]]]
             if key.startswith(f"{dataset}."):
                 celltype = key.split(".", 1)[1]
                 result[celltype] = pert_config
-                print(dataset, celltype, {k: len(v) for k, v in pert_config.items()})
         return result
 
     def validate(self) -> None:
@@ -84,7 +83,6 @@ def validate(self) -> None:
 
         # Check that dataset paths exist
         for dataset, path in self.datasets.items():
-            print(path)
             if not Path(path).exists():
                 logger.warning(f"Dataset path does not exist: {path}")
 
diff --git a/src/cell_load/data_modules/perturbation_dataloader.py b/src/cell_load/data_modules/perturbation_dataloader.py
index c787272..e6934b9 100644
--- a/src/cell_load/data_modules/perturbation_dataloader.py
+++ b/src/cell_load/data_modules/perturbation_dataloader.py
@@ -1,4 +1,5 @@
 import logging
+import os
 import glob
 import re
 
@@ -10,7 +11,7 @@
 import numpy as np
 import torch
 from lightning.pytorch import LightningDataModule
-from torch.utils.data import DataLoader, Dataset
+from torch.utils.data import DataLoader, Dataset, get_worker_info
 from tqdm import tqdm
 
 from ..config import ExperimentConfig
@@ -26,6 +27,28 @@
 logger = logging.getLogger(__name__)
 
 
+def _worker_init_fn(worker_id: int) -> None:
+    for var in (
+        "OMP_NUM_THREADS",
+        "MKL_NUM_THREADS",
+        "OPENBLAS_NUM_THREADS",
+        "NUMEXPR_NUM_THREADS",
+    ):
+        os.environ.setdefault(var, "1")
+    try:
+        torch.set_num_threads(1)
+        torch.set_num_interop_threads(1)
+    except RuntimeError:
+        pass
+
+    worker_info = get_worker_info()
+    if worker_info is None:
+        return
+    dataset = worker_info.dataset
+    if hasattr(dataset, "ensure_h5_open"):
+        dataset.ensure_h5_open()
+
+
 class PerturbationDataModule(LightningDataModule):
     """
     A unified data module that sets up train/val/test splits for multiple dataset/celltype
@@ -46,14 +69,18 @@ def __init__(
         embed_key: Literal["X_hvg", "X_state"] | None = None,
         output_space: Literal["gene", "all", "embedding"] = "gene",
         downsample: float | None = None,
-        is_log1p: bool = False,
+        downsample_cells: int | None = None,
+        is_log1p: bool = True,
         basal_mapping_strategy: Literal["batch", "random"] = "random",
         n_basal_samples: int = 1,
         should_yield_control_cells: bool = True,
         cell_sentence_len: int = 512,
         cache_perturbation_control_pairs: bool = False,
         drop_last: bool = False,
+        val_subsample_fraction: float | None = None,
         additional_obs: list[str] | None = None,
+        use_consecutive_loading: bool = False,
+        h5_open_kwargs: dict | None = None,
         **kwargs,  # missing perturbation_features_file and store_raw_basal for backwards compatibility
     ):
         """
@@ 
-69,12 +96,18 @@ def __init__( random_seed: For reproducible splits & sampling embed_key: Embedding key or matrix in the H5 file to use for feauturizing cells output_space: The output space for model predictions (gene, all genes, or embedding-only) - downsample: Fraction of counts to retain via binomial downsampling (only for output_space="all") - is_log1p: Whether raw counts in X are log1p-transformed (auto-set if uns/log1p is present) + downsample: If <=1, fraction of counts to retain via binomial downsampling; if >1, target + read depth per cell (only for output_space="all") + downsample_cells: Max cells per (cell_type, perturbation[, batch]) group; if a group has + fewer cells it is unchanged + is_log1p: Whether raw counts in X are log1p-transformed (default True; auto-set if uns/log1p is present) basal_mapping_strategy: One of {"batch","random","nearest","ot"} n_basal_samples: Number of control cells to sample per perturbed cell cache_perturbation_control_pairs: If True cache perturbation-control pairs at the start of training and reuse them. drop_last: Whether to drop the last sentence set if it is smaller than cell_sentence_len + val_subsample_fraction: Fraction of validation subsets to keep (subsamples self.val_datasets) + use_consecutive_loading: Whether to form cell sets from consecutive indices for faster IO + h5_open_kwargs: Optional kwargs to pass to h5py.File (e.g., rdcc_nbytes) """ super().__init__() @@ -89,6 +122,21 @@ def __init__( self.random_seed = random_seed self.rng = np.random.default_rng(random_seed) self.drop_last = drop_last + self.use_consecutive_loading = use_consecutive_loading + if val_subsample_fraction is None: + self.val_subsample_fraction = None + else: + try: + val_subsample_fraction = float(val_subsample_fraction) + except (TypeError, ValueError) as exc: + raise ValueError( + "val_subsample_fraction must be a float in (0, 1]." + ) from exc + if not (0.0 < val_subsample_fraction <= 1.0): + raise ValueError( + f"val_subsample_fraction must be in (0, 1]; got {val_subsample_fraction!r}" + ) + self.val_subsample_fraction = val_subsample_fraction # H5 field names self.pert_col = pert_col @@ -102,7 +150,24 @@ def __init__( f"output_space must be one of 'gene', 'all', or 'embedding'; got {self.output_space!r}" ) self.downsample = downsample - self.is_log1p = is_log1p + if downsample_cells is None: + self.downsample_cells = None + else: + if isinstance(downsample_cells, bool): + raise ValueError("downsample_cells must be a positive integer or None.") + if isinstance(downsample_cells, float): + if not downsample_cells.is_integer(): + raise ValueError( + "downsample_cells must be a positive integer or None." 
+ ) + downsample_cells = int(downsample_cells) + elif not isinstance(downsample_cells, (int, np.integer)): + raise ValueError("downsample_cells must be a positive integer or None.") + downsample_cells = int(downsample_cells) + if downsample_cells <= 0: + raise ValueError("downsample_cells must be a positive integer or None.") + self.downsample_cells = downsample_cells + self.is_log1p = bool(is_log1p) # Sampling and mapping self.n_basal_samples = n_basal_samples @@ -113,11 +178,14 @@ def __init__( # Optional behaviors self.map_controls = kwargs.get("map_controls", True) self.perturbation_features_file = kwargs.get("perturbation_features_file") - self.int_counts = kwargs.get("int_counts", False) + self.exp_counts = kwargs.get("exp_counts", False) self.normalize_counts = kwargs.get("normalize_counts", False) self.store_raw_basal = kwargs.get("store_raw_basal", False) self.barcode = kwargs.get("barcode", False) self.additional_obs = additional_obs + self.h5_open_kwargs = h5_open_kwargs + if self.use_consecutive_loading: + self._set_h5_cache_env_defaults() logger.info( f"Initializing DataModule: batch_size={batch_size}, workers={num_workers}, " @@ -160,6 +228,12 @@ def _get_reference_dataset(self) -> PerturbationDataset: return datasets[0].dataset raise ValueError("No datasets available to extract metadata.") + @staticmethod + def _set_h5_cache_env_defaults() -> None: + os.environ.setdefault("CELL_LOAD_H5_RDCC_NBYTES", str(64 * 1024 * 1024)) + os.environ.setdefault("CELL_LOAD_H5_RDCC_NSLOTS", "1000003") + os.environ.setdefault("CELL_LOAD_H5_RDCC_W0", "0.75") + def get_var_names(self): """ Get the variable names (gene names) from the first available dataset. @@ -174,6 +248,7 @@ def setup(self, stage: str | None = None): """ if len(self.train_datasets) == 0: self._setup_datasets() + self._apply_val_subsample() logger.info( "Done! 
Train / Val / Test splits: %d / %d / %d", len(self.train_datasets), @@ -201,6 +276,7 @@ def save_state(self, filepath: str): "embed_key": self.embed_key, "output_space": self.output_space, "downsample": self.downsample, + "downsample_cells": self.downsample_cells, "is_log1p": self.is_log1p, "basal_mapping_strategy": self.basal_mapping_strategy, "n_basal_samples": self.n_basal_samples, @@ -210,11 +286,14 @@ def save_state(self, filepath: str): # Include the optional behaviors "map_controls": self.map_controls, "perturbation_features_file": self.perturbation_features_file, - "int_counts": self.int_counts, + "exp_counts": self.exp_counts, "normalize_counts": self.normalize_counts, "store_raw_basal": self.store_raw_basal, "barcode": self.barcode, + "val_subsample_fraction": self.val_subsample_fraction, "additional_obs": self.additional_obs, + "use_consecutive_loading": self.use_consecutive_loading, + "h5_open_kwargs": self.h5_open_kwargs, } torch.save(save_dict, filepath) @@ -253,10 +332,12 @@ def load_state(cls, filepath: str): "perturbation_features_file": save_dict.pop( "perturbation_features_file", None ), - "int_counts": save_dict.pop("int_counts", False), + "exp_counts": save_dict.pop("exp_counts", False), "normalize_counts": save_dict.pop("normalize_counts", False), "store_raw_basal": save_dict.pop("store_raw_basal", False), "barcode": save_dict.pop("barcode", True), + "use_consecutive_loading": save_dict.pop("use_consecutive_loading", False), + "h5_open_kwargs": save_dict.pop("h5_open_kwargs", None), } # Create new instance with all the saved parameters @@ -379,8 +460,8 @@ def _create_dataloader( batch_size: int | None = None, ): """Create a DataLoader with appropriate configuration.""" - use_int_counts = "int_counts" in self.__dict__ and self.int_counts - collate_fn = partial(PerturbationDataset.collate_fn, int_counts=use_int_counts) + use_exp_counts = "exp_counts" in self.__dict__ and self.exp_counts + collate_fn = partial(PerturbationDataset.collate_fn, exp_counts=use_exp_counts) ds = MetadataConcatDataset(datasets) use_batch = self.basal_mapping_strategy == "batch" @@ -394,6 +475,8 @@ def _create_dataloader( cell_sentence_len=self.cell_sentence_len, test=test, use_batch=use_batch, + use_consecutive_loading=self.use_consecutive_loading, + downsample_cells=self.downsample_cells, ) return DataLoader( @@ -403,6 +486,8 @@ def _create_dataloader( collate_fn=collate_fn, pin_memory=True, prefetch_factor=4 if not test and self.num_workers > 0 else None, + persistent_workers=bool(self.num_workers > 0 and not test), + worker_init_fn=_worker_init_fn if self.num_workers > 0 else None, ) def _setup_global_maps(self): @@ -421,7 +506,6 @@ def _setup_global_maps(self): for dataset_name in self.config.get_all_datasets(): dataset_path = Path(self.config.datasets[dataset_name]) files = self._find_dataset_files(dataset_path) - for _fname, fpath in files.items(): with h5py.File(fpath, "r") as f: uns = f.get("uns") @@ -465,6 +549,19 @@ def _setup_global_maps(self): "Detected uns/log1p in at least one dataset; setting is_log1p=True." ) + if self.is_log1p: + if seen_log1p: + logger.warning( + "is_log1p mode is ENABLED. Detected uns/log1p metadata (example: %s).", + log1p_example, + ) + else: + logger.warning( + "is_log1p mode is ENABLED by configuration/default, but no uns/log1p metadata was detected." 
+ ) + else: + logger.warning("is_log1p mode is DISABLED.") + # Create one-hot maps if self.perturbation_features_file: # Load the custom featurizations from a torch file @@ -502,6 +599,7 @@ def _create_base_dataset( mapping_kwargs["cache_perturbation_control_pairs"] = ( self.cache_perturbation_control_pairs ) + mapping_kwargs["use_consecutive_loading"] = self.use_consecutive_loading return PerturbationDataset( name=dataset_name, @@ -528,6 +626,8 @@ def _create_base_dataset( additional_obs=self.additional_obs, downsample=self.downsample, is_log1p=self.is_log1p, + cell_sentence_len=self.cell_sentence_len, + h5_open_kwargs=self.h5_open_kwargs, ) def _setup_datasets(self): @@ -536,9 +636,26 @@ def _setup_datasets(self): Uses H5MetadataCache for faster metadata access. """ - for dataset_name in self.config.get_all_datasets(): + dataset_names = list(self.config.get_all_datasets()) + dataset_files: dict[str, dict[str, Path]] = {} + total_files = 0 + for dataset_name in dataset_names: dataset_path = Path(self.config.datasets[dataset_name]) files = self._find_dataset_files(dataset_path) + dataset_files[dataset_name] = files + total_files += len(files) + + pbar = ( + tqdm(total=total_files, desc="Processing datasets", leave=False) + if total_files > 0 + else None + ) + + for dataset_name in dataset_names: + files = dataset_files[dataset_name] + + if pbar is not None: + pbar.set_description(f"Processing {dataset_name}") # Get configuration for this dataset zeroshot_celltypes = self.config.get_zeroshot_celltypes(dataset_name) @@ -551,9 +668,7 @@ def _setup_datasets(self): logger.info(f" - Fewshot cell types: {list(fewshot_celltypes.keys())}") # Process each file in the dataset - for fname, fpath in tqdm( - list(files.items()), desc=f"Processing {dataset_name}" - ): + for fname, fpath in files.items(): # Create metadata cache cache = GlobalH5MetadataCache().get_cache( str(fpath), @@ -600,12 +715,40 @@ def _setup_datasets(self): val_sum += counts["val"] test_sum += counts["test"] - tqdm.write( - f"Processed {fname}: {train_sum} train, {val_sum} val, {test_sum} test" - ) + if pbar is not None: + pbar.update(1) + pbar.set_postfix_str( + f"{train_sum} train, {val_sum} val, {test_sum} test" + ) logger.info("\n") + if pbar is not None: + pbar.close() + + def _apply_val_subsample(self) -> None: + """Subsample validation datasets based on val_subsample_fraction.""" + if self.val_subsample_fraction is None or len(self.val_datasets) == 0: + return + + total = len(self.val_datasets) + keep = max(1, int(np.ceil(total * self.val_subsample_fraction))) + if keep >= total: + return + + # Deterministic ordering for reproducibility. 
+ ordered = sorted( + self.val_datasets, + key=lambda subset: (subset.dataset.name, str(subset.dataset.h5_path)), + ) + self.val_datasets = ordered[:keep] + logger.info( + "Subsampled validation datasets: kept %d/%d (val_subsample_fraction=%.4f).", + keep, + total, + self.val_subsample_fraction, + ) + def _split_fewshot_celltype( self, ds: PerturbationDataset, diff --git a/src/cell_load/data_modules/samplers.py b/src/cell_load/data_modules/samplers.py index 9ebb767..47c6eb1 100644 --- a/src/cell_load/data_modules/samplers.py +++ b/src/cell_load/data_modules/samplers.py @@ -33,6 +33,8 @@ def __init__( cell_sentence_len: int = 512, test: bool = False, use_batch: bool = False, + use_consecutive_loading: bool = False, + downsample_cells: int | None = None, seed: int = 0, epoch: int = 0, ): @@ -48,6 +50,7 @@ def __init__( self.batch_size = batch_size self.test = test self.use_batch = use_batch + self.use_consecutive_loading = use_consecutive_loading self.seed = seed self.epoch = epoch @@ -59,6 +62,7 @@ def __init__( self.cell_sentence_len = cell_sentence_len self.drop_last = drop_last + self.downsample_cells = self._validate_downsample_cells(downsample_cells) # Setup distributed settings if distributed mode is enabled. self.distributed = False @@ -124,9 +128,15 @@ def _create_batches(self) -> list[list[int]]: # If batch is smaller than cell_sentence_len, sample with replacement if len(sentence) < self.cell_sentence_len and not self.test: # during inference, don't sample by replacement - new_sentence = np.random.choice( - sentence, size=self.cell_sentence_len, replace=True - ).tolist() + if self.use_consecutive_loading: + repeats = int( + np.ceil(self.cell_sentence_len / max(len(sentence), 1)) + ) + new_sentence = (sentence * repeats)[: self.cell_sentence_len] + else: + new_sentence = np.random.choice( + sentence, size=self.cell_sentence_len, replace=True + ).tolist() num_partial += 1 else: new_sentence = copy.deepcopy(sentence) @@ -157,6 +167,48 @@ def _create_batches(self) -> list[list[int]]: return all_batches + def _validate_downsample_cells(self, downsample_cells: int | None) -> int | None: + if downsample_cells is None: + return None + if isinstance(downsample_cells, bool): + raise ValueError("downsample_cells must be a positive integer or None.") + if isinstance(downsample_cells, float): + if not downsample_cells.is_integer(): + raise ValueError("downsample_cells must be a positive integer or None.") + downsample_cells = int(downsample_cells) + elif not isinstance(downsample_cells, (int, np.integer)): + raise ValueError("downsample_cells must be a positive integer or None.") + downsample_cells = int(downsample_cells) + if downsample_cells <= 0: + raise ValueError("downsample_cells must be a positive integer or None.") + return downsample_cells + + def _apply_downsample_cells(self, sentences: list[list[int]]) -> list[list[int]]: + if self.downsample_cells is None or not sentences: + return sentences + + total = sum(len(sentence) for sentence in sentences) + if total <= self.downsample_cells: + return sentences + + order = np.random.permutation(len(sentences)) + selected: list[list[int]] = [] + remaining = self.downsample_cells + for idx in order: + if remaining <= 0: + break + sentence = sentences[idx] + if len(sentence) <= remaining: + selected.append(sentence) + remaining -= len(sentence) + else: + if not self.drop_last and remaining > 0: + selected.append(sentence[:remaining]) + remaining = 0 + break + + return selected + def _get_rank_sentences(self) -> list[list[int]]: """ Get the 
subset of sentences that this rank should process. @@ -239,11 +291,67 @@ def _process_subset(self, global_offset: int, subset: Subset) -> list[list[int]] group_indices = sorted_indices[start:end] np.random.shuffle(group_indices) + group_sentences = [] for i in range(0, len(group_indices), self.cell_sentence_len): sentence = group_indices[i : i + self.cell_sentence_len] if len(sentence) < self.cell_sentence_len and self.drop_last: continue - subset_batches.append(sentence.tolist()) + group_sentences.append(sentence.tolist()) + + group_sentences = self._apply_downsample_cells(group_sentences) + subset_batches.extend(group_sentences) + + return subset_batches + + def _process_subset_consecutive( + self, global_offset: int, subset: Subset + ) -> list[list[int]]: + """ + Process a single subset to create consecutive sentences based on H5 codes. + + This assumes the input indices are already in file order and splits + sentences at code-change boundaries without shuffling. + """ + base_dataset = subset.dataset + indices = np.array(subset.indices) + if indices.size == 0: + return [] + + cache: H5MetadataCache = self.metadata_caches[base_dataset.h5_path] + + # Codes in file order + cell_codes = cache.cell_type_codes[indices] + pert_codes = cache.pert_codes[indices] + if getattr(self, "use_batch", False): + batch_codes = cache.batch_codes[indices] + code_change = ( + (batch_codes[1:] != batch_codes[:-1]) + | (cell_codes[1:] != cell_codes[:-1]) + | (pert_codes[1:] != pert_codes[:-1]) + ) + else: + code_change = (cell_codes[1:] != cell_codes[:-1]) | ( + pert_codes[1:] != pert_codes[:-1] + ) + + # Global indices in dataset order + global_indices = np.arange(global_offset, global_offset + len(indices)) + + # Split into contiguous segments when codes change + boundaries = np.where(code_change)[0] + 1 + segments = np.split(global_indices, boundaries) + + subset_batches = [] + for segment in segments: + group_sentences = [] + for i in range(0, len(segment), self.cell_sentence_len): + sentence = segment[i : i + self.cell_sentence_len] + if len(sentence) < self.cell_sentence_len and self.drop_last: + continue + group_sentences.append(sentence.tolist()) + + group_sentences = self._apply_downsample_cells(group_sentences) + subset_batches.extend(group_sentences) return subset_batches @@ -254,7 +362,10 @@ def _create_sentences(self) -> list[list[int]]: global_offset = 0 all_batches = [] for subset in self.dataset.datasets: - subset_batches = self._process_subset(global_offset, subset) + if self.use_consecutive_loading: + subset_batches = self._process_subset_consecutive(global_offset, subset) + else: + subset_batches = self._process_subset(global_offset, subset) all_batches.extend(subset_batches) global_offset += len(subset) np.random.shuffle(all_batches) diff --git a/src/cell_load/dataset/_metadata.py b/src/cell_load/dataset/_metadata.py index f0736ea..900f8dc 100644 --- a/src/cell_load/dataset/_metadata.py +++ b/src/cell_load/dataset/_metadata.py @@ -1,4 +1,6 @@ -from torch.utils.data import ConcatDataset, Dataset +from bisect import bisect_right + +from torch.utils.data import ConcatDataset, Dataset, Subset class MetadataConcatDataset(ConcatDataset): @@ -25,3 +27,49 @@ def __init__(self, datasets: list[Dataset]): raise ValueError( "All datasets must share the same embed_key, control_pert, pert_col, and batch_col" ) + + def __getitems__(self, indices): + """ + Batch-aware fetch to enable fast-path loading when supported by datasets. + Falls back to per-item access for non-consecutive loading. 
+ """ + if not getattr(self.base.mapping_strategy, "use_consecutive_loading", False): + return [self[i] for i in indices] + + results = [None] * len(indices) + grouped = {} + + for out_pos, idx in enumerate(indices): + dataset_idx = bisect_right(self.cumulative_sizes, idx) + sample_idx = ( + idx + if dataset_idx == 0 + else idx - self.cumulative_sizes[dataset_idx - 1] + ) + grouped.setdefault(dataset_idx, []).append((out_pos, sample_idx)) + + for dataset_idx, pos_samples in grouped.items(): + ds = self.datasets[dataset_idx] + positions, sample_indices = zip(*pos_samples) + + if isinstance(ds, Subset) and hasattr(ds.dataset, "__getitems__"): + underlying_indices = [ds.indices[i] for i in sample_indices] + samples = ds.dataset.__getitems__(underlying_indices) + elif hasattr(ds, "__getitems__"): + samples = ds.__getitems__(list(sample_indices)) + else: + samples = [ds[i] for i in sample_indices] + + for pos, sample in zip(positions, samples): + results[pos] = sample + + return results + + def ensure_h5_open(self) -> None: + """ + Ensure all underlying H5 files are open in the current process. + """ + for ds in self.datasets: + base = ds.dataset if isinstance(ds, Subset) else ds + if hasattr(base, "ensure_h5_open"): + base.ensure_h5_open() diff --git a/src/cell_load/dataset/_perturbation.py b/src/cell_load/dataset/_perturbation.py index b11e64b..1e6a8c5 100644 --- a/src/cell_load/dataset/_perturbation.py +++ b/src/cell_load/dataset/_perturbation.py @@ -1,4 +1,5 @@ import logging +import os from pathlib import Path from functools import lru_cache @@ -45,7 +46,9 @@ def __init__( barcode: bool = False, additional_obs: list[str] | None = None, downsample: float | None = None, - is_log1p: bool = False, + is_log1p: bool = True, + cell_sentence_len: int | None = None, + h5_open_kwargs: dict | None = None, **kwargs, ): """ @@ -68,8 +71,11 @@ def __init__( store_raw_basal: If True, include raw basal expression barcode: If True, include cell barcodes in output additional_obs: Optional list of obs column names to include in each sample - downsample: Fraction of counts to retain via binomial downsampling (only for output_space="all") - is_log1p: Whether raw counts in X are log1p-transformed (affects downsampling) + downsample: If <=1, fraction of counts to retain via binomial downsampling; if >1, target + read depth per cell (only for output_space="all") + is_log1p: Whether raw counts in X are log1p-transformed (default True; affects downsampling) + cell_sentence_len: Optional sentence length for consecutive loading batches + h5_open_kwargs: Optional kwargs to pass to h5py.File (e.g., rdcc_nbytes) **kwargs: Additional options (e.g. 
output_space) """ super().__init__() @@ -101,19 +107,23 @@ def __init__( downsample = float(downsample) except (TypeError, ValueError) as exc: raise ValueError( - f"downsample must be a float in (0, 1]; got {downsample!r}" + f"downsample must be a positive float; got {downsample!r}" ) from exc - if not (0.0 < downsample <= 1.0): - raise ValueError(f"downsample must be in (0, 1]; got {downsample!r}") + if not (0.0 < downsample): + raise ValueError(f"downsample must be > 0; got {downsample!r}") self.downsample = downsample self.is_log1p = bool(is_log1p) + self.cell_sentence_len = cell_sentence_len + self.h5_open_kwargs = self._normalize_h5_open_kwargs(h5_open_kwargs) self.additional_obs = self._validate_additional_obs(additional_obs) # Load metadata cache and open file self.metadata_cache = GlobalH5MetadataCache().get_cache( str(self.h5_path), pert_col, cell_type_key, control_pert, batch_col ) - self.h5_file = h5py.File(self.h5_path, "r") + self.h5_file = None + self._h5_pid = None + self._open_h5_file() # Load cell barcodes if requested if self.barcode: @@ -135,6 +145,96 @@ def __init__( splits = ["train", "train_eval", "val", "test"] self.split_perturbed_indices = {s: set() for s in splits} self.split_control_indices = {s: set() for s in splits} + self._init_split_index_cache() + + def _init_split_index_cache(self) -> None: + self._split_code_to_name = ("train", "train_eval", "val", "test") + self._split_name_to_code = { + name: idx for idx, name in enumerate(self._split_code_to_name) + } + self._index_to_split_code = np.full(self.n_cells, -1, dtype=np.int8) + for split, indices in self.split_perturbed_indices.items(): + if indices: + self._index_to_split_code[list(indices)] = self._split_name_to_code[ + split + ] + for split, indices in self.split_control_indices.items(): + if indices: + self._index_to_split_code[list(indices)] = self._split_name_to_code[ + split + ] + + @staticmethod + def _parse_env_int(name: str, default: int) -> int: + raw = os.getenv(name) + if raw is None: + return default + try: + return int(float(raw)) + except ValueError: + logger.warning("Invalid %s=%r; using %d", name, raw, default) + return default + + @staticmethod + def _parse_env_float(name: str, default: float) -> float: + raw = os.getenv(name) + if raw is None: + return default + try: + return float(raw) + except ValueError: + logger.warning("Invalid %s=%r; using %.3f", name, raw, default) + return default + + def _default_h5_open_kwargs(self) -> dict: + rdcc_nbytes = self._parse_env_int("CELL_LOAD_H5_RDCC_NBYTES", 64 * 1024 * 1024) + rdcc_nslots = self._parse_env_int("CELL_LOAD_H5_RDCC_NSLOTS", 1_000_003) + rdcc_w0 = self._parse_env_float("CELL_LOAD_H5_RDCC_W0", 0.75) + kwargs = { + "rdcc_nbytes": rdcc_nbytes, + "rdcc_nslots": rdcc_nslots, + "rdcc_w0": rdcc_w0, + } + return self._sanitize_h5_open_kwargs(kwargs) + + @staticmethod + def _sanitize_h5_open_kwargs(kwargs: dict) -> dict: + cleaned = {} + rdcc_nbytes = kwargs.get("rdcc_nbytes") + if rdcc_nbytes is not None and rdcc_nbytes > 0: + cleaned["rdcc_nbytes"] = int(rdcc_nbytes) + rdcc_nslots = kwargs.get("rdcc_nslots") + if rdcc_nslots is not None and rdcc_nslots > 0: + cleaned["rdcc_nslots"] = int(rdcc_nslots) + rdcc_w0 = kwargs.get("rdcc_w0") + if rdcc_w0 is not None and 0.0 <= float(rdcc_w0) <= 1.0: + cleaned["rdcc_w0"] = float(rdcc_w0) + return cleaned + + def _normalize_h5_open_kwargs(self, h5_open_kwargs: dict | None) -> dict: + if h5_open_kwargs is None: + return self._default_h5_open_kwargs() + return 
self._sanitize_h5_open_kwargs(h5_open_kwargs) + + def _open_h5_file(self) -> None: + if self.h5_file is not None: + try: + self.h5_file.close() + except Exception: + pass + self.h5_file = h5py.File(self.h5_path, "r", **self.h5_open_kwargs) + self._h5_pid = os.getpid() + + def _ensure_h5_open(self) -> None: + if ( + self.h5_file is None + or self._h5_pid != os.getpid() + or not self.h5_file.id.valid + ): + self._open_h5_file() + + def ensure_h5_open(self) -> None: + self._ensure_h5_open() def set_store_raw_expression(self, flag: bool) -> None: """ @@ -175,9 +275,11 @@ def __getitem__(self, idx: int): - cell_type: the cell type - batch: the batch (as an int or string) - batch_name: the batch name (as a string) + - dataset_name: the dataset name (from the TOML) - pert_cell_counts: the raw gene expression of the perturbed cell (if store_raw_expression is True) - ctrl_cell_counts: the raw gene expression of the control cell (if store_raw_basal is True) """ + self._ensure_h5_open() # Get the perturbed cell expression, control cell expression, and index of mapped control cell file_idx = int(self.all_indices[idx]) @@ -215,6 +317,7 @@ def __getitem__(self, idx: int): "ctrl_cell_emb": ctrl_expr, "pert_emb": pert_onehot, "pert_name": pert_name, + "dataset_name": self.name, "batch_name": batch_name, "batch": batch_onehot, "cell_type": cell_type, @@ -250,6 +353,191 @@ def __getitem__(self, idx: int): return sample + def __getitems__(self, indices): + """ + Batch-aware fetch for consecutive loading with batched CSR densification. + Falls back to per-item access when not applicable. + """ + self._ensure_h5_open() + if not self._use_batched_fetch(): + return [self.__getitem__(int(i)) for i in indices] + + idx_arr = np.asarray(indices, dtype=np.int64) + if idx_arr.size == 0: + return [] + + file_indices = self.all_indices[idx_arr] + splits = [self._find_split_for_idx(int(i)) for i in file_indices] + + ctrl_indices = [] + missing_ctrl = [] + sentence_len = self.cell_sentence_len + use_sentence_blocks = ( + sentence_len is not None + and sentence_len > 0 + and len(file_indices) % sentence_len == 0 + and getattr(self.mapping_strategy, "use_consecutive_loading", False) + and hasattr(self.mapping_strategy, "_sample_consecutive_controls") + and hasattr(self.mapping_strategy, "split_control_pool") + ) + + if use_sentence_blocks: + for start in range(0, len(file_indices), sentence_len): + sentence_idx = file_indices[start : start + sentence_len] + split = splits[start] + cell_type_code = self.get_cell_type_code(sentence_idx[0]) + pool = self.mapping_strategy.split_control_pool[split].get( + cell_type_code, None + ) + if not pool: + raise ValueError( + f"No control cells found in RandomMappingStrategy for cell type '{self.get_cell_type(sentence_idx[0])}'" + ) + block = self.mapping_strategy._sample_consecutive_controls( + pool, len(sentence_idx) + ) + ctrl_indices.extend(block.tolist()) + missing_ctrl.extend([False] * len(block)) + else: + for file_idx, split in zip(file_indices, splits): + ctrl_idx = self.mapping_strategy.get_control_index( + self, split, int(file_idx) + ) + if ctrl_idx is None: + ctrl_indices.append(-1) + missing_ctrl.append(True) + else: + ctrl_indices.append(int(ctrl_idx)) + missing_ctrl.append(False) + + ctrl_indices_arr = np.asarray(ctrl_indices, dtype=np.int64) + missing_ctrl_mask = np.asarray(missing_ctrl, dtype=bool) + + if missing_ctrl_mask.any(): + missing_pos = int(np.flatnonzero(missing_ctrl_mask)[0]) + missing_file_idx = int(file_indices[missing_pos]) + if not self.embed_key: + raise 
ValueError( + f"No control cells found for cell type '{self.get_cell_type(missing_file_idx)}'" + ) + if (self.store_raw_basal and self.output_space != "embedding") or ( + self.barcode and self.cell_barcodes is not None + ): + raise ValueError( + f"No control cells found for cell type '{self.get_cell_type(missing_file_idx)}'" + ) + + if self.embed_key: + pert_expr_batch = self._fetch_obsm_expression_batch( + file_indices, self.embed_key + ) + if missing_ctrl_mask.any(): + ctrl_expr_batch = torch.zeros_like(pert_expr_batch) + valid_positions = np.flatnonzero(~missing_ctrl_mask) + if valid_positions.size: + valid_ctrl_indices = ctrl_indices_arr[valid_positions] + valid_positions_t = torch.from_numpy( + valid_positions.astype(np.int64) + ) + ctrl_expr_batch[valid_positions_t] = ( + self._fetch_obsm_expression_batch( + valid_ctrl_indices, self.embed_key + ) + ) + else: + ctrl_expr_batch = self._fetch_obsm_expression_batch( + ctrl_indices_arr, self.embed_key + ) + else: + pert_expr_batch = self._fetch_gene_expression_batch(file_indices) + ctrl_expr_batch = self._fetch_gene_expression_batch(ctrl_indices_arr) + + pert_counts_batch = None + ctrl_counts_batch = None + if self.store_raw_expression and self.output_space != "embedding": + if self.output_space == "gene": + pert_counts_batch = self._fetch_obsm_expression_batch( + file_indices, "X_hvg" + ) + elif self.output_space == "all": + if self.embed_key: + pert_counts_batch = self._fetch_gene_expression_batch(file_indices) + else: + pert_counts_batch = pert_expr_batch + + if self.store_raw_basal and self.output_space != "embedding": + if self.output_space == "gene": + ctrl_counts_batch = self._fetch_obsm_expression_batch( + ctrl_indices_arr, "X_hvg" + ) + elif self.output_space == "all": + if self.embed_key: + ctrl_counts_batch = self._fetch_gene_expression_batch( + ctrl_indices_arr + ) + else: + ctrl_counts_batch = ctrl_expr_batch + + samples = [] + for i, file_idx in enumerate(file_indices): + pert_expr = pert_expr_batch[i] + ctrl_expr = ctrl_expr_batch[i] + ctrl_idx = ctrl_indices_arr[i] + + pert_code = self.metadata_cache.pert_codes[file_idx] + pert_name = self.pert_categories[pert_code] + pert_onehot = ( + self.pert_onehot_map.get(pert_name) if self.pert_onehot_map else None + ) + + cell_type = self.cell_type_categories[ + self.metadata_cache.cell_type_codes[file_idx] + ] + cell_type_onehot = ( + self.cell_type_onehot_map.get(cell_type) + if self.cell_type_onehot_map + else None + ) + + batch_code = self.metadata_cache.batch_codes[file_idx] + batch_name = self.metadata_cache.batch_categories[batch_code] + batch_onehot = ( + self.batch_onehot_map.get(batch_name) if self.batch_onehot_map else None + ) + + sample = { + "pert_cell_emb": pert_expr, + "ctrl_cell_emb": ctrl_expr, + "pert_emb": pert_onehot, + "pert_name": pert_name, + "dataset_name": self.name, + "batch_name": batch_name, + "batch": batch_onehot, + "cell_type": cell_type, + "cell_type_onehot": cell_type_onehot, + } + + if pert_counts_batch is not None: + sample["pert_cell_counts"] = pert_counts_batch[i] + + if ctrl_counts_batch is not None: + sample["ctrl_cell_counts"] = ctrl_counts_batch[i] + + if self.barcode and self.cell_barcodes is not None: + sample["pert_cell_barcode"] = self.cell_barcodes[file_idx] + sample["ctrl_cell_barcode"] = self.cell_barcodes[ctrl_idx] + + if self.additional_obs: + for obs_key in self.additional_obs: + sample[obs_key] = self._fetch_obs_value(file_idx, obs_key) + + samples.append(sample) + + return samples + + def _use_batched_fetch(self) -> bool: + 
return bool(getattr(self.mapping_strategy, "use_consecutive_loading", False)) + def _validate_additional_obs(self, additional_obs: list[str] | None) -> list[str]: if additional_obs is None: return [] @@ -326,7 +614,9 @@ def get_dim_for_obsm(self, key: str) -> int: """ Get the feature dimensionality of obsm data with the specified key (e.g., 'X_uce'). """ - return self.h5_file[f"obsm/{key}"].shape[1] + matrix = self._get_obsm_matrix(key) + _, n_cols = self._get_matrix_shape(matrix) + return n_cols def get_cell_type(self, idx): """ @@ -337,6 +627,13 @@ def get_cell_type(self, idx): code = self.metadata_cache.cell_type_codes[idx] return self.metadata_cache.cell_type_categories[code] + def get_cell_type_code(self, idx: int) -> int: + """ + Get the cell type code for a given index. + """ + idx = int(idx) if hasattr(idx, "__int__") else idx + return int(self.metadata_cache.cell_type_codes[idx]) + def get_all_cell_types(self, indices): """ Get the cell types for all given indices. @@ -344,6 +641,12 @@ def get_all_cell_types(self, indices): codes = self.metadata_cache.cell_type_codes[indices] return self.metadata_cache.cell_type_categories[codes] + def get_all_cell_type_codes(self, indices) -> np.ndarray: + """ + Get the cell type codes for all given indices. + """ + return self.metadata_cache.cell_type_codes[indices] + def get_perturbation_name(self, idx): """ Get the perturbation name for a given index. @@ -353,6 +656,32 @@ def get_perturbation_name(self, idx): pert_code = self.metadata_cache.pert_codes[idx] return self.metadata_cache.pert_categories[pert_code] + def get_perturbation_code(self, idx: int) -> int: + """ + Get the perturbation code for a given index. + """ + idx = int(idx) if hasattr(idx, "__int__") else idx + return int(self.metadata_cache.pert_codes[idx]) + + def get_all_perturbation_codes(self, indices) -> np.ndarray: + """ + Get the perturbation codes for all given indices. + """ + return self.metadata_cache.pert_codes[indices] + + def get_batch_code(self, idx: int) -> int: + """ + Get the batch code for a given index. + """ + idx = int(idx) if hasattr(idx, "__int__") else idx + return int(self.metadata_cache.batch_codes[idx]) + + def get_all_batch_codes(self, indices) -> np.ndarray: + """ + Get the batch codes for all given indices. 
+ """ + return self.metadata_cache.batch_codes[indices] + def to_subset_dataset( self, split: str, @@ -435,22 +764,53 @@ def _fetch_gene_expression_csr_row(self, idx: int) -> tuple[np.ndarray, np.ndarr def _maybe_downsample_counts(self, counts: torch.Tensor) -> torch.Tensor: if ( self.downsample is None - or self.downsample >= 1.0 or self.output_space != "all" + or self.downsample == 1.0 ): return counts counts_np = counts.detach().cpu().numpy() + sampled = self._maybe_downsample_counts_array(counts_np) + return torch.tensor(sampled, dtype=torch.float32) + + def _maybe_downsample_counts_array(self, counts: np.ndarray) -> np.ndarray: + if ( + self.downsample is None + or self.output_space != "all" + or self.downsample == 1.0 + ): + return counts + if self.is_log1p: - counts_lin = np.expm1(counts_np) + counts_lin = np.expm1(counts) counts_int = np.rint(counts_lin).astype(np.int64) else: - counts_int = counts_np.astype(np.int64) + counts_int = counts.astype(np.int64) counts_int = np.maximum(counts_int, 0) - sampled = self.rng.binomial(counts_int, self.downsample) + if self.downsample < 1.0: + sampled = self.rng.binomial(counts_int, self.downsample) + else: + target = float(self.downsample) + if counts_int.ndim == 1: + total = counts_int.sum() + if total <= 0: + sampled = counts_int + else: + p = min(1.0, target / float(total)) + sampled = self.rng.binomial(counts_int, p) + else: + total = counts_int.sum(axis=1, keepdims=True).astype(np.float64) + p = np.divide( + target, + total, + out=np.ones_like(total, dtype=np.float64), + where=total > 0, + ) + p = np.minimum(p, 1.0) + sampled = self.rng.binomial(counts_int, p) if self.is_log1p: sampled = np.log1p(sampled) - return torch.tensor(sampled, dtype=torch.float32) + return sampled.astype(np.float32) def fetch_gene_expression(self, idx: int) -> torch.Tensor: """ @@ -460,7 +820,7 @@ def fetch_gene_expression(self, idx: int) -> torch.Tensor: if ( attrs.get("encoding-type") == "csr_matrix" and self.downsample is not None - and self.downsample < 1.0 + and self.downsample != 1.0 and self.output_space == "all" ): sub_indices, sub_data = self._fetch_gene_expression_csr_row(idx) @@ -472,9 +832,12 @@ def fetch_gene_expression(self, idx: int) -> torch.Tensor: else: counts_int = sub_data.astype(np.int64) counts_int = np.maximum(counts_int, 0) - sampled = self.rng.binomial(counts_int, self.downsample).astype( - np.float32 - ) + if self.downsample < 1.0: + p = self.downsample + else: + total = counts_int.sum() + p = 1.0 if total <= 0 else min(1.0, self.downsample / float(total)) + sampled = self.rng.binomial(counts_int, p).astype(np.float32) if self.is_log1p: sampled = np.log1p(sampled) dense[sub_indices] = sampled @@ -483,6 +846,276 @@ def fetch_gene_expression(self, idx: int) -> torch.Tensor: data = self._fetch_gene_expression_raw(idx) return self._maybe_downsample_counts(data) + def _fetch_dense_matrix_batch( + self, ds, indices: np.ndarray, n_cols: int + ) -> np.ndarray: + if indices.size == 0: + return np.empty((0, n_cols), dtype=np.float32) + + order = np.argsort(indices) + sorted_rows = indices[order] + dense_sorted = np.zeros((len(sorted_rows), n_cols), dtype=np.float32) + + run_start = 0 + for i in range(1, len(sorted_rows)): + if sorted_rows[i] != sorted_rows[i - 1] + 1: + row_start = int(sorted_rows[run_start]) + row_end = int(sorted_rows[i - 1]) + block = np.asarray(ds[row_start : row_end + 1], dtype=np.float32) + if block.ndim == 1: + block = block[:, None] + dense_sorted[run_start:i] = block + run_start = i + + row_start = 
int(sorted_rows[run_start]) + row_end = int(sorted_rows[-1]) + block = np.asarray(ds[row_start : row_end + 1], dtype=np.float32) + if block.ndim == 1: + block = block[:, None] + dense_sorted[run_start : len(sorted_rows)] = block + + inv_order = np.empty_like(order) + inv_order[order] = np.arange(len(order)) + return dense_sorted[inv_order] + + @staticmethod + def _is_csr_group(obj) -> bool: + return isinstance(obj, h5py.Group) and ( + obj.attrs.get("encoding-type") == "csr_matrix" + or all(k in obj for k in ("data", "indices", "indptr")) + ) + + def _infer_n_cols_from_indices(self, indices_ds) -> int: + # Rare fallback for malformed files that omit CSR shape metadata. + total_nnz = int(indices_ds.shape[0]) + if total_nnz == 0: + return 0 + + max_col = -1 + chunk = 1_000_000 + for start in range(0, total_nnz, chunk): + stop = min(start + chunk, total_nnz) + block = np.asarray(indices_ds[start:stop], dtype=np.int64) + if block.size == 0: + continue + local_max = int(block.max()) + if local_max > max_col: + max_col = local_max + return max_col + 1 + + def _get_matrix_shape(self, matrix_obj) -> tuple[int, int]: + if isinstance(matrix_obj, h5py.Dataset): + if len(matrix_obj.shape) == 1: + return int(matrix_obj.shape[0]), 1 + if len(matrix_obj.shape) >= 2: + return int(matrix_obj.shape[0]), int(matrix_obj.shape[1]) + raise ValueError("Dataset has invalid rank for matrix-like data.") + + if self._is_csr_group(matrix_obj): + shape_attr = matrix_obj.attrs.get("shape") + if shape_attr is not None: + shape_arr = np.asarray(shape_attr, dtype=np.int64).reshape(-1) + if shape_arr.size >= 2: + return int(shape_arr[0]), int(shape_arr[1]) + + if "indptr" not in matrix_obj: + raise KeyError("CSR group is missing required 'indptr' dataset.") + n_rows = int(matrix_obj["indptr"].shape[0]) - 1 + + if "indices" not in matrix_obj: + raise KeyError("CSR group is missing required 'indices' dataset.") + n_cols = self._infer_n_cols_from_indices(matrix_obj["indices"]) + return n_rows, n_cols + + raise TypeError( + f"Unsupported matrix storage type: {type(matrix_obj).__name__}. " + "Expected h5py.Dataset or CSR-encoded h5py.Group." + ) + + def _fetch_csr_matrix_batch( + self, matrix_group: h5py.Group, indices: np.ndarray, n_cols: int + ) -> np.ndarray: + if indices.size == 0: + return np.empty((0, n_cols), dtype=np.float32) + + if not all(k in matrix_group for k in ("indptr", "data", "indices")): + raise KeyError( + "CSR group must contain 'indptr', 'data', and 'indices' datasets." + ) + + indptr_ds = matrix_group["indptr"] + data_ds = matrix_group["data"] + indices_ds = matrix_group["indices"] + + order = np.argsort(indices) + sorted_rows = indices[order] + dense_sorted = np.zeros((len(sorted_rows), n_cols), dtype=np.float32) + + run_start = 0 + for i in range(1, len(sorted_rows)): + if sorted_rows[i] != sorted_rows[i - 1] + 1: + self._fill_dense_run( + sorted_rows, + run_start, + i, + dense_sorted, + indptr_ds, + data_ds, + indices_ds, + ) + run_start = i + + self._fill_dense_run( + sorted_rows, + run_start, + len(sorted_rows), + dense_sorted, + indptr_ds, + data_ds, + indices_ds, + ) + + inv_order = np.empty_like(order) + inv_order[order] = np.arange(len(order)) + return dense_sorted[inv_order] + + def _fetch_csr_row( + self, matrix_group: h5py.Group, idx: int, n_cols: int + ) -> np.ndarray: + if not all(k in matrix_group for k in ("indptr", "data", "indices")): + raise KeyError( + "CSR group must contain 'indptr', 'data', and 'indices' datasets." 
+ ) + + indptr_ds = matrix_group["indptr"] + data_ds = matrix_group["data"] + indices_ds = matrix_group["indices"] + + start_ptr = int(indptr_ds[idx]) + end_ptr = int(indptr_ds[idx + 1]) + + dense = np.zeros(n_cols, dtype=np.float32) + if end_ptr <= start_ptr: + return dense + + row_data = np.asarray(data_ds[start_ptr:end_ptr], dtype=np.float32) + row_indices = np.asarray(indices_ds[start_ptr:end_ptr], dtype=np.int64) + dense[row_indices] = row_data + return dense + + def _get_obsm_matrix(self, key: str): + path = f"/obsm/{key}" + if path not in self.h5_file: + raise KeyError(f"obsm key '{key}' not found in {self.h5_path}") + return self.h5_file[path] + + def _fetch_gene_expression_batch(self, indices: np.ndarray) -> torch.Tensor: + """ + Fetch raw gene counts for multiple indices at once (CSR fast path). + """ + if indices.size == 0: + return torch.empty((0, self.n_genes), dtype=torch.float32) + + attrs = dict(self.h5_file["X"].attrs) + if attrs.get("encoding-type") != "csr_matrix": + dense = self._fetch_dense_matrix_batch( + self.h5_file["/X"], indices, self.n_genes + ) + dense = self._maybe_downsample_counts_array(dense) + return torch.from_numpy(dense) + + indptr_ds = self.h5_file["/X/indptr"] + data_ds = self.h5_file["/X/data"] + indices_ds = self.h5_file["/X/indices"] + + order = np.argsort(indices) + sorted_rows = indices[order] + dense_sorted = np.zeros((len(sorted_rows), self.n_genes), dtype=np.float32) + + run_start = 0 + for i in range(1, len(sorted_rows)): + if sorted_rows[i] != sorted_rows[i - 1] + 1: + self._fill_dense_run( + sorted_rows, + run_start, + i, + dense_sorted, + indptr_ds, + data_ds, + indices_ds, + ) + run_start = i + + self._fill_dense_run( + sorted_rows, + run_start, + len(sorted_rows), + dense_sorted, + indptr_ds, + data_ds, + indices_ds, + ) + + inv_order = np.empty_like(order) + inv_order[order] = np.arange(len(order)) + dense = dense_sorted[inv_order] + dense = self._maybe_downsample_counts_array(dense) + + return torch.from_numpy(dense) + + def _fetch_obsm_expression_batch( + self, indices: np.ndarray, key: str + ) -> torch.Tensor: + matrix = self._get_obsm_matrix(key) + _, n_cols = self._get_matrix_shape(matrix) + if indices.size == 0: + return torch.empty((0, n_cols), dtype=torch.float32) + + if isinstance(matrix, h5py.Dataset): + dense = self._fetch_dense_matrix_batch(matrix, indices, n_cols) + elif self._is_csr_group(matrix): + dense = self._fetch_csr_matrix_batch(matrix, indices, n_cols) + else: + raise TypeError( + f"Unsupported obsm storage for key '{key}': {type(matrix).__name__}" + ) + + return torch.from_numpy(dense) + + def _fill_dense_run( + self, + sorted_rows: np.ndarray, + start: int, + end: int, + dense_sorted: np.ndarray, + indptr_ds, + data_ds, + indices_ds, + ) -> None: + if start >= end: + return + + row_start = int(sorted_rows[start]) + row_end = int(sorted_rows[end - 1]) + indptr_slice = indptr_ds[row_start : row_end + 2] + base_ptr = int(indptr_slice[0]) + end_ptr = int(indptr_slice[-1]) + + if end_ptr <= base_ptr: + return + + block_data = np.asarray(data_ds[base_ptr:end_ptr], dtype=np.float32) + block_indices = np.asarray(indices_ds[base_ptr:end_ptr], dtype=np.int64) + + for offset, row in enumerate(sorted_rows[start:end]): + ptr_start = int(indptr_slice[offset] - base_ptr) + ptr_end = int(indptr_slice[offset + 1] - base_ptr) + if ptr_end <= ptr_start: + continue + dense_sorted[start + offset, block_indices[ptr_start:ptr_end]] = block_data[ + ptr_start:ptr_end + ] + @lru_cache(maxsize=10000) def fetch_obsm_expression(self, idx: 
int, key: str) -> torch.Tensor: """ @@ -494,8 +1127,24 @@ def fetch_obsm_expression(self, idx: int, key: str) -> torch.Tensor: Returns: 1D FloatTensor of that embedding """ - row_data = self.h5_file[f"/obsm/{key}"][idx] - return torch.tensor(row_data, dtype=torch.float32) + matrix = self._get_obsm_matrix(key) + _, n_cols = self._get_matrix_shape(matrix) + + if isinstance(matrix, h5py.Dataset): + row_data = np.asarray(matrix[idx], dtype=np.float32) + if row_data.ndim == 0: + row_data = np.asarray([row_data], dtype=np.float32) + elif row_data.ndim > 1: + row_data = row_data.reshape(-1) + return torch.from_numpy(row_data) + + if self._is_csr_group(matrix): + row_data = self._fetch_csr_row(matrix, int(idx), n_cols) + return torch.from_numpy(row_data) + + raise TypeError( + f"Unsupported obsm storage for key '{key}': {type(matrix).__name__}" + ) def get_gene_names(self, output_space="all") -> list[str]: """ @@ -565,7 +1214,7 @@ def _decode(x): # Static methods ############################## @staticmethod - def collate_fn(batch, int_counts=False): + def collate_fn(batch, exp_counts=False): """ Optimized collate function with preallocated lists. Safely handles normalization when vectors sum to zero. @@ -638,42 +1287,30 @@ def collate_fn(batch, int_counts=False): is_discrete = suspected_discrete_torch(pert_cell_counts) is_log = suspected_log_torch(pert_cell_counts) already_logged = (not is_discrete) and is_log + if exp_counts: + if already_logged: + pert_cell_counts = torch.expm1(pert_cell_counts) + pert_cell_counts = torch.nan_to_num( + pert_cell_counts, nan=0.0, posinf=0.0, neginf=0.0 + ) + pert_cell_counts = pert_cell_counts.clamp_min(0).round().to(torch.int32) batch_dict["pert_cell_counts"] = pert_cell_counts - # if already_logged: # counts are already log transformed - # if ( - # int_counts - # ): # if the user wants to model with raw counts, don't log transform - # batch_dict["pert_cell_counts"] = torch.expm1(pert_cell_counts) - # else: - # batch_dict["pert_cell_counts"] = pert_cell_counts - # else: - # if int_counts: - # batch_dict["pert_cell_counts"] = pert_cell_counts - # else: - # batch_dict["pert_cell_counts"] = torch.log1p(pert_cell_counts) - if has_ctrl_cell_counts: ctrl_cell_counts = torch.stack(ctrl_cell_counts_list) - is_discrete = suspected_discrete_torch(pert_cell_counts) - is_log = suspected_log_torch(pert_cell_counts) + is_discrete = suspected_discrete_torch(ctrl_cell_counts) + is_log = suspected_log_torch(ctrl_cell_counts) already_logged = (not is_discrete) and is_log + if exp_counts: + if already_logged: + ctrl_cell_counts = torch.expm1(ctrl_cell_counts) + ctrl_cell_counts = torch.nan_to_num( + ctrl_cell_counts, nan=0.0, posinf=0.0, neginf=0.0 + ) + ctrl_cell_counts = ctrl_cell_counts.clamp_min(0).round().to(torch.int32) batch_dict["ctrl_cell_counts"] = ctrl_cell_counts - # if already_logged: # counts are already log transformed - # if ( - # int_counts - # ): # if the user wants to model with raw counts, don't log transform - # batch_dict["ctrl_cell_counts"] = torch.expm1(ctrl_cell_counts) - # else: - # batch_dict["ctrl_cell_counts"] = ctrl_cell_counts - # else: - # if int_counts: - # batch_dict["ctrl_cell_counts"] = ctrl_cell_counts - # else: - # batch_dict["ctrl_cell_counts"] = torch.log1p(ctrl_cell_counts) - if has_barcodes: batch_dict["pert_cell_barcode"] = pert_cell_barcode_list batch_dict["ctrl_cell_barcode"] = ctrl_cell_barcode_list @@ -748,6 +1385,13 @@ def _register_split_indices( # update them in the dataset self.split_perturbed_indices[split] |= 
set(perturbed_indices) self.split_control_indices[split] |= set(control_indices) + if not hasattr(self, "_index_to_split_code"): + self._init_split_index_cache() + code = self._split_name_to_code[split] + if len(perturbed_indices) > 0: + self._index_to_split_code[perturbed_indices] = code + if len(control_indices) > 0: + self._index_to_split_code[control_indices] = code # forward these to the mapping strategy self.mapping_strategy.register_split_indices( @@ -756,6 +1400,11 @@ def _register_split_indices( def _find_split_for_idx(self, idx: int) -> str | None: """Utility to find which split (train/val/test) this idx belongs to.""" + if hasattr(self, "_index_to_split_code"): + code = int(self._index_to_split_code[idx]) + if code >= 0: + return self._split_code_to_name[code] + return None for s in self.split_perturbed_indices.keys(): if ( idx in self.split_perturbed_indices[s] @@ -779,13 +1428,13 @@ def _get_num_genes(self) -> int: indices = self.h5_file["X/indices"][:] n_cols = indices.max() + 1 except KeyError: - n_cols = self.h5_file["obsm/X_hvg"].shape[1] + n_cols = self._get_matrix_shape(self.h5_file["obsm/X_hvg"])[1] return n_cols def get_num_hvgs(self) -> int: """Return the number of highly variable genes in the obsm matrix.""" try: - return self.h5_file["obsm/X_hvg"].shape[1] + return self._get_matrix_shape(self.h5_file["obsm/X_hvg"])[1] except: return 0 @@ -800,7 +1449,7 @@ def _get_num_cells(self) -> int: n_rows = len(indptr) - 1 except Exception: # if this also fails, fall back to obsm - n_rows = self.h5_file["obsm/X_hvg"].shape[0] + n_rows = self._get_matrix_shape(self.h5_file["obsm/X_hvg"])[0] return n_rows def get_pert_name(self, idx: int) -> str: @@ -820,9 +1469,13 @@ def __getstate__(self): # Copy the object's dict state = self.__dict__.copy() # Remove the open file object if it exists - if "h5_file" in state: - # We'll also store whether it's currently open, so that we can re-open later if needed - del state["h5_file"] + if self.h5_file is not None: + try: + self.h5_file.close() + except Exception: + pass + state.pop("h5_file", None) + state.pop("_h5_pid", None) return state def __setstate__(self, state): @@ -831,8 +1484,11 @@ def __setstate__(self, state): """ # TODO-Abhi: remove this before release self.__dict__.update(state) - # This ensures that after we unpickle, we have a valid h5_file handle again - self.h5_file = h5py.File(self.h5_path, "r") + if not hasattr(self, "h5_open_kwargs"): + self.h5_open_kwargs = self._normalize_h5_open_kwargs(None) + self.h5_file = None + self._h5_pid = None + self._open_h5_file() self.metadata_cache = GlobalH5MetadataCache().get_cache( str(self.h5_path), self.pert_col, @@ -840,6 +1496,8 @@ def __setstate__(self, state): self.control_pert, self.batch_col, ) + if not hasattr(self, "_index_to_split_code"): + self._init_split_index_cache() def _load_cell_barcodes(self) -> np.ndarray: """ diff --git a/src/cell_load/mapping_strategies/batch.py b/src/cell_load/mapping_strategies/batch.py index 153592b..450e0e1 100644 --- a/src/cell_load/mapping_strategies/batch.py +++ b/src/cell_load/mapping_strategies/batch.py @@ -20,8 +20,16 @@ class BatchMappingStrategy(BaseMappingStrategy): by the tuple (batch, cell_type) instead of just by cell type. 
""" - def __init__(self, name="batch", random_state=42, n_basal_samples=1, **kwargs): + def __init__( + self, + name="batch", + random_state=42, + n_basal_samples=1, + use_consecutive_loading=False, + **kwargs, + ): super().__init__(name, random_state, n_basal_samples, **kwargs) + self.use_consecutive_loading = use_consecutive_loading # For each split, store a mapping: {(batch, cell_type): [ctrl_indices]} self.split_control_maps = { "train": {}, @@ -29,6 +37,18 @@ def __init__(self, name="batch", random_state=42, n_basal_samples=1, **kwargs): "val": {}, "test": {}, } + # Fixed mapping from perturbed_idx -> list of control indices for consecutive loading. + self.split_control_mapping: dict[str, dict[int, list[int]]] = { + "train": {}, + "train_eval": {}, + "val": {}, + "test": {}, + } + if self.use_consecutive_loading: + logger.info( + "BatchMappingStrategy initialized with use_consecutive_loading=True; " + "control mappings will be assigned in file order." + ) def name(): """Name of the mapping strategy.""" @@ -38,7 +58,7 @@ def register_split_indices( self, dataset: "PerturbationDataset", split: str, - _perturbed_indices: np.ndarray, + perturbed_indices: np.ndarray, control_indices: np.ndarray, ): """ @@ -46,13 +66,66 @@ def register_split_indices( For each control cell, we retrieve both its batch and cell type, using that pair as the key. """ for idx in control_indices: - batch = dataset.get_batch(idx) - cell_type = dataset.get_cell_type(idx) + batch = dataset.get_batch_code(idx) + cell_type = dataset.get_cell_type_code(idx) key = (batch, cell_type) if key not in self.split_control_maps[split]: self.split_control_maps[split][key] = [] self.split_control_maps[split][key].append(idx) + if self.use_consecutive_loading: + self._build_consecutive_mapping( + dataset, split, perturbed_indices, control_indices + ) + + def _build_consecutive_mapping( + self, + dataset: "PerturbationDataset", + split: str, + perturbed_indices: np.ndarray, + control_indices: np.ndarray, + ) -> None: + """ + Build a fixed mapping from each index to control indices using sequential + assignment within each (batch, cell_type) pool, with fallback to cell_type pools. + """ + all_indices = np.concatenate([perturbed_indices, control_indices]) + + # Fallback pools by cell type (sorted for deterministic order). 
+ fallback_pools: dict[int, list[int]] = {} + for (batch, cell_type), indices in self.split_control_maps[split].items(): + fallback_pools.setdefault(cell_type, []).extend(indices) + for cell_type, pool in fallback_pools.items(): + fallback_pools[cell_type] = sorted(pool) + + key_offsets: dict[tuple[int, int], int] = {} + fallback_offsets: dict[int, int] = {} + + for idx in all_indices: + batch = dataset.get_batch_code(idx) + cell_type = dataset.get_cell_type_code(idx) + key = (batch, cell_type) + pool = self.split_control_maps[split].get(key, []) + + if not pool: + pool = fallback_pools.get(cell_type, []) + if not pool: + self.split_control_mapping[split][idx] = [] + continue + offset = fallback_offsets.get(cell_type, 0) + control_idxs = [ + pool[(offset + i) % len(pool)] for i in range(self.n_basal_samples) + ] + fallback_offsets[cell_type] = offset + self.n_basal_samples + else: + offset = key_offsets.get(key, 0) + control_idxs = [ + pool[(offset + i) % len(pool)] for i in range(self.n_basal_samples) + ] + key_offsets[key] = offset + self.n_basal_samples + + self.split_control_mapping[split][idx] = control_idxs + def get_control_indices( self, dataset: "PerturbationDataset", split: str, perturbed_idx: int ) -> np.ndarray: @@ -63,8 +136,18 @@ def get_control_indices( If the batch group for the perturbed cell is empty, the method falls back to using all control cells from the same cell type (regardless of batch). """ - batch = dataset.get_batch(perturbed_idx) - cell_type = dataset.get_cell_type(perturbed_idx) + if self.use_consecutive_loading: + control_idxs = self.split_control_mapping[split].get(perturbed_idx, []) + if not control_idxs: + raise ValueError( + "No control cells found in BatchMappingStrategy for cell type '{}'".format( + dataset.get_cell_type(perturbed_idx) + ) + ) + return np.array(control_idxs) + + batch = dataset.get_batch_code(perturbed_idx) + cell_type = dataset.get_cell_type_code(perturbed_idx) key = (batch, cell_type) pool = self.split_control_maps[split].get(key, []) @@ -78,7 +161,7 @@ def get_control_indices( if not pool: raise ValueError( "No control cells found in BatchMappingStrategy for cell type '{}'".format( - cell_type + dataset.get_cell_type(perturbed_idx) ) ) @@ -92,8 +175,14 @@ def get_control_index( This method first attempts to select from controls in the same batch and cell type. If no controls are present in the same batch, it falls back to all controls from the same cell type. """ - batch = dataset.get_batch(perturbed_idx) - cell_type = dataset.get_cell_type(perturbed_idx) + if self.use_consecutive_loading: + control_idxs = self.split_control_mapping[split].get(perturbed_idx, []) + if not control_idxs: + return None + return control_idxs[0] + + batch = dataset.get_batch_code(perturbed_idx) + cell_type = dataset.get_cell_type_code(perturbed_idx) key = (batch, cell_type) pool = self.split_control_maps[split].get(key, []) diff --git a/src/cell_load/mapping_strategies/mapping_strategies.py b/src/cell_load/mapping_strategies/mapping_strategies.py index 939f6a6..14660b0 100644 --- a/src/cell_load/mapping_strategies/mapping_strategies.py +++ b/src/cell_load/mapping_strategies/mapping_strategies.py @@ -52,6 +52,8 @@ def __setstate__(self, state): logger.info( f"Adding missing 'map_controls' attribute to {self.name} mapping strategy." 
) + if not hasattr(self, "use_consecutive_loading"): + self.use_consecutive_loading = False @abstractmethod def register_split_indices( diff --git a/src/cell_load/mapping_strategies/random.py b/src/cell_load/mapping_strategies/random.py index cf879ae..ebea3a8 100644 --- a/src/cell_load/mapping_strategies/random.py +++ b/src/cell_load/mapping_strategies/random.py @@ -31,11 +31,13 @@ def __init__( random_state=42, n_basal_samples=1, cache_perturbation_control_pairs=False, + use_consecutive_loading=False, **kwargs, ): super().__init__(name, random_state, n_basal_samples, **kwargs) self.cache_perturbation_control_pairs = cache_perturbation_control_pairs + self.use_consecutive_loading = use_consecutive_loading if self.cache_perturbation_control_pairs: logger.info( @@ -44,6 +46,17 @@ def __init__( logger.info( f"Warning: If using n_basal_samples > 1, use the original behavior by setting cache_perturbation_control_pairs=False" ) + if self.use_consecutive_loading: + if self.cache_perturbation_control_pairs: + logger.info( + "RandomMappingStrategy initialized with use_consecutive_loading=True; " + "control mappings will be assigned in file order." + ) + else: + logger.info( + "RandomMappingStrategy initialized with use_consecutive_loading=True; " + "control cells will be sampled as consecutive blocks with random offsets." + ) # Map cell type -> list of control indices. self.split_control_pool = { @@ -85,33 +98,34 @@ def register_split_indices( """ all_indices = np.concatenate([perturbed_indices, control_indices]) - # Get cell types for all control indices - cell_types = dataset.get_all_cell_types(control_indices) + # Get cell type codes for all control indices + cell_types = dataset.get_all_cell_type_codes(control_indices) # Group by cell type and store the control indices - for ct in np.unique(cell_types): - ct_mask = cell_types == ct + for ct_code in np.unique(cell_types): + ct_mask = cell_types == ct_code ct_indices = control_indices[ct_mask] - if ct not in self.split_control_pool[split]: - self.split_control_pool[split][ct] = list(ct_indices) + if ct_code not in self.split_control_pool[split]: + self.split_control_pool[split][ct_code] = list(ct_indices) else: - self.split_control_pool[split][ct].extend(ct_indices) + self.split_control_pool[split][ct_code].extend(ct_indices) - if self.cache_perturbation_control_pairs: + build_mapping = self.cache_perturbation_control_pairs + if build_mapping: logger.info( f"Creating cached perturbation-control mapping for split '{split}' with {len(perturbed_indices)} perturbed cells and {len(control_indices)} control cells" ) # Create a fixed mapping from perturbed_idx -> list of control indices # Only if caching is enabled - if self.cache_perturbation_control_pairs: + if build_mapping: pert_groups = {} # Group perturbed indices by cell type and perturbation name for pert_idx in all_indices: - pert_cell_type = dataset.get_cell_type(pert_idx) - pert_name = dataset.get_perturbation_name(pert_idx) + pert_cell_type = dataset.get_cell_type_code(pert_idx) + pert_name = dataset.get_perturbation_code(pert_idx) key = (pert_cell_type, pert_name) if key not in pert_groups: @@ -129,20 +143,29 @@ def register_split_indices( self.split_control_mapping[split][pert_idx] = [] continue - # Shuffle control pool for random assignment - shuffled_pool = pool.copy() - self.rng.shuffle(shuffled_pool) - # Calculate total assignments needed for this cell type / perturbation total_assignments_needed = len(pert_idxs_list) * self.n_basal_samples - - # Ensure we have enough controls for all 
assignments - assert len(shuffled_pool) >= total_assignments_needed, ( - f"Need {total_assignments_needed} controls for {cell_type} / {pert_name} but only have {len(shuffled_pool)}" - ) - - # Assign control cells without replacement to this cell type / perturbation - control_assignments = shuffled_pool[:total_assignments_needed] + if self.use_consecutive_loading: + pool_arr = np.asarray(pool, dtype=np.int64) + if pool_arr.size == 0: + control_assignments = [] + else: + repeats = int(np.ceil(total_assignments_needed / pool_arr.size)) + control_assignments = np.tile(pool_arr, repeats)[ + :total_assignments_needed + ].tolist() + else: + # Shuffle control pool for random assignment + shuffled_pool = pool.copy() + self.rng.shuffle(shuffled_pool) + + # Ensure we have enough controls for all assignments + assert len(shuffled_pool) >= total_assignments_needed, ( + f"Need {total_assignments_needed} controls for {cell_type} / {pert_name} but only have {len(shuffled_pool)}" + ) + + # Assign control cells without replacement to this cell type / perturbation + control_assignments = shuffled_pool[:total_assignments_needed] # Assign control cells to each perturbed cell for i, pert_idx in enumerate(pert_idxs_list): @@ -163,7 +186,8 @@ def get_control_indices( Returns n_basal_samples control indices that are from the same cell type as the perturbed cell. If cache_perturbation_control_pairs is True, uses the pre-computed mapping. - If False, samples new control cells each time (original behavior). + If use_consecutive_loading is True, samples a consecutive block with a random offset. + Otherwise, samples new control cells each time (original behavior). """ if self.cache_perturbation_control_pairs: @@ -174,13 +198,21 @@ def get_control_indices( f"No control cells found in RandomMappingStrategy for cell type '{dataset.get_cell_type(perturbed_idx)}'" ) return np.array(control_idxs) + if self.use_consecutive_loading: + pert_cell_type = dataset.get_cell_type_code(perturbed_idx) + pool = self.split_control_pool[split].get(pert_cell_type, None) + if not pool: + raise ValueError( + f"No control cells found in RandomMappingStrategy for cell type '{dataset.get_cell_type(perturbed_idx)}'" + ) + return self._sample_consecutive_controls(pool, self.n_basal_samples) else: # Sample new control cells each time (original behavior) - pert_cell_type = dataset.get_cell_type(perturbed_idx) + pert_cell_type = dataset.get_cell_type_code(perturbed_idx) pool = self.split_control_pool[split].get(pert_cell_type, None) if not pool: raise ValueError( - f"No control cells found in RandomMappingStrategy for cell type '{pert_cell_type}'" + f"No control cells found in RandomMappingStrategy for cell type '{dataset.get_cell_type(perturbed_idx)}'" ) control_idxs = self.rng.choices(pool, k=self.n_basal_samples) return np.array(control_idxs) @@ -192,7 +224,8 @@ def get_control_index( Returns a single control index from the same cell type as the perturbed cell. If cache_perturbation_control_pairs is True, uses the pre-computed mapping. - If False, samples a new control cell each time (original behavior). + If use_consecutive_loading is True, samples a consecutive control with a random offset. + Otherwise, samples a new control cell each time (original behavior). 
""" if self.cache_perturbation_control_pairs: @@ -201,10 +234,37 @@ def get_control_index( if len(control_idxs) == 0: return None return control_idxs[0] + if self.use_consecutive_loading: + pert_cell_type = dataset.get_cell_type_code(perturbed_idx) + pool = self.split_control_pool[split].get(pert_cell_type, None) + if not pool: + return None + return self._sample_consecutive_controls(pool, 1)[0] else: # Sample new control cell each time (original behavior) - pert_cell_type = dataset.get_cell_type(perturbed_idx) + pert_cell_type = dataset.get_cell_type_code(perturbed_idx) pool = self.split_control_pool[split].get(pert_cell_type, None) if not pool: return None return self.rng.choice(pool) + + def _sample_consecutive_controls( + self, pool: list[int], n_samples: int + ) -> np.ndarray: + """Return n_samples consecutive control indices with a random start offset.""" + pool_size = len(pool) + if pool_size == 0 or n_samples <= 0: + return np.array([], dtype=np.int64) + if pool_size == 1: + return np.array([pool[0]] * n_samples, dtype=np.int64) + + start = self.rng.randrange(pool_size) + if n_samples == 1: + return np.array([pool[start]], dtype=np.int64) + + if start + n_samples <= pool_size: + return np.array(pool[start : start + n_samples], dtype=np.int64) + + tail = pool[start:] + head = pool[: n_samples - len(tail)] + return np.array(tail + head, dtype=np.int64) diff --git a/src/cell_load/utils/data_utils.py b/src/cell_load/utils/data_utils.py index 25a3d83..ead8749 100644 --- a/src/cell_load/utils/data_utils.py +++ b/src/cell_load/utils/data_utils.py @@ -136,6 +136,11 @@ def generate_onehot_map(keys) -> dict: """ Build a map from each unique key to a fixed-length one-hot torch vector. + Note: + We clone each row from the identity matrix so every tensor owns compact + storage. This avoids pathological file sizes when maps are serialized + with pickle (shared-storage tensor views can serialize very poorly). + Args: keys: iterable of hashable items Returns: @@ -145,7 +150,7 @@ def generate_onehot_map(keys) -> dict: num_classes = len(unique_keys) # identity matrix rows are one-hot vectors onehots = torch.eye(num_classes) - return {k: onehots[i] for i, k in enumerate(unique_keys)} + return {k: onehots[i].clone() for i, k in enumerate(unique_keys)} def data_to_torch_X(X):