Skip to content

Commit 613ade4

Browse files
committed
Use thread-safe graph copy in label extraction
Resetting the count of days without a `graph.concurrentCopy` bug.
1 parent 4dcf091 commit 613ade4

File tree

1 file changed

+13
-15
lines changed

1 file changed

+13
-15
lines changed

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkFeaturesAndLabelsExtractor.java

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@
2222
import org.apache.commons.lang3.mutable.MutableLong;
2323
import org.neo4j.gds.RelationshipType;
2424
import org.neo4j.gds.api.Graph;
25-
import org.neo4j.gds.core.utils.TerminationFlag;
2625
import org.neo4j.gds.core.concurrency.ParallelUtil;
2726
import org.neo4j.gds.core.concurrency.Pools;
27+
import org.neo4j.gds.core.utils.TerminationFlag;
2828
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
2929
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
3030
import org.neo4j.gds.core.utils.mem.MemoryRange;
@@ -113,20 +113,18 @@ private static HugeLongArray extractLabels(
113113
var startRelationshipOffset = relationshipOffset.getValue();
114114
tasks.add(() -> {
115115
var currentRelationshipOffset = new MutableLong(startRelationshipOffset);
116-
partition.consume(nodeId -> {
117-
graph.forEachRelationship(nodeId, -10, (src, trg, weight) -> {
118-
if (weight == EdgeSplitter.NEGATIVE || weight == EdgeSplitter.POSITIVE) {
119-
globalLabels.set(currentRelationshipOffset.getAndIncrement(), (long) weight);
120-
} else {
121-
throw new IllegalArgumentException(formatWithLocale("Label should be either `1` or `0`. But got %f for relationship (%d, %d)",
122-
weight,
123-
src,
124-
trg
125-
));
126-
}
127-
return true;
128-
});
129-
});
116+
partition.consume(nodeId -> graph.concurrentCopy().forEachRelationship(nodeId, -10, (src, trg, weight) -> {
117+
if (weight == EdgeSplitter.NEGATIVE || weight == EdgeSplitter.POSITIVE) {
118+
globalLabels.set(currentRelationshipOffset.getAndIncrement(), (long) weight);
119+
} else {
120+
throw new IllegalArgumentException(formatWithLocale("Label should be either `1` or `0`. But got %f for relationship (%d, %d)",
121+
weight,
122+
src,
123+
trg
124+
));
125+
}
126+
return true;
127+
}));
130128
progressTracker.logProgress(partition.totalDegree());
131129
}
132130
);

0 commit comments

Comments
 (0)