From 2a4f61c4c1a0040ca13c9ddbe136323d4628ee70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Wed, 27 Aug 2025 22:21:05 +0200 Subject: [PATCH 1/8] Add GraphSamplingEndpoints --- .../api/graph_sampling_endpoints.py | 164 ++++++++++++++++++ 1 file changed, 164 insertions(+) create mode 100644 graphdatascience/procedure_surface/api/graph_sampling_endpoints.py diff --git a/graphdatascience/procedure_surface/api/graph_sampling_endpoints.py b/graphdatascience/procedure_surface/api/graph_sampling_endpoints.py new file mode 100644 index 000000000..f80cc21f7 --- /dev/null +++ b/graphdatascience/procedure_surface/api/graph_sampling_endpoints.py @@ -0,0 +1,164 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, List, Optional + +from graphdatascience import Graph +from graphdatascience.procedure_surface.api.base_result import BaseResult + + +class GraphSamplingEndpoints(ABC): + """ + Abstract base class defining the API for graph sampling algorithms algorithm. + """ + + @abstractmethod + def rwr( + self, + G: Graph, + graph_name: str, + startNodes: Optional[List[int]] = None, + restartProbability: Optional[float] = None, + samplingRatio: Optional[float] = None, + nodeLabelStratification: Optional[bool] = None, + relationshipWeightProperty: Optional[str] = None, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + sudo: Optional[bool] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + concurrency: Optional[Any] = None, + job_id: Optional[Any] = None, + ) -> GraphSamplingResult: + """ + Computes a set of Random Walks with Restart (RWR) for the given graph and stores the result as a new graph in the catalog. + + This method performs a random walk, beginning from a set of nodes (if provided), + where at each step there is a probability to restart back at the original nodes. + The result is turned into a new graph induced by the random walks and stored in the catalog. + + Parameters + ---------- + G : Graph + The input graph on which the Random Walk with Restart (RWR) will be + performed. + graph_name : str + The name of the new graph in the catalog. + startNodes : list of int, optional + A list of node IDs to start the random walk from. If not provided, all + nodes are used as potential starting points. + restartProbability : float, optional + The probability of restarting back to the original node at each step. + Should be a value between 0 and 1. If not specified, a default value is used. + samplingRatio : float, optional + The ratio of nodes to sample during the computation. This value should + be between 0 and 1. If not specified, no sampling is performed. + nodeLabelStratification : bool, optional + If True, the algorithm tries to preserve the label distribution of the original graph in the sampled graph. + relationshipWeightProperty : str, optional + The name of the property on relationships to use as weights during + the random walk. If not specified, the relationships are treated as + unweighted. + relationship_types : Optional[List[str]], default=None + The relationship types used to select relationships for this algorithm run. + node_labels : Optional[List[str]], default=None + The node labels used to select nodes for this algorithm run. + sudo : bool, optional + Override memory estimation limits. Use with caution as this can lead to + memory issues if the estimation is significantly wrong. + log_progress : bool, optional + If True, logs the progress of the computation. + username : str, optional + The username to attribute the procedure run to + concurrency : Any, optional + The number of concurrent threads used for the algorithm execution. + job_id : Any, optional + An identifier for the job that can be used for monitoring and cancellation + + Returns + ------- + GraphSamplingResult + The result of the Random Walk with Restart (RWR), including the sampled + nodes and their scores. + """ + pass + + @abstractmethod + def cnarw( + self, + G: Graph, + graph_name: str, + startNodes: Optional[List[int]] = None, + restartProbability: Optional[float] = None, + samplingRatio: Optional[float] = None, + nodeLabelStratification: Optional[bool] = None, + relationshipWeightProperty: Optional[str] = None, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + sudo: Optional[bool] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + concurrency: Optional[Any] = None, + job_id: Optional[Any] = None, + ) -> GraphSamplingResult: + """ + Computes a set of Random Walks with Restart (RWR) for the given graph and stores the result as a new graph in the catalog. + + This method performs a random walk, beginning from a set of nodes (if provided), + where at each step there is a probability to restart back at the original nodes. + The result is turned into a new graph induced by the random walks and stored in the catalog. + + Parameters + ---------- + G : Graph + The input graph on which the Random Walk with Restart (RWR) will be + performed. + graph_name : str + The name of the new graph in the catalog. + startNodes : list of int, optional + A list of node IDs to start the random walk from. If not provided, all + nodes are used as potential starting points. + restartProbability : float, optional + The probability of restarting back to the original node at each step. + Should be a value between 0 and 1. If not specified, a default value is used. + samplingRatio : float, optional + The ratio of nodes to sample during the computation. This value should + be between 0 and 1. If not specified, no sampling is performed. + nodeLabelStratification : bool, optional + If True, the algorithm tries to preserve the label distribution of the original graph in the sampled graph. + relationshipWeightProperty : str, optional + The name of the property on relationships to use as weights during + the random walk. If not specified, the relationships are treated as + unweighted. + relationship_types : Optional[List[str]], default=None + The relationship types used to select relationships for this algorithm run. + node_labels : Optional[List[str]], default=None + The node labels used to select nodes for this algorithm run. + sudo : bool, optional + Override memory estimation limits. Use with caution as this can lead to + memory issues if the estimation is significantly wrong. + log_progress : bool, optional + If True, logs the progress of the computation. + username : str, optional + The username to attribute the procedure run to + concurrency : Any, optional + The number of concurrent threads used for the algorithm execution. + job_id : Any, optional + An identifier for the job that can be used for monitoring and cancellation + + Returns + ------- + GraphSamplingResult + The result of the Random Walk with Restart (RWR), including the sampled + nodes and their scores. + """ + pass + + +class GraphSamplingResult(BaseResult): + graph_name: str + from_graph_name: str + node_count: int + relationship_count: int + start_node_count: int + project_millis: int From f051ff62945c69797f533f57f98a97a656dc05a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Wed, 27 Aug 2025 22:21:52 +0200 Subject: [PATCH 2/8] Implement GraphSamplingArrowEndpoints --- .../arrow/graph_sampling_arrow_endpoints.py | 93 ++++++++++++++++++ .../test_graph_sampling_arrow_endpoints.py | 98 +++++++++++++++++++ 2 files changed, 191 insertions(+) create mode 100644 graphdatascience/procedure_surface/arrow/graph_sampling_arrow_endpoints.py create mode 100644 graphdatascience/tests/integrationV2/procedure_surface/arrow/test_graph_sampling_arrow_endpoints.py diff --git a/graphdatascience/procedure_surface/arrow/graph_sampling_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/graph_sampling_arrow_endpoints.py new file mode 100644 index 000000000..444e6520c --- /dev/null +++ b/graphdatascience/procedure_surface/arrow/graph_sampling_arrow_endpoints.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +from typing import Any, List, Optional + +from graphdatascience import Graph +from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient +from graphdatascience.arrow_client.v2.job_client import JobClient +from graphdatascience.procedure_surface.api.graph_sampling_endpoints import ( + GraphSamplingEndpoints, + GraphSamplingResult, +) +from graphdatascience.procedure_surface.utils.config_converter import ConfigConverter + + +class GraphSamplingArrowEndpoints(GraphSamplingEndpoints): + def __init__(self, arrow_client: AuthenticatedArrowClient): + self._arrow_client = arrow_client + + def rwr( + self, + G: Graph, + graph_name: str, + startNodes: Optional[List[int]] = None, + restartProbability: Optional[float] = None, + samplingRatio: Optional[float] = None, + nodeLabelStratification: Optional[bool] = None, + relationshipWeightProperty: Optional[str] = None, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + sudo: Optional[bool] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + concurrency: Optional[Any] = None, + job_id: Optional[Any] = None, + ) -> GraphSamplingResult: + config = ConfigConverter.convert_to_gds_config( + from_graph_name=G.name(), + graph_name=graph_name, + startNodes=startNodes, + restartProbability=restartProbability, + samplingRatio=samplingRatio, + nodeLabelStratification=nodeLabelStratification, + relationshipWeightProperty=relationshipWeightProperty, + relationship_types=relationship_types, + node_labels=node_labels, + sudo=sudo, + log_progress=log_progress, + username=username, + concurrency=concurrency, + job_id=job_id, + ) + + job_id = JobClient.run_job_and_wait(self._arrow_client, "v2/graph.sample.rwr", config) + + return GraphSamplingResult(**JobClient.get_summary(self._arrow_client, job_id)) + + def cnarw( + self, + G: Graph, + graph_name: str, + startNodes: Optional[List[int]] = None, + restartProbability: Optional[float] = None, + samplingRatio: Optional[float] = None, + nodeLabelStratification: Optional[bool] = None, + relationshipWeightProperty: Optional[str] = None, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + sudo: Optional[bool] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + concurrency: Optional[Any] = None, + job_id: Optional[Any] = None, + ) -> GraphSamplingResult: + config = ConfigConverter.convert_to_gds_config( + from_graph_name=G.name(), + graph_name=graph_name, + startNodes=startNodes, + restartProbability=restartProbability, + samplingRatio=samplingRatio, + nodeLabelStratification=nodeLabelStratification, + relationshipWeightProperty=relationshipWeightProperty, + relationship_types=relationship_types, + node_labels=node_labels, + sudo=sudo, + log_progress=log_progress, + username=username, + concurrency=concurrency, + job_id=job_id, + ) + + job_id = JobClient.run_job_and_wait(self._arrow_client, "v2/graph.sample.cnarw", config) + + return GraphSamplingResult(**JobClient.get_summary(self._arrow_client, job_id)) diff --git a/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_graph_sampling_arrow_endpoints.py b/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_graph_sampling_arrow_endpoints.py new file mode 100644 index 000000000..0670936d5 --- /dev/null +++ b/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_graph_sampling_arrow_endpoints.py @@ -0,0 +1,98 @@ +from typing import Generator + +import pytest + +from graphdatascience import Graph +from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient +from graphdatascience.procedure_surface.arrow.graph_sampling_arrow_endpoints import GraphSamplingArrowEndpoints +from graphdatascience.tests.integrationV2.procedure_surface.arrow.graph_creation_helper import ( + create_graph, +) + + +@pytest.fixture +def sample_graph(arrow_client: AuthenticatedArrowClient) -> Generator[Graph, None, None]: + gdl = """ + (a :Node {id: 0}) + (b :Node {id: 1}) + (c :Node {id: 2}) + (d :Node {id: 3}) + (e :Node {id: 4}) + (a)-[:REL {weight: 1.0}]->(b) + (b)-[:REL {weight: 2.0}]->(c) + (c)-[:REL {weight: 1.5}]->(d) + (d)-[:REL {weight: 0.5}]->(e) + (e)-[:REL {weight: 1.2}]->(a) + """ + + yield create_graph(arrow_client, "sample_graph", gdl) + arrow_client.do_action("v2/graph.drop", {"graphName": "sample_graph"}) + arrow_client.do_action("v2/graph.drop", {"graphName": "sampled"}) + + +@pytest.fixture +def graph_sampling_endpoints( + arrow_client: AuthenticatedArrowClient, +) -> Generator[GraphSamplingArrowEndpoints, None, None]: + yield GraphSamplingArrowEndpoints(arrow_client) + + +def test_rwr_basic(graph_sampling_endpoints: GraphSamplingArrowEndpoints, sample_graph: Graph) -> None: + result = graph_sampling_endpoints.rwr( + G=sample_graph, graph_name="sampled", startNodes=[0, 1], restartProbability=0.15, samplingRatio=0.8 + ) + + assert result.graph_name == "sampled" + assert result.from_graph_name == sample_graph.name() + assert result.node_count > 0 + assert result.relationship_count >= 0 + assert result.start_node_count > 0 + assert result.project_millis >= 0 + + +def test_rwr_with_weights(graph_sampling_endpoints: GraphSamplingArrowEndpoints, sample_graph: Graph) -> None: + result = graph_sampling_endpoints.rwr( + G=sample_graph, + graph_name="sampled", + startNodes=[0], + restartProbability=0.2, + samplingRatio=0.6, + relationshipWeightProperty="weight", + ) + + assert result.graph_name == "sampled" + assert result.from_graph_name == sample_graph.name() + assert result.node_count > 0 + assert result.start_node_count >= 1 + assert result.project_millis >= 0 + + +def test_rwr_minimal_config(graph_sampling_endpoints: GraphSamplingArrowEndpoints, sample_graph: Graph) -> None: + result = graph_sampling_endpoints.rwr(G=sample_graph, graph_name="sampled") + + assert result.graph_name == "sampled" + assert result.from_graph_name == sample_graph.name() + assert result.node_count > 0 + assert result.project_millis >= 0 + + +def test_cnarw_basic(graph_sampling_endpoints: GraphSamplingArrowEndpoints, sample_graph: Graph) -> None: + result = graph_sampling_endpoints.cnarw( + G=sample_graph, graph_name="sampled", startNodes=[0, 1], restartProbability=0.15, samplingRatio=0.8 + ) + + assert result.graph_name == "sampled" + assert result.from_graph_name == sample_graph.name() + assert result.node_count > 0 + assert result.relationship_count >= 0 + assert result.start_node_count == 2 + assert result.project_millis >= 0 + + +def test_cnarw_minimal_config(graph_sampling_endpoints: GraphSamplingArrowEndpoints, sample_graph: Graph) -> None: + result = graph_sampling_endpoints.cnarw(G=sample_graph, graph_name="sampled") + + assert result.graph_name == "sampled" + assert result.from_graph_name == sample_graph.name() + assert result.node_count > 0 + assert result.project_millis >= 0 From 470bb31f314e6b657866ec85460ad2b0d05c9b6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Wed, 27 Aug 2025 23:10:30 +0200 Subject: [PATCH 3/8] Use snake case names for sampling arguments --- .../api/graph_sampling_endpoints.py | 40 +++++++++---------- .../arrow/graph_sampling_arrow_endpoints.py | 40 +++++++++---------- .../test_graph_sampling_arrow_endpoints.py | 12 +++--- 3 files changed, 46 insertions(+), 46 deletions(-) diff --git a/graphdatascience/procedure_surface/api/graph_sampling_endpoints.py b/graphdatascience/procedure_surface/api/graph_sampling_endpoints.py index f80cc21f7..d8524f5ff 100644 --- a/graphdatascience/procedure_surface/api/graph_sampling_endpoints.py +++ b/graphdatascience/procedure_surface/api/graph_sampling_endpoints.py @@ -17,11 +17,11 @@ def rwr( self, G: Graph, graph_name: str, - startNodes: Optional[List[int]] = None, - restartProbability: Optional[float] = None, - samplingRatio: Optional[float] = None, - nodeLabelStratification: Optional[bool] = None, - relationshipWeightProperty: Optional[str] = None, + start_nodes: Optional[List[int]] = None, + restart_probability: Optional[float] = None, + sampling_ratio: Optional[float] = None, + node_label_stratification: Optional[bool] = None, + relationship_weight_property: Optional[str] = None, relationship_types: Optional[List[str]] = None, node_labels: Optional[List[str]] = None, sudo: Optional[bool] = None, @@ -44,18 +44,18 @@ def rwr( performed. graph_name : str The name of the new graph in the catalog. - startNodes : list of int, optional + start_nodes : list of int, optional A list of node IDs to start the random walk from. If not provided, all nodes are used as potential starting points. - restartProbability : float, optional + restart_probability : float, optional The probability of restarting back to the original node at each step. Should be a value between 0 and 1. If not specified, a default value is used. - samplingRatio : float, optional + sampling_ratio : float, optional The ratio of nodes to sample during the computation. This value should be between 0 and 1. If not specified, no sampling is performed. - nodeLabelStratification : bool, optional + node_label_stratification : bool, optional If True, the algorithm tries to preserve the label distribution of the original graph in the sampled graph. - relationshipWeightProperty : str, optional + relationship_weight_property : str, optional The name of the property on relationships to use as weights during the random walk. If not specified, the relationships are treated as unweighted. @@ -88,11 +88,11 @@ def cnarw( self, G: Graph, graph_name: str, - startNodes: Optional[List[int]] = None, - restartProbability: Optional[float] = None, - samplingRatio: Optional[float] = None, - nodeLabelStratification: Optional[bool] = None, - relationshipWeightProperty: Optional[str] = None, + start_nodes: Optional[List[int]] = None, + restart_probability: Optional[float] = None, + sampling_ratio: Optional[float] = None, + node_label_stratification: Optional[bool] = None, + relationship_weight_property: Optional[str] = None, relationship_types: Optional[List[str]] = None, node_labels: Optional[List[str]] = None, sudo: Optional[bool] = None, @@ -115,18 +115,18 @@ def cnarw( performed. graph_name : str The name of the new graph in the catalog. - startNodes : list of int, optional + start_nodes : list of int, optional A list of node IDs to start the random walk from. If not provided, all nodes are used as potential starting points. - restartProbability : float, optional + restart_probability : float, optional The probability of restarting back to the original node at each step. Should be a value between 0 and 1. If not specified, a default value is used. - samplingRatio : float, optional + sampling_ratio : float, optional The ratio of nodes to sample during the computation. This value should be between 0 and 1. If not specified, no sampling is performed. - nodeLabelStratification : bool, optional + node_label_stratification : bool, optional If True, the algorithm tries to preserve the label distribution of the original graph in the sampled graph. - relationshipWeightProperty : str, optional + relationship_weight_property : str, optional The name of the property on relationships to use as weights during the random walk. If not specified, the relationships are treated as unweighted. diff --git a/graphdatascience/procedure_surface/arrow/graph_sampling_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/graph_sampling_arrow_endpoints.py index 444e6520c..1ef0be1c4 100644 --- a/graphdatascience/procedure_surface/arrow/graph_sampling_arrow_endpoints.py +++ b/graphdatascience/procedure_surface/arrow/graph_sampling_arrow_endpoints.py @@ -20,11 +20,11 @@ def rwr( self, G: Graph, graph_name: str, - startNodes: Optional[List[int]] = None, - restartProbability: Optional[float] = None, - samplingRatio: Optional[float] = None, - nodeLabelStratification: Optional[bool] = None, - relationshipWeightProperty: Optional[str] = None, + start_nodes: Optional[List[int]] = None, + restart_probability: Optional[float] = None, + sampling_ratio: Optional[float] = None, + node_label_stratification: Optional[bool] = None, + relationship_weight_property: Optional[str] = None, relationship_types: Optional[List[str]] = None, node_labels: Optional[List[str]] = None, sudo: Optional[bool] = None, @@ -36,11 +36,11 @@ def rwr( config = ConfigConverter.convert_to_gds_config( from_graph_name=G.name(), graph_name=graph_name, - startNodes=startNodes, - restartProbability=restartProbability, - samplingRatio=samplingRatio, - nodeLabelStratification=nodeLabelStratification, - relationshipWeightProperty=relationshipWeightProperty, + startNodes=start_nodes, + restartProbability=restart_probability, + samplingRatio=sampling_ratio, + nodeLabelStratification=node_label_stratification, + relationshipWeightProperty=relationship_weight_property, relationship_types=relationship_types, node_labels=node_labels, sudo=sudo, @@ -58,11 +58,11 @@ def cnarw( self, G: Graph, graph_name: str, - startNodes: Optional[List[int]] = None, - restartProbability: Optional[float] = None, - samplingRatio: Optional[float] = None, - nodeLabelStratification: Optional[bool] = None, - relationshipWeightProperty: Optional[str] = None, + start_nodes: Optional[List[int]] = None, + restart_probability: Optional[float] = None, + sampling_ratio: Optional[float] = None, + node_label_stratification: Optional[bool] = None, + relationship_weight_property: Optional[str] = None, relationship_types: Optional[List[str]] = None, node_labels: Optional[List[str]] = None, sudo: Optional[bool] = None, @@ -74,11 +74,11 @@ def cnarw( config = ConfigConverter.convert_to_gds_config( from_graph_name=G.name(), graph_name=graph_name, - startNodes=startNodes, - restartProbability=restartProbability, - samplingRatio=samplingRatio, - nodeLabelStratification=nodeLabelStratification, - relationshipWeightProperty=relationshipWeightProperty, + startNodes=start_nodes, + restartProbability=restart_probability, + samplingRatio=sampling_ratio, + nodeLabelStratification=node_label_stratification, + relationshipWeightProperty=relationship_weight_property, relationship_types=relationship_types, node_labels=node_labels, sudo=sudo, diff --git a/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_graph_sampling_arrow_endpoints.py b/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_graph_sampling_arrow_endpoints.py index 0670936d5..ff1111d77 100644 --- a/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_graph_sampling_arrow_endpoints.py +++ b/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_graph_sampling_arrow_endpoints.py @@ -39,7 +39,7 @@ def graph_sampling_endpoints( def test_rwr_basic(graph_sampling_endpoints: GraphSamplingArrowEndpoints, sample_graph: Graph) -> None: result = graph_sampling_endpoints.rwr( - G=sample_graph, graph_name="sampled", startNodes=[0, 1], restartProbability=0.15, samplingRatio=0.8 + G=sample_graph, graph_name="sampled", start_nodes=[0, 1], restart_probability=0.15, sampling_ratio=0.8 ) assert result.graph_name == "sampled" @@ -54,10 +54,10 @@ def test_rwr_with_weights(graph_sampling_endpoints: GraphSamplingArrowEndpoints, result = graph_sampling_endpoints.rwr( G=sample_graph, graph_name="sampled", - startNodes=[0], - restartProbability=0.2, - samplingRatio=0.6, - relationshipWeightProperty="weight", + start_nodes=[0], + restart_probability=0.2, + sampling_ratio=0.6, + relationship_weight_property="weight", ) assert result.graph_name == "sampled" @@ -78,7 +78,7 @@ def test_rwr_minimal_config(graph_sampling_endpoints: GraphSamplingArrowEndpoint def test_cnarw_basic(graph_sampling_endpoints: GraphSamplingArrowEndpoints, sample_graph: Graph) -> None: result = graph_sampling_endpoints.cnarw( - G=sample_graph, graph_name="sampled", startNodes=[0, 1], restartProbability=0.15, samplingRatio=0.8 + G=sample_graph, graph_name="sampled", start_nodes=[0, 1], restart_probability=0.15, sampling_ratio=0.8 ) assert result.graph_name == "sampled" From d647c45f5d4d1073d2a2e9550932a1302c9e2ff6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Wed, 27 Aug 2025 23:10:51 +0200 Subject: [PATCH 4/8] Implement cypher sampling endpoints --- .../cypher/graph_sampling_cypher_endpoints.py | 98 ++++++++++++++ .../cypher/cypher_graph_helper.py | 7 + .../test_graph_sampling_cypher_endpoints.py | 125 ++++++++++++++++++ 3 files changed, 230 insertions(+) create mode 100644 graphdatascience/procedure_surface/cypher/graph_sampling_cypher_endpoints.py create mode 100644 graphdatascience/tests/integrationV2/procedure_surface/cypher/cypher_graph_helper.py create mode 100644 graphdatascience/tests/integrationV2/procedure_surface/cypher/test_graph_sampling_cypher_endpoints.py diff --git a/graphdatascience/procedure_surface/cypher/graph_sampling_cypher_endpoints.py b/graphdatascience/procedure_surface/cypher/graph_sampling_cypher_endpoints.py new file mode 100644 index 000000000..cce413e6e --- /dev/null +++ b/graphdatascience/procedure_surface/cypher/graph_sampling_cypher_endpoints.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +from typing import Any, List, Optional + +from ...call_parameters import CallParameters +from ...graph.graph_object import Graph +from ...query_runner.query_runner import QueryRunner +from ..api.graph_sampling_endpoints import GraphSamplingEndpoints, GraphSamplingResult +from ..utils.config_converter import ConfigConverter + + +class GraphSamplingCypherEndpoints(GraphSamplingEndpoints): + def __init__(self, query_runner: QueryRunner): + self._query_runner = query_runner + + def rwr( + self, + G: Graph, + graph_name: str, + start_nodes: Optional[List[int]] = None, + restart_probability: Optional[float] = None, + sampling_ratio: Optional[float] = None, + node_label_stratification: Optional[bool] = None, + relationship_weight_property: Optional[str] = None, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + sudo: Optional[bool] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + concurrency: Optional[Any] = None, + job_id: Optional[Any] = None, + ) -> GraphSamplingResult: + config = ConfigConverter.convert_to_gds_config( + startNodes=start_nodes, + restartProbability=restart_probability, + samplingRatio=sampling_ratio, + nodeLabelStratification=node_label_stratification, + relationshipWeightProperty=relationship_weight_property, + relationshipTypes=relationship_types, + nodeLabels=node_labels, + sudo=sudo, + logProgress=log_progress, + username=username, + concurrency=concurrency, + jobId=job_id, + ) + + params = CallParameters( + graph_name=graph_name, + from_graph_name=G.name(), + config=config, + ) + params.ensure_job_id_in_config() + + result = self._query_runner.call_procedure(endpoint="gds.graph.sample.rwr", params=params).squeeze() + return GraphSamplingResult(**result.to_dict()) + + def cnarw( + self, + G: Graph, + graph_name: str, + start_nodes: Optional[List[int]] = None, + restart_probability: Optional[float] = None, + sampling_ratio: Optional[float] = None, + node_label_stratification: Optional[bool] = None, + relationship_weight_property: Optional[str] = None, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + sudo: Optional[bool] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + concurrency: Optional[Any] = None, + job_id: Optional[Any] = None, + ) -> GraphSamplingResult: + config = ConfigConverter.convert_to_gds_config( + startNodes=start_nodes, + restartProbability=restart_probability, + samplingRatio=sampling_ratio, + nodeLabelStratification=node_label_stratification, + relationshipWeightProperty=relationship_weight_property, + relationshipTypes=relationship_types, + nodeLabels=node_labels, + sudo=sudo, + logProgress=log_progress, + username=username, + concurrency=concurrency, + jobId=job_id, + ) + + params = CallParameters( + graph_name=graph_name, + from_graph_name=G.name(), + config=config, + ) + params.ensure_job_id_in_config() + + result = self._query_runner.call_procedure(endpoint="gds.graph.sample.cnarw", params=params).squeeze() + return GraphSamplingResult(**result.to_dict()) diff --git a/graphdatascience/tests/integrationV2/procedure_surface/cypher/cypher_graph_helper.py b/graphdatascience/tests/integrationV2/procedure_surface/cypher/cypher_graph_helper.py new file mode 100644 index 000000000..33a772740 --- /dev/null +++ b/graphdatascience/tests/integrationV2/procedure_surface/cypher/cypher_graph_helper.py @@ -0,0 +1,7 @@ +from graphdatascience import QueryRunner + + +def delete_all_graphs(query_runner: QueryRunner) -> None: + query_runner.run_cypher( + "CALL gds.graph.list() YIELD graphName CALL gds.graph.drop(graphName) YIELD graphName as g RETURN g" + ) diff --git a/graphdatascience/tests/integrationV2/procedure_surface/cypher/test_graph_sampling_cypher_endpoints.py b/graphdatascience/tests/integrationV2/procedure_surface/cypher/test_graph_sampling_cypher_endpoints.py new file mode 100644 index 000000000..63de3e264 --- /dev/null +++ b/graphdatascience/tests/integrationV2/procedure_surface/cypher/test_graph_sampling_cypher_endpoints.py @@ -0,0 +1,125 @@ +from typing import Generator + +import pytest + +from graphdatascience import Graph, QueryRunner +from graphdatascience.procedure_surface.cypher.graph_sampling_cypher_endpoints import GraphSamplingCypherEndpoints +from graphdatascience.tests.integrationV2.procedure_surface.cypher.cypher_graph_helper import delete_all_graphs + + +@pytest.fixture +def sample_graph(query_runner: QueryRunner) -> Generator[Graph, None, None]: + create_statement = """ + CREATE + (a: Node {id: 0}), + (b: Node {id: 1}), + (c: Node {id: 2}), + (d: Node {id: 3}), + (e: Node {id: 4}), + (a)-[:REL {weight: 1.0}]->(b), + (b)-[:REL {weight: 2.0}]->(c), + (c)-[:REL {weight: 1.5}]->(d), + (d)-[:REL {weight: 0.5}]->(e), + (e)-[:REL {weight: 1.2}]->(a) + """ + + query_runner.run_cypher(create_statement) + + query_runner.run_cypher(""" + MATCH (n) + OPTIONAL MATCH (n)-[r]->(m) + WITH gds.graph.project('g', n, m, {relationshipProperties: {weight: r.weight}}) AS G + RETURN G + """) + + yield Graph("g", query_runner) + + delete_all_graphs(query_runner) + query_runner.run_cypher("MATCH (n) DETACH DELETE n") + + +@pytest.fixture +def graph_sampling_endpoints(query_runner: QueryRunner) -> Generator[GraphSamplingCypherEndpoints, None, None]: + yield GraphSamplingCypherEndpoints(query_runner) + + +def test_rwr_basic(graph_sampling_endpoints: GraphSamplingCypherEndpoints, sample_graph: Graph) -> None: + """Test RWR sampling with basic configuration.""" + result = graph_sampling_endpoints.rwr( + G=sample_graph, graph_name="rwr_sampled", start_nodes=[0, 1], restart_probability=0.15, sampling_ratio=0.8 + ) + + assert result.graph_name == "rwr_sampled" + assert result.from_graph_name == sample_graph.name() + assert result.node_count > 0 + assert result.relationship_count >= 0 + assert result.start_node_count >= 1 + assert result.project_millis >= 0 + + +def test_rwr_with_weights(graph_sampling_endpoints: GraphSamplingCypherEndpoints, sample_graph: Graph) -> None: + """Test RWR sampling with weighted relationships.""" + result = graph_sampling_endpoints.rwr( + G=sample_graph, + graph_name="rwr_weighted", + restart_probability=0.2, + sampling_ratio=0.6, + relationship_weight_property="weight", + ) + + assert result.graph_name == "rwr_weighted" + assert result.from_graph_name == sample_graph.name() + assert result.node_count > 0 + assert result.start_node_count >= 1 + assert result.project_millis >= 0 + + +def test_rwr_minimal_config(graph_sampling_endpoints: GraphSamplingCypherEndpoints, sample_graph: Graph) -> None: + """Test RWR sampling with minimal configuration.""" + result = graph_sampling_endpoints.rwr(G=sample_graph, graph_name="rwr_minimal") + + assert result.graph_name == "rwr_minimal" + assert result.from_graph_name == sample_graph.name() + assert result.node_count > 0 + assert result.project_millis >= 0 + + +def test_cnarw_basic(graph_sampling_endpoints: GraphSamplingCypherEndpoints, sample_graph: Graph) -> None: + """Test CNARW sampling with basic configuration.""" + result = graph_sampling_endpoints.cnarw( + G=sample_graph, graph_name="cnarw_sampled", restart_probability=0.15, sampling_ratio=0.8 + ) + + assert result.graph_name == "cnarw_sampled" + assert result.from_graph_name == sample_graph.name() + assert result.node_count > 0 + assert result.relationship_count >= 0 + assert result.start_node_count >= 1 + assert result.project_millis >= 0 + + +def test_cnarw_with_stratification(graph_sampling_endpoints: GraphSamplingCypherEndpoints, sample_graph: Graph) -> None: + """Test CNARW sampling with node label stratification.""" + result = graph_sampling_endpoints.cnarw( + G=sample_graph, + graph_name="cnarw_stratified", + restart_probability=0.1, + sampling_ratio=0.7, + node_label_stratification=True, + ) + + assert result.graph_name == "cnarw_stratified" + assert result.from_graph_name == sample_graph.name() + assert result.node_count > 0 + assert result.start_node_count >= 1 + assert result.project_millis >= 0 + + +def test_cnarw_minimal_config(graph_sampling_endpoints: GraphSamplingCypherEndpoints, sample_graph: Graph) -> None: + """Test CNARW sampling with minimal configuration.""" + result = graph_sampling_endpoints.cnarw(G=sample_graph, graph_name="cnarw_minimal") + + assert result.graph_name == "cnarw_minimal" + assert result.from_graph_name == sample_graph.name() + assert result.start_node_count >= 1 + assert result.project_millis >= 0 From 3895d595a9c25b9e3a4cb0f2ef2b88df82d5c3e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Wed, 27 Aug 2025 23:12:55 +0200 Subject: [PATCH 5/8] Expose sampling endpoints in catalog endpoints --- .../procedure_surface/api/catalog_endpoints.py | 6 ++++++ .../procedure_surface/arrow/catalog_arrow_endpoints.py | 7 +++++++ 2 files changed, 13 insertions(+) diff --git a/graphdatascience/procedure_surface/api/catalog_endpoints.py b/graphdatascience/procedure_surface/api/catalog_endpoints.py index fecf0b45e..552e58e80 100644 --- a/graphdatascience/procedure_surface/api/catalog_endpoints.py +++ b/graphdatascience/procedure_surface/api/catalog_endpoints.py @@ -9,6 +9,7 @@ from graphdatascience import Graph from graphdatascience.procedure_surface.api.base_result import BaseResult +from graphdatascience.procedure_surface.api.graph_sampling_endpoints import GraphSamplingEndpoints class CatalogEndpoints(ABC): @@ -65,6 +66,11 @@ def filter( """ pass + @property + @abstractmethod + def sample(self) -> GraphSamplingEndpoints: + pass + class GraphListResult(BaseResult): graph_name: str diff --git a/graphdatascience/procedure_surface/arrow/catalog_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/catalog_arrow_endpoints.py index cab59fc88..2e9e2aee6 100644 --- a/graphdatascience/procedure_surface/arrow/catalog_arrow_endpoints.py +++ b/graphdatascience/procedure_surface/arrow/catalog_arrow_endpoints.py @@ -13,6 +13,8 @@ GraphFilterResult, GraphListResult, ) +from graphdatascience.procedure_surface.api.graph_sampling_endpoints import GraphSamplingEndpoints +from graphdatascience.procedure_surface.arrow.graph_sampling_arrow_endpoints import GraphSamplingArrowEndpoints from graphdatascience.procedure_surface.utils.config_converter import ConfigConverter from graphdatascience.query_runner.protocol.project_protocols import ProjectProtocol from graphdatascience.query_runner.termination_flag import TerminationFlag @@ -116,6 +118,10 @@ def filter( return GraphFilterResult(**JobClient.get_summary(self._arrow_client, job_id)) + @property + def sample(self) -> GraphSamplingEndpoints: + return GraphSamplingArrowEndpoints(self._arrow_client) + def _arrow_config(self) -> dict[str, Any]: connection_info = self._arrow_client.advertised_connection_info() @@ -131,6 +137,7 @@ def _arrow_config(self) -> dict[str, Any]: } + class ProjectionResult(BaseResult): graph_name: str node_count: int From e197129cecccc146c6433393d2d577a368213b14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Thu, 28 Aug 2025 12:17:06 +0200 Subject: [PATCH 6/8] Fix tests and code style --- .../procedure_surface/arrow/catalog_arrow_endpoints.py | 1 - .../arrow/test_graph_sampling_arrow_endpoints.py | 9 ++++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/graphdatascience/procedure_surface/arrow/catalog_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/catalog_arrow_endpoints.py index 2e9e2aee6..3bba9a16d 100644 --- a/graphdatascience/procedure_surface/arrow/catalog_arrow_endpoints.py +++ b/graphdatascience/procedure_surface/arrow/catalog_arrow_endpoints.py @@ -137,7 +137,6 @@ def _arrow_config(self) -> dict[str, Any]: } - class ProjectionResult(BaseResult): graph_name: str node_count: int diff --git a/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_graph_sampling_arrow_endpoints.py b/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_graph_sampling_arrow_endpoints.py index ff1111d77..677f49361 100644 --- a/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_graph_sampling_arrow_endpoints.py +++ b/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_graph_sampling_arrow_endpoints.py @@ -39,14 +39,14 @@ def graph_sampling_endpoints( def test_rwr_basic(graph_sampling_endpoints: GraphSamplingArrowEndpoints, sample_graph: Graph) -> None: result = graph_sampling_endpoints.rwr( - G=sample_graph, graph_name="sampled", start_nodes=[0, 1], restart_probability=0.15, sampling_ratio=0.8 + G=sample_graph, graph_name="sampled", restart_probability=0.15, sampling_ratio=0.8 ) assert result.graph_name == "sampled" assert result.from_graph_name == sample_graph.name() assert result.node_count > 0 assert result.relationship_count >= 0 - assert result.start_node_count > 0 + assert result.start_node_count >= 1 assert result.project_millis >= 0 @@ -54,7 +54,6 @@ def test_rwr_with_weights(graph_sampling_endpoints: GraphSamplingArrowEndpoints, result = graph_sampling_endpoints.rwr( G=sample_graph, graph_name="sampled", - start_nodes=[0], restart_probability=0.2, sampling_ratio=0.6, relationship_weight_property="weight", @@ -78,14 +77,14 @@ def test_rwr_minimal_config(graph_sampling_endpoints: GraphSamplingArrowEndpoint def test_cnarw_basic(graph_sampling_endpoints: GraphSamplingArrowEndpoints, sample_graph: Graph) -> None: result = graph_sampling_endpoints.cnarw( - G=sample_graph, graph_name="sampled", start_nodes=[0, 1], restart_probability=0.15, sampling_ratio=0.8 + G=sample_graph, graph_name="sampled", restart_probability=0.15, sampling_ratio=0.8 ) assert result.graph_name == "sampled" assert result.from_graph_name == sample_graph.name() assert result.node_count > 0 assert result.relationship_count >= 0 - assert result.start_node_count == 2 + assert result.start_node_count >= 1 assert result.project_millis >= 0 From 686463293424000eee20ca5ff22d55664bf31d70 Mon Sep 17 00:00:00 2001 From: Mats Rydberg Date: Thu, 28 Aug 2025 17:14:24 +0200 Subject: [PATCH 7/8] Update docstring for RWR - Content copied as much as possible from GDS Manual - Add default values --- .../api/graph_sampling_endpoints.py | 55 ++++++++++--------- 1 file changed, 29 insertions(+), 26 deletions(-) diff --git a/graphdatascience/procedure_surface/api/graph_sampling_endpoints.py b/graphdatascience/procedure_surface/api/graph_sampling_endpoints.py index d8524f5ff..36b68aa6b 100644 --- a/graphdatascience/procedure_surface/api/graph_sampling_endpoints.py +++ b/graphdatascience/procedure_surface/api/graph_sampling_endpoints.py @@ -9,7 +9,7 @@ class GraphSamplingEndpoints(ABC): """ - Abstract base class defining the API for graph sampling algorithms algorithm. + Abstract base class defining the API for graph sampling operations. """ @abstractmethod @@ -31,11 +31,11 @@ def rwr( job_id: Optional[Any] = None, ) -> GraphSamplingResult: """ - Computes a set of Random Walks with Restart (RWR) for the given graph and stores the result as a new graph in the catalog. + Random walk with restarts (RWR) samples the graph by taking random walks from a set of start nodes. - This method performs a random walk, beginning from a set of nodes (if provided), - where at each step there is a probability to restart back at the original nodes. - The result is turned into a new graph induced by the random walks and stored in the catalog. + On each step of a random walk, there is a probability that the walk stops, and a new walk from one of the start + nodes starts instead (i.e. the walk restarts). Each node visited on these walks will be part of the sampled + subgraph. The resulting subgraph is stored as a new graph in the Graph Catalog. Parameters ---------- @@ -43,43 +43,46 @@ def rwr( The input graph on which the Random Walk with Restart (RWR) will be performed. graph_name : str - The name of the new graph in the catalog. + The name of the new graph that is stored in the graph catalog. start_nodes : list of int, optional - A list of node IDs to start the random walk from. If not provided, all - nodes are used as potential starting points. + IDs of the initial set of nodes in the original graph from which the sampling random walks will start. + By default, a single node is chosen uniformly at random. restart_probability : float, optional - The probability of restarting back to the original node at each step. - Should be a value between 0 and 1. If not specified, a default value is used. + The probability that a sampling random walk restarts from one of the start nodes. + Default is 0.1. sampling_ratio : float, optional - The ratio of nodes to sample during the computation. This value should - be between 0 and 1. If not specified, no sampling is performed. + The fraction of nodes in the original graph to be sampled. + Default is 0.15. node_label_stratification : bool, optional - If True, the algorithm tries to preserve the label distribution of the original graph in the sampled graph. + If true, preserves the node label distribution of the original graph. + Default is False. relationship_weight_property : str, optional - The name of the property on relationships to use as weights during - the random walk. If not specified, the relationships are treated as - unweighted. + Name of the relationship property to use as weights. If unspecified, the algorithm runs unweighted. relationship_types : Optional[List[str]], default=None - The relationship types used to select relationships for this algorithm run. + Filter the named graph using the given relationship types. Relationships with any of the given types will be + included. node_labels : Optional[List[str]], default=None - The node labels used to select nodes for this algorithm run. + Filter the named graph using the given node labels. Nodes with any of the given labels will be included. sudo : bool, optional - Override memory estimation limits. Use with caution as this can lead to - memory issues if the estimation is significantly wrong. + Bypass heap control. Use with caution. + Default is False. log_progress : bool, optional - If True, logs the progress of the computation. + Turn `on/off` percentage logging while running procedure. + Default is True. username : str, optional - The username to attribute the procedure run to + Use Administrator access to run an algorithm on a graph owned by another user. + Default is None. concurrency : Any, optional - The number of concurrent threads used for the algorithm execution. + The number of concurrent threads used for running the algorithm. + Default is 4. job_id : Any, optional - An identifier for the job that can be used for monitoring and cancellation + An ID that can be provided to more easily track the algorithm’s progress. + By default, a random job id is generated. Returns ------- GraphSamplingResult - The result of the Random Walk with Restart (RWR), including the sampled - nodes and their scores. + The result of the Random Walk with Restart (RWR), including the dimensions of the sampled graph. """ pass From 1c6b3d09f1723136e15bf64376bf236961eb76cf Mon Sep 17 00:00:00 2001 From: Mats Rydberg Date: Thu, 28 Aug 2025 17:24:04 +0200 Subject: [PATCH 8/8] Update docstring for CNARW - Content copied as much as possible from GDS Manual - Add default values --- .../api/graph_sampling_endpoints.py | 61 ++++++++++--------- 1 file changed, 32 insertions(+), 29 deletions(-) diff --git a/graphdatascience/procedure_surface/api/graph_sampling_endpoints.py b/graphdatascience/procedure_surface/api/graph_sampling_endpoints.py index 36b68aa6b..ba1239edf 100644 --- a/graphdatascience/procedure_surface/api/graph_sampling_endpoints.py +++ b/graphdatascience/procedure_surface/api/graph_sampling_endpoints.py @@ -40,8 +40,7 @@ def rwr( Parameters ---------- G : Graph - The input graph on which the Random Walk with Restart (RWR) will be - performed. + The input graph to be sampled. graph_name : str The name of the new graph that is stored in the graph catalog. start_nodes : list of int, optional @@ -105,55 +104,59 @@ def cnarw( job_id: Optional[Any] = None, ) -> GraphSamplingResult: """ - Computes a set of Random Walks with Restart (RWR) for the given graph and stores the result as a new graph in the catalog. + Common Neighbour Aware Random Walk (CNARW) samples the graph by taking random walks from a set of start nodes - This method performs a random walk, beginning from a set of nodes (if provided), - where at each step there is a probability to restart back at the original nodes. - The result is turned into a new graph induced by the random walks and stored in the catalog. + CNARW is a graph sampling technique that involves optimizing the selection of the next-hop node. It takes into + account the number of common neighbours between the current node and the next-hop candidates. On each step of a + random walk, there is a probability that the walk stops, and a new walk from one of the start nodes starts + instead (i.e. the walk restarts). Each node visited on these walks will be part of the sampled subgraph. The + resulting subgraph is stored as a new graph in the Graph Catalog. Parameters ---------- G : Graph - The input graph on which the Random Walk with Restart (RWR) will be - performed. + The input graph to be sampled. graph_name : str - The name of the new graph in the catalog. + The name of the new graph that is stored in the graph catalog. start_nodes : list of int, optional - A list of node IDs to start the random walk from. If not provided, all - nodes are used as potential starting points. + IDs of the initial set of nodes in the original graph from which the sampling random walks will start. + By default, a single node is chosen uniformly at random. restart_probability : float, optional - The probability of restarting back to the original node at each step. - Should be a value between 0 and 1. If not specified, a default value is used. + The probability that a sampling random walk restarts from one of the start nodes. + Default is 0.1. sampling_ratio : float, optional - The ratio of nodes to sample during the computation. This value should - be between 0 and 1. If not specified, no sampling is performed. + The fraction of nodes in the original graph to be sampled. + Default is 0.15. node_label_stratification : bool, optional - If True, the algorithm tries to preserve the label distribution of the original graph in the sampled graph. + If true, preserves the node label distribution of the original graph. + Default is False. relationship_weight_property : str, optional - The name of the property on relationships to use as weights during - the random walk. If not specified, the relationships are treated as - unweighted. + Name of the relationship property to use as weights. If unspecified, the algorithm runs unweighted. relationship_types : Optional[List[str]], default=None - The relationship types used to select relationships for this algorithm run. + Filter the named graph using the given relationship types. Relationships with any of the given types will be + included. node_labels : Optional[List[str]], default=None - The node labels used to select nodes for this algorithm run. + Filter the named graph using the given node labels. Nodes with any of the given labels will be included. sudo : bool, optional - Override memory estimation limits. Use with caution as this can lead to - memory issues if the estimation is significantly wrong. + Bypass heap control. Use with caution. + Default is False. log_progress : bool, optional - If True, logs the progress of the computation. + Turn `on/off` percentage logging while running procedure. + Default is True. username : str, optional - The username to attribute the procedure run to + Use Administrator access to run an algorithm on a graph owned by another user. + Default is None. concurrency : Any, optional - The number of concurrent threads used for the algorithm execution. + The number of concurrent threads used for running the algorithm. + Default is 4. job_id : Any, optional - An identifier for the job that can be used for monitoring and cancellation + An ID that can be provided to more easily track the algorithm’s progress. + By default, a random job id is generated. Returns ------- GraphSamplingResult - The result of the Random Walk with Restart (RWR), including the sampled - nodes and their scores. + The result of the Common Neighbour Aware Random Walk (CNARW), including the dimensions of the sampled graph. """ pass