From 877d0f5b4649b1775c5442badca13f60512288aa Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 9 Sep 2025 09:47:03 +0000 Subject: [PATCH 01/23] init working version with pytest. --- pyproject.toml | 1 + .../pertpy_gpu/_distances.py | 1187 +++++++++++++++++ tests/pertpy/test_distances.py | 223 ++++ 3 files changed, 1411 insertions(+) create mode 100644 src/rapids_singlecell/pertpy_gpu/_distances.py create mode 100644 tests/pertpy/test_distances.py diff --git a/pyproject.toml b/pyproject.toml index 9c2d5cef..14f7ad14 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ test-minimal = [ "scanpy>=1.10.0", "bbknn", "decoupler", + "pertpy", "fast-array-utils", ] test = [ diff --git a/src/rapids_singlecell/pertpy_gpu/_distances.py b/src/rapids_singlecell/pertpy_gpu/_distances.py new file mode 100644 index 00000000..562471c8 --- /dev/null +++ b/src/rapids_singlecell/pertpy_gpu/_distances.py @@ -0,0 +1,1187 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Literal, NamedTuple + +import numpy as np +import pandas as pd +from pandas import Series +# TODO(selmanozleyen): adapt which progress bar to use, probably use rapidsinglecell's progress bar (if it exists) +# from rich.progress import track +# from scipy.sparse import issparse +# from scipy.spatial.distance import cosine, mahalanobis +# from scipy.special import gammaln +# from scipy.stats import kendalltau, kstest, pearsonr, spearmanr +# from sklearn.linear_model import LogisticRegression +# from sklearn.metrics import pairwise_distances, r2_score +# from sklearn.metrics.pairwise import polynomial_kernel, rbf_kernel +# from sklearn.neighbors import KernelDensity +# from statsmodels.discrete.discrete_model import NegativeBinomialP + +if TYPE_CHECKING: + from collections.abc import Callable + + from anndata import AnnData + + +class MeanVar(NamedTuple): + mean: float + variance: float + + +Metric = Literal[ + "edistance", + # "euclidean", + 
# "root_mean_squared_error", + # "mse", + # "mean_absolute_error", + # "pearson_distance", + # "spearman_distance", + # "kendalltau_distance", + # "cosine_distance", + # "r2_distance", + # "mean_pairwise", + # "mmd", + # "sym_kldiv", + # "t_test", + # "ks_test", + # "nb_ll", + # "classifier_proba", + # "classifier_cp", + # "mean_var_distribution", + # "mahalanobis", +] + + +class Distance: + """Distance class, used to compute distances between groups of cells. + + The distance metric can be specified by the user. This class also provides a + method to compute the pairwise distances between all groups of cells. + Currently available metrics: + + - "edistance": Energy distance (Default metric). + In essence, it is twice the mean pairwise distance between cells of two + groups minus the mean pairwise distance between cells within each group + respectively. More information can be found in + `Peidli et al. (2023) `__. + - "euclidean": euclidean distance. + Euclidean distance between the means of cells from two groups. + - "root_mean_squared_error": euclidean distance. + Euclidean distance between the means of cells from two groups. + - "mse": Pseudobulk mean squared error. + mean squared distance between the means of cells from two groups. + - "mean_absolute_error": Pseudobulk mean absolute distance. + Mean absolute distance between the means of cells from two groups. + - "pearson_distance": Pearson distance. + Pearson distance between the means of cells from two groups. + - "spearman_distance": Spearman distance. + Spearman distance between the means of cells from two groups. + - "kendalltau_distance": Kendall tau distance. + Kendall tau distance between the means of cells from two groups. + - "cosine_distance": Cosine distance. + Cosine distance between the means of cells from two groups. + - "r2_distance": coefficient of determination distance. + Coefficient of determination distance between the means of cells from two groups. 
+ - "mean_pairwise": Mean pairwise distance. + Mean of the pairwise euclidean distances between cells of two groups. + - "mmd": Maximum mean discrepancy + Maximum mean discrepancy between the cells of two groups. + Here, uses linear, rbf, and quadratic polynomial MMD. For theory on MMD in single-cell applications, see + `Lotfollahi et al. (2019) `__. + - "wasserstein": Wasserstein distance (Earth Mover's Distance) + Wasserstein distance between the cells of two groups. Uses an + OTT-JAX implementation of the Sinkhorn algorithm to compute the distance. + For more information on the optimal transport solver, see + `Cuturi et al. (2013) `__. + - "sym_kldiv": symmetrized Kullback–Leibler divergence distance. + Kullback–Leibler divergence of the gaussian distributions between cells of two groups. + Here we fit a gaussian distribution over one group of cells and then calculate the KL divergence on the other, and vice versa. + - "t_test": t-test statistic. + T-test statistic measure between cells of two groups. + - "ks_test": Kolmogorov-Smirnov test statistic. + Kolmogorov-Smirnov test statistic measure between cells of two groups. + - "nb_ll": log-likelihood over negative binomial + Average of log-likelihoods of samples of the secondary group after fitting a negative binomial distribution + over the samples of the first group. + - "classifier_proba": probability of a binary classifier + Average of the classification probability of the perturbation for a binary classifier. + - "classifier_cp": classifier class projection + Average of the class + - "mean_var_distribution": Distance between mean-variance distributions between cells of 2 groups. + Mean square distance between the mean-variance distributions of cells from 2 groups using Kernel Density Estimation (KDE). + - "mahalanobis": Mahalanobis distance between the means of cells from two groups. + It is originally used to measure distance between a point and a distribution. 
+ in this context, it quantifies the difference between the mean profiles of a target group and a reference group. + + Attributes: + metric: Name of distance metric. + layer_key: Name of the counts to use in adata.layers. + obsm_key: Name of embedding in adata.obsm to use. + cell_wise_metric: Metric from scipy.spatial.distance to use for pairwise distances between single cells. + + Examples: + >>> import pertpy as pt + >>> adata = pt.dt.distance_example() + >>> Distance = pt.tools.Distance(metric="edistance") + >>> X = adata.obsm["X_pca"][adata.obs["perturbation"] == "p-sgCREB1-2"] + >>> Y = adata.obsm["X_pca"][adata.obs["perturbation"] == "control"] + >>> D = Distance(X, Y) + """ + + def __init__( + self, + metric: Metric = "edistance", + agg_fct: Callable = np.mean, + layer_key: str = None, + obsm_key: str = None, + cell_wise_metric: str = "euclidean", + ): + """Initialize Distance class. + + Args: + metric: Distance metric to use. + agg_fct: Aggregation function to generate pseudobulk vectors. + layer_key: Name of the counts layer containing raw counts to calculate distances for. + Mutually exclusive with 'obsm_key'. + Is not used if `None`. + obsm_key: Name of embedding in adata.obsm to use. + Mutually exclusive with 'layer_key'. + Defaults to None, but is set to "X_pca" if not explicitly set internally. + cell_wise_metric: Metric from scipy.spatial.distance to use for pairwise distances between single cells. 
+ """ + metric_fct: AbstractDistance = None + self.aggregation_func = agg_fct + # elif metric in ("euclidean", "root_mean_squared_error"): + # metric_fct = EuclideanDistance(self.aggregation_func) + # elif metric == "mse": + # metric_fct = MeanSquaredDistance(self.aggregation_func) + # elif metric == "mean_absolute_error": + # metric_fct = MeanAbsoluteDistance(self.aggregation_func) + # elif metric == "pearson_distance": + # metric_fct = PearsonDistance(self.aggregation_func) + # elif metric == "spearman_distance": + # metric_fct = SpearmanDistance(self.aggregation_func) + # elif metric == "kendalltau_distance": + # metric_fct = KendallTauDistance(self.aggregation_func) + # elif metric == "cosine_distance": + # metric_fct = CosineDistance(self.aggregation_func) + # elif metric == "r2_distance": + # metric_fct = R2ScoreDistance(self.aggregation_func) + # elif metric == "mean_pairwise": + # metric_fct = MeanPairwiseDistance() + # elif metric == "mmd": + # metric_fct = MMD() + # elif metric == "sym_kldiv": + # metric_fct = SymmetricKLDivergence() + # elif metric == "t_test": + # metric_fct = TTestDistance() + # elif metric == "ks_test": + # metric_fct = KSTestDistance() + # elif metric == "nb_ll": + # metric_fct = NBLL() + # elif metric == "classifier_proba": + # metric_fct = ClassifierProbaDistance() + # elif metric == "classifier_cp": + # metric_fct = ClassifierClassProjection() + # elif metric == "mean_var_distribution": + # metric_fct = MeanVarDistributionDistance() + # elif metric == "mahalanobis": + # metric_fct = MahalanobisDistance(self.aggregation_func) + if metric == "edistance": + metric_fct = Edistance() + else: + raise ValueError(f"Metric {metric} not recognized.") + self.metric_fct = metric_fct + + if layer_key and obsm_key: + raise ValueError( + "Cannot use 'layer_key' and 'obsm_key' at the same time.\nPlease provide only one of the two keys." 
+ ) + if not layer_key and not obsm_key: + obsm_key = "X_pca" + self.layer_key = layer_key + self.obsm_key = obsm_key + self.metric = metric + self.cell_wise_metric = cell_wise_metric + + def __call__( + self, + X: np.ndarray, + Y: np.ndarray, + **kwargs, + ) -> float: + """Compute distance between vectors X and Y. + + Args: + X: First vector of shape (n_samples, n_features). + Y: Second vector of shape (n_samples, n_features). + kwargs: Passed to the metric function. + + Returns: + float: Distance between X and Y. + + Examples: + >>> import pertpy as pt + >>> adata = pt.dt.distance_example() + >>> Distance = pt.tools.Distance(metric="edistance") + >>> X = adata.obsm["X_pca"][adata.obs["perturbation"] == "p-sgCREB1-2"] + >>> Y = adata.obsm["X_pca"][adata.obs["perturbation"] == "control"] + >>> D = Distance(X, Y) + """ + if issparse(X): + X = X.toarray() + if issparse(Y): + Y = Y.toarray() + + if len(X) == 0 or len(Y) == 0: + raise ValueError("Neither X nor Y can be empty.") + + return self.metric_fct(X, Y, **kwargs) + + def bootstrap( + self, + X: np.ndarray, + Y: np.ndarray, + *, + n_bootstrap: int = 100, + random_state: int = 0, + **kwargs, + ) -> MeanVar: + """Bootstrap computation of mean and variance of the distance between vectors X and Y. + + Args: + X: First vector of shape (n_samples, n_features). + Y: Second vector of shape (n_samples, n_features). + n_bootstrap: Number of bootstrap samples. + random_state: Random state for bootstrapping. + **kwargs: Passed to the metric function. + + Returns: + Mean and variance of distance between X and Y. 
+ + Examples: + >>> import pertpy as pt + >>> adata = pt.dt.distance_example() + >>> Distance = pt.tools.Distance(metric="edistance") + >>> X = adata.obsm["X_pca"][adata.obs["perturbation"] == "p-sgCREB1-2"] + >>> Y = adata.obsm["X_pca"][adata.obs["perturbation"] == "control"] + >>> D = Distance.bootstrap(X, Y) + """ + return self._bootstrap_mode( + X, + Y, + n_bootstraps=n_bootstrap, + random_state=random_state, + **kwargs, + ) + + def pairwise( + self, + adata: AnnData, + groupby: str, + groups: list[str] | None = None, + bootstrap: bool = False, + n_bootstrap: int = 100, + random_state: int = 0, + show_progressbar: bool = True, + n_jobs: int = -1, + **kwargs, + ) -> pd.DataFrame | tuple[pd.DataFrame, pd.DataFrame]: + """Get pairwise distances between groups of cells. + + Args: + adata: Annotated data matrix. + groupby: Column name in adata.obs. + groups: List of groups to compute pairwise distances for. + If None, uses all groups. + bootstrap: Whether to bootstrap the distance. + n_bootstrap: Number of bootstrap samples. + random_state: Random state for bootstrapping. + show_progressbar: Whether to show progress bar. + n_jobs: Number of cores to use. Defaults to -1 (all). + kwargs: Additional keyword arguments passed to the metric function. + + Returns: + :class:`pandas.DataFrame`: Dataframe with pairwise distances. + tuple[:class:`pandas.DataFrame`, :class:`pandas.DataFrame`]: Two Dataframes, one for the mean and one for the variance of pairwise distances. 
+ + Examples: + >>> import pertpy as pt + >>> adata = pt.dt.distance_example() + >>> Distance = pt.tools.Distance(metric="edistance") + >>> pairwise_df = Distance.pairwise(adata, groupby="perturbation") + """ + groups = adata.obs[groupby].unique() if groups is None else groups + grouping = adata.obs[groupby].copy() + df = pd.DataFrame(index=groups, columns=groups, dtype=float) + if bootstrap: + df_var = pd.DataFrame(index=groups, columns=groups, dtype=float) + # fct = track if show_progressbar else lambda iterable: iterable + fct = lambda iterable: iterable # see TODO above about progress bar + + # Some metrics are able to handle precomputed distances. This means that + # the pairwise distances between all cells are computed once and then + # passed to the metric function. This is much faster than computing the + # pairwise distances for each group separately. Other metrics are not + # able to handle precomputed distances such as the PseudobulkDistance. + if self.metric_fct.accepts_precomputed: + # Precompute the pairwise distances if needed + if f"{self.obsm_key}_{self.cell_wise_metric}_predistances" not in adata.obsp: + self.precompute_distances(adata, n_jobs=n_jobs, **kwargs) + pwd = adata.obsp[f"{self.obsm_key}_{self.cell_wise_metric}_predistances"] + for index_x, group_x in enumerate(fct(groups)): + idx_x = grouping == group_x + for group_y in groups[index_x:]: # type: ignore + # subset the pairwise distance matrix to the two groups + idx_y = grouping == group_y + sub_pwd = pwd[idx_x | idx_y, :][:, idx_x | idx_y] + sub_idx = grouping[idx_x | idx_y] == group_x + if not bootstrap: + if group_x == group_y: + dist = 0.0 + else: + dist = self.metric_fct.from_precomputed(sub_pwd, sub_idx, **kwargs) + df.loc[group_x, group_y] = dist + df.loc[group_y, group_x] = dist + + else: + bootstrap_output = self._bootstrap_mode_precomputed( + sub_pwd, + sub_idx, + n_bootstraps=n_bootstrap, + random_state=random_state, + **kwargs, + ) + # In the bootstrap case, distance of group 
to itself is a mean and can be non-zero + df.loc[group_x, group_y] = df.loc[group_y, group_x] = bootstrap_output.mean + df_var.loc[group_x, group_y] = df_var.loc[group_y, group_x] = bootstrap_output.variance + else: + embedding = adata.layers[self.layer_key] if self.layer_key else adata.obsm[self.obsm_key].copy() + for index_x, group_x in enumerate(fct(groups)): + cells_x = embedding[np.asarray(grouping == group_x)].copy() + for group_y in groups[index_x:]: # type: ignore + cells_y = embedding[np.asarray(grouping == group_y)].copy() + if not bootstrap: + # By distance axiom, the distance between a group and itself is 0 + dist = 0.0 if group_x == group_y else self(cells_x, cells_y, **kwargs) + + df.loc[group_x, group_y] = dist + df.loc[group_y, group_x] = dist + else: + bootstrap_output = self.bootstrap( + cells_x, + cells_y, + n_bootstrap=n_bootstrap, + random_state=random_state, + **kwargs, + ) + # In the bootstrap case, distance of group to itself is a mean and can be non-zero + df.loc[group_x, group_y] = df.loc[group_y, group_x] = bootstrap_output.mean + df_var.loc[group_x, group_y] = df_var.loc[group_y, group_x] = bootstrap_output.variance + + df.index.name = groupby + df.columns.name = groupby + df.name = f"pairwise {self.metric}" + + if not bootstrap: + return df + else: + df = df.fillna(0) + df_var.index.name = groupby + df_var.columns.name = groupby + df_var = df_var.fillna(0) + df_var.name = f"pairwise {self.metric} variance" + + return df, df_var + + def onesided_distances( + self, + adata: AnnData, + groupby: str, + selected_group: str | None = None, + groups: list[str] | None = None, + bootstrap: bool = False, + n_bootstrap: int = 100, + random_state: int = 0, + show_progressbar: bool = True, + n_jobs: int = -1, + **kwargs, + ) -> pd.DataFrame | tuple[pd.DataFrame, pd.DataFrame]: + """Get distances between one selected cell group and the remaining other cell groups. + + Args: + adata: Annotated data matrix. + groupby: Column name in adata.obs. 
+ selected_group: Group to compute pairwise distances to all other. + groups: List of groups to compute distances to selected_group for. + If None, uses all groups. + bootstrap: Whether to bootstrap the distance. + n_bootstrap: Number of bootstrap samples. + random_state: Random state for bootstrapping. + show_progressbar: Whether to show progress bar. + n_jobs: Number of cores to use. Defaults to -1 (all). + kwargs: Additional keyword arguments passed to the metric function. + + Returns: + :class:`pandas.DataFrame`: Dataframe with distances of groups to selected_group. + tuple[:class:`pandas.DataFrame`, :class:`pandas.DataFrame`]: Two Dataframes, one for the mean and one for the variance of distances of groups to selected_group. + + + Examples: + >>> import pertpy as pt + >>> adata = pt.dt.distance_example() + >>> Distance = pt.tools.Distance(metric="edistance") + >>> pairwise_df = Distance.onesided_distances(adata, groupby="perturbation", selected_group="control") + """ + if self.metric == "classifier_cp": + if bootstrap: + raise NotImplementedError("Currently, ClassifierClassProjection does not support bootstrapping.") + return self.metric_fct.onesided_distances( # type: ignore + adata, + groupby, + selected_group, + groups, + show_progressbar, + n_jobs, + **kwargs, + ) + + groups = adata.obs[groupby].unique() if groups is None else groups + grouping = adata.obs[groupby].copy() + df = pd.Series(index=groups, dtype=float) + if bootstrap: + df_var = pd.Series(index=groups, dtype=float) + # fct = track if show_progressbar else lambda iterable: iterable + fct = lambda iterable: iterable # see TODO at the top of the file about progress bar + + # Some metrics are able to handle precomputed distances. This means that + # the pairwise distances between all cells are computed once and then + # passed to the metric function. This is much faster than computing the + # pairwise distances for each group separately. 
Other metrics are not + # able to handle precomputed distances such as the PseudobulkDistance. + if self.metric_fct.accepts_precomputed: + # Precompute the pairwise distances if needed + if f"{self.obsm_key}_{self.cell_wise_metric}_predistances" not in adata.obsp: + self.precompute_distances(adata, n_jobs=n_jobs, **kwargs) + pwd = adata.obsp[f"{self.obsm_key}_{self.cell_wise_metric}_predistances"] + for group_x in fct(groups): + idx_x = grouping == group_x + group_y = selected_group + if group_x == group_y: + df.loc[group_x] = 0.0 # by distance axiom + else: + idx_y = grouping == group_y + # subset the pairwise distance matrix to the two groups + sub_pwd = pwd[idx_x | idx_y, :][:, idx_x | idx_y] + sub_idx = grouping[idx_x | idx_y] == group_x + if not bootstrap: + dist = self.metric_fct.from_precomputed(sub_pwd, sub_idx, **kwargs) + df.loc[group_x] = dist + else: + bootstrap_output = self._bootstrap_mode_precomputed( + sub_pwd, + sub_idx, + n_bootstraps=n_bootstrap, + random_state=random_state, + **kwargs, + ) + df.loc[group_x] = bootstrap_output.mean + df_var.loc[group_x] = bootstrap_output.variance + else: + embedding = adata.layers[self.layer_key] if self.layer_key else adata.obsm[self.obsm_key].copy() + for group_x in fct(groups): + cells_x = embedding[np.asarray(grouping == group_x)].copy() + group_y = selected_group + cells_y = embedding[np.asarray(grouping == group_y)].copy() + if not bootstrap: + # By distance axiom, the distance between a group and itself is 0 + dist = 0.0 if group_x == group_y else self(cells_x, cells_y, **kwargs) + df.loc[group_x] = dist + else: + bootstrap_output = self.bootstrap( + cells_x, + cells_y, + n_bootstrap=n_bootstrap, + random_state=random_state, + **kwargs, + ) + # In the bootstrap case, distance of group to itself is a mean and can be non-zero + df.loc[group_x] = bootstrap_output.mean + df_var.loc[group_x] = bootstrap_output.variance + df.index.name = groupby + df.name = f"{self.metric} to {selected_group}" + if not 
bootstrap: + return df + else: + df_var.index.name = groupby + df_var = df_var.fillna(0) + df_var.name = f"pairwise {self.metric} variance to {selected_group}" + + return df, df_var + + def precompute_distances(self, adata: AnnData, n_jobs: int = -1) -> None: + """Precompute pairwise distances between all cells, writes to adata.obsp. + + The precomputed distances are stored in adata.obsp under the key + '{self.obsm_key}_{cell_wise_metric}_predistances', as they depend on + both the cell-wise metric and the embedding used. + + Args: + adata: Annotated data matrix. + n_jobs: Number of cores to use. Defaults to -1 (all). + + Examples: + >>> import pertpy as pt + >>> adata = pt.dt.distance_example() + >>> distance = pt.tools.Distance(metric="edistance") + >>> distance.precompute_distances(adata) + """ + cells = adata.layers[self.layer_key] if self.layer_key else adata.obsm[self.obsm_key].copy() + pwd = pairwise_distances(cells, cells, metric=self.cell_wise_metric, n_jobs=n_jobs) + adata.obsp[f"{self.obsm_key}_{self.cell_wise_metric}_predistances"] = pwd + + def compare_distance( + self, + pert: np.ndarray, + pred: np.ndarray, + ctrl: np.ndarray, + mode: Literal["simple", "scaled"] = "simple", + fit_to_pert_and_ctrl: bool = False, + **kwargs, + ) -> float: + """Compute the score of simulating a perturbation. + + Args: + pert: Real perturbed data. + pred: Simulated perturbed data. + ctrl: Control data + mode: Mode to use. + fit_to_pert_and_ctrl: Scales data based on both `pert` and `ctrl` if True, otherwise only on `ctrl`. + kwargs: Additional keyword arguments passed to the metric function. + """ + if mode == "simple": + pass # nothing to be done + elif mode == "scaled": + from sklearn.preprocessing import MinMaxScaler + + scaler = MinMaxScaler().fit(np.vstack((pert, ctrl)) if fit_to_pert_and_ctrl else ctrl) + pred = scaler.transform(pred) + pert = scaler.transform(pert) + else: + raise ValueError(f"Unknown mode {mode}. 
Please choose simple or scaled.") + + d1 = self.metric_fct(pert, pred, **kwargs) + d2 = self.metric_fct(ctrl, pred, **kwargs) + return d1 / d2 + + def _bootstrap_mode(self, X, Y, n_bootstraps=100, random_state=0, **kwargs) -> MeanVar: + rng = np.random.default_rng(random_state) + + distances = [] + for _ in range(n_bootstraps): + X_bootstrapped = X[rng.choice(a=X.shape[0], size=X.shape[0], replace=True)] + Y_bootstrapped = Y[rng.choice(a=Y.shape[0], size=X.shape[0], replace=True)] + + distance = self(X_bootstrapped, Y_bootstrapped, **kwargs) + distances.append(distance) + + mean = np.mean(distances) + variance = np.var(distances) + return MeanVar(mean=mean, variance=variance) + + def _bootstrap_mode_precomputed(self, sub_pwd, sub_idx, n_bootstraps=100, random_state=0, **kwargs) -> MeanVar: + rng = np.random.default_rng(random_state) + + distances = [] + for _ in range(n_bootstraps): + # To maintain the number of cells for both groups (whatever balancing they may have), + # we sample the positive and negative indices separately + bootstrap_pos_idx = rng.choice(a=sub_idx[sub_idx].index, size=sub_idx[sub_idx].size, replace=True) + bootstrap_neg_idx = rng.choice(a=sub_idx[~sub_idx].index, size=sub_idx[~sub_idx].size, replace=True) + bootstrap_idx = np.concatenate([bootstrap_pos_idx, bootstrap_neg_idx]) + bootstrap_idx_nrs = sub_idx.index.get_indexer(bootstrap_idx) + + bootstrap_sub_idx = sub_idx[bootstrap_idx] + bootstrap_sub_pwd = sub_pwd[bootstrap_idx_nrs, :][:, bootstrap_idx_nrs] + + distance = self.metric_fct.from_precomputed(bootstrap_sub_pwd, bootstrap_sub_idx, **kwargs) + distances.append(distance) + + mean = np.mean(distances) + variance = np.var(distances) + return MeanVar(mean=mean, variance=variance) + + +class AbstractDistance(ABC): + """Abstract class of distance metrics between two sets of vectors.""" + + @abstractmethod + def __init__(self) -> None: + super().__init__() + self.accepts_precomputed: bool = None + + @abstractmethod + def __call__(self, X: 
np.ndarray, Y: np.ndarray, **kwargs) -> float: + """Compute distance between vectors X and Y. + + Args: + X: First vector of shape (n_samples, n_features). + Y: Second vector of shape (n_samples, n_features). + kwargs: Passed to the metrics function. + + Returns: + float: Distance between X and Y. + """ + raise NotImplementedError("Metric class is abstract.") + + @abstractmethod + def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: + """Compute a distance between vectors X and Y with precomputed distances. + + Args: + P: Pairwise distance matrix of shape (n_samples, n_samples). + idx: Boolean array of shape (n_samples,) indicating which samples belong to X (or Y, since each metric is symmetric). + kwargs: Passed to the metrics function. + + Returns: + float: Distance between X and Y. + """ + raise NotImplementedError("Metric class is abstract.") + + +class Edistance(AbstractDistance): + """Edistance metric.""" + + def __init__(self) -> None: + super().__init__() + self.accepts_precomputed = True + self.cell_wise_metric = "euclidean" + + def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: + sigma_X = pairwise_distances(X, X, metric=self.cell_wise_metric, **kwargs).mean() + sigma_Y = pairwise_distances(Y, Y, metric=self.cell_wise_metric, **kwargs).mean() + delta = pairwise_distances(X, Y, metric=self.cell_wise_metric, **kwargs).mean() + return 2 * delta - sigma_X - sigma_Y + + def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: + sigma_X = P[idx, :][:, idx].mean() + sigma_Y = P[~idx, :][:, ~idx].mean() + delta = P[idx, :][:, ~idx].mean() + return 2 * delta - sigma_X - sigma_Y + + +# class MMD(AbstractDistance): +# """Linear Maximum Mean Discrepancy.""" + +# # Taken in parts from https://github.com/jindongwang/transferlearning/blob/master/code/distance/mmd_numpy_sklearn.py +# def __init__(self) -> None: +# super().__init__() +# self.accepts_precomputed = False + +# def __call__(self, X: np.ndarray, 
Y: np.ndarray, *, kernel="linear", gamma=1.0, degree=2, **kwargs) -> float: +# if kernel == "linear": +# XX = np.dot(X, X.T) +# YY = np.dot(Y, Y.T) +# XY = np.dot(X, Y.T) +# elif kernel == "rbf": +# XX = rbf_kernel(X, X, gamma=gamma) +# YY = rbf_kernel(Y, Y, gamma=gamma) +# XY = rbf_kernel(X, Y, gamma=gamma) +# elif kernel == "poly": +# XX = polynomial_kernel(X, X, degree=degree, gamma=gamma, coef0=0) +# YY = polynomial_kernel(Y, Y, degree=degree, gamma=gamma, coef0=0) +# XY = polynomial_kernel(X, Y, degree=degree, gamma=gamma, coef0=0) +# else: +# raise ValueError(f"Kernel {kernel} not recognized.") + +# return XX.mean() + YY.mean() - 2 * XY.mean() + +# def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: +# raise NotImplementedError("MMD cannot be called on a pairwise distance matrix.") + + + +# class EuclideanDistance(AbstractDistance): +# """Euclidean distance between pseudobulk vectors.""" + +# def __init__(self, aggregation_func: Callable = np.mean) -> None: +# super().__init__() +# self.accepts_precomputed = False +# self.aggregation_func = aggregation_func + +# def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: +# return np.linalg.norm( +# self.aggregation_func(X, axis=0) - self.aggregation_func(Y, axis=0), +# ord=2, +# **kwargs, +# ) + +# def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: +# raise NotImplementedError("EuclideanDistance cannot be called on a pairwise distance matrix.") + + +# class MeanSquaredDistance(AbstractDistance): +# """Mean squared distance between pseudobulk vectors.""" + +# def __init__(self, aggregation_func: Callable = np.mean) -> None: +# super().__init__() +# self.accepts_precomputed = False +# self.aggregation_func = aggregation_func + +# def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: +# return ( +# np.linalg.norm( +# self.aggregation_func(X, axis=0) - self.aggregation_func(Y, axis=0), +# ord=2, +# **kwargs, +# ) +# ** 2 +# / 
X.shape[1] +# ) + +# def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: +# raise NotImplementedError("MeanSquaredDistance cannot be called on a pairwise distance matrix.") + + +# class MeanAbsoluteDistance(AbstractDistance): +# """Absolute (Norm-1) distance between pseudobulk vectors.""" + +# def __init__(self, aggregation_func: Callable = np.mean) -> None: +# super().__init__() +# self.accepts_precomputed = False +# self.aggregation_func = aggregation_func + +# def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: +# return ( +# np.linalg.norm( +# self.aggregation_func(X, axis=0) - self.aggregation_func(Y, axis=0), +# ord=1, +# **kwargs, +# ) +# / X.shape[1] +# ) + +# def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: +# raise NotImplementedError("MeanAbsoluteDistance cannot be called on a pairwise distance matrix.") + + +# class MeanPairwiseDistance(AbstractDistance): +# """Mean of the pairwise euclidean distance between two groups of cells.""" + +# # NOTE: This is not a metric in the mathematical sense. 
+ +# def __init__(self) -> None: +# super().__init__() +# self.accepts_precomputed = True + +# def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: +# return pairwise_distances(X, Y, **kwargs).mean() + +# def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: +# return P[idx, :][:, ~idx].mean() + + +# class PearsonDistance(AbstractDistance): +# """Pearson distance between pseudobulk vectors.""" + +# def __init__(self, aggregation_func: Callable = np.mean) -> None: +# super().__init__() +# self.accepts_precomputed = False +# self.aggregation_func = aggregation_func + +# def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: +# return 1 - pearsonr(self.aggregation_func(X, axis=0), self.aggregation_func(Y, axis=0))[0] + +# def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: +# raise NotImplementedError("PearsonDistance cannot be called on a pairwise distance matrix.") + + +# class SpearmanDistance(AbstractDistance): +# """Spearman distance between pseudobulk vectors.""" + +# def __init__(self, aggregation_func: Callable = np.mean) -> None: +# super().__init__() +# self.accepts_precomputed = False +# self.aggregation_func = aggregation_func + +# def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: +# return 1 - spearmanr(self.aggregation_func(X, axis=0), self.aggregation_func(Y, axis=0))[0] + +# def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: +# raise NotImplementedError("SpearmanDistance cannot be called on a pairwise distance matrix.") + + +# class KendallTauDistance(AbstractDistance): +# """Kendall-tau distance between pseudobulk vectors.""" + +# def __init__(self, aggregation_func: Callable = np.mean) -> None: +# super().__init__() +# self.accepts_precomputed = False +# self.aggregation_func = aggregation_func + +# def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: +# x, y = self.aggregation_func(X, axis=0), 
self.aggregation_func(Y, axis=0) +# n = len(x) +# tau_corr = kendalltau(x, y).statistic +# tau_dist = (1 - tau_corr) * n * (n - 1) / 4 +# return tau_dist + +# def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: +# raise NotImplementedError("KendallTauDistance cannot be called on a pairwise distance matrix.") + + +# class CosineDistance(AbstractDistance): +# """Cosine distance between pseudobulk vectors.""" + +# def __init__(self, aggregation_func: Callable = np.mean) -> None: +# super().__init__() +# self.accepts_precomputed = False +# self.aggregation_func = aggregation_func + +# def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: +# return cosine(self.aggregation_func(X, axis=0), self.aggregation_func(Y, axis=0)) + +# def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: +# raise NotImplementedError("CosineDistance cannot be called on a pairwise distance matrix.") + + +# class R2ScoreDistance(AbstractDistance): +# """Coefficient of determination across genes between pseudobulk vectors.""" + +# # NOTE: This is not a distance metric but a similarity metric. + +# def __init__(self, aggregation_func: Callable = np.mean) -> None: +# super().__init__() +# self.accepts_precomputed = False +# self.aggregation_func = aggregation_func + +# def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: +# return 1 - r2_score(self.aggregation_func(X, axis=0), self.aggregation_func(Y, axis=0)) + +# def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: +# raise NotImplementedError("R2ScoreDistance cannot be called on a pairwise distance matrix.") + + +# class SymmetricKLDivergence(AbstractDistance): +# """Average of symmetric KL divergence between gene distributions of two groups. + +# Assuming a Gaussian distribution for each gene in each group, calculates +# the KL divergence between them and averages over all genes. Repeats this ABBA to get a symmetrized distance. 
+# See https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence#Symmetrised_divergence. + +# """ + +# def __init__(self) -> None: +# super().__init__() +# self.accepts_precomputed = False + +# def __call__(self, X: np.ndarray, Y: np.ndarray, epsilon=1e-8, **kwargs) -> float: +# kl_all = [] +# for i in range(X.shape[1]): +# x_mean, x_std = X[:, i].mean(), X[:, i].std() + epsilon +# y_mean, y_std = Y[:, i].mean(), Y[:, i].std() + epsilon +# kl = np.log(y_std / x_std) + (x_std**2 + (x_mean - y_mean) ** 2) / (2 * y_std**2) - 1 / 2 +# klr = np.log(x_std / y_std) + (y_std**2 + (y_mean - x_mean) ** 2) / (2 * x_std**2) - 1 / 2 +# kl_all.append(kl + klr) +# return sum(kl_all) / len(kl_all) + +# def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: +# raise NotImplementedError("SymmetricKLDivergence cannot be called on a pairwise distance matrix.") + + +# class TTestDistance(AbstractDistance): +# """Average of T test statistic between two groups assuming unequal variances.""" + +# def __init__(self) -> None: +# super().__init__() +# self.accepts_precomputed = False + +# def __call__(self, X: np.ndarray, Y: np.ndarray, epsilon=1e-8, **kwargs) -> float: +# t_test_all = [] +# n1 = X.shape[0] +# n2 = Y.shape[0] +# for i in range(X.shape[1]): +# m1, v1 = X[:, i].mean(), X[:, i].std() ** 2 * n1 / (n1 - 1) + epsilon +# m2, v2 = Y[:, i].mean(), Y[:, i].std() ** 2 * n2 / (n2 - 1) + epsilon +# vn1 = v1 / n1 +# vn2 = v2 / n2 +# t = (m1 - m2) / np.sqrt(vn1 + vn2) +# t_test_all.append(abs(t)) +# return sum(t_test_all) / len(t_test_all) + +# def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: +# raise NotImplementedError("TTestDistance cannot be called on a pairwise distance matrix.") + + +# class KSTestDistance(AbstractDistance): +# """Average of two-sided KS test statistic between two groups.""" + +# def __init__(self) -> None: +# super().__init__() +# self.accepts_precomputed = False + +# def __call__(self, X: np.ndarray, Y: 
np.ndarray, **kwargs) -> float: +# stats = [abs(kstest(X[:, i], Y[:, i])[0]) for i in range(X.shape[1])] +# return sum(stats) / len(stats) + +# def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: +# raise NotImplementedError("KSTestDistance cannot be called on a pairwise distance matrix.") + + +# class NBLL(AbstractDistance): +# """Average of Log likelihood (scalar) of group B cells according to a NB distribution fitted over group A.""" + +# def __init__(self) -> None: +# super().__init__() +# self.accepts_precomputed = False + +# def __call__(self, X: np.ndarray, Y: np.ndarray, epsilon=1e-8, **kwargs) -> float: +# def _is_count_matrix(matrix, tolerance=1e-6): +# return bool(matrix.dtype.kind == "i" or np.all(np.abs(matrix - np.round(matrix)) < tolerance)) + +# if not _is_count_matrix(matrix=X) or not _is_count_matrix(matrix=Y): +# raise ValueError("NBLL distance only works for raw counts.") + +# @jit(forceobj=True) +# def _compute_nll(y: np.ndarray, nb_params: tuple[float, float], epsilon: float) -> float: +# mu = np.exp(nb_params[0]) +# theta = 1 / nb_params[1] +# eps = epsilon + +# log_theta_mu_eps = np.log(theta + mu + eps) +# nll = ( +# theta * (np.log(theta + eps) - log_theta_mu_eps) +# + y * (np.log(mu + eps) - log_theta_mu_eps) +# + gammaln(y + theta) +# - gammaln(theta) +# - gammaln(y + 1) +# ) +# return nll.mean() + +# def _process_gene(x: np.ndarray, y: np.ndarray, epsilon: float) -> float: +# try: +# nb_params = NegativeBinomialP(x, np.ones_like(x)).fit(disp=False).params +# return _compute_nll(y, nb_params, epsilon) +# except np.linalg.LinAlgError: +# if x.mean() < 10 and y.mean() < 10: +# return 0.0 +# else: +# return np.nan # Use NaN to indicate skipped genes + +# nlls = [] +# genes_skipped = 0 + +# for i in range(X.shape[1]): +# nll = _process_gene(X[:, i], Y[:, i], epsilon) +# if np.isnan(nll): +# genes_skipped += 1 +# else: +# nlls.append(nll) + +# if genes_skipped > X.shape[1] / 2: +# raise 
AttributeError(f"{genes_skipped} genes could not be fit, which is over half.") + +# return -np.sum(nlls) / len(nlls) + +# def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: +# raise NotImplementedError("NBLL cannot be called on a pairwise distance matrix.") + + +# def _sample(X, frac=None, n=None): +# """Returns subsample of cells in format (train, test).""" +# if frac and n: +# raise ValueError("Cannot pass both frac and n.") +# if frac: +# n_cells = max(1, int(X.shape[0] * frac)) +# elif n: +# n_cells = n +# else: +# raise ValueError("Must pass either `frac` or `n`.") + +# rng = np.random.default_rng() +# sampled_indices = rng.choice(X.shape[0], n_cells, replace=False) +# remaining_indices = np.setdiff1d(np.arange(X.shape[0]), sampled_indices) +# return X[remaining_indices, :], X[sampled_indices, :] + + +# class ClassifierProbaDistance(AbstractDistance): +# """Average of classification probabilites of a binary classifier. + +# Assumes the first condition is control and the second is perturbed. +# Always holds out 20% of the perturbed condition. +# """ + +# def __init__(self) -> None: +# super().__init__() +# self.accepts_precomputed = False + +# def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: +# Y_train, Y_test = _sample(Y, frac=0.2) +# label = ["c"] * X.shape[0] + ["p"] * Y_train.shape[0] +# train = np.concatenate([X, Y_train]) + +# reg = LogisticRegression() +# reg.fit(train, label) +# test_labels = reg.predict_proba(Y_test) +# return np.mean(test_labels[:, 1]) + +# def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: +# raise NotImplementedError("ClassifierProbaDistance cannot be called on a pairwise distance matrix.") + + +# class ClassifierClassProjection(AbstractDistance): +# """Average of 1-(classification probability of control). + +# Warning: unlike all other distances, this must also take a list of categorical labels the same length as X. 
+# """ + +# def __init__(self) -> None: +# super().__init__() +# self.accepts_precomputed = False + +# def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: +# raise NotImplementedError("ClassifierClassProjection can currently only be called with onesided.") + +# def onesided_distances( +# self, +# adata: AnnData, +# groupby: str, +# selected_group: str | None = None, +# groups: list[str] | None = None, +# show_progressbar: bool = True, +# n_jobs: int = -1, +# **kwargs, +# ) -> Series: +# """Unlike the parent function, all groups except the selected group are factored into the classifier. + +# Similar to the parent function, the returned dataframe contains only the specified groups. +# """ +# groups = adata.obs[groupby].unique() if groups is None else groups +# fct = track if show_progressbar else lambda iterable: iterable + +# X = adata[adata.obs[groupby] != selected_group].X +# labels = adata[adata.obs[groupby] != selected_group].obs[groupby].values +# Y = adata[adata.obs[groupby] == selected_group].X + +# reg = LogisticRegression() +# reg.fit(X, labels) +# test_probas = reg.predict_proba(Y) + +# df = pd.Series(index=groups, dtype=float) + +# for group in fct(groups): +# if group == selected_group: +# df.loc[group] = 0 +# else: +# class_idx = list(reg.classes_).index(group) +# df.loc[group] = 1 - np.mean(test_probas[:, class_idx]) +# df.index.name = groupby +# df.name = f"classifier_cp to {selected_group}" + +# return df + +# def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: +# raise NotImplementedError("ClassifierClassProjection cannot be called on a pairwise distance matrix.") + + +# class MeanVarDistributionDistance(AbstractDistance): +# """Distance between mean-var distributions of gene expression.""" + +# def __init__(self) -> None: +# super().__init__() +# self.accepts_precomputed = False + +# @staticmethod +# def _mean_var(x, log: bool = False): +# mean = np.mean(x, axis=0) +# var = np.var(x, axis=0) +# positive 
= mean > 0 +# mean = mean[positive] +# var = var[positive] +# if log: +# mean = np.log(mean) +# var = np.log(var) +# return mean, var + +# @staticmethod +# def _prep_kde_data(x, y): +# return np.concatenate([x.reshape(-1, 1), y.reshape(-1, 1)], axis=1) + +# @staticmethod +# def _grid_points(d, n_points=100): +# # Make grid, add 1 bin on lower/upper end to get final n_points +# d_min = d.min() +# d_max = d.max() +# # Compute bin size +# d_bin = (d_max - d_min) / (n_points - 2) +# d_min = d_min - d_bin +# d_max = d_max + d_bin +# return np.arange(start=d_min + 0.5 * d_bin, stop=d_max, step=d_bin) + +# @staticmethod +# def _kde_eval_both(x_kde, y_kde, grid): +# n_points = len(grid) +# chunk_size = 10000 + +# result_x = np.zeros(n_points) +# result_y = np.zeros(n_points) + +# # Process same chunks for both KDEs +# for start in range(0, n_points, chunk_size): +# end = min(start + chunk_size, n_points) +# chunk = grid[start:end] +# result_x[start:end] = x_kde.score_samples(chunk) +# result_y[start:end] = y_kde.score_samples(chunk) + +# return result_x, result_y + +# def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: +# """Difference of mean-var distributions in 2 matrices. + +# Args: +# X: Normalized and log transformed cells x genes count matrix. +# Y: Normalized and log transformed cells x genes count matrix. +# kwargs: Passed to the metrics function. 
+# """ +# mean_x, var_x = self._mean_var(X, log=True) +# mean_y, var_y = self._mean_var(Y, log=True) + +# x = self._prep_kde_data(mean_x, var_x) +# y = self._prep_kde_data(mean_y, var_y) + +# # Gridpoints to eval KDE on +# mean_grid = self._grid_points(np.concatenate([mean_x, mean_y])) +# var_grid = self._grid_points(np.concatenate([var_x, var_y])) +# grid = np.array(np.meshgrid(mean_grid, var_grid)).T.reshape(-1, 2) + +# # Fit both KDEs first +# x_kde = KernelDensity(bandwidth="silverman", kernel="exponential").fit(x) +# y_kde = KernelDensity(bandwidth="silverman", kernel="exponential").fit(y) + +# # Evaluate both KDEs on same grid chunks +# kde_x, kde_y = self._kde_eval_both(x_kde, y_kde, grid) + +# return ((np.exp(kde_x) - np.exp(kde_y)) ** 2).mean() + +# def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: +# raise NotImplementedError("MeanVarDistributionDistance cannot be called on a pairwise distance matrix.") + + +# class MahalanobisDistance(AbstractDistance): +# """Mahalanobis distance between pseudobulk vectors.""" + +# def __init__(self, aggregation_func: Callable = np.mean) -> None: +# super().__init__() +# self.accepts_precomputed = False +# self.aggregation_func = aggregation_func + +# def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: +# return mahalanobis( +# self.aggregation_func(X, axis=0), +# self.aggregation_func(Y, axis=0), +# np.linalg.inv(np.cov(X.T)), +# ) + +# def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: +# raise NotImplementedError("Mahalanobis cannot be called on a pairwise distance matrix.") diff --git a/tests/pertpy/test_distances.py b/tests/pertpy/test_distances.py new file mode 100644 index 00000000..664fe3df --- /dev/null +++ b/tests/pertpy/test_distances.py @@ -0,0 +1,223 @@ + +import numpy as np +import pertpy as pt +import pytest +import scanpy as sc +from pandas import DataFrame, Series +from pytest import fixture, mark + +@pytest.fixture +def rng(): 
#TODO(selmanozleyen): Think of a way to integrate this with decoupler's rng fixture + rng = np.random.default_rng(seed=42) + return rng + + +actual_distances = [ + # Euclidean distances and related + # "euclidean", + # "mean_absolute_error", + # "mean_pairwise", + # "mse", + "edistance", + # Other + # "cosine_distance", + # "kendalltau_distance", + # "mmd", + # "pearson_distance", + # "spearman_distance", + # "t_test", + # "mahalanobis", +] +# semi_distances = ["r2_distance", "sym_kldiv", "ks_test"] +# non_distances = ["classifier_proba"] +# onesided_only = ["classifier_cp"] +# pseudo_counts_distances = ["nb_ll"] +# lognorm_counts_distances = ["mean_var_distribution"] +all_distances = ( + actual_distances #+ semi_distances + non_distances + lognorm_counts_distances + pseudo_counts_distances +) # + onesided_only +semi_distances = [] +non_distances = [] +onesided_only = [] +pseudo_counts_distances = [] +lognorm_counts_distances = [] + +@fixture +def adata(request): + low_subsample_distances = [ + "sym_kldiv", + "t_test", + "ks_test", + "classifier_proba", + "classifier_cp", + "mahalanobis", + "mean_var_distribution", + ] + no_subsample_distances = ["mahalanobis"] # mahalanobis only works on the full data without subsampling + + distance = request.node.callspec.params["distance"] + + adata = pt.dt.distance_example() + if distance not in no_subsample_distances: + if distance in low_subsample_distances: + adata = sc.pp.subsample(adata, 0.1, copy=True) + else: + adata = sc.pp.subsample(adata, 0.001, copy=True) + + adata = adata[:, np.random.default_rng().choice(adata.n_vars, 100, replace=False)].copy() + + adata.layers["lognorm"] = adata.X.copy() + adata.layers["counts"] = np.round(adata.X.toarray()).astype(int) + if "X_pca" not in adata.obsm: + sc.pp.pca(adata, n_comps=5) + if distance in lognorm_counts_distances: + groups = np.unique(adata.obs["perturbation"]) + # KDE is slow, subset to 3 groups for speed up + adata = 
adata[adata.obs["perturbation"].isin(groups[0:3])].copy() + + return adata + + +@fixture +def distance_obj(request): + distance = request.node.callspec.params["distance"] + if distance in lognorm_counts_distances: + Distance = pt.tl.Distance(distance, layer_key="lognorm") + elif distance in pseudo_counts_distances: + Distance = pt.tl.Distance(distance, layer_key="counts") + else: + Distance = pt.tl.Distance(distance, obsm_key="X_pca") + return Distance + + +@fixture +@mark.parametrize("distance", all_distances) +def pairwise_distance(adata, distance_obj, distance): + return distance_obj.pairwise(adata, groupby="perturbation", show_progressbar=True) + + +@mark.parametrize("distance", actual_distances + semi_distances) +def test_distance_axioms(pairwise_distance, distance): + # This is equivalent to testing for a semimetric, defined as fulfilling all axioms except triangle inequality. + # (M1) Definiteness + assert all(np.diag(pairwise_distance.values) == 0) # distance to self is 0 + + # (M2) Positivity + assert len(pairwise_distance) == np.sum(pairwise_distance.values == 0) # distance to other is not 0 + assert all(pairwise_distance.values.flatten() >= 0) # distance is non-negative + + # (M3) Symmetry + assert np.sum(pairwise_distance.values - pairwise_distance.values.T) == 0 + + +@mark.parametrize("distance", actual_distances) +def test_triangle_inequality(pairwise_distance, distance, rng): + # Test if distances are well-defined in accordance with metric axioms + # (M4) Triangle inequality (we just probe this for a few random triplets) + # Some tests are not well defined for the triangle inequality. We skip those. 
+ if distance in {"mahalanobis"}: + return + + for _ in range(5): + triplet = rng.choice(pairwise_distance.index, size=3, replace=False) + assert ( + pairwise_distance.loc[triplet[0], triplet[1]] + pairwise_distance.loc[triplet[1], triplet[2]] + >= pairwise_distance.loc[triplet[0], triplet[2]] + ) + + +@mark.parametrize("distance", all_distances) +def test_distance_layers(pairwise_distance, distance): + assert isinstance(pairwise_distance, DataFrame) + assert pairwise_distance.columns.equals(pairwise_distance.index) + assert np.sum(pairwise_distance.values - pairwise_distance.values.T) == 0 # symmetry + + +@mark.parametrize("distance", actual_distances + pseudo_counts_distances) +def test_distance_counts(adata, distance): + if distance != "mahalanobis": # skip, doesn't work because covariance matrix is a singular matrix, not invertible + distance = pt.tl.Distance(distance, layer_key="counts") + df = distance.pairwise(adata, groupby="perturbation") + assert isinstance(df, DataFrame) + assert df.columns.equals(df.index) + assert np.sum(df.values - df.values.T) == 0 + + +@mark.parametrize("distance", all_distances) +def test_mutually_exclusive_keys(distance): + with pytest.raises(ValueError): + _ = pt.tl.Distance(distance, layer_key="counts", obsm_key="X_pca") + + +@mark.parametrize("distance", actual_distances + semi_distances + non_distances) +def test_distance_output_type(distance, rng): + # Test if distances are outputting floats + Distance = pt.tl.Distance(distance) + X = rng.normal(size=(50, 10)) + Y = rng.normal(size=(50, 10)) + d = Distance(X, Y) + assert isinstance(d, float) + + +@mark.parametrize("distance", all_distances + onesided_only) +def test_distance_onesided(adata, distance_obj, distance): + # Test consistency of one-sided distance results + selected_group = adata.obs.perturbation.unique()[0] + df = distance_obj.onesided_distances(adata, groupby="perturbation", selected_group=selected_group) + assert isinstance(df, Series) + assert 
df.loc[selected_group] == 0 # distance to self is 0 + + +def test_bootstrap_distance_output_type(rng): + # Test if distances are outputting floats + Distance = pt.tl.Distance(metric="edistance") + X = rng.normal(size=(50, 10)) + Y = rng.normal(size=(50, 10)) + d = Distance.bootstrap(X, Y, n_bootstrap=3) + assert hasattr(d, "mean") + assert hasattr(d, "variance") + + +@mark.parametrize("distance", ["edistance"]) +def test_bootstrap_distance_pairwise(adata, distance): + # Test consistency of pairwise distance results + Distance = pt.tl.Distance(distance, obsm_key="X_pca") + bootstrap_output = Distance.pairwise(adata, groupby="perturbation", bootstrap=True, n_bootstrap=3) + + assert isinstance(bootstrap_output, tuple) + + mean = bootstrap_output[0] + var = bootstrap_output[1] + + assert mean.columns.equals(mean.index) + assert np.sum(mean.values - mean.values.T) == 0 # symmetry + assert np.sum(var.values - var.values.T) == 0 # symmetry + + +@mark.parametrize("distance", ["edistance"]) +def test_bootstrap_distance_onesided(adata, distance): + # Test consistency of one-sided distance results + selected_group = adata.obs.perturbation.unique()[0] + Distance = pt.tl.Distance(distance, obsm_key="X_pca") + bootstrap_output = Distance.onesided_distances( + adata, + groupby="perturbation", + selected_group=selected_group, + bootstrap=True, + n_bootstrap=3, + ) + + assert isinstance(bootstrap_output, tuple) + + +def test_compare_distance(rng): + X = rng.normal(size=(50, 10)) + Y = rng.normal(size=(50, 10)) + C = rng.normal(size=(50, 10)) + Distance = pt.tl.Distance() + res_simple = Distance.compare_distance(X, Y, C, mode="simple") + assert isinstance(res_simple, float) + res_scaled = Distance.compare_distance(X, Y, C, mode="scaled") + assert isinstance(res_scaled, float) + with pytest.raises(ValueError): + Distance.compare_distance(X, Y, C, mode="new_mode") From 5bd9f52726b32d51a31ca78089d012abe8c14519 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" 
<66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 9 Sep 2025 09:49:19 +0000 Subject: [PATCH 02/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pertpy_gpu/_distances.py | 109 +++++++++++++----- tests/pertpy/test_distances.py | 42 ++++--- 2 files changed, 112 insertions(+), 39 deletions(-) diff --git a/src/rapids_singlecell/pertpy_gpu/_distances.py b/src/rapids_singlecell/pertpy_gpu/_distances.py index 562471c8..2f6e80e2 100644 --- a/src/rapids_singlecell/pertpy_gpu/_distances.py +++ b/src/rapids_singlecell/pertpy_gpu/_distances.py @@ -5,7 +5,7 @@ import numpy as np import pandas as pd -from pandas import Series + # TODO(selmanozleyen): adapt which progress bar to use, probably use rapidsinglecell's progress bar (if it exists) # from rich.progress import track # from scipy.sparse import issparse @@ -327,7 +327,10 @@ def pairwise( # able to handle precomputed distances such as the PseudobulkDistance. if self.metric_fct.accepts_precomputed: # Precompute the pairwise distances if needed - if f"{self.obsm_key}_{self.cell_wise_metric}_predistances" not in adata.obsp: + if ( + f"{self.obsm_key}_{self.cell_wise_metric}_predistances" + not in adata.obsp + ): self.precompute_distances(adata, n_jobs=n_jobs, **kwargs) pwd = adata.obsp[f"{self.obsm_key}_{self.cell_wise_metric}_predistances"] for index_x, group_x in enumerate(fct(groups)): @@ -341,7 +344,9 @@ def pairwise( if group_x == group_y: dist = 0.0 else: - dist = self.metric_fct.from_precomputed(sub_pwd, sub_idx, **kwargs) + dist = self.metric_fct.from_precomputed( + sub_pwd, sub_idx, **kwargs + ) df.loc[group_x, group_y] = dist df.loc[group_y, group_x] = dist @@ -354,17 +359,29 @@ def pairwise( **kwargs, ) # In the bootstrap case, distance of group to itself is a mean and can be non-zero - df.loc[group_x, group_y] = df.loc[group_y, group_x] = bootstrap_output.mean - df_var.loc[group_x, group_y] = df_var.loc[group_y, group_x] = 
bootstrap_output.variance + df.loc[group_x, group_y] = df.loc[group_y, group_x] = ( + bootstrap_output.mean + ) + df_var.loc[group_x, group_y] = df_var.loc[group_y, group_x] = ( + bootstrap_output.variance + ) else: - embedding = adata.layers[self.layer_key] if self.layer_key else adata.obsm[self.obsm_key].copy() + embedding = ( + adata.layers[self.layer_key] + if self.layer_key + else adata.obsm[self.obsm_key].copy() + ) for index_x, group_x in enumerate(fct(groups)): cells_x = embedding[np.asarray(grouping == group_x)].copy() for group_y in groups[index_x:]: # type: ignore cells_y = embedding[np.asarray(grouping == group_y)].copy() if not bootstrap: # By distance axiom, the distance between a group and itself is 0 - dist = 0.0 if group_x == group_y else self(cells_x, cells_y, **kwargs) + dist = ( + 0.0 + if group_x == group_y + else self(cells_x, cells_y, **kwargs) + ) df.loc[group_x, group_y] = dist df.loc[group_y, group_x] = dist @@ -377,8 +394,12 @@ def pairwise( **kwargs, ) # In the bootstrap case, distance of group to itself is a mean and can be non-zero - df.loc[group_x, group_y] = df.loc[group_y, group_x] = bootstrap_output.mean - df_var.loc[group_x, group_y] = df_var.loc[group_y, group_x] = bootstrap_output.variance + df.loc[group_x, group_y] = df.loc[group_y, group_x] = ( + bootstrap_output.mean + ) + df_var.loc[group_x, group_y] = df_var.loc[group_y, group_x] = ( + bootstrap_output.variance + ) df.index.name = groupby df.columns.name = groupby @@ -436,7 +457,9 @@ def onesided_distances( """ if self.metric == "classifier_cp": if bootstrap: - raise NotImplementedError("Currently, ClassifierClassProjection does not support bootstrapping.") + raise NotImplementedError( + "Currently, ClassifierClassProjection does not support bootstrapping." 
+ ) return self.metric_fct.onesided_distances( # type: ignore adata, groupby, @@ -453,7 +476,9 @@ def onesided_distances( if bootstrap: df_var = pd.Series(index=groups, dtype=float) # fct = track if show_progressbar else lambda iterable: iterable - fct = lambda iterable: iterable # see TODO at the top of the file about progress bar + fct = ( + lambda iterable: iterable + ) # see TODO at the top of the file about progress bar # Some metrics are able to handle precomputed distances. This means that # the pairwise distances between all cells are computed once and then @@ -462,7 +487,10 @@ def onesided_distances( # able to handle precomputed distances such as the PseudobulkDistance. if self.metric_fct.accepts_precomputed: # Precompute the pairwise distances if needed - if f"{self.obsm_key}_{self.cell_wise_metric}_predistances" not in adata.obsp: + if ( + f"{self.obsm_key}_{self.cell_wise_metric}_predistances" + not in adata.obsp + ): self.precompute_distances(adata, n_jobs=n_jobs, **kwargs) pwd = adata.obsp[f"{self.obsm_key}_{self.cell_wise_metric}_predistances"] for group_x in fct(groups): @@ -476,7 +504,9 @@ def onesided_distances( sub_pwd = pwd[idx_x | idx_y, :][:, idx_x | idx_y] sub_idx = grouping[idx_x | idx_y] == group_x if not bootstrap: - dist = self.metric_fct.from_precomputed(sub_pwd, sub_idx, **kwargs) + dist = self.metric_fct.from_precomputed( + sub_pwd, sub_idx, **kwargs + ) df.loc[group_x] = dist else: bootstrap_output = self._bootstrap_mode_precomputed( @@ -489,14 +519,20 @@ def onesided_distances( df.loc[group_x] = bootstrap_output.mean df_var.loc[group_x] = bootstrap_output.variance else: - embedding = adata.layers[self.layer_key] if self.layer_key else adata.obsm[self.obsm_key].copy() + embedding = ( + adata.layers[self.layer_key] + if self.layer_key + else adata.obsm[self.obsm_key].copy() + ) for group_x in fct(groups): cells_x = embedding[np.asarray(grouping == group_x)].copy() group_y = selected_group cells_y = embedding[np.asarray(grouping == 
group_y)].copy() if not bootstrap: # By distance axiom, the distance between a group and itself is 0 - dist = 0.0 if group_x == group_y else self(cells_x, cells_y, **kwargs) + dist = ( + 0.0 if group_x == group_y else self(cells_x, cells_y, **kwargs) + ) df.loc[group_x] = dist else: bootstrap_output = self.bootstrap( @@ -537,8 +573,14 @@ def precompute_distances(self, adata: AnnData, n_jobs: int = -1) -> None: >>> distance = pt.tools.Distance(metric="edistance") >>> distance.precompute_distances(adata) """ - cells = adata.layers[self.layer_key] if self.layer_key else adata.obsm[self.obsm_key].copy() - pwd = pairwise_distances(cells, cells, metric=self.cell_wise_metric, n_jobs=n_jobs) + cells = ( + adata.layers[self.layer_key] + if self.layer_key + else adata.obsm[self.obsm_key].copy() + ) + pwd = pairwise_distances( + cells, cells, metric=self.cell_wise_metric, n_jobs=n_jobs + ) adata.obsp[f"{self.obsm_key}_{self.cell_wise_metric}_predistances"] = pwd def compare_distance( @@ -565,7 +607,9 @@ def compare_distance( elif mode == "scaled": from sklearn.preprocessing import MinMaxScaler - scaler = MinMaxScaler().fit(np.vstack((pert, ctrl)) if fit_to_pert_and_ctrl else ctrl) + scaler = MinMaxScaler().fit( + np.vstack((pert, ctrl)) if fit_to_pert_and_ctrl else ctrl + ) pred = scaler.transform(pred) pert = scaler.transform(pert) else: @@ -575,7 +619,9 @@ def compare_distance( d2 = self.metric_fct(ctrl, pred, **kwargs) return d1 / d2 - def _bootstrap_mode(self, X, Y, n_bootstraps=100, random_state=0, **kwargs) -> MeanVar: + def _bootstrap_mode( + self, X, Y, n_bootstraps=100, random_state=0, **kwargs + ) -> MeanVar: rng = np.random.default_rng(random_state) distances = [] @@ -590,22 +636,30 @@ def _bootstrap_mode(self, X, Y, n_bootstraps=100, random_state=0, **kwargs) -> M variance = np.var(distances) return MeanVar(mean=mean, variance=variance) - def _bootstrap_mode_precomputed(self, sub_pwd, sub_idx, n_bootstraps=100, random_state=0, **kwargs) -> MeanVar: + def 
_bootstrap_mode_precomputed( + self, sub_pwd, sub_idx, n_bootstraps=100, random_state=0, **kwargs + ) -> MeanVar: rng = np.random.default_rng(random_state) distances = [] for _ in range(n_bootstraps): # To maintain the number of cells for both groups (whatever balancing they may have), # we sample the positive and negative indices separately - bootstrap_pos_idx = rng.choice(a=sub_idx[sub_idx].index, size=sub_idx[sub_idx].size, replace=True) - bootstrap_neg_idx = rng.choice(a=sub_idx[~sub_idx].index, size=sub_idx[~sub_idx].size, replace=True) + bootstrap_pos_idx = rng.choice( + a=sub_idx[sub_idx].index, size=sub_idx[sub_idx].size, replace=True + ) + bootstrap_neg_idx = rng.choice( + a=sub_idx[~sub_idx].index, size=sub_idx[~sub_idx].size, replace=True + ) bootstrap_idx = np.concatenate([bootstrap_pos_idx, bootstrap_neg_idx]) bootstrap_idx_nrs = sub_idx.index.get_indexer(bootstrap_idx) bootstrap_sub_idx = sub_idx[bootstrap_idx] bootstrap_sub_pwd = sub_pwd[bootstrap_idx_nrs, :][:, bootstrap_idx_nrs] - distance = self.metric_fct.from_precomputed(bootstrap_sub_pwd, bootstrap_sub_idx, **kwargs) + distance = self.metric_fct.from_precomputed( + bootstrap_sub_pwd, bootstrap_sub_idx, **kwargs + ) distances.append(distance) mean = np.mean(distances) @@ -659,8 +713,12 @@ def __init__(self) -> None: self.cell_wise_metric = "euclidean" def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: - sigma_X = pairwise_distances(X, X, metric=self.cell_wise_metric, **kwargs).mean() - sigma_Y = pairwise_distances(Y, Y, metric=self.cell_wise_metric, **kwargs).mean() + sigma_X = pairwise_distances( + X, X, metric=self.cell_wise_metric, **kwargs + ).mean() + sigma_Y = pairwise_distances( + Y, Y, metric=self.cell_wise_metric, **kwargs + ).mean() delta = pairwise_distances(X, Y, metric=self.cell_wise_metric, **kwargs).mean() return 2 * delta - sigma_X - sigma_Y @@ -701,7 +759,6 @@ def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: # raise 
NotImplementedError("MMD cannot be called on a pairwise distance matrix.") - # class EuclideanDistance(AbstractDistance): # """Euclidean distance between pseudobulk vectors.""" diff --git a/tests/pertpy/test_distances.py b/tests/pertpy/test_distances.py index 664fe3df..fdd51d07 100644 --- a/tests/pertpy/test_distances.py +++ b/tests/pertpy/test_distances.py @@ -1,3 +1,4 @@ +from __future__ import annotations import numpy as np import pertpy as pt @@ -6,8 +7,9 @@ from pandas import DataFrame, Series from pytest import fixture, mark -@pytest.fixture -def rng(): #TODO(selmanozleyen): Think of a way to integrate this with decoupler's rng fixture + +@pytest.fixture +def rng(): # TODO(selmanozleyen): Think of a way to integrate this with decoupler's rng fixture rng = np.random.default_rng(seed=42) return rng @@ -33,15 +35,14 @@ def rng(): #TODO(selmanozleyen): Think of a way to integrate this with decoupler # onesided_only = ["classifier_cp"] # pseudo_counts_distances = ["nb_ll"] # lognorm_counts_distances = ["mean_var_distribution"] -all_distances = ( - actual_distances #+ semi_distances + non_distances + lognorm_counts_distances + pseudo_counts_distances -) # + onesided_only +all_distances = actual_distances # + semi_distances + non_distances + lognorm_counts_distances + pseudo_counts_distances # + onesided_only semi_distances = [] non_distances = [] onesided_only = [] pseudo_counts_distances = [] lognorm_counts_distances = [] + @fixture def adata(request): low_subsample_distances = [ @@ -53,7 +54,9 @@ def adata(request): "mahalanobis", "mean_var_distribution", ] - no_subsample_distances = ["mahalanobis"] # mahalanobis only works on the full data without subsampling + no_subsample_distances = [ + "mahalanobis" + ] # mahalanobis only works on the full data without subsampling distance = request.node.callspec.params["distance"] @@ -64,7 +67,9 @@ def adata(request): else: adata = sc.pp.subsample(adata, 0.001, copy=True) - adata = adata[:, 
np.random.default_rng().choice(adata.n_vars, 100, replace=False)].copy() + adata = adata[ + :, np.random.default_rng().choice(adata.n_vars, 100, replace=False) + ].copy() adata.layers["lognorm"] = adata.X.copy() adata.layers["counts"] = np.round(adata.X.toarray()).astype(int) @@ -103,7 +108,9 @@ def test_distance_axioms(pairwise_distance, distance): assert all(np.diag(pairwise_distance.values) == 0) # distance to self is 0 # (M2) Positivity - assert len(pairwise_distance) == np.sum(pairwise_distance.values == 0) # distance to other is not 0 + assert len(pairwise_distance) == np.sum( + pairwise_distance.values == 0 + ) # distance to other is not 0 assert all(pairwise_distance.values.flatten() >= 0) # distance is non-negative # (M3) Symmetry @@ -121,7 +128,8 @@ def test_triangle_inequality(pairwise_distance, distance, rng): for _ in range(5): triplet = rng.choice(pairwise_distance.index, size=3, replace=False) assert ( - pairwise_distance.loc[triplet[0], triplet[1]] + pairwise_distance.loc[triplet[1], triplet[2]] + pairwise_distance.loc[triplet[0], triplet[1]] + + pairwise_distance.loc[triplet[1], triplet[2]] >= pairwise_distance.loc[triplet[0], triplet[2]] ) @@ -130,12 +138,16 @@ def test_triangle_inequality(pairwise_distance, distance, rng): def test_distance_layers(pairwise_distance, distance): assert isinstance(pairwise_distance, DataFrame) assert pairwise_distance.columns.equals(pairwise_distance.index) - assert np.sum(pairwise_distance.values - pairwise_distance.values.T) == 0 # symmetry + assert ( + np.sum(pairwise_distance.values - pairwise_distance.values.T) == 0 + ) # symmetry @mark.parametrize("distance", actual_distances + pseudo_counts_distances) def test_distance_counts(adata, distance): - if distance != "mahalanobis": # skip, doesn't work because covariance matrix is a singular matrix, not invertible + if ( + distance != "mahalanobis" + ): # skip, doesn't work because covariance matrix is a singular matrix, not invertible distance = 
pt.tl.Distance(distance, layer_key="counts") df = distance.pairwise(adata, groupby="perturbation") assert isinstance(df, DataFrame) @@ -163,7 +175,9 @@ def test_distance_output_type(distance, rng): def test_distance_onesided(adata, distance_obj, distance): # Test consistency of one-sided distance results selected_group = adata.obs.perturbation.unique()[0] - df = distance_obj.onesided_distances(adata, groupby="perturbation", selected_group=selected_group) + df = distance_obj.onesided_distances( + adata, groupby="perturbation", selected_group=selected_group + ) assert isinstance(df, Series) assert df.loc[selected_group] == 0 # distance to self is 0 @@ -182,7 +196,9 @@ def test_bootstrap_distance_output_type(rng): def test_bootstrap_distance_pairwise(adata, distance): # Test consistency of pairwise distance results Distance = pt.tl.Distance(distance, obsm_key="X_pca") - bootstrap_output = Distance.pairwise(adata, groupby="perturbation", bootstrap=True, n_bootstrap=3) + bootstrap_output = Distance.pairwise( + adata, groupby="perturbation", bootstrap=True, n_bootstrap=3 + ) assert isinstance(bootstrap_output, tuple) From 25da71a3b2ef7cdaa4a7203b2582ac804e37d446 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 9 Sep 2025 11:49:59 +0000 Subject: [PATCH 03/23] naive implementation --- .../pertpy_gpu/_distances.py | 760 +++--------------- tests/pertpy/test_distances.py | 132 +-- 2 files changed, 191 insertions(+), 701 deletions(-) diff --git a/src/rapids_singlecell/pertpy_gpu/_distances.py b/src/rapids_singlecell/pertpy_gpu/_distances.py index 2f6e80e2..c2db2d20 100644 --- a/src/rapids_singlecell/pertpy_gpu/_distances.py +++ b/src/rapids_singlecell/pertpy_gpu/_distances.py @@ -5,7 +5,12 @@ import numpy as np import pandas as pd - +from pandas import Series +import cupy as cp +from cuml.metrics import pairwise_distances +from cupy.random import choice +from cupyx.scipy.sparse import issparse as cp_issparse +from cuml.preprocessing import MinMaxScaler # 
TODO(selmanozleyen): adapt which progress bar to use, probably use rapidsinglecell's progress bar (if it exists) # from rich.progress import track # from scipy.sparse import issparse @@ -132,7 +137,7 @@ class Distance: def __init__( self, metric: Metric = "edistance", - agg_fct: Callable = np.mean, + agg_fct: Callable = cp.mean, layer_key: str = None, obsm_key: str = None, cell_wise_metric: str = "euclidean", @@ -207,8 +212,8 @@ def __init__( def __call__( self, - X: np.ndarray, - Y: np.ndarray, + X: cp.ndarray | cp.sparse.spmatrix, + Y: cp.ndarray | cp.sparse.spmatrix, **kwargs, ) -> float: """Compute distance between vectors X and Y. @@ -229,9 +234,9 @@ def __call__( >>> Y = adata.obsm["X_pca"][adata.obs["perturbation"] == "control"] >>> D = Distance(X, Y) """ - if issparse(X): + if cp_issparse(X): X = X.toarray() - if issparse(Y): + if cp_issparse(Y): Y = Y.toarray() if len(X) == 0 or len(Y) == 0: @@ -241,8 +246,8 @@ def __call__( def bootstrap( self, - X: np.ndarray, - Y: np.ndarray, + X: cp.ndarray, + Y: cp.ndarray, *, n_bootstrap: int = 100, random_state: int = 0, @@ -372,9 +377,9 @@ def pairwise( else adata.obsm[self.obsm_key].copy() ) for index_x, group_x in enumerate(fct(groups)): - cells_x = embedding[np.asarray(grouping == group_x)].copy() + cells_x = embedding[cp.asarray(grouping == group_x)].copy() for group_y in groups[index_x:]: # type: ignore - cells_y = embedding[np.asarray(grouping == group_y)].copy() + cells_y = embedding[cp.asarray(grouping == group_y)].copy() if not bootstrap: # By distance axiom, the distance between a group and itself is 0 dist = ( @@ -416,6 +421,99 @@ def pairwise( return df, df_var + + def precompute_distances(self, adata: AnnData, n_jobs: int = -1) -> None: + """Precompute pairwise distances between all cells, writes to adata.obsp. + + The precomputed distances are stored in adata.obsp under the key + '{self.obsm_key}_{cell_wise_metric}_predistances', as they depend on + both the cell-wise metric and the embedding used. 
+ + Args: + adata: Annotated data matrix. + n_jobs: Number of cores to use. Defaults to -1 (all). + + Examples: + >>> import pertpy as pt + >>> adata = pt.dt.distance_example() + >>> distance = pt.tools.Distance(metric="edistance") + >>> distance.precompute_distances(adata) + """ + cells = adata.layers[self.layer_key] if self.layer_key else adata.obsm[self.obsm_key].copy() + pwd = pairwise_distances(cells, cells, metric=self.cell_wise_metric, n_jobs=n_jobs) + adata.obsp[f"{self.obsm_key}_{self.cell_wise_metric}_predistances"] = pwd + + def compare_distance( + self, + pert: cp.ndarray, + pred: cp.ndarray, + ctrl: cp.ndarray, + mode: Literal["simple", "scaled"] = "simple", + fit_to_pert_and_ctrl: bool = False, + **kwargs, + ) -> float: + """Compute the score of simulating a perturbation. + + Args: + pert: Real perturbed data. + pred: Simulated perturbed data. + ctrl: Control data + mode: Mode to use. + fit_to_pert_and_ctrl: Scales data based on both `pert` and `ctrl` if True, otherwise only on `ctrl`. + kwargs: Additional keyword arguments passed to the metric function. + """ + if mode == "simple": + pass # nothing to be done + elif mode == "scaled": + + scaler = MinMaxScaler().fit(np.vstack((pert, ctrl)) if fit_to_pert_and_ctrl else ctrl) + pred = scaler.transform(pred) + pert = scaler.transform(pert) + else: + raise ValueError(f"Unknown mode {mode}. 
Please choose simple or scaled.") + + d1 = self.metric_fct(pert, pred, **kwargs) + d2 = self.metric_fct(ctrl, pred, **kwargs) + return d1 / d2 + + def _bootstrap_mode(self, X, Y, n_bootstraps=100, random_state=0, **kwargs) -> MeanVar: + # rng = np.random.default_rng(random_state) + + distances = [] + for _ in range(n_bootstraps): + X_bootstrapped = X[choice(a=X.shape[0], size=X.shape[0], replace=True)] + Y_bootstrapped = Y[choice(a=Y.shape[0], size=X.shape[0], replace=True)] + + distance = self(X_bootstrapped, Y_bootstrapped, **kwargs) + distances.append(distance.get()) + + mean = np.mean(distances) + variance = np.var(distances) + return MeanVar(mean=mean, variance=variance) + + def _bootstrap_mode_precomputed(self, sub_pwd, sub_idx, n_bootstraps=100, random_state=0, **kwargs) -> MeanVar: + rng = np.random.default_rng(random_state) + + distances = [] + for _ in range(n_bootstraps): + # To maintain the number of cells for both groups (whatever balancing they may have), + # we sample the positive and negative indices separately + bootstrap_pos_idx = rng.choice(a=sub_idx[sub_idx].index, size=sub_idx[sub_idx].size, replace=True) + bootstrap_neg_idx = rng.choice(a=sub_idx[~sub_idx].index, size=sub_idx[~sub_idx].size, replace=True) + bootstrap_idx = np.concatenate([bootstrap_pos_idx, bootstrap_neg_idx]) + bootstrap_idx_nrs = sub_idx.index.get_indexer(bootstrap_idx) + + bootstrap_sub_idx = sub_idx[bootstrap_idx] + bootstrap_sub_pwd = sub_pwd[bootstrap_idx_nrs, :][:, bootstrap_idx_nrs] + + distance = self.metric_fct.from_precomputed(bootstrap_sub_pwd, bootstrap_sub_idx, **kwargs) + distances.append(distance.get()) + + mean = np.mean(distances) + variance = np.var(distances) + return MeanVar(mean=mean, variance=variance) + + def onesided_distances( self, adata: AnnData, @@ -556,117 +654,6 @@ def onesided_distances( return df, df_var - def precompute_distances(self, adata: AnnData, n_jobs: int = -1) -> None: - """Precompute pairwise distances between all cells, writes to 
adata.obsp. - - The precomputed distances are stored in adata.obsp under the key - '{self.obsm_key}_{cell_wise_metric}_predistances', as they depend on - both the cell-wise metric and the embedding used. - - Args: - adata: Annotated data matrix. - n_jobs: Number of cores to use. Defaults to -1 (all). - - Examples: - >>> import pertpy as pt - >>> adata = pt.dt.distance_example() - >>> distance = pt.tools.Distance(metric="edistance") - >>> distance.precompute_distances(adata) - """ - cells = ( - adata.layers[self.layer_key] - if self.layer_key - else adata.obsm[self.obsm_key].copy() - ) - pwd = pairwise_distances( - cells, cells, metric=self.cell_wise_metric, n_jobs=n_jobs - ) - adata.obsp[f"{self.obsm_key}_{self.cell_wise_metric}_predistances"] = pwd - - def compare_distance( - self, - pert: np.ndarray, - pred: np.ndarray, - ctrl: np.ndarray, - mode: Literal["simple", "scaled"] = "simple", - fit_to_pert_and_ctrl: bool = False, - **kwargs, - ) -> float: - """Compute the score of simulating a perturbation. - - Args: - pert: Real perturbed data. - pred: Simulated perturbed data. - ctrl: Control data - mode: Mode to use. - fit_to_pert_and_ctrl: Scales data based on both `pert` and `ctrl` if True, otherwise only on `ctrl`. - kwargs: Additional keyword arguments passed to the metric function. - """ - if mode == "simple": - pass # nothing to be done - elif mode == "scaled": - from sklearn.preprocessing import MinMaxScaler - - scaler = MinMaxScaler().fit( - np.vstack((pert, ctrl)) if fit_to_pert_and_ctrl else ctrl - ) - pred = scaler.transform(pred) - pert = scaler.transform(pert) - else: - raise ValueError(f"Unknown mode {mode}. 
Please choose simple or scaled.") - - d1 = self.metric_fct(pert, pred, **kwargs) - d2 = self.metric_fct(ctrl, pred, **kwargs) - return d1 / d2 - - def _bootstrap_mode( - self, X, Y, n_bootstraps=100, random_state=0, **kwargs - ) -> MeanVar: - rng = np.random.default_rng(random_state) - - distances = [] - for _ in range(n_bootstraps): - X_bootstrapped = X[rng.choice(a=X.shape[0], size=X.shape[0], replace=True)] - Y_bootstrapped = Y[rng.choice(a=Y.shape[0], size=X.shape[0], replace=True)] - - distance = self(X_bootstrapped, Y_bootstrapped, **kwargs) - distances.append(distance) - - mean = np.mean(distances) - variance = np.var(distances) - return MeanVar(mean=mean, variance=variance) - - def _bootstrap_mode_precomputed( - self, sub_pwd, sub_idx, n_bootstraps=100, random_state=0, **kwargs - ) -> MeanVar: - rng = np.random.default_rng(random_state) - - distances = [] - for _ in range(n_bootstraps): - # To maintain the number of cells for both groups (whatever balancing they may have), - # we sample the positive and negative indices separately - bootstrap_pos_idx = rng.choice( - a=sub_idx[sub_idx].index, size=sub_idx[sub_idx].size, replace=True - ) - bootstrap_neg_idx = rng.choice( - a=sub_idx[~sub_idx].index, size=sub_idx[~sub_idx].size, replace=True - ) - bootstrap_idx = np.concatenate([bootstrap_pos_idx, bootstrap_neg_idx]) - bootstrap_idx_nrs = sub_idx.index.get_indexer(bootstrap_idx) - - bootstrap_sub_idx = sub_idx[bootstrap_idx] - bootstrap_sub_pwd = sub_pwd[bootstrap_idx_nrs, :][:, bootstrap_idx_nrs] - - distance = self.metric_fct.from_precomputed( - bootstrap_sub_pwd, bootstrap_sub_idx, **kwargs - ) - distances.append(distance) - - mean = np.mean(distances) - variance = np.var(distances) - return MeanVar(mean=mean, variance=variance) - - class AbstractDistance(ABC): """Abstract class of distance metrics between two sets of vectors.""" @@ -676,7 +663,7 @@ def __init__(self) -> None: self.accepts_precomputed: bool = None @abstractmethod - def __call__(self, X: 
np.ndarray, Y: np.ndarray, **kwargs) -> float: + def __call__(self, X: cp.ndarray, Y: cp.ndarray, **kwargs) -> float: """Compute distance between vectors X and Y. Args: @@ -690,7 +677,7 @@ def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: raise NotImplementedError("Metric class is abstract.") @abstractmethod - def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: + def from_precomputed(self, P: cp.ndarray, idx: cp.ndarray, **kwargs) -> float: """Compute a distance between vectors X and Y with precomputed distances. Args: @@ -712,533 +699,14 @@ def __init__(self) -> None: self.accepts_precomputed = True self.cell_wise_metric = "euclidean" - def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: - sigma_X = pairwise_distances( - X, X, metric=self.cell_wise_metric, **kwargs - ).mean() - sigma_Y = pairwise_distances( - Y, Y, metric=self.cell_wise_metric, **kwargs - ).mean() + def __call__(self, X: cp.ndarray, Y: cp.ndarray, **kwargs) -> float: + sigma_X = pairwise_distances(X, X, metric=self.cell_wise_metric, **kwargs).mean() + sigma_Y = pairwise_distances(Y, Y, metric=self.cell_wise_metric, **kwargs).mean() delta = pairwise_distances(X, Y, metric=self.cell_wise_metric, **kwargs).mean() return 2 * delta - sigma_X - sigma_Y - def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: + def from_precomputed(self, P: cp.ndarray, idx: cp.ndarray, **kwargs) -> float: sigma_X = P[idx, :][:, idx].mean() sigma_Y = P[~idx, :][:, ~idx].mean() delta = P[idx, :][:, ~idx].mean() return 2 * delta - sigma_X - sigma_Y - - -# class MMD(AbstractDistance): -# """Linear Maximum Mean Discrepancy.""" - -# # Taken in parts from https://github.com/jindongwang/transferlearning/blob/master/code/distance/mmd_numpy_sklearn.py -# def __init__(self) -> None: -# super().__init__() -# self.accepts_precomputed = False - -# def __call__(self, X: np.ndarray, Y: np.ndarray, *, kernel="linear", gamma=1.0, degree=2, **kwargs) -> 
float: -# if kernel == "linear": -# XX = np.dot(X, X.T) -# YY = np.dot(Y, Y.T) -# XY = np.dot(X, Y.T) -# elif kernel == "rbf": -# XX = rbf_kernel(X, X, gamma=gamma) -# YY = rbf_kernel(Y, Y, gamma=gamma) -# XY = rbf_kernel(X, Y, gamma=gamma) -# elif kernel == "poly": -# XX = polynomial_kernel(X, X, degree=degree, gamma=gamma, coef0=0) -# YY = polynomial_kernel(Y, Y, degree=degree, gamma=gamma, coef0=0) -# XY = polynomial_kernel(X, Y, degree=degree, gamma=gamma, coef0=0) -# else: -# raise ValueError(f"Kernel {kernel} not recognized.") - -# return XX.mean() + YY.mean() - 2 * XY.mean() - -# def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: -# raise NotImplementedError("MMD cannot be called on a pairwise distance matrix.") - - -# class EuclideanDistance(AbstractDistance): -# """Euclidean distance between pseudobulk vectors.""" - -# def __init__(self, aggregation_func: Callable = np.mean) -> None: -# super().__init__() -# self.accepts_precomputed = False -# self.aggregation_func = aggregation_func - -# def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: -# return np.linalg.norm( -# self.aggregation_func(X, axis=0) - self.aggregation_func(Y, axis=0), -# ord=2, -# **kwargs, -# ) - -# def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: -# raise NotImplementedError("EuclideanDistance cannot be called on a pairwise distance matrix.") - - -# class MeanSquaredDistance(AbstractDistance): -# """Mean squared distance between pseudobulk vectors.""" - -# def __init__(self, aggregation_func: Callable = np.mean) -> None: -# super().__init__() -# self.accepts_precomputed = False -# self.aggregation_func = aggregation_func - -# def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: -# return ( -# np.linalg.norm( -# self.aggregation_func(X, axis=0) - self.aggregation_func(Y, axis=0), -# ord=2, -# **kwargs, -# ) -# ** 2 -# / X.shape[1] -# ) - -# def from_precomputed(self, P: np.ndarray, idx: np.ndarray, 
**kwargs) -> float: -# raise NotImplementedError("MeanSquaredDistance cannot be called on a pairwise distance matrix.") - - -# class MeanAbsoluteDistance(AbstractDistance): -# """Absolute (Norm-1) distance between pseudobulk vectors.""" - -# def __init__(self, aggregation_func: Callable = np.mean) -> None: -# super().__init__() -# self.accepts_precomputed = False -# self.aggregation_func = aggregation_func - -# def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: -# return ( -# np.linalg.norm( -# self.aggregation_func(X, axis=0) - self.aggregation_func(Y, axis=0), -# ord=1, -# **kwargs, -# ) -# / X.shape[1] -# ) - -# def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: -# raise NotImplementedError("MeanAbsoluteDistance cannot be called on a pairwise distance matrix.") - - -# class MeanPairwiseDistance(AbstractDistance): -# """Mean of the pairwise euclidean distance between two groups of cells.""" - -# # NOTE: This is not a metric in the mathematical sense. 
- -# def __init__(self) -> None: -# super().__init__() -# self.accepts_precomputed = True - -# def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: -# return pairwise_distances(X, Y, **kwargs).mean() - -# def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: -# return P[idx, :][:, ~idx].mean() - - -# class PearsonDistance(AbstractDistance): -# """Pearson distance between pseudobulk vectors.""" - -# def __init__(self, aggregation_func: Callable = np.mean) -> None: -# super().__init__() -# self.accepts_precomputed = False -# self.aggregation_func = aggregation_func - -# def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: -# return 1 - pearsonr(self.aggregation_func(X, axis=0), self.aggregation_func(Y, axis=0))[0] - -# def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: -# raise NotImplementedError("PearsonDistance cannot be called on a pairwise distance matrix.") - - -# class SpearmanDistance(AbstractDistance): -# """Spearman distance between pseudobulk vectors.""" - -# def __init__(self, aggregation_func: Callable = np.mean) -> None: -# super().__init__() -# self.accepts_precomputed = False -# self.aggregation_func = aggregation_func - -# def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: -# return 1 - spearmanr(self.aggregation_func(X, axis=0), self.aggregation_func(Y, axis=0))[0] - -# def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: -# raise NotImplementedError("SpearmanDistance cannot be called on a pairwise distance matrix.") - - -# class KendallTauDistance(AbstractDistance): -# """Kendall-tau distance between pseudobulk vectors.""" - -# def __init__(self, aggregation_func: Callable = np.mean) -> None: -# super().__init__() -# self.accepts_precomputed = False -# self.aggregation_func = aggregation_func - -# def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: -# x, y = self.aggregation_func(X, axis=0), 
self.aggregation_func(Y, axis=0) -# n = len(x) -# tau_corr = kendalltau(x, y).statistic -# tau_dist = (1 - tau_corr) * n * (n - 1) / 4 -# return tau_dist - -# def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: -# raise NotImplementedError("KendallTauDistance cannot be called on a pairwise distance matrix.") - - -# class CosineDistance(AbstractDistance): -# """Cosine distance between pseudobulk vectors.""" - -# def __init__(self, aggregation_func: Callable = np.mean) -> None: -# super().__init__() -# self.accepts_precomputed = False -# self.aggregation_func = aggregation_func - -# def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: -# return cosine(self.aggregation_func(X, axis=0), self.aggregation_func(Y, axis=0)) - -# def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: -# raise NotImplementedError("CosineDistance cannot be called on a pairwise distance matrix.") - - -# class R2ScoreDistance(AbstractDistance): -# """Coefficient of determination across genes between pseudobulk vectors.""" - -# # NOTE: This is not a distance metric but a similarity metric. - -# def __init__(self, aggregation_func: Callable = np.mean) -> None: -# super().__init__() -# self.accepts_precomputed = False -# self.aggregation_func = aggregation_func - -# def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: -# return 1 - r2_score(self.aggregation_func(X, axis=0), self.aggregation_func(Y, axis=0)) - -# def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: -# raise NotImplementedError("R2ScoreDistance cannot be called on a pairwise distance matrix.") - - -# class SymmetricKLDivergence(AbstractDistance): -# """Average of symmetric KL divergence between gene distributions of two groups. - -# Assuming a Gaussian distribution for each gene in each group, calculates -# the KL divergence between them and averages over all genes. Repeats this ABBA to get a symmetrized distance. 
-# See https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence#Symmetrised_divergence. - -# """ - -# def __init__(self) -> None: -# super().__init__() -# self.accepts_precomputed = False - -# def __call__(self, X: np.ndarray, Y: np.ndarray, epsilon=1e-8, **kwargs) -> float: -# kl_all = [] -# for i in range(X.shape[1]): -# x_mean, x_std = X[:, i].mean(), X[:, i].std() + epsilon -# y_mean, y_std = Y[:, i].mean(), Y[:, i].std() + epsilon -# kl = np.log(y_std / x_std) + (x_std**2 + (x_mean - y_mean) ** 2) / (2 * y_std**2) - 1 / 2 -# klr = np.log(x_std / y_std) + (y_std**2 + (y_mean - x_mean) ** 2) / (2 * x_std**2) - 1 / 2 -# kl_all.append(kl + klr) -# return sum(kl_all) / len(kl_all) - -# def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: -# raise NotImplementedError("SymmetricKLDivergence cannot be called on a pairwise distance matrix.") - - -# class TTestDistance(AbstractDistance): -# """Average of T test statistic between two groups assuming unequal variances.""" - -# def __init__(self) -> None: -# super().__init__() -# self.accepts_precomputed = False - -# def __call__(self, X: np.ndarray, Y: np.ndarray, epsilon=1e-8, **kwargs) -> float: -# t_test_all = [] -# n1 = X.shape[0] -# n2 = Y.shape[0] -# for i in range(X.shape[1]): -# m1, v1 = X[:, i].mean(), X[:, i].std() ** 2 * n1 / (n1 - 1) + epsilon -# m2, v2 = Y[:, i].mean(), Y[:, i].std() ** 2 * n2 / (n2 - 1) + epsilon -# vn1 = v1 / n1 -# vn2 = v2 / n2 -# t = (m1 - m2) / np.sqrt(vn1 + vn2) -# t_test_all.append(abs(t)) -# return sum(t_test_all) / len(t_test_all) - -# def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: -# raise NotImplementedError("TTestDistance cannot be called on a pairwise distance matrix.") - - -# class KSTestDistance(AbstractDistance): -# """Average of two-sided KS test statistic between two groups.""" - -# def __init__(self) -> None: -# super().__init__() -# self.accepts_precomputed = False - -# def __call__(self, X: np.ndarray, Y: 
np.ndarray, **kwargs) -> float: -# stats = [abs(kstest(X[:, i], Y[:, i])[0]) for i in range(X.shape[1])] -# return sum(stats) / len(stats) - -# def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: -# raise NotImplementedError("KSTestDistance cannot be called on a pairwise distance matrix.") - - -# class NBLL(AbstractDistance): -# """Average of Log likelihood (scalar) of group B cells according to a NB distribution fitted over group A.""" - -# def __init__(self) -> None: -# super().__init__() -# self.accepts_precomputed = False - -# def __call__(self, X: np.ndarray, Y: np.ndarray, epsilon=1e-8, **kwargs) -> float: -# def _is_count_matrix(matrix, tolerance=1e-6): -# return bool(matrix.dtype.kind == "i" or np.all(np.abs(matrix - np.round(matrix)) < tolerance)) - -# if not _is_count_matrix(matrix=X) or not _is_count_matrix(matrix=Y): -# raise ValueError("NBLL distance only works for raw counts.") - -# @jit(forceobj=True) -# def _compute_nll(y: np.ndarray, nb_params: tuple[float, float], epsilon: float) -> float: -# mu = np.exp(nb_params[0]) -# theta = 1 / nb_params[1] -# eps = epsilon - -# log_theta_mu_eps = np.log(theta + mu + eps) -# nll = ( -# theta * (np.log(theta + eps) - log_theta_mu_eps) -# + y * (np.log(mu + eps) - log_theta_mu_eps) -# + gammaln(y + theta) -# - gammaln(theta) -# - gammaln(y + 1) -# ) -# return nll.mean() - -# def _process_gene(x: np.ndarray, y: np.ndarray, epsilon: float) -> float: -# try: -# nb_params = NegativeBinomialP(x, np.ones_like(x)).fit(disp=False).params -# return _compute_nll(y, nb_params, epsilon) -# except np.linalg.LinAlgError: -# if x.mean() < 10 and y.mean() < 10: -# return 0.0 -# else: -# return np.nan # Use NaN to indicate skipped genes - -# nlls = [] -# genes_skipped = 0 - -# for i in range(X.shape[1]): -# nll = _process_gene(X[:, i], Y[:, i], epsilon) -# if np.isnan(nll): -# genes_skipped += 1 -# else: -# nlls.append(nll) - -# if genes_skipped > X.shape[1] / 2: -# raise 
AttributeError(f"{genes_skipped} genes could not be fit, which is over half.") - -# return -np.sum(nlls) / len(nlls) - -# def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: -# raise NotImplementedError("NBLL cannot be called on a pairwise distance matrix.") - - -# def _sample(X, frac=None, n=None): -# """Returns subsample of cells in format (train, test).""" -# if frac and n: -# raise ValueError("Cannot pass both frac and n.") -# if frac: -# n_cells = max(1, int(X.shape[0] * frac)) -# elif n: -# n_cells = n -# else: -# raise ValueError("Must pass either `frac` or `n`.") - -# rng = np.random.default_rng() -# sampled_indices = rng.choice(X.shape[0], n_cells, replace=False) -# remaining_indices = np.setdiff1d(np.arange(X.shape[0]), sampled_indices) -# return X[remaining_indices, :], X[sampled_indices, :] - - -# class ClassifierProbaDistance(AbstractDistance): -# """Average of classification probabilites of a binary classifier. - -# Assumes the first condition is control and the second is perturbed. -# Always holds out 20% of the perturbed condition. -# """ - -# def __init__(self) -> None: -# super().__init__() -# self.accepts_precomputed = False - -# def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: -# Y_train, Y_test = _sample(Y, frac=0.2) -# label = ["c"] * X.shape[0] + ["p"] * Y_train.shape[0] -# train = np.concatenate([X, Y_train]) - -# reg = LogisticRegression() -# reg.fit(train, label) -# test_labels = reg.predict_proba(Y_test) -# return np.mean(test_labels[:, 1]) - -# def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: -# raise NotImplementedError("ClassifierProbaDistance cannot be called on a pairwise distance matrix.") - - -# class ClassifierClassProjection(AbstractDistance): -# """Average of 1-(classification probability of control). - -# Warning: unlike all other distances, this must also take a list of categorical labels the same length as X. 
-# """ - -# def __init__(self) -> None: -# super().__init__() -# self.accepts_precomputed = False - -# def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: -# raise NotImplementedError("ClassifierClassProjection can currently only be called with onesided.") - -# def onesided_distances( -# self, -# adata: AnnData, -# groupby: str, -# selected_group: str | None = None, -# groups: list[str] | None = None, -# show_progressbar: bool = True, -# n_jobs: int = -1, -# **kwargs, -# ) -> Series: -# """Unlike the parent function, all groups except the selected group are factored into the classifier. - -# Similar to the parent function, the returned dataframe contains only the specified groups. -# """ -# groups = adata.obs[groupby].unique() if groups is None else groups -# fct = track if show_progressbar else lambda iterable: iterable - -# X = adata[adata.obs[groupby] != selected_group].X -# labels = adata[adata.obs[groupby] != selected_group].obs[groupby].values -# Y = adata[adata.obs[groupby] == selected_group].X - -# reg = LogisticRegression() -# reg.fit(X, labels) -# test_probas = reg.predict_proba(Y) - -# df = pd.Series(index=groups, dtype=float) - -# for group in fct(groups): -# if group == selected_group: -# df.loc[group] = 0 -# else: -# class_idx = list(reg.classes_).index(group) -# df.loc[group] = 1 - np.mean(test_probas[:, class_idx]) -# df.index.name = groupby -# df.name = f"classifier_cp to {selected_group}" - -# return df - -# def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: -# raise NotImplementedError("ClassifierClassProjection cannot be called on a pairwise distance matrix.") - - -# class MeanVarDistributionDistance(AbstractDistance): -# """Distance between mean-var distributions of gene expression.""" - -# def __init__(self) -> None: -# super().__init__() -# self.accepts_precomputed = False - -# @staticmethod -# def _mean_var(x, log: bool = False): -# mean = np.mean(x, axis=0) -# var = np.var(x, axis=0) -# positive 
= mean > 0 -# mean = mean[positive] -# var = var[positive] -# if log: -# mean = np.log(mean) -# var = np.log(var) -# return mean, var - -# @staticmethod -# def _prep_kde_data(x, y): -# return np.concatenate([x.reshape(-1, 1), y.reshape(-1, 1)], axis=1) - -# @staticmethod -# def _grid_points(d, n_points=100): -# # Make grid, add 1 bin on lower/upper end to get final n_points -# d_min = d.min() -# d_max = d.max() -# # Compute bin size -# d_bin = (d_max - d_min) / (n_points - 2) -# d_min = d_min - d_bin -# d_max = d_max + d_bin -# return np.arange(start=d_min + 0.5 * d_bin, stop=d_max, step=d_bin) - -# @staticmethod -# def _kde_eval_both(x_kde, y_kde, grid): -# n_points = len(grid) -# chunk_size = 10000 - -# result_x = np.zeros(n_points) -# result_y = np.zeros(n_points) - -# # Process same chunks for both KDEs -# for start in range(0, n_points, chunk_size): -# end = min(start + chunk_size, n_points) -# chunk = grid[start:end] -# result_x[start:end] = x_kde.score_samples(chunk) -# result_y[start:end] = y_kde.score_samples(chunk) - -# return result_x, result_y - -# def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: -# """Difference of mean-var distributions in 2 matrices. - -# Args: -# X: Normalized and log transformed cells x genes count matrix. -# Y: Normalized and log transformed cells x genes count matrix. -# kwargs: Passed to the metrics function. 
-# """ -# mean_x, var_x = self._mean_var(X, log=True) -# mean_y, var_y = self._mean_var(Y, log=True) - -# x = self._prep_kde_data(mean_x, var_x) -# y = self._prep_kde_data(mean_y, var_y) - -# # Gridpoints to eval KDE on -# mean_grid = self._grid_points(np.concatenate([mean_x, mean_y])) -# var_grid = self._grid_points(np.concatenate([var_x, var_y])) -# grid = np.array(np.meshgrid(mean_grid, var_grid)).T.reshape(-1, 2) - -# # Fit both KDEs first -# x_kde = KernelDensity(bandwidth="silverman", kernel="exponential").fit(x) -# y_kde = KernelDensity(bandwidth="silverman", kernel="exponential").fit(y) - -# # Evaluate both KDEs on same grid chunks -# kde_x, kde_y = self._kde_eval_both(x_kde, y_kde, grid) - -# return ((np.exp(kde_x) - np.exp(kde_y)) ** 2).mean() - -# def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: -# raise NotImplementedError("MeanVarDistributionDistance cannot be called on a pairwise distance matrix.") - - -# class MahalanobisDistance(AbstractDistance): -# """Mahalanobis distance between pseudobulk vectors.""" - -# def __init__(self, aggregation_func: Callable = np.mean) -> None: -# super().__init__() -# self.accepts_precomputed = False -# self.aggregation_func = aggregation_func - -# def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: -# return mahalanobis( -# self.aggregation_func(X, axis=0), -# self.aggregation_func(Y, axis=0), -# np.linalg.inv(np.cov(X.T)), -# ) - -# def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: -# raise NotImplementedError("Mahalanobis cannot be called on a pairwise distance matrix.") diff --git a/tests/pertpy/test_distances.py b/tests/pertpy/test_distances.py index fdd51d07..1f31e38e 100644 --- a/tests/pertpy/test_distances.py +++ b/tests/pertpy/test_distances.py @@ -1,19 +1,25 @@ from __future__ import annotations import numpy as np -import pertpy as pt +from pertpy import data as dt import pytest import scanpy as sc from pandas import DataFrame, Series 
from pytest import fixture, mark +import cupy as cp +from rapids_singlecell.pertpy_gpu._distances import Distance +from scipy import sparse as sp +@pytest.fixture +def cp_rng(): #TODO(selmanozleyen): Think of a way to integrate this with decoupler's rng fixture + rng = cp.random.default_rng(seed=42) + return rng -@pytest.fixture -def rng(): # TODO(selmanozleyen): Think of a way to integrate this with decoupler's rng fixture +@pytest.fixture +def np_rng(): #TODO(selmanozleyen): Think of a way to integrate this with decoupler's rng fixture rng = np.random.default_rng(seed=42) return rng - actual_distances = [ # Euclidean distances and related # "euclidean", @@ -60,7 +66,7 @@ def adata(request): distance = request.node.callspec.params["distance"] - adata = pt.dt.distance_example() + adata = dt.distance_example() if distance not in no_subsample_distances: if distance in low_subsample_distances: adata = sc.pp.subsample(adata, 0.1, copy=True) @@ -72,13 +78,29 @@ def adata(request): ].copy() adata.layers["lognorm"] = adata.X.copy() - adata.layers["counts"] = np.round(adata.X.toarray()).astype(int) + adata.layers["counts"] = cp.round(adata.X.toarray()).astype(int) if "X_pca" not in adata.obsm: sc.pp.pca(adata, n_comps=5) if distance in lognorm_counts_distances: groups = np.unique(adata.obs["perturbation"]) # KDE is slow, subset to 3 groups for speed up adata = adata[adata.obs["perturbation"].isin(groups[0:3])].copy() + + adata.X = cp.asarray(adata.X.toarray()) + for l_key in adata.layers.keys(): + if sp.issparse(adata.layers[l_key]): + from cupyx.scipy.sparse import csr_matrix, csc_matrix , coo_matrix + if sp.isspmatrix_csr(adata.layers[l_key]): + adata.layers[l_key] = csr_matrix(adata.layers[l_key]) + elif sp.isspmatrix_csc(adata.layers[l_key]): + adata.layers[l_key] = csc_matrix(adata.layers[l_key]) + elif sp.isspmatrix_coo(adata.layers[l_key]): + adata.layers[l_key] = coo_matrix(adata.layers[l_key]) + else: + adata.layers[l_key] = cp.asarray(adata.layers[l_key]) + 
adata.layers["lognorm"] = cp.asarray(adata.layers["lognorm"].toarray()) + adata.layers["counts"] = cp.asarray(adata.layers["counts"]) + adata.obsm["X_pca"] = cp.asarray(adata.obsm["X_pca"]) return adata @@ -87,18 +109,18 @@ def adata(request): def distance_obj(request): distance = request.node.callspec.params["distance"] if distance in lognorm_counts_distances: - Distance = pt.tl.Distance(distance, layer_key="lognorm") + d = Distance(distance, layer_key="lognorm") elif distance in pseudo_counts_distances: - Distance = pt.tl.Distance(distance, layer_key="counts") + d = Distance(distance, layer_key="counts") else: - Distance = pt.tl.Distance(distance, obsm_key="X_pca") - return Distance + d = Distance(distance, obsm_key="X_pca") + return d @fixture @mark.parametrize("distance", all_distances) def pairwise_distance(adata, distance_obj, distance): - return distance_obj.pairwise(adata, groupby="perturbation", show_progressbar=True) + return distance_obj.pairwise(adata, groupby="perturbation", show_progressbar=False) @mark.parametrize("distance", actual_distances + semi_distances) @@ -118,7 +140,7 @@ def test_distance_axioms(pairwise_distance, distance): @mark.parametrize("distance", actual_distances) -def test_triangle_inequality(pairwise_distance, distance, rng): +def test_triangle_inequality(pairwise_distance, distance, np_rng): # Test if distances are well-defined in accordance with metric axioms # (M4) Triangle inequality (we just probe this for a few random triplets) # Some tests are not well defined for the triangle inequality. We skip those. 
@@ -126,7 +148,8 @@ def test_triangle_inequality(pairwise_distance, distance, rng): return for _ in range(5): - triplet = rng.choice(pairwise_distance.index, size=3, replace=False) + from cupy.random import choice + triplet = np_rng.choice(pairwise_distance.index, size=3, replace=False) assert ( pairwise_distance.loc[triplet[0], triplet[1]] + pairwise_distance.loc[triplet[1], triplet[2]] @@ -158,16 +181,17 @@ def test_distance_counts(adata, distance): @mark.parametrize("distance", all_distances) def test_mutually_exclusive_keys(distance): with pytest.raises(ValueError): - _ = pt.tl.Distance(distance, layer_key="counts", obsm_key="X_pca") + _ = Distance(distance, layer_key="counts", obsm_key="X_pca") @mark.parametrize("distance", actual_distances + semi_distances + non_distances) -def test_distance_output_type(distance, rng): +def test_distance_output_type(distance, cp_rng): # Test if distances are outputting floats - Distance = pt.tl.Distance(distance) - X = rng.normal(size=(50, 10)) - Y = rng.normal(size=(50, 10)) - d = Distance(X, Y) + dist = Distance(distance) + X = cp_rng.standard_normal(size=(50, 10)) + Y = cp_rng.standard_normal(size=(50, 10)) + d = dist(X, Y) + d = float(d.get()) assert isinstance(d, float) @@ -182,12 +206,12 @@ def test_distance_onesided(adata, distance_obj, distance): assert df.loc[selected_group] == 0 # distance to self is 0 -def test_bootstrap_distance_output_type(rng): +def test_bootstrap_distance_output_type(cp_rng): # Test if distances are outputting floats - Distance = pt.tl.Distance(metric="edistance") - X = rng.normal(size=(50, 10)) - Y = rng.normal(size=(50, 10)) - d = Distance.bootstrap(X, Y, n_bootstrap=3) + d = Distance(metric="edistance") + X = cp_rng.standard_normal(size=(50, 10)) + Y = cp_rng.standard_normal(size=(50, 10)) + d = d.bootstrap(X, Y, n_bootstrap=3) assert hasattr(d, "mean") assert hasattr(d, "variance") @@ -195,10 +219,8 @@ def test_bootstrap_distance_output_type(rng): @mark.parametrize("distance", 
["edistance"]) def test_bootstrap_distance_pairwise(adata, distance): # Test consistency of pairwise distance results - Distance = pt.tl.Distance(distance, obsm_key="X_pca") - bootstrap_output = Distance.pairwise( - adata, groupby="perturbation", bootstrap=True, n_bootstrap=3 - ) + dist = Distance(distance, obsm_key="X_pca") + bootstrap_output = dist.pairwise(adata, groupby="perturbation", bootstrap=True, n_bootstrap=3) assert isinstance(bootstrap_output, tuple) @@ -210,30 +232,30 @@ def test_bootstrap_distance_pairwise(adata, distance): assert np.sum(var.values - var.values.T) == 0 # symmetry -@mark.parametrize("distance", ["edistance"]) -def test_bootstrap_distance_onesided(adata, distance): - # Test consistency of one-sided distance results - selected_group = adata.obs.perturbation.unique()[0] - Distance = pt.tl.Distance(distance, obsm_key="X_pca") - bootstrap_output = Distance.onesided_distances( - adata, - groupby="perturbation", - selected_group=selected_group, - bootstrap=True, - n_bootstrap=3, - ) - - assert isinstance(bootstrap_output, tuple) - - -def test_compare_distance(rng): - X = rng.normal(size=(50, 10)) - Y = rng.normal(size=(50, 10)) - C = rng.normal(size=(50, 10)) - Distance = pt.tl.Distance() - res_simple = Distance.compare_distance(X, Y, C, mode="simple") - assert isinstance(res_simple, float) - res_scaled = Distance.compare_distance(X, Y, C, mode="scaled") - assert isinstance(res_scaled, float) - with pytest.raises(ValueError): - Distance.compare_distance(X, Y, C, mode="new_mode") +# @mark.parametrize("distance", ["edistance"]) +# def test_bootstrap_distance_onesided(adata, distance): +# # Test consistency of one-sided distance results +# selected_group = adata.obs.perturbation.unique()[0] +# d = Distance(distance, obsm_key="X_pca") +# bootstrap_output = d.onesided_distances( +# adata, +# groupby="perturbation", +# selected_group=selected_group, +# bootstrap=True, +# n_bootstrap=3, +# ) + +# assert isinstance(bootstrap_output, tuple) + + +# def 
test_compare_distance(rng): +# X = rng.standard_normal(size=(50, 10)) +# Y = rng.standard_normal(size=(50, 10)) +# C = rng.standard_normal(size=(50, 10)) +# d = Distance() +# res_simple = d.compare_distance(X, Y, C, mode="simple") +# assert isinstance(res_simple.get(), float) +# res_scaled = d.compare_distance(X, Y, C, mode="scaled") +# assert isinstance(res_scaled.get(), float) +# with pytest.raises(ValueError): +# d.compare_distance(X, Y, C, mode="new_mode") From 5115d65038cc054d9a7a9446026dfe96ed1b7886 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 9 Sep 2025 11:55:23 +0000 Subject: [PATCH 04/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pertpy_gpu/_distances.py | 52 +++++++++++++------ tests/pertpy/test_distances.py | 28 ++++++---- 2 files changed, 53 insertions(+), 27 deletions(-) diff --git a/src/rapids_singlecell/pertpy_gpu/_distances.py b/src/rapids_singlecell/pertpy_gpu/_distances.py index c2db2d20..b1b6375b 100644 --- a/src/rapids_singlecell/pertpy_gpu/_distances.py +++ b/src/rapids_singlecell/pertpy_gpu/_distances.py @@ -3,14 +3,14 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Literal, NamedTuple +import cupy as cp import numpy as np import pandas as pd -from pandas import Series -import cupy as cp from cuml.metrics import pairwise_distances +from cuml.preprocessing import MinMaxScaler from cupy.random import choice from cupyx.scipy.sparse import issparse as cp_issparse -from cuml.preprocessing import MinMaxScaler + # TODO(selmanozleyen): adapt which progress bar to use, probably use rapidsinglecell's progress bar (if it exists) # from rich.progress import track # from scipy.sparse import issparse @@ -421,7 +421,6 @@ def pairwise( return df, df_var - def precompute_distances(self, adata: AnnData, n_jobs: int = -1) -> None: """Precompute pairwise distances between all cells, writes to 
adata.obsp. @@ -439,8 +438,14 @@ def precompute_distances(self, adata: AnnData, n_jobs: int = -1) -> None: >>> distance = pt.tools.Distance(metric="edistance") >>> distance.precompute_distances(adata) """ - cells = adata.layers[self.layer_key] if self.layer_key else adata.obsm[self.obsm_key].copy() - pwd = pairwise_distances(cells, cells, metric=self.cell_wise_metric, n_jobs=n_jobs) + cells = ( + adata.layers[self.layer_key] + if self.layer_key + else adata.obsm[self.obsm_key].copy() + ) + pwd = pairwise_distances( + cells, cells, metric=self.cell_wise_metric, n_jobs=n_jobs + ) adata.obsp[f"{self.obsm_key}_{self.cell_wise_metric}_predistances"] = pwd def compare_distance( @@ -465,8 +470,9 @@ def compare_distance( if mode == "simple": pass # nothing to be done elif mode == "scaled": - - scaler = MinMaxScaler().fit(np.vstack((pert, ctrl)) if fit_to_pert_and_ctrl else ctrl) + scaler = MinMaxScaler().fit( + np.vstack((pert, ctrl)) if fit_to_pert_and_ctrl else ctrl + ) pred = scaler.transform(pred) pert = scaler.transform(pert) else: @@ -476,7 +482,9 @@ def compare_distance( d2 = self.metric_fct(ctrl, pred, **kwargs) return d1 / d2 - def _bootstrap_mode(self, X, Y, n_bootstraps=100, random_state=0, **kwargs) -> MeanVar: + def _bootstrap_mode( + self, X, Y, n_bootstraps=100, random_state=0, **kwargs + ) -> MeanVar: # rng = np.random.default_rng(random_state) distances = [] @@ -491,29 +499,36 @@ def _bootstrap_mode(self, X, Y, n_bootstraps=100, random_state=0, **kwargs) -> M variance = np.var(distances) return MeanVar(mean=mean, variance=variance) - def _bootstrap_mode_precomputed(self, sub_pwd, sub_idx, n_bootstraps=100, random_state=0, **kwargs) -> MeanVar: + def _bootstrap_mode_precomputed( + self, sub_pwd, sub_idx, n_bootstraps=100, random_state=0, **kwargs + ) -> MeanVar: rng = np.random.default_rng(random_state) distances = [] for _ in range(n_bootstraps): # To maintain the number of cells for both groups (whatever balancing they may have), # we sample the positive 
and negative indices separately - bootstrap_pos_idx = rng.choice(a=sub_idx[sub_idx].index, size=sub_idx[sub_idx].size, replace=True) - bootstrap_neg_idx = rng.choice(a=sub_idx[~sub_idx].index, size=sub_idx[~sub_idx].size, replace=True) + bootstrap_pos_idx = rng.choice( + a=sub_idx[sub_idx].index, size=sub_idx[sub_idx].size, replace=True + ) + bootstrap_neg_idx = rng.choice( + a=sub_idx[~sub_idx].index, size=sub_idx[~sub_idx].size, replace=True + ) bootstrap_idx = np.concatenate([bootstrap_pos_idx, bootstrap_neg_idx]) bootstrap_idx_nrs = sub_idx.index.get_indexer(bootstrap_idx) bootstrap_sub_idx = sub_idx[bootstrap_idx] bootstrap_sub_pwd = sub_pwd[bootstrap_idx_nrs, :][:, bootstrap_idx_nrs] - distance = self.metric_fct.from_precomputed(bootstrap_sub_pwd, bootstrap_sub_idx, **kwargs) + distance = self.metric_fct.from_precomputed( + bootstrap_sub_pwd, bootstrap_sub_idx, **kwargs + ) distances.append(distance.get()) mean = np.mean(distances) variance = np.var(distances) return MeanVar(mean=mean, variance=variance) - def onesided_distances( self, adata: AnnData, @@ -654,6 +669,7 @@ def onesided_distances( return df, df_var + class AbstractDistance(ABC): """Abstract class of distance metrics between two sets of vectors.""" @@ -700,8 +716,12 @@ def __init__(self) -> None: self.cell_wise_metric = "euclidean" def __call__(self, X: cp.ndarray, Y: cp.ndarray, **kwargs) -> float: - sigma_X = pairwise_distances(X, X, metric=self.cell_wise_metric, **kwargs).mean() - sigma_Y = pairwise_distances(Y, Y, metric=self.cell_wise_metric, **kwargs).mean() + sigma_X = pairwise_distances( + X, X, metric=self.cell_wise_metric, **kwargs + ).mean() + sigma_Y = pairwise_distances( + Y, Y, metric=self.cell_wise_metric, **kwargs + ).mean() delta = pairwise_distances(X, Y, metric=self.cell_wise_metric, **kwargs).mean() return 2 * delta - sigma_X - sigma_Y diff --git a/tests/pertpy/test_distances.py b/tests/pertpy/test_distances.py index 1f31e38e..70111989 100644 --- 
a/tests/pertpy/test_distances.py +++ b/tests/pertpy/test_distances.py @@ -1,25 +1,29 @@ from __future__ import annotations +import cupy as cp import numpy as np -from pertpy import data as dt import pytest import scanpy as sc from pandas import DataFrame, Series +from pertpy import data as dt from pytest import fixture, mark -import cupy as cp -from rapids_singlecell.pertpy_gpu._distances import Distance from scipy import sparse as sp -@pytest.fixture -def cp_rng(): #TODO(selmanozleyen): Think of a way to integrate this with decoupler's rng fixture + +from rapids_singlecell.pertpy_gpu._distances import Distance + + +@pytest.fixture +def cp_rng(): # TODO(selmanozleyen): Think of a way to integrate this with decoupler's rng fixture rng = cp.random.default_rng(seed=42) return rng -@pytest.fixture -def np_rng(): #TODO(selmanozleyen): Think of a way to integrate this with decoupler's rng fixture +@pytest.fixture +def np_rng(): # TODO(selmanozleyen): Think of a way to integrate this with decoupler's rng fixture rng = np.random.default_rng(seed=42) return rng + actual_distances = [ # Euclidean distances and related # "euclidean", @@ -85,11 +89,12 @@ def adata(request): groups = np.unique(adata.obs["perturbation"]) # KDE is slow, subset to 3 groups for speed up adata = adata[adata.obs["perturbation"].isin(groups[0:3])].copy() - + adata.X = cp.asarray(adata.X.toarray()) for l_key in adata.layers.keys(): if sp.issparse(adata.layers[l_key]): - from cupyx.scipy.sparse import csr_matrix, csc_matrix , coo_matrix + from cupyx.scipy.sparse import coo_matrix, csc_matrix, csr_matrix + if sp.isspmatrix_csr(adata.layers[l_key]): adata.layers[l_key] = csr_matrix(adata.layers[l_key]) elif sp.isspmatrix_csc(adata.layers[l_key]): @@ -148,7 +153,6 @@ def test_triangle_inequality(pairwise_distance, distance, np_rng): return for _ in range(5): - from cupy.random import choice triplet = np_rng.choice(pairwise_distance.index, size=3, replace=False) assert ( pairwise_distance.loc[triplet[0], 
triplet[1]] @@ -220,7 +224,9 @@ def test_bootstrap_distance_output_type(cp_rng): def test_bootstrap_distance_pairwise(adata, distance): # Test consistency of pairwise distance results dist = Distance(distance, obsm_key="X_pca") - bootstrap_output = dist.pairwise(adata, groupby="perturbation", bootstrap=True, n_bootstrap=3) + bootstrap_output = dist.pairwise( + adata, groupby="perturbation", bootstrap=True, n_bootstrap=3 + ) assert isinstance(bootstrap_output, tuple) From 64c166db71087beef3b31cf17aedab747caa49a9 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 9 Sep 2025 11:56:03 +0000 Subject: [PATCH 05/23] error from merge --- tests/pertpy/test_distances.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pertpy/test_distances.py b/tests/pertpy/test_distances.py index 70111989..35b4a9fb 100644 --- a/tests/pertpy/test_distances.py +++ b/tests/pertpy/test_distances.py @@ -175,7 +175,7 @@ def test_distance_counts(adata, distance): if ( distance != "mahalanobis" ): # skip, doesn't work because covariance matrix is a singular matrix, not invertible - distance = pt.tl.Distance(distance, layer_key="counts") + distance = Distance(distance, layer_key="counts") df = distance.pairwise(adata, groupby="perturbation") assert isinstance(df, DataFrame) assert df.columns.equals(df.index) From 20cdf589c2e6fcba13012df177478b66088e09e8 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 9 Sep 2025 12:04:47 +0000 Subject: [PATCH 06/23] integrate more tests --- tests/pertpy/test_distances.py | 56 ++++++++++++++++++---------------- 1 file changed, 29 insertions(+), 27 deletions(-) diff --git a/tests/pertpy/test_distances.py b/tests/pertpy/test_distances.py index 35b4a9fb..e7f00c62 100644 --- a/tests/pertpy/test_distances.py +++ b/tests/pertpy/test_distances.py @@ -238,30 +238,32 @@ def test_bootstrap_distance_pairwise(adata, distance): assert np.sum(var.values - var.values.T) == 0 # symmetry -# @mark.parametrize("distance", ["edistance"]) -# def 
test_bootstrap_distance_onesided(adata, distance): -# # Test consistency of one-sided distance results -# selected_group = adata.obs.perturbation.unique()[0] -# d = Distance(distance, obsm_key="X_pca") -# bootstrap_output = d.onesided_distances( -# adata, -# groupby="perturbation", -# selected_group=selected_group, -# bootstrap=True, -# n_bootstrap=3, -# ) - -# assert isinstance(bootstrap_output, tuple) - - -# def test_compare_distance(rng): -# X = rng.standard_normal(size=(50, 10)) -# Y = rng.standard_normal(size=(50, 10)) -# C = rng.standard_normal(size=(50, 10)) -# d = Distance() -# res_simple = d.compare_distance(X, Y, C, mode="simple") -# assert isinstance(res_simple.get(), float) -# res_scaled = d.compare_distance(X, Y, C, mode="scaled") -# assert isinstance(res_scaled.get(), float) -# with pytest.raises(ValueError): -# d.compare_distance(X, Y, C, mode="new_mode") +@mark.parametrize("distance", ["edistance"]) +def test_bootstrap_distance_onesided(adata, distance): + # Test consistency of one-sided distance results + selected_group = adata.obs.perturbation.unique()[0] + d = Distance(distance, obsm_key="X_pca") + bootstrap_output = d.onesided_distances( + adata, + groupby="perturbation", + selected_group=selected_group, + bootstrap=True, + n_bootstrap=3, + ) + + assert isinstance(bootstrap_output, tuple) + + +def test_compare_distance(cp_rng): + X = cp_rng.standard_normal(size=(50, 10)) + Y = cp_rng.standard_normal(size=(50, 10)) + C = cp_rng.standard_normal(size=(50, 10)) + d = Distance() + res_simple = d.compare_distance(X, Y, C, mode="simple") + res_simple = float(res_simple.get()) + assert isinstance(res_simple, float) + res_scaled = d.compare_distance(X, Y, C, mode="scaled") + res_scaled = float(res_scaled.get()) + assert isinstance(res_scaled, float) + with pytest.raises(ValueError): + d.compare_distance(X, Y, C, mode="new_mode") From ee1ecfc945ec3f8dcf0c4e4faa483ea0b7aebc2a Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 9 Sep 2025 14:56:19 +0000 
Subject: [PATCH 07/23] save state with temporary scripts. now there is a reference code to improve --- src/rapids_singlecell/pertpy_gpu/__init__.py | 3 + src/rapids_singlecell/ptg.py | 3 + tmp_scripts/a.ipynb | 1259 ++++++++++++++++++ tmp_scripts/compute_edistance.py | 35 + tmp_scripts/compute_edistance_cpu.py | 21 + tmp_scripts/prepare_data.py | 26 + 6 files changed, 1347 insertions(+) create mode 100644 src/rapids_singlecell/pertpy_gpu/__init__.py create mode 100644 src/rapids_singlecell/ptg.py create mode 100644 tmp_scripts/a.ipynb create mode 100644 tmp_scripts/compute_edistance.py create mode 100644 tmp_scripts/compute_edistance_cpu.py create mode 100644 tmp_scripts/prepare_data.py diff --git a/src/rapids_singlecell/pertpy_gpu/__init__.py b/src/rapids_singlecell/pertpy_gpu/__init__.py new file mode 100644 index 00000000..6324ee3c --- /dev/null +++ b/src/rapids_singlecell/pertpy_gpu/__init__.py @@ -0,0 +1,3 @@ +from __future__ import annotations + +from ._distances import Distance diff --git a/src/rapids_singlecell/ptg.py b/src/rapids_singlecell/ptg.py new file mode 100644 index 00000000..2358f441 --- /dev/null +++ b/src/rapids_singlecell/ptg.py @@ -0,0 +1,3 @@ +from __future__ import annotations + +from .pertpy_gpu import * diff --git a/tmp_scripts/a.ipynb b/tmp_scripts/a.ipynb new file mode 100644 index 00000000..a1cb90b9 --- /dev/null +++ b/tmp_scripts/a.ipynb @@ -0,0 +1,1259 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 15, + "id": "e1d4de19", + "metadata": {}, + "outputs": [], + "source": [ + "from __future__ import annotations\n", + "\n", + "import os\n", + "\n", + "import anndata as ad\n", + "import cupy as cp\n", + "import pertpy as pt\n", + "import rmm\n", + "from rmm.allocators.cupy import rmm_cupy_allocator\n", + "\n", + "import rapids_singlecell as rsc\n", + "from rapids_singlecell.ptg import Distance\n", + "\n", + "rmm.reinitialize(\n", + " managed_memory=False, # Allows oversubscription\n", + " pool_allocator=True, # default 
is False\n", + " devices=0, # GPU device IDs to register. By default registers only GPU 0.\n", + ")\n", + "cp.cuda.set_allocator(rmm_cupy_allocator)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "e9c83bff", + "metadata": {}, + "outputs": [], + "source": [ + "save_dir = os.path.join(\n", + " os.path.expanduser(\"~\"), \"data\", \"adamson_2016_upr_epistasis_pca.h5ad\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "ea369599", + "metadata": {}, + "outputs": [], + "source": [ + "adata = ad.read_h5ad(save_dir)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "df4921a8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['3x_neg_ctrl_pMJ144-1',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_only_pMJ145',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_IRE1_pMJ152',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_IRE1_pMJ152',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'IRE1_only_pMJ148',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_IRE1_pMJ152',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'PERK_only_pMJ146',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'IRE1_only_pMJ148',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'PERK_only_pMJ146',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'IRE1_only_pMJ148',\n", + " 'PERK_only_pMJ146',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'IRE1_only_pMJ148',\n", 
+ " 'ATF6_only_pMJ145',\n", + " 'PERK_only_pMJ146',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'ATF6_only_pMJ145',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_PERK_pMJ150',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_only_pMJ145',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'PERK_only_pMJ146',\n", + " 'PERK_only_pMJ146',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'PERK_only_pMJ146',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'IRE1_only_pMJ148',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_IRE1_pMJ152',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'IRE1_only_pMJ148',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_PERK_pMJ150',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'IRE1_only_pMJ148',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_only_pMJ145',\n", + " 'IRE1_only_pMJ148',\n", + " '*',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'PERK_only_pMJ146',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_IRE1_pMJ152',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'IRE1_only_pMJ148',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'ATF6_only_pMJ145',\n", + " 
'3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_PERK_pMJ150',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'IRE1_only_pMJ148',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'PERK_only_pMJ146',\n", + " 'PERK_only_pMJ146',\n", + " nan,\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'PERK_only_pMJ146',\n", + " 'IRE1_only_pMJ148',\n", + " 'IRE1_only_pMJ148',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'PERK_only_pMJ146',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'IRE1_only_pMJ148',\n", + " 'PERK_only_pMJ146',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'PERK_only_pMJ146',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'PERK_only_pMJ146',\n", + " 'ATF6_only_pMJ145',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'IRE1_only_pMJ148',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'PERK_IRE1_pMJ154',\n", + " nan,\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'PERK_only_pMJ146',\n", + " 'PERK_only_pMJ146',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'PERK_IRE1_pMJ154',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'IRE1_only_pMJ148',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " nan,\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'PERK_IRE1_pMJ154',\n", + " nan,\n", + " 'ATF6_only_pMJ145',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'IRE1_only_pMJ148',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_only_pMJ145',\n", + " 
'PERK_only_pMJ146',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'PERK_IRE1_pMJ154',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_only_pMJ145',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'IRE1_only_pMJ148',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'IRE1_only_pMJ148',\n", + " 'ATF6_only_pMJ145',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_only_pMJ145',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'PERK_only_pMJ146',\n", + " 'PERK_only_pMJ146',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'PERK_only_pMJ146',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_IRE1_pMJ152',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_only_pMJ145',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'PERK_only_pMJ146',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'PERK_only_pMJ146',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'IRE1_only_pMJ148',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'IRE1_only_pMJ148',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'IRE1_only_pMJ148',\n", + " 'IRE1_only_pMJ148',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_only_pMJ145',\n", + " 'IRE1_only_pMJ148',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'PERK_only_pMJ146',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'PERK_only_pMJ146',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_only_pMJ145',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'PERK_IRE1_pMJ154',\n", + " 
'IRE1_only_pMJ148',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'IRE1_only_pMJ148',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'PERK_only_pMJ146',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'PERK_only_pMJ146',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'ATF6_PERK_pMJ150',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'PERK_only_pMJ146',\n", + " 'ATF6_only_pMJ145',\n", + " 'PERK_only_pMJ146',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'ATF6_PERK_pMJ150',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'IRE1_only_pMJ148',\n", + " 'PERK_only_pMJ146',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_only_pMJ145',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'IRE1_only_pMJ148',\n", + " 'PERK_only_pMJ146',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'IRE1_only_pMJ148',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'ATF6_only_pMJ145',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_PERK_pMJ150',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_only_pMJ145',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'IRE1_only_pMJ148',\n", + " 'ATF6_IRE1_pMJ152',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'IRE1_only_pMJ148',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'PERK_only_pMJ146',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'PERK_IRE1_pMJ154',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'PERK_only_pMJ146',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_IRE1_pMJ152',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'PERK_IRE1_pMJ154',\n", + " 
'PERK_only_pMJ146',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'PERK_only_pMJ146',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'IRE1_only_pMJ148',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_only_pMJ145',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'IRE1_only_pMJ148',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_only_pMJ145',\n", + " 'PERK_only_pMJ146',\n", + " 'PERK_IRE1_pMJ154',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_PERK_pMJ150',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_only_pMJ145',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_only_pMJ145',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'ATF6_only_pMJ145',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_only_pMJ145',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'ATF6_only_pMJ145',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " nan,\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'PERK_only_pMJ146',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_only_pMJ145',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_only_pMJ145',\n", + " 'IRE1_only_pMJ148',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " nan,\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'PERK_only_pMJ146',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'PERK_only_pMJ146',\n", + " 
'ATF6_PERK_pMJ150',\n", + " 'ATF6_PERK_pMJ150',\n", + " nan,\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'ATF6_PERK_pMJ150',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'PERK_IRE1_pMJ154',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_only_pMJ145',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'PERK_IRE1_pMJ154',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'IRE1_only_pMJ148',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'ATF6_only_pMJ145',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'PERK_IRE1_pMJ154',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_IRE1_pMJ152',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'ATF6_PERK_pMJ150',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'PERK_only_pMJ146',\n", + " 'IRE1_only_pMJ148',\n", + " 'IRE1_only_pMJ148',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'IRE1_only_pMJ148',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_IRE1_pMJ152',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'IRE1_only_pMJ148',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'PERK_only_pMJ146',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'PERK_only_pMJ146',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'IRE1_only_pMJ148',\n", + " 'IRE1_only_pMJ148',\n", + " 'PERK_IRE1_pMJ154',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'PERK_only_pMJ146',\n", + " 'PERK_only_pMJ146',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 
'IRE1_only_pMJ148',\n", + " 'PERK_only_pMJ146',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'PERK_IRE1_pMJ154',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_PERK_pMJ150',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'PERK_IRE1_pMJ154',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'PERK_only_pMJ146',\n", + " 'IRE1_only_pMJ148',\n", + " 'PERK_IRE1_pMJ154',\n", + " nan,\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_PERK_pMJ150',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'PERK_only_pMJ146',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_IRE1_pMJ152',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_PERK_pMJ150',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_only_pMJ145',\n", + " nan,\n", + " 'PERK_only_pMJ146',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'ATF6_PERK_pMJ150',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_PERK_pMJ150',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_only_pMJ145',\n", + " 'PERK_only_pMJ146',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'IRE1_only_pMJ148',\n", + " 'PERK_IRE1_pMJ154',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_only_pMJ145',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_PERK_pMJ150',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'PERK_only_pMJ146',\n", + " 'ATF6_only_pMJ145',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'PERK_only_pMJ146',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'IRE1_only_pMJ148',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'ATF6_PERK_pMJ150',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 
'ATF6_PERK_pMJ150',\n", + " 'PERK_IRE1_pMJ154',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'IRE1_only_pMJ148',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'IRE1_only_pMJ148',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'ATF6_IRE1_pMJ152',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'PERK_only_pMJ146',\n", + " 'PERK_only_pMJ146',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'PERK_only_pMJ146',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'IRE1_only_pMJ148',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'PERK_IRE1_pMJ154',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'PERK_IRE1_pMJ154',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_only_pMJ145',\n", + " 'PERK_only_pMJ146',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'IRE1_only_pMJ148',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'IRE1_only_pMJ148',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'PERK_only_pMJ146',\n", + " 'ATF6_only_pMJ145',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_only_pMJ145',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'ATF6_only_pMJ145',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'PERK_only_pMJ146',\n", + " 'PERK_only_pMJ146',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'IRE1_only_pMJ148',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'PERK_only_pMJ146',\n", + " 'PERK_IRE1_pMJ154',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'PERK_IRE1_pMJ154',\n", 
+ " '3x_neg_ctrl_pMJ144-1',\n", + " 'PERK_only_pMJ146',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'IRE1_only_pMJ148',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_only_pMJ145',\n", + " 'PERK_only_pMJ146',\n", + " 'PERK_only_pMJ146',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_only_pMJ145',\n", + " 'IRE1_only_pMJ148',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'ATF6_PERK_pMJ150',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'IRE1_only_pMJ148',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'PERK_only_pMJ146',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'IRE1_only_pMJ148',\n", + " 'IRE1_only_pMJ148',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_only_pMJ145',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'PERK_only_pMJ146',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'PERK_only_pMJ146',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_only_pMJ145',\n", + " 'IRE1_only_pMJ148',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'ATF6_PERK_pMJ150',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'ATF6_PERK_pMJ150',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'IRE1_only_pMJ148',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'PERK_only_pMJ146',\n", + " 'PERK_only_pMJ146',\n", + " 'IRE1_only_pMJ148',\n", + " 'PERK_only_pMJ146',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'PERK_only_pMJ146',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'PERK_only_pMJ146',\n", 
+ " 'ATF6_PERK_IRE1_pMJ158',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'IRE1_only_pMJ148',\n", + " 'ATF6_PERK_pMJ150',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'PERK_IRE1_pMJ154',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'PERK_only_pMJ146',\n", + " 'IRE1_only_pMJ148',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_only_pMJ145',\n", + " 'PERK_only_pMJ146',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'IRE1_only_pMJ148',\n", + " 'ATF6_only_pMJ145',\n", + " 'IRE1_only_pMJ148',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'PERK_only_pMJ146',\n", + " 'PERK_only_pMJ146',\n", + " 'ATF6_PERK_pMJ150',\n", + " '*',\n", + " 'XBP1_pBA578',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'PERK_only_pMJ146',\n", + " 'PERK_only_pMJ146',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_only_pMJ145',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'PERK_only_pMJ146',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'ATF6_only_pMJ145',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'PERK_only_pMJ146',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'PERK_IRE1_pMJ154',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'IRE1_only_pMJ148',\n", + " 'ATF6_PERK_pMJ150',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'IRE1_only_pMJ148',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 
'PERK_IRE1_pMJ154',\n", + " 'PERK_only_pMJ146',\n", + " 'IRE1_only_pMJ148',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_only_pMJ145',\n", + " 'IRE1_only_pMJ148',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_only_pMJ145',\n", + " 'IRE1_only_pMJ148',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'PERK_only_pMJ146',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'PERK_only_pMJ146',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'ATF6_only_pMJ145',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'IRE1_only_pMJ148',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'PERK_only_pMJ146',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'IRE1_only_pMJ148',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'IRE1_only_pMJ148',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'PERK_only_pMJ146',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'IRE1_only_pMJ148',\n", + " 'IRE1_only_pMJ148',\n", + " 'ATF6_PERK_pMJ150',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'IRE1_only_pMJ148',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'PERK_only_pMJ146',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'IRE1_only_pMJ148',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'PERK_only_pMJ146',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'IRE1_only_pMJ148',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " nan,\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_only_pMJ145',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 
'ATF6_IRE1_pMJ152',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'IRE1_only_pMJ148',\n", + " 'ATF6_only_pMJ145',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'PERK_only_pMJ146',\n", + " 'PERK_only_pMJ146',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'IRE1_only_pMJ148',\n", + " 'IRE1_only_pMJ148',\n", + " 'PERK_IRE1_pMJ154',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'IRE1_only_pMJ148',\n", + " 'PERK_only_pMJ146',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'PERK_only_pMJ146',\n", + " 'ATF6_PERK_pMJ150',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'IRE1_only_pMJ148',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_only_pMJ145',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_only_pMJ145',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_PERK_pMJ150',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_only_pMJ145',\n", + " 'IRE1_only_pMJ148',\n", + " 'ATF6_PERK_pMJ150',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'IRE1_only_pMJ148',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'PERK_only_pMJ146',\n", + " 'PERK_IRE1_pMJ154',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'PERK_only_pMJ146',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'ATF6_only_pMJ145',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_only_pMJ145',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_IRE1_pMJ152',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'IRE1_only_pMJ148',\n", + " 'PERK_only_pMJ146',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'PERK_only_pMJ146',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_IRE1_pMJ152',\n", + " nan,\n", 
+ " '3x_neg_ctrl_pMJ144-1',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_PERK_pMJ150',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_only_pMJ145',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'IRE1_only_pMJ148',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_only_pMJ145',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_only_pMJ145',\n", + " 'IRE1_only_pMJ148',\n", + " 'IRE1_only_pMJ148',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'PERK_IRE1_pMJ154',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'IRE1_only_pMJ148',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'ATF6_only_pMJ145',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'ATF6_PERK_pMJ150',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'ATF6_IRE1_pMJ152',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'PERK_IRE1_pMJ154',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_only_pMJ145',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'PERK_only_pMJ146',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_IRE1_pMJ152',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_PERK_pMJ150',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'PERK_only_pMJ146',\n", + " 'PERK_only_pMJ146',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'PERK_only_pMJ146',\n", + " 'IRE1_only_pMJ148',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'IRE1_only_pMJ148',\n", + " 'ATF6_only_pMJ145',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'PERK_only_pMJ146',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 
'PERK_only_pMJ146',\n", + " nan,\n", + " 'IRE1_only_pMJ148',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'PERK_only_pMJ146',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_only_pMJ145',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " nan,\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_only_pMJ145',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'IRE1_only_pMJ148',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'PERK_only_pMJ146',\n", + " 'ATF6_only_pMJ145',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'PERK_only_pMJ146',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'ATF6_only_pMJ145',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'IRE1_only_pMJ148',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_PERK_pMJ150',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'PERK_IRE1_pMJ154',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'IRE1_only_pMJ148',\n", + " 'ATF6_IRE1_pMJ152',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'PERK_only_pMJ146',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " 'PERK_only_pMJ146',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " 'PERK_only_pMJ146',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'PERK_IRE1_pMJ154',\n", + " 'IRE1_only_pMJ148',\n", + " 'PERK_only_pMJ146',\n", + " 'ATF6_PERK_IRE1_pMJ158',\n", + " '3x_neg_ctrl_pMJ144-2',\n", + " '3x_neg_ctrl_pMJ144-1',\n", + " 'PERK_only_pMJ146',\n", + " 
'ATF6_PERK_IRE1_pMJ158',\n", + " 'ATF6_IRE1_pMJ152',\n", + " ...]" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "48beeb95", + "metadata": {}, + "outputs": [], + "source": [ + "rsc.get.anndata_to_GPU(adata, convert_all=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "553326ea", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/shadeform/rapids_singlecell/src/rapids_singlecell/pertpy_gpu/_distances.py:730: RuntimeWarning: Mean of empty slice.\n", + " sigma_Y = P[~idx, :][:, ~idx].mean()\n", + "/home/shadeform/miniforge3/envs/pertpy/lib/python3.13/site-packages/numpy/_core/_methods.py:145: RuntimeWarning: invalid value encountered in divide\n", + " ret = ret.dtype.type(ret / rcount)\n", + "/home/shadeform/rapids_singlecell/src/rapids_singlecell/pertpy_gpu/_distances.py:731: RuntimeWarning: Mean of empty slice.\n", + " delta = P[idx, :][:, ~idx].mean()\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 6.32 s, sys: 1.88 s, total: 8.2 s\n", + "Wall time: 8.19 s\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/shadeform/rapids_singlecell/src/rapids_singlecell/pertpy_gpu/_distances.py:729: RuntimeWarning: Mean of empty slice.\n", + " sigma_X = P[idx, :][:, idx].mean()\n" + ] + } + ], + "source": [ + "%%time\n", + "\n", + "obs_key = \"perturbation\"\n", + "dist = Distance(obsm_key=\"X_pca\", metric=\"edistance\")\n", + "df = dist.pairwise(adata, groupby=obs_key)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "bd845994", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "9423ba54408542468fc5e8eec76894de", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": 
{}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/home/shadeform/miniforge3/envs/pertpy/lib/python3.13/site-packages/pertpy/tools/_distances/_distances.py:675: \n",
+       "RuntimeWarning: Mean of empty slice.\n",
+       "  sigma_Y = P[~idx, :][:, ~idx].mean()\n",
+       "
\n" + ], + "text/plain": [ + "/home/shadeform/miniforge3/envs/pertpy/lib/python3.13/site-packages/pertpy/tools/_distances/_distances.py:675: \n", + "RuntimeWarning: Mean of empty slice.\n", + " sigma_Y = P[~idx, :][:, ~idx].mean()\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/home/shadeform/miniforge3/envs/pertpy/lib/python3.13/site-packages/numpy/_core/_methods.py:145: RuntimeWarning: \n",
+       "invalid value encountered in divide\n",
+       "  ret = ret.dtype.type(ret / rcount)\n",
+       "
\n" + ], + "text/plain": [ + "/home/shadeform/miniforge3/envs/pertpy/lib/python3.13/site-packages/numpy/_core/_methods.py:145: RuntimeWarning: \n", + "invalid value encountered in divide\n", + " ret = ret.dtype.type(ret / rcount)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/home/shadeform/miniforge3/envs/pertpy/lib/python3.13/site-packages/pertpy/tools/_distances/_distances.py:676: \n",
+       "RuntimeWarning: Mean of empty slice.\n",
+       "  delta = P[idx, :][:, ~idx].mean()\n",
+       "
\n" + ], + "text/plain": [ + "/home/shadeform/miniforge3/envs/pertpy/lib/python3.13/site-packages/pertpy/tools/_distances/_distances.py:676: \n", + "RuntimeWarning: Mean of empty slice.\n", + " delta = P[idx, :][:, ~idx].mean()\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/home/shadeform/miniforge3/envs/pertpy/lib/python3.13/site-packages/pertpy/tools/_distances/_distances.py:674: \n",
+       "RuntimeWarning: Mean of empty slice.\n",
+       "  sigma_X = P[idx, :][:, idx].mean()\n",
+       "
\n" + ], + "text/plain": [ + "/home/shadeform/miniforge3/envs/pertpy/lib/python3.13/site-packages/pertpy/tools/_distances/_distances.py:674: \n", + "RuntimeWarning: Mean of empty slice.\n", + " sigma_X = P[idx, :][:, idx].mean()\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "dist = pt.tl.Distance(obsm_key=\"X_pca\", metric=\"edistance\")\n",
+    "df = dist.pairwise(adata, groupby=obs_key)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e92bee29",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "pertpy",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.13.7"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/tmp_scripts/compute_edistance.py b/tmp_scripts/compute_edistance.py
new file mode 100644
index 00000000..b795b5b6
--- /dev/null
+++ b/tmp_scripts/compute_edistance.py
@@ -0,0 +1,35 @@
+from __future__ import annotations
+
+import os
+import time
+
+import anndata as ad
+import cupy as cp
+import rmm
+from rmm.allocators.cupy import rmm_cupy_allocator
+
+import rapids_singlecell as rsc
+from rapids_singlecell.ptg import Distance
+
+rmm.reinitialize(
+    managed_memory=False,  # True would allow oversubscription (managed memory)
+    pool_allocator=True,  # default is False
+    devices=0,  # GPU device IDs to register. By default registers only GPU 0.
+)
+cp.cuda.set_allocator(rmm_cupy_allocator)  # route CuPy allocations through RMM
+
+
+if __name__ == "__main__":
+    obs_key = "perturbation"
+
+    # ~/data/adamson_2016_upr_epistasis_pca.h5ad (AnnData with "X_pca" in obsm)
+    adata_path = os.path.join(
+        os.path.expanduser("~"), "data", "adamson_2016_upr_epistasis_pca.h5ad"
+    )
+    adata = ad.read_h5ad(adata_path)
+    rsc.get.anndata_to_GPU(adata, convert_all=True)
+    dist = Distance(obsm_key="X_pca", metric="edistance")
+    start_time = time.perf_counter()  # monotonic clock meant for intervals
+    df = dist.pairwise(adata, groupby=obs_key)
+    elapsed = time.perf_counter() - start_time
+    print(f"Time taken: {elapsed} seconds")
diff --git a/tmp_scripts/compute_edistance_cpu.py b/tmp_scripts/compute_edistance_cpu.py
new file mode 100644
index 00000000..dc0ad05d
--- /dev/null
+++ b/tmp_scripts/compute_edistance_cpu.py
@@ -0,0 +1,21 @@
+from __future__ import annotations
+
+import os
+import time
+
+import anndata as ad
+from pertpy.tools import Distance
+
+if __name__ == "__main__":
+    obs_key = "perturbation"
+
+    # ~/data/adamson_2016_upr_epistasis_pca.h5ad (AnnData with "X_pca" in obsm)
+    adata_path = os.path.join(
+        os.path.expanduser("~"), "data", "adamson_2016_upr_epistasis_pca.h5ad"
+    )
+    adata = ad.read_h5ad(adata_path)
+    dist = Distance(obsm_key="X_pca", metric="edistance")
+    start_time = time.perf_counter()  # monotonic clock meant for intervals
+    df = dist.pairwise(adata, groupby=obs_key)
+    elapsed = time.perf_counter() - start_time
+    print(f"Time taken: {elapsed} seconds")
diff --git a/tmp_scripts/prepare_data.py b/tmp_scripts/prepare_data.py
new file mode 100644
index 00000000..2557eeb8
--- /dev/null
+++ b/tmp_scripts/prepare_data.py
@@ -0,0 +1,26 @@
+from __future__ import annotations
+
+import os
+
+import pertpy as pt
+import scanpy as sc
+
+import rapids_singlecell as rsc
+
+if __name__ == "__main__":
+    obs_key = "perturbation"
+    adata = pt.data.adamson_2016_upr_epistasis()
+
+    # Drop genes and cells that carry no counts at all.
+    sc.pp.filter_genes(adata, min_counts=1)
+    sc.pp.filter_cells(adata, min_counts=1)
+    # Label cells without a perturbation as "control"; the category has to
+    # be registered before fillna may use it on a categorical column.
+    labels = adata.obs[obs_key].cat.add_categories("control")
+    adata.obs[obs_key] = labels.fillna("control")
+    rsc.pp.pca(adata, n_comps=50)
+    # Write the result to ~/data/adamson_2016_upr_epistasis_pca.h5ad,
+    # creating the directory when it is missing.
+    out_dir = os.path.join(os.path.expanduser("~"), "data")
+    os.makedirs(out_dir, exist_ok=True)
+    adata.write(os.path.join(out_dir, "adamson_2016_upr_epistasis_pca.h5ad"))

From 28cc06fe89fab229b212211527f8d1cc492e5fe3 Mon Sep 17 00:00:00 2001
From: selmanozleyen 
Date: Tue, 9 Sep 2025 23:38:34 +0000
Subject: [PATCH 08/23] vibe coded version + the version I simplified. It works
 roughly

---
 .../pertpy_gpu/_distances_standalone.py       | 332 ++++++++++++++++++
 tmp_scripts/compute_edistance_standalone.py   |  59 ++++
 2 files changed, 391 insertions(+)
 create mode 100644 src/rapids_singlecell/pertpy_gpu/_distances_standalone.py
 create mode 100644 tmp_scripts/compute_edistance_standalone.py

diff --git a/src/rapids_singlecell/pertpy_gpu/_distances_standalone.py b/src/rapids_singlecell/pertpy_gpu/_distances_standalone.py
new file mode 100644
index 00000000..203e5364
--- /dev/null
+++ b/src/rapids_singlecell/pertpy_gpu/_distances_standalone.py
@@ -0,0 +1,332 @@
+from __future__ import annotations
+
+import cupy as cp
+import numpy as np
+import pandas as pd
+from anndata import AnnData
+
+from ..preprocessing._harmony._helper import _create_category_index_mapping
+from ..squidpy_gpu._utils import _assert_categorical_obs
+
+# CUDA kernel for edistance computation - simplified direct approach
+#
+# For two groups A and B with n_A and n_B cells the kernel evaluates
+#
+#     edistance(A, B) = 2 * delta_AB - sigma_A - sigma_B
+#
+# where delta_AB is the mean Euclidean distance over all n_A * n_B cross pairs
+# and sigma_X is the sum of Euclidean distances over all ordered pairs within
+# X divided by n_X * n_X (the zero self-distances count in the denominator).
+# A group without cells makes the affected means 0 / 0, i.e. NaN.
+#
+# Launch contract:
+#   * grid  = one block per entry of pair_left / pair_right
+#   * block = a power-of-two number of threads (the tree reduction halves the
+#     active range in every step)
+#   * dynamic shared memory = blockDim.x * sizeof(float) bytes
+#   * the scalar arguments k and n_features are declared as 32-bit `int`
+edistance_kernel_code = r"""
+// Euclidean distance between two rows of the row-major
+// [n_cells, n_features] embedding matrix.
+__device__ __forceinline__ float row_distance(
+    const float* __restrict__ embedding,
+    const int row_i,
+    const int row_j,
+    const int n_features)
+{
+    // Offsets are computed in size_t: row * n_features can exceed INT_MAX.
+    const float* vec_i = embedding + (size_t)row_i * n_features;
+    const float* vec_j = embedding + (size_t)row_j * n_features;
+
+    float dist_sq = 0.0f;
+    for (int feat = 0; feat < n_features; ++feat) {
+        const float diff = vec_i[feat] - vec_j[feat];
+        dist_sq += diff * diff;
+    }
+
+    return sqrtf(dist_sq);
+}
+
+// Sum `value` over all threads of the block and return the total to every
+// thread. Must be reached by the whole block; blockDim.x must be a power of
+// two and `scratch` must hold blockDim.x floats of shared memory.
+__device__ float block_sum(const float value, float* scratch)
+{
+    const int thread_id = threadIdx.x;
+
+    scratch[thread_id] = value;
+    __syncthreads();
+
+    // Parallel reduction
+    for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
+        if (thread_id < stride) {
+            scratch[thread_id] += scratch[thread_id + stride];
+        }
+        __syncthreads();
+    }
+
+    const float total = scratch[0];
+    __syncthreads();  // scratch may be overwritten by the next reduction
+    return total;
+}
+
+// Energy distance for a list of group pairs, one thread block per pair.
+//
+//   embedding    [n_cells * n_features] row-major cell embeddings
+//   cat_offsets  cells of group g are cell_indices[cat_offsets[g]] up to
+//                (excluding) cell_indices[cat_offsets[g + 1]]
+//   cell_indices row numbers into `embedding`, grouped by category
+//   pair_left    first group of each pair, indexed by blockIdx.x
+//   pair_right   second group of each pair, indexed by blockIdx.x
+//   edistances   output, one value per pair
+//   k            number of groups (not read by the kernel; kept so the
+//                argument list stays unchanged)
+//   n_features   number of columns of `embedding`
+//
+// Every block must be launched with a power-of-two blockDim.x and
+// blockDim.x * sizeof(float) bytes of dynamic shared memory.
+extern "C" __global__
+void edistance_pairwise_kernel(
+    const float* __restrict__ embedding,
+    const int* __restrict__ cat_offsets,
+    const int* __restrict__ cell_indices,
+    const int* __restrict__ pair_left,
+    const int* __restrict__ pair_right,
+    float* __restrict__ edistances,
+    int k,
+    int n_features)
+{
+    extern __shared__ float shared_sums[];
+
+    const int thread_id = threadIdx.x;
+    const int block_id = blockIdx.x;
+    const int block_size = blockDim.x;
+
+    const int a = pair_left[block_id];
+    const int b = pair_right[block_id];
+
+    if (a == b) {
+        // The energy distance of a group to itself is 0 by definition, so no
+        // pairwise sums are needed for the diagonal pairs.
+        // a and b depend on blockIdx.x only, so the whole block leaves
+        // together and no thread is left waiting in __syncthreads().
+        if (thread_id == 0) {
+            edistances[block_id] = 0.0f;
+        }
+        return;
+    }
+
+    const int start_a = cat_offsets[a];
+    const int end_a = cat_offsets[a + 1];
+    const int start_b = cat_offsets[b];
+    const int end_b = cat_offsets[b + 1];
+
+    // Each thread accumulates partial sums for [within_A, within_B, between_AB]
+    float local_within_A = 0.0f;
+    float local_within_B = 0.0f;
+    float local_between = 0.0f;
+
+    // Work split: each thread takes every blockDim.x-th cell of the first
+    // range and pairs it with the complete second range.
+
+    // 1. Within-group A distances (ALL ordered pairs, i.e. each pair twice)
+    for (int ia = start_a + thread_id; ia < end_a; ia += block_size) {
+        const int idx_i = cell_indices[ia];
+
+        for (int ja = start_a; ja < end_a; ++ja) {
+            const int idx_j = cell_indices[ja];
+
+            if (idx_i != idx_j) {  // Skip diagonal (distance = 0)
+                local_within_A += row_distance(embedding, idx_i, idx_j, n_features);
+            }
+        }
+    }
+
+    // 2. Within-group B distances (ALL ordered pairs, i.e. each pair twice)
+    for (int ib = start_b + thread_id; ib < end_b; ib += block_size) {
+        const int idx_i = cell_indices[ib];
+
+        for (int jb = start_b; jb < end_b; ++jb) {
+            const int idx_j = cell_indices[jb];
+
+            if (idx_i != idx_j) {  // Skip diagonal (distance = 0)
+                local_within_B += row_distance(embedding, idx_i, idx_j, n_features);
+            }
+        }
+    }
+
+    // 3. Between-group distances (ALL cross-pairs)
+    for (int ia = start_a + thread_id; ia < end_a; ia += block_size) {
+        const int idx_i = cell_indices[ia];
+
+        for (int jb = start_b; jb < end_b; ++jb) {
+            const int idx_j = cell_indices[jb];
+
+            local_between += row_distance(embedding, idx_i, idx_j, n_features);
+        }
+    }
+
+    // Block-wide totals (all threads take part in each reduction)
+    const float total_within_A = block_sum(local_within_A, shared_sums);
+    const float total_within_B = block_sum(local_within_B, shared_sums);
+    const float total_between = block_sum(local_between, shared_sums);
+
+    // Compute final edistance directly
+    if (thread_id == 0) {
+        // Group sizes as float: an int product n_a * n_a overflows once a
+        // group has more than 46340 cells.
+        const float n_a = (float)(end_a - start_a);
+        const float n_b = (float)(end_b - start_b);
+
+        // Normalize by the full pair count (zero self-distances included)
+        const float mean_within_A = total_within_A / (n_a * n_a);
+        const float mean_within_B = total_within_B / (n_b * n_b);
+        const float mean_between = total_between / (n_a * n_b);
+
+        // Edistance formula: 2*δ - σ_A - σ_B
+        edistances[block_id] = 2.0f * mean_between - mean_within_A - mean_within_B;
+    }
+}
+"""
+
+edistance_pairwise_kernel = cp.RawKernel(
+    edistance_kernel_code, "edistance_pairwise_kernel"
+)
+
+
+def _fill_distance_matrix(
+    edistances: cp.ndarray,
+    pairwise_distances: cp.ndarray,
+    pair_left: cp.ndarray,
+    pair_right: cp.ndarray,
+) -> None:
+    """Scatter per-pair energy distances into the symmetric matrix, in place.
+
+    ``edistances[i]`` belongs to groups ``(pair_left[i], pair_right[i])`` and is
+    written to both ``[a, b]`` and ``[b, a]`` of ``pairwise_distances``.
+    """
+    # Two fancy-index assignments on the device replace a per-pair Python
+    # loop, which forced device-to-host scalar reads for every single pair.
+    pairwise_distances[pair_left, pair_right] = edistances
+    pairwise_distances[pair_right, pair_left] = edistances
+
+
+def _edistance_pairwise_helper(
+    embedding: cp.ndarray, cat_offsets: cp.ndarray, cell_indices: cp.ndarray, k: int
+) -> cp.ndarray:
+    """
+    Fast pairwise edistance computation using CUDA kernels.
+
+    Parameters
+    ----------
+    embedding : cp.ndarray
+        Cell embeddings [n_cells, n_features], read as row-major float32
+    cat_offsets : cp.ndarray
+        Group start/end indices (from harmony helper), read as int32
+    cell_indices : cp.ndarray
+        Sorted cell indices by group (from harmony helper), read as int32
+    k : int
+        Number of groups
+
+    Returns
+    -------
+    pairwise_distances : cp.ndarray
+        Pairwise edistance matrix [k, k]
+
+    Raises
+    ------
+    RuntimeError
+        If no candidate block size fits into the device's shared memory
+    """
+    _, n_features = embedding.shape
+
+    # Upper triangle incl. diagonal, in row-major (a, then b >= a) order.
+    tri_rows, tri_cols = np.triu_indices(k)
+    pair_left = cp.asarray(tri_rows, dtype=cp.int32)
+    pair_right = cp.asarray(tri_cols, dtype=cp.int32)
+
+    edistances = cp.zeros(pair_left.size, dtype=np.float32)  # one per pair
+
+    # Query the device the kernel will run on rather than assuming device 0.
+    props = cp.cuda.runtime.getDeviceProperties(cp.cuda.Device().id)
+    max_smem = int(props.get("sharedMemPerBlock", 48 * 1024))
+
+    chosen_threads = None
+    for tpb in (1024, 512, 256, 128, 64, 32):
+        # Each thread needs one float for shared memory reduction
+        required = tpb * cp.dtype(cp.float32).itemsize
+        if required <= max_smem:
+            chosen_threads = tpb
+            shared_mem_size = required
+            break
+    if chosen_threads is None:
+        raise RuntimeError("No block size fits into the available shared memory")
+
+    # One block per group pair. The kernel declares ``int`` scalars while CuPy
+    # passes Python ints as 64-bit values, so k/n_features are cast explicitly.
+    # An empty grid is not a valid launch, hence the guard for k == 0.
+    if pair_left.size > 0:
+        edistance_pairwise_kernel(
+            (pair_left.size,),
+            (chosen_threads,),
+            (
+                embedding,
+                cat_offsets,
+                cell_indices,
+                pair_left,
+                pair_right,
+                edistances,
+                np.int32(k),
+                np.int32(n_features),
+            ),
+            shared_mem=shared_mem_size,
+        )
+
+    # Fill symmetric distance matrix
+    pairwise_distances = cp.zeros((k, k), dtype=np.float32)
+    _fill_distance_matrix(edistances, pairwise_distances, pair_left, pair_right)
+    return pairwise_distances
+
+
+def pairwise_edistance_gpu(
+    adata: AnnData,
+    groupby: str,
+    *,
+    obsm_key: str = "X_pca",
+    groups: list[str] | None = None,
+    copy: bool = False,
+) -> pd.DataFrame | None:
+    """GPU-accelerated pairwise edistance between the categories of ``groupby``.
+
+    Distances are computed on ``adata.obsm[obsm_key]``. ``groups`` restricts the
+    returned rows/columns to the given categories (``KeyError`` for unknown
+    ones). The frame is always returned; unless ``copy`` is true it is also
+    stored in ``adata.uns[f"{groupby}_pairwise_edistance"]["distances"]``.
+    """
+    _assert_categorical_obs(adata, key=groupby)  # Reuse validation
+
+    # The kernel indexes the embedding as row-major float32, so enforce that
+    # layout instead of inheriting whatever memory order ``obsm`` happens to use.
+    embedding = cp.ascontiguousarray(cp.asarray(adata.obsm[obsm_key]), dtype=cp.float32)
+    original_groups = adata.obs[groupby]
+    categories = list(original_groups.cat.categories.values)
+    group_map = {v: i for i, v in enumerate(categories)}
+    group_labels = cp.array([group_map[c] for c in original_groups], dtype=np.int32)
+
+    # Use harmony's category mapping (same as co_occurrence)
+    k = len(group_map)  # number of groups
+    cat_offsets, cell_indices = _create_category_index_mapping(group_labels, k)
+    distances = _edistance_pairwise_helper(embedding, cat_offsets, cell_indices, k)
+
+    # The matrix is ordered by category: label it with all categories first and
+    # only then select the requested groups.
+    df = pd.DataFrame(distances.get(), index=categories, columns=categories)
+    if groups is not None:
+        df = df.loc[list(groups), list(groups)]
+    df.index.name = groupby
+    df.columns.name = groupby
+    df.name = "pairwise edistance"
+    if not copy:
+        # Store in adata like co_occurrence does
+        adata.uns[f"{groupby}_pairwise_edistance"] = {"distances": df}
+    return df
diff --git a/tmp_scripts/compute_edistance_standalone.py b/tmp_scripts/compute_edistance_standalone.py
new file mode 100644
index 00000000..ffe49760
--- /dev/null
+++ b/tmp_scripts/compute_edistance_standalone.py
@@ -0,0 +1,59 @@
+from __future__ import annotations
+
+import os
+import time
+
+import anndata as ad
+import cupy as cp
+import numpy as np
+import pandas as pd
+import rmm
+from rmm.allocators.cupy import rmm_cupy_allocator
+
+import rapids_singlecell as rsc
+from rapids_singlecell.pertpy_gpu._distances_standalone import pairwise_edistance_gpu
+
+rmm.reinitialize(
+    managed_memory=False,  # plain device memory; True (managed memory) would allow oversubscription
+    pool_allocator=True,  # default is False
+    devices=0,  # GPU device IDs to register. By default registers only GPU 0.
+)
+cp.cuda.set_allocator(rmm_cupy_allocator)
+
+
+if __name__ == "__main__":
+    obs_key = "perturbation"
+
+    # homedir/data/adamson_2016_upr_epistasis
+    save_dir = os.path.join(
+        os.path.expanduser("~"), "data", "adamson_2016_upr_epistasis_pca.h5ad"
+    )
+    save_dir_df = os.path.join(os.path.expanduser("~"), "data", "df_cpu.csv")
+    adata = ad.read_h5ad(save_dir)
+    rsc.get.anndata_to_GPU(adata, convert_all=True)
+    start_time = time.time()
+    df_gpu = pairwise_edistance_gpu(adata, groupby=obs_key, obsm_key="X_pca")
+    end_time = time.time()
+    print(f"Time taken: {end_time - start_time} seconds")
+    # print("CPU time")
+    # dist = Distance(obsm_key="X_pca", metric="edistance")
+    # start_time = time.time()
+    # df_cpu = dist.pairwise(adata, groupby=obs_key)
+    # end_time = time.time()
+    # print(f"Time taken: {end_time - start_time} seconds")
+    df = pd.read_csv(save_dir_df, index_col=0)
+
+    groups = adata.obs[obs_key].unique()
+    for group_x in groups:
+        for group_y in groups:
+            if group_x == group_y:
+                assert df_gpu.loc[group_x, group_y] == 0
+            else:
+                assert np.isclose(
+                    df_gpu.loc[group_x, group_y], df.loc[group_x, group_y], atol=1e-3
+                ), (
+                    f"Group df_gpu: {df_gpu.loc[group_x, group_y]}, Group df: {df.loc[group_x, group_y]}"
+                )
+    # print(df.equals(df_gpu))
+    # print(df)
+    # print(df_gpu)

From 4df0982a841414ab5cbeb763c1b2415bed46cc73 Mon Sep 17 00:00:00 2001
From: selmanozleyen 
Date: Wed, 10 Sep 2025 00:04:54 +0000
Subject: [PATCH 09/23] changes

---
 .../pertpy_gpu/_distances_standalone.py       | 41 +++----------------
 1 file changed, 5 insertions(+), 36 deletions(-)

diff --git a/src/rapids_singlecell/pertpy_gpu/_distances_standalone.py b/src/rapids_singlecell/pertpy_gpu/_distances_standalone.py
index 203e5364..d36fe4a2 100644
--- a/src/rapids_singlecell/pertpy_gpu/_distances_standalone.py
+++ b/src/rapids_singlecell/pertpy_gpu/_distances_standalone.py
@@ -46,38 +46,6 @@
     if (a == b) {
         // Same group: edistance = 0 by definition, but we still need to compute
         // within-group sum to match cuml behavior
-        for (int ia = start_a + thread_id; ia < end_a; ia += block_size) {
-            const int idx_i = cell_indices[ia];
-
-            // Compute ALL pairwise distances within group (including symmetric pairs)
-            for (int ja = start_a; ja < end_a; ++ja) {
-                const int idx_j = cell_indices[ja];
-
-                if (idx_i != idx_j) {  // Skip diagonal (distance = 0)
-                    float dist_sq = 0.0f;
-                    for (int feat = 0; feat < n_features; ++feat) {
-                        float diff = embedding[idx_i * n_features + feat] -
-                                    embedding[idx_j * n_features + feat];
-                        dist_sq += diff * diff;
-                    }
-                    local_within_A += sqrtf(dist_sq);
-                }
-            }
-        }
-
-        // Store in shared memory for reduction
-        shared_sums[thread_id] = local_within_A;
-        __syncthreads();
-
-        // Parallel reduction
-        for (int stride = block_size / 2; stride > 0; stride >>= 1) {
-            if (thread_id < stride) {
-                shared_sums[thread_id] += shared_sums[thread_id + stride];
-            }
-            __syncthreads();
-        }
-
-        // For same groups, edistance is always 0
         if (thread_id == 0) {
             edistances[block_id] = 0.0f;
         }
@@ -178,12 +146,13 @@
         // Compute final edistance directly
         if (thread_id == 0) {
             // Normalize by total matrix elements (matching cuml behavior)
-            float mean_within_A = total_within_A / (n_a * n_a);
-            float mean_within_B = total_within_B / (n_b * n_b);
-            float mean_between = total_between / (n_a * n_b);
+            // cast to float
+            float mean_within_A = total_within_A / ((float)(n_a * n_a));
+            float mean_within_B = total_within_B / ((float)(n_b * n_b));
+            float mean_between = total_between / ((float)(n_a * n_b) * 0.5f);
 
             // Edistance formula: 2*δ - σ_A - σ_B
-            edistances[block_id] = 2.0f * mean_between - mean_within_A - mean_within_B;
+            edistances[block_id] = mean_between - mean_within_A - mean_within_B;
         }
     }
 }

From 33d3fc1624a523155fa47b9bccc5a219fe7a0c9a Mon Sep 17 00:00:00 2001
From: selmanozleyen 
Date: Wed, 10 Sep 2025 09:18:29 +0000
Subject: [PATCH 10/23] simplify kernels and create a separate file

---
 .../pertpy_gpu/_distances_standalone.py       | 344 ++++++------------
 .../pertpy_gpu/kernels/edistance_kernels.cu   |  77 ++++
 2 files changed, 192 insertions(+), 229 deletions(-)
 create mode 100644 src/rapids_singlecell/pertpy_gpu/kernels/edistance_kernels.cu

diff --git a/src/rapids_singlecell/pertpy_gpu/_distances_standalone.py b/src/rapids_singlecell/pertpy_gpu/_distances_standalone.py
index d36fe4a2..661512dc 100644
--- a/src/rapids_singlecell/pertpy_gpu/_distances_standalone.py
+++ b/src/rapids_singlecell/pertpy_gpu/_distances_standalone.py
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+import os
+from pathlib import Path
 import cupy as cp
 import numpy as np
 import pandas as pd
@@ -8,253 +10,105 @@
 from ..preprocessing._harmony._helper import _create_category_index_mapping
 from ..squidpy_gpu._utils import _assert_categorical_obs
 
-# CUDA kernel for edistance computation - simplified direct approach
-edistance_kernel_code = r"""
-extern "C" __global__
-void edistance_pairwise_kernel(
-    const float* __restrict__ embedding,
-    const int* __restrict__ cat_offsets,
-    const int* __restrict__ cell_indices,
-    const int* __restrict__ pair_left,
-    const int* __restrict__ pair_right,
-    float* __restrict__ edistances,
-    int k,
-    int n_features)
-{
-    extern __shared__ float shared_sums[];
-
-    const int thread_id = threadIdx.x;
-    const int block_id = blockIdx.x;
-    const int block_size = blockDim.x;
-
-    // Each thread accumulates partial sums for [within_A, within_B, between_AB]
-    float local_within_A = 0.0f;
-    float local_within_B = 0.0f;
-    float local_between = 0.0f;
-
-    const int a = pair_left[block_id];
-    const int b = pair_right[block_id];
-
-    const int start_a = cat_offsets[a];
-    const int end_a = cat_offsets[a + 1];
-    const int start_b = cat_offsets[b];
-    const int end_b = cat_offsets[b + 1];
-
-    const int n_a = end_a - start_a;
-    const int n_b = end_b - start_b;
-
-    if (a == b) {
-        // Same group: edistance = 0 by definition, but we still need to compute
-        // within-group sum to match cuml behavior
-        if (thread_id == 0) {
-            edistances[block_id] = 0.0f;
-        }
-
-    } else {
-        // Different groups: compute all three components and final edistance
-
-        // 1. Compute within-group A distances (ALL pairs including symmetric)
-        for (int ia = start_a + thread_id; ia < end_a; ia += block_size) {
-            const int idx_i = cell_indices[ia];
-
-            for (int ja = start_a; ja < end_a; ++ja) {
-                const int idx_j = cell_indices[ja];
-
-                if (idx_i != idx_j) {
-                    float dist_sq = 0.0f;
-                    for (int feat = 0; feat < n_features; ++feat) {
-                        float diff = embedding[idx_i * n_features + feat] -
-                                    embedding[idx_j * n_features + feat];
-                        dist_sq += diff * diff;
-                    }
-                    local_within_A += sqrtf(dist_sq);
-                }
-            }
-        }
-
-        // 2. Compute within-group B distances (ALL pairs including symmetric)
-        for (int ib = start_b + thread_id; ib < end_b; ib += block_size) {
-            const int idx_i = cell_indices[ib];
-
-            for (int jb = start_b; jb < end_b; ++jb) {
-                const int idx_j = cell_indices[jb];
-
-                if (idx_i != idx_j) {
-                    float dist_sq = 0.0f;
-                    for (int feat = 0; feat < n_features; ++feat) {
-                        float diff = embedding[idx_i * n_features + feat] -
-                                    embedding[idx_j * n_features + feat];
-                        dist_sq += diff * diff;
-                    }
-                    local_within_B += sqrtf(dist_sq);
-                }
-            }
-        }
-
-        // 3. Compute between-group distances (ALL cross-pairs)
-        for (int ia = start_a + thread_id; ia < end_a; ia += block_size) {
-            const int idx_i = cell_indices[ia];
-
-            for (int jb = start_b; jb < end_b; ++jb) {
-                const int idx_j = cell_indices[jb];
-
-                float dist_sq = 0.0f;
-                for (int feat = 0; feat < n_features; ++feat) {
-                    float diff = embedding[idx_i * n_features + feat] -
-                                embedding[idx_j * n_features + feat];
-                    dist_sq += diff * diff;
-                }
-                local_between += sqrtf(dist_sq);
-            }
-        }
-
-        // Reduce within_A
-        shared_sums[thread_id] = local_within_A;
-        __syncthreads();
-        for (int stride = block_size / 2; stride > 0; stride >>= 1) {
-            if (thread_id < stride) {
-                shared_sums[thread_id] += shared_sums[thread_id + stride];
-            }
-            __syncthreads();
-        }
-        float total_within_A = shared_sums[0];
-        __syncthreads();
-
-        // Reduce within_B
-        shared_sums[thread_id] = local_within_B;
-        __syncthreads();
-        for (int stride = block_size / 2; stride > 0; stride >>= 1) {
-            if (thread_id < stride) {
-                shared_sums[thread_id] += shared_sums[thread_id + stride];
-            }
-            __syncthreads();
-        }
-        float total_within_B = shared_sums[0];
-        __syncthreads();
-
-        // Reduce between
-        shared_sums[thread_id] = local_between;
-        __syncthreads();
-        for (int stride = block_size / 2; stride > 0; stride >>= 1) {
-            if (thread_id < stride) {
-                shared_sums[thread_id] += shared_sums[thread_id + stride];
-            }
-            __syncthreads();
-        }
-        float total_between = shared_sums[0];
-
-        // Compute final edistance directly
-        if (thread_id == 0) {
-            // Normalize by total matrix elements (matching cuml behavior)
-            // cast to float
-            float mean_within_A = total_within_A / ((float)(n_a * n_a));
-            float mean_within_B = total_within_B / ((float)(n_b * n_b));
-            float mean_between = total_between / ((float)(n_a * n_b) * 0.5f);
-
-            // Edistance formula: 2*δ - σ_A - σ_B
-            edistances[block_id] = mean_between - mean_within_A - mean_within_B;
-        }
-    }
-}
-"""
-
-edistance_pairwise_kernel = cp.RawKernel(
-    edistance_kernel_code, "edistance_pairwise_kernel"
-)
-
-
-def _fill_distance_matrix(
-    edistances: cp.ndarray,
-    pairwise_distances: cp.ndarray,
-    pair_left: cp.ndarray,
-    pair_right: cp.ndarray,
-):
-    """Fill the symmetric distance matrix from kernel output"""
-
-    for i in range(pair_left.size):
-        a, b = int(pair_left[i]), int(pair_right[i])
-        edist = float(edistances[i])
-
-        # Fill symmetric matrix
-        pairwise_distances[a, b] = edist
-        pairwise_distances[b, a] = edist
-
-
-def _edistance_pairwise_helper(
+# Load CUDA kernels from separate file
+def _load_edistance_kernels():
+    """Load CUDA kernels from separate .cu file"""
+    kernel_dir = Path(__file__).parent / "kernels"
+    kernel_file = kernel_dir / "edistance_kernels.cu"
+    
+    if not kernel_file.exists():
+        raise FileNotFoundError(f"CUDA kernel file not found: {kernel_file}")
+    
+    with open(kernel_file, 'r') as f:
+        kernel_code = f.read()
+    
+    # Compile kernels
+    compute_group_distances_kernel = cp.RawKernel(kernel_code, "compute_group_distances")
+    
+    return compute_group_distances_kernel
+
+# Load kernels at module import time
+compute_group_distances_kernel = _load_edistance_kernels()
+
+
+
+def compute_d_other_gpu(
     embedding: cp.ndarray, cat_offsets: cp.ndarray, cell_indices: cp.ndarray, k: int
 ) -> cp.ndarray:
     """
-    Fast pairwise edistance computation using CUDA kernels.
-
+    Compute between-group mean distances for all group pairs.
+    
     Parameters
     ----------
     embedding : cp.ndarray
         Cell embeddings [n_cells, n_features]
     cat_offsets : cp.ndarray
-        Group start/end indices (from harmony helper)
+        Group start/end indices
     cell_indices : cp.ndarray
-        Sorted cell indices by group (from harmony helper)
+        Sorted cell indices by group
     k : int
         Number of groups
-
+        
     Returns
     -------
-    pairwise_distances : cp.ndarray
-        Pairwise edistance matrix [k, k]
+    d_other : cp.ndarray
+        Between-group mean distances [k, k]
     """
-
-    n_cells, n_features = embedding.shape
-
-    # Build group pairs (same pattern as co_occurrence)
+    _, n_features = embedding.shape
+    
     pair_left = []
     pair_right = []
+    pair_indices = []     
+    # only upper triangle
     for a in range(k):
-        for b in range(a, k):  # Upper triangle
+        for b in range(a, k):
             pair_left.append(a)
             pair_right.append(b)
+            pair_indices.append(a * k + b)  # Flatten matrix index
+    
     pair_left = cp.asarray(pair_left, dtype=cp.int32)
     pair_right = cp.asarray(pair_right, dtype=cp.int32)
-
-    # Allocate output for final edistances (one per pair)
-    edistances = cp.zeros(pair_left.size, dtype=np.float32)
-
-    # Choose optimal block size (same logic as co_occurrence)
+    pair_indices = cp.asarray(pair_indices, dtype=cp.int32)
+    
+    num_pairs = len(pair_left)  # k * (k-1) pairs instead of k²
+
+    # Allocate output for off-diagonal distances only
+    d_other_offdiag = cp.zeros(num_pairs, dtype=np.float32)
+    
+    # Choose optimal block size
     props = cp.cuda.runtime.getDeviceProperties(0)
     max_smem = int(props.get("sharedMemPerBlock", 48 * 1024))
 
     chosen_threads = None
+    shared_mem_size = 0 # TODO: think of a better way to do this
     for tpb in (1024, 512, 256, 128, 64, 32):
-        # Each thread needs one float for shared memory reduction
         required = tpb * cp.dtype(cp.float32).itemsize
         if required <= max_smem:
             chosen_threads = tpb
             shared_mem_size = required
             break
 
-    # Launch kernel (similar pattern to co_occurrence)
-    grid = (pair_left.size,)  # One block per group pair
+    # Launch kernel - one block per upper-triangle group pair (diagonal included)
+    grid = (num_pairs,)
     block = (chosen_threads,)
-    edistance_pairwise_kernel(
+    compute_group_distances_kernel(
         grid,
         block,
-        (
-            embedding,
-            cat_offsets,
-            cell_indices,
-            pair_left,
-            pair_right,
-            edistances,
-            k,
-            n_features,
-        ),
+        (embedding, cat_offsets, cell_indices, pair_left, pair_right, d_other_offdiag, k, n_features),
         shared_mem=shared_mem_size,
     )
-
-    # Fill symmetric distance matrix
-    pairwise_distances = cp.zeros((k, k), dtype=np.float32)
-    _fill_distance_matrix(edistances, pairwise_distances, pair_left, pair_right)
-
-    return pairwise_distances
+    
+    # Build full k x k matrix
+    pairwise_means = cp.zeros((k, k), dtype=np.float32)
+    
+    # Fill the full matrix
+    for i, idx in enumerate(pair_indices.get()):
+        a, b = divmod(idx, k)
+        pairwise_means[a, b] = d_other_offdiag[i]
+        pairwise_means[b, a] = d_other_offdiag[i]
+    
+    
+    return pairwise_means
 
 
 def pairwise_edistance_gpu(
@@ -263,39 +117,71 @@ def pairwise_edistance_gpu(
     *,
     obsm_key: str = "X_pca",
     groups: list[str] | None = None,
-    copy: bool = False,
-) -> pd.DataFrame | None:
-    """GPU-accelerated pairwise edistance computation"""
-
-    # 1. Prepare data (exactly like co_occurrence)
-    _assert_categorical_obs(adata, key=groupby)  # Reuse validation
+) -> pd.DataFrame:
+    """
+    GPU-accelerated pairwise edistance computation with decomposed components.
+    
+    Returns d_itself, d_other arrays and final edistance DataFrame where:
+    df[a,b] = 2*d_other[a,b] - d_itself[a] - d_itself[b]
+    
+    Parameters
+    ----------
+    adata : AnnData
+        Annotated data matrix
+    groupby : str
+        Key in adata.obs for grouping
+    obsm_key : str
+        Key in adata.obsm for embeddings
+    groups : list[str] | None
+        Specific groups to compute (if None, use all)
+    copy : bool
+        Whether to return a copy
+        
+    Returns
+    -------
+    d_itself : cp.ndarray
+        Within-group mean distances [k]
+    d_other : cp.ndarray  
+        Between-group mean distances [k, k]
+    df : pd.DataFrame
+        Final edistance matrix
+    """
+    # 1. Prepare data (same as original)
+    _assert_categorical_obs(adata, key=groupby)
 
     embedding = cp.array(adata.obsm[obsm_key]).astype(np.float32)
     original_groups = adata.obs[groupby]
     group_map = {v: i for i, v in enumerate(original_groups.cat.categories.values)}
-    group_labels = cp.array([group_map[c] for c in original_groups], dtype=np.int32)
+    group_labels = cp.array([group_map[c] for c in original_groups], dtype=cp.int32)
 
-    # 2. Use harmony's category mapping (same as co_occurrence)
-    k = len(group_map)  # number of groups
+    # 2. Use harmony's category mapping
+    k = len(group_map)
     cat_offsets, cell_indices = _create_category_index_mapping(group_labels, k)
 
-    # 3. Compute pairwise edistances using GPU kernels
-    pairwise_distances = _edistance_pairwise_helper(
-        embedding, cat_offsets, cell_indices, k
-    )
+    # 3. Compute decomposed components
+    # d_itself = compute_d_itself_gpu(embedding, cat_offsets, cell_indices, k)
+    pairwise_means = compute_d_other_gpu(embedding, cat_offsets, cell_indices, k)
+    
+    # 4. Compute final edistance: df[a,b] = 2*d_other[a,b] - d_itself[a] - d_itself[b]
+    edistance_matrix = cp.zeros((k, k), dtype=np.float32)
+    for a in range(k):
+        for b in range(a+1, k):
+            edistance_matrix[a, b] = 2 * pairwise_means[a, b] - pairwise_means[a, a] - pairwise_means[b, b]
+            edistance_matrix[b, a] = edistance_matrix[a, b]
 
-    # 4. Create output DataFrame (same pattern as pertpy)
+    # 5. Create output DataFrame
     groups_list = (
         list(original_groups.cat.categories.values) if groups is None else groups
     )
-    df = pd.DataFrame(pairwise_distances.get(), index=groups_list, columns=groups_list)
+    df = pd.DataFrame(edistance_matrix.get(), index=groups_list, columns=groups_list)
     df.index.name = groupby
     df.columns.name = groupby
     df.name = "pairwise edistance"
 
-    if copy:
-        return df
 
-    # Store in adata like co_occurrence does
-    adata.uns[f"{groupby}_pairwise_edistance"] = {"distances": df}
+    # Store in adata
+    adata.uns[f"{groupby}_pairwise_edistance"] = {
+        "distances": df,
+    }
+    
     return df
diff --git a/src/rapids_singlecell/pertpy_gpu/kernels/edistance_kernels.cu b/src/rapids_singlecell/pertpy_gpu/kernels/edistance_kernels.cu
new file mode 100644
index 00000000..668915ea
--- /dev/null
+++ b/src/rapids_singlecell/pertpy_gpu/kernels/edistance_kernels.cu
@@ -0,0 +1,77 @@
+// kernels/edistance_kernels.cu
+#include 
+#include 
+
+extern "C" {
+
+/**
+ * Compute mean pairwise distances per group pair (a == b yields the within-group mean)
+ * Each block processes one group pair, threads collaborate within the block
+ */
+__global__ void compute_group_distances(
+    const float* __restrict__ embedding,
+    const int* __restrict__ cat_offsets,
+    const int* __restrict__ cell_indices,
+    const int* __restrict__ pair_left,
+    const int* __restrict__ pair_right,
+    float* __restrict__ d_other,
+    int k,
+    int n_features)
+{
+    extern __shared__ float shared_sums[];
+
+    const int thread_id = threadIdx.x;
+    const int block_id = blockIdx.x;
+    const int block_size = blockDim.x;
+
+    float local_sum = 0.0f;
+
+    const int a = pair_left[block_id];
+    const int b = pair_right[block_id];
+
+    // a == b is not special-cased: the host also passes diagonal pairs (within-group means)
+
+    const int start_a = cat_offsets[a];
+    const int end_a = cat_offsets[a + 1];
+    const int start_b = cat_offsets[b];
+    const int end_b = cat_offsets[b + 1];
+
+    const int n_a = end_a - start_a;
+    const int n_b = end_b - start_b;
+
+    // Compute between-group distances (ALL cross-pairs)
+    for (int ia = start_a + thread_id; ia < end_a; ia += block_size) {
+        const int idx_i = cell_indices[ia];
+
+        for (int jb = start_b; jb < end_b; ++jb) {
+            const int idx_j = cell_indices[jb];
+
+            float dist_sq = 0.0f;
+            #pragma unroll
+            for (int feat = 0; feat < n_features; ++feat) {
+                float diff = embedding[idx_i * n_features + feat] -
+                            embedding[idx_j * n_features + feat];
+                dist_sq += diff * diff;
+            }
+            local_sum += sqrtf(dist_sq);
+        }
+    }
+
+    // Reduce across threads using shared memory
+    shared_sums[thread_id] = local_sum;
+    __syncthreads();
+    
+    for (int stride = block_size / 2; stride > 0; stride >>= 1) {
+        if (thread_id < stride) {
+            shared_sums[thread_id] += shared_sums[thread_id + stride];
+        }
+        __syncthreads();
+    }
+
+    if (thread_id == 0) {
+        // Store mean between-group distance
+        d_other[block_id] = shared_sums[0] / (float)(n_a * n_b);
+    }
+}
+
+} // extern "C"
\ No newline at end of file

From a8b016be427d73cf6de66b80b38261688a121024 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
 <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 10 Sep 2025 09:18:39 +0000
Subject: [PATCH 11/23] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../pertpy_gpu/_distances_standalone.py       | 72 +++++++++++--------
 .../pertpy_gpu/kernels/edistance_kernels.cu   |  4 +-
 2 files changed, 44 insertions(+), 32 deletions(-)

diff --git a/src/rapids_singlecell/pertpy_gpu/_distances_standalone.py b/src/rapids_singlecell/pertpy_gpu/_distances_standalone.py
index 661512dc..9434330e 100644
--- a/src/rapids_singlecell/pertpy_gpu/_distances_standalone.py
+++ b/src/rapids_singlecell/pertpy_gpu/_distances_standalone.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
-import os
 from pathlib import Path
+
 import cupy as cp
 import numpy as np
 import pandas as pd
@@ -10,34 +10,37 @@
 from ..preprocessing._harmony._helper import _create_category_index_mapping
 from ..squidpy_gpu._utils import _assert_categorical_obs
 
+
 # Load CUDA kernels from separate file
 def _load_edistance_kernels():
     """Load CUDA kernels from separate .cu file"""
     kernel_dir = Path(__file__).parent / "kernels"
     kernel_file = kernel_dir / "edistance_kernels.cu"
-    
+
     if not kernel_file.exists():
         raise FileNotFoundError(f"CUDA kernel file not found: {kernel_file}")
-    
-    with open(kernel_file, 'r') as f:
+
+    with open(kernel_file) as f:
         kernel_code = f.read()
-    
+
     # Compile kernels
-    compute_group_distances_kernel = cp.RawKernel(kernel_code, "compute_group_distances")
-    
+    compute_group_distances_kernel = cp.RawKernel(
+        kernel_code, "compute_group_distances"
+    )
+
     return compute_group_distances_kernel
 
+
 # Load kernels at module import time
 compute_group_distances_kernel = _load_edistance_kernels()
 
 
-
 def compute_d_other_gpu(
     embedding: cp.ndarray, cat_offsets: cp.ndarray, cell_indices: cp.ndarray, k: int
 ) -> cp.ndarray:
     """
     Compute between-group mean distances for all group pairs.
-    
+
     Parameters
     ----------
     embedding : cp.ndarray
@@ -48,39 +51,39 @@ def compute_d_other_gpu(
         Sorted cell indices by group
     k : int
         Number of groups
-        
+
     Returns
     -------
     d_other : cp.ndarray
         Between-group mean distances [k, k]
     """
     _, n_features = embedding.shape
-    
+
     pair_left = []
     pair_right = []
-    pair_indices = []     
+    pair_indices = []
     # only upper triangle
     for a in range(k):
         for b in range(a, k):
             pair_left.append(a)
             pair_right.append(b)
             pair_indices.append(a * k + b)  # Flatten matrix index
-    
+
     pair_left = cp.asarray(pair_left, dtype=cp.int32)
     pair_right = cp.asarray(pair_right, dtype=cp.int32)
     pair_indices = cp.asarray(pair_indices, dtype=cp.int32)
-    
+
     num_pairs = len(pair_left)  # k * (k-1) pairs instead of k²
 
     # Allocate output for off-diagonal distances only
     d_other_offdiag = cp.zeros(num_pairs, dtype=np.float32)
-    
+
     # Choose optimal block size
     props = cp.cuda.runtime.getDeviceProperties(0)
     max_smem = int(props.get("sharedMemPerBlock", 48 * 1024))
 
     chosen_threads = None
-    shared_mem_size = 0 # TODO: think of a better way to do this
+    shared_mem_size = 0  # TODO: think of a better way to do this
     for tpb in (1024, 512, 256, 128, 64, 32):
         required = tpb * cp.dtype(cp.float32).itemsize
         if required <= max_smem:
@@ -94,20 +97,28 @@ def compute_d_other_gpu(
     compute_group_distances_kernel(
         grid,
         block,
-        (embedding, cat_offsets, cell_indices, pair_left, pair_right, d_other_offdiag, k, n_features),
+        (
+            embedding,
+            cat_offsets,
+            cell_indices,
+            pair_left,
+            pair_right,
+            d_other_offdiag,
+            k,
+            n_features,
+        ),
         shared_mem=shared_mem_size,
     )
-    
+
     # Build full k x k matrix
     pairwise_means = cp.zeros((k, k), dtype=np.float32)
-    
+
     # Fill the full matrix
     for i, idx in enumerate(pair_indices.get()):
         a, b = divmod(idx, k)
         pairwise_means[a, b] = d_other_offdiag[i]
         pairwise_means[b, a] = d_other_offdiag[i]
-    
-    
+
     return pairwise_means
 
 
@@ -120,10 +131,10 @@ def pairwise_edistance_gpu(
 ) -> pd.DataFrame:
     """
     GPU-accelerated pairwise edistance computation with decomposed components.
-    
+
     Returns d_itself, d_other arrays and final edistance DataFrame where:
     df[a,b] = 2*d_other[a,b] - d_itself[a] - d_itself[b]
-    
+
     Parameters
     ----------
     adata : AnnData
@@ -136,12 +147,12 @@ def pairwise_edistance_gpu(
         Specific groups to compute (if None, use all)
     copy : bool
         Whether to return a copy
-        
+
     Returns
     -------
     d_itself : cp.ndarray
         Within-group mean distances [k]
-    d_other : cp.ndarray  
+    d_other : cp.ndarray
         Between-group mean distances [k, k]
     df : pd.DataFrame
         Final edistance matrix
@@ -161,12 +172,14 @@ def pairwise_edistance_gpu(
     # 3. Compute decomposed components
     # d_itself = compute_d_itself_gpu(embedding, cat_offsets, cell_indices, k)
     pairwise_means = compute_d_other_gpu(embedding, cat_offsets, cell_indices, k)
-    
+
     # 4. Compute final edistance: df[a,b] = 2*d_other[a,b] - d_itself[a] - d_itself[b]
     edistance_matrix = cp.zeros((k, k), dtype=np.float32)
     for a in range(k):
-        for b in range(a+1, k):
-            edistance_matrix[a, b] = 2 * pairwise_means[a, b] - pairwise_means[a, a] - pairwise_means[b, b]
+        for b in range(a + 1, k):
+            edistance_matrix[a, b] = (
+                2 * pairwise_means[a, b] - pairwise_means[a, a] - pairwise_means[b, b]
+            )
             edistance_matrix[b, a] = edistance_matrix[a, b]
 
     # 5. Create output DataFrame
@@ -178,10 +191,9 @@ def pairwise_edistance_gpu(
     df.columns.name = groupby
     df.name = "pairwise edistance"
 
-
     # Store in adata
     adata.uns[f"{groupby}_pairwise_edistance"] = {
         "distances": df,
     }
-    
+
     return df
diff --git a/src/rapids_singlecell/pertpy_gpu/kernels/edistance_kernels.cu b/src/rapids_singlecell/pertpy_gpu/kernels/edistance_kernels.cu
index 668915ea..9aa2f9f0 100644
--- a/src/rapids_singlecell/pertpy_gpu/kernels/edistance_kernels.cu
+++ b/src/rapids_singlecell/pertpy_gpu/kernels/edistance_kernels.cu
@@ -60,7 +60,7 @@ __global__ void compute_group_distances(
     // Reduce across threads using shared memory
     shared_sums[thread_id] = local_sum;
     __syncthreads();
-    
+
     for (int stride = block_size / 2; stride > 0; stride >>= 1) {
         if (thread_id < stride) {
             shared_sums[thread_id] += shared_sums[thread_id + stride];
@@ -74,4 +74,4 @@ __global__ void compute_group_distances(
     }
 }
 
-} // extern "C"
\ No newline at end of file
+} // extern "C"

From 9323bfeadca80e7dd49d011229d1bddcd5209e2b Mon Sep 17 00:00:00 2001
From: selmanozleyen 
Date: Wed, 10 Sep 2025 11:32:37 +0000
Subject: [PATCH 12/23] move to float64 for precision.

---
 .../pertpy_gpu/_distances_standalone.py       |  8 +-
 .../pertpy_gpu/kernels/edistance_kernels.cu   | 16 ++--
 tmp_scripts/compute_edistance_standalone.py   | 25 ++++--
 tmp_scripts/generate_float64_reference.py     | 85 +++++++++++++++++++
 4 files changed, 114 insertions(+), 20 deletions(-)
 create mode 100644 tmp_scripts/generate_float64_reference.py

diff --git a/src/rapids_singlecell/pertpy_gpu/_distances_standalone.py b/src/rapids_singlecell/pertpy_gpu/_distances_standalone.py
index 9434330e..944fe383 100644
--- a/src/rapids_singlecell/pertpy_gpu/_distances_standalone.py
+++ b/src/rapids_singlecell/pertpy_gpu/_distances_standalone.py
@@ -76,7 +76,7 @@ def compute_d_other_gpu(
     num_pairs = len(pair_left)  # k * (k-1) pairs instead of k²
 
     # Allocate output for off-diagonal distances only
-    d_other_offdiag = cp.zeros(num_pairs, dtype=np.float32)
+    d_other_offdiag = cp.zeros(num_pairs, dtype=np.float64)
 
     # Choose optimal block size
     props = cp.cuda.runtime.getDeviceProperties(0)
@@ -85,7 +85,7 @@ def compute_d_other_gpu(
     chosen_threads = None
     shared_mem_size = 0  # TODO: think of a better way to do this
     for tpb in (1024, 512, 256, 128, 64, 32):
-        required = tpb * cp.dtype(cp.float32).itemsize
+        required = tpb * cp.dtype(cp.float64).itemsize
         if required <= max_smem:
             chosen_threads = tpb
             shared_mem_size = required
@@ -111,7 +111,7 @@ def compute_d_other_gpu(
     )
 
     # Build full k x k matrix
-    pairwise_means = cp.zeros((k, k), dtype=np.float32)
+    pairwise_means = cp.zeros((k, k), dtype=np.float64)
 
     # Fill the full matrix
     for i, idx in enumerate(pair_indices.get()):
@@ -160,7 +160,7 @@ def pairwise_edistance_gpu(
     # 1. Prepare data (same as original)
     _assert_categorical_obs(adata, key=groupby)
 
-    embedding = cp.array(adata.obsm[obsm_key]).astype(np.float32)
+    embedding = cp.array(adata.obsm[obsm_key]).astype(np.float64)
     original_groups = adata.obs[groupby]
     group_map = {v: i for i, v in enumerate(original_groups.cat.categories.values)}
     group_labels = cp.array([group_map[c] for c in original_groups], dtype=cp.int32)
diff --git a/src/rapids_singlecell/pertpy_gpu/kernels/edistance_kernels.cu b/src/rapids_singlecell/pertpy_gpu/kernels/edistance_kernels.cu
index 9aa2f9f0..48754dd9 100644
--- a/src/rapids_singlecell/pertpy_gpu/kernels/edistance_kernels.cu
+++ b/src/rapids_singlecell/pertpy_gpu/kernels/edistance_kernels.cu
@@ -9,22 +9,22 @@ extern "C" {
  * Each block processes one group pair, threads collaborate within the block
  */
 __global__ void compute_group_distances(
-    const float* __restrict__ embedding,
+    const double* __restrict__ embedding,
     const int* __restrict__ cat_offsets,
     const int* __restrict__ cell_indices,
     const int* __restrict__ pair_left,
     const int* __restrict__ pair_right,
-    float* __restrict__ d_other,
+    double* __restrict__ d_other,
     int k,
     int n_features)
 {
-    extern __shared__ float shared_sums[];
+    extern __shared__ double shared_sums[];
 
     const int thread_id = threadIdx.x;
     const int block_id = blockIdx.x;
     const int block_size = blockDim.x;
 
-    float local_sum = 0.0f;
+    double local_sum = 0.0;
 
     const int a = pair_left[block_id];
     const int b = pair_right[block_id];
@@ -46,14 +46,14 @@ __global__ void compute_group_distances(
         for (int jb = start_b; jb < end_b; ++jb) {
             const int idx_j = cell_indices[jb];
 
-            float dist_sq = 0.0f;
+            double dist_sq = 0.0;
             #pragma unroll
             for (int feat = 0; feat < n_features; ++feat) {
-                float diff = embedding[idx_i * n_features + feat] -
+                double diff = embedding[idx_i * n_features + feat] -
                             embedding[idx_j * n_features + feat];
                 dist_sq += diff * diff;
             }
-            local_sum += sqrtf(dist_sq);
+            local_sum += sqrt(dist_sq);
         }
     }
 
@@ -70,7 +70,7 @@ __global__ void compute_group_distances(
 
     if (thread_id == 0) {
         // Store mean between-group distance
-        d_other[block_id] = shared_sums[0] / (float)(n_a * n_b);
+        d_other[block_id] = shared_sums[0] / (double)(n_a * n_b);
     }
 }
 
diff --git a/tmp_scripts/compute_edistance_standalone.py b/tmp_scripts/compute_edistance_standalone.py
index ffe49760..8049c29a 100644
--- a/tmp_scripts/compute_edistance_standalone.py
+++ b/tmp_scripts/compute_edistance_standalone.py
@@ -28,7 +28,7 @@
     save_dir = os.path.join(
         os.path.expanduser("~"), "data", "adamson_2016_upr_epistasis_pca.h5ad"
     )
-    save_dir_df = os.path.join(os.path.expanduser("~"), "data", "df_cpu.csv")
+    save_dir_df = os.path.join(os.path.expanduser("~"), "data", "df_cpu_float64.csv")
     adata = ad.read_h5ad(save_dir)
     rsc.get.anndata_to_GPU(adata, convert_all=True)
     start_time = time.time()
@@ -43,17 +43,26 @@
     # print(f"Time taken: {end_time - start_time} seconds")
     df = pd.read_csv(save_dir_df, index_col=0)
 
+    is_not_close = []
     groups = adata.obs[obs_key].unique()
-    for group_x in groups:
-        for group_y in groups:
+    k = len(groups)
+    atol = 1e-8
+    for idx1 in range(k):
+        for idx2 in range(idx1 + 1, k):
+            group_x = groups[idx1]
+            group_y = groups[idx2]
             if group_x == group_y:
                 assert df_gpu.loc[group_x, group_y] == 0
             else:
-                assert np.isclose(
-                    df_gpu.loc[group_x, group_y], df.loc[group_x, group_y], atol=1e-3
-                ), (
-                    f"Group df_gpu: {df_gpu.loc[group_x, group_y]}, Group df: {df.loc[group_x, group_y]}"
-                )
+                if not np.isclose(
+                    df_gpu.loc[group_x, group_y], df.loc[group_x, group_y], atol=atol
+                ):
+                    is_not_close.append(
+                        ((group_x, group_y), df_gpu.loc[group_x, group_y], df.loc[group_x, group_y], np.abs(df_gpu.loc[group_x, group_y] - df.loc[group_x, group_y]))
+                    )
+                    print(f"Group df_gpu: {df_gpu.loc[group_x, group_y]}, Group df: {df.loc[group_x, group_y]}, idx: ({idx1}, {idx2})")
+    
+    print("Out of", int(k * (k - 1) / 2), "pairs,", len(is_not_close), "pairs are not close with atol=", atol)
     # print(df.equals(df_gpu))
     # print(df)
     # print(df_gpu)
diff --git a/tmp_scripts/generate_float64_reference.py b/tmp_scripts/generate_float64_reference.py
new file mode 100644
index 00000000..09dfc605
--- /dev/null
+++ b/tmp_scripts/generate_float64_reference.py
@@ -0,0 +1,85 @@
+from __future__ import annotations
+
+import os
+import numpy as np
+import pandas as pd
+import anndata as ad
+from sklearn.metrics import pairwise_distances as sklearn_pairwise_distances
+
+def compute_edistance_sklearn(X, Y, dtype=np.float64):
+    """Compute edistance using sklearn's pairwise_distances with specified precision"""
+    X = np.array(X, dtype=dtype)
+    Y = np.array(Y, dtype=dtype)
+    
+    # Compute pairwise distances using sklearn
+    sigma_X = sklearn_pairwise_distances(X, X, metric='euclidean').mean()
+    sigma_Y = sklearn_pairwise_distances(Y, Y, metric='euclidean').mean()
+    delta = sklearn_pairwise_distances(X, Y, metric='euclidean').mean()
+    
+    return 2 * delta - sigma_X - sigma_Y
+
+def compute_edistance_pairwise_sklearn(adata, groupby, obsm_key="X_pca", dtype=np.float64):
+    """Compute pairwise edistance matrix using sklearn with specified precision"""
+    # Get data and convert to CPU numpy
+    embedding = np.array(adata.obsm[obsm_key], dtype=dtype)
+    
+    groups = adata.obs[groupby].cat.categories
+    k = len(groups)
+    
+    print(f"Computing edistance for {k} groups with dtype {dtype}...")
+    
+    # Build edistance matrix
+    edistance_matrix = np.zeros((k, k), dtype=dtype)
+    
+    for i, group_a in enumerate(groups):
+        mask_a = adata.obs[groupby] == group_a
+        X = embedding[mask_a]
+        
+        for j, group_b in enumerate(groups):
+            if i == j:
+                edistance_matrix[i, j] = 0.0
+            elif i < j:  # Only compute upper triangle
+                mask_b = adata.obs[groupby] == group_b
+                Y = embedding[mask_b]
+                
+                edist = compute_edistance_sklearn(X, Y, dtype=dtype)
+                edistance_matrix[i, j] = edist
+                edistance_matrix[j, i] = edist  # Symmetric
+                
+        if (i + 1) % 5 == 0:
+            print(f"  Processed {i + 1}/{k} groups")
+    
+    return pd.DataFrame(edistance_matrix, index=groups, columns=groups)
+
+if __name__ == "__main__":
+    obs_key = "perturbation"
+    
+    # Load the data
+    save_dir = os.path.join(os.path.expanduser("~"), "data")
+    adata_path = os.path.join(save_dir, "adamson_2016_upr_epistasis_pca.h5ad")
+    
+    print(f"Loading data from {adata_path}...")
+    adata = ad.read_h5ad(adata_path)
+    
+    print(f"Data shape: {adata.shape}")
+    print(f"Groups: {len(adata.obs[obs_key].cat.categories)}")
+    
+    # Generate the float64 reference CSV using sklearn
+    print("\nGenerating float64 reference using sklearn...")
+    df_reference = compute_edistance_pairwise_sklearn(adata, obs_key, obsm_key="X_pca", dtype=np.float64)
+    
+    # Save the float64 reference
+    output_path = os.path.join(save_dir, "df_cpu_float64.csv")
+    df_reference.to_csv(output_path)
+    print(f"Saved float64 reference to: {output_path}")
+    
+    # Show a sample of values
+    print("\nSample values:")
+    groups = df_reference.index[:3]
+    for i, group_a in enumerate(groups):
+        for j, group_b in enumerate(groups):
+            if i < j:
+                val = df_reference.loc[group_a, group_b]
+                print(f"  {group_a} vs {group_b}: {val:.10f}")
+    
+    print("Float64 reference generation complete!")

From 2150b0a8913fe38ef2e680981edd11e49df49a10 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
 <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 10 Sep 2025 11:34:23 +0000
Subject: [PATCH 13/23] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 tmp_scripts/compute_edistance_standalone.py | 24 +++++++--
 tmp_scripts/generate_float64_reference.py   | 54 ++++++++++++---------
 2 files changed, 51 insertions(+), 27 deletions(-)

diff --git a/tmp_scripts/compute_edistance_standalone.py b/tmp_scripts/compute_edistance_standalone.py
index 8049c29a..408214a3 100644
--- a/tmp_scripts/compute_edistance_standalone.py
+++ b/tmp_scripts/compute_edistance_standalone.py
@@ -58,11 +58,27 @@
                     df_gpu.loc[group_x, group_y], df.loc[group_x, group_y], atol=atol
                 ):
                     is_not_close.append(
-                        ((group_x, group_y), df_gpu.loc[group_x, group_y], df.loc[group_x, group_y], np.abs(df_gpu.loc[group_x, group_y] - df.loc[group_x, group_y]))
+                        (
+                            (group_x, group_y),
+                            df_gpu.loc[group_x, group_y],
+                            df.loc[group_x, group_y],
+                            np.abs(
+                                df_gpu.loc[group_x, group_y] - df.loc[group_x, group_y]
+                            ),
+                        )
                     )
-                    print(f"Group df_gpu: {df_gpu.loc[group_x, group_y]}, Group df: {df.loc[group_x, group_y]}, idx: ({idx1}, {idx2})")
-    
-    print("Out of", int(k * (k - 1) / 2), "pairs,", len(is_not_close), "pairs are not close with atol=", atol)
+                    print(
+                        f"Group df_gpu: {df_gpu.loc[group_x, group_y]}, Group df: {df.loc[group_x, group_y]}, idx: ({idx1}, {idx2})"
+                    )
+
+    print(
+        "Out of",
+        int(k * (k - 1) / 2),
+        "pairs,",
+        len(is_not_close),
+        "pairs are not close with atol=",
+        atol,
+    )
     # print(df.equals(df_gpu))
     # print(df)
     # print(df_gpu)
diff --git a/tmp_scripts/generate_float64_reference.py b/tmp_scripts/generate_float64_reference.py
index 09dfc605..e3751137 100644
--- a/tmp_scripts/generate_float64_reference.py
+++ b/tmp_scripts/generate_float64_reference.py
@@ -1,78 +1,86 @@
 from __future__ import annotations
 
 import os
+
+import anndata as ad
 import numpy as np
 import pandas as pd
-import anndata as ad
 from sklearn.metrics import pairwise_distances as sklearn_pairwise_distances
 
+
 def compute_edistance_sklearn(X, Y, dtype=np.float64):
     """Compute edistance using sklearn's pairwise_distances with specified precision"""
     X = np.array(X, dtype=dtype)
     Y = np.array(Y, dtype=dtype)
-    
+
     # Compute pairwise distances using sklearn
-    sigma_X = sklearn_pairwise_distances(X, X, metric='euclidean').mean()
-    sigma_Y = sklearn_pairwise_distances(Y, Y, metric='euclidean').mean()
-    delta = sklearn_pairwise_distances(X, Y, metric='euclidean').mean()
-    
+    sigma_X = sklearn_pairwise_distances(X, X, metric="euclidean").mean()
+    sigma_Y = sklearn_pairwise_distances(Y, Y, metric="euclidean").mean()
+    delta = sklearn_pairwise_distances(X, Y, metric="euclidean").mean()
+
     return 2 * delta - sigma_X - sigma_Y
 
-def compute_edistance_pairwise_sklearn(adata, groupby, obsm_key="X_pca", dtype=np.float64):
+
+def compute_edistance_pairwise_sklearn(
+    adata, groupby, obsm_key="X_pca", dtype=np.float64
+):
     """Compute pairwise edistance matrix using sklearn with specified precision"""
     # Get data and convert to CPU numpy
     embedding = np.array(adata.obsm[obsm_key], dtype=dtype)
-    
+
     groups = adata.obs[groupby].cat.categories
     k = len(groups)
-    
+
     print(f"Computing edistance for {k} groups with dtype {dtype}...")
-    
+
     # Build edistance matrix
     edistance_matrix = np.zeros((k, k), dtype=dtype)
-    
+
     for i, group_a in enumerate(groups):
         mask_a = adata.obs[groupby] == group_a
         X = embedding[mask_a]
-        
+
         for j, group_b in enumerate(groups):
             if i == j:
                 edistance_matrix[i, j] = 0.0
             elif i < j:  # Only compute upper triangle
                 mask_b = adata.obs[groupby] == group_b
                 Y = embedding[mask_b]
-                
+
                 edist = compute_edistance_sklearn(X, Y, dtype=dtype)
                 edistance_matrix[i, j] = edist
                 edistance_matrix[j, i] = edist  # Symmetric
-                
+
         if (i + 1) % 5 == 0:
             print(f"  Processed {i + 1}/{k} groups")
-    
+
     return pd.DataFrame(edistance_matrix, index=groups, columns=groups)
 
+
 if __name__ == "__main__":
     obs_key = "perturbation"
-    
+
     # Load the data
     save_dir = os.path.join(os.path.expanduser("~"), "data")
     adata_path = os.path.join(save_dir, "adamson_2016_upr_epistasis_pca.h5ad")
-    
+
     print(f"Loading data from {adata_path}...")
     adata = ad.read_h5ad(adata_path)
-    
+
     print(f"Data shape: {adata.shape}")
     print(f"Groups: {len(adata.obs[obs_key].cat.categories)}")
-    
+
     # Generate the float64 reference CSV using sklearn
     print("\nGenerating float64 reference using sklearn...")
-    df_reference = compute_edistance_pairwise_sklearn(adata, obs_key, obsm_key="X_pca", dtype=np.float64)
-    
+    df_reference = compute_edistance_pairwise_sklearn(
+        adata, obs_key, obsm_key="X_pca", dtype=np.float64
+    )
+
     # Save the float64 reference
     output_path = os.path.join(save_dir, "df_cpu_float64.csv")
     df_reference.to_csv(output_path)
     print(f"Saved float64 reference to: {output_path}")
-    
+
     # Show a sample of values
     print("\nSample values:")
     groups = df_reference.index[:3]
@@ -81,5 +89,5 @@ def compute_edistance_pairwise_sklearn(adata, groupby, obsm_key="X_pca", dtype=n
             if i < j:
                 val = df_reference.loc[group_a, group_b]
                 print(f"  {group_a} vs {group_b}: {val:.10f}")
-    
+
     print("Float64 reference generation complete!")

From 7faef570f7990c1361b3498b5bbbd738313018d2 Mon Sep 17 00:00:00 2001
From: selmanozleyen 
Date: Wed, 10 Sep 2025 13:18:51 +0000
Subject: [PATCH 14/23] bootstrapping

---
 .../pertpy_gpu/_distances.py                  |   2 +-
 .../pertpy_gpu/_distances_standalone.py       | 286 +++++++++++++++++-
 tmp_scripts/compute_edistance.py              |   6 +-
 tmp_scripts/compute_edistance_cpu.py          |  23 +-
 tmp_scripts/compute_edistance_standalone.py   |  50 +--
 5 files changed, 332 insertions(+), 35 deletions(-)

diff --git a/src/rapids_singlecell/pertpy_gpu/_distances.py b/src/rapids_singlecell/pertpy_gpu/_distances.py
index b1b6375b..04c1e5b1 100644
--- a/src/rapids_singlecell/pertpy_gpu/_distances.py
+++ b/src/rapids_singlecell/pertpy_gpu/_distances.py
@@ -523,7 +523,7 @@ def _bootstrap_mode_precomputed(
             distance = self.metric_fct.from_precomputed(
                 bootstrap_sub_pwd, bootstrap_sub_idx, **kwargs
             )
-            distances.append(distance.get())
+            distances.append(distance)
 
         mean = np.mean(distances)
         variance = np.var(distances)
diff --git a/src/rapids_singlecell/pertpy_gpu/_distances_standalone.py b/src/rapids_singlecell/pertpy_gpu/_distances_standalone.py
index 944fe383..90abedc0 100644
--- a/src/rapids_singlecell/pertpy_gpu/_distances_standalone.py
+++ b/src/rapids_singlecell/pertpy_gpu/_distances_standalone.py
@@ -35,7 +35,7 @@ def _load_edistance_kernels():
 compute_group_distances_kernel = _load_edistance_kernels()
 
 
-def compute_d_other_gpu(
+def compute_pairwise_means_gpu(
     embedding: cp.ndarray, cat_offsets: cp.ndarray, cell_indices: cp.ndarray, k: int
 ) -> cp.ndarray:
     """
@@ -122,12 +122,177 @@ def compute_d_other_gpu(
     return pairwise_means
 
 
+def generate_bootstrap_indices(
+    cat_offsets: cp.ndarray,
+    k: int,
+    n_bootstrap: int = 100,
+    random_state: int = 0,
+) -> list[list[cp.ndarray]]:
+    """
+    Generate bootstrap indices for all groups and all bootstrap iterations.
+    This matches the CPU implementation's random sampling logic for reproducibility.
+    
+    Parameters
+    ----------
+    cat_offsets : cp.ndarray
+        Group start/end indices
+    k : int
+        Number of groups
+    n_bootstrap : int
+        Number of bootstrap samples
+    random_state : int
+        Random seed for reproducibility
+        
+    Returns
+    -------
+    bootstrap_indices : list[list[cp.ndarray]]
+        For each bootstrap iteration, list of indices arrays for each group
+        Shape: [n_bootstrap][k] where each element is cp.ndarray of group_size
+    """
+    import numpy as np
+    
+    # Use same RNG logic as CPU code
+    rng = np.random.default_rng(random_state)
+    
+    # Convert to numpy for CPU-based random generation
+    cat_offsets_np = cat_offsets.get()
+    
+    bootstrap_indices = []
+    
+    for bootstrap_iter in range(n_bootstrap):
+        group_indices = []
+        
+        for group_idx in range(k):
+            start_idx = cat_offsets_np[group_idx]
+            end_idx = cat_offsets_np[group_idx + 1]
+            group_size = end_idx - start_idx
+            
+            if group_size > 0:
+                # Generate bootstrap indices using same logic as CPU code
+                # rng.choice(a=X.shape[0], size=X.shape[0], replace=True)
+                bootstrap_group_indices = rng.choice(
+                    group_size, size=group_size, replace=True
+                )
+                # Convert to CuPy array
+                group_indices.append(cp.array(bootstrap_group_indices, dtype=cp.int32))
+            else:
+                # Empty group
+                group_indices.append(cp.array([], dtype=cp.int32))
+        
+        bootstrap_indices.append(group_indices)
+    
+    return bootstrap_indices
+
+
+def _bootstrap_sample_cells_from_indices(
+    *,
+    cat_offsets: cp.ndarray,
+    cell_indices: cp.ndarray,
+    k: int,
+    bootstrap_group_indices: list[cp.ndarray],
+) -> tuple[cp.ndarray, cp.ndarray]:
+    """
+    Bootstrap sample cells using pre-generated indices.
+    
+    Parameters
+    ----------
+    cat_offsets : cp.ndarray
+        Group start/end indices
+    cell_indices : cp.ndarray
+        Sorted cell indices by group
+    k : int
+        Number of groups
+    bootstrap_group_indices : list[cp.ndarray]
+        Pre-generated bootstrap indices for each group
+        
+    Returns
+    -------
+    new_cat_offsets, new_cell_indices : tuple[cp.ndarray, cp.ndarray]
+        New category structure with bootstrapped cells
+    """
+    new_cell_indices = []
+    new_cat_offsets = cp.zeros(k + 1, dtype=cp.int32)
+    
+    for group_idx in range(k):
+        start_idx = cat_offsets[group_idx]
+        end_idx = cat_offsets[group_idx + 1]
+        group_size = end_idx - start_idx
+        
+        if group_size > 0:
+            # Get original cell indices for this group
+            group_cells = cell_indices[start_idx:end_idx]
+            
+            # Use pre-generated bootstrap indices
+            bootstrap_indices = bootstrap_group_indices[group_idx]
+            bootstrap_cells = group_cells[bootstrap_indices]
+            
+            new_cell_indices.extend(bootstrap_cells.get().tolist())
+        
+        new_cat_offsets[group_idx + 1] = len(new_cell_indices)
+    
+    return new_cat_offsets, cp.array(new_cell_indices, dtype=cp.int32)
+
+
+def compute_pairwise_means_gpu_bootstrap(
+    embedding: cp.ndarray,
+    *,
+    cat_offsets: cp.ndarray,
+    cell_indices: cp.ndarray,
+    k: int,
+    n_bootstrap: int = 100,
+    random_state: int = 0,
+) -> tuple[cp.ndarray, cp.ndarray]:
+    """
+    Compute bootstrap statistics for between-group distances.
+    Uses CPU-compatible random generation for reproducibility.
+    
+    Returns:
+        means: [k, k] matrix of bootstrap means
+        variances: [k, k] matrix of bootstrap variances
+    """
+    # Generate all bootstrap indices upfront using CPU-compatible logic
+    bootstrap_indices = generate_bootstrap_indices(
+        cat_offsets, k, n_bootstrap, random_state
+    )
+    
+    bootstrap_results = []
+    
+    for bootstrap_iter in range(n_bootstrap):
+        # Use pre-generated indices for this bootstrap iteration
+        boot_cat_offsets, boot_cell_indices = _bootstrap_sample_cells_from_indices(
+            cat_offsets=cat_offsets,
+            cell_indices=cell_indices,
+            k=k,
+            bootstrap_group_indices=bootstrap_indices[bootstrap_iter],
+        )
+        
+        # Compute distances with bootstrapped samples
+        pairwise_means = compute_pairwise_means_gpu(
+            embedding=embedding,
+            cat_offsets=boot_cat_offsets,
+            cell_indices=boot_cell_indices,
+            k=k,
+        )
+        bootstrap_results.append(pairwise_means.get())
+    
+    # Compute statistics across bootstrap samples
+    bootstrap_stack = cp.array(bootstrap_results)  # [n_bootstrap, k, k]
+    means = cp.mean(bootstrap_stack, axis=0)
+    variances = cp.var(bootstrap_stack, axis=0)
+    
+    return means, variances
+
+
 def pairwise_edistance_gpu(
     adata: AnnData,
     groupby: str,
     *,
     obsm_key: str = "X_pca",
     groups: list[str] | None = None,
+    inplace: bool = False,
+    bootstrap: bool = False,
+    n_bootstrap: int = 100,
+    random_state: int = 0,
 ) -> pd.DataFrame:
     """
     GPU-accelerated pairwise edistance computation with decomposed components.
@@ -169,9 +334,117 @@ def pairwise_edistance_gpu(
     k = len(group_map)
     cat_offsets, cell_indices = _create_category_index_mapping(group_labels, k)
 
+    groups_list = (
+        list(original_groups.cat.categories.values) if groups is None else groups
+    )
+    if not bootstrap:
+        df = compute_pairwise_means_gpu_edistance(
+            embedding=embedding,
+            cat_offsets=cat_offsets,
+            cell_indices=cell_indices,
+            k=k,
+            groups_list=groups_list,
+            groupby=groupby,
+        )
+        if inplace:
+            adata.uns[f"{groupby}_pairwise_edistance"] = {
+                "distances": df,
+            }
+        return df
+
+    else:
+        df, df_var = compute_pairwise_means_gpu_edistance_bootstrap(
+            embedding=embedding,
+            cat_offsets=cat_offsets,
+            cell_indices=cell_indices,
+            k=k,
+            groups_list=groups_list,
+            groupby=groupby,
+            n_bootstrap=n_bootstrap,
+            random_state=random_state,
+        )
+
+        if inplace:
+            adata.uns[f"{groupby}_pairwise_edistance"] = {
+                "distances": df,
+                "distances_var": df_var,
+            }
+        return df, df_var
+
+
+def compute_pairwise_means_gpu_edistance_bootstrap(
+    embedding: cp.ndarray,
+    *,
+    cat_offsets: cp.ndarray,
+    cell_indices: cp.ndarray,
+    k: int,
+    groups_list: list[str],
+    groupby: str,
+    n_bootstrap: int = 100,
+    random_state: int = 0,
+) -> tuple[pd.DataFrame, pd.DataFrame]:
+    # Bootstrap computation
+    pairwise_means_boot, pairwise_vars_boot = compute_pairwise_means_gpu_bootstrap(
+        embedding=embedding,
+        cat_offsets=cat_offsets,
+        cell_indices=cell_indices,
+        k=k,
+        n_bootstrap=n_bootstrap,
+        random_state=random_state,
+    )
+
+    # 4. Compute final edistance for means and variances
+    edistance_means = cp.zeros((k, k), dtype=np.float32)
+    edistance_vars = cp.zeros((k, k), dtype=np.float32)
+
+    for a in range(k):
+        for b in range(a + 1, k):
+            # Bootstrap mean edistance
+            edistance_means[a, b] = (
+                2 * pairwise_means_boot[a, b]
+                - pairwise_means_boot[a, a]
+                - pairwise_means_boot[b, b]
+            )
+            edistance_means[b, a] = edistance_means[a, b]
+
+            # Bootstrap variance edistance (using delta method approximation)
+            # Var(2*X - Y - Z) = 4*Var(X) + Var(Y) + Var(Z) (assuming independence)
+            edistance_vars[a, b] = (
+                4 * pairwise_vars_boot[a, b]
+                + pairwise_vars_boot[a, a]
+                + pairwise_vars_boot[b, b]
+            )
+            edistance_vars[b, a] = edistance_vars[a, b]
+
+    # 5. Create output DataFrames
+
+    df_mean = pd.DataFrame(
+        edistance_means.get(), index=groups_list, columns=groups_list
+    )
+    df_mean.index.name = groupby
+    df_mean.columns.name = groupby
+    df_mean.name = "pairwise edistance"
+
+    df_var = pd.DataFrame(edistance_vars.get(), index=groups_list, columns=groups_list)
+    df_var.index.name = groupby
+    df_var.columns.name = groupby
+    df_var.name = "pairwise edistance variance"
+
+    return df_mean, df_var
+
+
+def compute_pairwise_means_gpu_edistance(
+    embedding: cp.ndarray,
+    *,
+    cat_offsets: cp.ndarray,
+    cell_indices: cp.ndarray,
+    k: int,
+    groups_list: list[str],
+    groupby: str,
+) -> pd.DataFrame:
     # 3. Compute decomposed components
     # d_itself = compute_d_itself_gpu(embedding, cat_offsets, cell_indices, k)
-    pairwise_means = compute_d_other_gpu(embedding, cat_offsets, cell_indices, k)
+    pairwise_means = compute_pairwise_means_gpu(embedding, cat_offsets, cell_indices, k)
 
     # 4. Compute final edistance: df[a,b] = 2*d_other[a,b] - d_itself[a] - d_itself[b]
     edistance_matrix = cp.zeros((k, k), dtype=np.float32)
@@ -183,17 +456,10 @@ def pairwise_edistance_gpu(
             edistance_matrix[b, a] = edistance_matrix[a, b]
 
     # 5. Create output DataFrame
-    groups_list = (
-        list(original_groups.cat.categories.values) if groups is None else groups
-    )
+
     df = pd.DataFrame(edistance_matrix.get(), index=groups_list, columns=groups_list)
     df.index.name = groupby
     df.columns.name = groupby
     df.name = "pairwise edistance"
 
-    # Store in adata
-    adata.uns[f"{groupby}_pairwise_edistance"] = {
-        "distances": df,
-    }
-
     return df
diff --git a/tmp_scripts/compute_edistance.py b/tmp_scripts/compute_edistance.py
index b795b5b6..e325e7a8 100644
--- a/tmp_scripts/compute_edistance.py
+++ b/tmp_scripts/compute_edistance.py
@@ -2,6 +2,7 @@
 
 import os
 import time
+from pathlib import Path
 
 import anndata as ad
 import cupy as cp
@@ -24,9 +25,10 @@
 
     # homedir/data/adamson_2016_upr_epistasis
     save_dir = os.path.join(
-        os.path.expanduser("~"), "data", "adamson_2016_upr_epistasis_pca.h5ad"
+        os.path.expanduser("~"),
+        "data",
     )
-    adata = ad.read_h5ad(save_dir)
+    adata = ad.read_h5ad(Path(save_dir) / "adamson_2016_upr_epistasis_pca.h5ad")
     rsc.get.anndata_to_GPU(adata, convert_all=True)
     dist = Distance(obsm_key="X_pca", metric="edistance")
     start_time = time.time()
diff --git a/tmp_scripts/compute_edistance_cpu.py b/tmp_scripts/compute_edistance_cpu.py
index dc0ad05d..694224d1 100644
--- a/tmp_scripts/compute_edistance_cpu.py
+++ b/tmp_scripts/compute_edistance_cpu.py
@@ -2,20 +2,37 @@
 
 import os
 import time
+from argparse import ArgumentParser
+from pathlib import Path
 
 import anndata as ad
 from pertpy.tools import Distance
 
 if __name__ == "__main__":
+    parser = ArgumentParser()
+    parser.add_argument("--bootstrap", action="store_true")
+    args = parser.parse_args()
     obs_key = "perturbation"
+    bootstrap = args.bootstrap
 
     # homedir/data/adamson_2016_upr_epistasis
     save_dir = os.path.join(
-        os.path.expanduser("~"), "data", "adamson_2016_upr_epistasis_pca.h5ad"
+        os.path.expanduser("~"),
+        "data",
     )
-    adata = ad.read_h5ad(save_dir)
+    adata = ad.read_h5ad(Path(save_dir) / "adamson_2016_upr_epistasis_pca.h5ad")
     dist = Distance(obsm_key="X_pca", metric="edistance")
     start_time = time.time()
-    df = dist.pairwise(adata, groupby=obs_key)
+    if bootstrap:
+        df, df_var = dist.pairwise(
+            adata, groupby=obs_key, bootstrap=True, n_bootstrap=100
+        )
+    else:
+        df = dist.pairwise(adata, groupby=obs_key)
     end_time = time.time()
     print(f"Time taken: {end_time - start_time} seconds")
+    if bootstrap:
+        df_var.to_csv(Path(save_dir) / "df_cpu_bootstrap_var.csv")
+        df.to_csv(Path(save_dir) / "df_cpu_bootstrap.csv")
+    else:
+        df.to_csv(Path(save_dir) / "df_cpu.csv")
diff --git a/tmp_scripts/compute_edistance_standalone.py b/tmp_scripts/compute_edistance_standalone.py
index 408214a3..d618c23a 100644
--- a/tmp_scripts/compute_edistance_standalone.py
+++ b/tmp_scripts/compute_edistance_standalone.py
@@ -9,9 +9,11 @@
 import pandas as pd
 import rmm
 from rmm.allocators.cupy import rmm_cupy_allocator
-
+from argparse import ArgumentParser
+from pathlib import Path
 import rapids_singlecell as rsc
 from rapids_singlecell.pertpy_gpu._distances_standalone import pairwise_edistance_gpu
+from pathlib import Path
 
 rmm.reinitialize(
     managed_memory=False,  # Allows oversubscription
@@ -20,33 +22,43 @@
 )
 cp.cuda.set_allocator(rmm_cupy_allocator)
 
-
 if __name__ == "__main__":
+    parser = ArgumentParser()
+    parser.add_argument("--bootstrap", action="store_true")
+    args = parser.parse_args()
+    bootstrap = args.bootstrap
     obs_key = "perturbation"
+    home_dir = Path(os.path.expanduser("~")) / "data"
 
     # homedir/data/adamson_2016_upr_epistasis
-    save_dir = os.path.join(
-        os.path.expanduser("~"), "data", "adamson_2016_upr_epistasis_pca.h5ad"
-    )
-    save_dir_df = os.path.join(os.path.expanduser("~"), "data", "df_cpu_float64.csv")
+    save_dir = home_dir / "adamson_2016_upr_epistasis_pca.h5ad"
     adata = ad.read_h5ad(save_dir)
     rsc.get.anndata_to_GPU(adata, convert_all=True)
+
+    df_expected = None
+    if bootstrap:
+        df_cpu_bootstrap_var = pd.read_csv(home_dir / "df_cpu_bootstrap_var.csv", index_col=0)
+        df_cpu_bootstrap = pd.read_csv(home_dir / "df_cpu_bootstrap.csv", index_col=0)
+        df_expected = df_cpu_bootstrap
+    else:
+        df_cpu_float64 = pd.read_csv(home_dir / "df_cpu_float64.csv", index_col=0)
+        df_expected = df_cpu_float64
+
+
     start_time = time.time()
-    df_gpu = pairwise_edistance_gpu(adata, groupby=obs_key, obsm_key="X_pca")
+    if not bootstrap:
+        df_gpu = pairwise_edistance_gpu(adata, groupby=obs_key, obsm_key="X_pca", bootstrap=bootstrap)
+    else:
+        df_gpu, df_gpu_var = pairwise_edistance_gpu(adata, groupby=obs_key, obsm_key="X_pca", bootstrap=bootstrap, n_bootstrap=100)
     end_time = time.time()
     print(f"Time taken: {end_time - start_time} seconds")
-    # print("CPU time")
-    # dist = Distance(obsm_key="X_pca", metric="edistance")
-    # start_time = time.time()
-    # df_cpu = dist.pairwise(adata, groupby=obs_key)
-    # end_time = time.time()
-    # print(f"Time taken: {end_time - start_time} seconds")
-    df = pd.read_csv(save_dir_df, index_col=0)
+
+  
 
     is_not_close = []
     groups = adata.obs[obs_key].unique()
     k = len(groups)
-    atol = 1e-8
+    atol = 1e-8 if not bootstrap else 1e-2
     for idx1 in range(k):
         for idx2 in range(idx1 + 1, k):
             group_x = groups[idx1]
@@ -55,20 +67,20 @@
                 assert df_gpu.loc[group_x, group_y] == 0
             else:
                 if not np.isclose(
-                    df_gpu.loc[group_x, group_y], df.loc[group_x, group_y], atol=atol
+                    df_gpu.loc[group_x, group_y], df_expected.loc[group_x, group_y], atol=atol
                 ):
                     is_not_close.append(
                         (
                             (group_x, group_y),
+                            df_expected.loc[group_x, group_y],
                             df_gpu.loc[group_x, group_y],
-                            df.loc[group_x, group_y],
                             np.abs(
-                                df_gpu.loc[group_x, group_y] - df.loc[group_x, group_y]
+                                df_expected.loc[group_x, group_y] - df_gpu.loc[group_x, group_y]
                             ),
                         )
                     )
                     print(
-                        f"Group df_gpu: {df_gpu.loc[group_x, group_y]}, Group df: {df.loc[group_x, group_y]}, idx: ({idx1}, {idx2})"
+                        f"Group df_gpu: {df_gpu.loc[group_x, group_y]}, Group df: {df_expected.loc[group_x, group_y]}, idx: ({idx1}, {idx2})"
                     )
 
     print(

From fb10db64b5fc6197d7d2f83211473ab6897f3dbd Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
 <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 10 Sep 2025 13:19:01 +0000
Subject: [PATCH 15/23] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../pertpy_gpu/_distances_standalone.py       | 48 +++++++++----------
 tmp_scripts/compute_edistance_standalone.py   | 32 ++++++++-----
 2 files changed, 45 insertions(+), 35 deletions(-)

diff --git a/src/rapids_singlecell/pertpy_gpu/_distances_standalone.py b/src/rapids_singlecell/pertpy_gpu/_distances_standalone.py
index 90abedc0..4e781b2d 100644
--- a/src/rapids_singlecell/pertpy_gpu/_distances_standalone.py
+++ b/src/rapids_singlecell/pertpy_gpu/_distances_standalone.py
@@ -131,7 +131,7 @@ def generate_bootstrap_indices(
     """
     Generate bootstrap indices for all groups and all bootstrap iterations.
     This matches the CPU implementation's random sampling logic for reproducibility.
-    
+
     Parameters
     ----------
     cat_offsets : cp.ndarray
@@ -142,7 +142,7 @@ def generate_bootstrap_indices(
         Number of bootstrap samples
     random_state : int
         Random seed for reproducibility
-        
+
     Returns
     -------
     bootstrap_indices : list[list[cp.ndarray]]
@@ -150,23 +150,23 @@ def generate_bootstrap_indices(
         Shape: [n_bootstrap][k] where each element is cp.ndarray of group_size
     """
     import numpy as np
-    
+
     # Use same RNG logic as CPU code
     rng = np.random.default_rng(random_state)
-    
+
     # Convert to numpy for CPU-based random generation
     cat_offsets_np = cat_offsets.get()
-    
+
     bootstrap_indices = []
-    
+
     for bootstrap_iter in range(n_bootstrap):
         group_indices = []
-        
+
         for group_idx in range(k):
             start_idx = cat_offsets_np[group_idx]
             end_idx = cat_offsets_np[group_idx + 1]
             group_size = end_idx - start_idx
-            
+
             if group_size > 0:
                 # Generate bootstrap indices using same logic as CPU code
                 # rng.choice(a=X.shape[0], size=X.shape[0], replace=True)
@@ -178,9 +178,9 @@ def generate_bootstrap_indices(
             else:
                 # Empty group
                 group_indices.append(cp.array([], dtype=cp.int32))
-        
+
         bootstrap_indices.append(group_indices)
-    
+
     return bootstrap_indices
 
 
@@ -193,7 +193,7 @@ def _bootstrap_sample_cells_from_indices(
 ) -> tuple[cp.ndarray, cp.ndarray]:
     """
     Bootstrap sample cells using pre-generated indices.
-    
+
     Parameters
     ----------
     cat_offsets : cp.ndarray
@@ -204,7 +204,7 @@ def _bootstrap_sample_cells_from_indices(
         Number of groups
     bootstrap_group_indices : list[cp.ndarray]
         Pre-generated bootstrap indices for each group
-        
+
     Returns
     -------
     new_cat_offsets, new_cell_indices : tuple[cp.ndarray, cp.ndarray]
@@ -212,24 +212,24 @@ def _bootstrap_sample_cells_from_indices(
     """
     new_cell_indices = []
     new_cat_offsets = cp.zeros(k + 1, dtype=cp.int32)
-    
+
     for group_idx in range(k):
         start_idx = cat_offsets[group_idx]
         end_idx = cat_offsets[group_idx + 1]
         group_size = end_idx - start_idx
-        
+
         if group_size > 0:
             # Get original cell indices for this group
             group_cells = cell_indices[start_idx:end_idx]
-            
+
             # Use pre-generated bootstrap indices
             bootstrap_indices = bootstrap_group_indices[group_idx]
             bootstrap_cells = group_cells[bootstrap_indices]
-            
+
             new_cell_indices.extend(bootstrap_cells.get().tolist())
-        
+
         new_cat_offsets[group_idx + 1] = len(new_cell_indices)
-    
+
     return new_cat_offsets, cp.array(new_cell_indices, dtype=cp.int32)
 
 
@@ -245,7 +245,7 @@ def compute_pairwise_means_gpu_bootstrap(
     """
     Compute bootstrap statistics for between-group distances.
     Uses CPU-compatible random generation for reproducibility.
-    
+
     Returns:
         means: [k, k] matrix of bootstrap means
         variances: [k, k] matrix of bootstrap variances
@@ -254,9 +254,9 @@ def compute_pairwise_means_gpu_bootstrap(
     bootstrap_indices = generate_bootstrap_indices(
         cat_offsets, k, n_bootstrap, random_state
     )
-    
+
     bootstrap_results = []
-    
+
     for bootstrap_iter in range(n_bootstrap):
         # Use pre-generated indices for this bootstrap iteration
         boot_cat_offsets, boot_cell_indices = _bootstrap_sample_cells_from_indices(
@@ -265,7 +265,7 @@ def compute_pairwise_means_gpu_bootstrap(
             k=k,
             bootstrap_group_indices=bootstrap_indices[bootstrap_iter],
         )
-        
+
         # Compute distances with bootstrapped samples
         pairwise_means = compute_pairwise_means_gpu(
             embedding=embedding,
@@ -274,12 +274,12 @@ def compute_pairwise_means_gpu_bootstrap(
             k=k,
         )
         bootstrap_results.append(pairwise_means.get())
-    
+
     # Compute statistics across bootstrap samples
     bootstrap_stack = cp.array(bootstrap_results)  # [n_bootstrap, k, k]
     means = cp.mean(bootstrap_stack, axis=0)
     variances = cp.var(bootstrap_stack, axis=0)
-    
+
     return means, variances
 
 
diff --git a/tmp_scripts/compute_edistance_standalone.py b/tmp_scripts/compute_edistance_standalone.py
index d618c23a..2f668a57 100644
--- a/tmp_scripts/compute_edistance_standalone.py
+++ b/tmp_scripts/compute_edistance_standalone.py
@@ -2,6 +2,8 @@
 
 import os
 import time
+from argparse import ArgumentParser
+from pathlib import Path
 
 import anndata as ad
 import cupy as cp
@@ -9,11 +11,9 @@
 import pandas as pd
 import rmm
 from rmm.allocators.cupy import rmm_cupy_allocator
-from argparse import ArgumentParser
-from pathlib import Path
+
 import rapids_singlecell as rsc
 from rapids_singlecell.pertpy_gpu._distances_standalone import pairwise_edistance_gpu
-from pathlib import Path
 
 rmm.reinitialize(
     managed_memory=False,  # Allows oversubscription
@@ -37,24 +37,31 @@
 
     df_expected = None
     if bootstrap:
-        df_cpu_bootstrap_var = pd.read_csv(home_dir / "df_cpu_bootstrap_var.csv", index_col=0)
+        df_cpu_bootstrap_var = pd.read_csv(
+            home_dir / "df_cpu_bootstrap_var.csv", index_col=0
+        )
         df_cpu_bootstrap = pd.read_csv(home_dir / "df_cpu_bootstrap.csv", index_col=0)
         df_expected = df_cpu_bootstrap
     else:
         df_cpu_float64 = pd.read_csv(home_dir / "df_cpu_float64.csv", index_col=0)
         df_expected = df_cpu_float64
 
-
     start_time = time.time()
     if not bootstrap:
-        df_gpu = pairwise_edistance_gpu(adata, groupby=obs_key, obsm_key="X_pca", bootstrap=bootstrap)
+        df_gpu = pairwise_edistance_gpu(
+            adata, groupby=obs_key, obsm_key="X_pca", bootstrap=bootstrap
+        )
     else:
-        df_gpu, df_gpu_var = pairwise_edistance_gpu(adata, groupby=obs_key, obsm_key="X_pca", bootstrap=bootstrap, n_bootstrap=100)
+        df_gpu, df_gpu_var = pairwise_edistance_gpu(
+            adata,
+            groupby=obs_key,
+            obsm_key="X_pca",
+            bootstrap=bootstrap,
+            n_bootstrap=100,
+        )
     end_time = time.time()
     print(f"Time taken: {end_time - start_time} seconds")
 
-  
-
     is_not_close = []
     groups = adata.obs[obs_key].unique()
     k = len(groups)
@@ -67,7 +74,9 @@
                 assert df_gpu.loc[group_x, group_y] == 0
             else:
                 if not np.isclose(
-                    df_gpu.loc[group_x, group_y], df_expected.loc[group_x, group_y], atol=atol
+                    df_gpu.loc[group_x, group_y],
+                    df_expected.loc[group_x, group_y],
+                    atol=atol,
                 ):
                     is_not_close.append(
                         (
@@ -75,7 +84,8 @@
                             df_expected.loc[group_x, group_y],
                             df_gpu.loc[group_x, group_y],
                             np.abs(
-                                df_expected.loc[group_x, group_y] - df_gpu.loc[group_x, group_y]
+                                df_expected.loc[group_x, group_y]
+                                - df_gpu.loc[group_x, group_y]
                             ),
                         )
                     )

From 6b7d7b270d6c435e2b7f537c40c0c9e4f99989db Mon Sep 17 00:00:00 2001
From: selmanozleyen 
Date: Wed, 10 Sep 2025 13:38:14 +0000
Subject: [PATCH 16/23] switch to f32

---
 .../pertpy_gpu/_distances_standalone.py          |  9 ++++-----
 .../pertpy_gpu/kernels/edistance_kernels.cu      | 16 ++++++++--------
 2 files changed, 12 insertions(+), 13 deletions(-)

diff --git a/src/rapids_singlecell/pertpy_gpu/_distances_standalone.py b/src/rapids_singlecell/pertpy_gpu/_distances_standalone.py
index 4e781b2d..a76f1374 100644
--- a/src/rapids_singlecell/pertpy_gpu/_distances_standalone.py
+++ b/src/rapids_singlecell/pertpy_gpu/_distances_standalone.py
@@ -76,7 +76,7 @@ def compute_pairwise_means_gpu(
     num_pairs = len(pair_left)  # k * (k-1) pairs instead of k²
 
     # Allocate output for off-diagonal distances only
-    d_other_offdiag = cp.zeros(num_pairs, dtype=np.float64)
+    d_other_offdiag = cp.zeros(num_pairs, dtype=np.float32)
 
     # Choose optimal block size
     props = cp.cuda.runtime.getDeviceProperties(0)
@@ -85,7 +85,7 @@ def compute_pairwise_means_gpu(
     chosen_threads = None
     shared_mem_size = 0  # TODO: think of a better way to do this
     for tpb in (1024, 512, 256, 128, 64, 32):
-        required = tpb * cp.dtype(cp.float64).itemsize
+        required = tpb * cp.dtype(cp.float32).itemsize
         if required <= max_smem:
             chosen_threads = tpb
             shared_mem_size = required
@@ -111,7 +111,7 @@ def compute_pairwise_means_gpu(
     )
 
     # Build full k x k matrix
-    pairwise_means = cp.zeros((k, k), dtype=np.float64)
+    pairwise_means = cp.zeros((k, k), dtype=np.float32)
 
     # Fill the full matrix
     for i, idx in enumerate(pair_indices.get()):
@@ -322,10 +322,9 @@ def pairwise_edistance_gpu(
     df : pd.DataFrame
         Final edistance matrix
     """
-    # 1. Prepare data (same as original)
     _assert_categorical_obs(adata, key=groupby)
 
-    embedding = cp.array(adata.obsm[obsm_key]).astype(np.float64)
+    embedding = cp.array(adata.obsm[obsm_key]).astype(np.float32)  # Changed from float64
     original_groups = adata.obs[groupby]
     group_map = {v: i for i, v in enumerate(original_groups.cat.categories.values)}
     group_labels = cp.array([group_map[c] for c in original_groups], dtype=cp.int32)
diff --git a/src/rapids_singlecell/pertpy_gpu/kernels/edistance_kernels.cu b/src/rapids_singlecell/pertpy_gpu/kernels/edistance_kernels.cu
index 48754dd9..9aa2f9f0 100644
--- a/src/rapids_singlecell/pertpy_gpu/kernels/edistance_kernels.cu
+++ b/src/rapids_singlecell/pertpy_gpu/kernels/edistance_kernels.cu
@@ -9,22 +9,22 @@ extern "C" {
  * Each block processes one group pair, threads collaborate within the block
  */
 __global__ void compute_group_distances(
-    const double* __restrict__ embedding,
+    const float* __restrict__ embedding,
     const int* __restrict__ cat_offsets,
     const int* __restrict__ cell_indices,
     const int* __restrict__ pair_left,
     const int* __restrict__ pair_right,
-    double* __restrict__ d_other,
+    float* __restrict__ d_other,
     int k,
     int n_features)
 {
-    extern __shared__ double shared_sums[];
+    extern __shared__ float shared_sums[];
 
     const int thread_id = threadIdx.x;
     const int block_id = blockIdx.x;
     const int block_size = blockDim.x;
 
-    double local_sum = 0.0;
+    float local_sum = 0.0f;
 
     const int a = pair_left[block_id];
     const int b = pair_right[block_id];
@@ -46,14 +46,14 @@ __global__ void compute_group_distances(
         for (int jb = start_b; jb < end_b; ++jb) {
             const int idx_j = cell_indices[jb];
 
-            double dist_sq = 0.0;
+            float dist_sq = 0.0f;
             #pragma unroll
             for (int feat = 0; feat < n_features; ++feat) {
-                double diff = embedding[idx_i * n_features + feat] -
+                float diff = embedding[idx_i * n_features + feat] -
                             embedding[idx_j * n_features + feat];
                 dist_sq += diff * diff;
             }
-            local_sum += sqrt(dist_sq);
+            local_sum += sqrtf(dist_sq);
         }
     }
 
@@ -70,7 +70,7 @@ __global__ void compute_group_distances(
 
     if (thread_id == 0) {
         // Store mean between-group distance
-        d_other[block_id] = shared_sums[0] / (double)(n_a * n_b);
+        d_other[block_id] = shared_sums[0] / (float)(n_a * n_b);
     }
 }
 

From 3bbf06a3140b39c5e7e50d7c4ae6d9c542a1b9fb Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
 <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 10 Sep 2025 13:38:24 +0000
Subject: [PATCH 17/23] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/rapids_singlecell/pertpy_gpu/_distances_standalone.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/rapids_singlecell/pertpy_gpu/_distances_standalone.py b/src/rapids_singlecell/pertpy_gpu/_distances_standalone.py
index a76f1374..d532314f 100644
--- a/src/rapids_singlecell/pertpy_gpu/_distances_standalone.py
+++ b/src/rapids_singlecell/pertpy_gpu/_distances_standalone.py
@@ -324,7 +324,9 @@ def pairwise_edistance_gpu(
     """
     _assert_categorical_obs(adata, key=groupby)
 
-    embedding = cp.array(adata.obsm[obsm_key]).astype(np.float32)  # Changed from float64
+    embedding = cp.array(adata.obsm[obsm_key]).astype(
+        np.float32
+    )  # Changed from float64
     original_groups = adata.obs[groupby]
     group_map = {v: i for i, v in enumerate(original_groups.cat.categories.values)}
     group_labels = cp.array([group_map[c] for c in original_groups], dtype=cp.int32)

From 7a3c003a0805169103c5afce911ce1983bb31d35 Mon Sep 17 00:00:00 2001
From: selmanozleyen 
Date: Sun, 28 Sep 2025 03:20:28 +0000
Subject: [PATCH 18/23] refactor

---
 src/rapids_singlecell/pertpy_gpu/__init__.py  |   2 +-
 .../pertpy_gpu/_distances.py                  | 732 ------------------
 ..._distances_standalone.py => _edistance.py} | 214 +++--
 tmp_scripts/compute_edistance_standalone.py   |   9 +-
 4 files changed, 113 insertions(+), 844 deletions(-)
 delete mode 100644 src/rapids_singlecell/pertpy_gpu/_distances.py
 rename src/rapids_singlecell/pertpy_gpu/{_distances_standalone.py => _edistance.py} (87%)

diff --git a/src/rapids_singlecell/pertpy_gpu/__init__.py b/src/rapids_singlecell/pertpy_gpu/__init__.py
index 6324ee3c..90b92ca7 100644
--- a/src/rapids_singlecell/pertpy_gpu/__init__.py
+++ b/src/rapids_singlecell/pertpy_gpu/__init__.py
@@ -1,3 +1,3 @@
 from __future__ import annotations
 
-from ._distances import Distance
+from ._edistance import pertpy_edistance
diff --git a/src/rapids_singlecell/pertpy_gpu/_distances.py b/src/rapids_singlecell/pertpy_gpu/_distances.py
deleted file mode 100644
index 04c1e5b1..00000000
--- a/src/rapids_singlecell/pertpy_gpu/_distances.py
+++ /dev/null
@@ -1,732 +0,0 @@
-from __future__ import annotations
-
-from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Literal, NamedTuple
-
-import cupy as cp
-import numpy as np
-import pandas as pd
-from cuml.metrics import pairwise_distances
-from cuml.preprocessing import MinMaxScaler
-from cupy.random import choice
-from cupyx.scipy.sparse import issparse as cp_issparse
-
-# TODO(selmanozleyen): adapt which progress bar to use, probably use rapidsinglecell's progress bar (if it exists)
-# from rich.progress import track
-# from scipy.sparse import issparse
-# from scipy.spatial.distance import cosine, mahalanobis
-# from scipy.special import gammaln
-# from scipy.stats import kendalltau, kstest, pearsonr, spearmanr
-# from sklearn.linear_model import LogisticRegression
-# from sklearn.metrics import pairwise_distances, r2_score
-# from sklearn.metrics.pairwise import polynomial_kernel, rbf_kernel
-# from sklearn.neighbors import KernelDensity
-# from statsmodels.discrete.discrete_model import NegativeBinomialP
-
-if TYPE_CHECKING:
-    from collections.abc import Callable
-
-    from anndata import AnnData
-
-
-class MeanVar(NamedTuple):
-    mean: float
-    variance: float
-
-
-Metric = Literal[
-    "edistance",
-    # "euclidean",
-    # "root_mean_squared_error",
-    # "mse",
-    # "mean_absolute_error",
-    # "pearson_distance",
-    # "spearman_distance",
-    # "kendalltau_distance",
-    # "cosine_distance",
-    # "r2_distance",
-    # "mean_pairwise",
-    # "mmd",
-    # "sym_kldiv",
-    # "t_test",
-    # "ks_test",
-    # "nb_ll",
-    # "classifier_proba",
-    # "classifier_cp",
-    # "mean_var_distribution",
-    # "mahalanobis",
-]
-
-
-class Distance:
-    """Distance class, used to compute distances between groups of cells.
-
-    The distance metric can be specified by the user. This class also provides a
-    method to compute the pairwise distances between all groups of cells.
-    Currently available metrics:
-
-    - "edistance": Energy distance (Default metric).
-        In essence, it is twice the mean pairwise distance between cells of two
-        groups minus the mean pairwise distance between cells within each group
-        respectively. More information can be found in
-        `Peidli et al. (2023) `__.
-    - "euclidean": euclidean distance.
-        Euclidean distance between the means of cells from two groups.
-    - "root_mean_squared_error": euclidean distance.
-        Euclidean distance between the means of cells from two groups.
-    - "mse": Pseudobulk mean squared error.
-        mean squared distance between the means of cells from two groups.
-    - "mean_absolute_error": Pseudobulk mean absolute distance.
-        Mean absolute distance between the means of cells from two groups.
-    - "pearson_distance": Pearson distance.
-        Pearson distance between the means of cells from two groups.
-    - "spearman_distance": Spearman distance.
-        Spearman distance between the means of cells from two groups.
-    - "kendalltau_distance": Kendall tau distance.
-        Kendall tau distance between the means of cells from two groups.
-    - "cosine_distance": Cosine distance.
-        Cosine distance between the means of cells from two groups.
-    - "r2_distance": coefficient of determination distance.
-        Coefficient of determination distance between the means of cells from two groups.
-    - "mean_pairwise": Mean pairwise distance.
-        Mean of the pairwise euclidean distances between cells of two groups.
-    - "mmd": Maximum mean discrepancy
-        Maximum mean discrepancy between the cells of two groups.
-        Here, uses linear, rbf, and quadratic polynomial MMD. For theory on MMD in single-cell applications, see
-        `Lotfollahi et al. (2019) `__.
-    - "wasserstein": Wasserstein distance (Earth Mover's Distance)
-        Wasserstein distance between the cells of two groups. Uses an
-        OTT-JAX implementation of the Sinkhorn algorithm to compute the distance.
-        For more information on the optimal transport solver, see
-        `Cuturi et al. (2013) `__.
-    - "sym_kldiv": symmetrized Kullback–Leibler divergence distance.
-        Kullback–Leibler divergence of the gaussian distributions between cells of two groups.
-        Here we fit a gaussian distribution over one group of cells and then calculate the KL divergence on the other, and vice versa.
-    - "t_test": t-test statistic.
-        T-test statistic measure between cells of two groups.
-    - "ks_test": Kolmogorov-Smirnov test statistic.
-        Kolmogorov-Smirnov test statistic measure between cells of two groups.
-    - "nb_ll": log-likelihood over negative binomial
-        Average of log-likelihoods of samples of the secondary group after fitting a negative binomial distribution
-        over the samples of the first group.
-    - "classifier_proba": probability of a binary classifier
-        Average of the classification probability of the perturbation for a binary classifier.
-    - "classifier_cp": classifier class projection
-        Average of the class
-    - "mean_var_distribution": Distance between mean-variance distributions between cells of 2 groups.
-       Mean square distance between the mean-variance distributions of cells from 2 groups using Kernel Density Estimation (KDE).
-    - "mahalanobis": Mahalanobis distance between the means of cells from two groups.
-        It is originally used to measure distance between a point and a distribution.
-        in this context, it quantifies the difference between the mean profiles of a target group and a reference group.
-
-    Attributes:
-        metric: Name of distance metric.
-        layer_key: Name of the counts to use in adata.layers.
-        obsm_key: Name of embedding in adata.obsm to use.
-        cell_wise_metric: Metric from scipy.spatial.distance to use for pairwise distances between single cells.
-
-    Examples:
-        >>> import pertpy as pt
-        >>> adata = pt.dt.distance_example()
-        >>> Distance = pt.tools.Distance(metric="edistance")
-        >>> X = adata.obsm["X_pca"][adata.obs["perturbation"] == "p-sgCREB1-2"]
-        >>> Y = adata.obsm["X_pca"][adata.obs["perturbation"] == "control"]
-        >>> D = Distance(X, Y)
-    """
-
-    def __init__(
-        self,
-        metric: Metric = "edistance",
-        agg_fct: Callable = cp.mean,
-        layer_key: str = None,
-        obsm_key: str = None,
-        cell_wise_metric: str = "euclidean",
-    ):
-        """Initialize Distance class.
-
-        Args:
-            metric: Distance metric to use.
-            agg_fct: Aggregation function to generate pseudobulk vectors.
-            layer_key: Name of the counts layer containing raw counts to calculate distances for.
-                              Mutually exclusive with 'obsm_key'.
-                              Is not used if `None`.
-            obsm_key: Name of embedding in adata.obsm to use.
-                      Mutually exclusive with 'layer_key'.
-                      Defaults to None, but is set to "X_pca" if not explicitly set internally.
-            cell_wise_metric: Metric from scipy.spatial.distance to use for pairwise distances between single cells.
-        """
-        metric_fct: AbstractDistance = None
-        self.aggregation_func = agg_fct
-        # elif metric in ("euclidean", "root_mean_squared_error"):
-        #     metric_fct = EuclideanDistance(self.aggregation_func)
-        # elif metric == "mse":
-        #     metric_fct = MeanSquaredDistance(self.aggregation_func)
-        # elif metric == "mean_absolute_error":
-        #     metric_fct = MeanAbsoluteDistance(self.aggregation_func)
-        # elif metric == "pearson_distance":
-        #     metric_fct = PearsonDistance(self.aggregation_func)
-        # elif metric == "spearman_distance":
-        #     metric_fct = SpearmanDistance(self.aggregation_func)
-        # elif metric == "kendalltau_distance":
-        #     metric_fct = KendallTauDistance(self.aggregation_func)
-        # elif metric == "cosine_distance":
-        #     metric_fct = CosineDistance(self.aggregation_func)
-        # elif metric == "r2_distance":
-        #     metric_fct = R2ScoreDistance(self.aggregation_func)
-        # elif metric == "mean_pairwise":
-        #     metric_fct = MeanPairwiseDistance()
-        # elif metric == "mmd":
-        #     metric_fct = MMD()
-        # elif metric == "sym_kldiv":
-        #     metric_fct = SymmetricKLDivergence()
-        # elif metric == "t_test":
-        #     metric_fct = TTestDistance()
-        # elif metric == "ks_test":
-        #     metric_fct = KSTestDistance()
-        # elif metric == "nb_ll":
-        #     metric_fct = NBLL()
-        # elif metric == "classifier_proba":
-        #     metric_fct = ClassifierProbaDistance()
-        # elif metric == "classifier_cp":
-        #     metric_fct = ClassifierClassProjection()
-        # elif metric == "mean_var_distribution":
-        #     metric_fct = MeanVarDistributionDistance()
-        # elif metric == "mahalanobis":
-        #     metric_fct = MahalanobisDistance(self.aggregation_func)
-        if metric == "edistance":
-            metric_fct = Edistance()
-        else:
-            raise ValueError(f"Metric {metric} not recognized.")
-        self.metric_fct = metric_fct
-
-        if layer_key and obsm_key:
-            raise ValueError(
-                "Cannot use 'layer_key' and 'obsm_key' at the same time.\nPlease provide only one of the two keys."
-            )
-        if not layer_key and not obsm_key:
-            obsm_key = "X_pca"
-        self.layer_key = layer_key
-        self.obsm_key = obsm_key
-        self.metric = metric
-        self.cell_wise_metric = cell_wise_metric
-
-    def __call__(
-        self,
-        X: cp.ndarray | cp.sparse.spmatrix,
-        Y: cp.ndarray | cp.sparse.spmatrix,
-        **kwargs,
-    ) -> float:
-        """Compute distance between vectors X and Y.
-
-        Args:
-            X: First vector of shape (n_samples, n_features).
-            Y: Second vector of shape (n_samples, n_features).
-            kwargs: Passed to the metric function.
-
-        Returns:
-            float: Distance between X and Y.
-
-        Examples:
-            >>> import pertpy as pt
-            >>> adata = pt.dt.distance_example()
-            >>> Distance = pt.tools.Distance(metric="edistance")
-            >>> X = adata.obsm["X_pca"][adata.obs["perturbation"] == "p-sgCREB1-2"]
-            >>> Y = adata.obsm["X_pca"][adata.obs["perturbation"] == "control"]
-            >>> D = Distance(X, Y)
-        """
-        if cp_issparse(X):
-            X = X.toarray()
-        if cp_issparse(Y):
-            Y = Y.toarray()
-
-        if len(X) == 0 or len(Y) == 0:
-            raise ValueError("Neither X nor Y can be empty.")
-
-        return self.metric_fct(X, Y, **kwargs)
-
-    def bootstrap(
-        self,
-        X: cp.ndarray,
-        Y: cp.ndarray,
-        *,
-        n_bootstrap: int = 100,
-        random_state: int = 0,
-        **kwargs,
-    ) -> MeanVar:
-        """Bootstrap computation of mean and variance of the distance between vectors X and Y.
-
-        Args:
-            X: First vector of shape (n_samples, n_features).
-            Y: Second vector of shape (n_samples, n_features).
-            n_bootstrap: Number of bootstrap samples.
-            random_state: Random state for bootstrapping.
-            **kwargs: Passed to the metric function.
-
-        Returns:
-            Mean and variance of distance between X and Y.
-
-        Examples:
-            >>> import pertpy as pt
-            >>> adata = pt.dt.distance_example()
-            >>> Distance = pt.tools.Distance(metric="edistance")
-            >>> X = adata.obsm["X_pca"][adata.obs["perturbation"] == "p-sgCREB1-2"]
-            >>> Y = adata.obsm["X_pca"][adata.obs["perturbation"] == "control"]
-            >>> D = Distance.bootstrap(X, Y)
-        """
-        return self._bootstrap_mode(
-            X,
-            Y,
-            n_bootstraps=n_bootstrap,
-            random_state=random_state,
-            **kwargs,
-        )
-
-    def pairwise(
-        self,
-        adata: AnnData,
-        groupby: str,
-        groups: list[str] | None = None,
-        bootstrap: bool = False,
-        n_bootstrap: int = 100,
-        random_state: int = 0,
-        show_progressbar: bool = True,
-        n_jobs: int = -1,
-        **kwargs,
-    ) -> pd.DataFrame | tuple[pd.DataFrame, pd.DataFrame]:
-        """Get pairwise distances between groups of cells.
-
-        Args:
-            adata: Annotated data matrix.
-            groupby: Column name in adata.obs.
-            groups: List of groups to compute pairwise distances for.
-                    If None, uses all groups.
-            bootstrap: Whether to bootstrap the distance.
-            n_bootstrap: Number of bootstrap samples.
-            random_state: Random state for bootstrapping.
-            show_progressbar: Whether to show progress bar.
-            n_jobs: Number of cores to use. Defaults to -1 (all).
-            kwargs: Additional keyword arguments passed to the metric function.
-
-        Returns:
-            :class:`pandas.DataFrame`: Dataframe with pairwise distances.
-            tuple[:class:`pandas.DataFrame`, :class:`pandas.DataFrame`]: Two Dataframes, one for the mean and one for the variance of pairwise distances.
-
-        Examples:
-            >>> import pertpy as pt
-            >>> adata = pt.dt.distance_example()
-            >>> Distance = pt.tools.Distance(metric="edistance")
-            >>> pairwise_df = Distance.pairwise(adata, groupby="perturbation")
-        """
-        groups = adata.obs[groupby].unique() if groups is None else groups
-        grouping = adata.obs[groupby].copy()
-        df = pd.DataFrame(index=groups, columns=groups, dtype=float)
-        if bootstrap:
-            df_var = pd.DataFrame(index=groups, columns=groups, dtype=float)
-        # fct = track if show_progressbar else lambda iterable: iterable
-        fct = lambda iterable: iterable  # see TODO above about progress bar
-
-        # Some metrics are able to handle precomputed distances. This means that
-        # the pairwise distances between all cells are computed once and then
-        # passed to the metric function. This is much faster than computing the
-        # pairwise distances for each group separately. Other metrics are not
-        # able to handle precomputed distances such as the PseudobulkDistance.
-        if self.metric_fct.accepts_precomputed:
-            # Precompute the pairwise distances if needed
-            if (
-                f"{self.obsm_key}_{self.cell_wise_metric}_predistances"
-                not in adata.obsp
-            ):
-                self.precompute_distances(adata, n_jobs=n_jobs, **kwargs)
-            pwd = adata.obsp[f"{self.obsm_key}_{self.cell_wise_metric}_predistances"]
-            for index_x, group_x in enumerate(fct(groups)):
-                idx_x = grouping == group_x
-                for group_y in groups[index_x:]:  # type: ignore
-                    # subset the pairwise distance matrix to the two groups
-                    idx_y = grouping == group_y
-                    sub_pwd = pwd[idx_x | idx_y, :][:, idx_x | idx_y]
-                    sub_idx = grouping[idx_x | idx_y] == group_x
-                    if not bootstrap:
-                        if group_x == group_y:
-                            dist = 0.0
-                        else:
-                            dist = self.metric_fct.from_precomputed(
-                                sub_pwd, sub_idx, **kwargs
-                            )
-                        df.loc[group_x, group_y] = dist
-                        df.loc[group_y, group_x] = dist
-
-                    else:
-                        bootstrap_output = self._bootstrap_mode_precomputed(
-                            sub_pwd,
-                            sub_idx,
-                            n_bootstraps=n_bootstrap,
-                            random_state=random_state,
-                            **kwargs,
-                        )
-                        # In the bootstrap case, distance of group to itself is a mean and can be non-zero
-                        df.loc[group_x, group_y] = df.loc[group_y, group_x] = (
-                            bootstrap_output.mean
-                        )
-                        df_var.loc[group_x, group_y] = df_var.loc[group_y, group_x] = (
-                            bootstrap_output.variance
-                        )
-        else:
-            embedding = (
-                adata.layers[self.layer_key]
-                if self.layer_key
-                else adata.obsm[self.obsm_key].copy()
-            )
-            for index_x, group_x in enumerate(fct(groups)):
-                cells_x = embedding[cp.asarray(grouping == group_x)].copy()
-                for group_y in groups[index_x:]:  # type: ignore
-                    cells_y = embedding[cp.asarray(grouping == group_y)].copy()
-                    if not bootstrap:
-                        # By distance axiom, the distance between a group and itself is 0
-                        dist = (
-                            0.0
-                            if group_x == group_y
-                            else self(cells_x, cells_y, **kwargs)
-                        )
-
-                        df.loc[group_x, group_y] = dist
-                        df.loc[group_y, group_x] = dist
-                    else:
-                        bootstrap_output = self.bootstrap(
-                            cells_x,
-                            cells_y,
-                            n_bootstrap=n_bootstrap,
-                            random_state=random_state,
-                            **kwargs,
-                        )
-                        # In the bootstrap case, distance of group to itself is a mean and can be non-zero
-                        df.loc[group_x, group_y] = df.loc[group_y, group_x] = (
-                            bootstrap_output.mean
-                        )
-                        df_var.loc[group_x, group_y] = df_var.loc[group_y, group_x] = (
-                            bootstrap_output.variance
-                        )
-
-        df.index.name = groupby
-        df.columns.name = groupby
-        df.name = f"pairwise {self.metric}"
-
-        if not bootstrap:
-            return df
-        else:
-            df = df.fillna(0)
-            df_var.index.name = groupby
-            df_var.columns.name = groupby
-            df_var = df_var.fillna(0)
-            df_var.name = f"pairwise {self.metric} variance"
-
-            return df, df_var
-
-    def precompute_distances(self, adata: AnnData, n_jobs: int = -1) -> None:
-        """Precompute pairwise distances between all cells, writes to adata.obsp.
-
-        The precomputed distances are stored in adata.obsp under the key
-        '{self.obsm_key}_{cell_wise_metric}_predistances', as they depend on
-        both the cell-wise metric and the embedding used.
-
-        Args:
-            adata: Annotated data matrix.
-            n_jobs: Number of cores to use. Defaults to -1 (all).
-
-        Examples:
-            >>> import pertpy as pt
-            >>> adata = pt.dt.distance_example()
-            >>> distance = pt.tools.Distance(metric="edistance")
-            >>> distance.precompute_distances(adata)
-        """
-        cells = (
-            adata.layers[self.layer_key]
-            if self.layer_key
-            else adata.obsm[self.obsm_key].copy()
-        )
-        pwd = pairwise_distances(
-            cells, cells, metric=self.cell_wise_metric, n_jobs=n_jobs
-        )
-        adata.obsp[f"{self.obsm_key}_{self.cell_wise_metric}_predistances"] = pwd
-
-    def compare_distance(
-        self,
-        pert: cp.ndarray,
-        pred: cp.ndarray,
-        ctrl: cp.ndarray,
-        mode: Literal["simple", "scaled"] = "simple",
-        fit_to_pert_and_ctrl: bool = False,
-        **kwargs,
-    ) -> float:
-        """Compute the score of simulating a perturbation.
-
-        Args:
-            pert: Real perturbed data.
-            pred: Simulated perturbed data.
-            ctrl: Control data
-            mode: Mode to use.
-            fit_to_pert_and_ctrl: Scales data based on both `pert` and `ctrl` if True, otherwise only on `ctrl`.
-            kwargs: Additional keyword arguments passed to the metric function.
-        """
-        if mode == "simple":
-            pass  # nothing to be done
-        elif mode == "scaled":
-            scaler = MinMaxScaler().fit(
-                np.vstack((pert, ctrl)) if fit_to_pert_and_ctrl else ctrl
-            )
-            pred = scaler.transform(pred)
-            pert = scaler.transform(pert)
-        else:
-            raise ValueError(f"Unknown mode {mode}. Please choose simple or scaled.")
-
-        d1 = self.metric_fct(pert, pred, **kwargs)
-        d2 = self.metric_fct(ctrl, pred, **kwargs)
-        return d1 / d2
-
-    def _bootstrap_mode(
-        self, X, Y, n_bootstraps=100, random_state=0, **kwargs
-    ) -> MeanVar:
-        # rng = np.random.default_rng(random_state)
-
-        distances = []
-        for _ in range(n_bootstraps):
-            X_bootstrapped = X[choice(a=X.shape[0], size=X.shape[0], replace=True)]
-            Y_bootstrapped = Y[choice(a=Y.shape[0], size=X.shape[0], replace=True)]
-
-            distance = self(X_bootstrapped, Y_bootstrapped, **kwargs)
-            distances.append(distance.get())
-
-        mean = np.mean(distances)
-        variance = np.var(distances)
-        return MeanVar(mean=mean, variance=variance)
-
-    def _bootstrap_mode_precomputed(
-        self, sub_pwd, sub_idx, n_bootstraps=100, random_state=0, **kwargs
-    ) -> MeanVar:
-        rng = np.random.default_rng(random_state)
-
-        distances = []
-        for _ in range(n_bootstraps):
-            # To maintain the number of cells for both groups (whatever balancing they may have),
-            # we sample the positive and negative indices separately
-            bootstrap_pos_idx = rng.choice(
-                a=sub_idx[sub_idx].index, size=sub_idx[sub_idx].size, replace=True
-            )
-            bootstrap_neg_idx = rng.choice(
-                a=sub_idx[~sub_idx].index, size=sub_idx[~sub_idx].size, replace=True
-            )
-            bootstrap_idx = np.concatenate([bootstrap_pos_idx, bootstrap_neg_idx])
-            bootstrap_idx_nrs = sub_idx.index.get_indexer(bootstrap_idx)
-
-            bootstrap_sub_idx = sub_idx[bootstrap_idx]
-            bootstrap_sub_pwd = sub_pwd[bootstrap_idx_nrs, :][:, bootstrap_idx_nrs]
-
-            distance = self.metric_fct.from_precomputed(
-                bootstrap_sub_pwd, bootstrap_sub_idx, **kwargs
-            )
-            distances.append(distance)
-
-        mean = np.mean(distances)
-        variance = np.var(distances)
-        return MeanVar(mean=mean, variance=variance)
-
-    def onesided_distances(
-        self,
-        adata: AnnData,
-        groupby: str,
-        selected_group: str | None = None,
-        groups: list[str] | None = None,
-        bootstrap: bool = False,
-        n_bootstrap: int = 100,
-        random_state: int = 0,
-        show_progressbar: bool = True,
-        n_jobs: int = -1,
-        **kwargs,
-    ) -> pd.DataFrame | tuple[pd.DataFrame, pd.DataFrame]:
-        """Get distances between one selected cell group and the remaining other cell groups.
-
-        Args:
-            adata: Annotated data matrix.
-            groupby: Column name in adata.obs.
-            selected_group: Group to compute pairwise distances to all other.
-            groups: List of groups to compute distances to selected_group for.
-                    If None, uses all groups.
-            bootstrap: Whether to bootstrap the distance.
-            n_bootstrap: Number of bootstrap samples.
-            random_state: Random state for bootstrapping.
-            show_progressbar: Whether to show progress bar.
-            n_jobs: Number of cores to use. Defaults to -1 (all).
-            kwargs: Additional keyword arguments passed to the metric function.
-
-        Returns:
-            :class:`pandas.DataFrame`: Dataframe with distances of groups to selected_group.
-            tuple[:class:`pandas.DataFrame`, :class:`pandas.DataFrame`]: Two Dataframes, one for the mean and one for the variance of distances of groups to selected_group.
-
-
-        Examples:
-            >>> import pertpy as pt
-            >>> adata = pt.dt.distance_example()
-            >>> Distance = pt.tools.Distance(metric="edistance")
-            >>> pairwise_df = Distance.onesided_distances(adata, groupby="perturbation", selected_group="control")
-        """
-        if self.metric == "classifier_cp":
-            if bootstrap:
-                raise NotImplementedError(
-                    "Currently, ClassifierClassProjection does not support bootstrapping."
-                )
-            return self.metric_fct.onesided_distances(  # type: ignore
-                adata,
-                groupby,
-                selected_group,
-                groups,
-                show_progressbar,
-                n_jobs,
-                **kwargs,
-            )
-
-        groups = adata.obs[groupby].unique() if groups is None else groups
-        grouping = adata.obs[groupby].copy()
-        df = pd.Series(index=groups, dtype=float)
-        if bootstrap:
-            df_var = pd.Series(index=groups, dtype=float)
-        # fct = track if show_progressbar else lambda iterable: iterable
-        fct = (
-            lambda iterable: iterable
-        )  # see TODO at the top of the file about progress bar
-
-        # Some metrics are able to handle precomputed distances. This means that
-        # the pairwise distances between all cells are computed once and then
-        # passed to the metric function. This is much faster than computing the
-        # pairwise distances for each group separately. Other metrics are not
-        # able to handle precomputed distances such as the PseudobulkDistance.
-        if self.metric_fct.accepts_precomputed:
-            # Precompute the pairwise distances if needed
-            if (
-                f"{self.obsm_key}_{self.cell_wise_metric}_predistances"
-                not in adata.obsp
-            ):
-                self.precompute_distances(adata, n_jobs=n_jobs, **kwargs)
-            pwd = adata.obsp[f"{self.obsm_key}_{self.cell_wise_metric}_predistances"]
-            for group_x in fct(groups):
-                idx_x = grouping == group_x
-                group_y = selected_group
-                if group_x == group_y:
-                    df.loc[group_x] = 0.0  # by distance axiom
-                else:
-                    idx_y = grouping == group_y
-                    # subset the pairwise distance matrix to the two groups
-                    sub_pwd = pwd[idx_x | idx_y, :][:, idx_x | idx_y]
-                    sub_idx = grouping[idx_x | idx_y] == group_x
-                    if not bootstrap:
-                        dist = self.metric_fct.from_precomputed(
-                            sub_pwd, sub_idx, **kwargs
-                        )
-                        df.loc[group_x] = dist
-                    else:
-                        bootstrap_output = self._bootstrap_mode_precomputed(
-                            sub_pwd,
-                            sub_idx,
-                            n_bootstraps=n_bootstrap,
-                            random_state=random_state,
-                            **kwargs,
-                        )
-                        df.loc[group_x] = bootstrap_output.mean
-                        df_var.loc[group_x] = bootstrap_output.variance
-        else:
-            embedding = (
-                adata.layers[self.layer_key]
-                if self.layer_key
-                else adata.obsm[self.obsm_key].copy()
-            )
-            for group_x in fct(groups):
-                cells_x = embedding[np.asarray(grouping == group_x)].copy()
-                group_y = selected_group
-                cells_y = embedding[np.asarray(grouping == group_y)].copy()
-                if not bootstrap:
-                    # By distance axiom, the distance between a group and itself is 0
-                    dist = (
-                        0.0 if group_x == group_y else self(cells_x, cells_y, **kwargs)
-                    )
-                    df.loc[group_x] = dist
-                else:
-                    bootstrap_output = self.bootstrap(
-                        cells_x,
-                        cells_y,
-                        n_bootstrap=n_bootstrap,
-                        random_state=random_state,
-                        **kwargs,
-                    )
-                    # In the bootstrap case, distance of group to itself is a mean and can be non-zero
-                    df.loc[group_x] = bootstrap_output.mean
-                    df_var.loc[group_x] = bootstrap_output.variance
-        df.index.name = groupby
-        df.name = f"{self.metric} to {selected_group}"
-        if not bootstrap:
-            return df
-        else:
-            df_var.index.name = groupby
-            df_var = df_var.fillna(0)
-            df_var.name = f"pairwise {self.metric} variance to {selected_group}"
-
-            return df, df_var
-
-
-class AbstractDistance(ABC):
-    """Abstract class of distance metrics between two sets of vectors."""
-
-    @abstractmethod
-    def __init__(self) -> None:
-        super().__init__()
-        self.accepts_precomputed: bool = None
-
-    @abstractmethod
-    def __call__(self, X: cp.ndarray, Y: cp.ndarray, **kwargs) -> float:
-        """Compute distance between vectors X and Y.
-
-        Args:
-            X: First vector of shape (n_samples, n_features).
-            Y: Second vector of shape (n_samples, n_features).
-            kwargs: Passed to the metrics function.
-
-        Returns:
-            float: Distance between X and Y.
-        """
-        raise NotImplementedError("Metric class is abstract.")
-
-    @abstractmethod
-    def from_precomputed(self, P: cp.ndarray, idx: cp.ndarray, **kwargs) -> float:
-        """Compute a distance between vectors X and Y with precomputed distances.
-
-        Args:
-            P: Pairwise distance matrix of shape (n_samples, n_samples).
-            idx: Boolean array of shape (n_samples,) indicating which samples belong to X (or Y, since each metric is symmetric).
-            kwargs: Passed to the metrics function.
-
-        Returns:
-            float: Distance between X and Y.
-        """
-        raise NotImplementedError("Metric class is abstract.")
-
-
-class Edistance(AbstractDistance):
-    """Edistance metric."""
-
-    def __init__(self) -> None:
-        super().__init__()
-        self.accepts_precomputed = True
-        self.cell_wise_metric = "euclidean"
-
-    def __call__(self, X: cp.ndarray, Y: cp.ndarray, **kwargs) -> float:
-        sigma_X = pairwise_distances(
-            X, X, metric=self.cell_wise_metric, **kwargs
-        ).mean()
-        sigma_Y = pairwise_distances(
-            Y, Y, metric=self.cell_wise_metric, **kwargs
-        ).mean()
-        delta = pairwise_distances(X, Y, metric=self.cell_wise_metric, **kwargs).mean()
-        return 2 * delta - sigma_X - sigma_Y
-
-    def from_precomputed(self, P: cp.ndarray, idx: cp.ndarray, **kwargs) -> float:
-        sigma_X = P[idx, :][:, idx].mean()
-        sigma_Y = P[~idx, :][:, ~idx].mean()
-        delta = P[idx, :][:, ~idx].mean()
-        return 2 * delta - sigma_X - sigma_Y
diff --git a/src/rapids_singlecell/pertpy_gpu/_distances_standalone.py b/src/rapids_singlecell/pertpy_gpu/_edistance.py
similarity index 87%
rename from src/rapids_singlecell/pertpy_gpu/_distances_standalone.py
rename to src/rapids_singlecell/pertpy_gpu/_edistance.py
index d532314f..0f76d15d 100644
--- a/src/rapids_singlecell/pertpy_gpu/_distances_standalone.py
+++ b/src/rapids_singlecell/pertpy_gpu/_edistance.py
@@ -1,14 +1,24 @@
 from __future__ import annotations
 
 from pathlib import Path
+from typing import TYPE_CHECKING, NamedTuple
 
 import cupy as cp
 import numpy as np
 import pandas as pd
-from anndata import AnnData
 
-from ..preprocessing._harmony._helper import _create_category_index_mapping
-from ..squidpy_gpu._utils import _assert_categorical_obs
+from rapids_singlecell.preprocessing._harmony._helper import (
+    _create_category_index_mapping,
+)
+from rapids_singlecell.squidpy_gpu._utils import _assert_categorical_obs
+
+if TYPE_CHECKING:
+    from anndata import AnnData
+
+
+class EDistanceResult(NamedTuple):
+    distances: pd.DataFrame
+    distances_var: pd.DataFrame | None
 
 
 # Load CUDA kernels from separate file
@@ -35,7 +45,87 @@ def _load_edistance_kernels():
 compute_group_distances_kernel = _load_edistance_kernels()
 
 
-def compute_pairwise_means_gpu(
+def pertpy_edistance(
+    adata: AnnData,
+    groupby: str,
+    *,
+    obsm_key: str = "X_pca",
+    groups: list[str] | None = None,
+    inplace: bool = False,
+    bootstrap: bool = False,
+    n_bootstrap: int = 100,
+    random_state: int = 0,
+) -> EDistanceResult:
+    """
+    GPU-accelerated pairwise edistance between the groups in ``adata.obs[groupby]``.
+
+    Parameters
+    ----------
+    adata : AnnData
+        Annotated data matrix
+    groupby : str
+        Key in adata.obs for grouping (must be categorical)
+    obsm_key : str
+        Key in adata.obsm for embeddings (cast to float32 on the GPU)
+    groups : list[str] | None
+        Axis labels for the result, one per category (if None, the categories)
+    inplace : bool
+        Also store the result in ``adata.uns[f"{groupby}_pairwise_edistance"]``
+    bootstrap : bool
+        Whether to bootstrap the distances
+    n_bootstrap : int
+        Number of bootstrap samples
+    random_state : int
+        Random state for bootstrapping
+
+    Returns
+    -------
+    EDistanceResult
+        ``distances[a, b] = 2 * d_other[a, b] - d_itself[a] - d_itself[b]``;
+        ``distances_var`` holds the bootstrap variances (None if not ``bootstrap``)
+    """
+    _assert_categorical_obs(adata, key=groupby)
+    embedding = cp.array(adata.obsm[obsm_key]).astype(np.float32)
+    original_groups = adata.obs[groupby]
+    categories = list(original_groups.cat.categories.values)
+    group_map = {v: i for i, v in enumerate(categories)}
+    group_labels = cp.array([group_map[c] for c in original_groups], dtype=cp.int32)
+    # Use harmony's category mapping
+    k = len(group_map)
+    cat_offsets, cell_indices = _create_category_index_mapping(group_labels, k)
+
+    groups_list = categories if groups is None else groups
+    if not bootstrap:
+        df = _prepare_edistance_df(
+            embedding=embedding,
+            cat_offsets=cat_offsets,
+            cell_indices=cell_indices,
+            k=k,
+            groups_list=groups_list,
+            groupby=groupby,
+        )
+        result = EDistanceResult(distances=df, distances_var=None)
+    else:
+        df, df_var = _prepare_edistance_df_bootstrap(
+            embedding=embedding,
+            cat_offsets=cat_offsets,
+            cell_indices=cell_indices,
+            k=k,
+            groups_list=groups_list,
+            groupby=groupby,
+            n_bootstrap=n_bootstrap,
+            random_state=random_state,
+        )
+        result = EDistanceResult(distances=df, distances_var=df_var)
+
+    if inplace:
+        # Store fields by name (dict(result) cannot); leave out an unset variance.
+        stored = {f: v for f, v in result._asdict().items() if v is not None}
+        adata.uns[f"{groupby}_pairwise_edistance"] = stored
+    return result
+
+
+def _pairwise_means(
     embedding: cp.ndarray, cat_offsets: cp.ndarray, cell_indices: cp.ndarray, k: int
 ) -> cp.ndarray:
     """
@@ -122,7 +212,7 @@ def compute_pairwise_means_gpu(
     return pairwise_means
 
 
-def generate_bootstrap_indices(
+def _generate_bootstrap_indices(
     cat_offsets: cp.ndarray,
     k: int,
     n_bootstrap: int = 100,
@@ -233,7 +323,7 @@ def _bootstrap_sample_cells_from_indices(
     return new_cat_offsets, cp.array(new_cell_indices, dtype=cp.int32)
 
 
-def compute_pairwise_means_gpu_bootstrap(
+def _pairwise_means_bootstrap(
     embedding: cp.ndarray,
     *,
     cat_offsets: cp.ndarray,
@@ -251,7 +341,7 @@ def compute_pairwise_means_gpu_bootstrap(
         variances: [k, k] matrix of bootstrap variances
     """
     # Generate all bootstrap indices upfront using CPU-compatible logic
-    bootstrap_indices = generate_bootstrap_indices(
+    bootstrap_indices = _generate_bootstrap_indices(
         cat_offsets, k, n_bootstrap, random_state
     )
 
@@ -267,7 +357,7 @@ def compute_pairwise_means_gpu_bootstrap(
         )
 
         # Compute distances with bootstrapped samples
-        pairwise_means = compute_pairwise_means_gpu(
+        pairwise_means = _pairwise_means(
             embedding=embedding,
             cat_offsets=boot_cat_offsets,
             cell_indices=boot_cell_indices,
@@ -283,97 +373,7 @@ def compute_pairwise_means_gpu_bootstrap(
     return means, variances
 
 
-def pairwise_edistance_gpu(
-    adata: AnnData,
-    groupby: str,
-    *,
-    obsm_key: str = "X_pca",
-    groups: list[str] | None = None,
-    inplace: bool = False,
-    bootstrap: bool = False,
-    n_bootstrap: int = 100,
-    random_state: int = 0,
-) -> pd.DataFrame:
-    """
-    GPU-accelerated pairwise edistance computation with decomposed components.
-
-    Returns d_itself, d_other arrays and final edistance DataFrame where:
-    df[a,b] = 2*d_other[a,b] - d_itself[a] - d_itself[b]
-
-    Parameters
-    ----------
-    adata : AnnData
-        Annotated data matrix
-    groupby : str
-        Key in adata.obs for grouping
-    obsm_key : str
-        Key in adata.obsm for embeddings
-    groups : list[str] | None
-        Specific groups to compute (if None, use all)
-    copy : bool
-        Whether to return a copy
-
-    Returns
-    -------
-    d_itself : cp.ndarray
-        Within-group mean distances [k]
-    d_other : cp.ndarray
-        Between-group mean distances [k, k]
-    df : pd.DataFrame
-        Final edistance matrix
-    """
-    _assert_categorical_obs(adata, key=groupby)
-
-    embedding = cp.array(adata.obsm[obsm_key]).astype(
-        np.float32
-    )  # Changed from float64
-    original_groups = adata.obs[groupby]
-    group_map = {v: i for i, v in enumerate(original_groups.cat.categories.values)}
-    group_labels = cp.array([group_map[c] for c in original_groups], dtype=cp.int32)
-
-    # 2. Use harmony's category mapping
-    k = len(group_map)
-    cat_offsets, cell_indices = _create_category_index_mapping(group_labels, k)
-
-    groups_list = (
-        list(original_groups.cat.categories.values) if groups is None else groups
-    )
-    if not bootstrap:
-        df = compute_pairwise_means_gpu_edistance(
-            embedding=embedding,
-            cat_offsets=cat_offsets,
-            cell_indices=cell_indices,
-            k=k,
-            groups_list=groups_list,
-            groupby=groupby,
-        )
-        if inplace:
-            adata.uns[f"{groupby}_pairwise_edistance"] = {
-                "distances": df,
-            }
-        return df
-
-    else:
-        df, df_var = compute_pairwise_means_gpu_edistance_bootstrap(
-            embedding=embedding,
-            cat_offsets=cat_offsets,
-            cell_indices=cell_indices,
-            k=k,
-            groups_list=groups_list,
-            groupby=groupby,
-            n_bootstrap=n_bootstrap,
-            random_state=random_state,
-        )
-
-        if inplace:
-            adata.uns[f"{groupby}_pairwise_edistance"] = {
-                "distances": df,
-                "distances_var": df_var,
-            }
-        return df, df_var
-
-
-def compute_pairwise_means_gpu_edistance_bootstrap(
+def _prepare_edistance_df_bootstrap(
     embedding: cp.ndarray,
     *,
     cat_offsets: cp.ndarray,
@@ -385,7 +385,7 @@ def compute_pairwise_means_gpu_edistance_bootstrap(
     random_state: int = 0,
 ) -> tuple[pd.DataFrame, pd.DataFrame]:
     # Bootstrap computation
-    pairwise_means_boot, pairwise_vars_boot = compute_pairwise_means_gpu_bootstrap(
+    pairwise_means_boot, pairwise_vars_boot = _pairwise_means_bootstrap(
         embedding=embedding,
         cat_offsets=cat_offsets,
         cell_indices=cell_indices,
@@ -394,7 +394,7 @@ def compute_pairwise_means_gpu_edistance_bootstrap(
         random_state=random_state,
     )
 
-    # 4. Compute final edistance for means and variances
+    # Compute final edistance for means and variances
     edistance_means = cp.zeros((k, k), dtype=np.float32)
     edistance_vars = cp.zeros((k, k), dtype=np.float32)
 
@@ -434,7 +434,7 @@ def compute_pairwise_means_gpu_edistance_bootstrap(
     return df_mean, df_var
 
 
-def compute_pairwise_means_gpu_edistance(
+def _prepare_edistance_df(
     embedding: cp.ndarray,
     *,
     cat_offsets: cp.ndarray,
@@ -443,11 +443,10 @@ def compute_pairwise_means_gpu_edistance(
     groups_list: list[str],
     groupby: str,
 ) -> pd.DataFrame:
-    # 3. Compute decomposed components
-    # d_itself = compute_d_itself_gpu(embedding, cat_offsets, cell_indices, k)
-    pairwise_means = compute_pairwise_means_gpu(embedding, cat_offsets, cell_indices, k)
+    # Compute means
+    pairwise_means = _pairwise_means(embedding, cat_offsets, cell_indices, k)
 
-    # 4. Compute final edistance: df[a,b] = 2*d_other[a,b] - d_itself[a] - d_itself[b]
+    # Compute final edistance: df[a,b] = 2*d_other[a,b] - d_itself[a] - d_itself[b]
     edistance_matrix = cp.zeros((k, k), dtype=np.float32)
     for a in range(k):
         for b in range(a + 1, k):
@@ -456,8 +455,7 @@ def compute_pairwise_means_gpu_edistance(
             )
             edistance_matrix[b, a] = edistance_matrix[a, b]
 
-    # 5. Create output DataFrame
-
+    # Create output DataFrame
     df = pd.DataFrame(edistance_matrix.get(), index=groups_list, columns=groups_list)
     df.index.name = groupby
     df.columns.name = groupby
diff --git a/tmp_scripts/compute_edistance_standalone.py b/tmp_scripts/compute_edistance_standalone.py
index 2f668a57..ef1d12dd 100644
--- a/tmp_scripts/compute_edistance_standalone.py
+++ b/tmp_scripts/compute_edistance_standalone.py
@@ -13,7 +13,7 @@
 from rmm.allocators.cupy import rmm_cupy_allocator
 
 import rapids_singlecell as rsc
-from rapids_singlecell.pertpy_gpu._distances_standalone import pairwise_edistance_gpu
+from rapids_singlecell.pertpy_gpu._edistance import pertpy_edistance
 
 rmm.reinitialize(
     managed_memory=False,  # Allows oversubscription
@@ -48,17 +48,20 @@
 
     start_time = time.time()
     if not bootstrap:
-        df_gpu = pairwise_edistance_gpu(
+        res = pertpy_edistance(
             adata, groupby=obs_key, obsm_key="X_pca", bootstrap=bootstrap
         )
+        df_gpu = res.distances
     else:
-        df_gpu, df_gpu_var = pairwise_edistance_gpu(
+        res = pertpy_edistance(
             adata,
             groupby=obs_key,
             obsm_key="X_pca",
             bootstrap=bootstrap,
             n_bootstrap=100,
         )
+        df_gpu = res.distances
+        df_gpu_var = res.distances_var
     end_time = time.time()
     print(f"Time taken: {end_time - start_time} seconds")
 

From ecfc5cf55d09f3a2634bb906b696ec9dbcb75eab Mon Sep 17 00:00:00 2001
From: selmanozleyen 
Date: Wed, 1 Oct 2025 20:12:56 +0000
Subject: [PATCH 19/23] push working state

---
 pyproject.toml                                |   1 -
 .../pertpy_gpu/_edistance.py                  |  30 +-
 .../_edistance.py}                            |  57 ++--
 tests/pertpy/test_distances.py                | 321 +++++-------------
 4 files changed, 117 insertions(+), 292 deletions(-)
 rename src/rapids_singlecell/pertpy_gpu/{kernels/edistance_kernels.cu => _kernels/_edistance.py} (54%)

diff --git a/pyproject.toml b/pyproject.toml
index 14f7ad14..9c2d5cef 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -47,7 +47,6 @@ test-minimal = [
     "scanpy>=1.10.0",
     "bbknn",
     "decoupler",
-    "pertpy",
     "fast-array-utils",
 ]
 test = [
diff --git a/src/rapids_singlecell/pertpy_gpu/_edistance.py b/src/rapids_singlecell/pertpy_gpu/_edistance.py
index 0f76d15d..c6569d0f 100644
--- a/src/rapids_singlecell/pertpy_gpu/_edistance.py
+++ b/src/rapids_singlecell/pertpy_gpu/_edistance.py
@@ -1,12 +1,14 @@
 from __future__ import annotations
 
-from pathlib import Path
 from typing import TYPE_CHECKING, NamedTuple
 
 import cupy as cp
 import numpy as np
 import pandas as pd
 
+from rapids_singlecell.pertpy_gpu._kernels._edistance import (
+    get_compute_group_distances_kernel,
+)
 from rapids_singlecell.preprocessing._harmony._helper import (
     _create_category_index_mapping,
 )
@@ -21,29 +23,7 @@ class EDistanceResult(NamedTuple):
     distances_var: pd.DataFrame | None
 
 
-# Load CUDA kernels from separate file
-def _load_edistance_kernels():
-    """Load CUDA kernels from separate .cu file"""
-    kernel_dir = Path(__file__).parent / "kernels"
-    kernel_file = kernel_dir / "edistance_kernels.cu"
-
-    if not kernel_file.exists():
-        raise FileNotFoundError(f"CUDA kernel file not found: {kernel_file}")
-
-    with open(kernel_file) as f:
-        kernel_code = f.read()
-
-    # Compile kernels
-    compute_group_distances_kernel = cp.RawKernel(
-        kernel_code, "compute_group_distances"
-    )
-
-    return compute_group_distances_kernel
-
-
-# Load kernels at module import time
-compute_group_distances_kernel = _load_edistance_kernels()
-
+compute_group_distances_kernel = get_compute_group_distances_kernel()
 
 def pertpy_edistance(
     adata: AnnData,
@@ -166,7 +146,7 @@ def _pairwise_means(
     num_pairs = len(pair_left)  # k * (k-1) pairs instead of k²
 
     # Allocate output for off-diagonal distances only
-    d_other_offdiag = cp.zeros(num_pairs, dtype=np.float32)
+    d_other_offdiag = cp.zeros(num_pairs, dtype=embedding.dtype)
 
     # Choose optimal block size
     props = cp.cuda.runtime.getDeviceProperties(0)
diff --git a/src/rapids_singlecell/pertpy_gpu/kernels/edistance_kernels.cu b/src/rapids_singlecell/pertpy_gpu/_kernels/_edistance.py
similarity index 54%
rename from src/rapids_singlecell/pertpy_gpu/kernels/edistance_kernels.cu
rename to src/rapids_singlecell/pertpy_gpu/_kernels/_edistance.py
index 9aa2f9f0..d669c2da 100644
--- a/src/rapids_singlecell/pertpy_gpu/kernels/edistance_kernels.cu
+++ b/src/rapids_singlecell/pertpy_gpu/_kernels/_edistance.py
@@ -1,23 +1,16 @@
-// kernels/edistance_kernels.cu
-#include 
-#include 
-
-extern "C" {
-
-/**
- * Compute between-group mean distances for off-diagonal pairs only
- * Each block processes one group pair, threads collaborate within the block
- */
-__global__ void compute_group_distances(
-    const float* __restrict__ embedding,
-    const int* __restrict__ cat_offsets,
-    const int* __restrict__ cell_indices,
-    const int* __restrict__ pair_left,
-    const int* __restrict__ pair_right,
-    float* __restrict__ d_other,
-    int k,
-    int n_features)
-{
+from __future__ import annotations
+
+from cuml.common.kernel_utils import cuda_kernel_factory
+
+_compute_group_distances_kernel = r"""
+(const float* __restrict__ embedding,
+ const int* __restrict__ cat_offsets,
+ const int* __restrict__ cell_indices,
+ const int* __restrict__ pair_left,
+ const int* __restrict__ pair_right,
+ float* __restrict__ d_other,
+ int k,
+ int n_features) {
     extern __shared__ float shared_sums[];
 
     const int thread_id = threadIdx.x;
@@ -29,8 +22,6 @@
     const int a = pair_left[block_id];
     const int b = pair_right[block_id];
 
-    // No need to check a == b since we only pass off-diagonal pairs
-
     const int start_a = cat_offsets[a];
     const int end_a = cat_offsets[a + 1];
     const int start_b = cat_offsets[b];
@@ -39,25 +30,23 @@
     const int n_a = end_a - start_a;
     const int n_b = end_b - start_b;
 
-    // Compute between-group distances (ALL cross-pairs)
     for (int ia = start_a + thread_id; ia < end_a; ia += block_size) {
         const int idx_i = cell_indices[ia];
 
         for (int jb = start_b; jb < end_b; ++jb) {
             const int idx_j = cell_indices[jb];
 
-            float dist_sq = 0.0f;
-            #pragma unroll
+            double dist_sq = 0.0;
             for (int feat = 0; feat < n_features; ++feat) {
-                float diff = embedding[idx_i * n_features + feat] -
-                            embedding[idx_j * n_features + feat];
+                double diff = (double)embedding[idx_i * n_features + feat] -
+                              (double)embedding[idx_j * n_features + feat];
                 dist_sq += diff * diff;
             }
-            local_sum += sqrtf(dist_sq);
+
+            local_sum += (float)sqrt(dist_sq);
         }
     }
 
-    // Reduce across threads using shared memory
     shared_sums[thread_id] = local_sum;
     __syncthreads();
 
@@ -69,9 +58,15 @@
     }
 
     if (thread_id == 0) {
-        // Store mean between-group distance
         d_other[block_id] = shared_sums[0] / (float)(n_a * n_b);
     }
 }
+"""
+
 
-} // extern "C"
+def get_compute_group_distances_kernel():
+    return cuda_kernel_factory(
+        _compute_group_distances_kernel,
+        (),
+        "compute_group_distances",
+    )
diff --git a/tests/pertpy/test_distances.py b/tests/pertpy/test_distances.py
index e7f00c62..7a7bb954 100644
--- a/tests/pertpy/test_distances.py
+++ b/tests/pertpy/test_distances.py
@@ -2,268 +2,119 @@
 
 import cupy as cp
 import numpy as np
+import pandas as pd
 import pytest
-import scanpy as sc
-from pandas import DataFrame, Series
-from pertpy import data as dt
-from pytest import fixture, mark
-from scipy import sparse as sp
+from anndata import AnnData
 
-from rapids_singlecell.pertpy_gpu._distances import Distance
+from rapids_singlecell.pertpy_gpu._edistance import EDistanceResult, pertpy_edistance
 
 
 @pytest.fixture
-def cp_rng():  # TODO(selmanozleyen): Think of a way to integrate this with decoupler's rng fixture
-    rng = cp.random.default_rng(seed=42)
-    return rng
-
-
-@pytest.fixture
-def np_rng():  # TODO(selmanozleyen): Think of a way to integrate this with decoupler's rng fixture
-    rng = np.random.default_rng(seed=42)
-    return rng
-
-
-actual_distances = [
-    # Euclidean distances and related
-    # "euclidean",
-    # "mean_absolute_error",
-    # "mean_pairwise",
-    # "mse",
-    "edistance",
-    # Other
-    # "cosine_distance",
-    # "kendalltau_distance",
-    # "mmd",
-    # "pearson_distance",
-    # "spearman_distance",
-    # "t_test",
-    # "mahalanobis",
-]
-# semi_distances = ["r2_distance", "sym_kldiv", "ks_test"]
-# non_distances = ["classifier_proba"]
-# onesided_only = ["classifier_cp"]
-# pseudo_counts_distances = ["nb_ll"]
-# lognorm_counts_distances = ["mean_var_distribution"]
-all_distances = actual_distances  # + semi_distances + non_distances + lognorm_counts_distances + pseudo_counts_distances  # + onesided_only
-semi_distances = []
-non_distances = []
-onesided_only = []
-pseudo_counts_distances = []
-lognorm_counts_distances = []
-
-
-@fixture
-def adata(request):
-    low_subsample_distances = [
-        "sym_kldiv",
-        "t_test",
-        "ks_test",
-        "classifier_proba",
-        "classifier_cp",
-        "mahalanobis",
-        "mean_var_distribution",
-    ]
-    no_subsample_distances = [
-        "mahalanobis"
-    ]  # mahalanobis only works on the full data without subsampling
-
-    distance = request.node.callspec.params["distance"]
-
-    adata = dt.distance_example()
-    if distance not in no_subsample_distances:
-        if distance in low_subsample_distances:
-            adata = sc.pp.subsample(adata, 0.1, copy=True)
-        else:
-            adata = sc.pp.subsample(adata, 0.001, copy=True)
-
-    adata = adata[
-        :, np.random.default_rng().choice(adata.n_vars, 100, replace=False)
-    ].copy()
-
-    adata.layers["lognorm"] = adata.X.copy()
-    adata.layers["counts"] = cp.round(adata.X.toarray()).astype(int)
-    if "X_pca" not in adata.obsm:
-        sc.pp.pca(adata, n_comps=5)
-    if distance in lognorm_counts_distances:
-        groups = np.unique(adata.obs["perturbation"])
-        # KDE is slow, subset to 3 groups for speed up
-        adata = adata[adata.obs["perturbation"].isin(groups[0:3])].copy()
-
-    adata.X = cp.asarray(adata.X.toarray())
-    for l_key in adata.layers.keys():
-        if sp.issparse(adata.layers[l_key]):
-            from cupyx.scipy.sparse import coo_matrix, csc_matrix, csr_matrix
-
-            if sp.isspmatrix_csr(adata.layers[l_key]):
-                adata.layers[l_key] = csr_matrix(adata.layers[l_key])
-            elif sp.isspmatrix_csc(adata.layers[l_key]):
-                adata.layers[l_key] = csc_matrix(adata.layers[l_key])
-            elif sp.isspmatrix_coo(adata.layers[l_key]):
-                adata.layers[l_key] = coo_matrix(adata.layers[l_key])
-        else:
-            adata.layers[l_key] = cp.asarray(adata.layers[l_key])
-    adata.layers["lognorm"] = cp.asarray(adata.layers["lognorm"].toarray())
-    adata.layers["counts"] = cp.asarray(adata.layers["counts"])
-    adata.obsm["X_pca"] = cp.asarray(adata.obsm["X_pca"])
-
+def small_adata() -> AnnData:
+    rng = np.random.default_rng(0)
+    n_groups = 3
+    cells_per_group = 4
+    n_features = 5
+    total_cells = n_groups * cells_per_group
+
+    cpu_embedding = rng.normal(size=(total_cells, n_features)).astype(np.float32)
+    groups = [f"g{idx}" for idx in range(n_groups) for _ in range(cells_per_group)]
+    obs = pd.DataFrame({"group": pd.Categorical(groups, categories=[f"g{i}" for i in range(n_groups)])})
+
+    adata = AnnData(cpu_embedding.copy(), obs=obs)
+    adata.obsm["X_pca"] = cp.asarray(cpu_embedding, dtype=cp.float32)
     return adata
 
 
-@fixture
-def distance_obj(request):
-    distance = request.node.callspec.params["distance"]
-    if distance in lognorm_counts_distances:
-        d = Distance(distance, layer_key="lognorm")
-    elif distance in pseudo_counts_distances:
-        d = Distance(distance, layer_key="counts")
-    else:
-        d = Distance(distance, obsm_key="X_pca")
-    return d
+def _compute_cpu_reference(adata: AnnData, obsm_key: str, group_key: str) -> tuple[np.ndarray, np.ndarray]:
+    embedding = adata.obsm[obsm_key].get()
+    group_series = adata.obs[group_key]
+    categories = list(group_series.cat.categories)
+    k = len(categories)
 
+    pair_means = np.zeros((k, k), dtype=np.float32)
 
-@fixture
-@mark.parametrize("distance", all_distances)
-def pairwise_distance(adata, distance_obj, distance):
-    return distance_obj.pairwise(adata, groupby="perturbation", show_progressbar=False)
+    for i, gi in enumerate(categories):
+        idx_i = np.where(group_series == gi)[0]
+        for j, gj in enumerate(categories[i:], start=i):
+            idx_j = np.where(group_series == gj)[0]
+            if len(idx_i) == 0 or len(idx_j) == 0:
+                mean_distance = 0.0
+            else:
+                distances = []
+                for idx in idx_i:
+                    diffs = embedding[idx] - embedding[idx_j]
+                    distances.append(np.sqrt(np.sum(diffs ** 2, axis=1)))
+                stacked = np.concatenate(distances)
+                mean_distance = stacked.mean(dtype=np.float64)
+            pair_means[i, j] = pair_means[j, i] = np.float32(mean_distance)
 
+    edistance = np.zeros((k, k), dtype=np.float32)
+    for i in range(k):
+        for j in range(i + 1, k):
+            value = 2 * pair_means[i, j] - pair_means[i, i] - pair_means[j, j]
+            edistance[i, j] = edistance[j, i] = np.float32(value)
 
-@mark.parametrize("distance", actual_distances + semi_distances)
-def test_distance_axioms(pairwise_distance, distance):
-    # This is equivalent to testing for a semimetric, defined as fulfilling all axioms except triangle inequality.
-    # (M1) Definiteness
-    assert all(np.diag(pairwise_distance.values) == 0)  # distance to self is 0
+    return pair_means, edistance
 
-    # (M2) Positivity
-    assert len(pairwise_distance) == np.sum(
-        pairwise_distance.values == 0
-    )  # distance to other is not 0
-    assert all(pairwise_distance.values.flatten() >= 0)  # distance is non-negative
 
-    # (M3) Symmetry
-    assert np.sum(pairwise_distance.values - pairwise_distance.values.T) == 0
+def test_pertpy_edistance_matches_cpu_reference(small_adata: AnnData) -> None:
+    result = pertpy_edistance(small_adata, groupby="group", obsm_key="X_pca")
 
+    assert isinstance(result, EDistanceResult)
+    assert result.distances_var is None
 
-@mark.parametrize("distance", actual_distances)
-def test_triangle_inequality(pairwise_distance, distance, np_rng):
-    # Test if distances are well-defined in accordance with metric axioms
-    # (M4) Triangle inequality (we just probe this for a few random triplets)
-    # Some tests are not well defined for the triangle inequality. We skip those.
-    if distance in {"mahalanobis"}:
-        return
+    _, cpu_edistance = _compute_cpu_reference(small_adata, "X_pca", "group")
 
-    for _ in range(5):
-        triplet = np_rng.choice(pairwise_distance.index, size=3, replace=False)
-        assert (
-            pairwise_distance.loc[triplet[0], triplet[1]]
-            + pairwise_distance.loc[triplet[1], triplet[2]]
-            >= pairwise_distance.loc[triplet[0], triplet[2]]
-        )
+    assert result.distances.shape == cpu_edistance.shape
+    np.testing.assert_allclose(result.distances.values, cpu_edistance, atol=1e-5)
+    assert np.allclose(result.distances.values, result.distances.values.T)
 
 
-@mark.parametrize("distance", all_distances)
-def test_distance_layers(pairwise_distance, distance):
-    assert isinstance(pairwise_distance, DataFrame)
-    assert pairwise_distance.columns.equals(pairwise_distance.index)
-    assert (
-        np.sum(pairwise_distance.values - pairwise_distance.values.T) == 0
-    )  # symmetry
-
-
-@mark.parametrize("distance", actual_distances + pseudo_counts_distances)
-def test_distance_counts(adata, distance):
-    if (
-        distance != "mahalanobis"
-    ):  # skip, doesn't work because covariance matrix is a singular matrix, not invertible
-        distance = Distance(distance, layer_key="counts")
-        df = distance.pairwise(adata, groupby="perturbation")
-        assert isinstance(df, DataFrame)
-        assert df.columns.equals(df.index)
-        assert np.sum(df.values - df.values.T) == 0
-
-
-@mark.parametrize("distance", all_distances)
-def test_mutually_exclusive_keys(distance):
-    with pytest.raises(ValueError):
-        _ = Distance(distance, layer_key="counts", obsm_key="X_pca")
-
-
-@mark.parametrize("distance", actual_distances + semi_distances + non_distances)
-def test_distance_output_type(distance, cp_rng):
-    # Test if distances are outputting floats
-    dist = Distance(distance)
-    X = cp_rng.standard_normal(size=(50, 10))
-    Y = cp_rng.standard_normal(size=(50, 10))
-    d = dist(X, Y)
-    d = float(d.get())
-    assert isinstance(d, float)
-
-
-@mark.parametrize("distance", all_distances + onesided_only)
-def test_distance_onesided(adata, distance_obj, distance):
-    # Test consistency of one-sided distance results
-    selected_group = adata.obs.perturbation.unique()[0]
-    df = distance_obj.onesided_distances(
-        adata, groupby="perturbation", selected_group=selected_group
+def test_pertpy_edistance_inplace_populates_uns(small_adata: AnnData) -> None:
+    key = "group_pairwise_edistance"
+    result = pertpy_edistance(
+        small_adata,
+        groupby="group",
+        obsm_key="X_pca",
+        inplace=True,
     )
-    assert isinstance(df, Series)
-    assert df.loc[selected_group] == 0  # distance to self is 0
-
 
-def test_bootstrap_distance_output_type(cp_rng):
-    # Test if distances are outputting floats
-    d = Distance(metric="edistance")
-    X = cp_rng.standard_normal(size=(50, 10))
-    Y = cp_rng.standard_normal(size=(50, 10))
-    d = d.bootstrap(X, Y, n_bootstrap=3)
-    assert hasattr(d, "mean")
-    assert hasattr(d, "variance")
+    assert isinstance(result, EDistanceResult)
+    assert key in small_adata.uns
+    stored = small_adata.uns[key]
+    assert set(stored.keys()) == {"distances", "distances_var"}
+    np.testing.assert_allclose(stored["distances"].values, result.distances.values)
+    assert stored["distances_var"] is None
 
 
-@mark.parametrize("distance", ["edistance"])
-def test_bootstrap_distance_pairwise(adata, distance):
-    # Test consistency of pairwise distance results
-    dist = Distance(distance, obsm_key="X_pca")
-    bootstrap_output = dist.pairwise(
-        adata, groupby="perturbation", bootstrap=True, n_bootstrap=3
+def test_pertpy_edistance_bootstrap_returns_variance(small_adata: AnnData) -> None:
+    result = pertpy_edistance(
+        small_adata,
+        groupby="group",
+        obsm_key="X_pca",
+        bootstrap=True,
+        n_bootstrap=8,
+        random_state=11,
     )
 
-    assert isinstance(bootstrap_output, tuple)
-
-    mean = bootstrap_output[0]
-    var = bootstrap_output[1]
+    assert isinstance(result, EDistanceResult)
+    assert result.distances_var is not None
+    assert result.distances.shape == result.distances_var.shape
+    assert np.allclose(result.distances.values, result.distances.values.T)
+    assert np.allclose(result.distances_var.values, result.distances_var.values.T)
+    assert np.all(result.distances_var.values >= 0)
 
-    assert mean.columns.equals(mean.index)
-    assert np.sum(mean.values - mean.values.T) == 0  # symmetry
-    assert np.sum(var.values - var.values.T) == 0  # symmetry
 
+def test_pertpy_edistance_requires_categorical_obs(small_adata: AnnData) -> None:
+    bad = small_adata.copy()
+    bad.obs["group"] = bad.obs["group"].astype(str)
 
-@mark.parametrize("distance", ["edistance"])
-def test_bootstrap_distance_onesided(adata, distance):
-    # Test consistency of one-sided distance results
-    selected_group = adata.obs.perturbation.unique()[0]
-    d = Distance(distance, obsm_key="X_pca")
-    bootstrap_output = d.onesided_distances(
-        adata,
-        groupby="perturbation",
-        selected_group=selected_group,
-        bootstrap=True,
-        n_bootstrap=3,
-    )
-
-    assert isinstance(bootstrap_output, tuple)
+    with pytest.raises(TypeError):
+        pertpy_edistance(bad, groupby="group", obsm_key="X_pca")
 
 
-def test_compare_distance(cp_rng):
-    X = cp_rng.standard_normal(size=(50, 10))
-    Y = cp_rng.standard_normal(size=(50, 10))
-    C = cp_rng.standard_normal(size=(50, 10))
-    d = Distance()
-    res_simple = d.compare_distance(X, Y, C, mode="simple")
-    res_simple = float(res_simple.get())
-    assert isinstance(res_simple, float)
-    res_scaled = d.compare_distance(X, Y, C, mode="scaled")
-    res_scaled = float(res_scaled.get())
-    assert isinstance(res_scaled, float)
-    with pytest.raises(ValueError):
-        d.compare_distance(X, Y, C, mode="new_mode")
+@pytest.mark.parametrize("missing_key", ["missing", "other"])
+def test_pertpy_edistance_missing_group_raises(small_adata: AnnData, missing_key: str) -> None:
+    with pytest.raises(KeyError):
+        pertpy_edistance(small_adata, groupby=missing_key, obsm_key="X_pca")

From 578326149152cbbc8bfb3488a89b6d869b9c8b0c Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
 <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 1 Oct 2025 20:13:07 +0000
Subject: [PATCH 20/23] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/rapids_singlecell/pertpy_gpu/_edistance.py |  1 +
 tests/pertpy/test_distances.py                 | 14 ++++++++++----
 2 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/src/rapids_singlecell/pertpy_gpu/_edistance.py b/src/rapids_singlecell/pertpy_gpu/_edistance.py
index c6569d0f..613c2384 100644
--- a/src/rapids_singlecell/pertpy_gpu/_edistance.py
+++ b/src/rapids_singlecell/pertpy_gpu/_edistance.py
@@ -25,6 +25,7 @@ class EDistanceResult(NamedTuple):
 
 compute_group_distances_kernel = get_compute_group_distances_kernel()
 
+
 def pertpy_edistance(
     adata: AnnData,
     groupby: str,
diff --git a/tests/pertpy/test_distances.py b/tests/pertpy/test_distances.py
index 7a7bb954..23a7a557 100644
--- a/tests/pertpy/test_distances.py
+++ b/tests/pertpy/test_distances.py
@@ -19,14 +19,18 @@ def small_adata() -> AnnData:
 
     cpu_embedding = rng.normal(size=(total_cells, n_features)).astype(np.float32)
     groups = [f"g{idx}" for idx in range(n_groups) for _ in range(cells_per_group)]
-    obs = pd.DataFrame({"group": pd.Categorical(groups, categories=[f"g{i}" for i in range(n_groups)])})
+    obs = pd.DataFrame(
+        {"group": pd.Categorical(groups, categories=[f"g{i}" for i in range(n_groups)])}
+    )
 
     adata = AnnData(cpu_embedding.copy(), obs=obs)
     adata.obsm["X_pca"] = cp.asarray(cpu_embedding, dtype=cp.float32)
     return adata
 
 
-def _compute_cpu_reference(adata: AnnData, obsm_key: str, group_key: str) -> tuple[np.ndarray, np.ndarray]:
+def _compute_cpu_reference(
+    adata: AnnData, obsm_key: str, group_key: str
+) -> tuple[np.ndarray, np.ndarray]:
     embedding = adata.obsm[obsm_key].get()
     group_series = adata.obs[group_key]
     categories = list(group_series.cat.categories)
@@ -44,7 +48,7 @@ def _compute_cpu_reference(adata: AnnData, obsm_key: str, group_key: str) -> tup
                 distances = []
                 for idx in idx_i:
                     diffs = embedding[idx] - embedding[idx_j]
-                    distances.append(np.sqrt(np.sum(diffs ** 2, axis=1)))
+                    distances.append(np.sqrt(np.sum(diffs**2, axis=1)))
                 stacked = np.concatenate(distances)
                 mean_distance = stacked.mean(dtype=np.float64)
             pair_means[i, j] = pair_means[j, i] = np.float32(mean_distance)
@@ -115,6 +119,8 @@ def test_pertpy_edistance_requires_categorical_obs(small_adata: AnnData) -> None
 
 
 @pytest.mark.parametrize("missing_key", ["missing", "other"])
-def test_pertpy_edistance_missing_group_raises(small_adata: AnnData, missing_key: str) -> None:
+def test_pertpy_edistance_missing_group_raises(
+    small_adata: AnnData, missing_key: str
+) -> None:
     with pytest.raises(KeyError):
         pertpy_edistance(small_adata, groupby=missing_key, obsm_key="X_pca")

From 0496af0e1b7d1b705b662efa47619e4a2d55b466 Mon Sep 17 00:00:00 2001
From: selmanozleyen 
Date: Thu, 16 Oct 2025 09:30:01 +0000
Subject: [PATCH 21/23] add documentation and write a comparison script

---
 .../pertpy_gpu/_edistance.py                  | 15 ++++---
 tmp_scripts/compute_edistance.py              | 37 ---------------
 tmp_scripts/run.py                            | 45 +++++++++++++++++++
 3 files changed, 54 insertions(+), 43 deletions(-)
 delete mode 100644 tmp_scripts/compute_edistance.py
 create mode 100644 tmp_scripts/run.py

diff --git a/src/rapids_singlecell/pertpy_gpu/_edistance.py b/src/rapids_singlecell/pertpy_gpu/_edistance.py
index 613c2384..65fe12d2 100644
--- a/src/rapids_singlecell/pertpy_gpu/_edistance.py
+++ b/src/rapids_singlecell/pertpy_gpu/_edistance.py
@@ -36,12 +36,15 @@ def pertpy_edistance(
     bootstrap: bool = False,
     n_bootstrap: int = 100,
     random_state: int = 0,
-) -> pd.DataFrame:
+) -> EDistanceResult:
     """
     GPU-accelerated pairwise edistance computation with decomposed components.
 
-    Returns d_itself, d_other arrays and final edistance DataFrame where:
-    df[a,b] = 2*d_other[a,b] - d_itself[a] - d_itself[b]
+    Returns an EDistanceResult holding ``distances`` and ``distances_var``.
+    With d[a,b] the mean distance between groups a and b, and d[a] = d[a,a]:
+    distances[a,b] = 2*d[a,b] - d[a] - d[b]
+    The ``distances_var`` DataFrame is defined analogously from d_var:
+    distances_var[a,b] = 4*d_var[a,b] + d_var[a] + d_var[b]
 
     Parameters
     ----------
@@ -58,8 +61,8 @@ def pertpy_edistance(
 
     Returns
     -------
-    df : pd.DataFrame
-        Final edistance matrix
+    result : EDistanceResult
+        EDistanceResult with the distances and, if bootstrap is True, distances_var (else None).
     """
     _assert_categorical_obs(adata, key=groupby)
 
@@ -101,7 +104,7 @@ def pertpy_edistance(
         result = EDistanceResult(distances=df, distances_var=df_var)
 
     if inplace:
-        adata.uns[f"{groupby}_pairwise_edistance"] = dict(result)
+        adata.uns[f"{groupby}_pairwise_edistance"] = result._asdict()
 
     return result
 
diff --git a/tmp_scripts/compute_edistance.py b/tmp_scripts/compute_edistance.py
deleted file mode 100644
index e325e7a8..00000000
--- a/tmp_scripts/compute_edistance.py
+++ /dev/null
@@ -1,37 +0,0 @@
-from __future__ import annotations
-
-import os
-import time
-from pathlib import Path
-
-import anndata as ad
-import cupy as cp
-import rmm
-from rmm.allocators.cupy import rmm_cupy_allocator
-
-import rapids_singlecell as rsc
-from rapids_singlecell.ptg import Distance
-
-rmm.reinitialize(
-    managed_memory=False,  # Allows oversubscription
-    pool_allocator=True,  # default is False
-    devices=0,  # GPU device IDs to register. By default registers only GPU 0.
-)
-cp.cuda.set_allocator(rmm_cupy_allocator)
-
-
-if __name__ == "__main__":
-    obs_key = "perturbation"
-
-    # homedir/data/adamson_2016_upr_epistasis
-    save_dir = os.path.join(
-        os.path.expanduser("~"),
-        "data",
-    )
-    adata = ad.read_h5ad(Path(save_dir) / "adamson_2016_upr_epistasis_pca.h5ad")
-    rsc.get.anndata_to_GPU(adata, convert_all=True)
-    dist = Distance(obsm_key="X_pca", metric="edistance")
-    start_time = time.time()
-    df = dist.pairwise(adata, groupby=obs_key)
-    end_time = time.time()
-    print(f"Time taken: {end_time - start_time} seconds")
diff --git a/tmp_scripts/run.py b/tmp_scripts/run.py
new file mode 100644
index 00000000..6ecfec46
--- /dev/null
+++ b/tmp_scripts/run.py
@@ -0,0 +1,45 @@
+from __future__ import annotations
+
+import os
+import time
+from argparse import ArgumentParser
+from pathlib import Path
+
+import rapids_singlecell as rsc
+from rapids_singlecell.pertpy_gpu._edistance import pertpy_edistance
+import anndata as ad
+from pertpy.tools import Distance
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    parser.add_argument("--bootstrap", action="store_true")
+    parser.add_argument("--gpu", action="store_true")
+    args = parser.parse_args()
+    obs_key = "perturbation"
+    bootstrap = args.bootstrap
+    gpu = args.gpu
+    # homedir/data/adamson_2016_upr_epistasis
+    save_dir = os.path.join(
+        os.path.expanduser("~"),
+        "data",
+    )
+    adata = ad.read_h5ad(Path(save_dir) / "adamson_2016_upr_epistasis_pca.h5ad")
+    if gpu:
+        rsc.get.anndata_to_GPU(adata, convert_all=True)
+    else:
+        dist = Distance(obsm_key="X_pca", metric="edistance")
+    start_time = time.time()
+    if gpu:
+        res = pertpy_edistance(
+            adata, groupby=obs_key, obsm_key="X_pca", bootstrap=bootstrap, n_bootstrap=100
+        )
+        df_gpu = res.distances
+        df_gpu_var = res.distances_var
+    else:
+        if bootstrap:
+            df, df_var = dist.pairwise(adata, groupby=obs_key, bootstrap=bootstrap, n_bootstrap=100)
+        else:
+            df = dist.pairwise(adata, groupby=obs_key)
+            df_var = None
+    end_time = time.time()
+    print(f"Time taken: {end_time - start_time} seconds")

From a7cf6c4ac06881ad421058bca31132ae886377d6 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
 <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 16 Oct 2025 09:30:11 +0000
Subject: [PATCH 22/23] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 tmp_scripts/run.py | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/tmp_scripts/run.py b/tmp_scripts/run.py
index 6ecfec46..bac0de01 100644
--- a/tmp_scripts/run.py
+++ b/tmp_scripts/run.py
@@ -5,11 +5,12 @@
 from argparse import ArgumentParser
 from pathlib import Path
 
-import rapids_singlecell as rsc
-from rapids_singlecell.pertpy_gpu._edistance import pertpy_edistance
 import anndata as ad
 from pertpy.tools import Distance
 
+import rapids_singlecell as rsc
+from rapids_singlecell.pertpy_gpu._edistance import pertpy_edistance
+
 if __name__ == "__main__":
     parser = ArgumentParser()
     parser.add_argument("--bootstrap", action="store_true")
@@ -31,13 +32,19 @@
     start_time = time.time()
     if gpu:
         res = pertpy_edistance(
-            adata, groupby=obs_key, obsm_key="X_pca", bootstrap=bootstrap, n_bootstrap=100
+            adata,
+            groupby=obs_key,
+            obsm_key="X_pca",
+            bootstrap=bootstrap,
+            n_bootstrap=100,
         )
         df_gpu = res.distances
         df_gpu_var = res.distances_var
     else:
         if bootstrap:
-            df, df_var = dist.pairwise(adata, groupby=obs_key, bootstrap=bootstrap, n_bootstrap=100)
+            df, df_var = dist.pairwise(
+                adata, groupby=obs_key, bootstrap=bootstrap, n_bootstrap=100
+            )
         else:
             df = dist.pairwise(adata, groupby=obs_key)
             df_var = None

From 63e6e83d37ccbeb65d1c6c4ce8eae4415a49d8b0 Mon Sep 17 00:00:00 2001
From: Intron7 
Date: Fri, 24 Oct 2025 17:05:49 +0200
Subject: [PATCH 23/23] speed up kernel

---
 src/rapids_singlecell/__init__.py             |  2 +-
 .../pertpy_gpu/_edistance.py                  | 33 ++------------
 .../pertpy_gpu/_kernels/_edistance.py         | 44 +++++++++++--------
 3 files changed, 31 insertions(+), 48 deletions(-)

diff --git a/src/rapids_singlecell/__init__.py b/src/rapids_singlecell/__init__.py
index 4ac1caff..01aa831c 100644
--- a/src/rapids_singlecell/__init__.py
+++ b/src/rapids_singlecell/__init__.py
@@ -2,7 +2,7 @@
 
 import cuml.internals.logger as logger
 
-from . import dcg, get, gr, pp, tl
+from . import dcg, get, gr, pp, ptg, tl
 from ._version import __version__
 
 logger.set_level(2)
diff --git a/src/rapids_singlecell/pertpy_gpu/_edistance.py b/src/rapids_singlecell/pertpy_gpu/_edistance.py
index 65fe12d2..cc17911b 100644
--- a/src/rapids_singlecell/pertpy_gpu/_edistance.py
+++ b/src/rapids_singlecell/pertpy_gpu/_edistance.py
@@ -23,9 +23,6 @@ class EDistanceResult(NamedTuple):
     distances_var: pd.DataFrame | None
 
 
-compute_group_distances_kernel = get_compute_group_distances_kernel()
-
-
 def pertpy_edistance(
     adata: AnnData,
     groupby: str,
@@ -150,24 +147,12 @@ def _pairwise_means(
     num_pairs = len(pair_left)  # k * (k-1) pairs instead of k²
 
     # Allocate output for off-diagonal distances only
-    d_other_offdiag = cp.zeros(num_pairs, dtype=embedding.dtype)
-
-    # Choose optimal block size
-    props = cp.cuda.runtime.getDeviceProperties(0)
-    max_smem = int(props.get("sharedMemPerBlock", 48 * 1024))
-
-    chosen_threads = None
-    shared_mem_size = 0  # TODO: think of a better way to do this
-    for tpb in (1024, 512, 256, 128, 64, 32):
-        required = tpb * cp.dtype(cp.float32).itemsize
-        if required <= max_smem:
-            chosen_threads = tpb
-            shared_mem_size = required
-            break
+    pairwise_means = cp.zeros((k, k), dtype=embedding.dtype)
 
+    compute_group_distances_kernel = get_compute_group_distances_kernel(embedding.dtype)
     # Launch kernel - one block per OFF-DIAGONAL group pair only
     grid = (num_pairs,)
-    block = (chosen_threads,)
+    block = (1024,)
     compute_group_distances_kernel(
         grid,
         block,
@@ -177,22 +162,12 @@ def _pairwise_means(
             cell_indices,
             pair_left,
             pair_right,
-            d_other_offdiag,
+            pairwise_means,
             k,
             n_features,
         ),
-        shared_mem=shared_mem_size,
     )
 
-    # Build full k x k matrix
-    pairwise_means = cp.zeros((k, k), dtype=np.float32)
-
-    # Fill the full matrix
-    for i, idx in enumerate(pair_indices.get()):
-        a, b = divmod(idx, k)
-        pairwise_means[a, b] = d_other_offdiag[i]
-        pairwise_means[b, a] = d_other_offdiag[i]
-
     return pairwise_means
 
 
diff --git a/src/rapids_singlecell/pertpy_gpu/_kernels/_edistance.py b/src/rapids_singlecell/pertpy_gpu/_kernels/_edistance.py
index d669c2da..70dafc05 100644
--- a/src/rapids_singlecell/pertpy_gpu/_kernels/_edistance.py
+++ b/src/rapids_singlecell/pertpy_gpu/_kernels/_edistance.py
@@ -3,21 +3,20 @@
 from cuml.common.kernel_utils import cuda_kernel_factory
 
 _compute_group_distances_kernel = r"""
-(const float* __restrict__ embedding,
+(const {0}* __restrict__ embedding,
  const int* __restrict__ cat_offsets,
  const int* __restrict__ cell_indices,
  const int* __restrict__ pair_left,
  const int* __restrict__ pair_right,
- float* __restrict__ d_other,
+ {0}* __restrict__ pairwise_means,
  int k,
  int n_features) {
-    extern __shared__ float shared_sums[];
 
     const int thread_id = threadIdx.x;
     const int block_id = blockIdx.x;
     const int block_size = blockDim.x;
 
-    float local_sum = 0.0f;
+    {0} local_sum = ({0})(0.0);
 
     const int a = pair_left[block_id];
     const int b = pair_right[block_id];
@@ -36,37 +35,46 @@
         for (int jb = start_b; jb < end_b; ++jb) {
             const int idx_j = cell_indices[jb];
 
-            double dist_sq = 0.0;
+            {0} dist_sq = ({0})(0.0);
             for (int feat = 0; feat < n_features; ++feat) {
-                double diff = (double)embedding[idx_i * n_features + feat] -
-                              (double)embedding[idx_j * n_features + feat];
+                {0} diff = embedding[idx_i * n_features + feat] -
+                              embedding[idx_j * n_features + feat];
                 dist_sq += diff * diff;
             }
 
-            local_sum += (float)sqrt(dist_sq);
+            local_sum += sqrt(dist_sq);
         }
     }
 
-    shared_sums[thread_id] = local_sum;
+    // --- warp-shuffle reduction -------------
+    #pragma unroll
+    for (int offset = 16; offset > 0; offset >>= 1)
+        local_sum += __shfl_down_sync(0xffffffff, local_sum, offset);
+
+    // --- block reduce -----------------------
+    static __shared__ {0} s[32];              // one per warp
+    if ((threadIdx.x & 31) == 0) s[threadIdx.x>>5] = local_sum;
     __syncthreads();
 
-    for (int stride = block_size / 2; stride > 0; stride >>= 1) {
-        if (thread_id < stride) {
-            shared_sums[thread_id] += shared_sums[thread_id + stride];
+    if (threadIdx.x < 32) {
+        {0} val = (threadIdx.x < (blockDim.x>>5)) ? s[threadIdx.x] : ({0})(0.0);
+        #pragma unroll
+        for (int offset = 16; offset > 0; offset >>= 1)
+            val += __shfl_down_sync(0xffffffff, val, offset);
+        if (threadIdx.x == 0) {
+            {0} mean = val/(({0})(n_a) * ({0})(n_b));
+            pairwise_means[a * k + b] = mean;
+            pairwise_means[b * k + a] = mean;
         }
-        __syncthreads();
-    }
 
-    if (thread_id == 0) {
-        d_other[block_id] = shared_sums[0] / (float)(n_a * n_b);
     }
 }
 """
 
 
-def get_compute_group_distances_kernel():
+def get_compute_group_distances_kernel(dtype):
     return cuda_kernel_factory(
         _compute_group_distances_kernel,
-        (),
+        (dtype,),
         "compute_group_distances",
     )