From b9adf781be75c81ea91d9db3262122b96eb7986f Mon Sep 17 00:00:00 2001 From: Merel Kuijs Date: Thu, 10 Jul 2025 12:45:34 +0200 Subject: [PATCH 1/7] Avoid creating a dense matrix --- src/graphcompass/imports/wwl_package/propagation_scheme.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/graphcompass/imports/wwl_package/propagation_scheme.py b/src/graphcompass/imports/wwl_package/propagation_scheme.py index ff2ecc8..202aad0 100644 --- a/src/graphcompass/imports/wwl_package/propagation_scheme.py +++ b/src/graphcompass/imports/wwl_package/propagation_scheme.py @@ -19,7 +19,7 @@ from collections import defaultdict from typing import List from tqdm import tqdm -from scipy.sparse import csr_matrix, diags +from scipy.sparse import csr_matrix, diags, eye #################### @@ -212,7 +212,7 @@ def fit_transform(self, X: List[ig.Graph], node_features = None, num_iterations: if it == 0: graph_feat.append(node_features[i]) else: - adj_cur = adj_mat[i] + csr_matrix(np.identity(adj_mat[i].shape[0])) + adj_cur = adj_mat[i] + eye(adj_mat[i].shape[0], format='csr') adj_cur = self._create_adj_avg(adj_cur) adj_cur.setdiag(0) From 43f265dda6b1c51678a3e3e2686eeda52f4463db Mon Sep 17 00:00:00 2001 From: Merel Kuijs Date: Thu, 10 Jul 2025 14:08:54 +0200 Subject: [PATCH 2/7] refactor: replace ot with geomloss for Wasserstein distance computation --- src/graphcompass/imports/wwl_package/wwl.py | 89 +++++++++------------ 1 file changed, 38 insertions(+), 51 deletions(-) diff --git a/src/graphcompass/imports/wwl_package/wwl.py b/src/graphcompass/imports/wwl_package/wwl.py index 311e89c..0c20445 100644 --- a/src/graphcompass/imports/wwl_package/wwl.py +++ b/src/graphcompass/imports/wwl_package/wwl.py @@ -1,17 +1,10 @@ -######## This file is copied from https://github.com/BorgwardtLab/WWL/blob/master/src/wwl/wwl.py ######## +######## This file is adapted from https://github.com/BorgwardtLab/WWL/blob/master/src/wwl/wwl.py ######## - -# 
----------------------------------------------------------------------------- -# This file contains the API for the WWL kernel computations -# -# December 2019, M. Togninalli -# ----------------------------------------------------------------------------- import sys import logging -from tqdm import tqdm -import ot -import numpy as np +import torch +from geomloss import SamplesLoss from sklearn.metrics.pairwise import laplacian_kernel from .propagation_scheme import WeisfeilerLehman, ContinuousWeisfeilerLehman @@ -23,41 +16,41 @@ def logging_config(level='DEBUG'): logging.basicConfig(level=level) pass -def _compute_wasserstein_distance(label_sequences, sinkhorn=False, - categorical=False, sinkhorn_lambda=1e-2): - ''' - Generate the Wasserstein distance matrix for the graphs embedded - in label_sequences - ''' - # Get the iteration number from the embedding file +def _compute_wasserstein_distance_geomloss(label_sequences, categorical=False, blur=0.05, p=2): + """ + Compute pairwise Wasserstein distances between graph node embeddings using GeomLoss. + Automatically uses GPU if available. 
+ + Args: + label_sequences: list of arrays (each array is [n_nodes, d] for a graph) + categorical: if True, assumes discrete labels; else assumes continuous node embeddings + blur: smoothing parameter for Sinkhorn (smaller = closer to EMD) + p: power for cost (usually 2 for Euclidean squared) + + Returns: + Distance matrix (n_graphs x n_graphs) + """ + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + sinkhorn = SamplesLoss("sinkhorn", p=p, blur=blur) + n = len(label_sequences) - - M = np.zeros((n,n)) - # Iterate over pairs of graphs - for graph_index_1, graph_1 in enumerate(label_sequences): - # Only keep the embeddings for the first h iterations - labels_1 = label_sequences[graph_index_1] - for graph_index_2, graph_2 in tqdm(enumerate(label_sequences[graph_index_1:])): - labels_2 = label_sequences[graph_index_2 + graph_index_1] - # Get cost matrix - ground_distance = 'hamming' if categorical else 'euclidean' - costs = ot.dist(labels_1, labels_2, metric=ground_distance) - - if sinkhorn: - mat = ot.sinkhorn( - np.ones(len(labels_1))/len(labels_1), - np.ones(len(labels_2))/len(labels_2), - costs, - sinkhorn_lambda, - numItermax=50 - ) - M[graph_index_1, graph_index_2 + graph_index_1] = np.sum(np.multiply(mat, costs)) - else: - M[graph_index_1, graph_index_2 + graph_index_1] = \ - ot.emd2([], [], costs) - - M = (M + M.T) - return M + M = torch.zeros((n, n), device=device) + + for i, emb_i in enumerate(label_sequences): + x_i = torch.tensor(emb_i, dtype=torch.float32, device=device) + + for j in range(i, n): + x_j = torch.tensor(label_sequences[j], dtype=torch.float32, device=device) + + # Uniform weights + a = torch.ones(x_i.shape[0], device=device) / x_i.shape[0] + b = torch.ones(x_j.shape[0], device=device) / x_j.shape[0] + + dist = sinkhorn(a, x_i, b, x_j) + M[i, j] = dist + M[j, i] = dist # symmetric + + return M.cpu().numpy() def pairwise_wasserstein_distance(X, node_features = None, num_iterations=3, sinkhorn=False, 
enforce_continuous=False): """ @@ -96,8 +89,7 @@ def pairwise_wasserstein_distance(X, node_features = None, num_iterations=3, sin # Compute the Wasserstein distance print("Computing Wasserstein distance between conditions...") - pairwise_distances = _compute_wasserstein_distance(node_representations, sinkhorn=sinkhorn, - categorical=categorical, sinkhorn_lambda=1e-2) + pairwise_distances = _compute_wasserstein_distance_geomloss(node_representations, categorical=categorical) return pairwise_distances def wwl(X, node_features=None, num_iterations=3, sinkhorn=False, gamma=None): @@ -108,8 +100,3 @@ def wwl(X, node_features=None, num_iterations=3, sinkhorn=False, gamma=None): num_iterations=num_iterations, sinkhorn=sinkhorn) wwl = laplacian_kernel(D_W, gamma=gamma) return wwl - - -####################### -# Class implementation -####################### \ No newline at end of file From 674edb7c0a1a2d5b3cf4129043723d64bd985ee7 Mon Sep 17 00:00:00 2001 From: Merel Kuijs Date: Thu, 10 Jul 2025 14:08:55 +0200 Subject: [PATCH 3/7] docs: streamline documentation in wwl.py Co-authored-by: aider (claude-3-5-haiku-20241022) --- src/graphcompass/imports/wwl_package/wwl.py | 53 ++++++++++++--------- 1 file changed, 31 insertions(+), 22 deletions(-) diff --git a/src/graphcompass/imports/wwl_package/wwl.py b/src/graphcompass/imports/wwl_package/wwl.py index 0c20445..47352d2 100644 --- a/src/graphcompass/imports/wwl_package/wwl.py +++ b/src/graphcompass/imports/wwl_package/wwl.py @@ -12,23 +12,26 @@ logging.basicConfig(level=logging.INFO) def logging_config(level='DEBUG'): - level = logging.getLevelName(level.upper()) - logging.basicConfig(level=level) - pass + """Configure logging level. -def _compute_wasserstein_distance_geomloss(label_sequences, categorical=False, blur=0.05, p=2): + Args: + level: Logging level (default: 'DEBUG') """ - Compute pairwise Wasserstein distances between graph node embeddings using GeomLoss. - Automatically uses GPU if available. 
+ logging.basicConfig(level=logging.getLevelName(level.upper())) + +def _compute_wasserstein_distance_geomloss(label_sequences, categorical=False, blur=0.05, p=2): + """Compute pairwise Wasserstein distances between graph node embeddings. + + Uses GeomLoss, automatically selecting GPU if available. Args: - label_sequences: list of arrays (each array is [n_nodes, d] for a graph) - categorical: if True, assumes discrete labels; else assumes continuous node embeddings - blur: smoothing parameter for Sinkhorn (smaller = closer to EMD) - p: power for cost (usually 2 for Euclidean squared) + label_sequences: Node embeddings for each graph + categorical: Whether labels are discrete or continuous + blur: Sinkhorn smoothing parameter + p: Cost function power (default: Euclidean squared) Returns: - Distance matrix (n_graphs x n_graphs) + Pairwise distance matrix """ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") sinkhorn = SamplesLoss("sinkhorn", p=p, blur=blur) @@ -52,15 +55,15 @@ def _compute_wasserstein_distance_geomloss(label_sequences, categorical=False, b return M.cpu().numpy() -def pairwise_wasserstein_distance(X, node_features = None, num_iterations=3, sinkhorn=False, enforce_continuous=False): - """ - Pairwise computation of the Wasserstein distance between embeddings of the - graphs in X. - args: - X (List[ig.graphs]): List of graphs - node_features (array): Array containing the node features for continuously attributed graphs - num_iterations (int): Number of iterations for the propagation scheme - sinkhorn (bool): Indicates whether sinkhorn approximation should be used +def pairwise_wasserstein_distance(X, node_features=None, num_iterations=3, sinkhorn=False, enforce_continuous=False): + """Compute pairwise Wasserstein distances between graph embeddings. 
+ + Args: + X: List of graphs + node_features: Node features for continuous graphs + num_iterations: Propagation scheme iterations + sinkhorn: Use Sinkhorn approximation + enforce_continuous: Force continuous embedding scheme """ # First check if the graphs are continuous vs categorical categorical = True @@ -93,8 +96,14 @@ def pairwise_wasserstein_distance(X, node_features = None, num_iterations=3, sin return pairwise_distances def wwl(X, node_features=None, num_iterations=3, sinkhorn=False, gamma=None): - """ - Pairwise computation of the Wasserstein Weisfeiler-Lehman kernel for graphs in X. + """Compute Wasserstein Weisfeiler-Lehman kernel for graph set. + + Args: + X: List of graphs + node_features: Optional node features + num_iterations: Propagation scheme iterations + sinkhorn: Use Sinkhorn approximation + gamma: Laplacian kernel parameter """ D_W = pairwise_wasserstein_distance(X, node_features = node_features, num_iterations=num_iterations, sinkhorn=sinkhorn) From dec3c03355085e07705d41a04a4adbde6b55ad49 Mon Sep 17 00:00:00 2001 From: Merel Kuijs Date: Thu, 10 Jul 2025 14:20:06 +0200 Subject: [PATCH 4/7] refactor: remove unused parameters and simplify Wasserstein distance computation --- src/graphcompass/imports/wwl_package/wwl.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/graphcompass/imports/wwl_package/wwl.py b/src/graphcompass/imports/wwl_package/wwl.py index 47352d2..2243036 100644 --- a/src/graphcompass/imports/wwl_package/wwl.py +++ b/src/graphcompass/imports/wwl_package/wwl.py @@ -19,14 +19,13 @@ def logging_config(level='DEBUG'): """ logging.basicConfig(level=logging.getLevelName(level.upper())) -def _compute_wasserstein_distance_geomloss(label_sequences, categorical=False, blur=0.05, p=2): +def _compute_wasserstein_distance_geomloss(label_sequences, blur=0.05, p=2): """Compute pairwise Wasserstein distances between graph node embeddings. Uses GeomLoss, automatically selecting GPU if available. 
Args: label_sequences: Node embeddings for each graph - categorical: Whether labels are discrete or continuous blur: Sinkhorn smoothing parameter p: Cost function power (default: Euclidean squared) @@ -55,14 +54,13 @@ def _compute_wasserstein_distance_geomloss(label_sequences, categorical=False, b return M.cpu().numpy() -def pairwise_wasserstein_distance(X, node_features=None, num_iterations=3, sinkhorn=False, enforce_continuous=False): +def pairwise_wasserstein_distance(X, node_features=None, num_iterations=3, enforce_continuous=False): """Compute pairwise Wasserstein distances between graph embeddings. Args: X: List of graphs node_features: Node features for continuous graphs num_iterations: Propagation scheme iterations - sinkhorn: Use Sinkhorn approximation enforce_continuous: Force continuous embedding scheme """ # First check if the graphs are continuous vs categorical @@ -92,20 +90,19 @@ def pairwise_wasserstein_distance(X, node_features=None, num_iterations=3, sinkh # Compute the Wasserstein distance print("Computing Wasserstein distance between conditions...") - pairwise_distances = _compute_wasserstein_distance_geomloss(node_representations, categorical=categorical) + pairwise_distances = _compute_wasserstein_distance_geomloss(node_representations) return pairwise_distances -def wwl(X, node_features=None, num_iterations=3, sinkhorn=False, gamma=None): +def wwl(X, node_features=None, num_iterations=3, gamma=None): """Compute Wasserstein Weisfeiler-Lehman kernel for graph set. 
Args: X: List of graphs node_features: Optional node features num_iterations: Propagation scheme iterations - sinkhorn: Use Sinkhorn approximation gamma: Laplacian kernel parameter """ D_W = pairwise_wasserstein_distance(X, node_features = node_features, - num_iterations=num_iterations, sinkhorn=sinkhorn) + num_iterations=num_iterations) wwl = laplacian_kernel(D_W, gamma=gamma) return wwl From 8333c32160fa135719fea99c941dcea82307b556 Mon Sep 17 00:00:00 2001 From: Merel Kuijs Date: Thu, 10 Jul 2025 14:20:07 +0200 Subject: [PATCH 5/7] feat: Improve docstrings and module documentation for clarity and precision Co-authored-by: aider (claude-3-5-haiku-20241022) --- src/graphcompass/imports/wwl_package/wwl.py | 57 +++++++++++++++------ 1 file changed, 40 insertions(+), 17 deletions(-) diff --git a/src/graphcompass/imports/wwl_package/wwl.py b/src/graphcompass/imports/wwl_package/wwl.py index 2243036..2db3c73 100644 --- a/src/graphcompass/imports/wwl_package/wwl.py +++ b/src/graphcompass/imports/wwl_package/wwl.py @@ -1,4 +1,11 @@ -######## This file is adapted from https://github.com/BorgwardtLab/WWL/blob/master/src/wwl/wwl.py ######## +""" +Wasserstein Weisfeiler-Lehman (WWL) kernel implementation. + +This module provides tools for computing graph similarities using the Wasserstein +Weisfeiler-Lehman kernel, supporting both categorical and continuous graph embeddings. + +Adapted from: https://github.com/BorgwardtLab/WWL/blob/master/src/wwl/wwl.py +""" import sys import logging @@ -12,25 +19,29 @@ logging.basicConfig(level=logging.INFO) def logging_config(level='DEBUG'): - """Configure logging level. + """Set the logging level for the application. + + Configures the global logging level to control the verbosity of log messages. Args: - level: Logging level (default: 'DEBUG') + level (str, optional): Logging level. Defaults to 'DEBUG'. + Typical values include 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'. 
""" logging.basicConfig(level=logging.getLevelName(level.upper())) def _compute_wasserstein_distance_geomloss(label_sequences, blur=0.05, p=2): """Compute pairwise Wasserstein distances between graph node embeddings. - Uses GeomLoss, automatically selecting GPU if available. + Calculates the optimal transport distance between node embeddings using + the Sinkhorn algorithm. Automatically uses GPU acceleration if available. Args: - label_sequences: Node embeddings for each graph - blur: Sinkhorn smoothing parameter - p: Cost function power (default: Euclidean squared) + label_sequences (list): List of node embeddings for each graph + blur (float, optional): Sinkhorn smoothing parameter. Defaults to 0.05. + p (int, optional): Power of the cost function. Defaults to 2 (squared Euclidean). Returns: - Pairwise distance matrix + numpy.ndarray: Symmetric matrix of pairwise Wasserstein distances """ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") sinkhorn = SamplesLoss("sinkhorn", p=p, blur=blur) @@ -57,11 +68,17 @@ def _compute_wasserstein_distance_geomloss(label_sequences, blur=0.05, p=2): def pairwise_wasserstein_distance(X, node_features=None, num_iterations=3, enforce_continuous=False): """Compute pairwise Wasserstein distances between graph embeddings. + Determines the appropriate embedding scheme (categorical or continuous) + and computes the Wasserstein distances between graph representations. + Args: - X: List of graphs - node_features: Node features for continuous graphs - num_iterations: Propagation scheme iterations - enforce_continuous: Force continuous embedding scheme + X (list): List of graphs to compare + node_features (array-like, optional): Pre-computed node features for continuous graphs + num_iterations (int, optional): Number of iterations for graph embedding. Defaults to 3. + enforce_continuous (bool, optional): Force use of continuous embedding scheme. Defaults to False. 
+ + Returns: + numpy.ndarray: Matrix of pairwise Wasserstein distances between graphs """ # First check if the graphs are continuous vs categorical categorical = True @@ -94,13 +111,19 @@ def pairwise_wasserstein_distance(X, node_features=None, num_iterations=3, enfor return pairwise_distances def wwl(X, node_features=None, num_iterations=3, gamma=None): - """Compute Wasserstein Weisfeiler-Lehman kernel for graph set. + """Compute the Wasserstein Weisfeiler-Lehman (WWL) kernel for a set of graphs. + + Combines Wasserstein distance computation with a Laplacian kernel to + measure graph similarities. Args: - X: List of graphs - node_features: Optional node features - num_iterations: Propagation scheme iterations - gamma: Laplacian kernel parameter + X (list): List of graphs to compare + node_features (array-like, optional): Pre-computed node features for continuous graphs + num_iterations (int, optional): Number of iterations for graph embedding. Defaults to 3. + gamma (float, optional): Scaling parameter for the Laplacian kernel. Defaults to None. 
+ + Returns: + numpy.ndarray: Kernel matrix representing graph similarities """ D_W = pairwise_wasserstein_distance(X, node_features = node_features, num_iterations=num_iterations) From a31f5003a7ac218748f2a322d4977e8f5132065b Mon Sep 17 00:00:00 2001 From: Merel Kuijs Date: Thu, 10 Jul 2025 14:32:41 +0200 Subject: [PATCH 6/7] refactor: Improve logging messages for clarity and consistency Co-authored-by: aider (claude-3-5-haiku-20241022) --- src/graphcompass/imports/wwl_package/wwl.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/graphcompass/imports/wwl_package/wwl.py b/src/graphcompass/imports/wwl_package/wwl.py index 2db3c73..87f3cb5 100644 --- a/src/graphcompass/imports/wwl_package/wwl.py +++ b/src/graphcompass/imports/wwl_package/wwl.py @@ -83,19 +83,19 @@ def pairwise_wasserstein_distance(X, node_features=None, num_iterations=3, enfor # First check if the graphs are continuous vs categorical categorical = True if enforce_continuous: - logging.info('Enforce continous flag is on, using CONTINUOUS propagation scheme.') + logging.info('Continuous embedding enforced: Using continuous propagation scheme.') categorical = False elif node_features is not None: - logging.info('Continuous node features provided, using CONTINUOUS propagation scheme.') + logging.info('Continuous node features detected: Using continuous propagation scheme.') categorical = False else: for g in X: - if not 'label' in g.vs.attribute_names(): - logging.info('No label attributed to graphs, use degree instead and use CONTINUOUS propagation scheme.') + if 'label' not in g.vs.attribute_names(): + logging.info('No categorical labels found: Switching to continuous propagation scheme using node degrees.') categorical = False break if categorical: - logging.info('Categorically-labelled graphs, using CATEGORICAL propagation scheme.') + logging.info('Categorical graph labels detected: Using categorical propagation scheme.') # Embed the nodes if categorical: @@ -106,7 
+106,7 @@ def pairwise_wasserstein_distance(X, node_features=None, num_iterations=3, enfor
         node_representations = es.fit_transform(X, node_features=node_features, num_iterations=num_iterations)
 
     # Compute the Wasserstein distance
-    print("Computing Wasserstein distance between conditions...")
+    logging.info("Computing pairwise Wasserstein distances between graph embeddings...")
     pairwise_distances = _compute_wasserstein_distance_geomloss(node_representations)
     return pairwise_distances
 

From 84f5f8593af81585a1299de8bb1afe3e97dd84e3 Mon Sep 17 00:00:00 2001
From: Merel Kuijs
Date: Thu, 10 Jul 2025 16:18:54 +0200
Subject: [PATCH 7/7] refactor: Update WWL package with improved Wasserstein
 distance computation

Improve categorical label handling and logging across the WWL package:

- propagation_scheme.py: enhance node label preprocessing — handle
  non-numeric labels gracefully, log the fallback to node degrees, and
  replace the assert with more flexible label validation
- wwl.py: add more robust validation for categorical node labels and
  provide clearer logging for label type detection, including mixed or
  invalid label types
- _WLkernel.py: replace print statements with logging, remove redundant
  variable initializations, and simplify uns dictionary management
Co-authored-by: aider (claude-3-5-haiku-20241022) cosmetic changes --- .gitignore | 1 + src/graphcompass/__init__.py | 2 +- src/graphcompass/__main__.py | 2 +- .../imports/wwl_package/propagation_scheme.py | 14 +++-- src/graphcompass/imports/wwl_package/wwl.py | 46 ++++++++++---- src/graphcompass/tl/_WLkernel.py | 63 +++++++++---------- 6 files changed, 74 insertions(+), 54 deletions(-) diff --git a/.gitignore b/.gitignore index 451e4d5..886d04e 100644 --- a/.gitignore +++ b/.gitignore @@ -161,3 +161,4 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ +.aider* diff --git a/src/graphcompass/__init__.py b/src/graphcompass/__init__.py index 05cfca9..5536b0d 100644 --- a/src/graphcompass/__init__.py +++ b/src/graphcompass/__init__.py @@ -1,4 +1,4 @@ -"""Graph-COMPASS.""" +"""GraphCompass""" from graphcompass import pl from graphcompass import tl from graphcompass import datasets diff --git a/src/graphcompass/__main__.py b/src/graphcompass/__main__.py index 233aea7..e637364 100644 --- a/src/graphcompass/__main__.py +++ b/src/graphcompass/__main__.py @@ -5,7 +5,7 @@ @click.command() @click.version_option() def main() -> None: - """Graph-COMPASS.""" + """GraphCompass""" if __name__ == "__main__": diff --git a/src/graphcompass/imports/wwl_package/propagation_scheme.py b/src/graphcompass/imports/wwl_package/propagation_scheme.py index 202aad0..dbe9298 100644 --- a/src/graphcompass/imports/wwl_package/propagation_scheme.py +++ b/src/graphcompass/imports/wwl_package/propagation_scheme.py @@ -1,4 +1,4 @@ -######## This file is copied from https://github.com/BorgwardtLab/WWL/blob/master/src/wwl/propagation_scheme.py ######## +######## This file is adapted from https://github.com/BorgwardtLab/WWL/blob/master/src/wwl/propagation_scheme.py ######## # ----------------------------------------------------------------------------- 
# This file contains the propagation schemes for categorically labeled and @@ -149,8 +149,14 @@ def _preprocess_graphs(self, X: List[ig.Graph]): # Iterate across graphs and load initial node features for graph in X: - if not 'label' in graph.vs.attribute_names(): - graph.vs['label'] = list(map(str, [l for l in graph.vs.degree()])) + if 'label' in graph.vs.attribute_names(): + labels = graph.vs['label'] + if not all(isinstance(label, (int, float)) for label in labels): + logging.warning("Non-numeric labels found. Falling back to node degrees.") + graph.vs['label'] = list(graph.vs.degree()) + else: + graph.vs['label'] = list(graph.vs.degree()) + # Get features and adjacency matrix node_features_cur = np.asarray(graph.vs['label']).astype(float).reshape(-1, 1) adj_mat_cur = csr_matrix(graph.get_adjacency_sparse()) @@ -220,4 +226,4 @@ def fit_transform(self, X: List[ig.Graph], node_features = None, num_iterations: graph_feat.append(graph_feat_cur) self._label_sequences.append(np.concatenate(graph_feat, axis=1)) - return self._label_sequences \ No newline at end of file + return self._label_sequences diff --git a/src/graphcompass/imports/wwl_package/wwl.py b/src/graphcompass/imports/wwl_package/wwl.py index 87f3cb5..a404bf0 100644 --- a/src/graphcompass/imports/wwl_package/wwl.py +++ b/src/graphcompass/imports/wwl_package/wwl.py @@ -10,8 +10,10 @@ import sys import logging +import numpy as np import torch from geomloss import SamplesLoss +from sklearn.preprocessing import OneHotEncoder from sklearn.metrics.pairwise import laplacian_kernel from .propagation_scheme import WeisfeilerLehman, ContinuousWeisfeilerLehman @@ -29,7 +31,7 @@ def logging_config(level='DEBUG'): """ logging.basicConfig(level=logging.getLevelName(level.upper())) -def _compute_wasserstein_distance_geomloss(label_sequences, blur=0.05, p=2): +def _compute_wasserstein_distance_geomloss(label_sequences, categorical=False, blur=0.05, p=2): """Compute pairwise Wasserstein distances between graph node 
embeddings. Calculates the optimal transport distance between node embeddings using @@ -37,11 +39,21 @@ def _compute_wasserstein_distance_geomloss(label_sequences, blur=0.05, p=2): Args: label_sequences (list): List of node embeddings for each graph + categorical (bool): Whether the node labels are categorical (discrete) or continuous embeddings. blur (float, optional): Sinkhorn smoothing parameter. Defaults to 0.05. p (int, optional): Power of the cost function. Defaults to 2 (squared Euclidean). Returns: numpy.ndarray: Symmetric matrix of pairwise Wasserstein distances + + Notes: + This function differs from the legacy implementation in that it leverages + the GeomLoss library for GPU-accelerated Sinkhorn computations, enabling + efficient optimal transport calculations even on large graphs. The legacy + function uses the POT library and runs on CPU only, which can be slower + for large datasets. Additionally, this function handles both categorical + and continuous node features in a unified manner via one-hot encoding + for discrete labels. 
""" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") sinkhorn = SamplesLoss("sinkhorn", p=p, blur=blur) @@ -49,17 +61,25 @@ def _compute_wasserstein_distance_geomloss(label_sequences, blur=0.05, p=2): n = len(label_sequences) M = torch.zeros((n, n), device=device) - for i, emb_i in enumerate(label_sequences): - x_i = torch.tensor(emb_i, dtype=torch.float32, device=device) + if categorical: + # Flatten all labels for one-hot encoder fitting + all_labels = np.concatenate(label_sequences).reshape(-1, 1) + enc = OneHotEncoder(sparse_output=False, dtype=np.float32) + enc.fit(all_labels) - for j in range(i, n): - x_j = torch.tensor(label_sequences[j], dtype=torch.float32, device=device) + # Encode all graphs now for speed + encoded_sequences = [torch.tensor(enc.transform(seq.reshape(-1,1)), device=device) for seq in label_sequences] - # Uniform weights - a = torch.ones(x_i.shape[0], device=device) / x_i.shape[0] - b = torch.ones(x_j.shape[0], device=device) / x_j.shape[0] + else: + # Assume label_sequences are arrays of shape (n_nodes, features) + encoded_sequences = [torch.tensor(seq, dtype=torch.float32, device=device) for seq in label_sequences] + + for i in range(n): + a = torch.ones(encoded_sequences[i].shape[0], device=device) / encoded_sequences[i].shape[0] + for j in range(i, n): + b = torch.ones(encoded_sequences[j].shape[0], device=device) / encoded_sequences[j].shape[0] - dist = sinkhorn(a, x_i, b, x_j) + dist = sinkhorn(a, encoded_sequences[i], b, encoded_sequences[j]) M[i, j] = dist M[j, i] = dist # symmetric @@ -90,12 +110,12 @@ def pairwise_wasserstein_distance(X, node_features=None, num_iterations=3, enfor categorical = False else: for g in X: - if 'label' not in g.vs.attribute_names(): - logging.info('No categorical labels found: Switching to continuous propagation scheme using node degrees.') + if 'label' not in g.vs.attribute_names() or not all(isinstance(label, (int, float)) for label in g.vs['label']): + logging.info('Invalid 
categorical labels found: Switching to continuous propagation scheme using node degrees.') categorical = False break if categorical: - logging.info('Categorical graph labels detected: Using categorical propagation scheme.') + logging.info('Valid categorical graph labels detected: Using categorical propagation scheme.') # Embed the nodes if categorical: @@ -107,7 +127,7 @@ def pairwise_wasserstein_distance(X, node_features=None, num_iterations=3, enfor # Compute the Wasserstein distance logging.info("Computing pairwise Wasserstein distances between graph embeddings...") - pairwise_distances = _compute_wasserstein_distance_geomloss(node_representations) + pairwise_distances = _compute_wasserstein_distance_geomloss(node_representations, categorical=categorical) return pairwise_distances def wwl(X, node_features=None, num_iterations=3, gamma=None): diff --git a/src/graphcompass/tl/_WLkernel.py b/src/graphcompass/tl/_WLkernel.py index bd3fca7..b7c2a11 100644 --- a/src/graphcompass/tl/_WLkernel.py +++ b/src/graphcompass/tl/_WLkernel.py @@ -8,14 +8,14 @@ from tqdm import tqdm from anndata import AnnData from graphcompass.tl.utils import _calculate_graph, _get_igraph -from graphcompass.imports.wwl_package import wwl, pairwise_wasserstein_distance +from graphcompass.imports.wwl_package import pairwise_wasserstein_distance def compare_conditions( adata: AnnData, library_key: str = "sample", cluster_key: str = "cell_type", - cell_type_keys: list = None, + cell_type_key: str = None, compute_spatial_graphs: bool = True, num_iterations: int = 3, kwargs_nhood_enrich={}, @@ -35,8 +35,8 @@ def compare_conditions( which stores mapping between ``library_id`` and obs. cluster_key Key in :attr:`anndata.AnnData.obs` where clustering is stored. - cell_type_keys - List of keys in :attr:`anndata.AnnData.obs` where cell types are stored. + cell_type_key + Key in :attr:`anndata.AnnData.obs` where cell types are stored. 
compute_spatial_graphs Set to False if spatial graphs have been calculated or `sq.gr.spatial_neighbors` has already been run before. kwargs_nhood_enrich @@ -49,7 +49,7 @@ def compare_conditions( Whether to return a copy of the Wasserstein distance object. """ if compute_spatial_graphs: - print("Computing spatial graphs...") + logging.info("Computing spatial graphs...") _calculate_graph( adata=adata, library_key=library_key, @@ -59,42 +59,31 @@ def compare_conditions( **kwargs ) else: - print("Spatial graphs were previously computed. Skipping computing spatial graphs...") + logging.info("Spatial graphs were previously computed. Skipping computing spatial graphs...") samples = adata.obs[library_key].unique() graphs = [] node_features = [] - cell_types = [] adata.uns["wl_kernel"] = {} - adata.uns["wl_kernel"] = {} - if cell_type_keys is not None: - for cell_type_key in cell_type_keys: - graphs = [] - node_features = [] - status = [] - cell_types = [] - adata.uns["wl_kernel"] = {} - adata.uns["wl_kernel"] = {} - - adata.uns["wl_kernel"][cell_type_key] = {} - adata.uns["wl_kernel"][cell_type_key] = {} - for sample in samples: - adata_sample = adata[adata.obs[library_key] == sample] - status.append(adata_sample.obs[library_key][0]) - graphs.append(_get_igraph(adata_sample, cluster_key=None)) - - node_features.append(np.array(adata_sample.obs[cell_type_key].values)) - cell_types.append(np.full(len(adata_sample.obs[cell_type_key]), cell_type_key)) - - node_features = np.array(node_features, dtype=object) - - wasserstein_distance = pairwise_wasserstein_distance(graphs, node_features=node_features, num_iterations=num_iterations) - adata.uns["wl_kernel"][cell_type_key]["wasserstein_distance"] = pd.DataFrame(wasserstein_distance, columns=samples, index=samples) + if cell_type_key is not None: + graphs = [] + adata.uns["wl_kernel"][cell_type_key] = {} + for sample in samples: + adata_sample = adata[adata.obs[library_key] == sample] + g = _get_igraph(adata_sample, 
cluster_key=cell_type_key) + graphs.append(g) + + wasserstein_distance = pairwise_wasserstein_distance( + graphs, + node_features=None, # triggers categorical WL + num_iterations=num_iterations + ) + adata.uns["wl_kernel"][cell_type_key]["wasserstein_distance"] = pd.DataFrame(wasserstein_distance, columns=samples, index=samples) else: - print("Defining node features...") + logging.info("Defining node features...") for sample in tqdm(samples): adata_sample = adata[adata.obs[library_key] == sample] graphs.append( @@ -109,9 +98,13 @@ def compare_conditions( node_features.append(np.array(features)) node_features = np.array(node_features, dtype=object) - wasserstein_distance = pairwise_wasserstein_distance(graphs, node_features=node_features, num_iterations=num_iterations) + wasserstein_distance = pairwise_wasserstein_distance( + graphs, + node_features=node_features, # triggers continuous WL + num_iterations=num_iterations + ) adata.uns["wl_kernel"]["wasserstein_distance"] = pd.DataFrame(wasserstein_distance, columns=samples, index=samples) - print("Done!") + logging.info("Done!") if copy: - return wasserstein_distance \ No newline at end of file + return wasserstein_distance