Aadduri/refactor emb#72
Conversation
…DataModule

- Add collate_dtype param to PerturbationDataset for float16/float32 tensor casting
- Wire collate_dtype through PerturbationDataModule to all dataset constructors

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…e_controls

When n_samples > pool_size (e.g., observational data with rare cell types that have only 2 control cells but sentence_len=64), the old tail+head wrap wrapped only once, returning fewer elements than requested. This caused an IndexError in __getitems__ during multi-worker DataLoader training. Use modular arithmetic to wrap around the pool as many times as needed.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
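For illustration, a minimal sketch of the modular-arithmetic wrap described above (the helper name and signature are assumptions, not the repository's actual code):

import numpy as np

def sample_controls(pool: np.ndarray, n_samples: int) -> np.ndarray:
    # Wrap around the pool as many times as needed; the old tail+head
    # approach wrapped only once and under-filled when n_samples > len(pool).
    idx = np.arange(n_samples) % len(pool)
    return pool[idx]

# Pool of 2 control cells, sentence_len=64: returns all 64 requested elements.
samples = sample_controls(np.array([10, 42]), 64)
assert len(samples) == 64  # the single-wrap version returned fewer, causing IndexError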
…n conditions

Caps any condition exceeding the median sentence count using a rolling window that advances each epoch, so all cells are eventually seen. Applied only to training dataloaders; val/test remain unbalanced. Bumps version to 0.11.0.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
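A minimal sketch of that rolling-window cap (the helper name and the window-advance scheme of one cap-width per epoch are illustrative assumptions, not the PR's actual code):

import numpy as np

def cap_condition(indices: np.ndarray, cap: int, epoch: int) -> np.ndarray:
    # Keep at most `cap` sentences for one condition; the window start
    # advances each epoch so every cell is eventually seen.
    if len(indices) <= cap:
        return indices
    start = (epoch * cap) % len(indices)
    return indices[(start + np.arange(cap)) % len(indices)]

cond = np.arange(100, 110)  # a condition with 10 sentences, median cap of 4
for epoch in range(3):
    print(cap_condition(cond, 4, epoch))
# [100 101 102 103], then [104 105 106 107], then [108 109 100 101]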
Code Review
This pull request bumps the version to 0.11.0 and migrates the project from print statements to a structured logging framework. It introduces several performance and memory optimizations, including a collate_dtype parameter for casting tensors to lower precision and a configurable pin_memory option for the dataloader. It also adds a balance_outliers feature to the sampler that downsamples over-represented perturbations using a rolling window. Review feedback flags the new default values for collate_dtype and pin_memory, and suggests improving logging efficiency and removing redundant code.
    use_consecutive_loading: bool = False,
    h5_open_kwargs: dict | None = None,
    show_progress: bool = True,
    collate_dtype: str = "float16",
The default value for collate_dtype is "float16". This is a significant change from the previous implicit default of float32 (via torch.FloatTensor). While this reduces memory usage, it may introduce precision issues or compatibility problems for models that expect float32 inputs. Consider whether "float32" would be a safer default for a general-purpose dataloader, or ensure the change is clearly documented.
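For context, the cast in question might look roughly like this inside the collate path (a sketch; the function name and structure are assumptions, not the PR's implementation):

import torch

_DTYPES = {"float16": torch.float16, "float32": torch.float32}

def collate(batch: list[torch.Tensor], collate_dtype: str = "float16") -> torch.Tensor:
    # Stack first, then cast once; float16 halves host memory but narrows the
    # representable range, which is why float32 may be the safer default here.
    return torch.stack(batch).to(_DTYPES[collate_dtype])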
    toml_config_path: str,
    batch_size: int = 128,
    num_workers: int = 8,
    pin_memory: bool = False,
    logger.info(
        f"balance_outliers: cap={cap} (median). "
        f"{n_capped}/{len(unique_codes)} conditions capped. "
        f"Sentences: {total_before} -> {total_after} (epoch {epoch})."
    )
It is recommended to use lazy interpolation in logging calls (passing arguments to the logger) rather than f-strings. This avoids the overhead of string formatting if the log level is disabled and maintains consistency with the logging style used in other parts of the PR (e.g., in filter_on_target_knockdown.py).
logger.info(
"balance_outliers: cap=%d (median). %d/%d conditions capped. Sentences: %d -> %d (epoch %d).",
cap,
n_capped,
len(unique_codes),
total_before,
total_after,
epoch,
)

    self.output_space = _OUTPUT_SPACE_ALIASES.get(
        kwargs.get("output_space", "gene"), kwargs.get("output_space", "gene")
    )
Adds a balance_outliers argument for limiting outlier condition groups such as control cells.
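A hypothetical call site combining the new options (which object actually accepts balance_outliers is an assumption; the other keywords match the signature shown above):

dm = PerturbationDataModule(
    toml_config_path="config.toml",
    batch_size=128,
    num_workers=8,
    pin_memory=False,
    collate_dtype="float32",   # opt out of the float16 default if precision matters
    balance_outliers=True,     # cap over-represented conditions such as controls
)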