Skip to content

Commit 655361f

Browse files
Implement Leiden stats proc
Co-authored-by: Ioannis Panagiotas <ioannis.panagiotas@neotechnology.com>
1 parent 08b020c commit 655361f

File tree

11 files changed

+465
-12
lines changed

11 files changed

+465
-12
lines changed

algo/src/main/java/org/neo4j/gds/leiden/Leiden.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636

3737
//TODO: take care of potential issues w. self-loops
3838

39-
public class Leiden extends Algorithm<HugeLongArray> {
39+
public class Leiden extends Algorithm<LeidenResult> {
4040

4141
private final Graph rootGraph;
4242

@@ -74,7 +74,7 @@ public Leiden(
7474
}
7575

7676
@Override
77-
public HugeLongArray compute() {
77+
public LeidenResult compute() {
7878
var workingGraph = rootGraph;
7979
var orientation = rootGraph.isUndirected() ? Orientation.UNDIRECTED : Orientation.NATURAL;
8080

@@ -85,6 +85,7 @@ public HugeLongArray compute() {
8585
HugeLongArray partition = HugeLongArray.newArray(workingGraph.nodeCount());
8686
partition.setAll(nodeId -> nodeId);
8787

88+
boolean didConverge = false;
8889
int iteration;
8990
// move on with refinement -> aggregation -> local move again
9091
for (iteration = 0; iteration < maxIterations; iteration++) {
@@ -97,7 +98,8 @@ public HugeLongArray compute() {
9798
communityVolumes = localMovePhasePartition.communityVolumes();
9899
var communitiesCount = Arrays.stream(partition.toArray()).distinct().count();
99100

100-
if (communitiesCount == workingGraph.nodeCount()) {
101+
didConverge = communitiesCount == workingGraph.nodeCount();
102+
if (didConverge) {
101103
break;
102104
}
103105

@@ -140,7 +142,7 @@ public HugeLongArray compute() {
140142
nodeVolumes = communityData.aggregatedNodeSeedVolume;
141143
}
142144

143-
return dendrograms[iteration - 1];
145+
return LeidenResult.of(dendrograms[iteration - 1], iteration, didConverge);
144146
}
145147

146148
private void initVolumes(HugeDoubleArray nodeVolumes, HugeDoubleArray communityVolumes) {
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.leiden;
21+
22+
import org.immutables.value.Value;
23+
import org.neo4j.gds.annotation.ValueClass;
24+
import org.neo4j.gds.core.utils.paged.HugeLongArray;
25+
26+
import java.util.function.LongUnaryOperator;
27+
28+
/**
 * Result of a Leiden run: the community assignment plus how many levels ran
 * and whether the run converged.
 *
 * NOTE(review): {@code ImmutableLeidenResult} is not defined in this file; it is
 * presumably generated from the {@code @ValueClass} annotation. {@link #of} passes
 * its arguments positionally, so confirm how the generated factory orders its
 * parameters before rearranging the accessors below.
 */
@ValueClass
29+
@SuppressWarnings("immutables:subtype")
30+
public interface LeidenResult {
31+
32+
/** Community assignment; callers look up a node's community with {@code get(nodeId)}. */
HugeLongArray communities();
33+
34+
/** Number of levels the algorithm ran; {@code Leiden.compute()} passes its loop counter here. */
long ranLevels();
35+
36+
/** Whether the run converged; {@code Leiden.compute()} sets this where it breaks out of its level loop. */
boolean didConverge();
37+
38+
/**
 * Exposes {@link #communities()} as a {@link LongUnaryOperator} by binding its {@code get} method.
 * NOTE(review): {@code @Value.Derived} presumably makes this a computed, stored attribute — confirm.
 */
@Value.Derived
39+
default LongUnaryOperator communitiesFunction() {
40+
return communities()::get;
41+
}
42+
43+
/** Static factory; forwards all three values to {@code ImmutableLeidenResult.of}. */
static LeidenResult of(HugeLongArray communities, long ranLevels, boolean didConverge) {
44+
return ImmutableLeidenResult.of(communities, ranLevels, didConverge);
45+
}
46+
47+
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.leiden;
21+
22+
23+
import org.neo4j.gds.annotation.Configuration;
24+
import org.neo4j.gds.annotation.ValueClass;
25+
import org.neo4j.gds.core.CypherMapWrapper;
26+
27+
/**
 * Configuration type for the Leiden stats mode. It declares no options of its own;
 * everything is inherited from {@code LeidenBaseConfig}.
 *
 * NOTE(review): {@code LeidenStatsConfigImpl} is not defined in this file; it is
 * presumably generated from the {@code @Configuration} annotation — confirm.
 */
@ValueClass
28+
@Configuration
29+
@SuppressWarnings("immutables:subtype")
30+
public interface LeidenStatsConfig extends LeidenBaseConfig {
31+
/** Builds a stats configuration from the user-supplied configuration map. */
static LeidenStatsConfig of(CypherMapWrapper userInput) {
32+
return new LeidenStatsConfigImpl(userInput);
33+
}
34+
}

algo/src/test/java/org/neo4j/gds/leiden/FootballTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ class FootballTest {
5353
void leiden(long seed) {
5454
var gamma = 1.0 / graph.relationshipCount();
5555
Leiden leiden = new Leiden(graph, 5, gamma, 0.01, seed, 1, ProgressTracker.NULL_TRACKER);
56-
var communities = leiden.compute();
56+
var communities = leiden.compute().communities();
5757
var communitiesMap = LongStream
5858
.range(0, graph.nodeCount())
5959
.mapToObj(v -> "a" + v)

algo/src/test/java/org/neo4j/gds/leiden/KarateTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ class KarateTest {
5050
void leiden(long seed) {
5151
var gamma = 1.0 / graph.relationshipCount();
5252
Leiden leiden = new Leiden(graph, 5, gamma, 0.01, seed, 1, ProgressTracker.NULL_TRACKER);
53-
var communities = leiden.compute();
53+
var communities = leiden.compute().communities();
5454
var communitiesMap = LongStream
5555
.range(0, graph.nodeCount())
5656
.mapToObj(v -> "a" + v)

algo/src/test/java/org/neo4j/gds/leiden/LeidenTest.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,16 @@ class LeidenTest {
7474

7575
@Test
7676
void leiden() {
77-
Leiden leiden = new Leiden(graph, 2, 1.0 / graph.relationshipCount(), 0.01, 19L, 1, ProgressTracker.NULL_TRACKER
77+
int maxLevels = 3;
78+
Leiden leiden = new Leiden(graph, maxLevels, 1.0 / graph.relationshipCount(), 0.01, 19L, 1, ProgressTracker.NULL_TRACKER
7879
);
79-
var communities = leiden.compute();
80+
81+
var leidenResult = leiden.compute();
82+
83+
assertThat(leidenResult.ranLevels()).isLessThanOrEqualTo(maxLevels);
84+
assertThat(leidenResult.didConverge()).isTrue();
85+
86+
var communities = leidenResult.communities();
8087
var communitiesMap = LongStream
8188
.range(0, graph.nodeCount())
8289
.mapToObj(v -> "a" + v)
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.leiden;
21+
22+
import org.neo4j.gds.AlgoBaseProc;
23+
import org.neo4j.gds.AlgorithmFactory;
24+
import org.neo4j.gds.core.CypherMapWrapper;
25+
import org.neo4j.gds.executor.ComputationResultConsumer;
26+
import org.neo4j.gds.executor.ProcedureExecutor;
27+
import org.neo4j.procedure.Description;
28+
import org.neo4j.procedure.Name;
29+
import org.neo4j.procedure.Procedure;
30+
31+
import java.util.Map;
32+
import java.util.stream.Stream;
33+
34+
import static org.neo4j.procedure.Mode.READ;
35+
36+
/**
 * Procedure entry point for {@code gds.alpha.leiden.stats}. The work is delegated
 * to a {@code ProcedureExecutor} driven by {@code LeidenStatsSpec}; the inherited
 * {@code AlgoBaseProc} hooks overridden below are deprecated stubs that return {@code null}.
 */
public class LeidenStatsProc extends AlgoBaseProc<Leiden, LeidenResult, LeidenStatsConfig, StatsResult> {
37+
38+
/**
 * Runs Leiden in stats mode.
 *
 * @param graphName     value of the {@code graphName} procedure argument
 * @param configuration value of the {@code configuration} procedure argument; defaults to an empty map
 * @return whatever the executor's {@code compute} call produces as a stream of {@code StatsResult}
 */
@Procedure(value = "gds.alpha.leiden.stats", mode = READ)
39+
@Description(STATS_DESCRIPTION)
40+
public Stream<StatsResult> stats(
41+
@Name(value = "graphName") String graphName,
42+
@Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration
43+
) {
44+
var statsSpec = new LeidenStatsSpec();
45+
// NOTE(review): the meaning of the two boolean flags is not visible here — confirm against ProcedureExecutor.compute.
return new ProcedureExecutor<>(
46+
statsSpec,
47+
executionContext()
48+
).compute(graphName, configuration, true, true);
49+
}
50+
51+
52+
53+
// Deprecated stub: returns null. Presumably unused because the spec supplies the factory — confirm.
@Override
54+
@Deprecated
55+
public AlgorithmFactory<?, Leiden, LeidenStatsConfig> algorithmFactory() {
56+
return null;
57+
}
58+
59+
// Deprecated stub: returns null. Presumably unused because the spec supplies the consumer — confirm.
@Override
60+
@Deprecated
61+
public <T extends ComputationResultConsumer<Leiden, LeidenResult, LeidenStatsConfig, Stream<StatsResult>>> T computationResultConsumer() {
62+
return null;
63+
}
64+
65+
// Deprecated stub: returns null. Presumably unused because the spec supplies the config function — confirm.
@Override
66+
@Deprecated
67+
protected LeidenStatsConfig newConfig(String username, CypherMapWrapper config) {
68+
return null;
69+
}
70+
}
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.leiden;
21+
22+
import org.neo4j.gds.executor.AlgorithmSpec;
23+
import org.neo4j.gds.executor.ComputationResultConsumer;
24+
import org.neo4j.gds.executor.GdsCallable;
25+
import org.neo4j.gds.executor.NewConfigFunction;
26+
27+
import java.util.stream.Stream;
28+
29+
import static org.neo4j.gds.executor.ExecutionMode.STREAM;
30+
import static org.neo4j.gds.leiden.LeidenStreamProc.DESCRIPTION;
31+
32+
/**
 * Algorithm spec for {@code gds.alpha.leiden.stats}: supplies the algorithm factory,
 * the config constructor and the consumer that turns a {@code LeidenResult} into a
 * single {@code StatsResult} row.
 *
 * The callable is registered with the STATS execution mode (it was previously tagged
 * STREAM, which mislabels a stats procedure). The enum constant is fully qualified so
 * that no additional static import is required.
 */
@GdsCallable(name = "gds.alpha.leiden.stats", description = DESCRIPTION, executionMode = org.neo4j.gds.executor.ExecutionMode.STATS)
33+
public class LeidenStatsSpec implements AlgorithmSpec<Leiden, LeidenResult, LeidenStatsConfig, Stream<StatsResult>, LeidenAlgorithmFactory<LeidenStatsConfig>> {
34+
/** Name of this spec. */
@Override
35+
public String name() {
36+
return "LeidenStats";
37+
}
38+
39+
/** Creates a fresh Leiden algorithm factory typed to the stats configuration. */
@Override
40+
public LeidenAlgorithmFactory<LeidenStatsConfig> algorithmFactory() {
41+
return new LeidenAlgorithmFactory<>();
42+
}
43+
44+
/** Builds the configuration from the user-supplied map; the first lambda argument is ignored. */
@Override
45+
public NewConfigFunction<LeidenStatsConfig> newConfigFunction() {
46+
return (__, config) -> LeidenStatsConfig.of(config);
47+
}
48+
49+
/**
 * Consumer that yields an empty stream when the computation produced no result,
 * and otherwise exactly one {@code StatsResult} built from the Leiden result and
 * the timings, node count and config of the computation.
 */
@Override
50+
public ComputationResultConsumer<Leiden, LeidenResult, LeidenStatsConfig, Stream<StatsResult>> computationResultConsumer() {
51+
return (computationResult, executionContext) -> {
52+
var leidenResult = computationResult.result();
53+
// No result available: emit no rows.
if (leidenResult == null) {
54+
return Stream.empty();
55+
}
56+
57+
var statsBuilder = new StatsResult.StatsBuilder(
58+
executionContext.callContext(),
59+
computationResult.config().concurrency()
60+
);
61+
62+
var statsResult = statsBuilder
63+
.withLevels(leidenResult.ranLevels())
64+
.withDidConverge(leidenResult.didConverge())
65+
.withCommunityFunction(leidenResult.communitiesFunction())
66+
.withPreProcessingMillis(computationResult.preProcessingMillis())
67+
.withComputeMillis(computationResult.computeMillis())
68+
.withNodeCount(computationResult.graph().nodeCount())
69+
.withConfig(computationResult.config())
70+
.build();
71+
72+
return Stream.of(statsResult);
73+
};
74+
}
75+
}

proc/community/src/main/java/org/neo4j/gds/leiden/LeidenStreamSpec.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
*/
2020
package org.neo4j.gds.leiden;
2121

22-
import org.neo4j.gds.core.utils.paged.HugeLongArray;
2322
import org.neo4j.gds.executor.AlgorithmSpec;
2423
import org.neo4j.gds.executor.ComputationResultConsumer;
2524
import org.neo4j.gds.executor.GdsCallable;
@@ -32,7 +31,7 @@
3231
import static org.neo4j.gds.leiden.LeidenStreamProc.DESCRIPTION;
3332

3433
@GdsCallable(name = "gds.alpha.leiden.stream", description = DESCRIPTION, executionMode = STREAM)
35-
public class LeidenStreamSpec implements AlgorithmSpec<Leiden, HugeLongArray, LeidenStreamConfig, Stream<StreamResult>, LeidenAlgorithmFactory<LeidenStreamConfig>> {
34+
public class LeidenStreamSpec implements AlgorithmSpec<Leiden, LeidenResult, LeidenStreamConfig, Stream<StreamResult>, LeidenAlgorithmFactory<LeidenStreamConfig>> {
3635
@Override
3736
public String name() {
3837
return "LeidenStream";
@@ -49,10 +48,14 @@ public NewConfigFunction<LeidenStreamConfig> newConfigFunction() {
4948
}
5049

5150
@Override
52-
public ComputationResultConsumer<Leiden, HugeLongArray, LeidenStreamConfig, Stream<StreamResult>> computationResultConsumer() {
51+
public ComputationResultConsumer<Leiden, LeidenResult, LeidenStreamConfig, Stream<StreamResult>> computationResultConsumer() {
5352
return (computationResult, executionContext) -> {
53+
var leidenResult = computationResult.result();
54+
if (leidenResult == null) {
55+
return Stream.empty();
56+
}
5457
var graph = computationResult.graph();
55-
var communities = computationResult.result();
58+
var communities = leidenResult.communities();
5659
return LongStream.range(0, graph.nodeCount())
5760
.mapToObj(nodeId -> new StreamResult(graph.toOriginalNodeId(nodeId), communities.get(nodeId)));
5861
};

0 commit comments

Comments
 (0)