20 changes: 20 additions & 0 deletions _nx_parallel/__init__.py
@@ -104,6 +104,13 @@ def get_info():
'get_chunks : str, function (default = "chunks")': "A function that takes in a list of all the nodes as input and returns an iterable `node_chunks`. The default chunking is done by slicing the `nodes` into `n_jobs` number of chunks."
},
},
"harmonic_centrality": {
"url": "https://github.com/networkx/nx-parallel/blob/main/nx_parallel/algorithms/centrality/harmonic.py#L10",
"additional_docs": "Compute harmonic centrality in parallel.",
"additional_parameters": {
"G : NetworkX graph": "A graph (directed or undirected). u : node or iterable, optional (default: all nodes in G) Compute harmonic centrality for the specified node(s). distance : edge attribute key, optional (default: None) Use the specified edge attribute as the edge weight. wf_improved : bool, optional (default: True) This parameter is included for API compatibility but not used in harmonic centrality. backend : str, optional (default: None) The parallel backend to use (`'loky'`, `'threading'`, etc.). **backend_kwargs : additional backend parameters"
},
},
"is_reachable": {
"url": "https://github.com/networkx/nx-parallel/blob/main/nx_parallel/algorithms/tournament.py#L13",
"additional_docs": "The function parallelizes the calculation of two neighborhoods of vertices in `G` and checks closure conditions for each neighborhood subset in parallel.",
@@ -139,6 +146,14 @@ def get_info():
'get_chunks : str, function (default = "chunks")': "A function that takes in a list of all the isolated nodes as input and returns an iterable `isolate_chunks`. The default chunking is done by slicing the `isolates` into `n_jobs` number of chunks."
},
},
"parallel_bfs": {
"url": "https://github.com/networkx/nx-parallel/blob/main/nx_parallel/algorithms/traversal/breadth_first_search.py#L10",
"additional_docs": "Perform a parallelized Breadth-First Search (BFS) on the graph.",
"additional_parameters": {
"G : graph": 'A NetworkX graph. source : node, optional Starting node for the BFS traversal. If None, BFS is performed for all nodes. get_chunks : str or function (default="chunks") A function to divide nodes into chunks for parallel processing. If "chunks", the nodes are split into `n_jobs` chunks automatically. n_jobs : int, optional Number of jobs to run in parallel. If None, defaults to the number of CPUs.',
"bfs_result : dict": "A dictionary where keys are nodes and values are their BFS traversal order.",
},
},
"square_clustering": {
"url": "https://github.com/networkx/nx-parallel/blob/main/nx_parallel/algorithms/cluster.py#L11",
"additional_docs": "The nodes are chunked into `node_chunks` and then the square clustering coefficient for all `node_chunks` are computed in parallel over `n_jobs` number of CPU cores.",
@@ -153,5 +168,10 @@ def get_info():
'get_chunks : str, function (default = "chunks")': "A function that takes in a list of all the nodes as input and returns an iterable `node_chunks`. The default chunking is done by slicing the `nodes` into `n_jobs` number of chunks."
},
},
"voterank": {
"url": "https://github.com/networkx/nx-parallel/blob/main/nx_parallel/algorithms/centrality/voterank.py#L27",
"additional_docs": "Parallelized VoteRank centrality using joblib with chunking.",
"additional_parameters": None,
},
},
}
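
For orientation, here is a minimal sketch of how the registry entries above are consumed at runtime. It assumes, as in the nx-parallel repo layout, that `get_info()` nests these entries under a `"functions"` key; that surrounding structure is not shown in this diff.

```python
# Sketch: reading the backend registry patched above. Assumption: get_info()
# wraps the per-function entries under a "functions" key, as in nx-parallel.
from _nx_parallel import get_info

funcs = get_info()["functions"]
print(funcs["harmonic_centrality"]["additional_docs"])
print(funcs["parallel_bfs"]["additional_parameters"])
print(funcs["voterank"]["additional_parameters"])  # None: no extra parameter docs
```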
11 changes: 7 additions & 4 deletions benchmarks/benchmarks/bench_cluster.py
@@ -5,13 +5,16 @@
get_cached_gnp_random_graph,
Benchmark,
)
import networkx as nx
import nx_parallel as nxp


-class Cluster(Benchmark):
+class VoteRank(Benchmark):
+    """Benchmark for the parallelized VoteRank centrality."""

params = [(backends), (num_nodes), (edge_prob)]
param_names = ["backend", "num_nodes", "edge_prob"]

-    def time_square_clustering(self, backend, num_nodes, edge_prob):
+    def time_voterank(self, backend, num_nodes, edge_prob):
+        """Benchmark VoteRank on different graph sizes and backends."""
G = get_cached_gnp_random_graph(num_nodes, edge_prob)
_ = nx.square_clustering(G, backend=backend)
_ = nxp.voterank(G, number_of_nodes=min(100, num_nodes), backend=backend)
20 changes: 20 additions & 0 deletions benchmarks/benchmarks/bench_harmonic_centrality.py
@@ -0,0 +1,20 @@
from .common import (
backends,
num_nodes,
edge_prob,
get_cached_gnp_random_graph,
Benchmark,
)
import nx_parallel as nxp


class HarmonicCentrality(Benchmark):
"""Benchmark for the parallelized Harmonic Centrality computation."""

params = [(backends), (num_nodes), (edge_prob)]
param_names = ["backend", "num_nodes", "edge_prob"]

def time_harmonic_centrality(self, backend, num_nodes, edge_prob):
"""Benchmark Harmonic Centrality on different graph sizes and backends."""
G = get_cached_gnp_random_graph(num_nodes, edge_prob)
_ = nxp.harmonic_centrality(G, backend=backend)
28 changes: 28 additions & 0 deletions benchmarks/benchmarks/bench_voterank.py
@@ -0,0 +1,28 @@
import networkx as nx
import nx_parallel as nxp
from .common import Benchmark


class BenchmarkVoteRank(Benchmark):
"""Benchmark for the voterank algorithm in nx_parallel."""

def setup(self):
"""Set up test graphs before running the benchmarks."""
self.G_small = nx.erdos_renyi_graph(100, 0.1, seed=42)
self.G_medium = nx.erdos_renyi_graph(1000, 0.05, seed=42)
self.G_large = nx.erdos_renyi_graph(5000, 0.01, seed=42)

def time_voterank_small(self):
"""Benchmark VoteRank on a small graph."""
nxp.voterank(self.G_small, number_of_nodes=10)

def time_voterank_medium(self):
"""Benchmark VoteRank on a medium graph."""
nxp.voterank(self.G_medium, number_of_nodes=50)

def time_voterank_large(self):
"""Benchmark VoteRank on a large graph."""
nxp.voterank(self.G_large, number_of_nodes=100)
2 changes: 2 additions & 0 deletions nx_parallel/algorithms/centrality/__init__.py
@@ -1 +1,3 @@
from .betweenness import *
from .harmonic import *
from .voterank import *
82 changes: 82 additions & 0 deletions nx_parallel/algorithms/centrality/harmonic.py
@@ -0,0 +1,82 @@
from functools import partial
from joblib import Parallel, delayed
import networkx as nx
import nx_parallel as nxp

__all__ = ["harmonic_centrality"]


@nxp._configure_if_nx_active()
def harmonic_centrality(
G, u=None, distance=None, wf_improved=True, *, backend=None, **backend_kwargs
):
"""Compute harmonic centrality in parallel.

Parameters
----------
G : NetworkX graph
A graph (directed or undirected).
u : node or iterable, optional (default: all nodes in G)
Compute harmonic centrality for the specified node(s).
distance : edge attribute key, optional (default: None)
Use the specified edge attribute as the edge weight.
wf_improved : bool, optional (default: True)
This parameter is included for API compatibility but not used in harmonic centrality.
backend : str, optional (default: None)
The parallel backend to use (`'loky'`, `'threading'`, etc.).
**backend_kwargs : additional backend parameters

Returns
-------
dict
Dictionary of nodes with harmonic centrality values.
"""

if hasattr(G, "graph_object"):
G = G.graph_object

u = set(G.nbunch_iter(u) if u is not None else G.nodes)
sources = set(G.nodes) # Always use all nodes as sources

centrality = {v: 0 for v in u}

transposed = False
if len(u) < len(sources):
transposed = True
u, sources = sources, u
if nx.is_directed(G):
G = nx.reverse(G, copy=False)

# Get number of parallel jobs
n_jobs = nxp.get_n_jobs()

# Chunking nodes for parallel processing
nodes = list(sources)
node_chunks = nxp.create_iterables(G, "node", n_jobs, nodes)

    def process_chunk(chunk):
        """Compute harmonic centrality contributions for a chunk of sources."""
        # Accumulate into a dict keyed on demand: when not transposed, the
        # keys are target nodes in `u`, which need not lie in this chunk.
        local_centrality = {}
        spl = partial(nx.shortest_path_length, G, weight=distance)

        for v in chunk:
            dist = spl(v)
            for node in u.intersection(dist):
                d = dist[node]
                if d == 0:
                    continue
                key = v if transposed else node
                local_centrality[key] = local_centrality.get(key, 0) + 1 / d

        return local_centrality

# Run parallel processing on node chunks
results = Parallel(n_jobs=n_jobs, backend=backend, **backend_kwargs)(
delayed(process_chunk)(chunk) for chunk in node_chunks
)

# Merge results
for result in results:
for node, value in result.items():
centrality[node] += value

return centrality
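
A minimal usage sketch for the function above (illustrative only: the graph size, edge probability, and seed are arbitrary, and the comparison against sequential NetworkX is just a sanity check):

```python
# Usage sketch, assuming nx-parallel is installed; values are illustrative.
import networkx as nx
import nx_parallel as nxp

G = nx.erdos_renyi_graph(200, 0.1, seed=42)

par = nxp.harmonic_centrality(G)  # parallel implementation above
seq = nx.harmonic_centrality(G)   # sequential NetworkX reference

# The two should agree up to floating-point accumulation order.
print(max(abs(par[n] - seq[n]) for n in G))
```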
1 change: 1 addition & 0 deletions nx_parallel/algorithms/centrality/tests/__init__.py
@@ -0,0 +1 @@
from .test_betweenness_centrality import *
71 changes: 71 additions & 0 deletions nx_parallel/algorithms/centrality/voterank.py
@@ -0,0 +1,71 @@
from joblib import Parallel, delayed
import nx_parallel as nxp

__all__ = ["voterank"]


def _compute_votes(G, vote_rank, nodes):
"""Compute votes for a chunk of nodes in parallel."""
votes = {n: 0 for n in nodes}

for n in nodes:
for nbr in G[n]:
votes[n] += vote_rank[nbr][1] # Node receives votes from neighbors

return votes


def _update_voting_ability(G, vote_rank, selected_node, avgDegree):
"""Update the voting ability of the selected node and its out-neighbors."""
for nbr in G[selected_node]:
vote_rank[nbr][1] = max(
vote_rank[nbr][1] - (1 / avgDegree), 0
) # Ensure non-negative


@nxp._configure_if_nx_active()
def voterank(G, number_of_nodes=None, *, backend=None, **backend_kwargs):
"""Parallelized VoteRank centrality using joblib with chunking."""
    if hasattr(G, "graph_object"):
        G = G.graph_object

    influential_nodes = []
    vote_rank = {n: [0, 1] for n in G.nodes()}  # (score, voting ability)

if len(G) == 0:
return influential_nodes
if number_of_nodes is None or number_of_nodes > len(G):
number_of_nodes = len(G)

avgDegree = sum(
deg for _, deg in (G.out_degree() if G.is_directed() else G.degree())
) / len(G)
nodes = list(G.nodes())
chunk_size = backend_kwargs.get("chunk_size", 100) # Support chunk size override
node_chunks = [nodes[i : i + chunk_size] for i in range(0, len(nodes), chunk_size)]

for _ in range(number_of_nodes):
# Step 1: Compute votes in parallel using chunks
        vote_chunks = Parallel(n_jobs=-1, backend=backend)(
delayed(_compute_votes)(G, vote_rank, chunk) for chunk in node_chunks
)

# Merge chunk results
votes = {n: 0 for n in G.nodes()}
for chunk_votes in vote_chunks:
for node, score in chunk_votes.items():
votes[node] += score

# Step 2: Reset votes for already selected nodes
for n in influential_nodes:
votes[n] = 0

# Step 3: Select the most influential node
n = max(sorted(G.nodes()), key=lambda x: votes[x]) # Deterministic tie-breaking
if votes[n] == 0:
return influential_nodes # Stop if no influential node found

influential_nodes.append(n)
vote_rank[n] = [0, 0] # Weaken selected node

# Step 4: Update voting ability
_update_voting_ability(G, vote_rank, n, avgDegree)

return influential_nodes
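
A usage sketch for the VoteRank implementation above (illustrative only; `chunk_size` is the backend keyword the code reads, and the agreement check assumes tie-breaking matches sequential NetworkX, which holds when node insertion order is already sorted, as in this generated graph):

```python
# Usage sketch, assuming nx-parallel is installed; values are illustrative.
import networkx as nx
import nx_parallel as nxp

G = nx.erdos_renyi_graph(500, 0.05, seed=1)

top = nxp.voterank(G, number_of_nodes=10, chunk_size=200)
print(top)  # the ten most influential spreaders, in selection order
print(top == nx.voterank(G, number_of_nodes=10))  # sanity check vs. sequential
```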
2 changes: 2 additions & 0 deletions nx_parallel/algorithms/traversal/__init__.py
@@ -0,0 +1,2 @@
from .breadth_first_search import *
96 changes: 96 additions & 0 deletions nx_parallel/algorithms/traversal/breadth_first_search.py
@@ -0,0 +1,96 @@
from collections import deque

from joblib import Parallel, delayed
import nx_parallel as nxp

__all__ = ["parallel_bfs"]


@nxp._configure_if_nx_active()
def parallel_bfs(G, source=None, get_chunks="chunks", n_jobs=None):
"""
Perform a parallelized Breadth-First Search (BFS) on the graph.

Parameters
----------
G : graph
A NetworkX graph.
source : node, optional
Starting node for the BFS traversal. If None, BFS is performed for all nodes.
get_chunks : str or function (default="chunks")
A function to divide nodes into chunks for parallel processing.
If "chunks", the nodes are split into `n_jobs` chunks automatically.
n_jobs : int, optional
Number of jobs to run in parallel. If None, defaults to the number of CPUs.

Returns
-------
bfs_result : dict
A dictionary where keys are nodes and values are their BFS traversal order.
"""
if hasattr(G, "graph_object"):
G = G.graph_object

if source is None:
nodes = G.nodes
else:
nodes = [source]

if n_jobs is None:
n_jobs = nxp.get_n_jobs()

# Create node chunks
if get_chunks == "chunks":
node_chunks = nxp.create_iterables(G, "node", n_jobs, nodes)
else:
node_chunks = get_chunks(nodes)

# Run BFS on each chunk in parallel
bfs_results = Parallel(n_jobs=n_jobs)(
delayed(_bfs_chunk)(G, chunk) for chunk in node_chunks
)

# Combine results from all chunks
combined_result = {}
for result in bfs_results:
combined_result.update(result)

return combined_result


def _bfs_chunk(G, nodes):
"""
Perform BFS for a subset of nodes.

Parameters
----------
G : graph
A NetworkX graph.
nodes : list
A list of nodes to start BFS from.

Returns
-------
bfs_result : dict
BFS traversal order for the given subset of nodes.
"""
    bfs_result = {}
    for node in nodes:
        if node not in bfs_result:
            visited = set()
            queue = deque([node])  # deque gives O(1) pops from the front
            order = 0

            while queue:
                current = queue.popleft()
                if current not in visited:
                    visited.add(current)
                    bfs_result[current] = order
                    order += 1
                    queue.extend(
                        neighbor
                        for neighbor in G.neighbors(current)
                        if neighbor not in visited
                    )

    return bfs_result
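
A usage sketch for `parallel_bfs` (illustrative only; it is imported from its module path, since this diff does not show it re-exported at the package top level, and the custom chunker is a hypothetical example):

```python
# Usage sketch, assuming nx-parallel is installed; values are illustrative.
import networkx as nx
from nx_parallel.algorithms.traversal.breadth_first_search import parallel_bfs

G = nx.path_graph(10)

# Single-source BFS: on a path graph the order equals distance from node 0.
print(parallel_bfs(G, source=0))  # {0: 0, 1: 1, ..., 9: 9}

# Custom chunking: any callable mapping the node list to an iterable of chunks.
def halves(nodes):
    ns = list(nodes)
    return [ns[: len(ns) // 2], ns[len(ns) // 2 :]]

print(parallel_bfs(G, get_chunks=halves))
```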
2 changes: 2 additions & 0 deletions nx_parallel/interface.py
@@ -18,6 +18,8 @@
# Centrality
"betweenness_centrality",
"edge_betweenness_centrality",
"harmonic_centrality",
"voterank",
# Efficiency
"local_efficiency",
# Shortest Paths : generic
Binary file added timing/heatmap_harmonic_centrality_timing.png
2 changes: 1 addition & 1 deletion timing/timing_individual_function.py
@@ -15,7 +15,7 @@
number_of_nodes_list = [200, 400, 800, 1600]
weighted = False
pList = [1, 0.8, 0.6, 0.4, 0.2]
-currFun = nx.tournament.is_reachable
+currFun = nx.harmonic_centrality
"""
for p in pList:
for num in range(len(number_of_nodes_list)):