Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
def get_start_node_ids(self, query_bundle: QueryBundle) -> List[str]:
    """Return the chunk ids used as graph-traversal start nodes.

    Prefers a prefetched chunk result when one has been attached by
    `CompositeTraversalBasedRetriever` (see its `_retrieve` override);
    otherwise performs the chunk VSS lookup directly.

    Args:
        query_bundle: The query to search for. Unused on the prefetched
            path, since the prefetch was built from the same query.

    Returns:
        A list of chunk id strings, one per retrieved chunk.
    """
    logger.debug('Getting start node ids for chunk-based search...')

    # CompositeTraversalBasedRetriever may attach a pre-kicked-off future
    # for this call (see its _retrieve override). Pop to consume so a reused
    # instance doesn't pick up a stale future on the next call.
    prefetched = self.__dict__.pop('_prefetched_chunks', None)
    if prefetched is not None:
        # Blocks until the prefetch completes; re-raises any exception
        # the prefetched VSS call raised.
        chunks = prefetched.result()
    else:
        chunks = get_diverse_vss_elements(
            'chunk',
            query_bundle,
            self.vector_store,
            self.args.vss_diversity_factor,
            self.args.vss_top_k,
            self.filter_config
        )

    return [chunk['chunk']['chunkId'] for chunk in chunks]

def do_graph_search(self, query_bundle: QueryBundle, start_node_ids:List[str]) -> SearchResultCollection:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,31 @@
from graphrag_toolkit.lexical_graph.retrieval.utils.query_decomposition import QueryDecomposition
from graphrag_toolkit.lexical_graph.retrieval.retrievers.entity_network_search import EntityNetworkSearch
from graphrag_toolkit.lexical_graph.retrieval.retrievers.chunk_based_search import ChunkBasedSearch
from graphrag_toolkit.lexical_graph.retrieval.utils.vector_utils import get_diverse_vss_elements
from graphrag_toolkit.lexical_graph.retrieval.model import SearchResultCollection, SearchResult

from llama_index.core.schema import QueryBundle, NodeWithScore

logger = logging.getLogger(__name__)


def _log_prefetch_exception(future):
"""Done-callback: surface prefetch failures that would otherwise be silent
if `_init` also raises and the future is never consumed via `.result()`.
"""
exc = future.exception()
if exc is not None:
logger.debug('Chunk prefetch failed: %s', exc)


def _is_chunk_based_retriever(retriever_ref):
    """Return True if `retriever_ref` is a `ChunkBasedSearch` instance, or a
    class object that is `ChunkBasedSearch` or a subclass of it.
    """
    if isinstance(retriever_ref, type):
        return issubclass(retriever_ref, ChunkBasedSearch)
    return isinstance(retriever_ref, ChunkBasedSearch)

TraversalBasedRetrieverType = Union[TraversalBasedBaseRetriever, Type[TraversalBasedBaseRetriever]]

@dataclass
Expand Down Expand Up @@ -106,6 +125,45 @@ def __init__(self,
for r in retrievers
]

def _retrieve(self, query_bundle: QueryBundle):
    """Kicks off a `ChunkBasedSearch`-start-node prefetch concurrently with
    the entity-context phase of `_init`, so the ~150–200 ms chunk VSS call
    hides behind the ~2 s entity-context work rather than running serially
    afterward. The future is attached onto each `ChunkBasedSearch` instance
    in `_get_search_results_for_query` and consumed in the sub-retriever's
    `get_start_node_ids`.

    No-ops (falls through to base `_retrieve`) when:
    - `args.derive_subqueries` is True — subqueries would carry different
      `query_bundle`s, and the prefetch was built from the original, so
      it wouldn't apply;
    - no `ChunkBasedSearch` is in the configured retriever list.
    """
    # Entries in weighted_retrievers may be WeightedTraversalBasedRetriever
    # wrappers or bare retriever refs (instance or class), hence the unwrap
    # before the ChunkBasedSearch check.
    should_prefetch = (
        not self.args.derive_subqueries
        and any(_is_chunk_based_retriever(
            wr.retriever if isinstance(wr, WeightedTraversalBasedRetriever) else wr)
            for wr in self.weighted_retrievers)
    )
    if not should_prefetch:
        return super()._retrieve(query_bundle)

    # Single worker: there is exactly one VSS call to hide. Exiting the
    # `with` block joins the worker, so the future cannot outlive this call.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1, thread_name_prefix='chunk-prefetch') as executor:
        self._chunk_prefetch = executor.submit(
            get_diverse_vss_elements,
            'chunk',
            query_bundle,
            self.vector_store,
            self.args.vss_diversity_factor,
            self.args.vss_top_k,
            self.filter_config,
        )
        # If the future is never consumed via .result() (e.g. _init raises),
        # the done-callback logs the failure instead of dropping it silently.
        self._chunk_prefetch.add_done_callback(_log_prefetch_exception)
        try:
            return super()._retrieve(query_bundle)
        finally:
            # Clear the hand-off attribute even on failure so a reused
            # instance never sees a stale future on its next call.
            # NOTE(review): the instance-attribute hand-off assumes _retrieve
            # is not called concurrently on one instance — confirm callers
            # are single-threaded per retriever.
            self._chunk_prefetch = None

def get_start_node_ids(self, query_bundle: QueryBundle) -> List[str]:
"""
Gets the starting node IDs for a given query.
Expand Down Expand Up @@ -174,9 +232,9 @@ def weighted_arg(v, weight, factor):
sub_args['reranker'] = 'tfidf'
sub_args['enrich_query'] = False

retriever = (wr.retriever if isinstance(wr.retriever, TraversalBasedBaseRetriever)
retriever = (wr.retriever if isinstance(wr.retriever, TraversalBasedBaseRetriever)
else wr.retriever(
self.graph_store,
self.graph_store,
self.vector_store,
# processors=[
# # No processing - just raw results
Expand All @@ -189,6 +247,13 @@ def weighted_arg(v, weight, factor):
**sub_args
))

# Hand off the chunk-VSS prefetch (kicked off in _retrieve concurrently
# with _init) to the ChunkBasedSearch instance. Always overwrite so a
# stale future from a prior _retrieve on a reused instance is replaced.
prefetch = getattr(self, '_chunk_prefetch', None)
if prefetch is not None and isinstance(retriever, ChunkBasedSearch):
retriever._prefetched_chunks = prefetch

retrievers.append(retriever)

search_results = []
Expand Down