From ac41d297b398f1ae68669da314438259347b6895 Mon Sep 17 00:00:00 2001
From: Kaival Parikh
Date: Sat, 19 Jul 2025 02:13:30 +0000
Subject: [PATCH 1/2] Add utility functions to score on-heap vectors with off-heap ones

---
 ...Lucene99MemorySegmentByteVectorScorer.java |   6 +-
 .../PanamaVectorUtilSupport.java              | 164 +++++++++++++-----
 2 files changed, 120 insertions(+), 50 deletions(-)

diff --git a/lucene/core/src/java24/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorer.java b/lucene/core/src/java24/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorer.java
index b65f1e570921..a8799c25a30a 100644
--- a/lucene/core/src/java24/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorer.java
+++ b/lucene/core/src/java24/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorer.java
@@ -32,7 +32,7 @@ abstract sealed class Lucene99MemorySegmentByteVectorScorer
 
   final int vectorByteSize;
   final MemorySegmentAccessInput input;
-  final MemorySegment query;
+  final byte[] query;
   byte[] scratch;
 
   /**
@@ -61,7 +61,7 @@ public static Optional create(
     super(values);
     this.input = input;
     this.vectorByteSize = values.getVectorByteLength();
-    this.query = MemorySegment.ofArray(queryVector);
+    this.query = queryVector;
   }
 
   final MemorySegment getSegment(int ord) throws IOException {
@@ -113,7 +113,7 @@ public float score(int node) throws IOException {
       checkOrdinal(node);
       // divide by 2 * 2^14 (maximum absolute value of product of 2 signed bytes) * len
       float raw = PanamaVectorUtilSupport.dotProduct(query, getSegment(node));
-      return 0.5f + raw / (float) (query.byteSize() * (1 << 15));
+      return 0.5f + raw / (float) (query.length * (1 << 15));
     }
   }
 
diff --git a/lucene/core/src/java24/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java b/lucene/core/src/java24/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java
index efd91aad4487..d0e3e4cda758 100644
--- a/lucene/core/src/java24/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java
+++ b/lucene/core/src/java24/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java
@@ -309,45 +309,99 @@ private float squareDistanceBody(float[] a, float[] b, int limit) {
   // We also support 128 bit vectors, going 32 bits at a time.
   // This is slower but still faster than not vectorizing at all.
+  private interface ByteVectorLoader {
+    int length();
+
+    ByteVector load(VectorSpecies<Byte> species, int index);
+
+    byte tail(int index);
+  }
+
+  private record ArrayLoader(byte[] arr) implements ByteVectorLoader {
+    @Override
+    public int length() {
+      return arr.length;
+    }
+
+    @Override
+    public ByteVector load(VectorSpecies<Byte> species, int index) {
+      assert index + species.length() < length();
+      return ByteVector.fromArray(species, arr, index);
+    }
+
+    @Override
+    public byte tail(int index) {
+      assert index < length();
+      return arr[index];
+    }
+  }
+
+  private record MemorySegmentLoader(MemorySegment segment) implements ByteVectorLoader {
+    @Override
+    public int length() {
+      return Math.toIntExact(segment.byteSize());
+    }
+
+    @Override
+    public ByteVector load(VectorSpecies<Byte> species, int index) {
+      assert index + species.length() < length();
+      return ByteVector.fromMemorySegment(species, segment, index, LITTLE_ENDIAN);
+    }
+
+    @Override
+    public byte tail(int index) {
+      assert index < length();
+      return segment.get(JAVA_BYTE, index);
+    }
+  }
+
   @Override
   public int dotProduct(byte[] a, byte[] b) {
-    return dotProduct(MemorySegment.ofArray(a), MemorySegment.ofArray(b));
+    return dotProductBody(new ArrayLoader(a), new ArrayLoader(b));
+  }
+
+  public static int dotProduct(byte[] a, MemorySegment b) {
+    return dotProductBody(new ArrayLoader(a), new MemorySegmentLoader(b));
   }
 
   public static int dotProduct(MemorySegment a, MemorySegment b) {
-    assert a.byteSize() == b.byteSize();
+    return dotProductBody(new MemorySegmentLoader(a), new MemorySegmentLoader(b));
+  }
+
+  private static int dotProductBody(ByteVectorLoader a, ByteVectorLoader b) {
+    assert a.length() == b.length();
     int i = 0;
     int res = 0;
     // only vectorize if we'll at least enter the loop a single time
-    if (a.byteSize() >= 16) {
+    if (a.length() >= 16) {
       // compute vectorized dot product consistent with VPDPBUSD instruction
       if (VECTOR_BITSIZE >= 512) {
-        i += BYTE_SPECIES.loopBound(a.byteSize());
+        i += BYTE_SPECIES.loopBound(a.length());
         res += dotProductBody512(a, b, i);
       } else if (VECTOR_BITSIZE == 256) {
-        i += BYTE_SPECIES.loopBound(a.byteSize());
+        i += BYTE_SPECIES.loopBound(a.length());
         res += dotProductBody256(a, b, i);
       } else {
         // tricky: we don't have SPECIES_32, so we workaround with "overlapping read"
-        i += ByteVector.SPECIES_64.loopBound(a.byteSize() - ByteVector.SPECIES_64.length());
+        i += ByteVector.SPECIES_64.loopBound(a.length() - ByteVector.SPECIES_64.length());
         res += dotProductBody128(a, b, i);
       }
     }
     // scalar tail
-    for (; i < a.byteSize(); i++) {
-      res += b.get(JAVA_BYTE, i) * a.get(JAVA_BYTE, i);
+    for (; i < a.length(); i++) {
+      res += a.tail(i) * b.tail(i);
     }
     return res;
   }
 
   /** vectorized dot product body (512 bit vectors) */
-  private static int dotProductBody512(MemorySegment a, MemorySegment b, int limit) {
+  private static int dotProductBody512(ByteVectorLoader a, ByteVectorLoader b, int limit) {
     IntVector acc = IntVector.zero(INT_SPECIES);
     for (int i = 0; i < limit; i += BYTE_SPECIES.length()) {
-      ByteVector va8 = ByteVector.fromMemorySegment(BYTE_SPECIES, a, i, LITTLE_ENDIAN);
-      ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES, b, i, LITTLE_ENDIAN);
+      ByteVector va8 = a.load(BYTE_SPECIES, i);
+      ByteVector vb8 = b.load(BYTE_SPECIES, i);
 
       // 16-bit multiply: avoid AVX-512 heavy multiply on zmm
       Vector<Short> va16 = va8.convertShape(B2S, SHORT_SPECIES, 0);
@@ -363,11 +417,11 @@ private static int dotProductBody512(MemorySegment a, MemorySegment b, int limit
   }
 
   /** vectorized dot product body (256 bit vectors) */
-  private static int dotProductBody256(MemorySegment a, MemorySegment b, int limit) {
+  private static int dotProductBody256(ByteVectorLoader a, ByteVectorLoader b, int limit) {
     IntVector acc = IntVector.zero(IntVector.SPECIES_256);
     for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length()) {
-      ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, LITTLE_ENDIAN);
-      ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, LITTLE_ENDIAN);
+      ByteVector va8 = a.load(ByteVector.SPECIES_64, i);
+      ByteVector vb8 = b.load(ByteVector.SPECIES_64, i);
 
       // 32-bit multiply and add into accumulator
       Vector<Integer> va32 = va8.convertShape(B2I, IntVector.SPECIES_256, 0);
@@ -379,13 +433,13 @@ private static int dotProductBody256(MemorySegment a, MemorySegment b, int limit
   }
 
   /** vectorized dot product body (128 bit vectors) */
-  private static int dotProductBody128(MemorySegment a, MemorySegment b, int limit) {
+  private static int dotProductBody128(ByteVectorLoader a, ByteVectorLoader b, int limit) {
     IntVector acc = IntVector.zero(IntVector.SPECIES_128);
     // 4 bytes at a time (re-loading half the vector each time!)
     for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length() >> 1) {
       // load 8 bytes
-      ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, LITTLE_ENDIAN);
-      ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, LITTLE_ENDIAN);
+      ByteVector va8 = a.load(ByteVector.SPECIES_64, i);
+      ByteVector vb8 = b.load(ByteVector.SPECIES_64, i);
 
       // process first "half" only: 16-bit multiply
       Vector<Short> va16 = va8.convert(B2S, 0);
@@ -577,27 +631,35 @@ private int int4DotProductBody128(byte[] a, byte[] b, int limit) {
 
   @Override
   public float cosine(byte[] a, byte[] b) {
-    return cosine(MemorySegment.ofArray(a), MemorySegment.ofArray(b));
+    return cosineBody(new ArrayLoader(a), new ArrayLoader(b));
   }
 
   public static float cosine(MemorySegment a, MemorySegment b) {
+    return cosineBody(new MemorySegmentLoader(a), new MemorySegmentLoader(b));
+  }
+
+  public static float cosine(byte[] a, MemorySegment b) {
+    return cosineBody(new ArrayLoader(a), new MemorySegmentLoader(b));
+  }
+
+  private static float cosineBody(ByteVectorLoader a, ByteVectorLoader b) {
     int i = 0;
     int sum = 0;
     int norm1 = 0;
     int norm2 = 0;
     // only vectorize if we'll at least enter the loop a single time
-    if (a.byteSize() >= 16) {
+    if (a.length() >= 16) {
       final float[] ret;
       if (VECTOR_BITSIZE >= 512) {
-        i += BYTE_SPECIES.loopBound((int) a.byteSize());
+        i += BYTE_SPECIES.loopBound(a.length());
        ret = cosineBody512(a, b, i);
       } else if (VECTOR_BITSIZE == 256) {
-        i += BYTE_SPECIES.loopBound((int) a.byteSize());
+        i += BYTE_SPECIES.loopBound(a.length());
        ret = cosineBody256(a, b, i);
       } else {
         // tricky: we don't have SPECIES_32, so we workaround with "overlapping read"
-        i += ByteVector.SPECIES_64.loopBound(a.byteSize() - ByteVector.SPECIES_64.length());
+        i += ByteVector.SPECIES_64.loopBound(a.length() - ByteVector.SPECIES_64.length());
        ret = cosineBody128(a, b, i);
       }
       sum += ret[0];
@@ -606,9 +668,9 @@ public static float cosine(MemorySegment a, MemorySegment b) {
     }
 
     // scalar tail
-    for (; i < a.byteSize(); i++) {
-      byte elem1 = a.get(JAVA_BYTE, i);
-      byte elem2 = b.get(JAVA_BYTE, i);
+    for (; i < a.length(); i++) {
+      byte elem1 = a.tail(i);
+      byte elem2 = b.tail(i);
       sum += elem1 * elem2;
       norm1 += elem1 * elem1;
       norm2 += elem2 * elem2;
@@ -617,13 +679,13 @@ public static float cosine(MemorySegment a, MemorySegment b) {
   }
 
   /** vectorized cosine body (512 bit vectors) */
-  private static float[] cosineBody512(MemorySegment a, MemorySegment b, int limit) {
+  private static float[] cosineBody512(ByteVectorLoader a, ByteVectorLoader b, int limit) {
     IntVector accSum = IntVector.zero(INT_SPECIES);
     IntVector accNorm1 = IntVector.zero(INT_SPECIES);
     IntVector accNorm2 = IntVector.zero(INT_SPECIES);
     for (int i = 0; i < limit; i += BYTE_SPECIES.length()) {
-      ByteVector va8 = ByteVector.fromMemorySegment(BYTE_SPECIES, a, i, LITTLE_ENDIAN);
-      ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES, b, i, LITTLE_ENDIAN);
+      ByteVector va8 = a.load(BYTE_SPECIES, i);
+      ByteVector vb8 = b.load(BYTE_SPECIES, i);
 
       // 16-bit multiply: avoid AVX-512 heavy multiply on zmm
       Vector<Short> va16 = va8.convertShape(B2S, SHORT_SPECIES, 0);
@@ -647,13 +709,13 @@ private static float[] cosineBody512(MemorySegment a, MemorySegment b, int limit
   }
 
   /** vectorized cosine body (256 bit vectors) */
-  private static float[] cosineBody256(MemorySegment a, MemorySegment b, int limit) {
+  private static float[] cosineBody256(ByteVectorLoader a, ByteVectorLoader b, int limit) {
     IntVector accSum = IntVector.zero(IntVector.SPECIES_256);
     IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_256);
     IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_256);
     for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length()) {
-      ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, LITTLE_ENDIAN);
-      ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, LITTLE_ENDIAN);
+      ByteVector va8 = a.load(ByteVector.SPECIES_64, i);
+      ByteVector vb8 = b.load(ByteVector.SPECIES_64, i);
 
       // 16-bit multiply, and add into accumulators
       Vector<Integer> va32 = va8.convertShape(B2I, IntVector.SPECIES_256, 0);
@@ -672,13 +734,13 @@ private static float[] cosineBody256(MemorySegment a, MemorySegment b, int limit
   }
 
   /** vectorized cosine body (128 bit vectors) */
-  private static float[] cosineBody128(MemorySegment a, MemorySegment b, int limit) {
+  private static float[] cosineBody128(ByteVectorLoader a, ByteVectorLoader b, int limit) {
     IntVector accSum = IntVector.zero(IntVector.SPECIES_128);
     IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_128);
     IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_128);
     for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length() >> 1) {
-      ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, LITTLE_ENDIAN);
-      ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, LITTLE_ENDIAN);
+      ByteVector va8 = a.load(ByteVector.SPECIES_64, i);
+      ByteVector vb8 = b.load(ByteVector.SPECIES_64, i);
 
       // process first half only: 16-bit multiply
       Vector<Short> va16 = va8.convert(B2S, 0);
@@ -700,39 +762,47 @@ private static float[] cosineBody128(MemorySegment a, MemorySegment b, int limit
 
   @Override
   public int squareDistance(byte[] a, byte[] b) {
-    return squareDistance(MemorySegment.ofArray(a), MemorySegment.ofArray(b));
+    return squareDistanceBody(new ArrayLoader(a), new ArrayLoader(b));
   }
 
   public static int squareDistance(MemorySegment a, MemorySegment b) {
-    assert a.byteSize() == b.byteSize();
+    return squareDistanceBody(new MemorySegmentLoader(a), new MemorySegmentLoader(b));
+  }
+
+  public static int squareDistance(byte[] a, MemorySegment b) {
+    return squareDistanceBody(new ArrayLoader(a), new MemorySegmentLoader(b));
+  }
+
+  private static int squareDistanceBody(ByteVectorLoader a, ByteVectorLoader b) {
+    assert a.length() == b.length();
     int i = 0;
     int res = 0;
     // only vectorize if we'll at least enter the loop a single time
-    if (a.byteSize() >= 16) {
+    if (a.length() >= 16) {
       if (VECTOR_BITSIZE >= 256) {
-        i += BYTE_SPECIES.loopBound((int) a.byteSize());
+        i += BYTE_SPECIES.loopBound(a.length());
         res += squareDistanceBody256(a, b, i);
       } else {
-        i += ByteVector.SPECIES_64.loopBound((int) a.byteSize());
+        i += ByteVector.SPECIES_64.loopBound(a.length());
         res += squareDistanceBody128(a, b, i);
       }
     }
     // scalar tail
-    for (; i < a.byteSize(); i++) {
-      int diff = a.get(JAVA_BYTE, i) - b.get(JAVA_BYTE, i);
+    for (; i < a.length(); i++) {
+      int diff = a.tail(i) - b.tail(i);
       res += diff * diff;
     }
     return res;
   }
 
   /** vectorized square distance body (256+ bit vectors) */
-  private static int squareDistanceBody256(MemorySegment a, MemorySegment b, int limit) {
+  private static int squareDistanceBody256(ByteVectorLoader a, ByteVectorLoader b, int limit) {
     IntVector acc = IntVector.zero(INT_SPECIES);
     for (int i = 0; i < limit; i += BYTE_SPECIES.length()) {
-      ByteVector va8 = ByteVector.fromMemorySegment(BYTE_SPECIES, a, i, LITTLE_ENDIAN);
-      ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES, b, i, LITTLE_ENDIAN);
+      ByteVector va8 = a.load(BYTE_SPECIES, i);
+      ByteVector vb8 = b.load(BYTE_SPECIES, i);
 
       // 32-bit sub, multiply, and add into accumulators
       // TODO: uses AVX-512 heavy multiply on zmm, should we just use 256-bit vectors on AVX-512?
@@ -746,14 +816,14 @@ private static int squareDistanceBody256(MemorySegment a, MemorySegment b, int l
   }
 
   /** vectorized square distance body (128 bit vectors) */
-  private static int squareDistanceBody128(MemorySegment a, MemorySegment b, int limit) {
+  private static int squareDistanceBody128(ByteVectorLoader a, ByteVectorLoader b, int limit) {
     // 128-bit implementation, which must "split up" vectors due to widening conversions
     // it doesn't help to do the overlapping read trick, due to 32-bit multiply in the formula
     IntVector acc1 = IntVector.zero(IntVector.SPECIES_128);
     IntVector acc2 = IntVector.zero(IntVector.SPECIES_128);
     for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length()) {
-      ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, LITTLE_ENDIAN);
-      ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, LITTLE_ENDIAN);
+      ByteVector va8 = a.load(ByteVector.SPECIES_64, i);
+      ByteVector vb8 = b.load(ByteVector.SPECIES_64, i);
 
       // 16-bit sub
       Vector<Short> va16 = va8.convertShape(B2S, ShortVector.SPECIES_128, 0);

From 21656eeb13d280bb5d2ec61dab61aa0ad12b6378 Mon Sep 17 00:00:00 2001
From: Kaival Parikh
Date: Sat, 19 Jul 2025 02:45:07 +0000
Subject: [PATCH 2/2] Fix assertions, add CHANGES.txt entry

---
 lucene/CHANGES.txt                                        | 2 ++
 .../internal/vectorization/PanamaVectorUtilSupport.java   | 8 ++++----
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index e625456a70c4..83c163b39202 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -53,6 +53,8 @@ Optimizations
 * GITHUB#14011: Reduce allocation rate in HNSW concurrent merge. (Viliam Durina)
 * GITHUB#14022: Optimize DFS marking of connected components in HNSW by reducing stack depth, improving performance and reducing allocations. (Viswanath Kuchibhotla)
 
+* GITHUB#14874: Improve off-heap KNN byte vector query performance in cases where indexing and search are performed by the same process. (Kaival Parikh)
+
 Bug Fixes
 ---------------------
 * GITHUB#14049: Randomize KNN codec params in RandomCodec. Fixes scalar quantization div-by-zero
diff --git a/lucene/core/src/java24/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java b/lucene/core/src/java24/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java
index d0e3e4cda758..7572560923b0 100644
--- a/lucene/core/src/java24/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java
+++ b/lucene/core/src/java24/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java
@@ -325,13 +325,13 @@ public int length() {
 
     @Override
     public ByteVector load(VectorSpecies<Byte> species, int index) {
-      assert index + species.length() < length();
+      assert index + species.length() <= length();
       return ByteVector.fromArray(species, arr, index);
     }
 
     @Override
     public byte tail(int index) {
-      assert index < length();
+      assert index <= length();
       return arr[index];
     }
   }
@@ -344,13 +344,13 @@ public int length() {
 
     @Override
     public ByteVector load(VectorSpecies<Byte> species, int index) {
-      assert index + species.length() < length();
+      assert index + species.length() <= length();
       return ByteVector.fromMemorySegment(species, segment, index, LITTLE_ENDIAN);
     }
 
     @Override
     public byte tail(int index) {
-      assert index < length();
+      assert index <= length();
       return segment.get(JAVA_BYTE, index);
     }
   }
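
Illustrative usage (not part of the patch): the new byte[]-vs-MemorySegment overloads let an on-heap query be scored directly against an off-heap stored vector, with no MemorySegment.ofArray wrapping of the query. Below is a minimal sketch, assuming same-package access to the package-private PanamaVectorUtilSupport class (e.g. from a test under org.apache.lucene.internal.vectorization) and a JDK where the Panama Vector API is available; the class name and variables in the sketch are made up for illustration.

import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;

class OnHeapVsOffHeapScoringSketch {
  public static void main(String[] args) {
    // On-heap query vector, as a KNN query would supply it (16 bytes, enough to take a vectorized path).
    byte[] query = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
    try (Arena arena = Arena.ofConfined()) {
      // Stand-in for an off-heap stored vector, as it would be read from a memory-mapped index file.
      MemorySegment stored = arena.allocate(query.length);
      MemorySegment.copy(query, 0, stored, ValueLayout.JAVA_BYTE, 0, query.length);

      // Mixed on-heap/off-heap scoring via the overloads added in this patch.
      int dot = PanamaVectorUtilSupport.dotProduct(query, stored);
      float cos = PanamaVectorUtilSupport.cosine(query, stored);
      int dist = PanamaVectorUtilSupport.squareDistance(query, stored);
      System.out.println(dot + " " + cos + " " + dist);
    }
  }
}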