Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/shared.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ on:
permissions:
contents: read

env:
COLUMNS: 150

jobs:
pre-commit:
runs-on: ubuntu-latest
Expand Down
32 changes: 17 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ from contextlib import asynccontextmanager
from dataclasses import dataclass

from mcp.server.fastmcp import Context, FastMCP
from mcp.server.session import ServerSession


# Mock database class for example
Expand Down Expand Up @@ -242,7 +243,7 @@ mcp = FastMCP("My App", lifespan=app_lifespan)

# Access type-safe lifespan context in tools
@mcp.tool()
def query_db(ctx: Context) -> str:
def query_db(ctx: Context[ServerSession, AppContext]) -> str:
"""Tool that uses initialized resources."""
db = ctx.request_context.lifespan_context.db
return db.query()
Expand Down Expand Up @@ -314,12 +315,13 @@ Tools can optionally receive a Context object by including a parameter with the
<!-- snippet-source examples/snippets/servers/tool_progress.py -->
```python
from mcp.server.fastmcp import Context, FastMCP
from mcp.server.session import ServerSession

mcp = FastMCP(name="Progress Example")


@mcp.tool()
async def long_running_task(task_name: str, ctx: Context, steps: int = 5) -> str:
async def long_running_task(task_name: str, ctx: Context[ServerSession, None], steps: int = 5) -> str:
"""Execute a task with progress updates."""
await ctx.info(f"Starting: {task_name}")

Expand Down Expand Up @@ -445,7 +447,7 @@ def get_user(user_id: str) -> UserProfile:

# Classes WITHOUT type hints cannot be used for structured output
class UntypedConfig:
def __init__(self, setting1, setting2):
def __init__(self, setting1, setting2): # type: ignore[reportMissingParameterType]
self.setting1 = setting1
self.setting2 = setting2

Expand Down Expand Up @@ -571,12 +573,13 @@ The Context object provides the following capabilities:
<!-- snippet-source examples/snippets/servers/tool_progress.py -->
```python
from mcp.server.fastmcp import Context, FastMCP
from mcp.server.session import ServerSession

mcp = FastMCP(name="Progress Example")


@mcp.tool()
async def long_running_task(task_name: str, ctx: Context, steps: int = 5) -> str:
async def long_running_task(task_name: str, ctx: Context[ServerSession, None], steps: int = 5) -> str:
"""Execute a task with progress updates."""
await ctx.info(f"Starting: {task_name}")

Expand Down Expand Up @@ -694,6 +697,7 @@ Request additional information from users. This example shows an Elicitation dur
from pydantic import BaseModel, Field

from mcp.server.fastmcp import Context, FastMCP
from mcp.server.session import ServerSession

mcp = FastMCP(name="Elicitation Example")

Expand All @@ -709,12 +713,7 @@ class BookingPreferences(BaseModel):


@mcp.tool()
async def book_table(
date: str,
time: str,
party_size: int,
ctx: Context,
) -> str:
async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerSession, None]) -> str:
"""Book a table with date availability check."""
# Check if date is available
if date == "2024-12-25":
Expand Down Expand Up @@ -750,13 +749,14 @@ Tools can interact with LLMs through sampling (generating text):
<!-- snippet-source examples/snippets/servers/sampling.py -->
```python
from mcp.server.fastmcp import Context, FastMCP
from mcp.server.session import ServerSession
from mcp.types import SamplingMessage, TextContent

mcp = FastMCP(name="Sampling Example")


@mcp.tool()
async def generate_poem(topic: str, ctx: Context) -> str:
async def generate_poem(topic: str, ctx: Context[ServerSession, None]) -> str:
"""Generate a poem using LLM sampling."""
prompt = f"Write a short poem about {topic}"

Expand Down Expand Up @@ -785,12 +785,13 @@ Tools can send logs and notifications through the context:
<!-- snippet-source examples/snippets/servers/notifications.py -->
```python
from mcp.server.fastmcp import Context, FastMCP
from mcp.server.session import ServerSession

mcp = FastMCP(name="Notifications Example")


@mcp.tool()
async def process_data(data: str, ctx: Context) -> str:
async def process_data(data: str, ctx: Context[ServerSession, None]) -> str:
"""Process data with logging."""
# Different log levels
await ctx.debug(f"Debug: Processing '{data}'")
Expand Down Expand Up @@ -1244,6 +1245,7 @@ Run from the repository root:

from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any

import mcp.server.stdio
import mcp.types as types
Expand Down Expand Up @@ -1272,7 +1274,7 @@ class Database:


@asynccontextmanager
async def server_lifespan(_server: Server) -> AsyncIterator[dict]:
async def server_lifespan(_server: Server) -> AsyncIterator[dict[str, Any]]:
"""Manage server startup and shutdown lifecycle."""
# Initialize resources on startup
db = await Database.connect()
Expand Down Expand Up @@ -1304,7 +1306,7 @@ async def handle_list_tools() -> list[types.Tool]:


@server.call_tool()
async def query_db(name: str, arguments: dict) -> list[types.TextContent]:
async def query_db(name: str, arguments: dict[str, Any]) -> list[types.TextContent]:
"""Handle database query tool call."""
if name != "query_db":
raise ValueError(f"Unknown tool: {name}")
Expand Down Expand Up @@ -1558,7 +1560,7 @@ server_params = StdioServerParameters(

# Optional: create a sampling callback
async def handle_sampling_message(
context: RequestContext, params: types.CreateMessageRequestParams
context: RequestContext[ClientSession, None], params: types.CreateMessageRequestParams
) -> types.CreateMessageResult:
print(f"Sampling request: {params.messages}")
return types.CreateMessageResult(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,7 @@ async def _default_redirect_handler(authorization_url: str) -> None:
# Create OAuth authentication handler using the new interface
oauth_auth = OAuthClientProvider(
server_url=self.server_url.replace("/mcp", ""),
client_metadata=OAuthClientMetadata.model_validate(
client_metadata_dict
),
client_metadata=OAuthClientMetadata.model_validate(client_metadata_dict),
storage=InMemoryTokenStorage(),
redirect_handler=_default_redirect_handler,
callback_handler=callback_handler,
Expand Down Expand Up @@ -322,9 +320,7 @@ async def interactive_loop(self):
await self.call_tool(tool_name, arguments)

else:
print(
"❌ Unknown command. Try 'list', 'call <tool_name>', or 'quit'"
)
print("❌ Unknown command. Try 'list', 'call <tool_name>', or 'quit'")

except KeyboardInterrupt:
print("\n\n👋 Goodbye!")
Expand Down
2 changes: 1 addition & 1 deletion examples/clients/simple-auth-client/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ select = ["E", "F", "I"]
ignore = []

[tool.ruff]
line-length = 88
line-length = 120
target-version = "py310"

[tool.uv]
Expand Down
57 changes: 13 additions & 44 deletions examples/clients/simple-chatbot/mcp_simple_chatbot/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@
from mcp.client.stdio import stdio_client

# Configure logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")


class Configuration:
Expand Down Expand Up @@ -75,29 +73,19 @@ def __init__(self, name: str, config: dict[str, Any]) -> None:

async def initialize(self) -> None:
"""Initialize the server connection."""
command = (
shutil.which("npx")
if self.config["command"] == "npx"
else self.config["command"]
)
command = shutil.which("npx") if self.config["command"] == "npx" else self.config["command"]
if command is None:
raise ValueError("The command must be a valid string and cannot be None.")

server_params = StdioServerParameters(
command=command,
args=self.config["args"],
env={**os.environ, **self.config["env"]}
if self.config.get("env")
else None,
env={**os.environ, **self.config["env"]} if self.config.get("env") else None,
)
try:
stdio_transport = await self.exit_stack.enter_async_context(
stdio_client(server_params)
)
stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
read, write = stdio_transport
session = await self.exit_stack.enter_async_context(
ClientSession(read, write)
)
session = await self.exit_stack.enter_async_context(ClientSession(read, write))
await session.initialize()
self.session = session
except Exception as e:
Expand All @@ -122,10 +110,7 @@ async def list_tools(self) -> list[Any]:

for item in tools_response:
if isinstance(item, tuple) and item[0] == "tools":
tools.extend(
Tool(tool.name, tool.description, tool.inputSchema, tool.title)
for tool in item[1]
)
tools.extend(Tool(tool.name, tool.description, tool.inputSchema, tool.title) for tool in item[1])

return tools

Expand Down Expand Up @@ -164,9 +149,7 @@ async def execute_tool(

except Exception as e:
attempt += 1
logging.warning(
f"Error executing tool: {e}. Attempt {attempt} of {retries}."
)
logging.warning(f"Error executing tool: {e}. Attempt {attempt} of {retries}.")
if attempt < retries:
logging.info(f"Retrying in {delay} seconds...")
await asyncio.sleep(delay)
Expand Down Expand Up @@ -209,9 +192,7 @@ def format_for_llm(self) -> str:
args_desc = []
if "properties" in self.input_schema:
for param_name, param_info in self.input_schema["properties"].items():
arg_desc = (
f"- {param_name}: {param_info.get('description', 'No description')}"
)
arg_desc = f"- {param_name}: {param_info.get('description', 'No description')}"
if param_name in self.input_schema.get("required", []):
arg_desc += " (required)"
args_desc.append(arg_desc)
Expand Down Expand Up @@ -281,10 +262,7 @@ def get_response(self, messages: list[dict[str, str]]) -> str:
logging.error(f"Status code: {status_code}")
logging.error(f"Response details: {e.response.text}")

return (
f"I encountered an error: {error_message}. "
"Please try again or rephrase your request."
)
return f"I encountered an error: {error_message}. Please try again or rephrase your request."


class ChatSession:
Expand Down Expand Up @@ -323,17 +301,13 @@ async def process_llm_response(self, llm_response: str) -> str:
tools = await server.list_tools()
if any(tool.name == tool_call["tool"] for tool in tools):
try:
result = await server.execute_tool(
tool_call["tool"], tool_call["arguments"]
)
result = await server.execute_tool(tool_call["tool"], tool_call["arguments"])

if isinstance(result, dict) and "progress" in result:
progress = result["progress"]
total = result["total"]
percentage = (progress / total) * 100
logging.info(
f"Progress: {progress}/{total} ({percentage:.1f}%)"
)
logging.info(f"Progress: {progress}/{total} ({percentage:.1f}%)")

return f"Tool execution result: {result}"
except Exception as e:
Expand Down Expand Up @@ -408,9 +382,7 @@ async def start(self) -> None:

final_response = self.llm_client.get_response(messages)
logging.info("\nFinal response: %s", final_response)
messages.append(
{"role": "assistant", "content": final_response}
)
messages.append({"role": "assistant", "content": final_response})
else:
messages.append({"role": "assistant", "content": llm_response})

Expand All @@ -426,10 +398,7 @@ async def main() -> None:
"""Initialize and run the chat session."""
config = Configuration()
server_config = config.load_config("servers_config.json")
servers = [
Server(name, srv_config)
for name, srv_config in server_config["mcpServers"].items()
]
servers = [Server(name, srv_config) for name, srv_config in server_config["mcpServers"].items()]
llm_client = LLMClient(config.llm_api_key)
chat_session = ChatSession(servers, llm_client)
await chat_session.start()
Expand Down
2 changes: 1 addition & 1 deletion examples/clients/simple-chatbot/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ select = ["E", "F", "I"]
ignore = []

[tool.ruff]
line-length = 88
line-length = 120
target-version = "py310"

[tool.uv]
Expand Down
3 changes: 2 additions & 1 deletion examples/servers/simple-auth/mcp_simple_auth/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ class ResourceServerSettings(BaseSettings):
# RFC 8707 resource validation
oauth_strict: bool = False

def __init__(self, **data):
# TODO(Marcelo): Is this even needed? I didn't have time to check.
def __init__(self, **data: Any):
"""Initialize settings with values from environment variables."""
super().__init__(**data)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class SimpleAuthSettings(BaseSettings):
mcp_scope: str = "user"


class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
class SimpleOAuthProvider(OAuthAuthorizationServerProvider[AuthorizationCode, RefreshToken, AccessToken]):
"""
Simple OAuth provider for demo purposes.
Expand Down Expand Up @@ -116,7 +116,7 @@ async def get_login_page(self, state: str) -> HTMLResponse:
<p>This is a simplified authentication demo. Use the demo credentials below:</p>
<p><strong>Username:</strong> demo_user<br>
<strong>Password:</strong> demo_password</p>
<form action="{self.server_url.rstrip("/")}/login/callback" method="post">
<input type="hidden" name="state" value="{state}">
<div class="form-group">
Expand Down Expand Up @@ -264,7 +264,8 @@ async def exchange_refresh_token(
"""Exchange refresh token - not supported in this example."""
raise NotImplementedError("Refresh tokens not supported")

async def revoke_token(self, token: str, token_type_hint: str | None = None) -> None:
# TODO(Marcelo): The type hint is wrong. We need to fix, and test to check if it works.
async def revoke_token(self, token: str, token_type_hint: str | None = None) -> None: # type: ignore
"""Revoke a token."""
if token in self.tokens:
del self.tokens[token]
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Example token verifier implementation using OAuth 2.0 Token Introspection (RFC 7662)."""

import logging
from typing import Any

from mcp.server.auth.provider import AccessToken, TokenVerifier
from mcp.shared.auth_utils import check_resource_allowed, resource_url_from_server_url
Expand Down Expand Up @@ -79,13 +80,13 @@ async def verify_token(self, token: str) -> AccessToken | None:
logger.warning(f"Token introspection failed: {e}")
return None

def _validate_resource(self, token_data: dict) -> bool:
def _validate_resource(self, token_data: dict[str, Any]) -> bool:
"""Validate token was issued for this resource server."""
if not self.server_url or not self.resource_url:
return False # Fail if strict validation requested but URLs missing

# Check 'aud' claim first (standard JWT audience)
aud = token_data.get("aud")
aud: list[str] | str | None = token_data.get("aud")
if isinstance(aud, list):
for audience in aud:
if self._is_valid_resource(audience):
Expand Down
Loading
Loading