Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/cell_load/_cli/filter_on_target_knockdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,12 @@ def main():
help="Column in adata.var containing gene names (default: gene_name)",
)

parser.add_argument(
"--verbose",
action="store_true",
help="Print per-stage cell and perturbation counts during filtering",
)

parser.add_argument(
"--preprocess",
action="store_true",
Expand Down Expand Up @@ -137,6 +143,7 @@ def main():
min_cells=args.min_cells,
layer=args.layer,
var_gene_name=args.var_gene_name,
verbose=args.verbose,
)

print(
Expand Down
339 changes: 151 additions & 188 deletions src/cell_load/utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import anndata
import h5py
import numpy as np
import pandas as pd
import scipy.sparse as sp
import torch

Expand Down Expand Up @@ -240,99 +241,16 @@ def suspected_log_torch(x: torch.Tensor) -> bool:
return global_max.item() < 15.0


def _mean(expr) -> float:
"""Return the mean of a dense or sparse 1-D/2-D slice."""
if sp.issparse(expr):
return float(expr.mean())
return float(np.asarray(expr).mean())


def is_on_target_knockdown(
adata: anndata.AnnData,
target_gene: str,
perturbation_column: str = "gene",
control_label: str = "non-targeting",
residual_expression: float = 0.30,
layer: str | None = None,
) -> bool:
"""
True ⇢ average expression of *target_gene* in perturbed cells is below
`residual_expression` × (average expression in control cells).

Parameters
----------
adata : AnnData
target_gene : str
Gene symbol to check.
perturbation_column : str, default "gene"
Column in ``adata.obs`` holding perturbation identities.
control_label : str, default "non-targeting"
Category in *perturbation_column* marking control cells.
residual_expression : float, default 0.30
Residual fraction (0‒1). 0.30 → 70 % knock-down.
layer : str | None, optional
Use this matrix in ``adata.layers`` instead of ``adata.X``.

Raises
------
KeyError
*target_gene* not present in ``adata.var_names``.
ValueError
No perturbed cells for *target_gene*, or control mean is zero.

Returns
-------
bool
"""
if target_gene == control_label:
# Never evaluate the control itself
return False

if target_gene not in adata.var_names:
print(f"Gene {target_gene!r} not found in `adata.var_names`.")
return 1

gene_idx = adata.var_names.get_loc(target_gene)
X = adata.layers[layer] if layer is not None else adata.X

control_cells = adata.obs[perturbation_column] == control_label
perturbed_cells = adata.obs[perturbation_column] == target_gene

if not perturbed_cells.any():
raise ValueError(f"No cells labelled with perturbation {target_gene!r}.")

try:
control_mean = _mean(X[control_cells, gene_idx])
except Exception:
control_cells = (adata.obs[perturbation_column] == control_label).values
control_mean = _mean(X[control_cells, gene_idx])

if np.isclose(control_mean, 0.0):
log.warning(
"Skipping perturbation %r: control mean expression is zero; cannot compute knock-down ratio.",
target_gene,
)
return False

try:
perturbed_mean = _mean(X[perturbed_cells, gene_idx])
except Exception:
perturbed_cells = (adata.obs[perturbation_column] == target_gene).values
perturbed_mean = _mean(X[perturbed_cells, gene_idx])

knockdown_ratio = perturbed_mean / control_mean
return knockdown_ratio < residual_expression


def filter_on_target_knockdown(
adata: anndata.AnnData,
perturbation_column: str = "gene",
control_label: str = "non-targeting",
residual_expression: float = 0.30, # perturbation-level threshold
cell_residual_expression: float = 0.50, # cell-level threshold
min_cells: int = 30, # **NEW**: minimum cells/perturbation
residual_expression: float = 0.30,
cell_residual_expression: float = 0.50,
min_cells: int = 30,
layer: str | None = None,
var_gene_name: str = "gene_name",
verbose: bool = False,
) -> anndata.AnnData:
"""
1. Keep perturbations whose *average* knock-down ≥ (1-residual_expression).
Expand All @@ -343,107 +261,152 @@ def filter_on_target_knockdown(
Returns
-------
AnnData
View of `adata` satisfying all three criteria.
Subset of `adata` satisfying all three criteria, with var index set to gene names.
"""
# --- prep ---
adata_ = set_var_index_to_col(adata.copy(), col=var_gene_name)
X = adata_.layers[layer] if layer is not None else adata_.X
perts = adata_.obs[perturbation_column]
control_cells = (perts == control_label).values

# ---------- stage 1: perturbation filter ----------
perts_to_keep = [control_label] # always keep controls
for pert in perts.unique():
if pert == control_label:
continue
if is_on_target_knockdown(
adata_,
target_gene=pert,
perturbation_column=perturbation_column,
control_label=control_label,
residual_expression=residual_expression,
layer=layer,
):
perts_to_keep.append(pert)

# ---------- stage 2: cell filter ----------
keep_mask = np.zeros(adata_.n_obs, dtype=bool)
keep_mask[control_cells] = True # retain all controls

# cache control means to avoid recomputation
control_mean_cache: dict[str, float] = {}

for pert in perts_to_keep:
if pert == control_label:
continue

if pert not in adata_.var_names:
continue

gene_idx = adata_.var_names.get_loc(pert)

# control mean for this gene
if pert not in control_mean_cache:
try:
ctrl_mean = _mean(X[control_cells, gene_idx])
except Exception:
print(control_cells.shape, control_cells)
print(gene_idx)
print(X[control_cells, gene_idx].shape)
control_mean_cache[pert] = ctrl_mean
else:
ctrl_mean = control_mean_cache[pert]
if var_gene_name not in adata.var.columns:
raise KeyError(f"Column {var_gene_name!r} not found in adata.var.")

pert_cells = (perts == pert).values
# pd.Index over the gene name column — used for membership testing and position lookup.
# get_indexer returns the first occurrence for duplicate gene names, matching
# the behaviour of var_names_make_unique (first occurrence keeps original name).
gene_index = pd.Index(adata.var[var_gene_name])

if np.isclose(ctrl_mean, 0.0):
log.warning(
"Skipping cell-level filtering for perturbation %r because control mean expression is zero.",
pert,
)
keep_mask[pert_cells] = False
continue
# FIX: Replace .A1 with .toarray().flatten() for scipy sparse matrices
expr_vals = (
X[pert_cells, gene_idx].toarray().flatten()
if sp.issparse(X)
else X[pert_cells, gene_idx]
)
ratios = expr_vals / ctrl_mean
keep_mask[pert_cells] = ratios < cell_residual_expression

# ---------- stage 3: minimum-cell filter ----------
for pert in perts.unique():
if pert == control_label:
continue
# cells of this perturbation *still* kept after stages 1-2
pert_mask = (perts == pert).values & keep_mask
if pert_mask.sum() < min_cells:
keep_mask[pert_mask] = False # drop them

# return view with all criteria satisfied
return adata_[keep_mask]


def set_var_index_to_col(adata: anndata.AnnData, col: str = "col", copy=True) -> None:
"""
Set `adata.var` index to the values in the specified column, allowing non-unique indices.

Parameters
----------
adata : AnnData
The AnnData object to modify.
col : str
Column in `adata.var` to use as the new index.

Raises
------
KeyError
If the specified column does not exist in `adata.var`.
"""
if col not in adata.var.columns:
raise KeyError(f"Column {col!r} not found in adata.var.")

adata.var.index = adata.var[col].astype("str")
adata.var_names_make_unique()
return adata
# ------------------------------------------------------------------
# 1. Shared setup
# ------------------------------------------------------------------
X = adata.layers[layer] if layer is not None else adata.X
perts = adata.obs[perturbation_column]
n_cells = adata.n_obs

control_mask = (perts == control_label).values # (n_cells,) bool

unique_perts = [p for p in perts.unique() if p != control_label]
matched_perts = [p for p in unique_perts if p in gene_index]
n_matched = len(matched_perts)

if verbose:
print(f"[input] {n_cells:,} cells | {len(unique_perts)} perturbations (excl. control)")

if n_matched == 0:
if verbose:
print("[no matched perturbations] returning only control cells")
return adata[control_mask].copy()

# ------------------------------------------------------------------
# 2. Extract dense submatrix for all matched genes at once
# Shape: (n_cells, n_matched)
# get_indexer translates gene names -> integer column positions, equivalent
# to what adata[:, matched_perts] does internally when var_names are gene names.
# ------------------------------------------------------------------
pert_positions = gene_index.get_indexer(matched_perts) # (n_matched,) int array

if sp.issparse(X):
X_sub = X[:, pert_positions].toarray() # (n_cells, n_matched)
else:
X_sub = np.asarray(X)[:, pert_positions]
Comment on lines +303 to +306
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current implementation densifies the entire submatrix X_sub for all matched perturbations. For large datasets with many perturbations (e.g., whole-genome screens), this can lead to excessive memory consumption or Out-Of-Memory (OOM) errors. Since X is typically sparse in single-cell data, it is more efficient to avoid full densification and instead use sparse indexing or only densify the specific elements needed for each stage.

Consider keeping X_sub sparse if X is sparse, and then using np.asarray(...).flatten() when extracting specific values (like diag_expr or expr_vals) or computing means to maintain the speedup while significantly reducing the memory footprint.


# ------------------------------------------------------------------
# 3. Control means for all matched genes -- single matrix op
# ------------------------------------------------------------------
ctrl_means = X_sub[control_mask].mean(axis=0) # (n_matched,)

# ------------------------------------------------------------------
# 4. Map each cell to its perturbation's column in X_sub
# (-1 for control / unmatched)
# ------------------------------------------------------------------
pert_to_col = {p: i for i, p in enumerate(matched_perts)}
cell_col_mapped = perts.map(pert_to_col) # NaN for control / unmatched
cell_col = np.full(n_cells, -1, dtype=np.int32)
valid = ~cell_col_mapped.isna()
cell_col[valid.values] = cell_col_mapped[valid].values.astype(np.int32)

matched_cells = cell_col >= 0 # (n_cells,) bool
valid_row = np.where(matched_cells)[0] # (n_valid,) cell row indices
valid_col = cell_col[valid_row] # (n_valid,) gene column indices into X_sub

# ------------------------------------------------------------------
# 5. Stage 1: perturbation-level filter
#
# For each matched pert i:
# mean(X_sub[cells_of_pert_i, i]) / ctrl_means[i] < residual_expression
#
# Fancy indexing gives each cell's expression at its own gene in one shot.
# np.bincount accumulates per-perturbation sums without a Python loop.
# ------------------------------------------------------------------
diag_expr = X_sub[valid_row, valid_col] # (n_valid,) each cell's own-gene expr

pert_sums = np.bincount(valid_col, weights=diag_expr, minlength=n_matched)
pert_counts = np.bincount(valid_col, minlength=n_matched).astype(np.float64)
pert_means = np.where(pert_counts > 0, pert_sums / np.maximum(pert_counts, 1.0), 0.0)

valid_ctrl = ~np.isclose(ctrl_means, 0.0)
kd_ratio = np.where(valid_ctrl, pert_means / np.where(valid_ctrl, ctrl_means, 1.0), np.inf)
stage1_pass = valid_ctrl & (kd_ratio < residual_expression) # (n_matched,) bool

if verbose:
cells_s1 = int(control_mask.sum()) + int(stage1_pass[valid_col].sum())
perts_s1 = len(unique_perts) - int(stage1_pass.sum())
print(f"[stage 1 (pert avg filter)] removed {n_cells - cells_s1:,} cells | {perts_s1} perturbations")
prev_cells = cells_s1

# ------------------------------------------------------------------
# 6. Stage 2: cell-level filter
#
# For each cell whose pert passed stage 1:
# X_sub[cell, pert_col] / ctrl_means[pert_col] < cell_residual_expression
#
# Fancy indexing (rows=passed cell indices, cols=their pert column) handles
# the entire stage in three numpy operations.
# ------------------------------------------------------------------
passed_sel = stage1_pass[valid_col] # (n_valid,) bool
passed_row = valid_row[passed_sel] # (n_passed,) cell row indices
passed_col = valid_col[passed_sel] # (n_passed,) gene column indices into X_sub

expr_vals = X_sub[passed_row, passed_col] # (n_passed,)
ctrl_means_per_cell = ctrl_means[passed_col] # (n_passed,)

nonzero_ctrl = ~np.isclose(ctrl_means_per_cell, 0.0)
cell_keep = np.zeros(len(passed_row), dtype=bool)
cell_keep[nonzero_ctrl] = (
expr_vals[nonzero_ctrl] / ctrl_means_per_cell[nonzero_ctrl] < cell_residual_expression
)

keep_mask = control_mask.copy()
keep_mask[passed_row] = cell_keep

if verbose:
cells_s2 = int(keep_mask.sum())
perts_after_s1 = set(np.array(matched_perts)[stage1_pass])
perts_after_s2 = set(perts.values[keep_mask & matched_cells])
print(f"[stage 2 (cell filter) ] removed {prev_cells - cells_s2:,} cells | {len(perts_after_s1 - perts_after_s2)} perturbations")
prev_cells = cells_s2

# ------------------------------------------------------------------
# 7. Stage 3: minimum cells per perturbation
# ------------------------------------------------------------------
kept_mask = keep_mask & matched_cells
kept_row = np.where(kept_mask)[0]
kept_col = cell_col[kept_row]

if len(kept_col) > 0:
pert_kept_counts = np.bincount(kept_col, minlength=n_matched)
drop_pert = pert_kept_counts < min_cells
cell_drop = drop_pert[kept_col]
keep_mask[kept_row[cell_drop]] = False

if verbose:
cells_s3 = int(keep_mask.sum())
perts_removed = int((drop_pert & (pert_kept_counts > 0)).sum()) if len(kept_col) > 0 else 0
print(f"[stage 3 (min cells filter)] removed {prev_cells - cells_s3:,} cells | {perts_removed} perturbations")
out_perts = len(set(perts.values[keep_mask]) - {control_label})
print(f"[output] {cells_s3:,} cells | {out_perts} perturbations (excl. control)")

# ------------------------------------------------------------------
# 8. Build output
# Subset first, then copy -- avoids copying the full expression matrix.
# ------------------------------------------------------------------
result = adata[keep_mask].copy()
if var_gene_name in result.var.columns:
result.var.index = result.var[var_gene_name].astype(str)
result.var_names_make_unique()
return result