8 changes: 2 additions & 6 deletions agent/main.py
@@ -807,7 +807,7 @@ async def _handle_slash_command(
return None


-async def main(model: str | None = None):
+async def main():
"""Interactive chat with the agent"""

# Clear screen
@@ -822,8 +822,6 @@ async def main(model: str | None = None):
hf_token = await _prompt_and_save_hf_token(prompt_session)

config = load_config(CLI_CONFIG_PATH)
-    if model:
-        config.model_name = model

# Resolve username for banner
hf_user = _get_hf_user(hf_token)
@@ -1223,8 +1221,6 @@ def cli():
_configure_runtime_logging()
# Suppress litellm pydantic deprecation warnings
warnings.filterwarnings("ignore", category=DeprecationWarning, module="litellm")
-    # Suppress whoosh invalid escape sequence warnings (third-party, unfixed upstream)
-    warnings.filterwarnings("ignore", category=SyntaxWarning, module="whoosh")

parser = argparse.ArgumentParser(description="Hugging Face Agent CLI")
parser.add_argument("prompt", nargs="?", default=None, help="Run headlessly with this prompt")
@@ -1242,7 +1238,7 @@ def cli():
max_iter = 10_000 # effectively unlimited
asyncio.run(headless_main(args.prompt, model=args.model, max_iterations=max_iter, stream=not args.no_stream))
else:
-            asyncio.run(main(model=args.model))
+            asyncio.run(main())
except KeyboardInterrupt:
print("\n\nGoodbye!")

6 changes: 6 additions & 0 deletions agent/search/__init__.py
@@ -0,0 +1,6 @@
"""Search infrastructure used by research tools."""

from agent.search.chunking import chunk_code, chunk_markdown
from agent.search.tantivy_index import SearchHit, TantivyTextIndex

__all__ = ["SearchHit", "TantivyTextIndex", "chunk_code", "chunk_markdown"]
48 changes: 48 additions & 0 deletions agent/search/cache.py
@@ -0,0 +1,48 @@
"""Small JSON cache for network-backed research search indexes."""

from __future__ import annotations

import hashlib
import json
import os
import time
from pathlib import Path
from typing import Any


DEFAULT_TTL_SECONDS = 7 * 24 * 60 * 60


def cache_root() -> Path:
configured = os.environ.get("ML_INTERN_SEARCH_CACHE_DIR")
if configured:
return Path(configured)
return Path.cwd() / ".ml-intern-cache" / "search"


def read_json(namespace: str, key: str, *, ttl_seconds: int = DEFAULT_TTL_SECONDS) -> Any | None:
path = _path(namespace, key)
try:
if time.time() - path.stat().st_mtime > ttl_seconds:
return None
return json.loads(path.read_text(encoding="utf-8"))
except (FileNotFoundError, json.JSONDecodeError, OSError):
return None


def write_json(namespace: str, key: str, value: Any) -> None:
path = _path(namespace, key)
path.parent.mkdir(parents=True, exist_ok=True)
tmp = path.with_suffix(".tmp")
tmp.write_text(json.dumps(value), encoding="utf-8")
tmp.replace(path)


def stable_key(*parts: object) -> str:
raw = "\x1f".join(str(part) for part in parts)
return hashlib.sha256(raw.encode("utf-8")).hexdigest()


def _path(namespace: str, key: str) -> Path:
safe_namespace = namespace.replace("/", "_")
return cache_root() / safe_namespace / f"{key}.json"
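
A minimal usage sketch of how a network-backed tool could sit on top of these helpers (not part of the diff; the namespace, key parts, and fetch step are illustrative):

```python
from agent.search import cache

def cached_search(query: str) -> list[dict]:
    # The key is stable across runs for the same query; bump the version part to invalidate.
    key = cache.stable_key("example-source", query, "v1")
    hit = cache.read_json("example-source", key)
    if hit is not None:
        return hit
    results = [{"title": "placeholder", "url": "https://example.com"}]  # stand-in for a real network call
    cache.write_json("example-source", key, results)
    return results
```

Entries live under `.ml-intern-cache/search/<namespace>/<key>.json` (or `ML_INTERN_SEARCH_CACHE_DIR` when set) and expire after the seven-day TTL. `write_json` writes to a temp file and swaps it in with `Path.replace`, so a partially written entry is never read back.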
117 changes: 117 additions & 0 deletions agent/search/chunking.py
@@ -0,0 +1,117 @@
"""Chunk text into source-addressable passages for search tools."""

from __future__ import annotations

import re
from dataclasses import dataclass


@dataclass(frozen=True)
class TextChunk:
title: str
text: str
line_start: int
line_end: int


def chunk_markdown(content: str, *, max_chars: int = 1800) -> list[TextChunk]:
"""Split markdown into heading-aware chunks with line ranges."""
lines = content.splitlines()
chunks: list[TextChunk] = []
heading = "Introduction"
buffer: list[tuple[int, str]] = []

def flush() -> None:
nonlocal buffer
if not buffer:
return
text = "\n".join(line for _, line in buffer).strip()
if text:
chunks.extend(_split_oversized(heading, buffer, max_chars=max_chars))
buffer = []

for index, line in enumerate(lines, 1):
heading_match = re.match(r"^(#{1,6})\s+(.+?)\s*$", line)
if heading_match:
flush()
heading = heading_match.group(2).strip()
buffer.append((index, line))
continue
buffer.append((index, line))
if sum(len(part) + 1 for _, part in buffer) >= max_chars:
flush()
flush()

if chunks:
return chunks
text = content.strip()
if not text:
return []
return [TextChunk(title="Content", text=text[:max_chars], line_start=1, line_end=len(lines) or 1)]


def chunk_code(content: str, *, window: int = 80, overlap: int = 15) -> list[TextChunk]:
"""Split source code into overlapping line windows."""
lines = content.splitlines()
if not lines:
return []
chunks: list[TextChunk] = []
step = max(1, window - overlap)
for start in range(0, len(lines), step):
end = min(len(lines), start + window)
chunk_lines = lines[start:end]
title = _guess_code_title(chunk_lines) or f"Lines {start + 1}-{end}"
chunks.append(
TextChunk(
title=title,
text="\n".join(chunk_lines).strip(),
line_start=start + 1,
line_end=end,
)
)
if end == len(lines):
break
return [chunk for chunk in chunks if chunk.text]


def _split_oversized(
heading: str, buffer: list[tuple[int, str]], *, max_chars: int
) -> list[TextChunk]:
chunks: list[TextChunk] = []
current: list[tuple[int, str]] = []
current_chars = 0
for item in buffer:
line_len = len(item[1]) + 1
if current and current_chars + line_len > max_chars:
chunks.append(_make_chunk(heading, current))
current = []
current_chars = 0
current.append(item)
current_chars += line_len
if current:
chunks.append(_make_chunk(heading, current))
return chunks


def _make_chunk(heading: str, items: list[tuple[int, str]]) -> TextChunk:
return TextChunk(
title=heading,
text="\n".join(line for _, line in items).strip(),
line_start=items[0][0],
line_end=items[-1][0],
)


def _guess_code_title(lines: list[str]) -> str | None:
for line in lines:
stripped = line.strip()
match = re.match(r"(async\s+def|def|class)\s+([A-Za-z_][\w]*)", stripped)
if match:
return stripped.rstrip(":")
if stripped.startswith("if __name__"):
return "Script entrypoint"
for line in lines:
stripped = line.strip()
if stripped and not stripped.startswith(("#", "//", "/*", "*")):
return stripped[:80]
return None
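
For reference, a small sketch of what the markdown chunker produces (the sample document is made up):

```python
from agent.search.chunking import chunk_markdown

doc = "# Setup\npip install tantivy\n\n# Usage\nBuild the index, then search it."
for chunk in chunk_markdown(doc):
    print(chunk.title, chunk.line_start, chunk.line_end)
# Headings become chunk titles with 1-based line ranges:
#   Setup 1 3
#   Usage 4 5
```

`chunk_code` works the same way but uses overlapping 80-line windows and takes its title from the first `def`/`class` line in the window.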
93 changes: 93 additions & 0 deletions agent/search/tantivy_index.py
@@ -0,0 +1,93 @@
"""Small Tantivy wrapper for local, snippet-first research search."""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Any

import tantivy


@dataclass(frozen=True)
class SearchHit:
score: float
fields: dict[str, Any]


class TantivyTextIndex:
"""A compact text index with stored metadata fields.

The wrapper keeps Tantivy-specific details out of tool handlers and gives the
research tools a simple workflow: add stored documents, then search them
with BM25 ranking and field boosts.
"""

def __init__(
self,
*,
text_fields: list[str],
stored_fields: list[str] | None = None,
field_boosts: dict[str, float] | None = None,
path: Path | None = None,
) -> None:
self.text_fields = text_fields
self.stored_fields = list(dict.fromkeys([*(stored_fields or []), *text_fields]))
self.field_boosts = field_boosts or {}

builder = tantivy.SchemaBuilder()
for field in self.stored_fields:
tokenizer = "en_stem" if field in text_fields else "default"
builder.add_text_field(field, stored=True, tokenizer_name=tokenizer)
self.schema = builder.build()

if path is not None:
path.mkdir(parents=True, exist_ok=True)
self.index = tantivy.Index(self.schema, path=str(path))
else:
self.index = tantivy.Index(self.schema)

def add_documents(self, documents: list[dict[str, Any]]) -> None:
if not documents:
return

writer = self.index.writer(heap_size=30_000_000, num_threads=1)
for item in documents:
doc = tantivy.Document()
for field in self.stored_fields:
value = item.get(field, "")
if value is None:
value = ""
doc.add_text(field, str(value))
writer.add_document(doc)
writer.commit()
writer.wait_merging_threads()
self.index.reload()

def search(self, query: str, *, limit: int = 10) -> tuple[list[SearchHit], list[str]]:
clean_query = (query or "").strip()
if not clean_query:
return [], []

parsed, errors = self.index.parse_query_lenient(
clean_query,
self.text_fields,
field_boosts=self.field_boosts,
)
searcher = self.index.searcher()
results = searcher.search(parsed, limit)
hits: list[SearchHit] = []
for score, address in results.hits:
doc = searcher.doc(address).to_dict()
hits.append(SearchHit(score=float(score), fields=_flatten_doc(doc)))
return hits, [str(err) for err in errors]


def _flatten_doc(doc: dict[str, Any]) -> dict[str, Any]:
flattened: dict[str, Any] = {}
for key, value in doc.items():
if isinstance(value, list):
flattened[key] = value[0] if value else ""
else:
flattened[key] = value
return flattened
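
A sketch of how the pieces compose end to end (the field names and sample document are illustrative, not taken from the tool handlers):

```python
from agent.search import TantivyTextIndex, chunk_markdown

readme = "# Quickstart\nInstall the package, build the index, then run a query.\n"

index = TantivyTextIndex(
    text_fields=["title", "text"],
    stored_fields=["source"],
    field_boosts={"title": 2.0},  # prefer heading matches over body matches
)
index.add_documents(
    [
        {"title": c.title, "text": c.text, "source": f"README.md:{c.line_start}-{c.line_end}"}
        for c in chunk_markdown(readme)
    ]
)

hits, errors = index.search("quickstart query", limit=5)
for hit in hits:
    print(round(hit.score, 2), hit.fields["source"], hit.fields["title"])
```

`parse_query_lenient` keeps malformed user queries from raising, so callers can surface `errors` as warnings while still returning whatever matched.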