lucene/CHANGES.txt: 2 additions & 0 deletions
@@ -150,6 +150,8 @@ Optimizations
* GITHUB#15160: Increased the size used for blocks of postings from 128 to 256.
This gives a noticeable speedup to many queries. (Adrien Grand)

* GITHUB#14863: Perform scoring for 4, 7, 8 bit quantized vectors off-heap. (Kaival Parikh)

Bug Fixes
---------------------
* GITHUB#14161: PointInSetQuery's constructor now throws IllegalArgumentException
@@ -54,11 +54,13 @@ static void compressBytes(byte[] raw, byte[] compressed) {
private byte[] bytesA;
private byte[] bytesB;
private byte[] halfBytesA;
private byte[] halfBytesAPacked;
private byte[] halfBytesB;
private byte[] halfBytesBPacked;
private float[] floatsA;
private float[] floatsB;
private int expectedhalfByteDotProduct;
private int expectedHalfByteDotProduct;
private int expectedHalfByteSquareDistance;

@Param({"1", "128", "207", "256", "300", "512", "702", "1024"})
int size;
@@ -74,16 +76,23 @@ public void init() {
random.nextBytes(bytesB);
// random half byte arrays for binary methods
// this means that all values must be between 0 and 15
expectedhalfByteDotProduct = 0;
expectedHalfByteDotProduct = 0;
expectedHalfByteSquareDistance = 0;
halfBytesA = new byte[size];
halfBytesB = new byte[size];
for (int i = 0; i < size; ++i) {
halfBytesA[i] = (byte) random.nextInt(16);
halfBytesB[i] = (byte) random.nextInt(16);
expectedhalfByteDotProduct += halfBytesA[i] * halfBytesB[i];
expectedHalfByteDotProduct += halfBytesA[i] * halfBytesB[i];

int diff = halfBytesA[i] - halfBytesB[i];
expectedHalfByteSquareDistance += diff * diff;
}
// pack the half byte arrays
if (size % 2 == 0) {
halfBytesAPacked = new byte[(size + 1) >> 1];
compressBytes(halfBytesA, halfBytesAPacked);

halfBytesBPacked = new byte[(size + 1) >> 1];
compressBytes(halfBytesB, halfBytesBPacked);
}
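
For reference, a minimal sketch of the packing that compressBytes is assumed to perform here: the first half of the raw array goes into the high nibbles and the second half into the low nibbles, which is the layout the unpacking code in DefaultVectorUtilSupport further down expects. The helper name packNibbles is illustrative only.

static void packNibbles(byte[] raw, byte[] packed) {
  // packed.length == raw.length / 2; raw values are assumed to be in 0..15
  for (int i = 0; i < packed.length; ++i) {
    packed[i] = (byte) ((raw[i] << 4) | raw[i + packed.length]);
  }
}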
@@ -108,6 +117,74 @@ public float binaryCosineVector() {
return VectorUtil.cosine(bytesA, bytesB);
}

@Benchmark
public int binarySquareScalar() {
return VectorUtil.squareDistance(bytesA, bytesB);
}

@Benchmark
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
public int binarySquareVector() {
return VectorUtil.squareDistance(bytesA, bytesB);
}

@Benchmark
public int binaryHalfByteSquareScalar() {
int v = VectorUtil.int4SquareDistance(halfBytesA, halfBytesB);
if (v != expectedHalfByteSquareDistance) {
throw new RuntimeException("Expected " + expectedHalfByteSquareDistance + " but got " + v);
}
return v;
}

@Benchmark
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
public int binaryHalfByteSquareVector() {
int v = VectorUtil.int4SquareDistance(halfBytesA, halfBytesB);
if (v != expectedHalfByteSquareDistance) {
throw new RuntimeException("Expected " + expectedHalfByteSquareDistance + " but got " + v);
}
return v;
}

@Benchmark
public int binaryHalfByteSquareSinglePackedScalar() {
int v = VectorUtil.int4SquareDistanceSinglePacked(halfBytesA, halfBytesBPacked);
if (v != expectedHalfByteSquareDistance) {
throw new RuntimeException("Expected " + expectedHalfByteSquareDistance + " but got " + v);
}
return v;
}

@Benchmark
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
public int binaryHalfByteSquareSinglePackedVector() {
int v = VectorUtil.int4SquareDistanceSinglePacked(halfBytesA, halfBytesBPacked);
if (v != expectedHalfByteSquareDistance) {
throw new RuntimeException("Expected " + expectedHalfByteSquareDistance + " but got " + v);
}
return v;
}

@Benchmark
public int binaryHalfByteSquareBothPackedScalar() {
int v = VectorUtil.int4SquareDistanceBothPacked(halfBytesAPacked, halfBytesBPacked);
if (v != expectedHalfByteSquareDistance) {
throw new RuntimeException("Expected " + expectedHalfByteSquareDistance + " but got " + v);
}
return v;
}

@Benchmark
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
public int binaryHalfByteSquareBothPackedVector() {
int v = VectorUtil.int4SquareDistanceBothPacked(halfBytesAPacked, halfBytesBPacked);
if (v != expectedHalfByteSquareDistance) {
throw new RuntimeException("Expected " + expectedHalfByteSquareDistance + " but got " + v);
}
return v;
}

@Benchmark
public int binaryDotProductScalar() {
return VectorUtil.dotProduct(bytesA, bytesB);
@@ -131,14 +208,22 @@ public int binaryDotProductUint8Vector() {
}

@Benchmark
public int binarySquareScalar() {
return VectorUtil.squareDistance(bytesA, bytesB);
public int binaryHalfByteDotProductScalar() {
int v = VectorUtil.int4DotProduct(halfBytesA, halfBytesB);
if (v != expectedHalfByteDotProduct) {
throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v);
}
return v;
}

@Benchmark
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
public int binarySquareVector() {
return VectorUtil.squareDistance(bytesA, bytesB);
public int binaryHalfByteDotProductVector() {
int v = VectorUtil.int4DotProduct(halfBytesA, halfBytesB);
if (v != expectedHalfByteDotProduct) {
throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v);
}
return v;
}

@Benchmark
@@ -153,37 +238,39 @@ public int binarySquareUint8Vector() {
}

@Benchmark
public int binaryHalfByteScalar() {
return VectorUtil.int4DotProduct(halfBytesA, halfBytesB);
public int binaryHalfByteDotProductSinglePackedScalar() {
int v = VectorUtil.int4DotProductSinglePacked(halfBytesA, halfBytesBPacked);
if (v != expectedHalfByteDotProduct) {
throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v);
}
return v;
}

@Benchmark
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
public int binaryHalfByteVector() {
return VectorUtil.int4DotProduct(halfBytesA, halfBytesB);
public int binaryHalfByteDotProductSinglePackedVector() {
int v = VectorUtil.int4DotProductSinglePacked(halfBytesA, halfBytesBPacked);
if (v != expectedHalfByteDotProduct) {
throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v);
}
return v;
}

@Benchmark
public int binaryHalfByteScalarPacked() {
if (size % 2 != 0) {
throw new RuntimeException("Size must be even for this benchmark");
}
int v = VectorUtil.int4DotProductPacked(halfBytesA, halfBytesBPacked);
if (v != expectedhalfByteDotProduct) {
throw new RuntimeException("Expected " + expectedhalfByteDotProduct + " but got " + v);
public int binaryHalfByteDotProductBothPackedScalar() {
int v = VectorUtil.int4DotProductBothPacked(halfBytesAPacked, halfBytesBPacked);
if (v != expectedHalfByteDotProduct) {
throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v);
}
return v;
}

@Benchmark
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
public int binaryHalfByteVectorPacked() {
if (size % 2 != 0) {
throw new RuntimeException("Size must be even for this benchmark");
}
int v = VectorUtil.int4DotProductPacked(halfBytesA, halfBytesBPacked);
if (v != expectedhalfByteDotProduct) {
throw new RuntimeException("Expected " + expectedhalfByteDotProduct + " but got " + v);
public int binaryHalfByteDotProductBothPackedVector() {
int v = VectorUtil.int4DotProductBothPacked(halfBytesAPacked, halfBytesBPacked);
if (v != expectedHalfByteDotProduct) {
throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v);
}
return v;
}
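As a rough illustration of what the new packed benchmarks exercise, the sketch below is a hypothetical demo class: only the VectorUtil methods are taken from this change, while the pack helper assumes the nibble layout described above. It checks that the packed variants agree with the plain int4 methods on the same data.

import java.util.Random;
import org.apache.lucene.util.VectorUtil;

public class Int4PackedEquivalenceDemo {
  public static void main(String[] args) {
    int size = 256; // must be even so two nibbles fit per packed byte
    Random random = new Random(42);
    byte[] a = new byte[size];
    byte[] b = new byte[size];
    for (int i = 0; i < size; ++i) {
      a[i] = (byte) random.nextInt(16); // int4 values: 0..15
      b[i] = (byte) random.nextInt(16);
    }
    byte[] aPacked = pack(a);
    byte[] bPacked = pack(b);
    // The packed variants are expected to match the unpacked int4 methods.
    System.out.println(VectorUtil.int4DotProduct(a, b)
        == VectorUtil.int4DotProductSinglePacked(a, bPacked));
    System.out.println(VectorUtil.int4DotProduct(a, b)
        == VectorUtil.int4DotProductBothPacked(aPacked, bPacked));
    System.out.println(VectorUtil.int4SquareDistance(a, b)
        == VectorUtil.int4SquareDistanceBothPacked(aPacked, bPacked));
  }

  // Assumed packing layout: high nibble = first half, low nibble = second half.
  private static byte[] pack(byte[] raw) {
    byte[] packed = new byte[raw.length >> 1];
    for (int i = 0; i < packed.length; ++i) {
      packed[i] = (byte) ((raw[i] << 4) | raw[i + packed.length]);
    }
    return packed;
  }
}

All three lines should print true when the assumed packing matches what compressBytes produces.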
@@ -37,4 +37,8 @@ private FlatVectorScorerUtil() {}
public static FlatVectorsScorer getLucene99FlatVectorsScorer() {
return IMPL.getLucene99FlatVectorsScorer();
}

public static FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer() {
return IMPL.getLucene99ScalarQuantizedVectorsScorer();
}
}
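
A short usage sketch (hypothetical calling code, class name illustrative): formats obtain the scalar-quantized scorer through this util so that the active vectorization provider can supply an off-heap implementation, which is how Lucene99ScalarQuantizedVectorsFormat is rewired later in this diff.

import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;

final class ScorerLookupSketch {
  // The concrete scorer depends on the vectorization provider in use; the
  // default provider returns Lucene99ScalarQuantizedVectorScorer wrapping
  // DefaultFlatVectorScorer (see the last file in this diff).
  static FlatVectorsScorer scalarQuantizedScorer() {
    return FlatVectorScorerUtil.getLucene99ScalarQuantizedVectorsScorer();
  }
}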
@@ -23,6 +23,7 @@
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.FloatToFloatFunction;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
@@ -245,7 +246,7 @@ public float score(int vectorOrdinal) throws IOException {
values.getSlice().seek((long) vectorOrdinal * (values.getVectorByteLength() + Float.BYTES));
values.getSlice().readBytes(compressedVector, 0, compressedVector.length);
float vectorOffset = values.getScoreCorrectionConstant(vectorOrdinal);
int dotProduct = VectorUtil.int4DotProductPacked(targetBytes, compressedVector);
int dotProduct = VectorUtil.int4DotProductSinglePacked(targetBytes, compressedVector);
// For the current implementation of scalar quantization, all dotproducts should
// be >= 0;
assert dotProduct >= 0;
@@ -301,11 +302,6 @@ public void setScoringOrdinal(int node) throws IOException {
}
}

@FunctionalInterface
private interface FloatToFloatFunction {
float apply(float f);
}

private static final class ScalarQuantizedRandomVectorScorerSupplier
implements RandomVectorScorerSupplier {

@@ -18,10 +18,10 @@
package org.apache.lucene.codecs.lucene99;

import java.io.IOException;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
@@ -68,7 +68,7 @@ public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {

final byte bits;
final boolean compress;
final Lucene99ScalarQuantizedVectorScorer flatVectorScorer;
final FlatVectorsScorer flatVectorScorer;

/** Constructs a format using default graph construction parameters */
public Lucene99ScalarQuantizedVectorsFormat() {
@@ -115,8 +115,7 @@ public Lucene99ScalarQuantizedVectorsFormat(
this.bits = (byte) bits;
this.confidenceInterval = confidenceInterval;
this.compress = compress;
this.flatVectorScorer =
new Lucene99ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE);
this.flatVectorScorer = FlatVectorScorerUtil.getLucene99ScalarQuantizedVectorsScorer();
}

public static float calculateDefaultConfidenceInterval(int vectorDimension) {
@@ -164,24 +164,35 @@ public int uint8DotProduct(byte[] a, byte[] b) {
}

@Override
public int int4DotProduct(byte[] a, boolean apacked, byte[] b, boolean bpacked) {
assert (apacked && bpacked) == false;
if (apacked || bpacked) {
byte[] packed = apacked ? a : b;
byte[] unpacked = apacked ? b : a;
int total = 0;
for (int i = 0; i < packed.length; i++) {
byte packedByte = packed[i];
byte unpacked1 = unpacked[i];
byte unpacked2 = unpacked[i + packed.length];
total += (packedByte & 0x0F) * unpacked2;
total += ((packedByte & 0xFF) >> 4) * unpacked1;
}
return total;
}
public int int4DotProduct(byte[] a, byte[] b) {
return dotProduct(a, b);
}

@Override
public int int4DotProductSinglePacked(byte[] unpacked, byte[] packed) {
int total = 0;
for (int i = 0; i < packed.length; i++) {
byte packedByte = packed[i];
byte unpacked1 = unpacked[i];
byte unpacked2 = unpacked[i + packed.length];
total += (packedByte & 0x0F) * unpacked2;
total += ((packedByte & 0xFF) >> 4) * unpacked1;
}
return total;
}

@Override
public int int4DotProductBothPacked(byte[] a, byte[] b) {
int total = 0;
for (int i = 0; i < a.length; i++) {
byte aByte = a[i];
byte bByte = b[i];
total += (aByte & 0x0F) * (bByte & 0x0F);
total += ((aByte & 0xFF) >> 4) * ((bByte & 0xFF) >> 4);
}
return total;
}

@Override
public float cosine(byte[] a, byte[] b) {
// Note: this will not overflow if dim < 2^18, since max(byte * byte) = 2^14.
@@ -210,6 +221,42 @@ public int squareDistance(byte[] a, byte[] b) {
return squareSum;
}

@Override
public int int4SquareDistance(byte[] a, byte[] b) {
return squareDistance(a, b);
}

@Override
public int int4SquareDistanceSinglePacked(byte[] unpacked, byte[] packed) {
int total = 0;
for (int i = 0; i < packed.length; i++) {
byte packedByte = packed[i];
byte unpacked1 = unpacked[i];
byte unpacked2 = unpacked[i + packed.length];

int diff1 = (packedByte & 0x0F) - unpacked2;
int diff2 = ((packedByte & 0xFF) >> 4) - unpacked1;

total += diff1 * diff1 + diff2 * diff2;
}
return total;
}

@Override
public int int4SquareDistanceBothPacked(byte[] a, byte[] b) {
int total = 0;
for (int i = 0; i < a.length; i++) {
byte aByte = a[i];
byte bByte = b[i];

int diff1 = (aByte & 0x0F) - (bByte & 0x0F);
int diff2 = ((aByte & 0xFF) >> 4) - ((bByte & 0xFF) >> 4);

total += diff1 * diff1 + diff2 * diff2;
}
return total;
}

@Override
public int uint8SquareDistance(byte[] a, byte[] b) {
// Note: this will not overflow if dim < 2^16, since max(ubyte * ubyte) = 2^16.
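To make the nibble arithmetic in the int4 methods above concrete, a small worked example (illustrative values only): with packed bytes 0x3A and 0x21, the decoded pairs are (3, 10) and (2, 1).

final class NibbleDecodeExample {
  public static void main(String[] args) {
    byte packedA = (byte) 0x3A; // high nibble 3, low nibble 10
    byte packedB = (byte) 0x21; // high nibble 2, low nibble 1
    // Same masks and shifts as int4DotProductBothPacked / int4SquareDistanceBothPacked:
    int dot = ((packedA & 0xFF) >> 4) * ((packedB & 0xFF) >> 4) // 3 * 2
            + (packedA & 0x0F) * (packedB & 0x0F);              // + 10 * 1
    int dHigh = ((packedA & 0xFF) >> 4) - ((packedB & 0xFF) >> 4); // 3 - 2
    int dLow = (packedA & 0x0F) - (packedB & 0x0F);                // 10 - 1
    int square = dHigh * dHigh + dLow * dLow;
    System.out.println(dot);    // 16
    System.out.println(square); // 82
  }
}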
@@ -19,6 +19,7 @@

import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorScorer;
import org.apache.lucene.store.IndexInput;

/** Default provider returning scalar implementations. */
@@ -40,6 +41,11 @@ public FlatVectorsScorer getLucene99FlatVectorsScorer() {
return DefaultFlatVectorScorer.INSTANCE;
}

@Override
public FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer() {
return new Lucene99ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE);
}

@Override
public PostingDecodingUtil newPostingDecodingUtil(IndexInput input) {
return new PostingDecodingUtil(input);