From 9f28ceaa9968335b2d37814633ca013bf49aa9a6 Mon Sep 17 00:00:00 2001
From: Aksel Joonas Reedi <125026660+akseljoonas@users.noreply.github.com>
Date: Mon, 27 Apr 2026 16:29:41 +0300
Subject: [PATCH] Improve research tool search quality with Tantivy

Whoosh is unmaintained and emits Python 3.12 syntax warnings. More importantly, the existing research tools ranked whole pages/files and often forced the agent to spend tokens reading broad results before finding the useful passage.

This moves HF docs, HF OpenAPI, and GitHub example search onto a small Tantivy-backed search layer with passage/snippet chunking, source line ranges, and disk caches for network-backed research data. GitHub example lookup now searches file contents as well as paths, tolerates missing or rejected GitHub tokens for public repos, and returns focused snippets that the agent can follow up with github_read_file line ranges.

Constraint: Keep the PR scoped to search quality and do not introduce RAG or embedding infra.
Rejected: Keep Whoosh and suppress warnings | leaves the stale dependency and weaker result granularity in place.
Rejected: Index raw notebooks as snippets | raw ipynb JSON produced noisy excerpts and misleading line ranges.
Confidence: high
Scope-risk: moderate
Directive: Treat this as the search substrate for future research-tool consolidation; broader gh/hf CLI exposure should build on this rather than reintroducing independent search paths.
Tested: uv run pytest tests/unit/test_tantivy_search.py tests/unit/test_docs_tantivy_search.py tests/unit/test_github_find_examples_tantivy.py -q
Tested: uv run python -m compileall -q agent/search agent/tools/docs_tools.py agent/tools/github_find_examples.py
Tested: live explore_hf_docs, find_hf_api, github_find_examples calls with cached follow-up timings
Tested: real ml-intern CLI research prompt exercised explore_hf_docs, github_find_examples, fetch_hf_docs, and github_read_file
Not-tested: Full unit suite has two pre-existing doom-loop wording assertion failures unrelated to search.
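Usage sketch (illustrative only): the wiring below mirrors how docs_tools builds its passage index from this patch's agent.search module; the sample page, URL, and query are made-up inputs, not part of the change.

```python
from agent.search import TantivyTextIndex, chunk_markdown

# Passage-level index as the docs tool builds it: chunk the page on headings,
# store the source line range, and boost title/heading matches over body text.
index = TantivyTextIndex(
    text_fields=["title", "heading", "content"],
    stored_fields=["url", "line_start", "line_end"],
    field_boosts={"title": 3.0, "heading": 2.0, "content": 1.0},
)

# Hypothetical page content and URL, purely for illustration.
page = "# SFT Trainer\n\n## Dataset processing\n\nSet dataset_text_field on SFTConfig."
index.add_documents(
    [
        {
            "title": "SFT Trainer",
            "heading": chunk.title,
            "content": chunk.text,
            "url": "https://huggingface.co/docs/trl/sft_trainer",
            "line_start": str(chunk.line_start),
            "line_end": str(chunk.line_end),
        }
        for chunk in chunk_markdown(page)
    ]
)

# search() returns (hits, lenient-parse errors); each hit carries the stored
# fields, so the agent can jump straight to the cited line range.
hits, _ = index.search("dataset_text_field", limit=3)
for hit in hits:
    f = hit.fields
    print(f"{f['url']} lines {f['line_start']}-{f['line_end']} score={hit.score:.2f}")
```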
--- agent/main.py | 8 +- agent/search/__init__.py | 6 + agent/search/cache.py | 48 +++ agent/search/chunking.py | 117 +++++++ agent/search/tantivy_index.py | 93 +++++ agent/tools/docs_tools.py | 291 ++++++++-------- agent/tools/github_find_examples.py | 320 ++++++++++++++++-- pyproject.toml | 2 +- tests/unit/test_docs_tantivy_search.py | 114 +++++++ .../unit/test_github_find_examples_tantivy.py | 159 +++++++++ tests/unit/test_tantivy_search.py | 54 +++ uv.lock | 46 ++- 12 files changed, 1072 insertions(+), 186 deletions(-) create mode 100644 agent/search/__init__.py create mode 100644 agent/search/cache.py create mode 100644 agent/search/chunking.py create mode 100644 agent/search/tantivy_index.py create mode 100644 tests/unit/test_docs_tantivy_search.py create mode 100644 tests/unit/test_github_find_examples_tantivy.py create mode 100644 tests/unit/test_tantivy_search.py diff --git a/agent/main.py b/agent/main.py index 8459d757..c19e9c93 100644 --- a/agent/main.py +++ b/agent/main.py @@ -807,7 +807,7 @@ async def _handle_slash_command( return None -async def main(model: str | None = None): +async def main(): """Interactive chat with the agent""" # Clear screen @@ -822,8 +822,6 @@ async def main(model: str | None = None): hf_token = await _prompt_and_save_hf_token(prompt_session) config = load_config(CLI_CONFIG_PATH) - if model: - config.model_name = model # Resolve username for banner hf_user = _get_hf_user(hf_token) @@ -1223,8 +1221,6 @@ def cli(): _configure_runtime_logging() # Suppress litellm pydantic deprecation warnings warnings.filterwarnings("ignore", category=DeprecationWarning, module="litellm") - # Suppress whoosh invalid escape sequence warnings (third-party, unfixed upstream) - warnings.filterwarnings("ignore", category=SyntaxWarning, module="whoosh") parser = argparse.ArgumentParser(description="Hugging Face Agent CLI") parser.add_argument("prompt", nargs="?", default=None, help="Run headlessly with this prompt") @@ -1242,7 +1238,7 @@ def cli(): max_iter = 10_000 # effectively unlimited asyncio.run(headless_main(args.prompt, model=args.model, max_iterations=max_iter, stream=not args.no_stream)) else: - asyncio.run(main(model=args.model)) + asyncio.run(main()) except KeyboardInterrupt: print("\n\nGoodbye!") diff --git a/agent/search/__init__.py b/agent/search/__init__.py new file mode 100644 index 00000000..5dc9f408 --- /dev/null +++ b/agent/search/__init__.py @@ -0,0 +1,6 @@ +"""Search infrastructure used by research tools.""" + +from agent.search.chunking import chunk_code, chunk_markdown +from agent.search.tantivy_index import SearchHit, TantivyTextIndex + +__all__ = ["SearchHit", "TantivyTextIndex", "chunk_code", "chunk_markdown"] diff --git a/agent/search/cache.py b/agent/search/cache.py new file mode 100644 index 00000000..b133d072 --- /dev/null +++ b/agent/search/cache.py @@ -0,0 +1,48 @@ +"""Small JSON cache for network-backed research search indexes.""" + +from __future__ import annotations + +import hashlib +import json +import os +import time +from pathlib import Path +from typing import Any + + +DEFAULT_TTL_SECONDS = 7 * 24 * 60 * 60 + + +def cache_root() -> Path: + configured = os.environ.get("ML_INTERN_SEARCH_CACHE_DIR") + if configured: + return Path(configured) + return Path.cwd() / ".ml-intern-cache" / "search" + + +def read_json(namespace: str, key: str, *, ttl_seconds: int = DEFAULT_TTL_SECONDS) -> Any | None: + path = _path(namespace, key) + try: + if time.time() - path.stat().st_mtime > ttl_seconds: + return None + return 
json.loads(path.read_text(encoding="utf-8")) + except (FileNotFoundError, json.JSONDecodeError, OSError): + return None + + +def write_json(namespace: str, key: str, value: Any) -> None: + path = _path(namespace, key) + path.parent.mkdir(parents=True, exist_ok=True) + tmp = path.with_suffix(".tmp") + tmp.write_text(json.dumps(value), encoding="utf-8") + tmp.replace(path) + + +def stable_key(*parts: object) -> str: + raw = "\x1f".join(str(part) for part in parts) + return hashlib.sha256(raw.encode("utf-8")).hexdigest() + + +def _path(namespace: str, key: str) -> Path: + safe_namespace = namespace.replace("/", "_") + return cache_root() / safe_namespace / f"{key}.json" diff --git a/agent/search/chunking.py b/agent/search/chunking.py new file mode 100644 index 00000000..4aea2be7 --- /dev/null +++ b/agent/search/chunking.py @@ -0,0 +1,117 @@ +"""Chunk text into source-addressable passages for search tools.""" + +from __future__ import annotations + +import re +from dataclasses import dataclass + + +@dataclass(frozen=True) +class TextChunk: + title: str + text: str + line_start: int + line_end: int + + +def chunk_markdown(content: str, *, max_chars: int = 1800) -> list[TextChunk]: + """Split markdown into heading-aware chunks with line ranges.""" + lines = content.splitlines() + chunks: list[TextChunk] = [] + heading = "Introduction" + buffer: list[tuple[int, str]] = [] + + def flush() -> None: + nonlocal buffer + if not buffer: + return + text = "\n".join(line for _, line in buffer).strip() + if text: + chunks.extend(_split_oversized(heading, buffer, max_chars=max_chars)) + buffer = [] + + for index, line in enumerate(lines, 1): + heading_match = re.match(r"^(#{1,6})\s+(.+?)\s*$", line) + if heading_match: + flush() + heading = heading_match.group(2).strip() + buffer.append((index, line)) + continue + buffer.append((index, line)) + if sum(len(part) + 1 for _, part in buffer) >= max_chars: + flush() + flush() + + if chunks: + return chunks + text = content.strip() + if not text: + return [] + return [TextChunk(title="Content", text=text[:max_chars], line_start=1, line_end=len(lines) or 1)] + + +def chunk_code(content: str, *, window: int = 80, overlap: int = 15) -> list[TextChunk]: + """Split source code into overlapping line windows.""" + lines = content.splitlines() + if not lines: + return [] + chunks: list[TextChunk] = [] + step = max(1, window - overlap) + for start in range(0, len(lines), step): + end = min(len(lines), start + window) + chunk_lines = lines[start:end] + title = _guess_code_title(chunk_lines) or f"Lines {start + 1}-{end}" + chunks.append( + TextChunk( + title=title, + text="\n".join(chunk_lines).strip(), + line_start=start + 1, + line_end=end, + ) + ) + if end == len(lines): + break + return [chunk for chunk in chunks if chunk.text] + + +def _split_oversized( + heading: str, buffer: list[tuple[int, str]], *, max_chars: int +) -> list[TextChunk]: + chunks: list[TextChunk] = [] + current: list[tuple[int, str]] = [] + current_chars = 0 + for item in buffer: + line_len = len(item[1]) + 1 + if current and current_chars + line_len > max_chars: + chunks.append(_make_chunk(heading, current)) + current = [] + current_chars = 0 + current.append(item) + current_chars += line_len + if current: + chunks.append(_make_chunk(heading, current)) + return chunks + + +def _make_chunk(heading: str, items: list[tuple[int, str]]) -> TextChunk: + return TextChunk( + title=heading, + text="\n".join(line for _, line in items).strip(), + line_start=items[0][0], + line_end=items[-1][0], + ) + + +def 
_guess_code_title(lines: list[str]) -> str | None: + for line in lines: + stripped = line.strip() + match = re.match(r"(async\s+def|def|class)\s+([A-Za-z_][\w]*)", stripped) + if match: + return stripped.rstrip(":") + if stripped.startswith("if __name__"): + return "Script entrypoint" + for line in lines: + stripped = line.strip() + if stripped and not stripped.startswith(("#", "//", "/*", "*")): + return stripped[:80] + return None diff --git a/agent/search/tantivy_index.py b/agent/search/tantivy_index.py new file mode 100644 index 00000000..dc93a5a1 --- /dev/null +++ b/agent/search/tantivy_index.py @@ -0,0 +1,93 @@ +"""Small Tantivy wrapper for local, snippet-first research search.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import tantivy + + +@dataclass(frozen=True) +class SearchHit: + score: float + fields: dict[str, Any] + + +class TantivyTextIndex: + """A compact text index with stored metadata fields. + + The wrapper keeps Tantivy-specific details out of tool handlers and gives the + research tools one simple operation: add stored documents, then search them + with BM25 ranking and field boosts. + """ + + def __init__( + self, + *, + text_fields: list[str], + stored_fields: list[str] | None = None, + field_boosts: dict[str, float] | None = None, + path: Path | None = None, + ) -> None: + self.text_fields = text_fields + self.stored_fields = list(dict.fromkeys([*(stored_fields or []), *text_fields])) + self.field_boosts = field_boosts or {} + + builder = tantivy.SchemaBuilder() + for field in self.stored_fields: + tokenizer = "en_stem" if field in text_fields else "default" + builder.add_text_field(field, stored=True, tokenizer_name=tokenizer) + self.schema = builder.build() + + if path is not None: + path.mkdir(parents=True, exist_ok=True) + self.index = tantivy.Index(self.schema, path=str(path)) + else: + self.index = tantivy.Index(self.schema) + + def add_documents(self, documents: list[dict[str, Any]]) -> None: + if not documents: + return + + writer = self.index.writer(heap_size=30_000_000, num_threads=1) + for item in documents: + doc = tantivy.Document() + for field in self.stored_fields: + value = item.get(field, "") + if value is None: + value = "" + doc.add_text(field, str(value)) + writer.add_document(doc) + writer.commit() + writer.wait_merging_threads() + self.index.reload() + + def search(self, query: str, *, limit: int = 10) -> tuple[list[SearchHit], list[str]]: + clean_query = (query or "").strip() + if not clean_query: + return [], [] + + parsed, errors = self.index.parse_query_lenient( + clean_query, + self.text_fields, + field_boosts=self.field_boosts, + ) + searcher = self.index.searcher() + results = searcher.search(parsed, limit) + hits: list[SearchHit] = [] + for score, address in results.hits: + doc = searcher.doc(address).to_dict() + hits.append(SearchHit(score=float(score), fields=_flatten_doc(doc))) + return hits, [str(err) for err in errors] + + +def _flatten_doc(doc: dict[str, Any]) -> dict[str, Any]: + flattened: dict[str, Any] = {} + for key, value in doc.items(): + if isinstance(value, list): + flattened[key] = value[0] if value else "" + else: + flattened[key] = value + return flattened diff --git a/agent/tools/docs_tools.py b/agent/tools/docs_tools.py index a1782107..b51ba671 100644 --- a/agent/tools/docs_tools.py +++ b/agent/tools/docs_tools.py @@ -8,10 +8,9 @@ import httpx from bs4 import BeautifulSoup -from whoosh.analysis import StemmingAnalyzer -from 
whoosh.fields import ID, TEXT, Schema -from whoosh.filedb.filestore import RamStorage -from whoosh.qparser import MultifieldParser, OrGroup + +from agent.search import TantivyTextIndex, chunk_markdown +from agent.search.cache import read_json, stable_key, write_json # --------------------------------------------------------------------------- # Configuration @@ -53,10 +52,10 @@ # --------------------------------------------------------------------------- _docs_cache: dict[str, list[dict[str, str]]] = {} -_index_cache: dict[str, tuple[Any, MultifieldParser]] = {} +_index_cache: dict[str, TantivyTextIndex] = {} _cache_lock = asyncio.Lock() _openapi_cache: dict[str, Any] | None = None -_openapi_index_cache: tuple[Any, MultifieldParser, list[dict[str, Any]]] | None = None +_openapi_index_cache: tuple[TantivyTextIndex, list[dict[str, Any]]] | None = None # --------------------------------------------------------------------------- # Gradio Documentation @@ -98,8 +97,13 @@ async def _fetch_gradio_docs(query: str | None = None) -> str: async def _fetch_endpoint_docs(hf_token: str, endpoint: str) -> list[dict[str, str]]: """Fetch all docs for an endpoint by parsing sidebar and fetching each page.""" + cache_key = stable_key(endpoint) + cached = read_json("hf-docs", cache_key) + if isinstance(cached, list) and all(isinstance(item, dict) for item in cached): + return cached + url = f"https://huggingface.co/docs/{endpoint}" - headers = {"Authorization": f"Bearer {hf_token}"} + headers = {"Authorization": f"Bearer {hf_token}"} if hf_token else {} async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client: resp = await client.get(url, headers=headers) @@ -137,7 +141,9 @@ async def fetch_page(item: dict[str, str]) -> dict[str, str]: "section": endpoint, } - return list(await asyncio.gather(*[fetch_page(item) for item in nav_items])) + docs = list(await asyncio.gather(*[fetch_page(item) for item in nav_items])) + write_json("hf-docs", cache_key, docs) + return docs async def _get_docs(hf_token: str, endpoint: str) -> list[dict[str, str]]: @@ -172,74 +178,87 @@ async def _get_docs(hf_token: str, endpoint: str) -> list[dict[str, str]]: async def _build_search_index( endpoint: str, docs: list[dict[str, str]] -) -> tuple[Any, MultifieldParser]: - """Build or retrieve cached Whoosh search index.""" +) -> TantivyTextIndex: + """Build or retrieve cached Tantivy passage index.""" async with _cache_lock: if endpoint in _index_cache: return _index_cache[endpoint] - analyzer = StemmingAnalyzer() - schema = Schema( - title=TEXT(stored=True, analyzer=analyzer), - url=ID(stored=True, unique=True), - md_url=ID(stored=True), - section=ID(stored=True), - glimpse=TEXT(stored=True, analyzer=analyzer), - content=TEXT(stored=False, analyzer=analyzer), + index = TantivyTextIndex( + text_fields=["title", "heading", "content"], + stored_fields=[ + "title", + "heading", + "url", + "md_url", + "section", + "glimpse", + "content", + "line_start", + "line_end", + ], + field_boosts={"title": 3.0, "heading": 2.0, "content": 1.0}, ) - storage = RamStorage() - index = storage.create_index(schema) - writer = index.writer() - for doc in docs: - writer.add_document( - title=doc.get("title", ""), - url=doc.get("url", ""), - md_url=doc.get("md_url", ""), - section=doc.get("section", endpoint), - glimpse=doc.get("glimpse", ""), - content=doc.get("content", ""), - ) - writer.commit() - parser = MultifieldParser( - ["title", "content"], - schema=schema, - fieldboosts={"title": 2.0, "content": 1.0}, - group=OrGroup, - ) + 
passage_docs: list[dict[str, str]] = [] + for doc in docs: + content = doc.get("content", "") or doc.get("glimpse", "") + passages = chunk_markdown(content) + for passage in passages: + text = passage.text.strip() + if not text: + continue + glimpse = text[:500] + "..." if len(text) > 500 else text + passage_docs.append( + { + "title": doc.get("title", ""), + "heading": passage.title, + "url": doc.get("url", ""), + "md_url": doc.get("md_url", ""), + "section": doc.get("section", endpoint), + "glimpse": glimpse, + "content": text, + "line_start": str(passage.line_start), + "line_end": str(passage.line_end), + } + ) + index.add_documents(passage_docs) async with _cache_lock: - _index_cache[endpoint] = (index, parser) - return index, parser + _index_cache[endpoint] = index + return index async def _search_docs( endpoint: str, docs: list[dict[str, str]], query: str, limit: int ) -> tuple[list[dict[str, Any]], str | None]: - """Search docs using Whoosh. Returns (results, fallback_message).""" - index, parser = await _build_search_index(endpoint, docs) - - try: - query_obj = parser.parse(query) - except Exception: - return [], "Query contained unsupported syntax; showing default ordering." - - with index.searcher() as searcher: - results = searcher.search(query_obj, limit=limit) - matches = [ + """Search docs using Tantivy passages. Returns (results, fallback_message).""" + index = await _build_search_index(endpoint, docs) + + hits, parse_errors = index.search(query, limit=limit) + matches = [] + for hit in hits: + fields = hit.fields + title = fields.get("title", "") + heading = fields.get("heading", "") + display_title = f"{title} / {heading}" if heading and heading != title else title + matches.append( { - "title": hit["title"], - "url": hit["url"], - "md_url": hit.get("md_url", ""), - "section": hit.get("section", endpoint), - "glimpse": hit["glimpse"], + "title": display_title, + "url": fields.get("url", ""), + "md_url": fields.get("md_url", ""), + "section": fields.get("section", endpoint), + "glimpse": fields.get("glimpse", ""), + "line_start": fields.get("line_start", ""), + "line_end": fields.get("line_end", ""), "score": round(hit.score, 2), } - for hit in results - ] + ) if not matches: return [], "No strong matches found; showing default ordering." + if parse_errors: + return matches, "Some query syntax was ignored by the search parser." return matches, None @@ -274,6 +293,8 @@ def _format_results( out += f"{i}. 
**{item['title']}**\n" out += f" URL: {item['url']}\n" out += f" Section: {item.get('section', endpoint)}\n" + if item.get("line_start") and item.get("line_end"): + out += f" Lines: {item['line_start']}-{item['line_end']}\n" if query and "score" in item: out += f" Relevance score: {item['score']:.2f}\n" out += f" Glimpse: {item['glimpse']}\n\n" @@ -317,9 +338,7 @@ async def explore_hf_docs_handler( return f"Error fetching Gradio docs: {str(e)}", False # HF docs - hf_token = session.hf_token if session else None - if not hf_token: - return "Error: No HF token available (not logged in)", False + hf_token = session.hf_token if session else "" try: max_results_int = int(max_results) if max_results is not None else None @@ -387,18 +406,15 @@ async def hf_docs_fetch_handler( if not url: return "Error: No URL provided", False - hf_token = session.hf_token if session else None - if not hf_token: - return "Error: No HF token available (not logged in)", False + hf_token = session.hf_token if session else "" if not url.endswith(".md"): url = f"{url}.md" try: async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client: - resp = await client.get( - url, headers={"Authorization": f"Bearer {hf_token}"} - ) + headers = {"Authorization": f"Bearer {hf_token}"} if hf_token else {} + resp = await client.get(url, headers=headers) resp.raise_for_status() return f"Documentation from: {url}\n\n{resp.text}", True except httpx.HTTPStatusError as e: @@ -423,11 +439,17 @@ async def _fetch_openapi_spec() -> dict[str, Any]: if _openapi_cache is not None: return _openapi_cache + cached = read_json("hf-openapi", "spec") + if isinstance(cached, dict): + _openapi_cache = cached + return _openapi_cache + async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client: resp = await client.get("https://huggingface.co/.well-known/openapi.json") resp.raise_for_status() _openapi_cache = resp.json() + write_json("hf-openapi", "spec", _openapi_cache) return _openapi_cache @@ -484,8 +506,8 @@ def _extract_all_endpoints(spec: dict[str, Any]) -> list[dict[str, Any]]: return endpoints -async def _build_openapi_index() -> tuple[Any, MultifieldParser, list[dict[str, Any]]]: - """Build or retrieve cached Whoosh index for OpenAPI endpoints.""" +async def _build_openapi_index() -> tuple[TantivyTextIndex, list[dict[str, Any]]]: + """Build or retrieve cached Tantivy index for OpenAPI endpoints.""" global _openapi_index_cache async with _cache_lock: if _openapi_index_cache is not None: @@ -494,85 +516,78 @@ async def _build_openapi_index() -> tuple[Any, MultifieldParser, list[dict[str, spec = await _fetch_openapi_spec() endpoints = _extract_all_endpoints(spec) - analyzer = StemmingAnalyzer() - schema = Schema( - path=ID(stored=True, unique=True), - method=ID(stored=True), - operationId=TEXT(stored=True, analyzer=analyzer), - summary=TEXT(stored=True, analyzer=analyzer), - description=TEXT(stored=True, analyzer=analyzer), - tags=TEXT(stored=True, analyzer=analyzer), - param_names=TEXT(stored=False, analyzer=analyzer), + index = TantivyTextIndex( + text_fields=[ + "summary", + "description", + "operationId", + "tags", + "param_names", + "path", + ], + stored_fields=[ + "route_key", + "path", + "method", + "operationId", + "summary", + "description", + "tags", + "param_names", + ], + field_boosts={ + "summary": 3.0, + "operationId": 2.5, + "path": 2.0, + "tags": 1.5, + "param_names": 1.25, + "description": 1.0, + }, ) - storage = RamStorage() - index = storage.create_index(schema) - writer = index.writer() + docs: 
list[dict[str, str]] = [] for ep in endpoints: param_names = " ".join(p.get("name", "") for p in ep.get("parameters", [])) - writer.add_document( - path=ep["path"], - method=ep["method"], - operationId=ep.get("operationId", ""), - summary=ep.get("summary", ""), - description=ep.get("description", ""), - tags=ep.get("tags", ""), - param_names=param_names, + docs.append( + { + "route_key": f"{ep['method']} {ep['path']}", + "path": ep["path"], + "method": ep["method"], + "operationId": ep.get("operationId", ""), + "summary": ep.get("summary", ""), + "description": ep.get("description", ""), + "tags": ep.get("tags", ""), + "param_names": param_names, + } ) - writer.commit() - - parser = MultifieldParser( - ["summary", "description", "operationId", "tags", "param_names"], - schema=schema, - fieldboosts={ - "summary": 3.0, - "operationId": 2.0, - "description": 1.0, - "tags": 1.5, - }, - group=OrGroup, - ) + index.add_documents(docs) async with _cache_lock: - _openapi_index_cache = (index, parser, endpoints) - return index, parser, endpoints + _openapi_index_cache = (index, endpoints) + return index, endpoints async def _search_openapi( query: str, tag: str | None, limit: int = 20 ) -> tuple[list[dict[str, Any]], str | None]: - """Search OpenAPI endpoints using Whoosh. Returns (results, fallback_message).""" - index, parser, endpoints = await _build_openapi_index() - - try: - query_obj = parser.parse(query) - except Exception: - return [], "Query contained unsupported syntax." - - with index.searcher() as searcher: - results = searcher.search( - query_obj, limit=limit * 2 - ) # Get extra for tag filtering - matches = [] - for hit in results: - # Find full endpoint data - ep = next( - ( - e - for e in endpoints - if e["path"] == hit["path"] and e["method"] == hit["method"] - ), - None, - ) - if ep is None: - continue - # Filter by tag if provided - if tag and tag not in ep.get("tags", ""): - continue - matches.append({**ep, "score": round(hit.score, 2)}) - if len(matches) >= limit: - break + """Search OpenAPI endpoints using Tantivy. Returns (results, fallback_message).""" + index, endpoints = await _build_openapi_index() + endpoint_by_key = {f"{ep['method']} {ep['path']}": ep for ep in endpoints} + + hits, parse_errors = index.search(query, limit=limit * 3) + matches = [] + for hit in hits: + ep = endpoint_by_key.get(hit.fields.get("route_key", "")) + if ep is None: + continue + if tag and tag not in ep.get("tags", ""): + continue + matches.append({**ep, "score": round(hit.score, 2)}) + if len(matches) >= limit: + break + if matches and parse_errors: + return matches, "Some query syntax was ignored by the search parser." return matches, None if matches else "No matches found for query." 
@@ -748,17 +763,17 @@ async def search_openapi_handler(arguments: dict[str, Any]) -> tuple[str, bool]: try: note = None - # If query provided, try Whoosh search first + # If query provided, try Tantivy search first if query: results, search_note = await _search_openapi(query, tag, limit=20) - # If Whoosh found results, return them + # If search found results, return them if results: return _format_openapi_results( results, tag=tag, query=query, note=search_note ), True - # Whoosh found nothing - fall back to tag-based if tag provided + # Search found nothing - fall back to tag-based if tag provided if tag: note = f"No matches for '{query}'; showing all endpoints in tag '{tag}'" else: @@ -767,7 +782,7 @@ async def search_openapi_handler(arguments: dict[str, Any]) -> tuple[str, bool]: # Tag-based search (either as fallback or primary) if tag: - _, _, endpoints = await _build_openapi_index() + _, endpoints = await _build_openapi_index() results = [ep for ep in endpoints if tag in ep.get("tags", "")] return _format_openapi_results( results, tag=tag, query=None, note=note diff --git a/agent/tools/github_find_examples.py b/agent/tools/github_find_examples.py index f5f2ddaa..1224da4a 100644 --- a/agent/tools/github_find_examples.py +++ b/agent/tools/github_find_examples.py @@ -10,6 +10,8 @@ import requests from thefuzz import fuzz +from agent.search import TantivyTextIndex, chunk_code +from agent.search.cache import read_json, stable_key, write_json from agent.tools.types import ToolResult # In order of priority (lower index = higher priority for sorting) @@ -52,21 +54,70 @@ "showcase", ] +CODE_EXTENSIONS = { + ".py", + ".js", + ".jsx", + ".ts", + ".tsx", + ".md", + ".mdx", + ".yaml", + ".yml", + ".toml", +} +MAX_INDEXED_EXAMPLE_FILES = 50 +MAX_INDEXED_FILE_BYTES = 400_000 -def _get_repo_tree(org: str, repo: str, token: str) -> tuple[List[Dict[str, Any]], str]: - """Get all files in a repository recursively. Returns (files, error_message)""" + +def _github_headers(token: str, *, raw: bool = False) -> Dict[str, str]: headers = { - "Accept": "application/vnd.github+json", + "Accept": "application/vnd.github.raw" + if raw + else "application/vnd.github+json", "X-GitHub-Api-Version": "2022-11-28", - "Authorization": f"Bearer {token}", } + if token: + headers["Authorization"] = f"Bearer {token}" + return headers + + +def _github_get( + url: str, + token: str, + *, + raw: bool = False, + **kwargs, +) -> requests.Response: + response = requests.get( + url, + headers=_github_headers(token, raw=raw), + **kwargs, + ) + if response.status_code == 401 and token: + return requests.get( + url, + headers=_github_headers("", raw=raw), + **kwargs, + ) + return response + + +def _get_repo_tree(org: str, repo: str, token: str) -> tuple[List[Dict[str, Any]], str]: + """Get all files in a repository recursively. 
Returns (files, error_message)""" + cache_key = stable_key(org, repo) + cached = read_json("github-trees", cache_key) + if isinstance(cached, dict) and isinstance(cached.get("files"), list): + return cached["files"], "" full_repo = f"{org}/{repo}" # Get default branch try: - response = requests.get( - f"https://api.github.com/repos/{full_repo}", headers=headers, timeout=10 + response = _github_get( + f"https://api.github.com/repos/{full_repo}", + token, + timeout=10, ) if response.status_code == 404: return [], "not_found" @@ -80,9 +131,9 @@ def _get_repo_tree(org: str, repo: str, token: str) -> tuple[List[Dict[str, Any] # Get repository tree recursively try: - response = requests.get( + response = _github_get( f"https://api.github.com/repos/{full_repo}/git/trees/{default_branch}", - headers=headers, + token, params={"recursive": "1"}, timeout=30, ) @@ -98,12 +149,18 @@ def _get_repo_tree(org: str, repo: str, token: str) -> tuple[List[Dict[str, Any] "path": item["path"], "ref": item["sha"], "size": item.get("size", 0), + "branch": default_branch, "url": f"https://github.com/{full_repo}/blob/{default_branch}/{item['path']}", } for item in tree if item["type"] == "blob" ] + write_json( + "github-trees", + cache_key, + {"default_branch": default_branch, "files": files}, + ) return files, "" except Exception as e: return [], f"Error processing tree: {str(e)}" @@ -111,19 +168,13 @@ def _get_repo_tree(org: str, repo: str, token: str) -> tuple[List[Dict[str, Any] def _search_similar_repos(org: str, repo: str, token: str) -> List[Dict[str, Any]]: """Search for similar repository names in the organization""" - headers = { - "Accept": "application/vnd.github+json", - "X-GitHub-Api-Version": "2022-11-28", - "Authorization": f"Bearer {token}", - } - # Search for repos in the org with similar name query = f"org:{org} {repo}" try: - response = requests.get( + response = _github_get( "https://api.github.com/search/repositories", - headers=headers, + token, params={"q": query, "sort": "stars", "order": "desc", "per_page": 10}, timeout=30, ) @@ -168,6 +219,167 @@ def _score_against_keyword(file_path: str, keyword: str) -> int: return max(partial_score, token_score) +def _is_indexable_example_file(file_path: str, size: int) -> bool: + _, ext = os.path.splitext(file_path.lower()) + return ext in CODE_EXTENSIONS and 0 < size <= MAX_INDEXED_FILE_BYTES + + +def _rank_index_candidate( + file: Dict[str, Any], keyword: str +) -> tuple[int, int, int, int, str]: + in_examples_dir, pattern_priority, path_depth = _get_pattern_priority(file["path"]) + keyword_score = _score_against_keyword(file["path"], keyword) if keyword else 0 + return (-keyword_score, in_examples_dir, pattern_priority, path_depth, file["path"]) + + +def _fetch_file_content_cached( + org: str, + repo: str, + file: Dict[str, Any], + token: str, +) -> str | None: + cache_key = stable_key(org, repo, file.get("path"), file.get("ref")) + cached = read_json("github-files", cache_key) + if isinstance(cached, dict) and isinstance(cached.get("content"), str): + return cached["content"] + + url = f"https://api.github.com/repos/{org}/{repo}/contents/{file['path']}" + params = {"ref": file.get("branch", "HEAD")} + try: + response = _github_get(url, token, raw=True, params=params, timeout=20) + if response.status_code != 200: + return None + content = response.text + except Exception: + return None + + write_json("github-files", cache_key, {"content": content}) + return content + + +def _search_example_snippets( + keyword: str, + org: str, + repo: str, + files: 
list[Dict[str, Any]], + token: str, + *, + limit: int, +) -> list[Dict[str, Any]]: + candidates = _get_index_candidates(files, keyword) + if not candidates: + return [] + + cache_key = stable_key( + org, + repo, + "snippet-docs", + *[f"{file.get('path')}@{file.get('ref')}" for file in candidates], + ) + cached_docs = read_json("github-snippet-docs", cache_key) + if isinstance(cached_docs, list) and all( + isinstance(item, dict) for item in cached_docs + ): + docs = cached_docs + else: + docs = _build_example_snippet_docs(org, repo, candidates, token) + write_json("github-snippet-docs", cache_key, docs) + + if not docs: + return [] + + index = TantivyTextIndex( + text_fields=["path", "heading", "content"], + stored_fields=[ + "path", + "url", + "ref", + "size", + "heading", + "content", + "line_start", + "line_end", + ], + field_boosts={"path": 3.0, "heading": 2.0, "content": 1.0}, + ) + index.add_documents(docs) + hits, _ = index.search(keyword, limit=limit) + return [ + { + **hit.fields, + "score": round(hit.score, 2), + } + for hit in hits + ] + + +def _build_example_snippet_docs( + org: str, + repo: str, + candidates: list[Dict[str, Any]], + token: str, +) -> list[dict[str, str]]: + docs: list[dict[str, str]] = [] + for file in candidates: + content = _fetch_file_content_cached(org, repo, file, token) + if not content: + continue + for chunk in chunk_code(content): + docs.append( + { + "path": file["path"], + "url": file["url"], + "ref": file["ref"], + "size": str(file.get("size", 0)), + "heading": chunk.title, + "content": chunk.text, + "line_start": str(chunk.line_start), + "line_end": str(chunk.line_end), + } + ) + return docs + + +def _get_index_candidates( + files: list[Dict[str, Any]], keyword: str +) -> list[Dict[str, Any]]: + return sorted( + [ + file + for file in files + if _is_indexable_example_file(file["path"], int(file.get("size", 0))) + ], + key=lambda file: _rank_index_candidate(file, keyword), + )[:MAX_INDEXED_EXAMPLE_FILES] + + +def _excerpt_around_query(content: str, query: str, *, max_chars: int = 900) -> str: + if len(content) <= max_chars: + return content + + terms = [ + term.lower() + for term in query.replace("_", " ").split() + if len(term.strip()) >= 3 + ] + content_lower = content.lower() + first_match = min( + (index for term in terms if (index := content_lower.find(term)) >= 0), + default=0, + ) + start = max(0, first_match - max_chars // 4) + end = min(len(content), start + max_chars) + if end - start < max_chars: + start = max(0, end - max_chars) + + excerpt = content[start:end] + if start > 0: + excerpt = "...\n" + excerpt + if end < len(content): + excerpt += "\n..." + return excerpt + + def _get_pattern_priority(file_path: str) -> tuple[int, int, int]: """ Get priority of a file path based on which example pattern directory it's in. @@ -284,14 +496,7 @@ def find_examples( Returns: ToolResult with matching files, or similar repos if repo not found """ - token = os.environ.get("GITHUB_TOKEN") - if not token: - return { - "formatted": "Error: GITHUB_TOKEN environment variable is required", - "totalResults": 0, - "resultsShared": 0, - "isError": True, - } + token = os.environ.get("GITHUB_TOKEN", "") if not repo: return { @@ -323,14 +528,45 @@ def find_examples( "resultsShared": 0, } - # Step 2: If keyword provided, score and filter by keyword + snippet_hits: list[Dict[str, Any]] = [] + + # Step 2: If keyword provided, score paths and search file contents. 
if keyword: + snippet_hits = _search_example_snippets( + keyword, + org, + repo, + example_files, + token, + limit=max(max_results * 2, 10), + ) + scored_files = [] for file in example_files: keyword_score = _score_against_keyword(file["path"], keyword) if keyword_score >= min_score: scored_files.append({**file, "score": keyword_score}) + if snippet_hits: + snippet_scores: dict[str, float] = {} + for hit in snippet_hits: + path = hit.get("path", "") + snippet_scores[path] = max( + snippet_scores.get(path, 0.0), float(hit["score"]) + ) + + seen_paths = {file["path"] for file in scored_files} + for file in example_files: + if file["path"] in snippet_scores and file["path"] not in seen_paths: + scored_files.append( + { + **file, + "score": min(100, int(70 + snippet_scores[file["path"]] * 10)), + "content_score": snippet_scores[file["path"]], + } + ) + seen_paths.add(file["path"]) + if not scored_files: return { "formatted": f"No files found in {org}/{repo} matching keyword '{keyword}' (min score: {min_score}) among {len(example_files)} example files.", @@ -338,8 +574,11 @@ def find_examples( "resultsShared": 0, } - # Sort by keyword score (descending) for best matches first - scored_files.sort(key=lambda x: x["score"], reverse=True) + # Prefer files with content hits, then path similarity. + scored_files.sort( + key=lambda x: (float(x.get("content_score", 0.0)), x["score"]), + reverse=True, + ) else: # No keyword: prioritize by pattern directory, then path depth scored_files = [] @@ -394,6 +633,27 @@ def find_examples( lines.append(f" To read, use: {read_params}") lines.append("") + if snippet_hits: + lines.append("## Best indexed code snippets") + lines.append( + "Use these line ranges with `github_read_file` before reading whole files." + ) + lines.append("") + for i, hit in enumerate(snippet_hits[:max_results], 1): + path = hit["path"] + line_start = hit["line_start"] + line_end = hit["line_end"] + excerpt = _excerpt_around_query(hit["content"], keyword) + lines.append(f"{i}. **{path}:{line_start}-{line_end}**") + lines.append(f" Relevance score: {hit['score']:.2f}") + lines.append( + f" To read exactly: {{'repo': '{org}/{repo}', 'path': '{path}', 'line_start': {line_start}, 'line_end': {line_end}}}" + ) + lines.append(" ```") + lines.append(excerpt) + lines.append(" ```") + lines.append("") + return { "formatted": "\n".join(lines), "totalResults": len(results), @@ -406,10 +666,10 @@ def find_examples( "name": "github_find_examples", "description": ( "Find working example scripts in GitHub repositories (from a list of predetermined directories e.g. examples/, scripts/, tutorials/, etc.). " - "Uses fuzzy keyword matching.\n\n" + "Uses fuzzy path matching plus Tantivy content search over indexed code snippets when a keyword is provided.\n\n" "MANDATORY before writing any ML training, fine-tuning, or inference code. 
" "Your internal knowledge of library APIs is outdated — working examples show current API patterns.\n\n" - "Sequence: github_find_examples → github_read_file (study the example) → implement based on what you found.\n\n" + "Sequence: github_find_examples → github_read_file with the returned line_start/line_end ranges → implement based on what you found.\n\n" "Skip this only for: simple data queries, status checks, non-code tasks.\n\n" "Examples:\n" " {keyword: 'sft', repo: 'trl'} → finds examples/scripts/sft.py\n" @@ -421,7 +681,7 @@ def find_examples( "properties": { "keyword": { "type": "string", - "description": "Keyword to fuzzy match against file paths (e.g., 'grpo', 'sft').", + "description": "Keyword to search against file paths and indexed code snippets (e.g., 'grpo', 'sft', 'dataset_text_field').", }, "repo": { "type": "string", diff --git a/pyproject.toml b/pyproject.toml index 432085e0..c462c3f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,13 +20,13 @@ dependencies = [ "rich>=13.0.0", "nbconvert>=7.16.6", "nbformat>=5.10.4", - "whoosh>=2.7.4", # Web backend dependencies "fastapi>=0.115.0", "uvicorn[standard]>=0.32.0", "httpx>=0.27.0", "websockets>=13.0", "apscheduler>=3.10,<4", + "tantivy>=0.25.1", ] [project.optional-dependencies] diff --git a/tests/unit/test_docs_tantivy_search.py b/tests/unit/test_docs_tantivy_search.py new file mode 100644 index 00000000..cffae09b --- /dev/null +++ b/tests/unit/test_docs_tantivy_search.py @@ -0,0 +1,114 @@ +import pytest + +from agent.tools import docs_tools + + +@pytest.mark.asyncio +async def test_search_docs_returns_best_passage_with_source_lines(): + docs_tools._index_cache.clear() + docs = [ + { + "title": "SFT Trainer", + "url": "https://huggingface.co/docs/trl/sft_trainer", + "md_url": "https://huggingface.co/docs/trl/sft_trainer.md", + "section": "trl", + "glimpse": "", + "content": ( + "# SFT Trainer\n\n" + "Overview of supervised fine tuning.\n\n" + "## Dataset processing\n\n" + "Set dataset_text_field on SFTConfig when your dataset stores text " + "in a custom column." 
+ ), + }, + { + "title": "DPO Trainer", + "url": "https://huggingface.co/docs/trl/dpo_trainer", + "md_url": "https://huggingface.co/docs/trl/dpo_trainer.md", + "section": "trl", + "glimpse": "", + "content": "# DPO Trainer\n\nPreference optimization reference.", + }, + ] + + results, note = await docs_tools._search_docs( + "trl", docs, "dataset_text_field SFTConfig", 3 + ) + + assert note is None + assert results[0]["title"] == "SFT Trainer / Dataset processing" + assert results[0]["line_start"] == "5" + assert results[0]["line_end"] == "7" + assert "dataset_text_field" in results[0]["glimpse"] + + +@pytest.mark.asyncio +async def test_explore_hf_docs_allows_public_docs_without_token(monkeypatch): + async def fake_get_docs(hf_token, endpoint): + assert hf_token == "" + assert endpoint == "trl" + return [ + { + "title": "SFT Trainer", + "url": "https://huggingface.co/docs/trl/sft_trainer", + "md_url": "https://huggingface.co/docs/trl/sft_trainer.md", + "section": "trl", + "glimpse": "Use SFTConfig with dataset_text_field.", + "content": "# SFT Trainer\n\nUse SFTConfig with dataset_text_field.", + } + ] + + monkeypatch.setattr(docs_tools, "_get_docs", fake_get_docs) + + text, ok = await docs_tools.explore_hf_docs_handler( + {"endpoint": "trl", "query": "dataset_text_field", "max_results": 1}, + session=None, + ) + + assert ok is True + assert "SFT Trainer" in text + assert "Lines:" in text + + +@pytest.mark.asyncio +async def test_search_openapi_uses_tantivy_index(monkeypatch): + docs_tools._openapi_cache = None + docs_tools._openapi_index_cache = None + + async def fake_fetch_openapi_spec(): + return { + "servers": [{"url": "https://huggingface.co"}], + "paths": { + "/api/repos/create": { + "post": { + "operationId": "createRepo", + "summary": "Create a repository", + "description": "Create model, dataset, or Space repositories.", + "tags": ["Repo"], + "parameters": [ + {"name": "name", "in": "query", "schema": {"type": "string"}} + ], + "responses": {"200": {"description": "Created"}}, + } + }, + "/api/models/{repo_id}": { + "get": { + "operationId": "modelInfo", + "summary": "Get model info", + "description": "Retrieve model metadata.", + "tags": ["Model"], + "parameters": [], + "responses": {"200": {"description": "OK"}}, + } + }, + }, + } + + monkeypatch.setattr(docs_tools, "_fetch_openapi_spec", fake_fetch_openapi_spec) + + results, note = await docs_tools._search_openapi("create repository", "Repo", 5) + + assert note is None + assert len(results) == 1 + assert results[0]["path"] == "/api/repos/create" + assert results[0]["operationId"] == "createRepo" diff --git a/tests/unit/test_github_find_examples_tantivy.py b/tests/unit/test_github_find_examples_tantivy.py new file mode 100644 index 00000000..bd7530ca --- /dev/null +++ b/tests/unit/test_github_find_examples_tantivy.py @@ -0,0 +1,159 @@ +from agent.tools import github_find_examples + + +class _FakeResponse: + def __init__(self, status_code): + self.status_code = status_code + + +def test_github_get_retries_public_request_when_token_is_rejected(monkeypatch): + seen_auth = [] + + def fake_get(url, headers, **kwargs): + seen_auth.append(headers.get("Authorization")) + return _FakeResponse(401 if headers.get("Authorization") else 200) + + monkeypatch.setattr(github_find_examples.requests, "get", fake_get) + + response = github_find_examples._github_get("https://api.github.com/repos/x/y", "bad") + + assert response.status_code == 200 + assert seen_auth == ["Bearer bad", None] + + +def test_excerpt_around_query_skips_unrelated_prefix(): 
+ content = "license header\n" * 80 + "trainer = GRPOTrainer(args=config)\n" + + excerpt = github_find_examples._excerpt_around_query( + content, "grpo trainer", max_chars=160 + ) + + assert excerpt.startswith("...\n") + assert "GRPOTrainer" in excerpt + assert len(excerpt) < 220 + + +def test_search_example_snippets_finds_content_only_match(monkeypatch): + files = [ + { + "path": "examples/scripts/sft.py", + "ref": "abc123", + "size": 240, + "branch": "main", + "url": "https://github.com/huggingface/trl/blob/main/examples/scripts/sft.py", + }, + { + "path": "examples/scripts/dpo.py", + "ref": "def456", + "size": 120, + "branch": "main", + "url": "https://github.com/huggingface/trl/blob/main/examples/scripts/dpo.py", + }, + ] + + def fake_fetch_file_content_cached(org, repo, file, token): + if file["path"].endswith("sft.py"): + return ( + "from trl import SFTConfig, SFTTrainer\n\n" + "config = SFTConfig(dataset_text_field='text', packing=True)\n" + "trainer = SFTTrainer(args=config)\n" + ) + return "from trl import DPOTrainer\n" + + monkeypatch.setattr( + github_find_examples, + "_fetch_file_content_cached", + fake_fetch_file_content_cached, + ) + + hits = github_find_examples._search_example_snippets( + "dataset_text_field packing", + "huggingface", + "trl", + files, + "token", + limit=3, + ) + + assert hits[0]["path"] == "examples/scripts/sft.py" + assert hits[0]["line_start"] == "1" + assert hits[0]["line_end"] == "4" + assert "dataset_text_field" in hits[0]["content"] + + +def test_find_examples_promotes_snippet_hit_from_file_content(monkeypatch): + files = [ + { + "path": "examples/scripts/sft.py", + "ref": "abc123", + "size": 240, + "branch": "main", + "url": "https://github.com/huggingface/trl/blob/main/examples/scripts/sft.py", + } + ] + + monkeypatch.setenv("GITHUB_TOKEN", "token") + monkeypatch.setattr( + github_find_examples, + "_get_repo_tree", + lambda org, repo, token: (files, ""), + ) + monkeypatch.setattr( + github_find_examples, + "_search_example_snippets", + lambda keyword, org, repo, files, token, limit: [ + { + "path": "examples/scripts/sft.py", + "url": "https://github.com/huggingface/trl/blob/main/examples/scripts/sft.py", + "ref": "abc123", + "size": "240", + "heading": "config = SFTConfig(...)", + "content": "config = SFTConfig(dataset_text_field='text', packing=True)", + "line_start": "3", + "line_end": "3", + "score": 1.5, + } + ], + ) + + result = github_find_examples.find_examples( + keyword="dataset_text_field", + repo="trl", + org="huggingface", + max_results=1, + min_score=95, + ) + + assert not result.get("isError", False) + assert "Best indexed code snippets" in result["formatted"] + assert "'line_start': 3, 'line_end': 3" in result["formatted"] + + +def test_find_examples_allows_public_repo_without_github_token(monkeypatch): + files = [ + { + "path": "examples/scripts/sft.py", + "ref": "abc123", + "size": 240, + "branch": "main", + "url": "https://github.com/huggingface/trl/blob/main/examples/scripts/sft.py", + } + ] + + monkeypatch.delenv("GITHUB_TOKEN", raising=False) + + def fake_get_repo_tree(org, repo, token): + assert token == "" + return files, "" + + monkeypatch.setattr(github_find_examples, "_get_repo_tree", fake_get_repo_tree) + monkeypatch.setattr( + github_find_examples, + "_search_example_snippets", + lambda keyword, org, repo, files, token, limit: [], + ) + + result = github_find_examples.find_examples(repo="trl", org="huggingface") + + assert not result.get("isError", False) + assert "examples/scripts/sft.py" in result["formatted"] diff 
--git a/tests/unit/test_tantivy_search.py b/tests/unit/test_tantivy_search.py new file mode 100644 index 00000000..64411e69 --- /dev/null +++ b/tests/unit/test_tantivy_search.py @@ -0,0 +1,54 @@ +from agent.search import TantivyTextIndex, chunk_code, chunk_markdown + + +def test_tantivy_text_index_ranks_field_boosted_hits(): + index = TantivyTextIndex( + text_fields=["title", "content"], + stored_fields=["url"], + field_boosts={"title": 3.0, "content": 1.0}, + ) + index.add_documents( + [ + { + "title": "SFTTrainer dataset_text_field", + "content": "Configuration reference for supervised fine tuning.", + "url": "https://example.test/sft", + }, + { + "title": "Generic training loop", + "content": "dataset_text_field appears in the body only.", + "url": "https://example.test/generic", + }, + ] + ) + + hits, errors = index.search("dataset_text_field", limit=2) + + assert errors == [] + assert [hit.fields["url"] for hit in hits] == [ + "https://example.test/sft", + "https://example.test/generic", + ] + + +def test_chunk_markdown_preserves_heading_and_line_range(): + content = "# Intro\n\nStart here\n\n## Packing\n\nUse packing with SFTConfig." + + chunks = chunk_markdown(content) + + assert chunks[-1].title == "Packing" + assert chunks[-1].line_start == 5 + assert chunks[-1].line_end == 7 + assert "SFTConfig" in chunks[-1].text + + +def test_chunk_code_uses_overlapping_line_windows(): + content = "\n".join(f"line_{i}" for i in range(1, 121)) + + chunks = chunk_code(content, window=50, overlap=10) + + assert [(chunk.line_start, chunk.line_end) for chunk in chunks] == [ + (1, 50), + (41, 90), + (81, 120), + ] diff --git a/uv.lock b/uv.lock index 3bddba0d..d4661e0e 100644 --- a/uv.lock +++ b/uv.lock @@ -1787,10 +1787,10 @@ dependencies = [ { name = "python-dotenv" }, { name = "requests" }, { name = "rich" }, + { name = "tantivy" }, { name = "thefuzz" }, { name = "uvicorn", extra = ["standard"] }, { name = "websockets" }, - { name = "whoosh" }, ] [package.optional-dependencies] @@ -1836,11 +1836,11 @@ requires-dist = [ { name = "python-dotenv", specifier = ">=1.2.1" }, { name = "requests", specifier = ">=2.33.0" }, { name = "rich", specifier = ">=13.0.0" }, + { name = "tantivy", specifier = ">=0.25.1" }, { name = "tenacity", marker = "extra == 'eval'", specifier = ">=8.0.0" }, { name = "thefuzz", specifier = ">=0.22.1" }, { name = "uvicorn", extras = ["standard"], specifier = ">=0.32.0" }, { name = "websockets", specifier = ">=13.0" }, - { name = "whoosh", specifier = ">=2.7.4" }, ] provides-extras = ["eval", "dev", "all"] @@ -3437,6 +3437,39 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d9/52/1064f510b141bd54025f9b55105e26d1fa970b9be67ad766380a3c9b74b0/starlette-0.50.0-py3-none-any.whl", hash = "sha256:9e5391843ec9b6e472eed1365a78c8098cfceb7a74bfd4d6b1c0c0095efb3bca", size = 74033, upload-time = "2025-11-01T15:25:25.461Z" }, ] +[[package]] +name = "tantivy" +version = "0.25.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/1b/f9/0cd3955d155d3e3ef74b864769514dd191e5dacba9f0beb7af2d914942ce/tantivy-0.25.1.tar.gz", hash = "sha256:68a3314699a7d18fcf338b52bae8ce46a97dde1128a3e47e33fa4db7f71f265e", size = 75120, upload-time = "2025-12-02T11:57:12.997Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4e/7a/8a277f377e8a151fc0e71d4ffc1114aefb6e5e1c7dd609fed0955cf34ed8/tantivy-0.25.1-cp311-cp311-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = 
"sha256:d363d7b4207d3a5aa7f0d212420df35bed18bdb6bae26a2a8bd57428388b7c29", size = 7637033, upload-time = "2025-12-02T11:56:18.104Z" }, + { url = "https://files.pythonhosted.org/packages/71/31/8b4acdedfc9f9a2d04b1340d07eef5213d6f151d1e18da0cb423e5f090d2/tantivy-0.25.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:8f4389cf1d889a1df7c5a3195806b4b56c37cee10d8a26faaa0dea35a867b5ff", size = 3932180, upload-time = "2025-12-02T11:56:19.833Z" }, + { url = "https://files.pythonhosted.org/packages/2f/dc/3e8499c21b4b9795e8f2fc54c68ce5b92905aaeadadaa56ecfa9180b11b1/tantivy-0.25.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:99864c09fc54652c3c2486cdf13f86cdc8200f4b481569cb291e095ca5d496e5", size = 4197620, upload-time = "2025-12-02T11:56:21.496Z" }, + { url = "https://files.pythonhosted.org/packages/f8/8e/f2ce62fffc811eb62bead92c7b23c2e218f817cbd54c4f3b802e03ba1438/tantivy-0.25.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:05abf37ddbc5063c575548be0d62931629c086bff7a5a1b67cf5a8f5ebf4cd8c", size = 4183794, upload-time = "2025-12-02T11:56:23.215Z" }, + { url = "https://files.pythonhosted.org/packages/de/64/24e2891b0ba3fd9853e10c296095a33b89bf3efd65e29da1ee5dae736040/tantivy-0.25.1-cp311-cp311-win_amd64.whl", hash = "sha256:f307ee8ad21597b0be23af83008fd66cfd5f958cdfa24ec0aaa08a38e86bbef4", size = 3424235, upload-time = "2025-12-02T11:56:25.172Z" }, + { url = "https://files.pythonhosted.org/packages/41/e7/6849c713ed0996c7628324c60512c4882006f0a62145e56c624a93407f90/tantivy-0.25.1-cp312-cp312-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:90fd919e5f611809f746560ecf36eb9be824dec62e21ae17a27243759edb9aa1", size = 7621494, upload-time = "2025-12-02T11:56:27.069Z" }, + { url = "https://files.pythonhosted.org/packages/c5/22/c3d8294600dc6e7fa350daef9ff337d3c06e132b81df727de9f7a50c692a/tantivy-0.25.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:4613c7cf6c23f3a97989819690a0f956d799354957de7a204abcc60083cebe02", size = 3925219, upload-time = "2025-12-02T11:56:29.403Z" }, + { url = "https://files.pythonhosted.org/packages/41/fc/cbb1df71dd44c9110eff4eaaeda9d44f2d06182fe0452193be20ddfba93f/tantivy-0.25.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c477bd20b4df804d57dfc5033431bef27cde605695ae141b03abbf6ebc069129", size = 4198699, upload-time = "2025-12-02T11:56:31.359Z" }, + { url = "https://files.pythonhosted.org/packages/47/4d/71abb78b774073c3ce12a4faa4351a9d910a71ffa3659526affba163873d/tantivy-0.25.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f9b1a1ba1113c523c7ff7b10f282d6c4074006f7ef8d71e1d973d51bf7291ddb", size = 4183585, upload-time = "2025-12-02T11:56:33.317Z" }, + { url = "https://files.pythonhosted.org/packages/be/16/3f00cd7ec458b92a0e977960af9ddfbeb762127d9acc68da9094a1fda556/tantivy-0.25.1-cp312-cp312-win_amd64.whl", hash = "sha256:9de0bafd3bd7ac9f8f82d53e17562e9db11a5af308fe5185c4bd86feaddbe4a6", size = 3424622, upload-time = "2025-12-02T11:56:34.788Z" }, + { url = "https://files.pythonhosted.org/packages/3d/25/73cfbcf1a8ea49be6c42817431cac46b70a119fe64da903fcc2d92b5b511/tantivy-0.25.1-cp313-cp313-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:f51ff7196c6f31719202080ed8372d5e3d51e92c749c032fb8234f012e99744c", size = 7622530, upload-time = "2025-12-02T11:56:36.839Z" }, + { url = 
"https://files.pythonhosted.org/packages/12/c8/c0d7591cdf4f7e7a9fc4da786d1ca8cd1aacffaa2be16ea6d401a8e4a566/tantivy-0.25.1-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:550e63321bfcacc003859f2fa29c1e8e56450807b3c9a501c1add27cfb9236d9", size = 3925637, upload-time = "2025-12-02T11:56:38.425Z" }, + { url = "https://files.pythonhosted.org/packages/3a/09/bedfc223bffec7641b417dd7ab071134b2ef8f8550e9b1fb6014657ef52e/tantivy-0.25.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fde31cc8d6e122faf7902aeea32bc008a429a6e8904e34d3468126a3ec01b016", size = 4197322, upload-time = "2025-12-02T11:56:40.411Z" }, + { url = "https://files.pythonhosted.org/packages/f5/f1/1fa5183500c8042200c9f2b840d34f5bbcfb434a1ee750e7132262d2a5c9/tantivy-0.25.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b11bd5a518b0be645320b47af8493f6a40c4f3234313e37adcf4534a564d27dd", size = 4183143, upload-time = "2025-12-02T11:56:42.048Z" }, + { url = "https://files.pythonhosted.org/packages/d5/74/a4c4f4eb95888ccb784da3b017aa0625ab1ac411bf5d022a9a797d9a2334/tantivy-0.25.1-cp313-cp313-win_amd64.whl", hash = "sha256:cc7fe88853e06b3251ee4fa42b7a2038727f850c8765bcc8167cfc73585dd24e", size = 3423491, upload-time = "2025-12-02T11:56:43.858Z" }, + { url = "https://files.pythonhosted.org/packages/8b/2f/581519492226f97d23bd0adc95dad991ebeaa73ea6abc8bff389a3096d9a/tantivy-0.25.1-cp313-cp313t-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:dae99e75b7eaa9bf5bd16ab106b416370f08c135aed0e117d62a3201cd1ffe36", size = 7610316, upload-time = "2025-12-02T11:56:45.927Z" }, + { url = "https://files.pythonhosted.org/packages/91/40/5d7bc315ab9e6a22c5572656e8ada1c836cfa96dccf533377504fbc3c9d9/tantivy-0.25.1-cp313-cp313t-macosx_10_12_x86_64.whl", hash = "sha256:506e9533c5ef4d3df43bad64ffecc0aa97c76e361ea610815dc3a20a9d6b30b3", size = 3919882, upload-time = "2025-12-02T11:56:48.469Z" }, + { url = "https://files.pythonhosted.org/packages/02/b9/e0ef2f57a6a72444cb66c2ffbc310ab33ffaace275f1c4b0319d84ea3f18/tantivy-0.25.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5dbd4f8f264dacbcc9dee542832da2173fd53deaaea03f082d95214f8b5ed6bc", size = 4196031, upload-time = "2025-12-02T11:56:50.151Z" }, + { url = "https://files.pythonhosted.org/packages/1e/02/bf3f8cacfd08642e14a73f7956a3fb95d58119132c98c121b9065a1f8615/tantivy-0.25.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:824c643ccb640dd9e35e00c5d5054ddf3323f56fe4219d57d428a9eeea13d22c", size = 4183437, upload-time = "2025-12-02T11:56:51.818Z" }, + { url = "https://files.pythonhosted.org/packages/9c/83/afa90e570198e2d1139dd567bec3c9cf44d8c54f63a649f16d711ede02f5/tantivy-0.25.1-cp313-cp313t-win_amd64.whl", hash = "sha256:09c987b840afcebac817836ac08407eff17272d8aa60ce6e291f89c81830221d", size = 3419409, upload-time = "2025-12-02T11:56:53.451Z" }, + { url = "https://files.pythonhosted.org/packages/ff/44/9f1d67aa5030f7eebc966c863d1316a510a971dd8bb45651df4acdfae9ed/tantivy-0.25.1-cp314-cp314-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:7f5d29ae85dd0f23df8d15b3e7b341d4f9eb5a446bbb9640df48ac1f6d9e0c6c", size = 7623723, upload-time = "2025-12-02T11:56:55.066Z" }, + { url = "https://files.pythonhosted.org/packages/db/30/6e085bd3ed9d12da3c91c185854abd70f9dfd35fb36a75ea98428d42c30b/tantivy-0.25.1-cp314-cp314-macosx_10_12_x86_64.whl", hash = "sha256:f2d2938fb69a74fc1bb36edfaf7f0d1596fa1264db0f377bda2195c58bcb6245", size = 3926243, 
upload-time = "2025-12-02T11:56:57.058Z" }, + { url = "https://files.pythonhosted.org/packages/32/f5/a00d65433430f51718e5cc6938df571765d7c4e03aedec5aef4ab567aa9b/tantivy-0.25.1-cp314-cp314-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f5ff124c4802558e627091e780b362ca944169736caba5a372eef39a79d0ae0", size = 4207186, upload-time = "2025-12-02T11:56:58.803Z" }, + { url = "https://files.pythonhosted.org/packages/19/63/61bdb12fc95f2a7f77bd419a5149bfa9f28caa76cb569bf2b6b06e1d033e/tantivy-0.25.1-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:43b80ef62a340416139c93d19264e5f808da48e04f9305f1092b8ed22be0a5be", size = 4187312, upload-time = "2025-12-02T11:57:00.595Z" }, + { url = "https://files.pythonhosted.org/packages/b7/de/e39c0b01d59019bf5c38face8b81defbc4a68cebf5e0c53bcb2cd715a449/tantivy-0.25.1-cp314-cp314-win_amd64.whl", hash = "sha256:286b654f40c70c1e6b64b9bc7031ed0bf5c440f5bffeaeeee21a0ee6cc39f0e2", size = 3436535, upload-time = "2025-12-02T11:57:02.267Z" }, +] + [[package]] name = "tenacity" version = "9.1.2" @@ -3910,15 +3943,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fa/a8/5b41e0da817d64113292ab1f8247140aac61cbf6cfd085d6a0fa77f4984f/websockets-15.0.1-py3-none-any.whl", hash = "sha256:f7a866fbc1e97b5c617ee4116daaa09b722101d4a3c170c787450ba409f9736f", size = 169743, upload-time = "2025-03-05T20:03:39.41Z" }, ] -[[package]] -name = "whoosh" -version = "2.7.4" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/25/2b/6beed2107b148edc1321da0d489afc4617b9ed317ef7b72d4993cad9b684/Whoosh-2.7.4.tar.gz", hash = "sha256:7ca5633dbfa9e0e0fa400d3151a8a0c4bec53bd2ecedc0a67705b17565c31a83", size = 968741, upload-time = "2016-04-04T01:19:32.327Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ba/19/24d0f1f454a2c1eb689ca28d2f178db81e5024f42d82729a4ff6771155cf/Whoosh-2.7.4-py2.py3-none-any.whl", hash = "sha256:aa39c3c3426e3fd107dcb4bde64ca1e276a65a889d9085a6e4b54ba82420a852", size = 468790, upload-time = "2016-04-04T01:19:40.379Z" }, -] - [[package]] name = "wrapt" version = "1.17.3"