def _subset_with_gene_var_index(
    adata: anndata.AnnData,
    keep_mask: np.ndarray,
    var_gene_name: str,
) -> anndata.AnnData:
    """Copy the cells selected by *keep_mask* and index ``var`` by gene name."""
    # Subset first, then copy, so only the kept rows are materialised.
    result = adata[keep_mask].copy()
    result.var.index = result.var[var_gene_name].astype(str)
    result.var_names_make_unique()
    return result


def filter_on_target_knockdown_fast(
    adata: anndata.AnnData,
    perturbation_column: str = "gene",
    control_label: str = "non-targeting",
    residual_expression: float = 0.30,
    cell_residual_expression: float = 0.50,
    min_cells: int = 30,
    layer: str | None = None,
    var_gene_name: str = "gene_name",
    verbose: bool = False,
) -> anndata.AnnData:
    """
    Keep only cells whose perturbation shows an on-target knock-down.

    Three filters are applied in order; control cells are always kept:

    1. Perturbation level: keep a perturbation only if the mean expression of
       its target gene in its own cells, divided by the mean expression of that
       gene in control cells, is below ``residual_expression``.
    2. Cell level: within surviving perturbations, keep a cell only if its own
       target-gene expression divided by the control mean is below
       ``cell_residual_expression``.
    3. Minimum cells: drop perturbations left with fewer than ``min_cells``
       cells after stages 1-2.

    Perturbations whose label is not found in ``adata.var[var_gene_name]``, and
    perturbations whose control mean is (close to) zero, are dropped.

    Parameters
    ----------
    adata : AnnData
        Input object. It is not modified; only the kept subset is copied.
    perturbation_column : str
        Column in ``adata.obs`` holding the perturbation label of each cell.
    control_label : str
        Value in *perturbation_column* marking control cells.
    residual_expression : float
        Perturbation-level residual fraction threshold (stage 1).
    cell_residual_expression : float
        Cell-level residual fraction threshold (stage 2).
    min_cells : int
        Minimum number of surviving cells per perturbation (stage 3).
    layer : str | None
        Use ``adata.layers[layer]`` instead of ``adata.X`` when given.
    var_gene_name : str
        Column in ``adata.var`` holding gene names.
    verbose : bool
        Print per-stage cell and perturbation counts.

    Returns
    -------
    AnnData
        Copy of the kept cells, with ``var`` indexed by gene name (made unique).

    Raises
    ------
    KeyError
        If *var_gene_name* is not a column of ``adata.var``.

    Notes
    -----
    A dense ``(n_cells, n_matched_perturbations)`` matrix is materialised, so
    memory grows with the number of perturbations that match a gene.
    """
    if var_gene_name not in adata.var.columns:
        raise KeyError(f"Column {var_gene_name!r} not found in adata.var.")

    # Gene names may repeat (or be missing several times). Index.get_indexer
    # refuses to work on a non-unique index, so resolve every name to the
    # position of its FIRST occurrence explicitly. This mirrors
    # var_names_make_unique, where the first occurrence keeps the plain name.
    gene_index = pd.Index(adata.var[var_gene_name])
    is_first = ~gene_index.duplicated(keep="first")
    unique_genes = gene_index[is_first]
    first_positions = np.flatnonzero(is_first)

    # ------------------------------------------------------------------
    # 1. Shared setup
    # ------------------------------------------------------------------
    X = adata.layers[layer] if layer is not None else adata.X
    perts = adata.obs[perturbation_column]
    n_cells = adata.n_obs

    # Force a plain bool array: nullable dtypes would otherwise yield a masked
    # array in which missing labels are neither True nor False.
    control_mask = (perts == control_label).to_numpy(dtype=bool, na_value=False)

    # Missing labels are not perturbations; dropping them also avoids
    # evaluating `pd.NA != control_label` in a boolean context.
    unique_perts = [p for p in perts.dropna().unique() if p != control_label]
    matched_perts = [p for p in unique_perts if p in unique_genes]
    n_matched = len(matched_perts)

    if verbose:
        print(f"[input] {n_cells:,} cells | {len(unique_perts)} perturbations (excl. control)")

    if n_matched == 0:
        if verbose:
            print("[no matched perturbations] returning only control cells")
        # Same post-processing as the regular path so callers always get a
        # gene-name var index.
        return _subset_with_gene_var_index(adata, control_mask, var_gene_name)

    # ------------------------------------------------------------------
    # 2. Dense submatrix holding only the target-gene columns
    #    Shape: (n_cells, n_matched); column i belongs to matched_perts[i].
    # ------------------------------------------------------------------
    pert_positions = first_positions[unique_genes.get_indexer(matched_perts)]

    if sp.issparse(X):
        X_sub = X[:, pert_positions].toarray()
    else:
        X_sub = np.asarray(X)[:, pert_positions]

    # ------------------------------------------------------------------
    # 3. Control mean of every target gene
    # ------------------------------------------------------------------
    if control_mask.any():
        ctrl_means = X_sub[control_mask].mean(axis=0)  # (n_matched,)
    else:
        # No control cells: no ratio can be computed. Zeros make every
        # perturbation fail the "non-zero control mean" requirement below
        # without triggering an empty-slice warning.
        ctrl_means = np.zeros(n_matched, dtype=np.float64)

    # ------------------------------------------------------------------
    # 4. Column of X_sub belonging to each cell's perturbation
    #    (-1 for control, unmatched and missing labels)
    # ------------------------------------------------------------------
    cell_col = pd.Index(matched_perts).get_indexer(perts)

    matched_cells = cell_col >= 0  # (n_cells,) bool
    valid_row = np.flatnonzero(matched_cells)  # (n_valid,) cell rows
    valid_col = cell_col[valid_row]  # (n_valid,) columns into X_sub

    # ------------------------------------------------------------------
    # 5. Stage 1: perturbation-level filter
    # ------------------------------------------------------------------
    # Each matched cell's expression of its own target gene.
    diag_expr = X_sub[valid_row, valid_col]

    pert_sums = np.bincount(valid_col, weights=diag_expr, minlength=n_matched)
    pert_counts = np.bincount(valid_col, minlength=n_matched)
    pert_means = pert_sums / np.maximum(pert_counts, 1)

    valid_ctrl = ~np.isclose(ctrl_means, 0.0)
    safe_ctrl = np.where(valid_ctrl, ctrl_means, 1.0)  # avoid division by ~0
    kd_ratio = np.where(valid_ctrl, pert_means / safe_ctrl, np.inf)
    stage1_pass = valid_ctrl & (kd_ratio < residual_expression)  # (n_matched,)

    if verbose:
        cells_s1 = int(control_mask.sum()) + int(stage1_pass[valid_col].sum())
        perts_s1 = len(unique_perts) - int(stage1_pass.sum())
        print(f"[stage 1 (pert avg filter)] removed {n_cells - cells_s1:,} cells | {perts_s1} perturbations")
        prev_cells = cells_s1

    # ------------------------------------------------------------------
    # 6. Stage 2: cell-level filter
    # ------------------------------------------------------------------
    passed_sel = stage1_pass[valid_col]  # (n_valid,) bool
    passed_row = valid_row[passed_sel]
    passed_col = valid_col[passed_sel]

    # stage1_pass already requires a non-zero control mean, so this division
    # never hits a (near-)zero denominator.
    cell_ratio = X_sub[passed_row, passed_col] / ctrl_means[passed_col]

    keep_mask = control_mask.copy()
    keep_mask[passed_row] = cell_ratio < cell_residual_expression

    if verbose:
        cells_s2 = int(keep_mask.sum())
        perts_after_s1 = set(np.array(matched_perts)[stage1_pass])
        perts_after_s2 = set(perts.values[keep_mask & matched_cells])
        print(f"[stage 2 (cell filter) ] removed {prev_cells - cells_s2:,} cells | {len(perts_after_s1 - perts_after_s2)} perturbations")
        prev_cells = cells_s2

    # ------------------------------------------------------------------
    # 7. Stage 3: minimum cells per perturbation
    # ------------------------------------------------------------------
    kept_row = np.flatnonzero(keep_mask & matched_cells)
    kept_col = cell_col[kept_row]

    # np.bincount handles an empty input, so no special case is needed.
    pert_kept_counts = np.bincount(kept_col, minlength=n_matched)
    drop_pert = pert_kept_counts < min_cells
    keep_mask[kept_row[drop_pert[kept_col]]] = False

    if verbose:
        cells_s3 = int(keep_mask.sum())
        # Only perturbations that still had cells count as "removed" here.
        perts_removed = int((drop_pert & (pert_kept_counts > 0)).sum())
        print(f"[stage 3 (min cells filter)] removed {prev_cells - cells_s3:,} cells | {perts_removed} perturbations")
        out_perts = len(set(perts.values[keep_mask]) - {control_label})
        print(f"[output] {cells_s3:,} cells | {out_perts} perturbations (excl. control)")

    # ------------------------------------------------------------------
    # 8. Build output
    # ------------------------------------------------------------------
    return _subset_with_gene_var_index(adata, keep_mask, var_gene_name)
Build output From 904a7eedf742d55907c7d2eaa60205917210caa3 Mon Sep 17 00:00:00 2001 From: Leon Hafner Date: Tue, 31 Mar 2026 18:20:29 -0700 Subject: [PATCH 3/3] replace old function --- .../_cli/filter_on_target_knockdown.py | 7 + src/cell_load/utils/data_utils.py | 227 +----------------- 2 files changed, 12 insertions(+), 222 deletions(-) diff --git a/src/cell_load/_cli/filter_on_target_knockdown.py b/src/cell_load/_cli/filter_on_target_knockdown.py index eb1799b..61057d1 100644 --- a/src/cell_load/_cli/filter_on_target_knockdown.py +++ b/src/cell_load/_cli/filter_on_target_knockdown.py @@ -99,6 +99,12 @@ def main(): help="Column in adata.var containing gene names (default: gene_name)", ) + parser.add_argument( + "--verbose", + action="store_true", + help="Print per-stage cell and perturbation counts during filtering", + ) + parser.add_argument( "--preprocess", action="store_true", @@ -137,6 +143,7 @@ def main(): min_cells=args.min_cells, layer=args.layer, var_gene_name=args.var_gene_name, + verbose=args.verbose, ) print( diff --git a/src/cell_load/utils/data_utils.py b/src/cell_load/utils/data_utils.py index fca53e7..44517f3 100644 --- a/src/cell_load/utils/data_utils.py +++ b/src/cell_load/utils/data_utils.py @@ -241,99 +241,16 @@ def suspected_log_torch(x: torch.Tensor) -> bool: return global_max.item() < 15.0 -def _mean(expr) -> float: - """Return the mean of a dense or sparse 1-D/2-D slice.""" - if sp.issparse(expr): - return float(expr.mean()) - return float(np.asarray(expr).mean()) - - -def is_on_target_knockdown( - adata: anndata.AnnData, - target_gene: str, - perturbation_column: str = "gene", - control_label: str = "non-targeting", - residual_expression: float = 0.30, - layer: str | None = None, -) -> bool: - """ - True ⇢ average expression of *target_gene* in perturbed cells is below - `residual_expression` × (average expression in control cells). - - Parameters - ---------- - adata : AnnData - target_gene : str - Gene symbol to check. 
- perturbation_column : str, default "gene" - Column in ``adata.obs`` holding perturbation identities. - control_label : str, default "non-targeting" - Category in *perturbation_column* marking control cells. - residual_expression : float, default 0.30 - Residual fraction (0‒1). 0.30 → 70 % knock-down. - layer : str | None, optional - Use this matrix in ``adata.layers`` instead of ``adata.X``. - - Raises - ------ - KeyError - *target_gene* not present in ``adata.var_names``. - ValueError - No perturbed cells for *target_gene*, or control mean is zero. - - Returns - ------- - bool - """ - if target_gene == control_label: - # Never evaluate the control itself - return False - - if target_gene not in adata.var_names: - print(f"Gene {target_gene!r} not found in `adata.var_names`.") - return 1 - - gene_idx = adata.var_names.get_loc(target_gene) - X = adata.layers[layer] if layer is not None else adata.X - - control_cells = adata.obs[perturbation_column] == control_label - perturbed_cells = adata.obs[perturbation_column] == target_gene - - if not perturbed_cells.any(): - raise ValueError(f"No cells labelled with perturbation {target_gene!r}.") - - try: - control_mean = _mean(X[control_cells, gene_idx]) - except Exception: - control_cells = (adata.obs[perturbation_column] == control_label).values - control_mean = _mean(X[control_cells, gene_idx]) - - if np.isclose(control_mean, 0.0): - log.warning( - "Skipping perturbation %r: control mean expression is zero; cannot compute knock-down ratio.", - target_gene, - ) - return False - - try: - perturbed_mean = _mean(X[perturbed_cells, gene_idx]) - except Exception: - perturbed_cells = (adata.obs[perturbation_column] == target_gene).values - perturbed_mean = _mean(X[perturbed_cells, gene_idx]) - - knockdown_ratio = perturbed_mean / control_mean - return knockdown_ratio < residual_expression - - def filter_on_target_knockdown( adata: anndata.AnnData, perturbation_column: str = "gene", control_label: str = "non-targeting", - 
residual_expression: float = 0.30, # perturbation-level threshold - cell_residual_expression: float = 0.50, # cell-level threshold - min_cells: int = 30, # **NEW**: minimum cells/perturbation + residual_expression: float = 0.30, + cell_residual_expression: float = 0.50, + min_cells: int = 30, layer: str | None = None, var_gene_name: str = "gene_name", + verbose: bool = False, ) -> anndata.AnnData: """ 1. Keep perturbations whose *average* knock-down ≥ (1-residual_expression). @@ -344,141 +261,7 @@ def filter_on_target_knockdown( Returns ------- AnnData - View of `adata` satisfying all three criteria. - """ - # --- prep --- - adata_ = set_var_index_to_col(adata.copy(), col=var_gene_name) - X = adata_.layers[layer] if layer is not None else adata_.X - perts = adata_.obs[perturbation_column] - control_cells = (perts == control_label).values - - # ---------- stage 1: perturbation filter ---------- - perts_to_keep = [control_label] # always keep controls - for pert in perts.unique(): - if pert == control_label: - continue - if is_on_target_knockdown( - adata_, - target_gene=pert, - perturbation_column=perturbation_column, - control_label=control_label, - residual_expression=residual_expression, - layer=layer, - ): - perts_to_keep.append(pert) - - # ---------- stage 2: cell filter ---------- - keep_mask = np.zeros(adata_.n_obs, dtype=bool) - keep_mask[control_cells] = True # retain all controls - - # cache control means to avoid recomputation - control_mean_cache: dict[str, float] = {} - - for pert in perts_to_keep: - if pert == control_label: - continue - - if pert not in adata_.var_names: - continue - - gene_idx = adata_.var_names.get_loc(pert) - - # control mean for this gene - if pert not in control_mean_cache: - try: - ctrl_mean = _mean(X[control_cells, gene_idx]) - except Exception: - print(control_cells.shape, control_cells) - print(gene_idx) - print(X[control_cells, gene_idx].shape) - control_mean_cache[pert] = ctrl_mean - else: - ctrl_mean = 
control_mean_cache[pert] - - pert_cells = (perts == pert).values - - if np.isclose(ctrl_mean, 0.0): - log.warning( - "Skipping cell-level filtering for perturbation %r because control mean expression is zero.", - pert, - ) - keep_mask[pert_cells] = False - continue - # FIX: Replace .A1 with .toarray().flatten() for scipy sparse matrices - expr_vals = ( - X[pert_cells, gene_idx].toarray().flatten() - if sp.issparse(X) - else X[pert_cells, gene_idx] - ) - ratios = expr_vals / ctrl_mean - keep_mask[pert_cells] = ratios < cell_residual_expression - - # ---------- stage 3: minimum-cell filter ---------- - for pert in perts.unique(): - if pert == control_label: - continue - # cells of this perturbation *still* kept after stages 1-2 - pert_mask = (perts == pert).values & keep_mask - if pert_mask.sum() < min_cells: - keep_mask[pert_mask] = False # drop them - - # return view with all criteria satisfied - return adata_[keep_mask] - - -def set_var_index_to_col(adata: anndata.AnnData, col: str = "col", copy=True) -> None: - """ - Set `adata.var` index to the values in the specified column, allowing non-unique indices. - - Parameters - ---------- - adata : AnnData - The AnnData object to modify. - col : str - Column in `adata.var` to use as the new index. - - Raises - ------ - KeyError - If the specified column does not exist in `adata.var`. - """ - if col not in adata.var.columns: - raise KeyError(f"Column {col!r} not found in adata.var.") - - adata.var.index = adata.var[col].astype("str") - adata.var_names_make_unique() - return adata - - -def filter_on_target_knockdown_fast( - adata: anndata.AnnData, - perturbation_column: str = "gene", - control_label: str = "non-targeting", - residual_expression: float = 0.30, - cell_residual_expression: float = 0.50, - min_cells: int = 30, - layer: str | None = None, - var_gene_name: str = "gene_name", - verbose: bool = False, -) -> anndata.AnnData: - """ - Vectorized reimplementation of filter_on_target_knockdown. 
Same semantics, much faster. - - Key differences from the original: - - No upfront adata.copy(): computation runs on the original X, the copy is deferred - to the end and covers only the kept subset of cells. - - Single dense submatrix X_sub (n_cells x n_matched_perts) replaces per-perturbation - sparse column slices in the original loops. - - Control means computed once for all matched genes via X_sub[control_mask].mean(axis=0). - - Stage 1 pert means: fancy indexing + np.bincount instead of a Python loop. - - Stage 2 cell-level ratios: fully vectorized via fancy indexing - (X_sub[row_idx, col_idx] retrieves each cell's expression at its own gene in one call). - - Stage 3 min-cell check: np.bincount over the keep mask, no loop. - - Note on unmatched perturbations (pert name not in var): - The original adds them to perts_to_keep in stage 1 (is_on_target_knockdown returns 1), - but their cells never enter keep_mask in stage 2 (continue is hit). Net effect: their - cells are dropped, identical to this implementation which simply never marks them as kept. + Subset of `adata` satisfying all three criteria, with var index set to gene names. """ if var_gene_name not in adata.var.columns: raise KeyError(f"Column {var_gene_name!r} not found in adata.var.")