From 1ac3c961bf6c1b7f0c152579b9ccdb2cc7099beb Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Mon, 26 Jan 2026 19:42:34 +0000 Subject: [PATCH 01/13] added option for use_consecutive_dataloader to speedup IO --- pyproject.toml | 2 +- .../data_modules/perturbation_dataloader.py | 7 ++ src/cell_load/data_modules/samplers.py | 67 ++++++++++++- src/cell_load/mapping_strategies/batch.py | 93 ++++++++++++++++++- .../mapping_strategies/mapping_strategies.py | 2 + src/cell_load/mapping_strategies/random.py | 51 ++++++---- 6 files changed, 199 insertions(+), 23 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 93cd94c..71533dd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "cell-load" -version = "0.8.8" +version = "0.9.0" description = "Dataloaders for training models on huge single-cell datasets" readme = "README.md" authors = [ diff --git a/src/cell_load/data_modules/perturbation_dataloader.py b/src/cell_load/data_modules/perturbation_dataloader.py index c787272..17d56b7 100644 --- a/src/cell_load/data_modules/perturbation_dataloader.py +++ b/src/cell_load/data_modules/perturbation_dataloader.py @@ -54,6 +54,7 @@ def __init__( cache_perturbation_control_pairs: bool = False, drop_last: bool = False, additional_obs: list[str] | None = None, + use_consecutive_loading: bool = False, **kwargs, # missing perturbation_features_file and store_raw_basal for backwards compatibility ): """ @@ -75,6 +76,7 @@ def __init__( n_basal_samples: Number of control cells to sample per perturbed cell cache_perturbation_control_pairs: If True cache perturbation-control pairs at the start of training and reuse them. 
drop_last: Whether to drop the last sentence set if it is smaller than cell_sentence_len + use_consecutive_loading: Whether to form cell sets from consecutive indices for faster IO """ super().__init__() @@ -89,6 +91,7 @@ def __init__( self.random_seed = random_seed self.rng = np.random.default_rng(random_seed) self.drop_last = drop_last + self.use_consecutive_loading = use_consecutive_loading # H5 field names self.pert_col = pert_col @@ -215,6 +218,7 @@ def save_state(self, filepath: str): "store_raw_basal": self.store_raw_basal, "barcode": self.barcode, "additional_obs": self.additional_obs, + "use_consecutive_loading": self.use_consecutive_loading, } torch.save(save_dict, filepath) @@ -257,6 +261,7 @@ def load_state(cls, filepath: str): "normalize_counts": save_dict.pop("normalize_counts", False), "store_raw_basal": save_dict.pop("store_raw_basal", False), "barcode": save_dict.pop("barcode", True), + "use_consecutive_loading": save_dict.pop("use_consecutive_loading", False), } # Create new instance with all the saved parameters @@ -394,6 +399,7 @@ def _create_dataloader( cell_sentence_len=self.cell_sentence_len, test=test, use_batch=use_batch, + use_consecutive_loading=self.use_consecutive_loading, ) return DataLoader( @@ -502,6 +508,7 @@ def _create_base_dataset( mapping_kwargs["cache_perturbation_control_pairs"] = ( self.cache_perturbation_control_pairs ) + mapping_kwargs["use_consecutive_loading"] = self.use_consecutive_loading return PerturbationDataset( name=dataset_name, diff --git a/src/cell_load/data_modules/samplers.py b/src/cell_load/data_modules/samplers.py index 9ebb767..d4eb413 100644 --- a/src/cell_load/data_modules/samplers.py +++ b/src/cell_load/data_modules/samplers.py @@ -33,6 +33,7 @@ def __init__( cell_sentence_len: int = 512, test: bool = False, use_batch: bool = False, + use_consecutive_loading: bool = False, seed: int = 0, epoch: int = 0, ): @@ -48,6 +49,7 @@ def __init__( self.batch_size = batch_size self.test = test self.use_batch = 
use_batch + self.use_consecutive_loading = use_consecutive_loading self.seed = seed self.epoch = epoch @@ -124,9 +126,15 @@ def _create_batches(self) -> list[list[int]]: # If batch is smaller than cell_sentence_len, sample with replacement if len(sentence) < self.cell_sentence_len and not self.test: # during inference, don't sample by replacement - new_sentence = np.random.choice( - sentence, size=self.cell_sentence_len, replace=True - ).tolist() + if self.use_consecutive_loading: + repeats = int( + np.ceil(self.cell_sentence_len / max(len(sentence), 1)) + ) + new_sentence = (sentence * repeats)[: self.cell_sentence_len] + else: + new_sentence = np.random.choice( + sentence, size=self.cell_sentence_len, replace=True + ).tolist() num_partial += 1 else: new_sentence = copy.deepcopy(sentence) @@ -247,6 +255,54 @@ def _process_subset(self, global_offset: int, subset: Subset) -> list[list[int]] return subset_batches + def _process_subset_consecutive( + self, global_offset: int, subset: Subset + ) -> list[list[int]]: + """ + Process a single subset to create consecutive sentences based on H5 codes. + + This assumes the input indices are already in file order and splits + sentences at code-change boundaries without shuffling. 
+ """ + base_dataset = subset.dataset + indices = np.array(subset.indices) + if indices.size == 0: + return [] + + cache: H5MetadataCache = self.metadata_caches[base_dataset.h5_path] + + # Codes in file order + cell_codes = cache.cell_type_codes[indices] + pert_codes = cache.pert_codes[indices] + if getattr(self, "use_batch", False): + batch_codes = cache.batch_codes[indices] + code_change = ( + (batch_codes[1:] != batch_codes[:-1]) + | (cell_codes[1:] != cell_codes[:-1]) + | (pert_codes[1:] != pert_codes[:-1]) + ) + else: + code_change = (cell_codes[1:] != cell_codes[:-1]) | ( + pert_codes[1:] != pert_codes[:-1] + ) + + # Global indices in dataset order + global_indices = np.arange(global_offset, global_offset + len(indices)) + + # Split into contiguous segments when codes change + boundaries = np.where(code_change)[0] + 1 + segments = np.split(global_indices, boundaries) + + subset_batches = [] + for segment in segments: + for i in range(0, len(segment), self.cell_sentence_len): + sentence = segment[i : i + self.cell_sentence_len] + if len(sentence) < self.cell_sentence_len and self.drop_last: + continue + subset_batches.append(sentence.tolist()) + + return subset_batches + def _create_sentences(self) -> list[list[int]]: """ Process each subset sequentially (across all datasets) and combine the batches. 
@@ -254,7 +310,10 @@ def _create_sentences(self) -> list[list[int]]: global_offset = 0 all_batches = [] for subset in self.dataset.datasets: - subset_batches = self._process_subset(global_offset, subset) + if self.use_consecutive_loading: + subset_batches = self._process_subset_consecutive(global_offset, subset) + else: + subset_batches = self._process_subset(global_offset, subset) all_batches.extend(subset_batches) global_offset += len(subset) np.random.shuffle(all_batches) diff --git a/src/cell_load/mapping_strategies/batch.py b/src/cell_load/mapping_strategies/batch.py index 153592b..2b84fda 100644 --- a/src/cell_load/mapping_strategies/batch.py +++ b/src/cell_load/mapping_strategies/batch.py @@ -20,8 +20,16 @@ class BatchMappingStrategy(BaseMappingStrategy): by the tuple (batch, cell_type) instead of just by cell type. """ - def __init__(self, name="batch", random_state=42, n_basal_samples=1, **kwargs): + def __init__( + self, + name="batch", + random_state=42, + n_basal_samples=1, + use_consecutive_loading=False, + **kwargs, + ): super().__init__(name, random_state, n_basal_samples, **kwargs) + self.use_consecutive_loading = use_consecutive_loading # For each split, store a mapping: {(batch, cell_type): [ctrl_indices]} self.split_control_maps = { "train": {}, @@ -29,6 +37,18 @@ def __init__(self, name="batch", random_state=42, n_basal_samples=1, **kwargs): "val": {}, "test": {}, } + # Fixed mapping from perturbed_idx -> list of control indices for consecutive loading. + self.split_control_mapping: dict[str, dict[int, list[int]]] = { + "train": {}, + "train_eval": {}, + "val": {}, + "test": {}, + } + if self.use_consecutive_loading: + logger.info( + "BatchMappingStrategy initialized with use_consecutive_loading=True; " + "control mappings will be assigned in file order." 
+ ) def name(): """Name of the mapping strategy.""" @@ -38,7 +58,7 @@ def register_split_indices( self, dataset: "PerturbationDataset", split: str, - _perturbed_indices: np.ndarray, + perturbed_indices: np.ndarray, control_indices: np.ndarray, ): """ @@ -53,6 +73,59 @@ def register_split_indices( self.split_control_maps[split][key] = [] self.split_control_maps[split][key].append(idx) + if self.use_consecutive_loading: + self._build_consecutive_mapping( + dataset, split, perturbed_indices, control_indices + ) + + def _build_consecutive_mapping( + self, + dataset: "PerturbationDataset", + split: str, + perturbed_indices: np.ndarray, + control_indices: np.ndarray, + ) -> None: + """ + Build a fixed mapping from each index to control indices using sequential + assignment within each (batch, cell_type) pool, with fallback to cell_type pools. + """ + all_indices = np.concatenate([perturbed_indices, control_indices]) + + # Fallback pools by cell type (sorted for deterministic order). + fallback_pools: dict[str, list[int]] = {} + for (batch, cell_type), indices in self.split_control_maps[split].items(): + fallback_pools.setdefault(cell_type, []).extend(indices) + for cell_type, pool in fallback_pools.items(): + fallback_pools[cell_type] = sorted(pool) + + key_offsets: dict[tuple[int, str], int] = {} + fallback_offsets: dict[str, int] = {} + + for idx in all_indices: + batch = dataset.get_batch(idx) + cell_type = dataset.get_cell_type(idx) + key = (batch, cell_type) + pool = self.split_control_maps[split].get(key, []) + + if not pool: + pool = fallback_pools.get(cell_type, []) + if not pool: + self.split_control_mapping[split][idx] = [] + continue + offset = fallback_offsets.get(cell_type, 0) + control_idxs = [ + pool[(offset + i) % len(pool)] for i in range(self.n_basal_samples) + ] + fallback_offsets[cell_type] = offset + self.n_basal_samples + else: + offset = key_offsets.get(key, 0) + control_idxs = [ + pool[(offset + i) % len(pool)] for i in range(self.n_basal_samples) 
+ ] + key_offsets[key] = offset + self.n_basal_samples + + self.split_control_mapping[split][idx] = control_idxs + def get_control_indices( self, dataset: "PerturbationDataset", split: str, perturbed_idx: int ) -> np.ndarray: @@ -63,6 +136,16 @@ def get_control_indices( If the batch group for the perturbed cell is empty, the method falls back to using all control cells from the same cell type (regardless of batch). """ + if self.use_consecutive_loading: + control_idxs = self.split_control_mapping[split].get(perturbed_idx, []) + if not control_idxs: + raise ValueError( + "No control cells found in BatchMappingStrategy for cell type '{}'".format( + dataset.get_cell_type(perturbed_idx) + ) + ) + return np.array(control_idxs) + batch = dataset.get_batch(perturbed_idx) cell_type = dataset.get_cell_type(perturbed_idx) key = (batch, cell_type) @@ -92,6 +175,12 @@ def get_control_index( This method first attempts to select from controls in the same batch and cell type. If no controls are present in the same batch, it falls back to all controls from the same cell type. """ + if self.use_consecutive_loading: + control_idxs = self.split_control_mapping[split].get(perturbed_idx, []) + if not control_idxs: + return None + return control_idxs[0] + batch = dataset.get_batch(perturbed_idx) cell_type = dataset.get_cell_type(perturbed_idx) key = (batch, cell_type) diff --git a/src/cell_load/mapping_strategies/mapping_strategies.py b/src/cell_load/mapping_strategies/mapping_strategies.py index 939f6a6..14660b0 100644 --- a/src/cell_load/mapping_strategies/mapping_strategies.py +++ b/src/cell_load/mapping_strategies/mapping_strategies.py @@ -52,6 +52,8 @@ def __setstate__(self, state): logger.info( f"Adding missing 'map_controls' attribute to {self.name} mapping strategy." 
) + if not hasattr(self, "use_consecutive_loading"): + self.use_consecutive_loading = False @abstractmethod def register_split_indices( diff --git a/src/cell_load/mapping_strategies/random.py b/src/cell_load/mapping_strategies/random.py index cf879ae..2d18024 100644 --- a/src/cell_load/mapping_strategies/random.py +++ b/src/cell_load/mapping_strategies/random.py @@ -31,11 +31,13 @@ def __init__( random_state=42, n_basal_samples=1, cache_perturbation_control_pairs=False, + use_consecutive_loading=False, **kwargs, ): super().__init__(name, random_state, n_basal_samples, **kwargs) self.cache_perturbation_control_pairs = cache_perturbation_control_pairs + self.use_consecutive_loading = use_consecutive_loading if self.cache_perturbation_control_pairs: logger.info( @@ -44,6 +46,11 @@ def __init__( logger.info( f"Warning: If using n_basal_samples > 1, use the original behavior by setting cache_perturbation_control_pairs=False" ) + if self.use_consecutive_loading: + logger.info( + "RandomMappingStrategy initialized with use_consecutive_loading=True; " + "control mappings will be assigned in file order." + ) # Map cell type -> list of control indices. 
self.split_control_pool = { @@ -98,14 +105,15 @@ def register_split_indices( else: self.split_control_pool[split][ct].extend(ct_indices) - if self.cache_perturbation_control_pairs: + build_mapping = self.cache_perturbation_control_pairs or self.use_consecutive_loading + if build_mapping: logger.info( f"Creating cached perturbation-control mapping for split '{split}' with {len(perturbed_indices)} perturbed cells and {len(control_indices)} control cells" ) # Create a fixed mapping from perturbed_idx -> list of control indices # Only if caching is enabled - if self.cache_perturbation_control_pairs: + if build_mapping: pert_groups = {} # Group perturbed indices by cell type and perturbation name @@ -129,20 +137,31 @@ def register_split_indices( self.split_control_mapping[split][pert_idx] = [] continue - # Shuffle control pool for random assignment - shuffled_pool = pool.copy() - self.rng.shuffle(shuffled_pool) - # Calculate total assignments needed for this cell type / perturbation total_assignments_needed = len(pert_idxs_list) * self.n_basal_samples - - # Ensure we have enough controls for all assignments - assert len(shuffled_pool) >= total_assignments_needed, ( - f"Need {total_assignments_needed} controls for {cell_type} / {pert_name} but only have {len(shuffled_pool)}" - ) - - # Assign control cells without replacement to this cell type / perturbation - control_assignments = shuffled_pool[:total_assignments_needed] + if self.use_consecutive_loading: + pool_arr = np.asarray(pool, dtype=np.int64) + if pool_arr.size == 0: + control_assignments = [] + else: + repeats = int( + np.ceil(total_assignments_needed / pool_arr.size) + ) + control_assignments = np.tile(pool_arr, repeats)[ + :total_assignments_needed + ].tolist() + else: + # Shuffle control pool for random assignment + shuffled_pool = pool.copy() + self.rng.shuffle(shuffled_pool) + + # Ensure we have enough controls for all assignments + assert len(shuffled_pool) >= total_assignments_needed, ( + f"Need 
{total_assignments_needed} controls for {cell_type} / {pert_name} but only have {len(shuffled_pool)}" + ) + + # Assign control cells without replacement to this cell type / perturbation + control_assignments = shuffled_pool[:total_assignments_needed] # Assign control cells to each perturbed cell for i, pert_idx in enumerate(pert_idxs_list): @@ -166,7 +185,7 @@ def get_control_indices( If False, samples new control cells each time (original behavior). """ - if self.cache_perturbation_control_pairs: + if self.cache_perturbation_control_pairs or self.use_consecutive_loading: # Use cached mapping control_idxs = self.split_control_mapping[split][perturbed_idx] if len(control_idxs) == 0: @@ -195,7 +214,7 @@ def get_control_index( If False, samples a new control cell each time (original behavior). """ - if self.cache_perturbation_control_pairs: + if self.cache_perturbation_control_pairs or self.use_consecutive_loading: # Use cached mapping control_idxs = self.split_control_mapping[split][perturbed_idx] if len(control_idxs) == 0: From c4845c1b9c991a9f6c284f234966072ed3a7d953 Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Tue, 3 Feb 2026 16:48:17 +0000 Subject: [PATCH 02/13] working fast data loader 3x for some accuracy drop --- pyproject.toml | 2 +- src/cell_load/mapping_strategies/random.py | 61 ++++++++++++++++++---- 2 files changed, 53 insertions(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 71533dd..5dc8eaa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "cell-load" -version = "0.9.0" +version = "0.9.1" description = "Dataloaders for training models on huge single-cell datasets" readme = "README.md" authors = [ diff --git a/src/cell_load/mapping_strategies/random.py b/src/cell_load/mapping_strategies/random.py index 2d18024..27dde8b 100644 --- a/src/cell_load/mapping_strategies/random.py +++ b/src/cell_load/mapping_strategies/random.py @@ -47,10 +47,16 @@ def __init__( f"Warning: If using n_basal_samples > 
1, use the original behavior by setting cache_perturbation_control_pairs=False" ) if self.use_consecutive_loading: - logger.info( - "RandomMappingStrategy initialized with use_consecutive_loading=True; " - "control mappings will be assigned in file order." - ) + if self.cache_perturbation_control_pairs: + logger.info( + "RandomMappingStrategy initialized with use_consecutive_loading=True; " + "control mappings will be assigned in file order." + ) + else: + logger.info( + "RandomMappingStrategy initialized with use_consecutive_loading=True; " + "control cells will be sampled as consecutive blocks with random offsets." + ) # Map cell type -> list of control indices. self.split_control_pool = { @@ -105,7 +111,7 @@ def register_split_indices( else: self.split_control_pool[split][ct].extend(ct_indices) - build_mapping = self.cache_perturbation_control_pairs or self.use_consecutive_loading + build_mapping = self.cache_perturbation_control_pairs if build_mapping: logger.info( f"Creating cached perturbation-control mapping for split '{split}' with {len(perturbed_indices)} perturbed cells and {len(control_indices)} control cells" @@ -182,10 +188,11 @@ def get_control_indices( Returns n_basal_samples control indices that are from the same cell type as the perturbed cell. If cache_perturbation_control_pairs is True, uses the pre-computed mapping. - If False, samples new control cells each time (original behavior). + If use_consecutive_loading is True, samples a consecutive block with a random offset. + Otherwise, samples new control cells each time (original behavior). 
""" - if self.cache_perturbation_control_pairs or self.use_consecutive_loading: + if self.cache_perturbation_control_pairs: # Use cached mapping control_idxs = self.split_control_mapping[split][perturbed_idx] if len(control_idxs) == 0: @@ -193,6 +200,14 @@ def get_control_indices( f"No control cells found in RandomMappingStrategy for cell type '{dataset.get_cell_type(perturbed_idx)}'" ) return np.array(control_idxs) + if self.use_consecutive_loading: + pert_cell_type = dataset.get_cell_type(perturbed_idx) + pool = self.split_control_pool[split].get(pert_cell_type, None) + if not pool: + raise ValueError( + f"No control cells found in RandomMappingStrategy for cell type '{pert_cell_type}'" + ) + return self._sample_consecutive_controls(pool, self.n_basal_samples) else: # Sample new control cells each time (original behavior) pert_cell_type = dataset.get_cell_type(perturbed_idx) @@ -211,15 +226,22 @@ def get_control_index( Returns a single control index from the same cell type as the perturbed cell. If cache_perturbation_control_pairs is True, uses the pre-computed mapping. - If False, samples a new control cell each time (original behavior). + If use_consecutive_loading is True, samples a consecutive control with a random offset. + Otherwise, samples a new control cell each time (original behavior). 
""" - if self.cache_perturbation_control_pairs or self.use_consecutive_loading: + if self.cache_perturbation_control_pairs: # Use cached mapping control_idxs = self.split_control_mapping[split][perturbed_idx] if len(control_idxs) == 0: return None return control_idxs[0] + if self.use_consecutive_loading: + pert_cell_type = dataset.get_cell_type(perturbed_idx) + pool = self.split_control_pool[split].get(pert_cell_type, None) + if not pool: + return None + return self._sample_consecutive_controls(pool, 1)[0] else: # Sample new control cell each time (original behavior) pert_cell_type = dataset.get_cell_type(perturbed_idx) @@ -227,3 +249,24 @@ def get_control_index( if not pool: return None return self.rng.choice(pool) + + def _sample_consecutive_controls( + self, pool: list[int], n_samples: int + ) -> np.ndarray: + """Return n_samples consecutive control indices with a random start offset.""" + pool_size = len(pool) + if pool_size == 0 or n_samples <= 0: + return np.array([], dtype=np.int64) + if pool_size == 1: + return np.array([pool[0]] * n_samples, dtype=np.int64) + + start = self.rng.randrange(pool_size) + if n_samples == 1: + return np.array([pool[start]], dtype=np.int64) + + if start + n_samples <= pool_size: + return np.array(pool[start : start + n_samples], dtype=np.int64) + + tail = pool[start:] + head = pool[: n_samples - len(tail)] + return np.array(tail + head, dtype=np.int64) From 85eb3f38cced43ddc3a118c985c982cc242283a5 Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Tue, 3 Feb 2026 21:13:26 +0000 Subject: [PATCH 03/13] two hours per epoch, batched densify, fixed controls --- pyproject.toml | 2 +- src/cell_load/config.py | 2 - .../data_modules/perturbation_dataloader.py | 38 +++- src/cell_load/dataset/_metadata.py | 39 +++- src/cell_load/dataset/_perturbation.py | 192 ++++++++++++++++++ src/cell_load/mapping_strategies/random.py | 99 +++------ 6 files changed, 288 insertions(+), 84 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 
5dc8eaa..8a1ba22 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "cell-load" -version = "0.9.1" +version = "0.9.3" description = "Dataloaders for training models on huge single-cell datasets" readme = "README.md" authors = [ diff --git a/src/cell_load/config.py b/src/cell_load/config.py index 6dad97a..ddc430e 100644 --- a/src/cell_load/config.py +++ b/src/cell_load/config.py @@ -70,7 +70,6 @@ def get_fewshot_celltypes(self, dataset: str) -> dict[str, dict[str, list[str]]] if key.startswith(f"{dataset}."): celltype = key.split(".", 1)[1] result[celltype] = pert_config - print(dataset, celltype, {k: len(v) for k, v in pert_config.items()}) return result def validate(self) -> None: @@ -84,7 +83,6 @@ def validate(self) -> None: # Check that dataset paths exist for dataset, path in self.datasets.items(): - print(path) if not Path(path).exists(): logger.warning(f"Dataset path does not exist: {path}") diff --git a/src/cell_load/data_modules/perturbation_dataloader.py b/src/cell_load/data_modules/perturbation_dataloader.py index 17d56b7..764658e 100644 --- a/src/cell_load/data_modules/perturbation_dataloader.py +++ b/src/cell_load/data_modules/perturbation_dataloader.py @@ -408,7 +408,8 @@ def _create_dataloader( num_workers=self.num_workers, collate_fn=collate_fn, pin_memory=True, - prefetch_factor=4 if not test and self.num_workers > 0 else None, + prefetch_factor=8 if not test and self.num_workers > 0 else None, + persistent_workers=bool(self.num_workers > 0 and not test), ) def _setup_global_maps(self): @@ -427,7 +428,6 @@ def _setup_global_maps(self): for dataset_name in self.config.get_all_datasets(): dataset_path = Path(self.config.datasets[dataset_name]) files = self._find_dataset_files(dataset_path) - for _fname, fpath in files.items(): with h5py.File(fpath, "r") as f: uns = f.get("uns") @@ -543,9 +543,26 @@ def _setup_datasets(self): Uses H5MetadataCache for faster metadata access. 
""" - for dataset_name in self.config.get_all_datasets(): + dataset_names = list(self.config.get_all_datasets()) + dataset_files: dict[str, dict[str, Path]] = {} + total_files = 0 + for dataset_name in dataset_names: dataset_path = Path(self.config.datasets[dataset_name]) files = self._find_dataset_files(dataset_path) + dataset_files[dataset_name] = files + total_files += len(files) + + pbar = ( + tqdm(total=total_files, desc="Processing datasets", leave=False) + if total_files > 0 + else None + ) + + for dataset_name in dataset_names: + files = dataset_files[dataset_name] + + if pbar is not None: + pbar.set_description(f"Processing {dataset_name}") # Get configuration for this dataset zeroshot_celltypes = self.config.get_zeroshot_celltypes(dataset_name) @@ -558,9 +575,7 @@ def _setup_datasets(self): logger.info(f" - Fewshot cell types: {list(fewshot_celltypes.keys())}") # Process each file in the dataset - for fname, fpath in tqdm( - list(files.items()), desc=f"Processing {dataset_name}" - ): + for fname, fpath in files.items(): # Create metadata cache cache = GlobalH5MetadataCache().get_cache( str(fpath), @@ -607,12 +622,17 @@ def _setup_datasets(self): val_sum += counts["val"] test_sum += counts["test"] - tqdm.write( - f"Processed {fname}: {train_sum} train, {val_sum} val, {test_sum} test" - ) + if pbar is not None: + pbar.update(1) + pbar.set_postfix_str( + f"{train_sum} train, {val_sum} val, {test_sum} test" + ) logger.info("\n") + if pbar is not None: + pbar.close() + def _split_fewshot_celltype( self, ds: PerturbationDataset, diff --git a/src/cell_load/dataset/_metadata.py b/src/cell_load/dataset/_metadata.py index f0736ea..c1a9dd4 100644 --- a/src/cell_load/dataset/_metadata.py +++ b/src/cell_load/dataset/_metadata.py @@ -1,4 +1,6 @@ -from torch.utils.data import ConcatDataset, Dataset +from bisect import bisect_right + +from torch.utils.data import ConcatDataset, Dataset, Subset class MetadataConcatDataset(ConcatDataset): @@ -25,3 +27,38 @@ def 
__init__(self, datasets: list[Dataset]): raise ValueError( "All datasets must share the same embed_key, control_pert, pert_col, and batch_col" ) + + def __getitems__(self, indices): + """ + Batch-aware fetch to enable fast-path loading when supported by datasets. + Falls back to per-item access for non-consecutive loading. + """ + if not getattr(self.base.mapping_strategy, "use_consecutive_loading", False): + return [self[i] for i in indices] + + results = [None] * len(indices) + grouped = {} + + for out_pos, idx in enumerate(indices): + dataset_idx = bisect_right(self.cumulative_sizes, idx) + sample_idx = ( + idx if dataset_idx == 0 else idx - self.cumulative_sizes[dataset_idx - 1] + ) + grouped.setdefault(dataset_idx, []).append((out_pos, sample_idx)) + + for dataset_idx, pos_samples in grouped.items(): + ds = self.datasets[dataset_idx] + positions, sample_indices = zip(*pos_samples) + + if isinstance(ds, Subset) and hasattr(ds.dataset, "__getitems__"): + underlying_indices = [ds.indices[i] for i in sample_indices] + samples = ds.dataset.__getitems__(underlying_indices) + elif hasattr(ds, "__getitems__"): + samples = ds.__getitems__(list(sample_indices)) + else: + samples = [ds[i] for i in sample_indices] + + for pos, sample in zip(positions, samples): + results[pos] = sample + + return results diff --git a/src/cell_load/dataset/_perturbation.py b/src/cell_load/dataset/_perturbation.py index b11e64b..ebc6952 100644 --- a/src/cell_load/dataset/_perturbation.py +++ b/src/cell_load/dataset/_perturbation.py @@ -250,6 +250,114 @@ def __getitem__(self, idx: int): return sample + def __getitems__(self, indices): + """ + Batch-aware fetch for consecutive loading with batched CSR densification. + Falls back to per-item access when not applicable. 
+ """ + if not self._use_batched_fetch(): + return [self.__getitem__(int(i)) for i in indices] + + idx_arr = np.asarray(indices, dtype=np.int64) + if idx_arr.size == 0: + return [] + + file_indices = self.all_indices[idx_arr] + splits = [self._find_split_for_idx(int(i)) for i in file_indices] + + ctrl_indices = [] + for file_idx, split in zip(file_indices, splits): + ctrl_idx = self.mapping_strategy.get_control_index( + self, split, int(file_idx) + ) + if ctrl_idx is None: + raise ValueError( + f"No control cells found for cell type '{self.get_cell_type(file_idx)}'" + ) + ctrl_indices.append(int(ctrl_idx)) + + ctrl_indices_arr = np.asarray(ctrl_indices, dtype=np.int64) + + pert_expr_batch = self._fetch_gene_expression_batch(file_indices) + ctrl_expr_batch = self._fetch_gene_expression_batch(ctrl_indices_arr) + + samples = [] + for i, file_idx in enumerate(file_indices): + pert_expr = pert_expr_batch[i] + ctrl_expr = ctrl_expr_batch[i] + ctrl_idx = ctrl_indices_arr[i] + + pert_code = self.metadata_cache.pert_codes[file_idx] + pert_name = self.pert_categories[pert_code] + pert_onehot = ( + self.pert_onehot_map.get(pert_name) if self.pert_onehot_map else None + ) + + cell_type = self.cell_type_categories[ + self.metadata_cache.cell_type_codes[file_idx] + ] + cell_type_onehot = ( + self.cell_type_onehot_map.get(cell_type) + if self.cell_type_onehot_map + else None + ) + + batch_code = self.metadata_cache.batch_codes[file_idx] + batch_name = self.metadata_cache.batch_categories[batch_code] + batch_onehot = ( + self.batch_onehot_map.get(batch_name) if self.batch_onehot_map else None + ) + + sample = { + "pert_cell_emb": pert_expr, + "ctrl_cell_emb": ctrl_expr, + "pert_emb": pert_onehot, + "pert_name": pert_name, + "batch_name": batch_name, + "batch": batch_onehot, + "cell_type": cell_type, + "cell_type_onehot": cell_type_onehot, + } + + if self.store_raw_expression and self.output_space != "embedding": + if self.output_space == "gene": + sample["pert_cell_counts"] = 
self.fetch_obsm_expression( + file_idx, "X_hvg" + ) + elif self.output_space == "all": + sample["pert_cell_counts"] = pert_expr + + if self.store_raw_basal and self.output_space != "embedding": + if self.output_space == "gene": + sample["ctrl_cell_counts"] = self.fetch_obsm_expression( + ctrl_idx, "X_hvg" + ) + elif self.output_space == "all": + sample["ctrl_cell_counts"] = ctrl_expr + + if self.barcode and self.cell_barcodes is not None: + sample["pert_cell_barcode"] = self.cell_barcodes[file_idx] + sample["ctrl_cell_barcode"] = self.cell_barcodes[ctrl_idx] + + if self.additional_obs: + for obs_key in self.additional_obs: + sample[obs_key] = self._fetch_obs_value(file_idx, obs_key) + + samples.append(sample) + + return samples + + def _use_batched_fetch(self) -> bool: + if not getattr(self.mapping_strategy, "use_consecutive_loading", False): + return False + if self.embed_key is not None: + return False + if self.output_space != "all": + return False + if self.downsample is not None and self.downsample < 1.0: + return False + return True + def _validate_additional_obs(self, additional_obs: list[str] | None) -> list[str]: if additional_obs is None: return [] @@ -483,6 +591,90 @@ def fetch_gene_expression(self, idx: int) -> torch.Tensor: data = self._fetch_gene_expression_raw(idx) return self._maybe_downsample_counts(data) + def _fetch_gene_expression_batch(self, indices: np.ndarray) -> torch.Tensor: + """ + Fetch raw gene counts for multiple indices at once (CSR fast path). 
+ """ + if indices.size == 0: + return torch.empty((0, self.n_genes), dtype=torch.float32) + + attrs = dict(self.h5_file["X"].attrs) + if attrs.get("encoding-type") != "csr_matrix": + row_data = self.h5_file["/X"][indices] + return torch.tensor(row_data, dtype=torch.float32) + + indptr_ds = self.h5_file["/X/indptr"] + data_ds = self.h5_file["/X/data"] + indices_ds = self.h5_file["/X/indices"] + + order = np.argsort(indices) + sorted_rows = indices[order] + dense_sorted = np.zeros((len(sorted_rows), self.n_genes), dtype=np.float32) + + run_start = 0 + for i in range(1, len(sorted_rows)): + if sorted_rows[i] != sorted_rows[i - 1] + 1: + self._fill_dense_run( + sorted_rows, + run_start, + i, + dense_sorted, + indptr_ds, + data_ds, + indices_ds, + ) + run_start = i + + self._fill_dense_run( + sorted_rows, + run_start, + len(sorted_rows), + dense_sorted, + indptr_ds, + data_ds, + indices_ds, + ) + + inv_order = np.empty_like(order) + inv_order[order] = np.arange(len(order)) + dense = dense_sorted[inv_order] + + return torch.from_numpy(dense) + + def _fill_dense_run( + self, + sorted_rows: np.ndarray, + start: int, + end: int, + dense_sorted: np.ndarray, + indptr_ds, + data_ds, + indices_ds, + ) -> None: + if start >= end: + return + + row_start = int(sorted_rows[start]) + row_end = int(sorted_rows[end - 1]) + indptr_slice = indptr_ds[row_start : row_end + 2] + base_ptr = int(indptr_slice[0]) + end_ptr = int(indptr_slice[-1]) + + if end_ptr <= base_ptr: + return + + block_data = np.asarray(data_ds[base_ptr:end_ptr], dtype=np.float32) + block_indices = np.asarray(indices_ds[base_ptr:end_ptr], dtype=np.int64) + + for offset, row in enumerate(sorted_rows[start:end]): + ptr_start = int(indptr_slice[offset] - base_ptr) + ptr_end = int(indptr_slice[offset + 1] - base_ptr) + if ptr_end <= ptr_start: + continue + dense_sorted[start + offset, block_indices[ptr_start:ptr_end]] = ( + block_data[ptr_start:ptr_end] + ) + @lru_cache(maxsize=10000) def fetch_obsm_expression(self, idx: 
int, key: str) -> torch.Tensor: """ diff --git a/src/cell_load/mapping_strategies/random.py b/src/cell_load/mapping_strategies/random.py index 27dde8b..04eacd4 100644 --- a/src/cell_load/mapping_strategies/random.py +++ b/src/cell_load/mapping_strategies/random.py @@ -47,16 +47,10 @@ def __init__( f"Warning: If using n_basal_samples > 1, use the original behavior by setting cache_perturbation_control_pairs=False" ) if self.use_consecutive_loading: - if self.cache_perturbation_control_pairs: - logger.info( - "RandomMappingStrategy initialized with use_consecutive_loading=True; " - "control mappings will be assigned in file order." - ) - else: - logger.info( - "RandomMappingStrategy initialized with use_consecutive_loading=True; " - "control cells will be sampled as consecutive blocks with random offsets." - ) + logger.info( + "RandomMappingStrategy initialized with use_consecutive_loading=True; " + "control mappings will be assigned in file order." + ) # Map cell type -> list of control indices. self.split_control_pool = { @@ -111,7 +105,7 @@ def register_split_indices( else: self.split_control_pool[split][ct].extend(ct_indices) - build_mapping = self.cache_perturbation_control_pairs + build_mapping = self.cache_perturbation_control_pairs or self.use_consecutive_loading if build_mapping: logger.info( f"Creating cached perturbation-control mapping for split '{split}' with {len(perturbed_indices)} perturbed cells and {len(control_indices)} control cells" @@ -187,12 +181,12 @@ def get_control_indices( """ Returns n_basal_samples control indices that are from the same cell type as the perturbed cell. - If cache_perturbation_control_pairs is True, uses the pre-computed mapping. - If use_consecutive_loading is True, samples a consecutive block with a random offset. - Otherwise, samples new control cells each time (original behavior). + If cache_perturbation_control_pairs is True (or use_consecutive_loading is enabled), + uses the pre-computed mapping. 
Otherwise, samples new control cells each time + (original behavior). """ - if self.cache_perturbation_control_pairs: + if self.cache_perturbation_control_pairs or self.use_consecutive_loading: # Use cached mapping control_idxs = self.split_control_mapping[split][perturbed_idx] if len(control_idxs) == 0: @@ -200,24 +194,15 @@ def get_control_indices( f"No control cells found in RandomMappingStrategy for cell type '{dataset.get_cell_type(perturbed_idx)}'" ) return np.array(control_idxs) - if self.use_consecutive_loading: - pert_cell_type = dataset.get_cell_type(perturbed_idx) - pool = self.split_control_pool[split].get(pert_cell_type, None) - if not pool: - raise ValueError( - f"No control cells found in RandomMappingStrategy for cell type '{pert_cell_type}'" - ) - return self._sample_consecutive_controls(pool, self.n_basal_samples) - else: - # Sample new control cells each time (original behavior) - pert_cell_type = dataset.get_cell_type(perturbed_idx) - pool = self.split_control_pool[split].get(pert_cell_type, None) - if not pool: - raise ValueError( - f"No control cells found in RandomMappingStrategy for cell type '{pert_cell_type}'" - ) - control_idxs = self.rng.choices(pool, k=self.n_basal_samples) - return np.array(control_idxs) + # Sample new control cells each time (original behavior) + pert_cell_type = dataset.get_cell_type(perturbed_idx) + pool = self.split_control_pool[split].get(pert_cell_type, None) + if not pool: + raise ValueError( + f"No control cells found in RandomMappingStrategy for cell type '{pert_cell_type}'" + ) + control_idxs = self.rng.choices(pool, k=self.n_basal_samples) + return np.array(control_idxs) def get_control_index( self, dataset: "PerturbationDataset", split: str, perturbed_idx: int @@ -225,48 +210,20 @@ def get_control_index( """ Returns a single control index from the same cell type as the perturbed cell. - If cache_perturbation_control_pairs is True, uses the pre-computed mapping. 
- If use_consecutive_loading is True, samples a consecutive control with a random offset. - Otherwise, samples a new control cell each time (original behavior). + If cache_perturbation_control_pairs is True (or use_consecutive_loading is enabled), + uses the pre-computed mapping. Otherwise, samples a new control cell each time + (original behavior). """ - if self.cache_perturbation_control_pairs: + if self.cache_perturbation_control_pairs or self.use_consecutive_loading: # Use cached mapping control_idxs = self.split_control_mapping[split][perturbed_idx] if len(control_idxs) == 0: return None return control_idxs[0] - if self.use_consecutive_loading: - pert_cell_type = dataset.get_cell_type(perturbed_idx) - pool = self.split_control_pool[split].get(pert_cell_type, None) - if not pool: - return None - return self._sample_consecutive_controls(pool, 1)[0] - else: - # Sample new control cell each time (original behavior) - pert_cell_type = dataset.get_cell_type(perturbed_idx) - pool = self.split_control_pool[split].get(pert_cell_type, None) - if not pool: - return None - return self.rng.choice(pool) - - def _sample_consecutive_controls( - self, pool: list[int], n_samples: int - ) -> np.ndarray: - """Return n_samples consecutive control indices with a random start offset.""" - pool_size = len(pool) - if pool_size == 0 or n_samples <= 0: - return np.array([], dtype=np.int64) - if pool_size == 1: - return np.array([pool[0]] * n_samples, dtype=np.int64) - - start = self.rng.randrange(pool_size) - if n_samples == 1: - return np.array([pool[start]], dtype=np.int64) - - if start + n_samples <= pool_size: - return np.array(pool[start : start + n_samples], dtype=np.int64) - - tail = pool[start:] - head = pool[: n_samples - len(tail)] - return np.array(tail + head, dtype=np.int64) + # Sample new control cell each time (original behavior) + pert_cell_type = dataset.get_cell_type(perturbed_idx) + pool = self.split_control_pool[split].get(pert_cell_type, None) + if not pool: + 
return None + return self.rng.choice(pool) From 323429dca1f86ed6066c365e5d425fe4252fb5d4 Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Tue, 3 Feb 2026 22:05:42 +0000 Subject: [PATCH 04/13] updated control cells to load contiguously, with batched densification, and with persistent workers, for 12x speedup --- .../data_modules/perturbation_dataloader.py | 1 + src/cell_load/dataset/_perturbation.py | 214 +++++++++++++++--- src/cell_load/mapping_strategies/random.py | 99 +++++--- 3 files changed, 251 insertions(+), 63 deletions(-) diff --git a/src/cell_load/data_modules/perturbation_dataloader.py b/src/cell_load/data_modules/perturbation_dataloader.py index 764658e..b4f4735 100644 --- a/src/cell_load/data_modules/perturbation_dataloader.py +++ b/src/cell_load/data_modules/perturbation_dataloader.py @@ -535,6 +535,7 @@ def _create_base_dataset( additional_obs=self.additional_obs, downsample=self.downsample, is_log1p=self.is_log1p, + cell_sentence_len=self.cell_sentence_len, ) def _setup_datasets(self): diff --git a/src/cell_load/dataset/_perturbation.py b/src/cell_load/dataset/_perturbation.py index ebc6952..5510ce0 100644 --- a/src/cell_load/dataset/_perturbation.py +++ b/src/cell_load/dataset/_perturbation.py @@ -46,6 +46,7 @@ def __init__( additional_obs: list[str] | None = None, downsample: float | None = None, is_log1p: bool = False, + cell_sentence_len: int | None = None, **kwargs, ): """ @@ -70,6 +71,7 @@ def __init__( additional_obs: Optional list of obs column names to include in each sample downsample: Fraction of counts to retain via binomial downsampling (only for output_space="all") is_log1p: Whether raw counts in X are log1p-transformed (affects downsampling) + cell_sentence_len: Optional sentence length for consecutive loading batches **kwargs: Additional options (e.g. 
output_space) """ super().__init__() @@ -107,6 +109,7 @@ def __init__( raise ValueError(f"downsample must be in (0, 1]; got {downsample!r}") self.downsample = downsample self.is_log1p = bool(is_log1p) + self.cell_sentence_len = cell_sentence_len self.additional_obs = self._validate_additional_obs(additional_obs) # Load metadata cache and open file @@ -266,20 +269,114 @@ def __getitems__(self, indices): splits = [self._find_split_for_idx(int(i)) for i in file_indices] ctrl_indices = [] - for file_idx, split in zip(file_indices, splits): - ctrl_idx = self.mapping_strategy.get_control_index( - self, split, int(file_idx) - ) - if ctrl_idx is None: - raise ValueError( - f"No control cells found for cell type '{self.get_cell_type(file_idx)}'" + missing_ctrl = [] + sentence_len = self.cell_sentence_len + use_sentence_blocks = ( + sentence_len is not None + and sentence_len > 0 + and len(file_indices) % sentence_len == 0 + and getattr(self.mapping_strategy, "use_consecutive_loading", False) + and hasattr(self.mapping_strategy, "_sample_consecutive_controls") + and hasattr(self.mapping_strategy, "split_control_pool") + ) + + if use_sentence_blocks: + for start in range(0, len(file_indices), sentence_len): + sentence_idx = file_indices[start : start + sentence_len] + split = splits[start] + cell_type = self.get_cell_type(sentence_idx[0]) + pool = self.mapping_strategy.split_control_pool[split].get( + cell_type, None ) - ctrl_indices.append(int(ctrl_idx)) + if not pool: + raise ValueError( + f"No control cells found in RandomMappingStrategy for cell type '{cell_type}'" + ) + block = self.mapping_strategy._sample_consecutive_controls( + pool, len(sentence_idx) + ) + ctrl_indices.extend(block.tolist()) + missing_ctrl.extend([False] * len(block)) + else: + for file_idx, split in zip(file_indices, splits): + ctrl_idx = self.mapping_strategy.get_control_index( + self, split, int(file_idx) + ) + if ctrl_idx is None: + ctrl_indices.append(-1) + missing_ctrl.append(True) + else: + 
ctrl_indices.append(int(ctrl_idx)) + missing_ctrl.append(False) ctrl_indices_arr = np.asarray(ctrl_indices, dtype=np.int64) + missing_ctrl_mask = np.asarray(missing_ctrl, dtype=bool) + + if missing_ctrl_mask.any(): + missing_pos = int(np.flatnonzero(missing_ctrl_mask)[0]) + missing_file_idx = int(file_indices[missing_pos]) + if not self.embed_key: + raise ValueError( + f"No control cells found for cell type '{self.get_cell_type(missing_file_idx)}'" + ) + if ( + (self.store_raw_basal and self.output_space != "embedding") + or (self.barcode and self.cell_barcodes is not None) + ): + raise ValueError( + f"No control cells found for cell type '{self.get_cell_type(missing_file_idx)}'" + ) + + if self.embed_key: + pert_expr_batch = self._fetch_obsm_expression_batch( + file_indices, self.embed_key + ) + if missing_ctrl_mask.any(): + ctrl_expr_batch = torch.zeros_like(pert_expr_batch) + valid_positions = np.flatnonzero(~missing_ctrl_mask) + if valid_positions.size: + valid_ctrl_indices = ctrl_indices_arr[valid_positions] + valid_positions_t = torch.from_numpy( + valid_positions.astype(np.int64) + ) + ctrl_expr_batch[valid_positions_t] = ( + self._fetch_obsm_expression_batch( + valid_ctrl_indices, self.embed_key + ) + ) + else: + ctrl_expr_batch = self._fetch_obsm_expression_batch( + ctrl_indices_arr, self.embed_key + ) + else: + pert_expr_batch = self._fetch_gene_expression_batch(file_indices) + ctrl_expr_batch = self._fetch_gene_expression_batch(ctrl_indices_arr) + + pert_counts_batch = None + ctrl_counts_batch = None + if self.store_raw_expression and self.output_space != "embedding": + if self.output_space == "gene": + pert_counts_batch = self._fetch_obsm_expression_batch( + file_indices, "X_hvg" + ) + elif self.output_space == "all": + if self.embed_key: + pert_counts_batch = self._fetch_gene_expression_batch(file_indices) + else: + pert_counts_batch = pert_expr_batch - pert_expr_batch = self._fetch_gene_expression_batch(file_indices) - ctrl_expr_batch = 
self._fetch_gene_expression_batch(ctrl_indices_arr) + if self.store_raw_basal and self.output_space != "embedding": + if self.output_space == "gene": + ctrl_counts_batch = self._fetch_obsm_expression_batch( + ctrl_indices_arr, "X_hvg" + ) + elif self.output_space == "all": + if self.embed_key: + ctrl_counts_batch = self._fetch_gene_expression_batch( + ctrl_indices_arr + ) + else: + ctrl_counts_batch = ctrl_expr_batch samples = [] for i, file_idx in enumerate(file_indices): @@ -319,21 +416,11 @@ def __getitems__(self, indices): "cell_type_onehot": cell_type_onehot, } - if self.store_raw_expression and self.output_space != "embedding": - if self.output_space == "gene": - sample["pert_cell_counts"] = self.fetch_obsm_expression( - file_idx, "X_hvg" - ) - elif self.output_space == "all": - sample["pert_cell_counts"] = pert_expr + if pert_counts_batch is not None: + sample["pert_cell_counts"] = pert_counts_batch[i] - if self.store_raw_basal and self.output_space != "embedding": - if self.output_space == "gene": - sample["ctrl_cell_counts"] = self.fetch_obsm_expression( - ctrl_idx, "X_hvg" - ) - elif self.output_space == "all": - sample["ctrl_cell_counts"] = ctrl_expr + if ctrl_counts_batch is not None: + sample["ctrl_cell_counts"] = ctrl_counts_batch[i] if self.barcode and self.cell_barcodes is not None: sample["pert_cell_barcode"] = self.cell_barcodes[file_idx] @@ -348,15 +435,7 @@ def __getitems__(self, indices): return samples def _use_batched_fetch(self) -> bool: - if not getattr(self.mapping_strategy, "use_consecutive_loading", False): - return False - if self.embed_key is not None: - return False - if self.output_space != "all": - return False - if self.downsample is not None and self.downsample < 1.0: - return False - return True + return bool(getattr(self.mapping_strategy, "use_consecutive_loading", False)) def _validate_additional_obs(self, additional_obs: list[str] | None) -> list[str]: if additional_obs is None: @@ -560,6 +639,25 @@ def 
_maybe_downsample_counts(self, counts: torch.Tensor) -> torch.Tensor: sampled = np.log1p(sampled) return torch.tensor(sampled, dtype=torch.float32) + def _maybe_downsample_counts_array(self, counts: np.ndarray) -> np.ndarray: + if ( + self.downsample is None + or self.downsample >= 1.0 + or self.output_space != "all" + ): + return counts + + if self.is_log1p: + counts_lin = np.expm1(counts) + counts_int = np.rint(counts_lin).astype(np.int64) + else: + counts_int = counts.astype(np.int64) + counts_int = np.maximum(counts_int, 0) + sampled = self.rng.binomial(counts_int, self.downsample) + if self.is_log1p: + sampled = np.log1p(sampled) + return sampled.astype(np.float32) + def fetch_gene_expression(self, idx: int) -> torch.Tensor: """ Fetch raw gene counts for a given cell index, applying optional downsampling. @@ -591,6 +689,38 @@ def fetch_gene_expression(self, idx: int) -> torch.Tensor: data = self._fetch_gene_expression_raw(idx) return self._maybe_downsample_counts(data) + def _fetch_dense_matrix_batch( + self, ds, indices: np.ndarray, n_cols: int + ) -> np.ndarray: + if indices.size == 0: + return np.empty((0, n_cols), dtype=np.float32) + + order = np.argsort(indices) + sorted_rows = indices[order] + dense_sorted = np.zeros((len(sorted_rows), n_cols), dtype=np.float32) + + run_start = 0 + for i in range(1, len(sorted_rows)): + if sorted_rows[i] != sorted_rows[i - 1] + 1: + row_start = int(sorted_rows[run_start]) + row_end = int(sorted_rows[i - 1]) + block = np.asarray(ds[row_start : row_end + 1], dtype=np.float32) + if block.ndim == 1: + block = block[:, None] + dense_sorted[run_start:i] = block + run_start = i + + row_start = int(sorted_rows[run_start]) + row_end = int(sorted_rows[-1]) + block = np.asarray(ds[row_start : row_end + 1], dtype=np.float32) + if block.ndim == 1: + block = block[:, None] + dense_sorted[run_start:len(sorted_rows)] = block + + inv_order = np.empty_like(order) + inv_order[order] = np.arange(len(order)) + return dense_sorted[inv_order] 
+ def _fetch_gene_expression_batch(self, indices: np.ndarray) -> torch.Tensor: """ Fetch raw gene counts for multiple indices at once (CSR fast path). @@ -600,8 +730,11 @@ def _fetch_gene_expression_batch(self, indices: np.ndarray) -> torch.Tensor: attrs = dict(self.h5_file["X"].attrs) if attrs.get("encoding-type") != "csr_matrix": - row_data = self.h5_file["/X"][indices] - return torch.tensor(row_data, dtype=torch.float32) + dense = self._fetch_dense_matrix_batch( + self.h5_file["/X"], indices, self.n_genes + ) + dense = self._maybe_downsample_counts_array(dense) + return torch.from_numpy(dense) indptr_ds = self.h5_file["/X/indptr"] data_ds = self.h5_file["/X/data"] @@ -638,7 +771,18 @@ def _fetch_gene_expression_batch(self, indices: np.ndarray) -> torch.Tensor: inv_order = np.empty_like(order) inv_order[order] = np.arange(len(order)) dense = dense_sorted[inv_order] + dense = self._maybe_downsample_counts_array(dense) + + return torch.from_numpy(dense) + def _fetch_obsm_expression_batch( + self, indices: np.ndarray, key: str + ) -> torch.Tensor: + ds = self.h5_file[f"/obsm/{key}"] + n_cols = int(ds.shape[1]) if ds.ndim > 1 else 1 + if indices.size == 0: + return torch.empty((0, n_cols), dtype=torch.float32) + dense = self._fetch_dense_matrix_batch(ds, indices, n_cols) return torch.from_numpy(dense) def _fill_dense_run( diff --git a/src/cell_load/mapping_strategies/random.py b/src/cell_load/mapping_strategies/random.py index 04eacd4..27dde8b 100644 --- a/src/cell_load/mapping_strategies/random.py +++ b/src/cell_load/mapping_strategies/random.py @@ -47,10 +47,16 @@ def __init__( f"Warning: If using n_basal_samples > 1, use the original behavior by setting cache_perturbation_control_pairs=False" ) if self.use_consecutive_loading: - logger.info( - "RandomMappingStrategy initialized with use_consecutive_loading=True; " - "control mappings will be assigned in file order." 
- ) + if self.cache_perturbation_control_pairs: + logger.info( + "RandomMappingStrategy initialized with use_consecutive_loading=True; " + "control mappings will be assigned in file order." + ) + else: + logger.info( + "RandomMappingStrategy initialized with use_consecutive_loading=True; " + "control cells will be sampled as consecutive blocks with random offsets." + ) # Map cell type -> list of control indices. self.split_control_pool = { @@ -105,7 +111,7 @@ def register_split_indices( else: self.split_control_pool[split][ct].extend(ct_indices) - build_mapping = self.cache_perturbation_control_pairs or self.use_consecutive_loading + build_mapping = self.cache_perturbation_control_pairs if build_mapping: logger.info( f"Creating cached perturbation-control mapping for split '{split}' with {len(perturbed_indices)} perturbed cells and {len(control_indices)} control cells" @@ -181,12 +187,12 @@ def get_control_indices( """ Returns n_basal_samples control indices that are from the same cell type as the perturbed cell. - If cache_perturbation_control_pairs is True (or use_consecutive_loading is enabled), - uses the pre-computed mapping. Otherwise, samples new control cells each time - (original behavior). + If cache_perturbation_control_pairs is True, uses the pre-computed mapping. + If use_consecutive_loading is True, samples a consecutive block with a random offset. + Otherwise, samples new control cells each time (original behavior). 
""" - if self.cache_perturbation_control_pairs or self.use_consecutive_loading: + if self.cache_perturbation_control_pairs: # Use cached mapping control_idxs = self.split_control_mapping[split][perturbed_idx] if len(control_idxs) == 0: @@ -194,15 +200,24 @@ def get_control_indices( f"No control cells found in RandomMappingStrategy for cell type '{dataset.get_cell_type(perturbed_idx)}'" ) return np.array(control_idxs) - # Sample new control cells each time (original behavior) - pert_cell_type = dataset.get_cell_type(perturbed_idx) - pool = self.split_control_pool[split].get(pert_cell_type, None) - if not pool: - raise ValueError( - f"No control cells found in RandomMappingStrategy for cell type '{pert_cell_type}'" - ) - control_idxs = self.rng.choices(pool, k=self.n_basal_samples) - return np.array(control_idxs) + if self.use_consecutive_loading: + pert_cell_type = dataset.get_cell_type(perturbed_idx) + pool = self.split_control_pool[split].get(pert_cell_type, None) + if not pool: + raise ValueError( + f"No control cells found in RandomMappingStrategy for cell type '{pert_cell_type}'" + ) + return self._sample_consecutive_controls(pool, self.n_basal_samples) + else: + # Sample new control cells each time (original behavior) + pert_cell_type = dataset.get_cell_type(perturbed_idx) + pool = self.split_control_pool[split].get(pert_cell_type, None) + if not pool: + raise ValueError( + f"No control cells found in RandomMappingStrategy for cell type '{pert_cell_type}'" + ) + control_idxs = self.rng.choices(pool, k=self.n_basal_samples) + return np.array(control_idxs) def get_control_index( self, dataset: "PerturbationDataset", split: str, perturbed_idx: int @@ -210,20 +225,48 @@ def get_control_index( """ Returns a single control index from the same cell type as the perturbed cell. - If cache_perturbation_control_pairs is True (or use_consecutive_loading is enabled), - uses the pre-computed mapping. Otherwise, samples a new control cell each time - (original behavior). 
+ If cache_perturbation_control_pairs is True, uses the pre-computed mapping. + If use_consecutive_loading is True, samples a consecutive control with a random offset. + Otherwise, samples a new control cell each time (original behavior). """ - if self.cache_perturbation_control_pairs or self.use_consecutive_loading: + if self.cache_perturbation_control_pairs: # Use cached mapping control_idxs = self.split_control_mapping[split][perturbed_idx] if len(control_idxs) == 0: return None return control_idxs[0] - # Sample new control cell each time (original behavior) - pert_cell_type = dataset.get_cell_type(perturbed_idx) - pool = self.split_control_pool[split].get(pert_cell_type, None) - if not pool: - return None - return self.rng.choice(pool) + if self.use_consecutive_loading: + pert_cell_type = dataset.get_cell_type(perturbed_idx) + pool = self.split_control_pool[split].get(pert_cell_type, None) + if not pool: + return None + return self._sample_consecutive_controls(pool, 1)[0] + else: + # Sample new control cell each time (original behavior) + pert_cell_type = dataset.get_cell_type(perturbed_idx) + pool = self.split_control_pool[split].get(pert_cell_type, None) + if not pool: + return None + return self.rng.choice(pool) + + def _sample_consecutive_controls( + self, pool: list[int], n_samples: int + ) -> np.ndarray: + """Return n_samples consecutive control indices with a random start offset.""" + pool_size = len(pool) + if pool_size == 0 or n_samples <= 0: + return np.array([], dtype=np.int64) + if pool_size == 1: + return np.array([pool[0]] * n_samples, dtype=np.int64) + + start = self.rng.randrange(pool_size) + if n_samples == 1: + return np.array([pool[start]], dtype=np.int64) + + if start + n_samples <= pool_size: + return np.array(pool[start : start + n_samples], dtype=np.int64) + + tail = pool[start:] + head = pool[: n_samples - len(tail)] + return np.array(tail + head, dtype=np.int64) From 9237475a6edd279edb2d0800f2ee2fd306680bb6 Mon Sep 17 00:00:00 2001 
From: Abhinav Adduri Date: Wed, 4 Feb 2026 22:34:01 +0000 Subject: [PATCH 05/13] finalized changes for consecutive data loader speed --- pyproject.toml | 2 +- .../data_modules/perturbation_dataloader.py | 37 +++- src/cell_load/dataset/_metadata.py | 9 + src/cell_load/dataset/_perturbation.py | 174 +++++++++++++++++- src/cell_load/mapping_strategies/batch.py | 24 +-- src/cell_load/mapping_strategies/random.py | 30 +-- 6 files changed, 237 insertions(+), 39 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8a1ba22..d332ddb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "cell-load" -version = "0.9.3" +version = "0.10.0" description = "Dataloaders for training models on huge single-cell datasets" readme = "README.md" authors = [ diff --git a/src/cell_load/data_modules/perturbation_dataloader.py b/src/cell_load/data_modules/perturbation_dataloader.py index b4f4735..8271a76 100644 --- a/src/cell_load/data_modules/perturbation_dataloader.py +++ b/src/cell_load/data_modules/perturbation_dataloader.py @@ -1,4 +1,5 @@ import logging +import os import glob import re @@ -10,7 +11,7 @@ import numpy as np import torch from lightning.pytorch import LightningDataModule -from torch.utils.data import DataLoader, Dataset +from torch.utils.data import DataLoader, Dataset, get_worker_info from tqdm import tqdm from ..config import ExperimentConfig @@ -26,6 +27,23 @@ logger = logging.getLogger(__name__) +def _worker_init_fn(worker_id: int) -> None: + for var in ("OMP_NUM_THREADS", "MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS", "NUMEXPR_NUM_THREADS"): + os.environ.setdefault(var, "1") + try: + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + except RuntimeError: + pass + + worker_info = get_worker_info() + if worker_info is None: + return + dataset = worker_info.dataset + if hasattr(dataset, "ensure_h5_open"): + dataset.ensure_h5_open() + + class PerturbationDataModule(LightningDataModule): """ A unified data module that sets up 
train/val/test splits for multiple dataset/celltype @@ -55,6 +73,7 @@ def __init__( drop_last: bool = False, additional_obs: list[str] | None = None, use_consecutive_loading: bool = False, + h5_open_kwargs: dict | None = None, **kwargs, # missing perturbation_features_file and store_raw_basal for backwards compatibility ): """ @@ -77,6 +96,7 @@ def __init__( cache_perturbation_control_pairs: If True cache perturbation-control pairs at the start of training and reuse them. drop_last: Whether to drop the last sentence set if it is smaller than cell_sentence_len use_consecutive_loading: Whether to form cell sets from consecutive indices for faster IO + h5_open_kwargs: Optional kwargs to pass to h5py.File (e.g., rdcc_nbytes) """ super().__init__() @@ -121,6 +141,9 @@ def __init__( self.store_raw_basal = kwargs.get("store_raw_basal", False) self.barcode = kwargs.get("barcode", False) self.additional_obs = additional_obs + self.h5_open_kwargs = h5_open_kwargs + if self.use_consecutive_loading: + self._set_h5_cache_env_defaults() logger.info( f"Initializing DataModule: batch_size={batch_size}, workers={num_workers}, " @@ -163,6 +186,12 @@ def _get_reference_dataset(self) -> PerturbationDataset: return datasets[0].dataset raise ValueError("No datasets available to extract metadata.") + @staticmethod + def _set_h5_cache_env_defaults() -> None: + os.environ.setdefault("CELL_LOAD_H5_RDCC_NBYTES", str(64 * 1024 * 1024)) + os.environ.setdefault("CELL_LOAD_H5_RDCC_NSLOTS", "1000003") + os.environ.setdefault("CELL_LOAD_H5_RDCC_W0", "0.75") + def get_var_names(self): """ Get the variable names (gene names) from the first available dataset. 
@@ -219,6 +248,7 @@ def save_state(self, filepath: str): "barcode": self.barcode, "additional_obs": self.additional_obs, "use_consecutive_loading": self.use_consecutive_loading, + "h5_open_kwargs": self.h5_open_kwargs, } torch.save(save_dict, filepath) @@ -262,6 +292,7 @@ def load_state(cls, filepath: str): "store_raw_basal": save_dict.pop("store_raw_basal", False), "barcode": save_dict.pop("barcode", True), "use_consecutive_loading": save_dict.pop("use_consecutive_loading", False), + "h5_open_kwargs": save_dict.pop("h5_open_kwargs", None), } # Create new instance with all the saved parameters @@ -408,8 +439,9 @@ def _create_dataloader( num_workers=self.num_workers, collate_fn=collate_fn, pin_memory=True, - prefetch_factor=8 if not test and self.num_workers > 0 else None, + prefetch_factor=4 if not test and self.num_workers > 0 else None, persistent_workers=bool(self.num_workers > 0 and not test), + worker_init_fn=_worker_init_fn if self.num_workers > 0 else None, ) def _setup_global_maps(self): @@ -536,6 +568,7 @@ def _create_base_dataset( downsample=self.downsample, is_log1p=self.is_log1p, cell_sentence_len=self.cell_sentence_len, + h5_open_kwargs=self.h5_open_kwargs, ) def _setup_datasets(self): diff --git a/src/cell_load/dataset/_metadata.py b/src/cell_load/dataset/_metadata.py index c1a9dd4..2038e3e 100644 --- a/src/cell_load/dataset/_metadata.py +++ b/src/cell_load/dataset/_metadata.py @@ -62,3 +62,12 @@ def __getitems__(self, indices): results[pos] = sample return results + + def ensure_h5_open(self) -> None: + """ + Ensure all underlying H5 files are open in the current process. 
+ """ + for ds in self.datasets: + base = ds.dataset if isinstance(ds, Subset) else ds + if hasattr(base, "ensure_h5_open"): + base.ensure_h5_open() diff --git a/src/cell_load/dataset/_perturbation.py b/src/cell_load/dataset/_perturbation.py index 5510ce0..a4c9049 100644 --- a/src/cell_load/dataset/_perturbation.py +++ b/src/cell_load/dataset/_perturbation.py @@ -1,4 +1,5 @@ import logging +import os from pathlib import Path from functools import lru_cache @@ -47,6 +48,7 @@ def __init__( downsample: float | None = None, is_log1p: bool = False, cell_sentence_len: int | None = None, + h5_open_kwargs: dict | None = None, **kwargs, ): """ @@ -72,6 +74,7 @@ def __init__( downsample: Fraction of counts to retain via binomial downsampling (only for output_space="all") is_log1p: Whether raw counts in X are log1p-transformed (affects downsampling) cell_sentence_len: Optional sentence length for consecutive loading batches + h5_open_kwargs: Optional kwargs to pass to h5py.File (e.g., rdcc_nbytes) **kwargs: Additional options (e.g. 
output_space) """ super().__init__() @@ -110,13 +113,16 @@ def __init__( self.downsample = downsample self.is_log1p = bool(is_log1p) self.cell_sentence_len = cell_sentence_len + self.h5_open_kwargs = self._normalize_h5_open_kwargs(h5_open_kwargs) self.additional_obs = self._validate_additional_obs(additional_obs) # Load metadata cache and open file self.metadata_cache = GlobalH5MetadataCache().get_cache( str(self.h5_path), pert_col, cell_type_key, control_pert, batch_col ) - self.h5_file = h5py.File(self.h5_path, "r") + self.h5_file = None + self._h5_pid = None + self._open_h5_file() # Load cell barcodes if requested if self.barcode: @@ -138,6 +144,94 @@ def __init__( splits = ["train", "train_eval", "val", "test"] self.split_perturbed_indices = {s: set() for s in splits} self.split_control_indices = {s: set() for s in splits} + self._init_split_index_cache() + + def _init_split_index_cache(self) -> None: + self._split_code_to_name = ("train", "train_eval", "val", "test") + self._split_name_to_code = { + name: idx for idx, name in enumerate(self._split_code_to_name) + } + self._index_to_split_code = np.full(self.n_cells, -1, dtype=np.int8) + for split, indices in self.split_perturbed_indices.items(): + if indices: + self._index_to_split_code[list(indices)] = self._split_name_to_code[split] + for split, indices in self.split_control_indices.items(): + if indices: + self._index_to_split_code[list(indices)] = self._split_name_to_code[split] + + @staticmethod + def _parse_env_int(name: str, default: int) -> int: + raw = os.getenv(name) + if raw is None: + return default + try: + return int(float(raw)) + except ValueError: + logger.warning("Invalid %s=%r; using %d", name, raw, default) + return default + + @staticmethod + def _parse_env_float(name: str, default: float) -> float: + raw = os.getenv(name) + if raw is None: + return default + try: + return float(raw) + except ValueError: + logger.warning("Invalid %s=%r; using %.3f", name, raw, default) + return default + + 
def _default_h5_open_kwargs(self) -> dict: + rdcc_nbytes = self._parse_env_int( + "CELL_LOAD_H5_RDCC_NBYTES", 64 * 1024 * 1024 + ) + rdcc_nslots = self._parse_env_int("CELL_LOAD_H5_RDCC_NSLOTS", 1_000_003) + rdcc_w0 = self._parse_env_float("CELL_LOAD_H5_RDCC_W0", 0.75) + kwargs = { + "rdcc_nbytes": rdcc_nbytes, + "rdcc_nslots": rdcc_nslots, + "rdcc_w0": rdcc_w0, + } + return self._sanitize_h5_open_kwargs(kwargs) + + @staticmethod + def _sanitize_h5_open_kwargs(kwargs: dict) -> dict: + cleaned = {} + rdcc_nbytes = kwargs.get("rdcc_nbytes") + if rdcc_nbytes is not None and rdcc_nbytes > 0: + cleaned["rdcc_nbytes"] = int(rdcc_nbytes) + rdcc_nslots = kwargs.get("rdcc_nslots") + if rdcc_nslots is not None and rdcc_nslots > 0: + cleaned["rdcc_nslots"] = int(rdcc_nslots) + rdcc_w0 = kwargs.get("rdcc_w0") + if rdcc_w0 is not None and 0.0 <= float(rdcc_w0) <= 1.0: + cleaned["rdcc_w0"] = float(rdcc_w0) + return cleaned + + def _normalize_h5_open_kwargs(self, h5_open_kwargs: dict | None) -> dict: + if h5_open_kwargs is None: + return self._default_h5_open_kwargs() + return self._sanitize_h5_open_kwargs(h5_open_kwargs) + + def _open_h5_file(self) -> None: + if self.h5_file is not None: + try: + self.h5_file.close() + except Exception: + pass + self.h5_file = h5py.File(self.h5_path, "r", **self.h5_open_kwargs) + self._h5_pid = os.getpid() + + def _ensure_h5_open(self) -> None: + if ( + self.h5_file is None + or self._h5_pid != os.getpid() + or not self.h5_file.id.valid + ): + self._open_h5_file() + + def ensure_h5_open(self) -> None: + self._ensure_h5_open() def set_store_raw_expression(self, flag: bool) -> None: """ @@ -181,6 +275,7 @@ def __getitem__(self, idx: int): - pert_cell_counts: the raw gene expression of the perturbed cell (if store_raw_expression is True) - ctrl_cell_counts: the raw gene expression of the control cell (if store_raw_basal is True) """ + self._ensure_h5_open() # Get the perturbed cell expression, control cell expression, and index of mapped control 
cell file_idx = int(self.all_indices[idx]) @@ -258,6 +353,7 @@ def __getitems__(self, indices): Batch-aware fetch for consecutive loading with batched CSR densification. Falls back to per-item access when not applicable. """ + self._ensure_h5_open() if not self._use_batched_fetch(): return [self.__getitem__(int(i)) for i in indices] @@ -284,13 +380,13 @@ def __getitems__(self, indices): for start in range(0, len(file_indices), sentence_len): sentence_idx = file_indices[start : start + sentence_len] split = splits[start] - cell_type = self.get_cell_type(sentence_idx[0]) + cell_type_code = self.get_cell_type_code(sentence_idx[0]) pool = self.mapping_strategy.split_control_pool[split].get( - cell_type, None + cell_type_code, None ) if not pool: raise ValueError( - f"No control cells found in RandomMappingStrategy for cell type '{cell_type}'" + f"No control cells found in RandomMappingStrategy for cell type '{self.get_cell_type(sentence_idx[0])}'" ) block = self.mapping_strategy._sample_consecutive_controls( pool, len(sentence_idx) @@ -524,6 +620,13 @@ def get_cell_type(self, idx): code = self.metadata_cache.cell_type_codes[idx] return self.metadata_cache.cell_type_categories[code] + def get_cell_type_code(self, idx: int) -> int: + """ + Get the cell type code for a given index. + """ + idx = int(idx) if hasattr(idx, "__int__") else idx + return int(self.metadata_cache.cell_type_codes[idx]) + def get_all_cell_types(self, indices): """ Get the cell types for all given indices. @@ -531,6 +634,12 @@ def get_all_cell_types(self, indices): codes = self.metadata_cache.cell_type_codes[indices] return self.metadata_cache.cell_type_categories[codes] + def get_all_cell_type_codes(self, indices) -> np.ndarray: + """ + Get the cell type codes for all given indices. + """ + return self.metadata_cache.cell_type_codes[indices] + def get_perturbation_name(self, idx): """ Get the perturbation name for a given index. 
@@ -540,6 +649,32 @@ def get_perturbation_name(self, idx): pert_code = self.metadata_cache.pert_codes[idx] return self.metadata_cache.pert_categories[pert_code] + def get_perturbation_code(self, idx: int) -> int: + """ + Get the perturbation code for a given index. + """ + idx = int(idx) if hasattr(idx, "__int__") else idx + return int(self.metadata_cache.pert_codes[idx]) + + def get_all_perturbation_codes(self, indices) -> np.ndarray: + """ + Get the perturbation codes for all given indices. + """ + return self.metadata_cache.pert_codes[indices] + + def get_batch_code(self, idx: int) -> int: + """ + Get the batch code for a given index. + """ + idx = int(idx) if hasattr(idx, "__int__") else idx + return int(self.metadata_cache.batch_codes[idx]) + + def get_all_batch_codes(self, indices) -> np.ndarray: + """ + Get the batch codes for all given indices. + """ + return self.metadata_cache.batch_codes[indices] + def to_subset_dataset( self, split: str, @@ -1084,6 +1219,13 @@ def _register_split_indices( # update them in the dataset self.split_perturbed_indices[split] |= set(perturbed_indices) self.split_control_indices[split] |= set(control_indices) + if not hasattr(self, "_index_to_split_code"): + self._init_split_index_cache() + code = self._split_name_to_code[split] + if len(perturbed_indices) > 0: + self._index_to_split_code[perturbed_indices] = code + if len(control_indices) > 0: + self._index_to_split_code[control_indices] = code # forward these to the mapping strategy self.mapping_strategy.register_split_indices( @@ -1092,6 +1234,11 @@ def _register_split_indices( def _find_split_for_idx(self, idx: int) -> str | None: """Utility to find which split (train/val/test) this idx belongs to.""" + if hasattr(self, "_index_to_split_code"): + code = int(self._index_to_split_code[idx]) + if code >= 0: + return self._split_code_to_name[code] + return None for s in self.split_perturbed_indices.keys(): if ( idx in self.split_perturbed_indices[s] @@ -1156,9 +1303,13 @@ def 
__getstate__(self): # Copy the object's dict state = self.__dict__.copy() # Remove the open file object if it exists - if "h5_file" in state: - # We'll also store whether it's currently open, so that we can re-open later if needed - del state["h5_file"] + if self.h5_file is not None: + try: + self.h5_file.close() + except Exception: + pass + state.pop("h5_file", None) + state.pop("_h5_pid", None) return state def __setstate__(self, state): @@ -1167,8 +1318,11 @@ def __setstate__(self, state): """ # TODO-Abhi: remove this before release self.__dict__.update(state) - # This ensures that after we unpickle, we have a valid h5_file handle again - self.h5_file = h5py.File(self.h5_path, "r") + if not hasattr(self, "h5_open_kwargs"): + self.h5_open_kwargs = self._normalize_h5_open_kwargs(None) + self.h5_file = None + self._h5_pid = None + self._open_h5_file() self.metadata_cache = GlobalH5MetadataCache().get_cache( str(self.h5_path), self.pert_col, @@ -1176,6 +1330,8 @@ def __setstate__(self, state): self.control_pert, self.batch_col, ) + if not hasattr(self, "_index_to_split_code"): + self._init_split_index_cache() def _load_cell_barcodes(self) -> np.ndarray: """ diff --git a/src/cell_load/mapping_strategies/batch.py b/src/cell_load/mapping_strategies/batch.py index 2b84fda..450e0e1 100644 --- a/src/cell_load/mapping_strategies/batch.py +++ b/src/cell_load/mapping_strategies/batch.py @@ -66,8 +66,8 @@ def register_split_indices( For each control cell, we retrieve both its batch and cell type, using that pair as the key. 
""" for idx in control_indices: - batch = dataset.get_batch(idx) - cell_type = dataset.get_cell_type(idx) + batch = dataset.get_batch_code(idx) + cell_type = dataset.get_cell_type_code(idx) key = (batch, cell_type) if key not in self.split_control_maps[split]: self.split_control_maps[split][key] = [] @@ -92,18 +92,18 @@ def _build_consecutive_mapping( all_indices = np.concatenate([perturbed_indices, control_indices]) # Fallback pools by cell type (sorted for deterministic order). - fallback_pools: dict[str, list[int]] = {} + fallback_pools: dict[int, list[int]] = {} for (batch, cell_type), indices in self.split_control_maps[split].items(): fallback_pools.setdefault(cell_type, []).extend(indices) for cell_type, pool in fallback_pools.items(): fallback_pools[cell_type] = sorted(pool) - key_offsets: dict[tuple[int, str], int] = {} - fallback_offsets: dict[str, int] = {} + key_offsets: dict[tuple[int, int], int] = {} + fallback_offsets: dict[int, int] = {} for idx in all_indices: - batch = dataset.get_batch(idx) - cell_type = dataset.get_cell_type(idx) + batch = dataset.get_batch_code(idx) + cell_type = dataset.get_cell_type_code(idx) key = (batch, cell_type) pool = self.split_control_maps[split].get(key, []) @@ -146,8 +146,8 @@ def get_control_indices( ) return np.array(control_idxs) - batch = dataset.get_batch(perturbed_idx) - cell_type = dataset.get_cell_type(perturbed_idx) + batch = dataset.get_batch_code(perturbed_idx) + cell_type = dataset.get_cell_type_code(perturbed_idx) key = (batch, cell_type) pool = self.split_control_maps[split].get(key, []) @@ -161,7 +161,7 @@ def get_control_indices( if not pool: raise ValueError( "No control cells found in BatchMappingStrategy for cell type '{}'".format( - cell_type + dataset.get_cell_type(perturbed_idx) ) ) @@ -181,8 +181,8 @@ def get_control_index( return None return control_idxs[0] - batch = dataset.get_batch(perturbed_idx) - cell_type = dataset.get_cell_type(perturbed_idx) + batch = 
dataset.get_batch_code(perturbed_idx) + cell_type = dataset.get_cell_type_code(perturbed_idx) key = (batch, cell_type) pool = self.split_control_maps[split].get(key, []) diff --git a/src/cell_load/mapping_strategies/random.py b/src/cell_load/mapping_strategies/random.py index 27dde8b..45a04e6 100644 --- a/src/cell_load/mapping_strategies/random.py +++ b/src/cell_load/mapping_strategies/random.py @@ -98,18 +98,18 @@ def register_split_indices( """ all_indices = np.concatenate([perturbed_indices, control_indices]) - # Get cell types for all control indices - cell_types = dataset.get_all_cell_types(control_indices) + # Get cell type codes for all control indices + cell_types = dataset.get_all_cell_type_codes(control_indices) # Group by cell type and store the control indices - for ct in np.unique(cell_types): - ct_mask = cell_types == ct + for ct_code in np.unique(cell_types): + ct_mask = cell_types == ct_code ct_indices = control_indices[ct_mask] - if ct not in self.split_control_pool[split]: - self.split_control_pool[split][ct] = list(ct_indices) + if ct_code not in self.split_control_pool[split]: + self.split_control_pool[split][ct_code] = list(ct_indices) else: - self.split_control_pool[split][ct].extend(ct_indices) + self.split_control_pool[split][ct_code].extend(ct_indices) build_mapping = self.cache_perturbation_control_pairs if build_mapping: @@ -124,8 +124,8 @@ def register_split_indices( # Group perturbed indices by cell type and perturbation name for pert_idx in all_indices: - pert_cell_type = dataset.get_cell_type(pert_idx) - pert_name = dataset.get_perturbation_name(pert_idx) + pert_cell_type = dataset.get_cell_type_code(pert_idx) + pert_name = dataset.get_perturbation_code(pert_idx) key = (pert_cell_type, pert_name) if key not in pert_groups: @@ -201,20 +201,20 @@ def get_control_indices( ) return np.array(control_idxs) if self.use_consecutive_loading: - pert_cell_type = dataset.get_cell_type(perturbed_idx) + pert_cell_type = 
dataset.get_cell_type_code(perturbed_idx) pool = self.split_control_pool[split].get(pert_cell_type, None) if not pool: raise ValueError( - f"No control cells found in RandomMappingStrategy for cell type '{pert_cell_type}'" + f"No control cells found in RandomMappingStrategy for cell type '{dataset.get_cell_type(perturbed_idx)}'" ) return self._sample_consecutive_controls(pool, self.n_basal_samples) else: # Sample new control cells each time (original behavior) - pert_cell_type = dataset.get_cell_type(perturbed_idx) + pert_cell_type = dataset.get_cell_type_code(perturbed_idx) pool = self.split_control_pool[split].get(pert_cell_type, None) if not pool: raise ValueError( - f"No control cells found in RandomMappingStrategy for cell type '{pert_cell_type}'" + f"No control cells found in RandomMappingStrategy for cell type '{dataset.get_cell_type(perturbed_idx)}'" ) control_idxs = self.rng.choices(pool, k=self.n_basal_samples) return np.array(control_idxs) @@ -237,14 +237,14 @@ def get_control_index( return None return control_idxs[0] if self.use_consecutive_loading: - pert_cell_type = dataset.get_cell_type(perturbed_idx) + pert_cell_type = dataset.get_cell_type_code(perturbed_idx) pool = self.split_control_pool[split].get(pert_cell_type, None) if not pool: return None return self._sample_consecutive_controls(pool, 1)[0] else: # Sample new control cell each time (original behavior) - pert_cell_type = dataset.get_cell_type(perturbed_idx) + pert_cell_type = dataset.get_cell_type_code(perturbed_idx) pool = self.split_control_pool[split].get(pert_cell_type, None) if not pool: return None From 10502b87027176a8a04dfcda2792b4ea16c25527 Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Thu, 5 Feb 2026 19:57:21 +0000 Subject: [PATCH 06/13] added self.val_subsample_fraction parameter to data module, to subset the number of validation batches used --- README.md | 1 + .../data_modules/perturbation_dataloader.py | 41 +++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git 
a/README.md b/README.md index 67778ac..115a402 100644 --- a/README.md +++ b/README.md @@ -287,6 +287,7 @@ filtered_adata.write_h5ad("filtered_data.h5ad") - **`should_yield_control_cells`**: Include control cells in output (default: `true`) - **`num_workers`**: Number of workers for data loading (default: 8) - **`batch_size`**: Batch size for training (default: 128) +- **`val_subsample_fraction`**: Fraction of validation subsets to keep (e.g., `0.01` keeps ~1% of `val_datasets`) ### Usage diff --git a/src/cell_load/data_modules/perturbation_dataloader.py b/src/cell_load/data_modules/perturbation_dataloader.py index 8271a76..95ff308 100644 --- a/src/cell_load/data_modules/perturbation_dataloader.py +++ b/src/cell_load/data_modules/perturbation_dataloader.py @@ -71,6 +71,7 @@ def __init__( cell_sentence_len: int = 512, cache_perturbation_control_pairs: bool = False, drop_last: bool = False, + val_subsample_fraction: float | None = None, additional_obs: list[str] | None = None, use_consecutive_loading: bool = False, h5_open_kwargs: dict | None = None, @@ -95,6 +96,7 @@ def __init__( n_basal_samples: Number of control cells to sample per perturbed cell cache_perturbation_control_pairs: If True cache perturbation-control pairs at the start of training and reuse them. 
drop_last: Whether to drop the last sentence set if it is smaller than cell_sentence_len + val_subsample_fraction: Fraction of validation subsets to keep (subsamples self.val_datasets) use_consecutive_loading: Whether to form cell sets from consecutive indices for faster IO h5_open_kwargs: Optional kwargs to pass to h5py.File (e.g., rdcc_nbytes) """ @@ -112,6 +114,20 @@ def __init__( self.rng = np.random.default_rng(random_seed) self.drop_last = drop_last self.use_consecutive_loading = use_consecutive_loading + if val_subsample_fraction is None: + self.val_subsample_fraction = None + else: + try: + val_subsample_fraction = float(val_subsample_fraction) + except (TypeError, ValueError) as exc: + raise ValueError( + "val_subsample_fraction must be a float in (0, 1]." + ) from exc + if not (0.0 < val_subsample_fraction <= 1.0): + raise ValueError( + f"val_subsample_fraction must be in (0, 1]; got {val_subsample_fraction!r}" + ) + self.val_subsample_fraction = val_subsample_fraction # H5 field names self.pert_col = pert_col @@ -206,6 +222,7 @@ def setup(self, stage: str | None = None): """ if len(self.train_datasets) == 0: self._setup_datasets() + self._apply_val_subsample() logger.info( "Done! 
Train / Val / Test splits: %d / %d / %d", len(self.train_datasets), @@ -246,6 +263,7 @@ def save_state(self, filepath: str): "normalize_counts": self.normalize_counts, "store_raw_basal": self.store_raw_basal, "barcode": self.barcode, + "val_subsample_fraction": self.val_subsample_fraction, "additional_obs": self.additional_obs, "use_consecutive_loading": self.use_consecutive_loading, "h5_open_kwargs": self.h5_open_kwargs, @@ -667,6 +685,29 @@ def _setup_datasets(self): if pbar is not None: pbar.close() + def _apply_val_subsample(self) -> None: + """Subsample validation datasets based on val_subsample_fraction.""" + if self.val_subsample_fraction is None or len(self.val_datasets) == 0: + return + + total = len(self.val_datasets) + keep = max(1, int(np.ceil(total * self.val_subsample_fraction))) + if keep >= total: + return + + # Deterministic ordering for reproducibility. + ordered = sorted( + self.val_datasets, + key=lambda subset: (subset.dataset.name, str(subset.dataset.h5_path)), + ) + self.val_datasets = ordered[:keep] + logger.info( + "Subsampled validation datasets: kept %d/%d (val_subsample_fraction=%.4f).", + keep, + total, + self.val_subsample_fraction, + ) + def _split_fewshot_celltype( self, ds: PerturbationDataset, From b701167afbb8dee87d395869d6ca1fee22b0467b Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Thu, 5 Feb 2026 20:59:08 +0000 Subject: [PATCH 07/13] updated code to also yield dataset name in each batch --- src/cell_load/dataset/_perturbation.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/cell_load/dataset/_perturbation.py b/src/cell_load/dataset/_perturbation.py index a4c9049..ab01e0e 100644 --- a/src/cell_load/dataset/_perturbation.py +++ b/src/cell_load/dataset/_perturbation.py @@ -272,6 +272,7 @@ def __getitem__(self, idx: int): - cell_type: the cell type - batch: the batch (as an int or string) - batch_name: the batch name (as a string) + - dataset_name: the dataset name (from the TOML) - pert_cell_counts: the raw gene 
expression of the perturbed cell (if store_raw_expression is True) - ctrl_cell_counts: the raw gene expression of the control cell (if store_raw_basal is True) """ @@ -313,6 +314,7 @@ def __getitem__(self, idx: int): "ctrl_cell_emb": ctrl_expr, "pert_emb": pert_onehot, "pert_name": pert_name, + "dataset_name": self.name, "batch_name": batch_name, "batch": batch_onehot, "cell_type": cell_type, @@ -506,6 +508,7 @@ def __getitems__(self, indices): "ctrl_cell_emb": ctrl_expr, "pert_emb": pert_onehot, "pert_name": pert_name, + "dataset_name": self.name, "batch_name": batch_name, "batch": batch_onehot, "cell_type": cell_type, From 879374c07ffb084c484718795f130972c220b363 Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Thu, 5 Feb 2026 21:45:27 +0000 Subject: [PATCH 08/13] added logic for downsampling the number of cells per condition --- .../data_modules/perturbation_dataloader.py | 25 +++++++- src/cell_load/data_modules/samplers.py | 58 ++++++++++++++++- src/cell_load/dataset/_perturbation.py | 64 +++++++++++-------- 3 files changed, 116 insertions(+), 31 deletions(-) diff --git a/src/cell_load/data_modules/perturbation_dataloader.py b/src/cell_load/data_modules/perturbation_dataloader.py index 95ff308..a6cd337 100644 --- a/src/cell_load/data_modules/perturbation_dataloader.py +++ b/src/cell_load/data_modules/perturbation_dataloader.py @@ -64,6 +64,7 @@ def __init__( embed_key: Literal["X_hvg", "X_state"] | None = None, output_space: Literal["gene", "all", "embedding"] = "gene", downsample: float | None = None, + downsample_cells: int | None = None, is_log1p: bool = False, basal_mapping_strategy: Literal["batch", "random"] = "random", n_basal_samples: int = 1, @@ -90,7 +91,10 @@ def __init__( random_seed: For reproducible splits & sampling embed_key: Embedding key or matrix in the H5 file to use for feauturizing cells output_space: The output space for model predictions (gene, all genes, or embedding-only) - downsample: Fraction of counts to retain via binomial 
downsampling (only for output_space="all") + downsample: If <=1, fraction of counts to retain via binomial downsampling; if >1, target + read depth per cell (only for output_space="all") + downsample_cells: Max cells per (cell_type, perturbation[, batch]) group; if a group has + fewer cells it is unchanged is_log1p: Whether raw counts in X are log1p-transformed (auto-set if uns/log1p is present) basal_mapping_strategy: One of {"batch","random","nearest","ot"} n_basal_samples: Number of control cells to sample per perturbed cell @@ -141,6 +145,23 @@ def __init__( f"output_space must be one of 'gene', 'all', or 'embedding'; got {self.output_space!r}" ) self.downsample = downsample + if downsample_cells is None: + self.downsample_cells = None + else: + if isinstance(downsample_cells, bool): + raise ValueError("downsample_cells must be a positive integer or None.") + if isinstance(downsample_cells, float): + if not downsample_cells.is_integer(): + raise ValueError( + "downsample_cells must be a positive integer or None." 
+ ) + downsample_cells = int(downsample_cells) + elif not isinstance(downsample_cells, (int, np.integer)): + raise ValueError("downsample_cells must be a positive integer or None.") + downsample_cells = int(downsample_cells) + if downsample_cells <= 0: + raise ValueError("downsample_cells must be a positive integer or None.") + self.downsample_cells = downsample_cells self.is_log1p = is_log1p # Sampling and mapping @@ -250,6 +271,7 @@ def save_state(self, filepath: str): "embed_key": self.embed_key, "output_space": self.output_space, "downsample": self.downsample, + "downsample_cells": self.downsample_cells, "is_log1p": self.is_log1p, "basal_mapping_strategy": self.basal_mapping_strategy, "n_basal_samples": self.n_basal_samples, @@ -449,6 +471,7 @@ def _create_dataloader( test=test, use_batch=use_batch, use_consecutive_loading=self.use_consecutive_loading, + downsample_cells=self.downsample_cells, ) return DataLoader( diff --git a/src/cell_load/data_modules/samplers.py b/src/cell_load/data_modules/samplers.py index d4eb413..5c50b3d 100644 --- a/src/cell_load/data_modules/samplers.py +++ b/src/cell_load/data_modules/samplers.py @@ -34,6 +34,7 @@ def __init__( test: bool = False, use_batch: bool = False, use_consecutive_loading: bool = False, + downsample_cells: int | None = None, seed: int = 0, epoch: int = 0, ): @@ -61,6 +62,7 @@ def __init__( self.cell_sentence_len = cell_sentence_len self.drop_last = drop_last + self.downsample_cells = self._validate_downsample_cells(downsample_cells) # Setup distributed settings if distributed mode is enabled. 
self.distributed = False @@ -165,6 +167,50 @@ def _create_batches(self) -> list[list[int]]: return all_batches + def _validate_downsample_cells(self, downsample_cells: int | None) -> int | None: + if downsample_cells is None: + return None + if isinstance(downsample_cells, bool): + raise ValueError("downsample_cells must be a positive integer or None.") + if isinstance(downsample_cells, float): + if not downsample_cells.is_integer(): + raise ValueError("downsample_cells must be a positive integer or None.") + downsample_cells = int(downsample_cells) + elif not isinstance(downsample_cells, (int, np.integer)): + raise ValueError("downsample_cells must be a positive integer or None.") + downsample_cells = int(downsample_cells) + if downsample_cells <= 0: + raise ValueError("downsample_cells must be a positive integer or None.") + return downsample_cells + + def _apply_downsample_cells( + self, sentences: list[list[int]] + ) -> list[list[int]]: + if self.downsample_cells is None or not sentences: + return sentences + + total = sum(len(sentence) for sentence in sentences) + if total <= self.downsample_cells: + return sentences + + order = np.random.permutation(len(sentences)) + selected: list[list[int]] = [] + remaining = self.downsample_cells + for idx in order: + if remaining <= 0: + break + sentence = sentences[idx] + if len(sentence) <= remaining: + selected.append(sentence) + remaining -= len(sentence) + else: + if not self.drop_last and remaining > 0: + selected.append(sentence[:remaining]) + remaining = 0 + break + + return selected + def _get_rank_sentences(self) -> list[list[int]]: """ Get the subset of sentences that this rank should process. 
@@ -247,11 +293,15 @@ def _process_subset(self, global_offset: int, subset: Subset) -> list[list[int]] group_indices = sorted_indices[start:end] np.random.shuffle(group_indices) + group_sentences = [] for i in range(0, len(group_indices), self.cell_sentence_len): sentence = group_indices[i : i + self.cell_sentence_len] if len(sentence) < self.cell_sentence_len and self.drop_last: continue - subset_batches.append(sentence.tolist()) + group_sentences.append(sentence.tolist()) + + group_sentences = self._apply_downsample_cells(group_sentences) + subset_batches.extend(group_sentences) return subset_batches @@ -295,11 +345,15 @@ def _process_subset_consecutive( subset_batches = [] for segment in segments: + group_sentences = [] for i in range(0, len(segment), self.cell_sentence_len): sentence = segment[i : i + self.cell_sentence_len] if len(sentence) < self.cell_sentence_len and self.drop_last: continue - subset_batches.append(sentence.tolist()) + group_sentences.append(sentence.tolist()) + + group_sentences = self._apply_downsample_cells(group_sentences) + subset_batches.extend(group_sentences) return subset_batches diff --git a/src/cell_load/dataset/_perturbation.py b/src/cell_load/dataset/_perturbation.py index ab01e0e..86813c5 100644 --- a/src/cell_load/dataset/_perturbation.py +++ b/src/cell_load/dataset/_perturbation.py @@ -71,7 +71,8 @@ def __init__( store_raw_basal: If True, include raw basal expression barcode: If True, include cell barcodes in output additional_obs: Optional list of obs column names to include in each sample - downsample: Fraction of counts to retain via binomial downsampling (only for output_space="all") + downsample: If <=1, fraction of counts to retain via binomial downsampling; if >1, target + read depth per cell (only for output_space="all") is_log1p: Whether raw counts in X are log1p-transformed (affects downsampling) cell_sentence_len: Optional sentence length for consecutive loading batches h5_open_kwargs: Optional kwargs to pass to 
h5py.File (e.g., rdcc_nbytes) @@ -106,10 +107,10 @@ def __init__( downsample = float(downsample) except (TypeError, ValueError) as exc: raise ValueError( - f"downsample must be a float in (0, 1]; got {downsample!r}" + f"downsample must be a positive float; got {downsample!r}" ) from exc - if not (0.0 < downsample <= 1.0): - raise ValueError(f"downsample must be in (0, 1]; got {downsample!r}") + if not (0.0 < downsample): + raise ValueError(f"downsample must be > 0; got {downsample!r}") self.downsample = downsample self.is_log1p = bool(is_log1p) self.cell_sentence_len = cell_sentence_len @@ -758,31 +759,15 @@ def _fetch_gene_expression_csr_row(self, idx: int) -> tuple[np.ndarray, np.ndarr return sub_indices.astype(np.int64), sub_data.astype(np.float32) def _maybe_downsample_counts(self, counts: torch.Tensor) -> torch.Tensor: - if ( - self.downsample is None - or self.downsample >= 1.0 - or self.output_space != "all" - ): + if self.downsample is None or self.output_space != "all" or self.downsample == 1.0: return counts counts_np = counts.detach().cpu().numpy() - if self.is_log1p: - counts_lin = np.expm1(counts_np) - counts_int = np.rint(counts_lin).astype(np.int64) - else: - counts_int = counts_np.astype(np.int64) - counts_int = np.maximum(counts_int, 0) - sampled = self.rng.binomial(counts_int, self.downsample) - if self.is_log1p: - sampled = np.log1p(sampled) + sampled = self._maybe_downsample_counts_array(counts_np) return torch.tensor(sampled, dtype=torch.float32) def _maybe_downsample_counts_array(self, counts: np.ndarray) -> np.ndarray: - if ( - self.downsample is None - or self.downsample >= 1.0 - or self.output_space != "all" - ): + if self.downsample is None or self.output_space != "all" or self.downsample == 1.0: return counts if self.is_log1p: @@ -791,7 +776,27 @@ def _maybe_downsample_counts_array(self, counts: np.ndarray) -> np.ndarray: else: counts_int = counts.astype(np.int64) counts_int = np.maximum(counts_int, 0) - sampled = 
self.rng.binomial(counts_int, self.downsample) + if self.downsample < 1.0: + sampled = self.rng.binomial(counts_int, self.downsample) + else: + target = float(self.downsample) + if counts_int.ndim == 1: + total = counts_int.sum() + if total <= 0: + sampled = counts_int + else: + p = min(1.0, target / float(total)) + sampled = self.rng.binomial(counts_int, p) + else: + total = counts_int.sum(axis=1, keepdims=True).astype(np.float64) + p = np.divide( + target, + total, + out=np.ones_like(total, dtype=np.float64), + where=total > 0, + ) + p = np.minimum(p, 1.0) + sampled = self.rng.binomial(counts_int, p) if self.is_log1p: sampled = np.log1p(sampled) return sampled.astype(np.float32) @@ -804,7 +809,7 @@ def fetch_gene_expression(self, idx: int) -> torch.Tensor: if ( attrs.get("encoding-type") == "csr_matrix" and self.downsample is not None - and self.downsample < 1.0 + and self.downsample != 1.0 and self.output_space == "all" ): sub_indices, sub_data = self._fetch_gene_expression_csr_row(idx) @@ -816,9 +821,12 @@ def fetch_gene_expression(self, idx: int) -> torch.Tensor: else: counts_int = sub_data.astype(np.int64) counts_int = np.maximum(counts_int, 0) - sampled = self.rng.binomial(counts_int, self.downsample).astype( - np.float32 - ) + if self.downsample < 1.0: + p = self.downsample + else: + total = counts_int.sum() + p = 1.0 if total <= 0 else min(1.0, self.downsample / float(total)) + sampled = self.rng.binomial(counts_int, p).astype(np.float32) if self.is_log1p: sampled = np.log1p(sampled) dense[sub_indices] = sampled From 518434f23412785878389bc7d68f95cdd0195a92 Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Thu, 5 Feb 2026 21:45:45 +0000 Subject: [PATCH 09/13] ruff formatting --- .../data_modules/perturbation_dataloader.py | 7 +++- src/cell_load/data_modules/samplers.py | 4 +- src/cell_load/dataset/_metadata.py | 4 +- src/cell_load/dataset/_perturbation.py | 37 ++++++++++++------- src/cell_load/mapping_strategies/random.py | 4 +- 5 files changed, 34 
insertions(+), 22 deletions(-) diff --git a/src/cell_load/data_modules/perturbation_dataloader.py b/src/cell_load/data_modules/perturbation_dataloader.py index a6cd337..68aa9a8 100644 --- a/src/cell_load/data_modules/perturbation_dataloader.py +++ b/src/cell_load/data_modules/perturbation_dataloader.py @@ -28,7 +28,12 @@ def _worker_init_fn(worker_id: int) -> None: - for var in ("OMP_NUM_THREADS", "MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS", "NUMEXPR_NUM_THREADS"): + for var in ( + "OMP_NUM_THREADS", + "MKL_NUM_THREADS", + "OPENBLAS_NUM_THREADS", + "NUMEXPR_NUM_THREADS", + ): os.environ.setdefault(var, "1") try: torch.set_num_threads(1) diff --git a/src/cell_load/data_modules/samplers.py b/src/cell_load/data_modules/samplers.py index 5c50b3d..47c6eb1 100644 --- a/src/cell_load/data_modules/samplers.py +++ b/src/cell_load/data_modules/samplers.py @@ -183,9 +183,7 @@ def _validate_downsample_cells(self, downsample_cells: int | None) -> int | None raise ValueError("downsample_cells must be a positive integer or None.") return downsample_cells - def _apply_downsample_cells( - self, sentences: list[list[int]] - ) -> list[list[int]]: + def _apply_downsample_cells(self, sentences: list[list[int]]) -> list[list[int]]: if self.downsample_cells is None or not sentences: return sentences diff --git a/src/cell_load/dataset/_metadata.py b/src/cell_load/dataset/_metadata.py index 2038e3e..900f8dc 100644 --- a/src/cell_load/dataset/_metadata.py +++ b/src/cell_load/dataset/_metadata.py @@ -42,7 +42,9 @@ def __getitems__(self, indices): for out_pos, idx in enumerate(indices): dataset_idx = bisect_right(self.cumulative_sizes, idx) sample_idx = ( - idx if dataset_idx == 0 else idx - self.cumulative_sizes[dataset_idx - 1] + idx + if dataset_idx == 0 + else idx - self.cumulative_sizes[dataset_idx - 1] ) grouped.setdefault(dataset_idx, []).append((out_pos, sample_idx)) diff --git a/src/cell_load/dataset/_perturbation.py b/src/cell_load/dataset/_perturbation.py index 86813c5..9d630bd 
100644 --- a/src/cell_load/dataset/_perturbation.py +++ b/src/cell_load/dataset/_perturbation.py @@ -155,10 +155,14 @@ def _init_split_index_cache(self) -> None: self._index_to_split_code = np.full(self.n_cells, -1, dtype=np.int8) for split, indices in self.split_perturbed_indices.items(): if indices: - self._index_to_split_code[list(indices)] = self._split_name_to_code[split] + self._index_to_split_code[list(indices)] = self._split_name_to_code[ + split + ] for split, indices in self.split_control_indices.items(): if indices: - self._index_to_split_code[list(indices)] = self._split_name_to_code[split] + self._index_to_split_code[list(indices)] = self._split_name_to_code[ + split + ] @staticmethod def _parse_env_int(name: str, default: int) -> int: @@ -183,9 +187,7 @@ def _parse_env_float(name: str, default: float) -> float: return default def _default_h5_open_kwargs(self) -> dict: - rdcc_nbytes = self._parse_env_int( - "CELL_LOAD_H5_RDCC_NBYTES", 64 * 1024 * 1024 - ) + rdcc_nbytes = self._parse_env_int("CELL_LOAD_H5_RDCC_NBYTES", 64 * 1024 * 1024) rdcc_nslots = self._parse_env_int("CELL_LOAD_H5_RDCC_NSLOTS", 1_000_003) rdcc_w0 = self._parse_env_float("CELL_LOAD_H5_RDCC_W0", 0.75) kwargs = { @@ -418,9 +420,8 @@ def __getitems__(self, indices): raise ValueError( f"No control cells found for cell type '{self.get_cell_type(missing_file_idx)}'" ) - if ( - (self.store_raw_basal and self.output_space != "embedding") - or (self.barcode and self.cell_barcodes is not None) + if (self.store_raw_basal and self.output_space != "embedding") or ( + self.barcode and self.cell_barcodes is not None ): raise ValueError( f"No control cells found for cell type '{self.get_cell_type(missing_file_idx)}'" @@ -759,7 +760,11 @@ def _fetch_gene_expression_csr_row(self, idx: int) -> tuple[np.ndarray, np.ndarr return sub_indices.astype(np.int64), sub_data.astype(np.float32) def _maybe_downsample_counts(self, counts: torch.Tensor) -> torch.Tensor: - if self.downsample is None or 
self.output_space != "all" or self.downsample == 1.0: + if ( + self.downsample is None + or self.output_space != "all" + or self.downsample == 1.0 + ): return counts counts_np = counts.detach().cpu().numpy() @@ -767,7 +772,11 @@ def _maybe_downsample_counts(self, counts: torch.Tensor) -> torch.Tensor: return torch.tensor(sampled, dtype=torch.float32) def _maybe_downsample_counts_array(self, counts: np.ndarray) -> np.ndarray: - if self.downsample is None or self.output_space != "all" or self.downsample == 1.0: + if ( + self.downsample is None + or self.output_space != "all" + or self.downsample == 1.0 + ): return counts if self.is_log1p: @@ -861,7 +870,7 @@ def _fetch_dense_matrix_batch( block = np.asarray(ds[row_start : row_end + 1], dtype=np.float32) if block.ndim == 1: block = block[:, None] - dense_sorted[run_start:len(sorted_rows)] = block + dense_sorted[run_start : len(sorted_rows)] = block inv_order = np.empty_like(order) inv_order[order] = np.arange(len(order)) @@ -961,9 +970,9 @@ def _fill_dense_run( ptr_end = int(indptr_slice[offset + 1] - base_ptr) if ptr_end <= ptr_start: continue - dense_sorted[start + offset, block_indices[ptr_start:ptr_end]] = ( - block_data[ptr_start:ptr_end] - ) + dense_sorted[start + offset, block_indices[ptr_start:ptr_end]] = block_data[ + ptr_start:ptr_end + ] @lru_cache(maxsize=10000) def fetch_obsm_expression(self, idx: int, key: str) -> torch.Tensor: diff --git a/src/cell_load/mapping_strategies/random.py b/src/cell_load/mapping_strategies/random.py index 45a04e6..ebea3a8 100644 --- a/src/cell_load/mapping_strategies/random.py +++ b/src/cell_load/mapping_strategies/random.py @@ -150,9 +150,7 @@ def register_split_indices( if pool_arr.size == 0: control_assignments = [] else: - repeats = int( - np.ceil(total_assignments_needed / pool_arr.size) - ) + repeats = int(np.ceil(total_assignments_needed / pool_arr.size)) control_assignments = np.tile(pool_arr, repeats)[ :total_assignments_needed ].tolist() From 
e6245a881b7336972c4ca21ae8fe331d8a37d7c6 Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Tue, 10 Feb 2026 17:53:31 +0000 Subject: [PATCH 10/13] obsm now works with csr or dense data both --- src/cell_load/dataset/_perturbation.py | 172 +++++++++++++++++++++++-- src/cell_load/utils/data_utils.py | 7 +- 2 files changed, 169 insertions(+), 10 deletions(-) diff --git a/src/cell_load/dataset/_perturbation.py b/src/cell_load/dataset/_perturbation.py index 9d630bd..0fe4966 100644 --- a/src/cell_load/dataset/_perturbation.py +++ b/src/cell_load/dataset/_perturbation.py @@ -614,7 +614,9 @@ def get_dim_for_obsm(self, key: str) -> int: """ Get the feature dimensionality of obsm data with the specified key (e.g., 'X_uce'). """ - return self.h5_file[f"obsm/{key}"].shape[1] + matrix = self._get_obsm_matrix(key) + _, n_cols = self._get_matrix_shape(matrix) + return n_cols def get_cell_type(self, idx): """ @@ -876,6 +878,135 @@ def _fetch_dense_matrix_batch( inv_order[order] = np.arange(len(order)) return dense_sorted[inv_order] + @staticmethod + def _is_csr_group(obj) -> bool: + return isinstance(obj, h5py.Group) and ( + obj.attrs.get("encoding-type") == "csr_matrix" + or all(k in obj for k in ("data", "indices", "indptr")) + ) + + def _infer_n_cols_from_indices(self, indices_ds) -> int: + # Rare fallback for malformed files that omit CSR shape metadata. 
+ total_nnz = int(indices_ds.shape[0]) + if total_nnz == 0: + return 0 + + max_col = -1 + chunk = 1_000_000 + for start in range(0, total_nnz, chunk): + stop = min(start + chunk, total_nnz) + block = np.asarray(indices_ds[start:stop], dtype=np.int64) + if block.size == 0: + continue + local_max = int(block.max()) + if local_max > max_col: + max_col = local_max + return max_col + 1 + + def _get_matrix_shape(self, matrix_obj) -> tuple[int, int]: + if isinstance(matrix_obj, h5py.Dataset): + if len(matrix_obj.shape) == 1: + return int(matrix_obj.shape[0]), 1 + if len(matrix_obj.shape) >= 2: + return int(matrix_obj.shape[0]), int(matrix_obj.shape[1]) + raise ValueError("Dataset has invalid rank for matrix-like data.") + + if self._is_csr_group(matrix_obj): + shape_attr = matrix_obj.attrs.get("shape") + if shape_attr is not None: + shape_arr = np.asarray(shape_attr, dtype=np.int64).reshape(-1) + if shape_arr.size >= 2: + return int(shape_arr[0]), int(shape_arr[1]) + + if "indptr" not in matrix_obj: + raise KeyError("CSR group is missing required 'indptr' dataset.") + n_rows = int(matrix_obj["indptr"].shape[0]) - 1 + + if "indices" not in matrix_obj: + raise KeyError("CSR group is missing required 'indices' dataset.") + n_cols = self._infer_n_cols_from_indices(matrix_obj["indices"]) + return n_rows, n_cols + + raise TypeError( + f"Unsupported matrix storage type: {type(matrix_obj).__name__}. " + "Expected h5py.Dataset or CSR-encoded h5py.Group." + ) + + def _fetch_csr_matrix_batch( + self, matrix_group: h5py.Group, indices: np.ndarray, n_cols: int + ) -> np.ndarray: + if indices.size == 0: + return np.empty((0, n_cols), dtype=np.float32) + + if not all(k in matrix_group for k in ("indptr", "data", "indices")): + raise KeyError( + "CSR group must contain 'indptr', 'data', and 'indices' datasets." 
+ ) + + indptr_ds = matrix_group["indptr"] + data_ds = matrix_group["data"] + indices_ds = matrix_group["indices"] + + order = np.argsort(indices) + sorted_rows = indices[order] + dense_sorted = np.zeros((len(sorted_rows), n_cols), dtype=np.float32) + + run_start = 0 + for i in range(1, len(sorted_rows)): + if sorted_rows[i] != sorted_rows[i - 1] + 1: + self._fill_dense_run( + sorted_rows, + run_start, + i, + dense_sorted, + indptr_ds, + data_ds, + indices_ds, + ) + run_start = i + + self._fill_dense_run( + sorted_rows, + run_start, + len(sorted_rows), + dense_sorted, + indptr_ds, + data_ds, + indices_ds, + ) + + inv_order = np.empty_like(order) + inv_order[order] = np.arange(len(order)) + return dense_sorted[inv_order] + + def _fetch_csr_row(self, matrix_group: h5py.Group, idx: int, n_cols: int) -> np.ndarray: + if not all(k in matrix_group for k in ("indptr", "data", "indices")): + raise KeyError( + "CSR group must contain 'indptr', 'data', and 'indices' datasets." + ) + + indptr_ds = matrix_group["indptr"] + data_ds = matrix_group["data"] + indices_ds = matrix_group["indices"] + + start_ptr = int(indptr_ds[idx]) + end_ptr = int(indptr_ds[idx + 1]) + + dense = np.zeros(n_cols, dtype=np.float32) + if end_ptr <= start_ptr: + return dense + + row_data = np.asarray(data_ds[start_ptr:end_ptr], dtype=np.float32) + row_indices = np.asarray(indices_ds[start_ptr:end_ptr], dtype=np.int64) + dense[row_indices] = row_data + return dense + + def _get_obsm_matrix(self, key: str): + path = f"/obsm/{key}" + if path not in self.h5_file: + raise KeyError(f"obsm key '{key}' not found in {self.h5_path}") + return self.h5_file[path] + def _fetch_gene_expression_batch(self, indices: np.ndarray) -> torch.Tensor: """ Fetch raw gene counts for multiple indices at once (CSR fast path). 
@@ -933,11 +1064,20 @@ def _fetch_gene_expression_batch(self, indices: np.ndarray) -> torch.Tensor: def _fetch_obsm_expression_batch( self, indices: np.ndarray, key: str ) -> torch.Tensor: - ds = self.h5_file[f"/obsm/{key}"] - n_cols = int(ds.shape[1]) if ds.ndim > 1 else 1 + matrix = self._get_obsm_matrix(key) + _, n_cols = self._get_matrix_shape(matrix) if indices.size == 0: return torch.empty((0, n_cols), dtype=torch.float32) - dense = self._fetch_dense_matrix_batch(ds, indices, n_cols) + + if isinstance(matrix, h5py.Dataset): + dense = self._fetch_dense_matrix_batch(matrix, indices, n_cols) + elif self._is_csr_group(matrix): + dense = self._fetch_csr_matrix_batch(matrix, indices, n_cols) + else: + raise TypeError( + f"Unsupported obsm storage for key '{key}': {type(matrix).__name__}" + ) + return torch.from_numpy(dense) def _fill_dense_run( @@ -985,8 +1125,22 @@ def fetch_obsm_expression(self, idx: int, key: str) -> torch.Tensor: Returns: 1D FloatTensor of that embedding """ - row_data = self.h5_file[f"/obsm/{key}"][idx] - return torch.tensor(row_data, dtype=torch.float32) + matrix = self._get_obsm_matrix(key) + _, n_cols = self._get_matrix_shape(matrix) + + if isinstance(matrix, h5py.Dataset): + row_data = np.asarray(matrix[idx], dtype=np.float32) + if row_data.ndim == 0: + row_data = np.asarray([row_data], dtype=np.float32) + elif row_data.ndim > 1: + row_data = row_data.reshape(-1) + return torch.from_numpy(row_data) + + if self._is_csr_group(matrix): + row_data = self._fetch_csr_row(matrix, int(idx), n_cols) + return torch.from_numpy(row_data) + + raise TypeError(f"Unsupported obsm storage for key '{key}': {type(matrix).__name__}") def get_gene_names(self, output_space="all") -> list[str]: """ @@ -1282,13 +1436,13 @@ def _get_num_genes(self) -> int: indices = self.h5_file["X/indices"][:] n_cols = indices.max() + 1 except KeyError: - n_cols = self.h5_file["obsm/X_hvg"].shape[1] + n_cols = self._get_matrix_shape(self.h5_file["obsm/X_hvg"])[1] return n_cols 
def get_num_hvgs(self) -> int: """Return the number of highly variable genes in the obsm matrix.""" try: - return self.h5_file["obsm/X_hvg"].shape[1] + return self._get_matrix_shape(self.h5_file["obsm/X_hvg"])[1] except: return 0 @@ -1303,7 +1457,7 @@ def _get_num_cells(self) -> int: n_rows = len(indptr) - 1 except Exception: # if this also fails, fall back to obsm - n_rows = self.h5_file["obsm/X_hvg"].shape[0] + n_rows = self._get_matrix_shape(self.h5_file["obsm/X_hvg"])[0] return n_rows def get_pert_name(self, idx: int) -> str: diff --git a/src/cell_load/utils/data_utils.py b/src/cell_load/utils/data_utils.py index 25a3d83..ead8749 100644 --- a/src/cell_load/utils/data_utils.py +++ b/src/cell_load/utils/data_utils.py @@ -136,6 +136,11 @@ def generate_onehot_map(keys) -> dict: """ Build a map from each unique key to a fixed-length one-hot torch vector. + Note: + We clone each row from the identity matrix so every tensor owns compact + storage. This avoids pathological file sizes when maps are serialized + with pickle (shared-storage tensor views can serialize very poorly). 
+ Args: keys: iterable of hashable items Returns: @@ -145,7 +150,7 @@ def generate_onehot_map(keys) -> dict: num_classes = len(unique_keys) # identity matrix rows are one-hot vectors onehots = torch.eye(num_classes) - return {k: onehots[i] for i, k in enumerate(unique_keys)} + return {k: onehots[i].clone() for i, k in enumerate(unique_keys)} def data_to_torch_X(X): From 8c179d9cb097568ad1c4adb3f135f49cfd9cb064 Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Tue, 10 Feb 2026 17:53:47 +0000 Subject: [PATCH 11/13] ruff formatting --- src/cell_load/dataset/_perturbation.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/cell_load/dataset/_perturbation.py b/src/cell_load/dataset/_perturbation.py index 0fe4966..deb5646 100644 --- a/src/cell_load/dataset/_perturbation.py +++ b/src/cell_load/dataset/_perturbation.py @@ -979,7 +979,9 @@ def _fetch_csr_matrix_batch( inv_order[order] = np.arange(len(order)) return dense_sorted[inv_order] - def _fetch_csr_row(self, matrix_group: h5py.Group, idx: int, n_cols: int) -> np.ndarray: + def _fetch_csr_row( + self, matrix_group: h5py.Group, idx: int, n_cols: int + ) -> np.ndarray: if not all(k in matrix_group for k in ("indptr", "data", "indices")): raise KeyError( "CSR group must contain 'indptr', 'data', and 'indices' datasets." 
@@ -1140,7 +1142,9 @@ def fetch_obsm_expression(self, idx: int, key: str) -> torch.Tensor: row_data = self._fetch_csr_row(matrix, int(idx), n_cols) return torch.from_numpy(row_data) - raise TypeError(f"Unsupported obsm storage for key '{key}': {type(matrix).__name__}") + raise TypeError( + f"Unsupported obsm storage for key '{key}': {type(matrix).__name__}" + ) def get_gene_names(self, output_space="all") -> list[str]: """ From b55a2ab1af7a31d33988b87ea1ebd2005605187b Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Sat, 21 Feb 2026 03:11:33 +0000 Subject: [PATCH 12/13] updated with a exp_counts parameter that expm1's counts and casts to int when set --- pyproject.toml | 2 +- .../data_modules/perturbation_dataloader.py | 10 ++-- src/cell_load/dataset/_perturbation.py | 46 +++++++------------ 3 files changed, 23 insertions(+), 35 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d332ddb..e3b55a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "cell-load" -version = "0.10.0" +version = "0.10.1" description = "Dataloaders for training models on huge single-cell datasets" readme = "README.md" authors = [ diff --git a/src/cell_load/data_modules/perturbation_dataloader.py b/src/cell_load/data_modules/perturbation_dataloader.py index 68aa9a8..45605e9 100644 --- a/src/cell_load/data_modules/perturbation_dataloader.py +++ b/src/cell_load/data_modules/perturbation_dataloader.py @@ -178,7 +178,7 @@ def __init__( # Optional behaviors self.map_controls = kwargs.get("map_controls", True) self.perturbation_features_file = kwargs.get("perturbation_features_file") - self.int_counts = kwargs.get("int_counts", False) + self.exp_counts = kwargs.get("exp_counts", False) self.normalize_counts = kwargs.get("normalize_counts", False) self.store_raw_basal = kwargs.get("store_raw_basal", False) self.barcode = kwargs.get("barcode", False) @@ -286,7 +286,7 @@ def save_state(self, filepath: str): # Include the optional behaviors 
"map_controls": self.map_controls, "perturbation_features_file": self.perturbation_features_file, - "int_counts": self.int_counts, + "exp_counts": self.exp_counts, "normalize_counts": self.normalize_counts, "store_raw_basal": self.store_raw_basal, "barcode": self.barcode, @@ -332,7 +332,7 @@ def load_state(cls, filepath: str): "perturbation_features_file": save_dict.pop( "perturbation_features_file", None ), - "int_counts": save_dict.pop("int_counts", False), + "exp_counts": save_dict.pop("exp_counts", False), "normalize_counts": save_dict.pop("normalize_counts", False), "store_raw_basal": save_dict.pop("store_raw_basal", False), "barcode": save_dict.pop("barcode", True), @@ -460,8 +460,8 @@ def _create_dataloader( batch_size: int | None = None, ): """Create a DataLoader with appropriate configuration.""" - use_int_counts = "int_counts" in self.__dict__ and self.int_counts - collate_fn = partial(PerturbationDataset.collate_fn, int_counts=use_int_counts) + use_exp_counts = "exp_counts" in self.__dict__ and self.exp_counts + collate_fn = partial(PerturbationDataset.collate_fn, exp_counts=use_exp_counts) ds = MetadataConcatDataset(datasets) use_batch = self.basal_mapping_strategy == "batch" diff --git a/src/cell_load/dataset/_perturbation.py b/src/cell_load/dataset/_perturbation.py index deb5646..a53ad75 100644 --- a/src/cell_load/dataset/_perturbation.py +++ b/src/cell_load/dataset/_perturbation.py @@ -1214,7 +1214,7 @@ def _decode(x): # Static methods ############################## @staticmethod - def collate_fn(batch, int_counts=False): + def collate_fn(batch, exp_counts=False): """ Optimized collate function with preallocated lists. Safely handles normalization when vectors sum to zero. 
@@ -1287,42 +1287,30 @@ def collate_fn(batch, int_counts=False): is_discrete = suspected_discrete_torch(pert_cell_counts) is_log = suspected_log_torch(pert_cell_counts) already_logged = (not is_discrete) and is_log + if exp_counts: + if already_logged: + pert_cell_counts = torch.expm1(pert_cell_counts) + pert_cell_counts = torch.nan_to_num( + pert_cell_counts, nan=0.0, posinf=0.0, neginf=0.0 + ) + pert_cell_counts = pert_cell_counts.clamp_min(0).round().to(torch.int32) batch_dict["pert_cell_counts"] = pert_cell_counts - # if already_logged: # counts are already log transformed - # if ( - # int_counts - # ): # if the user wants to model with raw counts, don't log transform - # batch_dict["pert_cell_counts"] = torch.expm1(pert_cell_counts) - # else: - # batch_dict["pert_cell_counts"] = pert_cell_counts - # else: - # if int_counts: - # batch_dict["pert_cell_counts"] = pert_cell_counts - # else: - # batch_dict["pert_cell_counts"] = torch.log1p(pert_cell_counts) - if has_ctrl_cell_counts: ctrl_cell_counts = torch.stack(ctrl_cell_counts_list) - is_discrete = suspected_discrete_torch(pert_cell_counts) - is_log = suspected_log_torch(pert_cell_counts) + is_discrete = suspected_discrete_torch(ctrl_cell_counts) + is_log = suspected_log_torch(ctrl_cell_counts) already_logged = (not is_discrete) and is_log + if exp_counts: + if already_logged: + ctrl_cell_counts = torch.expm1(ctrl_cell_counts) + ctrl_cell_counts = torch.nan_to_num( + ctrl_cell_counts, nan=0.0, posinf=0.0, neginf=0.0 + ) + ctrl_cell_counts = ctrl_cell_counts.clamp_min(0).round().to(torch.int32) batch_dict["ctrl_cell_counts"] = ctrl_cell_counts - # if already_logged: # counts are already log transformed - # if ( - # int_counts - # ): # if the user wants to model with raw counts, don't log transform - # batch_dict["ctrl_cell_counts"] = torch.expm1(ctrl_cell_counts) - # else: - # batch_dict["ctrl_cell_counts"] = ctrl_cell_counts - # else: - # if int_counts: - # batch_dict["ctrl_cell_counts"] = ctrl_cell_counts - # 
else: - # batch_dict["ctrl_cell_counts"] = torch.log1p(ctrl_cell_counts) - if has_barcodes: batch_dict["pert_cell_barcode"] = pert_cell_barcode_list batch_dict["ctrl_cell_barcode"] = ctrl_cell_barcode_list From 10e34ee7c6cd4554641276a5d296aee219220a1e Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Sat, 21 Feb 2026 15:24:35 +0000 Subject: [PATCH 13/13] changes default for downsampling is_log1p to true --- pyproject.toml | 2 +- .../data_modules/perturbation_dataloader.py | 19 ++++++++++++++++--- src/cell_load/dataset/_perturbation.py | 4 ++-- 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e3b55a1..5317f42 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "cell-load" -version = "0.10.1" +version = "0.10.2" description = "Dataloaders for training models on huge single-cell datasets" readme = "README.md" authors = [ diff --git a/src/cell_load/data_modules/perturbation_dataloader.py b/src/cell_load/data_modules/perturbation_dataloader.py index 45605e9..e6934b9 100644 --- a/src/cell_load/data_modules/perturbation_dataloader.py +++ b/src/cell_load/data_modules/perturbation_dataloader.py @@ -70,7 +70,7 @@ def __init__( output_space: Literal["gene", "all", "embedding"] = "gene", downsample: float | None = None, downsample_cells: int | None = None, - is_log1p: bool = False, + is_log1p: bool = True, basal_mapping_strategy: Literal["batch", "random"] = "random", n_basal_samples: int = 1, should_yield_control_cells: bool = True, @@ -100,7 +100,7 @@ def __init__( read depth per cell (only for output_space="all") downsample_cells: Max cells per (cell_type, perturbation[, batch]) group; if a group has fewer cells it is unchanged - is_log1p: Whether raw counts in X are log1p-transformed (auto-set if uns/log1p is present) + is_log1p: Whether raw counts in X are log1p-transformed (default True; auto-set if uns/log1p is present) basal_mapping_strategy: One of {"batch","random","nearest","ot"} 
n_basal_samples: Number of control cells to sample per perturbed cell cache_perturbation_control_pairs: If True cache perturbation-control pairs at the start of training and reuse them. @@ -167,7 +167,7 @@ def __init__( if downsample_cells <= 0: raise ValueError("downsample_cells must be a positive integer or None.") self.downsample_cells = downsample_cells - self.is_log1p = is_log1p + self.is_log1p = bool(is_log1p) # Sampling and mapping self.n_basal_samples = n_basal_samples @@ -549,6 +549,19 @@ def _setup_global_maps(self): "Detected uns/log1p in at least one dataset; setting is_log1p=True." ) + if self.is_log1p: + if seen_log1p: + logger.warning( + "is_log1p mode is ENABLED. Detected uns/log1p metadata (example: %s).", + log1p_example, + ) + else: + logger.warning( + "is_log1p mode is ENABLED by configuration/default, but no uns/log1p metadata was detected." + ) + else: + logger.warning("is_log1p mode is DISABLED.") + # Create one-hot maps if self.perturbation_features_file: # Load the custom featurizations from a torch file diff --git a/src/cell_load/dataset/_perturbation.py b/src/cell_load/dataset/_perturbation.py index a53ad75..1e6a8c5 100644 --- a/src/cell_load/dataset/_perturbation.py +++ b/src/cell_load/dataset/_perturbation.py @@ -46,7 +46,7 @@ def __init__( barcode: bool = False, additional_obs: list[str] | None = None, downsample: float | None = None, - is_log1p: bool = False, + is_log1p: bool = True, cell_sentence_len: int | None = None, h5_open_kwargs: dict | None = None, **kwargs, @@ -73,7 +73,7 @@ def __init__( additional_obs: Optional list of obs column names to include in each sample downsample: If <=1, fraction of counts to retain via binomial downsampling; if >1, target read depth per cell (only for output_space="all") - is_log1p: Whether raw counts in X are log1p-transformed (affects downsampling) + is_log1p: Whether raw counts in X are log1p-transformed (default True; affects downsampling) cell_sentence_len: Optional sentence length for 
consecutive loading batches h5_open_kwargs: Optional kwargs to pass to h5py.File (e.g., rdcc_nbytes) **kwargs: Additional options (e.g. output_space)