diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index 3ea1326b4608..885f44491a6c 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -130,6 +130,8 @@ New Features
* GITHUB#15176: Add `[Float|Byte]VectorValues#rescorer(element[])` interface to allow optimized rescoring of vectors.
(Ben Trent)
+* GITHUB#15169: Add codecs for 4 and 8 bit Optimized Scalar Quantization vectors (Trevor McCulloch)
+
Improvements
---------------------
* GITHUB#15148: Add support for uint8 distance and allow 8 bit scalar quantization (Trevor McCulloch)
diff --git a/lucene/core/src/java/module-info.java b/lucene/core/src/java/module-info.java
index 17876735f1f6..0952f2696868 100644
--- a/lucene/core/src/java/module-info.java
+++ b/lucene/core/src/java/module-info.java
@@ -87,7 +87,9 @@
org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat,
org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat,
org.apache.lucene.codecs.lucene102.Lucene102HnswBinaryQuantizedVectorsFormat,
- org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat;
+ org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat,
+ org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat,
+ org.apache.lucene.codecs.lucene104.Lucene104HnswScalarQuantizedVectorsFormat;
provides org.apache.lucene.codecs.PostingsFormat with
org.apache.lucene.codecs.lucene104.Lucene104PostingsFormat;
provides org.apache.lucene.index.SortFieldProvider with
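The `provides` entries above register the new formats with the JDK service loader so they can be resolved by name at read time. A minimal lookup sketch (an illustration only, assuming lucene-core is on the module path; both names come from the `NAME` constants introduced below):

```java
import org.apache.lucene.codecs.KnnVectorsFormat;

public class Lucene104SpiLookupSketch {
  public static void main(String[] args) {
    // Names match the NAME constants registered by the new lucene104 formats.
    KnnVectorsFormat flat = KnnVectorsFormat.forName("Lucene104ScalarQuantizedVectorsFormat");
    KnnVectorsFormat hnsw = KnnVectorsFormat.forName("Lucene104HnswScalarQuantizedVectorsFormat");
    System.out.println(flat + "\n" + hnsw);
  }
}
```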
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104HnswScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104HnswScalarQuantizedVectorsFormat.java
new file mode 100644
index 000000000000..7d49183249e2
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104HnswScalarQuantizedVectorsFormat.java
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.codecs.lucene104;
+
+import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
+import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN;
+import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER;
+import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_BEAM_WIDTH;
+import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_MAX_CONN;
+
+import java.io.IOException;
+import java.util.concurrent.ExecutorService;
+import org.apache.lucene.codecs.KnnVectorsFormat;
+import org.apache.lucene.codecs.KnnVectorsReader;
+import org.apache.lucene.codecs.KnnVectorsWriter;
+import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding;
+import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
+import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
+import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter;
+import org.apache.lucene.index.SegmentReadState;
+import org.apache.lucene.index.SegmentWriteState;
+import org.apache.lucene.search.TaskExecutor;
+import org.apache.lucene.util.hnsw.HnswGraph;
+
+/**
+ * A vectors format that uses an HNSW graph to store and search for vectors, with the vectors
+ * scalar quantized using {@link Lucene104ScalarQuantizedVectorsFormat} before being stored in the
+ * graph.
+ */
+public class Lucene104HnswScalarQuantizedVectorsFormat extends KnnVectorsFormat {
+
+ public static final String NAME = "Lucene104HnswScalarQuantizedVectorsFormat";
+
+ /**
+ * Controls how many of the nearest neighbor candidates are connected to the new node. Defaults to
+ * {@link Lucene99HnswVectorsFormat#DEFAULT_MAX_CONN}. See {@link HnswGraph} for more details.
+ */
+ private final int maxConn;
+
+ /**
+ * The number of candidate neighbors to track while searching the graph for each newly inserted
+ * node. Defaults to {@link Lucene99HnswVectorsFormat#DEFAULT_BEAM_WIDTH}. See {@link HnswGraph}
+ * for details.
+ */
+ private final int beamWidth;
+
+ /** The format for storing, reading, merging vectors on disk */
+ private final Lucene104ScalarQuantizedVectorsFormat flatVectorsFormat;
+
+ private final int numMergeWorkers;
+ private final TaskExecutor mergeExec;
+
+ /** Constructs a format using default graph construction parameters */
+ public Lucene104HnswScalarQuantizedVectorsFormat() {
+ this(
+ ScalarEncoding.UNSIGNED_BYTE,
+ DEFAULT_MAX_CONN,
+ DEFAULT_BEAM_WIDTH,
+ DEFAULT_NUM_MERGE_WORKER,
+ null);
+ }
+
+ /**
+ * Constructs a format using the given graph construction parameters.
+ *
+ * @param maxConn the maximum number of connections to a node in the HNSW graph
+ * @param beamWidth the size of the queue maintained during graph construction.
+ */
+ public Lucene104HnswScalarQuantizedVectorsFormat(int maxConn, int beamWidth) {
+ this(ScalarEncoding.UNSIGNED_BYTE, maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null);
+ }
+
+ /**
+ * Constructs a format using the given graph construction parameters and scalar quantization.
+ *
+ * @param encoding the scalar encoding used to quantize vectors
+ * @param maxConn the maximum number of connections to a node in the HNSW graph
+ * @param beamWidth the size of the queue maintained during graph construction.
+ * @param numMergeWorkers number of workers (threads) that will be used when doing merge. If
+ * larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec
+ * @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are
+ * generated by this format to do the merge
+ */
+ public Lucene104HnswScalarQuantizedVectorsFormat(
+ ScalarEncoding encoding,
+ int maxConn,
+ int beamWidth,
+ int numMergeWorkers,
+ ExecutorService mergeExec) {
+ super(NAME);
+ flatVectorsFormat = new Lucene104ScalarQuantizedVectorsFormat(encoding);
+ if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) {
+ throw new IllegalArgumentException(
+ "maxConn must be positive and less than or equal to "
+ + MAXIMUM_MAX_CONN
+ + "; maxConn="
+ + maxConn);
+ }
+ if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) {
+ throw new IllegalArgumentException(
+ "beamWidth must be positive and less than or equal to "
+ + MAXIMUM_BEAM_WIDTH
+ + "; beamWidth="
+ + beamWidth);
+ }
+ this.maxConn = maxConn;
+ this.beamWidth = beamWidth;
+ if (numMergeWorkers == 1 && mergeExec != null) {
+ throw new IllegalArgumentException(
+ "No executor service is needed as we'll use single thread to merge");
+ }
+ this.numMergeWorkers = numMergeWorkers;
+ if (mergeExec != null) {
+ this.mergeExec = new TaskExecutor(mergeExec);
+ } else {
+ this.mergeExec = null;
+ }
+ }
+
+ @Override
+ public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
+ return new Lucene99HnswVectorsWriter(
+ state,
+ maxConn,
+ beamWidth,
+ flatVectorsFormat.fieldsWriter(state),
+ numMergeWorkers,
+ mergeExec);
+ }
+
+ @Override
+ public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
+ return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state));
+ }
+
+ @Override
+ public int getMaxDimensions(String fieldName) {
+ return 1024;
+ }
+
+ @Override
+ public String toString() {
+ return "Lucene104HnswScalarQuantizedVectorsFormat(name=Lucene104HnswScalarQuantizedVectorsFormat, maxConn="
+ + maxConn
+ + ", beamWidth="
+ + beamWidth
+ + ", flatVectorFormat="
+ + flatVectorsFormat
+ + ")";
+ }
+}
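As a usage illustration (not part of this change), the new format can be plugged into an `IndexWriterConfig` through a per-field codec override. `Osq104CodecSketch` and the codec name "Osq104Sketch" are hypothetical; reopening such an index requires the codec name to be resolvable via SPI:

```java
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.FilterCodec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene104.Lucene104HnswScalarQuantizedVectorsFormat;
import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.index.IndexWriterConfig;

public class Osq104CodecSketch {
  public static IndexWriterConfig newConfig() {
    KnnVectorsFormat vectorsFormat =
        new Lucene104HnswScalarQuantizedVectorsFormat(
            ScalarEncoding.PACKED_NIBBLE, // 4 bits per dimension, two dimensions per byte
            16, // maxConn
            100, // beamWidth
            1, // numMergeWorkers: single-threaded merging...
            null); // ...so no executor may be passed (non-null would throw)
    Codec codec =
        new FilterCodec("Osq104Sketch", Codec.getDefault()) {
          @Override
          public KnnVectorsFormat knnVectorsFormat() {
            return new PerFieldKnnVectorsFormat() {
              @Override
              public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
                return vectorsFormat;
              }
            };
          }
        };
    return new IndexWriterConfig().setCodec(codec);
  }
}
```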
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java
new file mode 100644
index 000000000000..b5a8db5190eb
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.codecs.lucene104;
+
+import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
+import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
+import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
+
+import java.io.IOException;
+import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
+import org.apache.lucene.index.KnnVectorValues;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.util.ArrayUtil;
+import org.apache.lucene.util.VectorUtil;
+import org.apache.lucene.util.hnsw.RandomVectorScorer;
+import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
+import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
+import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
+
+/** Vector scorer over OptimizedScalarQuantized vectors */
+public class Lucene104ScalarQuantizedVectorScorer implements FlatVectorsScorer {
+ private final FlatVectorsScorer nonQuantizedDelegate;
+
+ public Lucene104ScalarQuantizedVectorScorer(FlatVectorsScorer nonQuantizedDelegate) {
+ this.nonQuantizedDelegate = nonQuantizedDelegate;
+ }
+
+ @Override
+ public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
+ VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues)
+ throws IOException {
+ if (vectorValues instanceof QuantizedByteVectorValues qv) {
+ return new ScalarQuantizedVectorScorerSupplier(qv, similarityFunction);
+ }
+ // It is possible to get to this branch during initial indexing and flush
+ return nonQuantizedDelegate.getRandomVectorScorerSupplier(similarityFunction, vectorValues);
+ }
+
+ @Override
+ public RandomVectorScorer getRandomVectorScorer(
+ VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target)
+ throws IOException {
+ if (vectorValues instanceof QuantizedByteVectorValues qv) {
+ OptimizedScalarQuantizer quantizer = qv.getQuantizer();
+ byte[] targetQuantized =
+ new byte
+ [OptimizedScalarQuantizer.discretize(
+ target.length, qv.getScalarEncoding().getDimensionsPerByte())];
+ // We make a copy as the quantization process mutates the input
+ float[] copy = ArrayUtil.copyOfSubArray(target, 0, target.length);
+ if (similarityFunction == COSINE) {
+ VectorUtil.l2normalize(copy);
+ }
+ target = copy;
+ var targetCorrectiveTerms =
+ quantizer.scalarQuantize(
+ target, targetQuantized, qv.getScalarEncoding().getBits(), qv.getCentroid());
+ return new RandomVectorScorer.AbstractRandomVectorScorer(qv) {
+ @Override
+ public float score(int node) throws IOException {
+ return quantizedScore(
+ targetQuantized, targetCorrectiveTerms, qv, node, similarityFunction);
+ }
+ };
+ }
+ // It is possible to get to this branch during initial indexing and flush
+ return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
+ }
+
+ @Override
+ public RandomVectorScorer getRandomVectorScorer(
+ VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target)
+ throws IOException {
+ return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
+ }
+
+ @Override
+ public String toString() {
+ return "Lucene104ScalarQuantizedVectorScorer(nonQuantizedDelegate="
+ + nonQuantizedDelegate
+ + ")";
+ }
+
+ private static final class ScalarQuantizedVectorScorerSupplier
+ implements RandomVectorScorerSupplier {
+ private final QuantizedByteVectorValues targetValues;
+ private final QuantizedByteVectorValues values;
+ private final VectorSimilarityFunction similarity;
+
+ public ScalarQuantizedVectorScorerSupplier(
+ QuantizedByteVectorValues values, VectorSimilarityFunction similarity) throws IOException {
+ this.targetValues = values.copy();
+ this.values = values;
+ this.similarity = similarity;
+ }
+
+ @Override
+ public UpdateableRandomVectorScorer scorer() throws IOException {
+ return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(values) {
+ private byte[] targetVector;
+ private OptimizedScalarQuantizer.QuantizationResult targetCorrectiveTerms;
+
+ @Override
+ public float score(int node) throws IOException {
+ return quantizedScore(targetVector, targetCorrectiveTerms, values, node, similarity);
+ }
+
+ @Override
+ public void setScoringOrdinal(int node) throws IOException {
+ var rawTargetVector = targetValues.vectorValue(node);
+ switch (values.getScalarEncoding()) {
+ case UNSIGNED_BYTE -> targetVector = rawTargetVector;
+ case SEVEN_BIT -> targetVector = rawTargetVector;
+ case PACKED_NIBBLE -> {
+ if (targetVector == null) {
+ targetVector = new byte[OptimizedScalarQuantizer.discretize(values.dimension(), 2)];
+ }
+ OffHeapScalarQuantizedVectorValues.unpackNibbles(rawTargetVector, targetVector);
+ }
+ }
+ targetCorrectiveTerms = targetValues.getCorrectiveTerms(node);
+ }
+ };
+ }
+
+ @Override
+ public RandomVectorScorerSupplier copy() throws IOException {
+ return new ScalarQuantizedVectorScorerSupplier(values.copy(), similarity);
+ }
+ }
+
+ private static final float[] SCALE_LUT =
+ new float[] {
+ 1f,
+ 1f / ((1 << 2) - 1),
+ 1f / ((1 << 3) - 1),
+ 1f / ((1 << 4) - 1),
+ 1f / ((1 << 5) - 1),
+ 1f / ((1 << 6) - 1),
+ 1f / ((1 << 7) - 1),
+ 1f / ((1 << 8) - 1),
+ };
+
+ private static float quantizedScore(
+ byte[] quantizedQuery,
+ OptimizedScalarQuantizer.QuantizationResult queryCorrections,
+ QuantizedByteVectorValues targetVectors,
+ int targetOrd,
+ VectorSimilarityFunction similarityFunction)
+ throws IOException {
+ var scalarEncoding = targetVectors.getScalarEncoding();
+ byte[] quantizedDoc = targetVectors.vectorValue(targetOrd);
+ float qcDist =
+ switch (scalarEncoding) {
+ case UNSIGNED_BYTE -> VectorUtil.uint8DotProduct(quantizedQuery, quantizedDoc);
+ case SEVEN_BIT -> VectorUtil.dotProduct(quantizedQuery, quantizedDoc);
+ case PACKED_NIBBLE -> VectorUtil.int4DotProductPacked(quantizedQuery, quantizedDoc);
+ };
+ OptimizedScalarQuantizer.QuantizationResult indexCorrections =
+ targetVectors.getCorrectiveTerms(targetOrd);
+ float scale = SCALE_LUT[scalarEncoding.getBits() - 1];
+ float x1 = indexCorrections.quantizedComponentSum();
+ float ax = indexCorrections.lowerInterval();
+ // Here we must scale according to the bits
+ float lx = (indexCorrections.upperInterval() - ax) * scale;
+ float ay = queryCorrections.lowerInterval();
+ float ly = (queryCorrections.upperInterval() - ay) * scale;
+ float y1 = queryCorrections.quantizedComponentSum();
+ float score =
+ ax * ay * targetVectors.dimension() + ay * lx * x1 + ax * ly * y1 + lx * ly * qcDist;
+ // For euclidean, we need to invert the score and apply the additional correction, which is
+ // assumed to be the squared l2norm of the centroid centered vectors.
+ if (similarityFunction == EUCLIDEAN) {
+ score =
+ queryCorrections.additionalCorrection()
+ + indexCorrections.additionalCorrection()
+ - 2 * score;
+ return Math.max(1 / (1f + score), 0);
+ } else {
+ // For cosine and max inner product, we need to apply the additional correction, which is
+ // assumed to be the non-centered dot-product between the vector and the centroid
+ score +=
+ queryCorrections.additionalCorrection()
+ + indexCorrections.additionalCorrection()
+ - targetVectors.getCentroidDP();
+ if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
+ return VectorUtil.scaleMaxInnerProductScore(score);
+ }
+ return Math.max((1f + score) / 2f, 0);
+ }
+ }
+}
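The four-term expression assembled in `quantizedScore` is the expansion of a dot product in which each component is reconstructed as `lower + step * quantized`. A standalone sketch of that algebra, substituting naive uniform quantization for `OptimizedScalarQuantizer` (which additionally optimizes the interval per vector):

```java
public class QuantizedDotSketch {
  public static void main(String[] args) {
    float[] x = {0.12f, -0.40f, 0.33f, 0.98f};
    float[] y = {0.55f, 0.10f, -0.72f, 0.41f};
    int maxQ = (1 << 8) - 1; // 8-bit UNSIGNED_BYTE encoding
    // Per-vector interval: ax/ay are the lower bounds, lx/ly the step sizes,
    // i.e. (upperInterval - lowerInterval) * SCALE_LUT[bits - 1] in the scorer above.
    float ax = min(x), lx = (max(x) - ax) / maxQ;
    float ay = min(y), ly = (max(y) - ay) / maxQ;
    float exact = 0, qDot = 0, qxSum = 0, qySum = 0;
    for (int i = 0; i < x.length; i++) {
      int qx = Math.round((x[i] - ax) / lx); // quantized component of x
      int qy = Math.round((y[i] - ay) / ly); // quantized component of y
      exact += x[i] * y[i];
      qDot += qx * qy; // corresponds to qcDist
      qxSum += qx; // corresponds to quantizedComponentSum (x1)
      qySum += qy; // corresponds to quantizedComponentSum (y1)
    }
    // dot(x, y) ~= ax*ay*d + ay*lx*sum(qx) + ax*ly*sum(qy) + lx*ly*dot(qx, qy)
    float estimate = ax * ay * x.length + ay * lx * qxSum + ax * ly * qySum + lx * ly * qDot;
    System.out.printf("exact=%.5f estimate=%.5f%n", exact, estimate);
  }

  private static float min(float[] v) {
    float m = v[0];
    for (float f : v) m = Math.min(m, f);
    return m;
  }

  private static float max(float[] v) {
    float m = v[0];
    for (float f : v) m = Math.max(m, f);
    return m;
  }
}
```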
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java
new file mode 100644
index 000000000000..9acb43c64cf5
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java
@@ -0,0 +1,215 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.codecs.lucene104;
+
+import java.io.IOException;
+import java.util.Optional;
+import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
+import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
+import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
+import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
+import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
+import org.apache.lucene.index.SegmentReadState;
+import org.apache.lucene.index.SegmentWriteState;
+
+/**
+ * The quantization format used here is a per-vector optimized scalar quantization. These ideas are
+ * evolutions of LVQ proposed in <i>Similarity search in the blink of an eye with compressed
+ * indices</i> by Cecilia Aguerrebere et al., the previous work on globally optimized scalar
+ * quantization in Apache Lucene, and <i>Accelerating Large-Scale Inference with Anisotropic Vector
+ * Quantization</i> by Ruiqi Guo et al. Also see {@link
+ * org.apache.lucene.util.quantization.OptimizedScalarQuantizer}. Some of the key features are:
+ *
+ * <ul>
+ *   <li>Estimating the distance between two vectors using their centroid-centered distance. This
+ *       requires some additional corrective factors, but allows for centroid centering to occur.
+ *   <li>Optimized scalar quantization of the centroid-centered vectors to 4, 7, or 8 bits per
+ *       dimension, as selected by {@link ScalarEncoding}.
+ *   <li>Packing two 4-bit dimensions into a single byte when using {@link
+ *       ScalarEncoding#PACKED_NIBBLE}, halving the storage required for the quantized vectors.
+ * </ul>
+ *
+ * <p>A previous work related to improvements over regular LVQ is <i>Practical and Asymptotically
+ * Optimal Quantization of High-Dimensional Vectors in Euclidean Space for Approximate Nearest
+ * Neighbor Search</i> by Jianyang Gao, et al.
+ *
+ * <p>The format is stored within two files:
+ *
+ * <h2>.veq (vector data) file</h2>
+ *
+ * <p>Stores the quantized vectors in a flat format. Additionally, it stores each vector's
+ * corrective factors. At the end of the file, additional information is stored for vector ordinal
+ * to centroid ordinal mapping and sparse vector information.
+ *
+ * <ul>
+ *   <li>For each vector:
+ *       <ul>
+ *         <li><b>[byte]</b> the quantized values. Each dimension may be up to 8 bits, and multiple
+ *             dimensions may be packed into a single byte.
+ *         <li><b>[float]</b> the optimized quantiles and an additional similarity dependent
+ *             corrective factor.
+ *         <li><b>[int]</b> the sum of the quantized components
+ *       </ul>
+ *   <li>After the vectors, sparse vector information keeping track of monotonic blocks.
+ * </ul>
+ *
+ * <h2>.vemq (vector metadata) file</h2>
+ *
+ * <p>Stores the metadata for the vectors. This includes the number of vectors, the number of
+ * dimensions, and file offset information.
+ *
+ * <ul>
+ *   <li><b>int</b> the field number
+ *   <li><b>int</b> the vector encoding ordinal
+ *   <li><b>int</b> the vector similarity ordinal
+ *   <li><b>vint</b> the vector dimensions
+ *   <li><b>vlong</b> the offset to the vector data in the .veq file
+ *   <li><b>vlong</b> the length of the vector data in the .veq file
+ *   <li><b>vint</b> the number of vectors
+ *   <li><b>vint</b> the wire number for {@link ScalarEncoding}
+ *   <li><b>[float]</b> the centroid
+ *   <li><b>float</b> the centroid square magnitude
+ *   <li>The sparse vector information, if required, mapping vector ordinal to doc ID
+ * </ul>
+ */
+public class Lucene104ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
+ public static final String QUANTIZED_VECTOR_COMPONENT = "QVEC";
+ public static final String NAME = "Lucene104ScalarQuantizedVectorsFormat";
+
+ static final int VERSION_START = 0;
+ static final int VERSION_CURRENT = VERSION_START;
+ static final String META_CODEC_NAME = "Lucene104ScalarQuantizedVectorsFormatMeta";
+ static final String VECTOR_DATA_CODEC_NAME = "Lucene104ScalarQuantizedVectorsFormatData";
+ static final String META_EXTENSION = "vemq";
+ static final String VECTOR_DATA_EXTENSION = "veq";
+ static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16;
+
+ private static final FlatVectorsFormat rawVectorFormat =
+ new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer());
+
+ private static final Lucene104ScalarQuantizedVectorScorer scorer =
+ new Lucene104ScalarQuantizedVectorScorer(FlatVectorScorerUtil.getLucene99FlatVectorsScorer());
+
+ private final ScalarEncoding encoding;
+
+ /**
+ * Allowed encodings for scalar quantization.
+ *
+ * <p>This specifies how many bits are used per dimension and also dictates packing of dimensions
+ * into a byte stream.
+ */
+ public enum ScalarEncoding {
+ /** Each dimension is quantized to 8 bits and treated as an unsigned value. */
+ UNSIGNED_BYTE(0, (byte) 8, 1),
+ /** Each dimension is quantized to 4 bits and two values are packed into each output byte. */
+ PACKED_NIBBLE(1, (byte) 4, 2),
+ /**
+ * Each dimension is quantized to 7 bits and treated as a signed value.
+ *
+ * <p>This is intended for backwards compatibility with older iterations of scalar quantization.
+ * This setting will produce an index the same size as {@link #UNSIGNED_BYTE} but will produce
+ * less accurate vector comparisons.
+ */
+ SEVEN_BIT(2, (byte) 7, 1);
+
+ /** The number used to identify this encoding on the wire, rather than relying on ordinal. */
+ private final int wireNumber;
+
+ private final byte bits;
+ private final int dimsPerByte;
+
+ ScalarEncoding(int wireNumber, byte bits, int dimsPerByte) {
+ this.wireNumber = wireNumber;
+ this.bits = bits;
+ this.dimsPerByte = dimsPerByte;
+ }
+
+ int getWireNumber() {
+ return wireNumber;
+ }
+
+ /** Return the number of bits used per dimension. */
+ public byte getBits() {
+ return bits;
+ }
+
+ /** Return the number of dimensions that can be packed into a single byte. */
+ public int getDimensionsPerByte() {
+ return this.dimsPerByte;
+ }
+
+ /** Return the number of bytes required to store a packed vector of the given dimensions. */
+ public int getPackedLength(int dimensions) {
+ return (dimensions + this.dimsPerByte - 1) / this.dimsPerByte;
+ }
+
+ /** Returns the encoding for the given wire number, or empty if unknown. */
+ public static Optional<ScalarEncoding> fromWireNumber(int wireNumber) {
+ for (ScalarEncoding encoding : values()) {
+ if (encoding.wireNumber == wireNumber) {
+ return Optional.of(encoding);
+ }
+ }
+ return Optional.empty();
+ }
+ }
+
+ /** Creates a new instance with UNSIGNED_BYTE encoding. */
+ public Lucene104ScalarQuantizedVectorsFormat() {
+ this(ScalarEncoding.UNSIGNED_BYTE);
+ }
+
+ /** Creates a new instance with the chosen encoding. */
+ public Lucene104ScalarQuantizedVectorsFormat(ScalarEncoding encoding) {
+ super(NAME);
+ this.encoding = encoding;
+ }
+
+ @Override
+ public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
+ return new Lucene104ScalarQuantizedVectorsWriter(
+ state, encoding, rawVectorFormat.fieldsWriter(state), scorer);
+ }
+
+ @Override
+ public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException {
+ return new Lucene104ScalarQuantizedVectorsReader(
+ state, rawVectorFormat.fieldsReader(state), scorer);
+ }
+
+ @Override
+ public int getMaxDimensions(String fieldName) {
+ return 1024;
+ }
+
+ @Override
+ public String toString() {
+ return "Lucene104ScalarQuantizedVectorsFormat(name="
+ + NAME
+ + ", encoding="
+ + encoding
+ + ", flatVectorScorer="
+ + scorer
+ + ", rawVectorFormat="
+ + rawVectorFormat
+ + ")";
+ }
+}
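A quick sketch of the storage arithmetic the encodings imply. Per the reader's `validateFieldEntry`, each vector stores `getPackedLength(dims)` quantized bytes plus 16 bytes of corrective terms (three floats and an int):

```java
import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding;

public class Osq104SizeSketch {
  public static void main(String[] args) {
    int dims = 768;
    for (ScalarEncoding e : ScalarEncoding.values()) {
      // getPackedLength rounds up: (dims + dimsPerByte - 1) / dimsPerByte
      int vectorBytes = e.getPackedLength(dims);
      System.out.printf(
          "%s: %d bits/dim -> %d vector bytes + 16 correction bytes%n",
          e, e.getBits(), vectorBytes);
    }
    // UNSIGNED_BYTE and SEVEN_BIT: 768 bytes; PACKED_NIBBLE: 384 bytes.
  }
}
```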
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java
new file mode 100644
index 000000000000..0251224998b4
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java
@@ -0,0 +1,451 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.codecs.lucene104;
+
+import static org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.VECTOR_DATA_EXTENSION;
+import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readSimilarityFunction;
+import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readVectorEncoding;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+import org.apache.lucene.codecs.CodecUtil;
+import org.apache.lucene.codecs.KnnVectorsReader;
+import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
+import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding;
+import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
+import org.apache.lucene.index.ByteVectorValues;
+import org.apache.lucene.index.CorruptIndexException;
+import org.apache.lucene.index.FieldInfo;
+import org.apache.lucene.index.FieldInfos;
+import org.apache.lucene.index.FloatVectorValues;
+import org.apache.lucene.index.IndexFileNames;
+import org.apache.lucene.index.SegmentReadState;
+import org.apache.lucene.index.VectorEncoding;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.AcceptDocs;
+import org.apache.lucene.search.KnnCollector;
+import org.apache.lucene.search.VectorScorer;
+import org.apache.lucene.store.ChecksumIndexInput;
+import org.apache.lucene.store.DataAccessHint;
+import org.apache.lucene.store.FileDataHint;
+import org.apache.lucene.store.FileTypeHint;
+import org.apache.lucene.store.IOContext;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.IOUtils;
+import org.apache.lucene.util.RamUsageEstimator;
+import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector;
+import org.apache.lucene.util.hnsw.RandomVectorScorer;
+import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
+
+/** Reader for scalar quantized vectors in the Lucene 10.4 format. */
+class Lucene104ScalarQuantizedVectorsReader extends FlatVectorsReader {
+
+ private static final long SHALLOW_SIZE =
+ RamUsageEstimator.shallowSizeOfInstance(Lucene104ScalarQuantizedVectorsReader.class);
+
+ private final Map<String, FieldEntry> fields = new HashMap<>();
+ private final IndexInput quantizedVectorData;
+ private final FlatVectorsReader rawVectorsReader;
+ private final Lucene104ScalarQuantizedVectorScorer vectorScorer;
+
+ Lucene104ScalarQuantizedVectorsReader(
+ SegmentReadState state,
+ FlatVectorsReader rawVectorsReader,
+ Lucene104ScalarQuantizedVectorScorer vectorsScorer)
+ throws IOException {
+ super(vectorsScorer);
+ this.vectorScorer = vectorsScorer;
+ this.rawVectorsReader = rawVectorsReader;
+ int versionMeta = -1;
+ String metaFileName =
+ IndexFileNames.segmentFileName(
+ state.segmentInfo.name,
+ state.segmentSuffix,
+ Lucene104ScalarQuantizedVectorsFormat.META_EXTENSION);
+ try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName)) {
+ Throwable priorE = null;
+ try {
+ versionMeta =
+ CodecUtil.checkIndexHeader(
+ meta,
+ Lucene104ScalarQuantizedVectorsFormat.META_CODEC_NAME,
+ Lucene104ScalarQuantizedVectorsFormat.VERSION_START,
+ Lucene104ScalarQuantizedVectorsFormat.VERSION_CURRENT,
+ state.segmentInfo.getId(),
+ state.segmentSuffix);
+ readFields(meta, state.fieldInfos);
+ } catch (Throwable exception) {
+ priorE = exception;
+ } finally {
+ CodecUtil.checkFooter(meta, priorE);
+ }
+ quantizedVectorData =
+ openDataInput(
+ state,
+ versionMeta,
+ VECTOR_DATA_EXTENSION,
+ Lucene104ScalarQuantizedVectorsFormat.VECTOR_DATA_CODEC_NAME,
+ // Quantized vectors are accessed randomly from their node ID stored in the HNSW
+ // graph.
+ state.context.withHints(
+ FileTypeHint.DATA, FileDataHint.KNN_VECTORS, DataAccessHint.RANDOM));
+ } catch (Throwable t) {
+ IOUtils.closeWhileSuppressingExceptions(t, this);
+ throw t;
+ }
+ }
+
+ private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOException {
+ for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) {
+ FieldInfo info = infos.fieldInfo(fieldNumber);
+ if (info == null) {
+ throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta);
+ }
+ FieldEntry fieldEntry = readField(meta, info);
+ validateFieldEntry(info, fieldEntry);
+ fields.put(info.name, fieldEntry);
+ }
+ }
+
+ static void validateFieldEntry(FieldInfo info, FieldEntry fieldEntry) {
+ int dimension = info.getVectorDimension();
+ if (dimension != fieldEntry.dimension) {
+ throw new IllegalStateException(
+ "Inconsistent vector dimension for field=\""
+ + info.name
+ + "\"; "
+ + dimension
+ + " != "
+ + fieldEntry.dimension);
+ }
+
+ long numQuantizedVectorBytes =
+ Math.multiplyExact(
+ (fieldEntry.scalarEncoding.getPackedLength(dimension)
+ + (Float.BYTES * 3)
+ + Integer.BYTES),
+ (long) fieldEntry.size);
+ if (numQuantizedVectorBytes != fieldEntry.vectorDataLength) {
+ throw new IllegalStateException(
+ "vector data length "
+ + fieldEntry.vectorDataLength
+ + " not matching size = "
+ + fieldEntry.size
+ + " * (dims="
+ + dimension
+ + " + 16"
+ + ") = "
+ + numQuantizedVectorBytes);
+ }
+ }
+
+ @Override
+ public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException {
+ FieldEntry fi = fields.get(field);
+ if (fi == null) {
+ return null;
+ }
+ return vectorScorer.getRandomVectorScorer(
+ fi.similarityFunction,
+ OffHeapScalarQuantizedVectorValues.load(
+ fi.ordToDocDISIReaderConfiguration,
+ fi.dimension,
+ fi.size,
+ new OptimizedScalarQuantizer(fi.similarityFunction),
+ fi.scalarEncoding,
+ fi.similarityFunction,
+ vectorScorer,
+ fi.centroid,
+ fi.centroidDP,
+ fi.vectorDataOffset,
+ fi.vectorDataLength,
+ quantizedVectorData),
+ target);
+ }
+
+ @Override
+ public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException {
+ return rawVectorsReader.getRandomVectorScorer(field, target);
+ }
+
+ @Override
+ public void checkIntegrity() throws IOException {
+ rawVectorsReader.checkIntegrity();
+ CodecUtil.checksumEntireFile(quantizedVectorData);
+ }
+
+ @Override
+ public FloatVectorValues getFloatVectorValues(String field) throws IOException {
+ FieldEntry fi = fields.get(field);
+ if (fi == null) {
+ return null;
+ }
+ if (fi.vectorEncoding != VectorEncoding.FLOAT32) {
+ throw new IllegalArgumentException(
+ "field=\""
+ + field
+ + "\" is encoded as: "
+ + fi.vectorEncoding
+ + " expected: "
+ + VectorEncoding.FLOAT32);
+ }
+ OffHeapScalarQuantizedVectorValues sqvv =
+ OffHeapScalarQuantizedVectorValues.load(
+ fi.ordToDocDISIReaderConfiguration,
+ fi.dimension,
+ fi.size,
+ new OptimizedScalarQuantizer(fi.similarityFunction),
+ fi.scalarEncoding,
+ fi.similarityFunction,
+ vectorScorer,
+ fi.centroid,
+ fi.centroidDP,
+ fi.vectorDataOffset,
+ fi.vectorDataLength,
+ quantizedVectorData);
+ return new ScalarQuantizedVectorValues(rawVectorsReader.getFloatVectorValues(field), sqvv);
+ }
+
+ @Override
+ public ByteVectorValues getByteVectorValues(String field) throws IOException {
+ return rawVectorsReader.getByteVectorValues(field);
+ }
+
+ @Override
+ public void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs)
+ throws IOException {
+ rawVectorsReader.search(field, target, knnCollector, acceptDocs);
+ }
+
+ @Override
+ public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs)
+ throws IOException {
+ if (knnCollector.k() == 0) return;
+ final RandomVectorScorer scorer = getRandomVectorScorer(field, target);
+ if (scorer == null) return;
+ OrdinalTranslatedKnnCollector collector =
+ new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc);
+ Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs.bits());
+ for (int i = 0; i < scorer.maxOrd(); i++) {
+ if (acceptedOrds == null || acceptedOrds.get(i)) {
+ collector.collect(i, scorer.score(i));
+ collector.incVisitedCount(1);
+ }
+ }
+ }
+
+ @Override
+ public void close() throws IOException {
+ IOUtils.close(quantizedVectorData, rawVectorsReader);
+ }
+
+ @Override
+ public long ramBytesUsed() {
+ long size = SHALLOW_SIZE;
+ size +=
+ RamUsageEstimator.sizeOfMap(
+ fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class));
+ size += rawVectorsReader.ramBytesUsed();
+ return size;
+ }
+
+ @Override
+ public Map<String, Long> getOffHeapByteSize(FieldInfo fieldInfo) {
+ Objects.requireNonNull(fieldInfo);
+ var raw = rawVectorsReader.getOffHeapByteSize(fieldInfo);
+ var fieldEntry = fields.get(fieldInfo.name);
+ if (fieldEntry == null) {
+ assert fieldInfo.getVectorEncoding() == VectorEncoding.BYTE;
+ return raw;
+ }
+ var quant = Map.of(VECTOR_DATA_EXTENSION, fieldEntry.vectorDataLength());
+ return KnnVectorsReader.mergeOffHeapByteSizeMaps(raw, quant);
+ }
+
+ public float[] getCentroid(String field) {
+ FieldEntry fieldEntry = fields.get(field);
+ if (fieldEntry != null) {
+ return fieldEntry.centroid;
+ }
+ return null;
+ }
+
+ private static IndexInput openDataInput(
+ SegmentReadState state,
+ int versionMeta,
+ String fileExtension,
+ String codecName,
+ IOContext context)
+ throws IOException {
+ String fileName =
+ IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension);
+ IndexInput in = state.directory.openInput(fileName, context);
+ try {
+ int versionVectorData =
+ CodecUtil.checkIndexHeader(
+ in,
+ codecName,
+ Lucene104ScalarQuantizedVectorsFormat.VERSION_START,
+ Lucene104ScalarQuantizedVectorsFormat.VERSION_CURRENT,
+ state.segmentInfo.getId(),
+ state.segmentSuffix);
+ if (versionMeta != versionVectorData) {
+ throw new CorruptIndexException(
+ "Format versions mismatch: meta="
+ + versionMeta
+ + ", "
+ + codecName
+ + "="
+ + versionVectorData,
+ in);
+ }
+ CodecUtil.retrieveChecksum(in);
+ return in;
+ } catch (Throwable t) {
+ IOUtils.closeWhileSuppressingExceptions(t, in);
+ throw t;
+ }
+ }
+
+ private FieldEntry readField(IndexInput input, FieldInfo info) throws IOException {
+ VectorEncoding vectorEncoding = readVectorEncoding(input);
+ VectorSimilarityFunction similarityFunction = readSimilarityFunction(input);
+ if (similarityFunction != info.getVectorSimilarityFunction()) {
+ throw new IllegalStateException(
+ "Inconsistent vector similarity function for field=\""
+ + info.name
+ + "\"; "
+ + similarityFunction
+ + " != "
+ + info.getVectorSimilarityFunction());
+ }
+ return FieldEntry.create(input, vectorEncoding, info.getVectorSimilarityFunction());
+ }
+
+ private record FieldEntry(
+ VectorSimilarityFunction similarityFunction,
+ VectorEncoding vectorEncoding,
+ int dimension,
+ long vectorDataOffset,
+ long vectorDataLength,
+ int size,
+ ScalarEncoding scalarEncoding,
+ float[] centroid,
+ float centroidDP,
+ OrdToDocDISIReaderConfiguration ordToDocDISIReaderConfiguration) {
+
+ static FieldEntry create(
+ IndexInput input,
+ VectorEncoding vectorEncoding,
+ VectorSimilarityFunction similarityFunction)
+ throws IOException {
+ int dimension = input.readVInt();
+ long vectorDataOffset = input.readVLong();
+ long vectorDataLength = input.readVLong();
+ int size = input.readVInt();
+ final float[] centroid;
+ float centroidDP = 0;
+ ScalarEncoding scalarEncoding = ScalarEncoding.UNSIGNED_BYTE;
+ if (size > 0) {
+ int wireNumber = input.readVInt();
+ scalarEncoding =
+ ScalarEncoding.fromWireNumber(wireNumber)
+ .orElseThrow(
+ () ->
+ new IllegalStateException(
+ "Could not get ScalarEncoding from wire number: " + wireNumber));
+ centroid = new float[dimension];
+ input.readFloats(centroid, 0, dimension);
+ centroidDP = Float.intBitsToFloat(input.readInt());
+ } else {
+ centroid = null;
+ }
+ OrdToDocDISIReaderConfiguration conf =
+ OrdToDocDISIReaderConfiguration.fromStoredMeta(input, size);
+ return new FieldEntry(
+ similarityFunction,
+ vectorEncoding,
+ dimension,
+ vectorDataOffset,
+ vectorDataLength,
+ size,
+ scalarEncoding,
+ centroid,
+ centroidDP,
+ conf);
+ }
+ }
+
+ /** Vector values holding raw and quantized vector values */
+ protected static final class ScalarQuantizedVectorValues extends FloatVectorValues {
+ private final FloatVectorValues rawVectorValues;
+ private final QuantizedByteVectorValues quantizedVectorValues;
+
+ ScalarQuantizedVectorValues(
+ FloatVectorValues rawVectorValues, QuantizedByteVectorValues quantizedVectorValues) {
+ this.rawVectorValues = rawVectorValues;
+ this.quantizedVectorValues = quantizedVectorValues;
+ }
+
+ @Override
+ public int dimension() {
+ return rawVectorValues.dimension();
+ }
+
+ @Override
+ public int size() {
+ return rawVectorValues.size();
+ }
+
+ @Override
+ public float[] vectorValue(int ord) throws IOException {
+ return rawVectorValues.vectorValue(ord);
+ }
+
+ @Override
+ public ScalarQuantizedVectorValues copy() throws IOException {
+ return new ScalarQuantizedVectorValues(rawVectorValues.copy(), quantizedVectorValues.copy());
+ }
+
+ @Override
+ public Bits getAcceptOrds(Bits acceptDocs) {
+ return rawVectorValues.getAcceptOrds(acceptDocs);
+ }
+
+ @Override
+ public int ordToDoc(int ord) {
+ return rawVectorValues.ordToDoc(ord);
+ }
+
+ @Override
+ public DocIndexIterator iterator() {
+ return rawVectorValues.iterator();
+ }
+
+ @Override
+ public VectorScorer scorer(float[] query) throws IOException {
+ return quantizedVectorValues.scorer(query);
+ }
+
+ QuantizedByteVectorValues getQuantizedVectorValues() throws IOException {
+ return quantizedVectorValues;
+ }
+ }
+}
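`packNibbles`/`unpackNibbles` are provided by `OffHeapScalarQuantizedVectorValues`, which is outside this hunk. Purely as an illustration of the `PACKED_NIBBLE` idea (two 4-bit dimensions per byte), here is a hypothetical round-trip sketch; the actual byte layout used by Lucene may differ:

```java
import java.util.Arrays;

public class NibblePackingSketch {
  // Pack two 4-bit values per output byte: even index in the low nibble, odd in the high.
  static void packNibbles(byte[] unpacked, byte[] packed) {
    for (int i = 0; i < packed.length; i++) {
      int lo = unpacked[2 * i] & 0x0F;
      int hi = (2 * i + 1) < unpacked.length ? unpacked[2 * i + 1] & 0x0F : 0;
      packed[i] = (byte) ((hi << 4) | lo);
    }
  }

  static void unpackNibbles(byte[] packed, byte[] unpacked) {
    for (int i = 0; i < packed.length; i++) {
      unpacked[2 * i] = (byte) (packed[i] & 0x0F);
      if (2 * i + 1 < unpacked.length) {
        unpacked[2 * i + 1] = (byte) ((packed[i] >>> 4) & 0x0F);
      }
    }
  }

  public static void main(String[] args) {
    byte[] dims = {1, 15, 7, 0, 9}; // five quantized 4-bit dimensions
    byte[] packed = new byte[(dims.length + 1) / 2]; // getPackedLength(5) == 3
    byte[] roundTrip = new byte[dims.length];
    packNibbles(dims, packed);
    unpackNibbles(packed, roundTrip);
    System.out.println(Arrays.equals(dims, roundTrip)); // true
  }
}
```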
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java
new file mode 100644
index 000000000000..12de1f83bd36
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java
@@ -0,0 +1,861 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.codecs.lucene104;
+
+import static org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT;
+import static org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_COMPONENT;
+import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import org.apache.lucene.codecs.CodecUtil;
+import org.apache.lucene.codecs.KnnVectorsReader;
+import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
+import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
+import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding;
+import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
+import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
+import org.apache.lucene.index.DocsWithFieldSet;
+import org.apache.lucene.index.FieldInfo;
+import org.apache.lucene.index.FloatVectorValues;
+import org.apache.lucene.index.IndexFileNames;
+import org.apache.lucene.index.KnnVectorValues;
+import org.apache.lucene.index.MergeState;
+import org.apache.lucene.index.SegmentWriteState;
+import org.apache.lucene.index.Sorter;
+import org.apache.lucene.index.VectorEncoding;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.internal.hppc.FloatArrayList;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.VectorScorer;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.store.IndexOutput;
+import org.apache.lucene.util.IOUtils;
+import org.apache.lucene.util.VectorUtil;
+import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;
+import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
+import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
+import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
+
+/** Writer for scalar quantized vectors in the Lucene 10.4 format. */
+public class Lucene104ScalarQuantizedVectorsWriter extends FlatVectorsWriter {
+ private static final long SHALLOW_RAM_BYTES_USED =
+ shallowSizeOfInstance(Lucene104ScalarQuantizedVectorsWriter.class);
+
+ private final SegmentWriteState segmentWriteState;
+ private final List<FieldWriter> fields = new ArrayList<>();
+ private final IndexOutput meta, vectorData;
+ private final ScalarEncoding encoding;
+ private final FlatVectorsWriter rawVectorDelegate;
+ private final Lucene104ScalarQuantizedVectorScorer vectorsScorer;
+ private boolean finished;
+
+ /**
+ * Sole constructor
+ *
+ * @param vectorsScorer the scorer to use for scoring vectors
+ */
+ protected Lucene104ScalarQuantizedVectorsWriter(
+ SegmentWriteState state,
+ ScalarEncoding encoding,
+ FlatVectorsWriter rawVectorDelegate,
+ Lucene104ScalarQuantizedVectorScorer vectorsScorer)
+ throws IOException {
+ super(vectorsScorer);
+ this.encoding = encoding;
+ this.vectorsScorer = vectorsScorer;
+ this.segmentWriteState = state;
+ String metaFileName =
+ IndexFileNames.segmentFileName(
+ state.segmentInfo.name,
+ state.segmentSuffix,
+ Lucene104ScalarQuantizedVectorsFormat.META_EXTENSION);
+
+ String vectorDataFileName =
+ IndexFileNames.segmentFileName(
+ state.segmentInfo.name,
+ state.segmentSuffix,
+ Lucene104ScalarQuantizedVectorsFormat.VECTOR_DATA_EXTENSION);
+ this.rawVectorDelegate = rawVectorDelegate;
+ try {
+ meta = state.directory.createOutput(metaFileName, state.context);
+ vectorData = state.directory.createOutput(vectorDataFileName, state.context);
+
+ CodecUtil.writeIndexHeader(
+ meta,
+ Lucene104ScalarQuantizedVectorsFormat.META_CODEC_NAME,
+ Lucene104ScalarQuantizedVectorsFormat.VERSION_CURRENT,
+ state.segmentInfo.getId(),
+ state.segmentSuffix);
+ CodecUtil.writeIndexHeader(
+ vectorData,
+ Lucene104ScalarQuantizedVectorsFormat.VECTOR_DATA_CODEC_NAME,
+ Lucene104ScalarQuantizedVectorsFormat.VERSION_CURRENT,
+ state.segmentInfo.getId(),
+ state.segmentSuffix);
+ } catch (Throwable t) {
+ IOUtils.closeWhileSuppressingExceptions(t, this);
+ throw t;
+ }
+ }
+
+ @Override
+ public FlatFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
+ FlatFieldVectorsWriter<?> rawVectorDelegate = this.rawVectorDelegate.addField(fieldInfo);
+ if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
+ @SuppressWarnings("unchecked")
+ FieldWriter fieldWriter =
+ new FieldWriter(fieldInfo, (FlatFieldVectorsWriter<float[]>) rawVectorDelegate);
+ fields.add(fieldWriter);
+ return fieldWriter;
+ }
+ return rawVectorDelegate;
+ }
+
+ @Override
+ public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
+ rawVectorDelegate.flush(maxDoc, sortMap);
+ for (FieldWriter field : fields) {
+ // after raw vectors are written, normalize vectors for clustering and quantization
+ if (VectorSimilarityFunction.COSINE == field.fieldInfo.getVectorSimilarityFunction()) {
+ field.normalizeVectors();
+ }
+ final float[] clusterCenter;
+ int vectorCount = field.flatFieldVectorsWriter.getVectors().size();
+ clusterCenter = new float[field.dimensionSums.length];
+ if (vectorCount > 0) {
+ for (int i = 0; i < field.dimensionSums.length; i++) {
+ clusterCenter[i] = field.dimensionSums[i] / vectorCount;
+ }
+ if (VectorSimilarityFunction.COSINE == field.fieldInfo.getVectorSimilarityFunction()) {
+ VectorUtil.l2normalize(clusterCenter);
+ }
+ }
+ if (segmentWriteState.infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) {
+ segmentWriteState.infoStream.message(
+ QUANTIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount);
+ }
+ OptimizedScalarQuantizer quantizer =
+ new OptimizedScalarQuantizer(field.fieldInfo.getVectorSimilarityFunction());
+ if (sortMap == null) {
+ writeField(field, clusterCenter, maxDoc, quantizer);
+ } else {
+ writeSortingField(field, clusterCenter, maxDoc, sortMap, quantizer);
+ }
+ field.finish();
+ }
+ }
+
+ private void writeField(
+ FieldWriter fieldData, float[] clusterCenter, int maxDoc, OptimizedScalarQuantizer quantizer)
+ throws IOException {
+ // write vector values
+ long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
+ writeVectors(fieldData, clusterCenter, quantizer);
+ long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
+ float centroidDp =
+ !fieldData.getVectors().isEmpty() ? VectorUtil.dotProduct(clusterCenter, clusterCenter) : 0;
+
+ writeMeta(
+ fieldData.fieldInfo,
+ maxDoc,
+ vectorDataOffset,
+ vectorDataLength,
+ clusterCenter,
+ centroidDp,
+ fieldData.getDocsWithFieldSet());
+ }
+
+ private void writeVectors(
+ FieldWriter fieldData, float[] clusterCenter, OptimizedScalarQuantizer scalarQuantizer)
+ throws IOException {
+ byte[] scratch =
+ new byte
+ [OptimizedScalarQuantizer.discretize(
+ fieldData.fieldInfo.getVectorDimension(), encoding.getDimensionsPerByte())];
+ byte[] vector =
+ switch (encoding) {
+ case UNSIGNED_BYTE -> scratch;
+ case SEVEN_BIT -> scratch;
+ case PACKED_NIBBLE ->
+ new byte[encoding.getPackedLength(fieldData.fieldInfo.getVectorDimension())];
+ };
+ for (int i = 0; i < fieldData.getVectors().size(); i++) {
+ float[] v = fieldData.getVectors().get(i);
+ OptimizedScalarQuantizer.QuantizationResult corrections =
+ scalarQuantizer.scalarQuantize(v, scratch, encoding.getBits(), clusterCenter);
+ if (encoding == ScalarEncoding.PACKED_NIBBLE) {
+ OffHeapScalarQuantizedVectorValues.packNibbles(scratch, vector);
+ }
+ vectorData.writeBytes(vector, vector.length);
+ vectorData.writeInt(Float.floatToIntBits(corrections.lowerInterval()));
+ vectorData.writeInt(Float.floatToIntBits(corrections.upperInterval()));
+ vectorData.writeInt(Float.floatToIntBits(corrections.additionalCorrection()));
+ vectorData.writeInt(corrections.quantizedComponentSum());
+ }
+ }
+
+ private void writeSortingField(
+ FieldWriter fieldData,
+ float[] clusterCenter,
+ int maxDoc,
+ Sorter.DocMap sortMap,
+ OptimizedScalarQuantizer scalarQuantizer)
+ throws IOException {
+ final int[] ordMap =
+ new int[fieldData.getDocsWithFieldSet().cardinality()]; // new ord to old ord
+
+ DocsWithFieldSet newDocsWithField = new DocsWithFieldSet();
+ mapOldOrdToNewOrd(fieldData.getDocsWithFieldSet(), sortMap, null, ordMap, newDocsWithField);
+
+ // write vector values
+ long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
+ writeSortedVectors(fieldData, clusterCenter, ordMap, scalarQuantizer);
+ long quantizedVectorLength = vectorData.getFilePointer() - vectorDataOffset;
+
+ float centroidDp = VectorUtil.dotProduct(clusterCenter, clusterCenter);
+ writeMeta(
+ fieldData.fieldInfo,
+ maxDoc,
+ vectorDataOffset,
+ quantizedVectorLength,
+ clusterCenter,
+ centroidDp,
+ newDocsWithField);
+ }
+
+ private void writeSortedVectors(
+ FieldWriter fieldData,
+ float[] clusterCenter,
+ int[] ordMap,
+ OptimizedScalarQuantizer scalarQuantizer)
+ throws IOException {
+ byte[] scratch =
+ new byte
+ [OptimizedScalarQuantizer.discretize(
+ fieldData.fieldInfo.getVectorDimension(), encoding.getDimensionsPerByte())];
+ byte[] vector =
+ switch (encoding) {
+ case UNSIGNED_BYTE -> scratch;
+ case SEVEN_BIT -> scratch;
+ case PACKED_NIBBLE ->
+ new byte[encoding.getPackedLength(fieldData.fieldInfo.getVectorDimension())];
+ };
+ for (int ordinal : ordMap) {
+ float[] v = fieldData.getVectors().get(ordinal);
+ OptimizedScalarQuantizer.QuantizationResult corrections =
+ scalarQuantizer.scalarQuantize(v, scratch, encoding.getBits(), clusterCenter);
+ if (encoding == ScalarEncoding.PACKED_NIBBLE) {
+ OffHeapScalarQuantizedVectorValues.packNibbles(scratch, vector);
+ }
+ vectorData.writeBytes(vector, vector.length);
+ vectorData.writeInt(Float.floatToIntBits(corrections.lowerInterval()));
+ vectorData.writeInt(Float.floatToIntBits(corrections.upperInterval()));
+ vectorData.writeInt(Float.floatToIntBits(corrections.additionalCorrection()));
+ vectorData.writeInt(corrections.quantizedComponentSum());
+ }
+ }
+
+ private void writeMeta(
+ FieldInfo field,
+ int maxDoc,
+ long vectorDataOffset,
+ long vectorDataLength,
+ float[] clusterCenter,
+ float centroidDp,
+ DocsWithFieldSet docsWithField)
+ throws IOException {
+ meta.writeInt(field.number);
+ meta.writeInt(field.getVectorEncoding().ordinal());
+ meta.writeInt(field.getVectorSimilarityFunction().ordinal());
+ meta.writeVInt(field.getVectorDimension());
+ meta.writeVLong(vectorDataOffset);
+ meta.writeVLong(vectorDataLength);
+ int count = docsWithField.cardinality();
+ meta.writeVInt(count);
+ if (count > 0) {
+ meta.writeVInt(encoding.getWireNumber());
+ final ByteBuffer buffer =
+ ByteBuffer.allocate(field.getVectorDimension() * Float.BYTES)
+ .order(ByteOrder.LITTLE_ENDIAN);
+ buffer.asFloatBuffer().put(clusterCenter);
+ meta.writeBytes(buffer.array(), buffer.array().length);
+ meta.writeInt(Float.floatToIntBits(centroidDp));
+ }
+ OrdToDocDISIReaderConfiguration.writeStoredMeta(
+ DIRECT_MONOTONIC_BLOCK_SHIFT, meta, vectorData, count, maxDoc, docsWithField);
+ }
+
+ @Override
+ public void finish() throws IOException {
+ if (finished) {
+ throw new IllegalStateException("already finished");
+ }
+ finished = true;
+ rawVectorDelegate.finish();
+ if (meta != null) {
+ // write end of fields marker
+ meta.writeInt(-1);
+ CodecUtil.writeFooter(meta);
+ }
+ if (vectorData != null) {
+ CodecUtil.writeFooter(vectorData);
+ }
+ }
+
+ @Override
+ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
+ if (!fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
+ rawVectorDelegate.mergeOneField(fieldInfo, mergeState);
+ return;
+ }
+
+ final float[] centroid;
+ final float[] mergedCentroid = new float[fieldInfo.getVectorDimension()];
+ int vectorCount = mergeAndRecalculateCentroids(mergeState, fieldInfo, mergedCentroid);
+ // Don't need access to the random vectors, we can just use the merged
+ rawVectorDelegate.mergeOneField(fieldInfo, mergeState);
+ centroid = mergedCentroid;
+ if (segmentWriteState.infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) {
+ segmentWriteState.infoStream.message(
+ QUANTIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount);
+ }
+ FloatVectorValues floatVectorValues =
+ MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
+ if (fieldInfo.getVectorSimilarityFunction() == COSINE) {
+ floatVectorValues = new NormalizedFloatVectorValues(floatVectorValues);
+ }
+ QuantizedFloatVectorValues quantizedVectorValues =
+ new QuantizedFloatVectorValues(
+ floatVectorValues,
+ new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()),
+ encoding,
+ centroid);
+ long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
+ DocsWithFieldSet docsWithField = writeVectorData(vectorData, quantizedVectorValues);
+ long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
+ float centroidDp =
+ docsWithField.cardinality() > 0 ? VectorUtil.dotProduct(centroid, centroid) : 0;
+ writeMeta(
+ fieldInfo,
+ segmentWriteState.segmentInfo.maxDoc(),
+ vectorDataOffset,
+ vectorDataLength,
+ centroid,
+ centroidDp,
+ docsWithField);
+ }
+
+ static DocsWithFieldSet writeVectorData(
+ IndexOutput output, QuantizedByteVectorValues quantizedByteVectorValues) throws IOException {
+ DocsWithFieldSet docsWithField = new DocsWithFieldSet();
+ KnnVectorValues.DocIndexIterator iterator = quantizedByteVectorValues.iterator();
+ for (int docV = iterator.nextDoc(); docV != NO_MORE_DOCS; docV = iterator.nextDoc()) {
+ // write vector
+ byte[] binaryValue = quantizedByteVectorValues.vectorValue(iterator.index());
+ output.writeBytes(binaryValue, binaryValue.length);
+ OptimizedScalarQuantizer.QuantizationResult corrections =
+ quantizedByteVectorValues.getCorrectiveTerms(iterator.index());
+ output.writeInt(Float.floatToIntBits(corrections.lowerInterval()));
+ output.writeInt(Float.floatToIntBits(corrections.upperInterval()));
+ output.writeInt(Float.floatToIntBits(corrections.additionalCorrection()));
+ output.writeInt(corrections.quantizedComponentSum());
+ docsWithField.add(docV);
+ }
+ return docsWithField;
+ }
+
+ @Override
+ public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(
+ FieldInfo fieldInfo, MergeState mergeState) throws IOException {
+ if (!fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
+ return rawVectorDelegate.mergeOneFieldToIndex(fieldInfo, mergeState);
+ }
+
+ final float[] centroid;
+ final float cDotC;
+ final float[] mergedCentroid = new float[fieldInfo.getVectorDimension()];
+ int vectorCount = mergeAndRecalculateCentroids(mergeState, fieldInfo, mergedCentroid);
+
+ // Don't need access to the random vectors, we can just use the merged
+ rawVectorDelegate.mergeOneField(fieldInfo, mergeState);
+ centroid = mergedCentroid;
+ cDotC = vectorCount > 0 ? VectorUtil.dotProduct(centroid, centroid) : 0;
+ if (segmentWriteState.infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) {
+ segmentWriteState.infoStream.message(
+ QUANTIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount);
+ }
+ return mergeOneFieldToIndex(segmentWriteState, fieldInfo, mergeState, centroid, cDotC);
+ }
+
+ private CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(
+ SegmentWriteState segmentWriteState,
+ FieldInfo fieldInfo,
+ MergeState mergeState,
+ float[] centroid,
+ float cDotC)
+ throws IOException {
+ long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
+ IndexOutput tempQuantizedVectorData = null;
+ IndexInput quantizedDataInput = null;
+ OptimizedScalarQuantizer quantizer =
+ new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
+ try {
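+ // Stage the quantized vectors in a temp file first; once fully written (with a
+ // checksum footer) the bytes are copied into the main vector data file, and the
+ // temp input is kept open to back the returned scorer supplier.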
+ tempQuantizedVectorData =
+ segmentWriteState.directory.createTempOutput(
+ vectorData.getName(), "temp", segmentWriteState.context);
+ final String tempQuantizedVectorName = tempQuantizedVectorData.getName();
+ FloatVectorValues floatVectorValues =
+ MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
+ if (fieldInfo.getVectorSimilarityFunction() == COSINE) {
+ floatVectorValues = new NormalizedFloatVectorValues(floatVectorValues);
+ }
+ DocsWithFieldSet docsWithField =
+ writeVectorData(
+ tempQuantizedVectorData,
+ new QuantizedFloatVectorValues(floatVectorValues, quantizer, encoding, centroid));
+ CodecUtil.writeFooter(tempQuantizedVectorData);
+ IOUtils.close(tempQuantizedVectorData);
+ quantizedDataInput =
+ segmentWriteState.directory.openInput(tempQuantizedVectorName, segmentWriteState.context);
+ vectorData.copyBytes(
+ quantizedDataInput, quantizedDataInput.length() - CodecUtil.footerLength());
+ long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
+ CodecUtil.retrieveChecksum(quantizedDataInput);
+ writeMeta(
+ fieldInfo,
+ segmentWriteState.segmentInfo.maxDoc(),
+ vectorDataOffset,
+ vectorDataLength,
+ centroid,
+ cDotC,
+ docsWithField);
+
+ final IndexInput finalQuantizedDataInput = quantizedDataInput;
+ tempQuantizedVectorData = null;
+ quantizedDataInput = null;
+
+ OffHeapScalarQuantizedVectorValues vectorValues =
+ new OffHeapScalarQuantizedVectorValues.DenseOffHeapVectorValues(
+ fieldInfo.getVectorDimension(),
+ docsWithField.cardinality(),
+ centroid,
+ cDotC,
+ quantizer,
+ encoding,
+ fieldInfo.getVectorSimilarityFunction(),
+ vectorsScorer,
+ finalQuantizedDataInput);
+ RandomVectorScorerSupplier scorerSupplier =
+ vectorsScorer.getRandomVectorScorerSupplier(
+ fieldInfo.getVectorSimilarityFunction(), vectorValues);
+ return new QuantizedCloseableRandomVectorScorerSupplier(
+ scorerSupplier,
+ vectorValues,
+ () -> {
+ IOUtils.close(finalQuantizedDataInput);
+ IOUtils.deleteFilesIgnoringExceptions(
+ segmentWriteState.directory, tempQuantizedVectorName);
+ });
+ } catch (Throwable t) {
+ IOUtils.closeWhileSuppressingExceptions(t, tempQuantizedVectorData, quantizedDataInput);
+ if (tempQuantizedVectorData != null) {
+ IOUtils.deleteFilesSuppressingExceptions(
+ t, segmentWriteState.directory, tempQuantizedVectorData.getName());
+ }
+ throw t;
+ }
+ }
+
+ @Override
+ public void close() throws IOException {
+ IOUtils.close(meta, vectorData, rawVectorDelegate);
+ }
+
+ static float[] getCentroid(KnnVectorsReader vectorsReader, String fieldName) {
+ if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) {
+ vectorsReader = candidateReader.getFieldReader(fieldName);
+ }
+ if (vectorsReader instanceof Lucene104ScalarQuantizedVectorsReader reader) {
+ return reader.getCentroid(fieldName);
+ }
+ return null;
+ }
+
+ static int mergeAndRecalculateCentroids(
+ MergeState mergeState, FieldInfo fieldInfo, float[] mergedCentroid) throws IOException {
+ boolean recalculate = false;
+ int totalVectorCount = 0;
+ for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
+ KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i];
+ if (knnVectorsReader == null
+ || knnVectorsReader.getFloatVectorValues(fieldInfo.name) == null) {
+ continue;
+ }
+ float[] centroid = getCentroid(knnVectorsReader, fieldInfo.name);
+ int vectorCount = knnVectorsReader.getFloatVectorValues(fieldInfo.name).size();
+ if (vectorCount == 0) {
+ continue;
+ }
+ totalVectorCount += vectorCount;
+ // If any segment lacks a stored centroid, or contains deleted docs, the
+ // weighted combination below would be skewed, so recalculate from scratch
+ if (centroid == null || mergeState.liveDocs[i] != null) {
+ recalculate = true;
+ break;
+ }
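+ // Weight each segment centroid by its vector count; the accumulated sum is
+ // divided by the total count below to yield the merged centroid.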
+ for (int j = 0; j < centroid.length; j++) {
+ mergedCentroid[j] += centroid[j] * vectorCount;
+ }
+ }
+ if (recalculate) {
+ return calculateCentroid(mergeState, fieldInfo, mergedCentroid);
+ } else {
+ for (int j = 0; j < mergedCentroid.length; j++) {
+ mergedCentroid[j] = mergedCentroid[j] / totalVectorCount;
+ }
+ if (fieldInfo.getVectorSimilarityFunction() == COSINE) {
+ VectorUtil.l2normalize(mergedCentroid);
+ }
+ return totalVectorCount;
+ }
+ }
+
+ static int calculateCentroid(MergeState mergeState, FieldInfo fieldInfo, float[] centroid)
+ throws IOException {
+ assert fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32);
+ // clear out the centroid
+ Arrays.fill(centroid, 0);
+ int count = 0;
+ for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
+ KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i];
+ if (knnVectorsReader == null) continue;
+ FloatVectorValues vectorValues =
+ mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldInfo.name);
+ if (vectorValues == null) {
+ continue;
+ }
+ KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
+ for (int doc = iterator.nextDoc();
+ doc != DocIdSetIterator.NO_MORE_DOCS;
+ doc = iterator.nextDoc()) {
+ ++count;
+ float[] vector = vectorValues.vectorValue(iterator.index());
+ for (int j = 0; j < vector.length; j++) {
+ centroid[j] += vector[j];
+ }
+ }
+ }
+ if (count == 0) {
+ return count;
+ }
+ for (int i = 0; i < centroid.length; i++) {
+ centroid[i] /= count;
+ }
+ if (fieldInfo.getVectorSimilarityFunction() == COSINE) {
+ VectorUtil.l2normalize(centroid);
+ }
+ return count;
+ }
+
+ @Override
+ public long ramBytesUsed() {
+ long total = SHALLOW_RAM_BYTES_USED;
+ for (FieldWriter field : fields) {
+ // the field tracks the delegate field usage
+ total += field.ramBytesUsed();
+ }
+ return total;
+ }
+
+ static class FieldWriter extends FlatFieldVectorsWriter<float[]> {
+ private static final long SHALLOW_SIZE = shallowSizeOfInstance(FieldWriter.class);
+ private final FieldInfo fieldInfo;
+ private boolean finished;
+ private final FlatFieldVectorsWriter<float[]> flatFieldVectorsWriter;
+ private final float[] dimensionSums;
+ private final FloatArrayList magnitudes = new FloatArrayList();
+
+ FieldWriter(FieldInfo fieldInfo, FlatFieldVectorsWriter<float[]> flatFieldVectorsWriter) {
+ this.fieldInfo = fieldInfo;
+ this.flatFieldVectorsWriter = flatFieldVectorsWriter;
+ this.dimensionSums = new float[fieldInfo.getVectorDimension()];
+ }
+
+ @Override
+ public List<float[]> getVectors() {
+ return flatFieldVectorsWriter.getVectors();
+ }
+
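+ // For cosine similarity the raw vectors are buffered un-normalized; addValue
+ // records each vector's magnitude so they can be normalized in place here,
+ // just before quantization.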
+ public void normalizeVectors() {
+ for (int i = 0; i < flatFieldVectorsWriter.getVectors().size(); i++) {
+ float[] vector = flatFieldVectorsWriter.getVectors().get(i);
+ float magnitude = magnitudes.get(i);
+ for (int j = 0; j < vector.length; j++) {
+ vector[j] /= magnitude;
+ }
+ }
+ }
+
+ @Override
+ public DocsWithFieldSet getDocsWithFieldSet() {
+ return flatFieldVectorsWriter.getDocsWithFieldSet();
+ }
+
+ @Override
+ public void finish() throws IOException {
+ if (finished) {
+ return;
+ }
+ assert flatFieldVectorsWriter.isFinished();
+ finished = true;
+ }
+
+ @Override
+ public boolean isFinished() {
+ return finished && flatFieldVectorsWriter.isFinished();
+ }
+
+ @Override
+ public void addValue(int docID, float[] vectorValue) throws IOException {
+ flatFieldVectorsWriter.addValue(docID, vectorValue);
+ if (fieldInfo.getVectorSimilarityFunction() == COSINE) {
+ float dp = VectorUtil.dotProduct(vectorValue, vectorValue);
+ float divisor = (float) Math.sqrt(dp);
+ magnitudes.add(divisor);
+ for (int i = 0; i < vectorValue.length; i++) {
+ dimensionSums[i] += (vectorValue[i] / divisor);
+ }
+ } else {
+ for (int i = 0; i < vectorValue.length; i++) {
+ dimensionSums[i] += vectorValue[i];
+ }
+ }
+ }
+
+ @Override
+ public float[] copyValue(float[] vectorValue) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public long ramBytesUsed() {
+ long size = SHALLOW_SIZE;
+ size += flatFieldVectorsWriter.ramBytesUsed();
+ size += magnitudes.ramBytesUsed();
+ return size;
+ }
+ }
+
+ static class QuantizedFloatVectorValues extends QuantizedByteVectorValues {
+ private OptimizedScalarQuantizer.QuantizationResult corrections;
+ private final byte[] quantized;
+ private final byte[] packed;
+ private final float[] centroid;
+ private final float centroidDP;
+ private final FloatVectorValues values;
+ private final OptimizedScalarQuantizer quantizer;
+ private final ScalarEncoding encoding;
+
+ private int lastOrd = -1;
+
+ QuantizedFloatVectorValues(
+ FloatVectorValues delegate,
+ OptimizedScalarQuantizer quantizer,
+ ScalarEncoding encoding,
+ float[] centroid) {
+ this.values = delegate;
+ this.quantizer = quantizer;
+ this.encoding = encoding;
+ this.quantized =
+ new byte
+ [OptimizedScalarQuantizer.discretize(
+ delegate.dimension(), encoding.getDimensionsPerByte())];
+ this.packed =
+ switch (encoding) {
+ case UNSIGNED_BYTE -> this.quantized;
+ case SEVEN_BIT -> this.quantized;
+ case PACKED_NIBBLE -> new byte[encoding.getPackedLength(delegate.dimension())];
+ };
+ this.centroid = centroid;
+ this.centroidDP = VectorUtil.dotProduct(centroid, centroid);
+ }
+
+ @Override
+ public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int ord) {
+ if (ord != lastOrd) {
+ throw new IllegalStateException(
+ "attempt to retrieve corrective terms for different ord "
+ + ord
+ + " than the quantization was done for: "
+ + lastOrd);
+ }
+ return corrections;
+ }
+
+ @Override
+ public byte[] vectorValue(int ord) throws IOException {
+ if (ord != lastOrd) {
+ quantize(ord);
+ lastOrd = ord;
+ }
+ return packed;
+ }
+
+ @Override
+ public int dimension() {
+ return values.dimension();
+ }
+
+ @Override
+ public OptimizedScalarQuantizer getQuantizer() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public ScalarEncoding getScalarEncoding() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public float[] getCentroid() throws IOException {
+ return centroid;
+ }
+
+ @Override
+ public float getCentroidDP() {
+ return centroidDP;
+ }
+
+ @Override
+ public int size() {
+ return values.size();
+ }
+
+ @Override
+ public VectorScorer scorer(float[] target) throws IOException {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public QuantizedByteVectorValues copy() throws IOException {
+ return new QuantizedFloatVectorValues(values.copy(), quantizer, encoding, centroid);
+ }
+
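+ // Quantizes into the byte-per-dimension scratch buffer; for PACKED_NIBBLE the
+ // result is additionally packed two dimensions per byte into `packed`.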
+ private void quantize(int ord) throws IOException {
+ corrections =
+ quantizer.scalarQuantize(
+ values.vectorValue(ord), quantized, encoding.getBits(), centroid);
+ if (encoding == ScalarEncoding.PACKED_NIBBLE) {
+ OffHeapScalarQuantizedVectorValues.packNibbles(quantized, packed);
+ }
+ }
+
+ @Override
+ public DocIndexIterator iterator() {
+ return values.iterator();
+ }
+
+ @Override
+ public int ordToDoc(int ord) {
+ return values.ordToDoc(ord);
+ }
+ }
+
+ static class QuantizedCloseableRandomVectorScorerSupplier
+ implements CloseableRandomVectorScorerSupplier {
+ private final RandomVectorScorerSupplier supplier;
+ private final KnnVectorValues vectorValues;
+ private final Closeable onClose;
+
+ QuantizedCloseableRandomVectorScorerSupplier(
+ RandomVectorScorerSupplier supplier, KnnVectorValues vectorValues, Closeable onClose) {
+ this.supplier = supplier;
+ this.onClose = onClose;
+ this.vectorValues = vectorValues;
+ }
+
+ @Override
+ public UpdateableRandomVectorScorer scorer() throws IOException {
+ return supplier.scorer();
+ }
+
+ @Override
+ public RandomVectorScorerSupplier copy() throws IOException {
+ return supplier.copy();
+ }
+
+ @Override
+ public void close() throws IOException {
+ onClose.close();
+ }
+
+ @Override
+ public int totalVectorCount() {
+ return vectorValues.size();
+ }
+ }
+
+ static final class NormalizedFloatVectorValues extends FloatVectorValues {
+ private final FloatVectorValues values;
+ private final float[] normalizedVector;
+
+ NormalizedFloatVectorValues(FloatVectorValues values) {
+ this.values = values;
+ this.normalizedVector = new float[values.dimension()];
+ }
+
+ @Override
+ public int dimension() {
+ return values.dimension();
+ }
+
+ @Override
+ public int size() {
+ return values.size();
+ }
+
+ @Override
+ public int ordToDoc(int ord) {
+ return values.ordToDoc(ord);
+ }
+
+ @Override
+ public float[] vectorValue(int ord) throws IOException {
+ System.arraycopy(values.vectorValue(ord), 0, normalizedVector, 0, normalizedVector.length);
+ VectorUtil.l2normalize(normalizedVector);
+ return normalizedVector;
+ }
+
+ @Override
+ public DocIndexIterator iterator() {
+ return values.iterator();
+ }
+
+ @Override
+ public NormalizedFloatVectorValues copy() throws IOException {
+ return new NormalizedFloatVectorValues(values.copy());
+ }
+ }
+}
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/OffHeapScalarQuantizedVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/OffHeapScalarQuantizedVectorValues.java
new file mode 100644
index 000000000000..d2c678d8f8ba
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/OffHeapScalarQuantizedVectorValues.java
@@ -0,0 +1,414 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.codecs.lucene104;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
+import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding;
+import org.apache.lucene.codecs.lucene90.IndexedDISI;
+import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.VectorScorer;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.hnsw.RandomVectorScorer;
+import org.apache.lucene.util.packed.DirectMonotonicReader;
+import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
+
+/**
+ * Scalar quantized vector values loaded from off-heap
+ *
+ * @lucene.internal
+ */
+public abstract class OffHeapScalarQuantizedVectorValues extends QuantizedByteVectorValues {
+ final int dimension;
+ final int size;
+ final VectorSimilarityFunction similarityFunction;
+ final FlatVectorsScorer vectorsScorer;
+
+ final IndexInput slice;
+ final byte[] vectorValue;
+ final ByteBuffer byteBuffer;
+ final int byteSize;
+ private int lastOrd = -1;
+ final float[] correctiveValues;
+ int quantizedComponentSum;
+ final OptimizedScalarQuantizer quantizer;
+ final ScalarEncoding encoding;
+ final float[] centroid;
+ final float centroidDp;
+
+ OffHeapScalarQuantizedVectorValues(
+ int dimension,
+ int size,
+ float[] centroid,
+ float centroidDp,
+ OptimizedScalarQuantizer quantizer,
+ ScalarEncoding encoding,
+ VectorSimilarityFunction similarityFunction,
+ FlatVectorsScorer vectorsScorer,
+ IndexInput slice) {
+ this.dimension = dimension;
+ this.size = size;
+ this.similarityFunction = similarityFunction;
+ this.vectorsScorer = vectorsScorer;
+ this.slice = slice;
+ this.centroid = centroid;
+ this.centroidDp = centroidDp;
+ this.correctiveValues = new float[3];
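+ // Per-vector record: packed quantized bytes + 3 float corrective terms + int component sum.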
+ this.byteSize = encoding.getPackedLength(dimension) + (Float.BYTES * 3) + Integer.BYTES;
+ this.byteBuffer = ByteBuffer.allocate(encoding.getPackedLength(dimension));
+ this.vectorValue = byteBuffer.array();
+ this.quantizer = quantizer;
+ this.encoding = encoding;
+ }
+
+ @Override
+ public int dimension() {
+ return dimension;
+ }
+
+ @Override
+ public int size() {
+ return size;
+ }
+
+ @Override
+ public byte[] vectorValue(int targetOrd) throws IOException {
+ if (lastOrd == targetOrd) {
+ return vectorValue;
+ }
+ slice.seek((long) targetOrd * byteSize);
+ slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), vectorValue.length);
+ slice.readFloats(correctiveValues, 0, 3);
+ quantizedComponentSum = slice.readInt();
+ lastOrd = targetOrd;
+ return vectorValue;
+ }
+
+ @Override
+ public float getCentroidDP() {
+ return centroidDp;
+ }
+
+ @Override
+ public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int targetOrd)
+ throws IOException {
+ if (lastOrd == targetOrd) {
+ return new OptimizedScalarQuantizer.QuantizationResult(
+ correctiveValues[0], correctiveValues[1], correctiveValues[2], quantizedComponentSum);
+ }
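+ // Cache miss: seek past this ord's packed vector bytes to its corrective terms.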
+ slice.seek(((long) targetOrd * byteSize) + vectorValue.length);
+ slice.readFloats(correctiveValues, 0, 3);
+ quantizedComponentSum = slice.readInt();
+ return new OptimizedScalarQuantizer.QuantizationResult(
+ correctiveValues[0], correctiveValues[1], correctiveValues[2], quantizedComponentSum);
+ }
+
+ @Override
+ public OptimizedScalarQuantizer getQuantizer() {
+ return quantizer;
+ }
+
+ @Override
+ public ScalarEncoding getScalarEncoding() {
+ return encoding;
+ }
+
+ @Override
+ public float[] getCentroid() {
+ return centroid;
+ }
+
+ @Override
+ public int getVectorByteLength() {
+ return vectorValue.length;
+ }
+
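+ // Packs two 4-bit values per byte: the first half of `unpacked` fills the high
+ // nibbles and the second half the low nibbles, e.g. [a, b, c, d] packs to
+ // [(a << 4) | c, (b << 4) | d]. unpackNibbles below reverses the mapping.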
+ static void packNibbles(byte[] unpacked, byte[] packed) {
+ assert unpacked.length == packed.length * 2;
+ for (int i = 0; i < packed.length; i++) {
+ int x = unpacked[i] << 4 | unpacked[packed.length + i];
+ packed[i] = (byte) x;
+ }
+ }
+
+ static void unpackNibbles(byte[] packed, byte[] unpacked) {
+ assert unpacked.length == packed.length * 2;
+ for (int i = 0; i < packed.length; i++) {
+ unpacked[i] = (byte) ((packed[i] >> 4) & 0x0F);
+ unpacked[packed.length + i] = (byte) (packed[i] & 0x0F);
+ }
+ }
+
+ static OffHeapScalarQuantizedVectorValues load(
+ OrdToDocDISIReaderConfiguration configuration,
+ int dimension,
+ int size,
+ OptimizedScalarQuantizer quantizer,
+ ScalarEncoding encoding,
+ VectorSimilarityFunction similarityFunction,
+ FlatVectorsScorer vectorsScorer,
+ float[] centroid,
+ float centroidDp,
+ long quantizedVectorDataOffset,
+ long quantizedVectorDataLength,
+ IndexInput vectorData)
+ throws IOException {
+ if (configuration.isEmpty()) {
+ return new EmptyOffHeapVectorValues(dimension, similarityFunction, vectorsScorer);
+ }
+ assert centroid != null;
+ IndexInput bytesSlice =
+ vectorData.slice(
+ "quantized-vector-data", quantizedVectorDataOffset, quantizedVectorDataLength);
+ if (configuration.isDense()) {
+ return new DenseOffHeapVectorValues(
+ dimension,
+ size,
+ centroid,
+ centroidDp,
+ quantizer,
+ encoding,
+ similarityFunction,
+ vectorsScorer,
+ bytesSlice);
+ } else {
+ return new SparseOffHeapVectorValues(
+ configuration,
+ dimension,
+ size,
+ centroid,
+ centroidDp,
+ quantizer,
+ encoding,
+ vectorData,
+ similarityFunction,
+ vectorsScorer,
+ bytesSlice);
+ }
+ }
+
+ /** Dense off-heap scalar quantized vector values */
+ static class DenseOffHeapVectorValues extends OffHeapScalarQuantizedVectorValues {
+ DenseOffHeapVectorValues(
+ int dimension,
+ int size,
+ float[] centroid,
+ float centroidDp,
+ OptimizedScalarQuantizer quantizer,
+ ScalarEncoding encoding,
+ VectorSimilarityFunction similarityFunction,
+ FlatVectorsScorer vectorsScorer,
+ IndexInput slice) {
+ super(
+ dimension,
+ size,
+ centroid,
+ centroidDp,
+ quantizer,
+ encoding,
+ similarityFunction,
+ vectorsScorer,
+ slice);
+ }
+
+ @Override
+ public OffHeapScalarQuantizedVectorValues.DenseOffHeapVectorValues copy() throws IOException {
+ return new OffHeapScalarQuantizedVectorValues.DenseOffHeapVectorValues(
+ dimension,
+ size,
+ centroid,
+ centroidDp,
+ quantizer,
+ encoding,
+ similarityFunction,
+ vectorsScorer,
+ slice.clone());
+ }
+
+ @Override
+ public Bits getAcceptOrds(Bits acceptDocs) {
+ return acceptDocs;
+ }
+
+ @Override
+ public VectorScorer scorer(float[] target) throws IOException {
+ OffHeapScalarQuantizedVectorValues.DenseOffHeapVectorValues copy = copy();
+ DocIndexIterator iterator = copy.iterator();
+ RandomVectorScorer scorer =
+ vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target);
+ return new VectorScorer() {
+ @Override
+ public float score() throws IOException {
+ return scorer.score(iterator.index());
+ }
+
+ @Override
+ public DocIdSetIterator iterator() {
+ return iterator;
+ }
+ };
+ }
+
+ @Override
+ public DocIndexIterator iterator() {
+ return createDenseIterator();
+ }
+ }
+
+ /** Sparse off-heap scalar quantized vector values */
+ private static class SparseOffHeapVectorValues extends OffHeapScalarQuantizedVectorValues {
+ private final DirectMonotonicReader ordToDoc;
+ private final IndexedDISI disi;
+ // dataIn is used to init a new IndexedDISI for #randomAccess()
+ private final IndexInput dataIn;
+ private final OrdToDocDISIReaderConfiguration configuration;
+
+ SparseOffHeapVectorValues(
+ OrdToDocDISIReaderConfiguration configuration,
+ int dimension,
+ int size,
+ float[] centroid,
+ float centroidDp,
+ OptimizedScalarQuantizer quantizer,
+ ScalarEncoding encoding,
+ IndexInput dataIn,
+ VectorSimilarityFunction similarityFunction,
+ FlatVectorsScorer vectorsScorer,
+ IndexInput slice)
+ throws IOException {
+ super(
+ dimension,
+ size,
+ centroid,
+ centroidDp,
+ quantizer,
+ encoding,
+ similarityFunction,
+ vectorsScorer,
+ slice);
+ this.configuration = configuration;
+ this.dataIn = dataIn;
+ this.ordToDoc = configuration.getDirectMonotonicReader(dataIn);
+ this.disi = configuration.getIndexedDISI(dataIn);
+ }
+
+ @Override
+ public SparseOffHeapVectorValues copy() throws IOException {
+ return new SparseOffHeapVectorValues(
+ configuration,
+ dimension,
+ size,
+ centroid,
+ centroidDp,
+ quantizer,
+ encoding,
+ dataIn,
+ similarityFunction,
+ vectorsScorer,
+ slice.clone());
+ }
+
+ @Override
+ public int ordToDoc(int ord) {
+ return (int) ordToDoc.get(ord);
+ }
+
+ @Override
+ public Bits getAcceptOrds(Bits acceptDocs) {
+ if (acceptDocs == null) {
+ return null;
+ }
+ return new Bits() {
+ @Override
+ public boolean get(int index) {
+ return acceptDocs.get(ordToDoc(index));
+ }
+
+ @Override
+ public int length() {
+ return size;
+ }
+ };
+ }
+
+ @Override
+ public DocIndexIterator iterator() {
+ return IndexedDISI.asDocIndexIterator(disi);
+ }
+
+ @Override
+ public VectorScorer scorer(float[] target) throws IOException {
+ SparseOffHeapVectorValues copy = copy();
+ DocIndexIterator iterator = copy.iterator();
+ RandomVectorScorer scorer =
+ vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target);
+ return new VectorScorer() {
+ @Override
+ public float score() throws IOException {
+ return scorer.score(iterator.index());
+ }
+
+ @Override
+ public DocIdSetIterator iterator() {
+ return iterator;
+ }
+ };
+ }
+ }
+
+ private static class EmptyOffHeapVectorValues extends OffHeapScalarQuantizedVectorValues {
+ EmptyOffHeapVectorValues(
+ int dimension,
+ VectorSimilarityFunction similarityFunction,
+ FlatVectorsScorer vectorsScorer) {
+ super(
+ dimension,
+ 0,
+ null,
+ Float.NaN,
+ null,
+ ScalarEncoding.UNSIGNED_BYTE,
+ similarityFunction,
+ vectorsScorer,
+ null);
+ }
+
+ @Override
+ public DocIndexIterator iterator() {
+ return createDenseIterator();
+ }
+
+ @Override
+ public OffHeapScalarQuantizedVectorValues.DenseOffHeapVectorValues copy() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public Bits getAcceptOrds(Bits acceptDocs) {
+ return null;
+ }
+
+ @Override
+ public VectorScorer scorer(float[] target) {
+ return null;
+ }
+ }
+}
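
Reviewer note: a minimal sketch of the nibble round trip implemented by
packNibbles/unpackNibbles above (assumes each value already fits in 4 bits,
i.e. 0..15, as produced by 4-bit quantization):

    byte[] unpacked = {1, 2, 3, 4};                // logical dimensions = 4
    byte[] packed = new byte[unpacked.length / 2]; // two dimensions per byte
    for (int i = 0; i < packed.length; i++) {
      // first half of unpacked -> high nibbles, second half -> low nibbles
      packed[i] = (byte) (unpacked[i] << 4 | unpacked[packed.length + i]);
    }
    // packed == {0x13, 0x24}; unpackNibbles reverses this to {1, 2, 3, 4}
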
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/QuantizedByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/QuantizedByteVectorValues.java
new file mode 100644
index 000000000000..48d0c4e665f1
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/QuantizedByteVectorValues.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.codecs.lucene104;
+
+import java.io.IOException;
+import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding;
+import org.apache.lucene.index.ByteVectorValues;
+import org.apache.lucene.search.VectorScorer;
+import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
+
+/** Scalar quantized byte vector values */
+abstract class QuantizedByteVectorValues extends ByteVectorValues {
+
+ /**
+ * Retrieve the corrective terms for the given vector ordinal. For the dot-product family of
+ * distances, the corrective terms are, in order
+ *
+ * <ul>
+ *   <li>the lower optimized interval
+ *   <li>the upper optimized interval
+ *   <li>the dot-product of the non-centered vector with the centroid
+ *   <li>the sum of quantized components
+ * </ul>
+ *
+ * <p>For euclidean:
+ *
+ * <ul>
+ *   <li>the lower optimized interval
+ *   <li>the upper optimized interval
+ *   <li>the l2norm of the centered vector
+ *   <li>the sum of quantized components
+ * </ul>
+ *
+ * @param vectorOrd the vector ordinal
+ * @return the corrective terms
+ * @throws IOException if an I/O error occurs
+ */
+ public abstract OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int vectorOrd)
+ throws IOException;
+
+ /**
+ * @return the quantizer used to quantize the vectors
+ */
+ public abstract OptimizedScalarQuantizer getQuantizer();
+
+ /**
+ * @return the scalar encoding used to pack the vectors.
+ */
+ public abstract ScalarEncoding getScalarEncoding();
+
+ /**
+ * @return the centroid used to center the vectors prior to quantization
+ */
+ public abstract float[] getCentroid() throws IOException;
+
+ /**
+ * @return the dot product of the centroid with itself
+ */
+ public abstract float getCentroidDP() throws IOException;
+
+ /**
+ * Return a {@link VectorScorer} for the given query vector.
+ *
+ * @param query the query vector
+ * @return a {@link VectorScorer} instance or null
+ */
+ public abstract VectorScorer scorer(float[] query) throws IOException;
+
+ @Override
+ public abstract QuantizedByteVectorValues copy() throws IOException;
+}
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswScalarQuantizedVectorsFormat.java
index 1966ed21d654..42e9381aa11d 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswScalarQuantizedVectorsFormat.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswScalarQuantizedVectorsFormat.java
@@ -42,6 +42,7 @@
*
* @lucene.experimental
*/
+@Deprecated
public class Lucene99HnswScalarQuantizedVectorsFormat extends KnnVectorsFormat {
public static final String NAME = "Lucene99HnswScalarQuantizedVectorsFormat";
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java
index 0f339ecbe0a8..da18fb8d2e57 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java
@@ -31,6 +31,7 @@
*
* @lucene.experimental
*/
+@Deprecated
public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
// The bits that are allowed for scalar quantization
diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
index 38b9cf6d67a5..c27a7b6f5fdb 100644
--- a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
+++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
@@ -214,7 +214,7 @@ public static int int4DotProduct(byte[] a, byte[] b) {
public static int int4DotProductPacked(byte[] unpacked, byte[] packed) {
if (packed.length != ((unpacked.length + 1) >> 1)) {
throw new IllegalArgumentException(
- "vector dimensions differ: " + unpacked.length + "!= 2 * " + packed.length);
+ "vector dimensions differ: " + unpacked.length + " != 2 * " + packed.length);
}
return IMPL.int4DotProduct(unpacked, false, packed, true);
}
diff --git a/lucene/core/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat b/lucene/core/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat
index 0558fc8fef05..fde541c2ac08 100644
--- a/lucene/core/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat
+++ b/lucene/core/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat
@@ -18,3 +18,5 @@ org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat
org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat
org.apache.lucene.codecs.lucene102.Lucene102HnswBinaryQuantizedVectorsFormat
org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat
+org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat
+org.apache.lucene.codecs.lucene104.Lucene104HnswScalarQuantizedVectorsFormat
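
Reviewer note: with the registrations above the new formats are loadable via
SPI. A minimal usage sketch (hypothetical codec name "Lucene104Demo"; standard
org.apache.lucene.codecs and org.apache.lucene.index imports assumed) that
selects the HNSW variant for every vector field, mirroring the wiring in the
tests below:

    IndexWriterConfig iwc = new IndexWriterConfig();
    iwc.setCodec(
        new FilterCodec("Lucene104Demo", Codec.getDefault()) {
          @Override
          public KnnVectorsFormat knnVectorsFormat() {
            // default encoding; the other constructors shown in the tests
            // accept a ScalarEncoding and HNSW parameters
            return new Lucene104HnswScalarQuantizedVectorsFormat();
          }
        });
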
diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104HnswScalarQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104HnswScalarQuantizedVectorsFormat.java
new file mode 100644
index 000000000000..7f01f1bbe852
--- /dev/null
+++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104HnswScalarQuantizedVectorsFormat.java
@@ -0,0 +1,183 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.codecs.lucene104;
+
+import static java.lang.String.format;
+import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT;
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.oneOf;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Locale;
+import org.apache.lucene.codecs.Codec;
+import org.apache.lucene.codecs.FilterCodec;
+import org.apache.lucene.codecs.KnnVectorsFormat;
+import org.apache.lucene.codecs.KnnVectorsReader;
+import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding;
+import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
+import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.KnnFloatVectorField;
+import org.apache.lucene.index.CodecReader;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.FloatVectorValues;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.IndexWriter;
+import org.apache.lucene.index.KnnVectorValues;
+import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.AcceptDocs;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
+import org.apache.lucene.tests.util.TestUtil;
+import org.apache.lucene.util.SameThreadExecutorService;
+
+public class TestLucene104HnswScalarQuantizedVectorsFormat extends BaseKnnVectorsFormatTestCase {
+
+ private static final KnnVectorsFormat FORMAT = new Lucene104HnswScalarQuantizedVectorsFormat();
+
+ @Override
+ protected Codec getCodec() {
+ return TestUtil.alwaysKnnVectorsFormat(FORMAT);
+ }
+
+ public void testToString() {
+ FilterCodec customCodec =
+ new FilterCodec("foo", Codec.getDefault()) {
+ @Override
+ public KnnVectorsFormat knnVectorsFormat() {
+ return new Lucene104HnswScalarQuantizedVectorsFormat(
+ ScalarEncoding.UNSIGNED_BYTE, 10, 20, 1, null);
+ }
+ };
+ String expectedPattern =
+ "Lucene104HnswScalarQuantizedVectorsFormat(name=Lucene104HnswScalarQuantizedVectorsFormat, maxConn=10, beamWidth=20,"
+ + " flatVectorFormat=Lucene104ScalarQuantizedVectorsFormat(name=Lucene104ScalarQuantizedVectorsFormat,"
+ + " encoding=UNSIGNED_BYTE,"
+ + " flatVectorScorer=Lucene104ScalarQuantizedVectorScorer(nonQuantizedDelegate=%s()),"
+ + " rawVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=%s())))";
+
+ var defaultScorer =
+ format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer", "DefaultFlatVectorScorer");
+ var memSegScorer =
+ format(
+ Locale.ROOT,
+ expectedPattern,
+ "Lucene99MemorySegmentFlatVectorsScorer",
+ "Lucene99MemorySegmentFlatVectorsScorer");
+ assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer)));
+ }
+
+ public void testSingleVectorCase() throws Exception {
+ float[] vector = randomVector(random().nextInt(12, 500));
+ for (VectorSimilarityFunction similarityFunction : VectorSimilarityFunction.values()) {
+ try (Directory dir = newDirectory();
+ IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
+ Document doc = new Document();
+ doc.add(new KnnFloatVectorField("f", vector, similarityFunction));
+ w.addDocument(doc);
+ w.commit();
+ try (IndexReader reader = DirectoryReader.open(w)) {
+ LeafReader r = getOnlyLeafReader(reader);
+ FloatVectorValues vectorValues = r.getFloatVectorValues("f");
+ KnnVectorValues.DocIndexIterator docIndexIterator = vectorValues.iterator();
+ assertEquals(1, vectorValues.size());
+ while (docIndexIterator.nextDoc() != NO_MORE_DOCS) {
+ assertArrayEquals(vector, vectorValues.vectorValue(docIndexIterator.index()), 0.00001f);
+ }
+ float[] randomVector = randomVector(vector.length);
+ float trueScore = similarityFunction.compare(vector, randomVector);
+ TopDocs td =
+ r.searchNearestVectors(
+ "f",
+ randomVector,
+ 1,
+ AcceptDocs.fromLiveDocs(null, r.maxDoc()),
+ Integer.MAX_VALUE);
+ assertEquals(1, td.totalHits.value());
+ assertTrue(td.scoreDocs[0].score >= 0);
+ // When it's the only vector in a segment, the score should be very close to the true
+ // score
+ assertEquals(trueScore, td.scoreDocs[0].score, 0.01f);
+ }
+ }
+ }
+ }
+
+ public void testLimits() {
+ expectThrows(
+ IllegalArgumentException.class,
+ () -> new Lucene104HnswScalarQuantizedVectorsFormat(-1, 20));
+ expectThrows(
+ IllegalArgumentException.class, () -> new Lucene104HnswScalarQuantizedVectorsFormat(0, 20));
+ expectThrows(
+ IllegalArgumentException.class, () -> new Lucene104HnswScalarQuantizedVectorsFormat(20, 0));
+ expectThrows(
+ IllegalArgumentException.class,
+ () -> new Lucene104HnswScalarQuantizedVectorsFormat(20, -1));
+ expectThrows(
+ IllegalArgumentException.class,
+ () -> new Lucene104HnswScalarQuantizedVectorsFormat(512 + 1, 20));
+ expectThrows(
+ IllegalArgumentException.class,
+ () -> new Lucene104HnswScalarQuantizedVectorsFormat(20, 3201));
+ expectThrows(
+ IllegalArgumentException.class,
+ () ->
+ new Lucene104HnswScalarQuantizedVectorsFormat(
+ ScalarEncoding.UNSIGNED_BYTE, 20, 100, 1, new SameThreadExecutorService()));
+ }
+
+ // Ensures that all expected vector similarity functions are translatable in the format.
+ public void testVectorSimilarityFuncs() {
+ // This does not necessarily have to be all similarity functions, but
+ // differences should be considered carefully.
+ var expectedValues = Arrays.stream(VectorSimilarityFunction.values()).toList();
+ assertEquals(Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS, expectedValues);
+ }
+
+ public void testSimpleOffHeapSize() throws IOException {
+ float[] vector = randomVector(random().nextInt(12, 500));
+ try (Directory dir = newDirectory();
+ IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
+ Document doc = new Document();
+ doc.add(new KnnFloatVectorField("f", vector, DOT_PRODUCT));
+ w.addDocument(doc);
+ w.commit();
+ try (IndexReader reader = DirectoryReader.open(w)) {
+ LeafReader r = getOnlyLeafReader(reader);
+ if (r instanceof CodecReader codecReader) {
+ KnnVectorsReader knnVectorsReader = codecReader.getVectorReader();
+ if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader fieldsReader) {
+ knnVectorsReader = fieldsReader.getFieldReader("f");
+ }
+ var fieldInfo = r.getFieldInfos().fieldInfo("f");
+ var offHeap = knnVectorsReader.getOffHeapByteSize(fieldInfo);
+ assertEquals(vector.length * Float.BYTES, (long) offHeap.get("vec"));
+ assertEquals(1L, (long) offHeap.get("vex"));
+ long corrections = Float.BYTES + Float.BYTES + Float.BYTES + Integer.BYTES;
+ long expected = fieldInfo.getVectorDimension() + corrections;
+ assertEquals(expected, (long) offHeap.get("veq"));
+ assertEquals(3, offHeap.size());
+ }
+ }
+ }
+ }
+}
diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java
new file mode 100644
index 000000000000..29041b5b07f0
--- /dev/null
+++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.codecs.lucene104;
+
+import static java.lang.String.format;
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.oneOf;
+
+import java.io.IOException;
+import java.util.Locale;
+import org.apache.lucene.codecs.Codec;
+import org.apache.lucene.codecs.FilterCodec;
+import org.apache.lucene.codecs.KnnVectorsFormat;
+import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.KnnFloatVectorField;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.FloatVectorValues;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.IndexWriter;
+import org.apache.lucene.index.IndexWriterConfig;
+import org.apache.lucene.index.KnnVectorValues;
+import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.KnnFloatVectorQuery;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.TotalHits;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
+import org.apache.lucene.tests.util.TestUtil;
+import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
+import org.junit.Before;
+
+public class TestLucene104ScalarQuantizedVectorsFormat extends BaseKnnVectorsFormatTestCase {
+
+ private ScalarEncoding encoding;
+ private KnnVectorsFormat format;
+
+ @Before
+ @Override
+ public void setUp() throws Exception {
+ var encodingValues = ScalarEncoding.values();
+ encoding = encodingValues[random().nextInt(encodingValues.length)];
+ format = new Lucene104ScalarQuantizedVectorsFormat(encoding);
+ super.setUp();
+ }
+
+ @Override
+ protected Codec getCodec() {
+ return TestUtil.alwaysKnnVectorsFormat(format);
+ }
+
+ public void testSearch() throws Exception {
+ String fieldName = "field";
+ int numVectors = random().nextInt(99, 500);
+ int dims = random().nextInt(4, 65);
+ float[] vector = randomVector(dims);
+ VectorSimilarityFunction similarityFunction = randomSimilarity();
+ KnnFloatVectorField knnField = new KnnFloatVectorField(fieldName, vector, similarityFunction);
+ IndexWriterConfig iwc = newIndexWriterConfig();
+ try (Directory dir = newDirectory()) {
+ try (IndexWriter w = new IndexWriter(dir, iwc)) {
+ for (int i = 0; i < numVectors; i++) {
+ Document doc = new Document();
+ knnField.setVectorValue(randomVector(dims));
+ doc.add(knnField);
+ w.addDocument(doc);
+ }
+ w.commit();
+
+ try (IndexReader reader = DirectoryReader.open(w)) {
+ IndexSearcher searcher = new IndexSearcher(reader);
+ final int k = random().nextInt(5, 50);
+ float[] queryVector = randomVector(dims);
+ Query q = new KnnFloatVectorQuery(fieldName, queryVector, k);
+ TopDocs collectedDocs = searcher.search(q, k);
+ assertEquals(k, collectedDocs.totalHits.value());
+ assertEquals(TotalHits.Relation.EQUAL_TO, collectedDocs.totalHits.relation());
+ }
+ }
+ }
+ }
+
+ public void testToString() {
+ FilterCodec customCodec =
+ new FilterCodec("foo", Codec.getDefault()) {
+ @Override
+ public KnnVectorsFormat knnVectorsFormat() {
+ return new Lucene104ScalarQuantizedVectorsFormat();
+ }
+ };
+ String expectedPattern =
+ "Lucene104ScalarQuantizedVectorsFormat("
+ + "name=Lucene104ScalarQuantizedVectorsFormat, "
+ + "encoding=UNSIGNED_BYTE, "
+ + "flatVectorScorer=Lucene104ScalarQuantizedVectorScorer(nonQuantizedDelegate=%s()), "
+ + "rawVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=%s()))";
+ var defaultScorer =
+ format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer", "DefaultFlatVectorScorer");
+ var memSegScorer =
+ format(
+ Locale.ROOT,
+ expectedPattern,
+ "Lucene99MemorySegmentFlatVectorsScorer",
+ "Lucene99MemorySegmentFlatVectorsScorer");
+ assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer)));
+ }
+
+ @Override
+ public void testRandomWithUpdatesAndGraph() {
+ // graph not supported
+ }
+
+ @Override
+ public void testSearchWithVisitedLimit() {
+ // visited limit is not respected, as it is brute force search
+ }
+
+ public void testQuantizedVectorsWriteAndRead() throws IOException {
+ String fieldName = "field";
+ int numVectors = random().nextInt(99, 500);
+ int dims = random().nextInt(4, 65);
+
+ float[] vector = randomVector(dims);
+ VectorSimilarityFunction similarityFunction = randomSimilarity();
+ KnnFloatVectorField knnField = new KnnFloatVectorField(fieldName, vector, similarityFunction);
+ try (Directory dir = newDirectory()) {
+ try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
+ for (int i = 0; i < numVectors; i++) {
+ Document doc = new Document();
+ knnField.setVectorValue(randomVector(dims));
+ doc.add(knnField);
+ w.addDocument(doc);
+ if (i % 101 == 0) {
+ w.commit();
+ }
+ }
+ w.commit();
+ w.forceMerge(1);
+
+ try (IndexReader reader = DirectoryReader.open(w)) {
+ LeafReader r = getOnlyLeafReader(reader);
+ FloatVectorValues vectorValues = r.getFloatVectorValues(fieldName);
+ assertEquals(vectorValues.size(), numVectors);
+ QuantizedByteVectorValues qvectorValues =
+ ((Lucene104ScalarQuantizedVectorsReader.ScalarQuantizedVectorValues) vectorValues)
+ .getQuantizedVectorValues();
+ float[] centroid = qvectorValues.getCentroid();
+ assertEquals(centroid.length, dims);
+
+ OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(similarityFunction);
+ byte[] scratch =
+ new byte[OptimizedScalarQuantizer.discretize(dims, encoding.getDimensionsPerByte())];
+ byte[] expectedVector = new byte[encoding.getPackedLength(dims)];
+ if (similarityFunction == VectorSimilarityFunction.COSINE) {
+ vectorValues =
+ new Lucene104ScalarQuantizedVectorsWriter.NormalizedFloatVectorValues(vectorValues);
+ }
+ KnnVectorValues.DocIndexIterator docIndexIterator = vectorValues.iterator();
+
+ while (docIndexIterator.nextDoc() != NO_MORE_DOCS) {
+ OptimizedScalarQuantizer.QuantizationResult corrections =
+ quantizer.scalarQuantize(
+ vectorValues.vectorValue(docIndexIterator.index()),
+ scratch,
+ encoding.getBits(),
+ centroid);
+ switch (encoding) {
+ case UNSIGNED_BYTE -> System.arraycopy(scratch, 0, expectedVector, 0, dims);
+ case SEVEN_BIT -> System.arraycopy(scratch, 0, expectedVector, 0, dims);
+ case PACKED_NIBBLE ->
+ OffHeapScalarQuantizedVectorValues.packNibbles(scratch, expectedVector);
+ }
+ assertArrayEquals(expectedVector, qvectorValues.vectorValue(docIndexIterator.index()));
+ var actualCorrections = qvectorValues.getCorrectiveTerms(docIndexIterator.index());
+ assertEquals(corrections.lowerInterval(), actualCorrections.lowerInterval(), 0.00001f);
+ assertEquals(corrections.upperInterval(), actualCorrections.upperInterval(), 0.00001f);
+ assertEquals(
+ corrections.additionalCorrection(),
+ actualCorrections.additionalCorrection(),
+ 0.00001f);
+ assertEquals(
+ corrections.quantizedComponentSum(), actualCorrections.quantizedComponentSum());
+ }
+ }
+ }
+ }
+ }
+}