@@ -54,11 +54,13 @@ static void compressBytes(byte[] raw, byte[] compressed) {
54
54
private byte [] bytesA ;
55
55
private byte [] bytesB ;
56
56
private byte [] halfBytesA ;
57
+ private byte [] halfBytesAPacked ;
57
58
private byte [] halfBytesB ;
58
59
private byte [] halfBytesBPacked ;
59
60
private float [] floatsA ;
60
61
private float [] floatsB ;
61
- private int expectedhalfByteDotProduct ;
62
+ private int expectedHalfByteDotProduct ;
63
+ private int expectedHalfByteSquareDistance ;
62
64
63
65
@ Param ({"1" , "128" , "207" , "256" , "300" , "512" , "702" , "1024" })
64
66
int size ;
@@ -74,16 +76,23 @@ public void init() {
74
76
random .nextBytes (bytesB );
75
77
// random half byte arrays for binary methods
76
78
// this means that all values must be between 0 and 15
77
- expectedhalfByteDotProduct = 0 ;
79
+ expectedHalfByteDotProduct = 0 ;
80
+ expectedHalfByteSquareDistance = 0 ;
78
81
halfBytesA = new byte [size ];
79
82
halfBytesB = new byte [size ];
80
83
for (int i = 0 ; i < size ; ++i ) {
81
84
halfBytesA [i ] = (byte ) random .nextInt (16 );
82
85
halfBytesB [i ] = (byte ) random .nextInt (16 );
83
- expectedhalfByteDotProduct += halfBytesA [i ] * halfBytesB [i ];
86
+ expectedHalfByteDotProduct += halfBytesA [i ] * halfBytesB [i ];
87
+
88
+ int diff = halfBytesA [i ] - halfBytesB [i ];
89
+ expectedHalfByteSquareDistance += diff * diff ;
84
90
}
85
91
// pack the half byte arrays
86
92
if (size % 2 == 0 ) {
93
+ halfBytesAPacked = new byte [(size + 1 ) >> 1 ];
94
+ compressBytes (halfBytesA , halfBytesAPacked );
95
+
87
96
halfBytesBPacked = new byte [(size + 1 ) >> 1 ];
88
97
compressBytes (halfBytesB , halfBytesBPacked );
89
98
}
@@ -108,6 +117,74 @@ public float binaryCosineVector() {
108
117
return VectorUtil .cosine (bytesA , bytesB );
109
118
}
110
119
120
+ @ Benchmark
121
+ public int binarySquareScalar () {
122
+ return VectorUtil .squareDistance (bytesA , bytesB );
123
+ }
124
+
125
+ @ Benchmark
126
+ @ Fork (jvmArgsPrepend = {"--add-modules=jdk.incubator.vector" })
127
+ public int binarySquareVector () {
128
+ return VectorUtil .squareDistance (bytesA , bytesB );
129
+ }
130
+
131
+ @ Benchmark
132
+ public int binaryHalfByteSquareScalar () {
133
+ int v = VectorUtil .int4SquareDistance (halfBytesA , halfBytesB );
134
+ if (v != expectedHalfByteSquareDistance ) {
135
+ throw new RuntimeException ("Expected " + expectedHalfByteDotProduct + " but got " + v );
136
+ }
137
+ return v ;
138
+ }
139
+
140
+ @ Benchmark
141
+ @ Fork (jvmArgsPrepend = {"--add-modules=jdk.incubator.vector" })
142
+ public int binaryHalfByteSquareVector () {
143
+ int v = VectorUtil .int4SquareDistance (halfBytesA , halfBytesB );
144
+ if (v != expectedHalfByteSquareDistance ) {
145
+ throw new RuntimeException ("Expected " + expectedHalfByteDotProduct + " but got " + v );
146
+ }
147
+ return v ;
148
+ }
149
+
150
+ @ Benchmark
151
+ public int binaryHalfByteSquareSinglePackedScalar () {
152
+ int v = VectorUtil .int4SquareDistanceSinglePacked (halfBytesA , halfBytesBPacked );
153
+ if (v != expectedHalfByteSquareDistance ) {
154
+ throw new RuntimeException ("Expected " + expectedHalfByteDotProduct + " but got " + v );
155
+ }
156
+ return v ;
157
+ }
158
+
159
+ @ Benchmark
160
+ @ Fork (jvmArgsPrepend = {"--add-modules=jdk.incubator.vector" })
161
+ public int binaryHalfByteSquareSinglePackedVector () {
162
+ int v = VectorUtil .int4SquareDistanceSinglePacked (halfBytesA , halfBytesBPacked );
163
+ if (v != expectedHalfByteSquareDistance ) {
164
+ throw new RuntimeException ("Expected " + expectedHalfByteDotProduct + " but got " + v );
165
+ }
166
+ return v ;
167
+ }
168
+
169
+ @ Benchmark
170
+ public int binaryHalfByteSquareBothPackedScalar () {
171
+ int v = VectorUtil .int4SquareDistanceBothPacked (halfBytesAPacked , halfBytesBPacked );
172
+ if (v != expectedHalfByteSquareDistance ) {
173
+ throw new RuntimeException ("Expected " + expectedHalfByteDotProduct + " but got " + v );
174
+ }
175
+ return v ;
176
+ }
177
+
178
+ @ Benchmark
179
+ @ Fork (jvmArgsPrepend = {"--add-modules=jdk.incubator.vector" })
180
+ public int binaryHalfByteSquareBothPackedVector () {
181
+ int v = VectorUtil .int4SquareDistanceBothPacked (halfBytesAPacked , halfBytesBPacked );
182
+ if (v != expectedHalfByteSquareDistance ) {
183
+ throw new RuntimeException ("Expected " + expectedHalfByteDotProduct + " but got " + v );
184
+ }
185
+ return v ;
186
+ }
187
+
111
188
@ Benchmark
112
189
public int binaryDotProductScalar () {
113
190
return VectorUtil .dotProduct (bytesA , bytesB );
@@ -131,14 +208,22 @@ public int binaryDotProductUint8Vector() {
131
208
}
132
209
133
210
@ Benchmark
134
- public int binarySquareScalar () {
135
- return VectorUtil .squareDistance (bytesA , bytesB );
211
+ public int binaryHalfByteDotProductScalar () {
212
+ int v = VectorUtil .int4DotProduct (halfBytesA , halfBytesB );
213
+ if (v != expectedHalfByteDotProduct ) {
214
+ throw new RuntimeException ("Expected " + expectedHalfByteDotProduct + " but got " + v );
215
+ }
216
+ return v ;
136
217
}
137
218
138
219
@ Benchmark
139
220
@ Fork (jvmArgsPrepend = {"--add-modules=jdk.incubator.vector" })
140
- public int binarySquareVector () {
141
- return VectorUtil .squareDistance (bytesA , bytesB );
221
+ public int binaryHalfByteDotProductVector () {
222
+ int v = VectorUtil .int4DotProduct (halfBytesA , halfBytesB );
223
+ if (v != expectedHalfByteDotProduct ) {
224
+ throw new RuntimeException ("Expected " + expectedHalfByteDotProduct + " but got " + v );
225
+ }
226
+ return v ;
142
227
}
143
228
144
229
@ Benchmark
@@ -153,37 +238,39 @@ public int binarySquareUint8Vector() {
153
238
}
154
239
155
240
@ Benchmark
156
- public int binaryHalfByteScalar () {
157
- return VectorUtil .int4DotProduct (halfBytesA , halfBytesB );
241
+ public int binaryHalfByteDotProductSinglePackedScalar () {
242
+ int v = VectorUtil .int4DotProductSinglePacked (halfBytesA , halfBytesBPacked );
243
+ if (v != expectedHalfByteDotProduct ) {
244
+ throw new RuntimeException ("Expected " + expectedHalfByteDotProduct + " but got " + v );
245
+ }
246
+ return v ;
158
247
}
159
248
160
249
@ Benchmark
161
250
@ Fork (jvmArgsPrepend = {"--add-modules=jdk.incubator.vector" })
162
- public int binaryHalfByteVector () {
163
- return VectorUtil .int4DotProduct (halfBytesA , halfBytesB );
251
+ public int binaryHalfByteDotProductSinglePackedVector () {
252
+ int v = VectorUtil .int4DotProductSinglePacked (halfBytesA , halfBytesBPacked );
253
+ if (v != expectedHalfByteDotProduct ) {
254
+ throw new RuntimeException ("Expected " + expectedHalfByteDotProduct + " but got " + v );
255
+ }
256
+ return v ;
164
257
}
165
258
166
259
@ Benchmark
167
- public int binaryHalfByteScalarPacked () {
168
- if (size % 2 != 0 ) {
169
- throw new RuntimeException ("Size must be even for this benchmark" );
170
- }
171
- int v = VectorUtil .int4DotProductPacked (halfBytesA , halfBytesBPacked );
172
- if (v != expectedhalfByteDotProduct ) {
173
- throw new RuntimeException ("Expected " + expectedhalfByteDotProduct + " but got " + v );
260
+ public int binaryHalfByteDotProductBothPackedScalar () {
261
+ int v = VectorUtil .int4DotProductBothPacked (halfBytesAPacked , halfBytesBPacked );
262
+ if (v != expectedHalfByteDotProduct ) {
263
+ throw new RuntimeException ("Expected " + expectedHalfByteDotProduct + " but got " + v );
174
264
}
175
265
return v ;
176
266
}
177
267
178
268
@ Benchmark
179
269
@ Fork (jvmArgsPrepend = {"--add-modules=jdk.incubator.vector" })
180
- public int binaryHalfByteVectorPacked () {
181
- if (size % 2 != 0 ) {
182
- throw new RuntimeException ("Size must be even for this benchmark" );
183
- }
184
- int v = VectorUtil .int4DotProductPacked (halfBytesA , halfBytesBPacked );
185
- if (v != expectedhalfByteDotProduct ) {
186
- throw new RuntimeException ("Expected " + expectedhalfByteDotProduct + " but got " + v );
270
+ public int binaryHalfByteDotProductBothPackedVector () {
271
+ int v = VectorUtil .int4DotProductBothPacked (halfBytesAPacked , halfBytesBPacked );
272
+ if (v != expectedHalfByteDotProduct ) {
273
+ throw new RuntimeException ("Expected " + expectedHalfByteDotProduct + " but got " + v );
187
274
}
188
275
return v ;
189
276
}
0 commit comments