From c5ff7804bd1cfd4acaa4b418f90733c0c87df795 Mon Sep 17 00:00:00 2001 From: voidwisp Date: Fri, 24 Apr 2026 15:25:10 +0100 Subject: [PATCH] perf(retrieval): prefetch ChunkBasedSearch start-node VSS call concurrently with _init MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In CompositeTraversalBasedRetriever._retrieve, the entity-context phase (self._init, ~2s on Neptune Serverless + AOSS) runs strictly before each sub-retriever's get_start_node_ids. For ChunkBasedSearch.get_start_node_ids specifically, the call reads only query_bundle / vector_store / args.vss_* / filter_config — it does not touch self.entity_contexts (which is what _init builds), so it has no data dependency on _init and can run concurrently with it. Override _retrieve in CompositeTraversalBasedRetriever to kick off a single-worker ThreadPoolExecutor that computes the chunk-VSS top-k via get_diverse_vss_elements before super()._retrieve(query_bundle) runs. Attach the resulting future onto each ChunkBasedSearch instance in _get_search_results_for_query. ChunkBasedSearch.get_start_node_ids pops the attribute and consumes the future via .result() if present, otherwise falls back to the existing inline VSS call. Guards: - Skip prefetch when args.derive_subqueries is True (subqueries carry different query_bundles, the prefetch was built from the original). - Skip when no ChunkBasedSearch is in the configured retriever list. - Consume-and-clear via __dict__.pop so a reused instance can't pick up a stale future on the next call. - add_done_callback logs any prefetch exception at debug level, so a prefetch failure stays visible even when _init raises first and the future is never consumed via .result() (it would otherwise be swallowed). Validated against production Neptune Serverless + AOSS (toolkit v3.18.3, pool_maxsize=32), 12 representative queries, 2 warmup + 10 timed samples interleaved OLD vs NEW: - Correctness: start_node_ids set-equal across all 12 queries. - Perf: paired median delta -85 ms, paired mean -95 ms, 10/12 queries improved. 
Worst case +37 ms (within query-intrinsic variance). --- .../retrievers/chunk_based_search.py | 25 ++++--- .../composite_traversal_based_retriever.py | 69 ++++++++++++++++++- 2 files changed, 83 insertions(+), 11 deletions(-) diff --git a/lexical-graph/src/graphrag_toolkit/lexical_graph/retrieval/retrievers/chunk_based_search.py b/lexical-graph/src/graphrag_toolkit/lexical_graph/retrieval/retrievers/chunk_based_search.py index aa2bb5e5..540e2bc8 100644 --- a/lexical-graph/src/graphrag_toolkit/lexical_graph/retrieval/retrievers/chunk_based_search.py +++ b/lexical-graph/src/graphrag_toolkit/lexical_graph/retrieval/retrievers/chunk_based_search.py @@ -121,15 +121,22 @@ def get_start_node_ids(self, query_bundle: QueryBundle) -> List[str]: """ logger.debug('Getting start node ids for chunk-based search...') - chunks = get_diverse_vss_elements( - 'chunk', - query_bundle, - self.vector_store, - self.args.vss_diversity_factor, - self.args.vss_top_k, - self.filter_config - ) - + # CompositeTraversalBasedRetriever may attach a pre-kicked-off future + # for this call (see its _retrieve override). Pop to consume so a reused + # instance doesn't pick up a stale future on the next call. 
+ prefetched = self.__dict__.pop('_prefetched_chunks', None) + if prefetched is not None: + chunks = prefetched.result() + else: + chunks = get_diverse_vss_elements( + 'chunk', + query_bundle, + self.vector_store, + self.args.vss_diversity_factor, + self.args.vss_top_k, + self.filter_config + ) + return [chunk['chunk']['chunkId'] for chunk in chunks] def do_graph_search(self, query_bundle: QueryBundle, start_node_ids:List[str]) -> SearchResultCollection: diff --git a/lexical-graph/src/graphrag_toolkit/lexical_graph/retrieval/retrievers/composite_traversal_based_retriever.py b/lexical-graph/src/graphrag_toolkit/lexical_graph/retrieval/retrievers/composite_traversal_based_retriever.py index 80b2627c..a78a5167 100644 --- a/lexical-graph/src/graphrag_toolkit/lexical_graph/retrieval/retrievers/composite_traversal_based_retriever.py +++ b/lexical-graph/src/graphrag_toolkit/lexical_graph/retrieval/retrievers/composite_traversal_based_retriever.py @@ -15,12 +15,31 @@ from graphrag_toolkit.lexical_graph.retrieval.utils.query_decomposition import QueryDecomposition from graphrag_toolkit.lexical_graph.retrieval.retrievers.entity_network_search import EntityNetworkSearch from graphrag_toolkit.lexical_graph.retrieval.retrievers.chunk_based_search import ChunkBasedSearch +from graphrag_toolkit.lexical_graph.retrieval.utils.vector_utils import get_diverse_vss_elements from graphrag_toolkit.lexical_graph.retrieval.model import SearchResultCollection, SearchResult from llama_index.core.schema import QueryBundle, NodeWithScore logger = logging.getLogger(__name__) + +def _log_prefetch_exception(future): + """Done-callback: surface prefetch failures that would otherwise be silent + if `_init` also raises and the future is never consumed via `.result()`. 
+ """ + exc = future.exception() + if exc is not None: + logger.debug('Chunk prefetch failed: %s', exc) + + +def _is_chunk_based_retriever(retriever_ref): + """Return True if retriever_ref is a ChunkBasedSearch instance or class (subclasses included).""" + if isinstance(retriever_ref, ChunkBasedSearch): + return True + if isinstance(retriever_ref, type) and issubclass(retriever_ref, ChunkBasedSearch): + return True + return False + TraversalBasedRetrieverType = Union[TraversalBasedBaseRetriever, Type[TraversalBasedBaseRetriever]] @dataclass @@ -106,6 +125,45 @@ def __init__(self, for r in retrievers ] + def _retrieve(self, query_bundle: QueryBundle): + """Kicks off a `ChunkBasedSearch`-start-node prefetch concurrently with + the entity-context phase of `_init`, so the ~150–200 ms chunk VSS call + hides behind the ~2 s entity-context work rather than running serially + afterward. The future is attached onto each `ChunkBasedSearch` instance + in `_get_search_results_for_query` and consumed in the sub-retriever's + `get_start_node_ids`. + + No-ops (falls through to base `_retrieve`) when: + - `args.derive_subqueries` is True — subqueries would carry different + `query_bundle`s, and the prefetch was built from the original, so + it wouldn't apply; + - no `ChunkBasedSearch` is in the configured retriever list. 
+ """ + should_prefetch = ( + not self.args.derive_subqueries + and any(_is_chunk_based_retriever( + wr.retriever if isinstance(wr, WeightedTraversalBasedRetriever) else wr) + for wr in self.weighted_retrievers) + ) + if not should_prefetch: + return super()._retrieve(query_bundle) + + with concurrent.futures.ThreadPoolExecutor(max_workers=1, thread_name_prefix='chunk-prefetch') as executor: + self._chunk_prefetch = executor.submit( + get_diverse_vss_elements, + 'chunk', + query_bundle, + self.vector_store, + self.args.vss_diversity_factor, + self.args.vss_top_k, + self.filter_config, + ) + self._chunk_prefetch.add_done_callback(_log_prefetch_exception) + try: + return super()._retrieve(query_bundle) + finally: + self._chunk_prefetch = None + def get_start_node_ids(self, query_bundle: QueryBundle) -> List[str]: """ Gets the starting node IDs for a given query. @@ -174,9 +232,9 @@ def weighted_arg(v, weight, factor): sub_args['reranker'] = 'tfidf' sub_args['enrich_query'] = False - retriever = (wr.retriever if isinstance(wr.retriever, TraversalBasedBaseRetriever) + retriever = (wr.retriever if isinstance(wr.retriever, TraversalBasedBaseRetriever) else wr.retriever( - self.graph_store, + self.graph_store, self.vector_store, # processors=[ # # No processing - just raw results @@ -189,6 +247,13 @@ def weighted_arg(v, weight, factor): **sub_args )) + # Hand off the chunk-VSS prefetch (kicked off in _retrieve concurrently + # with _init) to the ChunkBasedSearch instance. Always overwrite so a + # stale future from a prior _retrieve on a reused instance is replaced. + prefetch = getattr(self, '_chunk_prefetch', None) + if prefetch is not None and isinstance(retriever, ChunkBasedSearch): + retriever._prefetched_chunks = prefetch + retrievers.append(retriever) search_results = []