Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ fastembed
watchdog
onnxruntime
tokenizers
orjson
tree_sitter>=0.25.2
tree_sitter_languages; python_version < "3.13"
mcp==1.17.0
Expand Down
129 changes: 105 additions & 24 deletions scripts/hybrid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,64 @@
import os
import argparse
from typing import List, Dict, Any, Tuple
from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache

from qdrant_client import QdrantClient, models
from fastembed import TextEmbedding
import re
import json
import math

# JSON serialization helper: use orjson when it is installed (faster),
# otherwise fall back to the standard-library json module.
try:
    import orjson
except ImportError:
    orjson = None

if orjson is not None:
    def _json_dumps(obj):
        """Serialize ``obj`` to a JSON string using orjson."""
        # orjson.dumps returns bytes; decode so callers always get str.
        encoded = orjson.dumps(obj)
        return encoded.decode()
else:
    def _json_dumps(obj):
        """Serialize ``obj`` to a JSON string using the stdlib json module."""
        return json.dumps(obj)

import logging
import threading

# Qdrant connection pooling: use the project's pooled client manager when it
# can be imported, otherwise define minimal non-pooled stand-ins.
# NOTE(review): no stand-in is defined for pooled_qdrant_client - confirm
# nothing references it when the import fails.
try:
    from scripts.qdrant_client_manager import (
        get_qdrant_client,
        return_qdrant_client,
        pooled_qdrant_client,
    )
except ImportError:
    _POOL_AVAILABLE = False

    def get_qdrant_client(url=None, api_key=None, force_new=False, use_pool=True):
        """Create a plain QdrantClient; ``force_new`` and ``use_pool`` are ignored here."""
        endpoint = url or os.environ.get("QDRANT_URL", "http://localhost:6333")
        key = api_key or os.environ.get("QDRANT_API_KEY")
        return QdrantClient(url=endpoint, api_key=key)

    def return_qdrant_client(client):
        """No-op: without a pool there is nothing to hand the client back to."""
        return None
else:
    _POOL_AVAILABLE = True

# Shared ThreadPoolExecutor for parallel queries, created lazily and reused
# across calls.
_QUERY_EXECUTOR = None
_EXECUTOR_LOCK = threading.Lock()


def _get_query_executor(max_workers: int = 4) -> ThreadPoolExecutor:
    """Return the shared query executor, creating it on first use.

    ``max_workers`` only takes effect on the call that actually creates the
    executor; later calls return the existing instance unchanged.
    """
    global _QUERY_EXECUTOR
    # Fast path: already created, no need to take the lock.
    if _QUERY_EXECUTOR is not None:
        return _QUERY_EXECUTOR
    with _EXECUTOR_LOCK:
        # Re-check under the lock so only one thread creates the executor.
        if _QUERY_EXECUTOR is None:
            _QUERY_EXECUTOR = ThreadPoolExecutor(max_workers=max_workers)
    return _QUERY_EXECUTOR

# Cache of _sanitize_filter_obj results. Keys are id(flt) of the filter object
# that was passed in; values are the sanitized result (the same object or None).
# NOTE(review): id() values can be reused once an object is garbage-collected,
# so an id-keyed entry may be returned for a different filter object - confirm
# this is acceptable.
_FILTER_CACHE = {}
# Guards every read and write of _FILTER_CACHE.
_FILTER_CACHE_LOCK = threading.Lock()
# Size cap: once the cache holds this many entries, new results are simply not
# stored (nothing is evicted).
_FILTER_CACHE_MAX = 256

# Memoized regex compilation keyed on (pattern, flags); holds at most 128 entries.
@lru_cache(maxsize=128)
def _compile_regex(pattern: str, flags: int = 0):
    """Return the compiled regex for ``pattern`` and ``flags``, reusing a cached one if present."""
    compiled = re.compile(pattern, flags)
    return compiled

# Import unified caching system
try:
from scripts.cache_manager import get_search_cache, get_embedding_cache, get_expansion_cache
Expand Down Expand Up @@ -1272,10 +1320,18 @@ def lex_hash_vector(phrases: List[str], dim: int = LEX_VECTOR_DIM) -> List[float

# Defensive: sanitize Qdrant filter objects so we never send an empty filter {}
# Qdrant returns 400 if filter has no conditions; return None in that case.
# Uses caching for repeated filter patterns to avoid redundant validation.
def _sanitize_filter_obj(flt):
if flt is None:
return None

# Try cache first (hash by id for object identity)
cache_key = id(flt)
with _FILTER_CACHE_LOCK:
if cache_key in _FILTER_CACHE:
return _FILTER_CACHE[cache_key]

try:
if flt is None:
return None
# Try model-style attributes first
must = getattr(flt, "must", None)
should = getattr(flt, "should", None)
Expand All @@ -1286,17 +1342,24 @@ def _sanitize_filter_obj(flt):
m = [c for c in (flt.get("must") or []) if c is not None]
s = [c for c in (flt.get("should") or []) if c is not None]
mn = [c for c in (flt.get("must_not") or []) if c is not None]
return None if (not m and not s and not mn) else flt
# Unknown structure -> drop
return None
m = [c for c in (must or []) if c is not None]
s = [c for c in (should or []) if c is not None]
mn = [c for c in (must_not or []) if c is not None]
if not m and not s and not mn:
return None
return flt
result = None if (not m and not s and not mn) else flt
else:
# Unknown structure -> drop
result = None
else:
m = [c for c in (must or []) if c is not None]
s = [c for c in (should or []) if c is not None]
mn = [c for c in (must_not or []) if c is not None]
result = None if (not m and not s and not mn) else flt
except Exception:
return None
result = None

# Cache result (with size limit)
with _FILTER_CACHE_LOCK:
if len(_FILTER_CACHE) < _FILTER_CACHE_MAX:
_FILTER_CACHE[cache_key] = result

return result


def lex_query(client: QdrantClient, v: List[float], flt, per_query: int, collection_name: str | None = None) -> List[Any]:
Expand Down Expand Up @@ -1984,18 +2047,36 @@ def _scaled_rrf(rank: int) -> float:

flt_gated = _sanitize_filter_obj(flt_gated)

result_sets: List[List[Any]] = [
dense_query(
client,
vec_name,
v,
flt_gated,
_scaled_per_query,
collection,
query_text=queries[i] if i < len(queries) else None,
)
for i, v in enumerate(embedded)
]
# Parallel dense query execution for multiple queries
if len(embedded) > 1 and os.environ.get("PARALLEL_DENSE_QUERIES", "1") == "1":
executor = _get_query_executor()
futures = [
executor.submit(
dense_query,
client,
vec_name,
v,
flt_gated,
_scaled_per_query,
collection,
queries[i] if i < len(queries) else None,
)
for i, v in enumerate(embedded)
]
result_sets: List[List[Any]] = [f.result() for f in futures]
else:
result_sets: List[List[Any]] = [
dense_query(
client,
vec_name,
v,
flt_gated,
_scaled_per_query,
collection,
query_text=queries[i] if i < len(queries) else None,
)
for i, v in enumerate(embedded)
]
if os.environ.get("DEBUG_HYBRID_SEARCH"):
total_dense_results = sum(len(rs) for rs in result_sets)
logger.debug(f"Dense query returned {total_dense_results} total results across {len(result_sets)} queries")
Expand Down
8 changes: 4 additions & 4 deletions scripts/ingest_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def _detect_repo_name_from_path(path: Path) -> str:
import ast
import time
from pathlib import Path
from typing import List, Dict, Iterable
from typing import List, Dict, Iterable, Any, Optional

try:
from tqdm import tqdm
Expand Down Expand Up @@ -2782,7 +2782,7 @@ def index_repo(
batch_meta: list[dict] = []
batch_ids: list[int] = []
batch_lex: list[list[float]] = []
BATCH_SIZE = int(os.environ.get("INDEX_BATCH_SIZE", "64") or 64)
BATCH_SIZE = int(os.environ.get("INDEX_BATCH_SIZE", "256") or 256)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@m1rl0k any reason you raised defaults?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was a performance optimization - larger batches = fewer Qdrant upsert roundtrips = faster indexing. Trade-off is higher memory per batch... we can leave at default though

CHUNK_LINES = int(os.environ.get("INDEX_CHUNK_LINES", "120") or 120)
CHUNK_OVERLAP = int(os.environ.get("INDEX_CHUNK_OVERLAP", "20") or 20)
PROGRESS_EVERY = int(os.environ.get("INDEX_PROGRESS_EVERY", "200") or 200)
Expand Down Expand Up @@ -3286,8 +3286,8 @@ def make_point(pid, dense_vec, lex_vec, payload):
status={
"state": "indexing",
"progress": {
"files_processed": repo_progress.get(per_file_repo, 0),
"total_files": repo_total.get(per_file_repo, None),
"files_processed": files_indexed,
"total_files": files_seen,
"current_file": str(file_path),
},
},
Expand Down
19 changes: 17 additions & 2 deletions scripts/mcp_indexer_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
import subprocess
import threading
import time
from typing import Any, Dict, Optional, List
from typing import Any, Dict, Optional, List, Tuple

from pathlib import Path
import sys
Expand Down Expand Up @@ -629,6 +629,20 @@ def _maybe_parse_jsonish(obj: _Any):
import urllib.parse as _urlparse, ast as _ast


def _invalidate_router_scratchpad(workspace_path: str) -> bool:
    """Placeholder hook for invalidating a workspace's cached router scratchpad.

    Meant to run after indexing operations so the router picks up new or
    changed code. It is currently a stub: no cache is touched and
    ``workspace_path`` is not used.

    Args:
        workspace_path: Workspace whose router scratchpad should be invalidated.

    Returns:
        Always True, so callers that check the result keep working.
    """
    # The earlier try/except wrapped only ``return True``, so its except branch
    # could never run; real invalidation logic can be added here later.
    return True


def _parse_kv_string(s: str) -> _Dict[str, _Any]:
"""Parse non-JSON strings like "a=1&b=2" or "query=[\"a\",\"b\"]" into a dict.
Values are JSON-decoded when possible; else literal-eval; else kept as raw strings.
Expand All @@ -653,7 +667,7 @@ def _parse_kv_string(s: str) -> _Dict[str, _Any]:
out[k.strip()] = _coerce_value_string(v.strip())
return out
except Exception as e:
logger.debug(f"Failed to parse KV string '{input_str}': {e}")
logger.debug(f"Failed to parse KV string '{s}': {e}")
return {}
return out

Expand Down Expand Up @@ -4776,6 +4790,7 @@ async def expand_query(query: Any = None, max_new: Any = None) -> Dict[str, Any]
"Return JSON array of strings only. No explanations.\n"
f"Queries: {qlist}\n"
)
client = LlamaCppRefragClient()
out = client.generate_with_soft_embeddings(
prompt=prompt,
max_tokens=int(os.environ.get("EXPAND_MAX_TOKENS", "64") or 64),
Expand Down
29 changes: 29 additions & 0 deletions scripts/rerank_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
_RERANK_LOCK = threading.Lock()


_WARMUP_DONE = False

def _get_rerank_session():
global _RERANK_SESSION, _RERANK_TOKENIZER
if not (ort and Tokenizer and RERANKER_ONNX_PATH and RERANKER_TOKENIZER_PATH):
Expand Down Expand Up @@ -101,6 +103,33 @@ def _get_rerank_session():
from scripts.utils import sanitize_vector_name as _sanitize_vector_name


def warmup_reranker():
    """Load the reranker session and run one dummy inference to warm it up.

    Returns immediately if a previous call already completed. The dummy
    inference is best-effort: if it fails, the error is logged at debug level
    and not raised. Errors from ``_get_rerank_session`` itself still propagate.
    """
    global _WARMUP_DONE
    if _WARMUP_DONE:
        return
    sess, tok = _get_rerank_session()
    if sess and tok:
        try:
            # Dummy inference to warm up the session
            rerank_local([("warmup query", "warmup document")])
        except Exception:
            # Keep warmup non-fatal, but leave a trace instead of discarding
            # the failure silently.
            import logging

            logging.getLogger(__name__).debug(
                "Reranker warmup inference failed", exc_info=True
            )
    # Mark as done even when no session is available so later calls return early.
    _WARMUP_DONE = True


def _start_background_warmup():
    """Launch warmup_reranker on a daemon thread unless RERANK_WARMUP is set to a value other than "1"."""
    if os.environ.get("RERANK_WARMUP", "1") != "1":
        return
    warmup_thread = threading.Thread(target=warmup_reranker, daemon=True)
    warmup_thread.start()


# Auto-start warmup on module import: importing this module calls
# _start_background_warmup(), which spawns the daemon warmup thread unless
# RERANK_WARMUP is set to a value other than "1".
_start_background_warmup()


def _norm_under(u: str | None) -> str | None:
if not u:
return None
Expand Down
Loading
Loading