From 18f7291a8d110d897adf2b6ddd32778eeb5d5d06 Mon Sep 17 00:00:00 2001 From: Trevor McCulloch Date: Thu, 4 Sep 2025 14:59:29 -0700 Subject: [PATCH 01/22] vector scorer --- .../Lucene104ScalarQuantizedVectorScorer.java | 151 ++++++++++++++++++ .../lucene104/QuantizedByteVectorValues.java | 78 +++++++++ 2 files changed, 229 insertions(+) create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/lucene104/QuantizedByteVectorValues.java diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java new file mode 100644 index 000000000000..61c4f3bebc36 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java @@ -0,0 +1,151 @@ +package org.apache.lucene.codecs.lucene104; + +import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; +import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; + +import java.io.IOException; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; +import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; + +public class Lucene104ScalarQuantizedVectorScorer implements FlatVectorsScorer { + private final FlatVectorsScorer nonQuantizedDelegate; + + public Lucene104ScalarQuantizedVectorScorer(FlatVectorsScorer nonQuantizedDelegate) { + this.nonQuantizedDelegate = nonQuantizedDelegate; + } + + @Override + public 
RandomVectorScorerSupplier getRandomVectorScorerSupplier( + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues) + throws IOException { + if (vectorValues instanceof QuantizedByteVectorValues qv) { + return new ScalarQuantizedVectorScorerSupplier(qv, similarityFunction); + } + // It is possible to get to this branch during initial indexing and flush + return nonQuantizedDelegate.getRandomVectorScorerSupplier(similarityFunction, vectorValues); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target) + throws IOException { + if (vectorValues instanceof QuantizedByteVectorValues qv) { + OptimizedScalarQuantizer quantizer = qv.getQuantizer(); + byte[] targetQuantized = new byte[target.length]; + // XXX parameterize the number of bits + var targetCorrectiveTerms = + quantizer.scalarQuantize(target, targetQuantized, (byte) 8, qv.getCentroid()); + return new RandomVectorScorer.AbstractRandomVectorScorer(qv) { + @Override + public float score(int node) throws IOException { + return quantizedScore( + targetQuantized, targetCorrectiveTerms, qv, node, similarityFunction); + } + }; + } + // It is possible to get to this branch during initial indexing and flush + return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target) + throws IOException { + return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target); + } + + @Override + public String toString() { + return "ScalarQuantizedVectorScorer(nonQuantizedDelegate=" + nonQuantizedDelegate + ")"; + } + + private static final class ScalarQuantizedVectorScorerSupplier + implements RandomVectorScorerSupplier { + private final QuantizedByteVectorValues targetValues; + private final 
QuantizedByteVectorValues values; + private final VectorSimilarityFunction similarity; + + public ScalarQuantizedVectorScorerSupplier( + QuantizedByteVectorValues values, VectorSimilarityFunction similarity) throws IOException { + this.targetValues = values.copy(); + this.values = values; + this.similarity = similarity; + } + + @Override + public UpdateableRandomVectorScorer scorer() throws IOException { + return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(values) { + private byte[] targetVector; + private OptimizedScalarQuantizer.QuantizationResult targetCorrectiveTerms; + + @Override + public float score(int node) throws IOException { + return quantizedScore(targetVector, targetCorrectiveTerms, values, node, similarity); + } + + @Override + public void setScoringOrdinal(int node) throws IOException { + targetVector = targetValues.vectorValue(node); + targetCorrectiveTerms = targetValues.getCorrectiveTerms(node); + } + }; + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return new ScalarQuantizedVectorScorerSupplier(values.copy(), similarity); + } + } + + public static final float EIGHT_BIT_SCALE = 1f / ((1 << 8) - 1); + + // XXX factor this out to share with Lucene102BinaryFlatVectorsScorer + // we need to know how many bits were used for both the query and index vector for scaling. 
+ static float quantizedScore( + byte[] quantizedQuery, + OptimizedScalarQuantizer.QuantizationResult queryCorrections, + QuantizedByteVectorValues targetVectors, + int targetOrd, + VectorSimilarityFunction similarityFunction) + throws IOException { + byte[] binaryCode = targetVectors.vectorValue(targetOrd); + float qcDist = VectorUtil.uint8DotProduct(quantizedQuery, binaryCode); + OptimizedScalarQuantizer.QuantizationResult indexCorrections = + targetVectors.getCorrectiveTerms(targetOrd); + float x1 = indexCorrections.quantizedComponentSum(); + float ax = indexCorrections.lowerInterval(); + // Here we assume `lx` is simply bit vectors, so the scaling isn't necessary + float lx = (indexCorrections.upperInterval() - ax) * EIGHT_BIT_SCALE; + float ay = queryCorrections.lowerInterval(); + float ly = (queryCorrections.upperInterval() - ay) * EIGHT_BIT_SCALE; + float y1 = queryCorrections.quantizedComponentSum(); + float score = + ax * ay * targetVectors.dimension() + ay * lx * x1 + ax * ly * y1 + lx * ly * qcDist; + // For euclidean, we need to invert the score and apply the additional correction, which is + // assumed to be the squared l2norm of the centroid centered vectors. 
+ if (similarityFunction == EUCLIDEAN) { + score = + queryCorrections.additionalCorrection() + + indexCorrections.additionalCorrection() + - 2 * score; + return Math.max(1 / (1f + score), 0); + } else { + // For cosine and max inner product, we need to apply the additional correction, which is + // assumed to be the non-centered dot-product between the vector and the centroid + score += + queryCorrections.additionalCorrection() + + indexCorrections.additionalCorrection() + - targetVectors.getCentroidDP(); + if (similarityFunction == MAXIMUM_INNER_PRODUCT) { + return VectorUtil.scaleMaxInnerProductScore(score); + } + return Math.max((1f + score) / 2f, 0); + } + } +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/QuantizedByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/QuantizedByteVectorValues.java new file mode 100644 index 000000000000..e72960c78bdd --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/QuantizedByteVectorValues.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.codecs.lucene104; + +import java.io.IOException; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; + +/** Scalar quantized byte vector values */ +abstract class QuantizedByteVectorValues extends ByteVectorValues { + + /** + * Retrieve the corrective terms for the given vector ordinal. For the dot-product family of + * distances, the corrective terms are, in order + * + * + * + * For euclidean: + * + * + * + * @param vectorOrd the vector ordinal + * @return the corrective terms + * @throws IOException if an I/O error occurs + */ + public abstract OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int vectorOrd) + throws IOException; + + /** + * @return the quantizer used to quantize the vectors + */ + public abstract OptimizedScalarQuantizer getQuantizer(); + + public abstract float[] getCentroid() throws IOException; + + /** + * Return a {@link VectorScorer} for the given query vector. 
+ * + * @param query the query vector + * @return a {@link VectorScorer} instance or null + */ + public abstract VectorScorer scorer(float[] query) throws IOException; + + @Override + public abstract QuantizedByteVectorValues copy() throws IOException; + + float getCentroidDP() throws IOException { + // this only gets executed on-merge + float[] centroid = getCentroid(); + return VectorUtil.dotProduct(centroid, centroid); + } +} From 149965abe736e7abaf64f716459a244dbc75acf7 Mon Sep 17 00:00:00 2001 From: Trevor McCulloch Date: Thu, 4 Sep 2025 16:05:11 -0700 Subject: [PATCH 02/22] offheap vv --- .../Lucene104ScalarQuantizedVectorScorer.java | 1 + ...Lucene104ScalarQuantizedVectorsFormat.java | 144 ++++++ .../OffHeapScalarQuantizedVectorValues.java | 372 +++++++++++++++ .../lucene/codecs/lucene104/package-info.java | 436 ++++++++++++++++++ 4 files changed, 953 insertions(+) create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/lucene104/OffHeapScalarQuantizedVectorValues.java create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/lucene104/package-info.java diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java index 61c4f3bebc36..5bdea5d95c2c 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java @@ -13,6 +13,7 @@ import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; +/** Vector scorer over OptimizedScalarQuantized vectors */ public class Lucene104ScalarQuantizedVectorScorer implements FlatVectorsScorer { private final FlatVectorsScorer 
nonQuantizedDelegate; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java new file mode 100644 index 000000000000..1f08dd68bf9a --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.codecs.lucene104; + +import java.io.IOException; +import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; + +/** + * The format used here is a per-vector optimized scalar quantization. &#13;
These + * ideas are evolutions of LVQ proposed in Similarity + * search in the blink of an eye with compressed indices by Cecilia Aguerrebere et al., the + * previous work on globally optimized scalar quantization in Apache Lucene, and Accelerating Large-Scale Inference with Anisotropic + * Vector Quantization by Ruiqi Guo et. al. Also see {@link + * org.apache.lucene.util.quantization.OptimizedScalarQuantizer}. Some of key features are: + * + * + * + * A previous work related to improvements over regular LVQ is Practical and Asymptotically Optimal Quantization of + * High-Dimensional Vectors in Euclidean Space for Approximate Nearest Neighbor Search by + * Jianyang Gao, et. al. + * + *

The format is stored within two files: + * + *

.veb (vector data) file

+ * + *

Stores the binary quantized vectors in a flat format. Additionally, it stores each vector's + * corrective factors. At the end of the file, additional information is stored for vector ordinal + * to centroid ordinal mapping and sparse vector information. + * + *

+ * + *

.vemb (vector metadata) file

+ * + *

Stores the metadata for the vectors. This includes the number of vectors, the number of + * dimensions, and file offset information. + * + *

+ */ +public class Lucene104ScalarQuantizedVectorsFormat extends FlatVectorsFormat { + + public static final byte QUERY_BITS = 4; + public static final byte INDEX_BITS = 1; + + public static final String BINARIZED_VECTOR_COMPONENT = "BVEC"; + public static final String NAME = "Lucene104ScalarQuantizedVectorsFormat"; + + static final int VERSION_START = 0; + static final int VERSION_CURRENT = VERSION_START; + static final String META_CODEC_NAME = "Lucene104ScalarQuantizedVectorsFormatMeta"; + static final String VECTOR_DATA_CODEC_NAME = "Lucene104ScalarQuantizedVectorsFormatData"; + static final String META_EXTENSION = "vemb"; + static final String VECTOR_DATA_EXTENSION = "veb"; + static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16; + + private static final FlatVectorsFormat rawVectorFormat = + new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()); + + private static final Lucene104ScalarQuantizedVectorScorer scorer = + new Lucene104ScalarQuantizedVectorScorer(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()); + + /** Creates a new instance of the quantized vectors format. &#13;
*/ + public Lucene104ScalarQuantizedVectorsFormat() { + super(NAME); + } + + @Override + public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new Lucene104ScalarQuantizedVectorsWriter( + scorer, rawVectorFormat.fieldsWriter(state), state); + } + + @Override + public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new Lucene104ScalarQuantizedVectorsReader( + state, rawVectorFormat.fieldsReader(state), scorer); + } + + @Override + public int getMaxDimensions(String fieldName) { + return 1024; + } + + @Override + public String toString() { + return "Lucene104ScalarQuantizedVectorsFormat(name=" + + NAME + + ", flatVectorScorer=" + + scorer + + ", rawVectorFormat=" + + rawVectorFormat + + ")"; + } +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/OffHeapScalarQuantizedVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/OffHeapScalarQuantizedVectorValues.java new file mode 100644 index 000000000000..57453e09c8ab --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/OffHeapScalarQuantizedVectorValues.java @@ -0,0 +1,372 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. &#13;
+ */ +package org.apache.lucene.codecs.lucene104; + +import java.io.IOException; +import java.nio.ByteBuffer; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.lucene90.IndexedDISI; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.packed.DirectMonotonicReader; +import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; + +/** + * Scalar quantized vector values loaded from off-heap + * + * @lucene.internal + */ +public abstract class OffHeapScalarQuantizedVectorValues extends QuantizedByteVectorValues { + final int dimension; + final int size; + final VectorSimilarityFunction similarityFunction; + final FlatVectorsScorer vectorsScorer; + + final IndexInput slice; + final byte[] vectorValue; + final ByteBuffer byteBuffer; + final int byteSize; + private int lastOrd = -1; + final float[] correctiveValues; + int quantizedComponentSum; + final OptimizedScalarQuantizer quantizer; + final float[] centroid; + final float centroidDp; + + // XXX this needs bits, probably??? 
+ OffHeapScalarQuantizedVectorValues( + int dimension, + int size, + float[] centroid, + float centroidDp, + OptimizedScalarQuantizer quantizer, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, + IndexInput slice) { + this.dimension = dimension; + this.size = size; + this.similarityFunction = similarityFunction; + this.vectorsScorer = vectorsScorer; + this.slice = slice; + this.centroid = centroid; + this.centroidDp = centroidDp; + this.correctiveValues = new float[3]; + this.byteSize = dimension + (Float.BYTES * 3) + Integer.BYTES; + this.byteBuffer = ByteBuffer.allocate(dimension); + this.vectorValue = byteBuffer.array(); + this.quantizer = quantizer; + } + + @Override + public int dimension() { + return dimension; + } + + @Override + public int size() { + return size; + } + + @Override + public byte[] vectorValue(int targetOrd) throws IOException { + if (lastOrd == targetOrd) { + return vectorValue; + } + slice.seek((long) targetOrd * byteSize); + slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), dimension); + slice.readFloats(correctiveValues, 0, 3); + quantizedComponentSum = slice.readInt(); + lastOrd = targetOrd; + return vectorValue; + } + + @Override + public float getCentroidDP() { + return centroidDp; + } + + @Override + public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int targetOrd) + throws IOException { + if (lastOrd == targetOrd) { + return new OptimizedScalarQuantizer.QuantizationResult( + correctiveValues[0], correctiveValues[1], correctiveValues[2], quantizedComponentSum); + } + slice.seek(((long) targetOrd * byteSize) + dimension); + slice.readFloats(correctiveValues, 0, 3); + quantizedComponentSum = slice.readInt(); + return new OptimizedScalarQuantizer.QuantizationResult( + correctiveValues[0], correctiveValues[1], correctiveValues[2], quantizedComponentSum); + } + + @Override + public OptimizedScalarQuantizer getQuantizer() { + return quantizer; + } + + @Override + public float[] 
getCentroid() { + return centroid; + } + + @Override + public int getVectorByteLength() { + return dimension; + } + + static OffHeapScalarQuantizedVectorValues load( + OrdToDocDISIReaderConfiguration configuration, + int dimension, + int size, + OptimizedScalarQuantizer binaryQuantizer, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, + float[] centroid, + float centroidDp, + long quantizedVectorDataOffset, + long quantizedVectorDataLength, + IndexInput vectorData) + throws IOException { + if (configuration.isEmpty()) { + return new EmptyOffHeapVectorValues(dimension, similarityFunction, vectorsScorer); + } + assert centroid != null; + IndexInput bytesSlice = + vectorData.slice( + "quantized-vector-data", quantizedVectorDataOffset, quantizedVectorDataLength); + if (configuration.isDense()) { + return new DenseOffHeapVectorValues( + dimension, + size, + centroid, + centroidDp, + binaryQuantizer, + similarityFunction, + vectorsScorer, + bytesSlice); + } else { + return new SparseOffHeapVectorValues( + configuration, + dimension, + size, + centroid, + centroidDp, + binaryQuantizer, + vectorData, + similarityFunction, + vectorsScorer, + bytesSlice); + } + } + + /** Dense off-heap binarized vector values */ + static class DenseOffHeapVectorValues extends OffHeapScalarQuantizedVectorValues { + DenseOffHeapVectorValues( + int dimension, + int size, + float[] centroid, + float centroidDp, + OptimizedScalarQuantizer binaryQuantizer, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, + IndexInput slice) { + super( + dimension, + size, + centroid, + centroidDp, + binaryQuantizer, + similarityFunction, + vectorsScorer, + slice); + } + + @Override + public OffHeapScalarQuantizedVectorValues.DenseOffHeapVectorValues copy() throws IOException { + return new OffHeapScalarQuantizedVectorValues.DenseOffHeapVectorValues( + dimension, + size, + centroid, + centroidDp, + quantizer, + similarityFunction, + vectorsScorer, + 
slice.clone()); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + return acceptDocs; + } + + @Override + public VectorScorer scorer(float[] target) throws IOException { + OffHeapScalarQuantizedVectorValues.DenseOffHeapVectorValues copy = copy(); + DocIndexIterator iterator = copy.iterator(); + RandomVectorScorer scorer = + vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target); + return new VectorScorer() { + @Override + public float score() throws IOException { + return scorer.score(iterator.index()); + } + + @Override + public DocIdSetIterator iterator() { + return iterator; + } + }; + } + + @Override + public DocIndexIterator iterator() { + return createDenseIterator(); + } + } + + /** Sparse off-heap binarized vector values */ + private static class SparseOffHeapVectorValues extends OffHeapScalarQuantizedVectorValues { + private final DirectMonotonicReader ordToDoc; + private final IndexedDISI disi; + // dataIn was used to init a new IndexedDIS for #randomAccess() + private final IndexInput dataIn; + private final OrdToDocDISIReaderConfiguration configuration; + + SparseOffHeapVectorValues( + OrdToDocDISIReaderConfiguration configuration, + int dimension, + int size, + float[] centroid, + float centroidDp, + OptimizedScalarQuantizer binaryQuantizer, + IndexInput dataIn, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, + IndexInput slice) + throws IOException { + super( + dimension, + size, + centroid, + centroidDp, + binaryQuantizer, + similarityFunction, + vectorsScorer, + slice); + this.configuration = configuration; + this.dataIn = dataIn; + this.ordToDoc = configuration.getDirectMonotonicReader(dataIn); + this.disi = configuration.getIndexedDISI(dataIn); + } + + @Override + public SparseOffHeapVectorValues copy() throws IOException { + return new SparseOffHeapVectorValues( + configuration, + dimension, + size, + centroid, + centroidDp, + quantizer, + dataIn, + similarityFunction, + 
vectorsScorer, + slice.clone()); + } + + @Override + public int ordToDoc(int ord) { + return (int) ordToDoc.get(ord); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + if (acceptDocs == null) { + return null; + } + return new Bits() { + @Override + public boolean get(int index) { + return acceptDocs.get(ordToDoc(index)); + } + + @Override + public int length() { + return size; + } + }; + } + + @Override + public DocIndexIterator iterator() { + return IndexedDISI.asDocIndexIterator(disi); + } + + @Override + public VectorScorer scorer(float[] target) throws IOException { + SparseOffHeapVectorValues copy = copy(); + DocIndexIterator iterator = copy.iterator(); + RandomVectorScorer scorer = + vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target); + return new VectorScorer() { + @Override + public float score() throws IOException { + return scorer.score(iterator.index()); + } + + @Override + public DocIdSetIterator iterator() { + return iterator; + } + }; + } + } + + private static class EmptyOffHeapVectorValues extends OffHeapScalarQuantizedVectorValues { + EmptyOffHeapVectorValues( + int dimension, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer) { + super(dimension, 0, null, Float.NaN, null, similarityFunction, vectorsScorer, null); + } + + @Override + public DocIndexIterator iterator() { + return createDenseIterator(); + } + + @Override + public OffHeapScalarQuantizedVectorValues.DenseOffHeapVectorValues copy() { + throw new UnsupportedOperationException(); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + return null; + } + + @Override + public VectorScorer scorer(float[] target) { + return null; + } + } +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/package-info.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/package-info.java new file mode 100644 index 000000000000..a406f096a0e3 --- /dev/null +++ 
b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/package-info.java @@ -0,0 +1,436 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Lucene 10.4 file format. + * + *

Apache Lucene - Index File Formats

+ * + *
+ * + * + * + *
+ * + *

Introduction

+ * + *
+ * + *

This document defines the index file formats used in this version of Lucene. If you are using + * a different version of Lucene, please consult the copy of docs/ that was distributed + * with the version you are using. + * + *

This document attempts to provide a high-level definition of the Apache Lucene file formats. + *

+ * + *

Definitions

+ * + *
+ * + *

The fundamental concepts in Lucene are index, document, field and term. + * + *

An index contains a sequence of documents. + * + *

    + *
  • A document is a sequence of fields. + *
  • A field is a named sequence of terms. + *
  • A term is a sequence of bytes. + *
+ * + *

The same sequence of bytes in two different fields is considered a different term. Thus terms + * are represented as a pair: the string naming the field, and the bytes within the field. + * + *

Inverted Indexing

+ * + *

Lucene's index stores terms and statistics about those terms in order to make term-based + * search more efficient. Lucene's terms index falls into the family of indexes known as an + * inverted index. This is because it can list, for a term, the documents that contain it. + * This is the inverse of the natural relationship, in which documents list terms. + * + *

Types of Fields

+ * + *

In Lucene, fields may be stored, in which case their text is stored in the index + * literally, in a non-inverted manner. Fields that are inverted are called indexed. A field + * may be both stored and indexed. + * + *

The text of a field may be tokenized into terms to be indexed, or the text of a field + * may be used literally as a term to be indexed. Most fields are tokenized, but sometimes it is + * useful for certain identifier fields to be indexed literally. + * + *

See the {@link org.apache.lucene.document.Field Field} java docs for more information on + * Fields. + * + *

Segments

+ * + *

Lucene indexes may be composed of multiple sub-indexes, or segments. Each segment is a + * fully independent index, which could be searched separately. Indexes evolve by: + * + *

    + *
  1. Creating new segments for newly added documents. + *
  2. Merging existing segments. + *
+ * + *

Searches may involve multiple segments and/or multiple indexes, each index potentially + * composed of a set of segments. + * + *

Document Numbers

+ * + *

Internally, Lucene refers to documents by an integer document number. The first + * document added to an index is numbered zero, and each subsequent document added gets a number one + * greater than the previous. + * + *

Note that a document's number may change, so caution should be taken when storing these + * numbers outside of Lucene. In particular, numbers may change in the following situations: + * + *

    + *
  • + *

    The numbers stored in each segment are unique only within the segment, and must be + * converted before they can be used in a larger context. The standard technique is to + * allocate each segment a range of values, based on the range of numbers used in that + * segment. To convert a document number from a segment to an external value, the segment's + * base document number is added. To convert an external value back to a + * segment-specific value, the segment is identified by the range that the external value is + * in, and the segment's base value is subtracted. For example two five document segments + * might be combined, so that the first segment has a base value of zero, and the second of + * five. Document three from the second segment would have an external value of eight. + *

  • + *

    When documents are deleted, gaps are created in the numbering. These are eventually + * removed as the index evolves through merging. Deleted documents are dropped when segments + * are merged. A freshly-merged segment thus has no gaps in its numbering. + *

+ * + *
+ * + *

Index Structure Overview

+ * + *
+ * + *

Each segment index maintains the following: + * + *

    + *
  • {@link org.apache.lucene.codecs.lucene99.Lucene99SegmentInfoFormat Segment info}. This + * contains metadata about a segment, such as the number of documents, what files it uses, and + * information about how the segment is sorted + *
  • {@link org.apache.lucene.codecs.lucene94.Lucene94FieldInfosFormat Field names}. This + * contains metadata about the set of named fields used in the index. + *
  • {@link org.apache.lucene.codecs.lucene90.Lucene90StoredFieldsFormat Stored Field values}. + * This contains, for each document, a list of attribute-value pairs, where the attributes are + * field names. These are used to store auxiliary information about the document, such as its + * title, url, or an identifier to access a database. The set of stored fields are what is + * returned for each hit when searching. This is keyed by document number. + *
  • {@link org.apache.lucene.codecs.lucene103.Lucene103PostingsFormat Term dictionary}. A + * dictionary containing all of the terms used in all of the indexed fields of all of the + * documents. The dictionary also contains the number of documents which contain the term, and + * pointers to the term's frequency and proximity data. + *
  • {@link org.apache.lucene.codecs.lucene103.Lucene103PostingsFormat Term Frequency data}. For + * each term in the dictionary, the numbers of all the documents that contain that term, and + * the frequency of the term in that document, unless frequencies are omitted ({@link + * org.apache.lucene.index.IndexOptions#DOCS IndexOptions.DOCS}) + *
  • {@link org.apache.lucene.codecs.lucene103.Lucene103PostingsFormat Term Proximity data}. For + * each term in the dictionary, the positions that the term occurs in each document. Note that + * this will not exist if all fields in all documents omit position data. + *
  • {@link org.apache.lucene.codecs.lucene90.Lucene90NormsFormat Normalization factors}. For + * each field in each document, a value is stored that is multiplied into the score for hits + * on that field. + *
  • {@link org.apache.lucene.codecs.lucene90.Lucene90TermVectorsFormat Term Vectors}. For each + * field in each document, the term vector (sometimes called document vector) may be stored. A + * term vector consists of term text and term frequency. To add Term Vectors to your index see + * the {@link org.apache.lucene.document.Field Field} constructors + *
  • {@link org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat Per-document values}. Like + * stored values, these are also keyed by document number, but are generally intended to be + * loaded into main memory for fast access. Whereas stored values are generally intended for + * summary results from searches, per-document values are useful for things like scoring + * factors. + *
  • {@link org.apache.lucene.codecs.lucene90.Lucene90LiveDocsFormat Live documents}. An + * optional file indicating which documents are live. + *
  • {@link org.apache.lucene.codecs.lucene90.Lucene90PointsFormat Point values}. Optional pair + * of files, recording dimensionally indexed fields, to enable fast numeric range filtering + * and large numeric values like BigInteger and BigDecimal (1D) and geographic shape + * intersection (2D, 3D). + *
  • {@link org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat Vector values}. The + * vector format stores numeric vectors in a format optimized for random access and + * computation, supporting high-dimensional nearest-neighbor search. + *
+ * + *

Details on each of these are provided in their linked pages.

+ * + *

File Naming

+ * + *
+ * + *

All files belonging to a segment have the same name with varying extensions. The extensions + * correspond to the different file formats described below. When using the Compound File format + * (default for small segments) these files (except for the Segment info file, the Lock file, and + * Deleted documents file) are collapsed into a single .cfs file (see below for details) + * + *

Typically, all segments in an index are stored in a single directory, although this is not + * required. + * + *

File names are never re-used. That is, when any file is saved to the Directory it is given a + * never before used filename. This is achieved using a simple generations approach. For example, + * the first segments file is segments_1, then segments_2, etc. The generation is a sequential long + * integer represented in alpha-numeric (base 36) form.

+ * + *

Summary of File Extensions

+ * + *
+ * + *

The following table summarizes the names and extensions of the files in Lucene: + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + *
lucene filenames by extension
NameExtensionBrief Description
{@link org.apache.lucene.index.SegmentInfos Segments File}segments_NStores information about a commit point
Lock Filewrite.lockThe Write lock prevents multiple IndexWriters from writing to the same + * file.
{@link org.apache.lucene.codecs.lucene99.Lucene99SegmentInfoFormat Segment Info}.siStores metadata about a segment
{@link org.apache.lucene.codecs.lucene90.Lucene90CompoundFormat Compound File}.cfs, .cfeAn optional "virtual" file consisting of all the other index files for + * systems that frequently run out of file handles.
{@link org.apache.lucene.codecs.lucene94.Lucene94FieldInfosFormat Fields}.fnmStores information about the fields
{@link org.apache.lucene.codecs.lucene90.Lucene90StoredFieldsFormat Field Index}.fdxContains pointers to field data
{@link org.apache.lucene.codecs.lucene90.Lucene90StoredFieldsFormat Field Data}.fdtThe stored fields for documents
{@link org.apache.lucene.codecs.lucene103.Lucene103PostingsFormat Term Dictionary}.timThe term dictionary, stores term info
{@link org.apache.lucene.codecs.lucene103.Lucene103PostingsFormat Term Index}.tipThe index into the Term Dictionary
{@link org.apache.lucene.codecs.lucene103.Lucene103PostingsFormat Frequencies}.docContains the list of docs which contain each term along with frequency
{@link org.apache.lucene.codecs.lucene103.Lucene103PostingsFormat Positions}.posStores position information about where a term occurs in the index
{@link org.apache.lucene.codecs.lucene103.Lucene103PostingsFormat Payloads}.payStores additional per-position metadata information such as character offsets and user payloads
{@link org.apache.lucene.codecs.lucene90.Lucene90NormsFormat Norms}.nvd, .nvmEncodes length and boost factors for docs and fields
{@link org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat Per-Document Values}.dvd, .dvmEncodes additional scoring factors or other per-document information.
{@link org.apache.lucene.codecs.lucene90.Lucene90TermVectorsFormat Term Vector Index}.tvxStores offset into the document data file
{@link org.apache.lucene.codecs.lucene90.Lucene90TermVectorsFormat Term Vector Data}.tvdContains term vector data.
{@link org.apache.lucene.codecs.lucene90.Lucene90LiveDocsFormat Live Documents}.livInfo about what documents are live
{@link org.apache.lucene.codecs.lucene90.Lucene90PointsFormat Point values}.kdd, .kdi, .kdmHolds indexed points
{@link org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat Vector values}.vec, .vem, .veq, .vexHolds indexed vectors; .vec files contain the raw vector data, + * .vem the vector metadata, .veq the quantized vector data, and .vex the + * hnsw graph data.
+ * + *

+ * + *

Lock File

+ * + * The write lock, which is stored in the index directory by default, is named "write.lock". If the + * lock directory is different from the index directory then the write lock will be named + * "XXXX-write.lock" where XXXX is a unique prefix derived from the full path to the index + * directory. When this file is present, a writer is currently modifying the index (adding or + * removing documents). This lock file ensures that only one writer is modifying the index at a + * time. + * + *

History

+ * + *

Compatibility notes are provided in this document, describing how file formats have changed + * from prior versions: + * + *

+ * + * + * + *

Limitations

+ * + *
+ * + *

Lucene uses a Java int to refer to document numbers, and the index file format + * uses an Int32 on-disk to store document numbers. This is a limitation of both the + * index file format and the current implementation. Eventually these should be replaced with either + * UInt64 values, or better yet, {@link org.apache.lucene.store.DataOutput#writeVInt + * VInt} values which have no limit.

+ */ +package org.apache.lucene.codecs.lucene104; From aedce9c720ff8fa8d2ea7de14408431063a56c64 Mon Sep 17 00:00:00 2001 From: Trevor McCulloch Date: Fri, 5 Sep 2025 15:52:14 -0700 Subject: [PATCH 03/22] writer --- ...Lucene104ScalarQuantizedVectorsFormat.java | 56 +- ...Lucene104ScalarQuantizedVectorsWriter.java | 808 ++++++++++++++++++ .../lucene104/QuantizedByteVectorValues.java | 3 + 3 files changed, 838 insertions(+), 29 deletions(-) create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java index 1f08dd68bf9a..dfd140ac9f83 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.lucene.codecs.lucene102; +package org.apache.lucene.codecs.lucene104; import java.io.IOException; import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; @@ -26,10 +26,10 @@ import org.apache.lucene.index.SegmentWriteState; /** - * The binary quantization format used here is a per-vector optimized scalar quantization. 
These - * ideas are evolutions of LVQ proposed in Similarity - * search in the blink of an eye with compressed indices by Cecilia Aguerrebere et al., the - * previous work on globally optimized scalar quantization in Apache Lucene, and Similarity search in the + * blink of an eye with compressed indices by Cecilia Aguerrebere et al., the previous work on + * globally optimized scalar quantization in Apache Lucene, and Accelerating Large-Scale Inference with Anisotropic * Vector Quantization by Ruiqi Guo et. al. Also see {@link * org.apache.lucene.util.quantization.OptimizedScalarQuantizer}. Some of key features are: @@ -52,31 +52,32 @@ * *

The format is stored within two files: * - *

.veb (vector data) file

+ *

.veq (vector data) file

* - *

Stores the binary quantized vectors in a flat format. Additionally, it stores each vector's + *

Stores the quantized vectors in a flat format. Additionally, it stores each vector's * corrective factors. At the end of the file, additional information is stored for vector ordinal * to centroid ordinal mapping and sparse vector information. * *

    *
  • For each vector: *
      - *
    • [byte] the binary quantized values, each byte holds 8 bits. + *
    • [byte] the quantized values. Each dimension may be up to 8 bits, and multiple + * dimensions may be packed into a single byte. *
    • [float] the optimized quantiles and an additional similarity dependent * corrective factor. - *
    • short the sum of the quantized components + *
    • [int] the sum of the quantized components *
    *
  • After the vectors, sparse vector information keeping track of monotonic blocks. *
* - *

.vemb (vector metadata) file

+ *

.vemq (vector metadata) file

* *

Stores the metadata for the vectors. This includes the number of vectors, the number of * dimensions, and file offset information. * *

    *
  • int the field number - *
  • int the vector encoding ordinal + *
  • int the vector encoding ordinal, i.e. the ordinal of {@link org.apache.lucene.index.VectorEncoding} written for validation at read time *
  • int the vector similarity ordinal *
  • vint the vector dimensions *
  • vlong the offset to the vector data in the .veb file @@ -87,43 +88,40 @@ *
  • The sparse vector information, if required, mapping vector ordinal to doc ID *
*/ -public class Lucene102BinaryQuantizedVectorsFormat extends FlatVectorsFormat { - - public static final byte QUERY_BITS = 4; - public static final byte INDEX_BITS = 1; - - public static final String BINARIZED_VECTOR_COMPONENT = "BVEC"; - public static final String NAME = "Lucene102BinaryQuantizedVectorsFormat"; +public class Lucene104ScalarQuantizedVectorsFormat extends FlatVectorsFormat { + public static final String QUANTIZED_VECTOR_COMPONENT = "QVEC"; + public static final String NAME = "Lucene104ScalarQuantizedVectorsFormat"; static final int VERSION_START = 0; static final int VERSION_CURRENT = VERSION_START; - static final String META_CODEC_NAME = "Lucene102BinaryQuantizedVectorsFormatMeta"; - static final String VECTOR_DATA_CODEC_NAME = "Lucene102BinaryQuantizedVectorsFormatData"; - static final String META_EXTENSION = "vemb"; - static final String VECTOR_DATA_EXTENSION = "veb"; + static final String META_CODEC_NAME = "Lucene104ScalarQuantizedVectorsFormatMeta"; + static final String VECTOR_DATA_CODEC_NAME = "Lucene104ScalarQuantizedVectorsFormatData"; + static final String META_EXTENSION = "vemq"; + static final String VECTOR_DATA_EXTENSION = "veq"; static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16; private static final FlatVectorsFormat rawVectorFormat = new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()); - private static final Lucene102BinaryFlatVectorsScorer scorer = - new Lucene102BinaryFlatVectorsScorer(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()); + private static final Lucene104ScalarQuantizedVectorScorer scorer = + new Lucene104ScalarQuantizedVectorScorer(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()); /** Creates a new instance with the default number of vectors per cluster. 
*/ - public Lucene102BinaryQuantizedVectorsFormat() { + public Lucene104ScalarQuantizedVectorsFormat() { super(NAME); } @Override public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { - return new Lucene102BinaryQuantizedVectorsWriter( + return new Lucene104ScalarQuantizedVectorsWriter( scorer, rawVectorFormat.fieldsWriter(state), state); } @Override public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException { - return new Lucene102BinaryQuantizedVectorsReader( - state, rawVectorFormat.fieldsReader(state), scorer); + throw new UnsupportedOperationException("XXX TODO"); + // return new Lucene102BinaryQuantizedVectorsReader( + // state, rawVectorFormat.fieldsReader(state), scorer); } @Override @@ -133,7 +131,7 @@ public int getMaxDimensions(String fieldName) { @Override public String toString() { - return "Lucene102BinaryQuantizedVectorsFormat(name=" + return "Lucene104ScalarQuantizedVectorsFormat(name=" + NAME + ", flatVectorScorer=" + scorer diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java new file mode 100644 index 000000000000..055831571ecc --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java @@ -0,0 +1,808 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.codecs.lucene104; + +import static org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT; +import static org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_COMPONENT; +import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance; +import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.discretize; +import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.packAsBinary; + +import java.io.Closeable; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.apache.lucene.index.DocsWithFieldSet; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.MergeState; +import org.apache.lucene.index.SegmentWriteState; 
+import org.apache.lucene.index.Sorter; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.internal.hppc.FloatArrayList; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.util.IOUtils; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; +import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; + +/** Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10 */ +public class Lucene104ScalarQuantizedVectorsWriter extends FlatVectorsWriter { + private static final long SHALLOW_RAM_BYTES_USED = + shallowSizeOfInstance(Lucene104ScalarQuantizedVectorsWriter.class); + + private final SegmentWriteState segmentWriteState; + private final List fields = new ArrayList<>(); + private final IndexOutput meta, vectorData; + private final FlatVectorsWriter rawVectorDelegate; + private final Lucene104ScalarQuantizedVectorScorer vectorsScorer; + private boolean finished; + + /** + * Sole constructor + * + * @param vectorsScorer the scorer to use for scoring vectors + */ + protected Lucene104ScalarQuantizedVectorsWriter( + Lucene104ScalarQuantizedVectorScorer vectorsScorer, + FlatVectorsWriter rawVectorDelegate, + SegmentWriteState state) + throws IOException { + super(vectorsScorer); + this.vectorsScorer = vectorsScorer; + this.segmentWriteState = state; + String metaFileName = + IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + Lucene104ScalarQuantizedVectorsFormat.META_EXTENSION); + + String binarizedVectorDataFileName = + IndexFileNames.segmentFileName( + state.segmentInfo.name, + 
state.segmentSuffix, + Lucene104ScalarQuantizedVectorsFormat.VECTOR_DATA_EXTENSION); + this.rawVectorDelegate = rawVectorDelegate; + try { + meta = state.directory.createOutput(metaFileName, state.context); + vectorData = state.directory.createOutput(binarizedVectorDataFileName, state.context); + + CodecUtil.writeIndexHeader( + meta, + Lucene104ScalarQuantizedVectorsFormat.META_CODEC_NAME, + Lucene104ScalarQuantizedVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + CodecUtil.writeIndexHeader( + vectorData, + Lucene104ScalarQuantizedVectorsFormat.VECTOR_DATA_CODEC_NAME, + Lucene104ScalarQuantizedVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + } catch (Throwable t) { + IOUtils.closeWhileSuppressingExceptions(t, this); + throw t; + } + } + + @Override + public FlatFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException { + FlatFieldVectorsWriter rawVectorDelegate = this.rawVectorDelegate.addField(fieldInfo); + if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + @SuppressWarnings("unchecked") + FieldWriter fieldWriter = + new FieldWriter(fieldInfo, (FlatFieldVectorsWriter) rawVectorDelegate); + fields.add(fieldWriter); + return fieldWriter; + } + return rawVectorDelegate; + } + + @Override + public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { + rawVectorDelegate.flush(maxDoc, sortMap); + for (FieldWriter field : fields) { + // after raw vectors are written, normalize vectors for clustering and quantization + if (VectorSimilarityFunction.COSINE == field.fieldInfo.getVectorSimilarityFunction()) { + field.normalizeVectors(); + } + final float[] clusterCenter; + int vectorCount = field.flatFieldVectorsWriter.getVectors().size(); + clusterCenter = new float[field.dimensionSums.length]; + if (vectorCount > 0) { + for (int i = 0; i < field.dimensionSums.length; i++) { + clusterCenter[i] = field.dimensionSums[i] / vectorCount; + } + if 
(VectorSimilarityFunction.COSINE == field.fieldInfo.getVectorSimilarityFunction()) { + VectorUtil.l2normalize(clusterCenter); + } + } + if (segmentWriteState.infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) { + segmentWriteState.infoStream.message( + QUANTIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount); + } + OptimizedScalarQuantizer quantizer = + new OptimizedScalarQuantizer(field.fieldInfo.getVectorSimilarityFunction()); + if (sortMap == null) { + writeField(field, clusterCenter, maxDoc, quantizer); + } else { + writeSortingField(field, clusterCenter, maxDoc, sortMap, quantizer); + } + field.finish(); + } + } + + private void writeField( + FieldWriter fieldData, float[] clusterCenter, int maxDoc, OptimizedScalarQuantizer quantizer) + throws IOException { + // write vector values + long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); + writeVectors(fieldData, clusterCenter, quantizer); + long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; + float centroidDp = + !fieldData.getVectors().isEmpty() ? 
VectorUtil.dotProduct(clusterCenter, clusterCenter) : 0; + + // XXX should probably include bits; check existing sq format + writeMeta( + fieldData.fieldInfo, + maxDoc, + vectorDataOffset, + vectorDataLength, + clusterCenter, + centroidDp, + fieldData.getDocsWithFieldSet()); + } + + private void writeVectors( + FieldWriter fieldData, float[] clusterCenter, OptimizedScalarQuantizer scalarQuantizer) + throws IOException { + byte[] vector = new byte[fieldData.fieldInfo.getVectorDimension()]; + for (int i = 0; i < fieldData.getVectors().size(); i++) { + float[] v = fieldData.getVectors().get(i); + // XXX properly parameterize bits + OptimizedScalarQuantizer.QuantizationResult corrections = + scalarQuantizer.scalarQuantize(v, vector, (byte) 8, clusterCenter); + vectorData.writeBytes(vector, vector.length); + vectorData.writeInt(Float.floatToIntBits(corrections.lowerInterval())); + vectorData.writeInt(Float.floatToIntBits(corrections.upperInterval())); + vectorData.writeInt(Float.floatToIntBits(corrections.additionalCorrection())); + vectorData.writeInt(corrections.quantizedComponentSum()); + } + } + + private void writeSortingField( + FieldWriter fieldData, + float[] clusterCenter, + int maxDoc, + Sorter.DocMap sortMap, + OptimizedScalarQuantizer scalarQuantizer) + throws IOException { + final int[] ordMap = + new int[fieldData.getDocsWithFieldSet().cardinality()]; // new ord to old ord + + DocsWithFieldSet newDocsWithField = new DocsWithFieldSet(); + mapOldOrdToNewOrd(fieldData.getDocsWithFieldSet(), sortMap, null, ordMap, newDocsWithField); + + // write vector values + long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); + writeSortedVectors(fieldData, clusterCenter, ordMap, scalarQuantizer); + long quantizedVectorLength = vectorData.getFilePointer() - vectorDataOffset; + + float centroidDp = VectorUtil.dotProduct(clusterCenter, clusterCenter); + writeMeta( + fieldData.fieldInfo, + maxDoc, + vectorDataOffset, + quantizedVectorLength, + clusterCenter, + 
centroidDp, + newDocsWithField); + } + + private void writeSortedVectors( + FieldWriter fieldData, + float[] clusterCenter, + int[] ordMap, + OptimizedScalarQuantizer scalarQuantizer) + throws IOException { + int discreteDims = discretize(fieldData.fieldInfo.getVectorDimension(), 64); + // XXX properly parameterize bits + byte[] quantizationScratch = new byte[discreteDims]; + byte[] vector = new byte[discreteDims / 8]; + for (int ordinal : ordMap) { + float[] v = fieldData.getVectors().get(ordinal); + OptimizedScalarQuantizer.QuantizationResult corrections = + scalarQuantizer.scalarQuantize(v, quantizationScratch, (byte) 8, clusterCenter); + packAsBinary(quantizationScratch, vector); + vectorData.writeBytes(vector, vector.length); + vectorData.writeInt(Float.floatToIntBits(corrections.lowerInterval())); + vectorData.writeInt(Float.floatToIntBits(corrections.upperInterval())); + vectorData.writeInt(Float.floatToIntBits(corrections.additionalCorrection())); + vectorData.writeInt(corrections.quantizedComponentSum()); + } + } + + private void writeMeta( + FieldInfo field, + int maxDoc, + long vectorDataOffset, + long vectorDataLength, + float[] clusterCenter, + float centroidDp, + DocsWithFieldSet docsWithField) + throws IOException { + meta.writeInt(field.number); + meta.writeInt(field.getVectorEncoding().ordinal()); + meta.writeInt(field.getVectorSimilarityFunction().ordinal()); + meta.writeVInt(field.getVectorDimension()); + meta.writeVLong(vectorDataOffset); + meta.writeVLong(vectorDataLength); + int count = docsWithField.cardinality(); + meta.writeVInt(count); + if (count > 0) { + final ByteBuffer buffer = + ByteBuffer.allocate(field.getVectorDimension() * Float.BYTES) + .order(ByteOrder.LITTLE_ENDIAN); + buffer.asFloatBuffer().put(clusterCenter); + meta.writeBytes(buffer.array(), buffer.array().length); + meta.writeInt(Float.floatToIntBits(centroidDp)); + } + OrdToDocDISIReaderConfiguration.writeStoredMeta( + DIRECT_MONOTONIC_BLOCK_SHIFT, meta, vectorData, count, 
maxDoc, docsWithField); + } + + @Override + public void finish() throws IOException { + if (finished) { + throw new IllegalStateException("already finished"); + } + finished = true; + rawVectorDelegate.finish(); + if (meta != null) { + // write end of fields marker + meta.writeInt(-1); + CodecUtil.writeFooter(meta); + } + if (vectorData != null) { + CodecUtil.writeFooter(vectorData); + } + } + + @Override + public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { + if (!fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + rawVectorDelegate.mergeOneField(fieldInfo, mergeState); + return; + } + + final float[] centroid; + final float[] mergedCentroid = new float[fieldInfo.getVectorDimension()]; + int vectorCount = mergeAndRecalculateCentroids(mergeState, fieldInfo, mergedCentroid); + // Don't need access to the random vectors, we can just use the merged + rawVectorDelegate.mergeOneField(fieldInfo, mergeState); + centroid = mergedCentroid; + if (segmentWriteState.infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) { + segmentWriteState.infoStream.message( + QUANTIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount); + } + FloatVectorValues floatVectorValues = + MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + floatVectorValues = new NormalizedFloatVectorValues(floatVectorValues); + } + QuantizedFloatVectorValues quantizedVectorValues = + new QuantizedFloatVectorValues( + floatVectorValues, + new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()), + centroid); + long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); + DocsWithFieldSet docsWithField = writeVectorData(vectorData, quantizedVectorValues); + long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; + float centroidDp = + docsWithField.cardinality() > 0 ? 
VectorUtil.dotProduct(centroid, centroid) : 0; + writeMeta( + fieldInfo, + segmentWriteState.segmentInfo.maxDoc(), + vectorDataOffset, + vectorDataLength, + centroid, + centroidDp, + docsWithField); + } + + static DocsWithFieldSet writeVectorData( + IndexOutput output, QuantizedByteVectorValues quantizedByteVectorValues) throws IOException { + DocsWithFieldSet docsWithField = new DocsWithFieldSet(); + KnnVectorValues.DocIndexIterator iterator = quantizedByteVectorValues.iterator(); + for (int docV = iterator.nextDoc(); docV != NO_MORE_DOCS; docV = iterator.nextDoc()) { + // write vector + byte[] binaryValue = quantizedByteVectorValues.vectorValue(iterator.index()); + output.writeBytes(binaryValue, binaryValue.length); + OptimizedScalarQuantizer.QuantizationResult corrections = + quantizedByteVectorValues.getCorrectiveTerms(iterator.index()); + output.writeInt(Float.floatToIntBits(corrections.lowerInterval())); + output.writeInt(Float.floatToIntBits(corrections.upperInterval())); + output.writeInt(Float.floatToIntBits(corrections.additionalCorrection())); + output.writeInt(corrections.quantizedComponentSum()); + docsWithField.add(docV); + } + return docsWithField; + } + + @Override + public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex( + FieldInfo fieldInfo, MergeState mergeState) throws IOException { + if (!fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + return rawVectorDelegate.mergeOneFieldToIndex(fieldInfo, mergeState); + } + + final float[] centroid; + final float cDotC; + final float[] mergedCentroid = new float[fieldInfo.getVectorDimension()]; + int vectorCount = mergeAndRecalculateCentroids(mergeState, fieldInfo, mergedCentroid); + + // Don't need access to the random vectors, we can just use the merged + rawVectorDelegate.mergeOneField(fieldInfo, mergeState); + centroid = mergedCentroid; + cDotC = vectorCount > 0 ? 
VectorUtil.dotProduct(centroid, centroid) : 0; + if (segmentWriteState.infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) { + segmentWriteState.infoStream.message( + QUANTIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount); + } + return mergeOneFieldToIndex(segmentWriteState, fieldInfo, mergeState, centroid, cDotC); + } + + private CloseableRandomVectorScorerSupplier mergeOneFieldToIndex( + SegmentWriteState segmentWriteState, + FieldInfo fieldInfo, + MergeState mergeState, + float[] centroid, + float cDotC) + throws IOException { + long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); + IndexOutput tempQuantizedVectorData = null; + IndexInput quantizedDataInput = null; + OptimizedScalarQuantizer quantizer = + new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); + try { + // XXX I guess we write to a temp file so that we can actually read it? that sucks. + tempQuantizedVectorData = + segmentWriteState.directory.createTempOutput( + vectorData.getName(), "temp", segmentWriteState.context); + final String tempQuantizedVectorName = tempQuantizedVectorData.getName(); + FloatVectorValues floatVectorValues = + MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + floatVectorValues = new NormalizedFloatVectorValues(floatVectorValues); + } + DocsWithFieldSet docsWithField = + writeVectorData( + tempQuantizedVectorData, + new QuantizedFloatVectorValues(floatVectorValues, quantizer, centroid)); + CodecUtil.writeFooter(tempQuantizedVectorData); + IOUtils.close(tempQuantizedVectorData); + quantizedDataInput = + segmentWriteState.directory.openInput(tempQuantizedVectorName, segmentWriteState.context); + vectorData.copyBytes( + quantizedDataInput, quantizedDataInput.length() - CodecUtil.footerLength()); + long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; + CodecUtil.retrieveChecksum(quantizedDataInput); + writeMeta( + fieldInfo, + 
segmentWriteState.segmentInfo.maxDoc(), + vectorDataOffset, + vectorDataLength, + centroid, + cDotC, + docsWithField); + + final IndexInput finalQuantizedDataInput = quantizedDataInput; + tempQuantizedVectorData = null; + quantizedDataInput = null; + + OffHeapScalarQuantizedVectorValues vectorValues = + new OffHeapScalarQuantizedVectorValues.DenseOffHeapVectorValues( + fieldInfo.getVectorDimension(), + docsWithField.cardinality(), + centroid, + cDotC, + quantizer, + fieldInfo.getVectorSimilarityFunction(), + vectorsScorer, + finalQuantizedDataInput); + RandomVectorScorerSupplier scorerSupplier = + vectorsScorer.getRandomVectorScorerSupplier( + fieldInfo.getVectorSimilarityFunction(), vectorValues); + return new QuantizedCloseableRandomVectorScorerSupplier( + scorerSupplier, + vectorValues, + () -> { + IOUtils.close(finalQuantizedDataInput); + IOUtils.deleteFilesIgnoringExceptions( + segmentWriteState.directory, tempQuantizedVectorName); + }); + } catch (Throwable t) { + IOUtils.closeWhileSuppressingExceptions(t, tempQuantizedVectorData, quantizedDataInput); + if (tempQuantizedVectorData != null) { + IOUtils.deleteFilesSuppressingExceptions( + t, segmentWriteState.directory, tempQuantizedVectorData.getName()); + } + throw t; + } + } + + @Override + public void close() throws IOException { + IOUtils.close(meta, vectorData, rawVectorDelegate); + } + + static float[] getCentroid(KnnVectorsReader vectorsReader, String fieldName) { + if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) { + vectorsReader = candidateReader.getFieldReader(fieldName); + } + // XXX not fixable yet + // if (vectorsReader instanceof Lucene102BinaryQuantizedVectorsReader reader) { + // return reader.getCentroid(fieldName); + // } + return null; + } + + static int mergeAndRecalculateCentroids( + MergeState mergeState, FieldInfo fieldInfo, float[] mergedCentroid) throws IOException { + boolean recalculate = false; + int totalVectorCount = 0; + for (int i = 0; i < 
mergeState.knnVectorsReaders.length; i++) { + KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i]; + if (knnVectorsReader == null + || knnVectorsReader.getFloatVectorValues(fieldInfo.name) == null) { + continue; + } + float[] centroid = getCentroid(knnVectorsReader, fieldInfo.name); + int vectorCount = knnVectorsReader.getFloatVectorValues(fieldInfo.name).size(); + if (vectorCount == 0) { + continue; + } + totalVectorCount += vectorCount; + // If any segment has no stored centroid, or has deleted docs, we must + // recalculate the merged centroid from the raw vectors + if (centroid == null || mergeState.liveDocs[i] != null) { + recalculate = true; + break; + } + for (int j = 0; j < centroid.length; j++) { + mergedCentroid[j] += centroid[j] * vectorCount; + } + } + if (recalculate) { + return calculateCentroid(mergeState, fieldInfo, mergedCentroid); + } else { + for (int j = 0; j < mergedCentroid.length; j++) { + mergedCentroid[j] = mergedCentroid[j] / totalVectorCount; + } + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + VectorUtil.l2normalize(mergedCentroid); + } + return totalVectorCount; + } + } + + static int calculateCentroid(MergeState mergeState, FieldInfo fieldInfo, float[] centroid) + throws IOException { + assert fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32); + // clear out the centroid + Arrays.fill(centroid, 0); + int count = 0; + for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) { + KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i]; + if (knnVectorsReader == null) continue; + FloatVectorValues vectorValues = + mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldInfo.name); + if (vectorValues == null) { + continue; + } + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + for (int doc = iterator.nextDoc(); + doc != DocIdSetIterator.NO_MORE_DOCS; + doc = iterator.nextDoc()) { + ++count; + float[] vector = 
vectorValues.vectorValue(iterator.index()); + for (int j = 0; j < vector.length; j++) { + centroid[j] += vector[j]; + } + } + } + if (count == 0) { + return count; + } + for (int i = 0; i < centroid.length; i++) { + centroid[i] /= count; + } + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + VectorUtil.l2normalize(centroid); + } + return count; + } + + @Override + public long ramBytesUsed() { + long total = SHALLOW_RAM_BYTES_USED; + for (FieldWriter field : fields) { + // the field tracks the delegate field usage + total += field.ramBytesUsed(); + } + return total; + } + + static class FieldWriter extends FlatFieldVectorsWriter { + private static final long SHALLOW_SIZE = shallowSizeOfInstance(FieldWriter.class); + private final FieldInfo fieldInfo; + private boolean finished; + private final FlatFieldVectorsWriter flatFieldVectorsWriter; + private final float[] dimensionSums; + private final FloatArrayList magnitudes = new FloatArrayList(); + + FieldWriter(FieldInfo fieldInfo, FlatFieldVectorsWriter flatFieldVectorsWriter) { + this.fieldInfo = fieldInfo; + this.flatFieldVectorsWriter = flatFieldVectorsWriter; + this.dimensionSums = new float[fieldInfo.getVectorDimension()]; + } + + @Override + public List getVectors() { + return flatFieldVectorsWriter.getVectors(); + } + + public void normalizeVectors() { + for (int i = 0; i < flatFieldVectorsWriter.getVectors().size(); i++) { + float[] vector = flatFieldVectorsWriter.getVectors().get(i); + float magnitude = magnitudes.get(i); + for (int j = 0; j < vector.length; j++) { + vector[j] /= magnitude; + } + } + } + + @Override + public DocsWithFieldSet getDocsWithFieldSet() { + return flatFieldVectorsWriter.getDocsWithFieldSet(); + } + + @Override + public void finish() throws IOException { + if (finished) { + return; + } + assert flatFieldVectorsWriter.isFinished(); + finished = true; + } + + @Override + public boolean isFinished() { + return finished && flatFieldVectorsWriter.isFinished(); + } + + @Override 
+ public void addValue(int docID, float[] vectorValue) throws IOException { + flatFieldVectorsWriter.addValue(docID, vectorValue); + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + float dp = VectorUtil.dotProduct(vectorValue, vectorValue); + float divisor = (float) Math.sqrt(dp); + magnitudes.add(divisor); + for (int i = 0; i < vectorValue.length; i++) { + dimensionSums[i] += (vectorValue[i] / divisor); + } + } else { + for (int i = 0; i < vectorValue.length; i++) { + dimensionSums[i] += vectorValue[i]; + } + } + } + + @Override + public float[] copyValue(float[] vectorValue) { + throw new UnsupportedOperationException(); + } + + @Override + public long ramBytesUsed() { + long size = SHALLOW_SIZE; + size += flatFieldVectorsWriter.ramBytesUsed(); + size += magnitudes.ramBytesUsed(); + return size; + } + } + + static class QuantizedFloatVectorValues extends QuantizedByteVectorValues { + private OptimizedScalarQuantizer.QuantizationResult corrections; + private final byte[] quantized; + private final float[] centroid; + private final FloatVectorValues values; + private final OptimizedScalarQuantizer quantizer; + + private int lastOrd = -1; + + QuantizedFloatVectorValues( + FloatVectorValues delegate, OptimizedScalarQuantizer quantizer, float[] centroid) { + this.values = delegate; + this.quantizer = quantizer; + this.quantized = new byte[delegate.dimension()]; + this.centroid = centroid; + } + + @Override + public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int ord) { + if (ord != lastOrd) { + throw new IllegalStateException( + "attempt to retrieve corrective terms for different ord " + + ord + + " than the quantization was done for: " + + lastOrd); + } + return corrections; + } + + @Override + public byte[] vectorValue(int ord) throws IOException { + if (ord != lastOrd) { + quantize(ord); + lastOrd = ord; + } + return quantized; + } + + @Override + public int dimension() { + return values.dimension(); + } + + @Override + public 
OptimizedScalarQuantizer getQuantizer() { + throw new UnsupportedOperationException(); + } + + @Override + public float[] getCentroid() throws IOException { + return centroid; + } + + @Override + public int size() { + return values.size(); + } + + @Override + public VectorScorer scorer(float[] target) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public QuantizedByteVectorValues copy() throws IOException { + return new QuantizedFloatVectorValues(values.copy(), quantizer, centroid); + } + + private void quantize(int ord) throws IOException { + // XXX properly parameterize bits + corrections = + quantizer.scalarQuantize(values.vectorValue(ord), quantized, (byte) 8, centroid); + } + + @Override + public DocIndexIterator iterator() { + return values.iterator(); + } + + @Override + public int ordToDoc(int ord) { + return values.ordToDoc(ord); + } + } + + static class QuantizedCloseableRandomVectorScorerSupplier + implements CloseableRandomVectorScorerSupplier { + private final RandomVectorScorerSupplier supplier; + private final KnnVectorValues vectorValues; + private final Closeable onClose; + + QuantizedCloseableRandomVectorScorerSupplier( + RandomVectorScorerSupplier supplier, KnnVectorValues vectorValues, Closeable onClose) { + this.supplier = supplier; + this.onClose = onClose; + this.vectorValues = vectorValues; + } + + @Override + public UpdateableRandomVectorScorer scorer() throws IOException { + return supplier.scorer(); + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return supplier.copy(); + } + + @Override + public void close() throws IOException { + onClose.close(); + } + + @Override + public int totalVectorCount() { + return vectorValues.size(); + } + } + + static final class NormalizedFloatVectorValues extends FloatVectorValues { + private final FloatVectorValues values; + private final float[] normalizedVector; + + NormalizedFloatVectorValues(FloatVectorValues values) { + 
this.values = values; + this.normalizedVector = new float[values.dimension()]; + } + + @Override + public int dimension() { + return values.dimension(); + } + + @Override + public int size() { + return values.size(); + } + + @Override + public int ordToDoc(int ord) { + return values.ordToDoc(ord); + } + + @Override + public float[] vectorValue(int ord) throws IOException { + System.arraycopy(values.vectorValue(ord), 0, normalizedVector, 0, normalizedVector.length); + VectorUtil.l2normalize(normalizedVector); + return normalizedVector; + } + + @Override + public DocIndexIterator iterator() { + return values.iterator(); + } + + @Override + public NormalizedFloatVectorValues copy() throws IOException { + return new NormalizedFloatVectorValues(values.copy()); + } + } +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/QuantizedByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/QuantizedByteVectorValues.java index e72960c78bdd..f497a277d50d 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/QuantizedByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/QuantizedByteVectorValues.java @@ -22,6 +22,7 @@ import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; +// XXX do I want to make this public? this also overlaps heavily with the binarized version. /** Scalar quantized byte vector values */ abstract class QuantizedByteVectorValues extends ByteVectorValues { @@ -70,6 +71,8 @@ public abstract OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(i @Override public abstract QuantizedByteVectorValues copy() throws IOException; + // XXX off heap overrides this. this is probably only used in one other spot so it should be + // abstract. 
float getCentroidDP() throws IOException { // this only gets executed on-merge float[] centroid = getCentroid(); From 89a4228833b46ede3731649f2c1e35ef2f7e25ab Mon Sep 17 00:00:00 2001 From: Trevor McCulloch Date: Mon, 8 Sep 2025 10:20:59 -0700 Subject: [PATCH 04/22] reader --- lucene/core/src/java/module-info.java | 1 + ...Lucene104ScalarQuantizedVectorsFormat.java | 5 +- ...Lucene104ScalarQuantizedVectorsReader.java | 437 ++++++++++++++++++ ...Lucene104ScalarQuantizedVectorsWriter.java | 7 +- 4 files changed, 443 insertions(+), 7 deletions(-) create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java diff --git a/lucene/core/src/java/module-info.java b/lucene/core/src/java/module-info.java index b9030b52a470..3956cd829121 100644 --- a/lucene/core/src/java/module-info.java +++ b/lucene/core/src/java/module-info.java @@ -33,6 +33,7 @@ exports org.apache.lucene.codecs.lucene102; exports org.apache.lucene.codecs.lucene103.blocktree; exports org.apache.lucene.codecs.lucene103; + exports org.apache.lucene.codecs.lucene104; exports org.apache.lucene.codecs.perfield; exports org.apache.lucene.codecs; exports org.apache.lucene.document; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java index dfd140ac9f83..c62a56c33a21 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java @@ -119,9 +119,8 @@ public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOExceptio @Override public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException { - throw new UnsupportedOperationException("XXX TODO"); - // return new Lucene102BinaryQuantizedVectorsReader( - // state, 
rawVectorFormat.fieldsReader(state), scorer); + return new Lucene104ScalarQuantizedVectorsReader( + state, rawVectorFormat.fieldsReader(state), scorer); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java new file mode 100644 index 000000000000..9d96b8a963fb --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java @@ -0,0 +1,437 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.codecs.lucene104; + +import static org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.VECTOR_DATA_EXTENSION; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readSimilarityFunction; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readVectorEncoding; +import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.discretize; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.CorruptIndexException; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.AcceptDocs; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.store.ChecksumIndexInput; +import org.apache.lucene.store.DataAccessHint; +import org.apache.lucene.store.FileDataHint; +import org.apache.lucene.store.FileTypeHint; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.IOUtils; +import org.apache.lucene.util.RamUsageEstimator; +import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; + +/** Reader for scalar quantized 
vectors in the Lucene 10.4 format. */
+class Lucene104ScalarQuantizedVectorsReader extends FlatVectorsReader {
+
+  private static final long SHALLOW_SIZE =
+      RamUsageEstimator.shallowSizeOfInstance(Lucene104ScalarQuantizedVectorsReader.class);
+
+  private final Map<String, FieldEntry> fields = new HashMap<>();
+  private final IndexInput quantizedVectorData;
+  private final FlatVectorsReader rawVectorsReader;
+  private final Lucene104ScalarQuantizedVectorScorer vectorScorer;
+
+  Lucene104ScalarQuantizedVectorsReader(
+      SegmentReadState state,
+      FlatVectorsReader rawVectorsReader,
+      Lucene104ScalarQuantizedVectorScorer vectorsScorer)
+      throws IOException {
+    super(vectorsScorer);
+    this.vectorScorer = vectorsScorer;
+    this.rawVectorsReader = rawVectorsReader;
+    int versionMeta = -1;
+    String metaFileName =
+        IndexFileNames.segmentFileName(
+            state.segmentInfo.name,
+            state.segmentSuffix,
+            Lucene104ScalarQuantizedVectorsFormat.META_EXTENSION);
+    try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName)) {
+      Throwable priorE = null;
+      try {
+        versionMeta =
+            CodecUtil.checkIndexHeader(
+                meta,
+                Lucene104ScalarQuantizedVectorsFormat.META_CODEC_NAME,
+                Lucene104ScalarQuantizedVectorsFormat.VERSION_START,
+                Lucene104ScalarQuantizedVectorsFormat.VERSION_CURRENT,
+                state.segmentInfo.getId(),
+                state.segmentSuffix);
+        readFields(meta, state.fieldInfos);
+      } catch (Throwable exception) {
+        priorE = exception;
+      } finally {
+        CodecUtil.checkFooter(meta, priorE);
+      }
+      quantizedVectorData =
+          openDataInput(
+              state,
+              versionMeta,
+              VECTOR_DATA_EXTENSION,
+              Lucene104ScalarQuantizedVectorsFormat.VECTOR_DATA_CODEC_NAME,
+              // Quantized vectors are accessed randomly from their node ID stored in the HNSW
+              // graph.
+ state.context.withHints( + FileTypeHint.DATA, FileDataHint.KNN_VECTORS, DataAccessHint.RANDOM)); + } catch (Throwable t) { + IOUtils.closeWhileSuppressingExceptions(t, this); + throw t; + } + } + + private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOException { + for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) { + FieldInfo info = infos.fieldInfo(fieldNumber); + if (info == null) { + throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta); + } + FieldEntry fieldEntry = readField(meta, info); + validateFieldEntry(info, fieldEntry); + fields.put(info.name, fieldEntry); + } + } + + static void validateFieldEntry(FieldInfo info, FieldEntry fieldEntry) { + int dimension = info.getVectorDimension(); + if (dimension != fieldEntry.dimension) { + throw new IllegalStateException( + "Inconsistent vector dimension for field=\"" + + info.name + + "\"; " + + dimension + + " != " + + fieldEntry.dimension); + } + + long numQuantizedVectorBytes = + Math.multiplyExact((dimension + (Float.BYTES * 3) + Integer.BYTES), (long) fieldEntry.size); + if (numQuantizedVectorBytes != fieldEntry.vectorDataLength) { + throw new IllegalStateException( + "vector data length " + + fieldEntry.vectorDataLength + + " not matching size = " + + fieldEntry.size + + " * (dims=" + + dimension + + " + 16" + + ") = " + + numQuantizedVectorBytes); + } + } + + @Override + public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException { + FieldEntry fi = fields.get(field); + if (fi == null) { + return null; + } + return vectorScorer.getRandomVectorScorer( + fi.similarityFunction, + OffHeapScalarQuantizedVectorValues.load( + fi.ordToDocDISIReaderConfiguration, + fi.dimension, + fi.size, + new OptimizedScalarQuantizer(fi.similarityFunction), + fi.similarityFunction, + vectorScorer, + fi.centroid, + fi.centroidDP, + fi.vectorDataOffset, + fi.vectorDataLength, + quantizedVectorData), + target); 
+ } + + @Override + public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException { + return rawVectorsReader.getRandomVectorScorer(field, target); + } + + @Override + public void checkIntegrity() throws IOException { + rawVectorsReader.checkIntegrity(); + CodecUtil.checksumEntireFile(quantizedVectorData); + } + + @Override + public FloatVectorValues getFloatVectorValues(String field) throws IOException { + FieldEntry fi = fields.get(field); + if (fi == null) { + return null; + } + if (fi.vectorEncoding != VectorEncoding.FLOAT32) { + throw new IllegalArgumentException( + "field=\"" + + field + + "\" is encoded as: " + + fi.vectorEncoding + + " expected: " + + VectorEncoding.FLOAT32); + } + OffHeapScalarQuantizedVectorValues sqvv = + OffHeapScalarQuantizedVectorValues.load( + fi.ordToDocDISIReaderConfiguration, + fi.dimension, + fi.size, + new OptimizedScalarQuantizer(fi.similarityFunction), + fi.similarityFunction, + vectorScorer, + fi.centroid, + fi.centroidDP, + fi.vectorDataOffset, + fi.vectorDataLength, + quantizedVectorData); + return new ScalarQuantizedVectorValues(rawVectorsReader.getFloatVectorValues(field), sqvv); + } + + @Override + public ByteVectorValues getByteVectorValues(String field) throws IOException { + return rawVectorsReader.getByteVectorValues(field); + } + + @Override + public void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) + throws IOException { + rawVectorsReader.search(field, target, knnCollector, acceptDocs); + } + + @Override + public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) + throws IOException { + if (knnCollector.k() == 0) return; + final RandomVectorScorer scorer = getRandomVectorScorer(field, target); + if (scorer == null) return; + OrdinalTranslatedKnnCollector collector = + new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc); + Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs.bits()); + 
for (int i = 0; i < scorer.maxOrd(); i++) {
+      if (acceptedOrds == null || acceptedOrds.get(i)) {
+        collector.collect(i, scorer.score(i));
+        collector.incVisitedCount(1);
+      }
+    }
+  }
+
+  @Override
+  public void close() throws IOException {
+    IOUtils.close(quantizedVectorData, rawVectorsReader);
+  }
+
+  @Override
+  public long ramBytesUsed() {
+    long size = SHALLOW_SIZE;
+    size +=
+        RamUsageEstimator.sizeOfMap(
+            fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class));
+    size += rawVectorsReader.ramBytesUsed();
+    return size;
+  }
+
+  @Override
+  public Map<String, Long> getOffHeapByteSize(FieldInfo fieldInfo) {
+    Objects.requireNonNull(fieldInfo);
+    var raw = rawVectorsReader.getOffHeapByteSize(fieldInfo);
+    var fieldEntry = fields.get(fieldInfo.name);
+    if (fieldEntry == null) {
+      assert fieldInfo.getVectorEncoding() == VectorEncoding.BYTE;
+      return raw;
+    }
+    var quant = Map.of(VECTOR_DATA_EXTENSION, fieldEntry.vectorDataLength());
+    return KnnVectorsReader.mergeOffHeapByteSizeMaps(raw, quant);
+  }
+
+  public float[] getCentroid(String field) {
+    FieldEntry fieldEntry = fields.get(field);
+    if (fieldEntry != null) {
+      return fieldEntry.centroid;
+    }
+    return null;
+  }
+
+  private static IndexInput openDataInput(
+      SegmentReadState state,
+      int versionMeta,
+      String fileExtension,
+      String codecName,
+      IOContext context)
+      throws IOException {
+    String fileName =
+        IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension);
+    IndexInput in = state.directory.openInput(fileName, context);
+    try {
+      int versionVectorData =
+          CodecUtil.checkIndexHeader(
+              in,
+              codecName,
+              Lucene104ScalarQuantizedVectorsFormat.VERSION_START,
+              Lucene104ScalarQuantizedVectorsFormat.VERSION_CURRENT,
+              state.segmentInfo.getId(),
+              state.segmentSuffix);
+      if (versionMeta != versionVectorData) {
+        throw new CorruptIndexException(
+            "Format versions mismatch: meta="
+                + versionMeta
+                + ", "
+                + codecName
+                + "="
+                + versionVectorData,
+            in);
+      }
+      
CodecUtil.retrieveChecksum(in); + return in; + } catch (Throwable t) { + IOUtils.closeWhileSuppressingExceptions(t, in); + throw t; + } + } + + private FieldEntry readField(IndexInput input, FieldInfo info) throws IOException { + VectorEncoding vectorEncoding = readVectorEncoding(input); + VectorSimilarityFunction similarityFunction = readSimilarityFunction(input); + if (similarityFunction != info.getVectorSimilarityFunction()) { + throw new IllegalStateException( + "Inconsistent vector similarity function for field=\"" + + info.name + + "\"; " + + similarityFunction + + " != " + + info.getVectorSimilarityFunction()); + } + return FieldEntry.create(input, vectorEncoding, info.getVectorSimilarityFunction()); + } + + private record FieldEntry( + VectorSimilarityFunction similarityFunction, + VectorEncoding vectorEncoding, + int dimension, + int descritizedDimension, + long vectorDataOffset, + long vectorDataLength, + int size, + float[] centroid, + float centroidDP, + OrdToDocDISIReaderConfiguration ordToDocDISIReaderConfiguration) { + + static FieldEntry create( + IndexInput input, + VectorEncoding vectorEncoding, + VectorSimilarityFunction similarityFunction) + throws IOException { + int dimension = input.readVInt(); + long vectorDataOffset = input.readVLong(); + long vectorDataLength = input.readVLong(); + int size = input.readVInt(); + final float[] centroid; + float centroidDP = 0; + if (size > 0) { + centroid = new float[dimension]; + input.readFloats(centroid, 0, dimension); + centroidDP = Float.intBitsToFloat(input.readInt()); + } else { + centroid = null; + } + OrdToDocDISIReaderConfiguration conf = + OrdToDocDISIReaderConfiguration.fromStoredMeta(input, size); + return new FieldEntry( + similarityFunction, + vectorEncoding, + dimension, + discretize(dimension, 64), + vectorDataOffset, + vectorDataLength, + size, + centroid, + centroidDP, + conf); + } + } + + /** Vector values holding row and quantized vector values */ + protected static final class 
ScalarQuantizedVectorValues extends FloatVectorValues { + private final FloatVectorValues rawVectorValues; + private final QuantizedByteVectorValues quantizedVectorValues; + + ScalarQuantizedVectorValues( + FloatVectorValues rawVectorValues, QuantizedByteVectorValues quantizedVectorValues) { + this.rawVectorValues = rawVectorValues; + this.quantizedVectorValues = quantizedVectorValues; + } + + @Override + public int dimension() { + return rawVectorValues.dimension(); + } + + @Override + public int size() { + return rawVectorValues.size(); + } + + @Override + public float[] vectorValue(int ord) throws IOException { + return rawVectorValues.vectorValue(ord); + } + + @Override + public ScalarQuantizedVectorValues copy() throws IOException { + return new ScalarQuantizedVectorValues(rawVectorValues.copy(), quantizedVectorValues.copy()); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + return rawVectorValues.getAcceptOrds(acceptDocs); + } + + @Override + public int ordToDoc(int ord) { + return rawVectorValues.ordToDoc(ord); + } + + @Override + public DocIndexIterator iterator() { + return rawVectorValues.iterator(); + } + + @Override + public VectorScorer scorer(float[] query) throws IOException { + return quantizedVectorValues.scorer(query); + } + + QuantizedByteVectorValues getQuantizedVectorValues() throws IOException { + return quantizedVectorValues; + } + } +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java index 055831571ecc..4a432e4fa9bb 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java @@ -473,10 +473,9 @@ static float[] getCentroid(KnnVectorsReader vectorsReader, String fieldName) { if (vectorsReader instanceof 
PerFieldKnnVectorsFormat.FieldsReader candidateReader) { vectorsReader = candidateReader.getFieldReader(fieldName); } - // XXX not fixable yet - // if (vectorsReader instanceof Lucene102BinaryQuantizedVectorsReader reader) { - // return reader.getCentroid(fieldName); - // } + if (vectorsReader instanceof Lucene104ScalarQuantizedVectorsReader reader) { + return reader.getCentroid(fieldName); + } return null; } From a55fee90d11ac595c9e451e767289040add16c7b Mon Sep 17 00:00:00 2001 From: Trevor McCulloch Date: Mon, 8 Sep 2025 11:21:16 -0700 Subject: [PATCH 05/22] wrap up and test flat vector format for sq8 --- lucene/core/src/java/module-info.java | 3 +- .../Lucene104ScalarQuantizedVectorScorer.java | 10 +- ...Lucene104ScalarQuantizedVectorsWriter.java | 9 +- .../org.apache.lucene.codecs.KnnVectorsFormat | 1 + ...Lucene104ScalarQuantizedVectorsFormat.java | 178 ++++++++++++++++++ 5 files changed, 192 insertions(+), 9 deletions(-) create mode 100644 lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java diff --git a/lucene/core/src/java/module-info.java b/lucene/core/src/java/module-info.java index 3956cd829121..0d931587b1a3 100644 --- a/lucene/core/src/java/module-info.java +++ b/lucene/core/src/java/module-info.java @@ -88,7 +88,8 @@ org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat, org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat, org.apache.lucene.codecs.lucene102.Lucene102HnswBinaryQuantizedVectorsFormat, - org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat; + org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat, + org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat; provides org.apache.lucene.codecs.PostingsFormat with org.apache.lucene.codecs.lucene103.Lucene103PostingsFormat; provides org.apache.lucene.index.SortFieldProvider with diff --git 
a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java index 5bdea5d95c2c..042aaf77abc5 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java @@ -1,5 +1,6 @@ package org.apache.lucene.codecs.lucene104; +import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; @@ -7,6 +8,7 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; @@ -39,6 +41,12 @@ public RandomVectorScorer getRandomVectorScorer( if (vectorValues instanceof QuantizedByteVectorValues qv) { OptimizedScalarQuantizer quantizer = qv.getQuantizer(); byte[] targetQuantized = new byte[target.length]; + // We make a copy as the quantization process mutates the input + float[] copy = ArrayUtil.copyOfSubArray(target, 0, target.length); + if (similarityFunction == COSINE) { + VectorUtil.l2normalize(copy); + } + target = copy; // XXX parameterize the number of bits var targetCorrectiveTerms = quantizer.scalarQuantize(target, targetQuantized, (byte) 8, qv.getCentroid()); @@ -63,7 +71,7 @@ public RandomVectorScorer getRandomVectorScorer( @Override public String toString() { - return "ScalarQuantizedVectorScorer(nonQuantizedDelegate=" + nonQuantizedDelegate + ")"; + return "Lucene104ScalarQuantizedVectorScorer(nonQuantizedDelegate=" + 
nonQuantizedDelegate + ")"; } private static final class ScalarQuantizedVectorScorerSupplier diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java index 4a432e4fa9bb..240583b810c2 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java @@ -21,8 +21,6 @@ import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance; -import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.discretize; -import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.packAsBinary; import java.io.Closeable; import java.io.IOException; @@ -238,15 +236,12 @@ private void writeSortedVectors( int[] ordMap, OptimizedScalarQuantizer scalarQuantizer) throws IOException { - int discreteDims = discretize(fieldData.fieldInfo.getVectorDimension(), 64); // XXX properly parameterize bits - byte[] quantizationScratch = new byte[discreteDims]; - byte[] vector = new byte[discreteDims / 8]; + byte[] vector = new byte[fieldData.fieldInfo.getVectorDimension()]; for (int ordinal : ordMap) { float[] v = fieldData.getVectors().get(ordinal); OptimizedScalarQuantizer.QuantizationResult corrections = - scalarQuantizer.scalarQuantize(v, quantizationScratch, (byte) 8, clusterCenter); - packAsBinary(quantizationScratch, vector); + scalarQuantizer.scalarQuantize(v, vector, (byte) 8, clusterCenter); vectorData.writeBytes(vector, vector.length); vectorData.writeInt(Float.floatToIntBits(corrections.lowerInterval())); vectorData.writeInt(Float.floatToIntBits(corrections.upperInterval())); diff --git 
a/lucene/core/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat b/lucene/core/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat index 0558fc8fef05..b4ea8f29e730 100644 --- a/lucene/core/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat +++ b/lucene/core/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat @@ -18,3 +18,4 @@ org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat org.apache.lucene.codecs.lucene102.Lucene102HnswBinaryQuantizedVectorsFormat org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat +org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java new file mode 100644 index 000000000000..957314444f3d --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.codecs.lucene104; + +import static java.lang.String.format; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.oneOf; + +import java.io.IOException; +import java.util.Locale; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.FilterCodec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.KnnFloatVectorQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; +import org.apache.lucene.tests.util.TestUtil; +import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; + +public class TestLucene104ScalarQuantizedVectorsFormat extends BaseKnnVectorsFormatTestCase { + + private static final KnnVectorsFormat FORMAT = new Lucene104ScalarQuantizedVectorsFormat(); + + @Override + protected Codec getCodec() { + return TestUtil.alwaysKnnVectorsFormat(FORMAT); + } + + public void testSearch() throws Exception { + String fieldName = "field"; + int numVectors = random().nextInt(99, 500); + int dims = random().nextInt(4, 65); + float[] vector = randomVector(dims); + 
VectorSimilarityFunction similarityFunction = randomSimilarity(); + KnnFloatVectorField knnField = new KnnFloatVectorField(fieldName, vector, similarityFunction); + IndexWriterConfig iwc = newIndexWriterConfig(); + try (Directory dir = newDirectory()) { + try (IndexWriter w = new IndexWriter(dir, iwc)) { + for (int i = 0; i < numVectors; i++) { + Document doc = new Document(); + knnField.setVectorValue(randomVector(dims)); + doc.add(knnField); + w.addDocument(doc); + } + w.commit(); + + try (IndexReader reader = DirectoryReader.open(w)) { + IndexSearcher searcher = new IndexSearcher(reader); + final int k = random().nextInt(5, 50); + float[] queryVector = randomVector(dims); + Query q = new KnnFloatVectorQuery(fieldName, queryVector, k); + TopDocs collectedDocs = searcher.search(q, k); + assertEquals(k, collectedDocs.totalHits.value()); + assertEquals(TotalHits.Relation.EQUAL_TO, collectedDocs.totalHits.relation()); + } + } + } + } + + public void testToString() { + FilterCodec customCodec = + new FilterCodec("foo", Codec.getDefault()) { + @Override + public KnnVectorsFormat knnVectorsFormat() { + return new Lucene104ScalarQuantizedVectorsFormat(); + } + }; + String expectedPattern = + "Lucene104ScalarQuantizedVectorsFormat(" + + "name=Lucene104ScalarQuantizedVectorsFormat, " + + "flatVectorScorer=Lucene104ScalarQuantizedVectorScorer(nonQuantizedDelegate=%s()), " + + "rawVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=%s()))"; + var defaultScorer = + format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer", "DefaultFlatVectorScorer"); + var memSegScorer = + format( + Locale.ROOT, + expectedPattern, + "Lucene99MemorySegmentFlatVectorsScorer", + "Lucene99MemorySegmentFlatVectorsScorer"); + assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer))); + } + + @Override + public void testRandomWithUpdatesAndGraph() { + // graph not supported + } + + @Override + public void testSearchWithVisitedLimit() { + // visited limit is 
not respected, as it is brute force search + } + + public void testQuantizedVectorsWriteAndRead() throws IOException { + String fieldName = "field"; + int numVectors = random().nextInt(99, 500); + int dims = random().nextInt(4, 65); + + float[] vector = randomVector(dims); + VectorSimilarityFunction similarityFunction = randomSimilarity(); + KnnFloatVectorField knnField = new KnnFloatVectorField(fieldName, vector, similarityFunction); + try (Directory dir = newDirectory()) { + try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { + for (int i = 0; i < numVectors; i++) { + Document doc = new Document(); + knnField.setVectorValue(randomVector(dims)); + doc.add(knnField); + w.addDocument(doc); + if (i % 101 == 0) { + w.commit(); + } + } + w.commit(); + w.forceMerge(1); + + try (IndexReader reader = DirectoryReader.open(w)) { + LeafReader r = getOnlyLeafReader(reader); + FloatVectorValues vectorValues = r.getFloatVectorValues(fieldName); + assertEquals(vectorValues.size(), numVectors); + QuantizedByteVectorValues qvectorValues = + ((Lucene104ScalarQuantizedVectorsReader.ScalarQuantizedVectorValues) vectorValues) + .getQuantizedVectorValues(); + float[] centroid = qvectorValues.getCentroid(); + assertEquals(centroid.length, dims); + + OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(similarityFunction); + byte[] expectedVector = new byte[dims]; + if (similarityFunction == VectorSimilarityFunction.COSINE) { + vectorValues = + new Lucene104ScalarQuantizedVectorsWriter.NormalizedFloatVectorValues(vectorValues); + } + KnnVectorValues.DocIndexIterator docIndexIterator = vectorValues.iterator(); + + while (docIndexIterator.nextDoc() != NO_MORE_DOCS) { + OptimizedScalarQuantizer.QuantizationResult corrections = + quantizer.scalarQuantize( + vectorValues.vectorValue(docIndexIterator.index()), + expectedVector, + // XXX FIXME + (byte)8, + centroid); + assertArrayEquals(expectedVector, qvectorValues.vectorValue(docIndexIterator.index())); + 
assertEquals(corrections, qvectorValues.getCorrectiveTerms(docIndexIterator.index())); + } + } + } + } + } +} From b7abe2d8d851104c80d177c62a4db10fc43015ec Mon Sep 17 00:00:00 2001 From: Trevor McCulloch Date: Mon, 8 Sep 2025 13:09:22 -0700 Subject: [PATCH 06/22] hnsw codec --- lucene/core/src/java/module-info.java | 3 +- ...ne104HnswScalarQuantizedVectorsFormat.java | 155 +++++++++++++++ .../org.apache.lucene.codecs.KnnVectorsFormat | 1 + ...ne104HnswScalarQuantizedVectorsFormat.java | 181 ++++++++++++++++++ 4 files changed, 339 insertions(+), 1 deletion(-) create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104HnswScalarQuantizedVectorsFormat.java create mode 100644 lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104HnswScalarQuantizedVectorsFormat.java diff --git a/lucene/core/src/java/module-info.java b/lucene/core/src/java/module-info.java index 0d931587b1a3..b54746cf3409 100644 --- a/lucene/core/src/java/module-info.java +++ b/lucene/core/src/java/module-info.java @@ -89,7 +89,8 @@ org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat, org.apache.lucene.codecs.lucene102.Lucene102HnswBinaryQuantizedVectorsFormat, org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat, - org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat; + org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat, + org.apache.lucene.codecs.lucene104.Lucene104HnswScalarQuantizedVectorsFormat; provides org.apache.lucene.codecs.PostingsFormat with org.apache.lucene.codecs.lucene103.Lucene103PostingsFormat; provides org.apache.lucene.index.SortFieldProvider with diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104HnswScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104HnswScalarQuantizedVectorsFormat.java new file mode 100644 index 000000000000..b90c697e7846 --- /dev/null +++ 
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.codecs.lucene104;

import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_BEAM_WIDTH;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_MAX_CONN;

import java.io.IOException;
import java.util.concurrent.ExecutorService;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.search.TaskExecutor;
import org.apache.lucene.util.hnsw.HnswGraph;

/**
 * A vectors format that uses an HNSW graph to store and search for vectors. Vectors are scalar
 * quantized using {@link Lucene104ScalarQuantizedVectorsFormat} before being stored in the graph.
 */
public class Lucene104HnswScalarQuantizedVectorsFormat extends KnnVectorsFormat {

  /**
   * Codec/SPI name. This string is written into segment metadata and used for SPI lookup, so it
   * must identify this format (the original said "Binary", a copy/paste slip from the 102 format,
   * which contradicted {@link #toString()} and the class itself).
   */
  public static final String NAME = "Lucene104HnswScalarQuantizedVectorsFormat";

  /**
   * Controls how many of the nearest neighbor candidates are connected to the new node. Defaults to
   * {@link Lucene99HnswVectorsFormat#DEFAULT_MAX_CONN}. See {@link HnswGraph} for more details.
   */
  private final int maxConn;

  /**
   * The number of candidate neighbors to track while searching the graph for each newly inserted
   * node. Defaults to {@link Lucene99HnswVectorsFormat#DEFAULT_BEAM_WIDTH}. See {@link HnswGraph}
   * for details.
   */
  private final int beamWidth;

  /** The format for storing, reading, merging vectors on disk. */
  private static final FlatVectorsFormat flatVectorsFormat =
      new Lucene104ScalarQuantizedVectorsFormat();

  /** Number of merge workers; when 1, merging is single-threaded and needs no executor. */
  private final int numMergeWorkers;

  /** Executor used by writers during merge, or null for single-threaded merging. */
  private final TaskExecutor mergeExec;

  /** Constructs a format using default graph construction parameters. */
  public Lucene104HnswScalarQuantizedVectorsFormat() {
    this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null);
  }

  /**
   * Constructs a format using the given graph construction parameters.
   *
   * @param maxConn the maximum number of connections to a node in the HNSW graph
   * @param beamWidth the size of the queue maintained during graph construction
   */
  public Lucene104HnswScalarQuantizedVectorsFormat(int maxConn, int beamWidth) {
    this(maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null);
  }

  /**
   * Constructs a format using the given graph construction parameters and scalar quantization.
   *
   * @param maxConn the maximum number of connections to a node in the HNSW graph
   * @param beamWidth the size of the queue maintained during graph construction
   * @param numMergeWorkers number of workers (threads) that will be used when doing merge. If
   *     larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec
   * @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are
   *     generated by this format to do the merge
   * @throws IllegalArgumentException if maxConn or beamWidth is out of range, or an executor is
   *     supplied for single-threaded merging
   */
  public Lucene104HnswScalarQuantizedVectorsFormat(
      int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) {
    super(NAME);
    if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) {
      throw new IllegalArgumentException(
          "maxConn must be positive and less than or equal to "
              + MAXIMUM_MAX_CONN
              + "; maxConn="
              + maxConn);
    }
    if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) {
      throw new IllegalArgumentException(
          "beamWidth must be positive and less than or equal to "
              + MAXIMUM_BEAM_WIDTH
              + "; beamWidth="
              + beamWidth);
    }
    this.maxConn = maxConn;
    this.beamWidth = beamWidth;
    if (numMergeWorkers == 1 && mergeExec != null) {
      throw new IllegalArgumentException(
          "No executor service is needed as we'll use single thread to merge");
    }
    this.numMergeWorkers = numMergeWorkers;
    if (mergeExec != null) {
      this.mergeExec = new TaskExecutor(mergeExec);
    } else {
      this.mergeExec = null;
    }
  }

  @Override
  public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
    // Graph building is delegated to the 9.9 HNSW writer over the quantized flat storage.
    return new Lucene99HnswVectorsWriter(
        state,
        maxConn,
        beamWidth,
        flatVectorsFormat.fieldsWriter(state),
        numMergeWorkers,
        mergeExec);
  }

  @Override
  public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
    return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state));
  }

  @Override
  public int getMaxDimensions(String fieldName) {
    // Same fixed limit as the other Lucene KNN formats.
    return 1024;
  }

  @Override
  public String toString() {
    // Use NAME so the reported name cannot drift from the registered SPI name.
    return "Lucene104HnswScalarQuantizedVectorsFormat(name="
        + NAME
        + ", maxConn="
        + maxConn
        + ", beamWidth="
        + beamWidth
        + ", flatVectorFormat="
        + flatVectorsFormat
        + ")";
  }
}
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.codecs.lucene104; + +import static java.lang.String.format; +import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.discretize; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.oneOf; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Locale; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.FilterCodec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.CodecReader; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.AcceptDocs; +import org.apache.lucene.search.TopDocs; +import 
org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; +import org.apache.lucene.tests.util.TestUtil; +import org.apache.lucene.util.SameThreadExecutorService; + +public class TestLucene104HnswScalarQuantizedVectorsFormat extends BaseKnnVectorsFormatTestCase { + + private static final KnnVectorsFormat FORMAT = new Lucene104HnswScalarQuantizedVectorsFormat(); + + @Override + protected Codec getCodec() { + return TestUtil.alwaysKnnVectorsFormat(FORMAT); + } + + public void testToString() { + FilterCodec customCodec = + new FilterCodec("foo", Codec.getDefault()) { + @Override + public KnnVectorsFormat knnVectorsFormat() { + return new Lucene104HnswScalarQuantizedVectorsFormat(10, 20, 1, null); + } + }; + String expectedPattern = + "Lucene104HnswScalarQuantizedVectorsFormat(name=Lucene104HnswScalarQuantizedVectorsFormat, maxConn=10, beamWidth=20," + + " flatVectorFormat=Lucene104ScalarQuantizedVectorsFormat(name=Lucene104ScalarQuantizedVectorsFormat," + + " flatVectorScorer=Lucene104ScalarQuantizedVectorScorer(nonQuantizedDelegate=%s())," + + " rawVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=%s())))"; + + var defaultScorer = + format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer", "DefaultFlatVectorScorer"); + var memSegScorer = + format( + Locale.ROOT, + expectedPattern, + "Lucene99MemorySegmentFlatVectorsScorer", + "Lucene99MemorySegmentFlatVectorsScorer"); + assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer))); + } + + public void testSingleVectorCase() throws Exception { + float[] vector = randomVector(random().nextInt(12, 500)); + for (VectorSimilarityFunction similarityFunction : VectorSimilarityFunction.values()) { + try (Directory dir = newDirectory(); + IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { + Document doc = new Document(); + doc.add(new KnnFloatVectorField("f", vector, similarityFunction)); + w.addDocument(doc); + w.commit(); + 
try (IndexReader reader = DirectoryReader.open(w)) { + LeafReader r = getOnlyLeafReader(reader); + FloatVectorValues vectorValues = r.getFloatVectorValues("f"); + KnnVectorValues.DocIndexIterator docIndexIterator = vectorValues.iterator(); + assert (vectorValues.size() == 1); + while (docIndexIterator.nextDoc() != NO_MORE_DOCS) { + assertArrayEquals(vector, vectorValues.vectorValue(docIndexIterator.index()), 0.00001f); + } + float[] randomVector = randomVector(vector.length); + float trueScore = similarityFunction.compare(vector, randomVector); + TopDocs td = + r.searchNearestVectors( + "f", + randomVector, + 1, + AcceptDocs.fromLiveDocs(null, r.maxDoc()), + Integer.MAX_VALUE); + assertEquals(1, td.totalHits.value()); + assertTrue(td.scoreDocs[0].score >= 0); + // When it's the only vector in a segment, the score should be very close to the true + // score + assertEquals(trueScore, td.scoreDocs[0].score, 0.01f); + } + } + } + } + + public void testLimits() { + expectThrows( + IllegalArgumentException.class, + () -> new Lucene104HnswScalarQuantizedVectorsFormat(-1, 20)); + expectThrows( + IllegalArgumentException.class, () -> new Lucene104HnswScalarQuantizedVectorsFormat(0, 20)); + expectThrows( + IllegalArgumentException.class, () -> new Lucene104HnswScalarQuantizedVectorsFormat(20, 0)); + expectThrows( + IllegalArgumentException.class, + () -> new Lucene104HnswScalarQuantizedVectorsFormat(20, -1)); + expectThrows( + IllegalArgumentException.class, + () -> new Lucene104HnswScalarQuantizedVectorsFormat(512 + 1, 20)); + expectThrows( + IllegalArgumentException.class, + () -> new Lucene104HnswScalarQuantizedVectorsFormat(20, 3201)); + expectThrows( + IllegalArgumentException.class, + () -> + new Lucene104HnswScalarQuantizedVectorsFormat( + 20, 100, 1, new SameThreadExecutorService())); + } + + // Ensures that all expected vector similarity functions are translatable in the format. 
+ public void testVectorSimilarityFuncs() { + // This does not necessarily have to be all similarity functions, but + // differences should be considered carefully. + var expectedValues = Arrays.stream(VectorSimilarityFunction.values()).toList(); + assertEquals(Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS, expectedValues); + } + + public void testSimpleOffHeapSize() throws IOException { + float[] vector = randomVector(random().nextInt(12, 500)); + try (Directory dir = newDirectory(); + IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { + Document doc = new Document(); + doc.add(new KnnFloatVectorField("f", vector, DOT_PRODUCT)); + w.addDocument(doc); + w.commit(); + try (IndexReader reader = DirectoryReader.open(w)) { + LeafReader r = getOnlyLeafReader(reader); + if (r instanceof CodecReader codecReader) { + KnnVectorsReader knnVectorsReader = codecReader.getVectorReader(); + if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader fieldsReader) { + knnVectorsReader = fieldsReader.getFieldReader("f"); + } + var fieldInfo = r.getFieldInfos().fieldInfo("f"); + var offHeap = knnVectorsReader.getOffHeapByteSize(fieldInfo); + assertEquals(vector.length * Float.BYTES, (long) offHeap.get("vec")); + assertEquals(1L, (long) offHeap.get("vex")); + long corrections = Float.BYTES + Float.BYTES + Float.BYTES + Integer.BYTES; + long expected = fieldInfo.getVectorDimension() + corrections; + assertEquals(expected, (long) offHeap.get("veq")); + assertEquals(3, offHeap.size()); + } + } + } + } +} From 7cec9e802afe6cd282f243ceca53e22ea1277faa Mon Sep 17 00:00:00 2001 From: Trevor McCulloch Date: Mon, 8 Sep 2025 13:54:01 -0700 Subject: [PATCH 07/22] enum for scalar encoding --- ...ne104HnswScalarQuantizedVectorsFormat.java | 1 - .../Lucene104ScalarQuantizedVectorScorer.java | 4 +- ...Lucene104ScalarQuantizedVectorsFormat.java | 71 +++++++++++++++++-- ...ne104HnswScalarQuantizedVectorsFormat.java | 2 +- ...Lucene104ScalarQuantizedVectorsFormat.java | 3 +- 5 
files changed, 72 insertions(+), 9 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104HnswScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104HnswScalarQuantizedVectorsFormat.java index b90c697e7846..b24dbd8df0b9 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104HnswScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104HnswScalarQuantizedVectorsFormat.java @@ -28,7 +28,6 @@ import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsWriter; import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; -import org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java index 042aaf77abc5..5e57e80d4e26 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java @@ -71,7 +71,9 @@ public RandomVectorScorer getRandomVectorScorer( @Override public String toString() { - return "Lucene104ScalarQuantizedVectorScorer(nonQuantizedDelegate=" + nonQuantizedDelegate + ")"; + return "Lucene104ScalarQuantizedVectorScorer(nonQuantizedDelegate=" + + nonQuantizedDelegate + + ")"; } private static final class ScalarQuantizedVectorScorerSupplier diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java 
b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java index c62a56c33a21..b67c062da3be 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java @@ -17,6 +17,7 @@ package org.apache.lucene.codecs.lucene104; import java.io.IOException; +import java.util.Optional; import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; import org.apache.lucene.codecs.hnsw.FlatVectorsReader; @@ -77,11 +78,12 @@ * *
    *
  • int the field number - *
  • int the vector encoding ordinal XXX wut is this? + *
  • int the vector encoding ordinal *
  • int the vector similarity ordinal - *
  • vint the vector dimensions - *
  • vlong the offset to the vector data in the .veb file - *
  • vlong the length of the vector data in the .veb file + *
  • vint the vector dimensions XXX encode a value indicating scalar encoding (8-bit or + * 4-bit packed). + *
  • vlong the offset to the vector data in the .veq file + *
  • vlong the length of the vector data in the .veq file *
  • vint the number of vectors *
  • [float] the centroid *
  • float the centroid square magnitude @@ -106,9 +108,66 @@ public class Lucene104ScalarQuantizedVectorsFormat extends FlatVectorsFormat { private static final Lucene104ScalarQuantizedVectorScorer scorer = new Lucene104ScalarQuantizedVectorScorer(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()); - /** Creates a new instance with the default number of vectors per cluster. */ + private final ScalarEncoding encoding; + + /** + * Allowed encodings for scalar quantization. + * + *

    This specifies how many bits are used per dimension and also dictates packing of dimensions + * into a byte stream. + */ + public enum ScalarEncoding { + /** Each dimension is quantized to 8 bits and treated as an unsigned value. */ + UNSIGNED_BYTE(0, (byte) 8), + /** Each dimension is quantized to 4 bits two values are packed into each output byte. */ + PACKED_NIBBLE(1, (byte) 4); + + /** The number used to identify this encoding on the wire, rather than relying on ordinal. */ + private int wireNumber; + + private byte bits; + private int dimensionsPerByte; + + ScalarEncoding(int wireNumber, byte bits) { + this.wireNumber = wireNumber; + this.bits = bits; + assert 8 % bits == 0; + this.dimensionsPerByte = 8 / bits; + } + + int getWireNumber() { + return wireNumber; + } + + int getDimensionsPerByte() { + return dimensionsPerByte; + } + + /** Return the number of bits used per dimension. */ + public byte getBits() { + return bits; + } + + /** Returns the encoding for the given wire number, or empty if unknown. */ + public static Optional fromWireNumber(int wireNumber) { + for (ScalarEncoding encoding : values()) { + if (encoding.wireNumber == wireNumber) { + return Optional.of(encoding); + } + } + return Optional.empty(); + } + } + + /** Creates a new instance with UNSIGNED_BYTE encoding. */ public Lucene104ScalarQuantizedVectorsFormat() { + this(ScalarEncoding.UNSIGNED_BYTE); + } + + /** Creates a new instance with the chosen encoding. 
*/ + public Lucene104ScalarQuantizedVectorsFormat(ScalarEncoding encoding) { super(NAME); + this.encoding = encoding; } @Override @@ -132,6 +191,8 @@ public int getMaxDimensions(String fieldName) { public String toString() { return "Lucene104ScalarQuantizedVectorsFormat(name=" + NAME + + ", encoding=" + + encoding + ", flatVectorScorer=" + scorer + ", rawVectorFormat=" diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104HnswScalarQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104HnswScalarQuantizedVectorsFormat.java index bb624e5117c4..7870a6a02d8e 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104HnswScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104HnswScalarQuantizedVectorsFormat.java @@ -19,7 +19,6 @@ import static java.lang.String.format; import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; -import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.discretize; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.oneOf; @@ -69,6 +68,7 @@ public KnnVectorsFormat knnVectorsFormat() { String expectedPattern = "Lucene104HnswScalarQuantizedVectorsFormat(name=Lucene104HnswScalarQuantizedVectorsFormat, maxConn=10, beamWidth=20," + " flatVectorFormat=Lucene104ScalarQuantizedVectorsFormat(name=Lucene104ScalarQuantizedVectorsFormat," + + " encoding=UNSIGNED_BYTE," + " flatVectorScorer=Lucene104ScalarQuantizedVectorScorer(nonQuantizedDelegate=%s())," + " rawVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=%s())))"; diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java index 957314444f3d..4cdaa76600e1 100644 
--- a/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java @@ -97,6 +97,7 @@ public KnnVectorsFormat knnVectorsFormat() { String expectedPattern = "Lucene104ScalarQuantizedVectorsFormat(" + "name=Lucene104ScalarQuantizedVectorsFormat, " + + "encoding=UNSIGNED_BYTE, " + "flatVectorScorer=Lucene104ScalarQuantizedVectorScorer(nonQuantizedDelegate=%s()), " + "rawVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=%s()))"; var defaultScorer = @@ -166,7 +167,7 @@ public void testQuantizedVectorsWriteAndRead() throws IOException { vectorValues.vectorValue(docIndexIterator.index()), expectedVector, // XXX FIXME - (byte)8, + (byte) 8, centroid); assertArrayEquals(expectedVector, qvectorValues.vectorValue(docIndexIterator.index())); assertEquals(corrections, qvectorValues.getCorrectiveTerms(docIndexIterator.index())); From 7b76f3d1c9508f79ae36ea1568ca28f07c0785eb Mon Sep 17 00:00:00 2001 From: Trevor McCulloch Date: Mon, 8 Sep 2025 14:38:27 -0700 Subject: [PATCH 08/22] fix most of the write path --- ...ne104HnswScalarQuantizedVectorsFormat.java | 11 ++--- ...Lucene104ScalarQuantizedVectorsFormat.java | 6 +-- ...Lucene104ScalarQuantizedVectorsReader.java | 11 +++-- ...Lucene104ScalarQuantizedVectorsWriter.java | 42 ++++++++++++------- .../OffHeapScalarQuantizedVectorValues.java | 39 ++++++++++++----- .../lucene104/QuantizedByteVectorValues.java | 4 +- ...ne104HnswScalarQuantizedVectorsFormat.java | 6 ++- 7 files changed, 79 insertions(+), 40 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104HnswScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104HnswScalarQuantizedVectorsFormat.java index b24dbd8df0b9..88fca83d2923 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104HnswScalarQuantizedVectorsFormat.java 
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104HnswScalarQuantizedVectorsFormat.java @@ -28,6 +28,7 @@ import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsWriter; import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter; @@ -58,15 +59,14 @@ public class Lucene104HnswScalarQuantizedVectorsFormat extends KnnVectorsFormat private final int beamWidth; /** The format for storing, reading, merging vectors on disk */ - private static final FlatVectorsFormat flatVectorsFormat = - new Lucene104ScalarQuantizedVectorsFormat(); + private final Lucene104ScalarQuantizedVectorsFormat flatVectorsFormat; private final int numMergeWorkers; private final TaskExecutor mergeExec; /** Constructs a format using default graph construction parameters */ public Lucene104HnswScalarQuantizedVectorsFormat() { - this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null); + this(ScalarEncoding.UNSIGNED_BYTE, DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null); } /** @@ -76,7 +76,7 @@ public Lucene104HnswScalarQuantizedVectorsFormat() { * @param beamWidth the size of the queue maintained during graph construction. 
*/ public Lucene104HnswScalarQuantizedVectorsFormat(int maxConn, int beamWidth) { - this(maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null); + this(ScalarEncoding.UNSIGNED_BYTE, maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null); } /** @@ -90,8 +90,9 @@ public Lucene104HnswScalarQuantizedVectorsFormat(int maxConn, int beamWidth) { * generated by this format to do the merge */ public Lucene104HnswScalarQuantizedVectorsFormat( - int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) { + ScalarEncoding encoding, int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) { super(NAME); + flatVectorsFormat = new Lucene104ScalarQuantizedVectorsFormat(encoding); if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) { throw new IllegalArgumentException( "maxConn must be positive and less than or equal to " diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java index b67c062da3be..7daef6bcc78a 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java @@ -80,11 +80,11 @@ *

  • int the field number *
  • int the vector encoding ordinal *
  • int the vector similarity ordinal - *
  • vint the vector dimensions XXX encode a value indicating scalar encoding (8-bit or - * 4-bit packed). + *
  • vint the vector dimensions *
  • vlong the offset to the vector data in the .veq file *
  • vlong the length of the vector data in the .veq file *
  • vint the number of vectors + *
  • vint the wire number for ScalarEncoding *
  • [float] the centroid *
  • float the centroid square magnitude *
  • The sparse vector information, if required, mapping vector ordinal to doc ID @@ -173,7 +173,7 @@ public Lucene104ScalarQuantizedVectorsFormat(ScalarEncoding encoding) { @Override public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { return new Lucene104ScalarQuantizedVectorsWriter( - scorer, rawVectorFormat.fieldsWriter(state), state); + state, encoding, rawVectorFormat.fieldsWriter(state), scorer); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java index 9d96b8a963fb..50aa07a76f8f 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java @@ -19,7 +19,6 @@ import static org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.VECTOR_DATA_EXTENSION; import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readSimilarityFunction; import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readVectorEncoding; -import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.discretize; import java.io.IOException; import java.util.HashMap; @@ -28,6 +27,7 @@ import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding; import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.CorruptIndexException; @@ -165,6 +165,7 @@ public RandomVectorScorer getRandomVectorScorer(String field, float[] target) th fi.dimension, fi.size, new OptimizedScalarQuantizer(fi.similarityFunction), + 
fi.scalarEncoding, fi.similarityFunction, vectorScorer, fi.centroid, @@ -207,6 +208,7 @@ public FloatVectorValues getFloatVectorValues(String field) throws IOException { fi.dimension, fi.size, new OptimizedScalarQuantizer(fi.similarityFunction), + fi.scalarEncoding, fi.similarityFunction, vectorScorer, fi.centroid, @@ -337,10 +339,10 @@ private record FieldEntry( VectorSimilarityFunction similarityFunction, VectorEncoding vectorEncoding, int dimension, - int descritizedDimension, long vectorDataOffset, long vectorDataLength, int size, + ScalarEncoding scalarEncoding, float[] centroid, float centroidDP, OrdToDocDISIReaderConfiguration ordToDocDISIReaderConfiguration) { @@ -356,7 +358,10 @@ static FieldEntry create( int size = input.readVInt(); final float[] centroid; float centroidDP = 0; + ScalarEncoding scalarEncoding = ScalarEncoding.UNSIGNED_BYTE; if (size > 0) { + int wireNumber = input.readVInt(); + scalarEncoding = ScalarEncoding.fromWireNumber(wireNumber).orElseThrow(() -> new IllegalStateException("Could not get ScalarEncoding from wire number: " + wireNumber)); centroid = new float[dimension]; input.readFloats(centroid, 0, dimension); centroidDP = Float.intBitsToFloat(input.readInt()); @@ -369,10 +374,10 @@ static FieldEntry create( similarityFunction, vectorEncoding, dimension, - discretize(dimension, 64), vectorDataOffset, vectorDataLength, size, + scalarEncoding, centroid, centroidDP, conf); diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java index 240583b810c2..c9d6921222cc 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java @@ -33,6 +33,7 @@ import org.apache.lucene.codecs.KnnVectorsReader; import 
org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding; import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.index.DocsWithFieldSet; @@ -65,6 +66,7 @@ public class Lucene104ScalarQuantizedVectorsWriter extends FlatVectorsWriter { private final SegmentWriteState segmentWriteState; private final List fields = new ArrayList<>(); private final IndexOutput meta, vectorData; + private final ScalarEncoding encoding; private final FlatVectorsWriter rawVectorDelegate; private final Lucene104ScalarQuantizedVectorScorer vectorsScorer; private boolean finished; @@ -75,11 +77,13 @@ public class Lucene104ScalarQuantizedVectorsWriter extends FlatVectorsWriter { * @param vectorsScorer the scorer to use for scoring vectors */ protected Lucene104ScalarQuantizedVectorsWriter( - Lucene104ScalarQuantizedVectorScorer vectorsScorer, + SegmentWriteState state, + ScalarEncoding encoding, FlatVectorsWriter rawVectorDelegate, - SegmentWriteState state) + Lucene104ScalarQuantizedVectorScorer vectorsScorer) throws IOException { super(vectorsScorer); + this.encoding = encoding; this.vectorsScorer = vectorsScorer; this.segmentWriteState = state; String metaFileName = @@ -88,7 +92,7 @@ protected Lucene104ScalarQuantizedVectorsWriter( state.segmentSuffix, Lucene104ScalarQuantizedVectorsFormat.META_EXTENSION); - String binarizedVectorDataFileName = + String vectorDataFileName = IndexFileNames.segmentFileName( state.segmentInfo.name, state.segmentSuffix, @@ -96,7 +100,7 @@ protected Lucene104ScalarQuantizedVectorsWriter( this.rawVectorDelegate = rawVectorDelegate; try { meta = state.directory.createOutput(metaFileName, state.context); - vectorData = state.directory.createOutput(binarizedVectorDataFileName, state.context); + vectorData = 
state.directory.createOutput(vectorDataFileName, state.context); CodecUtil.writeIndexHeader( meta, @@ -173,7 +177,6 @@ private void writeField( float centroidDp = !fieldData.getVectors().isEmpty() ? VectorUtil.dotProduct(clusterCenter, clusterCenter) : 0; - // XXX should probably include bits; check existing sq format writeMeta( fieldData.fieldInfo, maxDoc, @@ -190,9 +193,9 @@ private void writeVectors( byte[] vector = new byte[fieldData.fieldInfo.getVectorDimension()]; for (int i = 0; i < fieldData.getVectors().size(); i++) { float[] v = fieldData.getVectors().get(i); - // XXX properly parameterize bits + // XXX must pack PACKED_NIBBLE OptimizedScalarQuantizer.QuantizationResult corrections = - scalarQuantizer.scalarQuantize(v, vector, (byte) 8, clusterCenter); + scalarQuantizer.scalarQuantize(v, vector, encoding.getBits(), clusterCenter); vectorData.writeBytes(vector, vector.length); vectorData.writeInt(Float.floatToIntBits(corrections.lowerInterval())); vectorData.writeInt(Float.floatToIntBits(corrections.upperInterval())); @@ -236,12 +239,12 @@ private void writeSortedVectors( int[] ordMap, OptimizedScalarQuantizer scalarQuantizer) throws IOException { - // XXX properly parameterize bits byte[] vector = new byte[fieldData.fieldInfo.getVectorDimension()]; for (int ordinal : ordMap) { float[] v = fieldData.getVectors().get(ordinal); + // XXX must pack PACKED_NIBBLE OptimizedScalarQuantizer.QuantizationResult corrections = - scalarQuantizer.scalarQuantize(v, vector, (byte) 8, clusterCenter); + scalarQuantizer.scalarQuantize(v, vector, encoding.getBits(), clusterCenter); vectorData.writeBytes(vector, vector.length); vectorData.writeInt(Float.floatToIntBits(corrections.lowerInterval())); vectorData.writeInt(Float.floatToIntBits(corrections.upperInterval())); @@ -268,6 +271,7 @@ private void writeMeta( int count = docsWithField.cardinality(); meta.writeVInt(count); if (count > 0) { + meta.writeVInt(encoding.getWireNumber()); final ByteBuffer buffer = 
ByteBuffer.allocate(field.getVectorDimension() * Float.BYTES) .order(ByteOrder.LITTLE_ENDIAN); @@ -322,6 +326,7 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE new QuantizedFloatVectorValues( floatVectorValues, new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()), + encoding, centroid); long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); DocsWithFieldSet docsWithField = writeVectorData(vectorData, quantizedVectorValues); @@ -393,7 +398,6 @@ private CloseableRandomVectorScorerSupplier mergeOneFieldToIndex( OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); try { - // XXX I guess we write to a temp file so that we can actually read it? that sucks. tempQuantizedVectorData = segmentWriteState.directory.createTempOutput( vectorData.getName(), "temp", segmentWriteState.context); @@ -406,7 +410,7 @@ private CloseableRandomVectorScorerSupplier mergeOneFieldToIndex( DocsWithFieldSet docsWithField = writeVectorData( tempQuantizedVectorData, - new QuantizedFloatVectorValues(floatVectorValues, quantizer, centroid)); + new QuantizedFloatVectorValues(floatVectorValues, quantizer, encoding, centroid)); CodecUtil.writeFooter(tempQuantizedVectorData); IOUtils.close(tempQuantizedVectorData); quantizedDataInput = @@ -435,6 +439,7 @@ private CloseableRandomVectorScorerSupplier mergeOneFieldToIndex( centroid, cDotC, quantizer, + encoding, fieldInfo.getVectorSimilarityFunction(), vectorsScorer, finalQuantizedDataInput); @@ -645,13 +650,15 @@ static class QuantizedFloatVectorValues extends QuantizedByteVectorValues { private final float[] centroid; private final FloatVectorValues values; private final OptimizedScalarQuantizer quantizer; + private final ScalarEncoding encoding; private int lastOrd = -1; QuantizedFloatVectorValues( - FloatVectorValues delegate, OptimizedScalarQuantizer quantizer, float[] centroid) { + FloatVectorValues delegate, 
OptimizedScalarQuantizer quantizer, ScalarEncoding encoding, float[] centroid) { this.values = delegate; this.quantizer = quantizer; + this.encoding = encoding; this.quantized = new byte[delegate.dimension()]; this.centroid = centroid; } @@ -687,6 +694,11 @@ public OptimizedScalarQuantizer getQuantizer() { throw new UnsupportedOperationException(); } + @Override + public ScalarEncoding getScalarEncoding() { + throw new UnsupportedOperationException(); + } + @Override public float[] getCentroid() throws IOException { return centroid; @@ -704,13 +716,13 @@ public VectorScorer scorer(float[] target) throws IOException { @Override public QuantizedByteVectorValues copy() throws IOException { - return new QuantizedFloatVectorValues(values.copy(), quantizer, centroid); + return new QuantizedFloatVectorValues(values.copy(), quantizer, encoding, centroid); } private void quantize(int ord) throws IOException { - // XXX properly parameterize bits + // XXX pack PACKED_NIBBLE, maybe??? corrections = - quantizer.scalarQuantize(values.vectorValue(ord), quantized, (byte) 8, centroid); + quantizer.scalarQuantize(values.vectorValue(ord), quantized, encoding.getBits(), centroid); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/OffHeapScalarQuantizedVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/OffHeapScalarQuantizedVectorValues.java index 57453e09c8ab..1d89ebc7920e 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/OffHeapScalarQuantizedVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/OffHeapScalarQuantizedVectorValues.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.nio.ByteBuffer; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding; import org.apache.lucene.codecs.lucene90.IndexedDISI; import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; 
import org.apache.lucene.index.VectorSimilarityFunction; @@ -49,16 +50,17 @@ public abstract class OffHeapScalarQuantizedVectorValues extends QuantizedByteVe final float[] correctiveValues; int quantizedComponentSum; final OptimizedScalarQuantizer quantizer; + final ScalarEncoding encoding; final float[] centroid; final float centroidDp; - // XXX this needs bits, probably??? OffHeapScalarQuantizedVectorValues( int dimension, int size, float[] centroid, float centroidDp, OptimizedScalarQuantizer quantizer, + ScalarEncoding encoding, VectorSimilarityFunction similarityFunction, FlatVectorsScorer vectorsScorer, IndexInput slice) { @@ -74,6 +76,7 @@ public abstract class OffHeapScalarQuantizedVectorValues extends QuantizedByteVe this.byteBuffer = ByteBuffer.allocate(dimension); this.vectorValue = byteBuffer.array(); this.quantizer = quantizer; + this.encoding = encoding; } @Override @@ -123,6 +126,11 @@ public OptimizedScalarQuantizer getQuantizer() { return quantizer; } + @Override + public ScalarEncoding getScalarEncoding() { + return encoding; + } + @Override public float[] getCentroid() { return centroid; @@ -137,7 +145,8 @@ static OffHeapScalarQuantizedVectorValues load( OrdToDocDISIReaderConfiguration configuration, int dimension, int size, - OptimizedScalarQuantizer binaryQuantizer, + OptimizedScalarQuantizer quantizer, + ScalarEncoding encoding, VectorSimilarityFunction similarityFunction, FlatVectorsScorer vectorsScorer, float[] centroid, @@ -159,7 +168,8 @@ static OffHeapScalarQuantizedVectorValues load( size, centroid, centroidDp, - binaryQuantizer, + quantizer, + encoding, similarityFunction, vectorsScorer, bytesSlice); @@ -170,7 +180,8 @@ static OffHeapScalarQuantizedVectorValues load( size, centroid, centroidDp, - binaryQuantizer, + quantizer, + encoding, vectorData, similarityFunction, vectorsScorer, @@ -178,14 +189,15 @@ static OffHeapScalarQuantizedVectorValues load( } } - /** Dense off-heap binarized vector values */ + /** Dense off-heap scalar 
quantized vector values */ static class DenseOffHeapVectorValues extends OffHeapScalarQuantizedVectorValues { DenseOffHeapVectorValues( int dimension, int size, float[] centroid, float centroidDp, - OptimizedScalarQuantizer binaryQuantizer, + OptimizedScalarQuantizer quantizer, + ScalarEncoding encoding, VectorSimilarityFunction similarityFunction, FlatVectorsScorer vectorsScorer, IndexInput slice) { @@ -194,7 +206,8 @@ static class DenseOffHeapVectorValues extends OffHeapScalarQuantizedVectorValues size, centroid, centroidDp, - binaryQuantizer, + quantizer, + encoding, similarityFunction, vectorsScorer, slice); @@ -208,6 +221,7 @@ public OffHeapScalarQuantizedVectorValues.DenseOffHeapVectorValues copy() throws centroid, centroidDp, quantizer, + encoding, similarityFunction, vectorsScorer, slice.clone()); @@ -243,7 +257,7 @@ public DocIndexIterator iterator() { } } - /** Sparse off-heap binarized vector values */ + /** Sparse off-heap scalar quantized vector values */ private static class SparseOffHeapVectorValues extends OffHeapScalarQuantizedVectorValues { private final DirectMonotonicReader ordToDoc; private final IndexedDISI disi; @@ -257,7 +271,8 @@ private static class SparseOffHeapVectorValues extends OffHeapScalarQuantizedVec int size, float[] centroid, float centroidDp, - OptimizedScalarQuantizer binaryQuantizer, + OptimizedScalarQuantizer quantizer, + ScalarEncoding encoding, IndexInput dataIn, VectorSimilarityFunction similarityFunction, FlatVectorsScorer vectorsScorer, @@ -268,7 +283,8 @@ private static class SparseOffHeapVectorValues extends OffHeapScalarQuantizedVec size, centroid, centroidDp, - binaryQuantizer, + quantizer, + encoding, similarityFunction, vectorsScorer, slice); @@ -287,6 +303,7 @@ public SparseOffHeapVectorValues copy() throws IOException { centroid, centroidDp, quantizer, + encoding, dataIn, similarityFunction, vectorsScorer, @@ -346,7 +363,7 @@ private static class EmptyOffHeapVectorValues extends OffHeapScalarQuantizedVect int 
dimension, VectorSimilarityFunction similarityFunction, FlatVectorsScorer vectorsScorer) { - super(dimension, 0, null, Float.NaN, null, similarityFunction, vectorsScorer, null); + super(dimension, 0, null, Float.NaN, null, null, similarityFunction, vectorsScorer, null); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/QuantizedByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/QuantizedByteVectorValues.java index f497a277d50d..87166a0cca8f 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/QuantizedByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/QuantizedByteVectorValues.java @@ -17,12 +17,12 @@ package org.apache.lucene.codecs.lucene104; import java.io.IOException; +import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.search.VectorScorer; import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; -// XXX do I want to make this public? this also overlaps heavily with the binarized version. 
/** Scalar quantized byte vector values */ abstract class QuantizedByteVectorValues extends ByteVectorValues { @@ -58,6 +58,8 @@ public abstract OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(i */ public abstract OptimizedScalarQuantizer getQuantizer(); + public abstract ScalarEncoding getScalarEncoding(); + public abstract float[] getCentroid() throws IOException; /** diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104HnswScalarQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104HnswScalarQuantizedVectorsFormat.java index 7870a6a02d8e..7f01f1bbe852 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104HnswScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104HnswScalarQuantizedVectorsFormat.java @@ -29,6 +29,7 @@ import org.apache.lucene.codecs.FilterCodec; import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.document.Document; @@ -62,7 +63,8 @@ public void testToString() { new FilterCodec("foo", Codec.getDefault()) { @Override public KnnVectorsFormat knnVectorsFormat() { - return new Lucene104HnswScalarQuantizedVectorsFormat(10, 20, 1, null); + return new Lucene104HnswScalarQuantizedVectorsFormat( + ScalarEncoding.UNSIGNED_BYTE, 10, 20, 1, null); } }; String expectedPattern = @@ -140,7 +142,7 @@ public void testLimits() { IllegalArgumentException.class, () -> new Lucene104HnswScalarQuantizedVectorsFormat( - 20, 100, 1, new SameThreadExecutorService())); + ScalarEncoding.UNSIGNED_BYTE, 20, 100, 1, new SameThreadExecutorService())); } // Ensures that all expected vector similarity functions 
are translatable in the format. From cf4fdef88a094497a22f2e18543ee4e7b55cc330 Mon Sep 17 00:00:00 2001 From: Trevor McCulloch Date: Mon, 8 Sep 2025 15:27:45 -0700 Subject: [PATCH 09/22] packing without testing --- .../Lucene104ScalarQuantizedVectorScorer.java | 29 +++++++++---- ...Lucene104ScalarQuantizedVectorsFormat.java | 12 +++--- ...Lucene104ScalarQuantizedVectorsWriter.java | 42 +++++++++++++++---- 3 files changed, 59 insertions(+), 24 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java index 5e57e80d4e26..07994697d31c 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java @@ -47,9 +47,8 @@ public RandomVectorScorer getRandomVectorScorer( VectorUtil.l2normalize(copy); } target = copy; - // XXX parameterize the number of bits var targetCorrectiveTerms = - quantizer.scalarQuantize(target, targetQuantized, (byte) 8, qv.getCentroid()); + quantizer.scalarQuantize(target, targetQuantized, qv.getScalarEncoding().getBits(), qv.getCentroid()); return new RandomVectorScorer.AbstractRandomVectorScorer(qv) { @Override public float score(int node) throws IOException { @@ -114,10 +113,17 @@ public RandomVectorScorerSupplier copy() throws IOException { } } - public static final float EIGHT_BIT_SCALE = 1f / ((1 << 8) - 1); + private static final float[] SCALE_LUT = new float[]{ + 1f, + 1f / ((1 << 2) - 1), + 1f / ((1 << 3) - 1), + 1f / ((1 << 4) - 1), + 1f / ((1 << 5) - 1), + 1f / ((1 << 6) - 1), + 1f / ((1 << 7) - 1), + 1f / ((1 << 8) - 1), + }; - // XXX factor this out to share with Lucene102BinaryFlatVectorsScorer - // we need to know how many bits were used for both the query and index vector for scaling. 
static float quantizedScore( byte[] quantizedQuery, OptimizedScalarQuantizer.QuantizationResult queryCorrections, @@ -125,16 +131,21 @@ static float quantizedScore( int targetOrd, VectorSimilarityFunction similarityFunction) throws IOException { - byte[] binaryCode = targetVectors.vectorValue(targetOrd); - float qcDist = VectorUtil.uint8DotProduct(quantizedQuery, binaryCode); + var scalarEncoding = targetVectors.getScalarEncoding(); + byte[] quantizedDoc = targetVectors.vectorValue(targetOrd); + float qcDist = switch(scalarEncoding) { + case UNSIGNED_BYTE -> VectorUtil.uint8DotProduct(quantizedQuery, quantizedDoc); + case PACKED_NIBBLE -> VectorUtil.int4DotProductPacked(quantizedQuery, quantizedDoc); + }; OptimizedScalarQuantizer.QuantizationResult indexCorrections = targetVectors.getCorrectiveTerms(targetOrd); + float scale = SCALE_LUT[scalarEncoding.getBits() - 1]; float x1 = indexCorrections.quantizedComponentSum(); float ax = indexCorrections.lowerInterval(); // Here we assume `lx` is simply bit vectors, so the scaling isn't necessary - float lx = (indexCorrections.upperInterval() - ax) * EIGHT_BIT_SCALE; + float lx = (indexCorrections.upperInterval() - ax) * scale; float ay = queryCorrections.lowerInterval(); - float ly = (queryCorrections.upperInterval() - ay) * EIGHT_BIT_SCALE; + float ly = (queryCorrections.upperInterval() - ay) * scale; float y1 = queryCorrections.quantizedComponentSum(); float score = ax * ay * targetVectors.dimension() + ay * lx * x1 + ax * ly * y1 + lx * ly * qcDist; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java index 7daef6bcc78a..0b60509af5c9 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java @@ -126,28 
+126,26 @@ public enum ScalarEncoding { private int wireNumber; private byte bits; - private int dimensionsPerByte; ScalarEncoding(int wireNumber, byte bits) { + assert 8 % bits == 0; this.wireNumber = wireNumber; this.bits = bits; - assert 8 % bits == 0; - this.dimensionsPerByte = 8 / bits; } int getWireNumber() { return wireNumber; } - int getDimensionsPerByte() { - return dimensionsPerByte; - } - /** Return the number of bits used per dimension. */ public byte getBits() { return bits; } + public int packedLength(int dimensions) { + return (dimensions * bits + 7) / 8; + } + /** Returns the encoding for the given wire number, or empty if unknown. */ public static Optional fromWireNumber(int wireNumber) { for (ScalarEncoding encoding : values()) { diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java index c9d6921222cc..51619f9f06ac 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java @@ -190,12 +190,18 @@ private void writeField( private void writeVectors( FieldWriter fieldData, float[] clusterCenter, OptimizedScalarQuantizer scalarQuantizer) throws IOException { - byte[] vector = new byte[fieldData.fieldInfo.getVectorDimension()]; + byte[] scratch = new byte[fieldData.fieldInfo.getVectorDimension()]; + byte[] vector = switch(encoding) { + case UNSIGNED_BYTE -> scratch; + case PACKED_NIBBLE -> new byte[encoding.packedLength(scratch.length)]; + }; for (int i = 0; i < fieldData.getVectors().size(); i++) { float[] v = fieldData.getVectors().get(i); - // XXX must pack PACKED_NIBBLE OptimizedScalarQuantizer.QuantizationResult corrections = - scalarQuantizer.scalarQuantize(v, vector, encoding.getBits(), clusterCenter); + 
scalarQuantizer.scalarQuantize(v, scratch, encoding.getBits(), clusterCenter); + if (encoding == ScalarEncoding.PACKED_NIBBLE) { + packNibbles(scratch, vector); + } vectorData.writeBytes(vector, vector.length); vectorData.writeInt(Float.floatToIntBits(corrections.lowerInterval())); vectorData.writeInt(Float.floatToIntBits(corrections.upperInterval())); @@ -239,12 +245,18 @@ private void writeSortedVectors( int[] ordMap, OptimizedScalarQuantizer scalarQuantizer) throws IOException { - byte[] vector = new byte[fieldData.fieldInfo.getVectorDimension()]; + byte[] scratch = new byte[fieldData.fieldInfo.getVectorDimension()]; + byte[] vector = switch(encoding) { + case UNSIGNED_BYTE -> scratch; + case PACKED_NIBBLE -> new byte[encoding.packedLength(scratch.length)]; + }; for (int ordinal : ordMap) { float[] v = fieldData.getVectors().get(ordinal); - // XXX must pack PACKED_NIBBLE OptimizedScalarQuantizer.QuantizationResult corrections = - scalarQuantizer.scalarQuantize(v, vector, encoding.getBits(), clusterCenter); + scalarQuantizer.scalarQuantize(v, scratch, encoding.getBits(), clusterCenter); + if (encoding == ScalarEncoding.PACKED_NIBBLE) { + packNibbles(scratch, vector); + } vectorData.writeBytes(vector, vector.length); vectorData.writeInt(Float.floatToIntBits(corrections.lowerInterval())); vectorData.writeInt(Float.floatToIntBits(corrections.upperInterval())); @@ -253,6 +265,13 @@ private void writeSortedVectors( } } + private static void packNibbles(byte[] unpacked, byte[] packed) { + for (int i = 0; i < packed.length; i++) { + int x = unpacked[i] << 4 | unpacked[packed.length + i]; + packed[i] = (byte) x; + } + } + private void writeMeta( FieldInfo field, int maxDoc, @@ -647,6 +666,7 @@ public long ramBytesUsed() { static class QuantizedFloatVectorValues extends QuantizedByteVectorValues { private OptimizedScalarQuantizer.QuantizationResult corrections; private final byte[] quantized; + private final byte[] packed; private final float[] centroid; private final 
FloatVectorValues values; private final OptimizedScalarQuantizer quantizer; @@ -660,6 +680,10 @@ static class QuantizedFloatVectorValues extends QuantizedByteVectorValues { this.quantizer = quantizer; this.encoding = encoding; this.quantized = new byte[delegate.dimension()]; + this.packed = switch (encoding) { + case UNSIGNED_BYTE -> this.quantized; + case PACKED_NIBBLE -> new byte[encoding.packedLength(this.quantized.length)]; + }; this.centroid = centroid; } @@ -681,7 +705,7 @@ public byte[] vectorValue(int ord) throws IOException { quantize(ord); lastOrd = ord; } - return quantized; + return packed; } @Override @@ -720,9 +744,11 @@ public QuantizedByteVectorValues copy() throws IOException { } private void quantize(int ord) throws IOException { - // XXX pack PACKED_NIBBLE, maybe??? corrections = quantizer.scalarQuantize(values.vectorValue(ord), quantized, encoding.getBits(), centroid); + if (encoding == ScalarEncoding.PACKED_NIBBLE) { + packNibbles(quantized, packed); + } } @Override From 132c8ee152227cf3cdbcae3ad06c83cae4693f87 Mon Sep 17 00:00:00 2001 From: Trevor McCulloch Date: Mon, 8 Sep 2025 16:04:44 -0700 Subject: [PATCH 10/22] flat vectors test --- ...ne104HnswScalarQuantizedVectorsFormat.java | 14 ++++-- .../Lucene104ScalarQuantizedVectorScorer.java | 33 +++++++------ ...Lucene104ScalarQuantizedVectorsReader.java | 11 ++++- ...Lucene104ScalarQuantizedVectorsWriter.java | 48 +++++++++---------- .../OffHeapScalarQuantizedVectorValues.java | 28 +++++++++-- .../org/apache/lucene/util/VectorUtil.java | 2 +- ...Lucene104ScalarQuantizedVectorsFormat.java | 28 ++++++++--- 7 files changed, 109 insertions(+), 55 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104HnswScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104HnswScalarQuantizedVectorsFormat.java index 88fca83d2923..7d49183249e2 100644 --- 
a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104HnswScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104HnswScalarQuantizedVectorsFormat.java @@ -27,7 +27,6 @@ import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsWriter; -import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; @@ -66,7 +65,12 @@ public class Lucene104HnswScalarQuantizedVectorsFormat extends KnnVectorsFormat /** Constructs a format using default graph construction parameters */ public Lucene104HnswScalarQuantizedVectorsFormat() { - this(ScalarEncoding.UNSIGNED_BYTE, DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null); + this( + ScalarEncoding.UNSIGNED_BYTE, + DEFAULT_MAX_CONN, + DEFAULT_BEAM_WIDTH, + DEFAULT_NUM_MERGE_WORKER, + null); } /** @@ -90,7 +94,11 @@ public Lucene104HnswScalarQuantizedVectorsFormat(int maxConn, int beamWidth) { * generated by this format to do the merge */ public Lucene104HnswScalarQuantizedVectorsFormat( - ScalarEncoding encoding, int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) { + ScalarEncoding encoding, + int maxConn, + int beamWidth, + int numMergeWorkers, + ExecutorService mergeExec) { super(NAME); flatVectorsFormat = new Lucene104ScalarQuantizedVectorsFormat(encoding); if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) { diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java index 07994697d31c..5aacee8b0af3 100644 --- 
a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java @@ -48,7 +48,8 @@ public RandomVectorScorer getRandomVectorScorer( } target = copy; var targetCorrectiveTerms = - quantizer.scalarQuantize(target, targetQuantized, qv.getScalarEncoding().getBits(), qv.getCentroid()); + quantizer.scalarQuantize( + target, targetQuantized, qv.getScalarEncoding().getBits(), qv.getCentroid()); return new RandomVectorScorer.AbstractRandomVectorScorer(qv) { @Override public float score(int node) throws IOException { @@ -113,16 +114,17 @@ public RandomVectorScorerSupplier copy() throws IOException { } } - private static final float[] SCALE_LUT = new float[]{ - 1f, - 1f / ((1 << 2) - 1), - 1f / ((1 << 3) - 1), - 1f / ((1 << 4) - 1), - 1f / ((1 << 5) - 1), - 1f / ((1 << 6) - 1), - 1f / ((1 << 7) - 1), - 1f / ((1 << 8) - 1), - }; + private static final float[] SCALE_LUT = + new float[] { + 1f, + 1f / ((1 << 2) - 1), + 1f / ((1 << 3) - 1), + 1f / ((1 << 4) - 1), + 1f / ((1 << 5) - 1), + 1f / ((1 << 6) - 1), + 1f / ((1 << 7) - 1), + 1f / ((1 << 8) - 1), + }; static float quantizedScore( byte[] quantizedQuery, @@ -133,10 +135,11 @@ static float quantizedScore( throws IOException { var scalarEncoding = targetVectors.getScalarEncoding(); byte[] quantizedDoc = targetVectors.vectorValue(targetOrd); - float qcDist = switch(scalarEncoding) { - case UNSIGNED_BYTE -> VectorUtil.uint8DotProduct(quantizedQuery, quantizedDoc); - case PACKED_NIBBLE -> VectorUtil.int4DotProductPacked(quantizedQuery, quantizedDoc); - }; + float qcDist = + switch (scalarEncoding) { + case UNSIGNED_BYTE -> VectorUtil.uint8DotProduct(quantizedQuery, quantizedDoc); + case PACKED_NIBBLE -> VectorUtil.int4DotProductPacked(quantizedQuery, quantizedDoc); + }; OptimizedScalarQuantizer.QuantizationResult indexCorrections = targetVectors.getCorrectiveTerms(targetOrd); float scale = 
SCALE_LUT[scalarEncoding.getBits() - 1]; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java index 50aa07a76f8f..f0b27d5aaaf2 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java @@ -137,7 +137,9 @@ static void validateFieldEntry(FieldInfo info, FieldEntry fieldEntry) { } long numQuantizedVectorBytes = - Math.multiplyExact((dimension + (Float.BYTES * 3) + Integer.BYTES), (long) fieldEntry.size); + Math.multiplyExact( + (fieldEntry.scalarEncoding.packedLength(dimension) + (Float.BYTES * 3) + Integer.BYTES), + (long) fieldEntry.size); if (numQuantizedVectorBytes != fieldEntry.vectorDataLength) { throw new IllegalStateException( "vector data length " @@ -361,7 +363,12 @@ static FieldEntry create( ScalarEncoding scalarEncoding = ScalarEncoding.UNSIGNED_BYTE; if (size > 0) { int wireNumber = input.readVInt(); - scalarEncoding = ScalarEncoding.fromWireNumber(wireNumber).orElseThrow(() -> new IllegalStateException("Could not get ScalarEncoding from wire number: " + wireNumber)); + scalarEncoding = + ScalarEncoding.fromWireNumber(wireNumber) + .orElseThrow( + () -> + new IllegalStateException( + "Could not get ScalarEncoding from wire number: " + wireNumber)); centroid = new float[dimension]; input.readFloats(centroid, 0, dimension); centroidDP = Float.intBitsToFloat(input.readInt()); diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java index 51619f9f06ac..65a0db11b14f 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java +++ 
b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java @@ -191,16 +191,17 @@ private void writeVectors( FieldWriter fieldData, float[] clusterCenter, OptimizedScalarQuantizer scalarQuantizer) throws IOException { byte[] scratch = new byte[fieldData.fieldInfo.getVectorDimension()]; - byte[] vector = switch(encoding) { - case UNSIGNED_BYTE -> scratch; - case PACKED_NIBBLE -> new byte[encoding.packedLength(scratch.length)]; - }; + byte[] vector = + switch (encoding) { + case UNSIGNED_BYTE -> scratch; + case PACKED_NIBBLE -> new byte[encoding.packedLength(scratch.length)]; + }; for (int i = 0; i < fieldData.getVectors().size(); i++) { float[] v = fieldData.getVectors().get(i); OptimizedScalarQuantizer.QuantizationResult corrections = scalarQuantizer.scalarQuantize(v, scratch, encoding.getBits(), clusterCenter); if (encoding == ScalarEncoding.PACKED_NIBBLE) { - packNibbles(scratch, vector); + OffHeapScalarQuantizedVectorValues.packNibbles(scratch, vector); } vectorData.writeBytes(vector, vector.length); vectorData.writeInt(Float.floatToIntBits(corrections.lowerInterval())); @@ -246,16 +247,17 @@ private void writeSortedVectors( OptimizedScalarQuantizer scalarQuantizer) throws IOException { byte[] scratch = new byte[fieldData.fieldInfo.getVectorDimension()]; - byte[] vector = switch(encoding) { - case UNSIGNED_BYTE -> scratch; - case PACKED_NIBBLE -> new byte[encoding.packedLength(scratch.length)]; - }; + byte[] vector = + switch (encoding) { + case UNSIGNED_BYTE -> scratch; + case PACKED_NIBBLE -> new byte[encoding.packedLength(scratch.length)]; + }; for (int ordinal : ordMap) { float[] v = fieldData.getVectors().get(ordinal); OptimizedScalarQuantizer.QuantizationResult corrections = scalarQuantizer.scalarQuantize(v, scratch, encoding.getBits(), clusterCenter); if (encoding == ScalarEncoding.PACKED_NIBBLE) { - packNibbles(scratch, vector); + OffHeapScalarQuantizedVectorValues.packNibbles(scratch, vector); } 
vectorData.writeBytes(vector, vector.length); vectorData.writeInt(Float.floatToIntBits(corrections.lowerInterval())); @@ -265,13 +267,6 @@ private void writeSortedVectors( } } - private static void packNibbles(byte[] unpacked, byte[] packed) { - for (int i = 0; i < packed.length; i++) { - int x = unpacked[i] << 4 | unpacked[packed.length + i]; - packed[i] = (byte) x; - } - } - private void writeMeta( FieldInfo field, int maxDoc, @@ -675,15 +670,19 @@ static class QuantizedFloatVectorValues extends QuantizedByteVectorValues { private int lastOrd = -1; QuantizedFloatVectorValues( - FloatVectorValues delegate, OptimizedScalarQuantizer quantizer, ScalarEncoding encoding, float[] centroid) { + FloatVectorValues delegate, + OptimizedScalarQuantizer quantizer, + ScalarEncoding encoding, + float[] centroid) { this.values = delegate; this.quantizer = quantizer; this.encoding = encoding; this.quantized = new byte[delegate.dimension()]; - this.packed = switch (encoding) { - case UNSIGNED_BYTE -> this.quantized; - case PACKED_NIBBLE -> new byte[encoding.packedLength(this.quantized.length)]; - }; + this.packed = + switch (encoding) { + case UNSIGNED_BYTE -> this.quantized; + case PACKED_NIBBLE -> new byte[encoding.packedLength(this.quantized.length)]; + }; this.centroid = centroid; } @@ -745,9 +744,10 @@ public QuantizedByteVectorValues copy() throws IOException { private void quantize(int ord) throws IOException { corrections = - quantizer.scalarQuantize(values.vectorValue(ord), quantized, encoding.getBits(), centroid); + quantizer.scalarQuantize( + values.vectorValue(ord), quantized, encoding.getBits(), centroid); if (encoding == ScalarEncoding.PACKED_NIBBLE) { - packNibbles(quantized, packed); + OffHeapScalarQuantizedVectorValues.packNibbles(quantized, packed); } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/OffHeapScalarQuantizedVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/OffHeapScalarQuantizedVectorValues.java index 
1d89ebc7920e..ba104bf08ef2 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/OffHeapScalarQuantizedVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/OffHeapScalarQuantizedVectorValues.java @@ -72,8 +72,8 @@ public abstract class OffHeapScalarQuantizedVectorValues extends QuantizedByteVe this.centroid = centroid; this.centroidDp = centroidDp; this.correctiveValues = new float[3]; - this.byteSize = dimension + (Float.BYTES * 3) + Integer.BYTES; - this.byteBuffer = ByteBuffer.allocate(dimension); + this.byteSize = encoding.packedLength(dimension) + (Float.BYTES * 3) + Integer.BYTES; + this.byteBuffer = ByteBuffer.allocate(encoding.packedLength(dimension)); this.vectorValue = byteBuffer.array(); this.quantizer = quantizer; this.encoding = encoding; @@ -95,7 +95,7 @@ public byte[] vectorValue(int targetOrd) throws IOException { return vectorValue; } slice.seek((long) targetOrd * byteSize); - slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), dimension); + slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), vectorValue.length); slice.readFloats(correctiveValues, 0, 3); quantizedComponentSum = slice.readInt(); lastOrd = targetOrd; @@ -141,6 +141,17 @@ public int getVectorByteLength() { return dimension; } + static void packNibbles(byte[] unpacked, byte[] packed) { + int limit = (unpacked.length & 1) == 0 ? 
packed.length : packed.length - 1; + for (int i = 0; i < limit; i++) { + int x = unpacked[i] << 4 | unpacked[packed.length + i]; + packed[i] = (byte) x; + } + if ((unpacked.length & 1) == 1) { + packed[packed.length - 1] = (byte) (unpacked[packed.length] << 4); + } + } + static OffHeapScalarQuantizedVectorValues load( OrdToDocDISIReaderConfiguration configuration, int dimension, @@ -363,7 +374,16 @@ private static class EmptyOffHeapVectorValues extends OffHeapScalarQuantizedVect int dimension, VectorSimilarityFunction similarityFunction, FlatVectorsScorer vectorsScorer) { - super(dimension, 0, null, Float.NaN, null, null, similarityFunction, vectorsScorer, null); + super( + dimension, + 0, + null, + Float.NaN, + null, + ScalarEncoding.UNSIGNED_BYTE, + similarityFunction, + vectorsScorer, + null); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java index 38b9cf6d67a5..c27a7b6f5fdb 100644 --- a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java +++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java @@ -214,7 +214,7 @@ public static int int4DotProduct(byte[] a, byte[] b) { public static int int4DotProductPacked(byte[] unpacked, byte[] packed) { if (packed.length != ((unpacked.length + 1) >> 1)) { throw new IllegalArgumentException( - "vector dimensions differ: " + unpacked.length + "!= 2 * " + packed.length); + "vector dimensions differ: " + unpacked.length + " != 2 * " + packed.length); } return IMPL.int4DotProduct(unpacked, false, packed, true); } diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java index 4cdaa76600e1..5620bcc8ef28 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java +++ 
b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java @@ -26,6 +26,7 @@ import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.FilterCodec; import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding; import org.apache.lucene.document.Document; import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.index.DirectoryReader; @@ -45,14 +46,24 @@ import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; import org.apache.lucene.tests.util.TestUtil; import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; +import org.junit.Before; public class TestLucene104ScalarQuantizedVectorsFormat extends BaseKnnVectorsFormatTestCase { - private static final KnnVectorsFormat FORMAT = new Lucene104ScalarQuantizedVectorsFormat(); + private ScalarEncoding encoding; + private KnnVectorsFormat format; + + @Before + @Override + public void setUp() throws Exception { + encoding = random().nextBoolean() ? 
ScalarEncoding.UNSIGNED_BYTE : ScalarEncoding.PACKED_NIBBLE; + format = new Lucene104ScalarQuantizedVectorsFormat(encoding); + super.setUp(); + } @Override protected Codec getCodec() { - return TestUtil.alwaysKnnVectorsFormat(FORMAT); + return TestUtil.alwaysKnnVectorsFormat(format); } public void testSearch() throws Exception { @@ -154,7 +165,8 @@ public void testQuantizedVectorsWriteAndRead() throws IOException { assertEquals(centroid.length, dims); OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(similarityFunction); - byte[] expectedVector = new byte[dims]; + byte[] scratch = new byte[dims]; + byte[] expectedVector = new byte[encoding.packedLength(dims)]; if (similarityFunction == VectorSimilarityFunction.COSINE) { vectorValues = new Lucene104ScalarQuantizedVectorsWriter.NormalizedFloatVectorValues(vectorValues); @@ -165,10 +177,14 @@ public void testQuantizedVectorsWriteAndRead() throws IOException { OptimizedScalarQuantizer.QuantizationResult corrections = quantizer.scalarQuantize( vectorValues.vectorValue(docIndexIterator.index()), - expectedVector, - // XXX FIXME - (byte) 8, + scratch, + encoding.getBits(), centroid); + switch (encoding) { + case UNSIGNED_BYTE -> System.arraycopy(scratch, 0, expectedVector, 0, dims); + case PACKED_NIBBLE -> + OffHeapScalarQuantizedVectorValues.packNibbles(scratch, expectedVector); + } assertArrayEquals(expectedVector, qvectorValues.vectorValue(docIndexIterator.index())); assertEquals(corrections, qvectorValues.getCorrectiveTerms(docIndexIterator.index())); } From e90b6d113876f2d6916a25e8acde34da50c784db Mon Sep 17 00:00:00 2001 From: Trevor McCulloch Date: Tue, 9 Sep 2025 09:31:50 -0700 Subject: [PATCH 11/22] fix license --- .../Lucene104ScalarQuantizedVectorScorer.java | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java 
b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java index 5aacee8b0af3..84173525cc9b 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java @@ -1,3 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.lucene.codecs.lucene104; import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; From f9cc396d7d9dbb8fcd6a0117f86e0e7f60d929da Mon Sep 17 00:00:00 2001 From: Trevor McCulloch Date: Tue, 9 Sep 2025 10:08:57 -0700 Subject: [PATCH 12/22] CHANGES --- lucene/CHANGES.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 28aeee6b0e4f..5f5f7bcd6963 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -120,6 +120,8 @@ New Features * GITHUB#14565: Add ParentsChildrenBlockJoinQuery that supports parent and child filter in the same query along with limiting number of child documents to retrieve per parent. 
(Jinny Wang) +* GITHUB#15169: Add codecs for 4 and 8 bit Optimized Scalar Quantization vectors (Trevor McCulloch) + Improvements --------------------- From e6e6b6a7ddf6692c2beae20497829fa868e7ac97 Mon Sep 17 00:00:00 2001 From: Trevor McCulloch Date: Tue, 9 Sep 2025 11:11:02 -0700 Subject: [PATCH 13/22] handle boundary cases with nibble encoding -- unpacked must always have an even length --- .../Lucene104ScalarQuantizedVectorScorer.java | 5 +++- ...Lucene104ScalarQuantizedVectorsFormat.java | 8 ++++++- ...Lucene104ScalarQuantizedVectorsReader.java | 4 +++- ...Lucene104ScalarQuantizedVectorsWriter.java | 23 ++++++++++++++----- .../OffHeapScalarQuantizedVectorValues.java | 11 ++++----- ...Lucene104ScalarQuantizedVectorsFormat.java | 5 ++-- 6 files changed, 38 insertions(+), 18 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java index 84173525cc9b..ed0c4f41990a 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java @@ -56,7 +56,10 @@ public RandomVectorScorer getRandomVectorScorer( throws IOException { if (vectorValues instanceof QuantizedByteVectorValues qv) { OptimizedScalarQuantizer quantizer = qv.getQuantizer(); - byte[] targetQuantized = new byte[target.length]; + byte[] targetQuantized = + new byte + [OptimizedScalarQuantizer.discretize( + target.length, qv.getScalarEncoding().getDimensionsPerByte())]; // We make a copy as the quantization process mutates the input float[] copy = ArrayUtil.copyOfSubArray(target, 0, target.length); if (similarityFunction == COSINE) { diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java 
b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java index 0b60509af5c9..513e895b0a49 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java @@ -142,7 +142,13 @@ public byte getBits() { return bits; } - public int packedLength(int dimensions) { + /** Return the number of dimensions that can be packed into a single byte. */ + public int getDimensionsPerByte() { + return 8 / bits; + } + + /** Return the number of bytes required to store a packed vector of the given dimensions. */ + public int getPackedLength(int dimensions) { return (dimensions * bits + 7) / 8; } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java index f0b27d5aaaf2..0251224998b4 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java @@ -138,7 +138,9 @@ static void validateFieldEntry(FieldInfo info, FieldEntry fieldEntry) { long numQuantizedVectorBytes = Math.multiplyExact( - (fieldEntry.scalarEncoding.packedLength(dimension) + (Float.BYTES * 3) + Integer.BYTES), + (fieldEntry.scalarEncoding.getPackedLength(dimension) + + (Float.BYTES * 3) + + Integer.BYTES), (long) fieldEntry.size); if (numQuantizedVectorBytes != fieldEntry.vectorDataLength) { throw new IllegalStateException( diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java index 65a0db11b14f..b8fb50b43d7c 100644 --- 
a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java @@ -190,11 +190,15 @@ private void writeField( private void writeVectors( FieldWriter fieldData, float[] clusterCenter, OptimizedScalarQuantizer scalarQuantizer) throws IOException { - byte[] scratch = new byte[fieldData.fieldInfo.getVectorDimension()]; + byte[] scratch = + new byte + [OptimizedScalarQuantizer.discretize( + fieldData.fieldInfo.getVectorDimension(), encoding.getDimensionsPerByte())]; byte[] vector = switch (encoding) { case UNSIGNED_BYTE -> scratch; - case PACKED_NIBBLE -> new byte[encoding.packedLength(scratch.length)]; + case PACKED_NIBBLE -> + new byte[encoding.getPackedLength(fieldData.fieldInfo.getVectorDimension())]; }; for (int i = 0; i < fieldData.getVectors().size(); i++) { float[] v = fieldData.getVectors().get(i); @@ -246,11 +250,15 @@ private void writeSortedVectors( int[] ordMap, OptimizedScalarQuantizer scalarQuantizer) throws IOException { - byte[] scratch = new byte[fieldData.fieldInfo.getVectorDimension()]; + byte[] scratch = + new byte + [OptimizedScalarQuantizer.discretize( + fieldData.fieldInfo.getVectorDimension(), encoding.getDimensionsPerByte())]; byte[] vector = switch (encoding) { case UNSIGNED_BYTE -> scratch; - case PACKED_NIBBLE -> new byte[encoding.packedLength(scratch.length)]; + case PACKED_NIBBLE -> + new byte[encoding.getPackedLength(fieldData.fieldInfo.getVectorDimension())]; }; for (int ordinal : ordMap) { float[] v = fieldData.getVectors().get(ordinal); @@ -677,11 +685,14 @@ static class QuantizedFloatVectorValues extends QuantizedByteVectorValues { this.values = delegate; this.quantizer = quantizer; this.encoding = encoding; - this.quantized = new byte[delegate.dimension()]; + this.quantized = + new byte + [OptimizedScalarQuantizer.discretize( + delegate.dimension(), encoding.getDimensionsPerByte())]; 
this.packed = switch (encoding) { case UNSIGNED_BYTE -> this.quantized; - case PACKED_NIBBLE -> new byte[encoding.packedLength(this.quantized.length)]; + case PACKED_NIBBLE -> new byte[encoding.getPackedLength(delegate.dimension())]; }; this.centroid = centroid; } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/OffHeapScalarQuantizedVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/OffHeapScalarQuantizedVectorValues.java index ba104bf08ef2..584f00ca0f64 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/OffHeapScalarQuantizedVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/OffHeapScalarQuantizedVectorValues.java @@ -72,8 +72,8 @@ public abstract class OffHeapScalarQuantizedVectorValues extends QuantizedByteVe this.centroid = centroid; this.centroidDp = centroidDp; this.correctiveValues = new float[3]; - this.byteSize = encoding.packedLength(dimension) + (Float.BYTES * 3) + Integer.BYTES; - this.byteBuffer = ByteBuffer.allocate(encoding.packedLength(dimension)); + this.byteSize = encoding.getPackedLength(dimension) + (Float.BYTES * 3) + Integer.BYTES; + this.byteBuffer = ByteBuffer.allocate(encoding.getPackedLength(dimension)); this.vectorValue = byteBuffer.array(); this.quantizer = quantizer; this.encoding = encoding; @@ -142,14 +142,11 @@ public int getVectorByteLength() { } static void packNibbles(byte[] unpacked, byte[] packed) { - int limit = (unpacked.length & 1) == 0 ? 
packed.length : packed.length - 1; - for (int i = 0; i < limit; i++) { + assert unpacked.length == packed.length * 2; + for (int i = 0; i < packed.length; i++) { int x = unpacked[i] << 4 | unpacked[packed.length + i]; packed[i] = (byte) x; } - if ((unpacked.length & 1) == 1) { - packed[packed.length - 1] = (byte) (unpacked[packed.length] << 4); - } } static OffHeapScalarQuantizedVectorValues load( diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java index 5620bcc8ef28..04fd5ce13027 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java @@ -165,8 +165,9 @@ public void testQuantizedVectorsWriteAndRead() throws IOException { assertEquals(centroid.length, dims); OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(similarityFunction); - byte[] scratch = new byte[dims]; - byte[] expectedVector = new byte[encoding.packedLength(dims)]; + byte[] scratch = + new byte[OptimizedScalarQuantizer.discretize(dims, encoding.getDimensionsPerByte())]; + byte[] expectedVector = new byte[encoding.getPackedLength(dims)]; if (similarityFunction == VectorSimilarityFunction.COSINE) { vectorValues = new Lucene104ScalarQuantizedVectorsWriter.NormalizedFloatVectorValues(vectorValues); From 01a9748ffcb5fe3b42cb99306a3e417463aab39f Mon Sep 17 00:00:00 2001 From: Trevor McCulloch Date: Tue, 9 Sep 2025 11:36:59 -0700 Subject: [PATCH 14/22] resilience to small floating point errors --- .../TestLucene104ScalarQuantizedVectorsFormat.java | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java 
b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java index 04fd5ce13027..e84c6798f848 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java @@ -187,7 +187,12 @@ public void testQuantizedVectorsWriteAndRead() throws IOException { OffHeapScalarQuantizedVectorValues.packNibbles(scratch, expectedVector); } assertArrayEquals(expectedVector, qvectorValues.vectorValue(docIndexIterator.index())); - assertEquals(corrections, qvectorValues.getCorrectiveTerms(docIndexIterator.index())); + var actualCorrections = + qvectorValues.getCorrectiveTerms(docIndexIterator.index()); + assertEquals(corrections.lowerInterval(), actualCorrections.lowerInterval(), 0.00001f); + assertEquals(corrections.upperInterval(), actualCorrections.upperInterval(), 0.00001f); + assertEquals(corrections.additionalCorrection(), actualCorrections.additionalCorrection(), 0.00001f); + assertEquals(corrections.quantizedComponentSum(), actualCorrections.quantizedComponentSum()); } } } From 2e5f89dc47e398d8d8060dc4c5b72a10d4581cdc Mon Sep 17 00:00:00 2001 From: Trevor McCulloch Date: Tue, 9 Sep 2025 11:42:39 -0700 Subject: [PATCH 15/22] tidy-- --- .../TestLucene104ScalarQuantizedVectorsFormat.java | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java index e84c6798f848..25be0c78a51b 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java @@ -187,12 +187,15 @@ public void 
testQuantizedVectorsWriteAndRead() throws IOException { OffHeapScalarQuantizedVectorValues.packNibbles(scratch, expectedVector); } assertArrayEquals(expectedVector, qvectorValues.vectorValue(docIndexIterator.index())); - var actualCorrections = - qvectorValues.getCorrectiveTerms(docIndexIterator.index()); + var actualCorrections = qvectorValues.getCorrectiveTerms(docIndexIterator.index()); assertEquals(corrections.lowerInterval(), actualCorrections.lowerInterval(), 0.00001f); assertEquals(corrections.upperInterval(), actualCorrections.upperInterval(), 0.00001f); - assertEquals(corrections.additionalCorrection(), actualCorrections.additionalCorrection(), 0.00001f); - assertEquals(corrections.quantizedComponentSum(), actualCorrections.quantizedComponentSum()); + assertEquals( + corrections.additionalCorrection(), + actualCorrections.additionalCorrection(), + 0.00001f); + assertEquals( + corrections.quantizedComponentSum(), actualCorrections.quantizedComponentSum()); } } } From b30731c3f3a0aa430c45355fca57237b9321650f Mon Sep 17 00:00:00 2001 From: Trevor McCulloch Date: Wed, 10 Sep 2025 10:05:37 -0700 Subject: [PATCH 16/22] remove unnecessary default --- ...Lucene104ScalarQuantizedVectorsWriter.java | 7 +++++++ .../lucene104/QuantizedByteVectorValues.java | 19 +++++++++++-------- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java index b8fb50b43d7c..cb84bf3e03b5 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java @@ -671,6 +671,7 @@ static class QuantizedFloatVectorValues extends QuantizedByteVectorValues { private final byte[] quantized; private final byte[] packed; private final float[] 
centroid; + private final float centroidDP; private final FloatVectorValues values; private final OptimizedScalarQuantizer quantizer; private final ScalarEncoding encoding; @@ -695,6 +696,7 @@ static class QuantizedFloatVectorValues extends QuantizedByteVectorValues { case PACKED_NIBBLE -> new byte[encoding.getPackedLength(delegate.dimension())]; }; this.centroid = centroid; + this.centroidDP = VectorUtil.dotProduct(centroid, centroid); } @Override @@ -738,6 +740,11 @@ public float[] getCentroid() throws IOException { return centroid; } + @Override + public float getCentroidDP() { + return centroidDP; + } + @Override public int size() { return values.size(); diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/QuantizedByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/QuantizedByteVectorValues.java index 87166a0cca8f..8ccaafa445da 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/QuantizedByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/QuantizedByteVectorValues.java @@ -58,10 +58,21 @@ public abstract OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(i */ public abstract OptimizedScalarQuantizer getQuantizer(); + /** + * @return the scalar encoding used to pack the vectors. + */ public abstract ScalarEncoding getScalarEncoding(); + /** + * @return the centroid used to center the vectors prior to quantization + */ public abstract float[] getCentroid() throws IOException; + /** + * @return the dot product of the centroid. + */ + public abstract float getCentroidDP() throws IOException; + /** * Return a {@link VectorScorer} for the given query vector. * @@ -72,12 +83,4 @@ public abstract OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(i @Override public abstract QuantizedByteVectorValues copy() throws IOException; - - // XXX off heap overrides this. this is probably only used in one other spot so it should be - // abstract. 
- float getCentroidDP() throws IOException { - // this only gets executed on-merge - float[] centroid = getCentroid(); - return VectorUtil.dotProduct(centroid, centroid); - } } From bb89c01fa42505a1718ab62a06910f2502e134b9 Mon Sep 17 00:00:00 2001 From: Trevor McCulloch Date: Wed, 10 Sep 2025 10:13:02 -0700 Subject: [PATCH 17/22] tidy --- .../lucene/codecs/lucene104/QuantizedByteVectorValues.java | 1 - 1 file changed, 1 deletion(-) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/QuantizedByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/QuantizedByteVectorValues.java index 8ccaafa445da..48d0c4e665f1 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/QuantizedByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/QuantizedByteVectorValues.java @@ -20,7 +20,6 @@ import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.search.VectorScorer; -import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; /** Scalar quantized byte vector values */ From 11a978a926a2617ab1d232d991266f776e20210f Mon Sep 17 00:00:00 2001 From: Trevor McCulloch Date: Wed, 10 Sep 2025 12:01:16 -0700 Subject: [PATCH 18/22] unpack bytes during updateable scoring --- .../Lucene104ScalarQuantizedVectorScorer.java | 13 +++++++++++-- .../OffHeapScalarQuantizedVectorValues.java | 8 ++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java index ed0c4f41990a..6ef80c3e107f 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java +++ 
b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java @@ -121,7 +121,16 @@ public float score(int node) throws IOException { @Override public void setScoringOrdinal(int node) throws IOException { - targetVector = targetValues.vectorValue(node); + var rawTargetVector = targetValues.vectorValue(node); + switch (values.getScalarEncoding()) { + case UNSIGNED_BYTE -> targetVector = rawTargetVector; + case PACKED_NIBBLE -> { + if (targetVector == null) { + targetVector = new byte[OptimizedScalarQuantizer.discretize(values.dimension(), 2)]; + } + OffHeapScalarQuantizedVectorValues.unpackNibbles(rawTargetVector, targetVector); + } + } targetCorrectiveTerms = targetValues.getCorrectiveTerms(node); } }; @@ -145,7 +154,7 @@ public RandomVectorScorerSupplier copy() throws IOException { 1f / ((1 << 8) - 1), }; - static float quantizedScore( + private static float quantizedScore( byte[] quantizedQuery, OptimizedScalarQuantizer.QuantizationResult queryCorrections, QuantizedByteVectorValues targetVectors, diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/OffHeapScalarQuantizedVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/OffHeapScalarQuantizedVectorValues.java index 584f00ca0f64..d2c678d8f8ba 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/OffHeapScalarQuantizedVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/OffHeapScalarQuantizedVectorValues.java @@ -149,6 +149,14 @@ static void packNibbles(byte[] unpacked, byte[] packed) { } } + static void unpackNibbles(byte[] packed, byte[] unpacked) { + assert unpacked.length == packed.length * 2; + for (int i = 0; i < packed.length; i++) { + unpacked[i] = (byte) ((packed[i] >> 4) & 0x0F); + unpacked[packed.length + i] = (byte) (packed[i] & 0x0F); + } + } + static OffHeapScalarQuantizedVectorValues load( OrdToDocDISIReaderConfiguration configuration, int dimension, From 
bc5d3852f3efa26e0b48717e960f885f2d06d05a Mon Sep 17 00:00:00 2001 From: Trevor McCulloch Date: Mon, 15 Sep 2025 12:48:44 -0700 Subject: [PATCH 19/22] Apply suggestion from @benwtrent Co-authored-by: Benjamin Trent --- .../codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java index 6ef80c3e107f..ebce991f200e 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java @@ -173,7 +173,7 @@ private static float quantizedScore( float scale = SCALE_LUT[scalarEncoding.getBits() - 1]; float x1 = indexCorrections.quantizedComponentSum(); float ax = indexCorrections.lowerInterval(); - // Here we assume `lx` is simply bit vectors, so the scaling isn't necessary + // Here we must scale according to the bits float lx = (indexCorrections.upperInterval() - ax) * scale; float ay = queryCorrections.lowerInterval(); float ly = (queryCorrections.upperInterval() - ay) * scale; From 30350e410815109c65148f1a70f97310706b41a1 Mon Sep 17 00:00:00 2001 From: Trevor McCulloch Date: Mon, 15 Sep 2025 14:00:21 -0700 Subject: [PATCH 20/22] add 7 bit representation --- .../Lucene104ScalarQuantizedVectorScorer.java | 1 + ...Lucene104ScalarQuantizedVectorsFormat.java | 25 +++++++++++++------ ...Lucene104ScalarQuantizedVectorsWriter.java | 3 +++ ...Lucene104ScalarQuantizedVectorsFormat.java | 3 ++- 4 files changed, 23 insertions(+), 9 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java index 
6ef80c3e107f..86a978d8839a 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java @@ -166,6 +166,7 @@ private static float quantizedScore( float qcDist = switch (scalarEncoding) { case UNSIGNED_BYTE -> VectorUtil.uint8DotProduct(quantizedQuery, quantizedDoc); + case SEVEN_BIT -> VectorUtil.dotProduct(quantizedQuery, quantizedDoc); case PACKED_NIBBLE -> VectorUtil.int4DotProductPacked(quantizedQuery, quantizedDoc); }; OptimizedScalarQuantizer.QuantizationResult indexCorrections = diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java index 513e895b0a49..9acb43c64cf5 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java @@ -118,19 +118,28 @@ public class Lucene104ScalarQuantizedVectorsFormat extends FlatVectorsFormat { */ public enum ScalarEncoding { /** Each dimension is quantized to 8 bits and treated as an unsigned value. */ - UNSIGNED_BYTE(0, (byte) 8), + UNSIGNED_BYTE(0, (byte) 8, 1), /** Each dimension is quantized to 4 bits two values are packed into each output byte. */ - PACKED_NIBBLE(1, (byte) 4); + PACKED_NIBBLE(1, (byte) 4, 2), + /** + * Each dimension is quantized to 7 bits and treated as a signed value. + * + *

    This is intended for backwards compatibility with older iterations of scalar quantization. + * This setting will produce an index the same size as {@link #UNSIGNED_BYTE} but will produce + * less accurate vector comparisons. + */ + SEVEN_BIT(2, (byte) 7, 1); /** The number used to identify this encoding on the wire, rather than relying on ordinal. */ - private int wireNumber; + private final int wireNumber; - private byte bits; + private final byte bits; + private final int dimsPerByte; - ScalarEncoding(int wireNumber, byte bits) { - assert 8 % bits == 0; + ScalarEncoding(int wireNumber, byte bits, int dimsPerByte) { this.wireNumber = wireNumber; this.bits = bits; + this.dimsPerByte = dimsPerByte; } int getWireNumber() { @@ -144,12 +153,12 @@ public byte getBits() { /** Return the number of dimensions that can be packed into a single byte. */ public int getDimensionsPerByte() { - return 8 / bits; + return this.dimsPerByte; } /** Return the number of bytes required to store a packed vector of the given dimensions. */ public int getPackedLength(int dimensions) { - return (dimensions * bits + 7) / 8; + return (dimensions + this.dimsPerByte - 1) / this.dimsPerByte; } /** Returns the encoding for the given wire number, or empty if unknown. 
*/ diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java index cb84bf3e03b5..12de1f83bd36 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java @@ -197,6 +197,7 @@ private void writeVectors( byte[] vector = switch (encoding) { case UNSIGNED_BYTE -> scratch; + case SEVEN_BIT -> scratch; case PACKED_NIBBLE -> new byte[encoding.getPackedLength(fieldData.fieldInfo.getVectorDimension())]; }; @@ -257,6 +258,7 @@ private void writeSortedVectors( byte[] vector = switch (encoding) { case UNSIGNED_BYTE -> scratch; + case SEVEN_BIT -> scratch; case PACKED_NIBBLE -> new byte[encoding.getPackedLength(fieldData.fieldInfo.getVectorDimension())]; }; @@ -693,6 +695,7 @@ static class QuantizedFloatVectorValues extends QuantizedByteVectorValues { this.packed = switch (encoding) { case UNSIGNED_BYTE -> this.quantized; + case SEVEN_BIT -> this.quantized; case PACKED_NIBBLE -> new byte[encoding.getPackedLength(delegate.dimension())]; }; this.centroid = centroid; diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java index 25be0c78a51b..e6f9fbcd1051 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java @@ -56,7 +56,8 @@ public class TestLucene104ScalarQuantizedVectorsFormat extends BaseKnnVectorsFor @Before @Override public void setUp() throws Exception { - encoding = random().nextBoolean() ? 
ScalarEncoding.UNSIGNED_BYTE : ScalarEncoding.PACKED_NIBBLE; + var encodingValues = ScalarEncoding.values(); + encoding = encodingValues[random().nextInt(encodingValues.length)]; format = new Lucene104ScalarQuantizedVectorsFormat(encoding); super.setUp(); } From 93209dd680f9d49ea97e0eea46d8f81d3a742244 Mon Sep 17 00:00:00 2001 From: Trevor McCulloch Date: Mon, 15 Sep 2025 14:01:22 -0700 Subject: [PATCH 21/22] mark existing 99 formats as deprecated --- .../lucene99/Lucene99HnswScalarQuantizedVectorsFormat.java | 1 + .../codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java | 1 + 2 files changed, 2 insertions(+) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswScalarQuantizedVectorsFormat.java index 1966ed21d654..42e9381aa11d 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswScalarQuantizedVectorsFormat.java @@ -42,6 +42,7 @@ * * @lucene.experimental */ +@Deprecated public class Lucene99HnswScalarQuantizedVectorsFormat extends KnnVectorsFormat { public static final String NAME = "Lucene99HnswScalarQuantizedVectorsFormat"; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java index 0f339ecbe0a8..da18fb8d2e57 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java @@ -31,6 +31,7 @@ * * @lucene.experimental */ +@Deprecated public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat { // The bits that are allowed for scalar quantization From 
2de7d8180f284193ad4d7f27ded6b5c52a77bb9e Mon Sep 17 00:00:00 2001 From: Trevor McCulloch Date: Mon, 15 Sep 2025 14:23:26 -0700 Subject: [PATCH 22/22] fix some missing 7 bit checks --- .../codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java | 1 + .../lucene104/TestLucene104ScalarQuantizedVectorsFormat.java | 1 + 2 files changed, 2 insertions(+) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java index b2ece366baab..b5a8db5190eb 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java @@ -124,6 +124,7 @@ public void setScoringOrdinal(int node) throws IOException { var rawTargetVector = targetValues.vectorValue(node); switch (values.getScalarEncoding()) { case UNSIGNED_BYTE -> targetVector = rawTargetVector; + case SEVEN_BIT -> targetVector = rawTargetVector; case PACKED_NIBBLE -> { if (targetVector == null) { targetVector = new byte[OptimizedScalarQuantizer.discretize(values.dimension(), 2)]; diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java index e6f9fbcd1051..29041b5b07f0 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java @@ -184,6 +184,7 @@ public void testQuantizedVectorsWriteAndRead() throws IOException { centroid); switch (encoding) { case UNSIGNED_BYTE -> System.arraycopy(scratch, 0, expectedVector, 0, dims); + case SEVEN_BIT -> System.arraycopy(scratch, 0, expectedVector, 0, dims); case PACKED_NIBBLE -> 
OffHeapScalarQuantizedVectorValues.packNibbles(scratch, expectedVector); }