Skip to content

Commit ac04f41

Browse files
author
Kaival Parikh
committed
Perform scoring for 4 and 7 bit quantized vectors off-heap
1 parent 50a4f18 commit ac04f41

17 files changed

+970
-202
lines changed

lucene/CHANGES.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,8 @@ Optimizations
136136
* GITHUB#15151: Use `SimScorer#score` bulk API to compute impact scores per
137137
block of postings. (Adrien Grand)
138138

139+
* GITHUB#14863: Perform scoring for 4 and 7 bit quantized vectors off-heap. (Kaival Parikh)
140+
139141
Bug Fixes
140142
---------------------
141143
* GITHUB#14161: PointInSetQuery's constructor now throws IllegalArgumentException

lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorUtilBenchmark.java

Lines changed: 112 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,13 @@ static void compressBytes(byte[] raw, byte[] compressed) {
5454
private byte[] bytesA;
5555
private byte[] bytesB;
5656
private byte[] halfBytesA;
57+
private byte[] halfBytesAPacked;
5758
private byte[] halfBytesB;
5859
private byte[] halfBytesBPacked;
5960
private float[] floatsA;
6061
private float[] floatsB;
61-
private int expectedhalfByteDotProduct;
62+
private int expectedHalfByteDotProduct;
63+
private int expectedHalfByteSquareDistance;
6264

6365
@Param({"1", "128", "207", "256", "300", "512", "702", "1024"})
6466
int size;
@@ -74,16 +76,23 @@ public void init() {
7476
random.nextBytes(bytesB);
7577
// random half byte arrays for binary methods
7678
// this means that all values must be between 0 and 15
77-
expectedhalfByteDotProduct = 0;
79+
expectedHalfByteDotProduct = 0;
80+
expectedHalfByteSquareDistance = 0;
7881
halfBytesA = new byte[size];
7982
halfBytesB = new byte[size];
8083
for (int i = 0; i < size; ++i) {
8184
halfBytesA[i] = (byte) random.nextInt(16);
8285
halfBytesB[i] = (byte) random.nextInt(16);
83-
expectedhalfByteDotProduct += halfBytesA[i] * halfBytesB[i];
86+
expectedHalfByteDotProduct += halfBytesA[i] * halfBytesB[i];
87+
88+
int diff = halfBytesA[i] - halfBytesB[i];
89+
expectedHalfByteSquareDistance += diff * diff;
8490
}
8591
// pack the half byte arrays
8692
if (size % 2 == 0) {
93+
halfBytesAPacked = new byte[(size + 1) >> 1];
94+
compressBytes(halfBytesA, halfBytesAPacked);
95+
8796
halfBytesBPacked = new byte[(size + 1) >> 1];
8897
compressBytes(halfBytesB, halfBytesBPacked);
8998
}
@@ -108,6 +117,74 @@ public float binaryCosineVector() {
108117
return VectorUtil.cosine(bytesA, bytesB);
109118
}
110119

120+
@Benchmark
121+
public int binarySquareScalar() {
122+
return VectorUtil.squareDistance(bytesA, bytesB);
123+
}
124+
125+
@Benchmark
126+
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
127+
public int binarySquareVector() {
128+
return VectorUtil.squareDistance(bytesA, bytesB);
129+
}
130+
131+
@Benchmark
132+
public int binaryHalfByteSquareScalar() {
133+
int v = VectorUtil.int4SquareDistance(halfBytesA, halfBytesB);
134+
if (v != expectedHalfByteSquareDistance) {
135+
throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v);
136+
}
137+
return v;
138+
}
139+
140+
@Benchmark
141+
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
142+
public int binaryHalfByteSquareVector() {
143+
int v = VectorUtil.int4SquareDistance(halfBytesA, halfBytesB);
144+
if (v != expectedHalfByteSquareDistance) {
145+
throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v);
146+
}
147+
return v;
148+
}
149+
150+
@Benchmark
151+
public int binaryHalfByteSquareSinglePackedScalar() {
152+
int v = VectorUtil.int4SquareDistanceSinglePacked(halfBytesA, halfBytesBPacked);
153+
if (v != expectedHalfByteSquareDistance) {
154+
throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v);
155+
}
156+
return v;
157+
}
158+
159+
@Benchmark
160+
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
161+
public int binaryHalfByteSquareSinglePackedVector() {
162+
int v = VectorUtil.int4SquareDistanceSinglePacked(halfBytesA, halfBytesBPacked);
163+
if (v != expectedHalfByteSquareDistance) {
164+
throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v);
165+
}
166+
return v;
167+
}
168+
169+
@Benchmark
170+
public int binaryHalfByteSquareBothPackedScalar() {
171+
int v = VectorUtil.int4SquareDistanceBothPacked(halfBytesAPacked, halfBytesBPacked);
172+
if (v != expectedHalfByteSquareDistance) {
173+
throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v);
174+
}
175+
return v;
176+
}
177+
178+
@Benchmark
179+
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
180+
public int binaryHalfByteSquareBothPackedVector() {
181+
int v = VectorUtil.int4SquareDistanceBothPacked(halfBytesAPacked, halfBytesBPacked);
182+
if (v != expectedHalfByteSquareDistance) {
183+
throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v);
184+
}
185+
return v;
186+
}
187+
111188
@Benchmark
112189
public int binaryDotProductScalar() {
113190
return VectorUtil.dotProduct(bytesA, bytesB);
@@ -131,14 +208,22 @@ public int binaryDotProductUint8Vector() {
131208
}
132209

133210
@Benchmark
134-
public int binarySquareScalar() {
135-
return VectorUtil.squareDistance(bytesA, bytesB);
211+
public int binaryHalfByteDotProductScalar() {
212+
int v = VectorUtil.int4DotProduct(halfBytesA, halfBytesB);
213+
if (v != expectedHalfByteDotProduct) {
214+
throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v);
215+
}
216+
return v;
136217
}
137218

138219
@Benchmark
139220
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
140-
public int binarySquareVector() {
141-
return VectorUtil.squareDistance(bytesA, bytesB);
221+
public int binaryHalfByteDotProductVector() {
222+
int v = VectorUtil.int4DotProduct(halfBytesA, halfBytesB);
223+
if (v != expectedHalfByteDotProduct) {
224+
throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v);
225+
}
226+
return v;
142227
}
143228

144229
@Benchmark
@@ -153,37 +238,39 @@ public int binarySquareUint8Vector() {
153238
}
154239

155240
@Benchmark
156-
public int binaryHalfByteScalar() {
157-
return VectorUtil.int4DotProduct(halfBytesA, halfBytesB);
241+
public int binaryHalfByteDotProductSinglePackedScalar() {
242+
int v = VectorUtil.int4DotProductSinglePacked(halfBytesA, halfBytesBPacked);
243+
if (v != expectedHalfByteDotProduct) {
244+
throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v);
245+
}
246+
return v;
158247
}
159248

160249
@Benchmark
161250
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
162-
public int binaryHalfByteVector() {
163-
return VectorUtil.int4DotProduct(halfBytesA, halfBytesB);
251+
public int binaryHalfByteDotProductSinglePackedVector() {
252+
int v = VectorUtil.int4DotProductSinglePacked(halfBytesA, halfBytesBPacked);
253+
if (v != expectedHalfByteDotProduct) {
254+
throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v);
255+
}
256+
return v;
164257
}
165258

166259
@Benchmark
167-
public int binaryHalfByteScalarPacked() {
168-
if (size % 2 != 0) {
169-
throw new RuntimeException("Size must be even for this benchmark");
170-
}
171-
int v = VectorUtil.int4DotProductPacked(halfBytesA, halfBytesBPacked);
172-
if (v != expectedhalfByteDotProduct) {
173-
throw new RuntimeException("Expected " + expectedhalfByteDotProduct + " but got " + v);
260+
public int binaryHalfByteDotProductBothPackedScalar() {
261+
int v = VectorUtil.int4DotProductBothPacked(halfBytesAPacked, halfBytesBPacked);
262+
if (v != expectedHalfByteDotProduct) {
263+
throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v);
174264
}
175265
return v;
176266
}
177267

178268
@Benchmark
179269
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
180-
public int binaryHalfByteVectorPacked() {
181-
if (size % 2 != 0) {
182-
throw new RuntimeException("Size must be even for this benchmark");
183-
}
184-
int v = VectorUtil.int4DotProductPacked(halfBytesA, halfBytesBPacked);
185-
if (v != expectedhalfByteDotProduct) {
186-
throw new RuntimeException("Expected " + expectedhalfByteDotProduct + " but got " + v);
270+
public int binaryHalfByteDotProductBothPackedVector() {
271+
int v = VectorUtil.int4DotProductBothPacked(halfBytesAPacked, halfBytesBPacked);
272+
if (v != expectedHalfByteDotProduct) {
273+
throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v);
187274
}
188275
return v;
189276
}

lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorScorerUtil.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,8 @@ private FlatVectorScorerUtil() {}
3737
public static FlatVectorsScorer getLucene99FlatVectorsScorer() {
3838
return IMPL.getLucene99FlatVectorsScorer();
3939
}
40+
41+
public static FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer() {
42+
return IMPL.getLucene99ScalarQuantizedVectorsScorer();
43+
}
4044
}

lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.java

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
2424
import org.apache.lucene.index.KnnVectorValues;
2525
import org.apache.lucene.index.VectorSimilarityFunction;
26+
import org.apache.lucene.util.FloatToFloatFunction;
2627
import org.apache.lucene.util.VectorUtil;
2728
import org.apache.lucene.util.hnsw.RandomVectorScorer;
2829
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
@@ -245,7 +246,7 @@ public float score(int vectorOrdinal) throws IOException {
245246
values.getSlice().seek((long) vectorOrdinal * (values.getVectorByteLength() + Float.BYTES));
246247
values.getSlice().readBytes(compressedVector, 0, compressedVector.length);
247248
float vectorOffset = values.getScoreCorrectionConstant(vectorOrdinal);
248-
int dotProduct = VectorUtil.int4DotProductPacked(targetBytes, compressedVector);
249+
int dotProduct = VectorUtil.int4DotProductSinglePacked(targetBytes, compressedVector);
249250
// For the current implementation of scalar quantization, all dotproducts should
250251
// be >= 0;
251252
assert dotProduct >= 0;
@@ -301,11 +302,6 @@ public void setScoringOrdinal(int node) throws IOException {
301302
}
302303
}
303304

304-
@FunctionalInterface
305-
private interface FloatToFloatFunction {
306-
float apply(float f);
307-
}
308-
309305
private static final class ScalarQuantizedRandomVectorScorerSupplier
310306
implements RandomVectorScorerSupplier {
311307

lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
package org.apache.lucene.codecs.lucene99;
1919

2020
import java.io.IOException;
21-
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
2221
import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
2322
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
2423
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
24+
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
2525
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
2626
import org.apache.lucene.index.SegmentReadState;
2727
import org.apache.lucene.index.SegmentWriteState;
@@ -68,7 +68,7 @@ public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
6868

6969
final byte bits;
7070
final boolean compress;
71-
final Lucene99ScalarQuantizedVectorScorer flatVectorScorer;
71+
final FlatVectorsScorer flatVectorScorer;
7272

7373
/** Constructs a format using default graph construction parameters */
7474
public Lucene99ScalarQuantizedVectorsFormat() {
@@ -115,8 +115,7 @@ public Lucene99ScalarQuantizedVectorsFormat(
115115
this.bits = (byte) bits;
116116
this.confidenceInterval = confidenceInterval;
117117
this.compress = compress;
118-
this.flatVectorScorer =
119-
new Lucene99ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE);
118+
this.flatVectorScorer = FlatVectorScorerUtil.getLucene99ScalarQuantizedVectorsScorer();
120119
}
121120

122121
public static float calculateDefaultConfidenceInterval(int vectorDimension) {

lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java

Lines changed: 62 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -164,24 +164,35 @@ public int uint8DotProduct(byte[] a, byte[] b) {
164164
}
165165

166166
@Override
167-
public int int4DotProduct(byte[] a, boolean apacked, byte[] b, boolean bpacked) {
168-
assert (apacked && bpacked) == false;
169-
if (apacked || bpacked) {
170-
byte[] packed = apacked ? a : b;
171-
byte[] unpacked = apacked ? b : a;
172-
int total = 0;
173-
for (int i = 0; i < packed.length; i++) {
174-
byte packedByte = packed[i];
175-
byte unpacked1 = unpacked[i];
176-
byte unpacked2 = unpacked[i + packed.length];
177-
total += (packedByte & 0x0F) * unpacked2;
178-
total += ((packedByte & 0xFF) >> 4) * unpacked1;
179-
}
180-
return total;
181-
}
167+
public int int4DotProduct(byte[] a, byte[] b) {
182168
return dotProduct(a, b);
183169
}
184170

171+
@Override
172+
public int int4DotProductSinglePacked(byte[] unpacked, byte[] packed) {
173+
int total = 0;
174+
for (int i = 0; i < packed.length; i++) {
175+
byte packedByte = packed[i];
176+
byte unpacked1 = unpacked[i];
177+
byte unpacked2 = unpacked[i + packed.length];
178+
total += (packedByte & 0x0F) * unpacked2;
179+
total += ((packedByte & 0xFF) >> 4) * unpacked1;
180+
}
181+
return total;
182+
}
183+
184+
@Override
185+
public int int4DotProductBothPacked(byte[] a, byte[] b) {
186+
int total = 0;
187+
for (int i = 0; i < a.length; i++) {
188+
byte aByte = a[i];
189+
byte bByte = b[i];
190+
total += (aByte & 0x0F) * (bByte & 0x0F);
191+
total += ((aByte & 0xFF) >> 4) * ((bByte & 0xFF) >> 4);
192+
}
193+
return total;
194+
}
195+
185196
@Override
186197
public float cosine(byte[] a, byte[] b) {
187198
// Note: this will not overflow if dim < 2^18, since max(byte * byte) = 2^14.
@@ -210,6 +221,42 @@ public int squareDistance(byte[] a, byte[] b) {
210221
return squareSum;
211222
}
212223

224+
@Override
225+
public int int4SquareDistance(byte[] a, byte[] b) {
226+
return squareDistance(a, b);
227+
}
228+
229+
@Override
230+
public int int4SquareDistanceSinglePacked(byte[] unpacked, byte[] packed) {
231+
int total = 0;
232+
for (int i = 0; i < packed.length; i++) {
233+
byte packedByte = packed[i];
234+
byte unpacked1 = unpacked[i];
235+
byte unpacked2 = unpacked[i + packed.length];
236+
237+
int diff1 = (packedByte & 0x0F) - unpacked2;
238+
int diff2 = ((packedByte & 0xFF) >> 4) - unpacked1;
239+
240+
total += diff1 * diff1 + diff2 * diff2;
241+
}
242+
return total;
243+
}
244+
245+
@Override
246+
public int int4SquareDistanceBothPacked(byte[] a, byte[] b) {
247+
int total = 0;
248+
for (int i = 0; i < a.length; i++) {
249+
byte aByte = a[i];
250+
byte bByte = b[i];
251+
252+
int diff1 = (aByte & 0x0F) - (bByte & 0x0F);
253+
int diff2 = ((aByte & 0xFF) >> 4) - ((bByte & 0xFF) >> 4);
254+
255+
total += diff1 * diff1 + diff2 * diff2;
256+
}
257+
return total;
258+
}
259+
213260
@Override
214261
public int uint8SquareDistance(byte[] a, byte[] b) {
215262
// Note: this will not overflow if dim < 2^16, since max(ubyte * ubyte) = 2^16.

lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
2121
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
22+
import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorScorer;
2223
import org.apache.lucene.store.IndexInput;
2324

2425
/** Default provider returning scalar implementations. */
@@ -40,6 +41,11 @@ public FlatVectorsScorer getLucene99FlatVectorsScorer() {
4041
return DefaultFlatVectorScorer.INSTANCE;
4142
}
4243

44+
@Override
45+
public FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer() {
46+
return new Lucene99ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE);
47+
}
48+
4349
@Override
4450
public PostingDecodingUtil newPostingDecodingUtil(IndexInput input) {
4551
return new PostingDecodingUtil(input);

0 commit comments

Comments
 (0)