org/apache/lucene/util/hnsw/ConcurrentHnswMerger.java
@@ -16,19 +16,10 @@
*/
package org.apache.lucene.util.hnsw;

import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;

import java.io.IOException;
import java.util.Comparator;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.hnsw.HnswGraphProvider;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.internal.hppc.IntIntHashMap;
import org.apache.lucene.search.TaskExecutor;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.FixedBitSet;

/** This merger merges graphs concurrently, using {@link HnswConcurrentMergeBuilder} */
public class ConcurrentHnswMerger extends IncrementalHnswGraphMerger {
@@ -54,89 +45,13 @@ public ConcurrentHnswMerger(
@Override
protected HnswBuilder createBuilder(KnnVectorValues mergedVectorValues, int maxOrd)
throws IOException {
OnHeapHnswGraph graph;
BitSet initializedNodes = null;

if (graphReaders.size() == 0) {
graph = new OnHeapHnswGraph(M, maxOrd);
} else {
graphReaders.sort(Comparator.comparingInt(GraphReader::graphSize).reversed());
GraphReader initGraphReader = graphReaders.get(0);
KnnVectorsReader initReader = initGraphReader.reader();
MergeState.DocMap initDocMap = initGraphReader.initDocMap();
int initGraphSize = initGraphReader.graphSize();
HnswGraph initializerGraph = ((HnswGraphProvider) initReader).getGraph(fieldInfo.name);

if (initializerGraph.size() == 0) {
graph = new OnHeapHnswGraph(M, maxOrd);
} else {
initializedNodes = new FixedBitSet(maxOrd);
int[] oldToNewOrdinalMap =
getNewOrdMapping(
fieldInfo,
initReader,
initDocMap,
initGraphSize,
mergedVectorValues,
initializedNodes);
graph = InitializedHnswGraphBuilder.initGraph(initializerGraph, oldToNewOrdinalMap, maxOrd);
}
}
GraphMergeContext mergeContext = prepareGraphMerge(mergedVectorValues, maxOrd);
return new HnswConcurrentMergeBuilder(
taskExecutor, numWorker, scorerSupplier, beamWidth, graph, initializedNodes);
}

/**
* Creates a new mapping from old ordinals to new ordinals and returns the total number of vectors
* in the newly merged segment.
*
* @param mergedVectorValues vector values in the merged segment
* @param initializedNodes track what nodes have been initialized
* @return the mapping from old ordinals to new ordinals
* @throws IOException If an error occurs while reading from the merge state
*/
private static int[] getNewOrdMapping(
FieldInfo fieldInfo,
KnnVectorsReader initReader,
MergeState.DocMap initDocMap,
int initGraphSize,
KnnVectorValues mergedVectorValues,
BitSet initializedNodes)
throws IOException {
KnnVectorValues.DocIndexIterator initializerIterator = null;

switch (fieldInfo.getVectorEncoding()) {
case BYTE -> initializerIterator = initReader.getByteVectorValues(fieldInfo.name).iterator();
case FLOAT32 ->
initializerIterator = initReader.getFloatVectorValues(fieldInfo.name).iterator();
}

IntIntHashMap newIdToOldOrdinal = new IntIntHashMap(initGraphSize);
int maxNewDocID = -1;
for (int docId = initializerIterator.nextDoc();
docId != NO_MORE_DOCS;
docId = initializerIterator.nextDoc()) {
int newId = initDocMap.get(docId);
maxNewDocID = Math.max(newId, maxNewDocID);
assert newIdToOldOrdinal.containsKey(newId) == false;
newIdToOldOrdinal.put(newId, initializerIterator.index());
}

if (maxNewDocID == -1) {
return new int[0];
}
final int[] oldToNewOrdinalMap = new int[initGraphSize];
KnnVectorValues.DocIndexIterator mergedVectorIterator = mergedVectorValues.iterator();
for (int newDocId = mergedVectorIterator.nextDoc();
newDocId <= maxNewDocID;
newDocId = mergedVectorIterator.nextDoc()) {
int oldOrd = newIdToOldOrdinal.getOrDefault(newDocId, -1);
if (oldOrd != -1) {
int newOrd = mergedVectorIterator.index();
initializedNodes.set(newOrd);
oldToNewOrdinalMap[oldOrd] = newOrd;
}
}
return oldToNewOrdinalMap;
taskExecutor,
numWorker,
scorerSupplier,
beamWidth,
M,
mergeContext);
}
}
org/apache/lucene/util/hnsw/GraphMergeContext.java
@@ -0,0 +1,22 @@
package org.apache.lucene.util.hnsw;

import org.apache.lucene.util.BitSet;

/**
* A helper class that holds the context for merging graphs.
*
* @param initGraphs graphs that participate in initialization. For now, these are all graphs that do not
*     have any deletions. If there are no such graphs, this is null.
* @param oldToNewOrdinalMaps for each graph in {@code initGraphs}, the old-to-new ordinal mapping.
* @param maxOrd max ordinal of the new (to be created) graph
* @param initializedNodes all new ordinals that are covered by {@code initGraphs}; they might already have
*     been initialized, as part of the very first graph, or will be initialized in a later step, e.g. see
*     {@link UpdateGraphsUtils#joinSetGraphMerge(HnswGraph, HnswGraph, int[], HnswBuilder)}. Note: if
*     {@code initGraphs} is non-null but this field is null, all ordinals are/will be initialized.
*/
record GraphMergeContext(
    HnswGraph[] initGraphs, int[][] oldToNewOrdinalMaps, int maxOrd, BitSet initializedNodes) {

public boolean allInitialized() {
return initGraphs != null && initializedNodes == null;
}
}
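
For orientation, here is a minimal sketch of the three states this record can take, using only the accessors declared above (illustration only, not part of this change):

// Sketch: the three states documented in the javadoc above.
static String describe(GraphMergeContext ctx) {
  if (ctx.initGraphs() == null) {
    // No deletion-free graphs to start from: the merged graph is built from scratch.
    return "scratch build up to maxOrd=" + ctx.maxOrd();
  }
  if (ctx.allInitialized()) {
    // initGraphs is non-null and initializedNodes is null: every new ordinal
    // is covered by some graph in initGraphs.
    return "fully covered by " + ctx.initGraphs().length + " graphs";
  }
  // Mixed: ordinals set in initializedNodes come from initGraphs; the rest
  // still need regular one-by-one insertion.
  return "partially covered; remaining ordinals inserted individually";
}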
org/apache/lucene/util/hnsw/HnswConcurrentMergeBuilder.java
@@ -51,11 +51,13 @@ public HnswConcurrentMergeBuilder(
int numWorker,
RandomVectorScorerSupplier scorerSupplier,
int beamWidth,
OnHeapHnswGraph hnsw,
BitSet initializedNodes)
int M,
GraphMergeContext graphMergeContext)
throws IOException {
OnHeapHnswGraph hnsw = initGraph(M, graphMergeContext);
this.taskExecutor = taskExecutor;
AtomicInteger workProgress = new AtomicInteger(0);
// graphs[0] is the base graph already copied into hnsw by initGraph(), so workers start from 1
AtomicInteger nextGraphToMerge = new AtomicInteger(1);
workers = new ConcurrentMergeWorker[numWorker];
hnswLock = new HnswLock();
for (int i = 0; i < numWorker; i++) {
@@ -66,11 +68,20 @@
HnswGraphBuilder.randSeed,
hnsw,
hnswLock,
initializedNodes,
workProgress);
workProgress,
nextGraphToMerge,
graphMergeContext);
}
}

private static OnHeapHnswGraph initGraph(int M, GraphMergeContext graphMergeContext) throws IOException {
if (graphMergeContext.initGraphs() == null || graphMergeContext.initGraphs().length == 0) {
return new OnHeapHnswGraph(M, graphMergeContext.maxOrd());
}
return InitializedHnswGraphBuilder.initGraph(
    graphMergeContext.initGraphs()[0],
    graphMergeContext.oldToNewOrdinalMaps()[0],
    graphMergeContext.maxOrd());
}

@Override
public OnHeapHnswGraph build(int maxOrd) throws IOException {
if (frozen) {
@@ -146,6 +157,14 @@ private static final class ConcurrentMergeWorker extends HnswGraphBuilder {
*/
private final AtomicInteger workProgress;

/**
* A common AtomicInteger shared among all workers, tracking which graph to merge next. If initGraphs is
* null, this field is ignored.
*/
private final AtomicInteger nextGraphToMerge;

private final GraphMergeContext graphMergeContext;

private final BitSet initializedNodes;
private int batchSize = DEFAULT_BATCH_SIZE;

@@ -155,8 +174,10 @@ private ConcurrentMergeWorker(
long seed,
OnHeapHnswGraph hnsw,
HnswLock hnswLock,
BitSet initializedNodes,
AtomicInteger workProgress)
AtomicInteger workProgress,
AtomicInteger nextGraphToMerge,
GraphMergeContext graphMergeContext)
throws IOException {
super(
scorerSupplier,
@@ -167,7 +188,9 @@ private ConcurrentMergeWorker(
new MergeSearcher(
new NeighborQueue(beamWidth, true), hnswLock, new FixedBitSet(hnsw.maxNodeId() + 1)));
this.workProgress = workProgress;
this.initializedNodes = initializedNodes;
this.nextGraphToMerge = nextGraphToMerge;
this.graphMergeContext = graphMergeContext;
this.initializedNodes = graphMergeContext.initializedNodes();
}

/**
@@ -177,6 +200,18 @@ private ConcurrentMergeWorker(
* finishing around the same time.
*/
private void run(int maxOrd) throws IOException {
while (graphMergeContext.initGraphs() != null
    && nextGraphToMerge.get() < graphMergeContext.initGraphs().length) {
int graphToWork = nextGraphToMerge.getAndIncrement();
if (graphToWork >= graphMergeContext.initGraphs().length) {
break;
}
UpdateGraphsUtils.joinSetGraphMerge(
graphMergeContext.initGraphs()[graphToWork],
hnsw,
getGraphSearcher(),
graphMergeContext.oldToNewOrdinalMaps()[graphToWork],
this);
}
if (graphMergeContext.allInitialized()) {
// all the work has been done above since all the graphs are in the initGraphs set
return;
}
int start = getStartPos(maxOrd);
int end;
while (start != -1) {
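
The loop at the top of run() above is a small work-stealing pattern: each worker pre-checks whether graphs remain, claims an index with getAndIncrement(), and then re-checks the bound, because another worker may have claimed the last graph between the pre-check and the increment. A self-contained sketch of the same pattern (hypothetical names, not PR code):

import java.util.concurrent.atomic.AtomicInteger;

// Sketch of the atomic claim loop used by ConcurrentMergeWorker.run().
// The re-check after getAndIncrement() prevents an out-of-bounds claim
// when several workers race past the initial nextTask.get() test.
class ClaimLoopSketch {
  static void work(AtomicInteger nextTask, Runnable[] tasks) {
    while (nextTask.get() < tasks.length) { // cheap pre-check
      int claimed = nextTask.getAndIncrement(); // atomic claim
      if (claimed >= tasks.length) {
        break; // another worker claimed the last task first
      }
      tasks[claimed].run(); // this worker exclusively owns tasks[claimed]
    }
  }
}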
org/apache/lucene/util/hnsw/HnswGraphBuilder.java
@@ -191,6 +191,10 @@ public OnHeapHnswGraph getGraph() {
return hnsw;
}

public HnswGraphSearcher getGraphSearcher() {
return graphSearcher;
}

/** add vectors in range [minOrd, maxOrd) */
protected void addVectors(int minOrd, int maxOrd) throws IOException {
if (frozen) {
@@ -646,5 +650,6 @@ public TopDocs topDocs() {
public KnnSearchStrategy getSearchStrategy() {
return null;
}

}
}
org/apache/lucene/util/hnsw/IncrementalHnswGraphMerger.java
@@ -47,7 +47,7 @@ public class IncrementalHnswGraphMerger implements HnswGraphMerger {
protected final int M;
protected final int beamWidth;

protected List<GraphReader> graphReaders = new ArrayList<>();
protected List<GraphReader> initGraphReaders = new ArrayList<>();
private int numReaders = 0;

/** Represents a vector reader that contains graph info. */
@@ -98,7 +98,7 @@ public IncrementalHnswGraphMerger addReader(
candidateVectorCount = vectorValues.size();
}
}
graphReaders.add(new GraphReader(reader, docMap, candidateVectorCount));
initGraphReaders.add(new GraphReader(reader, docMap, candidateVectorCount));
return this;
}

@@ -112,57 +112,49 @@
*/
protected HnswBuilder createBuilder(KnnVectorValues mergedVectorValues, int maxOrd)
throws IOException {
if (graphReaders.size() == 0) {
GraphMergeContext mergeContext = prepareGraphMerge(mergedVectorValues, maxOrd);
if (mergeContext.initGraphs() == null || mergeContext.initGraphs().length == 0) {
return HnswGraphBuilder.create(
scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed, maxOrd);
}
graphReaders.sort(Comparator.comparingInt(GraphReader::graphSize).reversed());

final BitSet initializedNodes =
graphReaders.size() == numReaders ? null : new FixedBitSet(maxOrd);
int[][] ordMaps = getNewOrdMapping(mergedVectorValues, initializedNodes);
HnswGraph[] graphs = new HnswGraph[graphReaders.size()];
for (int i = 0; i < graphReaders.size(); i++) {
HnswGraph graph = ((HnswGraphProvider) graphReaders.get(i).reader).getGraph(fieldInfo.name);
if (graph.size() == 0) {
throw new IllegalStateException("Graph should not be empty");
}
graphs[i] = graph;
scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed, mergeContext.maxOrd());
}

return MergingHnswGraphBuilder.fromGraphs(
scorerSupplier,
beamWidth,
HnswGraphBuilder.randSeed,
graphs,
ordMaps,
maxOrd,
initializedNodes);
mergeContext.initGraphs(),
mergeContext.oldToNewOrdinalMaps(),
mergeContext.maxOrd(),
mergeContext.initializedNodes());
}

protected final int[][] getNewOrdMapping(
/**
* Get the old-to-new ordinal mapping for all graphs that participate in initialization,
* i.e. for all graphs that have been added to {@link #initGraphReaders}.
*/
protected final int[][] getOldToNewOrdMapping(
KnnVectorValues mergedVectorValues, BitSet initializedNodes) throws IOException {
final int numGraphs = graphReaders.size();
final int numGraphs = initGraphReaders.size();
IntIntHashMap[] newDocIdToOldOrdinals = new IntIntHashMap[numGraphs];
final int[][] oldToNewOrdinalMap = new int[numGraphs][];
for (int i = 0; i < numGraphs; i++) {
KnnVectorValues.DocIndexIterator vectorsIter = null;
switch (fieldInfo.getVectorEncoding()) {
case BYTE ->
vectorsIter = graphReaders.get(i).reader.getByteVectorValues(fieldInfo.name).iterator();
vectorsIter = initGraphReaders.get(i).reader.getByteVectorValues(fieldInfo.name).iterator();
case FLOAT32 ->
vectorsIter =
graphReaders.get(i).reader.getFloatVectorValues(fieldInfo.name).iterator();
initGraphReaders.get(i).reader.getFloatVectorValues(fieldInfo.name).iterator();
}
newDocIdToOldOrdinals[i] = new IntIntHashMap(graphReaders.get(i).graphSize);
MergeState.DocMap docMap = graphReaders.get(i).initDocMap();
newDocIdToOldOrdinals[i] = new IntIntHashMap(initGraphReaders.get(i).graphSize);
MergeState.DocMap docMap = initGraphReaders.get(i).initDocMap();
for (int docId = vectorsIter.nextDoc();
docId != NO_MORE_DOCS;
docId = vectorsIter.nextDoc()) {
int newDocId = docMap.get(docId);
newDocIdToOldOrdinals[i].put(newDocId, vectorsIter.index());
}
oldToNewOrdinalMap[i] = new int[graphReaders.get(i).graphSize];
oldToNewOrdinalMap[i] = new int[initGraphReaders.get(i).graphSize];
}

KnnVectorValues.DocIndexIterator mergedVectorIterator = mergedVectorValues.iterator();
@@ -184,6 +176,31 @@ protected final int[][] getNewOrdMapping(
return oldToNewOrdinalMap;
}
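
To make the two-pass remapping concrete, here is a self-contained toy version (plain java.util collections stand in for Lucene's internal IntIntHashMap; all doc ids and ordinals are hypothetical):

import java.util.HashMap;
import java.util.Map;

// Toy walk-through of the remapping: pass 1 maps merged doc ids back to old
// ordinals, pass 2 walks merged vectors in doc-id order, where the iteration
// index is the new ordinal.
class OrdRemapSketch {
  public static void main(String[] args) {
    // Suppose one init graph holds old ordinals {0, 1, 2} whose docs map to
    // merged doc ids {5, 2, 9} via the DocMap.
    int[] oldOrdToNewDocId = {5, 2, 9};
    Map<Integer, Integer> newDocIdToOldOrd = new HashMap<>();
    for (int oldOrd = 0; oldOrd < oldOrdToNewDocId.length; oldOrd++) {
      newDocIdToOldOrd.put(oldOrdToNewDocId[oldOrd], oldOrd);
    }
    // Merged docs that have vectors, in doc-id order: {2, 5, 7, 9}.
    int[] mergedDocIds = {2, 5, 7, 9};
    int[] oldToNewOrd = new int[oldOrdToNewDocId.length];
    for (int newOrd = 0; newOrd < mergedDocIds.length; newOrd++) {
      Integer oldOrd = newDocIdToOldOrd.get(mergedDocIds[newOrd]);
      if (oldOrd != null) {
        oldToNewOrd[oldOrd] = newOrd;
      }
    }
    // oldToNewOrd == {1, 0, 3}; doc 7 (new ordinal 2) has no counterpart in
    // the init graph and would be flagged via initializedNodes for regular
    // insertion.
    System.out.println(java.util.Arrays.toString(oldToNewOrd));
  }
}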

/**
* Prepares the context for merging graphs.
* Sorts {@link #initGraphReaders} by descending graph size so that the largest graph is used as the base
* graph, then packages everything needed into a {@link GraphMergeContext}.
*/
protected GraphMergeContext prepareGraphMerge(KnnVectorValues mergedVectorValues, int maxOrd)
    throws IOException {
if (initGraphReaders.size() == 0) {
return new GraphMergeContext(null, null, maxOrd, null);
}
initGraphReaders.sort(Comparator.comparingInt(GraphReader::graphSize).reversed());

final BitSet initializedNodes =
initGraphReaders.size() == numReaders ? null : new FixedBitSet(maxOrd);
int[][] ordMaps = getOldToNewOrdMapping(mergedVectorValues, initializedNodes);
HnswGraph[] graphs = new HnswGraph[initGraphReaders.size()];
for (int i = 0; i < initGraphReaders.size(); i++) {
HnswGraph graph = ((HnswGraphProvider) initGraphReaders.get(i).reader).getGraph(fieldInfo.name);
if (graph.size() == 0) {
throw new IllegalStateException("Graph should not be empty");
}
graphs[i] = graph;
}
return new GraphMergeContext(graphs, ordMaps, maxOrd, initializedNodes);
}
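
As a quick illustration of the ordering applied above (sketch only; the sizes are hypothetical): with deletion-free segment graphs of sizes 100, 400, and 250, the sort yields [400, 250, 100], so the 400-node graph becomes the base that is copied wholesale while the other two are joined in afterwards.

import java.util.Arrays;
import java.util.Comparator;

// Sketch of the descending-size sort in prepareGraphMerge(); not PR code.
class SortOrderSketch {
  record Reader(int graphSize) {}

  public static void main(String[] args) {
    Reader[] readers = {new Reader(100), new Reader(400), new Reader(250)};
    Arrays.sort(readers, Comparator.comparingInt(Reader::graphSize).reversed());
    // Prints: [Reader[graphSize=400], Reader[graphSize=250], Reader[graphSize=100]]
    System.out.println(Arrays.toString(readers));
  }
}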

@Override
public OnHeapHnswGraph merge(
KnnVectorValues mergedVectorValues, InfoStream infoStream, int maxOrd) throws IOException {
@@ -204,4 +221,5 @@ private static boolean hasDeletes(Bits liveDocs) {
}
return false;
}

}