diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 3ea1326b4608..3bf3d34f4826 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -150,6 +150,8 @@ Optimizations * GITHUB#15160: Increased the size used for blocks of postings from 128 to 256. This gives a noticeable speedup to many queries. (Adrien Grand) +* GITHUB#14863: Perform scoring for 4, 7, 8 bit quantized vectors off-heap. (Kaival Parikh) + Bug Fixes --------------------- * GITHUB#14161: PointInSetQuery's constructor now throws IllegalArgumentException diff --git a/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorUtilBenchmark.java b/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorUtilBenchmark.java index a8eb1b945cee..4c8253fdab9f 100644 --- a/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorUtilBenchmark.java +++ b/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorUtilBenchmark.java @@ -54,11 +54,13 @@ static void compressBytes(byte[] raw, byte[] compressed) { private byte[] bytesA; private byte[] bytesB; private byte[] halfBytesA; + private byte[] halfBytesAPacked; private byte[] halfBytesB; private byte[] halfBytesBPacked; private float[] floatsA; private float[] floatsB; - private int expectedhalfByteDotProduct; + private int expectedHalfByteDotProduct; + private int expectedHalfByteSquareDistance; @Param({"1", "128", "207", "256", "300", "512", "702", "1024"}) int size; @@ -74,16 +76,23 @@ public void init() { random.nextBytes(bytesB); // random half byte arrays for binary methods // this means that all values must be between 0 and 15 - expectedhalfByteDotProduct = 0; + expectedHalfByteDotProduct = 0; + expectedHalfByteSquareDistance = 0; halfBytesA = new byte[size]; halfBytesB = new byte[size]; for (int i = 0; i < size; ++i) { halfBytesA[i] = (byte) random.nextInt(16); halfBytesB[i] = (byte) random.nextInt(16); - expectedhalfByteDotProduct += halfBytesA[i] * halfBytesB[i]; + expectedHalfByteDotProduct += 
halfBytesA[i] * halfBytesB[i]; + + int diff = halfBytesA[i] - halfBytesB[i]; + expectedHalfByteSquareDistance += diff * diff; } // pack the half byte arrays if (size % 2 == 0) { + halfBytesAPacked = new byte[(size + 1) >> 1]; + compressBytes(halfBytesA, halfBytesAPacked); + halfBytesBPacked = new byte[(size + 1) >> 1]; compressBytes(halfBytesB, halfBytesBPacked); } @@ -108,6 +117,74 @@ public float binaryCosineVector() { return VectorUtil.cosine(bytesA, bytesB); } + @Benchmark + public int binarySquareScalar() { + return VectorUtil.squareDistance(bytesA, bytesB); + } + + @Benchmark + @Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"}) + public int binarySquareVector() { + return VectorUtil.squareDistance(bytesA, bytesB); + } + + @Benchmark + public int binaryHalfByteSquareScalar() { + int v = VectorUtil.int4SquareDistance(halfBytesA, halfBytesB); + if (v != expectedHalfByteSquareDistance) { + throw new RuntimeException("Expected " + expectedHalfByteSquareDistance + " but got " + v); + } + return v; + } + + @Benchmark + @Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"}) + public int binaryHalfByteSquareVector() { + int v = VectorUtil.int4SquareDistance(halfBytesA, halfBytesB); + if (v != expectedHalfByteSquareDistance) { + throw new RuntimeException("Expected " + expectedHalfByteSquareDistance + " but got " + v); + } + return v; + } + + @Benchmark + public int binaryHalfByteSquareSinglePackedScalar() { + int v = VectorUtil.int4SquareDistanceSinglePacked(halfBytesA, halfBytesBPacked); + if (v != expectedHalfByteSquareDistance) { + throw new RuntimeException("Expected " + expectedHalfByteSquareDistance + " but got " + v); + } + return v; + } + + @Benchmark + @Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"}) + public int binaryHalfByteSquareSinglePackedVector() { + int v = VectorUtil.int4SquareDistanceSinglePacked(halfBytesA, halfBytesBPacked); + if (v != expectedHalfByteSquareDistance) { + throw new RuntimeException("Expected " + 
expectedHalfByteSquareDistance + " but got " + v); + } + return v; + } + + @Benchmark + public int binaryHalfByteSquareBothPackedScalar() { + int v = VectorUtil.int4SquareDistanceBothPacked(halfBytesAPacked, halfBytesBPacked); + if (v != expectedHalfByteSquareDistance) { + throw new RuntimeException("Expected " + expectedHalfByteSquareDistance + " but got " + v); + } + return v; + } + + @Benchmark + @Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"}) + public int binaryHalfByteSquareBothPackedVector() { + int v = VectorUtil.int4SquareDistanceBothPacked(halfBytesAPacked, halfBytesBPacked); + if (v != expectedHalfByteSquareDistance) { + throw new RuntimeException("Expected " + expectedHalfByteSquareDistance + " but got " + v); + } + return v; + } + @Benchmark public int binaryDotProductScalar() { return VectorUtil.dotProduct(bytesA, bytesB); @@ -131,14 +208,22 @@ public int binaryDotProductUint8Vector() { } @Benchmark - public int binarySquareScalar() { - return VectorUtil.squareDistance(bytesA, bytesB); + public int binaryHalfByteDotProductScalar() { + int v = VectorUtil.int4DotProduct(halfBytesA, halfBytesB); + if (v != expectedHalfByteDotProduct) { + throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v); + } + return v; } @Benchmark @Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"}) - public int binarySquareVector() { - return VectorUtil.squareDistance(bytesA, bytesB); + public int binaryHalfByteDotProductVector() { + int v = VectorUtil.int4DotProduct(halfBytesA, halfBytesB); + if (v != expectedHalfByteDotProduct) { + throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v); + } + return v; } @Benchmark @@ -153,37 +238,39 @@ public int binarySquareUint8Vector() { } @Benchmark - public int binaryHalfByteScalar() { - return VectorUtil.int4DotProduct(halfBytesA, halfBytesB); + public int binaryHalfByteDotProductSinglePackedScalar() { + int v = VectorUtil.int4DotProductSinglePacked(halfBytesA, 
halfBytesBPacked); + if (v != expectedHalfByteDotProduct) { + throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v); + } + return v; } @Benchmark @Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"}) - public int binaryHalfByteVector() { - return VectorUtil.int4DotProduct(halfBytesA, halfBytesB); + public int binaryHalfByteDotProductSinglePackedVector() { + int v = VectorUtil.int4DotProductSinglePacked(halfBytesA, halfBytesBPacked); + if (v != expectedHalfByteDotProduct) { + throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v); + } + return v; } @Benchmark - public int binaryHalfByteScalarPacked() { - if (size % 2 != 0) { - throw new RuntimeException("Size must be even for this benchmark"); - } - int v = VectorUtil.int4DotProductPacked(halfBytesA, halfBytesBPacked); - if (v != expectedhalfByteDotProduct) { - throw new RuntimeException("Expected " + expectedhalfByteDotProduct + " but got " + v); + public int binaryHalfByteDotProductBothPackedScalar() { + int v = VectorUtil.int4DotProductBothPacked(halfBytesAPacked, halfBytesBPacked); + if (v != expectedHalfByteDotProduct) { + throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v); } return v; } @Benchmark @Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"}) - public int binaryHalfByteVectorPacked() { - if (size % 2 != 0) { - throw new RuntimeException("Size must be even for this benchmark"); - } - int v = VectorUtil.int4DotProductPacked(halfBytesA, halfBytesBPacked); - if (v != expectedhalfByteDotProduct) { - throw new RuntimeException("Expected " + expectedhalfByteDotProduct + " but got " + v); + public int binaryHalfByteDotProductBothPackedVector() { + int v = VectorUtil.int4DotProductBothPacked(halfBytesAPacked, halfBytesBPacked); + if (v != expectedHalfByteDotProduct) { + throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v); } return v; } diff --git 
a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorScorerUtil.java b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorScorerUtil.java index 808d7b3cc882..123c18e00c08 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorScorerUtil.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorScorerUtil.java @@ -37,4 +37,8 @@ private FlatVectorScorerUtil() {} public static FlatVectorsScorer getLucene99FlatVectorsScorer() { return IMPL.getLucene99FlatVectorsScorer(); } + + public static FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer() { + return IMPL.getLucene99ScalarQuantizedVectorsScorer(); + } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.java index 117521ddcc2a..80afaf5c685a 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.java @@ -23,6 +23,7 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.FloatToFloatFunction; import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; @@ -245,7 +246,7 @@ public float score(int vectorOrdinal) throws IOException { values.getSlice().seek((long) vectorOrdinal * (values.getVectorByteLength() + Float.BYTES)); values.getSlice().readBytes(compressedVector, 0, compressedVector.length); float vectorOffset = values.getScoreCorrectionConstant(vectorOrdinal); - int dotProduct = VectorUtil.int4DotProductPacked(targetBytes, compressedVector); + int dotProduct = VectorUtil.int4DotProductSinglePacked(targetBytes, 
compressedVector); // For the current implementation of scalar quantization, all dotproducts should // be >= 0; assert dotProduct >= 0; @@ -301,11 +302,6 @@ public void setScoringOrdinal(int node) throws IOException { } } - @FunctionalInterface - private interface FloatToFloatFunction { - float apply(float f); - } - private static final class ScalarQuantizedRandomVectorScorerSupplier implements RandomVectorScorerSupplier { diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java index 0f339ecbe0a8..76c73980aef8 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java @@ -18,10 +18,10 @@ package org.apache.lucene.codecs.lucene99; import java.io.IOException; -import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; @@ -68,7 +68,7 @@ public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat { final byte bits; final boolean compress; - final Lucene99ScalarQuantizedVectorScorer flatVectorScorer; + final FlatVectorsScorer flatVectorScorer; /** Constructs a format using default graph construction parameters */ public Lucene99ScalarQuantizedVectorsFormat() { @@ -115,8 +115,7 @@ public Lucene99ScalarQuantizedVectorsFormat( this.bits = (byte) bits; this.confidenceInterval = confidenceInterval; this.compress = compress; - this.flatVectorScorer = - new 
Lucene99ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE); + this.flatVectorScorer = FlatVectorScorerUtil.getLucene99ScalarQuantizedVectorsScorer(); } public static float calculateDefaultConfidenceInterval(int vectorDimension) { diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java index 89c813a4b93b..7f08c673a7f1 100644 --- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java +++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java @@ -164,24 +164,35 @@ public int uint8DotProduct(byte[] a, byte[] b) { } @Override - public int int4DotProduct(byte[] a, boolean apacked, byte[] b, boolean bpacked) { - assert (apacked && bpacked) == false; - if (apacked || bpacked) { - byte[] packed = apacked ? a : b; - byte[] unpacked = apacked ? b : a; - int total = 0; - for (int i = 0; i < packed.length; i++) { - byte packedByte = packed[i]; - byte unpacked1 = unpacked[i]; - byte unpacked2 = unpacked[i + packed.length]; - total += (packedByte & 0x0F) * unpacked2; - total += ((packedByte & 0xFF) >> 4) * unpacked1; - } - return total; - } + public int int4DotProduct(byte[] a, byte[] b) { return dotProduct(a, b); } + @Override + public int int4DotProductSinglePacked(byte[] unpacked, byte[] packed) { + int total = 0; + for (int i = 0; i < packed.length; i++) { + byte packedByte = packed[i]; + byte unpacked1 = unpacked[i]; + byte unpacked2 = unpacked[i + packed.length]; + total += (packedByte & 0x0F) * unpacked2; + total += ((packedByte & 0xFF) >> 4) * unpacked1; + } + return total; + } + + @Override + public int int4DotProductBothPacked(byte[] a, byte[] b) { + int total = 0; + for (int i = 0; i < a.length; i++) { + byte aByte = a[i]; + byte bByte = b[i]; + total += (aByte & 0x0F) * (bByte & 0x0F); + total += ((aByte & 0xFF) >> 4) * ((bByte & 0xFF) >> 
4); + } + return total; + } + @Override public float cosine(byte[] a, byte[] b) { // Note: this will not overflow if dim < 2^18, since max(byte * byte) = 2^14. @@ -210,6 +221,42 @@ public int squareDistance(byte[] a, byte[] b) { return squareSum; } + @Override + public int int4SquareDistance(byte[] a, byte[] b) { + return squareDistance(a, b); + } + + @Override + public int int4SquareDistanceSinglePacked(byte[] unpacked, byte[] packed) { + int total = 0; + for (int i = 0; i < packed.length; i++) { + byte packedByte = packed[i]; + byte unpacked1 = unpacked[i]; + byte unpacked2 = unpacked[i + packed.length]; + + int diff1 = (packedByte & 0x0F) - unpacked2; + int diff2 = ((packedByte & 0xFF) >> 4) - unpacked1; + + total += diff1 * diff1 + diff2 * diff2; + } + return total; + } + + @Override + public int int4SquareDistanceBothPacked(byte[] a, byte[] b) { + int total = 0; + for (int i = 0; i < a.length; i++) { + byte aByte = a[i]; + byte bByte = b[i]; + + int diff1 = (aByte & 0x0F) - (bByte & 0x0F); + int diff2 = ((aByte & 0xFF) >> 4) - ((bByte & 0xFF) >> 4); + + total += diff1 * diff1 + diff2 * diff2; + } + return total; + } + @Override public int uint8SquareDistance(byte[] a, byte[] b) { // Note: this will not overflow if dim < 2^16, since max(ubyte * ubyte) = 2^16. 
diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java index c5e9301e9bc4..21977fa3dc77 100644 --- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java +++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java @@ -19,6 +19,7 @@ import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorScorer; import org.apache.lucene.store.IndexInput; /** Default provider returning scalar implementations. */ @@ -40,6 +41,11 @@ public FlatVectorsScorer getLucene99FlatVectorsScorer() { return DefaultFlatVectorScorer.INSTANCE; } + @Override + public FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer() { + return new Lucene99ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE); + } + @Override public PostingDecodingUtil newPostingDecodingUtil(IndexInput input) { return new PostingDecodingUtil(input); diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java index 7190f983b4ce..7242a2501a19 100644 --- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java +++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java @@ -36,18 +36,40 @@ public interface VectorUtilSupport { /** Returns the dot product computed over signed bytes. */ int dotProduct(byte[] a, byte[] b); + /** Returns the dot product computed over unsigned half-bytes, both uncompressed. */ + int int4DotProduct(byte[] a, byte[] b); + + /** Returns the dot product computed over unsigned half-bytes, one compressed. 
*/ + int int4DotProductSinglePacked(byte[] unpacked, byte[] packed); + + /** Returns the dot product computed over unsigned half-bytes, both compressed. */ + int int4DotProductBothPacked(byte[] a, byte[] b); + /** Returns the dot product computed as though the bytes were unsigned. */ int uint8DotProduct(byte[] a, byte[] b); - /** Returns the dot product over the computed bytes, assuming the values are int4 encoded. */ - int int4DotProduct(byte[] a, boolean apacked, byte[] b, boolean bpacked); - /** Returns the cosine similarity between the two byte vectors. */ float cosine(byte[] a, byte[] b); /** Returns the sum of squared differences of the two byte vectors. */ int squareDistance(byte[] a, byte[] b); + /** + * Returns the sum of squared differences between two unsigned half-byte vectors, both + * uncompressed. + */ + int int4SquareDistance(byte[] a, byte[] b); + + /** + * Returns the sum of squared differences between two unsigned half-byte vectors, one compressed. + */ + int int4SquareDistanceSinglePacked(byte[] unpacked, byte[] packed); + + /** + * Returns the sum of squared differences between two unsigned half-byte vectors, both compressed. + */ + int int4SquareDistanceBothPacked(byte[] a, byte[] b); + /** Returns the sum of squared differences of the two unsigned byte vectors. */ int uint8SquareDistance(byte[] a, byte[] b); diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java index 24864318af5a..cf9c56c59774 100644 --- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java +++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java @@ -109,6 +109,9 @@ public static VectorizationProvider getInstance() { /** Returns a FlatVectorsScorer that supports the Lucene99 format. 
*/ public abstract FlatVectorsScorer getLucene99FlatVectorsScorer(); + /** Returns a FlatVectorsScorer that supports the Lucene99 scalar quantized format. */ + public abstract FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer(); + /** Create a new {@link PostingDecodingUtil} for the given {@link IndexInput}. */ public abstract PostingDecodingUtil newPostingDecodingUtil(IndexInput input) throws IOException; diff --git a/lucene/core/src/java/org/apache/lucene/util/FloatToFloatFunction.java b/lucene/core/src/java/org/apache/lucene/util/FloatToFloatFunction.java new file mode 100644 index 000000000000..9068a5438361 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/util/FloatToFloatFunction.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.util; + +/** + * Simple interface to map one float to another (useful in scaling scores). 
+ * + * @lucene.internal + */ +@FunctionalInterface +public interface FloatToFloatFunction { + float apply(float f); +} diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java index 38b9cf6d67a5..db1f6fee083b 100644 --- a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java +++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java @@ -113,6 +113,37 @@ public static int squareDistance(byte[] a, byte[] b) { return IMPL.squareDistance(a, b); } + /** Returns the sum of squared differences between two uint4 (values between [0,15]) vectors. */ + public static int int4SquareDistance(byte[] a, byte[] b) { + if (a.length != b.length) { + throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); + } + return IMPL.int4SquareDistance(a, b); + } + + /** + * Returns the sum of squared differences between two uint4 (values between [0,15]) vectors. The + * second vector is considered "packed" (i.e. every byte representing two values). + */ + public static int int4SquareDistanceSinglePacked(byte[] unpacked, byte[] packed) { + if (packed.length != ((unpacked.length + 1) >> 1)) { + throw new IllegalArgumentException( + "vector dimensions differ: " + unpacked.length + "!= 2 * " + packed.length); + } + return IMPL.int4SquareDistanceSinglePacked(unpacked, packed); + } + + /** + * Returns the sum of squared differences between two uint4 (values between [0,15]) vectors. Both + * vectors are considered "packed" (i.e. every byte representing two values). 
+ */ + public static int int4SquareDistanceBothPacked(byte[] a, byte[] b) { + if (a.length != b.length) { + throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); + } + return IMPL.int4SquareDistanceBothPacked(a, b); + } + /** Returns the sum of squared differences of the two vectors where each byte is unsigned */ public static int uint8SquareDistance(byte[] a, byte[] b) { if (a.length != b.length) { @@ -189,15 +220,22 @@ public static int uint8DotProduct(byte[] a, byte[] b) { return IMPL.uint8DotProduct(a, b); } + /** + * Dot product computed over uint4 (values between [0,15]) bytes. + * + * @param a bytes containing a vector + * @param b bytes containing another vector, of the same dimension + * @return the value of the dot product of the two vectors + */ public static int int4DotProduct(byte[] a, byte[] b) { if (a.length != b.length) { throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); } - return IMPL.int4DotProduct(a, false, b, false); + return IMPL.int4DotProduct(a, b); } /** - * Dot product computed over int4 (values between [0,15]) bytes. The second vector is considered + * Dot product computed over uint4 (values between [0,15]) bytes. The second vector is considered * "packed" (i.e. every byte representing two values). The following packing is assumed: * *
@@ -211,12 +249,28 @@ public static int int4DotProduct(byte[] a, byte[] b) { * @param packed the packed vector, of length {@code (unpacked.length + 1) / 2} * @return the value of the dot product of the two vectors */ - public static int int4DotProductPacked(byte[] unpacked, byte[] packed) { + public static int int4DotProductSinglePacked(byte[] unpacked, byte[] packed) { if (packed.length != ((unpacked.length + 1) >> 1)) { throw new IllegalArgumentException( "vector dimensions differ: " + unpacked.length + "!= 2 * " + packed.length); } - return IMPL.int4DotProduct(unpacked, false, packed, true); + return IMPL.int4DotProductSinglePacked(unpacked, packed); + } + + /** + * Dot product computed over uint4 (values between [0,15]) bytes. Both vectors are considered + * "packed" (i.e. every byte representing two values). + * + * @param a bytes containing a packed vector + * @param b bytes containing another packed vector, of the same dimension + * @return the value of the dot product of the two vectors + */ + public static int int4DotProductBothPacked(byte[] a, byte[] b) { + if (a.length != b.length) { + throw new IllegalArgumentException( + "vector dimensions differ: " + a.length + " != " + b.length); + } + return IMPL.int4DotProductBothPacked(a, b); } /** diff --git a/lucene/core/src/java24/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorer.java b/lucene/core/src/java24/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorer.java new file mode 100644 index 000000000000..12b95f6c2ff2 --- /dev/null +++ b/lucene/core/src/java24/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorer.java @@ -0,0 +1,323 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.internal.vectorization; + +import static java.lang.foreign.ValueLayout.JAVA_INT_UNALIGNED; +import static org.apache.lucene.codecs.hnsw.ScalarQuantizedVectorScorer.quantizeQuery; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.nio.ByteOrder; +import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorScorer; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.MemorySegmentAccessInput; +import org.apache.lucene.util.FloatToFloatFunction; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; +import org.apache.lucene.util.quantization.QuantizedByteVectorValues; +import org.apache.lucene.util.quantization.ScalarQuantizer; + +class Lucene99MemorySegmentScalarQuantizedVectorScorer implements FlatVectorsScorer { + static final Lucene99MemorySegmentScalarQuantizedVectorScorer INSTANCE = + new Lucene99MemorySegmentScalarQuantizedVectorScorer(); + + private static final FlatVectorsScorer DELEGATE = + new 
Lucene99ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE); + + private Lucene99MemorySegmentScalarQuantizedVectorScorer() {} + + @Override + public RandomVectorScorerSupplier getRandomVectorScorerSupplier( + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues) + throws IOException { + if (vectorValues instanceof QuantizedByteVectorValues quantized + && quantized.getSlice() instanceof MemorySegmentAccessInput input) { + return new RandomVectorScorerSupplierImpl(similarityFunction, quantized, input); + } + return DELEGATE.getRandomVectorScorerSupplier(similarityFunction, vectorValues); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target) + throws IOException { + if (vectorValues instanceof QuantizedByteVectorValues quantized + && quantized.getSlice() instanceof MemorySegmentAccessInput input) { + return new RandomVectorScorerImpl(similarityFunction, quantized, input, target); + } + return DELEGATE.getRandomVectorScorer(similarityFunction, vectorValues, target); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target) + throws IOException { + return DELEGATE.getRandomVectorScorer(similarityFunction, vectorValues, target); + } + + @Override + public String toString() { + return "Lucene99MemorySegmentScalarQuantizedVectorScorer()"; + } + + private abstract static class RandomVectorScorerBase + extends RandomVectorScorer.AbstractRandomVectorScorer { + + private final ScalarQuantizer quantizer; + private final float constMultiplier; + private final MemorySegmentAccessInput input; + private final int vectorByteSize; + private final int nodeSize; + private final Scorer scorer; + private final FloatToFloatFunction scaler; + private byte[] scratch; + + RandomVectorScorerBase( + VectorSimilarityFunction similarityFunction, + 
QuantizedByteVectorValues values, + MemorySegmentAccessInput input) { + super(values); + + this.quantizer = values.getScalarQuantizer(); + this.constMultiplier = this.quantizer.getConstantMultiplier(); + this.input = input; + this.vectorByteSize = values.getVectorByteLength(); + this.nodeSize = this.vectorByteSize + Float.BYTES; + + this.scorer = + switch (similarityFunction) { + case EUCLIDEAN -> { + if (this.quantizer.getBits() <= 4) { + if (this.vectorByteSize != values.dimension()) { + yield this::compressedInt4Euclidean; + } + yield this::int4Euclidean; + } + yield this::euclidean; + } + case DOT_PRODUCT, COSINE, MAXIMUM_INNER_PRODUCT -> { + if (this.quantizer.getBits() <= 4) { + if (this.vectorByteSize != values.dimension()) { + yield this::compressedInt4DotProduct; + } + yield this::int4DotProduct; + } + yield this::dotProduct; + } + }; + + this.scaler = + switch (similarityFunction) { + case EUCLIDEAN -> VectorUtil::normalizeDistanceToUnitInterval; + case DOT_PRODUCT, COSINE -> VectorUtil::normalizeToUnitInterval; + case MAXIMUM_INNER_PRODUCT -> VectorUtil::scaleMaxInnerProductScore; + }; + + checkInvariants(); + } + + final void checkInvariants() { + if (input.length() < (long) nodeSize * maxOrd()) { + throw new IllegalArgumentException("input length is less than expected vector data"); + } + } + + final void checkOrdinal(int ord) { + if (ord < 0 || ord >= maxOrd()) { + throw new IllegalArgumentException("illegal ordinal: " + ord); + } + } + + ScalarQuantizer getQuantizer() { + return quantizer; + } + + private static final ValueLayout.OfInt INT_UNALIGNED_LE = + JAVA_INT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN); + + @SuppressWarnings("restricted") + Node getNode(int ord) throws IOException { + checkOrdinal(ord); + long byteOffset = (long) ord * nodeSize; + MemorySegment node = input.segmentSliceOrNull(byteOffset, nodeSize); + if (node == null) { + if (scratch == null) { + scratch = new byte[nodeSize]; + } + input.readBytes(byteOffset, scratch, 0, 
nodeSize); + node = MemorySegment.ofArray(scratch); + } + return new Node( + node.asSlice(0, vectorByteSize), // heap-safe, unlike reinterpret + Float.intBitsToFloat(node.get(INT_UNALIGNED_LE, vectorByteSize))); + } + + float scoreBody(int ord, float queryOffset) throws IOException { + Node node = getNode(ord); + return scaler.apply(scorer.score(node.vector) * constMultiplier + node.offset + queryOffset); + } + + abstract int euclidean(MemorySegment doc); + + abstract int int4Euclidean(MemorySegment doc); + + abstract int compressedInt4Euclidean(MemorySegment doc); + + abstract int dotProduct(MemorySegment doc); + + abstract int int4DotProduct(MemorySegment doc); + + abstract int compressedInt4DotProduct(MemorySegment doc); + + record Node(MemorySegment vector, float offset) {} + + @FunctionalInterface + private interface Scorer { + int score(MemorySegment doc) throws IOException; + } + } + + private static class RandomVectorScorerImpl extends RandomVectorScorerBase { + private final byte[] targetBytes; + private final float queryOffset; + + RandomVectorScorerImpl( + VectorSimilarityFunction similarityFunction, + QuantizedByteVectorValues values, + MemorySegmentAccessInput input, + float[] target) { + super(similarityFunction, values, input); + this.targetBytes = new byte[target.length]; + this.queryOffset = quantizeQuery(target, targetBytes, similarityFunction, getQuantizer()); + } + + @Override + public float score(int node) throws IOException { + return scoreBody(node, queryOffset); + } + + @Override + int euclidean(MemorySegment doc) { + return PanamaVectorUtilSupport.uint8SquareDistance(targetBytes, doc); + } + + @Override + int int4Euclidean(MemorySegment doc) { + return PanamaVectorUtilSupport.int4SquareDistance(targetBytes, doc); + } + + @Override + int compressedInt4Euclidean(MemorySegment doc) { + return PanamaVectorUtilSupport.int4SquareDistanceSinglePacked(targetBytes, doc); + } + + @Override + int dotProduct(MemorySegment doc) { + return 
PanamaVectorUtilSupport.uint8DotProduct(targetBytes, doc); + } + + @Override + int int4DotProduct(MemorySegment doc) { + return PanamaVectorUtilSupport.int4DotProduct(targetBytes, doc); + } + + @Override + int compressedInt4DotProduct(MemorySegment doc) { + return PanamaVectorUtilSupport.int4DotProductSinglePacked(targetBytes, doc); + } + } + + private record RandomVectorScorerSupplierImpl( + VectorSimilarityFunction similarityFunction, + QuantizedByteVectorValues values, + MemorySegmentAccessInput input) + implements RandomVectorScorerSupplier { + + @Override + public UpdateableRandomVectorScorer scorer() { + return new UpdateableRandomVectorScorerImpl(similarityFunction, values, input); + } + + @Override + public RandomVectorScorerSupplier copy() { + return new RandomVectorScorerSupplierImpl(similarityFunction, values, input); + } + } + + private static class UpdateableRandomVectorScorerImpl extends RandomVectorScorerBase + implements UpdateableRandomVectorScorer { + private MemorySegment query; + private float queryOffset; + + UpdateableRandomVectorScorerImpl( + VectorSimilarityFunction similarityFunction, + QuantizedByteVectorValues values, + MemorySegmentAccessInput input) { + super(similarityFunction, values, input); + } + + @Override + public void setScoringOrdinal(int ord) throws IOException { + Node node = getNode(ord); + // A heap-backed vector aliases getNode's reusable scratch buffer, so take a private copy. + byte[] copy = node.vector.isNative() ? null : node.vector.toArray(ValueLayout.JAVA_BYTE); + query = copy == null ? 
node.vector : MemorySegment.ofArray(copy); + queryOffset = node.offset; + } + + @Override + public float score(int node) throws IOException { + return scoreBody(node, queryOffset); + } + + @Override + int euclidean(MemorySegment doc) { + return PanamaVectorUtilSupport.uint8SquareDistance(query, doc); + } + + @Override + int int4Euclidean(MemorySegment doc) { + return PanamaVectorUtilSupport.int4SquareDistance(query, doc); + } + + @Override + int compressedInt4Euclidean(MemorySegment doc) { + return PanamaVectorUtilSupport.int4SquareDistanceBothPacked(query, doc); + } + + @Override + int dotProduct(MemorySegment doc) { + return 
PanamaVectorUtilSupport.uint8DotProduct(query, doc); + } + + @Override + int int4DotProduct(MemorySegment doc) { + return PanamaVectorUtilSupport.int4DotProduct(query, doc); + } + + @Override + int compressedInt4DotProduct(MemorySegment doc) { + return PanamaVectorUtilSupport.int4DotProductBothPacked(query, doc); + } + } +} diff --git a/lucene/core/src/java24/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java b/lucene/core/src/java24/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java index a77c4846ca2a..ba612f750040 100644 --- a/lucene/core/src/java24/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java +++ b/lucene/core/src/java24/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java @@ -360,7 +360,7 @@ public byte tail(int index) { @Override public int dotProduct(byte[] a, byte[] b) { - return dotProductBody(new ArrayLoader(a), new ArrayLoader(b)); + return dotProductBody(new ArrayLoader(a), new ArrayLoader(b), true); } @Override @@ -369,15 +369,19 @@ public int uint8DotProduct(byte[] a, byte[] b) { } public static int dotProduct(byte[] a, MemorySegment b) { - return dotProductBody(new ArrayLoader(a), new MemorySegmentLoader(b)); + return dotProductBody(new ArrayLoader(a), new MemorySegmentLoader(b), true); } public static int dotProduct(MemorySegment a, MemorySegment b) { - return dotProductBody(new MemorySegmentLoader(a), new MemorySegmentLoader(b)); + return dotProductBody(new MemorySegmentLoader(a), new MemorySegmentLoader(b), true); } - private static int dotProductBody(ByteVectorLoader a, ByteVectorLoader b) { - return dotProductBody(a, b, true); + public static int uint8DotProduct(byte[] a, MemorySegment b) { + return dotProductBody(new ArrayLoader(a), new MemorySegmentLoader(b), false); + } + + public static int uint8DotProduct(MemorySegment a, MemorySegment b) { + return dotProductBody(new MemorySegmentLoader(a), new MemorySegmentLoader(b), false); } private static int 
dotProductBody(ByteVectorLoader a, ByteVectorLoader b, boolean signed) { @@ -479,178 +483,198 @@ private static int dotProductBody128( return acc.reduceLanes(ADD); } + private static class Int4Constants { + static final VectorSpeciesBYTE_SPECIES; + static final VectorSpecies SHORT_SPECIES; + static final int CHUNK; + + static { + if (VECTOR_BITSIZE >= 512) { + BYTE_SPECIES = ByteVector.SPECIES_256; + SHORT_SPECIES = ShortVector.SPECIES_512; + CHUNK = 4096; + } else if (VECTOR_BITSIZE == 256) { + BYTE_SPECIES = ByteVector.SPECIES_128; + SHORT_SPECIES = ShortVector.SPECIES_256; + CHUNK = 2048; + } else { + BYTE_SPECIES = ByteVector.SPECIES_64; + SHORT_SPECIES = ShortVector.SPECIES_128; + CHUNK = 1024; + } + } + } + @Override - public int int4DotProduct(byte[] a, boolean apacked, byte[] b, boolean bpacked) { - assert (apacked && bpacked) == false; + public int int4DotProduct(byte[] a, byte[] b) { + return int4DotProductBody(new ArrayLoader(a), new ArrayLoader(b)); + } + + public static int int4DotProduct(byte[] a, MemorySegment b) { + return int4DotProductBody(new ArrayLoader(a), new MemorySegmentLoader(b)); + } + + public static int int4DotProduct(MemorySegment a, MemorySegment b) { + return int4DotProductBody(new MemorySegmentLoader(a), new MemorySegmentLoader(b)); + } + + private static int int4DotProductBody(ByteVectorLoader a, ByteVectorLoader b) { int i = 0; int res = 0; - if (apacked || bpacked) { - byte[] packed = apacked ? a : b; - byte[] unpacked = apacked ? 
b : a; - if (packed.length >= 32) { - if (VECTOR_BITSIZE >= 512) { - i += ByteVector.SPECIES_256.loopBound(packed.length); - res += dotProductBody512Int4Packed(unpacked, packed, i); - } else if (VECTOR_BITSIZE == 256) { - i += ByteVector.SPECIES_128.loopBound(packed.length); - res += dotProductBody256Int4Packed(unpacked, packed, i); - } else { - i += ByteVector.SPECIES_64.loopBound(packed.length); - res += dotProductBody128Int4Packed(unpacked, packed, i); - } - } - // scalar tail - for (; i < packed.length; i++) { - byte packedByte = packed[i]; - byte unpacked1 = unpacked[i]; - byte unpacked2 = unpacked[i + packed.length]; - res += (packedByte & 0x0F) * unpacked2; - res += ((packedByte & 0xFF) >> 4) * unpacked1; - } - } else { - if (VECTOR_BITSIZE >= 512 || VECTOR_BITSIZE == 256) { - return dotProduct(a, b); - } else if (a.length >= 32) { - i += ByteVector.SPECIES_128.loopBound(a.length); - res += int4DotProductBody128(a, b, i); - } - // scalar tail - for (; i < a.length; i++) { - res += b[i] * a[i]; - } + if (a.length() >= 32) { + i += Int4Constants.BYTE_SPECIES.loopBound(a.length()); + res += int4DotProductBody(a, b, i); + } + // scalar tail + for (; i < a.length(); i++) { + res += a.tail(i) * b.tail(i); } - return res; } - private int dotProductBody512Int4Packed(byte[] unpacked, byte[] packed, int limit) { + private static int int4DotProductBody(ByteVectorLoader a, ByteVectorLoader b, int limit) { int sum = 0; - // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator - for (int i = 0; i < limit; i += 4096) { - ShortVector acc0 = ShortVector.zero(ShortVector.SPECIES_512); - ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_512); - int innerLimit = Math.min(limit - i, 4096); - for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_256.length()) { - // packed - var vb8 = ByteVector.fromArray(ByteVector.SPECIES_256, packed, i + j); + // iterate in chunks to ensure we don't overflow the short accumulator + for (int i = 0; i < 
limit; i += Int4Constants.CHUNK) { + ShortVector acc = ShortVector.zero(Int4Constants.SHORT_SPECIES); + int innerLimit = Math.min(limit - i, Int4Constants.CHUNK); + for (int j = 0; j < innerLimit; j += Int4Constants.BYTE_SPECIES.length()) { // unpacked - var va8 = ByteVector.fromArray(ByteVector.SPECIES_256, unpacked, i + j + packed.length); + ByteVector vb8 = b.load(Int4Constants.BYTE_SPECIES, i + j); + Vector vb16 = vb8.convertShape(B2S, Int4Constants.SHORT_SPECIES, 0); - // upper - ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8); - Vector prod16 = prod8.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_512, 0); - acc0 = acc0.add(prod16); + // unpacked + ByteVector va8 = a.load(Int4Constants.BYTE_SPECIES, i + j); + Vector va16 = va8.convertShape(B2S, Int4Constants.SHORT_SPECIES, 0); - // lower - ByteVector vc8 = ByteVector.fromArray(ByteVector.SPECIES_256, unpacked, i + j); - ByteVector prod8a = vb8.lanewise(LSHR, 4).mul(vc8); - Vector prod16a = prod8a.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_512, 0); - acc1 = acc1.add(prod16a); + acc = acc.add(vb16.mul(va16)); } - IntVector intAcc0 = acc0.convertShape(S2I, IntVector.SPECIES_512, 0).reinterpretAsInts(); - IntVector intAcc1 = acc0.convertShape(S2I, IntVector.SPECIES_512, 1).reinterpretAsInts(); - IntVector intAcc2 = acc1.convertShape(S2I, IntVector.SPECIES_512, 0).reinterpretAsInts(); - IntVector intAcc3 = acc1.convertShape(S2I, IntVector.SPECIES_512, 1).reinterpretAsInts(); - sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reduceLanes(ADD); + Vector intAcc0 = acc.convert(S2I, 0); + Vector intAcc1 = acc.convert(S2I, 1); + sum += intAcc0.add(intAcc1).reinterpretAsInts().reduceLanes(ADD); } return sum; } - private int dotProductBody256Int4Packed(byte[] unpacked, byte[] packed, int limit) { + @Override + public int int4DotProductSinglePacked(byte[] unpacked, byte[] packed) { + return int4DotProductSinglePackedBody(new ArrayLoader(unpacked), new ArrayLoader(packed)); + } + + public static int 
int4DotProductSinglePacked(byte[] unpacked, MemorySegment packed) { + return int4DotProductSinglePackedBody( + new ArrayLoader(unpacked), new MemorySegmentLoader(packed)); + } + + private static int int4DotProductSinglePackedBody( + ByteVectorLoader unpacked, ByteVectorLoader packed) { + int i = 0; + int res = 0; + if (packed.length() >= 32) { + i += Int4Constants.BYTE_SPECIES.loopBound(packed.length()); + res += int4DotProductSinglePackedBody(unpacked, packed, i); + } + // scalar tail + for (; i < packed.length(); i++) { + byte packedByte = packed.tail(i); + byte unpacked1 = unpacked.tail(i); + byte unpacked2 = unpacked.tail(i + packed.length()); + res += (packedByte & 0x0F) * unpacked2; + res += ((packedByte & 0xFF) >> 4) * unpacked1; + } + return res; + } + + private static int int4DotProductSinglePackedBody( + ByteVectorLoader unpacked, ByteVectorLoader packed, int limit) { int sum = 0; - // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator - for (int i = 0; i < limit; i += 2048) { - ShortVector acc0 = ShortVector.zero(ShortVector.SPECIES_256); - ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_256); - int innerLimit = Math.min(limit - i, 2048); - for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_128.length()) { + // iterate in chunks to ensure we don't overflow the short accumulator + for (int i = 0; i < limit; i += Int4Constants.CHUNK) { + ShortVector acc0 = ShortVector.zero(Int4Constants.SHORT_SPECIES); + ShortVector acc1 = ShortVector.zero(Int4Constants.SHORT_SPECIES); + int innerLimit = Math.min(limit - i, Int4Constants.CHUNK); + for (int j = 0; j < innerLimit; j += Int4Constants.BYTE_SPECIES.length()) { // packed - var vb8 = ByteVector.fromArray(ByteVector.SPECIES_128, packed, i + j); - // unpacked - var va8 = ByteVector.fromArray(ByteVector.SPECIES_128, unpacked, i + j + packed.length); + ByteVector vb8 = packed.load(Int4Constants.BYTE_SPECIES, i + j); // upper + ByteVector va8 = 
unpacked.load(Int4Constants.BYTE_SPECIES, i + j + packed.length()); ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8); - Vector prod16 = prod8.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_256, 0); + Vector prod16 = prod8.convertShape(ZERO_EXTEND_B2S, Int4Constants.SHORT_SPECIES, 0); acc0 = acc0.add(prod16); // lower - ByteVector vc8 = ByteVector.fromArray(ByteVector.SPECIES_128, unpacked, i + j); + ByteVector vc8 = unpacked.load(Int4Constants.BYTE_SPECIES, i + j); ByteVector prod8a = vb8.lanewise(LSHR, 4).mul(vc8); - Vector prod16a = prod8a.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_256, 0); + Vector prod16a = + prod8a.convertShape(ZERO_EXTEND_B2S, Int4Constants.SHORT_SPECIES, 0); acc1 = acc1.add(prod16a); } - IntVector intAcc0 = acc0.convertShape(S2I, IntVector.SPECIES_256, 0).reinterpretAsInts(); - IntVector intAcc1 = acc0.convertShape(S2I, IntVector.SPECIES_256, 1).reinterpretAsInts(); - IntVector intAcc2 = acc1.convertShape(S2I, IntVector.SPECIES_256, 0).reinterpretAsInts(); - IntVector intAcc3 = acc1.convertShape(S2I, IntVector.SPECIES_256, 1).reinterpretAsInts(); - sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reduceLanes(ADD); + Vector intAcc0 = acc0.convert(S2I, 0); + Vector intAcc1 = acc0.convert(S2I, 1); + Vector intAcc2 = acc1.convert(S2I, 0); + Vector intAcc3 = acc1.convert(S2I, 1); + sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reinterpretAsInts().reduceLanes(ADD); } return sum; } - /** vectorized dot product body (128 bit vectors) */ - private int dotProductBody128Int4Packed(byte[] unpacked, byte[] packed, int limit) { - int sum = 0; - // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator - for (int i = 0; i < limit; i += 1024) { - ShortVector acc0 = ShortVector.zero(ShortVector.SPECIES_128); - ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_128); - int innerLimit = Math.min(limit - i, 1024); - for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_64.length()) { - // packed - 
ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, packed, i + j); - // unpacked - ByteVector va8 = - ByteVector.fromArray(ByteVector.SPECIES_64, unpacked, i + j + packed.length); + @Override + public int int4DotProductBothPacked(byte[] a, byte[] b) { + return int4DotProductBothPackedBody(new ArrayLoader(a), new ArrayLoader(b)); + } - // upper - ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8); - ShortVector prod16 = - prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts(); - acc0 = acc0.add(prod16.and((short) 0xFF)); + public static int int4DotProductBothPacked(MemorySegment a, MemorySegment b) { + return int4DotProductBothPackedBody(new MemorySegmentLoader(a), new MemorySegmentLoader(b)); + } - // lower - va8 = ByteVector.fromArray(ByteVector.SPECIES_64, unpacked, i + j); - prod8 = vb8.lanewise(LSHR, 4).mul(va8); - prod16 = prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts(); - acc1 = acc1.add(prod16.and((short) 0xFF)); - } - IntVector intAcc0 = acc0.convertShape(S2I, IntVector.SPECIES_128, 0).reinterpretAsInts(); - IntVector intAcc1 = acc0.convertShape(S2I, IntVector.SPECIES_128, 1).reinterpretAsInts(); - IntVector intAcc2 = acc1.convertShape(S2I, IntVector.SPECIES_128, 0).reinterpretAsInts(); - IntVector intAcc3 = acc1.convertShape(S2I, IntVector.SPECIES_128, 1).reinterpretAsInts(); - sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reduceLanes(ADD); + private static int int4DotProductBothPackedBody(ByteVectorLoader a, ByteVectorLoader b) { + int i = 0; + int res = 0; + if (a.length() >= 32) { + i += Int4Constants.BYTE_SPECIES.loopBound(a.length()); + res += int4DotProductBothPackedBody(a, b, i); } - return sum; + // scalar tail + for (; i < a.length(); i++) { + byte aByte = a.tail(i); + byte bByte = b.tail(i); + res += (aByte & 0x0F) * (bByte & 0x0F); + res += ((aByte & 0xFF) >> 4) * ((bByte & 0xFF) >> 4); + } + return res; } - private int int4DotProductBody128(byte[] a, byte[] b, int limit) { + private 
static int int4DotProductBothPackedBody( + ByteVectorLoader a, ByteVectorLoader b, int limit) { int sum = 0; - // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator - for (int i = 0; i < limit; i += 1024) { - ShortVector acc0 = ShortVector.zero(ShortVector.SPECIES_128); - ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_128); - int innerLimit = Math.min(limit - i, 1024); - for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_128.length()) { - ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i + j); - ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i + j); - ByteVector prod8 = va8.mul(vb8); - ShortVector prod16 = - prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts(); - acc0 = acc0.add(prod16.and((short) 0xFF)); - - va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i + j + 8); - vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i + j + 8); - prod8 = va8.mul(vb8); - prod16 = prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts(); - acc1 = acc1.add(prod16.and((short) 0xFF)); + // iterate in chunks to ensure we don't overflow the short accumulator + for (int i = 0; i < limit; i += Int4Constants.CHUNK) { + ShortVector acc0 = ShortVector.zero(Int4Constants.SHORT_SPECIES); + ShortVector acc1 = ShortVector.zero(Int4Constants.SHORT_SPECIES); + int innerLimit = Math.min(limit - i, Int4Constants.CHUNK); + for (int j = 0; j < innerLimit; j += Int4Constants.BYTE_SPECIES.length()) { + // packed + var vb8 = b.load(Int4Constants.BYTE_SPECIES, i + j); + // packed + var va8 = a.load(Int4Constants.BYTE_SPECIES, i + j); + + // upper + ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8.and((byte) 0x0F)); + Vector prod16 = prod8.convertShape(ZERO_EXTEND_B2S, Int4Constants.SHORT_SPECIES, 0); + acc0 = acc0.add(prod16); + + // lower + ByteVector prod8a = vb8.lanewise(LSHR, 4).mul(va8.lanewise(LSHR, 4)); + Vector prod16a = + prod8a.convertShape(ZERO_EXTEND_B2S, 
Int4Constants.SHORT_SPECIES, 0); + acc1 = acc1.add(prod16a); } - IntVector intAcc0 = acc0.convertShape(S2I, IntVector.SPECIES_128, 0).reinterpretAsInts(); - IntVector intAcc1 = acc0.convertShape(S2I, IntVector.SPECIES_128, 1).reinterpretAsInts(); - IntVector intAcc2 = acc1.convertShape(S2I, IntVector.SPECIES_128, 0).reinterpretAsInts(); - IntVector intAcc3 = acc1.convertShape(S2I, IntVector.SPECIES_128, 1).reinterpretAsInts(); - sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reduceLanes(ADD); + Vector intAcc0 = acc0.convert(S2I, 0); + Vector intAcc1 = acc0.convert(S2I, 1); + Vector intAcc2 = acc1.convert(S2I, 0); + Vector intAcc3 = acc1.convert(S2I, 1); + sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reinterpretAsInts().reduceLanes(ADD); } return sum; } @@ -788,7 +812,7 @@ private static float[] cosineBody128(ByteVectorLoader a, ByteVectorLoader b, int @Override public int squareDistance(byte[] a, byte[] b) { - return squareDistanceBody(new ArrayLoader(a), new ArrayLoader(b)); + return squareDistanceBody(new ArrayLoader(a), new ArrayLoader(b), true); } @Override @@ -797,15 +821,19 @@ public int uint8SquareDistance(byte[] a, byte[] b) { } public static int squareDistance(MemorySegment a, MemorySegment b) { - return squareDistanceBody(new MemorySegmentLoader(a), new MemorySegmentLoader(b)); + return squareDistanceBody(new MemorySegmentLoader(a), new MemorySegmentLoader(b), true); } public static int squareDistance(byte[] a, MemorySegment b) { - return squareDistanceBody(new ArrayLoader(a), new MemorySegmentLoader(b)); + return squareDistanceBody(new ArrayLoader(a), new MemorySegmentLoader(b), true); } - private static int squareDistanceBody(ByteVectorLoader a, ByteVectorLoader b) { - return squareDistanceBody(a, b, true); + public static int uint8SquareDistance(MemorySegment a, MemorySegment b) { + return squareDistanceBody(new MemorySegmentLoader(a), new MemorySegmentLoader(b), false); + } + + public static int uint8SquareDistance(byte[] a, 
MemorySegment b) { + return squareDistanceBody(new ArrayLoader(a), new MemorySegmentLoader(b), false); } private static int squareDistanceBody(ByteVectorLoader a, ByteVectorLoader b, boolean signed) { @@ -886,6 +914,183 @@ private static int squareDistanceBody128( return acc1.add(acc2).reduceLanes(ADD); } + @Override + public int int4SquareDistance(byte[] a, byte[] b) { + return int4SquareDistanceBody(new ArrayLoader(a), new ArrayLoader(b)); + } + + public static int int4SquareDistance(byte[] a, MemorySegment b) { + return int4SquareDistanceBody(new ArrayLoader(a), new MemorySegmentLoader(b)); + } + + public static int int4SquareDistance(MemorySegment a, MemorySegment b) { + return int4SquareDistanceBody(new MemorySegmentLoader(a), new MemorySegmentLoader(b)); + } + + private static int int4SquareDistanceBody(ByteVectorLoader a, ByteVectorLoader b) { + int i = 0; + int res = 0; + if (a.length() >= 32) { + i += Int4Constants.BYTE_SPECIES.loopBound(a.length()); + res += int4SquareDistanceBody(a, b, i); + } + // scalar tail + for (; i < a.length(); i++) { + int diff = a.tail(i) - b.tail(i); + res += diff * diff; + } + return res; + } + + private static int int4SquareDistanceBody(ByteVectorLoader a, ByteVectorLoader b, int limit) { + int sum = 0; + // iterate in chunks to ensure we don't overflow the short accumulator + for (int i = 0; i < limit; i += Int4Constants.CHUNK) { + ShortVector acc = ShortVector.zero(Int4Constants.SHORT_SPECIES); + int innerLimit = Math.min(limit - i, Int4Constants.CHUNK); + for (int j = 0; j < innerLimit; j += Int4Constants.BYTE_SPECIES.length()) { + // unpacked + var vb8 = b.load(Int4Constants.BYTE_SPECIES, i + j); + // unpacked + var va8 = a.load(Int4Constants.BYTE_SPECIES, i + j); + + ByteVector diff8 = vb8.sub(va8); + Vector diff16 = diff8.convertShape(B2S, Int4Constants.SHORT_SPECIES, 0); + acc = acc.add(diff16.mul(diff16)); + } + Vector intAcc0 = acc.convert(S2I, 0); + Vector intAcc1 = acc.convert(S2I, 1); + sum += 
intAcc0.add(intAcc1).reinterpretAsInts().reduceLanes(ADD); + } + return sum; + } + + @Override + public int int4SquareDistanceSinglePacked(byte[] a, byte[] b) { + return int4SquareDistanceSinglePackedBody(new ArrayLoader(a), new ArrayLoader(b)); + } + + public static int int4SquareDistanceSinglePacked(byte[] a, MemorySegment b) { + return int4SquareDistanceSinglePackedBody(new ArrayLoader(a), new MemorySegmentLoader(b)); + } + + private static int int4SquareDistanceSinglePackedBody( + ByteVectorLoader unpacked, ByteVectorLoader packed) { + int i = 0; + int res = 0; + if (packed.length() >= 32) { + i += Int4Constants.BYTE_SPECIES.loopBound(packed.length()); + res += int4SquareDistanceSinglePackedBody(unpacked, packed, i); + } + // scalar tail + for (; i < packed.length(); i++) { + byte packedByte = packed.tail(i); + byte unpacked1 = unpacked.tail(i); + byte unpacked2 = unpacked.tail(i + packed.length()); + + int diff1 = (packedByte & 0x0F) - unpacked2; + int diff2 = ((packedByte & 0xFF) >> 4) - unpacked1; + + res += diff1 * diff1 + diff2 * diff2; + } + return res; + } + + private static int int4SquareDistanceSinglePackedBody( + ByteVectorLoader unpacked, ByteVectorLoader packed, int limit) { + int sum = 0; + // iterate in chunks to ensure we don't overflow the short accumulator + for (int i = 0; i < limit; i += Int4Constants.CHUNK) { + ShortVector acc0 = ShortVector.zero(Int4Constants.SHORT_SPECIES); + ShortVector acc1 = ShortVector.zero(Int4Constants.SHORT_SPECIES); + int innerLimit = Math.min(limit - i, Int4Constants.CHUNK); + for (int j = 0; j < innerLimit; j += Int4Constants.BYTE_SPECIES.length()) { + // packed + ByteVector vb8 = packed.load(Int4Constants.BYTE_SPECIES, i + j); + + // upper + ByteVector va8 = unpacked.load(Int4Constants.BYTE_SPECIES, i + j + packed.length()); + ByteVector diff8 = vb8.and((byte) 0x0F).sub(va8); + Vector diff16 = diff8.convertShape(B2S, Int4Constants.SHORT_SPECIES, 0); + acc0 = acc0.add(diff16.mul(diff16)); + + // lower + 
ByteVector vc8 = unpacked.load(Int4Constants.BYTE_SPECIES, i + j); + ByteVector diff8a = vb8.lanewise(LSHR, 4).sub(vc8); + Vector diff16a = diff8a.convertShape(B2S, Int4Constants.SHORT_SPECIES, 0); + acc1 = acc1.add(diff16a.mul(diff16a)); + } + Vector intAcc0 = acc0.convert(S2I, 0); + Vector intAcc1 = acc0.convert(S2I, 1); + Vector intAcc2 = acc1.convert(S2I, 0); + Vector intAcc3 = acc1.convert(S2I, 1); + sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reinterpretAsInts().reduceLanes(ADD); + } + return sum; + } + + @Override + public int int4SquareDistanceBothPacked(byte[] a, byte[] b) { + return int4SquareDistanceBothPackedBody(new ArrayLoader(a), new ArrayLoader(b)); + } + + public static int int4SquareDistanceBothPacked(MemorySegment a, MemorySegment b) { + return int4SquareDistanceBothPackedBody(new MemorySegmentLoader(a), new MemorySegmentLoader(b)); + } + + private static int int4SquareDistanceBothPackedBody(ByteVectorLoader a, ByteVectorLoader b) { + int i = 0; + int res = 0; + if (a.length() >= 32) { + i += Int4Constants.BYTE_SPECIES.loopBound(a.length()); + res += int4SquareDistanceBothPackedBody(a, b, i); + } + // scalar tail + for (; i < a.length(); i++) { + byte aByte = a.tail(i); + byte bByte = b.tail(i); + + int diff1 = (aByte & 0x0F) - (bByte & 0x0F); + int diff2 = ((aByte & 0xFF) >> 4) - ((bByte & 0xFF) >> 4); + + res += diff1 * diff1 + diff2 * diff2; + } + return res; + } + + private static int int4SquareDistanceBothPackedBody( + ByteVectorLoader a, ByteVectorLoader b, int limit) { + int sum = 0; + // iterate in chunks to ensure we don't overflow the short accumulator + for (int i = 0; i < limit; i += Int4Constants.CHUNK) { + ShortVector acc0 = ShortVector.zero(Int4Constants.SHORT_SPECIES); + ShortVector acc1 = ShortVector.zero(Int4Constants.SHORT_SPECIES); + int innerLimit = Math.min(limit - i, Int4Constants.CHUNK); + for (int j = 0; j < innerLimit; j += Int4Constants.BYTE_SPECIES.length()) { + // packed + var vb8 = 
b.load(Int4Constants.BYTE_SPECIES, i + j); + // packed + var va8 = a.load(Int4Constants.BYTE_SPECIES, i + j); + + // upper + ByteVector diff8 = vb8.and((byte) 0x0F).sub(va8.and((byte) 0x0F)); + Vector diff16 = diff8.convertShape(B2S, Int4Constants.SHORT_SPECIES, 0); + acc0 = acc0.add(diff16.mul(diff16)); + + // lower + ByteVector diff8a = vb8.lanewise(LSHR, 4).sub(va8.lanewise(LSHR, 4)); + Vector diff16a = diff8a.convertShape(B2S, Int4Constants.SHORT_SPECIES, 0); + acc1 = acc1.add(diff16a.mul(diff16a)); + } + Vector intAcc0 = acc0.convert(S2I, 0); + Vector intAcc1 = acc0.convert(S2I, 1); + Vector intAcc2 = acc1.convert(S2I, 0); + Vector intAcc3 = acc1.convert(S2I, 1); + sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reinterpretAsInts().reduceLanes(ADD); + } + return sum; + } + // Experiments suggest that we need at least 8 lanes so that the overhead of going with the vector // approach and counting trues on vector masks pays off. private static final boolean ENABLE_FIND_NEXT_GEQ_VECTOR_OPTO = INT_SPECIES.length() >= 8; diff --git a/lucene/core/src/java24/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java b/lucene/core/src/java24/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java index 54b3be67afcb..cf3ab94f417c 100644 --- a/lucene/core/src/java24/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java +++ b/lucene/core/src/java24/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java @@ -78,6 +78,11 @@ public FlatVectorsScorer getLucene99FlatVectorsScorer() { return Lucene99MemorySegmentFlatVectorsScorer.INSTANCE; } + @Override + public FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer() { + return Lucene99MemorySegmentScalarQuantizedVectorScorer.INSTANCE; + } + @Override public PostingDecodingUtil newPostingDecodingUtil(IndexInput input) throws IOException { if (input instanceof MemorySegmentAccessInput msai) { diff --git 
a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java index 2c6c54cece73..3ad2cab88690 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java @@ -308,10 +308,19 @@ public KnnVectorsFormat knnVectorsFormat() { } }; String expectedPattern = - "Lucene99HnswScalarQuantizedVectorsFormat(name=Lucene99HnswScalarQuantizedVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=Lucene99ScalarQuantizedVectorsFormat(name=Lucene99ScalarQuantizedVectorsFormat, confidenceInterval=0.9, bits=4, compress=false, flatVectorScorer=ScalarQuantizedVectorScorer(nonQuantizedDelegate=DefaultFlatVectorScorer()), rawVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=%s())))"; - var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer"); + "Lucene99HnswScalarQuantizedVectorsFormat(name=Lucene99HnswScalarQuantizedVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=Lucene99ScalarQuantizedVectorsFormat(name=Lucene99ScalarQuantizedVectorsFormat, confidenceInterval=0.9, bits=4, compress=false, flatVectorScorer=%s, rawVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=%s)))"; + var defaultScorer = + format( + Locale.ROOT, + expectedPattern, + "ScalarQuantizedVectorScorer(nonQuantizedDelegate=DefaultFlatVectorScorer())", + "DefaultFlatVectorScorer()"); var memSegScorer = - format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer"); + format( + Locale.ROOT, + expectedPattern, + "Lucene99MemorySegmentScalarQuantizedVectorScorer()", + "Lucene99MemorySegmentFlatVectorsScorer()"); assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer))); } diff --git 
a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java index e04054c27e37..7156afd9cc3c 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java @@ -372,10 +372,19 @@ public KnnVectorsFormat knnVectorsFormat() { } }; String expectedPattern = - "Lucene99ScalarQuantizedVectorsFormat(name=Lucene99ScalarQuantizedVectorsFormat, confidenceInterval=0.9, bits=4, compress=false, flatVectorScorer=ScalarQuantizedVectorScorer(nonQuantizedDelegate=DefaultFlatVectorScorer()), rawVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=%s()))"; - var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer"); + "Lucene99ScalarQuantizedVectorsFormat(name=Lucene99ScalarQuantizedVectorsFormat, confidenceInterval=0.9, bits=4, compress=false, flatVectorScorer=%s, rawVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=%s))"; + var defaultScorer = + format( + Locale.ROOT, + expectedPattern, + "ScalarQuantizedVectorScorer(nonQuantizedDelegate=DefaultFlatVectorScorer())", + "DefaultFlatVectorScorer()"); var memSegScorer = - format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer"); + format( + Locale.ROOT, + expectedPattern, + "Lucene99MemorySegmentScalarQuantizedVectorScorer()", + "Lucene99MemorySegmentFlatVectorsScorer()"); assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer))); } diff --git a/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java b/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java index 7ec661b3659f..78280e7e4c36 100644 --- 
a/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java +++ b/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java @@ -107,11 +107,23 @@ public void testInt4DotProduct() { b[i] = (byte) random().nextInt(16); } - assertIntReturningProviders(p -> p.int4DotProduct(a, false, pack(b), true)); - assertIntReturningProviders(p -> p.int4DotProduct(pack(a), true, b, false)); + assertIntReturningProviders(p -> p.int4DotProduct(a, b)); + assertIntReturningProviders(p -> p.int4DotProductSinglePacked(a, pack(b))); + assertIntReturningProviders(p -> p.int4DotProductSinglePacked(b, pack(a))); + assertIntReturningProviders(p -> p.int4DotProductBothPacked(pack(a), pack(b))); + + assertEquals( + LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b), + PANAMA_PROVIDER.getVectorUtilSupport().int4DotProduct(a, b)); + assertEquals( + LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b), + PANAMA_PROVIDER.getVectorUtilSupport().int4DotProductSinglePacked(a, pack(b))); + assertEquals( + LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b), + PANAMA_PROVIDER.getVectorUtilSupport().int4DotProductSinglePacked(b, pack(a))); assertEquals( LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b), - PANAMA_PROVIDER.getVectorUtilSupport().int4DotProduct(a, false, pack(b), true)); + PANAMA_PROVIDER.getVectorUtilSupport().int4DotProductBothPacked(pack(a), pack(b))); } public void testInt4DotProductBoundaries() { @@ -122,20 +134,106 @@ public void testInt4DotProductBoundaries() { Arrays.fill(a, MAX_VALUE); Arrays.fill(b, MAX_VALUE); - assertIntReturningProviders(p -> p.int4DotProduct(a, false, pack(b), true)); - assertIntReturningProviders(p -> p.int4DotProduct(pack(a), true, b, false)); + + assertIntReturningProviders(p -> p.int4DotProduct(a, b)); + assertIntReturningProviders(p -> p.int4DotProductSinglePacked(a, pack(b))); + assertIntReturningProviders(p -> p.int4DotProductSinglePacked(b, pack(a))); + 
assertIntReturningProviders(p -> p.int4DotProductBothPacked(pack(a), pack(b))); + + assertEquals( + LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b), + PANAMA_PROVIDER.getVectorUtilSupport().int4DotProduct(a, b)); + assertEquals( + LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b), + PANAMA_PROVIDER.getVectorUtilSupport().int4DotProductSinglePacked(a, pack(b))); + assertEquals( + LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b), + PANAMA_PROVIDER.getVectorUtilSupport().int4DotProductSinglePacked(b, pack(a))); assertEquals( LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b), - PANAMA_PROVIDER.getVectorUtilSupport().int4DotProduct(a, false, pack(b), true)); + PANAMA_PROVIDER.getVectorUtilSupport().int4DotProductBothPacked(pack(a), pack(b))); byte MIN_VALUE = 0; Arrays.fill(a, MIN_VALUE); Arrays.fill(b, MIN_VALUE); - assertIntReturningProviders(p -> p.int4DotProduct(a, false, pack(b), true)); - assertIntReturningProviders(p -> p.int4DotProduct(pack(a), true, b, false)); + + assertIntReturningProviders(p -> p.int4DotProduct(a, b)); + assertIntReturningProviders(p -> p.int4DotProductSinglePacked(a, pack(b))); + assertIntReturningProviders(p -> p.int4DotProductSinglePacked(b, pack(a))); + assertIntReturningProviders(p -> p.int4DotProductBothPacked(pack(a), pack(b))); + + assertEquals( + LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b), + PANAMA_PROVIDER.getVectorUtilSupport().int4DotProduct(a, b)); + assertEquals( + LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b), + PANAMA_PROVIDER.getVectorUtilSupport().int4DotProductSinglePacked(a, pack(b))); + assertEquals( + LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b), + PANAMA_PROVIDER.getVectorUtilSupport().int4DotProductSinglePacked(b, pack(a))); + assertEquals( + LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b), + PANAMA_PROVIDER.getVectorUtilSupport().int4DotProductBothPacked(pack(a), pack(b))); + } + + public void testInt4SquareDistance() { + assumeTrue("even sizes 
only", size % 2 == 0); + var a = new byte[size]; + var b = new byte[size]; + for (int i = 0; i < size; ++i) { + a[i] = (byte) random().nextInt(16); + b[i] = (byte) random().nextInt(16); + } + + assertIntReturningProviders(p -> p.int4SquareDistance(a, b)); + assertIntReturningProviders(p -> p.int4SquareDistanceSinglePacked(a, pack(b))); + assertIntReturningProviders(p -> p.int4SquareDistanceSinglePacked(b, pack(a))); + assertIntReturningProviders(p -> p.int4SquareDistanceBothPacked(pack(a), pack(b))); + + assertEquals( + LUCENE_PROVIDER.getVectorUtilSupport().squareDistance(a, b), + PANAMA_PROVIDER.getVectorUtilSupport().int4SquareDistance(a, b)); + assertEquals( + LUCENE_PROVIDER.getVectorUtilSupport().squareDistance(a, b), + PANAMA_PROVIDER.getVectorUtilSupport().int4SquareDistanceSinglePacked(a, pack(b))); + assertEquals( + LUCENE_PROVIDER.getVectorUtilSupport().squareDistance(a, b), + PANAMA_PROVIDER.getVectorUtilSupport().int4SquareDistanceSinglePacked(b, pack(a))); + assertEquals( + LUCENE_PROVIDER.getVectorUtilSupport().squareDistance(a, b), + PANAMA_PROVIDER.getVectorUtilSupport().int4SquareDistanceBothPacked(pack(a), pack(b))); + } + + public void testInt4SquareDistanceBoundaries() { + assumeTrue("even sizes only", size % 2 == 0); + + // squareDistance is maximized when the vectors are farthest apart (all 15s vs. all 0s) + + byte MAX_VALUE = 15; + var a = new byte[size]; + Arrays.fill(a, MAX_VALUE); + + byte MIN_VALUE = 0; + var b = new byte[size]; + Arrays.fill(b, MIN_VALUE); + + assertIntReturningProviders(p -> p.int4SquareDistance(a, b)); + assertIntReturningProviders(p -> p.int4SquareDistanceSinglePacked(a, pack(b))); + assertIntReturningProviders(p -> p.int4SquareDistanceSinglePacked(b, pack(a))); + assertIntReturningProviders(p -> p.int4SquareDistanceBothPacked(pack(a), pack(b))); + + assertEquals( + LUCENE_PROVIDER.getVectorUtilSupport().squareDistance(a, b), + PANAMA_PROVIDER.getVectorUtilSupport().int4SquareDistance(a, b)); + assertEquals( + LUCENE_PROVIDER.getVectorUtilSupport().squareDistance(a,
b), + PANAMA_PROVIDER.getVectorUtilSupport().int4SquareDistanceSinglePacked(a, pack(b))); + assertEquals( + LUCENE_PROVIDER.getVectorUtilSupport().squareDistance(a, b), + PANAMA_PROVIDER.getVectorUtilSupport().int4SquareDistanceSinglePacked(b, pack(a))); assertEquals( - LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b), - PANAMA_PROVIDER.getVectorUtilSupport().int4DotProduct(a, false, pack(b), true)); + LUCENE_PROVIDER.getVectorUtilSupport().squareDistance(a, b), + PANAMA_PROVIDER.getVectorUtilSupport().int4SquareDistanceBothPacked(pack(a), pack(b))); } public void testInt4BitDotProduct() {