diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index 3ea1326b4608..885f44491a6c 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -130,6 +130,8 @@ New Features
 * GITHUB#15176: Add `[Float|Byte]VectorValues#rescorer(element[])` interface to allow optimized rescoring
   of vectors. (Ben Trent)
 
+* GITHUB#15169: Add codecs for 4 and 8 bit Optimized Scalar Quantization vectors (Trevor McCulloch)
+
 Improvements
 ---------------------
 * GITHUB#15148: Add support for uint8 distance and allow 8 bit scalar quantization (Trevor McCulloch)
diff --git a/lucene/core/src/java/module-info.java b/lucene/core/src/java/module-info.java
index 17876735f1f6..0952f2696868 100644
--- a/lucene/core/src/java/module-info.java
+++ b/lucene/core/src/java/module-info.java
@@ -87,7 +87,9 @@
       org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat,
       org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat,
       org.apache.lucene.codecs.lucene102.Lucene102HnswBinaryQuantizedVectorsFormat,
-      org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat;
+      org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat,
+      org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat,
+      org.apache.lucene.codecs.lucene104.Lucene104HnswScalarQuantizedVectorsFormat;
   provides org.apache.lucene.codecs.PostingsFormat with
       org.apache.lucene.codecs.lucene104.Lucene104PostingsFormat;
   provides org.apache.lucene.index.SortFieldProvider with
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104HnswScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104HnswScalarQuantizedVectorsFormat.java
new file mode 100644
index 000000000000..7d49183249e2
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104HnswScalarQuantizedVectorsFormat.java
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.codecs.lucene104;
+
+import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
+import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN;
+import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER;
+import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_BEAM_WIDTH;
+import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_MAX_CONN;
+
+import java.io.IOException;
+import java.util.concurrent.ExecutorService;
+import org.apache.lucene.codecs.KnnVectorsFormat;
+import org.apache.lucene.codecs.KnnVectorsReader;
+import org.apache.lucene.codecs.KnnVectorsWriter;
+import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding;
+import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
+import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
+import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter;
+import org.apache.lucene.index.SegmentReadState;
+import org.apache.lucene.index.SegmentWriteState;
+import org.apache.lucene.search.TaskExecutor;
+import org.apache.lucene.util.hnsw.HnswGraph;
+
+/**
+ * A vectors format that uses an HNSW graph to store and search for vectors. But vectors are scalar
+ * quantized using {@link Lucene104ScalarQuantizedVectorsFormat} before being stored in the graph.
+ */
+public class Lucene104HnswScalarQuantizedVectorsFormat extends KnnVectorsFormat {
+
+  public static final String NAME = "Lucene104HnswScalarQuantizedVectorsFormat";
+
+  /**
+   * Controls how many of the nearest neighbor candidates are connected to the new node. Defaults
+   * to {@link Lucene99HnswVectorsFormat#DEFAULT_MAX_CONN}. See {@link HnswGraph} for more details.
+   */
+  private final int maxConn;
+
+  /**
+   * The number of candidate neighbors to track while searching the graph for each newly inserted
+   * node. Defaults to {@link Lucene99HnswVectorsFormat#DEFAULT_BEAM_WIDTH}. See {@link HnswGraph}
+   * for details.
+   */
+  private final int beamWidth;
+
+  /** The format for storing, reading, and merging vectors on disk. */
+  private final Lucene104ScalarQuantizedVectorsFormat flatVectorsFormat;
+
+  private final int numMergeWorkers;
+  private final TaskExecutor mergeExec;
+
+  /** Constructs a format using default graph construction parameters. */
+  public Lucene104HnswScalarQuantizedVectorsFormat() {
+    this(
+        ScalarEncoding.UNSIGNED_BYTE,
+        DEFAULT_MAX_CONN,
+        DEFAULT_BEAM_WIDTH,
+        DEFAULT_NUM_MERGE_WORKER,
+        null);
+  }
+
+  /**
+   * Constructs a format using the given graph construction parameters.
+   *
+   * @param maxConn the maximum number of connections to a node in the HNSW graph
+   * @param beamWidth the size of the queue maintained during graph construction.
+   */
+  public Lucene104HnswScalarQuantizedVectorsFormat(int maxConn, int beamWidth) {
+    this(ScalarEncoding.UNSIGNED_BYTE, maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null);
+  }
+
+  /**
+   * Constructs a format using the given graph construction parameters and scalar quantization.
+   *
+   * @param encoding the scalar encoding to use for the quantized vectors
+   * @param maxConn the maximum number of connections to a node in the HNSW graph
+   * @param beamWidth the size of the queue maintained during graph construction.
+   * @param numMergeWorkers number of workers (threads) that will be used when doing merge.
If + * larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec + * @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are + * generated by this format to do the merge + */ + public Lucene104HnswScalarQuantizedVectorsFormat( + ScalarEncoding encoding, + int maxConn, + int beamWidth, + int numMergeWorkers, + ExecutorService mergeExec) { + super(NAME); + flatVectorsFormat = new Lucene104ScalarQuantizedVectorsFormat(encoding); + if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) { + throw new IllegalArgumentException( + "maxConn must be positive and less than or equal to " + + MAXIMUM_MAX_CONN + + "; maxConn=" + + maxConn); + } + if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) { + throw new IllegalArgumentException( + "beamWidth must be positive and less than or equal to " + + MAXIMUM_BEAM_WIDTH + + "; beamWidth=" + + beamWidth); + } + this.maxConn = maxConn; + this.beamWidth = beamWidth; + if (numMergeWorkers == 1 && mergeExec != null) { + throw new IllegalArgumentException( + "No executor service is needed as we'll use single thread to merge"); + } + this.numMergeWorkers = numMergeWorkers; + if (mergeExec != null) { + this.mergeExec = new TaskExecutor(mergeExec); + } else { + this.mergeExec = null; + } + } + + @Override + public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new Lucene99HnswVectorsWriter( + state, + maxConn, + beamWidth, + flatVectorsFormat.fieldsWriter(state), + numMergeWorkers, + mergeExec); + } + + @Override + public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state)); + } + + @Override + public int getMaxDimensions(String fieldName) { + return 1024; + } + + @Override + public String toString() { + return "Lucene104HnswScalarQuantizedVectorsFormat(name=Lucene104HnswScalarQuantizedVectorsFormat, maxConn=" + + maxConn + + ", beamWidth=" + + beamWidth + + ", flatVectorFormat=" + + flatVectorsFormat + + ")"; + } +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java new file mode 100644 index 000000000000..b5a8db5190eb --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.codecs.lucene104; + +import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; +import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; +import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; + +import java.io.IOException; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; +import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; + +/** Vector scorer over OptimizedScalarQuantized vectors */ +public class Lucene104ScalarQuantizedVectorScorer implements FlatVectorsScorer { + private final FlatVectorsScorer nonQuantizedDelegate; + + public Lucene104ScalarQuantizedVectorScorer(FlatVectorsScorer nonQuantizedDelegate) { + this.nonQuantizedDelegate = nonQuantizedDelegate; + } + + @Override + public RandomVectorScorerSupplier getRandomVectorScorerSupplier( + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues) + throws IOException { + if (vectorValues instanceof QuantizedByteVectorValues qv) { + return new ScalarQuantizedVectorScorerSupplier(qv, similarityFunction); + } + // It is possible to get to this branch during initial indexing and flush + return nonQuantizedDelegate.getRandomVectorScorerSupplier(similarityFunction, vectorValues); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target) + throws IOException { + if (vectorValues instanceof QuantizedByteVectorValues qv) { + OptimizedScalarQuantizer quantizer = qv.getQuantizer(); + byte[] targetQuantized = + new byte + [OptimizedScalarQuantizer.discretize( + target.length, qv.getScalarEncoding().getDimensionsPerByte())]; + // We make a copy as the quantization process mutates the input + float[] copy = ArrayUtil.copyOfSubArray(target, 0, target.length); + if (similarityFunction == COSINE) { + VectorUtil.l2normalize(copy); + } + target = copy; + var targetCorrectiveTerms = + quantizer.scalarQuantize( + target, targetQuantized, qv.getScalarEncoding().getBits(), qv.getCentroid()); + return new RandomVectorScorer.AbstractRandomVectorScorer(qv) { + @Override + public float score(int node) throws IOException { + return quantizedScore( + targetQuantized, targetCorrectiveTerms, qv, node, similarityFunction); + } + }; + } + // It is possible to get to this branch during initial indexing and flush + return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target) + throws IOException { + return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target); + } + + @Override + public String toString() { + return "Lucene104ScalarQuantizedVectorScorer(nonQuantizedDelegate=" + + nonQuantizedDelegate + + ")"; + } + + private static final class ScalarQuantizedVectorScorerSupplier + implements RandomVectorScorerSupplier { + private final QuantizedByteVectorValues targetValues; + private final QuantizedByteVectorValues values; + private final 
VectorSimilarityFunction similarity; + + public ScalarQuantizedVectorScorerSupplier( + QuantizedByteVectorValues values, VectorSimilarityFunction similarity) throws IOException { + this.targetValues = values.copy(); + this.values = values; + this.similarity = similarity; + } + + @Override + public UpdateableRandomVectorScorer scorer() throws IOException { + return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(values) { + private byte[] targetVector; + private OptimizedScalarQuantizer.QuantizationResult targetCorrectiveTerms; + + @Override + public float score(int node) throws IOException { + return quantizedScore(targetVector, targetCorrectiveTerms, values, node, similarity); + } + + @Override + public void setScoringOrdinal(int node) throws IOException { + var rawTargetVector = targetValues.vectorValue(node); + switch (values.getScalarEncoding()) { + case UNSIGNED_BYTE -> targetVector = rawTargetVector; + case SEVEN_BIT -> targetVector = rawTargetVector; + case PACKED_NIBBLE -> { + if (targetVector == null) { + targetVector = new byte[OptimizedScalarQuantizer.discretize(values.dimension(), 2)]; + } + OffHeapScalarQuantizedVectorValues.unpackNibbles(rawTargetVector, targetVector); + } + } + targetCorrectiveTerms = targetValues.getCorrectiveTerms(node); + } + }; + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return new ScalarQuantizedVectorScorerSupplier(values.copy(), similarity); + } + } + + private static final float[] SCALE_LUT = + new float[] { + 1f, + 1f / ((1 << 2) - 1), + 1f / ((1 << 3) - 1), + 1f / ((1 << 4) - 1), + 1f / ((1 << 5) - 1), + 1f / ((1 << 6) - 1), + 1f / ((1 << 7) - 1), + 1f / ((1 << 8) - 1), + }; + + private static float quantizedScore( + byte[] quantizedQuery, + OptimizedScalarQuantizer.QuantizationResult queryCorrections, + QuantizedByteVectorValues targetVectors, + int targetOrd, + VectorSimilarityFunction similarityFunction) + throws IOException { + var scalarEncoding = targetVectors.getScalarEncoding(); + byte[] quantizedDoc = targetVectors.vectorValue(targetOrd); + float qcDist = + switch (scalarEncoding) { + case UNSIGNED_BYTE -> VectorUtil.uint8DotProduct(quantizedQuery, quantizedDoc); + case SEVEN_BIT -> VectorUtil.dotProduct(quantizedQuery, quantizedDoc); + case PACKED_NIBBLE -> VectorUtil.int4DotProductPacked(quantizedQuery, quantizedDoc); + }; + OptimizedScalarQuantizer.QuantizationResult indexCorrections = + targetVectors.getCorrectiveTerms(targetOrd); + float scale = SCALE_LUT[scalarEncoding.getBits() - 1]; + float x1 = indexCorrections.quantizedComponentSum(); + float ax = indexCorrections.lowerInterval(); + // Here we must scale according to the bits + float lx = (indexCorrections.upperInterval() - ax) * scale; + float ay = queryCorrections.lowerInterval(); + float ly = (queryCorrections.upperInterval() - ay) * scale; + float y1 = queryCorrections.quantizedComponentSum(); + float score = + ax * ay * targetVectors.dimension() + ay * lx * x1 + ax * ly * y1 + lx * ly * qcDist; + // For euclidean, we need to invert the score and apply the additional correction, which is + // assumed to be the squared l2norm of the centroid centered vectors. 
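+    // (The four terms in `score` expand the dot product of the two de-quantized vectors: each
+    // component is reconstructed as roughly (lowerInterval + step * quantizedValue), with step
+    // being the scaled interval width, so
+    // x . y ~= ax*ay*dims + ay*lx*sum(qx) + ax*ly*sum(qy) + lx*ly*(qx . qy).)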
+    if (similarityFunction == EUCLIDEAN) {
+      score =
+          queryCorrections.additionalCorrection()
+              + indexCorrections.additionalCorrection()
+              - 2 * score;
+      return Math.max(1 / (1f + score), 0);
+    } else {
+      // For cosine and max inner product, we need to apply the additional correction, which is
+      // assumed to be the non-centered dot-product between the vector and the centroid
+      score +=
+          queryCorrections.additionalCorrection()
+              + indexCorrections.additionalCorrection()
+              - targetVectors.getCentroidDP();
+      if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
+        return VectorUtil.scaleMaxInnerProductScore(score);
+      }
+      return Math.max((1f + score) / 2f, 0);
+    }
+  }
+}
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java
new file mode 100644
index 000000000000..9acb43c64cf5
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java
@@ -0,0 +1,215 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.codecs.lucene104;
+
+import java.io.IOException;
+import java.util.Optional;
+import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
+import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
+import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
+import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
+import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
+import org.apache.lucene.index.SegmentReadState;
+import org.apache.lucene.index.SegmentWriteState;
+
+/**
+ * The quantization format used here is a per-vector optimized scalar quantization. These ideas are
+ * evolutions of LVQ proposed in Similarity search in the blink of an eye with compressed indices
+ * by Cecilia Aguerrebere et al., the previous work on globally optimized scalar quantization in
+ * Apache Lucene, and Accelerating Large-Scale Inference with Anisotropic Vector Quantization by
+ * Ruiqi Guo et al. Also see {@link org.apache.lucene.util.quantization.OptimizedScalarQuantizer},
+ * which describes some of the key features.
+ *
+ * <p>A previous work related to improvements over regular LVQ is Practical and Asymptotically
+ * Optimal Quantization of High-Dimensional Vectors in Euclidean Space for Approximate Nearest
+ * Neighbor Search by Jianyang Gao, et al.
+ *
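+ * <p>For illustration only, quantizing one raw vector against a centroid with {@link
+ * org.apache.lucene.util.quantization.OptimizedScalarQuantizer} looks roughly like the sketch
+ * below ({@code vector} and {@code centroid} are assumed caller-supplied float arrays of equal
+ * length; note that {@code scalarQuantize} mutates its input, so callers quantizing a query make
+ * a copy first):
+ *
+ * <pre>{@code
+ * OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(VectorSimilarityFunction.DOT_PRODUCT);
+ * byte[] quantized = new byte[OptimizedScalarQuantizer.discretize(vector.length, 1)];
+ * OptimizedScalarQuantizer.QuantizationResult corrections =
+ *     osq.scalarQuantize(vector, quantized, (byte) 8, centroid);
+ * }</pre>
+ *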

<p>The format is stored within two files:
+ *
+ * <h2>.veq (vector data) file</h2>
+ *
+ * <p>Stores the quantized vectors in a flat format. Additionally, it stores each vector's
+ * corrective factors. At the end of the file, additional information is stored for vector ordinal
+ * to centroid ordinal mapping and sparse vector information.
+ *
+ * <h2>.vemq (vector metadata) file</h2>
+ *
+ * <p>Stores the metadata for the vectors. This includes the number of vectors, the number of
+ * dimensions, and file offset information.
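+ *
+ * <p>As a non-normative usage sketch (names like {@code MyCodec} are placeholders), the format
+ * can be selected per field through the usual {@code FilterCodec} and
+ * {@code PerFieldKnnVectorsFormat} wiring; registering the custom codec for reading is elided:
+ *
+ * <pre>{@code
+ * IndexWriterConfig iwc = new IndexWriterConfig();
+ * iwc.setCodec(
+ *     new FilterCodec("MyCodec", Codec.getDefault()) {
+ *       public KnnVectorsFormat knnVectorsFormat() {
+ *         return new PerFieldKnnVectorsFormat() {
+ *           public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
+ *             return new Lucene104ScalarQuantizedVectorsFormat(ScalarEncoding.PACKED_NIBBLE);
+ *           }
+ *         };
+ *       }
+ *     });
+ * }</pre>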

+ */ +public class Lucene104ScalarQuantizedVectorsFormat extends FlatVectorsFormat { + public static final String QUANTIZED_VECTOR_COMPONENT = "QVEC"; + public static final String NAME = "Lucene104ScalarQuantizedVectorsFormat"; + + static final int VERSION_START = 0; + static final int VERSION_CURRENT = VERSION_START; + static final String META_CODEC_NAME = "Lucene104ScalarQuantizedVectorsFormatMeta"; + static final String VECTOR_DATA_CODEC_NAME = "Lucene104ScalarQuantizedVectorsFormatData"; + static final String META_EXTENSION = "vemq"; + static final String VECTOR_DATA_EXTENSION = "veq"; + static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16; + + private static final FlatVectorsFormat rawVectorFormat = + new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()); + + private static final Lucene104ScalarQuantizedVectorScorer scorer = + new Lucene104ScalarQuantizedVectorScorer(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()); + + private final ScalarEncoding encoding; + + /** + * Allowed encodings for scalar quantization. + * + *

<p>This specifies how many bits are used per dimension and also dictates packing of dimensions
+   * into a byte stream.
+   */
+  public enum ScalarEncoding {
+    /** Each dimension is quantized to 8 bits and treated as an unsigned value. */
+    UNSIGNED_BYTE(0, (byte) 8, 1),
+    /** Each dimension is quantized to 4 bits, and two values are packed into each output byte. */
+    PACKED_NIBBLE(1, (byte) 4, 2),
+    /**
+     * Each dimension is quantized to 7 bits and treated as a signed value.
+     *
+     *

<p>This is intended for backwards compatibility with older iterations of scalar quantization.
+     * This setting will produce an index the same size as {@link #UNSIGNED_BYTE} but will produce
+     * less accurate vector comparisons.
+     */
+    SEVEN_BIT(2, (byte) 7, 1);
+
+    /** The number used to identify this encoding on the wire, rather than relying on ordinal. */
+    private final int wireNumber;
+
+    private final byte bits;
+    private final int dimsPerByte;
+
+    ScalarEncoding(int wireNumber, byte bits, int dimsPerByte) {
+      this.wireNumber = wireNumber;
+      this.bits = bits;
+      this.dimsPerByte = dimsPerByte;
+    }
+
+    int getWireNumber() {
+      return wireNumber;
+    }
+
+    /** Return the number of bits used per dimension. */
+    public byte getBits() {
+      return bits;
+    }
+
+    /** Return the number of dimensions that can be packed into a single byte. */
+    public int getDimensionsPerByte() {
+      return this.dimsPerByte;
+    }
+
+    /** Return the number of bytes required to store a packed vector of the given dimensions. */
+    public int getPackedLength(int dimensions) {
+      return (dimensions + this.dimsPerByte - 1) / this.dimsPerByte;
+    }
+
+    /** Returns the encoding for the given wire number, or empty if unknown. */
+    public static Optional<ScalarEncoding> fromWireNumber(int wireNumber) {
+      for (ScalarEncoding encoding : values()) {
+        if (encoding.wireNumber == wireNumber) {
+          return Optional.of(encoding);
+        }
+      }
+      return Optional.empty();
+    }
+  }
+
+  /** Creates a new instance with UNSIGNED_BYTE encoding. */
+  public Lucene104ScalarQuantizedVectorsFormat() {
+    this(ScalarEncoding.UNSIGNED_BYTE);
+  }
+
+  /** Creates a new instance with the chosen encoding. */
+  public Lucene104ScalarQuantizedVectorsFormat(ScalarEncoding encoding) {
+    super(NAME);
+    this.encoding = encoding;
+  }
+
+  @Override
+  public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
+    return new Lucene104ScalarQuantizedVectorsWriter(
+        state, encoding, rawVectorFormat.fieldsWriter(state), scorer);
+  }
+
+  @Override
+  public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException {
+    return new Lucene104ScalarQuantizedVectorsReader(
+        state, rawVectorFormat.fieldsReader(state), scorer);
+  }
+
+  @Override
+  public int getMaxDimensions(String fieldName) {
+    return 1024;
+  }
+
+  @Override
+  public String toString() {
+    return "Lucene104ScalarQuantizedVectorsFormat(name="
+        + NAME
+        + ", encoding="
+        + encoding
+        + ", flatVectorScorer="
+        + scorer
+        + ", rawVectorFormat="
+        + rawVectorFormat
+        + ")";
+  }
+}
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java
new file mode 100644
index 000000000000..0251224998b4
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java
@@ -0,0 +1,451 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.codecs.lucene104; + +import static org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.VECTOR_DATA_EXTENSION; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readSimilarityFunction; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readVectorEncoding; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.CorruptIndexException; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.AcceptDocs; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.store.ChecksumIndexInput; +import org.apache.lucene.store.DataAccessHint; +import org.apache.lucene.store.FileDataHint; +import org.apache.lucene.store.FileTypeHint; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.IOUtils; +import org.apache.lucene.util.RamUsageEstimator; +import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; + +/** Reader for scalar quantized vectors in the Lucene 10.4 format. 
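Each vector is stored as its
+ * packed quantized bytes followed by three corrective floats (lower interval, upper interval,
+ * additional correction) and an int quantized component sum, matching the order written by
+ * {@link Lucene104ScalarQuantizedVectorsWriter}.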
*/ +class Lucene104ScalarQuantizedVectorsReader extends FlatVectorsReader { + + private static final long SHALLOW_SIZE = + RamUsageEstimator.shallowSizeOfInstance(Lucene104ScalarQuantizedVectorsReader.class); + + private final Map fields = new HashMap<>(); + private final IndexInput quantizedVectorData; + private final FlatVectorsReader rawVectorsReader; + private final Lucene104ScalarQuantizedVectorScorer vectorScorer; + + Lucene104ScalarQuantizedVectorsReader( + SegmentReadState state, + FlatVectorsReader rawVectorsReader, + Lucene104ScalarQuantizedVectorScorer vectorsScorer) + throws IOException { + super(vectorsScorer); + this.vectorScorer = vectorsScorer; + this.rawVectorsReader = rawVectorsReader; + int versionMeta = -1; + String metaFileName = + IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + Lucene104ScalarQuantizedVectorsFormat.META_EXTENSION); + try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName)) { + Throwable priorE = null; + try { + versionMeta = + CodecUtil.checkIndexHeader( + meta, + Lucene104ScalarQuantizedVectorsFormat.META_CODEC_NAME, + Lucene104ScalarQuantizedVectorsFormat.VERSION_START, + Lucene104ScalarQuantizedVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + readFields(meta, state.fieldInfos); + } catch (Throwable exception) { + priorE = exception; + } finally { + CodecUtil.checkFooter(meta, priorE); + } + quantizedVectorData = + openDataInput( + state, + versionMeta, + VECTOR_DATA_EXTENSION, + Lucene104ScalarQuantizedVectorsFormat.VECTOR_DATA_CODEC_NAME, + // Quantized vectors are accessed randomly from their node ID stored in the HNSW + // graph. + state.context.withHints( + FileTypeHint.DATA, FileDataHint.KNN_VECTORS, DataAccessHint.RANDOM)); + } catch (Throwable t) { + IOUtils.closeWhileSuppressingExceptions(t, this); + throw t; + } + } + + private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOException { + for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) { + FieldInfo info = infos.fieldInfo(fieldNumber); + if (info == null) { + throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta); + } + FieldEntry fieldEntry = readField(meta, info); + validateFieldEntry(info, fieldEntry); + fields.put(info.name, fieldEntry); + } + } + + static void validateFieldEntry(FieldInfo info, FieldEntry fieldEntry) { + int dimension = info.getVectorDimension(); + if (dimension != fieldEntry.dimension) { + throw new IllegalStateException( + "Inconsistent vector dimension for field=\"" + + info.name + + "\"; " + + dimension + + " != " + + fieldEntry.dimension); + } + + long numQuantizedVectorBytes = + Math.multiplyExact( + (fieldEntry.scalarEncoding.getPackedLength(dimension) + + (Float.BYTES * 3) + + Integer.BYTES), + (long) fieldEntry.size); + if (numQuantizedVectorBytes != fieldEntry.vectorDataLength) { + throw new IllegalStateException( + "vector data length " + + fieldEntry.vectorDataLength + + " not matching size = " + + fieldEntry.size + + " * (dims=" + + dimension + + " + 16" + + ") = " + + numQuantizedVectorBytes); + } + } + + @Override + public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException { + FieldEntry fi = fields.get(field); + if (fi == null) { + return null; + } + return vectorScorer.getRandomVectorScorer( + fi.similarityFunction, + OffHeapScalarQuantizedVectorValues.load( + fi.ordToDocDISIReaderConfiguration, + fi.dimension, + fi.size, + new 
OptimizedScalarQuantizer(fi.similarityFunction), + fi.scalarEncoding, + fi.similarityFunction, + vectorScorer, + fi.centroid, + fi.centroidDP, + fi.vectorDataOffset, + fi.vectorDataLength, + quantizedVectorData), + target); + } + + @Override + public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException { + return rawVectorsReader.getRandomVectorScorer(field, target); + } + + @Override + public void checkIntegrity() throws IOException { + rawVectorsReader.checkIntegrity(); + CodecUtil.checksumEntireFile(quantizedVectorData); + } + + @Override + public FloatVectorValues getFloatVectorValues(String field) throws IOException { + FieldEntry fi = fields.get(field); + if (fi == null) { + return null; + } + if (fi.vectorEncoding != VectorEncoding.FLOAT32) { + throw new IllegalArgumentException( + "field=\"" + + field + + "\" is encoded as: " + + fi.vectorEncoding + + " expected: " + + VectorEncoding.FLOAT32); + } + OffHeapScalarQuantizedVectorValues sqvv = + OffHeapScalarQuantizedVectorValues.load( + fi.ordToDocDISIReaderConfiguration, + fi.dimension, + fi.size, + new OptimizedScalarQuantizer(fi.similarityFunction), + fi.scalarEncoding, + fi.similarityFunction, + vectorScorer, + fi.centroid, + fi.centroidDP, + fi.vectorDataOffset, + fi.vectorDataLength, + quantizedVectorData); + return new ScalarQuantizedVectorValues(rawVectorsReader.getFloatVectorValues(field), sqvv); + } + + @Override + public ByteVectorValues getByteVectorValues(String field) throws IOException { + return rawVectorsReader.getByteVectorValues(field); + } + + @Override + public void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) + throws IOException { + rawVectorsReader.search(field, target, knnCollector, acceptDocs); + } + + @Override + public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) + throws IOException { + if (knnCollector.k() == 0) return; + final RandomVectorScorer scorer = getRandomVectorScorer(field, target); + if (scorer == null) return; + OrdinalTranslatedKnnCollector collector = + new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc); + Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs.bits()); + for (int i = 0; i < scorer.maxOrd(); i++) { + if (acceptedOrds == null || acceptedOrds.get(i)) { + collector.collect(i, scorer.score(i)); + collector.incVisitedCount(1); + } + } + } + + @Override + public void close() throws IOException { + IOUtils.close(quantizedVectorData, rawVectorsReader); + } + + @Override + public long ramBytesUsed() { + long size = SHALLOW_SIZE; + size += + RamUsageEstimator.sizeOfMap( + fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class)); + size += rawVectorsReader.ramBytesUsed(); + return size; + } + + @Override + public Map getOffHeapByteSize(FieldInfo fieldInfo) { + Objects.requireNonNull(fieldInfo); + var raw = rawVectorsReader.getOffHeapByteSize(fieldInfo); + var fieldEntry = fields.get(fieldInfo.name); + if (fieldEntry == null) { + assert fieldInfo.getVectorEncoding() == VectorEncoding.BYTE; + return raw; + } + var quant = Map.of(VECTOR_DATA_EXTENSION, fieldEntry.vectorDataLength()); + return KnnVectorsReader.mergeOffHeapByteSizeMaps(raw, quant); + } + + public float[] getCentroid(String field) { + FieldEntry fieldEntry = fields.get(field); + if (fieldEntry != null) { + return fieldEntry.centroid; + } + return null; + } + + private static IndexInput openDataInput( + SegmentReadState state, + int versionMeta, + String fileExtension, + String 
codecName, + IOContext context) + throws IOException { + String fileName = + IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension); + IndexInput in = state.directory.openInput(fileName, context); + try { + int versionVectorData = + CodecUtil.checkIndexHeader( + in, + codecName, + Lucene104ScalarQuantizedVectorsFormat.VERSION_START, + Lucene104ScalarQuantizedVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + if (versionMeta != versionVectorData) { + throw new CorruptIndexException( + "Format versions mismatch: meta=" + + versionMeta + + ", " + + codecName + + "=" + + versionVectorData, + in); + } + CodecUtil.retrieveChecksum(in); + return in; + } catch (Throwable t) { + IOUtils.closeWhileSuppressingExceptions(t, in); + throw t; + } + } + + private FieldEntry readField(IndexInput input, FieldInfo info) throws IOException { + VectorEncoding vectorEncoding = readVectorEncoding(input); + VectorSimilarityFunction similarityFunction = readSimilarityFunction(input); + if (similarityFunction != info.getVectorSimilarityFunction()) { + throw new IllegalStateException( + "Inconsistent vector similarity function for field=\"" + + info.name + + "\"; " + + similarityFunction + + " != " + + info.getVectorSimilarityFunction()); + } + return FieldEntry.create(input, vectorEncoding, info.getVectorSimilarityFunction()); + } + + private record FieldEntry( + VectorSimilarityFunction similarityFunction, + VectorEncoding vectorEncoding, + int dimension, + long vectorDataOffset, + long vectorDataLength, + int size, + ScalarEncoding scalarEncoding, + float[] centroid, + float centroidDP, + OrdToDocDISIReaderConfiguration ordToDocDISIReaderConfiguration) { + + static FieldEntry create( + IndexInput input, + VectorEncoding vectorEncoding, + VectorSimilarityFunction similarityFunction) + throws IOException { + int dimension = input.readVInt(); + long vectorDataOffset = input.readVLong(); + long vectorDataLength = input.readVLong(); + int size = input.readVInt(); + final float[] centroid; + float centroidDP = 0; + ScalarEncoding scalarEncoding = ScalarEncoding.UNSIGNED_BYTE; + if (size > 0) { + int wireNumber = input.readVInt(); + scalarEncoding = + ScalarEncoding.fromWireNumber(wireNumber) + .orElseThrow( + () -> + new IllegalStateException( + "Could not get ScalarEncoding from wire number: " + wireNumber)); + centroid = new float[dimension]; + input.readFloats(centroid, 0, dimension); + centroidDP = Float.intBitsToFloat(input.readInt()); + } else { + centroid = null; + } + OrdToDocDISIReaderConfiguration conf = + OrdToDocDISIReaderConfiguration.fromStoredMeta(input, size); + return new FieldEntry( + similarityFunction, + vectorEncoding, + dimension, + vectorDataOffset, + vectorDataLength, + size, + scalarEncoding, + centroid, + centroidDP, + conf); + } + } + + /** Vector values holding row and quantized vector values */ + protected static final class ScalarQuantizedVectorValues extends FloatVectorValues { + private final FloatVectorValues rawVectorValues; + private final QuantizedByteVectorValues quantizedVectorValues; + + ScalarQuantizedVectorValues( + FloatVectorValues rawVectorValues, QuantizedByteVectorValues quantizedVectorValues) { + this.rawVectorValues = rawVectorValues; + this.quantizedVectorValues = quantizedVectorValues; + } + + @Override + public int dimension() { + return rawVectorValues.dimension(); + } + + @Override + public int size() { + return rawVectorValues.size(); + } + + @Override + public float[] vectorValue(int ord) throws 
IOException { + return rawVectorValues.vectorValue(ord); + } + + @Override + public ScalarQuantizedVectorValues copy() throws IOException { + return new ScalarQuantizedVectorValues(rawVectorValues.copy(), quantizedVectorValues.copy()); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + return rawVectorValues.getAcceptOrds(acceptDocs); + } + + @Override + public int ordToDoc(int ord) { + return rawVectorValues.ordToDoc(ord); + } + + @Override + public DocIndexIterator iterator() { + return rawVectorValues.iterator(); + } + + @Override + public VectorScorer scorer(float[] query) throws IOException { + return quantizedVectorValues.scorer(query); + } + + QuantizedByteVectorValues getQuantizedVectorValues() throws IOException { + return quantizedVectorValues; + } + } +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java new file mode 100644 index 000000000000..12de1f83bd36 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java @@ -0,0 +1,861 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.lucene.codecs.lucene104;
+
+import static org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT;
+import static org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_COMPONENT;
+import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import org.apache.lucene.codecs.CodecUtil;
+import org.apache.lucene.codecs.KnnVectorsReader;
+import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
+import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
+import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding;
+import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
+import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
+import org.apache.lucene.index.DocsWithFieldSet;
+import org.apache.lucene.index.FieldInfo;
+import org.apache.lucene.index.FloatVectorValues;
+import org.apache.lucene.index.IndexFileNames;
+import org.apache.lucene.index.KnnVectorValues;
+import org.apache.lucene.index.MergeState;
+import org.apache.lucene.index.SegmentWriteState;
+import org.apache.lucene.index.Sorter;
+import org.apache.lucene.index.VectorEncoding;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.internal.hppc.FloatArrayList;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.VectorScorer;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.store.IndexOutput;
+import org.apache.lucene.util.IOUtils;
+import org.apache.lucene.util.VectorUtil;
+import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;
+import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
+import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
+import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
+
+/** Writer for scalar quantized vectors in the Lucene 10.4 format. */
+public class Lucene104ScalarQuantizedVectorsWriter extends FlatVectorsWriter {
+  private static final long SHALLOW_RAM_BYTES_USED =
+      shallowSizeOfInstance(Lucene104ScalarQuantizedVectorsWriter.class);
+
+  private final SegmentWriteState segmentWriteState;
+  private final List<FieldWriter> fields = new ArrayList<>();
+  private final IndexOutput meta, vectorData;
+  private final ScalarEncoding encoding;
+  private final FlatVectorsWriter rawVectorDelegate;
+  private final Lucene104ScalarQuantizedVectorScorer vectorsScorer;
+  private boolean finished;
+
+  /**
+   * Sole constructor
+   *
+   * @param vectorsScorer the scorer to use for scoring vectors
+   */
+  protected Lucene104ScalarQuantizedVectorsWriter(
+      SegmentWriteState state,
+      ScalarEncoding encoding,
+      FlatVectorsWriter rawVectorDelegate,
+      Lucene104ScalarQuantizedVectorScorer vectorsScorer)
+      throws IOException {
+    super(vectorsScorer);
+    this.encoding = encoding;
+    this.vectorsScorer = vectorsScorer;
+    this.segmentWriteState = state;
+    String metaFileName =
+        IndexFileNames.segmentFileName(
+            state.segmentInfo.name,
+            state.segmentSuffix,
+            Lucene104ScalarQuantizedVectorsFormat.META_EXTENSION);
+
+    String vectorDataFileName =
+        IndexFileNames.segmentFileName(
state.segmentInfo.name, + state.segmentSuffix, + Lucene104ScalarQuantizedVectorsFormat.VECTOR_DATA_EXTENSION); + this.rawVectorDelegate = rawVectorDelegate; + try { + meta = state.directory.createOutput(metaFileName, state.context); + vectorData = state.directory.createOutput(vectorDataFileName, state.context); + + CodecUtil.writeIndexHeader( + meta, + Lucene104ScalarQuantizedVectorsFormat.META_CODEC_NAME, + Lucene104ScalarQuantizedVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + CodecUtil.writeIndexHeader( + vectorData, + Lucene104ScalarQuantizedVectorsFormat.VECTOR_DATA_CODEC_NAME, + Lucene104ScalarQuantizedVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + } catch (Throwable t) { + IOUtils.closeWhileSuppressingExceptions(t, this); + throw t; + } + } + + @Override + public FlatFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException { + FlatFieldVectorsWriter rawVectorDelegate = this.rawVectorDelegate.addField(fieldInfo); + if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + @SuppressWarnings("unchecked") + FieldWriter fieldWriter = + new FieldWriter(fieldInfo, (FlatFieldVectorsWriter) rawVectorDelegate); + fields.add(fieldWriter); + return fieldWriter; + } + return rawVectorDelegate; + } + + @Override + public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { + rawVectorDelegate.flush(maxDoc, sortMap); + for (FieldWriter field : fields) { + // after raw vectors are written, normalize vectors for clustering and quantization + if (VectorSimilarityFunction.COSINE == field.fieldInfo.getVectorSimilarityFunction()) { + field.normalizeVectors(); + } + final float[] clusterCenter; + int vectorCount = field.flatFieldVectorsWriter.getVectors().size(); + clusterCenter = new float[field.dimensionSums.length]; + if (vectorCount > 0) { + for (int i = 0; i < field.dimensionSums.length; i++) { + clusterCenter[i] = field.dimensionSums[i] / vectorCount; + } + if (VectorSimilarityFunction.COSINE == field.fieldInfo.getVectorSimilarityFunction()) { + VectorUtil.l2normalize(clusterCenter); + } + } + if (segmentWriteState.infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) { + segmentWriteState.infoStream.message( + QUANTIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount); + } + OptimizedScalarQuantizer quantizer = + new OptimizedScalarQuantizer(field.fieldInfo.getVectorSimilarityFunction()); + if (sortMap == null) { + writeField(field, clusterCenter, maxDoc, quantizer); + } else { + writeSortingField(field, clusterCenter, maxDoc, sortMap, quantizer); + } + field.finish(); + } + } + + private void writeField( + FieldWriter fieldData, float[] clusterCenter, int maxDoc, OptimizedScalarQuantizer quantizer) + throws IOException { + // write vector values + long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); + writeVectors(fieldData, clusterCenter, quantizer); + long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; + float centroidDp = + !fieldData.getVectors().isEmpty() ? 
VectorUtil.dotProduct(clusterCenter, clusterCenter) : 0; + + writeMeta( + fieldData.fieldInfo, + maxDoc, + vectorDataOffset, + vectorDataLength, + clusterCenter, + centroidDp, + fieldData.getDocsWithFieldSet()); + } + + private void writeVectors( + FieldWriter fieldData, float[] clusterCenter, OptimizedScalarQuantizer scalarQuantizer) + throws IOException { + byte[] scratch = + new byte + [OptimizedScalarQuantizer.discretize( + fieldData.fieldInfo.getVectorDimension(), encoding.getDimensionsPerByte())]; + byte[] vector = + switch (encoding) { + case UNSIGNED_BYTE -> scratch; + case SEVEN_BIT -> scratch; + case PACKED_NIBBLE -> + new byte[encoding.getPackedLength(fieldData.fieldInfo.getVectorDimension())]; + }; + for (int i = 0; i < fieldData.getVectors().size(); i++) { + float[] v = fieldData.getVectors().get(i); + OptimizedScalarQuantizer.QuantizationResult corrections = + scalarQuantizer.scalarQuantize(v, scratch, encoding.getBits(), clusterCenter); + if (encoding == ScalarEncoding.PACKED_NIBBLE) { + OffHeapScalarQuantizedVectorValues.packNibbles(scratch, vector); + } + vectorData.writeBytes(vector, vector.length); + vectorData.writeInt(Float.floatToIntBits(corrections.lowerInterval())); + vectorData.writeInt(Float.floatToIntBits(corrections.upperInterval())); + vectorData.writeInt(Float.floatToIntBits(corrections.additionalCorrection())); + vectorData.writeInt(corrections.quantizedComponentSum()); + } + } + + private void writeSortingField( + FieldWriter fieldData, + float[] clusterCenter, + int maxDoc, + Sorter.DocMap sortMap, + OptimizedScalarQuantizer scalarQuantizer) + throws IOException { + final int[] ordMap = + new int[fieldData.getDocsWithFieldSet().cardinality()]; // new ord to old ord + + DocsWithFieldSet newDocsWithField = new DocsWithFieldSet(); + mapOldOrdToNewOrd(fieldData.getDocsWithFieldSet(), sortMap, null, ordMap, newDocsWithField); + + // write vector values + long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); + writeSortedVectors(fieldData, clusterCenter, ordMap, scalarQuantizer); + long quantizedVectorLength = vectorData.getFilePointer() - vectorDataOffset; + + float centroidDp = VectorUtil.dotProduct(clusterCenter, clusterCenter); + writeMeta( + fieldData.fieldInfo, + maxDoc, + vectorDataOffset, + quantizedVectorLength, + clusterCenter, + centroidDp, + newDocsWithField); + } + + private void writeSortedVectors( + FieldWriter fieldData, + float[] clusterCenter, + int[] ordMap, + OptimizedScalarQuantizer scalarQuantizer) + throws IOException { + byte[] scratch = + new byte + [OptimizedScalarQuantizer.discretize( + fieldData.fieldInfo.getVectorDimension(), encoding.getDimensionsPerByte())]; + byte[] vector = + switch (encoding) { + case UNSIGNED_BYTE -> scratch; + case SEVEN_BIT -> scratch; + case PACKED_NIBBLE -> + new byte[encoding.getPackedLength(fieldData.fieldInfo.getVectorDimension())]; + }; + for (int ordinal : ordMap) { + float[] v = fieldData.getVectors().get(ordinal); + OptimizedScalarQuantizer.QuantizationResult corrections = + scalarQuantizer.scalarQuantize(v, scratch, encoding.getBits(), clusterCenter); + if (encoding == ScalarEncoding.PACKED_NIBBLE) { + OffHeapScalarQuantizedVectorValues.packNibbles(scratch, vector); + } + vectorData.writeBytes(vector, vector.length); + vectorData.writeInt(Float.floatToIntBits(corrections.lowerInterval())); + vectorData.writeInt(Float.floatToIntBits(corrections.upperInterval())); + vectorData.writeInt(Float.floatToIntBits(corrections.additionalCorrection())); + 
vectorData.writeInt(corrections.quantizedComponentSum()); + } + } + + private void writeMeta( + FieldInfo field, + int maxDoc, + long vectorDataOffset, + long vectorDataLength, + float[] clusterCenter, + float centroidDp, + DocsWithFieldSet docsWithField) + throws IOException { + meta.writeInt(field.number); + meta.writeInt(field.getVectorEncoding().ordinal()); + meta.writeInt(field.getVectorSimilarityFunction().ordinal()); + meta.writeVInt(field.getVectorDimension()); + meta.writeVLong(vectorDataOffset); + meta.writeVLong(vectorDataLength); + int count = docsWithField.cardinality(); + meta.writeVInt(count); + if (count > 0) { + meta.writeVInt(encoding.getWireNumber()); + final ByteBuffer buffer = + ByteBuffer.allocate(field.getVectorDimension() * Float.BYTES) + .order(ByteOrder.LITTLE_ENDIAN); + buffer.asFloatBuffer().put(clusterCenter); + meta.writeBytes(buffer.array(), buffer.array().length); + meta.writeInt(Float.floatToIntBits(centroidDp)); + } + OrdToDocDISIReaderConfiguration.writeStoredMeta( + DIRECT_MONOTONIC_BLOCK_SHIFT, meta, vectorData, count, maxDoc, docsWithField); + } + + @Override + public void finish() throws IOException { + if (finished) { + throw new IllegalStateException("already finished"); + } + finished = true; + rawVectorDelegate.finish(); + if (meta != null) { + // write end of fields marker + meta.writeInt(-1); + CodecUtil.writeFooter(meta); + } + if (vectorData != null) { + CodecUtil.writeFooter(vectorData); + } + } + + @Override + public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { + if (!fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + rawVectorDelegate.mergeOneField(fieldInfo, mergeState); + return; + } + + final float[] centroid; + final float[] mergedCentroid = new float[fieldInfo.getVectorDimension()]; + int vectorCount = mergeAndRecalculateCentroids(mergeState, fieldInfo, mergedCentroid); + // Don't need access to the random vectors, we can just use the merged + rawVectorDelegate.mergeOneField(fieldInfo, mergeState); + centroid = mergedCentroid; + if (segmentWriteState.infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) { + segmentWriteState.infoStream.message( + QUANTIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount); + } + FloatVectorValues floatVectorValues = + MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + floatVectorValues = new NormalizedFloatVectorValues(floatVectorValues); + } + QuantizedFloatVectorValues quantizedVectorValues = + new QuantizedFloatVectorValues( + floatVectorValues, + new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()), + encoding, + centroid); + long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); + DocsWithFieldSet docsWithField = writeVectorData(vectorData, quantizedVectorValues); + long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; + float centroidDp = + docsWithField.cardinality() > 0 ? 
VectorUtil.dotProduct(centroid, centroid) : 0; + writeMeta( + fieldInfo, + segmentWriteState.segmentInfo.maxDoc(), + vectorDataOffset, + vectorDataLength, + centroid, + centroidDp, + docsWithField); + } + + static DocsWithFieldSet writeVectorData( + IndexOutput output, QuantizedByteVectorValues quantizedByteVectorValues) throws IOException { + DocsWithFieldSet docsWithField = new DocsWithFieldSet(); + KnnVectorValues.DocIndexIterator iterator = quantizedByteVectorValues.iterator(); + for (int docV = iterator.nextDoc(); docV != NO_MORE_DOCS; docV = iterator.nextDoc()) { + // write vector + byte[] binaryValue = quantizedByteVectorValues.vectorValue(iterator.index()); + output.writeBytes(binaryValue, binaryValue.length); + OptimizedScalarQuantizer.QuantizationResult corrections = + quantizedByteVectorValues.getCorrectiveTerms(iterator.index()); + output.writeInt(Float.floatToIntBits(corrections.lowerInterval())); + output.writeInt(Float.floatToIntBits(corrections.upperInterval())); + output.writeInt(Float.floatToIntBits(corrections.additionalCorrection())); + output.writeInt(corrections.quantizedComponentSum()); + docsWithField.add(docV); + } + return docsWithField; + } + + @Override + public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex( + FieldInfo fieldInfo, MergeState mergeState) throws IOException { + if (!fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + return rawVectorDelegate.mergeOneFieldToIndex(fieldInfo, mergeState); + } + + final float[] centroid; + final float cDotC; + final float[] mergedCentroid = new float[fieldInfo.getVectorDimension()]; + int vectorCount = mergeAndRecalculateCentroids(mergeState, fieldInfo, mergedCentroid); + + // Don't need access to the random vectors, we can just use the merged + rawVectorDelegate.mergeOneField(fieldInfo, mergeState); + centroid = mergedCentroid; + cDotC = vectorCount > 0 ? 
VectorUtil.dotProduct(centroid, centroid) : 0; + if (segmentWriteState.infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) { + segmentWriteState.infoStream.message( + QUANTIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount); + } + return mergeOneFieldToIndex(segmentWriteState, fieldInfo, mergeState, centroid, cDotC); + } + + private CloseableRandomVectorScorerSupplier mergeOneFieldToIndex( + SegmentWriteState segmentWriteState, + FieldInfo fieldInfo, + MergeState mergeState, + float[] centroid, + float cDotC) + throws IOException { + long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); + IndexOutput tempQuantizedVectorData = null; + IndexInput quantizedDataInput = null; + OptimizedScalarQuantizer quantizer = + new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); + try { + tempQuantizedVectorData = + segmentWriteState.directory.createTempOutput( + vectorData.getName(), "temp", segmentWriteState.context); + final String tempQuantizedVectorName = tempQuantizedVectorData.getName(); + FloatVectorValues floatVectorValues = + MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + floatVectorValues = new NormalizedFloatVectorValues(floatVectorValues); + } + DocsWithFieldSet docsWithField = + writeVectorData( + tempQuantizedVectorData, + new QuantizedFloatVectorValues(floatVectorValues, quantizer, encoding, centroid)); + CodecUtil.writeFooter(tempQuantizedVectorData); + IOUtils.close(tempQuantizedVectorData); + quantizedDataInput = + segmentWriteState.directory.openInput(tempQuantizedVectorName, segmentWriteState.context); + vectorData.copyBytes( + quantizedDataInput, quantizedDataInput.length() - CodecUtil.footerLength()); + long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; + CodecUtil.retrieveChecksum(quantizedDataInput); + writeMeta( + fieldInfo, + segmentWriteState.segmentInfo.maxDoc(), + vectorDataOffset, + vectorDataLength, + centroid, + cDotC, + docsWithField); + + final IndexInput finalQuantizedDataInput = quantizedDataInput; + tempQuantizedVectorData = null; + quantizedDataInput = null; + + OffHeapScalarQuantizedVectorValues vectorValues = + new OffHeapScalarQuantizedVectorValues.DenseOffHeapVectorValues( + fieldInfo.getVectorDimension(), + docsWithField.cardinality(), + centroid, + cDotC, + quantizer, + encoding, + fieldInfo.getVectorSimilarityFunction(), + vectorsScorer, + finalQuantizedDataInput); + RandomVectorScorerSupplier scorerSupplier = + vectorsScorer.getRandomVectorScorerSupplier( + fieldInfo.getVectorSimilarityFunction(), vectorValues); + return new QuantizedCloseableRandomVectorScorerSupplier( + scorerSupplier, + vectorValues, + () -> { + IOUtils.close(finalQuantizedDataInput); + IOUtils.deleteFilesIgnoringExceptions( + segmentWriteState.directory, tempQuantizedVectorName); + }); + } catch (Throwable t) { + IOUtils.closeWhileSuppressingExceptions(t, tempQuantizedVectorData, quantizedDataInput); + if (tempQuantizedVectorData != null) { + IOUtils.deleteFilesSuppressingExceptions( + t, segmentWriteState.directory, tempQuantizedVectorData.getName()); + } + throw t; + } + } + + @Override + public void close() throws IOException { + IOUtils.close(meta, vectorData, rawVectorDelegate); + } + + static float[] getCentroid(KnnVectorsReader vectorsReader, String fieldName) { + if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) { + vectorsReader = candidateReader.getFieldReader(fieldName); + } + if (vectorsReader instanceof 
Lucene104ScalarQuantizedVectorsReader reader) {
+ return reader.getCentroid(fieldName);
+ }
+ return null;
+ }
+
+ static int mergeAndRecalculateCentroids(
+ MergeState mergeState, FieldInfo fieldInfo, float[] mergedCentroid) throws IOException {
+ boolean recalculate = false;
+ int totalVectorCount = 0;
+ for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
+ KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i];
+ if (knnVectorsReader == null
+ || knnVectorsReader.getFloatVectorValues(fieldInfo.name) == null) {
+ continue;
+ }
+ float[] centroid = getCentroid(knnVectorsReader, fieldInfo.name);
+ int vectorCount = knnVectorsReader.getFloatVectorValues(fieldInfo.name).size();
+ if (vectorCount == 0) {
+ continue;
+ }
+ totalVectorCount += vectorCount;
+ // If any segment is missing a centroid, or contains deleted docs, we must
+ // recalculate the centroid over the live documents
+ if (centroid == null || mergeState.liveDocs[i] != null) {
+ recalculate = true;
+ break;
+ }
+ for (int j = 0; j < centroid.length; j++) {
+ mergedCentroid[j] += centroid[j] * vectorCount;
+ }
+ }
+ if (recalculate) {
+ return calculateCentroid(mergeState, fieldInfo, mergedCentroid);
+ } else {
+ for (int j = 0; j < mergedCentroid.length; j++) {
+ mergedCentroid[j] = mergedCentroid[j] / totalVectorCount;
+ }
+ if (fieldInfo.getVectorSimilarityFunction() == COSINE) {
+ VectorUtil.l2normalize(mergedCentroid);
+ }
+ return totalVectorCount;
+ }
+ }
+
+ static int calculateCentroid(MergeState mergeState, FieldInfo fieldInfo, float[] centroid)
+ throws IOException {
+ assert fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32);
+ // clear out the centroid
+ Arrays.fill(centroid, 0);
+ int count = 0;
+ for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
+ KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i];
+ if (knnVectorsReader == null) continue;
+ FloatVectorValues vectorValues =
+ mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldInfo.name);
+ if (vectorValues == null) {
+ continue;
+ }
+ KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
+ for (int doc = iterator.nextDoc();
+ doc != DocIdSetIterator.NO_MORE_DOCS;
+ doc = iterator.nextDoc()) {
+ ++count;
+ float[] vector = vectorValues.vectorValue(iterator.index());
+ for (int j = 0; j < vector.length; j++) {
+ centroid[j] += vector[j];
+ }
+ }
+ }
+ if (count == 0) {
+ return count;
+ }
+ for (int i = 0; i < centroid.length; i++) {
+ centroid[i] /= count;
+ }
+ if (fieldInfo.getVectorSimilarityFunction() == COSINE) {
+ VectorUtil.l2normalize(centroid);
+ }
+ return count;
+ }
+
+ @Override
+ public long ramBytesUsed() {
+ long total = SHALLOW_RAM_BYTES_USED;
+ for (FieldWriter field : fields) {
+ // the field tracks the delegate field usage
+ total += field.ramBytesUsed();
+ }
+ return total;
+ }
+
+ static class FieldWriter extends FlatFieldVectorsWriter<float[]> {
+ private static final long SHALLOW_SIZE = shallowSizeOfInstance(FieldWriter.class);
+ private final FieldInfo fieldInfo;
+ private boolean finished;
+ private final FlatFieldVectorsWriter<float[]> flatFieldVectorsWriter;
+ private final float[] dimensionSums;
+ private final FloatArrayList magnitudes = new FloatArrayList();
+
+ FieldWriter(FieldInfo fieldInfo, FlatFieldVectorsWriter<float[]> flatFieldVectorsWriter) {
+ this.fieldInfo = fieldInfo;
+ this.flatFieldVectorsWriter = flatFieldVectorsWriter;
+ this.dimensionSums = new float[fieldInfo.getVectorDimension()];
+ }
+
+ @Override
+ public List<float[]> getVectors() {
+ return
flatFieldVectorsWriter.getVectors(); + } + + public void normalizeVectors() { + for (int i = 0; i < flatFieldVectorsWriter.getVectors().size(); i++) { + float[] vector = flatFieldVectorsWriter.getVectors().get(i); + float magnitude = magnitudes.get(i); + for (int j = 0; j < vector.length; j++) { + vector[j] /= magnitude; + } + } + } + + @Override + public DocsWithFieldSet getDocsWithFieldSet() { + return flatFieldVectorsWriter.getDocsWithFieldSet(); + } + + @Override + public void finish() throws IOException { + if (finished) { + return; + } + assert flatFieldVectorsWriter.isFinished(); + finished = true; + } + + @Override + public boolean isFinished() { + return finished && flatFieldVectorsWriter.isFinished(); + } + + @Override + public void addValue(int docID, float[] vectorValue) throws IOException { + flatFieldVectorsWriter.addValue(docID, vectorValue); + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + float dp = VectorUtil.dotProduct(vectorValue, vectorValue); + float divisor = (float) Math.sqrt(dp); + magnitudes.add(divisor); + for (int i = 0; i < vectorValue.length; i++) { + dimensionSums[i] += (vectorValue[i] / divisor); + } + } else { + for (int i = 0; i < vectorValue.length; i++) { + dimensionSums[i] += vectorValue[i]; + } + } + } + + @Override + public float[] copyValue(float[] vectorValue) { + throw new UnsupportedOperationException(); + } + + @Override + public long ramBytesUsed() { + long size = SHALLOW_SIZE; + size += flatFieldVectorsWriter.ramBytesUsed(); + size += magnitudes.ramBytesUsed(); + return size; + } + } + + static class QuantizedFloatVectorValues extends QuantizedByteVectorValues { + private OptimizedScalarQuantizer.QuantizationResult corrections; + private final byte[] quantized; + private final byte[] packed; + private final float[] centroid; + private final float centroidDP; + private final FloatVectorValues values; + private final OptimizedScalarQuantizer quantizer; + private final ScalarEncoding encoding; + + private int lastOrd = -1; + + QuantizedFloatVectorValues( + FloatVectorValues delegate, + OptimizedScalarQuantizer quantizer, + ScalarEncoding encoding, + float[] centroid) { + this.values = delegate; + this.quantizer = quantizer; + this.encoding = encoding; + this.quantized = + new byte + [OptimizedScalarQuantizer.discretize( + delegate.dimension(), encoding.getDimensionsPerByte())]; + this.packed = + switch (encoding) { + case UNSIGNED_BYTE -> this.quantized; + case SEVEN_BIT -> this.quantized; + case PACKED_NIBBLE -> new byte[encoding.getPackedLength(delegate.dimension())]; + }; + this.centroid = centroid; + this.centroidDP = VectorUtil.dotProduct(centroid, centroid); + } + + @Override + public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int ord) { + if (ord != lastOrd) { + throw new IllegalStateException( + "attempt to retrieve corrective terms for different ord " + + ord + + " than the quantization was done for: " + + lastOrd); + } + return corrections; + } + + @Override + public byte[] vectorValue(int ord) throws IOException { + if (ord != lastOrd) { + quantize(ord); + lastOrd = ord; + } + return packed; + } + + @Override + public int dimension() { + return values.dimension(); + } + + @Override + public OptimizedScalarQuantizer getQuantizer() { + throw new UnsupportedOperationException(); + } + + @Override + public ScalarEncoding getScalarEncoding() { + throw new UnsupportedOperationException(); + } + + @Override + public float[] getCentroid() throws IOException { + return centroid; + } + + @Override + public float 
getCentroidDP() { + return centroidDP; + } + + @Override + public int size() { + return values.size(); + } + + @Override + public VectorScorer scorer(float[] target) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public QuantizedByteVectorValues copy() throws IOException { + return new QuantizedFloatVectorValues(values.copy(), quantizer, encoding, centroid); + } + + private void quantize(int ord) throws IOException { + corrections = + quantizer.scalarQuantize( + values.vectorValue(ord), quantized, encoding.getBits(), centroid); + if (encoding == ScalarEncoding.PACKED_NIBBLE) { + OffHeapScalarQuantizedVectorValues.packNibbles(quantized, packed); + } + } + + @Override + public DocIndexIterator iterator() { + return values.iterator(); + } + + @Override + public int ordToDoc(int ord) { + return values.ordToDoc(ord); + } + } + + static class QuantizedCloseableRandomVectorScorerSupplier + implements CloseableRandomVectorScorerSupplier { + private final RandomVectorScorerSupplier supplier; + private final KnnVectorValues vectorValues; + private final Closeable onClose; + + QuantizedCloseableRandomVectorScorerSupplier( + RandomVectorScorerSupplier supplier, KnnVectorValues vectorValues, Closeable onClose) { + this.supplier = supplier; + this.onClose = onClose; + this.vectorValues = vectorValues; + } + + @Override + public UpdateableRandomVectorScorer scorer() throws IOException { + return supplier.scorer(); + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return supplier.copy(); + } + + @Override + public void close() throws IOException { + onClose.close(); + } + + @Override + public int totalVectorCount() { + return vectorValues.size(); + } + } + + static final class NormalizedFloatVectorValues extends FloatVectorValues { + private final FloatVectorValues values; + private final float[] normalizedVector; + + NormalizedFloatVectorValues(FloatVectorValues values) { + this.values = values; + this.normalizedVector = new float[values.dimension()]; + } + + @Override + public int dimension() { + return values.dimension(); + } + + @Override + public int size() { + return values.size(); + } + + @Override + public int ordToDoc(int ord) { + return values.ordToDoc(ord); + } + + @Override + public float[] vectorValue(int ord) throws IOException { + System.arraycopy(values.vectorValue(ord), 0, normalizedVector, 0, normalizedVector.length); + VectorUtil.l2normalize(normalizedVector); + return normalizedVector; + } + + @Override + public DocIndexIterator iterator() { + return values.iterator(); + } + + @Override + public NormalizedFloatVectorValues copy() throws IOException { + return new NormalizedFloatVectorValues(values.copy()); + } + } +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/OffHeapScalarQuantizedVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/OffHeapScalarQuantizedVectorValues.java new file mode 100644 index 000000000000..d2c678d8f8ba --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/OffHeapScalarQuantizedVectorValues.java @@ -0,0 +1,414 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.codecs.lucene104; + +import java.io.IOException; +import java.nio.ByteBuffer; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding; +import org.apache.lucene.codecs.lucene90.IndexedDISI; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.packed.DirectMonotonicReader; +import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; + +/** + * Scalar quantized vector values loaded from off-heap + * + * @lucene.internal + */ +public abstract class OffHeapScalarQuantizedVectorValues extends QuantizedByteVectorValues { + final int dimension; + final int size; + final VectorSimilarityFunction similarityFunction; + final FlatVectorsScorer vectorsScorer; + + final IndexInput slice; + final byte[] vectorValue; + final ByteBuffer byteBuffer; + final int byteSize; + private int lastOrd = -1; + final float[] correctiveValues; + int quantizedComponentSum; + final OptimizedScalarQuantizer quantizer; + final ScalarEncoding encoding; + final float[] centroid; + final float centroidDp; + + OffHeapScalarQuantizedVectorValues( + int dimension, + int size, + float[] centroid, + float centroidDp, + OptimizedScalarQuantizer quantizer, + ScalarEncoding encoding, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, + IndexInput slice) { + this.dimension = dimension; + this.size = size; + this.similarityFunction = similarityFunction; + this.vectorsScorer = vectorsScorer; + this.slice = slice; + this.centroid = centroid; + this.centroidDp = centroidDp; + this.correctiveValues = new float[3]; + this.byteSize = encoding.getPackedLength(dimension) + (Float.BYTES * 3) + Integer.BYTES; + this.byteBuffer = ByteBuffer.allocate(encoding.getPackedLength(dimension)); + this.vectorValue = byteBuffer.array(); + this.quantizer = quantizer; + this.encoding = encoding; + } + + @Override + public int dimension() { + return dimension; + } + + @Override + public int size() { + return size; + } + + @Override + public byte[] vectorValue(int targetOrd) throws IOException { + if (lastOrd == targetOrd) { + return vectorValue; + } + slice.seek((long) targetOrd * byteSize); + slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), vectorValue.length); + slice.readFloats(correctiveValues, 0, 3); + quantizedComponentSum = slice.readInt(); + lastOrd = targetOrd; + return vectorValue; + } + + @Override + public float getCentroidDP() { + return centroidDp; + } + + @Override + public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int targetOrd) + throws IOException { + if (lastOrd == targetOrd) { + return new OptimizedScalarQuantizer.QuantizationResult( + correctiveValues[0], correctiveValues[1], 
correctiveValues[2], quantizedComponentSum);
+ }
+ // seek past the packed vector bytes (not the dimension count) to the corrective terms
+ slice.seek(((long) targetOrd * byteSize) + vectorValue.length);
+ slice.readFloats(correctiveValues, 0, 3);
+ quantizedComponentSum = slice.readInt();
+ return new OptimizedScalarQuantizer.QuantizationResult(
+ correctiveValues[0], correctiveValues[1], correctiveValues[2], quantizedComponentSum);
+ }
+
+ @Override
+ public OptimizedScalarQuantizer getQuantizer() {
+ return quantizer;
+ }
+
+ @Override
+ public ScalarEncoding getScalarEncoding() {
+ return encoding;
+ }
+
+ @Override
+ public float[] getCentroid() {
+ return centroid;
+ }
+
+ @Override
+ public int getVectorByteLength() {
+ return encoding.getPackedLength(dimension);
+ }
+
+ static void packNibbles(byte[] unpacked, byte[] packed) {
+ // the first half of the vector fills the high nibbles, the second half the low nibbles
+ assert unpacked.length == packed.length * 2;
+ for (int i = 0; i < packed.length; i++) {
+ int x = unpacked[i] << 4 | unpacked[packed.length + i];
+ packed[i] = (byte) x;
+ }
+ }
+
+ static void unpackNibbles(byte[] packed, byte[] unpacked) {
+ assert unpacked.length == packed.length * 2;
+ for (int i = 0; i < packed.length; i++) {
+ unpacked[i] = (byte) ((packed[i] >> 4) & 0x0F);
+ unpacked[packed.length + i] = (byte) (packed[i] & 0x0F);
+ }
+ }
+
+ static OffHeapScalarQuantizedVectorValues load(
+ OrdToDocDISIReaderConfiguration configuration,
+ int dimension,
+ int size,
+ OptimizedScalarQuantizer quantizer,
+ ScalarEncoding encoding,
+ VectorSimilarityFunction similarityFunction,
+ FlatVectorsScorer vectorsScorer,
+ float[] centroid,
+ float centroidDp,
+ long quantizedVectorDataOffset,
+ long quantizedVectorDataLength,
+ IndexInput vectorData)
+ throws IOException {
+ if (configuration.isEmpty()) {
+ return new EmptyOffHeapVectorValues(dimension, similarityFunction, vectorsScorer);
+ }
+ assert centroid != null;
+ IndexInput bytesSlice =
+ vectorData.slice(
+ "quantized-vector-data", quantizedVectorDataOffset, quantizedVectorDataLength);
+ if (configuration.isDense()) {
+ return new DenseOffHeapVectorValues(
+ dimension,
+ size,
+ centroid,
+ centroidDp,
+ quantizer,
+ encoding,
+ similarityFunction,
+ vectorsScorer,
+ bytesSlice);
+ } else {
+ return new SparseOffHeapVectorValues(
+ configuration,
+ dimension,
+ size,
+ centroid,
+ centroidDp,
+ quantizer,
+ encoding,
+ vectorData,
+ similarityFunction,
+ vectorsScorer,
+ bytesSlice);
+ }
+ }
+
+ /** Dense off-heap scalar quantized vector values */
+ static class DenseOffHeapVectorValues extends OffHeapScalarQuantizedVectorValues {
+ DenseOffHeapVectorValues(
+ int dimension,
+ int size,
+ float[] centroid,
+ float centroidDp,
+ OptimizedScalarQuantizer quantizer,
+ ScalarEncoding encoding,
+ VectorSimilarityFunction similarityFunction,
+ FlatVectorsScorer vectorsScorer,
+ IndexInput slice) {
+ super(
+ dimension,
+ size,
+ centroid,
+ centroidDp,
+ quantizer,
+ encoding,
+ similarityFunction,
+ vectorsScorer,
+ slice);
+ }
+
+ @Override
+ public OffHeapScalarQuantizedVectorValues.DenseOffHeapVectorValues copy() throws IOException {
+ return new OffHeapScalarQuantizedVectorValues.DenseOffHeapVectorValues(
+ dimension,
+ size,
+ centroid,
+ centroidDp,
+ quantizer,
+ encoding,
+ similarityFunction,
+ vectorsScorer,
+ slice.clone());
+ }
+
+ @Override
+ public Bits getAcceptOrds(Bits acceptDocs) {
+ return acceptDocs;
+ }
+
+ @Override
+ public VectorScorer scorer(float[] target) throws IOException {
+ OffHeapScalarQuantizedVectorValues.DenseOffHeapVectorValues copy = copy();
+ DocIndexIterator iterator = copy.iterator();
+ RandomVectorScorer scorer =
+ vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target);
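+ // The anonymous scorer below scores whichever ordinal the iterator is
+ // currently positioned on; scoring against a copy() gives it an IndexInput
+ // clone and read buffers independent of this shared instance.
+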
return new VectorScorer() {
+ @Override
+ public float score() throws IOException {
+ return scorer.score(iterator.index());
+ }
+
+ @Override
+ public DocIdSetIterator iterator() {
+ return iterator;
+ }
+ };
+ }
+
+ @Override
+ public DocIndexIterator iterator() {
+ return createDenseIterator();
+ }
+ }
+
+ /** Sparse off-heap scalar quantized vector values */
+ private static class SparseOffHeapVectorValues extends OffHeapScalarQuantizedVectorValues {
+ private final DirectMonotonicReader ordToDoc;
+ private final IndexedDISI disi;
+ // dataIn is used to init a new IndexedDISI for #randomAccess()
+ private final IndexInput dataIn;
+ private final OrdToDocDISIReaderConfiguration configuration;
+
+ SparseOffHeapVectorValues(
+ OrdToDocDISIReaderConfiguration configuration,
+ int dimension,
+ int size,
+ float[] centroid,
+ float centroidDp,
+ OptimizedScalarQuantizer quantizer,
+ ScalarEncoding encoding,
+ IndexInput dataIn,
+ VectorSimilarityFunction similarityFunction,
+ FlatVectorsScorer vectorsScorer,
+ IndexInput slice)
+ throws IOException {
+ super(
+ dimension,
+ size,
+ centroid,
+ centroidDp,
+ quantizer,
+ encoding,
+ similarityFunction,
+ vectorsScorer,
+ slice);
+ this.configuration = configuration;
+ this.dataIn = dataIn;
+ this.ordToDoc = configuration.getDirectMonotonicReader(dataIn);
+ this.disi = configuration.getIndexedDISI(dataIn);
+ }
+
+ @Override
+ public SparseOffHeapVectorValues copy() throws IOException {
+ return new SparseOffHeapVectorValues(
+ configuration,
+ dimension,
+ size,
+ centroid,
+ centroidDp,
+ quantizer,
+ encoding,
+ dataIn,
+ similarityFunction,
+ vectorsScorer,
+ slice.clone());
+ }
+
+ @Override
+ public int ordToDoc(int ord) {
+ return (int) ordToDoc.get(ord);
+ }
+
+ @Override
+ public Bits getAcceptOrds(Bits acceptDocs) {
+ if (acceptDocs == null) {
+ return null;
+ }
+ return new Bits() {
+ @Override
+ public boolean get(int index) {
+ return acceptDocs.get(ordToDoc(index));
+ }
+
+ @Override
+ public int length() {
+ return size;
+ }
+ };
+ }
+
+ @Override
+ public DocIndexIterator iterator() {
+ return IndexedDISI.asDocIndexIterator(disi);
+ }
+
+ @Override
+ public VectorScorer scorer(float[] target) throws IOException {
+ SparseOffHeapVectorValues copy = copy();
+ DocIndexIterator iterator = copy.iterator();
+ RandomVectorScorer scorer =
+ vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target);
+ return new VectorScorer() {
+ @Override
+ public float score() throws IOException {
+ return scorer.score(iterator.index());
+ }
+
+ @Override
+ public DocIdSetIterator iterator() {
+ return iterator;
+ }
+ };
+ }
+ }
+
+ private static class EmptyOffHeapVectorValues extends OffHeapScalarQuantizedVectorValues {
+ EmptyOffHeapVectorValues(
+ int dimension,
+ VectorSimilarityFunction similarityFunction,
+ FlatVectorsScorer vectorsScorer) {
+ super(
+ dimension,
+ 0,
+ null,
+ Float.NaN,
+ null,
+ ScalarEncoding.UNSIGNED_BYTE,
+ similarityFunction,
+ vectorsScorer,
+ null);
+ }
+
+ @Override
+ public DocIndexIterator iterator() {
+ return createDenseIterator();
+ }
+
+ @Override
+ public OffHeapScalarQuantizedVectorValues.DenseOffHeapVectorValues copy() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public Bits getAcceptOrds(Bits acceptDocs) {
+ return null;
+ }
+
+ @Override
+ public VectorScorer scorer(float[] target) {
+ return null;
+ }
+ }
+}
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/QuantizedByteVectorValues.java
b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/QuantizedByteVectorValues.java new file mode 100644 index 000000000000..48d0c4e665f1 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/QuantizedByteVectorValues.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.codecs.lucene104; + +import java.io.IOException; +import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; + +/** Scalar quantized byte vector values */ +abstract class QuantizedByteVectorValues extends ByteVectorValues { + + /** + * Retrieve the corrective terms for the given vector ordinal. For the dot-product family of + * distances, the corrective terms are, in order + * + *
<ul>
+ *   <li>the lower optimized interval
+ *   <li>the upper optimized interval
+ *   <li>the dot product of the non-centered vector with the centroid
+ *   <li>the sum of the quantized components
+ * </ul>
+ *
+ * For euclidean:
+ *
+ * <ul>
+ *   <li>the lower optimized interval
+ *   <li>the upper optimized interval
+ *   <li>the l2 norm of the centered vector
+ *   <li>the sum of the quantized components
+ * </ul>
+ *
+ * @param vectorOrd the vector ordinal
+ * @return the corrective terms
+ * @throws IOException if an I/O error occurs
+ */
+ public abstract OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int vectorOrd)
+ throws IOException;
+
+ /**
+ * @return the quantizer used to quantize the vectors.
+ */
+ public abstract OptimizedScalarQuantizer getQuantizer();
+
+ /**
+ * @return the scalar encoding used to pack the vectors.
+ */
+ public abstract ScalarEncoding getScalarEncoding();
+
+ /**
+ * @return the centroid used to center the vectors prior to quantization.
+ */
+ public abstract float[] getCentroid() throws IOException;
+
+ /**
+ * @return the dot product of the centroid with itself.
+ */
+ public abstract float getCentroidDP() throws IOException;
+
+ /**
+ * Return a {@link VectorScorer} for the given query vector.
+ *
+ * @param query the query vector
+ * @return a {@link VectorScorer} instance or null
+ */
+ public abstract VectorScorer scorer(float[] query) throws IOException;
+
+ @Override
+ public abstract QuantizedByteVectorValues copy() throws IOException;
+}
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswScalarQuantizedVectorsFormat.java
index 1966ed21d654..42e9381aa11d 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswScalarQuantizedVectorsFormat.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswScalarQuantizedVectorsFormat.java
@@ -42,6 +42,7 @@
 * * @lucene.experimental */
+@Deprecated
public class Lucene99HnswScalarQuantizedVectorsFormat extends KnnVectorsFormat {
public static final String NAME = "Lucene99HnswScalarQuantizedVectorsFormat";
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java
index 0f339ecbe0a8..da18fb8d2e57 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java
@@ -31,6 +31,7 @@
 * * @lucene.experimental */
+@Deprecated
public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
// The bits that are allowed for scalar quantization
diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
index 38b9cf6d67a5..c27a7b6f5fdb 100644
--- a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
+++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
@@ -214,7 +214,7 @@ public static int int4DotProduct(byte[] a, byte[] b) {
 public static int int4DotProductPacked(byte[] unpacked, byte[] packed) {
 if (packed.length != ((unpacked.length + 1) >> 1)) {
 throw new IllegalArgumentException(
- "vector dimensions differ: " + unpacked.length + "!= 2 * " + packed.length);
+ "vector dimensions differ: " + unpacked.length + " != 2 * " + packed.length);
 }
 return IMPL.int4DotProduct(unpacked, false, packed, true);
 }
diff --git a/lucene/core/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat b/lucene/core/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat
index 0558fc8fef05..fde541c2ac08 100644
--- a/lucene/core/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat
+++
b/lucene/core/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat @@ -18,3 +18,5 @@ org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat org.apache.lucene.codecs.lucene102.Lucene102HnswBinaryQuantizedVectorsFormat org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat +org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat +org.apache.lucene.codecs.lucene104.Lucene104HnswScalarQuantizedVectorsFormat diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104HnswScalarQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104HnswScalarQuantizedVectorsFormat.java new file mode 100644 index 000000000000..7f01f1bbe852 --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104HnswScalarQuantizedVectorsFormat.java @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.codecs.lucene104; + +import static java.lang.String.format; +import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.oneOf; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Locale; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.FilterCodec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.CodecReader; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.AcceptDocs; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; +import org.apache.lucene.tests.util.TestUtil; +import org.apache.lucene.util.SameThreadExecutorService; + +public class TestLucene104HnswScalarQuantizedVectorsFormat extends BaseKnnVectorsFormatTestCase { + + private static final KnnVectorsFormat FORMAT = new Lucene104HnswScalarQuantizedVectorsFormat(); + + @Override + protected Codec getCodec() { + return TestUtil.alwaysKnnVectorsFormat(FORMAT); + } + + public void testToString() { + FilterCodec customCodec = + new FilterCodec("foo", Codec.getDefault()) { + @Override + public KnnVectorsFormat knnVectorsFormat() { + return new Lucene104HnswScalarQuantizedVectorsFormat( + ScalarEncoding.UNSIGNED_BYTE, 10, 20, 1, null); + } + }; + String expectedPattern = + "Lucene104HnswScalarQuantizedVectorsFormat(name=Lucene104HnswScalarQuantizedVectorsFormat, maxConn=10, beamWidth=20," + + " flatVectorFormat=Lucene104ScalarQuantizedVectorsFormat(name=Lucene104ScalarQuantizedVectorsFormat," + + " encoding=UNSIGNED_BYTE," + + " flatVectorScorer=Lucene104ScalarQuantizedVectorScorer(nonQuantizedDelegate=%s())," + + " rawVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=%s())))"; + + var defaultScorer = + format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer", "DefaultFlatVectorScorer"); + var memSegScorer = + format( + Locale.ROOT, + expectedPattern, + "Lucene99MemorySegmentFlatVectorsScorer", + "Lucene99MemorySegmentFlatVectorsScorer"); + assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer))); + } + + public void testSingleVectorCase() throws Exception { + float[] vector = randomVector(random().nextInt(12, 500)); + for (VectorSimilarityFunction similarityFunction : VectorSimilarityFunction.values()) { + try (Directory dir = newDirectory(); + IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { + Document doc = new Document(); + doc.add(new KnnFloatVectorField("f", vector, similarityFunction)); + w.addDocument(doc); + w.commit(); + try (IndexReader reader = DirectoryReader.open(w)) { + LeafReader r = 
getOnlyLeafReader(reader); + FloatVectorValues vectorValues = r.getFloatVectorValues("f"); + KnnVectorValues.DocIndexIterator docIndexIterator = vectorValues.iterator(); + assert (vectorValues.size() == 1); + while (docIndexIterator.nextDoc() != NO_MORE_DOCS) { + assertArrayEquals(vector, vectorValues.vectorValue(docIndexIterator.index()), 0.00001f); + } + float[] randomVector = randomVector(vector.length); + float trueScore = similarityFunction.compare(vector, randomVector); + TopDocs td = + r.searchNearestVectors( + "f", + randomVector, + 1, + AcceptDocs.fromLiveDocs(null, r.maxDoc()), + Integer.MAX_VALUE); + assertEquals(1, td.totalHits.value()); + assertTrue(td.scoreDocs[0].score >= 0); + // When it's the only vector in a segment, the score should be very close to the true + // score + assertEquals(trueScore, td.scoreDocs[0].score, 0.01f); + } + } + } + } + + public void testLimits() { + expectThrows( + IllegalArgumentException.class, + () -> new Lucene104HnswScalarQuantizedVectorsFormat(-1, 20)); + expectThrows( + IllegalArgumentException.class, () -> new Lucene104HnswScalarQuantizedVectorsFormat(0, 20)); + expectThrows( + IllegalArgumentException.class, () -> new Lucene104HnswScalarQuantizedVectorsFormat(20, 0)); + expectThrows( + IllegalArgumentException.class, + () -> new Lucene104HnswScalarQuantizedVectorsFormat(20, -1)); + expectThrows( + IllegalArgumentException.class, + () -> new Lucene104HnswScalarQuantizedVectorsFormat(512 + 1, 20)); + expectThrows( + IllegalArgumentException.class, + () -> new Lucene104HnswScalarQuantizedVectorsFormat(20, 3201)); + expectThrows( + IllegalArgumentException.class, + () -> + new Lucene104HnswScalarQuantizedVectorsFormat( + ScalarEncoding.UNSIGNED_BYTE, 20, 100, 1, new SameThreadExecutorService())); + } + + // Ensures that all expected vector similarity functions are translatable in the format. + public void testVectorSimilarityFuncs() { + // This does not necessarily have to be all similarity functions, but + // differences should be considered carefully. 
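+ // (The HNSW variant of this format reuses Lucene99HnswVectorsReader, which
+ // decodes the similarity function from its index in SIMILARITY_FUNCTIONS,
+ // so the two lists must stay aligned.)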
+ var expectedValues = Arrays.stream(VectorSimilarityFunction.values()).toList(); + assertEquals(Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS, expectedValues); + } + + public void testSimpleOffHeapSize() throws IOException { + float[] vector = randomVector(random().nextInt(12, 500)); + try (Directory dir = newDirectory(); + IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { + Document doc = new Document(); + doc.add(new KnnFloatVectorField("f", vector, DOT_PRODUCT)); + w.addDocument(doc); + w.commit(); + try (IndexReader reader = DirectoryReader.open(w)) { + LeafReader r = getOnlyLeafReader(reader); + if (r instanceof CodecReader codecReader) { + KnnVectorsReader knnVectorsReader = codecReader.getVectorReader(); + if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader fieldsReader) { + knnVectorsReader = fieldsReader.getFieldReader("f"); + } + var fieldInfo = r.getFieldInfos().fieldInfo("f"); + var offHeap = knnVectorsReader.getOffHeapByteSize(fieldInfo); + assertEquals(vector.length * Float.BYTES, (long) offHeap.get("vec")); + assertEquals(1L, (long) offHeap.get("vex")); + long corrections = Float.BYTES + Float.BYTES + Float.BYTES + Integer.BYTES; + long expected = fieldInfo.getVectorDimension() + corrections; + assertEquals(expected, (long) offHeap.get("veq")); + assertEquals(3, offHeap.size()); + } + } + } + } +} diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java new file mode 100644 index 000000000000..29041b5b07f0 --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.codecs.lucene104; + +import static java.lang.String.format; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.oneOf; + +import java.io.IOException; +import java.util.Locale; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.FilterCodec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.KnnFloatVectorQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; +import org.apache.lucene.tests.util.TestUtil; +import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; +import org.junit.Before; + +public class TestLucene104ScalarQuantizedVectorsFormat extends BaseKnnVectorsFormatTestCase { + + private ScalarEncoding encoding; + private KnnVectorsFormat format; + + @Before + @Override + public void setUp() throws Exception { + var encodingValues = ScalarEncoding.values(); + encoding = encodingValues[random().nextInt(encodingValues.length)]; + format = new Lucene104ScalarQuantizedVectorsFormat(encoding); + super.setUp(); + } + + @Override + protected Codec getCodec() { + return TestUtil.alwaysKnnVectorsFormat(format); + } + + public void testSearch() throws Exception { + String fieldName = "field"; + int numVectors = random().nextInt(99, 500); + int dims = random().nextInt(4, 65); + float[] vector = randomVector(dims); + VectorSimilarityFunction similarityFunction = randomSimilarity(); + KnnFloatVectorField knnField = new KnnFloatVectorField(fieldName, vector, similarityFunction); + IndexWriterConfig iwc = newIndexWriterConfig(); + try (Directory dir = newDirectory()) { + try (IndexWriter w = new IndexWriter(dir, iwc)) { + for (int i = 0; i < numVectors; i++) { + Document doc = new Document(); + knnField.setVectorValue(randomVector(dims)); + doc.add(knnField); + w.addDocument(doc); + } + w.commit(); + + try (IndexReader reader = DirectoryReader.open(w)) { + IndexSearcher searcher = new IndexSearcher(reader); + final int k = random().nextInt(5, 50); + float[] queryVector = randomVector(dims); + Query q = new KnnFloatVectorQuery(fieldName, queryVector, k); + TopDocs collectedDocs = searcher.search(q, k); + assertEquals(k, collectedDocs.totalHits.value()); + assertEquals(TotalHits.Relation.EQUAL_TO, collectedDocs.totalHits.relation()); + } + } + } + } + + public void testToString() { + FilterCodec customCodec = + new FilterCodec("foo", Codec.getDefault()) { + @Override + public KnnVectorsFormat knnVectorsFormat() { + return new Lucene104ScalarQuantizedVectorsFormat(); + } + }; + String expectedPattern = + "Lucene104ScalarQuantizedVectorsFormat(" + + "name=Lucene104ScalarQuantizedVectorsFormat, " + + 
"encoding=UNSIGNED_BYTE, " + + "flatVectorScorer=Lucene104ScalarQuantizedVectorScorer(nonQuantizedDelegate=%s()), " + + "rawVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=%s()))"; + var defaultScorer = + format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer", "DefaultFlatVectorScorer"); + var memSegScorer = + format( + Locale.ROOT, + expectedPattern, + "Lucene99MemorySegmentFlatVectorsScorer", + "Lucene99MemorySegmentFlatVectorsScorer"); + assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer))); + } + + @Override + public void testRandomWithUpdatesAndGraph() { + // graph not supported + } + + @Override + public void testSearchWithVisitedLimit() { + // visited limit is not respected, as it is brute force search + } + + public void testQuantizedVectorsWriteAndRead() throws IOException { + String fieldName = "field"; + int numVectors = random().nextInt(99, 500); + int dims = random().nextInt(4, 65); + + float[] vector = randomVector(dims); + VectorSimilarityFunction similarityFunction = randomSimilarity(); + KnnFloatVectorField knnField = new KnnFloatVectorField(fieldName, vector, similarityFunction); + try (Directory dir = newDirectory()) { + try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { + for (int i = 0; i < numVectors; i++) { + Document doc = new Document(); + knnField.setVectorValue(randomVector(dims)); + doc.add(knnField); + w.addDocument(doc); + if (i % 101 == 0) { + w.commit(); + } + } + w.commit(); + w.forceMerge(1); + + try (IndexReader reader = DirectoryReader.open(w)) { + LeafReader r = getOnlyLeafReader(reader); + FloatVectorValues vectorValues = r.getFloatVectorValues(fieldName); + assertEquals(vectorValues.size(), numVectors); + QuantizedByteVectorValues qvectorValues = + ((Lucene104ScalarQuantizedVectorsReader.ScalarQuantizedVectorValues) vectorValues) + .getQuantizedVectorValues(); + float[] centroid = qvectorValues.getCentroid(); + assertEquals(centroid.length, dims); + + OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(similarityFunction); + byte[] scratch = + new byte[OptimizedScalarQuantizer.discretize(dims, encoding.getDimensionsPerByte())]; + byte[] expectedVector = new byte[encoding.getPackedLength(dims)]; + if (similarityFunction == VectorSimilarityFunction.COSINE) { + vectorValues = + new Lucene104ScalarQuantizedVectorsWriter.NormalizedFloatVectorValues(vectorValues); + } + KnnVectorValues.DocIndexIterator docIndexIterator = vectorValues.iterator(); + + while (docIndexIterator.nextDoc() != NO_MORE_DOCS) { + OptimizedScalarQuantizer.QuantizationResult corrections = + quantizer.scalarQuantize( + vectorValues.vectorValue(docIndexIterator.index()), + scratch, + encoding.getBits(), + centroid); + switch (encoding) { + case UNSIGNED_BYTE -> System.arraycopy(scratch, 0, expectedVector, 0, dims); + case SEVEN_BIT -> System.arraycopy(scratch, 0, expectedVector, 0, dims); + case PACKED_NIBBLE -> + OffHeapScalarQuantizedVectorValues.packNibbles(scratch, expectedVector); + } + assertArrayEquals(expectedVector, qvectorValues.vectorValue(docIndexIterator.index())); + var actualCorrections = qvectorValues.getCorrectiveTerms(docIndexIterator.index()); + assertEquals(corrections.lowerInterval(), actualCorrections.lowerInterval(), 0.00001f); + assertEquals(corrections.upperInterval(), actualCorrections.upperInterval(), 0.00001f); + assertEquals( + corrections.additionalCorrection(), + actualCorrections.additionalCorrection(), + 0.00001f); + assertEquals( + corrections.quantizedComponentSum(), 
actualCorrections.quantizedComponentSum()); + } + } + } + } + } +}
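
Reviewer note: below is a minimal, illustrative sketch of how an application could opt into the new format; it is not part of the patch. The class name, codec name, field name, and vector values are invented for illustration; only the constructors and the ScalarEncoding enum introduced in this diff are assumed.

// Lucene104UsageSketch.java - illustrative only, not part of this change.
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.FilterCodec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene104.Lucene104HnswScalarQuantizedVectorsFormat;
import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.ByteBuffersDirectory;
import org.apache.lucene.store.Directory;

public class Lucene104UsageSketch {
  public static void main(String[] args) throws Exception {
    // Wrap the default codec, overriding only the KNN vectors format.
    // PACKED_NIBBLE packs two 4-bit components per byte (~8x smaller than float32);
    // UNSIGNED_BYTE and SEVEN_BIT keep one component per byte (~4x smaller).
    Codec codec =
        new FilterCodec("Lucene104UsageSketch", Codec.getDefault()) {
          @Override
          public KnnVectorsFormat knnVectorsFormat() {
            // args: encoding, maxConn, beamWidth, numMergeWorkers, mergeExec
            return new Lucene104HnswScalarQuantizedVectorsFormat(
                ScalarEncoding.PACKED_NIBBLE, 16, 100, 1, null);
          }
        };
    try (Directory dir = new ByteBuffersDirectory();
        IndexWriter writer = new IndexWriter(dir, new IndexWriterConfig().setCodec(codec))) {
      Document doc = new Document();
      doc.add(
          new KnnFloatVectorField(
              "embedding",
              new float[] {0.12f, -0.34f, 0.56f, 0.78f},
              VectorSimilarityFunction.COSINE));
      writer.addDocument(doc);
      writer.commit();
      // To reopen the index later, the codec must be resolvable by name, i.e.
      // registered via SPI (META-INF/services/org.apache.lucene.codecs.Codec).
    }
  }
}

At search time nothing changes for callers: as with the other quantized formats, KnnFloatVectorQuery works as usual, with candidate scoring served by the quantized vectors and their corrective terms.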