Skip to content

Commit 0167f13

Browse files
author
Kaival Parikh
committed
Implement off-heap quantized scoring
1 parent b3f4011 commit 0167f13

File tree

7 files changed

+511
-38
lines changed

7 files changed

+511
-38
lines changed

lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorScorerUtil.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,8 @@ private FlatVectorScorerUtil() {}
3737
public static FlatVectorsScorer getLucene99FlatVectorsScorer() {
3838
return IMPL.getLucene99FlatVectorsScorer();
3939
}
40+
41+
public static FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer() {
42+
return IMPL.getLucene99ScalarQuantizedVectorsScorer();
43+
}
4044
}

lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
package org.apache.lucene.codecs.lucene99;
1919

2020
import java.io.IOException;
21-
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
2221
import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
2322
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
2423
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
24+
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
2525
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
2626
import org.apache.lucene.index.SegmentReadState;
2727
import org.apache.lucene.index.SegmentWriteState;
@@ -70,7 +70,7 @@ public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
7070

7171
final byte bits;
7272
final boolean compress;
73-
final Lucene99ScalarQuantizedVectorScorer flatVectorScorer;
73+
final FlatVectorsScorer flatVectorScorer;
7474

7575
/** Constructs a format using default graph construction parameters */
7676
public Lucene99ScalarQuantizedVectorsFormat() {
@@ -117,8 +117,7 @@ public Lucene99ScalarQuantizedVectorsFormat(
117117
this.bits = (byte) bits;
118118
this.confidenceInterval = confidenceInterval;
119119
this.compress = compress;
120-
this.flatVectorScorer =
121-
new Lucene99ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE);
120+
this.flatVectorScorer = FlatVectorScorerUtil.getLucene99ScalarQuantizedVectorsScorer();
122121
}
123122

124123
public static float calculateDefaultConfidenceInterval(int vectorDimension) {

lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
2121
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
22+
import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorScorer;
2223
import org.apache.lucene.store.IndexInput;
2324

2425
/** Default provider returning scalar implementations. */
@@ -40,6 +41,11 @@ public FlatVectorsScorer getLucene99FlatVectorsScorer() {
4041
return DefaultFlatVectorScorer.INSTANCE;
4142
}
4243

44+
@Override
45+
public FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer() {
46+
return new Lucene99ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE);
47+
}
48+
4349
@Override
4450
public PostingDecodingUtil newPostingDecodingUtil(IndexInput input) {
4551
return new PostingDecodingUtil(input);

lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,9 @@ public static VectorizationProvider getInstance() {
111111
/** Returns a FlatVectorsScorer that supports the Lucene99 format. */
112112
public abstract FlatVectorsScorer getLucene99FlatVectorsScorer();
113113

114+
/** Returns a FlatVectorsScorer that supports the Lucene99 format. */
115+
public abstract FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer();
116+
114117
/** Create a new {@link PostingDecodingUtil} for the given {@link IndexInput}. */
115118
public abstract PostingDecodingUtil newPostingDecodingUtil(IndexInput input) throws IOException;
116119

Lines changed: 319 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,319 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.lucene.internal.vectorization;
18+
19+
import static java.lang.foreign.ValueLayout.JAVA_FLOAT;
20+
import static org.apache.lucene.codecs.hnsw.ScalarQuantizedVectorScorer.quantizeQuery;
21+
22+
import java.io.IOException;
23+
import java.lang.foreign.MemorySegment;
24+
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
25+
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
26+
import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorScorer;
27+
import org.apache.lucene.index.KnnVectorValues;
28+
import org.apache.lucene.index.VectorSimilarityFunction;
29+
import org.apache.lucene.store.MemorySegmentAccessInput;
30+
import org.apache.lucene.util.VectorUtil;
31+
import org.apache.lucene.util.hnsw.RandomVectorScorer;
32+
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
33+
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
34+
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
35+
import org.apache.lucene.util.quantization.ScalarQuantizer;
36+
37+
class Lucene99MemorySegmentScalarQuantizedVectorScorer implements FlatVectorsScorer {
38+
static final Lucene99MemorySegmentScalarQuantizedVectorScorer INSTANCE =
39+
new Lucene99MemorySegmentScalarQuantizedVectorScorer();
40+
41+
private static final FlatVectorsScorer DELEGATE =
42+
new Lucene99ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE);
43+
44+
private Lucene99MemorySegmentScalarQuantizedVectorScorer() {}
45+
46+
@Override
47+
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
48+
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues)
49+
throws IOException {
50+
if (vectorValues instanceof QuantizedByteVectorValues quantized
51+
&& quantized.getSlice() instanceof MemorySegmentAccessInput input) {
52+
return new RandomVectorScorerSupplierImpl(similarityFunction, quantized, input);
53+
}
54+
return DELEGATE.getRandomVectorScorerSupplier(similarityFunction, vectorValues);
55+
}
56+
57+
@Override
58+
public RandomVectorScorer getRandomVectorScorer(
59+
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target)
60+
throws IOException {
61+
if (vectorValues instanceof QuantizedByteVectorValues quantized
62+
&& quantized.getSlice() instanceof MemorySegmentAccessInput input) {
63+
return new RandomVectorScorerImpl(similarityFunction, quantized, input, target);
64+
}
65+
return DELEGATE.getRandomVectorScorer(similarityFunction, vectorValues, target);
66+
}
67+
68+
@Override
69+
public RandomVectorScorer getRandomVectorScorer(
70+
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target)
71+
throws IOException {
72+
return DELEGATE.getRandomVectorScorer(similarityFunction, vectorValues, target);
73+
}
74+
75+
private abstract static class RandomVectorScorerBase
76+
extends RandomVectorScorer.AbstractRandomVectorScorer {
77+
78+
private final ScalarQuantizer quantizer;
79+
private final float constMultiplier;
80+
private final MemorySegmentAccessInput input;
81+
private final int vectorByteSize;
82+
private final int nodeSize;
83+
private final Scorer scorer;
84+
private final Scaler scaler;
85+
private byte[] scratch;
86+
87+
RandomVectorScorerBase(
88+
VectorSimilarityFunction similarityFunction,
89+
QuantizedByteVectorValues values,
90+
MemorySegmentAccessInput input) {
91+
super(values);
92+
93+
this.quantizer = values.getScalarQuantizer();
94+
this.constMultiplier = this.quantizer.getConstantMultiplier();
95+
this.input = input;
96+
this.vectorByteSize = values.getVectorByteLength();
97+
this.nodeSize = this.vectorByteSize + Float.BYTES;
98+
99+
this.scorer =
100+
switch (similarityFunction) {
101+
case EUCLIDEAN -> {
102+
if (this.quantizer.getBits() <= 4) {
103+
if (this.vectorByteSize != values.dimension()) {
104+
yield this::compressedInt4Euclidean;
105+
}
106+
yield this::int4Euclidean;
107+
}
108+
yield this::euclidean;
109+
}
110+
case DOT_PRODUCT, COSINE, MAXIMUM_INNER_PRODUCT -> {
111+
if (this.quantizer.getBits() <= 4) {
112+
if (this.vectorByteSize != values.dimension()) {
113+
yield this::compressedInt4DotProduct;
114+
}
115+
yield this::int4DotProduct;
116+
}
117+
yield this::dotProduct;
118+
}
119+
};
120+
121+
this.scaler =
122+
switch (similarityFunction) {
123+
case EUCLIDEAN -> distance -> 1 / (1 + distance);
124+
case DOT_PRODUCT, COSINE -> score -> Math.max((1 + score) / 2, 0);
125+
case MAXIMUM_INNER_PRODUCT -> VectorUtil::scaleMaxInnerProductScore;
126+
};
127+
128+
checkInvariants();
129+
}
130+
131+
final void checkInvariants() {
132+
if (input.length() < (long) nodeSize * maxOrd()) {
133+
throw new IllegalArgumentException("input length is less than expected vector data");
134+
}
135+
}
136+
137+
final void checkOrdinal(int ord) {
138+
if (ord < 0 || ord >= maxOrd()) {
139+
throw new IllegalArgumentException("illegal ordinal: " + ord);
140+
}
141+
}
142+
143+
ScalarQuantizer getQuantizer() {
144+
return quantizer;
145+
}
146+
147+
@SuppressWarnings("restricted")
148+
Node getNode(int ord) throws IOException {
149+
checkOrdinal(ord);
150+
long byteOffset = (long) ord * nodeSize;
151+
MemorySegment node = input.segmentSliceOrNull(byteOffset, nodeSize);
152+
if (node == null) {
153+
if (scratch == null) {
154+
scratch = new byte[nodeSize];
155+
}
156+
input.readBytes(byteOffset, scratch, 0, nodeSize);
157+
node = MemorySegment.ofArray(scratch);
158+
}
159+
return new Node(node.reinterpret(vectorByteSize), node.get(JAVA_FLOAT, vectorByteSize));
160+
}
161+
162+
float scoreBody(int ord, float queryOffset) throws IOException {
163+
checkOrdinal(ord);
164+
Node node = getNode(ord);
165+
return scaler.scale(scorer.score(node.vector) * constMultiplier + node.offset + queryOffset);
166+
}
167+
168+
abstract int euclidean(MemorySegment doc);
169+
170+
abstract int int4Euclidean(MemorySegment doc);
171+
172+
abstract int compressedInt4Euclidean(MemorySegment doc);
173+
174+
abstract int dotProduct(MemorySegment doc);
175+
176+
abstract int int4DotProduct(MemorySegment doc);
177+
178+
abstract int compressedInt4DotProduct(MemorySegment doc);
179+
180+
record Node(MemorySegment vector, float offset) {}
181+
182+
@FunctionalInterface
183+
private interface Scorer {
184+
int score(MemorySegment doc) throws IOException;
185+
}
186+
187+
@FunctionalInterface
188+
private interface Scaler {
189+
float scale(float score);
190+
}
191+
}
192+
193+
private static class RandomVectorScorerImpl extends RandomVectorScorerBase {
194+
private final byte[] targetBytes;
195+
private final float queryOffset;
196+
197+
RandomVectorScorerImpl(
198+
VectorSimilarityFunction similarityFunction,
199+
QuantizedByteVectorValues values,
200+
MemorySegmentAccessInput input,
201+
float[] target) {
202+
super(similarityFunction, values, input);
203+
this.targetBytes = new byte[target.length];
204+
this.queryOffset = quantizeQuery(target, targetBytes, similarityFunction, getQuantizer());
205+
}
206+
207+
@Override
208+
public float score(int node) throws IOException {
209+
return scoreBody(node, queryOffset);
210+
}
211+
212+
@Override
213+
int euclidean(MemorySegment doc) {
214+
return PanamaVectorUtilSupport.squareDistance(targetBytes, doc);
215+
}
216+
217+
@Override
218+
int int4Euclidean(MemorySegment doc) {
219+
// TODO
220+
throw new UnsupportedOperationException();
221+
}
222+
223+
@Override
224+
int compressedInt4Euclidean(MemorySegment doc) {
225+
// TODO
226+
throw new UnsupportedOperationException();
227+
}
228+
229+
@Override
230+
int dotProduct(MemorySegment doc) {
231+
return PanamaVectorUtilSupport.dotProduct(targetBytes, doc);
232+
}
233+
234+
@Override
235+
int int4DotProduct(MemorySegment doc) {
236+
return PanamaVectorUtilSupport.int4DotProduct(targetBytes, false, doc, false);
237+
}
238+
239+
@Override
240+
int compressedInt4DotProduct(MemorySegment doc) {
241+
return PanamaVectorUtilSupport.int4DotProduct(targetBytes, false, doc, true);
242+
}
243+
}
244+
245+
private record RandomVectorScorerSupplierImpl(
246+
VectorSimilarityFunction similarityFunction,
247+
QuantizedByteVectorValues values,
248+
MemorySegmentAccessInput input)
249+
implements RandomVectorScorerSupplier {
250+
251+
@Override
252+
public UpdateableRandomVectorScorer scorer() {
253+
return new UpdateableRandomVectorScorerImpl(similarityFunction, values, input);
254+
}
255+
256+
@Override
257+
public RandomVectorScorerSupplier copy() {
258+
return new RandomVectorScorerSupplierImpl(similarityFunction, values, input);
259+
}
260+
}
261+
262+
private static class UpdateableRandomVectorScorerImpl extends RandomVectorScorerBase
263+
implements UpdateableRandomVectorScorer {
264+
private MemorySegment query;
265+
private float queryOffset;
266+
267+
UpdateableRandomVectorScorerImpl(
268+
VectorSimilarityFunction similarityFunction,
269+
QuantizedByteVectorValues values,
270+
MemorySegmentAccessInput input) {
271+
super(similarityFunction, values, input);
272+
}
273+
274+
@Override
275+
public void setScoringOrdinal(int ord) throws IOException {
276+
checkOrdinal(ord);
277+
Node node = getNode(ord);
278+
query = node.vector;
279+
queryOffset = node.offset;
280+
}
281+
282+
@Override
283+
public float score(int node) throws IOException {
284+
return scoreBody(node, queryOffset);
285+
}
286+
287+
@Override
288+
int euclidean(MemorySegment doc) {
289+
return PanamaVectorUtilSupport.squareDistance(query, doc);
290+
}
291+
292+
@Override
293+
int int4Euclidean(MemorySegment doc) {
294+
// TODO
295+
throw new UnsupportedOperationException();
296+
}
297+
298+
@Override
299+
int compressedInt4Euclidean(MemorySegment doc) {
300+
// TODO
301+
throw new UnsupportedOperationException();
302+
}
303+
304+
@Override
305+
int dotProduct(MemorySegment doc) {
306+
return PanamaVectorUtilSupport.dotProduct(query, doc);
307+
}
308+
309+
@Override
310+
int int4DotProduct(MemorySegment doc) {
311+
return PanamaVectorUtilSupport.int4DotProduct(query, false, doc, false);
312+
}
313+
314+
@Override
315+
int compressedInt4DotProduct(MemorySegment doc) {
316+
return PanamaVectorUtilSupport.int4DotProduct(query, true, doc, true);
317+
}
318+
}
319+
}

0 commit comments

Comments
 (0)