Skip to content

Commit ba394f4

Browse files
Address review comments
Co-authored-by: Florentin Dörre <florentin.dorre@neotechnology.com>
1 parent c7dc9d9 commit ba394f4

File tree

2 files changed

+8
-23
lines changed

2 files changed

+8
-23
lines changed

algo/src/main/java/org/neo4j/gds/kmeans/ClusterManager.java

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ abstract class ClusterManager {
3333
final int dimensions;
3434
final int k;
3535

36-
public ClusterManager(NodePropertyValues values, int dimensions, int k) {
36+
ClusterManager(NodePropertyValues values, int dimensions, int k) {
3737
this.dimensions = dimensions;
3838
this.k = k;
3939
this.nodePropertyValues = values;
@@ -82,7 +82,7 @@ public int findClosestCenter(long nodeId) {
8282
class FloatClusterManager extends ClusterManager {
8383
private final float[][] clusterCenters;
8484

85-
public FloatClusterManager(NodePropertyValues values, int dimensions, int k) {
85+
FloatClusterManager(NodePropertyValues values, int dimensions, int k) {
8686
super(values, dimensions, k);
8787
this.clusterCenters = new float[k][dimensions];
8888
}
@@ -115,8 +115,9 @@ public void updateFromTask(KmeansTask task) {
115115
var floatKmeansTask = (FloatKmeansTask) task;
116116
for (int centerId = 0; centerId < k; ++centerId) {
117117
nodesInCluster[centerId] += task.getNumAssignedAtCenter(centerId);
118+
var taskContributionToCluster = floatKmeansTask.getCenterContribution(centerId);
118119
for (int dimension = 0; dimension < dimensions; ++dimension) {
119-
clusterCenters[centerId][dimension] += floatKmeansTask.getCenterContribution(centerId)[dimension];
120+
clusterCenters[centerId][dimension] += taskContributionToCluster[dimension];
120121
}
121122
}
122123
}
@@ -132,26 +133,12 @@ private float floatEuclidean(float[] left, float[] right) {
132133
return (float) Math.sqrt(Intersections.sumSquareDelta(left, right, right.length));
133134
}
134135

135-
@Override
136-
public int findClosestCenter(long nodeId) {
137-
var property = nodePropertyValues.floatArrayValue(nodeId);
138-
int community = 0;
139-
float smallestDistance = Float.MAX_VALUE;
140-
for (int centerId = 0; centerId < k; ++centerId) {
141-
float distance = floatEuclidean(property, clusterCenters[centerId]);
142-
if (Float.compare(distance, smallestDistance) < 0) {
143-
smallestDistance = distance;
144-
community = centerId;
145-
}
146-
}
147-
return community;
148-
}
149136
}
150137

151138
class DoubleClusterManager extends ClusterManager {
152139
private final double[][] clusterCenters;
153140

154-
public DoubleClusterManager(NodePropertyValues values, int dimensions, int k) {
141+
DoubleClusterManager(NodePropertyValues values, int dimensions, int k) {
155142
super(values, dimensions, k);
156143
this.clusterCenters = new double[k][dimensions];
157144
}
@@ -177,9 +164,9 @@ public void updateFromTask(KmeansTask task) {
177164
var doubleKmeansTask = (DoubleKmeansTask) task;
178165
for (int centerId = 0; centerId < k; ++centerId) {
179166
nodesInCluster[centerId] += task.getNumAssignedAtCenter(centerId);
180-
167+
var taskContributionToCluster = doubleKmeansTask.getCenterContribution(centerId);
181168
for (int dimension = 0; dimension < dimensions; ++dimension) {
182-
clusterCenters[centerId][dimension] += doubleKmeansTask.getCenterContribution(centerId)[dimension];
169+
clusterCenters[centerId][dimension] += taskContributionToCluster[dimension];
183170
}
184171
}
185172
}
@@ -198,5 +185,4 @@ public void initialAssignCluster(int i, long id) {
198185
System.arraycopy(cluster, 0, clusterCenters[i], 0, cluster.length);
199186
}
200187

201-
202188
}

algo/src/main/java/org/neo4j/gds/kmeans/KmeansTask.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
*/
2020
package org.neo4j.gds.kmeans;
2121

22-
import org.neo4j.gds.api.nodeproperties.ValueType;
2322
import org.neo4j.gds.api.properties.nodes.NodePropertyValues;
2423
import org.neo4j.gds.core.utils.paged.HugeIntArray;
2524
import org.neo4j.gds.core.utils.partition.Partition;
@@ -78,7 +77,7 @@ static KmeansTask createTask(
7877
Partition partition,
7978
ProgressTracker progressTracker
8079
) {
81-
if (nodePropertyValues.valueType() == ValueType.DOUBLE_ARRAY) {
80+
if (clusterManager instanceof DoubleClusterManager) {
8281
return new DoubleKmeansTask(
8382
clusterManager,
8483
nodePropertyValues,

0 commit comments

Comments
 (0)