From c55f116bed2c7663e651b65628c7de0066a797e7 Mon Sep 17 00:00:00 2001
From: Daniel Song
Date: Fri, 23 Jan 2026 12:01:31 -0800
Subject: [PATCH] feat: add MCP tools exposing review agent functionality

Implements 5 MCP tools: review_pr, check_pr_size, check_pr_lint,
get_review_history, and get_cost_summary. Uses lazy imports to avoid
circular dependencies.

Closes #14

Co-Authored-By: Claude Opus 4.5
---
 src/pr_review_agent/mcp/tools.py | 333 +++++++++++++++++++++++++++++++
 tests/test_mcp_tools.py          | 138 +++++++++++++
 2 files changed, 471 insertions(+)
 create mode 100644 src/pr_review_agent/mcp/tools.py
 create mode 100644 tests/test_mcp_tools.py

diff --git a/src/pr_review_agent/mcp/tools.py b/src/pr_review_agent/mcp/tools.py
new file mode 100644
index 0000000..c4e5f3f
--- /dev/null
+++ b/src/pr_review_agent/mcp/tools.py
@@ -0,0 +1,333 @@
+"""MCP tools exposing review agent functionality."""
+
+import os
+from pathlib import Path
+
+from mcp.types import TextContent, Tool
+
+from pr_review_agent.mcp.server import get_anthropic_key, get_github_token, server
+
+
+@server.list_tools()
+async def list_tools() -> list[Tool]:
+    """Return available MCP tools."""
+    return [
+        Tool(
+            name="review_pr",
+            description="Run the full review pipeline on a GitHub PR",
+            inputSchema={
+                "type": "object",
+                "properties": {
+                    "repo": {
+                        "type": "string",
+                        "description": "GitHub repo in owner/repo format",
+                    },
+                    "pr_number": {
+                        "type": "integer",
+                        "description": "PR number to review",
+                    },
+                },
+                "required": ["repo", "pr_number"],
+            },
+        ),
+        Tool(
+            name="check_pr_size",
+            description="Run only the size gate on a PR",
+            inputSchema={
+                "type": "object",
+                "properties": {
+                    "repo": {
+                        "type": "string",
+                        "description": "GitHub repo in owner/repo format",
+                    },
+                    "pr_number": {
+                        "type": "integer",
+                        "description": "PR number to check",
+                    },
+                },
+                "required": ["repo", "pr_number"],
+            },
+        ),
+        Tool(
+            name="check_pr_lint",
+            description="Run only the lint gate on a PR",
+            inputSchema={
+                "type": "object",
+                "properties": {
+                    "repo": {
+                        "type": "string",
+                        "description": "GitHub repo in owner/repo format",
+                    },
+                    "pr_number": {
+                        "type": "integer",
+                        "description": "PR number to check",
+                    },
+                },
+                "required": ["repo", "pr_number"],
+            },
+        ),
+        Tool(
+            name="get_review_history",
+            description="Get past reviews for a repo or specific file",
+            inputSchema={
+                "type": "object",
+                "properties": {
+                    "repo": {
+                        "type": "string",
+                        "description": "GitHub repo in owner/repo format",
+                    },
+                    "file_path": {
+                        "type": "string",
+                        "description": "Optional file path to filter by",
+                    },
+                },
+                "required": ["repo"],
+            },
+        ),
+        Tool(
+            name="get_cost_summary",
+            description="Get cost metrics for reviews",
+            inputSchema={
+                "type": "object",
+                "properties": {
+                    "repo": {
+                        "type": "string",
+                        "description": "Optional repo filter (owner/repo)",
+                    },
+                    "days": {
+                        "type": "integer",
+                        "description": "Number of days to look back (default 30)",
+                    },
+                },
+                "required": [],
+            },
+        ),
+    ]
+
+
+@server.call_tool()
+async def call_tool(name: str, arguments: dict) -> list[TextContent]:
+    """Handle tool calls."""
+    if name == "review_pr":
+        return await _review_pr(arguments)
+    elif name == "check_pr_size":
+        return await _check_pr_size(arguments)
+    elif name == "check_pr_lint":
+        return await _check_pr_lint(arguments)
+    elif name == "get_review_history":
+        return await _get_review_history(arguments)
+    elif name == "get_cost_summary":
+        return await _get_cost_summary(arguments)
+    else:
+        return [TextContent(type="text", text=f"Unknown tool: {name}")]
+
+
+async def _review_pr(args: dict) -> list[TextContent]:
+    """Run full review pipeline."""
+    from pr_review_agent.config import load_config
+    from pr_review_agent.gates.lint_gate import run_lint
+    from pr_review_agent.gates.size_gate import check_size
+    from pr_review_agent.github_client import GitHubClient
+    from pr_review_agent.review.confidence import calculate_confidence
+    from pr_review_agent.review.llm_reviewer import LLMReviewer
+
+    repo = args["repo"]
+    pr_number = args["pr_number"]
+    owner, repo_name = repo.split("/")
+
+    config = load_config(Path(".ai-review.yaml"))
+    github = GitHubClient(get_github_token())
+    pr = github.fetch_pr(owner, repo_name, pr_number)
+
+    # Size gate
+    size_result = check_size(pr, config)
+    if not size_result.passed:
+        return [TextContent(
+            type="text",
+            text=f"Size gate failed: {size_result.reason}",
+        )]
+
+    # Lint gate
+    lint_result = run_lint(pr.files_changed, config)
+    if not lint_result.passed:
+        return [TextContent(
+            type="text",
+            text=f"Lint gate failed: {lint_result.error_count} errors",
+        )]
+
+    # LLM review
+    reviewer = LLMReviewer(get_anthropic_key())
+    review = reviewer.review(
+        diff=pr.diff,
+        pr_description=pr.description,
+        model=config.llm.default_model,
+        config=config,
+    )
+
+    confidence = calculate_confidence(review, pr, config)
+
+    lines = [
+        f"## Review: {repo}#{pr_number}",
+        f"**Confidence:** {confidence.score:.0%} ({confidence.level})",
+        f"**Summary:** {review.summary}",
+        "",
+    ]
+
+    if review.issues:
+        lines.append(f"**Issues ({len(review.issues)}):**")
+        for issue in review.issues:
+            loc = f"{issue.file}:{issue.line}" if issue.line else issue.file
+            lines.append(f"- [{issue.severity}] {loc}: {issue.description}")
+
+    lines.append(f"\nModel: {review.model} | Cost: ${review.cost_usd:.4f}")
+
+    return [TextContent(type="text", text="\n".join(lines))]
+
+
+async def _check_pr_size(args: dict) -> list[TextContent]:
+    """Run size gate only."""
+    from pr_review_agent.config import load_config
+    from pr_review_agent.gates.size_gate import check_size
+    from pr_review_agent.github_client import GitHubClient
+
+    repo = args["repo"]
+    pr_number = args["pr_number"]
+    owner, repo_name = repo.split("/")
+
+    config = load_config(Path(".ai-review.yaml"))
+    github = GitHubClient(get_github_token())
+    pr = github.fetch_pr(owner, repo_name, pr_number)
+
+    result = check_size(pr, config)
+    status = "PASSED" if result.passed else "FAILED"
+    text = (
+        f"Size gate: {status}\n"
+        f"Lines: {pr.lines_added + pr.lines_removed} "
+        f"(limit: {config.limits.max_lines_changed})\n"
+        f"Files: {len(pr.files_changed)} "
+        f"(limit: {config.limits.max_files_changed})"
+    )
+    if not result.passed:
+        text += f"\nReason: {result.reason}"
+
+    return [TextContent(type="text", text=text)]
+
+
+async def _check_pr_lint(args: dict) -> list[TextContent]:
+    """Run lint gate only."""
+    from pr_review_agent.config import load_config
+    from pr_review_agent.gates.lint_gate import run_lint
+    from pr_review_agent.github_client import GitHubClient
+
+    repo = args["repo"]
+    pr_number = args["pr_number"]
+    owner, repo_name = repo.split("/")
+
+    config = load_config(Path(".ai-review.yaml"))
+    github = GitHubClient(get_github_token())
+    pr = github.fetch_pr(owner, repo_name, pr_number)
+
+    result = run_lint(pr.files_changed, config)
+    status = "PASSED" if result.passed else "FAILED"
+    text = f"Lint gate: {status}\nErrors: {result.error_count}"
+    if not result.passed:
+        text += f"\nThreshold: {config.linting.fail_threshold}"
+
+    return [TextContent(type="text", text=text)]
+
+
+async def _get_review_history(args: dict) -> list[TextContent]:
+    """Get review history from Supabase."""
+    supabase_url = os.environ.get("SUPABASE_URL")
+    supabase_key = os.environ.get("SUPABASE_KEY")
+
+    if not supabase_url or not supabase_key:
+        return [TextContent(
+            type="text",
+            text="SUPABASE_URL and SUPABASE_KEY required for review history",
+        )]
+
+    from supabase import create_client
+
+    client = create_client(supabase_url, supabase_key)
+    repo = args["repo"]
+    owner, repo_name = repo.split("/")
+
+    query = (
+        client.table("review_events")
+        .select("*")
+        .eq("repo_owner", owner)
+        .eq("repo_name", repo_name)
+        .order("created_at", desc=True)
+        .limit(20)
+    )
+
+    result = query.execute()
+    reviews = result.data if result.data else []
+
+    if not reviews:
+        return [TextContent(type="text", text=f"No reviews found for {repo}")]
+
+    lines = [f"## Review History: {repo} ({len(reviews)} recent)"]
+    for r in reviews:
+        confidence = r.get("confidence_score")
+        conf_str = f"{confidence:.0%}" if confidence is not None else "N/A"
+        lines.append(
+            f"- PR#{r['pr_number']}: {r.get('outcome', 'unknown')} "
+            f"(confidence: {conf_str}, cost: ${(r.get('cost_usd') or 0):.4f})"
+        )
+
+    return [TextContent(type="text", text="\n".join(lines))]
+
+
+async def _get_cost_summary(args: dict) -> list[TextContent]:
+    """Get cost summary from Supabase."""
+    supabase_url = os.environ.get("SUPABASE_URL")
+    supabase_key = os.environ.get("SUPABASE_KEY")
+
+    if not supabase_url or not supabase_key:
+        return [TextContent(
+            type="text",
+            text="SUPABASE_URL and SUPABASE_KEY required for cost summary",
+        )]
+
+    from datetime import UTC, datetime
+
+    from supabase import create_client
+
+    client = create_client(supabase_url, supabase_key)
+    days = args.get("days", 30)
+
+    since = datetime.now(UTC)
+    since = since.replace(hour=0, minute=0, second=0, microsecond=0)
+    from datetime import timedelta
+
+    since = since - timedelta(days=days)
+
+    query = (
+        client.table("review_events")
+        .select("cost_usd,model_used,repo_owner,repo_name")
+        .eq("llm_called", True)
+        .gte("created_at", since.isoformat())
+    )
+
+    repo = args.get("repo")
+    if repo:
+        owner, repo_name = repo.split("/")
+        query = query.eq("repo_owner", owner).eq("repo_name", repo_name)
+
+    result = query.execute()
+    reviews = result.data if result.data else []
+
+    total_cost = sum(r.get("cost_usd", 0) or 0 for r in reviews)
+    review_count = len(reviews)
+    avg_cost = total_cost / review_count if review_count > 0 else 0
+
+    lines = [
+        f"## Cost Summary (last {days} days)",
+        f"Total spend: ${total_cost:.2f}",
+        f"Reviews: {review_count}",
+        f"Avg cost/review: ${avg_cost:.4f}",
+    ]
+
+    return [TextContent(type="text", text="\n".join(lines))]
diff --git a/tests/test_mcp_tools.py b/tests/test_mcp_tools.py
new file mode 100644
index 0000000..715a3f3
--- /dev/null
+++ b/tests/test_mcp_tools.py
@@ -0,0 +1,138 @@
+"""Tests for MCP tools."""
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from pr_review_agent.mcp.tools import (
+    _check_pr_lint,
+    _check_pr_size,
+    _get_cost_summary,
+    _get_review_history,
+    _review_pr,
+    list_tools,
+)
+
+
+@pytest.mark.asyncio
+async def test_list_tools_returns_all():
+    """list_tools returns all 5 tools."""
+    tools = await list_tools()
+    assert len(tools) == 5
+    names = {t.name for t in tools}
+    assert names == {
+        "review_pr",
+        "check_pr_size",
+        "check_pr_lint",
+        "get_review_history",
+        "get_cost_summary",
+    }
+
+
+@pytest.mark.asyncio
+async def test_list_tools_have_schemas():
+    """Each tool has an input schema."""
+    tools = await list_tools()
+    for tool in tools:
+        assert tool.inputSchema is not None
+        assert tool.inputSchema["type"] == "object"
+
+
+@pytest.mark.asyncio
+@patch("pr_review_agent.mcp.tools.get_github_token", return_value="token")
+@patch("pr_review_agent.mcp.tools.get_anthropic_key", return_value="key")
+async def test_review_pr_size_gate_fails(mock_key, mock_token):
+    """review_pr returns early when size gate fails."""
+    with (
+        patch("pr_review_agent.github_client.GitHubClient") as mock_gh,
+        patch("pr_review_agent.config.load_config"),
+        patch("pr_review_agent.gates.size_gate.check_size") as mock_size,
+    ):
+        mock_pr = MagicMock()
+        mock_gh.return_value.fetch_pr.return_value = mock_pr
+        mock_size.return_value = MagicMock(passed=False, reason="Too large")
+
+        result = await _review_pr({"repo": "org/repo", "pr_number": 1})
+
+    assert len(result) == 1
+    assert "Size gate failed" in result[0].text
+
+
+@pytest.mark.asyncio
+@patch("pr_review_agent.mcp.tools.get_github_token", return_value="token")
+@patch("pr_review_agent.mcp.tools.get_anthropic_key", return_value="key")
+async def test_review_pr_lint_gate_fails(mock_key, mock_token):
+    """review_pr returns early when lint gate fails."""
+    with (
+        patch("pr_review_agent.github_client.GitHubClient") as mock_gh,
+        patch("pr_review_agent.config.load_config"),
+        patch("pr_review_agent.gates.size_gate.check_size") as mock_size,
+        patch("pr_review_agent.gates.lint_gate.run_lint") as mock_lint,
+    ):
+        mock_pr = MagicMock()
+        mock_gh.return_value.fetch_pr.return_value = mock_pr
+        mock_size.return_value = MagicMock(passed=True)
+        mock_lint.return_value = MagicMock(passed=False, error_count=5)
+
+        result = await _review_pr({"repo": "org/repo", "pr_number": 1})
+
+    assert "Lint gate failed" in result[0].text
+
+
+@pytest.mark.asyncio
+@patch("pr_review_agent.mcp.tools.get_github_token", return_value="token")
+async def test_check_pr_size_passed(mock_token):
+    """check_pr_size returns PASSED status."""
+    with (
+        patch("pr_review_agent.github_client.GitHubClient") as mock_gh,
+        patch("pr_review_agent.config.load_config") as mock_config,
+        patch("pr_review_agent.gates.size_gate.check_size") as mock_size,
+    ):
+        mock_pr = MagicMock(lines_added=50, lines_removed=10, files_changed=["a.py"])
+        mock_gh.return_value.fetch_pr.return_value = mock_pr
+        mock_size.return_value = MagicMock(passed=True)
+        mock_config.return_value = MagicMock(
+            limits=MagicMock(max_lines_changed=500, max_files_changed=20)
+        )
+
+        result = await _check_pr_size({"repo": "org/repo", "pr_number": 1})
+
+    assert "PASSED" in result[0].text
+
+
+@pytest.mark.asyncio
+@patch("pr_review_agent.mcp.tools.get_github_token", return_value="token")
+async def test_check_pr_lint_failed(mock_token):
+    """check_pr_lint returns FAILED with error count."""
+    with (
+        patch("pr_review_agent.github_client.GitHubClient") as mock_gh,
+        patch("pr_review_agent.config.load_config") as mock_config,
+        patch("pr_review_agent.gates.lint_gate.run_lint") as mock_lint,
+    ):
+        mock_pr = MagicMock(files_changed=["a.py"])
+        mock_gh.return_value.fetch_pr.return_value = mock_pr
+        mock_lint.return_value = MagicMock(passed=False, error_count=12)
+        mock_config.return_value = MagicMock(
+            linting=MagicMock(fail_threshold=5)
+        )
+
+        result = await _check_pr_lint({"repo": "org/repo", "pr_number": 1})
+
+    assert "FAILED" in result[0].text
+    assert "12" in result[0].text
+
+
+@pytest.mark.asyncio
+async def test_get_review_history_no_supabase():
+    """get_review_history returns error without credentials."""
+    with patch.dict("os.environ", {}, clear=True):
+        result = await _get_review_history({"repo": "org/repo"})
+    assert "SUPABASE_URL" in result[0].text
+
+
+@pytest.mark.asyncio
+async def test_get_cost_summary_no_supabase():
+    """get_cost_summary returns error without credentials."""
+    with patch.dict("os.environ", {}, clear=True):
+        result = await _get_cost_summary({})
+    assert "SUPABASE_URL" in result[0].text