Skip to content

Commit aa9cd02

Browse files
authored
Merge pull request #5206 from breakanalysis/lp-termination-flag
Better termination flag handling in LP
2 parents 4197ba3 + 178d593 commit aa9cd02

File tree

13 files changed: 147 additions (+147) and 55 deletions (-55).

ml/ml-algo/src/main/java/org/neo4j/gds/ml/models/ClassifierTrainerFactory.java

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -63,7 +63,8 @@ public static ClassifierTrainer create(
6363
(RandomForestTrainerConfig) config,
6464
false,
6565
randomSeed,
66-
progressTracker
66+
progressTracker,
67+
terminationFlag
6768
);
6869
}
6970
default:

ml/ml-algo/src/main/java/org/neo4j/gds/ml/models/randomforest/RandomForestClassifierTrainer.java

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@
2323
import org.neo4j.gds.annotation.ValueClass;
2424
import org.neo4j.gds.core.concurrency.ParallelUtil;
2525
import org.neo4j.gds.core.concurrency.Pools;
26+
import org.neo4j.gds.core.utils.TerminationFlag;
2627
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
2728
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
2829
import org.neo4j.gds.core.utils.mem.MemoryRange;
@@ -60,6 +61,7 @@ public class RandomForestClassifierTrainer implements ClassifierTrainer {
6061
private final boolean computeOutOfBagError;
6162
private final SplittableRandom random;
6263
private final ProgressTracker progressTracker;
64+
private final TerminationFlag terminationFlag;
6365
private Optional<Double> outOfBagError = Optional.empty();
6466

6567
public RandomForestClassifierTrainer(
@@ -68,14 +70,16 @@ public RandomForestClassifierTrainer(
6870
RandomForestTrainerConfig config,
6971
boolean computeOutOfBagError,
7072
Optional<Long> randomSeed,
71-
ProgressTracker progressTracker
73+
ProgressTracker progressTracker,
74+
TerminationFlag terminationFlag
7275
) {
7376
this.classIdMap = classIdMap;
7477
this.config = config;
7578
this.concurrency = concurrency;
7679
this.computeOutOfBagError = computeOutOfBagError;
7780
this.random = new SplittableRandom(randomSeed.orElseGet(() -> new SplittableRandom().nextLong()));
7881
this.progressTracker = progressTracker;
82+
this.terminationFlag = terminationFlag;
7983
}
8084

8185
public static MemoryEstimation memoryEstimation(
@@ -153,7 +157,7 @@ public RandomForestClassifier train(
153157
numberOfTreesTrained
154158
)
155159
).collect(Collectors.toList());
156-
ParallelUtil.runWithConcurrency(concurrency, tasks, Pools.DEFAULT);
160+
ParallelUtil.runWithConcurrency(concurrency, tasks, terminationFlag, Pools.DEFAULT);
157161

158162
outOfBagError = maybePredictions.map(predictions -> OutOfBagErrorMetric.evaluate(
159163
trainSet,

ml/ml-algo/src/test/java/org/neo4j/gds/ml/models/randomforest/RandomForestClassifierTest.java

Lines changed: 16 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
2525
import org.junit.jupiter.params.provider.CsvSource;
2626
import org.junit.jupiter.params.provider.ValueSource;
2727
import org.neo4j.gds.core.GraphDimensions;
28+
import org.neo4j.gds.core.utils.TerminationFlag;
2829
import org.neo4j.gds.core.utils.mem.MemoryRange;
2930
import org.neo4j.gds.core.utils.paged.HugeLongArray;
3031
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
@@ -110,7 +111,8 @@ void usingOneTree(int concurrency) {
110111
.build(),
111112
false,
112113
Optional.of(42L),
113-
ProgressTracker.NULL_TRACKER
114+
ProgressTracker.NULL_TRACKER,
115+
TerminationFlag.RUNNING_TRUE
114116
);
115117

116118
var randomForestPredictor = randomForestTrainer.train(allFeatureVectors, allLabels, trainSet);
@@ -141,7 +143,8 @@ void usingTwentyTrees(int concurrency) {
141143
.build(),
142144
false,
143145
Optional.of(1337L),
144-
ProgressTracker.NULL_TRACKER
146+
ProgressTracker.NULL_TRACKER,
147+
TerminationFlag.RUNNING_TRUE
145148
);
146149

147150
var randomForestPredictor = randomForestTrainer.train(allFeatureVectors, allLabels, trainSet);
@@ -171,7 +174,8 @@ void shouldMakeSaneErrorEstimation(int concurrency) {
171174
.build(),
172175
true,
173176
Optional.of(1337L),
174-
ProgressTracker.NULL_TRACKER
177+
ProgressTracker.NULL_TRACKER,
178+
TerminationFlag.RUNNING_TRUE
175179
);
176180

177181
randomForestTrainer.train(allFeatureVectors, allLabels, trainSet);
@@ -195,7 +199,8 @@ void considerTrainSet(int concurrency) {
195199
.build(),
196200
false,
197201
Optional.of(1337L),
198-
ProgressTracker.NULL_TRACKER
202+
ProgressTracker.NULL_TRACKER,
203+
TerminationFlag.RUNNING_TRUE
199204
);
200205

201206
HugeLongArray mutableTrainSet = HugeLongArray.newArray(NUM_SAMPLES / 2);
@@ -232,19 +237,19 @@ void predictOverheadMemoryEstimation(
232237

233238
@ParameterizedTest
234239
@CsvSource(value = {
235-
" 6, 100_000, 10, 10, 1, 1, 0.1, 1.0, 4413594, 5226418",
240+
" 6, 100_000, 10, 10, 1, 1, 0.1, 1.0, 4413602, 5226426",
236241
// Should increase fairly little with more trees if training set big.
237-
" 10, 100_000, 10, 10, 1, 10, 0.1, 1.0, 4414242, 6295802",
242+
" 10, 100_000, 10, 10, 1, 10, 0.1, 1.0, 4414250, 6295810",
238243
// Should be capped by number of training examples, despite high max depth.
239-
" 8_000, 500, 10, 10, 1, 1, 0.1, 1.0, 23154, 182954",
244+
" 8_000, 500, 10, 10, 1, 1, 0.1, 1.0, 23162, 182962",
240245
// Should increase very little when having more classes.
241-
" 10, 100_000, 100, 10, 1, 10, 0.1, 1.0, 4414962, 6296522",
246+
" 10, 100_000, 100, 10, 1, 10, 0.1, 1.0, 4414970, 6296530",
242247
// Should increase very little when using more features for splits.
243-
" 10, 100_000, 100, 10, 1, 10, 0.9, 1.0, 4415034, 6296686",
248+
" 10, 100_000, 100, 10, 1, 10, 0.9, 1.0, 4415042, 6296694",
244249
// Should decrease a lot when sampling fewer training examples per tree.
245-
" 10, 100_000, 100, 10, 1, 10, 0.1, 0.2, 1204962, 2446522",
250+
" 10, 100_000, 100, 10, 1, 10, 0.1, 0.2, 1204970, 2446530",
246251
// Should almost be x4 when concurrency * 4.
247-
" 10, 100_000, 100, 10, 4, 10, 0.1, 1.0, 16457256, 21037256",
252+
" 10, 100_000, 100, 10, 4, 10, 0.1, 1.0, 16457264, 21037264",
248253
})
249254
void trainMemoryEstimation(
250255
int maxDepth,

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/linkPipeline/LinkFeatureExtractor.java

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222
import org.neo4j.gds.api.Graph;
2323
import org.neo4j.gds.core.concurrency.ParallelUtil;
2424
import org.neo4j.gds.core.concurrency.Pools;
25+
import org.neo4j.gds.core.utils.TerminationFlag;
2526
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
2627
import org.neo4j.gds.core.utils.partition.DegreePartition;
2728
import org.neo4j.gds.core.utils.partition.PartitionUtils;
@@ -71,7 +72,8 @@ public static Features extractFeatures(
7172
Graph graph,
7273
List<LinkFeatureStep> linkFeatureSteps,
7374
int concurrency,
74-
ProgressTracker progressTracker
75+
ProgressTracker progressTracker,
76+
TerminationFlag terminationFlag
7577
) {
7678
var extractor = of(graph, linkFeatureSteps);
7779

@@ -101,7 +103,7 @@ public static Features extractFeatures(
101103
relationshipOffset += partition.totalDegree();
102104
}
103105

104-
ParallelUtil.runWithConcurrency(concurrency, linkFeatureWriters, Pools.DEFAULT);
106+
ParallelUtil.runWithConcurrency(concurrency, linkFeatureWriters, terminationFlag, Pools.DEFAULT);
105107

106108
return FeaturesFactory.wrap(linkFeatures);
107109
}

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkFeaturesAndLabelsExtractor.java

Lines changed: 17 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222
import org.apache.commons.lang3.mutable.MutableLong;
2323
import org.neo4j.gds.RelationshipType;
2424
import org.neo4j.gds.api.Graph;
25+
import org.neo4j.gds.core.utils.TerminationFlag;
2526
import org.neo4j.gds.core.concurrency.ParallelUtil;
2627
import org.neo4j.gds.core.concurrency.Pools;
2728
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
@@ -75,18 +76,29 @@ static FeaturesAndLabels extractFeaturesAndLabels(
7576
Graph graph,
7677
List<LinkFeatureStep> featureSteps,
7778
int concurrency,
78-
ProgressTracker progressTracker
79+
ProgressTracker progressTracker,
80+
TerminationFlag terminationFlag
7981
) {
8082
progressTracker.setVolume(graph.relationshipCount() * 2);
81-
var features = LinkFeatureExtractor.extractFeatures(graph, featureSteps, concurrency, progressTracker);
83+
var features = LinkFeatureExtractor.extractFeatures(
84+
graph,
85+
featureSteps,
86+
concurrency,
87+
progressTracker,
88+
terminationFlag
89+
);
8290

83-
var labels = extractLabels(graph, features.size(), concurrency, progressTracker);
91+
var labels = extractLabels(graph, features.size(), concurrency, progressTracker, terminationFlag);
8492

8593
return ImmutableFeaturesAndLabels.of(features, labels);
8694
}
8795

8896
private static HugeLongArray extractLabels(
89-
Graph graph, long numberOfTargets, int concurrency, ProgressTracker progressTracker
97+
Graph graph,
98+
long numberOfTargets,
99+
int concurrency,
100+
ProgressTracker progressTracker,
101+
TerminationFlag terminationFlag
90102
) {
91103
var globalLabels = HugeLongArray.newArray(numberOfTargets);
92104
var partitions = PartitionUtils.degreePartition(
@@ -121,7 +133,7 @@ private static HugeLongArray extractLabels(
121133
relationshipOffset.add(partition.totalDegree());
122134
}
123135

124-
ParallelUtil.runWithConcurrency(concurrency, tasks, Pools.DEFAULT);
136+
ParallelUtil.runWithConcurrency(concurrency, tasks, terminationFlag, Pools.DEFAULT);
125137
return globalLabels;
126138
}
127139
}

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionTrain.java

Lines changed: 13 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -20,10 +20,10 @@
2020
package org.neo4j.gds.ml.pipeline.linkPipeline.train;
2121

2222
import org.jetbrains.annotations.NotNull;
23-
import org.neo4j.gds.Algorithm;
2423
import org.neo4j.gds.RelationshipType;
2524
import org.neo4j.gds.api.Graph;
2625
import org.neo4j.gds.core.model.Model;
26+
import org.neo4j.gds.core.utils.TerminationFlag;
2727
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
2828
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
2929
import org.neo4j.gds.core.utils.mem.MemoryRange;
@@ -66,7 +66,7 @@
6666
import static org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkFeaturesAndLabelsExtractor.extractFeaturesAndLabels;
6767
import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;
6868

69-
public class LinkPredictionTrain extends Algorithm<LinkPredictionTrainResult> {
69+
public final class LinkPredictionTrain {
7070

7171
public static final String MODEL_TYPE = "LinkPrediction";
7272

@@ -75,6 +75,9 @@ public class LinkPredictionTrain extends Algorithm<LinkPredictionTrainResult> {
7575
private final LinkPredictionTrainingPipeline pipeline;
7676
private final LinkPredictionTrainConfig config;
7777
private final LocalIdMap classIdMap;
78+
private final ProgressTracker progressTracker;
79+
private final TerminationFlag terminationFlag;
80+
7881

7982
public static LocalIdMap makeClassIdMap() {
8083
return LocalIdMap.of((long) EdgeSplitter.NEGATIVE, (long) EdgeSplitter.POSITIVE);
@@ -85,13 +88,15 @@ public LinkPredictionTrain(
8588
Graph validationGraph,
8689
LinkPredictionTrainingPipeline pipeline,
8790
LinkPredictionTrainConfig config,
88-
ProgressTracker progressTracker
91+
ProgressTracker progressTracker,
92+
TerminationFlag terminationFlag
8993
) {
90-
super(progressTracker);
9194
this.trainGraph = trainGraph;
9295
this.validationGraph = validationGraph;
9396
this.pipeline = pipeline;
9497
this.config = config;
98+
this.terminationFlag = terminationFlag;
99+
this.progressTracker = progressTracker;
95100
this.classIdMap = makeClassIdMap();
96101
}
97102

@@ -113,15 +118,15 @@ public static List<Task> progressTasks(int validationFolds, int numberOfModelSel
113118
);
114119
}
115120

116-
@Override
117121
public LinkPredictionTrainResult compute() {
118122

119123
progressTracker.beginSubTask("Extract train features");
120124
var trainData = extractFeaturesAndLabels(
121125
trainGraph,
122126
pipeline.featureSteps(),
123127
config.concurrency(),
124-
progressTracker
128+
progressTracker,
129+
terminationFlag
125130
);
126131
var trainRelationshipIds = new ReadOnlyHugeLongIdentityArray(trainData.size());
127132
progressTracker.endSubTask("Extract train features");
@@ -281,7 +286,8 @@ private void computeTestMetric(Classifier classifier, TrainingStatistics trainin
281286
validationGraph,
282287
pipeline.featureSteps(),
283288
config.concurrency(),
284-
progressTracker
289+
progressTracker,
290+
terminationFlag
285291
);
286292
progressTracker.endSubTask("Extract test features");
287293

@@ -362,11 +368,6 @@ private Model<Classifier.ClassifierData, LinkPredictionTrainConfig, LinkPredicti
362368
);
363369
}
364370

365-
@Override
366-
public void release() {
367-
368-
}
369-
370371
public static MemoryEstimation estimate(
371372
LinkPredictionTrainingPipeline pipeline,
372373
LinkPredictionTrainConfig trainConfig

pipeline/src/test/java/org/neo4j/gds/ml/pipeline/linkPipeline/LinkFeatureExtractorTest.java

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222
import org.junit.jupiter.api.Test;
2323
import org.neo4j.gds.Orientation;
2424
import org.neo4j.gds.api.Graph;
25+
import org.neo4j.gds.core.utils.TerminationFlag;
2526
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
2627
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
2728
import org.neo4j.gds.extension.GdlExtension;
@@ -58,7 +59,8 @@ void singleLinkFeatureStep() {
5859
graph,
5960
List.of(new HadamardFeatureStep(List.of("array"))),
6061
1,
61-
ProgressTracker.NULL_TRACKER
62+
ProgressTracker.NULL_TRACKER,
63+
TerminationFlag.RUNNING_TRUE
6264
);
6365

6466
var expected = HugeObjectArray.of(
@@ -86,7 +88,8 @@ void multipleLinkFeatureStep() {
8688
new CosineFeatureStep(List.of("noise", "z"))
8789
),
8890
1,
89-
ProgressTracker.NULL_TRACKER
91+
ProgressTracker.NULL_TRACKER,
92+
TerminationFlag.RUNNING_TRUE
9093
);
9194

9295
var normA = Math.sqrt(42 * 42 + 13 * 13);

pipeline/src/test/java/org/neo4j/gds/ml/pipeline/linkPipeline/linkfunctions/CosineLinkFeatureStepTest.java

Lines changed: 22 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
package org.neo4j.gds.ml.pipeline.linkPipeline.linkfunctions;
2121

2222
import org.junit.jupiter.api.Test;
23+
import org.neo4j.gds.core.utils.TerminationFlag;
2324
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
2425
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkFeatureExtractor;
2526
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkFeatureStepFactory;
@@ -40,7 +41,13 @@ public void runCosineLinkFeatureStep() {
4041
ImmutableLinkFeatureStepConfiguration.builder().nodeProperties(List.of("noise", "z", "array")).build()
4142
);
4243

43-
var linkFeatures = LinkFeatureExtractor.extractFeatures(graph, List.of(step), 4, ProgressTracker.NULL_TRACKER);
44+
var linkFeatures = LinkFeatureExtractor.extractFeatures(
45+
graph,
46+
List.of(step),
47+
4,
48+
ProgressTracker.NULL_TRACKER,
49+
TerminationFlag.RUNNING_TRUE
50+
);
4451

4552
var delta = 0.0001D;
4653

@@ -64,7 +71,13 @@ public void handlesZeroVectors() {
6471
ImmutableLinkFeatureStepConfiguration.builder().nodeProperties(List.of("zeros")).build()
6572
);
6673

67-
var linkFeatures = LinkFeatureExtractor.extractFeatures(graph, List.of(step), 4, ProgressTracker.NULL_TRACKER);
74+
var linkFeatures = LinkFeatureExtractor.extractFeatures(
75+
graph,
76+
List.of(step),
77+
4,
78+
ProgressTracker.NULL_TRACKER,
79+
TerminationFlag.RUNNING_TRUE
80+
);
6881

6982
for (long i = 0; i < linkFeatures.size(); i++) {
7083
assertThat(linkFeatures.get(i)).hasSize(1).containsExactly(0.0);
@@ -78,7 +91,13 @@ public void failsOnNaNValues() {
7891
ImmutableLinkFeatureStepConfiguration.builder().nodeProperties(List.of("invalidValue", "z")).build()
7992
);
8093

81-
assertThatThrownBy(() -> LinkFeatureExtractor.extractFeatures(graph, List.of(step), 4, ProgressTracker.NULL_TRACKER))
94+
assertThatThrownBy(() -> LinkFeatureExtractor.extractFeatures(
95+
graph,
96+
List.of(step),
97+
4,
98+
ProgressTracker.NULL_TRACKER,
99+
TerminationFlag.RUNNING_TRUE
100+
))
82101
.hasMessage("Encountered NaN in the nodeProperty `invalidValue` for nodes ['1'] when computing the cosine feature vector. " +
83102
"Either define a default value if its a stored property or check the nodePropertyStep");
84103
}

0 commit comments

Comments (0)