Skip to content

Commit 045d248

Browse files
Implement Leiden stream procedure
Co-authored-by: Ioannis Panagiotas <ioannis.panagiotas@neotechnology.com>
1 parent 34056f6 commit 045d248

File tree

8 files changed

+366
-1
lines changed

8 files changed

+366
-1
lines changed
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.leiden;
21+
22+
import org.neo4j.gds.GraphAlgorithmFactory;
23+
import org.neo4j.gds.api.Graph;
24+
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
25+
26+
public class LeidenAlgorithmFactory<CONFIG extends LeidenBaseConfig> extends GraphAlgorithmFactory<Leiden, CONFIG> {
27+
@Override
28+
public Leiden build(Graph graph, CONFIG configuration, ProgressTracker progressTracker) {
29+
30+
double gamma = configuration.gamma() / graph.relationshipCount();
31+
32+
return new Leiden(
33+
graph,
34+
configuration.maxLevels(),
35+
gamma,
36+
configuration.theta(),
37+
configuration.randomSeed().orElse(0L),
38+
progressTracker
39+
);
40+
}
41+
42+
@Override
43+
public String taskName() {
44+
return "Leiden";
45+
}
46+
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.leiden;
21+
22+
import org.immutables.value.Value;
23+
import org.neo4j.gds.config.AlgoBaseConfig;
24+
import org.neo4j.gds.config.RandomSeedConfig;
25+
import org.neo4j.gds.config.RelationshipWeightConfig;
26+
27+
public interface LeidenBaseConfig extends
28+
AlgoBaseConfig,
29+
RelationshipWeightConfig,
30+
RandomSeedConfig {
31+
32+
@Value.Default
33+
default double gamma() {
34+
return 1.0;
35+
}
36+
37+
@Value.Default
38+
default double theta() {
39+
return 0.01;
40+
}
41+
42+
@Value.Default
43+
default int maxLevels() {
44+
return 10;
45+
}
46+
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.leiden;
21+
22+
23+
import org.neo4j.gds.annotation.Configuration;
24+
import org.neo4j.gds.annotation.ValueClass;
25+
import org.neo4j.gds.core.CypherMapWrapper;
26+
27+
@ValueClass
28+
@Configuration
29+
@SuppressWarnings("immutables:subtype")
30+
public interface LeidenStreamConfig extends LeidenBaseConfig {
31+
static LeidenStreamConfig of(CypherMapWrapper userInput) {
32+
return new LeidenStreamConfigImpl(userInput);
33+
}
34+
}

doc/asciidoc/operations-reference/appendix-a-graph-algos.adoc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,4 +309,6 @@ include::../algorithms/algorithm-tiers.adoc[]
309309
| `gds.alpha.kmeans.mutate`
310310
| `gds.alpha.kmeans.stats`
311311
| `gds.alpha.kmeans.stream`
312+
.1+<.^| Leiden
313+
| `gds.alpha.leiden.stream`
312314
|===

open-packaging/src/test/java/org/neo4j/gds/OpenGdsProcedureSmokeTest.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ class OpenGdsProcedureSmokeTest extends BaseProcTest {
6262
"gds.alpha.closeness.harmonic.write",
6363
"gds.alpha.closeness.harmonic.stream",
6464

65+
"gds.alpha.leiden.stream",
66+
6567
"gds.alpha.kmeans.mutate",
6668
"gds.alpha.kmeans.stats",
6769
"gds.alpha.kmeans.stream",
@@ -455,7 +457,7 @@ void countShouldMatch() {
455457
);
456458

457459
// If you find yourself updating this count, please also update the count in SmokeTest.kt
458-
int expectedCount = 314;
460+
int expectedCount = 315;
459461
assertEquals(
460462
expectedCount,
461463
registeredProcedures.size(),
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.leiden;
21+
22+
import org.neo4j.gds.AlgoBaseProc;
23+
import org.neo4j.gds.AlgorithmFactory;
24+
import org.neo4j.gds.core.CypherMapWrapper;
25+
import org.neo4j.gds.core.utils.paged.HugeLongArray;
26+
import org.neo4j.gds.executor.ComputationResultConsumer;
27+
import org.neo4j.gds.executor.ProcedureExecutor;
28+
import org.neo4j.procedure.Description;
29+
import org.neo4j.procedure.Name;
30+
import org.neo4j.procedure.Procedure;
31+
32+
import java.util.Map;
33+
import java.util.stream.Stream;
34+
35+
import static org.neo4j.procedure.Mode.READ;
36+
37+
public class LeidenStreamProc extends AlgoBaseProc<Leiden, HugeLongArray, LeidenStreamConfig, LeidenStreamProc.StreamResult> {
38+
// Config
39+
// relationshipWeightProperty
40+
// maxIterations
41+
42+
// Output
43+
// nodeId
44+
// communityId
45+
46+
static final String DESCRIPTION =
47+
"Leiden is a community detection algorithm, which guarantees that communities are well connected";
48+
49+
@Procedure(name = "gds.alpha.leiden.stream", mode = READ)
50+
@Description(DESCRIPTION)
51+
public Stream<StreamResult> stream(
52+
@Name(value = "graphName") String graphName,
53+
@Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration
54+
) {
55+
var streamSpec = new LeidenStreamSpec();
56+
57+
return new ProcedureExecutor<>(
58+
streamSpec,
59+
executionContext()
60+
).compute(graphName, configuration, true, true);
61+
}
62+
63+
@Override
64+
public AlgorithmFactory<?, Leiden, LeidenStreamConfig> algorithmFactory() {
65+
return new LeidenStreamSpec().algorithmFactory();
66+
}
67+
68+
@Override
69+
public ComputationResultConsumer<Leiden, HugeLongArray, LeidenStreamConfig, Stream<StreamResult>> computationResultConsumer() {
70+
return new LeidenStreamSpec().computationResultConsumer();
71+
}
72+
73+
@Override
74+
protected LeidenStreamConfig newConfig(String username, CypherMapWrapper config) {
75+
return new LeidenStreamSpec().newConfigFunction().apply(username, config);
76+
}
77+
78+
public static class StreamResult {
79+
80+
public final long nodeId;
81+
public final long communityId;
82+
83+
StreamResult(long nodeId, long communityId) {
84+
this.nodeId = nodeId;
85+
this.communityId = communityId;
86+
}
87+
}
88+
}
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.leiden;
21+
22+
import org.neo4j.gds.core.utils.paged.HugeLongArray;
23+
import org.neo4j.gds.executor.AlgorithmSpec;
24+
import org.neo4j.gds.executor.ComputationResultConsumer;
25+
import org.neo4j.gds.executor.GdsCallable;
26+
import org.neo4j.gds.executor.NewConfigFunction;
27+
28+
import java.util.stream.LongStream;
29+
import java.util.stream.Stream;
30+
31+
import static org.neo4j.gds.executor.ExecutionMode.STREAM;
32+
import static org.neo4j.gds.leiden.LeidenStreamProc.DESCRIPTION;
33+
34+
@GdsCallable(name = "gds.alpha.leiden.stream", description = DESCRIPTION, executionMode = STREAM)
35+
public class LeidenStreamSpec implements AlgorithmSpec<Leiden, HugeLongArray, LeidenStreamConfig, Stream<LeidenStreamProc.StreamResult>, LeidenAlgorithmFactory<LeidenStreamConfig>> {
36+
@Override
37+
public String name() {
38+
return "LeidenStream";
39+
}
40+
41+
@Override
42+
public LeidenAlgorithmFactory<LeidenStreamConfig> algorithmFactory() {
43+
return new LeidenAlgorithmFactory<>();
44+
}
45+
46+
@Override
47+
public NewConfigFunction<LeidenStreamConfig> newConfigFunction() {
48+
return (__, config) -> LeidenStreamConfig.of(config);
49+
}
50+
51+
@Override
52+
public ComputationResultConsumer<Leiden, HugeLongArray, LeidenStreamConfig, Stream<LeidenStreamProc.StreamResult>> computationResultConsumer() {
53+
return (computationResult, executionContext) -> {
54+
var graph = computationResult.graph();
55+
var communities = computationResult.result();
56+
return LongStream.range(0, graph.nodeCount())
57+
.mapToObj(nodeId -> new LeidenStreamProc.StreamResult(graph.toOriginalNodeId(nodeId), communities.get(nodeId)));
58+
};
59+
}
60+
}
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.leiden;
21+
22+
import org.junit.jupiter.api.BeforeEach;
23+
import org.junit.jupiter.api.Test;
24+
import org.neo4j.gds.BaseProcTest;
25+
import org.neo4j.gds.catalog.GraphProjectProc;
26+
import org.neo4j.gds.extension.Neo4jGraph;
27+
28+
import java.util.HashSet;
29+
30+
import static org.assertj.core.api.Assertions.assertThat;
31+
32+
class LeidenStreamProcTest extends BaseProcTest {
33+
34+
@Neo4jGraph
35+
private static final String DB_CYPHER =
36+
"CREATE " +
37+
" (a0:Node)," +
38+
" (a1:Node)," +
39+
" (a2:Node)," +
40+
" (a3:Node)," +
41+
" (a4:Node)," +
42+
" (a5:Node)," +
43+
" (a6:Node)," +
44+
" (a7:Node)," +
45+
" (a0)-[:R {weight: 1.0}]->(a1)," +
46+
" (a0)-[:R {weight: 1.0}]->(a2)," +
47+
" (a0)-[:R {weight: 1.0}]->(a3)," +
48+
" (a0)-[:R {weight: 1.0}]->(a4)," +
49+
" (a2)-[:R {weight: 1.0}]->(a3)," +
50+
" (a2)-[:R {weight: 1.0}]->(a4)," +
51+
" (a3)-[:R {weight: 1.0}]->(a4)," +
52+
" (a1)-[:R {weight: 1.0}]->(a5)," +
53+
" (a1)-[:R {weight: 1.0}]->(a6)," +
54+
" (a1)-[:R {weight: 1.0}]->(a7)," +
55+
" (a5)-[:R {weight: 1.0}]->(a6)," +
56+
" (a5)-[:R {weight: 1.0}]->(a7)," +
57+
" (a6)-[:R {weight: 1.0}]->(a7)";
58+
59+
@BeforeEach
60+
void setUp() throws Exception {
61+
registerProcedures(
62+
GraphProjectProc.class,
63+
LeidenStreamProc.class
64+
);
65+
66+
runQuery("CALL gds.graph.project('leiden', '*', '*')");
67+
}
68+
69+
@Test
70+
void stream() {
71+
runQuery("CALL gds.alpha.leiden.stream('leiden')", result -> {
72+
assertThat(result.columns()).containsExactlyInAnyOrder("nodeId", "communityId");
73+
long resultRowCount = 0;
74+
var communities = new HashSet<Long>();
75+
while(result.hasNext()) {
76+
var next = result.next();
77+
assertThat(next.get("nodeId")).isInstanceOf(Long.class);
78+
assertThat(next.get("communityId")).isInstanceOf(Long.class);
79+
communities.add((Long) next.get("communityId"));
80+
resultRowCount++;
81+
}
82+
assertThat(resultRowCount).isEqualTo(8L);
83+
assertThat(communities).hasSize(2);
84+
return true;
85+
});
86+
}
87+
}

0 commit comments

Comments
 (0)