Skip to content

Commit c7dc9d9

Browse files
IoannisPanagiotasFlorentinDvnickolov
committed
Make euclidean distance be double for float as well and use that intel to reduce code.
Co-authored-by: Florentin Dörre <florentin.dorre@neotechnology.com> Co-authored-by: Veselin Nikolov <veselin.nikolov@neotechnology.com>
1 parent 261b9a6 commit c7dc9d9

File tree

1 file changed

+30
-19
lines changed

1 file changed

+30
-19
lines changed

algo/src/main/java/org/neo4j/gds/kmeans/ClusterManager.java

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -55,14 +55,28 @@ void initializeCenters(List<Long> initialCenterIds) {
5555
}
5656
}
5757

58-
abstract int findClosestCenter(long nodeId);
5958

6059
static ClusterManager createClusterManager(NodePropertyValues values, int dimensions, int k) {
6160
if (values.valueType() == ValueType.FLOAT_ARRAY) {
6261
return new FloatClusterManager(values, dimensions, k);
6362
}
6463
return new DoubleClusterManager(values, dimensions, k);
6564
}
65+
66+
public abstract double euclidean(long nodeId, int centerId);
67+
68+
public int findClosestCenter(long nodeId) {
69+
int community = 0;
70+
double smallestDistance = Double.MAX_VALUE;
71+
for (int centerId = 0; centerId < k; ++centerId) {
72+
double distance = euclidean(nodeId, centerId);
73+
if (Double.compare(distance, smallestDistance) < 0) {
74+
smallestDistance = distance;
75+
community = centerId;
76+
}
77+
}
78+
return community;
79+
}
6680
}
6781

6882
class FloatClusterManager extends ClusterManager {
@@ -107,6 +121,13 @@ public void updateFromTask(KmeansTask task) {
107121
}
108122
}
109123

124+
@Override
125+
public double euclidean(long nodeId, int centerId) {
126+
float[] left = nodePropertyValues.floatArrayValue(nodeId);
127+
float[] right = clusterCenters[centerId];
128+
return Math.sqrt(Intersections.sumSquareDelta(left, right, right.length));
129+
}
130+
110131
private float floatEuclidean(float[] left, float[] right) {
111132
return (float) Math.sqrt(Intersections.sumSquareDelta(left, right, right.length));
112133
}
@@ -164,28 +185,18 @@ public void updateFromTask(KmeansTask task) {
164185
}
165186

166187
@Override
167-
public void initialAssignCluster(int i, long id) {
168-
double[] cluster = nodePropertyValues.doubleArrayValue(id);
169-
System.arraycopy(cluster, 0, clusterCenters[i], 0, cluster.length);
170-
}
171-
172-
private double doubleEuclidean(double[] left, double[] right) {
188+
public double euclidean(long nodeId, int centerId) {
189+
double[] left = nodePropertyValues.doubleArrayValue(nodeId);
190+
double[] right = clusterCenters[centerId];
173191
return Math.sqrt(Intersections.sumSquareDelta(left, right, right.length));
192+
174193
}
175194

176195
@Override
177-
public int findClosestCenter(long nodeId) {
178-
var property = nodePropertyValues.doubleArrayValue(nodeId);
179-
int community = 0;
180-
double smallestDistance = Double.MAX_VALUE;
181-
for (int centerId = 0; centerId < k; ++centerId) {
182-
double distance = doubleEuclidean(property, clusterCenters[centerId]);
183-
if (Double.compare(distance, smallestDistance) < 0) {
184-
smallestDistance = distance;
185-
community = centerId;
186-
}
187-
}
188-
return community;
196+
public void initialAssignCluster(int i, long id) {
197+
double[] cluster = nodePropertyValues.doubleArrayValue(id);
198+
System.arraycopy(cluster, 0, clusterCenters[i], 0, cluster.length);
189199
}
190200

201+
191202
}

0 commit comments

Comments
 (0)