Skip to content

Implement RFC 7523 JWT flows #1247

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 32 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
873bbe0
Implement client credentials auth flow
LucaButBoring Jun 23, 2025
7c65d76
Add example client and server for client_credentials flow
LucaButBoring Jun 23, 2025
d927bc0
Set issuer to AS URI for RFC8414 compliance in auth example
LucaButBoring Jun 24, 2025
20b5dfc
Merge RS/AS in client_credentials example
LucaButBoring Jun 24, 2025
1a5f104
Revert "Merge RS/AS in client_credentials example"
LucaButBoring Jun 24, 2025
fe548e5
Implement simplified RS+AS token mapping without DCR
LucaButBoring Jun 24, 2025
7e0dfaf
Update naming and docs for client credentials example
LucaButBoring Jun 24, 2025
4ab922c
Implement client_secret_post support in client credentials example
LucaButBoring Jun 24, 2025
e3a2b6d
Don't use client_credentials by default; fix test
LucaButBoring Jun 24, 2025
ed2a486
Merge branch 'main' of https://github.com/modelcontextprotocol/python…
LucaButBoring Jun 24, 2025
a8067e1
Add tests for client_credentials flow
LucaButBoring Jun 25, 2025
0be70c4
Merge branch 'main' into feat/client-credentials
LucaButBoring Jun 25, 2025
13b3478
Update function name in PRM unit tests
LucaButBoring Jun 25, 2025
aaf2cc7
Merge branch 'main' into feat/client-credentials
LucaButBoring Jun 26, 2025
efecc7d
Use client_metadata to determine if 2LO should be used
LucaButBoring Jun 26, 2025
f1d1591
Merge branch 'main' into feat/client-credentials
LucaButBoring Jun 30, 2025
06177d1
Merge branch 'main' into feat/client-credentials
LucaButBoring Jul 9, 2025
2a2f562
Merge branch 'main' into feat/client-credentials
LucaButBoring Jul 10, 2025
31eeb63
Fix Markdown formatting
LucaButBoring Jul 10, 2025
ac75345
Merge branch 'main' of https://github.com/modelcontextprotocol/python…
LucaButBoring Jul 17, 2025
c3c6725
Merge branch 'main' into feat/client-credentials
LucaButBoring Jul 24, 2025
4a4c007
Update auth tests to fix mock behavior
LucaButBoring Jul 24, 2025
6677894
Merge branch 'main' of https://github.com/modelcontextprotocol/python…
LucaButBoring Jul 28, 2025
fc8331c
Implement RFC 7523 authorization grant flow
LucaButBoring Jul 29, 2025
39758e2
Implement RFC 7523 Section 2.2 for client_credentials
yannj-fr Aug 1, 2025
36532fd
Add case for preconfigured assertion in RFC7523 S2.2 flow
LucaButBoring Aug 1, 2025
ed23997
Remove 2LO example for now
LucaButBoring Aug 1, 2025
e10f7c9
Remove 2LO in this branch, limit to RFC7523
LucaButBoring Aug 7, 2025
bcc5b39
Fix rfc8707 test error
LucaButBoring Aug 7, 2025
b90a6c2
Merge branch 'main' into feat/rfc7523
LucaButBoring Aug 7, 2025
ba3cd1e
Merge branch 'main' of https://github.com/modelcontextprotocol/python…
LucaButBoring Aug 12, 2025
27f38e2
Fix type error in FastMCP auth test
LucaButBoring Aug 12, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 1 addition & 10 deletions examples/clients/simple-auth-client/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,7 @@ classifiers = [
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
]
dependencies = [
"click>=8.0.0",
"mcp>=1.0.0",
]
dependencies = ["click>=8.0.0", "mcp"]

[project.scripts]
mcp-simple-auth-client = "mcp_simple_auth_client.main:cli"
Expand All @@ -44,9 +41,3 @@ target-version = "py310"

[tool.uv]
dev-dependencies = ["pyright>=1.1.379", "pytest>=8.3.3", "ruff>=0.6.9"]

[tool.uv.sources]
mcp = { path = "../../../" }

[[tool.uv.index]]
url = "https://pypi.org/simple"
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:

async def register_client(self, client_info: OAuthClientInformationFull):
"""Register a new OAuth client."""
if not client_info.client_id:
raise ValueError("No client_id provided")
self.clients[client_info.client_id] = client_info

async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str:
Expand Down Expand Up @@ -209,6 +211,8 @@ async def exchange_authorization_code(
"""Exchange authorization code for tokens."""
if authorization_code.code not in self.auth_codes:
raise ValueError("Invalid authorization code")
if not client.client_id:
raise ValueError("No client_id provided")

# Generate MCP access token
mcp_token = f"mcp_{secrets.token_hex(32)}"
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ dependencies = [
"uvicorn>=0.23.1; sys_platform != 'emscripten'",
"jsonschema>=4.20.0",
"pywin32>=310; sys_platform == 'win32'",
"pyjwt[crypto]>=2.10.1",
]

[project.optional-dependencies]
Expand Down Expand Up @@ -116,7 +117,7 @@ extend-exclude = ["README.md"]
"tests/server/fastmcp/test_func_metadata.py" = ["E501"]

[tool.uv.workspace]
members = ["examples/servers/*", "examples/snippets"]
members = ["examples/clients/*", "examples/servers/*", "examples/snippets"]

[tool.uv.sources]
mcp = { workspace = true }
Expand Down
201 changes: 176 additions & 25 deletions src/mcp/client/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
import time
from collections.abc import AsyncGenerator, Awaitable, Callable
from dataclasses import dataclass, field
from typing import Protocol
from typing import Any, Protocol
from urllib.parse import urlencode, urljoin, urlparse
from uuid import uuid4

import anyio
import httpx
import jwt
from pydantic import BaseModel, Field, ValidationError

from mcp.client.streamable_http import MCP_PROTOCOL_VERSION
Expand Down Expand Up @@ -61,6 +63,58 @@ def generate(cls) -> "PKCEParameters":
return cls(code_verifier=code_verifier, code_challenge=code_challenge)


class JWTParameters(BaseModel):
"""JWT parameters."""

assertion: str | None = Field(
default=None,
description="JWT assertion for JWT authentication. "
"Will be used instead of generating a new assertion if provided.",
)

issuer: str | None = Field(default=None, description="Issuer for JWT assertions.")
subject: str | None = Field(default=None, description="Subject identifier for JWT assertions.")
audience: str | None = Field(default=None, description="Audience for JWT assertions.")
claims: dict[str, Any] | None = Field(default=None, description="Additional claims for JWT assertions.")
jwt_signing_algorithm: str | None = Field(default="RS256", description="Algorithm for signing JWT assertions.")
jwt_signing_key: str | None = Field(default=None, description="Private key for JWT signing.")
jwt_lifetime_seconds: int = Field(default=300, description="Lifetime of generated JWT in seconds.")

def to_assertion(self, with_audience_fallback: str | None = None) -> str:
if self.assertion is not None:
# Prebuilt JWT (e.g. acquired out-of-band)
assertion = self.assertion
else:
if not self.jwt_signing_key:
raise OAuthFlowError("Missing signing key for JWT bearer grant")
if not self.issuer:
raise OAuthFlowError("Missing issuer for JWT bearer grant")
if not self.subject:
raise OAuthFlowError("Missing subject for JWT bearer grant")

audience = self.audience if self.audience else with_audience_fallback
if not audience:
raise OAuthFlowError("Missing audience for JWT bearer grant")

now = int(time.time())
claims: dict[str, Any] = {
"iss": self.issuer,
"sub": self.subject,
"aud": audience,
"exp": now + self.jwt_lifetime_seconds,
"iat": now,
"jti": str(uuid4()),
}
claims.update(self.claims or {})

assertion = jwt.encode(
claims,
self.jwt_signing_key,
algorithm=self.jwt_signing_algorithm or "RS256",
)
return assertion


class TokenStorage(Protocol):
"""Protocol for token storage implementations."""

Expand Down Expand Up @@ -88,8 +142,8 @@ class OAuthContext:
server_url: str
client_metadata: OAuthClientMetadata
storage: TokenStorage
redirect_handler: Callable[[str], Awaitable[None]]
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]]
redirect_handler: Callable[[str], Awaitable[None]] | None
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None
timeout: float = 300.0

# Discovered metadata
Expand Down Expand Up @@ -189,8 +243,8 @@ def __init__(
server_url: str,
client_metadata: OAuthClientMetadata,
storage: TokenStorage,
redirect_handler: Callable[[str], Awaitable[None]],
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]],
redirect_handler: Callable[[str], Awaitable[None]] | None = None,
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None,
timeout: float = 300.0,
):
"""Initialize OAuth2 authentication."""
Expand Down Expand Up @@ -309,8 +363,21 @@ async def _handle_registration_response(self, response: httpx.Response) -> None:
except ValidationError as e:
raise OAuthRegistrationError(f"Invalid registration response: {e}")

async def _perform_authorization(self) -> tuple[str, str]:
async def _perform_authorization(self) -> httpx.Request:
"""Perform the authorization flow."""
auth_code, code_verifier = await self._perform_authorization_code_grant()
token_request = await self._exchange_token_authorization_code(auth_code, code_verifier)
return token_request

async def _perform_authorization_code_grant(self) -> tuple[str, str]:
"""Perform the authorization redirect and get auth code."""
if self.context.client_metadata.redirect_uris is None:
raise OAuthFlowError("No redirect URIs provided for authorization code grant")
if not self.context.redirect_handler:
raise OAuthFlowError("No redirect handler provided for authorization code grant")
if not self.context.callback_handler:
raise OAuthFlowError("No callback handler provided for authorization code grant")

if self.context.oauth_metadata and self.context.oauth_metadata.authorization_endpoint:
auth_endpoint = str(self.context.oauth_metadata.authorization_endpoint)
else:
Expand Down Expand Up @@ -355,24 +422,34 @@ async def _perform_authorization(self) -> tuple[str, str]:
# Return auth code and code verifier for token exchange
return auth_code, pkce_params.code_verifier

async def _exchange_token(self, auth_code: str, code_verifier: str) -> httpx.Request:
"""Build token exchange request."""
if not self.context.client_info:
raise OAuthFlowError("Missing client info")

def _get_token_endpoint(self) -> str:
if self.context.oauth_metadata and self.context.oauth_metadata.token_endpoint:
token_url = str(self.context.oauth_metadata.token_endpoint)
else:
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
token_url = urljoin(auth_base_url, "/token")
return token_url

async def _exchange_token_authorization_code(
self, auth_code: str, code_verifier: str, *, token_data: dict[str, Any] | None = {}
) -> httpx.Request:
"""Build token exchange request for authorization_code flow."""
if self.context.client_metadata.redirect_uris is None:
raise OAuthFlowError("No redirect URIs provided for authorization code grant")
if not self.context.client_info:
raise OAuthFlowError("Missing client info")

token_data = {
"grant_type": "authorization_code",
"code": auth_code,
"redirect_uri": str(self.context.client_metadata.redirect_uris[0]),
"client_id": self.context.client_info.client_id,
"code_verifier": code_verifier,
}
token_url = self._get_token_endpoint()
token_data = token_data or {}
token_data.update(
{
"grant_type": "authorization_code",
"code": auth_code,
"redirect_uri": str(self.context.client_metadata.redirect_uris[0]),
"client_id": self.context.client_info.client_id,
"code_verifier": code_verifier,
}
)

# Only include resource param if conditions are met
if self.context.should_include_resource_param(self.context.protocol_version):
Expand All @@ -388,7 +465,9 @@ async def _exchange_token(self, auth_code: str, code_verifier: str) -> httpx.Req
async def _handle_token_response(self, response: httpx.Response) -> None:
"""Handle token exchange response."""
if response.status_code != 200:
raise OAuthTokenError(f"Token exchange failed: {response.status_code}")
body = await response.aread()
body = body.decode("utf-8")
raise OAuthTokenError(f"Token exchange failed ({response.status_code}): {body}")

try:
content = await response.aread()
Expand Down Expand Up @@ -535,12 +614,8 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
registration_response = yield registration_request
await self._handle_registration_response(registration_response)

# Step 4: Perform authorization
auth_code, code_verifier = await self._perform_authorization()

# Step 5: Exchange authorization code for tokens
token_request = await self._exchange_token(auth_code, code_verifier)
token_response = yield token_request
# Step 4: Perform authorization and complete token exchange
token_response = yield await self._perform_authorization()
await self._handle_token_response(token_response)
except Exception:
logger.exception("OAuth flow error")
Expand All @@ -549,3 +624,79 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
# Retry with new tokens
self._add_auth_header(request)
yield request


class RFC7523OAuthClientProvider(OAuthClientProvider):
"""OAuth client provider for RFC7532 clients."""

jwt_parameters: JWTParameters | None = None

def __init__(
self,
server_url: str,
client_metadata: OAuthClientMetadata,
storage: TokenStorage,
redirect_handler: Callable[[str], Awaitable[None]] | None = None,
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None,
timeout: float = 300.0,
jwt_parameters: JWTParameters | None = None,
) -> None:
super().__init__(server_url, client_metadata, storage, redirect_handler, callback_handler, timeout)
self.jwt_parameters = jwt_parameters

async def _exchange_token_authorization_code(
self, auth_code: str, code_verifier: str, *, token_data: dict[str, Any] | None = None
) -> httpx.Request:
"""Build token exchange request for authorization_code flow."""
token_data = token_data or {}
if self.context.client_metadata.token_endpoint_auth_method == "private_key_jwt":
self._add_client_authentication_jwt(token_data=token_data)
return await super()._exchange_token_authorization_code(auth_code, code_verifier, token_data=token_data)

async def _perform_authorization(self) -> httpx.Request:
"""Perform the authorization flow."""
if "urn:ietf:params:oauth:grant-type:jwt-bearer" in self.context.client_metadata.grant_types:
token_request = await self._exchange_token_jwt_bearer()
return token_request
else:
return await super()._perform_authorization()

def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]):
"""Add JWT assertion for client authentication to token endpoint parameters."""
if not self.jwt_parameters:
raise OAuthTokenError("Missing JWT parameters for private_key_jwt flow")

token_url = self._get_token_endpoint()
assertion = self.jwt_parameters.to_assertion(with_audience_fallback=token_url)

# When using private_key_jwt, in a client_credentials flow, we use RFC 7523 Section 2.2
token_data["client_assertion"] = assertion
token_data["client_assertion_type"] = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
# We need to set the audience to the token endpoint, the audience is difference from the one in claims
# it represents the resource server that will validate the token
token_data["audience"] = self.context.get_resource_url()

async def _exchange_token_jwt_bearer(self) -> httpx.Request:
"""Build token exchange request for JWT bearer grant."""
if not self.context.client_info:
raise OAuthFlowError("Missing client info")
if not self.jwt_parameters:
raise OAuthFlowError("Missing JWT parameters")

token_url = self._get_token_endpoint()
assertion = self.jwt_parameters.to_assertion(with_audience_fallback=token_url)

token_data = {
"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
"assertion": assertion,
}

if self.context.should_include_resource_param(self.context.protocol_version):
token_data["resource"] = self.context.get_resource_url()

if self.context.client_metadata.scope:
token_data["scope"] = self.context.client_metadata.scope

return httpx.Request(
"POST", token_url, data=token_data, headers={"Content-Type": "application/x-www-form-urlencoded"}
)
18 changes: 8 additions & 10 deletions src/mcp/shared/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,11 @@ class OAuthClientMetadata(BaseModel):
for the full specification.
"""

redirect_uris: list[AnyUrl] = Field(..., min_length=1)
# token_endpoint_auth_method: this implementation only supports none &
# client_secret_post;
# ie: we do not support client_secret_basic
token_endpoint_auth_method: Literal["none", "client_secret_post"] = "client_secret_post"
# grant_types: this implementation only supports authorization_code & refresh_token
grant_types: list[Literal["authorization_code", "refresh_token"]] = [
redirect_uris: list[AnyUrl] | None = Field(..., min_length=1)
# supported auth methods for the token endpoint
token_endpoint_auth_method: Literal["none", "client_secret_post", "private_key_jwt"] = "client_secret_post"
# supported grant_types of this implementation
grant_types: list[Literal["authorization_code", "refresh_token", "urn:ietf:params:oauth:grant-type:jwt-bearer"]] = [
"authorization_code",
"refresh_token",
]
Expand Down Expand Up @@ -81,10 +79,10 @@ def validate_scope(self, requested_scope: str | None) -> list[str] | None:
def validate_redirect_uri(self, redirect_uri: AnyUrl | None) -> AnyUrl:
if redirect_uri is not None:
# Validate redirect_uri against client's registered redirect URIs
if redirect_uri not in self.redirect_uris:
if self.redirect_uris is None or redirect_uri not in self.redirect_uris:
raise InvalidRedirectUriError(f"Redirect URI '{redirect_uri}' not registered for client")
return redirect_uri
elif len(self.redirect_uris) == 1:
elif self.redirect_uris is not None and len(self.redirect_uris) == 1:
return self.redirect_uris[0]
else:
raise InvalidRedirectUriError("redirect_uri must be specified when client has multiple registered URIs")
Expand All @@ -96,7 +94,7 @@ class OAuthClientInformationFull(OAuthClientMetadata):
(client information plus metadata).
"""

client_id: str
client_id: str | None = None
client_secret: str | None = None
client_id_issued_at: int | None = None
client_secret_expires_at: int | None = None
Expand Down
Loading
Loading