diff --git a/src/lean_explore/mcp/tools.py b/src/lean_explore/mcp/tools.py index 6a5535d..f543434 100644 --- a/src/lean_explore/mcp/tools.py +++ b/src/lean_explore/mcp/tools.py @@ -8,10 +8,32 @@ from lean_explore.mcp.app import AppContext, BackendServiceType, mcp_app from lean_explore.models import SearchResponse, SearchResult +from lean_explore.models.search_types import ( + SearchResultSummary, + SearchSummaryResponse, + extract_bold_description, +) + + +class SearchResultSummaryDict(TypedDict, total=False): + """Serialized SearchResultSummary for slim MCP search responses.""" + + id: int + name: str + description: str | None + + +class SearchSummaryResponseDict(TypedDict, total=False): + """Serialized SearchSummaryResponse for slim MCP search responses.""" + + query: str + results: list[SearchResultSummaryDict] + count: int + processing_time_ms: int | None class SearchResultDict(TypedDict, total=False): - """Serialized SearchResult for MCP tool responses.""" + """Serialized SearchResult for verbose MCP tool responses.""" id: int name: str @@ -24,7 +46,7 @@ class SearchResultDict(TypedDict, total=False): class SearchResponseDict(TypedDict, total=False): - """Serialized SearchResponse for MCP tool responses.""" + """Serialized SearchResponse for verbose MCP tool responses.""" query: str results: list[SearchResultDict] @@ -55,6 +77,41 @@ async def _get_backend_from_context(ctx: MCPContext) -> BackendServiceType: return backend +async def _execute_backend_search( + backend: BackendServiceType, + query: str, + limit: int, + rerank_top: int | None, + packages: list[str] | None, +) -> SearchResponse: + """Execute a search on the backend, handling both async and sync backends. + + Args: + backend: The backend service (ApiClient or Service). + query: The search query string. + limit: Maximum number of results. + rerank_top: Number of candidates to rerank with cross-encoder. + packages: Optional package filter. + + Returns: + The search response from the backend. + + Raises: + RuntimeError: If the backend does not support search. + """ + if not hasattr(backend, "search"): + logger.error("Backend service does not have a 'search' method.") + raise RuntimeError("Search functionality not available on configured backend.") + + if asyncio.iscoroutinefunction(backend.search): + return await backend.search( + query=query, limit=limit, rerank_top=rerank_top, packages=packages + ) + return backend.search( + query=query, limit=limit, rerank_top=rerank_top, packages=packages + ) + + @mcp_app.tool() async def search( ctx: MCPContext, @@ -62,8 +119,12 @@ async def search( limit: int = 10, rerank_top: int | None = 50, packages: list[str] | None = None, -) -> SearchResponseDict: - """Searches Lean declarations by a query string. +) -> SearchSummaryResponseDict: + """Searches Lean declarations and returns concise results. + + Returns slim results (id, name, short description) to minimize token usage. + Use get_by_id to retrieve full details for specific declarations, or + search_verbose to get all fields upfront. Args: ctx: The MCP context, providing access to the backend service. @@ -75,7 +136,7 @@ async def search( Defaults to None (all packages). Returns: - A dictionary containing the search response with results. + A dictionary containing slim search results with id, name, and description. """ backend = await _get_backend_from_context(ctx) logger.info( @@ -83,21 +144,65 @@ async def search( f"rerank_top: {rerank_top}, packages: {packages}" ) - if not hasattr(backend, "search"): - logger.error("Backend service does not have a 'search' method.") - raise RuntimeError("Search functionality not available on configured backend.") + response = await _execute_backend_search( + backend, query, limit, rerank_top, packages + ) - # Call backend search (handle both async and sync) - if asyncio.iscoroutinefunction(backend.search): - response: SearchResponse = await backend.search( - query=query, limit=limit, rerank_top=rerank_top, packages=packages - ) - else: - response: SearchResponse = backend.search( - query=query, limit=limit, rerank_top=rerank_top, packages=packages + # Convert full results to slim summaries + summary_results = [ + SearchResultSummary( + id=result.id, + name=result.name, + description=extract_bold_description(result.informalization), ) + for result in response.results + ] + summary_response = SearchSummaryResponse( + query=response.query, + results=summary_results, + count=response.count, + processing_time_ms=response.processing_time_ms, + ) + + return summary_response.model_dump(exclude_none=True) + + +@mcp_app.tool() +async def search_verbose( + ctx: MCPContext, + query: str, + limit: int = 10, + rerank_top: int | None = 50, + packages: list[str] | None = None, +) -> SearchResponseDict: + """Searches Lean declarations and returns full results with all fields. + + Returns complete results including source code, dependencies, module info, + and full informalization. Use this when you need all details upfront. For + a more concise overview, use search instead. + + Args: + ctx: The MCP context, providing access to the backend service. + query: A search query string, e.g., "continuous function". + limit: The maximum number of search results to return. Defaults to 10. + rerank_top: Number of candidates to rerank with cross-encoder. Set to 0 or + None to skip reranking. Defaults to 50. Only used with local backend. + packages: Filter results to specific packages (e.g., ["Mathlib", "Std"]). + Defaults to None (all packages). + + Returns: + A dictionary containing the full search response with all fields. + """ + backend = await _get_backend_from_context(ctx) + logger.info( + f"MCP Tool 'search_verbose' called with query: '{query}', limit: {limit}, " + f"rerank_top: {rerank_top}, packages: {packages}" + ) + + response = await _execute_backend_search( + backend, query, limit, rerank_top, packages + ) - # Return as dict for MCP return response.model_dump(exclude_none=True) @@ -108,6 +213,9 @@ async def get_by_id( ) -> SearchResultDict | None: """Retrieves a specific declaration by its unique identifier. + Returns the full declaration including source code, dependencies, module + info, and informalization. Use this to expand results from the search tool. + Args: ctx: The MCP context, providing access to the backend service. declaration_id: The unique integer identifier of the declaration. diff --git a/src/lean_explore/models/__init__.py b/src/lean_explore/models/__init__.py index b1cbee6..3088cd1 100644 --- a/src/lean_explore/models/__init__.py +++ b/src/lean_explore/models/__init__.py @@ -4,6 +4,20 @@ """ from lean_explore.models.search_db import Base, Declaration -from lean_explore.models.search_types import SearchResponse, SearchResult +from lean_explore.models.search_types import ( + SearchResponse, + SearchResult, + SearchResultSummary, + SearchSummaryResponse, + extract_bold_description, +) -__all__ = ["Base", "Declaration", "SearchResult", "SearchResponse"] +__all__ = [ + "Base", + "Declaration", + "SearchResult", + "SearchResponse", + "SearchResultSummary", + "SearchSummaryResponse", + "extract_bold_description", +] diff --git a/src/lean_explore/models/search_types.py b/src/lean_explore/models/search_types.py index 6d714e5..7a57239 100644 --- a/src/lean_explore/models/search_types.py +++ b/src/lean_explore/models/search_types.py @@ -1,8 +1,62 @@ """Type definitions for search results and related data structures.""" +import re + from pydantic import BaseModel, ConfigDict +def extract_bold_description(informalization: str | None) -> str | None: + """Extract the bold header text from an informalization string. + + Informalizations follow the pattern: **Bold Title.** Rest of description... + This function extracts just the bold title portion. + + Args: + informalization: The full informalization text, or None. + + Returns: + The bold header text (without ** markers), or None if no bold + header is found or input is None. + """ + if not informalization: + return None + match = re.match(r"\*\*(.+?)\*\*", informalization) + return match.group(1) if match else None + + +class SearchResultSummary(BaseModel): + """A slim search result containing only identification and description. + + Used by the MCP search tool to return concise results that minimize + token usage. Consumers can use the id to fetch full details via get_by_id. + """ + + id: int + """Primary key identifier.""" + + name: str + """Fully qualified Lean name (e.g., 'Nat.add').""" + + description: str | None + """Short description extracted from the informalization bold header.""" + + +class SearchSummaryResponse(BaseModel): + """Response from a slim search operation containing summary results.""" + + query: str + """The original search query string.""" + + results: list[SearchResultSummary] + """List of slim search results.""" + + count: int + """Number of results returned.""" + + processing_time_ms: int | None = None + """Processing time in milliseconds, if available.""" + + class SearchResult(BaseModel): """A search result representing a Lean declaration. diff --git a/tests/mcp/tools_test.py b/tests/mcp/tools_test.py index 482231c..b3815cf 100644 --- a/tests/mcp/tools_test.py +++ b/tests/mcp/tools_test.py @@ -1,13 +1,19 @@ """Tests for the MCP tools module. -These tests verify the MCP tool definitions for search and get_by_id operations. +These tests verify the MCP tool definitions for search, search_verbose, +and get_by_id operations. """ from unittest.mock import AsyncMock, MagicMock import pytest -from lean_explore.mcp.tools import _get_backend_from_context, get_by_id, search +from lean_explore.mcp.tools import ( + _get_backend_from_context, + get_by_id, + search, + search_verbose, +) from lean_explore.models import SearchResponse, SearchResult @@ -39,39 +45,65 @@ async def test_get_backend_not_available(self): await _get_backend_from_context(mock_ctx) +def _make_search_response( + query: str = "test query", + informalization: str | None = "**Test Title.** A test informalization.", +) -> SearchResponse: + """Create a SearchResponse with a single result for testing. + + Args: + query: The query string for the response. + informalization: The informalization text for the result. + + Returns: + A SearchResponse with one result. + """ + return SearchResponse( + query=query, + results=[ + SearchResult( + id=1, + name="Test.result", + module="Test.Module", + docstring="A test result", + source_text="def test := 1", + source_link="https://example.com", + dependencies=None, + informalization=informalization, + ) + ], + count=1, + processing_time_ms=42, + ) + + +def _make_mock_context(backend: MagicMock) -> MagicMock: + """Create a mock MCP context wrapping a backend service. + + Args: + backend: The mock backend service. + + Returns: + A mock MCP context with the backend attached. + """ + mock_app_context = MagicMock() + mock_app_context.backend_service = backend + + mock_ctx = MagicMock() + mock_ctx.request_context.lifespan_context = mock_app_context + return mock_ctx + + class TestSearchTool: - """Tests for the search MCP tool.""" + """Tests for the search MCP tool (slim results).""" @pytest.fixture def mock_context_with_backend(self): """Create a mock MCP context with a backend service.""" mock_backend = MagicMock() - mock_backend.search = AsyncMock( - return_value=SearchResponse( - query="test query", - results=[ - SearchResult( - id=1, - name="Test.result", - module="Test.Module", - docstring="A test result", - source_text="def test := 1", - source_link="https://example.com", - dependencies=None, - informalization="Test informalization", - ) - ], - count=1, - processing_time_ms=42, - ) - ) - - mock_app_context = MagicMock() - mock_app_context.backend_service = mock_backend - - mock_ctx = MagicMock() - mock_ctx.request_context.lifespan_context = mock_app_context + mock_backend.search = AsyncMock(return_value=_make_search_response()) + mock_ctx = _make_mock_context(mock_backend) return mock_ctx, mock_backend async def test_search_calls_backend(self, mock_context_with_backend): @@ -84,16 +116,56 @@ async def test_search_calls_backend(self, mock_context_with_backend): query="test query", limit=10, rerank_top=50, packages=None ) - async def test_search_returns_dict(self, mock_context_with_backend): - """Test that search returns a dictionary response.""" + async def test_search_returns_slim_dict(self, mock_context_with_backend): + """Test that search returns a slim dict with only id, name, description.""" mock_ctx, _ = mock_context_with_backend result = await search(mock_ctx, query="test query", limit=10) assert isinstance(result, dict) - assert "results" in result - assert "query" in result assert result["query"] == "test query" + assert result["count"] == 1 + + # Verify slim format: only id, name, description + search_result = result["results"][0] + assert search_result["id"] == 1 + assert search_result["name"] == "Test.result" + assert search_result["description"] == "Test Title." + + # Verify full fields are NOT present + assert "module" not in search_result + assert "source_text" not in search_result + assert "source_link" not in search_result + assert "docstring" not in search_result + assert "dependencies" not in search_result + assert "informalization" not in search_result + + async def test_search_extracts_bold_description(self): + """Test that search extracts the bold header from informalization.""" + response = _make_search_response( + informalization="**Continuous Map Between Topological Spaces.** " + "A function that preserves the topology." + ) + mock_backend = MagicMock() + mock_backend.search = AsyncMock(return_value=response) + mock_ctx = _make_mock_context(mock_backend) + + result = await search(mock_ctx, query="continuous") + + description = result["results"][0]["description"] + assert description == "Continuous Map Between Topological Spaces." + + async def test_search_handles_no_informalization(self): + """Test that search handles results without informalization.""" + response = _make_search_response(informalization=None) + mock_backend = MagicMock() + mock_backend.search = AsyncMock(return_value=response) + mock_ctx = _make_mock_context(mock_backend) + + result = await search(mock_ctx, query="test") + + # description should be excluded from output (exclude_none=True) + assert "description" not in result["results"][0] async def test_search_default_limit(self, mock_context_with_backend): """Test search with default limit and rerank_top parameters.""" @@ -119,11 +191,7 @@ async def test_search_backend_without_method(self): """Test error when backend lacks search method.""" mock_backend = MagicMock(spec=[]) # No methods - mock_app_context = MagicMock() - mock_app_context.backend_service = mock_backend - - mock_ctx = MagicMock() - mock_ctx.request_context.lifespan_context = mock_app_context + mock_ctx = _make_mock_context(mock_backend) with pytest.raises(RuntimeError, match="Search functionality not available"): await search(mock_ctx, query="test") @@ -138,16 +206,103 @@ async def test_search_with_sync_backend(self): ) mock_backend = MagicMock() - # Make search a regular function, not async mock_backend.search = MagicMock(return_value=mock_response) + mock_ctx = _make_mock_context(mock_backend) - mock_app_context = MagicMock() - mock_app_context.backend_service = mock_backend + result = await search(mock_ctx, query="test") - mock_ctx = MagicMock() - mock_ctx.request_context.lifespan_context = mock_app_context + mock_backend.search.assert_called_once() + assert result["count"] == 0 - result = await search(mock_ctx, query="test") + +class TestSearchVerboseTool: + """Tests for the search_verbose MCP tool (full results).""" + + @pytest.fixture + def mock_context_with_backend(self): + """Create a mock MCP context with a backend service.""" + mock_backend = MagicMock() + mock_backend.search = AsyncMock(return_value=_make_search_response()) + + mock_ctx = _make_mock_context(mock_backend) + return mock_ctx, mock_backend + + async def test_search_verbose_calls_backend(self, mock_context_with_backend): + """Test that search_verbose calls the backend search method.""" + mock_ctx, mock_backend = mock_context_with_backend + + await search_verbose(mock_ctx, query="test query", limit=10) + + mock_backend.search.assert_called_once_with( + query="test query", limit=10, rerank_top=50, packages=None + ) + + async def test_search_verbose_returns_full_dict(self, mock_context_with_backend): + """Test that search_verbose returns all fields.""" + mock_ctx, _ = mock_context_with_backend + + result = await search_verbose(mock_ctx, query="test query", limit=10) + + assert isinstance(result, dict) + assert result["query"] == "test query" + assert result["count"] == 1 + + # Verify full fields are present + search_result = result["results"][0] + assert search_result["id"] == 1 + assert search_result["name"] == "Test.result" + assert search_result["module"] == "Test.Module" + assert search_result["source_text"] == "def test := 1" + assert search_result["source_link"] == "https://example.com" + assert search_result["docstring"] == "A test result" + assert ( + search_result["informalization"] + == "**Test Title.** A test informalization." + ) + + async def test_search_verbose_default_limit(self, mock_context_with_backend): + """Test search_verbose with default limit and rerank_top parameters.""" + mock_ctx, mock_backend = mock_context_with_backend + + await search_verbose(mock_ctx, query="test") + + mock_backend.search.assert_called_once_with( + query="test", limit=10, rerank_top=50, packages=None + ) + + async def test_search_verbose_with_packages_filter(self, mock_context_with_backend): + """Test search_verbose with packages filter.""" + mock_ctx, mock_backend = mock_context_with_backend + + await search_verbose(mock_ctx, query="test", packages=["Mathlib", "Std"]) + + mock_backend.search.assert_called_once_with( + query="test", limit=10, rerank_top=50, packages=["Mathlib", "Std"] + ) + + async def test_search_verbose_backend_without_method(self): + """Test error when backend lacks search method.""" + mock_backend = MagicMock(spec=[]) # No methods + + mock_ctx = _make_mock_context(mock_backend) + + with pytest.raises(RuntimeError, match="Search functionality not available"): + await search_verbose(mock_ctx, query="test") + + async def test_search_verbose_with_sync_backend(self): + """Test search_verbose with a synchronous backend.""" + mock_response = SearchResponse( + query="test", + results=[], + count=0, + processing_time_ms=5, + ) + + mock_backend = MagicMock() + mock_backend.search = MagicMock(return_value=mock_response) + mock_ctx = _make_mock_context(mock_backend) + + result = await search_verbose(mock_ctx, query="test") mock_backend.search.assert_called_once() assert result["count"] == 0 @@ -173,11 +328,7 @@ def mock_context_with_backend(self): mock_backend = MagicMock() mock_backend.get_by_id = AsyncMock(return_value=mock_result) - mock_app_context = MagicMock() - mock_app_context.backend_service = mock_backend - - mock_ctx = MagicMock() - mock_ctx.request_context.lifespan_context = mock_app_context + mock_ctx = _make_mock_context(mock_backend) return mock_ctx, mock_backend, mock_result @@ -204,11 +355,7 @@ async def test_get_by_id_not_found(self): mock_backend = MagicMock() mock_backend.get_by_id = AsyncMock(return_value=None) - mock_app_context = MagicMock() - mock_app_context.backend_service = mock_backend - - mock_ctx = MagicMock() - mock_ctx.request_context.lifespan_context = mock_app_context + mock_ctx = _make_mock_context(mock_backend) result = await get_by_id(mock_ctx, declaration_id=99999) @@ -218,11 +365,7 @@ async def test_get_by_id_backend_without_method(self): """Test error when backend lacks get_by_id method.""" mock_backend = MagicMock(spec=[]) # No methods - mock_app_context = MagicMock() - mock_app_context.backend_service = mock_backend - - mock_ctx = MagicMock() - mock_ctx.request_context.lifespan_context = mock_app_context + mock_ctx = _make_mock_context(mock_backend) with pytest.raises(RuntimeError, match="Get by ID functionality not available"): await get_by_id(mock_ctx, declaration_id=1) @@ -241,14 +384,8 @@ async def test_get_by_id_with_sync_backend(self): ) mock_backend = MagicMock() - # Make get_by_id a regular function, not async mock_backend.get_by_id = MagicMock(return_value=mock_result) - - mock_app_context = MagicMock() - mock_app_context.backend_service = mock_backend - - mock_ctx = MagicMock() - mock_ctx.request_context.lifespan_context = mock_app_context + mock_ctx = _make_mock_context(mock_backend) result = await get_by_id(mock_ctx, declaration_id=1) diff --git a/tests/models/search_types_test.py b/tests/models/search_types_test.py new file mode 100644 index 0000000..85267b9 --- /dev/null +++ b/tests/models/search_types_test.py @@ -0,0 +1,137 @@ +"""Tests for search type models and utility functions.""" + +from lean_explore.models.search_types import ( + SearchResultSummary, + SearchSummaryResponse, + extract_bold_description, +) + + +class TestExtractBoldDescription: + """Tests for extract_bold_description utility function.""" + + def test_extracts_bold_header(self): + """Test extraction from a standard informalization.""" + text = "**Group Homomorphism.** A function that preserves group structure." + assert extract_bold_description(text) == "Group Homomorphism." + + def test_extracts_bold_header_without_trailing_period(self): + """Test extraction when bold header has no trailing period.""" + text = "**Linear Map** A linear transformation between vector spaces." + assert extract_bold_description(text) == "Linear Map" + + def test_returns_none_for_none_input(self): + """Test that None input returns None.""" + assert extract_bold_description(None) is None + + def test_returns_none_for_empty_string(self): + """Test that empty string returns None.""" + assert extract_bold_description("") is None + + def test_returns_none_when_no_bold_markers(self): + """Test that text without bold markers returns None.""" + text = "A plain description with no bold header." + assert extract_bold_description(text) is None + + def test_extracts_first_bold_section_only(self): + """Test that only the first bold section is extracted.""" + text = "**First Bold.** Some text **Second Bold.** More text." + assert extract_bold_description(text) == "First Bold." + + def test_handles_bold_with_special_characters(self): + """Test extraction with special characters in bold text.""" + text = "**Nat.add (Addition of Natural Numbers).** Adds two natural numbers." + assert ( + extract_bold_description(text) == "Nat.add (Addition of Natural Numbers)." + ) + + +class TestSearchResultSummary: + """Tests for SearchResultSummary model.""" + + def test_create_with_description(self): + """Test creating a summary with all fields.""" + summary = SearchResultSummary( + id=1, + name="Nat.add", + description="Addition of Natural Numbers.", + ) + assert summary.id == 1 + assert summary.name == "Nat.add" + assert summary.description == "Addition of Natural Numbers." + + def test_create_with_none_description(self): + """Test creating a summary with no description.""" + summary = SearchResultSummary( + id=42, + name="List.map", + description=None, + ) + assert summary.id == 42 + assert summary.name == "List.map" + assert summary.description is None + + def test_model_dump(self): + """Test serialization to dictionary.""" + summary = SearchResultSummary( + id=1, + name="Nat.add", + description="Addition.", + ) + data = summary.model_dump() + assert data == {"id": 1, "name": "Nat.add", "description": "Addition."} + + def test_model_dump_excludes_none(self): + """Test that model_dump with exclude_none omits None fields.""" + summary = SearchResultSummary(id=1, name="Nat.add", description=None) + data = summary.model_dump(exclude_none=True) + assert data == {"id": 1, "name": "Nat.add"} + + +class TestSearchSummaryResponse: + """Tests for SearchSummaryResponse model.""" + + def test_create_response(self): + """Test creating a summary response with results.""" + results = [ + SearchResultSummary(id=1, name="Nat.add", description="Addition."), + SearchResultSummary(id=2, name="Nat.mul", description="Multiplication."), + ] + response = SearchSummaryResponse( + query="natural number", + results=results, + count=2, + processing_time_ms=150, + ) + assert response.query == "natural number" + assert len(response.results) == 2 + assert response.count == 2 + assert response.processing_time_ms == 150 + + def test_create_response_without_processing_time(self): + """Test that processing_time_ms defaults to None.""" + response = SearchSummaryResponse( + query="test", + results=[], + count=0, + ) + assert response.processing_time_ms is None + + def test_model_dump(self): + """Test serialization of full response.""" + results = [ + SearchResultSummary(id=1, name="Nat.add", description="Addition."), + ] + response = SearchSummaryResponse( + query="add", + results=results, + count=1, + processing_time_ms=100, + ) + data = response.model_dump(exclude_none=True) + assert data == { + "query": "add", + "results": [{"id": 1, "name": "Nat.add", "description": "Addition."}], + "count": 1, + "processing_time_ms": 100, + }