Skip to content

Commit 3a704d7

Browse files
committed
Add client_secret_basic auth support to MCP client
- Implement HTTP Basic auth for OAuth token requests - Automatically sets selects auth method when OAuthClientProvider is configured with OAuthClientMetadata that has token_endpoint_auth_method=None. - Made OAuthClientMetadata.token_endpoint_auth_method optional to support the above auto-configuration. - Removed ` "token_endpoint_auth_method": "client_secret_post"` from the simple-auth-client example as is now auto-configured.
1 parent 48f8385 commit 3a704d7

File tree

5 files changed

+230
-15
lines changed

5 files changed

+230
-15
lines changed

examples/clients/simple-auth-client/mcp_simple_auth_client/main.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,6 @@ async def callback_handler() -> tuple[str, str | None]:
177177
"redirect_uris": ["http://localhost:3030/callback"],
178178
"grant_types": ["authorization_code", "refresh_token"],
179179
"response_types": ["code"],
180-
"token_endpoint_auth_method": "client_secret_post",
181180
}
182181

183182
async def _default_redirect_handler(authorization_url: str) -> None:

src/mcp/client/auth.py

Lines changed: 66 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from collections.abc import AsyncGenerator, Awaitable, Callable
1515
from dataclasses import dataclass, field
1616
from typing import Protocol
17-
from urllib.parse import urlencode, urljoin, urlparse
17+
from urllib.parse import quote, urlencode, urljoin, urlparse
1818

1919
import anyio
2020
import httpx
@@ -175,6 +175,42 @@ def should_include_resource_param(self, protocol_version: str | None = None) ->
175175
# Version format is YYYY-MM-DD, so string comparison works
176176
return protocol_version >= "2025-06-18"
177177

178+
def prepare_token_auth(
179+
self, data: dict[str, str], headers: dict[str, str] | None = None
180+
) -> tuple[dict[str, str], dict[str, str]]:
181+
"""Prepare authentication for token requests.
182+
183+
Args:
184+
data: The form data to send
185+
headers: Optional headers dict to update
186+
187+
Returns:
188+
Tuple of (updated_data, updated_headers)
189+
"""
190+
if headers is None:
191+
headers = {}
192+
193+
if not self.client_info:
194+
return data, headers
195+
196+
auth_method = self.client_info.token_endpoint_auth_method
197+
198+
if auth_method == "client_secret_basic" and self.client_info.client_secret:
199+
# URL-encode client ID and secret per RFC 6749 Section 2.3.1
200+
encoded_id = quote(self.client_info.client_id, safe="")
201+
encoded_secret = quote(self.client_info.client_secret, safe="")
202+
credentials = f"{encoded_id}:{encoded_secret}"
203+
encoded_credentials = base64.b64encode(credentials.encode()).decode()
204+
headers["Authorization"] = f"Basic {encoded_credentials}"
205+
# Don't include client_secret in body for basic auth
206+
data = {k: v for k, v in data.items() if k != "client_secret"}
207+
elif auth_method == "client_secret_post" and self.client_info.client_secret:
208+
# Include client_secret in request body
209+
data["client_secret"] = self.client_info.client_secret
210+
# For auth_method == "none", don't add any client_secret
211+
212+
return data, headers
213+
178214

179215
class OAuthClientProvider(httpx.Auth):
180216
"""
@@ -291,6 +327,27 @@ async def _register_client(self) -> httpx.Request | None:
291327

292328
registration_data = self.context.client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True)
293329

330+
# If token_endpoint_auth_method is None, auto-select based on server support
331+
if self.context.client_metadata.token_endpoint_auth_method is None:
332+
preference_order = ["client_secret_basic", "client_secret_post", "none"]
333+
334+
if self.context.oauth_metadata and self.context.oauth_metadata.token_endpoint_auth_methods_supported:
335+
supported = self.context.oauth_metadata.token_endpoint_auth_methods_supported
336+
for method in preference_order:
337+
if method in supported:
338+
registration_data["token_endpoint_auth_method"] = method
339+
break
340+
else:
341+
# No compatible methods between client and server
342+
raise OAuthRegistrationError(
343+
f"No compatible authentication methods. "
344+
f"Server supports: {supported}, "
345+
f"Client supports: {preference_order}"
346+
)
347+
else:
348+
# No server metadata available, use our default preference
349+
registration_data["token_endpoint_auth_method"] = preference_order[0]
350+
294351
return httpx.Request(
295352
"POST", registration_url, json=registration_data, headers={"Content-Type": "application/json"}
296353
)
@@ -378,12 +435,11 @@ async def _exchange_token(self, auth_code: str, code_verifier: str) -> httpx.Req
378435
if self.context.should_include_resource_param(self.context.protocol_version):
379436
token_data["resource"] = self.context.get_resource_url() # RFC 8707
380437

381-
if self.context.client_info.client_secret:
382-
token_data["client_secret"] = self.context.client_info.client_secret
438+
# Prepare authentication based on preferred method
439+
headers = {"Content-Type": "application/x-www-form-urlencoded"}
440+
token_data, headers = self.context.prepare_token_auth(token_data, headers)
383441

384-
return httpx.Request(
385-
"POST", token_url, data=token_data, headers={"Content-Type": "application/x-www-form-urlencoded"}
386-
)
442+
return httpx.Request("POST", token_url, data=token_data, headers=headers)
387443

388444
async def _handle_token_response(self, response: httpx.Response) -> None:
389445
"""Handle token exchange response."""
@@ -432,12 +488,11 @@ async def _refresh_token(self) -> httpx.Request:
432488
if self.context.should_include_resource_param(self.context.protocol_version):
433489
refresh_data["resource"] = self.context.get_resource_url() # RFC 8707
434490

435-
if self.context.client_info.client_secret:
436-
refresh_data["client_secret"] = self.context.client_info.client_secret
491+
# Prepare authentication based on preferred method
492+
headers = {"Content-Type": "application/x-www-form-urlencoded"}
493+
refresh_data, headers = self.context.prepare_token_auth(refresh_data, headers)
437494

438-
return httpx.Request(
439-
"POST", token_url, data=refresh_data, headers={"Content-Type": "application/x-www-form-urlencoded"}
440-
)
495+
return httpx.Request("POST", token_url, data=refresh_data, headers=headers)
441496

442497
async def _handle_refresh_response(self, response: httpx.Response) -> bool:
443498
"""Handle token refresh response. Returns True if successful."""

src/mcp/server/auth/handlers/register.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,11 @@ async def handle(self, request: Request) -> Response:
4949
)
5050

5151
client_id = str(uuid4())
52+
53+
# If auth method is None, default to client_secret_post
54+
if client_metadata.token_endpoint_auth_method is None:
55+
client_metadata.token_endpoint_auth_method = "client_secret_post"
56+
5257
client_secret = None
5358
if client_metadata.token_endpoint_auth_method != "none":
5459
# cryptographically secure random 32-byte hex string

src/mcp/shared/auth.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class OAuthClientMetadata(BaseModel):
4242
"""
4343

4444
redirect_uris: list[AnyUrl] = Field(..., min_length=1)
45-
token_endpoint_auth_method: Literal["none", "client_secret_post", "client_secret_basic"] = "client_secret_post"
45+
token_endpoint_auth_method: Literal["none", "client_secret_post", "client_secret_basic"] | None = None
4646
# grant_types: this implementation only supports authorization_code & refresh_token
4747
grant_types: list[Literal["authorization_code", "refresh_token"] | str] = [
4848
"authorization_code",

tests/client/test_auth.py

Lines changed: 158 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,25 @@
22
Tests for refactored OAuth client authentication implementation.
33
"""
44

5+
import base64
6+
import json
57
import time
68
from unittest import mock
9+
from urllib.parse import unquote
710

811
import httpx
912
import pytest
1013
from inline_snapshot import Is, snapshot
1114
from pydantic import AnyHttpUrl, AnyUrl
1215

13-
from mcp.client.auth import OAuthClientProvider, PKCEParameters
14-
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken, ProtectedResourceMetadata
16+
from mcp.client.auth import OAuthClientProvider, OAuthRegistrationError, PKCEParameters
17+
from mcp.shared.auth import (
18+
OAuthClientInformationFull,
19+
OAuthClientMetadata,
20+
OAuthMetadata,
21+
OAuthToken,
22+
ProtectedResourceMetadata,
23+
)
1524

1625

1726
class MockTokenStorage:
@@ -415,6 +424,43 @@ async def test_register_client_skip_if_registered(self, oauth_provider: OAuthCli
415424
request = await oauth_provider._register_client()
416425
assert request is None
417426

427+
@pytest.mark.anyio
428+
async def test_register_client_none_auth_method_with_server_metadata(self, oauth_provider: OAuthClientProvider):
429+
"""Test that token_endpoint_auth_method=None selects from server's supported methods."""
430+
# Set server metadata with specific supported methods
431+
oauth_provider.context.oauth_metadata = OAuthMetadata(
432+
issuer=AnyHttpUrl("https://auth.example.com"),
433+
authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"),
434+
token_endpoint=AnyHttpUrl("https://auth.example.com/token"),
435+
token_endpoint_auth_methods_supported=["client_secret_post"],
436+
)
437+
# Ensure client_metadata has None for token_endpoint_auth_method
438+
assert oauth_provider.context.client_metadata.token_endpoint_auth_method is None
439+
440+
request = await oauth_provider._register_client()
441+
assert request is not None
442+
443+
body = json.loads(request.content)
444+
assert body["token_endpoint_auth_method"] == "client_secret_post"
445+
446+
@pytest.mark.anyio
447+
async def test_register_client_none_auth_method_no_compatible(self, oauth_provider: OAuthClientProvider):
448+
"""Test that registration raises error when no compatible auth methods."""
449+
# Set server metadata with unsupported methods only
450+
oauth_provider.context.oauth_metadata = OAuthMetadata(
451+
issuer=AnyHttpUrl("https://auth.example.com"),
452+
authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"),
453+
token_endpoint=AnyHttpUrl("https://auth.example.com/token"),
454+
token_endpoint_auth_methods_supported=["private_key_jwt", "client_secret_jwt"],
455+
)
456+
assert oauth_provider.context.client_metadata.token_endpoint_auth_method is None
457+
458+
with pytest.raises(OAuthRegistrationError) as exc_info:
459+
await oauth_provider._register_client()
460+
461+
assert "No compatible authentication methods" in str(exc_info.value)
462+
assert "private_key_jwt" in str(exc_info.value)
463+
418464
@pytest.mark.anyio
419465
async def test_token_exchange_request(self, oauth_provider: OAuthClientProvider):
420466
"""Test token exchange request building."""
@@ -423,6 +469,7 @@ async def test_token_exchange_request(self, oauth_provider: OAuthClientProvider)
423469
client_id="test_client",
424470
client_secret="test_secret",
425471
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
472+
token_endpoint_auth_method="client_secret_post",
426473
)
427474

428475
request = await oauth_provider._exchange_token("test_auth_code", "test_verifier")
@@ -448,6 +495,7 @@ async def test_refresh_token_request(self, oauth_provider: OAuthClientProvider,
448495
client_id="test_client",
449496
client_secret="test_secret",
450497
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
498+
token_endpoint_auth_method="client_secret_post",
451499
)
452500

453501
request = await oauth_provider._refresh_token()
@@ -463,6 +511,114 @@ async def test_refresh_token_request(self, oauth_provider: OAuthClientProvider,
463511
assert "client_id=test_client" in content
464512
assert "client_secret=test_secret" in content
465513

514+
@pytest.mark.anyio
515+
async def test_basic_auth_token_exchange(self, oauth_provider: OAuthClientProvider):
516+
"""Test token exchange with client_secret_basic authentication."""
517+
# Set up OAuth metadata to support basic auth
518+
oauth_provider.context.oauth_metadata = OAuthMetadata(
519+
issuer=AnyHttpUrl("https://auth.example.com"),
520+
authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"),
521+
token_endpoint=AnyHttpUrl("https://auth.example.com/token"),
522+
token_endpoint_auth_methods_supported=["client_secret_basic", "client_secret_post"],
523+
)
524+
525+
client_id_raw = "test@client" # Include special character to test URL encoding
526+
client_secret_raw = "test:secret" # Include colon to test URL encoding
527+
528+
oauth_provider.context.client_info = OAuthClientInformationFull(
529+
client_id=client_id_raw,
530+
client_secret=client_secret_raw,
531+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
532+
token_endpoint_auth_method="client_secret_basic",
533+
)
534+
535+
request = await oauth_provider._exchange_token("test_auth_code", "test_verifier")
536+
537+
# Should use basic auth (registered method)
538+
assert "Authorization" in request.headers
539+
assert request.headers["Authorization"].startswith("Basic ")
540+
541+
# Decode and verify credentials are properly URL-encoded
542+
encoded_creds = request.headers["Authorization"][6:] # Remove "Basic " prefix
543+
decoded = base64.b64decode(encoded_creds).decode()
544+
client_id, client_secret = decoded.split(":", 1)
545+
546+
# Check URL encoding was applied
547+
assert client_id == "test%40client" # @ should be encoded as %40
548+
assert client_secret == "test%3Asecret" # : should be encoded as %3A
549+
550+
# Verify decoded values match original
551+
assert unquote(client_id) == client_id_raw
552+
assert unquote(client_secret) == client_secret_raw
553+
554+
# client_secret should NOT be in body for basic auth
555+
content = request.content.decode()
556+
assert "client_secret=" not in content
557+
assert "client_id=test%40client" in content # client_id still in body
558+
559+
@pytest.mark.anyio
560+
async def test_basic_auth_refresh_token(self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken):
561+
"""Test token refresh with client_secret_basic authentication."""
562+
oauth_provider.context.current_tokens = valid_tokens
563+
564+
# Set up OAuth metadata to only support basic auth
565+
oauth_provider.context.oauth_metadata = OAuthMetadata(
566+
issuer=AnyHttpUrl("https://auth.example.com"),
567+
authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"),
568+
token_endpoint=AnyHttpUrl("https://auth.example.com/token"),
569+
token_endpoint_auth_methods_supported=["client_secret_basic"],
570+
)
571+
572+
client_id = "test_client"
573+
client_secret = "test_secret"
574+
oauth_provider.context.client_info = OAuthClientInformationFull(
575+
client_id=client_id,
576+
client_secret=client_secret,
577+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
578+
token_endpoint_auth_method="client_secret_basic",
579+
)
580+
581+
request = await oauth_provider._refresh_token()
582+
583+
assert "Authorization" in request.headers
584+
assert request.headers["Authorization"].startswith("Basic ")
585+
586+
encoded_creds = request.headers["Authorization"][6:]
587+
decoded = base64.b64decode(encoded_creds).decode()
588+
assert decoded == f"{client_id}:{client_secret}"
589+
590+
# client_secret should NOT be in body
591+
content = request.content.decode()
592+
assert "client_secret=" not in content
593+
594+
@pytest.mark.anyio
595+
async def test_none_auth_method(self, oauth_provider: OAuthClientProvider):
596+
"""Test 'none' authentication method (public client)."""
597+
oauth_provider.context.oauth_metadata = OAuthMetadata(
598+
issuer=AnyHttpUrl("https://auth.example.com"),
599+
authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"),
600+
token_endpoint=AnyHttpUrl("https://auth.example.com/token"),
601+
token_endpoint_auth_methods_supported=["none"],
602+
)
603+
604+
client_id = "public_client"
605+
oauth_provider.context.client_info = OAuthClientInformationFull(
606+
client_id=client_id,
607+
client_secret=None, # No secret for public client
608+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
609+
token_endpoint_auth_method="none",
610+
)
611+
612+
request = await oauth_provider._exchange_token("test_auth_code", "test_verifier")
613+
614+
# Should NOT have Authorization header
615+
assert "Authorization" not in request.headers
616+
617+
# Should NOT have client_secret in body
618+
content = request.content.decode()
619+
assert "client_secret=" not in content
620+
assert "client_id=public_client" in content
621+
466622

467623
class TestProtectedResourceMetadata:
468624
"""Test protected resource handling."""

0 commit comments

Comments
 (0)