From 900920bf8ffc4d504a551a6c60649224fa713c3b Mon Sep 17 00:00:00 2001 From: Normen Seemann Date: Tue, 11 Nov 2025 16:28:14 +0100 Subject: [PATCH] some modifications to download the sift1m dataset from huggingface --- fdb-extensions/fdb-extensions.gradle | 46 +- .../async/hnsw/CompactStorageAdapter.java | 5 +- .../apple/foundationdb/async/hnsw/HNSW.java | 23 +- .../async/hnsw/InliningStorageAdapter.java | 7 +- .../foundationdb/async/hnsw/HNSWTest.java | 579 ++++++++++++++++-- 5 files changed, 596 insertions(+), 64 deletions(-) diff --git a/fdb-extensions/fdb-extensions.gradle b/fdb-extensions/fdb-extensions.gradle index e77584b6ed..6490ed676e 100644 --- a/fdb-extensions/fdb-extensions.gradle +++ b/fdb-extensions/fdb-extensions.gradle @@ -44,21 +44,44 @@ dependencies { testFixturesAnnotationProcessor(libs.autoService) } -def siftSmallFile = layout.buildDirectory.file('downloads/siftsmall.tar.gz') -def extractDir = layout.buildDirectory.dir("extracted") +// Describe all files here +def siftDownloads = [ + siftsmall: [ + url : 'https://huggingface.co/datasets/vecdata/siftsmall/resolve/3106e1b83049c44713b1ce06942d0ab474bbdfb6/siftsmall.tar.gz', + dest: layout.buildDirectory.file("downloads/siftsmall.tar.gz") + ], + sift1mbase: [ + url : 'https://huggingface.co/datasets/qbo-odp/sift1m/resolve/main/sift_base.fvecs', + dest: layout.buildDirectory.file("downloads/sift_base.fvecs") + ], + sift1mgroundtruth: [ + url : 'https://huggingface.co/datasets/qbo-odp/sift1m/resolve/main/sift_groundtruth.ivecs', + dest: layout.buildDirectory.file("downloads/sift_groundtruth.ivecs") + ], + sift1mquery: [ + url : 'https://huggingface.co/datasets/qbo-odp/sift1m/resolve/main/sift_query.fvecs', + dest: layout.buildDirectory.file("downloads/sift_query.fvecs") + ], +] -// Task that downloads the CSV exactly once unless it changed -tasks.register('downloadSiftSmall', de.undercouch.gradle.tasks.download.Download) { - src 'https://huggingface.co/datasets/vecdata/siftsmall/resolve/3106e1b83049c44713b1ce06942d0ab474bbdfb6/siftsmall.tar.gz' - dest siftSmallFile.get().asFile - onlyIfModified true - tempAndMove true - retries 3 +// Register one Download task per entry +def downloadTasks = siftDownloads.collect { name, cfg -> + tasks.register("download${name.capitalize()}", Download) { + src cfg.url + dest cfg.dest.get().asFile + onlyIfModified true + tempAndMove true + retries 3 + overwrite false + outputs.file(dest) + } } +def extractDir = layout.buildDirectory.dir("extracted") + tasks.register('extractSiftSmall', Copy) { - dependsOn 'downloadSiftSmall' - from(tarTree(resources.gzip(siftSmallFile))) + dependsOn downloadTasks + from(tarTree(resources.gzip(siftDownloads.siftsmall.dest))) into extractDir doLast { @@ -72,6 +95,7 @@ tasks.register('extractSiftSmall', Copy) { } test { + dependsOn downloadTasks dependsOn tasks.named('extractSiftSmall') inputs.dir extractDir } diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactStorageAdapter.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactStorageAdapter.java index d14bee5368..b03c296f67 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactStorageAdapter.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactStorageAdapter.java @@ -86,7 +86,8 @@ public CompactStorageAdapter(@Nonnull final Config config, * @param layer the layer of the node to fetch * @param primaryKey the primary key of the node to fetch * - * @return a future that will complete with the fetched 
{@link AbstractNode} + * @return a future that will complete with the fetched {@link AbstractNode} or {@code null} if the node cannot + * be fetched * * @throws IllegalStateException if the node cannot be found in the database for the given key */ @@ -101,7 +102,7 @@ protected CompletableFuture> fetchNodeInternal(@Nonn return readTransaction.get(keyBytes) .thenApply(valueBytes -> { if (valueBytes == null) { - throw new IllegalStateException("cannot fetch node"); + return null; } return nodeFromRaw(storageTransform, layer, primaryKey, keyBytes, valueBytes); }); diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java index 2e7816b530..9f46c479c9 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java @@ -581,7 +581,7 @@ private Quantizer quantizer(@Nullable final AccessInfo accessInfo) { return onReadListener.onAsyncRead( storageAdapter.fetchNode(readTransaction, storageTransform, layer, nodeReference.getPrimaryKey())) - .thenApply(node -> biMapFunction.apply(nodeReference, node)); + .thenApply(node -> biMapFunction.apply(nodeReference, Objects.requireNonNull(node))); } /** @@ -754,6 +754,18 @@ public CompletableFuture insert(@Nonnull final Transaction transaction, @N } return StorageAdapter.fetchAccessInfo(getConfig(), transaction, getSubspace(), getOnReadListener()) + .thenCombine(exists(transaction, newPrimaryKey), + (accessInfo, nodeAlreadyExists) -> { + if (nodeAlreadyExists) { + if (logger.isInfoEnabled()) { + logger.info("new record already exists in HNSW with key={} on layer={}", newPrimaryKey, + insertionLayer); + } + + throw new IllegalStateException("key already exists"); + } + return accessInfo; + }) .thenCompose(accessInfo -> { final AccessInfo currentAccessInfo; final AffineOperator storageTransform = storageTransform(accessInfo); @@ -821,6 +833,15 @@ public CompletableFuture insert(@Nonnull final Transaction transaction, @N }).thenCompose(ignored -> AsyncUtil.DONE); } + @Nonnull + @VisibleForTesting + CompletableFuture exists(@Nonnull final ReadTransaction readTransaction, + @Nonnull final Tuple primaryKey) { + final StorageAdapter storageAdapter = getStorageAdapterForLayer(0); + return storageAdapter.fetchNode(readTransaction, AffineOperator.identity(), 0, primaryKey) + .thenApply(Objects::nonNull); + } + /** * Method to keep stats if necessary. Stats need to be kept and maintained when the client would like to use * e.g. RaBitQ as RaBitQ needs a stable somewhat correct centroid in order to function properly. 
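
Note on the storage/HNSW changes above: CompactStorageAdapter#fetchNodeInternal now completes with null when the key is absent instead of throwing, HNSW gains a package-private exists() probe built on that, and insert() rejects a primary key that is already present (the InliningStorageAdapter hunk below applies the same null-on-missing convention). A minimal caller-side sketch of how these pieces could be combined, assuming same-package access to the @VisibleForTesting exists() method; the class and helper name are illustrative and not part of this patch:

    package com.apple.foundationdb.async.hnsw;

    import java.util.concurrent.CompletableFuture;

    import com.apple.foundationdb.Database;
    import com.apple.foundationdb.linear.RealVector;
    import com.apple.foundationdb.tuple.Tuple;

    final class InsertIfAbsentSketch {
        private InsertIfAbsentSketch() {
        }

        // Insert the vector only if primaryKey is not yet present; insert() itself would throw
        // IllegalStateException on a duplicate, this helper just avoids that exception path.
        static CompletableFuture<Boolean> insertIfAbsent(final Database db, final HNSW hnsw,
                                                         final Tuple primaryKey, final RealVector vector) {
            return db.runAsync(tr ->
                    hnsw.exists(tr, primaryKey).thenCompose(alreadyThere -> {
                        if (alreadyThere) {
                            return CompletableFuture.completedFuture(false); // duplicate: skip silently
                        }
                        return hnsw.insert(tr, primaryKey, vector).thenApply(ignored -> true);
                    }));
        }
    }
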
diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningStorageAdapter.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningStorageAdapter.java index fce0fdac34..e9276024c0 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningStorageAdapter.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningStorageAdapter.java @@ -103,7 +103,12 @@ public InliningStorageAdapter(@Nonnull final Config config, return AsyncUtil.collect(readTransaction.getRange(Range.startsWith(rangeKey), ReadTransaction.ROW_LIMIT_UNLIMITED, false, StreamingMode.WANT_ALL), readTransaction.getExecutor()) - .thenApply(keyValues -> nodeFromRaw(storageTransform, layer, primaryKey, keyValues)); + .thenApply(keyValues -> { + if (keyValues.isEmpty()) { + return null; + } + return nodeFromRaw(storageTransform, layer, primaryKey, keyValues); + }); } /** diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWTest.java index 94dc7a9804..871a455179 100644 --- a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWTest.java +++ b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWTest.java @@ -21,8 +21,10 @@ package com.apple.foundationdb.async.hnsw; import com.apple.foundationdb.Database; +import com.apple.foundationdb.Range; import com.apple.foundationdb.Transaction; import com.apple.foundationdb.async.rtree.RTree; +import com.apple.foundationdb.directory.DirectoryLayer; import com.apple.foundationdb.linear.AffineOperator; import com.apple.foundationdb.linear.DoubleRealVector; import com.apple.foundationdb.linear.HalfRealVector; @@ -31,16 +33,17 @@ import com.apple.foundationdb.linear.RealVector; import com.apple.foundationdb.linear.StoredVecsIterator; import com.apple.foundationdb.rabitq.EncodedRealVector; +import com.apple.foundationdb.subspace.Subspace; import com.apple.foundationdb.test.TestDatabaseExtension; import com.apple.foundationdb.test.TestExecutors; import com.apple.foundationdb.test.TestSubspaceExtension; import com.apple.foundationdb.tuple.Tuple; import com.apple.test.RandomSeedSource; import com.apple.test.RandomizedTestUtils; -import com.apple.test.SuperSlow; import com.apple.test.Tags; import com.google.common.base.Verify; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Maps; import com.google.common.collect.ObjectArrays; @@ -50,6 +53,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Tag; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.api.parallel.Execution; import org.junit.jupiter.api.parallel.ExecutionMode; @@ -66,8 +70,10 @@ import java.nio.file.Path; import java.nio.file.Paths; import java.nio.file.StandardOpenOption; +import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Comparator; +import java.util.Deque; import java.util.Iterator; import java.util.List; import java.util.Locale; @@ -79,6 +85,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BinaryOperator; import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.LongStream; @@ -101,7 
+108,7 @@ class HNSWTest { @RegisterExtension static final TestDatabaseExtension dbExtension = new TestDatabaseExtension(); @RegisterExtension - TestSubspaceExtension rtSubspace = new TestSubspaceExtension(dbExtension); + TestSubspaceExtension hnswSubspace = new TestSubspaceExtension(dbExtension); @RegisterExtension TestSubspaceExtension rtSecondarySubspace = new TestSubspaceExtension(dbExtension); @@ -119,7 +126,7 @@ void testCompactSerialization(final long seed) { final int numDimensions = 768; final CompactStorageAdapter storageAdapter = new CompactStorageAdapter(HNSW.newConfigBuilder().build(numDimensions), CompactNode.factory(), - rtSubspace.getSubspace(), OnWriteListener.NOOP, OnReadListener.NOOP); + hnswSubspace.getSubspace(), OnWriteListener.NOOP, OnReadListener.NOOP); final AbstractNode originalNode = db.run(tr -> { final NodeFactory nodeFactory = storageAdapter.getNodeFactory(); @@ -159,7 +166,7 @@ void testInliningSerialization(final long seed) { final int numDimensions = 768; final InliningStorageAdapter storageAdapter = new InliningStorageAdapter(HNSW.newConfigBuilder().build(numDimensions), - InliningNode.factory(), rtSubspace.getSubspace(), + InliningNode.factory(), hnswSubspace.getSubspace(), OnWriteListener.NOOP, OnReadListener.NOOP); final Node originalNode = db.run(tr -> { @@ -208,10 +215,11 @@ void testBasicInsert(final long seed, final boolean useInlining, final boolean e final Metric metric = Metric.EUCLIDEAN_METRIC; final AtomicLong nextNodeIdAtomic = new AtomicLong(0L); + final TestOnWriteListener onWriteListener = new TestOnWriteListener(); final TestOnReadListener onReadListener = new TestOnReadListener(); final int numDimensions = 128; - final HNSW hnsw = new HNSW(rtSubspace.getSubspace(), TestExecutors.defaultThreadPool(), + final HNSW hnsw = new HNSW(hnswSubspace.getSubspace(), TestExecutors.defaultThreadPool(), HNSW.newConfigBuilder().setMetric(metric) .setUseInlining(useInlining).setExtendCandidates(extendCandidates) .setKeepPrunedConnections(keepPrunedConnections) @@ -221,17 +229,17 @@ void testBasicInsert(final long seed, final boolean useInlining, final boolean e .setMaintainStatsProbability(0.1d) .setStatsThreshold(100) .setM(32).setMMax(32).setMMax0(64).build(numDimensions), - OnWriteListener.NOOP, onReadListener); + onWriteListener, onReadListener); final int k = 50; final HalfRealVector queryVector = createRandomHalfVector(random, numDimensions); final TreeSet recordsOrderedByDistance = new TreeSet<>(Comparator.comparing(PrimaryKeyVectorAndDistance::getDistance)); - for (int i = 0; i < 1000;) { - i += basicInsertBatch(hnsw, 100, nextNodeIdAtomic, onReadListener, - tr -> { - final var primaryKey = createNextPrimaryKey(nextNodeIdAtomic); + for (long i = 0; i < 1000;) { + final Stats batchStats = insertBatch(hnsw, i, 100, + id -> { + final var primaryKey = primaryKey(id); final HalfRealVector dataVector = createRandomHalfVector(random, numDimensions); final double distance = metric.distance(dataVector, queryVector); final PrimaryKeyVectorAndDistance record = @@ -242,6 +250,8 @@ void testBasicInsert(final long seed, final boolean useInlining, final boolean e } return record; }); + batchStats.logInsertInfo(i); + i += batchStats.getNumRecords(); } onReadListener.reset(); @@ -267,7 +277,7 @@ void testBasicInsert(final long seed, final boolean useInlining, final boolean e final double recall = (double)recallCount / (double)k; logger.info("search transaction took elapsedTime={}ms; read nodes={}, read bytes={}, recall={}", TimeUnit.NANOSECONDS.toMillis(endTs - 
beginTs), - onReadListener.getNodeCountByLayer(), onReadListener.getBytesReadByLayer(), + onReadListener.getNodesReadByLayer(), onReadListener.getBytesReadByLayer(), String.format(Locale.ROOT, "%.2f", recall * 100.0d)); Assertions.assertThat(recall).isGreaterThan(0.9); @@ -295,7 +305,7 @@ void testBasicInsertWithRaBitQEncodings(final long seed) { final AtomicLong nextNodeIdAtomic = new AtomicLong(0L); final int numDimensions = 128; - final HNSW hnsw = new HNSW(rtSubspace.getSubspace(), TestExecutors.defaultThreadPool(), + final HNSW hnsw = new HNSW(hnswSubspace.getSubspace(), TestExecutors.defaultThreadPool(), HNSW.newConfigBuilder().setMetric(metric) .setUseRaBitQ(true) .setRaBitQNumExBits(5) @@ -303,7 +313,7 @@ void testBasicInsertWithRaBitQEncodings(final long seed) { .setMaintainStatsProbability(1.0d) // for every vector we maintain the stats .setStatsThreshold(950) // after 950 vectors we enable RaBitQ .setM(32).setMMax(32).setMMax0(64).build(numDimensions), - OnWriteListener.NOOP, OnReadListener.NOOP); + new TestOnWriteListener(), new TestOnReadListener()); final int k = 499; final DoubleRealVector queryVector = createRandomDoubleVector(random, numDimensions); @@ -311,10 +321,10 @@ void testBasicInsertWithRaBitQEncodings(final long seed) { final TreeSet recordsOrderedByDistance = new TreeSet<>(Comparator.comparing(PrimaryKeyVectorAndDistance::getDistance)); - for (int i = 0; i < 1000;) { - i += basicInsertBatch(hnsw, 100, nextNodeIdAtomic, new TestOnReadListener(), - tr -> { - final var primaryKey = createNextPrimaryKey(nextNodeIdAtomic); + for (long i = 0; i < 1000;) { + final Stats batchStats = insertBatch(hnsw, i, 100, + id -> { + final var primaryKey = primaryKey(id); final DoubleRealVector dataVector = createRandomDoubleVector(random, numDimensions); final double distance = metric.distance(dataVector, queryVector); dataMap.put(primaryKey, dataVector); @@ -327,6 +337,8 @@ void testBasicInsertWithRaBitQEncodings(final long seed) { } return record; }); + batchStats.logInsertInfo(i); + i += batchStats.getNumRecords(); } // @@ -384,41 +396,42 @@ void testBasicInsertWithRaBitQEncodings(final long seed) { Assertions.assertThat(encodedVectorCount).isGreaterThan(0); } - private int basicInsertBatch(final HNSW hnsw, final int batchSize, - @Nonnull final AtomicLong nextNodeIdAtomic, @Nonnull final TestOnReadListener onReadListener, - @Nonnull final Function insertFunction) { + private Stats insertBatch(@Nonnull final HNSW hnsw, final long startIndex, final int batchSize, + @Nonnull final Function generatorFunction) { + final TestOnReadListener onReadListener = (TestOnReadListener)hnsw.getOnReadListener(); + final TestOnWriteListener onWriteListener = (TestOnWriteListener)hnsw.getOnWriteListener(); return db.run(tr -> { onReadListener.reset(); - final long nextNodeId = nextNodeIdAtomic.get(); + onWriteListener.reset(); final long beginTs = System.nanoTime(); for (int i = 0; i < batchSize; i ++) { - final var record = insertFunction.apply(tr); + final var record = generatorFunction.apply(startIndex + i); if (record == null) { - return i; + return new Stats(i, System.nanoTime() - beginTs, + onReadListener.getNodesReadByLayer(), onReadListener.getBytesReadByLayer(), + onWriteListener.getNodesWrittenByLayer(), onWriteListener.getBytesReadByLayer(), + onReadListener.getSumMByLayer(), 0, 0, 0); } hnsw.insert(tr, record.getPrimaryKey(), record.getVector()).join(); } - final long endTs = System.nanoTime(); - logger.info("inserted batchSize={} records starting at nodeId={} took elapsedTime={}ms, 
readCounts={}, readBytes={}", - batchSize, nextNodeId, TimeUnit.NANOSECONDS.toMillis(endTs - beginTs), - onReadListener.getNodeCountByLayer(), onReadListener.getBytesReadByLayer()); - return batchSize; + return new Stats(batchSize, System.nanoTime() - beginTs, + onReadListener.getNodesReadByLayer(), onReadListener.getBytesReadByLayer(), + onWriteListener.getNodesWrittenByLayer(), onWriteListener.getBytesReadByLayer(), + onReadListener.getSumMByLayer(), 0, 0, 0); }); } @Test - @SuperSlow + //@SuperSlow void testSIFTInsertSmall() throws Exception { final Metric metric = Metric.EUCLIDEAN_METRIC; final int k = 100; final AtomicLong nextNodeIdAtomic = new AtomicLong(0L); - final TestOnReadListener onReadListener = new TestOnReadListener(); - - final HNSW hnsw = new HNSW(rtSubspace.getSubspace(), TestExecutors.defaultThreadPool(), + final HNSW hnsw = new HNSW(hnswSubspace.getSubspace(), TestExecutors.defaultThreadPool(), HNSW.newConfigBuilder().setUseRaBitQ(true).setRaBitQNumExBits(5) .setMetric(metric).setM(32).setMMax(32).setMMax0(64).build(128), - OnWriteListener.NOOP, onReadListener); + new TestOnWriteListener(), new TestOnReadListener()); final Path siftSmallPath = Paths.get(".out/extracted/siftsmall/siftsmall_base.fvecs"); @@ -427,16 +440,16 @@ void testSIFTInsertSmall() throws Exception { try (final var fileChannel = FileChannel.open(siftSmallPath, StandardOpenOption.READ)) { final Iterator vectorIterator = new StoredVecsIterator.StoredFVecsIterator(fileChannel); - int i = 0; + long i = 0; final AtomicReference sumReference = new AtomicReference<>(null); while (vectorIterator.hasNext()) { - i += basicInsertBatch(hnsw, 100, nextNodeIdAtomic, onReadListener, - tr -> { + final Stats batchStats = insertBatch(hnsw, i, 100, + id -> { if (!vectorIterator.hasNext()) { return null; } final DoubleRealVector doubleVector = vectorIterator.next(); - final Tuple currentPrimaryKey = createNextPrimaryKey(nextNodeIdAtomic); + final Tuple currentPrimaryKey = primaryKey(id); final HalfRealVector currentVector = doubleVector.toHalfRealVector(); if (sumReference.get() == null) { @@ -448,6 +461,8 @@ void testSIFTInsertSmall() throws Exception { dataMap.put(Math.toIntExact(currentPrimaryKey.getLong(0)), currentVector); return new PrimaryKeyAndVector(currentPrimaryKey, currentVector); }); + batchStats.logInsertInfo(i); + i += batchStats.getNumRecords(); } Assertions.assertThat(i).isEqualTo(10000); } @@ -480,7 +495,7 @@ private void validateSIFTSmall(@Nonnull final HNSW hnsw, @Nonnull final Map insertStatsQueue = new ArrayDeque<>(); + Stats runningInsertStats = null; + final Deque queryStatsQueue = new ArrayDeque<>(); + Stats runningQueryStats = null; + + final Path siftBasePath = Paths.get(".out/downloads/sift_base.fvecs"); + try (final var fileChannel = FileChannel.open(siftBasePath, StandardOpenOption.READ)) { + final Iterator vectorIterator = new StoredVecsIterator.StoredFVecsIterator(fileChannel); + + boolean doneSkipping = false; + long i = 0; + while (vectorIterator.hasNext()) { + if (!doneSkipping) { + if (i < skip || db.run(transaction -> + hnsw.exists(transaction, Tuple.from(nextNodeIdAtomic.get()))).join()) { + if (skip >= i) { + if (i % 5000 == 0) { + logger.info("skipping numRecords = {}", i); + } + } else { + if (i % 10 == 0) { + logger.info("skipping records since record exists numRecords = {}", i); + } + } + + vectorIterator.next(); + i++; + nextNodeIdAtomic.set(i); + continue; + } + + doneSkipping = true; + logger.info("done skipping numRecords = {}", i); + } + + final Stats insertStats = 
insertBatch(hnsw, i, 10, + id -> { + if (!vectorIterator.hasNext()) { + return null; + } + final DoubleRealVector doubleVector = vectorIterator.next(); + final Tuple currentPrimaryKey = primaryKey(id); + final HalfRealVector currentVector = doubleVector.toHalfRealVector(); + + return new PrimaryKeyAndVector(currentPrimaryKey, currentVector); + }); + insertStats.logInsertTrace(i); + + if (runningInsertStats == null) { + runningInsertStats = insertStats; + insertStatsQueue.addFirst(insertStats); + } else { + while (insertStatsQueue.size() >= statsQueueMaxSize) { + final Stats removedStats = insertStatsQueue.removeLast(); + runningInsertStats = runningInsertStats.remove(removedStats); + } + insertStatsQueue.addFirst(insertStats); + runningInsertStats = runningInsertStats.add(insertStats); + } + i += insertStats.getNumRecords(); + + if (i % 1000 == 0) { + runningInsertStats.logInsertAveragesInfo(i - 1000); + + final Stats queryStats = querySIFT(hnsw, i); + if (queryStats != null) { + if (runningQueryStats == null) { + runningQueryStats = queryStats; + queryStatsQueue.addFirst(queryStats); + } else { + while (queryStatsQueue.size() >= statsQueueMaxSize) { + final Stats removedStats = queryStatsQueue.removeLast(); + runningQueryStats = runningQueryStats.remove(removedStats); + } + queryStatsQueue.addFirst(queryStats); + runningQueryStats = runningQueryStats.add(queryStats); + } + + runningQueryStats.logQueryAveragesInfo(i); + } + } + } + } + } + + @Test + void testSIFTQuery1M() throws Exception { + final Subspace subspace = getSameSubspace("testSIFTInsert1M"); + final Metric metric = Metric.EUCLIDEAN_METRIC; + + final HNSW hnsw = new HNSW(subspace, TestExecutors.defaultThreadPool(), + HNSW.newConfigBuilder().setUseRaBitQ(true).setRaBitQNumExBits(5) + .setMetric(metric).setM(32).setMMax(32).setMMax0(64).build(128), + new TestOnWriteListener(), new TestOnReadListener()); + + final Stats queryStats = Objects.requireNonNull(querySIFT(hnsw, 430000)); + queryStats.logQueryAveragesInfo(430000); + } + + @Nullable + private Stats querySIFT(@Nonnull final HNSW hnsw, final long numInserted) throws IOException { + final Path siftGroundTruthPath = Paths.get(".out/downloads/sift_groundtruth.ivecs"); + final Path siftQueryPath = Paths.get(".out/downloads/sift_query.fvecs"); + + final TestOnReadListener onReadListener = (TestOnReadListener)hnsw.getOnReadListener(); + + Stats stats = null; + try (final var queryChannel = FileChannel.open(siftQueryPath, StandardOpenOption.READ); + final var groundTruthChannel = FileChannel.open(siftGroundTruthPath, StandardOpenOption.READ)) { + final Iterator queryIterator = new StoredVecsIterator.StoredFVecsIterator(queryChannel); + final Iterator> groundTruthIterator = new StoredVecsIterator.StoredIVecsIterator(groundTruthChannel); + + Verify.verify(queryIterator.hasNext() == groundTruthIterator.hasNext()); + + int i = 0; + while (queryIterator.hasNext()) { + final DoubleRealVector queryVector = queryIterator.next(); + final Set groundTruthIndices = Sets.newHashSet(groundTruthIterator.next()); + // remove all indexes for items not yet inserted + groundTruthIndices.removeIf(id -> id >= numInserted); + if (groundTruthIndices.isEmpty()) { + continue; + } + + onReadListener.reset(); + final long beginTs = System.nanoTime(); + final List results = + db.run(tr -> hnsw.kNearestNeighborsSearch(tr, groundTruthIndices.size(), + efSearchFromK(groundTruthIndices.size()), true, queryVector).join()); + final long endTs = System.nanoTime(); + + int recallCount = 0; + for (final ResultEntry 
resultEntry : results) { + final int primaryKeyIndex = (int)resultEntry.getPrimaryKey().getLong(0); + if (groundTruthIndices.contains(primaryKeyIndex)) { + recallCount ++; + } + } + + final Stats currentStats = + new Stats(0L, endTs - beginTs, + onReadListener.getNodesReadByLayer(), onReadListener.getBytesReadByLayer(), + ImmutableMap.of(), ImmutableMap.of(), ImmutableMap.of(), 1, + groundTruthIndices.size(), recallCount); + + if (stats == null) { + stats = currentStats; + } else { + stats = stats.add(currentStats); + } + + i ++; + if (i >= 100) { + break; + } + } + } + return stats; + } + + private static int efSearchFromK(final int k) { + return 100; + //return Math.min(Math.max(4 * k, 64), 400); + } + + @Nonnull + private Subspace getSameSubspace(@Nonnull final String name) { + return dbExtension.getDatabase().runAsync(tr -> + DirectoryLayer.getDefault().createOrOpen(tr, List.of("fdb-extensions-test")) + .thenApply(directorySubspace -> directorySubspace.subspace(Tuple.from(name))) + ).join(); + } + + private void clearSubspace(@Nonnull final Subspace subspace) { + dbExtension.getDatabase().run(tx -> { + tx.clear(Range.startsWith(subspace.pack())); + return null; + }); + } + private void writeNode(@Nonnull final Transaction transaction, @Nonnull final StorageAdapter storageAdapter, @Nonnull final AbstractNode node, @@ -575,42 +790,48 @@ private static Tuple createRandomPrimaryKey(final @Nonnull Random random) { } @Nonnull - private static Tuple createNextPrimaryKey(@Nonnull final AtomicLong nextIdAtomic) { - return Tuple.from(nextIdAtomic.getAndIncrement()); + private static Tuple primaryKey(final long id) { + return Tuple.from(id); } private static class TestOnReadListener implements OnReadListener { - final Map nodeCountByLayer; - final Map sumMByLayer; - final Map bytesReadByLayer; + @Nonnull + private final Map nodesReadByLayer; + @Nonnull + private final Map bytesReadByLayer; + @Nonnull + private final Map sumMByLayer; public TestOnReadListener() { - this.nodeCountByLayer = Maps.newConcurrentMap(); - this.sumMByLayer = Maps.newConcurrentMap(); + this.nodesReadByLayer = Maps.newConcurrentMap(); this.bytesReadByLayer = Maps.newConcurrentMap(); + this.sumMByLayer = Maps.newConcurrentMap(); } - public Map getNodeCountByLayer() { - return nodeCountByLayer; + @Nonnull + public Map getNodesReadByLayer() { + return nodesReadByLayer; } + @Nonnull public Map getBytesReadByLayer() { return bytesReadByLayer; } + @Nonnull public Map getSumMByLayer() { return sumMByLayer; } public void reset() { - nodeCountByLayer.clear(); + nodesReadByLayer.clear(); bytesReadByLayer.clear(); sumMByLayer.clear(); } @Override public void onNodeRead(final int layer, @Nonnull final Node node) { - nodeCountByLayer.compute(layer, (l, oldValue) -> (oldValue == null ? 0 : oldValue) + 1L); + nodesReadByLayer.compute(layer, (l, oldValue) -> (oldValue == null ? 0 : oldValue) + 1L); sumMByLayer.compute(layer, (l, oldValue) -> (oldValue == null ? 
0 : oldValue) + node.getNeighbors().size()); } @@ -621,6 +842,44 @@ public void onKeyValueRead(final int layer, @Nonnull final byte[] key, @Nullable } } + private static class TestOnWriteListener implements OnWriteListener { + @Nonnull + private final Map nodesWrittenByLayer; + @Nonnull + private final Map bytesWrittenByLayer; + + public TestOnWriteListener() { + this.nodesWrittenByLayer = Maps.newConcurrentMap(); + this.bytesWrittenByLayer = Maps.newConcurrentMap(); + } + + @Nonnull + public Map getNodesWrittenByLayer() { + return nodesWrittenByLayer; + } + + @Nonnull + public Map getBytesReadByLayer() { + return bytesWrittenByLayer; + } + + public void reset() { + nodesWrittenByLayer.clear(); + bytesWrittenByLayer.clear(); + } + + @Override + public void onNodeWritten(final int layer, @Nonnull final Node node) { + nodesWrittenByLayer.compute(layer, (l, oldValue) -> (oldValue == null ? 0 : oldValue) + 1L); + } + + @Override + public void onKeyValueWritten(final int layer, @Nonnull final byte[] key, @Nonnull final byte[] value) { + bytesWrittenByLayer.compute(layer, (l, oldValue) -> (oldValue == null ? 0 : oldValue) + + key.length + value.length); + } + } + private static class PrimaryKeyAndVector { @Nonnull private final Tuple primaryKey; @@ -658,4 +917,226 @@ public double getDistance() { return distance; } } + + private static class Stats { + private final long numRecords; + private final long elapsedTimeNs; + @Nonnull + private final Map nodesReadByLayerMap; + @Nonnull + private final Map bytesReadByLayerMap; + @Nonnull + private final Map nodesWrittenByLayerMap; + @Nonnull + private final Map bytesWrittenByLayerMap; + @Nonnull + @SuppressWarnings("checkstyle:MemberName") + private final Map mByLayerMap; + private final long numQueries; + private final long numResults; + private final long numRecall; + + public Stats(final long numRecords, final long elapsedTimeNs, + @Nonnull final Map nodesReadByLayerMap, + @Nonnull final Map bytesReadByLayerMap, + @Nonnull final Map nodesWrittenByLayerMap, + @Nonnull final Map bytesWrittenByLayerMap, + @Nonnull final Map mByLayerMap, + final long numQueries, + final long numResults, + final long numRecall) { + this.numRecords = numRecords; + this.elapsedTimeNs = elapsedTimeNs; + this.nodesReadByLayerMap = ImmutableMap.copyOf(nodesReadByLayerMap); + this.bytesReadByLayerMap = ImmutableMap.copyOf(bytesReadByLayerMap); + this.nodesWrittenByLayerMap = ImmutableMap.copyOf(nodesWrittenByLayerMap); + this.bytesWrittenByLayerMap = ImmutableMap.copyOf(bytesWrittenByLayerMap); + this.mByLayerMap = ImmutableMap.copyOf(mByLayerMap); + this.numQueries = numQueries; + this.numResults = numResults; + this.numRecall = numRecall; + } + + public long getNumRecords() { + return numRecords; + } + + public long getElapsedTimeNs() { + return elapsedTimeNs; + } + + @Nonnull + public Map getNodesReadByLayerMap() { + return nodesReadByLayerMap; + } + + @Nonnull + public Map getBytesReadByLayerMap() { + return bytesReadByLayerMap; + } + + @Nonnull + public Map getNodesWrittenByLayerMap() { + return nodesWrittenByLayerMap; + } + + @Nonnull + public Map getBytesWrittenByLayerMap() { + return bytesWrittenByLayerMap; + } + + @Nonnull + public Map getMByLayerMap() { + return mByLayerMap; + } + + public long getNumQueries() { + return numQueries; + } + + public long getNumResults() { + return numResults; + } + + public long getNumRecall() { + return numRecall; + } + + @Nonnull + public Stats add(@Nonnull final Stats other) { + return new Stats(getNumRecords() + 
other.getNumRecords(), + getElapsedTimeNs() + other.getElapsedTimeNs(), + aggregateMap(getNodesReadByLayerMap(), other.getNodesReadByLayerMap(), Long::sum), + aggregateMap(getBytesReadByLayerMap(), other.getBytesReadByLayerMap(), Long::sum), + aggregateMap(getNodesWrittenByLayerMap(), other.getNodesWrittenByLayerMap(), Long::sum), + aggregateMap(getBytesWrittenByLayerMap(), other.getBytesWrittenByLayerMap(), Long::sum), + aggregateMap(getMByLayerMap(), other.getMByLayerMap(), Long::sum), + getNumQueries() + other.getNumQueries(), + getNumResults() + other.getNumResults(), + getNumRecall() + other.getNumRecall()); + } + + @Nonnull + public Stats remove(@Nonnull final Stats other) { + return new Stats(getNumRecords() - other.getNumRecords(), + getElapsedTimeNs() - other.getElapsedTimeNs(), + aggregateMap(getNodesReadByLayerMap(), other.getNodesReadByLayerMap(), (l, r) -> l - r), + aggregateMap(getBytesReadByLayerMap(), other.getBytesReadByLayerMap(), (l, r) -> l - r), + aggregateMap(getNodesWrittenByLayerMap(), other.getNodesWrittenByLayerMap(), (l, r) -> l - r), + aggregateMap(getBytesWrittenByLayerMap(), other.getBytesWrittenByLayerMap(), (l, r) -> l - r), + aggregateMap(getMByLayerMap(), other.getMByLayerMap(), (l, r) -> l - r), + getNumQueries() - other.getNumQueries(), + getNumResults() - other.getNumResults(), + getNumRecall() - other.getNumRecall()); + } + + public void logInsertInfo(final long index) { + if (logger.isInfoEnabled()) { + logger.info(getInsertLogMessage(index)); + } + } + + public void logInsertTrace(final long index) { + if (logger.isTraceEnabled()) { + logger.trace(getInsertLogMessage(index)); + } + } + + @Nonnull + private String getInsertLogMessage(final long index) { + return String.format("inserted batchSize=%d records starting at nodeId=%d took elapsedTime=%dms, nodesRead=%s, bytesRead=%s", + getNumRecords(), index, TimeUnit.NANOSECONDS.toMillis(getElapsedTimeNs()), + getNodesReadByLayerMap(), getBytesReadByLayerMap()); + } + + public void logInsertAveragesInfo(final long index) { + if (logger.isInfoEnabled()) { + logger.info(getInsertAveragesLogMessage(index)); + } + } + + @Nonnull + private String getInsertAveragesLogMessage(final long index) { + //return String.format("after inserting %d records starting at nodeId=%d; elapsedTime=%dms, nodesRead=%d, bytesRead=%d, nodesWritten=%d, bytesWritten=%d, m=%d, nodesRead=%s, bytesRead=%s, nodesWritten=%s, bytesWritten=%s, m=%s", + return String.format("i %d,%d,%d,%d,%d,%d,%d,%d,%s,%s,%s,%s,%s", + getNumRecords(), index, + TimeUnit.NANOSECONDS.toMillis(getElapsedTimeNs() / getNumRecords()), + sumMap(getNodesReadByLayerMap()) / getNumRecords(), + sumMap(getBytesReadByLayerMap()) / getNumRecords(), + sumMap(getNodesWrittenByLayerMap()) / getNumRecords(), + sumMap(getBytesWrittenByLayerMap()) / getNumRecords(), + sumMap(getMByLayerMap()) / sumMap(getNodesReadByLayerMap()), + averageOfMap(getNodesReadByLayerMap(), getNumRecords()), + averageOfMap(getBytesReadByLayerMap(), getNumRecords()), + averageOfMap(getNodesWrittenByLayerMap(), getNumRecords()), + averageOfMap(getBytesWrittenByLayerMap(), getNumRecords()), + averageOfMap(getMByLayerMap(), getNodesReadByLayerMap())); + } + + public void logQueryAveragesInfo(final long index) { + if (logger.isInfoEnabled()) { + logger.info(getQueryAveragesLogMessage(index)); + } + } + + @Nonnull + private String getQueryAveragesLogMessage(final long index) { + //return String.format("querying, num=%d; averages after inserting %d records took elapsedTime=%dms, recall=%.2f, nodesRead=%d, 
bytesRead=%d, nodesRead=%s, bytesRead=%s", + return String.format("%d,%d,%d,%.2f,%d,%d,%s,%s", + getNumQueries(), + index, + TimeUnit.NANOSECONDS.toMillis(getElapsedTimeNs() / getNumQueries()), + (double)getNumRecall() * 100.0d / getNumResults(), + sumMap(getNodesReadByLayerMap()) / getNumQueries(), + sumMap(getBytesReadByLayerMap()) / getNumQueries(), + averageOfMap(getNodesReadByLayerMap(), getNumQueries()), + averageOfMap(getBytesReadByLayerMap(), getNumQueries())); + } + + @Nonnull + private static Map aggregateMap(@Nonnull final Map map1, + @Nonnull final Map map2, + @Nonnull final BinaryOperator operator) { + final ImmutableMap.Builder resultBuilder = ImmutableMap.builder(); + + for (final Map.Entry entry1 : map1.entrySet()) { + if (map2.containsKey(entry1.getKey())) { + resultBuilder.put(entry1.getKey(), operator.apply(entry1.getValue(), map2.get(entry1.getKey()))); + } else { + resultBuilder.put(entry1); + } + } + + for (final Map.Entry entry2 : map2.entrySet()) { + if (!map1.containsKey(entry2.getKey())) { + resultBuilder.put(entry2); + } + } + + return resultBuilder.build(); + } + + @Nonnull + private static Map averageOfMap(@Nonnull final Map map, final long numRecords) { + final ImmutableMap.Builder resultBuilder = ImmutableMap.builder(); + for (final Map.Entry entry : map.entrySet()) { + resultBuilder.put(entry.getKey(), entry.getValue() / numRecords); + } + return resultBuilder.build(); + } + + @Nonnull + private static Map averageOfMap(@Nonnull final Map dividentMap, + @Nonnull final Map divisorMap) { + final ImmutableMap.Builder resultBuilder = ImmutableMap.builder(); + for (final Map.Entry entry : dividentMap.entrySet()) { + resultBuilder.put(entry.getKey(), entry.getValue() / divisorMap.get(entry.getKey())); + } + return resultBuilder.build(); + } + + private static long sumMap(@Nonnull final Map map) { + return map.values().stream().mapToLong(v -> v).sum(); + } + } }
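
A note on the moving averages in testSIFTInsert1M: the test keeps a bounded Deque of per-batch Stats plus a running aggregate, evicting the oldest batch with Stats#remove before adding the newest with Stats#add, so the logged averages reflect roughly the last statsQueueMaxSize batches. A small stand-alone sketch of that sliding-window aggregation with a plain long in place of Stats (all names here are illustrative, not part of this patch):

    import java.util.ArrayDeque;
    import java.util.Deque;

    public final class SlidingWindowAverageSketch {
        public static void main(final String[] args) {
            final int windowSize = 10;            // plays the role of statsQueueMaxSize
            final Deque<Long> window = new ArrayDeque<>();
            long runningSum = 0L;                 // plays the role of runningInsertStats

            for (long batchElapsedMs = 1; batchElapsedMs <= 25; batchElapsedMs++) {
                while (window.size() >= windowSize) {
                    runningSum -= window.removeLast();     // Stats#remove on the evicted batch
                }
                window.addFirst(batchElapsedMs);
                runningSum += batchElapsedMs;              // Stats#add on the new batch

                System.out.printf("window=%d average=%.2f%n",
                        window.size(), (double) runningSum / window.size());
            }
        }
    }
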