diff --git a/.gitignore b/.gitignore index 451e4d5..886d04e 100644 --- a/.gitignore +++ b/.gitignore @@ -161,3 +161,4 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ +.aider* diff --git a/src/graphcompass/__init__.py b/src/graphcompass/__init__.py index 05cfca9..5536b0d 100644 --- a/src/graphcompass/__init__.py +++ b/src/graphcompass/__init__.py @@ -1,4 +1,4 @@ -"""Graph-COMPASS.""" +"""GraphCompass""" from graphcompass import pl from graphcompass import tl from graphcompass import datasets diff --git a/src/graphcompass/__main__.py b/src/graphcompass/__main__.py index 233aea7..e637364 100644 --- a/src/graphcompass/__main__.py +++ b/src/graphcompass/__main__.py @@ -5,7 +5,7 @@ @click.command() @click.version_option() def main() -> None: - """Graph-COMPASS.""" + """GraphCompass""" if __name__ == "__main__": diff --git a/src/graphcompass/imports/wwl_package/propagation_scheme.py b/src/graphcompass/imports/wwl_package/propagation_scheme.py index ff2ecc8..dbe9298 100644 --- a/src/graphcompass/imports/wwl_package/propagation_scheme.py +++ b/src/graphcompass/imports/wwl_package/propagation_scheme.py @@ -1,4 +1,4 @@ -######## This file is copied from https://github.com/BorgwardtLab/WWL/blob/master/src/wwl/propagation_scheme.py ######## +######## This file is adapted from https://github.com/BorgwardtLab/WWL/blob/master/src/wwl/propagation_scheme.py ######## # ----------------------------------------------------------------------------- # This file contains the propagation schemes for categorically labeled and @@ -19,7 +19,7 @@ from collections import defaultdict from typing import List from tqdm import tqdm -from scipy.sparse import csr_matrix, diags +from scipy.sparse import csr_matrix, diags, eye #################### @@ -149,8 +149,14 @@ def _preprocess_graphs(self, X: List[ig.Graph]): # Iterate across graphs and load initial node features for graph in X: - if not 'label' in graph.vs.attribute_names(): - graph.vs['label'] = list(map(str, [l for l in graph.vs.degree()])) + if 'label' in graph.vs.attribute_names(): + labels = graph.vs['label'] + if not all(isinstance(label, (int, float)) for label in labels): + logging.warning("Non-numeric labels found. Falling back to node degrees.") + graph.vs['label'] = list(graph.vs.degree()) + else: + graph.vs['label'] = list(graph.vs.degree()) + # Get features and adjacency matrix node_features_cur = np.asarray(graph.vs['label']).astype(float).reshape(-1, 1) adj_mat_cur = csr_matrix(graph.get_adjacency_sparse()) @@ -212,7 +218,7 @@ def fit_transform(self, X: List[ig.Graph], node_features = None, num_iterations: if it == 0: graph_feat.append(node_features[i]) else: - adj_cur = adj_mat[i] + csr_matrix(np.identity(adj_mat[i].shape[0])) + adj_cur = adj_mat[i] + eye(adj_mat[i].shape[0], format='csr') adj_cur = self._create_adj_avg(adj_cur) adj_cur.setdiag(0) @@ -220,4 +226,4 @@ def fit_transform(self, X: List[ig.Graph], node_features = None, num_iterations: graph_feat.append(graph_feat_cur) self._label_sequences.append(np.concatenate(graph_feat, axis=1)) - return self._label_sequences \ No newline at end of file + return self._label_sequences diff --git a/src/graphcompass/imports/wwl_package/wwl.py b/src/graphcompass/imports/wwl_package/wwl.py index 311e89c..a404bf0 100644 --- a/src/graphcompass/imports/wwl_package/wwl.py +++ b/src/graphcompass/imports/wwl_package/wwl.py @@ -1,17 +1,19 @@ -######## This file is copied from https://github.com/BorgwardtLab/WWL/blob/master/src/wwl/wwl.py ######## +""" +Wasserstein Weisfeiler-Lehman (WWL) kernel implementation. +This module provides tools for computing graph similarities using the Wasserstein +Weisfeiler-Lehman kernel, supporting both categorical and continuous graph embeddings. + +Adapted from: https://github.com/BorgwardtLab/WWL/blob/master/src/wwl/wwl.py +""" -# ----------------------------------------------------------------------------- -# This file contains the API for the WWL kernel computations -# -# December 2019, M. Togninalli -# ----------------------------------------------------------------------------- import sys import logging -from tqdm import tqdm -import ot import numpy as np +import torch +from geomloss import SamplesLoss +from sklearn.preprocessing import OneHotEncoder from sklearn.metrics.pairwise import laplacian_kernel from .propagation_scheme import WeisfeilerLehman, ContinuousWeisfeilerLehman @@ -19,72 +21,101 @@ logging.basicConfig(level=logging.INFO) def logging_config(level='DEBUG'): - level = logging.getLevelName(level.upper()) - logging.basicConfig(level=level) - pass - -def _compute_wasserstein_distance(label_sequences, sinkhorn=False, - categorical=False, sinkhorn_lambda=1e-2): - ''' - Generate the Wasserstein distance matrix for the graphs embedded - in label_sequences - ''' - # Get the iteration number from the embedding file - n = len(label_sequences) - - M = np.zeros((n,n)) - # Iterate over pairs of graphs - for graph_index_1, graph_1 in enumerate(label_sequences): - # Only keep the embeddings for the first h iterations - labels_1 = label_sequences[graph_index_1] - for graph_index_2, graph_2 in tqdm(enumerate(label_sequences[graph_index_1:])): - labels_2 = label_sequences[graph_index_2 + graph_index_1] - # Get cost matrix - ground_distance = 'hamming' if categorical else 'euclidean' - costs = ot.dist(labels_1, labels_2, metric=ground_distance) - - if sinkhorn: - mat = ot.sinkhorn( - np.ones(len(labels_1))/len(labels_1), - np.ones(len(labels_2))/len(labels_2), - costs, - sinkhorn_lambda, - numItermax=50 - ) - M[graph_index_1, graph_index_2 + graph_index_1] = np.sum(np.multiply(mat, costs)) - else: - M[graph_index_1, graph_index_2 + graph_index_1] = \ - ot.emd2([], [], costs) - - M = (M + M.T) - return M - -def pairwise_wasserstein_distance(X, node_features = None, num_iterations=3, sinkhorn=False, enforce_continuous=False): + """Set the logging level for the application. + + Configures the global logging level to control the verbosity of log messages. + + Args: + level (str, optional): Logging level. Defaults to 'DEBUG'. + Typical values include 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'. """ - Pairwise computation of the Wasserstein distance between embeddings of the - graphs in X. - args: - X (List[ig.graphs]): List of graphs - node_features (array): Array containing the node features for continuously attributed graphs - num_iterations (int): Number of iterations for the propagation scheme - sinkhorn (bool): Indicates whether sinkhorn approximation should be used + logging.basicConfig(level=logging.getLevelName(level.upper())) + +def _compute_wasserstein_distance_geomloss(label_sequences, categorical=False, blur=0.05, p=2): + """Compute pairwise Wasserstein distances between graph node embeddings. + + Calculates the optimal transport distance between node embeddings using + the Sinkhorn algorithm. Automatically uses GPU acceleration if available. + + Args: + label_sequences (list): List of node embeddings for each graph + categorical (bool): Whether the node labels are categorical (discrete) or continuous embeddings. + blur (float, optional): Sinkhorn smoothing parameter. Defaults to 0.05. + p (int, optional): Power of the cost function. Defaults to 2 (squared Euclidean). + + Returns: + numpy.ndarray: Symmetric matrix of pairwise Wasserstein distances + + Notes: + This function differs from the legacy implementation in that it leverages + the GeomLoss library for GPU-accelerated Sinkhorn computations, enabling + efficient optimal transport calculations even on large graphs. The legacy + function uses the POT library and runs on CPU only, which can be slower + for large datasets. Additionally, this function handles both categorical + and continuous node features in a unified manner via one-hot encoding + for discrete labels. + """ + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + sinkhorn = SamplesLoss("sinkhorn", p=p, blur=blur) + + n = len(label_sequences) + M = torch.zeros((n, n), device=device) + + if categorical: + # Flatten all labels for one-hot encoder fitting + all_labels = np.concatenate(label_sequences).reshape(-1, 1) + enc = OneHotEncoder(sparse_output=False, dtype=np.float32) + enc.fit(all_labels) + + # Encode all graphs now for speed + encoded_sequences = [torch.tensor(enc.transform(seq.reshape(-1,1)), device=device) for seq in label_sequences] + + else: + # Assume label_sequences are arrays of shape (n_nodes, features) + encoded_sequences = [torch.tensor(seq, dtype=torch.float32, device=device) for seq in label_sequences] + + for i in range(n): + a = torch.ones(encoded_sequences[i].shape[0], device=device) / encoded_sequences[i].shape[0] + for j in range(i, n): + b = torch.ones(encoded_sequences[j].shape[0], device=device) / encoded_sequences[j].shape[0] + + dist = sinkhorn(a, encoded_sequences[i], b, encoded_sequences[j]) + M[i, j] = dist + M[j, i] = dist # symmetric + + return M.cpu().numpy() + +def pairwise_wasserstein_distance(X, node_features=None, num_iterations=3, enforce_continuous=False): + """Compute pairwise Wasserstein distances between graph embeddings. + + Determines the appropriate embedding scheme (categorical or continuous) + and computes the Wasserstein distances between graph representations. + + Args: + X (list): List of graphs to compare + node_features (array-like, optional): Pre-computed node features for continuous graphs + num_iterations (int, optional): Number of iterations for graph embedding. Defaults to 3. + enforce_continuous (bool, optional): Force use of continuous embedding scheme. Defaults to False. + + Returns: + numpy.ndarray: Matrix of pairwise Wasserstein distances between graphs """ # First check if the graphs are continuous vs categorical categorical = True if enforce_continuous: - logging.info('Enforce continous flag is on, using CONTINUOUS propagation scheme.') + logging.info('Continuous embedding enforced: Using continuous propagation scheme.') categorical = False elif node_features is not None: - logging.info('Continuous node features provided, using CONTINUOUS propagation scheme.') + logging.info('Continuous node features detected: Using continuous propagation scheme.') categorical = False else: for g in X: - if not 'label' in g.vs.attribute_names(): - logging.info('No label attributed to graphs, use degree instead and use CONTINUOUS propagation scheme.') + if 'label' not in g.vs.attribute_names() or not all(isinstance(label, (int, float)) for label in g.vs['label']): + logging.info('Invalid categorical labels found: Switching to continuous propagation scheme using node degrees.') categorical = False break if categorical: - logging.info('Categorically-labelled graphs, using CATEGORICAL propagation scheme.') + logging.info('Valid categorical graph labels detected: Using categorical propagation scheme.') # Embed the nodes if categorical: @@ -95,21 +126,26 @@ def pairwise_wasserstein_distance(X, node_features = None, num_iterations=3, sin node_representations = es.fit_transform(X, node_features=node_features, num_iterations=num_iterations) # Compute the Wasserstein distance - print("Computing Wasserstein distance between conditions...") - pairwise_distances = _compute_wasserstein_distance(node_representations, sinkhorn=sinkhorn, - categorical=categorical, sinkhorn_lambda=1e-2) + logging.info("Computing pairwise Wasserstein distances between graph embeddings...") + pairwise_distances = _compute_wasserstein_distance_geomloss(node_representations, categorical=categorical) return pairwise_distances -def wwl(X, node_features=None, num_iterations=3, sinkhorn=False, gamma=None): - """ - Pairwise computation of the Wasserstein Weisfeiler-Lehman kernel for graphs in X. +def wwl(X, node_features=None, num_iterations=3, gamma=None): + """Compute the Wasserstein Weisfeiler-Lehman (WWL) kernel for a set of graphs. + + Combines Wasserstein distance computation with a Laplacian kernel to + measure graph similarities. + + Args: + X (list): List of graphs to compare + node_features (array-like, optional): Pre-computed node features for continuous graphs + num_iterations (int, optional): Number of iterations for graph embedding. Defaults to 3. + gamma (float, optional): Scaling parameter for the Laplacian kernel. Defaults to None. + + Returns: + numpy.ndarray: Kernel matrix representing graph similarities """ D_W = pairwise_wasserstein_distance(X, node_features = node_features, - num_iterations=num_iterations, sinkhorn=sinkhorn) + num_iterations=num_iterations) wwl = laplacian_kernel(D_W, gamma=gamma) return wwl - - -####################### -# Class implementation -####################### \ No newline at end of file diff --git a/src/graphcompass/tl/_WLkernel.py b/src/graphcompass/tl/_WLkernel.py index bd3fca7..b7c2a11 100644 --- a/src/graphcompass/tl/_WLkernel.py +++ b/src/graphcompass/tl/_WLkernel.py @@ -8,14 +8,14 @@ from tqdm import tqdm from anndata import AnnData from graphcompass.tl.utils import _calculate_graph, _get_igraph -from graphcompass.imports.wwl_package import wwl, pairwise_wasserstein_distance +from graphcompass.imports.wwl_package import pairwise_wasserstein_distance def compare_conditions( adata: AnnData, library_key: str = "sample", cluster_key: str = "cell_type", - cell_type_keys: list = None, + cell_type_key: str = None, compute_spatial_graphs: bool = True, num_iterations: int = 3, kwargs_nhood_enrich={}, @@ -35,8 +35,8 @@ def compare_conditions( which stores mapping between ``library_id`` and obs. cluster_key Key in :attr:`anndata.AnnData.obs` where clustering is stored. - cell_type_keys - List of keys in :attr:`anndata.AnnData.obs` where cell types are stored. + cell_type_key + Key in :attr:`anndata.AnnData.obs` where cell types are stored. compute_spatial_graphs Set to False if spatial graphs have been calculated or `sq.gr.spatial_neighbors` has already been run before. kwargs_nhood_enrich @@ -49,7 +49,7 @@ def compare_conditions( Whether to return a copy of the Wasserstein distance object. """ if compute_spatial_graphs: - print("Computing spatial graphs...") + logging.info("Computing spatial graphs...") _calculate_graph( adata=adata, library_key=library_key, @@ -59,42 +59,31 @@ def compare_conditions( **kwargs ) else: - print("Spatial graphs were previously computed. Skipping computing spatial graphs...") + logging.info("Spatial graphs were previously computed. Skipping computing spatial graphs...") samples = adata.obs[library_key].unique() graphs = [] node_features = [] - cell_types = [] adata.uns["wl_kernel"] = {} - adata.uns["wl_kernel"] = {} - if cell_type_keys is not None: - for cell_type_key in cell_type_keys: - graphs = [] - node_features = [] - status = [] - cell_types = [] - adata.uns["wl_kernel"] = {} - adata.uns["wl_kernel"] = {} - - adata.uns["wl_kernel"][cell_type_key] = {} - adata.uns["wl_kernel"][cell_type_key] = {} - for sample in samples: - adata_sample = adata[adata.obs[library_key] == sample] - status.append(adata_sample.obs[library_key][0]) - graphs.append(_get_igraph(adata_sample, cluster_key=None)) - - node_features.append(np.array(adata_sample.obs[cell_type_key].values)) - cell_types.append(np.full(len(adata_sample.obs[cell_type_key]), cell_type_key)) - - node_features = np.array(node_features, dtype=object) - - wasserstein_distance = pairwise_wasserstein_distance(graphs, node_features=node_features, num_iterations=num_iterations) - adata.uns["wl_kernel"][cell_type_key]["wasserstein_distance"] = pd.DataFrame(wasserstein_distance, columns=samples, index=samples) + if cell_type_key is not None: + graphs = [] + adata.uns["wl_kernel"][cell_type_key] = {} + for sample in samples: + adata_sample = adata[adata.obs[library_key] == sample] + g = _get_igraph(adata_sample, cluster_key=cell_type_key) + graphs.append(g) + + wasserstein_distance = pairwise_wasserstein_distance( + graphs, + node_features=None, # triggers categorical WL + num_iterations=num_iterations + ) + adata.uns["wl_kernel"][cell_type_key]["wasserstein_distance"] = pd.DataFrame(wasserstein_distance, columns=samples, index=samples) else: - print("Defining node features...") + logging.info("Defining node features...") for sample in tqdm(samples): adata_sample = adata[adata.obs[library_key] == sample] graphs.append( @@ -109,9 +98,13 @@ def compare_conditions( node_features.append(np.array(features)) node_features = np.array(node_features, dtype=object) - wasserstein_distance = pairwise_wasserstein_distance(graphs, node_features=node_features, num_iterations=num_iterations) + wasserstein_distance = pairwise_wasserstein_distance( + graphs, + node_features=node_features, # triggers continuous WL + num_iterations=num_iterations + ) adata.uns["wl_kernel"]["wasserstein_distance"] = pd.DataFrame(wasserstein_distance, columns=samples, index=samples) - print("Done!") + logging.info("Done!") if copy: - return wasserstein_distance \ No newline at end of file + return wasserstein_distance