Skip to content

Commit a6d3715

Browse files
Implement param resolution by passing a map
1 parent 62c99fb commit a6d3715

File tree

11 files changed

+88
-18
lines changed

11 files changed

+88
-18
lines changed

core/src/main/java/org/neo4j/gds/config/GraphProjectFromGraphConfig.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
import org.neo4j.gds.concurrency.ConcurrencyValidatorService;
2727
import org.neo4j.gds.core.CypherMapWrapper;
2828

29+
import java.util.Map;
30+
2931

3032
@ValueClass
3133
@Configuration
@@ -53,6 +55,12 @@ default int concurrency() {
5355
return ConcurrencyConfig.DEFAULT_CONCURRENCY;
5456
}
5557

58+
@Value.Default
59+
@Value.Parameter(false)
60+
default Map<String, Object> parameterMap() {
61+
return Map.of();
62+
}
63+
5664
@Value.Check
5765
default void validateReadConcurrency() {
5866
ConcurrencyValidatorService.validator().validate(concurrency(), "concurrency", ConcurrencyConfig.CONCURRENCY_LIMITATION);

proc/catalog/src/test/java/org/neo4j/gds/catalog/GraphProjectSubgraphProcTest.java

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
class GraphProjectSubgraphProcTest extends BaseProcTest {
4242

4343
@Neo4jGraph
44-
public static final String DB = "CREATE (a:A)-[:REL]->(b:B)";
44+
public static final String DB = "CREATE (a:A)-[:REL { weight: 42.0 }]->(b:B)";
4545

4646
@BeforeEach
4747
void setup() throws Exception {
@@ -51,7 +51,8 @@ void setup() throws Exception {
5151
.graphProject()
5252
.withNodeLabel("A")
5353
.withNodeLabel("B")
54-
.withAnyRelationshipType()
54+
.withRelationshipType("REL")
55+
.withRelationshipProperty("weight")
5556
.yields()
5657
);
5758
}
@@ -155,12 +156,19 @@ void throwsOnSemanticNodeError() {
155156

156157
@Test
157158
void throwsOnSemanticRelationshipError() {
158-
var subGraphQuery = "CALL gds.beta.graph.project.subgraph('subgraph', 'graph', 'true', 'r:BAR AND r.weight > 42')";
159+
var subGraphQuery = "CALL gds.beta.graph.project.subgraph('subgraph', 'graph', 'true', 'r:BAR AND r.prop > 42')";
159160

160161
assertThatThrownBy(() -> runQuery(subGraphQuery))
161162
.getRootCause()
162163
.isInstanceOf(SemanticErrors.class)
163-
.hasMessageContaining("Unknown property `weight`.")
164+
.hasMessageContaining("Unknown property `prop`.")
164165
.hasMessageContaining("Unknown relationship type `BAR`.");
165166
}
167+
168+
@Test
169+
void shouldResolveParameters() {
170+
var subGraphQuery = "CALL gds.beta.graph.project.subgraph('subgraph', 'graph', 'true', 'r:REL AND r.weight > $weight', { parameterMap: { weight: $weight } })";
171+
172+
runQuery(subGraphQuery, Map.of("weight", 42));
173+
}
166174
}

subgraph-filtering/src/main/java/org/neo4j/gds/beta/filter/GraphStoreFilter.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ public static GraphStore filter(
9393
graphStore,
9494
expressions.nodeExpression(),
9595
config.concurrency(),
96+
config.parameterMap(),
9697
executorService,
9798
progressTracker
9899
);
@@ -103,6 +104,7 @@ public static GraphStore filter(
103104
inputNodes,
104105
filteredNodes.idMap(),
105106
config.concurrency(),
107+
config.parameterMap(),
106108
executorService,
107109
progressTracker
108110
);

subgraph-filtering/src/main/java/org/neo4j/gds/beta/filter/NodesFilter.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
4545

4646
import java.util.Iterator;
47+
import java.util.Map;
4748
import java.util.Optional;
4849
import java.util.concurrent.ExecutorService;
4950
import java.util.function.Function;
@@ -61,6 +62,7 @@ static FilteredNodes filterNodes(
6162
GraphStore inputGraphStore,
6263
Expression expression,
6364
int concurrency,
65+
Map<String, Object> parameterMap,
6466
ExecutorService executorService,
6567
ProgressTracker progressTracker
6668
) {
@@ -79,6 +81,7 @@ static FilteredNodes filterNodes(
7981
var tasks = NodeFilterTask.of(
8082
inputGraphStore,
8183
expression,
84+
parameterMap,
8285
partitions,
8386
nodesBuilder,
8487
progressTracker
@@ -247,6 +250,7 @@ private static final class NodeFilterTask implements Runnable {
247250
static Iterator<NodeFilterTask> of(
248251
GraphStore inputGraphStore,
249252
Expression expression,
253+
Map<String, Object> parameterMap,
250254
Iterator<Partition> partitions,
251255
NodesBuilder nodesBuilder,
252256
ProgressTracker progressTracker
@@ -261,6 +265,7 @@ protected NodeFilterTask fetch() {
261265
return new NodeFilterTask(
262266
partitions.next(),
263267
expression,
268+
parameterMap,
264269
inputGraphStore,
265270
nodesBuilder,
266271
progressTracker
@@ -272,6 +277,7 @@ protected NodeFilterTask fetch() {
272277
private NodeFilterTask(
273278
Partition partition,
274279
Expression expression,
280+
Map<String, Object> parameterMap,
275281
GraphStore inputGraphStore,
276282
NodesBuilder nodesBuilder,
277283
ProgressTracker progressTracker
@@ -280,7 +286,7 @@ private NodeFilterTask(
280286
this.expression = expression;
281287
this.inputGraphStore = inputGraphStore;
282288
this.nodesBuilder = nodesBuilder;
283-
this.nodeContext = new EvaluationContext.NodeEvaluationContext(inputGraphStore);
289+
this.nodeContext = new EvaluationContext.NodeEvaluationContext(inputGraphStore, parameterMap);
284290
this.progressTracker = progressTracker;
285291
}
286292

subgraph-filtering/src/main/java/org/neo4j/gds/beta/filter/RelationshipsFilter.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ static FilteredRelationships filterRelationships(
6767
IdMap inputNodes,
6868
IdMap outputNodes,
6969
int concurrency,
70+
Map<String, Object> parameterMap,
7071
ExecutorService executorService,
7172
ProgressTracker progressTracker
7273
) {
@@ -85,6 +86,7 @@ static FilteredRelationships filterRelationships(
8586
outputNodes,
8687
relType,
8788
concurrency,
89+
parameterMap,
8890
executorService,
8991
progressTracker
9092
);
@@ -140,6 +142,7 @@ private static FilteredRelationship filterRelationshipType(
140142
IdMap outputNodes,
141143
RelationshipType relType,
142144
int concurrency,
145+
Map<String, Object> parameterMap,
143146
ExecutorService executorService,
144147
ProgressTracker progressTracker
145148
) {
@@ -172,6 +175,7 @@ private static FilteredRelationship filterRelationshipType(
172175
outputNodes,
173176
relationshipsBuilder,
174177
relType,
178+
parameterMap,
175179
propertyIndices,
176180
progressTracker
177181
),
@@ -217,6 +221,7 @@ private RelationshipFilterTask(
217221
IdMap outputNodes,
218222
RelationshipsBuilder relationshipsBuilder,
219223
RelationshipType relType,
224+
Map<String, Object> parameterMap,
220225
Map<String, Integer> propertyIndices,
221226
ProgressTracker progressTracker
222227
) {
@@ -227,7 +232,7 @@ private RelationshipFilterTask(
227232
this.outputNodes = outputNodes;
228233
this.relationshipsBuilder = relationshipsBuilder;
229234
this.relType = relType;
230-
this.evaluationContext = new EvaluationContext.RelationshipEvaluationContext(propertyIndices);
235+
this.evaluationContext = new EvaluationContext.RelationshipEvaluationContext(propertyIndices, parameterMap);
231236
this.progressTracker = progressTracker;
232237
}
233238

subgraph-filtering/src/main/java/org/neo4j/gds/beta/filter/expression/AstFactoryAdapter.java

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -323,11 +323,6 @@ public Expression oldParameter(InputPosition p, Expression.LeafExpression.Variab
323323
throw new UnsupportedOperationException();
324324
}
325325

326-
@Override
327-
public Expression newParameter(InputPosition p, Expression.LeafExpression.Variable v) {
328-
throw new UnsupportedOperationException();
329-
}
330-
331326
@Override
332327
public NULL newSingleQuery(List<NULL> nulls) {
333328
throw new UnsupportedOperationException();

subgraph-filtering/src/main/java/org/neo4j/gds/beta/filter/expression/EvaluationContext.java

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,16 @@
3232

3333
public abstract class EvaluationContext {
3434

35+
private final Map<String, Object> parameterMap;
36+
37+
protected EvaluationContext(Map<String, Object> parameterMap) {
38+
this.parameterMap = parameterMap;
39+
}
40+
41+
double resolveParameter(String parameterName) {
42+
return ((Number) this.parameterMap.get(parameterName)).doubleValue();
43+
}
44+
3545
abstract double getProperty(String propertyKey, ValueType propertyType);
3646

3747
abstract boolean hasLabelsOrTypes(List<String> labelsOrTypes);
@@ -41,7 +51,11 @@ public static class NodeEvaluationContext extends EvaluationContext {
4151
private final GraphStore graphStore;
4252
private long nodeId;
4353

44-
public NodeEvaluationContext(GraphStore graphStore) {
54+
public NodeEvaluationContext(
55+
GraphStore graphStore,
56+
Map<String, Object> parameterMap
57+
) {
58+
super(parameterMap);
4559
this.graphStore = graphStore;
4660
}
4761

@@ -80,7 +94,11 @@ public static class RelationshipEvaluationContext extends EvaluationContext {
8094

8195
private final ObjectIntMap<String> propertyIndices;
8296

83-
public RelationshipEvaluationContext(Map<String, Integer> propertyIndices) {
97+
public RelationshipEvaluationContext(
98+
Map<String, Integer> propertyIndices,
99+
Map<String, Object> parameterMap
100+
) {
101+
super(parameterMap);
84102
this.propertyIndices = new ObjectIntScatterMap<>();
85103
propertyIndices.forEach(this.propertyIndices::put);
86104
}

subgraph-filtering/src/main/java/org/neo4j/gds/beta/filter/expression/Expression.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,24 @@ default double evaluate(EvaluationContext context) {
198198
}
199199

200200
}
201+
202+
@ValueClass
203+
interface NewParameter extends UnaryExpression {
204+
205+
@Value.Derived
206+
@Override
207+
default double evaluate(EvaluationContext context) {
208+
return context.resolveParameter(((LeafExpression.Variable) in()).name());
209+
}
210+
211+
@Override
212+
default ValidationContext validate(ValidationContext context) {
213+
if (!(in() instanceof LeafExpression.Variable)) {
214+
throw new IllegalStateException();
215+
}
216+
return context;
217+
}
218+
}
201219
}
202220

203221
interface BinaryExpression extends Expression {

subgraph-filtering/src/main/java/org/neo4j/gds/beta/filter/expression/GdsAstFactory.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,14 @@ public Expression.LeafExpression.Variable newVariable(InputPosition p, String na
4141
return ImmutableVariable.builder().name(name).build();
4242
}
4343

44+
45+
@Override
46+
public Expression newParameter(
47+
InputPosition p, Expression.LeafExpression.Variable v
48+
) {
49+
return ImmutableNewParameter.of(v.valueType(), v);
50+
}
51+
4452
@Override
4553
public Expression.Literal.DoubleLiteral newDouble(InputPosition p, String image) {
4654
return ImmutableDoubleLiteral.builder().value(Double.parseDouble(image)).build();

subgraph-filtering/src/test/java/org/neo4j/gds/beta/filter/expression/EvaluationContextTest.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ void nodeEvaluationContextPositive(
7777
ValueType valueType,
7878
List<String> expectedLabels
7979
) {
80-
var context = new EvaluationContext.NodeEvaluationContext(graphStore);
80+
var context = new EvaluationContext.NodeEvaluationContext(graphStore, Map.of());
8181
context.init(idFunction.of(variable));
8282
assertThat(context.getProperty(propertyKey, valueType)).isEqualTo(expectedValue);
8383
assertThat(context.hasLabelsOrTypes(expectedLabels)).isTrue();
@@ -93,22 +93,22 @@ private static Stream<Arguments> nodesNegative() {
9393
@ParameterizedTest
9494
@MethodSource("nodesNegative")
9595
void nodeEvaluationContextNegative(String variable, List<String> unExpectedLabels) {
96-
var context = new EvaluationContext.NodeEvaluationContext(graphStore);
96+
var context = new EvaluationContext.NodeEvaluationContext(graphStore, Map.of());
9797
context.init(idFunction.of(variable));
9898
assertThat(context.hasLabelsOrTypes(unExpectedLabels)).isFalse();
9999
}
100100

101101
@Test
102102
void relationshipEvaluationContextPositive() {
103-
var context = new EvaluationContext.RelationshipEvaluationContext(Map.of("baz", 0));
103+
var context = new EvaluationContext.RelationshipEvaluationContext(Map.of("baz", 0), Map.of());
104104
context.init("REL", new double[]{84});
105105
assertThat(context.getProperty("baz", ValueType.DOUBLE)).isEqualTo(84);
106106
assertThat(context.hasLabelsOrTypes(List.of("REL"))).isTrue();
107107
}
108108

109109
@Test
110110
void relationshipEvaluationContextNegative() {
111-
var context = new EvaluationContext.RelationshipEvaluationContext(Map.of());
111+
var context = new EvaluationContext.RelationshipEvaluationContext(Map.of(), Map.of());
112112
context.init("REL");
113113
assertThat(context.hasLabelsOrTypes(List.of("BAR"))).isFalse();
114114
}

0 commit comments

Comments
 (0)