-
Notifications
You must be signed in to change notification settings - Fork 26
Aadduri/refactor emb #72
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
132c50a
64f54b4
9452eb9
9dec3ea
f3c1895
28b7bb3
194d5ff
d458908
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,6 +2,7 @@ | |
| import os | ||
| import glob | ||
| import re | ||
| import sys | ||
|
|
||
| from functools import partial | ||
| from pathlib import Path | ||
|
|
@@ -17,6 +18,8 @@ | |
| from ..config import ExperimentConfig | ||
| from ..dataset import MetadataConcatDataset, PerturbationDataset | ||
| from ..mapping_strategies import BatchMappingStrategy, RandomMappingStrategy | ||
|
|
||
| _OUTPUT_SPACE_ALIASES: dict[str, str] = {"hvg": "gene", "transcriptome": "all"} | ||
| from ..utils.data_utils import ( | ||
| GlobalH5MetadataCache, | ||
| generate_onehot_map, | ||
|
|
@@ -61,15 +64,19 @@ def __init__( | |
| toml_config_path: str, | ||
| batch_size: int = 128, | ||
| num_workers: int = 8, | ||
| pin_memory: bool = False, | ||
| random_seed: int = 42, # this should be removed by seed everything | ||
| pert_col: str = "gene", | ||
| batch_col: str = "gem_group", | ||
| cell_type_key: str = "cell_type", | ||
| control_pert: str = "non-targeting", | ||
| embed_key: Literal["X_hvg", "X_state"] | None = None, | ||
| output_space: Literal["gene", "all", "embedding"] = "gene", | ||
| output_space: Literal[ | ||
| "gene", "all", "embedding", "hvg", "transcriptome" | ||
| ] = "gene", | ||
| downsample: float | None = None, | ||
| downsample_cells: int | None = None, | ||
| balance_outliers: bool = False, | ||
| is_log1p: bool = True, | ||
| basal_mapping_strategy: Literal["batch", "random"] = "random", | ||
| n_basal_samples: int = 1, | ||
|
|
@@ -81,6 +88,8 @@ def __init__( | |
| additional_obs: list[str] | None = None, | ||
| use_consecutive_loading: bool = False, | ||
| h5_open_kwargs: dict | None = None, | ||
| show_progress: bool = True, | ||
| collate_dtype: str = "float16", | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The default value for |
||
| **kwargs, # missing perturbation_features_file and store_raw_basal for backwards compatibility | ||
| ): | ||
| """ | ||
|
|
@@ -109,6 +118,7 @@ def __init__( | |
| val_subsample_fraction: Fraction of validation subsets to keep (subsamples self.val_datasets) | ||
| use_consecutive_loading: Whether to form cell sets from consecutive indices for faster IO | ||
| h5_open_kwargs: Optional kwargs to pass to h5py.File (e.g., rdcc_nbytes) | ||
| show_progress: Whether to display tqdm progress during dataset setup | ||
| """ | ||
| super().__init__() | ||
|
|
||
|
|
@@ -120,6 +130,7 @@ def __init__( | |
| # Experiment level params | ||
| self.batch_size = batch_size | ||
| self.num_workers = num_workers | ||
| self.pin_memory = bool(pin_memory) | ||
| self.random_seed = random_seed | ||
| self.rng = np.random.default_rng(random_seed) | ||
| self.drop_last = drop_last | ||
|
|
@@ -145,7 +156,7 @@ def __init__( | |
| self.cell_type_key = cell_type_key | ||
| self.control_pert = control_pert | ||
| self.embed_key = embed_key | ||
| self.output_space = output_space | ||
| self.output_space = _OUTPUT_SPACE_ALIASES.get(output_space, output_space) | ||
| if self.output_space not in {"gene", "all", "embedding"}: | ||
| raise ValueError( | ||
| f"output_space must be one of 'gene', 'all', or 'embedding'; got {self.output_space!r}" | ||
|
|
@@ -168,6 +179,7 @@ def __init__( | |
| if downsample_cells <= 0: | ||
| raise ValueError("downsample_cells must be a positive integer or None.") | ||
| self.downsample_cells = downsample_cells | ||
| self.balance_outliers = bool(balance_outliers) | ||
| self.is_log1p = bool(is_log1p) | ||
|
|
||
| # Sampling and mapping | ||
|
|
@@ -185,6 +197,8 @@ def __init__( | |
| self.barcode = kwargs.get("barcode", False) | ||
| self.additional_obs = additional_obs | ||
| self.h5_open_kwargs = h5_open_kwargs | ||
| self.show_progress = bool(show_progress) | ||
| self.collate_dtype = collate_dtype | ||
| if self.use_consecutive_loading: | ||
| self._set_h5_cache_env_defaults() | ||
|
|
||
|
|
@@ -295,6 +309,7 @@ def save_state(self, filepath: str): | |
| "additional_obs": self.additional_obs, | ||
| "use_consecutive_loading": self.use_consecutive_loading, | ||
| "h5_open_kwargs": self.h5_open_kwargs, | ||
| "collate_dtype": self.collate_dtype, | ||
| } | ||
|
|
||
| torch.save(save_dict, filepath) | ||
|
|
@@ -339,6 +354,7 @@ def load_state(cls, filepath: str): | |
| "barcode": save_dict.pop("barcode", True), | ||
| "use_consecutive_loading": save_dict.pop("use_consecutive_loading", False), | ||
| "h5_open_kwargs": save_dict.pop("h5_open_kwargs", None), | ||
| "collate_dtype": save_dict.pop("collate_dtype", "float16"), | ||
| } | ||
|
|
||
| # Create new instance with all the saved parameters | ||
|
|
@@ -469,6 +485,7 @@ def _create_dataloader( | |
|
|
||
| batch_size = batch_size or (1 if test else self.batch_size) | ||
|
|
||
| is_training = datasets is self.train_datasets | ||
| sampler = PerturbationBatchSampler( | ||
| dataset=ds, | ||
| batch_size=batch_size, | ||
|
|
@@ -478,14 +495,15 @@ def _create_dataloader( | |
| use_batch=use_batch, | ||
| use_consecutive_loading=self.use_consecutive_loading, | ||
| downsample_cells=self.downsample_cells, | ||
| balance_outliers=self.balance_outliers if is_training else False, | ||
| ) | ||
|
|
||
| return DataLoader( | ||
| ds, | ||
| batch_sampler=sampler, | ||
| num_workers=self.num_workers, | ||
| collate_fn=collate_fn, | ||
| pin_memory=True, | ||
| pin_memory=getattr(self, "pin_memory", False), | ||
| prefetch_factor=4 if not test and self.num_workers > 0 else None, | ||
| persistent_workers=bool(self.num_workers > 0 and not test), | ||
| worker_init_fn=_worker_init_fn if self.num_workers > 0 else None, | ||
|
|
@@ -629,6 +647,7 @@ def _create_base_dataset( | |
| is_log1p=self.is_log1p, | ||
| cell_sentence_len=self.cell_sentence_len, | ||
| h5_open_kwargs=self.h5_open_kwargs, | ||
| collate_dtype=self.collate_dtype, | ||
| ) | ||
|
|
||
| def _setup_datasets(self): | ||
|
|
@@ -647,8 +666,13 @@ def _setup_datasets(self): | |
| total_files += len(files) | ||
|
|
||
| pbar = ( | ||
| tqdm(total=total_files, desc="Processing datasets", leave=False) | ||
| if total_files > 0 | ||
| tqdm( | ||
| total=total_files, | ||
| desc="Processing datasets", | ||
| leave=False, | ||
| file=sys.stderr, | ||
| ) | ||
| if (self.show_progress and total_files > 0) | ||
| else None | ||
| ) | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The default for `pin_memory` has been changed from `True` (hardcoded in the previous version) to `False`. Memory pinning is generally recommended when training on GPUs as it speeds up data transfer from CPU to GPU. If the primary use case is GPU training, consider keeping the default as `True`.