diff --git a/lexical-graph/src/graphrag_toolkit/lexical_graph/retrieval/retrievers/chunk_based_search.py b/lexical-graph/src/graphrag_toolkit/lexical_graph/retrieval/retrievers/chunk_based_search.py index aa2bb5e5..540e2bc8 100644 --- a/lexical-graph/src/graphrag_toolkit/lexical_graph/retrieval/retrievers/chunk_based_search.py +++ b/lexical-graph/src/graphrag_toolkit/lexical_graph/retrieval/retrievers/chunk_based_search.py @@ -121,15 +121,22 @@ def get_start_node_ids(self, query_bundle: QueryBundle) -> List[str]: """ logger.debug('Getting start node ids for chunk-based search...') - chunks = get_diverse_vss_elements( - 'chunk', - query_bundle, - self.vector_store, - self.args.vss_diversity_factor, - self.args.vss_top_k, - self.filter_config - ) - + # CompositeTraversalBasedRetriever may attach a pre-kicked-off future + # for this call (see its _retrieve override). Pop to consume so a reused + # instance doesn't pick up a stale future on the next call. + prefetched = self.__dict__.pop('_prefetched_chunks', None) + if prefetched is not None: + chunks = prefetched.result() + else: + chunks = get_diverse_vss_elements( + 'chunk', + query_bundle, + self.vector_store, + self.args.vss_diversity_factor, + self.args.vss_top_k, + self.filter_config + ) + return [chunk['chunk']['chunkId'] for chunk in chunks] def do_graph_search(self, query_bundle: QueryBundle, start_node_ids:List[str]) -> SearchResultCollection: diff --git a/lexical-graph/src/graphrag_toolkit/lexical_graph/retrieval/retrievers/composite_traversal_based_retriever.py b/lexical-graph/src/graphrag_toolkit/lexical_graph/retrieval/retrievers/composite_traversal_based_retriever.py index 80b2627c..a78a5167 100644 --- a/lexical-graph/src/graphrag_toolkit/lexical_graph/retrieval/retrievers/composite_traversal_based_retriever.py +++ b/lexical-graph/src/graphrag_toolkit/lexical_graph/retrieval/retrievers/composite_traversal_based_retriever.py @@ -15,12 +15,31 @@ from 
graphrag_toolkit.lexical_graph.retrieval.utils.query_decomposition import QueryDecomposition from graphrag_toolkit.lexical_graph.retrieval.retrievers.entity_network_search import EntityNetworkSearch from graphrag_toolkit.lexical_graph.retrieval.retrievers.chunk_based_search import ChunkBasedSearch +from graphrag_toolkit.lexical_graph.retrieval.utils.vector_utils import get_diverse_vss_elements from graphrag_toolkit.lexical_graph.retrieval.model import SearchResultCollection, SearchResult from llama_index.core.schema import QueryBundle, NodeWithScore logger = logging.getLogger(__name__) + +def _log_prefetch_exception(future): + """Done-callback: surface prefetch failures that would otherwise be silent + if `_init` also raises and the future is never consumed via `.result()`. + """ + exc = future.exception() + if exc is not None: + logger.debug('Chunk prefetch failed: %s', exc) + + +def _is_chunk_based_retriever(retriever_ref): + """Return True if retriever_ref is a ChunkBasedSearch instance or a subclass of it.""" + if isinstance(retriever_ref, ChunkBasedSearch): + return True + if isinstance(retriever_ref, type) and issubclass(retriever_ref, ChunkBasedSearch): + return True + return False + TraversalBasedRetrieverType = Union[TraversalBasedBaseRetriever, Type[TraversalBasedBaseRetriever]] @dataclass @@ -106,6 +125,45 @@ def __init__(self, for r in retrievers ] + def _retrieve(self, query_bundle: QueryBundle): + """Kicks off a `ChunkBasedSearch`-start-node prefetch concurrently with + the entity-context phase of `_init`, so the ~150–200 ms chunk VSS call + hides behind the ~2 s entity-context work rather than running serially + afterward. The future is attached onto each `ChunkBasedSearch` instance + in `_get_search_results_for_query` and consumed in the sub-retriever's + `get_start_node_ids`. 
+ + No-ops (falls through to base `_retrieve`) when: + - `args.derive_subqueries` is True — subqueries would carry different + `query_bundle`s, and the prefetch was built from the original, so + it wouldn't apply; + - no `ChunkBasedSearch` is in the configured retriever list. + """ + should_prefetch = ( + not self.args.derive_subqueries + and any(_is_chunk_based_retriever( + wr.retriever if isinstance(wr, WeightedTraversalBasedRetriever) else wr) + for wr in self.weighted_retrievers) + ) + if not should_prefetch: + return super()._retrieve(query_bundle) + + with concurrent.futures.ThreadPoolExecutor(max_workers=1, thread_name_prefix='chunk-prefetch') as executor: + self._chunk_prefetch = executor.submit( + get_diverse_vss_elements, + 'chunk', + query_bundle, + self.vector_store, + self.args.vss_diversity_factor, + self.args.vss_top_k, + self.filter_config, + ) + self._chunk_prefetch.add_done_callback(_log_prefetch_exception) + try: + return super()._retrieve(query_bundle) + finally: + self._chunk_prefetch = None + def get_start_node_ids(self, query_bundle: QueryBundle) -> List[str]: """ Gets the starting node IDs for a given query. @@ -174,9 +232,9 @@ def weighted_arg(v, weight, factor): sub_args['reranker'] = 'tfidf' sub_args['enrich_query'] = False - retriever = (wr.retriever if isinstance(wr.retriever, TraversalBasedBaseRetriever) + retriever = (wr.retriever if isinstance(wr.retriever, TraversalBasedBaseRetriever) else wr.retriever( - self.graph_store, + self.graph_store, self.vector_store, # processors=[ # # No processing - just raw results @@ -189,6 +247,13 @@ **sub_args )) + # Hand off the chunk-VSS prefetch (kicked off in _retrieve concurrently + # with _init) to the ChunkBasedSearch instance. Set only when present; + # staleness is avoided because get_start_node_ids pops the future on use. 
+ prefetch = getattr(self, '_chunk_prefetch', None) + if prefetch is not None and isinstance(retriever, ChunkBasedSearch): + retriever._prefetched_chunks = prefetch + retrievers.append(retriever) search_results = []