diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactStorageAdapter.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactStorageAdapter.java index d14bee5368..b03c296f67 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactStorageAdapter.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactStorageAdapter.java @@ -86,7 +86,8 @@ public CompactStorageAdapter(@Nonnull final Config config, * @param layer the layer of the node to fetch * @param primaryKey the primary key of the node to fetch * - * @return a future that will complete with the fetched {@link AbstractNode} + * @return a future that will complete with the fetched {@link AbstractNode} or {@code null} if the node cannot + * be fetched * * @throws IllegalStateException if the node cannot be found in the database for the given key */ @@ -101,7 +102,7 @@ protected CompletableFuture> fetchNodeInternal(@Nonn return readTransaction.get(keyBytes) .thenApply(valueBytes -> { if (valueBytes == null) { - throw new IllegalStateException("cannot fetch node"); + return null; } return nodeFromRaw(storageTransform, layer, primaryKey, keyBytes, valueBytes); }); diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Config.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Config.java index eda5a0c17d..efa2d3181b 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Config.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Config.java @@ -21,6 +21,7 @@ package com.apple.foundationdb.async.hnsw; import com.apple.foundationdb.linear.Metric; +import com.google.common.base.Preconditions; import com.google.errorprone.annotations.CanIgnoreReturnValue; import javax.annotation.Nonnull; @@ -31,7 +32,7 @@ */ @SuppressWarnings("checkstyle:MemberName") public final class Config { - public static final long DEFAULT_RANDOM_SEED = 0L; + public 
static final boolean DEFAULT_DETERMINISTIC_SEEDING = false; @Nonnull public static final Metric DEFAULT_METRIC = Metric.EUCLIDEAN_METRIC; public static final boolean DEFAULT_USE_INLINING = false; public static final int DEFAULT_M = 16; @@ -52,127 +53,55 @@ public final class Config { public static final int DEFAULT_MAX_NUM_CONCURRENT_NODE_FETCHES = 16; public static final int DEFAULT_MAX_NUM_CONCURRENT_NEIGHBOR_FETCHES = 16; - /** - * The random seed that is used to probabilistically determine the highest layer of an insert. - */ - private final long randomSeed; - - /** - * The metric that is used to determine distances between vectors. - */ + private final boolean deterministicSeeding; @Nonnull private final Metric metric; - - /** - * The number of dimensions used. All vectors must have exactly this number of dimensions. - */ private final int numDimensions; - - /** - * Indicator if all layers except layer {@code 0} use inlining. If inlining is used, each node is persisted - * as a key/value pair per neighbor which includes the vectors of the neighbors but not for itself. If inlining is - * not used, each node is persisted as exactly one key/value pair per node which stores its own vector but - * specifically excludes the vectors of the neighbors. - */ private final boolean useInlining; - - /** - * This attribute (named {@code M} by the HNSW paper) is the connectivity value for all nodes stored on any layer. - * While by no means enforced or even enforceable, we strive to create and maintain exactly {@code m} neighbors for - * a node. Due to insert/delete operations it is possible that the actual number of neighbors a node references is - * not exactly {@code m} at any given time. - */ private final int m; - - /** - * This attribute (named {@code M_max} by the HNSW paper) is the maximum connectivity value for nodes stored on a - * layer greater than {@code 0}. We will never create more that {@code mMax} neighbors for a node. 
That means that - * we even prune the neighbors of a node if the actual number of neighbors would otherwise exceed {@code mMax}. - */ private final int mMax; - - /** - * This attribute (named {@code M_max0} by the HNSW paper) is the maximum connectivity value for nodes stored on - * layer {@code 0}. We will never create more that {@code mMax0} neighbors for a node that is stored on that layer. - * That means that we even prune the neighbors of a node if the actual number of neighbors would otherwise exceed - * {@code mMax0}. - */ private final int mMax0; - - /** - * Maximum size of the search queues (on independent queue per layer) that are used during the insertion of a new - * node. If {@code efConstruction} is set to {@code 1}, the search naturally follows a greedy approach - * (monotonous descent), whereas a high number for {@code efConstruction} allows for a more nuanced search that can - * tolerate (false) local minima. - */ private final int efConstruction; - - /** - * Indicator to signal if, during the insertion of a node, the set of nearest neighbors of that node is to be - * extended by the actual neighbors of those neighbors to form a set of candidates that the new node may be - * connected to during the insert operation. - */ private final boolean extendCandidates; - - /** - * Indicator to signal if, during the insertion of a node, candidates that have been discarded due to not satisfying - * the select-neighbor heuristic may get added back in to pad the set of neighbors if the new node would otherwise - * have too few neighbors (see {@link #m}). - */ private final boolean keepPrunedConnections; - - /** - * If sampling is necessary (currently iff {@link #isUseRaBitQ()} is {@code true}), this attribute represents the - * probability of a vector being inserted to also be written into the - * {@link StorageAdapter#SUBSPACE_PREFIX_SAMPLES} subspace. 
The vectors in that subspace are continuously aggregated - * until a total {@link #statsThreshold} has been reached. - */ private final double sampleVectorStatsProbability; - - /** - * If sampling is necessary (currently iff {@link #isUseRaBitQ()} is {@code true}), this attribute represents the - * probability of the {@link StorageAdapter#SUBSPACE_PREFIX_SAMPLES} subspace to be further aggregated (rolled-up) - * when a new vector is inserted. The vectors in that subspace are continuously aggregated until a total - * {@link #statsThreshold} has been reached. - */ private final double maintainStatsProbability; - - /** - * If sampling is necessary (currently iff {@link #isUseRaBitQ()} is {@code true}), this attribute represents the - * threshold (being a number of vectors) that when reached causes the stats maintenance logic to compute the actual - * statistics (currently the centroid of the vectors that have been inserted to far). - */ private final int statsThreshold; - - /** - * Indicator if we should RaBitQ quantization. See {@link com.apple.foundationdb.rabitq.RaBitQuantizer} for more - * details. - */ private final boolean useRaBitQ; - - /** - * Number of bits per dimensions iff {@link #isUseRaBitQ()} is set to {@code true}, ignored otherwise. If RaBitQ - * encoding is used, a vector is stored using roughly {@code 25 + numDimensions * (numExBits + 1) / 8} bytes. - */ private final int raBitQNumExBits; - - /** - * Maximum number of concurrent node fetches during search and modification operations. - */ private final int maxNumConcurrentNodeFetches; - - /** - * Maximum number of concurrent neighborhood fetches during modification operations when the neighbors are pruned. 
- */ private final int maxNumConcurrentNeighborhoodFetches; - private Config(final long randomSeed, @Nonnull final Metric metric, final int numDimensions, + private Config(final boolean deterministicSeeding, @Nonnull final Metric metric, final int numDimensions, final boolean useInlining, final int m, final int mMax, final int mMax0, final int efConstruction, final boolean extendCandidates, final boolean keepPrunedConnections, final double sampleVectorStatsProbability, final double maintainStatsProbability, final int statsThreshold, final boolean useRaBitQ, final int raBitQNumExBits, final int maxNumConcurrentNodeFetches, final int maxNumConcurrentNeighborhoodFetches) { - this.randomSeed = randomSeed; + Preconditions.checkArgument(numDimensions >= 1, "numDimensions must be (1, MAX_INT]"); + Preconditions.checkArgument(m >= 4 && m <= 200, "m must be [4, 200]"); + Preconditions.checkArgument(mMax >= 4 && mMax <= 200, "mMax must be [4, 200]"); + Preconditions.checkArgument(mMax0 >= 4 && mMax0 <= 300, "mMax0 must be [4, 300]"); + Preconditions.checkArgument(m <= mMax, "m must be less than or equal to mMax"); + Preconditions.checkArgument(mMax <= mMax0, "mMax must be less than or equal to mMax0"); + Preconditions.checkArgument(efConstruction >= 100 && efConstruction <= 400, + "efConstruction must be [100, 400]"); + Preconditions.checkArgument(!useRaBitQ || + (sampleVectorStatsProbability > 0.0d && sampleVectorStatsProbability <= 1.0d), + "sampleVectorStatsProbability out of range"); + Preconditions.checkArgument(!useRaBitQ || + (maintainStatsProbability > 0.0d && maintainStatsProbability <= 1.0d), + "maintainStatsProbability out of range"); + Preconditions.checkArgument(!useRaBitQ || statsThreshold > 10, "statThreshold out of range"); + Preconditions.checkArgument(!useRaBitQ || (raBitQNumExBits > 0 && raBitQNumExBits < 16), + "raBitQNumExBits out of range"); + Preconditions.checkArgument(maxNumConcurrentNodeFetches > 0 && maxNumConcurrentNodeFetches <= 64, + 
"maxNumConcurrentNodeFetches must be (0, 64]"); + Preconditions.checkArgument(maxNumConcurrentNeighborhoodFetches > 0 && + maxNumConcurrentNeighborhoodFetches <= 64, + "maxNumConcurrentNeighborhoodFetches must be (0, 64]"); + + this.deterministicSeeding = deterministicSeeding; this.metric = metric; this.numDimensions = numDimensions; this.useInlining = useInlining; @@ -191,78 +120,184 @@ private Config(final long randomSeed, @Nonnull final Metric metric, final int nu this.maxNumConcurrentNeighborhoodFetches = maxNumConcurrentNeighborhoodFetches; } - public long getRandomSeed() { - return randomSeed; + /** + * Indicator that if {@code true} causes the insert logic of the HNSW to be seeded using a hash of the primary key + * of the record that is inserted. That can be useful for testing. If {@code isDeterministicSeeding} is + * {@code false}, we use {@link System#nanoTime()} for seeding. + */ + public boolean isDeterministicSeeding() { + return deterministicSeeding; } + /** + * The metric that is used to determine distances between vectors. + */ @Nonnull public Metric getMetric() { return metric; } + /** + * The number of dimensions used. All vectors must have exactly this number of dimensions. + */ public int getNumDimensions() { return numDimensions; } + /** + * Indicator if all layers except layer {@code 0} use inlining. One entire layer is fully managed using either + * the compact or the inlining layout. If inlining is used, each node is persisted as a key/value pair per neighbor + * which includes the vectors of the neighbors but not the vector for itself. If inlining is not used, and therefore + * the compact layout is used instead, each node is persisted as exactly one key/value pair per node which stores + * its own vector but specifically excludes the vectors of the neighbors. + *

+ * If a layer uses the compact storage layout, each vector is stored with the node and therefore is stored exactly + * once. During a nearest neighbor search, a fetch of the neighborhood of a node incurs a fetch (get) for each of + * the neighbors of that node. + *

+ * If a layer uses the inlining storage layout, a vector of a node is stored with the neighboring information of an + * adjacent node pointing to this node and is therefore stored multiple times (once per neighbor). During a + * nearest neighbor search, however, the neighboring vectors of a vector can all be fetched in one range scan. + *

+ * Choosing which storage format is right for the use case depends on some factors: + *

+ */ public boolean isUseInlining() { return useInlining; } + /** + * This attribute (named {@code M} by the HNSW paper) is the connectivity value for all nodes stored on any layer. + * While by no means enforced or even enforceable, we strive to create and maintain exactly {@code m} neighbors for + * a node. Due to insert/delete operations it is possible that the actual number of neighbors a node references is + * not exactly {@code m} at any given time. + */ public int getM() { return m; } + /** + * This attribute (named {@code M_max} by the HNSW paper) is the maximum connectivity value for nodes stored on a + * layer greater than {@code 0}. We will never create more that {@code mMax} neighbors for a node. That means that + * we even prune the neighbors of a node if the actual number of neighbors would otherwise exceed {@code mMax}. + * Note that this attribute must be greater than or equal to {@link #m}. + */ public int getMMax() { return mMax; } + /** + * This attribute (named {@code M_max0} by the HNSW paper) is the maximum connectivity value for nodes stored on + * layer {@code 0}. We will never create more that {@code mMax0} neighbors for a node that is stored on that layer. + * That means that we even prune the neighbors of a node if the actual number of neighbors would otherwise exceed + * {@code mMax0}. Note that this attribute must be greater than or equal to {@link #mMax}. + */ public int getMMax0() { return mMax0; } + /** + * Maximum size of the search queues (one independent queue per layer) that are used during the insertion of a new + * node. If {@code efConstruction} is set to {@code 1}, the search naturally follows a greedy approach + * (monotonous descent), whereas a high number for {@code efConstruction} allows for a more nuanced search that can + * tolerate (false) local minima. 
+ */ public int getEfConstruction() { return efConstruction; } + /** + * Indicator to signal if, during the insertion of a node, the set of nearest neighbors of that node is to be + * extended by the actual neighbors of those neighbors to form a set of candidates that the new node may be + * connected to during the insert operation. + */ public boolean isExtendCandidates() { return extendCandidates; } + /** + * Indicator to signal if, during the insertion of a node, candidates that have been discarded due to not satisfying + * the select-neighbor heuristic may get added back in to pad the set of neighbors if the new node would otherwise + * have too few neighbors (see {@link #m}). + */ public boolean isKeepPrunedConnections() { return keepPrunedConnections; } + /** + * If sampling is necessary (currently iff {@link #isUseRaBitQ()} is {@code true}), this attribute represents the + * probability of a vector being inserted to also be written into the + * {@link StorageAdapter#SUBSPACE_PREFIX_SAMPLES} subspace. The vectors in that subspace are continuously aggregated + * until a total {@link #statsThreshold} has been reached. + */ public double getSampleVectorStatsProbability() { return sampleVectorStatsProbability; } + /** + * If sampling is necessary (currently iff {@link #isUseRaBitQ()} is {@code true}), this attribute represents the + * probability of the {@link StorageAdapter#SUBSPACE_PREFIX_SAMPLES} subspace to be further aggregated (rolled-up) + * when a new vector is inserted. The vectors in that subspace are continuously aggregated until a total + * {@link #statsThreshold} has been reached. 
+ */ public double getMaintainStatsProbability() { return maintainStatsProbability; } + /** + * If sampling is necessary (currently iff {@link #isUseRaBitQ()} is {@code true}), this attribute represents the + * threshold (being a number of vectors) that when reached causes the stats maintenance logic to compute the actual + * statistics (currently the centroid of the vectors that have been inserted to far). + */ public int getStatsThreshold() { return statsThreshold; } + /** + * Indicator if we should RaBitQ quantization. See {@link com.apple.foundationdb.rabitq.RaBitQuantizer} for more + * details. + */ public boolean isUseRaBitQ() { return useRaBitQ; } + /** + * Number of bits per dimensions iff {@link #isUseRaBitQ()} is set to {@code true}, ignored otherwise. If RaBitQ + * encoding is used, a vector is stored using roughly {@code 25 + numDimensions * (numExBits + 1) / 8} bytes. + */ public int getRaBitQNumExBits() { return raBitQNumExBits; } + /** + * Maximum number of concurrent node fetches during search and modification operations. + */ public int getMaxNumConcurrentNodeFetches() { return maxNumConcurrentNodeFetches; } + /** + * Maximum number of concurrent neighborhood fetches during modification operations when the neighbors are pruned. 
+ */ public int getMaxNumConcurrentNeighborhoodFetches() { return maxNumConcurrentNeighborhoodFetches; } @Nonnull public ConfigBuilder toBuilder() { - return new ConfigBuilder(getRandomSeed(), getMetric(), isUseInlining(), getM(), getMMax(), getMMax0(), + return new ConfigBuilder(isDeterministicSeeding(), getMetric(), isUseInlining(), getM(), getMMax(), getMMax0(), getEfConstruction(), isExtendCandidates(), isKeepPrunedConnections(), getSampleVectorStatsProbability(), getMaintainStatsProbability(), getStatsThreshold(), isUseRaBitQ(), getRaBitQNumExBits(), getMaxNumConcurrentNodeFetches(), @@ -278,7 +313,7 @@ public boolean equals(final Object o) { return false; } final Config config = (Config)o; - return randomSeed == config.randomSeed && numDimensions == config.numDimensions && + return deterministicSeeding == config.deterministicSeeding && numDimensions == config.numDimensions && useInlining == config.useInlining && m == config.m && mMax == config.mMax && mMax0 == config.mMax0 && efConstruction == config.efConstruction && extendCandidates == config.extendCandidates && keepPrunedConnections == config.keepPrunedConnections && @@ -292,7 +327,7 @@ public boolean equals(final Object o) { @Override public int hashCode() { - return Objects.hash(randomSeed, metric, numDimensions, useInlining, m, mMax, mMax0, efConstruction, + return Objects.hash(deterministicSeeding, metric, numDimensions, useInlining, m, mMax, mMax0, efConstruction, extendCandidates, keepPrunedConnections, sampleVectorStatsProbability, maintainStatsProbability, statsThreshold, useRaBitQ, raBitQNumExBits, maxNumConcurrentNodeFetches, maxNumConcurrentNeighborhoodFetches); } @@ -300,7 +335,7 @@ public int hashCode() { @Override @Nonnull public String toString() { - return "Config[randomSeed=" + getRandomSeed() + ", metric=" + getMetric() + + return "Config[deterministicSeeding=" + isDeterministicSeeding() + ", metric=" + getMetric() + ", numDimensions=" + getNumDimensions() + ", isUseInlining=" + 
isUseInlining() + ", M=" + getM() + ", MMax=" + getMMax() + ", MMax0=" + getMMax0() + ", efConstruction=" + getEfConstruction() + ", isExtendCandidates=" + isExtendCandidates() + @@ -321,7 +356,7 @@ public String toString() { @CanIgnoreReturnValue @SuppressWarnings("checkstyle:MemberName") public static class ConfigBuilder { - private long randomSeed = DEFAULT_RANDOM_SEED; + private boolean deterministicSeeding = DEFAULT_DETERMINISTIC_SEEDING; @Nonnull private Metric metric = DEFAULT_METRIC; private boolean useInlining = DEFAULT_USE_INLINING; @@ -345,13 +380,13 @@ public static class ConfigBuilder { public ConfigBuilder() { } - public ConfigBuilder(final long randomSeed, @Nonnull final Metric metric, final boolean useInlining, + public ConfigBuilder(final boolean deterministicSeeding, @Nonnull final Metric metric, final boolean useInlining, final int m, final int mMax, final int mMax0, final int efConstruction, final boolean extendCandidates, final boolean keepPrunedConnections, final double sampleVectorStatsProbability, final double maintainStatsProbability, final int statsThreshold, final boolean useRaBitQ, final int raBitQNumExBits, final int maxNumConcurrentNodeFetches, final int maxNumConcurrentNeighborhoodFetches) { - this.randomSeed = randomSeed; + this.deterministicSeeding = deterministicSeeding; this.metric = metric; this.useInlining = useInlining; this.m = m; @@ -369,13 +404,13 @@ public ConfigBuilder(final long randomSeed, @Nonnull final Metric metric, final this.maxNumConcurrentNeighborhoodFetches = maxNumConcurrentNeighborhoodFetches; } - public long getRandomSeed() { - return randomSeed; + public boolean isDeterministicSeeding() { + return deterministicSeeding; } @Nonnull - public ConfigBuilder setRandomSeed(final long randomSeed) { - this.randomSeed = randomSeed; + public ConfigBuilder setDeterministicSeeding(final boolean deterministicSeeding) { + this.deterministicSeeding = deterministicSeeding; return this; } @@ -529,7 +564,7 @@ public 
ConfigBuilder setMaxNumConcurrentNeighborhoodFetches(final int maxNumConc } public Config build(final int numDimensions) { - return new Config(getRandomSeed(), getMetric(), numDimensions, isUseInlining(), getM(), getMMax(), + return new Config(isDeterministicSeeding(), getMetric(), numDimensions, isUseInlining(), getM(), getMMax(), getMMax0(), getEfConstruction(), isExtendCandidates(), isKeepPrunedConnections(), getSampleVectorStatsProbability(), getMaintainStatsProbability(), getStatsThreshold(), isUseRaBitQ(), getRaBitQNumExBits(), getMaxNumConcurrentNodeFetches(), diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java index 2e7816b530..639cd2f273 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java @@ -55,8 +55,8 @@ import java.util.Objects; import java.util.PriorityQueue; import java.util.Queue; -import java.util.Random; import java.util.Set; +import java.util.SplittableRandom; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicReference; @@ -89,8 +89,6 @@ public class HNSW { @Nonnull private static final Logger logger = LoggerFactory.getLogger(HNSW.class); - @Nonnull - private final Random random; @Nonnull private final Subspace subspace; @Nonnull @@ -141,7 +139,6 @@ public HNSW(@Nonnull final Subspace subspace, @Nonnull final Config config, @Nonnull final OnWriteListener onWriteListener, @Nonnull final OnReadListener onReadListener) { - this.random = new Random(config.getRandomSeed()); this.subspace = subspace; this.executor = executor; this.config = config; @@ -581,7 +578,7 @@ private Quantizer quantizer(@Nullable final AccessInfo accessInfo) { return onReadListener.onAsyncRead( storageAdapter.fetchNode(readTransaction, storageTransform, layer, 
nodeReference.getPrimaryKey())) - .thenApply(node -> biMapFunction.apply(nodeReference, node)); + .thenApply(node -> biMapFunction.apply(nodeReference, Objects.requireNonNull(node))); } /** @@ -748,19 +745,35 @@ private Quantizer quantizer(@Nullable final AccessInfo accessInfo) { @Nonnull public CompletableFuture insert(@Nonnull final Transaction transaction, @Nonnull final Tuple newPrimaryKey, @Nonnull final RealVector newVector) { - final int insertionLayer = insertionLayer(); + final SplittableRandom random = random(newPrimaryKey); + final int insertionLayer = insertionLayer(random); if (logger.isTraceEnabled()) { logger.trace("new node with key={} selected to be inserted into layer={}", newPrimaryKey, insertionLayer); } return StorageAdapter.fetchAccessInfo(getConfig(), transaction, getSubspace(), getOnReadListener()) - .thenCompose(accessInfo -> { - final AccessInfo currentAccessInfo; + .thenCombine(exists(transaction, newPrimaryKey), + (accessInfo, nodeAlreadyExists) -> { + if (nodeAlreadyExists) { + if (logger.isDebugEnabled()) { + logger.debug("new record already exists in HNSW with key={} on layer={}", + newPrimaryKey, insertionLayer); + } + } + return new AccessInfoAndNodeExistence(accessInfo, nodeAlreadyExists); + }) + .thenCompose(accessInfoAndNodeExistence -> { + if (accessInfoAndNodeExistence.isNodeExists()) { + return AsyncUtil.DONE; + } + + final AccessInfo accessInfo = accessInfoAndNodeExistence.getAccessInfo(); final AffineOperator storageTransform = storageTransform(accessInfo); final Transformed transformedNewVector = storageTransform.transform(newVector); final Quantizer quantizer = quantizer(accessInfo); final Estimator estimator = quantizer.estimator(); + final AccessInfo currentAccessInfo; if (accessInfo == null) { // this is the first node writeLonelyNodes(quantizer, transaction, newPrimaryKey, transformedNewVector, @@ -817,10 +830,24 @@ public CompletableFuture insert(@Nonnull final Transaction transaction, @N insertIntoLayers(transaction, 
storageTransform, quantizer, newPrimaryKey, transformedNewVector, nodeReference, lMax, insertionLayer)) .thenCompose(ignored -> - addToStatsIfNecessary(transaction, currentAccessInfo, transformedNewVector)); + addToStatsIfNecessary(random.split(), transaction, currentAccessInfo, transformedNewVector)); }).thenCompose(ignored -> AsyncUtil.DONE); } + @Nonnull + @VisibleForTesting + CompletableFuture exists(@Nonnull final ReadTransaction readTransaction, + @Nonnull final Tuple primaryKey) { + final StorageAdapter storageAdapter = getStorageAdapterForLayer(0); + + // + // Call fetchNode() to check for the node's existence; we are handing in the identity operator, since we don't + // care about the vector itself at all. + // + return storageAdapter.fetchNode(readTransaction, AffineOperator.identity(), 0, primaryKey) + .thenApply(Objects::nonNull); + } + /** * Method to keep stats if necessary. Stats need to be kept and maintained when the client would like to use * e.g. RaBitQ as RaBitQ needs a stable somewhat correct centroid in order to function properly. @@ -832,21 +859,23 @@ public CompletableFuture insert(@Nonnull final Transaction transaction, @N * in order to finally compute the centroid if {@link Config#getStatsThreshold()} number of vectors have been * sampled and aggregated. That centroid is then used to update the access info. 
* + * @param random a random to use * @param transaction the transaction * @param currentAccessInfo this current access info that was fetched as part of an insert * @param transformedNewVector the new vector (in the transformed coordinate system) that may be added * @return a future that returns {@code null} when completed */ @Nonnull - private CompletableFuture addToStatsIfNecessary(@Nonnull final Transaction transaction, + private CompletableFuture addToStatsIfNecessary(@Nonnull final SplittableRandom random, + @Nonnull final Transaction transaction, @Nonnull final AccessInfo currentAccessInfo, @Nonnull final Transformed transformedNewVector) { if (getConfig().isUseRaBitQ() && !currentAccessInfo.canUseRaBitQ()) { - if (shouldSampleVector()) { + if (shouldSampleVector(random)) { StorageAdapter.appendSampledVector(transaction, getSubspace(), 1, transformedNewVector, onWriteListener); } - if (shouldMaintainStats()) { + if (shouldMaintainStats(random)) { return StorageAdapter.consumeSampledVectors(transaction, getSubspace(), 50, onReadListener) .thenApply(sampledVectors -> { @@ -1512,6 +1541,15 @@ private StorageAdapter getStorageAdapterForLayer(final getOnReadListener()); } + @Nonnull + private SplittableRandom random(@Nonnull final Tuple primaryKey) { + if (config.isDeterministicSeeding()) { + return new SplittableRandom(primaryKey.hashCode()); + } else { + return new SplittableRandom(System.nanoTime()); + } + } + /** * Calculates a random layer for a new element to be inserted. *

@@ -1521,20 +1559,20 @@ private StorageAdapter getStorageAdapterForLayer(final * is {@code floor(-ln(u) * lambda)}, where {@code u} is a uniform random * number and {@code lambda} is a normalization factor derived from a system * configuration parameter {@code M}. - * + * @param random a random to use * @return a non-negative integer representing the randomly selected layer. */ - private int insertionLayer() { + private int insertionLayer(@Nonnull final SplittableRandom random) { double lambda = 1.0 / Math.log(getConfig().getM()); double u = 1.0 - random.nextDouble(); // Avoid log(0) return (int) Math.floor(-Math.log(u) * lambda); } - private boolean shouldSampleVector() { + private boolean shouldSampleVector(@Nonnull final SplittableRandom random) { return random.nextDouble() < getConfig().getSampleVectorStatsProbability(); } - private boolean shouldMaintainStats() { + private boolean shouldMaintainStats(@Nonnull final SplittableRandom random) { return random.nextDouble() < getConfig().getMaintainStatsProbability(); } @@ -1546,4 +1584,24 @@ private static List drain(@Nonnull Queue queue) { } return resultBuilder.build(); } + + private static class AccessInfoAndNodeExistence { + @Nullable + private final AccessInfo accessInfo; + private final boolean nodeExists; + + public AccessInfoAndNodeExistence(@Nullable final AccessInfo accessInfo, final boolean nodeExists) { + this.accessInfo = accessInfo; + this.nodeExists = nodeExists; + } + + @Nullable + public AccessInfo getAccessInfo() { + return accessInfo; + } + + public boolean isNodeExists() { + return nodeExists; + } + } } diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/StorageAdapter.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/StorageAdapter.java index 9bffaacc67..4c7296ad45 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/StorageAdapter.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/StorageAdapter.java @@ 
-29,8 +29,6 @@ import com.apple.foundationdb.async.AsyncUtil; import com.apple.foundationdb.linear.AffineOperator; import com.apple.foundationdb.linear.DoubleRealVector; -import com.apple.foundationdb.linear.FloatRealVector; -import com.apple.foundationdb.linear.HalfRealVector; import com.apple.foundationdb.linear.Quantizer; import com.apple.foundationdb.linear.RealVector; import com.apple.foundationdb.linear.Transformed; @@ -59,7 +57,6 @@ * @param the type of {@link NodeReference} this storage adapter manages */ interface StorageAdapter { - ImmutableList VECTOR_TYPES = ImmutableList.copyOf(VectorType.values()); /** * Subspace for data. @@ -199,29 +196,24 @@ static RealVector vectorFromTuple(@Nonnull final Config config, @Nonnull final T /** * Creates a {@link RealVector} from a byte array. *

- This method interprets the input byte array by interpreting the first byte of the array as the precision shift. - * The byte array must have the proper size, i.e. the invariant {@code (bytesLength - 1) % precision == 0} must - * hold. + * This method interprets the input byte array by interpreting the first byte of the array. + * It then delegates to {@link RealVector#fromBytes(VectorType, byte[])}. * @param config an HNSW config * @param vectorBytes the non-null byte array to convert. * @return a new {@link RealVector} instance created from the byte array. - * @throws com.google.common.base.VerifyException if the length of {@code vectorBytes} does not meet the invariant - * {@code (bytesLength - 1) % precision == 0} */ @Nonnull static RealVector vectorFromBytes(@Nonnull final Config config, @Nonnull final byte[] vectorBytes) { final byte vectorTypeOrdinal = vectorBytes[0]; - switch (fromVectorTypeOrdinal(vectorTypeOrdinal)) { - case HALF: - return HalfRealVector.fromBytes(vectorBytes); - case SINGLE: - return FloatRealVector.fromBytes(vectorBytes); - case DOUBLE: - return DoubleRealVector.fromBytes(vectorBytes); + switch (RealVector.fromVectorTypeOrdinal(vectorTypeOrdinal)) { case RABITQ: Verify.verify(config.isUseRaBitQ()); return EncodedRealVector.fromBytes(vectorBytes, config.getNumDimensions(), config.getRaBitQNumExBits()); + case HALF: + case SINGLE: + case DOUBLE: + return RealVector.fromBytes(vectorBytes); default: throw new RuntimeException("unable to serialize vector"); } @@ -251,11 +243,6 @@ static Tuple tupleFromVector(@Nonnull final RealVector vector) { return Tuple.from(vector.getRawData()); } - @Nonnull - static VectorType fromVectorTypeOrdinal(final int ordinal) { - return VECTOR_TYPES.get(ordinal); - } - @Nonnull static CompletableFuture fetchAccessInfo(@Nonnull final Config config, @Nonnull final ReadTransaction readTransaction, diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/linear/RealVector.java
b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/RealVector.java index b0c79513f2..67e1dd4175 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/linear/RealVector.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/RealVector.java @@ -20,8 +20,9 @@ package com.apple.foundationdb.linear; -import com.google.common.base.Preconditions; import com.apple.foundationdb.half.Half; +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; import javax.annotation.Nonnull; @@ -34,6 +35,8 @@ * data type conversions and raw data representation. */ public interface RealVector { + ImmutableList VECTOR_TYPES = ImmutableList.copyOf(VectorType.values()); + /** * Returns the number of elements in the vector, i.e. the number of dimensions. * @return the number of dimensions @@ -189,4 +192,47 @@ default RealVector multiply(final double scalarFactor) { } return withData(result); } + + @Nonnull + static VectorType fromVectorTypeOrdinal(final int ordinal) { + return VECTOR_TYPES.get(ordinal); + } + + /** + * Creates a {@link RealVector} from a byte array. + *

+ * This method interprets the input byte array by interpreting the first byte of the array as the type of vector. + * It then delegates to {@link #fromBytes(VectorType, byte[])} to do the actual deserialization. + * + * @param vectorBytes the non-null byte array to convert. + * @return a new {@link RealVector} instance created from the byte array. + */ + @Nonnull + static RealVector fromBytes(@Nonnull final byte[] vectorBytes) { + final byte vectorTypeOrdinal = vectorBytes[0]; + return fromBytes(fromVectorTypeOrdinal(vectorTypeOrdinal), vectorBytes); + } + + /** + * Creates a {@link RealVector} from a byte array. + *

+ * This implementation dispatches to the actual logic that deserializes a byte array to a vector which is located in + * the respective implementations of {@link RealVector}. + * @param vectorType the vector type of the serialized vector + * @param vectorBytes the non-null byte array to convert. + * @return a new {@link RealVector} instance created from the byte array. + */ + @Nonnull + static RealVector fromBytes(@Nonnull final VectorType vectorType, @Nonnull final byte[] vectorBytes) { + switch (vectorType) { + case HALF: + return HalfRealVector.fromBytes(vectorBytes); + case SINGLE: + return FloatRealVector.fromBytes(vectorBytes); + case DOUBLE: + return DoubleRealVector.fromBytes(vectorBytes); + default: + throw new RuntimeException("unable to deserialize vector"); + } + } } diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/ConfigTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/ConfigTest.java index c3a5c69117..141df55cfe 100644 --- a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/ConfigTest.java +++ b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/ConfigTest.java @@ -32,7 +32,7 @@ void testConfig() { Assertions.assertThat(HNSW.newConfigBuilder().build(768)).isEqualTo(defaultConfig); Assertions.assertThat(defaultConfig.toBuilder().build(768)).isEqualTo(defaultConfig); - final long randomSeed = 1L; + final boolean deterministicSeeding = true; final Metric metric = Metric.COSINE_METRIC; final boolean useInlining = true; final int m = Config.DEFAULT_M + 1; @@ -41,7 +41,7 @@ void testConfig() { final int efConstruction = Config.DEFAULT_EF_CONSTRUCTION + 1; final boolean extendCandidates = true; final boolean keepPrunedConnections = true; - final int statsThreshold = 1; + final int statsThreshold = 5000; final double sampleVectorStatsProbability = 0.000001d; final double maintainStatsProbability = 0.000002d; @@ -51,7 +51,7 @@ void testConfig() { final int maxNumConcurrentNodeFetches
= 1; final int maxNumConcurrentNeighborhoodFetches = 2; - Assertions.assertThat(defaultConfig.getRandomSeed()).isNotEqualTo(randomSeed); + Assertions.assertThat(defaultConfig.isDeterministicSeeding()).isNotEqualTo(deterministicSeeding); Assertions.assertThat(defaultConfig.getMetric()).isNotSameAs(metric); Assertions.assertThat(defaultConfig.isUseInlining()).isNotEqualTo(useInlining); Assertions.assertThat(defaultConfig.getM()).isNotEqualTo(m); @@ -73,7 +73,7 @@ void testConfig() { final Config newConfig = defaultConfig.toBuilder() - .setRandomSeed(randomSeed) + .setDeterministicSeeding(deterministicSeeding) .setMetric(metric) .setUseInlining(useInlining) .setM(m) @@ -91,7 +91,7 @@ void testConfig() { .setMaxNumConcurrentNeighborhoodFetches(maxNumConcurrentNeighborhoodFetches) .build(768); - Assertions.assertThat(newConfig.getRandomSeed()).isEqualTo(randomSeed); + Assertions.assertThat(newConfig.isDeterministicSeeding()).isEqualTo(deterministicSeeding); Assertions.assertThat(newConfig.getMetric()).isSameAs(metric); Assertions.assertThat(newConfig.isUseInlining()).isEqualTo(useInlining); Assertions.assertThat(newConfig.getM()).isEqualTo(m); @@ -116,7 +116,7 @@ void testConfig() { void testEqualsHashCodeAndToString() { final Config config1 = HNSW.newConfigBuilder().build(768); final Config config2 = HNSW.newConfigBuilder().build(768); - final Config config3 = HNSW.newConfigBuilder().setM(1).build(768); + final Config config3 = HNSW.newConfigBuilder().setM(4).build(768); Assertions.assertThat(config1.hashCode()).isEqualTo(config2.hashCode()); Assertions.assertThat(config1).isEqualTo(config2); diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/DataRecordsTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/DataRecordsTest.java index 62fcd89076..6d316103c8 100644 --- a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/DataRecordsTest.java +++ 
b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/DataRecordsTest.java @@ -31,21 +31,24 @@ import org.junit.jupiter.params.ParameterizedTest; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import java.util.List; +import java.util.Objects; import java.util.Random; +import java.util.function.BiFunction; import java.util.function.Function; class DataRecordsTest { @ParameterizedTest @RandomSeedSource({0x0fdbL, 0x5ca1eL, 123456L, 78910L, 1123581321345589L}) void testAccessInfo(final long randomSeed) { - assertHashCodeEqualsToString(randomSeed, DataRecordsTest::accessInfo); + assertHashCodeEqualsToString(randomSeed, DataRecordsTest::accessInfo, DataRecordsTest::accessInfo); } @ParameterizedTest @RandomSeedSource({0x0fdbL, 0x5ca1eL, 123456L, 78910L, 1123581321345589L}) void testAggregatedVector(final long randomSeed) { - assertHashCodeEqualsToString(randomSeed, DataRecordsTest::aggregatedVector); + assertHashCodeEqualsToString(randomSeed, DataRecordsTest::aggregatedVector, DataRecordsTest::aggregatedVector); } @ParameterizedTest @@ -58,7 +61,7 @@ void testCompactNode(final long randomSeed) { final CompactNode compactNode1Clone = compactNode(new Random(dependentRandomSeed)); Assertions.assertThat(compactNode1).hasToString(compactNode1Clone.toString()); - final CompactNode compactNode2 = compactNode(random); + final CompactNode compactNode2 = compactNode(random, compactNode1); Assertions.assertThat(compactNode1).doesNotHaveToString(compactNode2.toString()); Assertions.assertThatThrownBy(compactNode1::asInliningNode).isInstanceOf(IllegalStateException.class); @@ -74,7 +77,7 @@ void testInliningNode(final long randomSeed) { final InliningNode inliningNode1Clone = inliningNode(new Random(dependentRandomSeed)); Assertions.assertThat(inliningNode1).hasToString(inliningNode1Clone.toString()); - final InliningNode inliningNode2 = inliningNode(random); + final InliningNode inliningNode2 = inliningNode(random, inliningNode1); 
Assertions.assertThat(inliningNode1).doesNotHaveToString(inliningNode2.toString()); Assertions.assertThatThrownBy(inliningNode1::asCompactNode).isInstanceOf(IllegalStateException.class); @@ -83,13 +86,13 @@ void testInliningNode(final long randomSeed) { @ParameterizedTest @RandomSeedSource({0x0fdbL, 0x5ca1eL, 123456L, 78910L, 1123581321345589L}) void testEntryNodeReference(final long randomSeed) { - assertHashCodeEqualsToString(randomSeed, DataRecordsTest::entryNodeReference); + assertHashCodeEqualsToString(randomSeed, DataRecordsTest::entryNodeReference, DataRecordsTest::entryNodeReference); } @ParameterizedTest @RandomSeedSource({0x0fdbL, 0x5ca1eL, 123456L, 78910L, 1123581321345589L}) void testNodeReference(final long randomSeed) { - assertHashCodeEqualsToString(randomSeed, DataRecordsTest::nodeReference); + assertHashCodeEqualsToString(randomSeed, DataRecordsTest::nodeReference, DataRecordsTest::nodeReference); final NodeReference nodeReference = nodeReference(new Random(randomSeed)); Assertions.assertThat(nodeReference.isNodeReferenceWithVector()).isFalse(); Assertions.assertThatThrownBy(nodeReference::asNodeReferenceWithVector).isInstanceOf(IllegalStateException.class); @@ -98,7 +101,8 @@ void testNodeReference(final long randomSeed) { @ParameterizedTest @RandomSeedSource({0x0fdbL, 0x5ca1eL, 123456L, 78910L, 1123581321345589L}) void testNodeReferenceWithVector(final long randomSeed) { - assertHashCodeEqualsToString(randomSeed, DataRecordsTest::nodeReferenceWithVector); + assertHashCodeEqualsToString(randomSeed, DataRecordsTest::nodeReferenceWithVector, + DataRecordsTest::nodeReferenceWithVector); final NodeReferenceWithVector nodeReference = nodeReferenceWithVector(new Random(randomSeed)); Assertions.assertThat(nodeReference.isNodeReferenceWithVector()).isTrue(); Assertions.assertThat(nodeReference.asNodeReferenceWithVector()).isInstanceOf(NodeReferenceWithVector.class); @@ -107,7 +111,8 @@ void testNodeReferenceWithVector(final long randomSeed) { 
@ParameterizedTest @RandomSeedSource({0x0fdbL, 0x5ca1eL, 123456L, 78910L, 1123581321345589L}) void testNodeReferenceWithDistance(final long randomSeed) { - assertHashCodeEqualsToString(randomSeed, DataRecordsTest::nodeReferenceWithDistance); + assertHashCodeEqualsToString(randomSeed, DataRecordsTest::nodeReferenceWithDistance, + DataRecordsTest::nodeReferenceWithDistance); final NodeReferenceWithDistance nodeReference = nodeReferenceWithDistance(new Random(randomSeed)); Assertions.assertThat(nodeReference.isNodeReferenceWithVector()).isTrue(); Assertions.assertThat(nodeReference.asNodeReferenceWithVector()).isInstanceOf(NodeReferenceWithDistance.class); @@ -116,10 +121,12 @@ void testNodeReferenceWithDistance(final long randomSeed) { @ParameterizedTest @RandomSeedSource({0x0fdbL, 0x5ca1eL, 123456L, 78910L, 1123581321345589L}) void testResultEntry(final long randomSeed) { - assertHashCodeEqualsToString(randomSeed, DataRecordsTest::resultEntry); + assertHashCodeEqualsToString(randomSeed, DataRecordsTest::resultEntry, DataRecordsTest::resultEntry); } - private static void assertHashCodeEqualsToString(final long randomSeed, final Function createFunction) { + private static void assertHashCodeEqualsToString(final long randomSeed, + @Nonnull final Function createFunction, + @Nonnull final BiFunction createDifferentFunction) { final Random random = new Random(randomSeed); final long dependentRandomSeed = random.nextLong(); final T t1 = createFunction.apply(new Random(dependentRandomSeed)); @@ -128,7 +135,7 @@ private static void assertHashCodeEqualsToString(final long randomSeed, fina Assertions.assertThat(t1).isEqualTo(t1Clone); Assertions.assertThat(t1).hasToString(t1Clone.toString()); - final T t2 = createFunction.apply(random); + final T t2 = createDifferentFunction.apply(random, t1); Assertions.assertThat(t1).isNotEqualTo(t2); Assertions.assertThat(t1).doesNotHaveToString(t2.toString()); } @@ -138,6 +145,14 @@ private static ResultEntry resultEntry(@Nonnull final 
Random random) { return new ResultEntry(primaryKey(random), rawVector(random), random.nextDouble(), random.nextInt(100)); } + @Nonnull + private static ResultEntry resultEntry(@Nonnull final Random random, @Nonnull final ResultEntry original) { + return new ResultEntry(primaryKey(random, original.getPrimaryKey()), + rawVector(random, Objects.requireNonNull(original.getVector())), + differentDouble(random, original.getDistance()), + differentInteger(random, original.getRankOrRowNumber(), 100)); + } + @Nonnull private static CompactNode compactNode(@Nonnull final Random random) { return CompactNode.factory() @@ -145,10 +160,27 @@ private static CompactNode compactNode(@Nonnull final Random random) { .asCompactNode(); } + @Nonnull + private static CompactNode compactNode(@Nonnull final Random random, @Nonnull CompactNode original) { + return CompactNode.factory() + .create(primaryKey(random, original.getPrimaryKey()), + vector(random, original.getVector()), + nodeReferences(random, original.getNeighbors())) + .asCompactNode(); + } + @Nonnull private static InliningNode inliningNode(@Nonnull final Random random) { return InliningNode.factory() - .create(primaryKey(random), vector(random), nodeReferenceWithVectors(random)) + .create(primaryKey(random), null, nodeReferenceWithVectors(random)) + .asInliningNode(); + } + + private static InliningNode inliningNode(@Nonnull final Random random, @Nonnull final InliningNode original) { + return InliningNode.factory() + .create(primaryKey(random, original.getPrimaryKey()), + null, + nodeReferenceWithVectors(random, original.getNeighbors())) .asInliningNode(); } @@ -157,9 +189,26 @@ private static NodeReferenceWithDistance nodeReferenceWithDistance(@Nonnull fina return new NodeReferenceWithDistance(primaryKey(random), vector(random), random.nextDouble()); } + @Nonnull + private static NodeReferenceWithDistance nodeReferenceWithDistance(@Nonnull final Random random, + @Nonnull final NodeReferenceWithDistance original) { + return 
new NodeReferenceWithDistance( + primaryKey(random, original.getPrimaryKey()), + vector(random, original.getVector()), + differentDouble(random, original.getDistance())); + } + @Nonnull private static List nodeReferenceWithVectors(@Nonnull final Random random) { - int size = random.nextInt(20); + return nodeReferenceWithVectors(random, null); + } + + @Nonnull + private static List nodeReferenceWithVectors(@Nonnull final Random random, + @Nullable final List original) { + final int size = original == null + ? random.nextInt(20) + : differentInteger(random, original.size(), 20); final ImmutableList.Builder resultBuilder = ImmutableList.builder(); for (int i = 0; i < size; i ++) { resultBuilder.add(nodeReferenceWithVector(random)); @@ -172,9 +221,24 @@ private static NodeReferenceWithVector nodeReferenceWithVector(@Nonnull final Ra return new NodeReferenceWithVector(primaryKey(random), vector(random)); } + @Nonnull + private static NodeReferenceWithVector nodeReferenceWithVector(@Nonnull final Random random, + @Nonnull final NodeReferenceWithVector original) { + return new NodeReferenceWithVector(primaryKey(random, original.getPrimaryKey()), + vector(random, original.getVector())); + } + @Nonnull private static List nodeReferences(@Nonnull final Random random) { - int size = random.nextInt(20); + return nodeReferences(random, null); + } + + @Nonnull + private static List nodeReferences(@Nonnull final Random random, + @Nullable final List original) { + final int size = original == null + ? 
random.nextInt(20) + : differentInteger(random, original.size(), 20); final ImmutableList.Builder resultBuilder = ImmutableList.builder(); for (int i = 0; i < size; i ++) { resultBuilder.add(nodeReference(random)); @@ -187,33 +251,104 @@ private static NodeReference nodeReference(@Nonnull final Random random) { return new NodeReference(primaryKey(random)); } + @Nonnull + private static NodeReference nodeReference(@Nonnull final Random random, @Nonnull NodeReference original) { + return new NodeReference(primaryKey(random, original.getPrimaryKey())); + } + @Nonnull private static AggregatedVector aggregatedVector(@Nonnull final Random random) { return new AggregatedVector(random.nextInt(100), vector(random)); } + @Nonnull + private static AggregatedVector aggregatedVector(@Nonnull final Random random, + @Nonnull final AggregatedVector original) { + return new AggregatedVector(differentInteger(random, original.getPartialCount(), 100), + vector(random, original.getPartialVector())); + } + @Nonnull private static AccessInfo accessInfo(@Nonnull final Random random) { return new AccessInfo(entryNodeReference(random), random.nextLong(), rawVector(random)); } + @Nonnull + private static AccessInfo accessInfo(@Nonnull final Random random, @Nonnull final AccessInfo original) { + return new AccessInfo(entryNodeReference(random, original.getEntryNodeReference()), + differentLong(random, original.getRotatorSeed()), + rawVector(random, Objects.requireNonNull(original.getNegatedCentroid()))); + } + @Nonnull private static EntryNodeReference entryNodeReference(@Nonnull final Random random) { return new EntryNodeReference(primaryKey(random), vector(random), random.nextInt(10)); } + @Nonnull + private static EntryNodeReference entryNodeReference(@Nonnull final Random random, + @Nonnull final EntryNodeReference original) { + return new EntryNodeReference(primaryKey(random, original.getPrimaryKey()), + vector(random, original.getVector()), + differentInteger(random, 
original.getLayer(), 10)); + } + @Nonnull private static Tuple primaryKey(@Nonnull final Random random) { return Tuple.from(random.nextInt(100)); } + @Nonnull + private static Tuple primaryKey(@Nonnull final Random random, @Nonnull final Tuple original) { + return Tuple.from(differentInteger(random, Math.toIntExact(original.getLong(0)), 100)); + } + @Nonnull private static Transformed vector(@Nonnull final Random random) { return AffineOperator.identity().transform(rawVector(random)); } + @Nonnull + private static Transformed vector(@Nonnull final Random random, + @Nonnull final Transformed original) { + return AffineOperator.identity().transform(rawVector(random, original.getUnderlyingVector())); + } + @Nonnull private static RealVector rawVector(@Nonnull final Random random) { return RealVectorTest.createRandomDoubleVector(random, 768); } + + @Nonnull + private static RealVector rawVector(@Nonnull final Random random, @Nonnull final RealVector original) { + RealVector randomVector; + do { + randomVector = RealVectorTest.createRandomDoubleVector(random, 768); + } while (randomVector.equals(original)); + return randomVector; + } + + private static int differentInteger(@Nonnull final Random random, final int original, final int bound) { + int randomInteger; + do { + randomInteger = random.nextInt(bound); + } while (randomInteger == original); + return randomInteger; + } + + private static long differentLong(@Nonnull final Random random, final long original) { + long randomLong; + do { + randomLong = random.nextLong(); + } while (randomLong == original); + return randomLong; + } + + private static double differentDouble(@Nonnull final Random random, final double original) { + double randomDouble; + do { + randomDouble = random.nextDouble(); + } while (randomDouble == original); + return randomDouble; + } } diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWTest.java 
b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWTest.java index 94dc7a9804..726a38902f 100644 --- a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWTest.java +++ b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWTest.java @@ -77,9 +77,8 @@ import java.util.Set; import java.util.TreeSet; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Function; +import java.util.function.BiFunction; import java.util.stream.Collectors; import java.util.stream.LongStream; import java.util.stream.Stream; @@ -206,21 +205,25 @@ void testBasicInsert(final long seed, final boolean useInlining, final boolean e final boolean keepPrunedConnections, final boolean useRaBitQ) { final Random random = new Random(seed); final Metric metric = Metric.EUCLIDEAN_METRIC; - final AtomicLong nextNodeIdAtomic = new AtomicLong(0L); - final TestOnReadListener onReadListener = new TestOnReadListener(); final int numDimensions = 128; final HNSW hnsw = new HNSW(rtSubspace.getSubspace(), TestExecutors.defaultThreadPool(), - HNSW.newConfigBuilder().setMetric(metric) - .setUseInlining(useInlining).setExtendCandidates(extendCandidates) + HNSW.newConfigBuilder() + .setDeterministicSeeding(true) + .setMetric(metric) + .setUseInlining(useInlining) + .setExtendCandidates(extendCandidates) .setKeepPrunedConnections(keepPrunedConnections) .setUseRaBitQ(useRaBitQ) .setRaBitQNumExBits(5) .setSampleVectorStatsProbability(1.0d) .setMaintainStatsProbability(0.1d) .setStatsThreshold(100) - .setM(32).setMMax(32).setMMax0(64).build(numDimensions), + .setM(32) + .setMMax(32) + .setMMax0(64) + .build(numDimensions), OnWriteListener.NOOP, onReadListener); final int k = 50; @@ -229,9 +232,9 @@ void testBasicInsert(final long seed, final boolean useInlining, final boolean e new TreeSet<>(Comparator.comparing(PrimaryKeyVectorAndDistance::getDistance)); for (int 
i = 0; i < 1000;) { - i += basicInsertBatch(hnsw, 100, nextNodeIdAtomic, onReadListener, - tr -> { - final var primaryKey = createNextPrimaryKey(nextNodeIdAtomic); + i += basicInsertBatch(hnsw, 100, i, onReadListener, + (tr, nextId) -> { + final var primaryKey = createPrimaryKey(nextId); final HalfRealVector dataVector = createRandomHalfVector(random, numDimensions); final double distance = metric.distance(dataVector, queryVector); final PrimaryKeyVectorAndDistance record = @@ -244,6 +247,21 @@ void testBasicInsert(final long seed, final boolean useInlining, final boolean e }); } + // + // Attempt to mutate some records by updating them using the same primary keys but different random vectors. + // This should not fail but should be silently ignored. If this succeeds, the following searches will all + // return records that are not aligned with recordsOrderedByDistance. + // + for (int i = 0; i < 100;) { + i += basicInsertBatch(hnsw, 100, 0, onReadListener, + (tr, ignored) -> { + final var primaryKey = createPrimaryKey(random.nextInt(1000)); + final HalfRealVector dataVector = createRandomHalfVector(random, numDimensions); + final double distance = metric.distance(dataVector, queryVector); + return new PrimaryKeyVectorAndDistance(primaryKey, dataVector, distance); + }); + } + onReadListener.reset(); final long beginTs = System.nanoTime(); final List results = @@ -293,16 +311,20 @@ void testBasicInsertWithRaBitQEncodings(final long seed) { final Random random = new Random(seed); final Metric metric = Metric.EUCLIDEAN_METRIC; - final AtomicLong nextNodeIdAtomic = new AtomicLong(0L); final int numDimensions = 128; final HNSW hnsw = new HNSW(rtSubspace.getSubspace(), TestExecutors.defaultThreadPool(), - HNSW.newConfigBuilder().setMetric(metric) + HNSW.newConfigBuilder() + .setDeterministicSeeding(true) + .setMetric(metric) .setUseRaBitQ(true) .setRaBitQNumExBits(5) .setSampleVectorStatsProbability(1.0d) // every vector is sampled .setMaintainStatsProbability(1.0d) // 
for every vector we maintain the stats .setStatsThreshold(950) // after 950 vectors we enable RaBitQ - .setM(32).setMMax(32).setMMax0(64).build(numDimensions), + .setM(32) + .setMMax(32) + .setMMax0(64) + .build(numDimensions), OnWriteListener.NOOP, OnReadListener.NOOP); final int k = 499; @@ -312,9 +334,9 @@ void testBasicInsertWithRaBitQEncodings(final long seed) { new TreeSet<>(Comparator.comparing(PrimaryKeyVectorAndDistance::getDistance)); for (int i = 0; i < 1000;) { - i += basicInsertBatch(hnsw, 100, nextNodeIdAtomic, new TestOnReadListener(), - tr -> { - final var primaryKey = createNextPrimaryKey(nextNodeIdAtomic); + i += basicInsertBatch(hnsw, 100, i, new TestOnReadListener(), + (tr, nextId) -> { + final var primaryKey = createPrimaryKey(nextId); final DoubleRealVector dataVector = createRandomDoubleVector(random, numDimensions); final double distance = metric.distance(dataVector, queryVector); dataMap.put(primaryKey, dataVector); @@ -385,14 +407,13 @@ void testBasicInsertWithRaBitQEncodings(final long seed) { } private int basicInsertBatch(final HNSW hnsw, final int batchSize, - @Nonnull final AtomicLong nextNodeIdAtomic, @Nonnull final TestOnReadListener onReadListener, - @Nonnull final Function insertFunction) { + final long firstId, @Nonnull final TestOnReadListener onReadListener, + @Nonnull final BiFunction insertFunction) { return db.run(tr -> { onReadListener.reset(); - final long nextNodeId = nextNodeIdAtomic.get(); final long beginTs = System.nanoTime(); for (int i = 0; i < batchSize; i ++) { - final var record = insertFunction.apply(tr); + final var record = insertFunction.apply(tr, firstId + i); if (record == null) { return i; } @@ -400,7 +421,7 @@ final var record = insertFunction.apply(tr); } final long endTs = System.nanoTime(); logger.info("inserted batchSize={} records starting at nodeId={} took elapsedTime={}ms, readCounts={}, readBytes={}", - batchSize, nextNodeId, TimeUnit.NANOSECONDS.toMillis(endTs - beginTs), + batchSize, firstId, 
TimeUnit.NANOSECONDS.toMillis(endTs - beginTs), onReadListener.getNodeCountByLayer(), onReadListener.getBytesReadByLayer()); return batchSize; }); @@ -411,13 +432,18 @@ final var record = insertFunction.apply(tr); void testSIFTInsertSmall() throws Exception { final Metric metric = Metric.EUCLIDEAN_METRIC; final int k = 100; - final AtomicLong nextNodeIdAtomic = new AtomicLong(0L); - final TestOnReadListener onReadListener = new TestOnReadListener(); final HNSW hnsw = new HNSW(rtSubspace.getSubspace(), TestExecutors.defaultThreadPool(), - HNSW.newConfigBuilder().setUseRaBitQ(true).setRaBitQNumExBits(5) - .setMetric(metric).setM(32).setMMax(32).setMMax0(64).build(128), + HNSW.newConfigBuilder() + .setDeterministicSeeding(false) + .setUseRaBitQ(true) + .setRaBitQNumExBits(5) + .setMetric(metric) + .setM(32) + .setMMax(32) + .setMMax0(64) + .build(128), OnWriteListener.NOOP, onReadListener); final Path siftSmallPath = Paths.get(".out/extracted/siftsmall/siftsmall_base.fvecs"); @@ -430,13 +456,13 @@ void testSIFTInsertSmall() throws Exception { int i = 0; final AtomicReference sumReference = new AtomicReference<>(null); while (vectorIterator.hasNext()) { - i += basicInsertBatch(hnsw, 100, nextNodeIdAtomic, onReadListener, - tr -> { + i += basicInsertBatch(hnsw, 100, i, onReadListener, + (tr, nextId) -> { if (!vectorIterator.hasNext()) { return null; } final DoubleRealVector doubleVector = vectorIterator.next(); - final Tuple currentPrimaryKey = createNextPrimaryKey(nextNodeIdAtomic); + final Tuple currentPrimaryKey = createPrimaryKey(nextId); final HalfRealVector currentVector = doubleVector.toHalfRealVector(); if (sumReference.get() == null) { @@ -571,12 +597,12 @@ private NodeReferenceWithVector createRandomNodeReferenceWithVector(@Nonnull fin @Nonnull private static Tuple createRandomPrimaryKey(final @Nonnull Random random) { - return Tuple.from(random.nextLong()); + return createPrimaryKey(random.nextLong()); } @Nonnull - private static Tuple 
createNextPrimaryKey(@Nonnull final AtomicLong nextIdAtomic) { - return Tuple.from(nextIdAtomic.getAndIncrement()); + private static Tuple createPrimaryKey(final long nextId) { + return Tuple.from(nextId); } private static class TestOnReadListener implements OnReadListener { diff --git a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/IndexScanType.java b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/IndexScanType.java index 11df7874bd..7ae70326f2 100644 --- a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/IndexScanType.java +++ b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/IndexScanType.java @@ -48,6 +48,8 @@ public class IndexScanType implements PlanHashable, PlanSerializable { public static final IndexScanType BY_TIME_WINDOW = new IndexScanType("BY_TIME_WINDOW"); @Nonnull public static final IndexScanType BY_TEXT_TOKEN = new IndexScanType("BY_TEXT_TOKEN"); + @Nonnull + public static final IndexScanType BY_DISTANCE = new IndexScanType("BY_DISTANCE"); private final String name; diff --git a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/metadata/IndexOptions.java b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/metadata/IndexOptions.java index 2b66805b2f..a393debff4 100644 --- a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/metadata/IndexOptions.java +++ b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/metadata/IndexOptions.java @@ -21,6 +21,7 @@ package com.apple.foundationdb.record.metadata; import com.apple.foundationdb.annotation.API; +import com.apple.foundationdb.async.hnsw.Config; import com.apple.foundationdb.async.rtree.RTree; import com.apple.foundationdb.record.provider.foundationdb.IndexMaintainer; @@ -223,6 +224,143 @@ public class IndexOptions { */ public static final String RTREE_USE_NODE_SLOT_INDEX = "rtreeUseNodeSlotIndex"; + /** + * HNSW-only: The seeding method that is used to seed the PRNG
that is in turn used to probabilistically determine + * the highest layer of an insert into an HNSW structure. See {@link Config#isDeterministicSeeding()}. The default + * value is {@link Config#DEFAULT_DETERMINISTIC_SEEDING}. + */ + public static final String HNSW_DETERMINISTIC_SEEDING = "hnswDeterministicSeeding"; + + /** + * HNSW-only: The metric that is used to determine distances between vectors. The default metric is + * {@link Config#DEFAULT_METRIC}. See {@link Config#getMetric()}. + */ + public static final String HNSW_METRIC = "hnswMetric"; + + /** + * HNSW-only: The number of dimensions used. All vectors must have exactly this number of dimensions. This option + * must be set when interacting with a vector index as there is no default. + * See {@link Config#getNumDimensions()}. + */ + public static final String HNSW_NUM_DIMENSIONS = "hnswNumDimensions"; + + /** + * HNSW-only: Indicator if all layers except layer {@code 0} use inlining. If inlining is used, each node is + * persisted as a key/value pair per neighbor which includes the vectors of the neighbors but not for itself. If + * inlining is not used, each node is persisted as exactly one key/value pair per node which stores its own vector + * but specifically excludes the vectors of the neighbors. The default value is set to + * {@link Config#DEFAULT_USE_INLINING}. See {@link Config#isUseInlining()}. + */ + public static final String HNSW_USE_INLINING = "hnswUseInlining"; + + /** + * HNSW-only: This option (named {@code M} by the HNSW paper) is the connectivity value for all nodes stored on + * any layer. While by no means enforced or even enforceable, we strive to create and maintain exactly {@code m} + * neighbors for a node. Due to insert/delete operations it is possible that the actual number of neighbors a node + * references is not exactly {@code m} at any given time. The default value is set to {@link Config#DEFAULT_M}. + * See {@link Config#getM()}. 
+ */ + public static final String HNSW_M = "hnswM"; + + /** + * HNSW-only: This option (named {@code M_max} by the HNSW paper) is the maximum connectivity value for nodes + * stored on a layer greater than {@code 0}. A node can never have more than {@code mMax} neighbors. That means that + * neighbors of a node are pruned if the actual number of neighbors would otherwise exceed {@code mMax}. Note that + * this option must be greater than or equal to {@link #HNSW_M}. The default value is set to + * {@link Config#DEFAULT_M_MAX}. See {@link Config#getMMax()}. + */ + public static final String HNSW_M_MAX = "hnswMMax"; + + /** + * HNSW-only: This option (named {@code M_max0} by the HNSW paper) is the maximum connectivity value for nodes + * stored on layer {@code 0}. We will never create more than {@code mMax0} neighbors for a node that is stored on + * that layer. That means that we even prune the neighbors of a node if the actual number of neighbors would + * otherwise exceed {@code mMax0}. Note that this option must be greater than or equal to {@link #HNSW_M_MAX}. + * The default value is set to {@link Config#DEFAULT_M_MAX_0}. See {@link Config#getMMax0()}. + */ + public static final String HNSW_M_MAX_0 = "hnswMMax0"; + + /** + * HNSW-only: Maximum size of the search queues (one independent queue per layer) that are used during the insertion + * of a new node. If {@code HNSW_EF_CONSTRUCTION} is set to {@code 1}, the search naturally follows a greedy + * approach (monotonic descent), whereas a high number for {@code HNSW_EF_CONSTRUCTION} allows for a more nuanced + * search that can tolerate (false) local minima. The default value is set to {@link Config#DEFAULT_EF_CONSTRUCTION}. + * See {@link Config#getEfConstruction()}. 
+ */ + public static final String HNSW_EF_CONSTRUCTION = "hnswEfConstruction"; + + /** + * HNSW-only: Indicator to signal if, during the insertion of a node, the set of nearest neighbors of that node is + * to be extended by the actual neighbors of those neighbors to form a set of candidates that the new node may be + * connected to during the insert operation. The default value is set to {@link Config#DEFAULT_EXTEND_CANDIDATES}. + * See {@link Config#isExtendCandidates()}. + */ + public static final String HNSW_EXTEND_CANDIDATES = "hnswExtendCandidates"; + + /** + * HNSW-only: Indicator to signal if, during the insertion of a node, candidates that have been discarded due to not + * satisfying the select-neighbor heuristic may get added back in to pad the set of neighbors if the new node would + * otherwise have too few neighbors (see {@link Config#getM()}). The default value is set to + * {@link Config#DEFAULT_KEEP_PRUNED_CONNECTIONS}. See {@link Config#isKeepPrunedConnections()}. + */ + public static final String HNSW_KEEP_PRUNED_CONNECTIONS = "hnswKeepPrunedConnections"; + + /** + * HNSW-only: If sampling is necessary (currently iff {@link #HNSW_USE_RABITQ} is {@code "true"}), this option + * represents the probability of a vector being inserted to also be written into the samples subspace of the hnsw + * structure. The vectors in that subspace are continuously aggregated until a total {@link #HNSW_STATS_THRESHOLD} + * has been reached. The default value is set to {@link Config#DEFAULT_SAMPLE_VECTOR_STATS_PROBABILITY}. See + * {@link Config#getSampleVectorStatsProbability()}. + */ + public static final String HNSW_SAMPLE_VECTOR_STATS_PROBABILITY = "hnswSampleVectorStatsProbability"; + + /** + * HNSW-only: If sampling is necessary (currently iff {@link #HNSW_USE_RABITQ} is {@code "true"}), this option + * represents the probability of the samples subspace to be further aggregated (rolled-up) when a new vector is + * inserted. 
The vectors in that subspace are continuously aggregated until a total + * {@link #HNSW_STATS_THRESHOLD} has been reached. The default value is set to + * {@link Config#DEFAULT_MAINTAIN_STATS_PROBABILITY}. See {@link Config#getMaintainStatsProbability()}. + */ + public static final String HNSW_MAINTAIN_STATS_PROBABILITY = "hnswMaintainStatsProbability"; + + /** + * HNSW-only: If sampling is necessary (currently iff {@link #HNSW_USE_RABITQ} is {@code "true"}), this option + * represents the threshold (being a number of vectors) that when reached causes the stats maintenance logic to + * compute the actual statistics (currently the centroid of the vectors that have been inserted so far). The result + * is then inserted into the access info subspace of the index. The default value is set to + * {@link Config#DEFAULT_STATS_THRESHOLD}. See {@link Config#getStatsThreshold()}. + */ + public static final String HNSW_STATS_THRESHOLD = "hnswStatsThreshold"; + + /** + * HNSW-only: Indicator if we should use RaBitQ quantization. See {@link com.apple.foundationdb.rabitq.RaBitQuantizer} + * for more details. The default value is set to {@link Config#DEFAULT_USE_RABITQ}. + * See {@link Config#isUseRaBitQ()}. + */ + public static final String HNSW_USE_RABITQ = "hnswUseRaBitQ"; + + /** + * HNSW-only: Number of bits per dimension iff {@link #HNSW_USE_RABITQ} is set to {@code "true"}, ignored + * otherwise. If RaBitQ encoding is used, a vector is stored using roughly + * {@code 25 + numDimensions * (numExBits + 1) / 8} bytes. The default value is set to + * {@link Config#DEFAULT_RABITQ_NUM_EX_BITS}. See {@link Config#getRaBitQNumExBits()}. + */ + public static final String HNSW_RABITQ_NUM_EX_BITS = "hnswRaBitQNumExBits"; + + /** + * HNSW-only: Maximum number of concurrent node fetches during search and modification operations. The default value + * is set to {@link Config#DEFAULT_MAX_NUM_CONCURRENT_NODE_FETCHES}. + * See {@link Config#getMaxNumConcurrentNodeFetches()}. 
+ */ + public static final String HNSW_MAX_NUM_CONCURRENT_NODE_FETCHES = "hnswMaxNumConcurrentNodeFetches"; + + /** + * HNSW-only: Maximum number of concurrent neighborhood fetches during modification operations when the neighbors + * are pruned. The default value is set to {@link Config#DEFAULT_MAX_NUM_CONCURRENT_NEIGHBOR_FETCHES}. + * See {@link Config#getMaxNumConcurrentNeighborhoodFetches()}. + */ + public static final String HNSW_MAX_NUM_CONCURRENT_NEIGHBORHOOD_FETCHES = "hnswMaxNumConcurrentNeighborhoodFetches"; + private IndexOptions() { } } diff --git a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/metadata/IndexTypes.java b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/metadata/IndexTypes.java index 1d19171093..8d10f26d9e 100644 --- a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/metadata/IndexTypes.java +++ b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/metadata/IndexTypes.java @@ -164,6 +164,11 @@ public class IndexTypes { */ public static final String MULTIDIMENSIONAL = "multidimensional"; + /** + * An index using an HNSW structure. + */ + public static final String VECTOR = "vector"; + private IndexTypes() { } } diff --git a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/FDBStoreTimer.java b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/FDBStoreTimer.java index c31508c8d2..8432bb363b 100644 --- a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/FDBStoreTimer.java +++ b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/FDBStoreTimer.java @@ -761,6 +761,14 @@ public enum Counts implements Count { LOCKS_ATTEMPTED("number of attempts to register a lock", false), /** Count of the locks released. 
*/ LOCKS_RELEASED("number of locks released", false), + VECTOR_NODE_READS("intermediate nodes read", false), + VECTOR_NODE_READ_BYTES("intermediate node bytes read", true), + VECTOR_NODE0_READS("intermediate nodes read", false), + VECTOR_NODE0_READ_BYTES("intermediate node bytes read", true), + VECTOR_NODE_WRITES("intermediate nodes written", false), + VECTOR_NODE_WRITE_BYTES("intermediate node bytes written", true), + VECTOR_NODE0_WRITES("intermediate nodes written", false), + VECTOR_NODE0_WRITE_BYTES("intermediate node bytes written", true), ; private final String title; diff --git a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/IndexScanComparisons.java b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/IndexScanComparisons.java index 16ba1182f1..46780ce539 100644 --- a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/IndexScanComparisons.java +++ b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/IndexScanComparisons.java @@ -92,8 +92,14 @@ public IndexScanType getScanType() { return scanType; } + @Override + public boolean hasScanComparisons() { + return true; + } + @Nonnull - public ScanComparisons getComparisons() { + @Override + public ScanComparisons getScanComparisons() { return scanComparisons; } diff --git a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/IndexScanParameters.java b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/IndexScanParameters.java index aca9d6ae83..c25b51af05 100644 --- a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/IndexScanParameters.java +++ b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/IndexScanParameters.java @@ -26,8 +26,10 @@ import com.apple.foundationdb.record.PlanHashable; import 
com.apple.foundationdb.record.PlanSerializable; import com.apple.foundationdb.record.PlanSerializationContext; +import com.apple.foundationdb.record.RecordCoreException; import com.apple.foundationdb.record.metadata.Index; import com.apple.foundationdb.record.planprotos.PIndexScanParameters; +import com.apple.foundationdb.record.query.plan.ScanComparisons; import com.apple.foundationdb.record.query.plan.cascades.Correlated; import com.apple.foundationdb.record.query.plan.explain.ExplainTokensWithPrecedence; import com.apple.foundationdb.record.query.plan.cascades.explain.Attribute; @@ -85,6 +87,14 @@ public interface IndexScanParameters extends PlanHashable, Correlated detailsBuilder, @Nonnull ImmutableMap.Builder attributeMapBuilder); + default boolean hasScanComparisons() { + return false; + } + + default ScanComparisons getScanComparisons() { + throw new RecordCoreException("this index scan parameter object does not use ScanComparisons"); + } + @Nonnull IndexScanParameters translateCorrelations(@Nonnull TranslationMap translationMap, boolean shouldSimplifyValues); diff --git a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/VectorIndexScanBounds.java b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/VectorIndexScanBounds.java new file mode 100644 index 0000000000..162a4098aa --- /dev/null +++ b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/VectorIndexScanBounds.java @@ -0,0 +1,100 @@ +/* + * MultidimensionalIndexScanBounds.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2022 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.record.provider.foundationdb; + +import com.apple.foundationdb.annotation.API; +import com.apple.foundationdb.linear.RealVector; +import com.apple.foundationdb.record.IndexScanType; +import com.apple.foundationdb.record.RecordCoreException; +import com.apple.foundationdb.record.TupleRange; +import com.apple.foundationdb.record.query.expressions.Comparisons; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +/** + * TODO. + */ +@API(API.Status.EXPERIMENTAL) +public class VectorIndexScanBounds implements IndexScanBounds { + @Nonnull + private final TupleRange prefixRange; + + @Nonnull + private final Comparisons.Type comparisonType; + @Nullable + private final RealVector queryVector; + private final int limit; + @Nonnull final VectorIndexScanOptions vectorIndexScanOptions; + + public VectorIndexScanBounds(@Nonnull final TupleRange prefixRange, + @Nonnull final Comparisons.Type comparisonType, + @Nullable final RealVector queryVector, + final int limit, + @Nonnull final VectorIndexScanOptions vectorIndexScanOptions) { + this.prefixRange = prefixRange; + this.comparisonType = comparisonType; + this.queryVector = queryVector; + this.limit = limit; + this.vectorIndexScanOptions = vectorIndexScanOptions; + } + + @Nonnull + @Override + public IndexScanType getScanType() { + return IndexScanType.BY_DISTANCE; + } + + @Nonnull + public TupleRange getPrefixRange() { + return prefixRange; + } + + @Nonnull + public Comparisons.Type getComparisonType() { + return comparisonType; + } + + @Nullable + 
public RealVector getQueryVector() { + return queryVector; + } + + public int getLimit() { + return limit; + } + + @Nonnull + public VectorIndexScanOptions getVectorIndexScanOptions() { + return vectorIndexScanOptions; + } + + public int getAdjustedLimit() { + switch (getComparisonType()) { + case DISTANCE_RANK_LESS_THAN: + return limit - 1; + case DISTANCE_RANK_LESS_THAN_OR_EQUAL: + return limit; + default: + throw new RecordCoreException("unsupported comparison"); + } + } +} diff --git a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/VectorIndexScanComparisons.java b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/VectorIndexScanComparisons.java new file mode 100644 index 0000000000..d2cc648f3e --- /dev/null +++ b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/VectorIndexScanComparisons.java @@ -0,0 +1,343 @@ +/* + * MultidimensionalIndexScanComparisons.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2022 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.apple.foundationdb.record.provider.foundationdb; + +import com.apple.foundationdb.annotation.API; +import com.apple.foundationdb.annotation.SpotBugsSuppressWarnings; +import com.apple.foundationdb.record.EvaluationContext; +import com.apple.foundationdb.record.IndexScanType; +import com.apple.foundationdb.record.PlanDeserializer; +import com.apple.foundationdb.record.PlanHashable; +import com.apple.foundationdb.record.PlanSerializationContext; +import com.apple.foundationdb.record.TupleRange; +import com.apple.foundationdb.record.metadata.Index; +import com.apple.foundationdb.record.planprotos.PIndexScanParameters; +import com.apple.foundationdb.record.planprotos.PVectorIndexScanComparisons; +import com.apple.foundationdb.record.query.expressions.Comparisons; +import com.apple.foundationdb.record.query.expressions.Comparisons.DistanceRankValueComparison; +import com.apple.foundationdb.record.query.plan.ScanComparisons; +import com.apple.foundationdb.record.query.plan.cascades.AliasMap; +import com.apple.foundationdb.record.query.plan.cascades.CorrelationIdentifier; +import com.apple.foundationdb.record.query.plan.cascades.explain.Attribute; +import com.apple.foundationdb.record.query.plan.cascades.values.translation.TranslationMap; +import com.apple.foundationdb.record.query.plan.explain.ExplainTokens; +import com.apple.foundationdb.record.query.plan.explain.ExplainTokensWithPrecedence; +import com.google.auto.service.AutoService; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.Objects; +import java.util.Set; + +/** + * {@link ScanComparisons} for use in a vector index scan. 
+ */ +@API(API.Status.UNSTABLE) +public class VectorIndexScanComparisons implements IndexScanParameters { + @Nonnull + private final ScanComparisons prefixScanComparisons; + @Nonnull + private final DistanceRankValueComparison distanceRankValueComparison; + @Nonnull + private final VectorIndexScanOptions vectorIndexScanOptions; + + public VectorIndexScanComparisons(@Nonnull final ScanComparisons prefixScanComparisons, + @Nonnull final DistanceRankValueComparison distanceRankValueComparison, + @Nonnull final VectorIndexScanOptions vectorIndexScanOptions) { + this.prefixScanComparisons = prefixScanComparisons; + this.distanceRankValueComparison = distanceRankValueComparison; + this.vectorIndexScanOptions = vectorIndexScanOptions; + } + + @Nonnull + @Override + public IndexScanType getScanType() { + return IndexScanType.BY_DISTANCE; + } + + @Nonnull + public ScanComparisons getPrefixScanComparisons() { + return prefixScanComparisons; + } + + @Nonnull + public DistanceRankValueComparison getDistanceRankValueComparison() { + return distanceRankValueComparison; + } + + @Nonnull + public VectorIndexScanOptions getVectorIndexScanOptions() { + return vectorIndexScanOptions; + } + + @Override + public boolean hasScanComparisons() { + return true; + } + + @Override + public ScanComparisons getScanComparisons() { + final var builder = new ScanComparisons.Builder(); + builder.addAll(prefixScanComparisons.getEqualityComparisons(), ImmutableSet.of()); + if (!prefixScanComparisons.isEquality()) { + builder.addAll(ImmutableList.of(), prefixScanComparisons.getInequalityComparisons()); + return builder.build(); + } + // only equalities coming from the prefix + if (getDistanceRankValueComparison().getType().isEquality()) { + builder.addEqualityComparison(getDistanceRankValueComparison()); + } else { + builder.addInequalityComparison(getDistanceRankValueComparison()); + } + return builder.build(); + } + + @Nonnull + @Override + public VectorIndexScanBounds bind(@Nonnull final 
FDBRecordStoreBase store, @Nonnull final Index index, + @Nonnull final EvaluationContext context) { + return new VectorIndexScanBounds(prefixScanComparisons.toTupleRange(store, context), + distanceRankValueComparison.getType(), distanceRankValueComparison.getVector(store, context), + distanceRankValueComparison.getLimit(store, context), vectorIndexScanOptions); + } + + @Override + public int planHash(@Nonnull PlanHashMode mode) { + return PlanHashable.objectsPlanHash(mode, prefixScanComparisons, distanceRankValueComparison, + vectorIndexScanOptions); + } + + @Override + public boolean isUnique(@Nonnull Index index) { + // + // This is currently never true as we would need an equality-bound scan comparison that includes the primary + // key which we currently cannot express. We can only express equality-bound constraints on the prefix, thus + // the only case where this is true, and we can detect it would currently occur if the prefix contained the + // primary key which in turn would make for a very uninteresting scenario, as each partition would contain + // exactly one vector. + // + return false; + } + + @Nonnull + @Override + public ExplainTokensWithPrecedence explain() { + @Nullable var tupleRange = prefixScanComparisons.toTupleRangeWithoutContext(); + final var prefix = tupleRange == null + ? prefixScanComparisons.explain().getExplainTokens() + : new ExplainTokens().addToString(tupleRange); + + ExplainTokens distanceRank; + try { + @Nullable var vector = distanceRankValueComparison.getVector(null, null); + int limit = distanceRankValueComparison.getLimit(null, null); + distanceRank = + new ExplainTokens().addNested(vector == null + ? 
new ExplainTokens().addKeyword("null") + : new ExplainTokens().addToString(vector)); + distanceRank.addKeyword(distanceRankValueComparison.getType().name()).addWhitespace().addToString(limit); + } catch (final Comparisons.EvaluationContextRequiredException e) { + distanceRank = + new ExplainTokens().addNested(distanceRankValueComparison.explain().getExplainTokens()); + } + + return ExplainTokensWithPrecedence.of(prefix.addOptionalWhitespace().addToString(":{").addOptionalWhitespace() + .addNested(distanceRank).addOptionalWhitespace().addToString("}:") + .addOptionalWhitespace().addNested(vectorIndexScanOptions.explain().getExplainTokens())); + } + + @SuppressWarnings("checkstyle:VariableDeclarationUsageDistance") + @Override + public void getPlannerGraphDetails(@Nonnull final ImmutableList.Builder detailsBuilder, + @Nonnull final ImmutableMap.Builder attributeMapBuilder) { + @Nullable TupleRange tupleRange = prefixScanComparisons.toTupleRangeWithoutContext(); + if (tupleRange != null) { + detailsBuilder.add("prefix: " + tupleRange.getLowEndpoint().toString(false) + "{{plow}}, {{phigh}}" + tupleRange.getHighEndpoint().toString(true)); + attributeMapBuilder.put("plow", Attribute.gml(tupleRange.getLow() == null ? "-∞" : tupleRange.getLow().toString())); + attributeMapBuilder.put("phigh", Attribute.gml(tupleRange.getHigh() == null ? 
"∞" : tupleRange.getHigh().toString())); + } else { + detailsBuilder.add("prefix comparisons: {{pcomparisons}}"); + attributeMapBuilder.put("pcomparisons", Attribute.gml(prefixScanComparisons.toString())); + } + + try { + @Nullable var vector = distanceRankValueComparison.getVector(null, null); + int limit = distanceRankValueComparison.getLimit(null, null); + detailsBuilder.add("distanceRank: {{vector}} {{type}} {{limit}}"); + attributeMapBuilder.put("vector", Attribute.gml(String.valueOf(vector))); + attributeMapBuilder.put("type", Attribute.gml(distanceRankValueComparison.getType())); + attributeMapBuilder.put("limit", Attribute.gml(limit)); + } catch (final Comparisons.EvaluationContextRequiredException e) { + detailsBuilder.add("distanceRank: {{comparison}}"); + attributeMapBuilder.put("comparison", Attribute.gml(distanceRankValueComparison)); + } + + detailsBuilder.add("scan options: {{scanoptions}}"); + attributeMapBuilder.put("scanoptions", Attribute.gml(vectorIndexScanOptions.toString())); + } + + @Nonnull + @Override + public Set getCorrelatedTo() { + final ImmutableSet.Builder correlatedToBuilder = ImmutableSet.builder(); + correlatedToBuilder.addAll(prefixScanComparisons.getCorrelatedTo()); + correlatedToBuilder.addAll(distanceRankValueComparison.getCorrelatedTo()); + return correlatedToBuilder.build(); + } + + @Nonnull + @Override + public IndexScanParameters rebase(@Nonnull final AliasMap translationMap) { + return translateCorrelations(TranslationMap.rebaseWithAliasMap(translationMap), false); + } + + @Override + @SuppressWarnings("PMD.CompareObjectsWithEquals") + public boolean semanticEquals(@Nullable final Object other, @Nonnull final AliasMap aliasMap) { + if (this == other) { + return true; + } + if (other == null || getClass() != other.getClass()) { + return false; + } + + final VectorIndexScanComparisons that = (VectorIndexScanComparisons)other; + + if (!prefixScanComparisons.semanticEquals(that.prefixScanComparisons, aliasMap)) { + return 
false; + } + + if (!distanceRankValueComparison.semanticEquals(that.distanceRankValueComparison, aliasMap)) { + return false; + } + return vectorIndexScanOptions.equals(that.vectorIndexScanOptions); + } + + @Override + public int semanticHashCode() { + int hashCode = prefixScanComparisons.semanticHashCode(); + hashCode = 31 * hashCode + distanceRankValueComparison.semanticHashCode(); + return 31 * hashCode + vectorIndexScanOptions.hashCode(); + } + + @Nonnull + @Override + @SuppressWarnings("PMD.CompareObjectsWithEquals") + public IndexScanParameters translateCorrelations(@Nonnull final TranslationMap translationMap, + final boolean shouldSimplifyValues) { + final ScanComparisons translatedPrefixScanComparisons = + prefixScanComparisons.translateCorrelations(translationMap, shouldSimplifyValues); + + final DistanceRankValueComparison translatedDistanceRankValueComparison = + distanceRankValueComparison.translateCorrelations(translationMap, shouldSimplifyValues); + + if (translatedPrefixScanComparisons != prefixScanComparisons || + translatedDistanceRankValueComparison != distanceRankValueComparison) { + return withComparisonsAndOptions(translatedPrefixScanComparisons, translatedDistanceRankValueComparison, + vectorIndexScanOptions); + } + return this; + } + + @Nonnull + protected VectorIndexScanComparisons withComparisonsAndOptions(@Nonnull final ScanComparisons prefixScanComparisons, + @Nonnull final DistanceRankValueComparison distanceRankValueComparison, + @Nonnull final VectorIndexScanOptions vectorIndexScanOptions) { + return new VectorIndexScanComparisons(prefixScanComparisons, distanceRankValueComparison, + vectorIndexScanOptions); + } + + @Override + public String toString() { + return "BY_VALUE(VECTOR):" + prefixScanComparisons + ":" + distanceRankValueComparison + ":" + vectorIndexScanOptions; + } + + @Override + @SpotBugsSuppressWarnings("EQ_UNUSUAL") + @SuppressWarnings("EqualsWhichDoesntCheckParameterClass") + public boolean equals(final Object o) { + 
return semanticEquals(o, AliasMap.emptyMap()); + } + + @Override + public int hashCode() { + return semanticHashCode(); + } + + @Nonnull + @Override + public PVectorIndexScanComparisons toProto(@Nonnull final PlanSerializationContext serializationContext) { + final PVectorIndexScanComparisons.Builder builder = PVectorIndexScanComparisons.newBuilder(); + builder.setPrefixScanComparisons(prefixScanComparisons.toProto(serializationContext)); + builder.setDistanceRankValueComparison(distanceRankValueComparison.toProto(serializationContext)); + builder.setVectorIndexScanOptions(vectorIndexScanOptions.toProto(serializationContext)); + return builder.build(); + } + + @Nonnull + @Override + public PIndexScanParameters toIndexScanParametersProto(@Nonnull final PlanSerializationContext serializationContext) { + return PIndexScanParameters.newBuilder().setVectorIndexScanComparisons(toProto(serializationContext)).build(); + } + + @Nonnull + public static VectorIndexScanComparisons fromProto(@Nonnull final PlanSerializationContext serializationContext, + @Nonnull final PVectorIndexScanComparisons vectorIndexScanComparisonsProto) { + return new VectorIndexScanComparisons(ScanComparisons.fromProto(serializationContext, + Objects.requireNonNull(vectorIndexScanComparisonsProto.getPrefixScanComparisons())), + DistanceRankValueComparison.fromProto(serializationContext, Objects.requireNonNull(vectorIndexScanComparisonsProto.getDistanceRankValueComparison())), + VectorIndexScanOptions.fromProto(Objects.requireNonNull(vectorIndexScanComparisonsProto.getVectorIndexScanOptions()))); + } + + @Nonnull + public static VectorIndexScanComparisons byDistance(@Nullable ScanComparisons prefixScanComparisons, + @Nonnull final DistanceRankValueComparison distanceRankValueComparison, + @Nonnull VectorIndexScanOptions vectorIndexScanOptions) { + if (prefixScanComparisons == null) { + prefixScanComparisons = ScanComparisons.EMPTY; + } + + return new VectorIndexScanComparisons(prefixScanComparisons, 
distanceRankValueComparison, + vectorIndexScanOptions); + } + + /** + * Deserializer. + */ + @AutoService(PlanDeserializer.class) + public static class Deserializer implements PlanDeserializer { + @Nonnull + @Override + public Class getProtoMessageClass() { + return PVectorIndexScanComparisons.class; + } + + @Nonnull + @Override + public VectorIndexScanComparisons fromProto(@Nonnull final PlanSerializationContext serializationContext, + @Nonnull final PVectorIndexScanComparisons vectorIndexScanComparisonsProto) { + return VectorIndexScanComparisons.fromProto(serializationContext, vectorIndexScanComparisonsProto); + } + } +} diff --git a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/VectorIndexScanOptions.java b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/VectorIndexScanOptions.java new file mode 100644 index 0000000000..3b8dedb5c8 --- /dev/null +++ b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/VectorIndexScanOptions.java @@ -0,0 +1,259 @@ +/* + * VectorIndexScanOptions.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.apple.foundationdb.record.provider.foundationdb; + +import com.apple.foundationdb.record.PlanDeserializer; +import com.apple.foundationdb.record.PlanHashable; +import com.apple.foundationdb.record.PlanSerializable; +import com.apple.foundationdb.record.PlanSerializationContext; +import com.apple.foundationdb.record.metadata.expressions.LiteralKeyExpression; +import com.apple.foundationdb.record.planprotos.PVectorIndexScanOptions; +import com.apple.foundationdb.record.planprotos.PVectorIndexScanOptions.POptionEntry; +import com.apple.foundationdb.record.query.plan.explain.DefaultExplainFormatter; +import com.apple.foundationdb.record.query.plan.explain.ExplainTokens; +import com.apple.foundationdb.record.query.plan.explain.ExplainTokensWithPrecedence; +import com.google.auto.service.AutoService; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Maps; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.Collections; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; + +public final class VectorIndexScanOptions implements PlanHashable, PlanSerializable { + public static final OptionKey HNSW_EF_SEARCH = new OptionKey<>("hnswEfSearch", Integer.class); + public static final OptionKey HNSW_RETURN_VECTORS = new OptionKey<>("hnswReturnVectors", Boolean.class); + + private static final VectorIndexScanOptions EMPTY = new VectorIndexScanOptions(ImmutableMap.of()); + + private static final Map> optionsNameMap = + ImmutableMap.of(HNSW_EF_SEARCH.getOptionName(), HNSW_EF_SEARCH, + HNSW_RETURN_VECTORS.getOptionName(), HNSW_RETURN_VECTORS); + + @Nonnull + private final Map, Object> optionsMap; + + private VectorIndexScanOptions(@Nonnull final Map, Object> optionsMap) { + this.optionsMap = optionsMap; + } + + public boolean containsOption(@Nonnull final OptionKey key) { + return optionsMap.containsKey(key); + } + + @Nullable + public T getOption(@Nonnull final 
OptionKey key) { + return key.getClazz().cast(optionsMap.get(key)); + } + + @Nonnull + public VectorIndexScanOptions.Builder toBuilder() { + return new Builder(optionsMap); + } + + @Override + public int planHash(@Nonnull final PlanHashMode hashMode) { + return PlanHashable.objectPlanHash(hashMode, optionsMap); + } + + @Nonnull + @Override + public PVectorIndexScanOptions toProto(@Nonnull final PlanSerializationContext serializationContext) { + final PVectorIndexScanOptions.Builder scanOptionsBuilder = PVectorIndexScanOptions.newBuilder(); + for (final Map.Entry, Object> entry : optionsMap.entrySet()) { + scanOptionsBuilder.addOptionEntries( + POptionEntry.newBuilder() + .setKey(entry.getKey().getOptionName()) + .setValue(LiteralKeyExpression.toProtoValue(entry.getValue()))) + .build(); + } + + return scanOptionsBuilder.build(); + } + + @Nonnull + public ExplainTokensWithPrecedence explain() { + final var explainTokens = + new ExplainTokens().addSequence(() -> new ExplainTokens().addCommaAndWhiteSpace(), + optionsMap.entrySet() + .stream() + .map(entry -> + new ExplainTokens().addIdentifier(entry.getKey().getOptionName()).addKeyword(":") + .addWhitespace().addToString(entry.getValue())) + .collect(Collectors.toList())); + + return ExplainTokensWithPrecedence.of(new ExplainTokens().addOptionalWhitespace().addOpeningSquareBracket() + .addNested(explainTokens).addOptionalWhitespace().addClosingSquareBracket()); + } + + @Override + public boolean equals(final Object o) { + if (o == this) { + return true; + } + if (!(o instanceof VectorIndexScanOptions)) { + return false; + } + final VectorIndexScanOptions that = (VectorIndexScanOptions)o; + return Objects.equals(optionsMap, that.optionsMap); + } + + @Override + public int hashCode() { + return Objects.hashCode(optionsMap); + } + + @Override + public String toString() { + return explain().getExplainTokens().render(DefaultExplainFormatter.forDebugging()).toString(); + } + + @Nonnull + public static 
VectorIndexScanOptions.Builder builder() { + return new Builder(); + } + + @Nonnull + public static VectorIndexScanOptions empty() { + return EMPTY; + } + + @Nonnull + public static VectorIndexScanOptions fromProto(@Nonnull final PVectorIndexScanOptions vectorIndexScanOptionsProto) { + final Map, Object> optionsMap = + Maps.newHashMapWithExpectedSize(vectorIndexScanOptionsProto.getOptionEntriesCount()); + for (int i = 0; i < vectorIndexScanOptionsProto.getOptionEntriesCount(); i ++) { + final POptionEntry optionEntryProto = vectorIndexScanOptionsProto.getOptionEntries(i); + optionsMap.put(Objects.requireNonNull(optionsNameMap.get(optionEntryProto.getKey())), + LiteralKeyExpression.fromProtoValue(optionEntryProto.getValue())); + } + return new VectorIndexScanOptions(optionsMap); + } + + public static class Builder { + @Nonnull + private final Map, Object> optionsMap; + + public Builder() { + this(ImmutableMap.of()); + } + + public Builder(@Nonnull final Map, Object> optionsMap) { + // creating an ordinary hashmap here since it is not null-averse + this.optionsMap = Maps.newHashMap(optionsMap); + } + + @Nonnull + public Builder putOption(@Nonnull final OptionKey key, T value) { + optionsMap.put(key, value); + return this; + } + + @Nonnull + public Builder removeOption(@Nonnull final OptionKey key) { + optionsMap.remove(key); + return this; + } + + @Nonnull + public VectorIndexScanOptions build() { + if (optionsMap.isEmpty()) { + return EMPTY; + } + return new VectorIndexScanOptions(Collections.unmodifiableMap(Maps.newHashMap(optionsMap))); + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (!(o instanceof Builder)) { + return false; + } + final Builder builder = (Builder)o; + return Objects.equals(optionsMap, builder.optionsMap); + } + + @Override + public int hashCode() { + return Objects.hashCode(optionsMap); + } + } + + public static class OptionKey { + @Nonnull + private final String optionName; + @Nonnull + 
private final Class clazz; + + public OptionKey(@Nonnull final String optionName, @Nonnull final Class clazz) { + this.optionName = optionName; + this.clazz = clazz; + } + + @Nonnull + public String getOptionName() { + return optionName; + } + + @Nonnull + public Class getClazz() { + return clazz; + } + + @Override + public boolean equals(final Object o) { + if (!(o instanceof OptionKey)) { + return false; + } + final OptionKey optionKey = (OptionKey)o; + return Objects.equals(optionName, optionKey.optionName) && clazz == optionKey.clazz; + } + + @Override + public int hashCode() { + return Objects.hash(optionName); + } + } + + /** + * Deserializer. + */ + @AutoService(PlanDeserializer.class) + public static class Deserializer implements PlanDeserializer { + @Nonnull + @Override + public Class getProtoMessageClass() { + return PVectorIndexScanOptions.class; + } + + @Nonnull + @Override + public VectorIndexScanOptions fromProto(@Nonnull final PlanSerializationContext serializationContext, + @Nonnull final PVectorIndexScanOptions vectorIndexScanOptionsProto) { + return VectorIndexScanOptions.fromProto(vectorIndexScanOptionsProto); + } + } +} diff --git a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexHelper.java b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexHelper.java new file mode 100644 index 0000000000..a0be8a86c0 --- /dev/null +++ b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexHelper.java @@ -0,0 +1,156 @@ +/* + * VectorIndexHelper.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.record.provider.foundationdb.indexes; + +import com.apple.foundationdb.annotation.API; +import com.apple.foundationdb.async.hnsw.Config; +import com.apple.foundationdb.async.hnsw.Config.ConfigBuilder; +import com.apple.foundationdb.async.hnsw.HNSW; +import com.apple.foundationdb.linear.Metric; +import com.apple.foundationdb.record.logging.LogMessageKeys; +import com.apple.foundationdb.record.metadata.Index; +import com.apple.foundationdb.record.metadata.IndexOptions; +import com.apple.foundationdb.record.metadata.MetaDataException; +import com.apple.foundationdb.record.provider.common.StoreTimer; + +import javax.annotation.Nonnull; + +/** + * Helper functions for index maintainers that use a {@link HNSW}. + */ +@API(API.Status.EXPERIMENTAL) +public final class VectorIndexHelper { + private VectorIndexHelper() { + } + + /** + * Parse standard options into {@link Config}. 
+ * @param index the index definition to get options from + * @return parsed config options + */ + @Nonnull + public static Config getConfig(@Nonnull final Index index) { + final ConfigBuilder builder = HNSW.newConfigBuilder(); + final String hnswRandomSeedOption = index.getOption(IndexOptions.HNSW_DETERMINISTIC_SEEDING); + if (hnswRandomSeedOption != null) { + builder.setDeterministicSeeding(Boolean.parseBoolean(hnswRandomSeedOption)); + } + final String hnswMetricOption = index.getOption(IndexOptions.HNSW_METRIC); + if (hnswMetricOption != null) { + builder.setMetric(Metric.valueOf(hnswMetricOption)); + } + final String hnswNumDimensionsOption = index.getOption(IndexOptions.HNSW_NUM_DIMENSIONS); + if (hnswNumDimensionsOption == null) { + throw new MetaDataException("need to specify the number of dimensions", + LogMessageKeys.INDEX_NAME, index.getName()); + } + final int numDimensions = Integer.parseInt(hnswNumDimensionsOption); + + final String hnswUseInliningOption = index.getOption(IndexOptions.HNSW_USE_INLINING); + if (hnswUseInliningOption != null) { + builder.setUseInlining(Boolean.parseBoolean(hnswUseInliningOption)); + } + final String hnswMOption = index.getOption(IndexOptions.HNSW_M); + if (hnswMOption != null) { + builder.setM(Integer.parseInt(hnswMOption)); + } + final String hnswMMaxOption = index.getOption(IndexOptions.HNSW_M_MAX); + if (hnswMMaxOption != null) { + builder.setMMax(Integer.parseInt(hnswMMaxOption)); + } + final String hnswMMax0Option = index.getOption(IndexOptions.HNSW_M_MAX_0); + if (hnswMMax0Option != null) { + builder.setMMax0(Integer.parseInt(hnswMMax0Option)); + } + final String hnswEfConstructionOption = index.getOption(IndexOptions.HNSW_EF_CONSTRUCTION); + if (hnswEfConstructionOption != null) { + builder.setEfConstruction(Integer.parseInt(hnswEfConstructionOption)); + } + final String hnswExtendCandidatesOption = index.getOption(IndexOptions.HNSW_EXTEND_CANDIDATES); + if (hnswExtendCandidatesOption != null) { + 
builder.setExtendCandidates(Boolean.parseBoolean(hnswExtendCandidatesOption)); + } + final String hnswKeepPrunedConnectionsOption = index.getOption(IndexOptions.HNSW_KEEP_PRUNED_CONNECTIONS); + if (hnswKeepPrunedConnectionsOption != null) { + builder.setKeepPrunedConnections(Boolean.parseBoolean(hnswKeepPrunedConnectionsOption)); + } + final String hnswSampleVectorStatsProbabilityOption = index.getOption(IndexOptions.HNSW_SAMPLE_VECTOR_STATS_PROBABILITY); + if (hnswSampleVectorStatsProbabilityOption != null) { + builder.setSampleVectorStatsProbability(Double.parseDouble(hnswSampleVectorStatsProbabilityOption)); + } + final String hnswMaintainStatsProbabilityOption = index.getOption(IndexOptions.HNSW_MAINTAIN_STATS_PROBABILITY); + if (hnswMaintainStatsProbabilityOption != null) { + builder.setMaintainStatsProbability(Double.parseDouble(hnswMaintainStatsProbabilityOption)); + } + final String hnswStatsThresholdOption = index.getOption(IndexOptions.HNSW_STATS_THRESHOLD); + if (hnswStatsThresholdOption != null) { + builder.setStatsThreshold(Integer.parseInt(hnswStatsThresholdOption)); + } + final String hnswUseRaBitQOption = index.getOption(IndexOptions.HNSW_USE_RABITQ); + if (hnswUseRaBitQOption != null) { + builder.setUseRaBitQ(Boolean.parseBoolean(hnswUseRaBitQOption)); + } + final String hnswRaBitQNumExBitsOption = index.getOption(IndexOptions.HNSW_RABITQ_NUM_EX_BITS); + if (hnswRaBitQNumExBitsOption != null) { + builder.setRaBitQNumExBits(Integer.parseInt(hnswRaBitQNumExBitsOption)); + } + final String hnswMaxNumConcurrentNodeFetchesOption = index.getOption(IndexOptions.HNSW_MAX_NUM_CONCURRENT_NODE_FETCHES); + if (hnswMaxNumConcurrentNodeFetchesOption != null) { + builder.setMaxNumConcurrentNodeFetches(Integer.parseInt(hnswMaxNumConcurrentNodeFetchesOption)); + } + final String hnswMaxNumConcurrentNeighborhoodFetchesOption = index.getOption(IndexOptions.HNSW_MAX_NUM_CONCURRENT_NEIGHBORHOOD_FETCHES); + if (hnswMaxNumConcurrentNeighborhoodFetchesOption != null) { + 
builder.setMaxNumConcurrentNeighborhoodFetches(Integer.parseInt(hnswMaxNumConcurrentNeighborhoodFetchesOption)); + } + return builder.build(numDimensions); + } + + /** + * Instrumentation events specific to vector index maintenance. + */ + public enum Events implements StoreTimer.DetailEvent { + VECTOR_SCAN("scanning the partition of a vector index"), + VECTOR_SKIP_SCAN("skip scan the prefix tuples of a vector index scan"); + + private final String title; + private final String logKey; + + Events(String title, String logKey) { + this.title = title; + this.logKey = (logKey != null) ? logKey : StoreTimer.DetailEvent.super.logKey(); + } + + Events(String title) { + this(title, null); + } + + @Override + public String title() { + return title; + } + + @Override + @Nonnull + public String logKey() { + return this.logKey; + } + } +} diff --git a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexMaintainer.java b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexMaintainer.java new file mode 100644 index 0000000000..e5cb5bc996 --- /dev/null +++ b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexMaintainer.java @@ -0,0 +1,560 @@ +/* + * VectorIndexMaintainer.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.record.provider.foundationdb.indexes; + +import com.apple.foundationdb.KeyValue; +import com.apple.foundationdb.ReadTransaction; +import com.apple.foundationdb.Transaction; +import com.apple.foundationdb.annotation.API; +import com.apple.foundationdb.async.AsyncUtil; +import com.apple.foundationdb.async.hnsw.Config; +import com.apple.foundationdb.async.hnsw.HNSW; +import com.apple.foundationdb.async.hnsw.Node; +import com.apple.foundationdb.async.hnsw.NodeReference; +import com.apple.foundationdb.async.hnsw.OnReadListener; +import com.apple.foundationdb.async.hnsw.OnWriteListener; +import com.apple.foundationdb.async.hnsw.ResultEntry; +import com.apple.foundationdb.linear.RealVector; +import com.apple.foundationdb.record.CursorStreamingMode; +import com.apple.foundationdb.record.EndpointType; +import com.apple.foundationdb.record.ExecuteProperties; +import com.apple.foundationdb.record.IndexEntry; +import com.apple.foundationdb.record.IndexScanType; +import com.apple.foundationdb.record.PipelineOperation; +import com.apple.foundationdb.record.RecordCoreException; +import com.apple.foundationdb.record.RecordCursor; +import com.apple.foundationdb.record.RecordCursorContinuation; +import com.apple.foundationdb.record.RecordCursorProto; +import com.apple.foundationdb.record.ScanProperties; +import com.apple.foundationdb.record.TupleRange; +import com.apple.foundationdb.record.cursors.AsyncLockCursor; +import com.apple.foundationdb.record.cursors.ChainedCursor; +import com.apple.foundationdb.record.cursors.LazyCursor; +import com.apple.foundationdb.record.cursors.ListCursor; +import com.apple.foundationdb.record.locking.LockIdentifier; +import com.apple.foundationdb.record.metadata.Key; +import com.apple.foundationdb.record.metadata.expressions.KeyExpression; +import 
com.apple.foundationdb.record.metadata.expressions.KeyWithValueExpression; +import com.apple.foundationdb.record.provider.common.StoreTimer; +import com.apple.foundationdb.record.provider.foundationdb.FDBIndexableRecord; +import com.apple.foundationdb.record.provider.foundationdb.FDBStoreTimer; +import com.apple.foundationdb.record.provider.foundationdb.IndexMaintainerState; +import com.apple.foundationdb.record.provider.foundationdb.IndexScanBounds; +import com.apple.foundationdb.record.provider.foundationdb.KeyValueCursor; +import com.apple.foundationdb.record.provider.foundationdb.VectorIndexScanBounds; +import com.apple.foundationdb.record.provider.foundationdb.VectorIndexScanOptions; +import com.apple.foundationdb.record.query.QueryToKeyMatcher; +import com.apple.foundationdb.subspace.Subspace; +import com.apple.foundationdb.tuple.ByteArrayUtil2; +import com.apple.foundationdb.tuple.Tuple; +import com.apple.foundationdb.tuple.TupleHelpers; +import com.google.common.base.Verify; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; +import com.google.protobuf.ByteString; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.Message; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.function.Function; + +/** + * An index maintainer for keeping an {@link HNSW}. + */ +@API(API.Status.EXPERIMENTAL) +public class VectorIndexMaintainer extends StandardIndexMaintainer { + @Nonnull + private final Config config; + + public VectorIndexMaintainer(IndexMaintainerState state) { + super(state); + this.config = VectorIndexHelper.getConfig(state.index); + } + + @Nonnull + public Config getConfig() { + return config; + } + + /** + * Scan the vector index. 
+ * @param scanBounds the {@link VectorIndexScanBounds bounds} of the scan to perform + * @param continuation any continuation from a previous scan invocation + * @param scanProperties skip, limit and other properties of the scan + * @return a {@link RecordCursor} of index entries + */ + @Nonnull + @Override + @SuppressWarnings("resource") + public RecordCursor scan(@Nonnull final IndexScanBounds scanBounds, @Nullable final byte[] continuation, + @Nonnull final ScanProperties scanProperties) { + if (!scanBounds.getScanType().equals(IndexScanType.BY_DISTANCE)) { + throw new RecordCoreException("Can only scan vector index by value."); + } + if (!(scanBounds instanceof VectorIndexScanBounds)) { + throw new RecordCoreException("Need proper vector index scan bounds."); + } + final VectorIndexScanBounds vectorIndexScanBounds = (VectorIndexScanBounds)scanBounds; + + final KeyWithValueExpression keyWithValueExpression = getKeyWithValueExpression(state.index.getRootExpression()); + final int prefixSize = keyWithValueExpression.getSplitPoint(); + + final ExecuteProperties executeProperties = scanProperties.getExecuteProperties(); + final ScanProperties innerScanProperties = scanProperties.with(ExecuteProperties::clearSkipAndLimit); + final Subspace indexSubspace = getIndexSubspace(); + final FDBStoreTimer timer = Objects.requireNonNull(state.context.getTimer()); + + // + // If there is a {@code prefix > 0}, then we model the scan as a flatmap over the distinct prefixes as the outer + // and the correlated HNSW search as the inner. + // + if (prefixSize > 0) { + // + // Skip-scan through the prefixes in a way that we only consider each distinct prefix. That skip scan + // forms the outer of a join with an inner that searches the R-tree for that prefix using the + // spatial predicates of the scan bounds. 
+ // + return RecordCursor.flatMapPipelined(prefixSkipScan(prefixSize, timer, vectorIndexScanBounds, innerScanProperties), + (prefixTuple, innerContinuation) -> { + Verify.verify(prefixTuple.size() == prefixSize); + final Subspace hnswSubspace = indexSubspace.subspace(prefixTuple); + + return scanSinglePartition(prefixTuple, innerContinuation, hnswSubspace, + timer, vectorIndexScanBounds, scanProperties); + }, + continuation, + state.store.getPipelineSize(PipelineOperation.INDEX_TO_RECORD)) + .skipThenLimit(executeProperties.getSkip(), executeProperties.getReturnedRowLimit()); + } else { + // + // As {@code prefix == 0}, there only is exactly one prefix ({@code null}). While it is possible to also + // just do a flatmap over some non-existing outer, it's probably more efficient to just do a plain scan + // of the HNSW here. + // + return scanSinglePartition(null, continuation, + indexSubspace, timer, vectorIndexScanBounds, scanProperties) + .skipThenLimit(executeProperties.getSkip(), executeProperties.getReturnedRowLimit()); + } + } + + /** + * Scan one partition of the vector index, i.e. the one HNSW that holds the data for the partition identified + * by {@code prefixTuple}. 
+ * @param prefixTuple the tuple identifying the partition + * @param continuation the continuation for this scan or {@code null} if this is the first execution + * @param hnswSubspace the subspace where the HNSW resides + * @param timer the times + * @param vectorIndexScanBounds the bounds for this scan + * @param scanProperties the scan properties for this scan + * @return a {@link RecordCursor} returning the index entries for this scan + */ + @Nonnull + @SuppressWarnings("resource") + private RecordCursor scanSinglePartition(@Nullable final Tuple prefixTuple, + @Nullable final byte[] continuation, + @Nonnull final Subspace hnswSubspace, + @Nonnull final FDBStoreTimer timer, + @Nonnull final VectorIndexScanBounds vectorIndexScanBounds, + @Nonnull final ScanProperties scanProperties) { + if (continuation != null) { + final RecordCursorProto.VectorIndexScanContinuation parsedContinuation = + Continuation.fromBytes(continuation); + final ImmutableList.Builder indexEntriesBuilder = ImmutableList.builder(); + for (int i = 0; i < parsedContinuation.getIndexEntriesCount(); i++) { + final RecordCursorProto.VectorIndexScanContinuation.IndexEntry indexEntryProto = + parsedContinuation.getIndexEntries(i); + indexEntriesBuilder.add(new IndexEntry(state.index, + Tuple.fromBytes(indexEntryProto.getKey().toByteArray()), + Tuple.fromBytes(indexEntryProto.getValue().toByteArray()))); + } + final ImmutableList indexEntries = indexEntriesBuilder.build(); + return new ListCursor<>(indexEntries, parsedContinuation.getInnerContinuation().toByteArray()) + .mapResult(result -> + result.withContinuation(new Continuation(indexEntries, result.getContinuation()))); + } + + final HNSW hnsw = new HNSW(hnswSubspace, getExecutor(), getConfig(), + OnWriteListener.NOOP, new OnRead(timer)); + final ReadTransaction transaction = + state.context.readTransaction(scanProperties.getExecuteProperties().getIsolationLevel().isSnapshot()); + return new LazyCursor<>( + state.context.acquireReadLock(new 
LockIdentifier(hnswSubspace)) + .thenApply(lock -> + new AsyncLockCursor<>(lock, + new LazyCursor<>( + kNearestNeighborSearch(prefixTuple, hnsw, transaction, + vectorIndexScanBounds), + getExecutor()))), + state.context.getExecutor()); + } + + @SuppressWarnings({"resource", "checkstyle:MethodName"}) + @Nonnull + private CompletableFuture> + kNearestNeighborSearch(@Nullable final Tuple prefixTuple, + @Nonnull final HNSW hnsw, + @Nonnull final ReadTransaction transaction, + @Nonnull final VectorIndexScanBounds vectorIndexScanBounds) { + return hnsw.kNearestNeighborsSearch(transaction, vectorIndexScanBounds.getAdjustedLimit(), + efSearch(vectorIndexScanBounds), returnVectors(hnsw.getConfig(), vectorIndexScanBounds), + Objects.requireNonNull(vectorIndexScanBounds.getQueryVector())) + .thenApply(resultEntries -> { + final ImmutableList.Builder nearestNeighborEntriesBuilder = ImmutableList.builder(); + for (final ResultEntry nearestNeighbor : resultEntries) { + nearestNeighborEntriesBuilder.add(toIndexEntry(prefixTuple, nearestNeighbor)); + } + final ImmutableList nearestNeighborsEntries = nearestNeighborEntriesBuilder.build(); + return new ListCursor<>(getExecutor(), nearestNeighborsEntries, 0) + .mapResult(result -> { + final RecordCursorContinuation continuation = result.getContinuation(); + if (continuation.isEnd()) { + return result; + } + return result.withContinuation(new Continuation(nearestNeighborsEntries, continuation)); + }); + }); + } + + @Nonnull + private IndexEntry toIndexEntry(@Nullable final Tuple prefixTuple, @Nonnull final ResultEntry resultEntry) { + final List keyItems = Lists.newArrayList(); + if (prefixTuple != null) { + keyItems.addAll(prefixTuple.getItems()); + } + keyItems.addAll(resultEntry.getPrimaryKey().getItems()); + final List valueItems = Lists.newArrayList(); + final RealVector vector = resultEntry.getVector(); + valueItems.add(vector == null ? 
null : resultEntry.getVector().getRawData()); + return new IndexEntry(state.index, Tuple.fromList(keyItems), + Tuple.fromList(valueItems)); + } + + @Nonnull + @Override + public RecordCursor scan(@Nonnull final IndexScanType scanType, @Nonnull final TupleRange range, + @Nullable final byte[] continuation, @Nonnull final ScanProperties scanProperties) { + throw new IllegalStateException("index maintainer does not support this scan api"); + } + + @Nonnull + private Function> prefixSkipScan(final int prefixSize, + @Nonnull final StoreTimer timer, + @Nonnull final VectorIndexScanBounds vectorIndexScanBounds, + @Nonnull final ScanProperties innerScanProperties) { + Verify.verify(prefixSize > 0); + return outerContinuation -> timer.instrument(MultiDimensionalIndexHelper.Events.MULTIDIMENSIONAL_SKIP_SCAN, + new ChainedCursor<>(state.context, + lastKeyOptional -> nextPrefixTuple(vectorIndexScanBounds.getPrefixRange(), + prefixSize, lastKeyOptional.orElse(null), innerScanProperties), + Tuple::pack, + Tuple::fromBytes, + outerContinuation, + innerScanProperties)); + } + + @SuppressWarnings({"resource", "PMD.CloseResource"}) + private CompletableFuture> nextPrefixTuple(@Nonnull final TupleRange prefixRange, + final int prefixSize, + @Nullable final Tuple lastPrefixTuple, + @Nonnull final ScanProperties scanProperties) { + final Subspace indexSubspace = getIndexSubspace(); + final KeyValueCursor cursor; + if (lastPrefixTuple == null) { + cursor = KeyValueCursor.Builder.withSubspace(indexSubspace) + .setContext(state.context) + .setRange(prefixRange) + .setContinuation(null) + .setScanProperties(scanProperties.setStreamingMode(CursorStreamingMode.ITERATOR) + .with(innerExecuteProperties -> innerExecuteProperties.setReturnedRowLimit(1))) + .build(); + } else { + KeyValueCursor.Builder builder = KeyValueCursor.Builder.withSubspace(indexSubspace) + .setContext(state.context) + .setContinuation(null) + .setScanProperties(scanProperties) + 
.setScanProperties(scanProperties.setStreamingMode(CursorStreamingMode.ITERATOR) + .with(innerExecuteProperties -> innerExecuteProperties.setReturnedRowLimit(1))); + + cursor = builder.setLow(indexSubspace.pack(lastPrefixTuple), EndpointType.RANGE_EXCLUSIVE) + .setHigh(prefixRange.getHigh(), prefixRange.getHighEndpoint()) + .build(); + } + + return cursor.onNext().thenApply(next -> { + cursor.close(); + if (next.hasNext()) { + final KeyValue kv = Objects.requireNonNull(next.get()); + return Optional.of(TupleHelpers.subTuple(indexSubspace.unpack(kv.getKey()), 0, prefixSize)); + } + return Optional.empty(); + }); + } + + @Override + protected CompletableFuture updateIndexKeys(@Nonnull final FDBIndexableRecord savedRecord, + final boolean remove, + @Nonnull final List indexEntries) { + Verify.verify(indexEntries.size() == 1); + final KeyWithValueExpression keyWithValueExpression = getKeyWithValueExpression(state.index.getRootExpression()); + final int prefixSize = keyWithValueExpression.getColumnSize(); + final Subspace indexSubspace = getIndexSubspace(); + final var indexEntry = indexEntries.get(0); + + final byte[] vectorBytes = indexEntry.getValue().getBytes(0); + if (vectorBytes == null) { + // + // If there is no vector (e.g. vector is NULL), we don't even need to index it. 
+ // + return AsyncUtil.DONE; + } + + final Tuple prefixKey = indexEntry.getKey(); + final Subspace rtSubspace; + if (prefixSize > 0) { + rtSubspace = indexSubspace.subspace(prefixKey); + } else { + rtSubspace = indexSubspace; + } + return state.context.doWithWriteLock(new LockIdentifier(rtSubspace), () -> { + final List primaryKeyParts = Lists.newArrayList(savedRecord.getPrimaryKey().getItems()); + state.index.trimPrimaryKey(primaryKeyParts); + final Tuple trimmedPrimaryKey = Tuple.fromList(primaryKeyParts); + final FDBStoreTimer timer = Objects.requireNonNull(getTimer()); + final HNSW hnsw = + new HNSW(rtSubspace, getExecutor(), getConfig(), new OnWrite(timer), OnReadListener.NOOP); + if (remove) { + throw new UnsupportedOperationException("not implemented"); + } else { + return hnsw.insert(state.transaction, trimmedPrimaryKey, + RealVector.fromBytes(vectorBytes)); + } + }); + } + + @Override + public boolean canDeleteWhere(@Nonnull final QueryToKeyMatcher matcher, @Nonnull final Key.Evaluated evaluated) { + if (!super.canDeleteWhere(matcher, evaluated)) { + return false; + } + return evaluated.size() <= getKeyWithValueExpression(state.index.getRootExpression()).getColumnSize(); + } + + @Override + public CompletableFuture deleteWhere(@Nonnull final Transaction tr, @Nonnull final Tuple prefix) { + Verify.verify(getKeyWithValueExpression(state.index.getRootExpression()).getColumnSize() >= prefix.size()); + return super.deleteWhere(tr, prefix); + } + + /** + * TODO. 
+ */ + @Nonnull + private static KeyWithValueExpression getKeyWithValueExpression(@Nonnull final KeyExpression root) { + if (root instanceof KeyWithValueExpression) { + return (KeyWithValueExpression)root; + } + throw new RecordCoreException("structure of vector index is not supported"); + } + + private int efSearch(@Nonnull final VectorIndexScanBounds scanBounds) { + final VectorIndexScanOptions scanOptions = scanBounds.getVectorIndexScanOptions(); + final Integer efSearchOptionValue = scanOptions.getOption(VectorIndexScanOptions.HNSW_EF_SEARCH); + if (efSearchOptionValue != null) { + return efSearchOptionValue; + } + final var k = scanBounds.getAdjustedLimit(); + return Math.min(Math.max(4 * k, 64), Math.max(k, 400)); + } + + private boolean returnVectors(@Nonnull final Config config, @Nonnull final VectorIndexScanBounds scanBounds) { + final VectorIndexScanOptions scanOptions = scanBounds.getVectorIndexScanOptions(); + final Boolean returnVectorsValue = scanOptions.getOption(VectorIndexScanOptions.HNSW_RETURN_VECTORS); + if (returnVectorsValue != null) { + return returnVectorsValue; + } + + // + // If we use RaBitQ, the vectors returned must be reconstructed which means we potentially wasted computation + // resources if the user didn't explicitly ask for it. If RaBitQ is not used, the vectors returned are identical + // to their inserted counterparts. We also already fetched them, so returning them is free. 
+        //
+        return !config.isUseRaBitQ();
+    }
+
+    /**
+     * Read listener that forwards read events to an {@link FDBStoreTimer}.
+     */
+    // NOTE(review): type arguments in this class were lost in extraction; the
+    // {@code Node<? extends NodeReference>} bound is restored from the listener
+    // contract as I understand it -- confirm against OnReadListener.
+    static class OnRead implements OnReadListener {
+        @Nonnull
+        private final FDBStoreTimer timer;
+
+        public OnRead(@Nonnull final FDBStoreTimer timer) {
+            this.timer = timer;
+        }
+
+        /**
+         * Instruments the given future as a {@code VECTOR_SCAN} event.
+         *
+         * @param future the asynchronous read to instrument
+         * @return the future returned by {@code timer.instrument}
+         */
+        @Override
+        public <T extends Node<? extends NodeReference>> CompletableFuture<T> onAsyncRead(@Nonnull CompletableFuture<T> future) {
+            return timer.instrument(VectorIndexHelper.Events.VECTOR_SCAN, future);
+        }
+
+        /**
+         * Counts one node read; layer {@code 0} has its own counter.
+         *
+         * @param layer the layer the node was read from
+         * @param node the node that was read (not inspected here)
+         */
+        @Override
+        public void onNodeRead(final int layer, @Nonnull final Node<? extends NodeReference> node) {
+            if (layer == 0) {
+                timer.increment(FDBStoreTimer.Counts.VECTOR_NODE0_READS);
+            } else {
+                timer.increment(FDBStoreTimer.Counts.VECTOR_NODE_READS);
+            }
+        }
+
+        /**
+         * Records the sizes of a key/value pair that was read.
+         *
+         * @param layer the layer the pair belongs to
+         * @param key the raw key
+         * @param value the raw value; {@code null} is accounted for as zero bytes
+         */
+        @Override
+        public void onKeyValueRead(final int layer, @Nonnull final byte[] key, @Nullable final byte[] value) {
+            final int keyLength = key.length;
+            final int valueLength = value == null ? 0 : value.length;
+
+            final int totalLength = keyLength + valueLength;
+            timer.increment(FDBStoreTimer.Counts.LOAD_INDEX_KEY);
+            timer.increment(FDBStoreTimer.Counts.LOAD_INDEX_KEY_BYTES, keyLength);
+            timer.increment(FDBStoreTimer.Counts.LOAD_INDEX_VALUE_BYTES, valueLength);
+
+            // The *_READ_BYTES counters take the number of bytes, mirroring
+            // OnWrite#onKeyValueWritten; incrementing by one would under-report.
+            if (layer == 0) {
+                timer.increment(FDBStoreTimer.Counts.VECTOR_NODE0_READ_BYTES, totalLength);
+            } else {
+                timer.increment(FDBStoreTimer.Counts.VECTOR_NODE_READ_BYTES, totalLength);
+            }
+        }
+    }
+
+    /**
+     * Write listener that forwards write events to an {@link FDBStoreTimer}.
+     */
+    // NOTE(review): {@code Node<? extends NodeReference>} restored as in OnRead -- confirm.
+    static class OnWrite implements OnWriteListener {
+        @Nonnull
+        private final FDBStoreTimer timer;
+
+        public OnWrite(@Nonnull final FDBStoreTimer timer) {
+            this.timer = timer;
+        }
+
+        /**
+         * Counts one node write; layer {@code 0} has its own counter.
+         *
+         * @param layer the layer the node was written to
+         * @param node the node that was written (not inspected here)
+         */
+        @Override
+        public void onNodeWritten(final int layer, @Nonnull final Node<? extends NodeReference> node) {
+            if (layer == 0) {
+                timer.increment(FDBStoreTimer.Counts.VECTOR_NODE0_WRITES);
+            } else {
+                timer.increment(FDBStoreTimer.Counts.VECTOR_NODE_WRITES);
+            }
+        }
+
+        /**
+         * Records the sizes of a key/value pair that was written.
+         *
+         * @param layer the layer the pair belongs to
+         * @param key the raw key
+         * @param value the raw value
+         */
+        @Override
+        public void onKeyValueWritten(final int layer, @Nonnull final byte[] key, @Nonnull final byte[] value) {
+            final int keyLength = key.length;
+            final int valueLength = value.length;
+
+            final int totalLength = keyLength + valueLength;
+            timer.increment(FDBStoreTimer.Counts.SAVE_INDEX_KEY);
+            timer.increment(FDBStoreTimer.Counts.SAVE_INDEX_KEY_BYTES, keyLength);
+            timer.increment(FDBStoreTimer.Counts.SAVE_INDEX_VALUE_BYTES, valueLength);
+
+            if (layer == 0) {
+                timer.increment(FDBStoreTimer.Counts.VECTOR_NODE0_WRITE_BYTES, totalLength);
+            } else {
+                timer.increment(FDBStoreTimer.Counts.VECTOR_NODE_WRITE_BYTES, totalLength);
+            }
+        }
+    }
+
+    /**
+     * Continuation that carries a list of index entries together with an inner continuation.
+     * It is at its end exactly when the inner continuation is at its end. The serialized
+     * forms are computed lazily and cached.
+     */
+    private static final class Continuation implements RecordCursorContinuation {
+        // NOTE(review): element type restored (stripped in extraction); the entries are
+        // used via getKey()/getValue() below -- confirm it is IndexEntry.
+        @Nonnull
+        private final List<IndexEntry> indexEntries;
+        @Nonnull
+        private final RecordCursorContinuation innerContinuation;
+
+        @Nullable
+        private ByteString cachedByteString;
+        @Nullable
+        private byte[] cachedBytes;
+
+        private Continuation(@Nonnull final List<IndexEntry> indexEntries,
+                             @Nonnull final RecordCursorContinuation innerContinuation) {
+            // defensive, immutable snapshot of the caller's list
+            this.indexEntries = ImmutableList.copyOf(indexEntries);
+            this.innerContinuation = innerContinuation;
+        }
+
+        @Nonnull
+        public List<IndexEntry> getIndexEntries() {
+            return indexEntries;
+        }
+
+        @Nonnull
+        public RecordCursorContinuation getInnerContinuation() {
+            return innerContinuation;
+        }
+
+        /**
+         * Serializes this continuation as a {@code VectorIndexScanContinuation} message.
+         *
+         * @return {@link ByteString#EMPTY} if this continuation is at its end, otherwise the
+         *         (cached) serialized message holding every index entry and the inner continuation
+         */
+        @Nonnull
+        @Override
+        public ByteString toByteString() {
+            if (isEnd()) {
+                return ByteString.EMPTY;
+            }
+
+            if (cachedByteString == null) {
+                final RecordCursorProto.VectorIndexScanContinuation.Builder builder =
+                        RecordCursorProto.VectorIndexScanContinuation.newBuilder();
+                for (final var indexEntry : getIndexEntries()) {
+                    // key and value are serialized separately; writing the key into both
+                    // fields would lose the entry's value on resume
+                    builder.addIndexEntries(RecordCursorProto.VectorIndexScanContinuation.IndexEntry.newBuilder()
+                            .setKey(ByteString.copyFrom(indexEntry.getKey().pack()))
+                            .setValue(ByteString.copyFrom(indexEntry.getValue().pack()))
+                            .build());
+                }
+
+                cachedByteString = builder
+                        .setInnerContinuation(Objects.requireNonNull(innerContinuation.toByteString()))
+                        .build()
+                        .toByteString();
+            }
+            return cachedByteString;
+        }
+
+        /**
+         * Byte-array form of {@link #toByteString()}.
+         *
+         * @return {@code null} if this continuation is at its end, otherwise the (cached) bytes
+         */
+        @Nullable
+        @Override
+        public byte[] toBytes() {
+            if (isEnd()) {
+                return null;
+            }
+            if (cachedBytes == null) {
+                cachedBytes = toByteString().toByteArray();
+            }
+            return cachedBytes;
+        }
+
+        @Override
+        public boolean isEnd() {
+            return getInnerContinuation().isEnd();
+        }
+
+        /**
+         * Parses a serialized {@code VectorIndexScanContinuation}.
+         *
+         * @param continuationBytes the serialized continuation
+         * @return the parsed message
+         * @throws RecordCoreException if the bytes are not a valid message; the raw bytes are
+         *         attached as log info
+         */
+        @Nonnull
+        private static RecordCursorProto.VectorIndexScanContinuation fromBytes(@Nonnull byte[] continuationBytes) {
+            try {
+                return RecordCursorProto.VectorIndexScanContinuation.parseFrom(continuationBytes);
+            } catch (InvalidProtocolBufferException ex) {
+                throw new RecordCoreException("error parsing continuation", ex)
+                        .addLogInfo("raw_bytes", ByteArrayUtil2.loggable(continuationBytes));
+            }
+        }
+    }
+}
diff --git a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexMaintainerFactory.java b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexMaintainerFactory.java
new file mode 100644
index 0000000000..5e164ba943
--- /dev/null
+++ b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexMaintainerFactory.java
@@ -0,0 +1,191 @@
+/*
+ * VectorIndexMaintainerFactory.java
+ *
+ * This source file is part of the FoundationDB open source project
+ *
+ * Copyright 2023 Apple Inc. and the FoundationDB project authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.apple.foundationdb.record.provider.foundationdb.indexes;
+
+import com.apple.foundationdb.annotation.API;
+import com.apple.foundationdb.async.hnsw.Config;
+import com.apple.foundationdb.record.logging.LogMessageKeys;
+import com.apple.foundationdb.record.metadata.Index;
+import com.apple.foundationdb.record.metadata.IndexOptions;
+import com.apple.foundationdb.record.metadata.IndexTypes;
+import com.apple.foundationdb.record.metadata.IndexValidator;
+import com.apple.foundationdb.record.metadata.MetaDataException;
+import com.apple.foundationdb.record.metadata.MetaDataValidator;
+import com.apple.foundationdb.record.metadata.expressions.KeyExpression;
+import com.apple.foundationdb.record.metadata.expressions.KeyWithValueExpression;
+import com.apple.foundationdb.record.provider.foundationdb.IndexMaintainer;
+import com.apple.foundationdb.record.provider.foundationdb.IndexMaintainerFactory;
+import com.apple.foundationdb.record.provider.foundationdb.IndexMaintainerState;
+import com.google.auto.service.AutoService;
+
+import javax.annotation.Nonnull;
+import java.util.Arrays;
+import java.util.Set;
+import java.util.function.Function;
+
+/**
+ * A factory for {@link VectorIndexMaintainer} index maintainers.
+ */
+@AutoService(IndexMaintainerFactory.class)
+@API(API.Status.EXPERIMENTAL)
+public class VectorIndexMaintainerFactory implements IndexMaintainerFactory {
+    static final String[] TYPES = { IndexTypes.VECTOR };
+
+    @Override
+    @Nonnull
+    public Iterable<String> getIndexTypes() {
+        return Arrays.asList(TYPES);
+    }
+
+    @Override
+    @Nonnull
+    public IndexValidator getIndexValidator(Index index) {
+        return new VectorIndexValidator(index);
+    }
+
+    @Override
+    @Nonnull
+    public IndexMaintainer getIndexMaintainer(@Nonnull final IndexMaintainerState state) {
+        return new VectorIndexMaintainer(state);
+    }
+
+    /**
+     * Index validator for HNSW-based vector indexes.
+     */
+    private static class VectorIndexValidator extends IndexValidator {
+        public VectorIndexValidator(final Index index) {
+            super(index);
+        }
+
+        @Override
+        public void validate(@Nonnull MetaDataValidator metaDataValidator) {
+            super.validate(metaDataValidator);
+            validateStructure();
+
+            try {
+                VectorIndexHelper.getConfig(index);
+            } catch (final IllegalArgumentException illegalArgumentException) {
+                throw new MetaDataException("incorrect index options", illegalArgumentException);
+            }
+        }
+
+        /**
+         * Validates the key expression structure of a vector index.
+         * <p>
+         * The root expression must be a {@link KeyWithValueExpression}. Its split point divides the columns:
+         * <ul>
+         *   <li>columns before the split point: index prefix (for partitioning)</li>
+         *   <li>columns after the split point: vector column followed by optional covering columns</li>
+         * </ul>
+         * The first column after the split point is always the vector column, so at least one column
+         * is required after the split point. There are some other structural requirements to the root key
+         * expression of a vector index and some general requirements to the index itself:
+         * <ul>
+         *   <li>the root key expression must not contain a grouping key expression</li>
+         *   <li>the index must not be unique</li>
+         *   <li>the index must not contain any version columns</li>
+         * </ul>
+         * <p>
+         * TODO: Currently only exactly one column after the split point is supported (no covering columns yet).
+         */
+        private void validateStructure() {
+            validateNotGrouping();
+            validateNotUnique();
+            validateNotVersion();
+
+            final KeyExpression key = index.getRootExpression();
+            if (!(key instanceof KeyWithValueExpression)) {
+                throw new KeyExpression.InvalidExpressionException(
+                        "vector index type must use top key with value expression",
+                        LogMessageKeys.INDEX_TYPE, index.getType(),
+                        LogMessageKeys.INDEX_NAME, index.getName(),
+                        LogMessageKeys.INDEX_KEY, index.getRootExpression());
+            }
+            if (key.createsDuplicates()) {
+                throw new KeyExpression.InvalidExpressionException(
+                        "fan outs not supported in index type",
+                        LogMessageKeys.INDEX_TYPE, index.getType(),
+                        LogMessageKeys.INDEX_NAME, index.getName(),
+                        LogMessageKeys.INDEX_KEY, index.getRootExpression());
+            }
+        }
+
+        @Override
+        @SuppressWarnings("PMD.CompareObjectsWithEquals")
+        public void validateChangedOptions(@Nonnull final Index oldIndex,
+                                           @Nonnull final Set<String> changedOptions) {
+            if (!changedOptions.isEmpty()) {
+                final Config oldOptions = VectorIndexHelper.getConfig(oldIndex);
+                final Config newOptions = VectorIndexHelper.getConfig(index);
+
+                // do not allow changing any of the following
+                disallowChange(changedOptions, IndexOptions.HNSW_DETERMINISTIC_SEEDING,
+                        oldOptions, newOptions, Config::isDeterministicSeeding);
+                disallowChange(changedOptions, IndexOptions.HNSW_METRIC,
+                        oldOptions, newOptions, Config::getMetric);
+                disallowChange(changedOptions, IndexOptions.HNSW_NUM_DIMENSIONS,
+                        oldOptions, newOptions, Config::getNumDimensions);
+                disallowChange(changedOptions, IndexOptions.HNSW_USE_INLINING,
+                        oldOptions, newOptions, Config::isUseInlining);
+                disallowChange(changedOptions, IndexOptions.HNSW_M,
+                        oldOptions, newOptions, Config::getM);
+                disallowChange(changedOptions, IndexOptions.HNSW_M_MAX,
+                        oldOptions, newOptions, Config::getMMax);
+                disallowChange(changedOptions, IndexOptions.HNSW_M_MAX_0,
+                        oldOptions, newOptions, Config::getMMax0);
+                disallowChange(changedOptions, IndexOptions.HNSW_EF_CONSTRUCTION,
+                        oldOptions, newOptions, Config::getEfConstruction);
+                disallowChange(changedOptions, IndexOptions.HNSW_EXTEND_CANDIDATES,
+                        oldOptions, newOptions, Config::isExtendCandidates);
+                disallowChange(changedOptions, IndexOptions.HNSW_KEEP_PRUNED_CONNECTIONS,
+                        oldOptions, newOptions, Config::isKeepPrunedConnections);
+                disallowChange(changedOptions, IndexOptions.HNSW_USE_RABITQ,
+                        oldOptions, newOptions, Config::isUseRaBitQ);
+                disallowChange(changedOptions, IndexOptions.HNSW_RABITQ_NUM_EX_BITS,
+                        oldOptions, newOptions, Config::getRaBitQNumExBits);
+
+                // The following index options can be changed; drop them before delegating to super.
+                changedOptions.remove(IndexOptions.HNSW_SAMPLE_VECTOR_STATS_PROBABILITY);
+                changedOptions.remove(IndexOptions.HNSW_MAINTAIN_STATS_PROBABILITY);
+                changedOptions.remove(IndexOptions.HNSW_STATS_THRESHOLD);
+                changedOptions.remove(IndexOptions.HNSW_MAX_NUM_CONCURRENT_NODE_FETCHES);
+                changedOptions.remove(IndexOptions.HNSW_MAX_NUM_CONCURRENT_NEIGHBORHOOD_FETCHES);
+            }
+            super.validateChangedOptions(oldIndex, changedOptions);
+        }
+
+        private <T> void disallowChange(@Nonnull final Set<String> changedOptions,
+                                        @Nonnull final String optionName,
+                                        @Nonnull final Config oldConfig, @Nonnull final Config newConfig,
+                                        Function<Config, T> extractorFunction) {
+            if (changedOptions.contains(optionName)) {
+                final T oldValue = extractorFunction.apply(oldConfig);
+                final T newValue = extractorFunction.apply(newConfig);
+                if (!oldValue.equals(newValue)) {
+                    throw new MetaDataException("attempted to change " + optionName +
+                            " from " + oldValue + " to " + newValue,
+                            LogMessageKeys.INDEX_NAME, index.getName());
+                }
+                changedOptions.remove(optionName); // same effective value: treat the option as unchanged
+            }
+        }
+    }
+}
diff --git a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/query/expressions/Comparisons.java b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/query/expressions/Comparisons.java
index
c2ea478553..c06bb67634 100644 --- a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/query/expressions/Comparisons.java +++ b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/query/expressions/Comparisons.java @@ -39,6 +39,7 @@ import com.apple.foundationdb.record.metadata.expressions.TupleFieldsHelper; import com.apple.foundationdb.record.planprotos.PComparison; import com.apple.foundationdb.record.planprotos.PComparison.PComparisonType; +import com.apple.foundationdb.record.planprotos.PDistanceRankValueComparison; import com.apple.foundationdb.record.planprotos.PInvertedFunctionComparison; import com.apple.foundationdb.record.planprotos.PListComparison; import com.apple.foundationdb.record.planprotos.PMultiColumnComparison; @@ -645,7 +646,13 @@ public enum Type { @API(API.Status.EXPERIMENTAL) LIKE, IS_DISTINCT_FROM(false), - NOT_DISTINCT_FROM(true); + NOT_DISTINCT_FROM(true), + @API(API.Status.EXPERIMENTAL) + DISTANCE_RANK_EQUALS(true), + @API(API.Status.EXPERIMENTAL) + DISTANCE_RANK_LESS_THAN, + @API(API.Status.EXPERIMENTAL) + DISTANCE_RANK_LESS_THAN_OR_EQUAL; @Nonnull private static final Supplier> protoEnumBiMapSupplier = @@ -1531,6 +1538,12 @@ public static class ValueComparison implements Comparison { @Nonnull private final Supplier hashCodeSupplier; + protected ValueComparison(@Nonnull final PlanSerializationContext serializationContext, + @Nonnull final PValueComparison valueComparisonProto) { + this(Type.fromProto(serializationContext, Objects.requireNonNull(valueComparisonProto.getType())), + Value.fromValueProto(serializationContext, Objects.requireNonNull(valueComparisonProto.getComparandValue()))); + } + public ValueComparison(@Nonnull final Type type, @Nonnull final Value comparandValue) { this(type, comparandValue, ParameterRelationshipGraph.unbound()); @@ -1561,7 +1574,7 @@ public Type getType() { @Nonnull @Override - public Comparison withType(@Nonnull final Type newType) { + public ValueComparison 
withType(@Nonnull final Type newType) { if (type == newType) { return this; } @@ -1607,7 +1620,8 @@ public Object getComparand(@Nullable FDBRecordStoreBase store, @Nullable Eval @Nonnull @Override - public Comparison translateCorrelations(@Nonnull final TranslationMap translationMap, final boolean shouldSimplifyValues) { + public ValueComparison translateCorrelations(@Nonnull final TranslationMap translationMap, + final boolean shouldSimplifyValues) { if (comparandValue.getCorrelatedTo() .stream() .noneMatch(translationMap::containsSourceAlias)) { @@ -1705,14 +1719,19 @@ public int planHash(@Nonnull final PlanHashMode mode) { @Nonnull @Override - public Comparison withParameterRelationshipMap(@Nonnull final ParameterRelationshipGraph parameterRelationshipGraph) { + public ValueComparison withParameterRelationshipMap(@Nonnull final ParameterRelationshipGraph parameterRelationshipGraph) { Verify.verify(this.parameterRelationshipGraph.isUnbound()); return new ValueComparison(type, comparandValue, parameterRelationshipGraph); } @Nonnull @Override - public PValueComparison toProto(@Nonnull final PlanSerializationContext serializationContext) { + public Message toProto(@Nonnull final PlanSerializationContext serializationContext) { + return toValueComparisonProto(serializationContext); + } + + @Nonnull + public PValueComparison toValueComparisonProto(@Nonnull final PlanSerializationContext serializationContext) { return PValueComparison.newBuilder() .setType(type.toProto(serializationContext)) .setComparandValue(comparandValue.toValueProto(serializationContext)) @@ -1722,14 +1741,13 @@ public PValueComparison toProto(@Nonnull final PlanSerializationContext serializ @Nonnull @Override public PComparison toComparisonProto(@Nonnull final PlanSerializationContext serializationContext) { - return PComparison.newBuilder().setValueComparison(toProto(serializationContext)).build(); + return 
PComparison.newBuilder().setValueComparison(toValueComparisonProto(serializationContext)).build(); } @Nonnull public static ValueComparison fromProto(@Nonnull final PlanSerializationContext serializationContext, @Nonnull final PValueComparison valueComparisonProto) { - return new ValueComparison(Type.fromProto(serializationContext, Objects.requireNonNull(valueComparisonProto.getType())), - Value.fromValueProto(serializationContext, Objects.requireNonNull(valueComparisonProto.getComparandValue()))); + return new ValueComparison(serializationContext, valueComparisonProto); } /** @@ -1752,6 +1770,221 @@ public ValueComparison fromProto(@Nonnull final PlanSerializationContext seriali } } + @SpotBugsSuppressWarnings("EQ_DOESNT_OVERRIDE_EQUALS") + public static class DistanceRankValueComparison extends ValueComparison { + private static final ObjectPlanHash BASE_HASH = new ObjectPlanHash("Distance-Rank-Value-Comparison"); + + @Nonnull + private final Value limitValue; + + protected DistanceRankValueComparison(@Nonnull PlanSerializationContext serializationContext, + @Nonnull final PDistanceRankValueComparison distanceRankValueComparisonProto) { + super(serializationContext, distanceRankValueComparisonProto.getSuper()); + this.limitValue = Value.fromValueProto(serializationContext, + Objects.requireNonNull(distanceRankValueComparisonProto.getLimitValue())); + } + + public DistanceRankValueComparison(@Nonnull final Type type, @Nonnull final Value comparandValue, + @Nonnull final Value limitValue) { + this(type, comparandValue, ParameterRelationshipGraph.unbound(), limitValue); + } + + public DistanceRankValueComparison(@Nonnull final Type type, @Nonnull final Value comparandValue, + @Nonnull final ParameterRelationshipGraph parameterRelationshipGraph, + @Nonnull final Value limitValue) { + super(type, comparandValue, parameterRelationshipGraph); + Verify.verify(type == Type.DISTANCE_RANK_LESS_THAN || + type == Type.DISTANCE_RANK_LESS_THAN_OR_EQUAL); + this.limitValue = 
limitValue; + } + + @Nonnull + public Value getLimitValue() { + return limitValue; + } + + @Nonnull + @Override + public DistanceRankValueComparison withType(@Nonnull final Type newType) { + if (getType() == newType) { + return this; + } + return new DistanceRankValueComparison(newType, getComparandValue(), parameterRelationshipGraph, + getLimitValue()); + } + + @Nonnull + @Override + @SuppressWarnings("PMD.CompareObjectsWithEquals") + public DistanceRankValueComparison withValue(@Nonnull final Value value) { + if (getComparandValue() == value) { + return this; + } + return new DistanceRankValueComparison(getType(), value, parameterRelationshipGraph, getLimitValue()); + } + + @Nonnull + @Override + @SuppressWarnings("PMD.CompareObjectsWithEquals") + public Optional replaceValuesMaybe(@Nonnull final Function> replacementFunction) { + return replacementFunction.apply(getComparandValue()) + .flatMap(replacedComparandValue -> + replacementFunction.apply(getLimitValue()).map(replacedLimitValue -> { + if (replacedComparandValue == getComparandValue() && + replacedLimitValue == getLimitValue()) { + return this; + } + return new DistanceRankValueComparison(getType(), replacedComparandValue, + parameterRelationshipGraph, replacedLimitValue); + })); + } + + @Nonnull + @Override + public DistanceRankValueComparison translateCorrelations(@Nonnull final TranslationMap translationMap, + final boolean shouldSimplifyValues) { + if (getComparandValue().getCorrelatedTo() + .stream() + .noneMatch(translationMap::containsSourceAlias) && + getLimitValue().getCorrelatedTo() + .stream() + .noneMatch(translationMap::containsSourceAlias)) { + return this; + } + + return new DistanceRankValueComparison(getType(), + getComparandValue().translateCorrelations(translationMap, shouldSimplifyValues), + parameterRelationshipGraph, + getLimitValue().translateCorrelations(translationMap, shouldSimplifyValues)); + } + + @Nonnull + @Override + public Set getCorrelatedTo() { + return 
ImmutableSet.builder() + .addAll(getComparandValue().getCorrelatedTo()) + .addAll(getLimitValue().getCorrelatedTo()) + .build(); + } + + @Nonnull + @Override + public ConstrainedBoolean semanticEqualsTyped(@Nonnull final Comparison other, @Nonnull final ValueEquivalence valueEquivalence) { + return super.semanticEqualsTyped(other, valueEquivalence) + .compose(ignored -> getLimitValue() + .semanticEquals(((DistanceRankValueComparison)other).getLimitValue(), + valueEquivalence)); + } + + @Nullable + @Override + @SuppressWarnings("PMD.CompareObjectsWithEquals") + public Boolean eval(@Nullable FDBRecordStoreBase store, @Nonnull EvaluationContext context, @Nullable Object v) { + throw new IllegalStateException("this comparison can only be evaluated using an index"); + } + + @Nonnull + @Override + public String typelessString() { + return typelessExplain().render(DefaultExplainFormatter.forDebugging()).toString(); + } + + @Override + public String toString() { + return explain().getExplainTokens().render(DefaultExplainFormatter.forDebugging()).toString(); + } + + @Nonnull + @Override + public ExplainTokensWithPrecedence explain() { + return ExplainTokensWithPrecedence.of(new ExplainTokens().addKeyword(getType().name()) + .addWhitespace().addNested(typelessExplain())); + } + + @Nonnull + private ExplainTokens typelessExplain() { + return new ExplainTokens().addNested(getComparandValue().explain().getExplainTokens()) + .addKeyword(":").addWhitespace() + .addNested(getLimitValue().explain().getExplainTokens()); + } + + @Override + public int planHash(@Nonnull final PlanHashMode mode) { + switch (mode.getKind()) { + case LEGACY: + case FOR_CONTINUATION: + return PlanHashable.objectsPlanHash(mode, BASE_HASH, getType(), getComparandValue(), getLimitValue()); + default: + throw new UnsupportedOperationException("Hash Kind " + mode.name() + " is not supported"); + } + } + + @Override + public int computeHashCode() { + return Objects.hash(super.computeHashCode(), 
getType().name(), getComparandValue(), getLimitValue()); + } + + @Nonnull + @Override + public DistanceRankValueComparison withParameterRelationshipMap(@Nonnull final ParameterRelationshipGraph parameterRelationshipGraph) { + Verify.verify(this.parameterRelationshipGraph.isUnbound()); + return new DistanceRankValueComparison(getType(), getComparandValue(), parameterRelationshipGraph, + getLimitValue()); + } + + @Nonnull + @Override + public PDistanceRankValueComparison toProto(@Nonnull final PlanSerializationContext serializationContext) { + return PDistanceRankValueComparison.newBuilder() + .setSuper(super.toValueComparisonProto(serializationContext)) + .setLimitValue(getLimitValue().toValueProto(serializationContext)) + .build(); + } + + @Nonnull + @Override + public PComparison toComparisonProto(@Nonnull final PlanSerializationContext serializationContext) { + return PComparison.newBuilder().setDistanceRankValueComparison(toProto(serializationContext)).build(); + } + + @Nullable + public RealVector getVector(@Nullable final FDBRecordStoreBase store, final @Nullable EvaluationContext context) { + return (RealVector)getComparand(store, context); + } + + public int getLimit(@Nullable final FDBRecordStoreBase store, final @Nullable EvaluationContext context) { + if (context == null) { + throw EvaluationContextRequiredException.instance(); + } + return (int)Objects.requireNonNull(getLimitValue().eval(store, context)); + } + + @Nonnull + public static DistanceRankValueComparison fromProto(@Nonnull final PlanSerializationContext serializationContext, + @Nonnull final PDistanceRankValueComparison distanceRankValueComparisonProto) { + return new DistanceRankValueComparison(serializationContext, distanceRankValueComparisonProto); + } + + /** + * Deserializer. 
+ */ + @AutoService(PlanDeserializer.class) + public static class Deserializer implements PlanDeserializer { + @Nonnull + @Override + public Class getProtoMessageClass() { + return PDistanceRankValueComparison.class; + } + + @Nonnull + @Override + public DistanceRankValueComparison fromProto(@Nonnull final PlanSerializationContext serializationContext, + @Nonnull final PDistanceRankValueComparison distanceRankValueComparisonProto) { + return DistanceRankValueComparison.fromProto(serializationContext, distanceRankValueComparisonProto); + } + } + } + /** * A comparison with a list of values. */ diff --git a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/query/plan/RecordQueryPlanner.java b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/query/plan/RecordQueryPlanner.java index f1bcd08a7f..d1be287c68 100644 --- a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/query/plan/RecordQueryPlanner.java +++ b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/query/plan/RecordQueryPlanner.java @@ -1745,7 +1745,7 @@ private RecordQueryPlan planScan(@Nonnull CandidateScan candidateScan, Set possibleTypes; if (candidateScan.index == null) { Verify.verify(indexScanParameters instanceof IndexScanComparisons); - final ScanComparisons scanComparisons = ((IndexScanComparisons)indexScanParameters).getComparisons(); + final ScanComparisons scanComparisons = ((IndexScanComparisons)indexScanParameters).getScanComparisons(); if (primaryKeyHasRecordTypePrefix && RecordTypeKeyComparison.hasRecordTypeKeyComparison(scanComparisons)) { possibleTypes = RecordTypeKeyComparison.recordTypeKeyComparisonTypes(scanComparisons); } else { diff --git a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/query/plan/ScanComparisons.java b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/query/plan/ScanComparisons.java index 4292dda86a..34903baa44 100644 --- 
a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/query/plan/ScanComparisons.java +++ b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/query/plan/ScanComparisons.java @@ -151,6 +151,7 @@ public static ComparisonType getComparisonType(@Nonnull Comparisons.Comparison c switch (comparison.getType()) { case EQUALS: case IS_NULL: + case DISTANCE_RANK_EQUALS: return ComparisonType.EQUALITY; case LESS_THAN: case LESS_THAN_OR_EQUALS: @@ -159,6 +160,8 @@ public static ComparisonType getComparisonType(@Nonnull Comparisons.Comparison c case STARTS_WITH: case NOT_NULL: case SORT: + case DISTANCE_RANK_LESS_THAN: + case DISTANCE_RANK_LESS_THAN_OR_EQUAL: return ComparisonType.INEQUALITY; case NOT_EQUALS: default: diff --git a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/query/plan/cascades/values/LiteralValue.java b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/query/plan/cascades/values/LiteralValue.java index d1f9cf854a..dedf8022af 100644 --- a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/query/plan/cascades/values/LiteralValue.java +++ b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/query/plan/cascades/values/LiteralValue.java @@ -145,7 +145,9 @@ public int planHash(@Nonnull final PlanHashMode mode) { @Override public PLiteralValue toProto(@Nonnull final PlanSerializationContext serializationContext) { final var builder = PLiteralValue.newBuilder(); - builder.setValue(PlanSerialization.valueObjectToProto(value)); + // TODO there should be no punishment if the type is serialized alongside with the value + final var resultTypeProto = resultType.isVector() ? 
resultType.toTypeProto(serializationContext) : null; + builder.setValue(PlanSerialization.valueObjectToProto(value, resultTypeProto)); builder.setResultType(resultType.toTypeProto(serializationContext)); return builder.build(); } diff --git a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/query/plan/plans/RecordQueryIndexPlan.java b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/query/plan/plans/RecordQueryIndexPlan.java index 2ed9cc2b15..5d32d720ef 100644 --- a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/query/plan/plans/RecordQueryIndexPlan.java +++ b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/query/plan/plans/RecordQueryIndexPlan.java @@ -597,17 +597,14 @@ public void logPlanStructure(StoreTimer timer) { @Override public boolean hasScanComparisons() { - return scanParameters instanceof IndexScanComparisons; + return scanParameters.hasScanComparisons(); } @Nonnull @Override public ScanComparisons getScanComparisons() { - if (scanParameters instanceof IndexScanComparisons) { - return ((IndexScanComparisons)scanParameters).getComparisons(); - } else { - throw new RecordCoreException("this plan does not use ScanComparisons"); - } + Verify.verify(hasScanComparisons()); + return scanParameters.getScanComparisons(); } @Nonnull diff --git a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/query/plan/serialization/PlanSerialization.java b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/query/plan/serialization/PlanSerialization.java index c855b82d89..c01965b4ed 100644 --- a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/query/plan/serialization/PlanSerialization.java +++ b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/query/plan/serialization/PlanSerialization.java @@ -124,9 +124,10 @@ public static Object protoToValueObject(@Nonnull final PComparableObject proto) // this can be extended in the future to cover other 
types as well, such as // FDBRecordVersion, and UUID instead of having special sub-message requirements final var type = Type.fromTypeProto(PlanSerializationContext.newForCurrentMode(), proto.getType()); - Verify.verify(type.isVector()); - final var primitiveObject = proto.getPrimitiveObject(); - return VectorUtils.parseVector(primitiveObject.getBytesValue(), (Type.Vector)type); + if (type.isVector()) { + final var primitiveObject = proto.getPrimitiveObject(); + return VectorUtils.parseVector(primitiveObject.getBytesValue(), (Type.Vector)type); + } } if (proto.hasEnumObject()) { diff --git a/fdb-record-layer-core/src/main/proto/record_cursor.proto b/fdb-record-layer-core/src/main/proto/record_cursor.proto index e8eaee8aeb..5cb54a5ac3 100644 --- a/fdb-record-layer-core/src/main/proto/record_cursor.proto +++ b/fdb-record-layer-core/src/main/proto/record_cursor.proto @@ -135,6 +135,15 @@ message MultidimensionalIndexScanContinuation { optional bytes lastKey = 2; } +message VectorIndexScanContinuation { + message IndexEntry { + optional bytes key = 1; + optional bytes value = 2; + } + repeated IndexEntry indexEntries = 1; + optional bytes inner_continuation = 2; +} + message TempTableInsertContinuation { optional bytes child_continuation = 1; optional planprotos.PTempTable tempTable = 2; diff --git a/fdb-record-layer-core/src/main/proto/record_query_plan.proto b/fdb-record-layer-core/src/main/proto/record_query_plan.proto index 49dc07faf9..2b7ce1f6c3 100644 --- a/fdb-record-layer-core/src/main/proto/record_query_plan.proto +++ b/fdb-record-layer-core/src/main/proto/record_query_plan.proto @@ -1395,6 +1395,9 @@ message PComparison { LIKE = 19; IS_DISTINCT_FROM = 20; NOT_DISTINCT_FROM = 21; + DISTANCE_RANK_EQUALS = 22; + DISTANCE_RANK_LESS_THAN = 23; + DISTANCE_RANK_LESS_THAN_OR_EQUAL = 24; } extensions 5000 to max; @@ -1411,6 +1414,7 @@ message PComparison { PRecordTypeComparison record_type_comparison = 10; PConversionSimpleComparison conversion_simple_comparison = 11; 
PConversionParameterComparison conversion_parameter_comparison = 12; + PDistanceRankValueComparison distance_rank_value_comparison = 13; } } @@ -1477,6 +1481,11 @@ message PRecordTypeComparison { optional string record_type_name = 1; } +message PDistanceRankValueComparison { + optional PValueComparison super = 1; + optional PValue limitValue = 2; +} + // // Query Predicates // @@ -1822,6 +1831,7 @@ message PIndexScanParameters { PIndexScanComparisons index_scan_comparisons = 2; PMultidimensionalIndexScanComparisons multidimensional_index_scan_comparisons = 3; PTimeWindowScanComparisons time_window_scan_comparisons = 4; + PVectorIndexScanComparisons vector_index_scan_comparisons = 5; } } @@ -1857,6 +1867,20 @@ message PTimeWindowScanComparisons { optional PTimeWindowForFunction time_window = 2; } +message PVectorIndexScanComparisons { + optional PScanComparisons prefix_scan_comparisons = 1; + optional PDistanceRankValueComparison distance_rank_value_comparison = 2; + optional PVectorIndexScanOptions vector_index_scan_options = 3; +} + +message PVectorIndexScanOptions { + message POptionEntry { + optional string key = 1; + optional com.apple.foundationdb.record.expressions.Value value = 2; + } + repeated POptionEntry optionEntries = 1; +} + enum PIndexFetchMethod { SCAN_AND_FETCH = 1; USE_REMOTE_FETCH = 2; diff --git a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/VectorIndexScanComparisonsTest.java b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/VectorIndexScanComparisonsTest.java new file mode 100644 index 0000000000..136ff1f66f --- /dev/null +++ b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/VectorIndexScanComparisonsTest.java @@ -0,0 +1,319 @@ +/* + * VectorIndexScanComparisonsTest.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. 
and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.record.provider.foundationdb; + +import com.apple.foundationdb.linear.DoubleRealVector; +import com.apple.foundationdb.record.query.expressions.Comparisons; +import com.apple.foundationdb.record.query.plan.ScanComparisons; +import com.apple.foundationdb.record.query.plan.cascades.AliasMap; +import com.apple.foundationdb.record.query.plan.cascades.CorrelationIdentifier; +import com.apple.foundationdb.record.query.plan.cascades.explain.Attribute; +import com.apple.foundationdb.record.query.plan.cascades.typing.Type; +import com.apple.foundationdb.record.query.plan.cascades.values.LiteralValue; +import com.apple.foundationdb.record.query.plan.cascades.values.QuantifiedObjectValue; +import com.apple.foundationdb.record.query.plan.cascades.values.translation.TranslationMap; +import com.apple.foundationdb.record.query.plan.explain.DefaultExplainFormatter; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterables; +import org.junit.jupiter.api.Test; + +import javax.annotation.Nonnull; +import java.util.Objects; +import java.util.concurrent.ThreadLocalRandom; + +import static org.assertj.core.api.Assertions.assertThat; + +class VectorIndexScanComparisonsTest { + @Test + void translateCorrelationsTest() { + final ScanComparisons originalPrefixScanComparisons = 
randomPrefixScanComparisons(); + final Comparisons.DistanceRankValueComparison originalDistanceRankComparison = randomDistanceRankComparison(); + final VectorIndexScanOptions originalScanOptions = + VectorIndexScanOptions.builder() + .putOption(VectorIndexScanOptions.HNSW_EF_SEARCH, 101) + .putOption(VectorIndexScanOptions.HNSW_RETURN_VECTORS, false) + .build(); + final VectorIndexScanComparisons original = + VectorIndexScanComparisons.byDistance(originalPrefixScanComparisons, + originalDistanceRankComparison, + originalScanOptions); + final TranslationMap translationMap = + TranslationMap.regularBuilder() + .when(q1()).then(((sourceAlias, leafValue) -> + Objects.requireNonNull(originalPrefixScanComparisons.getEqualityComparisons().get(0).getValue()))) + .when(q2()).then(((sourceAlias, leafValue) -> + Objects.requireNonNull(Iterables.getOnlyElement(originalPrefixScanComparisons.getInequalityComparisons()).getValue()))) + .when(q3()).then(((sourceAlias, leafValue) -> + originalDistanceRankComparison.getComparandValue())) + .when(q4()).then(((sourceAlias, leafValue) -> + originalDistanceRankComparison.getLimitValue())) + .build(); + + final Comparisons.DistanceRankValueComparison correlatedDistanceRankComparison = correlatedDistanceRankComparison(); + final ScanComparisons correlatedPrefixScanComparisons = correlatedPrefixScanComparisons(); + + final VectorIndexScanComparisons correlated = + VectorIndexScanComparisons.byDistance(correlatedPrefixScanComparisons, + correlatedDistanceRankComparison, + originalScanOptions); + + final IndexScanParameters translated = + correlated.translateCorrelations(translationMap, false); + assertThat(translated).isEqualTo(original); + } + + @Test + void rebaseTest() { + final VectorIndexScanOptions originalScanOptions = + VectorIndexScanOptions.builder() + .putOption(VectorIndexScanOptions.HNSW_EF_SEARCH, 101) + .putOption(VectorIndexScanOptions.HNSW_RETURN_VECTORS, false) + .build(); + final AliasMap aliasMap = + AliasMap.builder() + 
.put(q1(), q5()) + .put(q2(), q6()) + .put(q3(), q7()) + .put(q4(), q8()) + .build(); + + final ScanComparisons originalPrefixScanComparisons = correlatedPrefixScanComparisons(); + final Comparisons.DistanceRankValueComparison originalDistanceRankComparison = correlatedDistanceRankComparison(); + + final VectorIndexScanComparisons original = + VectorIndexScanComparisons.byDistance(originalPrefixScanComparisons, + originalDistanceRankComparison, + originalScanOptions); + + final IndexScanParameters rebased = original.rebase(aliasMap); + assertThat(rebased).isNotEqualTo(original); + assertThat(rebased.getCorrelatedTo()).containsExactly(q5(), q6(), q7(), q8()); + + final ImmutableList.Builder originalDetailsBuilder = ImmutableList.builder(); + final ImmutableMap.Builder originalAttributeMapBuilder = ImmutableMap.builder(); + original.getPlannerGraphDetails(originalDetailsBuilder, originalAttributeMapBuilder); + final ImmutableList originalDetails = originalDetailsBuilder.build(); + final ImmutableMap originalAttributeMap = originalAttributeMapBuilder.build(); + + final ImmutableList.Builder rebasedDetailsBuilder = ImmutableList.builder(); + final ImmutableMap.Builder rebasedAttributeMapBuilder = ImmutableMap.builder(); + rebased.getPlannerGraphDetails(rebasedDetailsBuilder, rebasedAttributeMapBuilder); + + assertThat(rebasedDetailsBuilder.build()).isEqualTo(originalDetails); + assertThat(rebasedAttributeMapBuilder.build()).doesNotHaveToString(originalAttributeMap.toString()); + assertThat(rebased).doesNotHaveToString(original.toString()); + assertThat(renderExplain(rebased)).isNotEqualTo(renderExplain(original)); + + final IndexScanParameters inverseRebased = rebased.rebase(aliasMap.inverse()); + assertThat(inverseRebased).isEqualTo(original); + + final ImmutableList.Builder inverseRebasedDetailsBuilder = ImmutableList.builder(); + final ImmutableMap.Builder inverseRebasedAttributeMapBuilder = ImmutableMap.builder(); + 
inverseRebased.getPlannerGraphDetails(inverseRebasedDetailsBuilder, inverseRebasedAttributeMapBuilder); + assertThat(inverseRebasedDetailsBuilder.build()).isEqualTo(originalDetails); + assertThat(inverseRebasedAttributeMapBuilder.build()).hasToString(originalAttributeMap.toString()); + assertThat(inverseRebased).hasToString(original.toString()); + assertThat(renderExplain(inverseRebased)).isEqualTo(renderExplain(original)); + assertThat(inverseRebased).hasSameHashCodeAs(original); + assertThat(inverseRebased.semanticHashCode()).isEqualTo(original.semanticHashCode()); + } + + @Test + void withComparisonsAndOptions() { + final VectorIndexScanOptions originalScanOptions = + VectorIndexScanOptions.builder() + .putOption(VectorIndexScanOptions.HNSW_EF_SEARCH, 101) + .putOption(VectorIndexScanOptions.HNSW_RETURN_VECTORS, false) + .build(); + + final ScanComparisons originalPrefixScanComparisons = correlatedPrefixScanComparisons(); + final Comparisons.DistanceRankValueComparison originalDistanceRankComparison = correlatedDistanceRankComparison(); + + final VectorIndexScanComparisons original = + VectorIndexScanComparisons.byDistance(originalPrefixScanComparisons, + originalDistanceRankComparison, + originalScanOptions); + + final AliasMap aliasMap = + AliasMap.builder() + .put(q1(), q5()) + .put(q2(), q6()) + .put(q3(), q7()) + .put(q4(), q8()) + .build(); + + final ScanComparisons rebasedPrefixScanComparisons = originalPrefixScanComparisons.rebase(aliasMap); + final Comparisons.DistanceRankValueComparison rebasedDistanceRankComparison = + (Comparisons.DistanceRankValueComparison)originalDistanceRankComparison.rebase(aliasMap); + final VectorIndexScanOptions newScanOptions = + VectorIndexScanOptions.builder() + .putOption(VectorIndexScanOptions.HNSW_EF_SEARCH, 100) + .putOption(VectorIndexScanOptions.HNSW_RETURN_VECTORS, true) + .build(); + + final var newVectorIndexComparisons = + original.withComparisonsAndOptions(rebasedPrefixScanComparisons, 
rebasedDistanceRankComparison, + newScanOptions); + assertThat(newVectorIndexComparisons.getPrefixScanComparisons()).isEqualTo(rebasedPrefixScanComparisons); + assertThat(newVectorIndexComparisons.getDistanceRankValueComparison()).isEqualTo(rebasedDistanceRankComparison); + assertThat(newVectorIndexComparisons.getVectorIndexScanOptions()).isEqualTo(newScanOptions); + } + + @Test + void scanComparisonsTest1() { + final ScanComparisons originalPrefixScanComparisons = correlatedPrefixScanComparisons(); + final Comparisons.DistanceRankValueComparison originalDistanceRankComparison = correlatedDistanceRankComparison(); + + final VectorIndexScanComparisons original = + VectorIndexScanComparisons.byDistance(originalPrefixScanComparisons, + originalDistanceRankComparison, + VectorIndexScanOptions.empty()); + assertThat(original.hasScanComparisons()).isTrue(); + final ScanComparisons scanComparisons = original.getScanComparisons(); + assertThat(scanComparisons).isEqualTo(originalPrefixScanComparisons); + } + + @Test + void scanComparisonsTest2() { + final ScanComparisons originalPrefixScanComparisons = correlatedEqualsPrefixScanComparisons(); + final Comparisons.DistanceRankValueComparison originalDistanceRankComparison = correlatedDistanceRankComparison(); + + final VectorIndexScanComparisons original = + VectorIndexScanComparisons.byDistance(originalPrefixScanComparisons, + originalDistanceRankComparison, + VectorIndexScanOptions.empty()); + assertThat(original.hasScanComparisons()).isTrue(); + final ScanComparisons scanComparisons = original.getScanComparisons(); + assertThat(scanComparisons.getEqualityComparisons()).isEqualTo(originalPrefixScanComparisons.getEqualityComparisons()); + assertThat(scanComparisons.getInequalityComparisons()).hasSize(1) + .allSatisfy(comparison -> assertThat(comparison).isEqualTo(originalDistanceRankComparison)); + } + + @Nonnull + protected static String renderExplain(@Nonnull final IndexScanParameters vectorIndexScanComparisons) { + return 
vectorIndexScanComparisons.explain() + .getExplainTokens() + .render(DefaultExplainFormatter.forDebugging()) + .toString(); + } + + @Nonnull + private static ScanComparisons randomPrefixScanComparisons() { + return new ScanComparisons.Builder() + .addEqualityComparison( + new Comparisons.ValueComparison(Comparisons.Type.EQUALS, + new LiteralValue<>(ThreadLocalRandom.current().nextInt(100)))) + .addInequalityComparison( + new Comparisons.ValueComparison(Comparisons.Type.LESS_THAN, + new LiteralValue<>(ThreadLocalRandom.current().nextInt(100)))) + .build(); + } + + @Nonnull + private static ScanComparisons correlatedPrefixScanComparisons() { + return new ScanComparisons.Builder() + .addEqualityComparison( + new Comparisons.ValueComparison(Comparisons.Type.EQUALS, + QuantifiedObjectValue.of(q1(), Type.primitiveType(Type.TypeCode.INT)))) + .addInequalityComparison( + new Comparisons.ValueComparison(Comparisons.Type.LESS_THAN, + QuantifiedObjectValue.of(q2(), Type.primitiveType(Type.TypeCode.INT)))) + .build(); + } + + @Nonnull + private static ScanComparisons correlatedEqualsPrefixScanComparisons() { + return new ScanComparisons.Builder() + .addEqualityComparison( + new Comparisons.ValueComparison(Comparisons.Type.EQUALS, + QuantifiedObjectValue.of(q1(), Type.primitiveType(Type.TypeCode.INT)))) + .addEqualityComparison( + new Comparisons.ValueComparison(Comparisons.Type.EQUALS, + QuantifiedObjectValue.of(q2(), Type.primitiveType(Type.TypeCode.INT)))) + .build(); + } + + @Nonnull + private static Comparisons.DistanceRankValueComparison randomDistanceRankComparison() { + return new Comparisons.DistanceRankValueComparison(Comparisons.Type.DISTANCE_RANK_LESS_THAN_OR_EQUAL, + getRandomVectorValue(), new LiteralValue<>(10)); + } + + @Nonnull + private static Comparisons.DistanceRankValueComparison correlatedDistanceRankComparison() { + return new Comparisons.DistanceRankValueComparison(Comparisons.Type.DISTANCE_RANK_LESS_THAN_OR_EQUAL, + QuantifiedObjectValue.of(q3(), 
Type.Vector.of(false, 64, 128)), + QuantifiedObjectValue.of(q4(), Type.primitiveType(Type.TypeCode.INT, false))); + } + + @Nonnull + private static LiteralValue getRandomVectorValue() { + final int numDimensions = 128; + final double[] components = new double[128]; + for (int i = 0; i < numDimensions; i ++) { + components[i] = ThreadLocalRandom.current().nextDouble(); + } + return new LiteralValue<>(Type.Vector.of(false, 64, 128), + new DoubleRealVector(components)); + } + + @Nonnull + protected static CorrelationIdentifier q1() { + return CorrelationIdentifier.of("q1"); + } + + @Nonnull + protected static CorrelationIdentifier q2() { + return CorrelationIdentifier.of("q2"); + } + + @Nonnull + protected static CorrelationIdentifier q3() { + return CorrelationIdentifier.of("q3"); + } + + @Nonnull + protected static CorrelationIdentifier q4() { + return CorrelationIdentifier.of("q4"); + } + + @Nonnull + protected static CorrelationIdentifier q5() { + return CorrelationIdentifier.of("q5"); + } + + @Nonnull + protected static CorrelationIdentifier q6() { + return CorrelationIdentifier.of("q6"); + } + + @Nonnull + protected static CorrelationIdentifier q7() { + return CorrelationIdentifier.of("q7"); + } + + @Nonnull + protected static CorrelationIdentifier q8() { + return CorrelationIdentifier.of("q8"); + } +} + diff --git a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/indexes/MultidimensionalIndexTestBase.java b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/indexes/MultidimensionalIndexTestBase.java index 31a90d6887..8a73edcb3b 100644 --- a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/indexes/MultidimensionalIndexTestBase.java +++ b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/indexes/MultidimensionalIndexTestBase.java @@ -492,7 +492,7 @@ void basicReadWithNullsTest(final boolean useAsync, @Nonnull 
final String storag } void indexReadTest(final boolean useAsync, final long seed, final int numRecords, @Nonnull final String storage, - final boolean storeHilbertValues, final boolean useNodeSlotIndex) throws Exception { + final boolean storeHilbertValues, final boolean useNodeSlotIndex) throws Exception { final RecordMetaDataHook additionalIndexes = metaDataBuilder -> { addCalendarNameStartEpochIndex(metaDataBuilder); @@ -501,11 +501,16 @@ void indexReadTest(final boolean useAsync, final long seed, final int numRecords loadRecords(useAsync, false, additionalIndexes, seed, ImmutableList.of("business"), numRecords); final long intervalStartInclusive = epochMean + 3600L; final long intervalEndInclusive = epochMean + 5L * 3600L; + final HypercubeScanParameters hypercubeScanParameters = + new HypercubeScanParameters("business", + (Long)null, intervalEndInclusive, + intervalStartInclusive, null); + Assertions.assertFalse(hypercubeScanParameters.hasScanComparisons()); + Assertions.assertThrows(RecordCoreException.class, hypercubeScanParameters::getScanComparisons); + final RecordQueryIndexPlan indexPlan = new RecordQueryIndexPlan("EventIntervals", - new HypercubeScanParameters("business", - (Long)null, intervalEndInclusive, - intervalStartInclusive, null), + hypercubeScanParameters, false); Set actualResults = getResults(additionalIndexes, indexPlan); diff --git a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexScanOptionsTest.java b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexScanOptionsTest.java new file mode 100644 index 0000000000..4c8b8dcf4b --- /dev/null +++ b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexScanOptionsTest.java @@ -0,0 +1,137 @@ +/* + * VectorIndexScanOptionsTest.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 
Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.record.provider.foundationdb.indexes; + +import com.apple.foundationdb.record.PlanHashable; +import com.apple.foundationdb.record.PlanSerializationContext; +import com.apple.foundationdb.record.planprotos.PVectorIndexScanOptions; +import com.apple.foundationdb.record.provider.foundationdb.VectorIndexScanOptions; +import com.apple.foundationdb.record.provider.foundationdb.VectorIndexScanOptions.Builder; +import com.apple.foundationdb.record.query.plan.explain.DefaultExplainFormatter; +import org.junit.jupiter.api.Test; + +import javax.annotation.Nonnull; + +import static org.assertj.core.api.Assertions.assertThat; + +class VectorIndexScanOptionsTest { + @Test + void builderRoundTripTest() { + final Builder builder = VectorIndexScanOptions.builder(); + final VectorIndexScanOptions optionsOriginal = builder.build(); + assertThat(optionsOriginal.containsOption(VectorIndexScanOptions.HNSW_EF_SEARCH)).isFalse(); + + Builder resultBuilder = builder.putOption(VectorIndexScanOptions.HNSW_EF_SEARCH, 10); + assertThat(resultBuilder).isSameAs(builder); + + final VectorIndexScanOptions optionsAfterPut = builder.build(); + assertThat(optionsAfterPut.containsOption(VectorIndexScanOptions.HNSW_EF_SEARCH)).isTrue(); + assertThat(optionsAfterPut).isNotEqualTo(optionsOriginal); + + 
assertThat(optionsAfterPut.toBuilder().build()).isEqualTo(optionsAfterPut); + + resultBuilder = builder.removeOption(VectorIndexScanOptions.HNSW_EF_SEARCH); + assertThat(resultBuilder).isSameAs(builder); + + final VectorIndexScanOptions optionsAfterRemove = builder.build(); + assertThat(optionsOriginal.planHash(PlanHashable.CURRENT_FOR_CONTINUATION)) + .isEqualTo(optionsAfterRemove.planHash(PlanHashable.CURRENT_FOR_CONTINUATION)); + assertThat(optionsAfterRemove).isEqualTo(optionsOriginal); + } + + @Test + void builderEqualityTest() { + final Builder builder1 = VectorIndexScanOptions.builder(); + final Builder builder2 = VectorIndexScanOptions.builder(); + + assertThat(builder1).isEqualTo(builder2); + builder1.putOption(VectorIndexScanOptions.HNSW_EF_SEARCH, 10); + assertThat(builder1).isNotEqualTo(builder2); + builder2.putOption(VectorIndexScanOptions.HNSW_EF_SEARCH, 10); + assertThat(builder1).hasSameHashCodeAs(builder2); + assertThat(builder1).isEqualTo(builder2); + builder1.removeOption(VectorIndexScanOptions.HNSW_EF_SEARCH); + builder2.removeOption(VectorIndexScanOptions.HNSW_EF_SEARCH); + assertThat(builder1).hasSameHashCodeAs(builder2); + assertThat(builder1).isEqualTo(builder2); + + builder1.putOption(VectorIndexScanOptions.HNSW_EF_SEARCH, 10); + builder2.putOption(VectorIndexScanOptions.HNSW_EF_SEARCH, 20); + assertThat(builder1).isNotEqualTo(builder2); + + assertThat((VectorIndexScanOptions.OptionKey)VectorIndexScanOptions.HNSW_EF_SEARCH) + .isNotEqualTo(VectorIndexScanOptions.HNSW_RETURN_VECTORS); + } + + @Test + void protoRoundTripTest() { + final Builder builder = VectorIndexScanOptions.builder(); + + final VectorIndexScanOptions options = + builder.putOption(VectorIndexScanOptions.HNSW_EF_SEARCH, 10) + .build(); + + final PVectorIndexScanOptions proto = + options.toProto(PlanSerializationContext.newForCurrentMode()); + + final VectorIndexScanOptions.Deserializer deserializer = new VectorIndexScanOptions.Deserializer(); + final VectorIndexScanOptions 
optionsAfterRoundTrip = + deserializer.fromProto(PlanSerializationContext.newForCurrentMode(), proto); + + assertThat(optionsAfterRoundTrip.planHash(PlanHashable.CURRENT_FOR_CONTINUATION)) + .isEqualTo(options.planHash(PlanHashable.CURRENT_FOR_CONTINUATION)); + assertThat(optionsAfterRoundTrip).hasSameHashCodeAs(options); + assertThat(optionsAfterRoundTrip).isEqualTo(options); + } + + @Test + void explainTest() { + final VectorIndexScanOptions options1 = + VectorIndexScanOptions.builder() + .putOption(VectorIndexScanOptions.HNSW_EF_SEARCH, 100) + .build(); + final String explain1 = renderExplain(options1); + + final VectorIndexScanOptions options2 = + VectorIndexScanOptions.builder() + .putOption(VectorIndexScanOptions.HNSW_EF_SEARCH, 200) + .build(); + assertThat(options1).doesNotHaveToString(options2.toString()); + final String explain2 = renderExplain(options2); + assertThat(explain1).isNotEqualTo(explain2); + + final VectorIndexScanOptions options3 = + options2.toBuilder() + .putOption(VectorIndexScanOptions.HNSW_EF_SEARCH, 100) + .build(); + assertThat(options3).hasToString(options1.toString()); + final String explain3 = renderExplain(options3); + assertThat(explain3).isEqualTo(explain1); + } + + @Nonnull + private static String renderExplain(@Nonnull final VectorIndexScanOptions options) { + return options.explain() + .getExplainTokens() + .render(DefaultExplainFormatter.forDebugging()) + .toString(); + } +} diff --git a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexTest.java b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexTest.java new file mode 100644 index 0000000000..65132ad37c --- /dev/null +++ b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexTest.java @@ -0,0 +1,516 @@ +/* + * VectorIndexTest.java + * + * This source file is part of the FoundationDB open source project + * + 
* Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.record.provider.foundationdb.indexes; + +import com.apple.foundationdb.async.hnsw.NodeReference; +import com.apple.foundationdb.linear.HalfRealVector; +import com.apple.foundationdb.linear.Metric; +import com.apple.foundationdb.record.Bindings; +import com.apple.foundationdb.record.EvaluationContext; +import com.apple.foundationdb.record.ExecuteProperties; +import com.apple.foundationdb.record.ExecuteState; +import com.apple.foundationdb.record.IndexEntry; +import com.apple.foundationdb.record.IndexFetchMethod; +import com.apple.foundationdb.record.IndexScanType; +import com.apple.foundationdb.record.IsolationLevel; +import com.apple.foundationdb.record.RecordCursor; +import com.apple.foundationdb.record.RecordCursorIterator; +import com.apple.foundationdb.record.ScanProperties; +import com.apple.foundationdb.record.TupleRange; +import com.apple.foundationdb.record.metadata.Index; +import com.apple.foundationdb.record.metadata.IndexOptions; +import com.apple.foundationdb.record.metadata.IndexTypes; +import com.apple.foundationdb.record.metadata.IndexValidator; +import com.apple.foundationdb.record.metadata.MetaDataException; +import com.apple.foundationdb.record.metadata.MetaDataValidator; +import com.apple.foundationdb.record.metadata.expressions.KeyWithValueExpression; +import 
com.apple.foundationdb.record.provider.foundationdb.FDBQueriedRecord; +import com.apple.foundationdb.record.provider.foundationdb.FDBRecordContext; +import com.apple.foundationdb.record.provider.foundationdb.FDBStoredRecord; +import com.apple.foundationdb.record.provider.foundationdb.IndexMaintainer; +import com.apple.foundationdb.record.provider.foundationdb.IndexMaintainerFactory; +import com.apple.foundationdb.record.provider.foundationdb.IndexMaintainerFactoryRegistry; +import com.apple.foundationdb.record.provider.foundationdb.VectorIndexScanComparisons; +import com.apple.foundationdb.record.provider.foundationdb.VectorIndexScanOptions; +import com.apple.foundationdb.record.query.expressions.Comparisons; +import com.apple.foundationdb.record.query.expressions.Query; +import com.apple.foundationdb.record.query.plan.QueryPlanConstraint; +import com.apple.foundationdb.record.query.plan.ScanComparisons; +import com.apple.foundationdb.record.query.plan.cascades.typing.Type; +import com.apple.foundationdb.record.query.plan.cascades.values.LiteralValue; +import com.apple.foundationdb.record.query.plan.plans.RecordQueryFetchFromPartialRecordPlan; +import com.apple.foundationdb.record.query.plan.plans.RecordQueryIndexPlan; +import com.apple.foundationdb.record.vector.TestRecordsVectorsProto.VectorRecord; +import com.apple.test.RandomizedTestUtils; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.ObjectArrays; +import com.google.common.collect.Sets; +import com.google.common.primitives.Ints; +import com.google.protobuf.Message; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nonnull; +import java.util.List; +import 
java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Random; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static com.apple.foundationdb.record.metadata.Key.Expressions.concat; +import static com.apple.foundationdb.record.metadata.Key.Expressions.field; +import static org.assertj.core.api.Assertions.assertThat; + +class VectorIndexTest extends VectorIndexTestBase { + private static final Logger logger = LoggerFactory.getLogger(VectorIndexTest.class); + + @Nonnull + static Stream randomSeedsWithAsync() { + return RandomizedTestUtils.randomSeeds(0xdeadc0deL) + .flatMap(seed -> Sets.cartesianProduct(ImmutableSet.of(true, false)).stream() + .map(arguments -> Arguments.of(ObjectArrays.concat(seed, arguments.toArray())))); + } + + @Nonnull + static Stream randomSeedsWithReturnVectors() { + return RandomizedTestUtils.randomSeeds(0xdeadbeefL) + .flatMap(seed -> Sets.cartesianProduct(ImmutableSet.of(true, false)).stream() + .map(arguments -> Arguments.of(ObjectArrays.concat(seed, arguments.toArray())))); + } + + @Nonnull + static Stream randomSeedsWithAsyncAndLimit() { + return RandomizedTestUtils.randomSeeds(0xdeadc0deL) + .flatMap(seed -> Sets.cartesianProduct(ImmutableSet.of(true, false), + ImmutableSet.of(3, 17, 1000)).stream() + .map(arguments -> Arguments.of(ObjectArrays.concat(seed, arguments.toArray())))); + } + + @ParameterizedTest + @MethodSource("randomSeedsWithAsync") + void basicWriteReadTest(final long seed, final boolean useAsync) throws Exception { + final Random random = new Random(seed); + final List> savedRecords = + saveRecords(useAsync, this::addVectorIndexes, random, 1000, 0.3); + try (final FDBRecordContext context = openContext()) { + openRecordStore(context, this::addVectorIndexes); + for (int l = 0; l < 1000; l ++) { + final FDBStoredRecord loadedRecord = + recordStore.loadRecord(savedRecords.get(l).getPrimaryKey()); + + assertThat(loadedRecord).isNotNull(); 
+ assertThat(loadedRecord.getRecord()).isEqualTo(savedRecords.get(l).getRecord()); + } + commit(context); + } + } + + @ParameterizedTest + @MethodSource("randomSeedsWithAsyncAndLimit") + void basicWriteIndexReadWithContinuationTest(final long seed, final boolean useAsync, final int limit) throws Exception { + final int k = 100; + final Random random = new Random(seed); + final HalfRealVector queryVector = randomHalfVector(random, 128); + + final List> savedRecords = + saveRecords(useAsync, this::addUngroupedVectorIndex, random, 1000); + + final Set expectedResults = + sortByDistances(savedRecords, queryVector, Metric.EUCLIDEAN_METRIC).stream() + .limit(k) + .map(nodeReferenceWithDistance -> + nodeReferenceWithDistance.getPrimaryKey().getLong(0)) + .collect(ImmutableSet.toImmutableSet()); + + final var indexPlan = + createIndexPlan(queryVector, k, "UngroupedVectorIndex"); + + verifyRebase(indexPlan); + verifySerialization(indexPlan); + + int allCounter = 0; + int recallCounter = 0; + try (final FDBRecordContext context = openContext()) { + openRecordStore(context, this::addUngroupedVectorIndex); + + byte[] continuation = null; + do { + try (RecordCursorIterator> cursor = + executeQuery(indexPlan, continuation, Bindings.EMPTY_BINDINGS, limit)) { + int numRecords = 0; + while (cursor.hasNext()) { + final FDBQueriedRecord rec = cursor.next(); + final VectorRecord record = + VectorRecord.newBuilder() + .mergeFrom(Objects.requireNonNull(rec).getRecord()) + .build(); + numRecords++; + allCounter++; + if (expectedResults.contains(record.getRecNo())) { + recallCounter++; + } + } + if (cursor.getNoNextReason() == RecordCursor.NoNextReason.SOURCE_EXHAUSTED) { + continuation = null; + } else { + continuation = cursor.getContinuation(); + } + if (logger.isInfoEnabled()) { + logger.info("ungrouped read {} records, allCounters={}, recallCounters={}", numRecords, allCounter, + recallCounter); + } + } + } while (continuation != null); + assertThat(allCounter).isEqualTo(k); + 
assertThat((double)recallCounter / k).isGreaterThan(0.9); + } + } + + @ParameterizedTest + @MethodSource("randomSeedsWithAsyncAndLimit") + void basicWriteIndexReadGroupedWithContinuationTest(final long seed, final boolean useAsync, final int limit) throws Exception { + final int k = 100; + final Random random = new Random(seed); + final HalfRealVector queryVector = randomHalfVector(random, 128); + + final Map> expectedResults = + saveRandomRecords(random, this::addGroupedVectorIndex, useAsync, 1000, queryVector); + final var indexPlan = createIndexPlan(queryVector, k, "GroupedVectorIndex"); + + verifyRebase(indexPlan); + verifySerialization(indexPlan); + + try (FDBRecordContext context = openContext()) { + openRecordStore(context, this::addGroupedVectorIndex); + + final int[] allCounters = new int[2]; + final int[] recallCounters = new int[2]; + byte[] continuation = null; + do { + try (final RecordCursorIterator> cursor = + executeQuery(indexPlan, continuation, Bindings.EMPTY_BINDINGS, limit)) { + int numRecords = 0; + while (cursor.hasNext()) { + final FDBQueriedRecord rec = cursor.next(); + final VectorRecord record = + VectorRecord.newBuilder() + .mergeFrom(Objects.requireNonNull(rec).getRecord()) + .build(); + numRecords++; + allCounters[record.getGroupId()]++; + if (expectedResults.get(record.getGroupId()).contains(record.getRecNo())) { + recallCounters[record.getGroupId()]++; + } + } + if (cursor.getNoNextReason() == RecordCursor.NoNextReason.SOURCE_EXHAUSTED) { + continuation = null; + } else { + continuation = cursor.getContinuation(); + } + if (logger.isInfoEnabled()) { + logger.info("grouped read {} records, allCounters={}, recallCounters={}", numRecords, allCounters, + recallCounters); + } + } + } while (continuation != null); + assertThat(Ints.asList(allCounters)) + .allSatisfy(allCounter -> + assertThat(allCounter).isEqualTo(k)); + assertThat(Ints.asList(recallCounters)) + .allSatisfy(recallCounter -> + assertThat((double)recallCounter / 
k).isGreaterThan(0.9)); + } + } + + @ParameterizedTest + @MethodSource("randomSeedsWithAsync") + void deleteWhereGroupedTest(final long seed, final boolean useAsync) throws Exception { + final int k = 100; + final Random random = new Random(seed); + final HalfRealVector queryVector = randomHalfVector(random, 128); + + final Map> expectedResults = saveRandomRecords(random, this::addGroupedVectorIndex, + useAsync, 200, queryVector); + final var indexPlan = createIndexPlan(queryVector, k, "GroupedVectorIndex"); + + try (FDBRecordContext context = openContext()) { + openRecordStore(context, this::addGroupedVectorIndex); + recordStore.deleteRecordsWhere(Query.field("group_id").equalsValue(0)); + + final int[] allCounters = new int[2]; + final int[] recallCounters = new int[2]; + try (final RecordCursorIterator> cursor = executeQuery(indexPlan)) { + while (cursor.hasNext()) { + final FDBQueriedRecord rec = cursor.next(); + final VectorRecord record = + VectorRecord.newBuilder() + .mergeFrom(Objects.requireNonNull(rec).getRecord()) + .build(); + allCounters[record.getGroupId()] ++; + if (expectedResults.get(record.getGroupId()).contains(record.getRecNo())) { + recallCounters[record.getGroupId()] ++; + } + } + } + assertThat(allCounters[0]).isEqualTo(0); + assertThat(allCounters[1]).isEqualTo(k); + + assertThat((double)recallCounters[0] / k).isEqualTo(0.0); + assertThat((double)recallCounters[1] / k).isGreaterThan(0.9); + } + } + + @Test + void directIndexValidatorTest() throws Exception { + try (FDBRecordContext context = openContext()) { + openRecordStore(context, this::addGroupedVectorIndex); + + final Index index = + Objects.requireNonNull(recordStore.getMetaDataProvider()) + .getRecordMetaData().getIndex("GroupedVectorIndex"); + final IndexMaintainerFactoryRegistry indexMaintainerRegistry = recordStore.getIndexMaintainerRegistry(); + final MetaDataValidator metaDataValidator = + new MetaDataValidator(recordStore.getRecordMetaData(), indexMaintainerRegistry); + 
metaDataValidator.validate(); + + // validate the allowed changes all at once + validateIndexEvolution(metaDataValidator, index, + ImmutableMap.builder() + // cannot change those per se but must accept same value + .put(IndexOptions.HNSW_DETERMINISTIC_SEEDING, "false") + .put(IndexOptions.HNSW_METRIC, Metric.EUCLIDEAN_METRIC.name()) + .put(IndexOptions.HNSW_NUM_DIMENSIONS, "128") + .put(IndexOptions.HNSW_USE_INLINING, "false") + .put(IndexOptions.HNSW_M, "16") + .put(IndexOptions.HNSW_M_MAX, "16") + .put(IndexOptions.HNSW_M_MAX_0, "32") + .put(IndexOptions.HNSW_EF_CONSTRUCTION, "200") + .put(IndexOptions.HNSW_EXTEND_CANDIDATES, "false") + .put(IndexOptions.HNSW_KEEP_PRUNED_CONNECTIONS, "false") + .put(IndexOptions.HNSW_USE_RABITQ, "false") + .put(IndexOptions.HNSW_RABITQ_NUM_EX_BITS, "4") + + // these are allowed to change in any way + .put(IndexOptions.HNSW_SAMPLE_VECTOR_STATS_PROBABILITY, "0.999") + .put(IndexOptions.HNSW_MAINTAIN_STATS_PROBABILITY, "0.78") + .put(IndexOptions.HNSW_STATS_THRESHOLD, "500") + .put(IndexOptions.HNSW_MAX_NUM_CONCURRENT_NODE_FETCHES, "17") + .put(IndexOptions.HNSW_MAX_NUM_CONCURRENT_NEIGHBORHOOD_FETCHES, "9").build()); + + Assertions.assertThatThrownBy(() -> validateIndexEvolution(metaDataValidator, index, + ImmutableMap.of(IndexOptions.HNSW_NUM_DIMENSIONS, "128", + IndexOptions.HNSW_DETERMINISTIC_SEEDING, "true"))).isInstanceOf(MetaDataException.class); + + Assertions.assertThatThrownBy(() -> validateIndexEvolution(metaDataValidator, index, + ImmutableMap.of(IndexOptions.HNSW_NUM_DIMENSIONS, "128", + IndexOptions.HNSW_METRIC, Metric.EUCLIDEAN_SQUARE_METRIC.name()))) + .isInstanceOf(MetaDataException.class); + + Assertions.assertThatThrownBy(() -> validateIndexEvolution(metaDataValidator, index, + ImmutableMap.of(IndexOptions.HNSW_NUM_DIMENSIONS, "768"))) + .isInstanceOf(MetaDataException.class); + + Assertions.assertThatThrownBy(() -> validateIndexEvolution(metaDataValidator, index, + ImmutableMap.of(IndexOptions.HNSW_NUM_DIMENSIONS, 
"128", + IndexOptions.HNSW_USE_INLINING, "true"))).isInstanceOf(MetaDataException.class); + + Assertions.assertThatThrownBy(() -> validateIndexEvolution(metaDataValidator, index, + ImmutableMap.of(IndexOptions.HNSW_NUM_DIMENSIONS, "128", + IndexOptions.HNSW_M, "8"))).isInstanceOf(MetaDataException.class); + + Assertions.assertThatThrownBy(() -> validateIndexEvolution(metaDataValidator, index, + ImmutableMap.of(IndexOptions.HNSW_NUM_DIMENSIONS, "128", + IndexOptions.HNSW_M_MAX, "8"))).isInstanceOf(MetaDataException.class); + + Assertions.assertThatThrownBy(() -> validateIndexEvolution(metaDataValidator, index, + ImmutableMap.of(IndexOptions.HNSW_NUM_DIMENSIONS, "128", + IndexOptions.HNSW_M_MAX_0, "16"))).isInstanceOf(MetaDataException.class); + + Assertions.assertThatThrownBy(() -> validateIndexEvolution(metaDataValidator, index, + ImmutableMap.of(IndexOptions.HNSW_NUM_DIMENSIONS, "128", + IndexOptions.HNSW_EF_CONSTRUCTION, "500"))).isInstanceOf(MetaDataException.class); + + Assertions.assertThatThrownBy(() -> validateIndexEvolution(metaDataValidator, index, + ImmutableMap.of(IndexOptions.HNSW_NUM_DIMENSIONS, "128", + IndexOptions.HNSW_EXTEND_CANDIDATES, "true"))).isInstanceOf(MetaDataException.class); + + Assertions.assertThatThrownBy(() -> validateIndexEvolution(metaDataValidator, index, + ImmutableMap.of(IndexOptions.HNSW_NUM_DIMENSIONS, "128", + IndexOptions.HNSW_KEEP_PRUNED_CONNECTIONS, "true"))) + .isInstanceOf(MetaDataException.class); + + Assertions.assertThatThrownBy(() -> validateIndexEvolution(metaDataValidator, index, + ImmutableMap.of(IndexOptions.HNSW_NUM_DIMENSIONS, "128", + IndexOptions.HNSW_USE_RABITQ, "true"))).isInstanceOf(MetaDataException.class); + + Assertions.assertThatThrownBy(() -> validateIndexEvolution(metaDataValidator, index, + ImmutableMap.of(IndexOptions.HNSW_NUM_DIMENSIONS, "128", + IndexOptions.HNSW_RABITQ_NUM_EX_BITS, "1"))).isInstanceOf(MetaDataException.class); + } + } + + private void validateIndexEvolution(@Nonnull final 
MetaDataValidator metaDataValidator, + @Nonnull final Index oldIndex, @Nonnull final Map optionsMap) { + final Index newIndex = + new Index("GroupedVectorIndex", + new KeyWithValueExpression(concat(field("group_id"), field("vector_data")), 1), + IndexTypes.VECTOR, + optionsMap); + + final IndexMaintainerFactoryRegistry indexMaintainerRegistry = recordStore.getIndexMaintainerRegistry(); + final IndexMaintainerFactory indexMaintainerFactory = + indexMaintainerRegistry.getIndexMaintainerFactory(oldIndex); + + final IndexValidator validatorForCompatibleNewIndex = + indexMaintainerFactory.getIndexValidator(newIndex); + validatorForCompatibleNewIndex.validate(metaDataValidator); + validatorForCompatibleNewIndex.validateChangedOptions(oldIndex); + } + + @SuppressWarnings("resource") + @Test + void directIndexMaintainerTest() throws Exception { + try (FDBRecordContext context = openContext()) { + openRecordStore(context, this::addGroupedVectorIndex); + + final Index index = + Objects.requireNonNull(recordStore.getMetaDataProvider()) + .getRecordMetaData().getIndex("GroupedVectorIndex"); + final IndexMaintainer indexMaintainer = recordStore.getIndexMaintainer(index); + Assertions.assertThatThrownBy(() -> indexMaintainer.scan(IndexScanType.BY_VALUE, TupleRange.ALL, + null, ScanProperties.FORWARD_SCAN)).isInstanceOf(IllegalStateException.class); + } + } + + + @ParameterizedTest + @MethodSource("randomSeedsWithReturnVectors") + void directIndexReadGroupedWithContinuationTest(final long seed, final boolean returnVectors) throws Exception { + final int k = 100; + final Random random = new Random(seed); + final HalfRealVector queryVector = randomHalfVector(random, 128); + + final Map> expectedResults = + saveRandomRecords(random, this::addGroupedVectorIndex, true, 1000, queryVector); + + try (FDBRecordContext context = openContext()) { + openRecordStore(context, this::addGroupedVectorIndex); + + final int[] allCounters = new int[2]; + final int[] recallCounters = new int[2]; + + 
final Index index = + Objects.requireNonNull(recordStore.getMetaDataProvider()) + .getRecordMetaData().getIndex("GroupedVectorIndex"); + final IndexMaintainer indexMaintainer = recordStore.getIndexMaintainer(index); + final VectorIndexScanComparisons vectorIndexScanComparisons = + createVectorIndexScanComparisons(queryVector, k, + VectorIndexScanOptions.builder() + .putOption(VectorIndexScanOptions.HNSW_RETURN_VECTORS, returnVectors) + .build()); + final ScanProperties scanProperties = ExecuteProperties.newBuilder() + .setIsolationLevel(IsolationLevel.SERIALIZABLE) + .setState(ExecuteState.NO_LIMITS) + .setReturnedRowLimit(Integer.MAX_VALUE).build().asScanProperties(false); + + + try (final RecordCursor cursor = + indexMaintainer.scan(vectorIndexScanComparisons.bind(recordStore, index, + EvaluationContext.empty()), null, scanProperties)) { + final RecordCursorIterator cursorIterator = cursor.asIterator(); + int numRecords = 0; + while (cursorIterator.hasNext()) { + final IndexEntry indexEntry = Objects.requireNonNull(cursorIterator.next()); + numRecords++; + final int groupId = Math.toIntExact(indexEntry.getPrimaryKey().getLong(0)); + final long recNo = indexEntry.getPrimaryKey().getLong(1); + allCounters[groupId]++; + if (expectedResults.get(groupId).contains(recNo)) { + recallCounters[groupId]++; + } + assertThat(indexEntry.getValue().get(0) != null).isEqualTo(returnVectors); + } + if (logger.isInfoEnabled()) { + logger.info("grouped read {} records, allCounters={}, recallCounters={}", numRecords, allCounters, + recallCounters); + } + } + + assertThat(Ints.asList(allCounters)) + .allSatisfy(allCounter -> + assertThat(allCounter).isEqualTo(k)); + assertThat(Ints.asList(recallCounters)) + .allSatisfy(recallCounter -> + assertThat((double)recallCounter / k).isGreaterThan(0.9)); + } + } + + @Nonnull + private static RecordQueryIndexPlan createIndexPlan(@Nonnull final HalfRealVector queryVector, final int k, + @Nonnull final String indexName) { +
final var vectorIndexScanComparisons = + createVectorIndexScanComparisons(queryVector, k, VectorIndexScanOptions.empty()); + + final var baseRecordType = + Type.Record.fromFieldDescriptorsMap( + Type.Record.toFieldDescriptorMap(VectorRecord.getDescriptor().getFields())); + + return new RecordQueryIndexPlan(indexName, field("recNo"), + vectorIndexScanComparisons, IndexFetchMethod.SCAN_AND_FETCH, + RecordQueryFetchFromPartialRecordPlan.FetchIndexRecords.PRIMARY_KEY, false, false, + Optional.empty(), baseRecordType, QueryPlanConstraint.noConstraint()); + } + + @Nonnull + private static VectorIndexScanComparisons createVectorIndexScanComparisons(@Nonnull final HalfRealVector queryVector, final int k, + @Nonnull final VectorIndexScanOptions vectorIndexScanOptions) { + final Comparisons.DistanceRankValueComparison distanceRankComparison = + new Comparisons.DistanceRankValueComparison(Comparisons.Type.DISTANCE_RANK_LESS_THAN_OR_EQUAL, + new LiteralValue<>(Type.Vector.of(false, 16, 128), queryVector), + new LiteralValue<>(k)); + + return VectorIndexScanComparisons.byDistance(ScanComparisons.EMPTY, + distanceRankComparison, vectorIndexScanOptions); + } + + @Nonnull + private Map> saveRandomRecords(@Nonnull final Random random, @Nonnull final RecordMetaDataHook hook, + final boolean useAsync, final int numSamples, + @Nonnull final HalfRealVector queryVector) throws Exception { + final List> savedRecords = + saveRecords(useAsync, hook, random, numSamples); + + return sortByDistances(savedRecords, queryVector, Metric.EUCLIDEAN_METRIC) + .stream() + .map(NodeReference::getPrimaryKey) + .map(primaryKey -> primaryKey.getLong(0)) + .collect(Collectors.groupingBy(nodeId -> Math.toIntExact(nodeId) % 2, Collectors.toSet())); + } +} diff --git a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexTestBase.java b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexTestBase.java 
new file mode 100644 index 0000000000..525d64f593 --- /dev/null +++ b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexTestBase.java @@ -0,0 +1,226 @@ +/* + * VectorIndexTestBase.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.record.provider.foundationdb.indexes; + +import com.apple.foundationdb.async.AsyncUtil; +import com.apple.foundationdb.async.hnsw.NodeReferenceWithDistance; +import com.apple.foundationdb.half.Half; +import com.apple.foundationdb.linear.AffineOperator; +import com.apple.foundationdb.linear.HalfRealVector; +import com.apple.foundationdb.linear.Metric; +import com.apple.foundationdb.linear.RealVector; +import com.apple.foundationdb.record.RecordMetaData; +import com.apple.foundationdb.record.RecordMetaDataBuilder; +import com.apple.foundationdb.record.metadata.Index; +import com.apple.foundationdb.record.metadata.IndexOptions; +import com.apple.foundationdb.record.metadata.IndexTypes; +import com.apple.foundationdb.record.metadata.expressions.KeyWithValueExpression; +import com.apple.foundationdb.record.provider.foundationdb.FDBRecordContext; +import com.apple.foundationdb.record.provider.foundationdb.FDBStoredRecord; +import 
com.apple.foundationdb.record.provider.foundationdb.query.FDBRecordStoreQueryTestBase; +import com.apple.foundationdb.record.vector.TestRecordsVectorsProto; +import com.apple.foundationdb.record.vector.TestRecordsVectorsProto.VectorRecord; +import com.apple.foundationdb.tuple.Tuple; +import com.apple.test.Tags; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import com.google.protobuf.ByteString; +import com.google.protobuf.Message; +import org.assertj.core.util.Lists; +import org.junit.jupiter.api.Tag; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nonnull; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.Random; +import java.util.concurrent.CompletableFuture; +import java.util.function.Function; + +import static com.apple.foundationdb.record.metadata.Key.Expressions.concat; +import static com.apple.foundationdb.record.metadata.Key.Expressions.concatenateFields; +import static com.apple.foundationdb.record.metadata.Key.Expressions.field; + +/** + * Common test helpers for vector type indexes. 
+ */ +@Tag(Tags.RequiresFDB) +public class VectorIndexTestBase extends FDBRecordStoreQueryTestBase { + private static final Logger logger = LoggerFactory.getLogger(VectorIndexTestBase.class); + + @CanIgnoreReturnValue + protected RecordMetaDataBuilder addVectorIndexes(@Nonnull final RecordMetaDataBuilder metaDataBuilder) { + addUngroupedVectorIndex(metaDataBuilder); + addGroupedVectorIndex(metaDataBuilder); + return metaDataBuilder; + } + + @CanIgnoreReturnValue + protected RecordMetaDataBuilder addUngroupedVectorIndex(@Nonnull final RecordMetaDataBuilder metaDataBuilder) { + metaDataBuilder.addIndex("VectorRecord", + new Index("UngroupedVectorIndex", new KeyWithValueExpression(field("vector_data"), 0), + IndexTypes.VECTOR, + ImmutableMap.of(IndexOptions.HNSW_METRIC, Metric.EUCLIDEAN_METRIC.name(), + IndexOptions.HNSW_NUM_DIMENSIONS, "128"))); + return metaDataBuilder; + } + + @CanIgnoreReturnValue + protected RecordMetaDataBuilder addGroupedVectorIndex(@Nonnull final RecordMetaDataBuilder metaDataBuilder) { + metaDataBuilder.addIndex("VectorRecord", + new Index("GroupedVectorIndex", new KeyWithValueExpression(concat(field("group_id"), field("vector_data")), 1), + IndexTypes.VECTOR, + ImmutableMap.of(IndexOptions.HNSW_METRIC, Metric.EUCLIDEAN_METRIC.name(), + IndexOptions.HNSW_NUM_DIMENSIONS, "128"))); + return metaDataBuilder; + } + + protected void openRecordStore(FDBRecordContext context) throws Exception { + openRecordStore(context, NO_HOOK); + } + + protected void openRecordStore(final FDBRecordContext context, final RecordMetaDataHook hook) throws Exception { + RecordMetaDataBuilder metaDataBuilder = RecordMetaData.newBuilder().setRecords(TestRecordsVectorsProto.getDescriptor()); + metaDataBuilder.getRecordType("VectorRecord").setPrimaryKey(concatenateFields("group_id", "rec_no")); + hook.apply(metaDataBuilder); + createOrOpenRecordStore(context, metaDataBuilder.getRecordMetaData()); + } + + protected static Function getRecordGenerator(@Nonnull final Random 
random, + final double nullProbability) { + return recNo -> { + final VectorRecord.Builder recordBuilder = + VectorRecord.newBuilder() + .setRecNo(recNo) + .setGroupId(recNo.intValue() % 2); + if (random.nextDouble() >= nullProbability) { + final RealVector vector = randomHalfVector(random, 128); + recordBuilder.setVectorData(ByteString.copyFrom(vector.getRawData())); + } + + return recordBuilder.build(); + }; + } + + @Nonnull + protected static HalfRealVector randomHalfVector(final Random random, final int numDimensions) { + final Half[] componentData = new Half[numDimensions]; + for (int i = 0; i < componentData.length; i++) { + componentData[i] = Half.valueOf(random.nextFloat()); + } + + return new HalfRealVector(componentData); + } + + protected List> saveRecords(final boolean useAsync, + @Nonnull final RecordMetaDataHook hook, + @Nonnull final Random random, + final int numSamples) throws Exception { + return saveRecords(useAsync, hook, random, numSamples, 0.0d); + } + + protected List> saveRecords(final boolean useAsync, + @Nonnull final RecordMetaDataHook hook, + @Nonnull final Random random, + final int numSamples, + final double nullProbability) throws Exception { + final var recordGenerator = getRecordGenerator(random, nullProbability); + if (useAsync) { + return asyncBatch(hook, numSamples, 100, + recNo -> recordStore.saveRecordAsync(recordGenerator.apply(recNo))); + } else { + return batch(hook, numSamples, 100, + recNo -> recordStore.saveRecord(recordGenerator.apply(recNo))); + } + } + + private List> batch(final RecordMetaDataHook hook, final int numRecords, + final int batchSize, + Function> recordConsumer) throws Exception { + final List> records = Lists.newArrayList(); + while (records.size() < numRecords) { + try (FDBRecordContext context = openContext()) { + openRecordStore(context, hook); + int recNoInBatch; + + for (recNoInBatch = 0; records.size() < numRecords && recNoInBatch < batchSize; recNoInBatch++) { + 
records.add(recordConsumer.apply((long)records.size())); + } + commit(context); + logger.info("committed batch of sync inserts, numRecordsCommitted = {}", records.size()); + } + } + return records; + } + + private List> + asyncBatch(@Nonnull final RecordMetaDataHook hook, + final int numRecords, + final int batchSize, + @Nonnull final Function>> recordConsumer) throws Exception { + final List> records = Lists.newArrayList(); + while (records.size() < numRecords) { + try (FDBRecordContext context = openContext()) { + openRecordStore(context, hook); + int recNoInBatch; + final ArrayList>> futures = Lists.newArrayList(); + + for (recNoInBatch = 0; records.size() + recNoInBatch < numRecords && recNoInBatch < batchSize; recNoInBatch++) { + futures.add(recordConsumer.apply((long)records.size() + recNoInBatch)); + } + + // wait and then commit + AsyncUtil.whenAll(futures).get(); + futures.forEach(future -> records.add(future.join())); + commit(context); + logger.info("committed batch of async inserts, numRecordsCommitted = {}", records.size()); + } + } + return records; + } + + protected static List + sortByDistances(@Nonnull final List> storedRecords, + @Nonnull final RealVector queryVector, + @Nonnull final Metric metric) { + return storedRecords.stream() + .map(storedRecord -> { + final VectorRecord vectorRecord = (VectorRecord)storedRecord.getRecord(); + final RealVector storedVector = + RealVector.fromBytes(vectorRecord.getVectorData().toByteArray()); + return new NodeReferenceWithDistance(Tuple.from(vectorRecord.getRecNo()), + AffineOperator.identity().transform(storedVector), + metric.distance(storedVector, queryVector)); + }) + .sorted(Comparator.comparing(NodeReferenceWithDistance::getDistance)) + .collect(ImmutableList.toImmutableList()); + } + + protected static void logRecord(final long recNo, @Nonnull final ByteString vectorData) { + if (logger.isInfoEnabled()) { + logger.info("recNo: {}; vectorData: [{}]", + recNo,
RealVector.fromBytes(vectorData.toByteArray())); + } + } +} diff --git a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/query/FDBRecordStoreQueryTestBase.java b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/query/FDBRecordStoreQueryTestBase.java index 60dc3336c3..9fe2cf17a6 100644 --- a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/query/FDBRecordStoreQueryTestBase.java +++ b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/query/FDBRecordStoreQueryTestBase.java @@ -23,6 +23,8 @@ import com.apple.foundationdb.record.Bindings; import com.apple.foundationdb.record.EvaluationContext; import com.apple.foundationdb.record.ExecuteProperties; +import com.apple.foundationdb.record.ExecuteState; +import com.apple.foundationdb.record.IsolationLevel; import com.apple.foundationdb.record.PlanHashable; import com.apple.foundationdb.record.PlanSerializationContext; import com.apple.foundationdb.record.RecordCursor; @@ -57,12 +59,16 @@ import com.apple.foundationdb.record.query.plan.RecordQueryPlanner; import com.apple.foundationdb.record.query.plan.cascades.CascadesPlanner; import com.apple.foundationdb.record.query.plan.cascades.CorrelationIdentifier; +import com.apple.foundationdb.record.query.plan.cascades.Memoizer; +import com.apple.foundationdb.record.query.plan.cascades.PlannerStage; import com.apple.foundationdb.record.query.plan.cascades.Reference; +import com.apple.foundationdb.record.query.plan.cascades.References; import com.apple.foundationdb.record.query.plan.cascades.debug.Debugger; import com.apple.foundationdb.record.query.plan.cascades.matching.structure.BindingMatcher; import com.apple.foundationdb.record.query.plan.cascades.explain.ExplainPlanVisitor; import com.apple.foundationdb.record.query.plan.cascades.typing.TypeRepository; import 
com.apple.foundationdb.record.query.plan.cascades.values.ConstantObjectValue; +import com.apple.foundationdb.record.query.plan.cascades.values.translation.ToUniqueAliasesTranslationMap; import com.apple.foundationdb.record.query.plan.plans.QueryResult; import com.apple.foundationdb.record.query.plan.plans.RecordQueryPlan; import com.apple.foundationdb.record.query.plan.serialization.DefaultPlanSerializationRegistry; @@ -698,16 +704,51 @@ protected static RecordQueryPlan verifySerialization(@Nonnull final RecordQueryP return deserializedPlan; } + /** + * Rebase the plan using a {@link ToUniqueAliasesTranslationMap} and verify that exactly one plan results whose + * plan hash and equality match the original plan. + * @param plan the original plan + * @return the rebased and verified plan + */ + @Nonnull + protected static RecordQueryPlan verifyRebase(@Nonnull final RecordQueryPlan plan) { + final var rebasedPlans = + References.rebaseGraphs(ImmutableList.of(Reference.plannedOf(plan)), + Memoizer.noMemoization(PlannerStage.PLANNED), new ToUniqueAliasesTranslationMap(), + false); + Assertions.assertEquals(1, rebasedPlans.size()); + final var rebasedPlan = rebasedPlans.get(0).getOnlyElementAsPlan(); + Assertions.assertEquals(plan.planHash(PlanHashable.CURRENT_FOR_CONTINUATION), + rebasedPlan.planHash(PlanHashable.CURRENT_FOR_CONTINUATION)); + Assertions.assertEquals(plan, rebasedPlan); + return rebasedPlan; + } + @Nonnull protected RecordCursorIterator> executeQuery(@Nonnull final RecordQueryPlan plan) { return executeQuery(plan, Bindings.EMPTY_BINDINGS); } @Nonnull - protected RecordCursorIterator> executeQuery(@Nonnull final RecordQueryPlan plan, @Nonnull Bindings bindings) { + protected RecordCursorIterator> executeQuery(@Nonnull final RecordQueryPlan plan, + @Nonnull final Bindings bindings) { + return executeQuery(plan, null, bindings, Integer.MAX_VALUE); + } + + @Nonnull + @SuppressWarnings("resource") + protected RecordCursorIterator> executeQuery(@Nonnull final RecordQueryPlan plan, +
@Nullable byte[] continuation, + @Nonnull final Bindings bindings, + final int limit) { final var usedTypes = usedTypes().evaluate(plan); final var typeRepository = TypeRepository.newBuilder().addAllTypes(usedTypes).build(); - return plan.execute(recordStore, EvaluationContext.forBindingsAndTypeRepository(bindings, typeRepository)).asIterator(); + final var executeProperties = ExecuteProperties.newBuilder() + .setIsolationLevel(IsolationLevel.SERIALIZABLE) + .setState(ExecuteState.NO_LIMITS) + .setReturnedRowLimit(limit).build(); + return plan.execute(recordStore, EvaluationContext.forBindingsAndTypeRepository(bindings, typeRepository), + continuation, executeProperties).asIterator(); } protected static class Holder { diff --git a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/query/expressions/ComparisonsTestBase.java b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/query/expressions/ComparisonsTestBase.java new file mode 100644 index 0000000000..041d625f11 --- /dev/null +++ b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/query/expressions/ComparisonsTestBase.java @@ -0,0 +1,102 @@ +/* + * ComparisonsTestBase.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.apple.foundationdb.record.query.expressions; + +import com.apple.foundationdb.record.PlanHashable; +import com.apple.foundationdb.record.PlanSerializationContext; +import com.apple.foundationdb.record.planprotos.PComparison; +import com.apple.foundationdb.record.query.plan.cascades.CorrelationIdentifier; +import com.apple.foundationdb.record.query.plan.cascades.values.LeafValue; +import com.apple.foundationdb.record.query.plan.cascades.values.QuantifiedObjectValue; +import com.apple.foundationdb.record.query.plan.cascades.values.Value; +import com.apple.foundationdb.record.query.plan.cascades.values.translation.TranslationMap; +import com.apple.foundationdb.record.query.plan.explain.DefaultExplainFormatter; +import com.apple.foundationdb.record.query.plan.serialization.PlanSerialization; +import com.google.common.collect.Iterables; +import com.google.protobuf.Descriptors; +import com.google.protobuf.Message; +import org.assertj.core.api.Assertions; + +import javax.annotation.Nonnull; +import java.util.Map; +import java.util.Optional; +import java.util.function.Function; + +import static org.assertj.core.api.Assertions.assertThat; + +class ComparisonsTestBase { + protected ComparisonsTestBase() { + // nothing + } + + @Nonnull + protected static CorrelationIdentifier q1() { + return CorrelationIdentifier.of("q1"); + } + + @Nonnull + protected static CorrelationIdentifier q2() { + return CorrelationIdentifier.of("q2"); + } + + @Nonnull + protected static CorrelationIdentifier q3() { + return CorrelationIdentifier.of("q3"); + } + + @SuppressWarnings("unchecked") + protected static void protoRoundTripComparison(@Nonnull final T original) { + final PComparison comparisonProto = original.toComparisonProto(PlanSerializationContext.newForCurrentMode()); + final Map allFields = comparisonProto.getAllFields(); + Assertions.assertThat(allFields).hasSize(1); + final Message specificComparison = (Message)Iterables.getOnlyElement(allFields.values()); + 
assertThat(original.toProto(PlanSerializationContext.newForCurrentMode())).isEqualTo(specificComparison); + + final T roundTripped = + (T)PlanSerialization.dispatchFromProtoContainer(PlanSerializationContext.newForCurrentMode(), + comparisonProto); + assertThat(roundTripped.planHash(PlanHashable.CURRENT_FOR_CONTINUATION)) + .isEqualTo(original.planHash(PlanHashable.CURRENT_FOR_CONTINUATION)); + assertThat(roundTripped).hasSameHashCodeAs(original); + assertThat(roundTripped).isEqualTo(original); + } + + @Nonnull + protected static Function> replacementFunctionFromTranslationMap(@Nonnull final TranslationMap translationMap) { + return value -> { + if (value instanceof QuantifiedObjectValue) { + final CorrelationIdentifier alias = ((QuantifiedObjectValue)value).getAlias(); + if (translationMap.containsSourceAlias(alias)) { + return Optional.of(translationMap.applyTranslationFunction(alias, (LeafValue)value)); + } + } + return Optional.empty(); + }; + } + + @Nonnull + protected static String renderExplain(@Nonnull final Comparisons.Comparison comparison) { + return comparison.explain() + .getExplainTokens() + .render(DefaultExplainFormatter.forDebugging()) + .toString(); + } +} diff --git a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/query/expressions/DistanceRankValueComparisonTest.java b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/query/expressions/DistanceRankValueComparisonTest.java new file mode 100644 index 0000000000..52a5c458d4 --- /dev/null +++ b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/query/expressions/DistanceRankValueComparisonTest.java @@ -0,0 +1,182 @@ +/* + * DistanceRankValueComparisonTest.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.record.query.expressions; + +import com.apple.foundationdb.linear.DoubleRealVector; +import com.apple.foundationdb.record.EvaluationContext; +import com.apple.foundationdb.record.PlanHashable; +import com.apple.foundationdb.record.query.ParameterRelationshipGraph; +import com.apple.foundationdb.record.query.expressions.Comparisons.DistanceRankValueComparison; +import com.apple.foundationdb.record.query.plan.cascades.typing.Type; +import com.apple.foundationdb.record.query.plan.cascades.values.LiteralValue; +import com.apple.foundationdb.record.query.plan.cascades.values.QuantifiedObjectValue; +import com.apple.foundationdb.record.query.plan.cascades.values.Value; +import com.apple.foundationdb.record.query.plan.cascades.values.translation.TranslationMap; +import org.junit.jupiter.api.Test; + +import javax.annotation.Nonnull; +import java.util.Optional; +import java.util.concurrent.ThreadLocalRandom; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class DistanceRankValueComparisonTest extends ComparisonsTestBase { + @Test + void withValueTest() { + final DistanceRankValueComparison original = randomComparison(); + final Value originalVectorValue = original.getComparandValue(); + final DistanceRankValueComparison withNewValue = original.withValue(getRandomVectorValue()); + assertThat(withNewValue).isNotEqualTo(original); + final DistanceRankValueComparison withOldValue = original.withValue(originalVectorValue); + 
assertThat(withOldValue.planHash(PlanHashable.CURRENT_FOR_CONTINUATION)) + .isEqualTo(original.planHash(PlanHashable.CURRENT_FOR_CONTINUATION)); + assertThat(withOldValue).hasSameHashCodeAs(original); + assertThat(withOldValue).isEqualTo(original); + } + + @Test + void withTypeTest() { + final DistanceRankValueComparison original = randomComparison(); + final Comparisons.Type originalVectorType = original.getType(); + final DistanceRankValueComparison withNewType = original.withType(Comparisons.Type.DISTANCE_RANK_LESS_THAN); + assertThat(withNewType).isNotEqualTo(original); + final DistanceRankValueComparison withOldType = original.withType(originalVectorType); + assertThat(withOldType.planHash(PlanHashable.CURRENT_FOR_CONTINUATION)) + .isEqualTo(original.planHash(PlanHashable.CURRENT_FOR_CONTINUATION)); + assertThat(withOldType).hasSameHashCodeAs(original); + assertThat(withOldType).isEqualTo(original); + } + + @Test + void withParameterRelationshipMapTest() { + final DistanceRankValueComparison original = randomComparison(); + final DistanceRankValueComparison withNewGraph = + original.withParameterRelationshipMap(ParameterRelationshipGraph.empty()); + assertThat(withNewGraph).hasSameHashCodeAs(original); + assertThat(withNewGraph).isEqualTo(original); + } + + @Test + void correlatedToTest() { + final DistanceRankValueComparison comparison = randomComparison(); + assertThat(comparison.getCorrelatedTo()).isEmpty(); + final DistanceRankValueComparison correlatedComparison = correlatedComparison(); + assertThat(correlatedComparison.getCorrelatedTo()).containsExactly(q1(), q2()); + assertThat(correlatedComparison.isCorrelatedTo(q1())).isTrue(); + assertThat(correlatedComparison.isCorrelatedTo(q2())).isTrue(); + assertThat(correlatedComparison.isCorrelatedTo(q3())).isFalse(); + } + + @Test + void replaceValuesTest() { + final DistanceRankValueComparison original = randomComparison(); + final TranslationMap translationMap = + TranslationMap.regularBuilder() + 
.when(q1()).then(((sourceAlias, leafValue) -> original.getComparandValue())) + .when(q2()).then(((sourceAlias, leafValue) -> original.getLimitValue())) + .build(); + + final DistanceRankValueComparison correlatedComparison = correlatedComparison(); + + final Optional translatedOptional = + correlatedComparison.replaceValuesMaybe(replacementFunctionFromTranslationMap(translationMap)); + assertThat(translatedOptional).contains(original); + + final TranslationMap badTranslationMap = + TranslationMap.regularBuilder() + .when(q1()).then(((sourceAlias, leafValue) -> original.getComparandValue())) + .build(); + final Optional badlyTranslatedOptional = + correlatedComparison.replaceValuesMaybe(replacementFunctionFromTranslationMap(badTranslationMap)); + assertThat(badlyTranslatedOptional).isEmpty(); + } + + @Test + void translateCorrelationsTest() { + final DistanceRankValueComparison original = randomComparison(); + final TranslationMap translationMap = + TranslationMap.regularBuilder() + .when(q1()).then(((sourceAlias, leafValue) -> original.getComparandValue())) + .when(q2()).then(((sourceAlias, leafValue) -> original.getLimitValue())) + .build(); + + final DistanceRankValueComparison correlatedComparison = correlatedComparison(); + final DistanceRankValueComparison translated = + correlatedComparison.translateCorrelations(translationMap, false); + assertThat(translated).isEqualTo(original); + } + + @Test + void protoRoundTripTest1() { + protoRoundTripComparison(randomComparison()); + } + + @Test + void protoRoundTripTest2() { + protoRoundTripComparison(correlatedComparison()); + } + + @Test + void explainTest() { + final DistanceRankValueComparison randomComparison = randomComparison(); + final DistanceRankValueComparison randomComparison2 = randomComparison(); + assertThat(renderExplain(randomComparison)).isNotEqualTo(renderExplain(randomComparison2)); + assertThat(randomComparison.typelessString()).isNotEqualTo(randomComparison2.typelessString()); + 
assertThat(randomComparison).doesNotHaveToString(randomComparison2.toString()); + + final DistanceRankValueComparison comparison = correlatedComparison(); + final DistanceRankValueComparison comparison2 = correlatedComparison(); + assertThat(renderExplain(comparison)).isEqualTo(renderExplain(comparison2)); + assertThat(comparison.typelessString()).isEqualTo(comparison2.typelessString()); + assertThat(comparison).hasToString(comparison2.toString()); + } + + @Test + void evalTest() { + assertThatThrownBy(() -> randomComparison().eval(null, EvaluationContext.empty(), 10)) + .isInstanceOf(IllegalStateException.class); + } + + @Nonnull + private static DistanceRankValueComparison correlatedComparison() { + return new DistanceRankValueComparison(Comparisons.Type.DISTANCE_RANK_LESS_THAN_OR_EQUAL, + QuantifiedObjectValue.of(q1(), Type.Vector.of(false, 64, 128)), + QuantifiedObjectValue.of(q2(), Type.primitiveType(Type.TypeCode.INT, false))); + } + + @Nonnull + private static DistanceRankValueComparison randomComparison() { + return new DistanceRankValueComparison(Comparisons.Type.DISTANCE_RANK_LESS_THAN_OR_EQUAL, + getRandomVectorValue(), new LiteralValue<>(10)); + } + + @Nonnull + private static LiteralValue<DoubleRealVector> getRandomVectorValue() { /* Builds a literal vector of numDimensions random doubles; the array length and the declared vector type both derive from numDimensions so they cannot drift apart. */ + final int numDimensions = 128; + final double[] components = new double[numDimensions]; + for (int i = 0; i < numDimensions; i++) { + components[i] = ThreadLocalRandom.current().nextDouble(); + } + return new LiteralValue<>(Type.Vector.of(false, 64, numDimensions), + new DoubleRealVector(components)); + } +} diff --git a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/query/expressions/ValueComparisonTest.java b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/query/expressions/ValueComparisonTest.java new file mode 100644 index 0000000000..9813c1791b --- /dev/null +++ b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/query/expressions/ValueComparisonTest.java @@ -0,0 +1,170 @@ +/* + * 
ValueComparisonTest.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.record.query.expressions; + +import com.apple.foundationdb.record.Bindings; +import com.apple.foundationdb.record.EvaluationContext; +import com.apple.foundationdb.record.PlanHashable; +import com.apple.foundationdb.record.query.ParameterRelationshipGraph; +import com.apple.foundationdb.record.query.expressions.Comparisons.Comparison; +import com.apple.foundationdb.record.query.expressions.Comparisons.ValueComparison; +import com.apple.foundationdb.record.query.plan.cascades.typing.Type; +import com.apple.foundationdb.record.query.plan.cascades.values.LiteralValue; +import com.apple.foundationdb.record.query.plan.cascades.values.QuantifiedObjectValue; +import com.apple.foundationdb.record.query.plan.cascades.values.Value; +import com.apple.foundationdb.record.query.plan.cascades.values.translation.TranslationMap; +import org.junit.jupiter.api.Test; + +import javax.annotation.Nonnull; +import java.util.Optional; + +import static org.assertj.core.api.Assertions.assertThat; + +class ValueComparisonTest extends ComparisonsTestBase { + @Test + void withValueTest() { + final ValueComparison original = comparison(); + final Value originalVectorValue = original.getComparandValue(); + final ValueComparison 
withNewValue = original.withValue(new LiteralValue<>(20)); + assertThat(withNewValue).isNotEqualTo(original); + final ValueComparison withOldValue = original.withValue(originalVectorValue); + assertThat(withOldValue.planHash(PlanHashable.CURRENT_FOR_CONTINUATION)) + .isEqualTo(original.planHash(PlanHashable.CURRENT_FOR_CONTINUATION)); + assertThat(withOldValue).hasSameHashCodeAs(original); + assertThat(withOldValue).isEqualTo(original); + } + + @Test + void withTypeTest() { + final ValueComparison original = comparison(); + final Comparisons.Type originalVectorType = original.getType(); + final ValueComparison withNewType = original.withType(Comparisons.Type.DISTANCE_RANK_LESS_THAN); + assertThat(withNewType).isNotEqualTo(original); + final ValueComparison withOldType = original.withType(originalVectorType); + assertThat(withOldType.planHash(PlanHashable.CURRENT_FOR_CONTINUATION)) + .isEqualTo(original.planHash(PlanHashable.CURRENT_FOR_CONTINUATION)); + assertThat(withOldType).hasSameHashCodeAs(original); + assertThat(withOldType).isEqualTo(original); + } + + @Test + void withParameterRelationshipMapTest() { + final ValueComparison original = comparison(); + final ValueComparison withNewGraph = + original.withParameterRelationshipMap(ParameterRelationshipGraph.empty()); + assertThat(withNewGraph).hasSameHashCodeAs(original); + assertThat(withNewGraph).isEqualTo(original); + } + + @Test + void correlatedToTest() { + final ValueComparison comparison = comparison(); + assertThat(comparison.getCorrelatedTo()).isEmpty(); + final ValueComparison correlatedComparison = correlatedComparison(); + assertThat(correlatedComparison.getCorrelatedTo()).containsExactly(q1()); + assertThat(correlatedComparison.isCorrelatedTo(q1())).isTrue(); + assertThat(correlatedComparison.isCorrelatedTo(q2())).isFalse(); + } + + @Test + void replaceValuesTest() { + final ValueComparison original = comparison(); + final TranslationMap translationMap = + TranslationMap.regularBuilder() + 
.when(q1()).then(((sourceAlias, leafValue) -> original.getComparandValue())) + .build(); + + final ValueComparison correlatedComparison = correlatedComparison(); + + final Optional translatedOptional = + correlatedComparison.replaceValuesMaybe(replacementFunctionFromTranslationMap(translationMap)); + assertThat(translatedOptional).contains(original); + + final TranslationMap badTranslationMap = + TranslationMap.regularBuilder() + .when(q2()).then(((sourceAlias, leafValue) -> original.getComparandValue())) + .build(); + final Optional badlyTranslatedOptional = + correlatedComparison.replaceValuesMaybe(replacementFunctionFromTranslationMap(badTranslationMap)); + assertThat(badlyTranslatedOptional).isEmpty(); + } + + @Test + void translateCorrelationsTest() { + final ValueComparison original = comparison(); + final TranslationMap translationMap = + TranslationMap.regularBuilder() + .when(q1()).then(((sourceAlias, leafValue) -> original.getComparandValue())) + .build(); + + final ValueComparison correlatedComparison = correlatedComparison(); + final ValueComparison translated = + correlatedComparison.translateCorrelations(translationMap, false); + assertThat(translated).isEqualTo(original); + } + + @Test + void protoRoundTripTest1() { + protoRoundTripComparison(comparison()); + } + + @Test + void protoRoundTripTest2() { + protoRoundTripComparison(correlatedComparison()); + } + + @Test + void explainTest() { + final ValueComparison randomComparison = comparison(); + final ValueComparison randomComparison2 = comparison(); + assertThat(renderExplain(randomComparison)).isEqualTo(renderExplain(randomComparison2)); + assertThat(randomComparison.typelessString()).isEqualTo(randomComparison2.typelessString()); + assertThat(randomComparison).hasToString(randomComparison2.toString()); + + final ValueComparison comparison = correlatedComparison(); + final ValueComparison comparison2 = correlatedComparison(); + 
assertThat(renderExplain(comparison)).isEqualTo(renderExplain(comparison2)); + assertThat(comparison.typelessString()).isEqualTo(comparison2.typelessString()); + assertThat(comparison).hasToString(comparison2.toString()); + } + + @Test + void evalTest() { + assertThat(comparison().eval(null, EvaluationContext.empty(), 10)).isTrue(); + assertThat(comparison().eval(null, EvaluationContext.empty(), 20)).isFalse(); + + final EvaluationContext evaluationContext = + EvaluationContext.empty().withBinding(Bindings.Internal.CORRELATION, q1(), 10); + assertThat(comparison().eval(null, evaluationContext, 10)).isTrue(); + assertThat(comparison().eval(null, evaluationContext, 20)).isFalse(); + } + + @Nonnull + private static ValueComparison correlatedComparison() { + return new ValueComparison(Comparisons.Type.EQUALS, + QuantifiedObjectValue.of(q1(), Type.primitiveType(Type.TypeCode.INT, false))); + } + + @Nonnull + private static ValueComparison comparison() { + return new ValueComparison(Comparisons.Type.EQUALS, new LiteralValue<>(10)); + } +} diff --git a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/query/plan/ScanComparisonsTest.java b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/query/plan/ScanComparisonsTest.java index 4fb9c91b25..04f0b02e9c 100644 --- a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/query/plan/ScanComparisonsTest.java +++ b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/query/plan/ScanComparisonsTest.java @@ -46,7 +46,7 @@ private void checkRange(String expected, ScanComparisons.Builder comparisons, Ev } @Test - public void justEquals() throws Exception { + public void justEquals() { ScanComparisons.Builder comparisons = new ScanComparisons.Builder(); comparisons.addEqualityComparison(new Comparisons.SimpleComparison(Comparisons.Type.EQUALS, "abc")); comparisons.addEqualityComparison(new Comparisons.SimpleComparison(Comparisons.Type.EQUALS, "xyz")); @@ -54,14 +54,14 @@ public 
void justEquals() throws Exception { } @Test - public void justLess() throws Exception { + public void justLess() { ScanComparisons.Builder comparisons = new ScanComparisons.Builder(); comparisons.addInequalityComparison(new Comparisons.SimpleComparison(Comparisons.Type.LESS_THAN, "xyz")); checkRange("([null],[xyz])", comparisons, context()); } @Test - public void multipleLess() throws Exception { + public void multipleLess() { ScanComparisons.Builder comparisons = new ScanComparisons.Builder(); comparisons.addInequalityComparison(new Comparisons.SimpleComparison(Comparisons.Type.LESS_THAN, "mno")); comparisons.addInequalityComparison(new Comparisons.SimpleComparison(Comparisons.Type.LESS_THAN, "xyz")); @@ -69,14 +69,14 @@ public void multipleLess() throws Exception { } @Test - public void justLessEquals() throws Exception { + public void justLessEquals() { ScanComparisons.Builder comparisons = new ScanComparisons.Builder(); comparisons.addInequalityComparison(new Comparisons.SimpleComparison(Comparisons.Type.LESS_THAN_OR_EQUALS, "xyz")); checkRange("([null],[xyz]]", comparisons, context()); } @Test - public void lessAndLessEqualsSame() throws Exception { + public void lessAndLessEqualsSame() { ScanComparisons.Builder comparisons = new ScanComparisons.Builder(); comparisons.addInequalityComparison(new Comparisons.SimpleComparison(Comparisons.Type.LESS_THAN_OR_EQUALS, "xyz")); comparisons.addInequalityComparison(new Comparisons.SimpleComparison(Comparisons.Type.LESS_THAN, "xyz")); @@ -84,14 +84,14 @@ public void lessAndLessEqualsSame() throws Exception { } @Test - public void justGreater() throws Exception { + public void justGreater() { ScanComparisons.Builder comparisons = new ScanComparisons.Builder(); comparisons.addInequalityComparison(new Comparisons.SimpleComparison(Comparisons.Type.GREATER_THAN, "xyz")); checkRange("([xyz],>", comparisons, context()); } @Test - public void multipleGreater() throws Exception { + public void multipleGreater() { 
ScanComparisons.Builder comparisons = new ScanComparisons.Builder(); comparisons.addInequalityComparison(new Comparisons.SimpleComparison(Comparisons.Type.GREATER_THAN, "mno")); comparisons.addInequalityComparison(new Comparisons.SimpleComparison(Comparisons.Type.GREATER_THAN, "xyz")); @@ -99,14 +99,14 @@ public void multipleGreater() throws Exception { } @Test - public void justGreaterEquals() throws Exception { + public void justGreaterEquals() { ScanComparisons.Builder comparisons = new ScanComparisons.Builder(); comparisons.addInequalityComparison(new Comparisons.SimpleComparison(Comparisons.Type.GREATER_THAN_OR_EQUALS, "xyz")); checkRange("[[xyz],>", comparisons, context()); } @Test - public void greaterAndGreaterEqualsSame() throws Exception { + public void greaterAndGreaterEqualsSame() { ScanComparisons.Builder comparisons = new ScanComparisons.Builder(); comparisons.addInequalityComparison(new Comparisons.SimpleComparison(Comparisons.Type.GREATER_THAN_OR_EQUALS, "xyz")); comparisons.addInequalityComparison(new Comparisons.SimpleComparison(Comparisons.Type.GREATER_THAN, "xyz")); @@ -132,7 +132,7 @@ public void parameters() throws Exception { } @Test - public void notNull() throws Exception { + public void notNull() { ScanComparisons.Builder comparisons = new ScanComparisons.Builder(); comparisons.addInequalityComparison(new Comparisons.NullComparison(Comparisons.Type.NOT_NULL)); comparisons.addInequalityComparison(new Comparisons.SimpleComparison(Comparisons.Type.LESS_THAN, "zzz")); @@ -140,7 +140,7 @@ public void notNull() throws Exception { } @Test - public void ambiguousUntilBound() throws Exception { + public void ambiguousUntilBound() { ScanComparisons.Builder comparisons = new ScanComparisons.Builder(); comparisons.addInequalityComparison(new Comparisons.ParameterComparison(Comparisons.Type.GREATER_THAN_OR_EQUALS, "p1")); comparisons.addInequalityComparison(new Comparisons.ParameterComparison(Comparisons.Type.GREATER_THAN, "p2")); @@ -149,7 +149,7 @@ 
public void ambiguousUntilBound() throws Exception { } @Test - public void multiColumn() throws Exception { + public void multiColumn() { ScanComparisons.Builder comparisons = new ScanComparisons.Builder(); comparisons.addEqualityComparison(new Comparisons.NullComparison(Comparisons.Type.IS_NULL)); comparisons.addInequalityComparison(new Comparisons.MultiColumnComparison(new Comparisons.SimpleComparison(Comparisons.Type.LESS_THAN, Tuple.from("xxx", "yyy")))); diff --git a/fdb-record-layer-core/src/test/proto/test_records_vector.proto b/fdb-record-layer-core/src/test/proto/test_records_vector.proto new file mode 100644 index 0000000000..63ff019c9d --- /dev/null +++ b/fdb-record-layer-core/src/test/proto/test_records_vector.proto @@ -0,0 +1,39 @@ +/* + * test_records_vector.proto + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +syntax = "proto2"; + +package com.apple.foundationdb.record.test.vector; + +option java_package = "com.apple.foundationdb.record.vector"; +option java_outer_classname = "TestRecordsVectorsProto"; + +import "record_metadata_options.proto"; + +option (schema).store_record_versions = true; + +message VectorRecord { /* Test record keyed by rec_no (marked primary_key); vector_data is an opaque bytes field -- presumably the serialized vector, confirm against the index tests. */ + optional int64 rec_no = 1 [(field).primary_key = true]; + optional int32 group_id = 2; + optional bytes vector_data = 3; +} + +message RecordTypeUnion { /* Wraps the only record type declared in this schema, VectorRecord. */ + optional VectorRecord _VectorRecord = 1; +}