Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/mcp/client/auth/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ async def _exchange_token_authorization_code(

async def _handle_token_response(self, response: httpx.Response) -> None:
"""Handle token exchange response."""
if response.status_code != 200:
if response.status_code not in {200, 201}:
body = await response.aread()
body = body.decode("utf-8")
raise OAuthTokenError(f"Token exchange failed ({response.status_code}): {body}")
Expand Down
110 changes: 110 additions & 0 deletions tests/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,116 @@ async def test_auth_flow_no_unnecessary_retry_after_oauth(
# Verify exactly one request was yielded (no double-sending)
assert request_yields == 1, f"Expected 1 request yield, got {request_yields}"

@pytest.mark.anyio
async def test_token_exchange_accepts_201_status(
self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage
):
"""Test that token exchange accepts both 200 and 201 status codes."""
# Ensure no tokens are stored
oauth_provider.context.current_tokens = None
oauth_provider.context.token_expiry_time = None
oauth_provider._initialized = True

# Create a test request
test_request = httpx.Request("GET", "https://api.example.com/mcp")

# Mock the auth flow
auth_flow = oauth_provider.async_auth_flow(test_request)

# First request should be the original request without auth header
request = await auth_flow.__anext__()
assert "Authorization" not in request.headers

# Send a 401 response to trigger the OAuth flow
response = httpx.Response(
401,
headers={
"WWW-Authenticate": 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"'
},
request=test_request,
)

# Next request should be to discover protected resource metadata
discovery_request = await auth_flow.asend(response)
assert discovery_request.method == "GET"
assert str(discovery_request.url) == "https://api.example.com/.well-known/oauth-protected-resource"

# Send a successful discovery response with minimal protected resource metadata
discovery_response = httpx.Response(
200,
content=b'{"resource": "https://api.example.com/mcp", "authorization_servers": ["https://auth.example.com"]}',
request=discovery_request,
)

# Next request should be to discover OAuth metadata
oauth_metadata_request = await auth_flow.asend(discovery_response)
assert oauth_metadata_request.method == "GET"
assert str(oauth_metadata_request.url).startswith("https://auth.example.com/")
assert "mcp-protocol-version" in oauth_metadata_request.headers

# Send a successful OAuth metadata response
oauth_metadata_response = httpx.Response(
200,
content=(
b'{"issuer": "https://auth.example.com", '
b'"authorization_endpoint": "https://auth.example.com/authorize", '
b'"token_endpoint": "https://auth.example.com/token", '
b'"registration_endpoint": "https://auth.example.com/register"}'
),
request=oauth_metadata_request,
)

# Next request should be to register client
registration_request = await auth_flow.asend(oauth_metadata_response)
assert registration_request.method == "POST"
assert str(registration_request.url) == "https://auth.example.com/register"

# Send a successful registration response with 201 status
registration_response = httpx.Response(
201,
content=b'{"client_id": "test_client_id", "client_secret": "test_client_secret", "redirect_uris": ["http://localhost:3030/callback"]}',
request=registration_request,
)

# Mock the authorization process
oauth_provider._perform_authorization_code_grant = mock.AsyncMock(
return_value=("test_auth_code", "test_code_verifier")
)

# Next request should be to exchange token
token_request = await auth_flow.asend(registration_response)
assert token_request.method == "POST"
assert str(token_request.url) == "https://auth.example.com/token"
assert "code=test_auth_code" in token_request.content.decode()

# Send a successful token response with 201 status code (test both 200 and 201 are accepted)
token_response = httpx.Response(
201,
content=(
b'{"access_token": "new_access_token", "token_type": "Bearer", "expires_in": 3600, '
b'"refresh_token": "new_refresh_token"}'
),
request=token_request,
)

# Final request should be the original request with auth header
final_request = await auth_flow.asend(token_response)
assert final_request.headers["Authorization"] == "Bearer new_access_token"
assert final_request.method == "GET"
assert str(final_request.url) == "https://api.example.com/mcp"

# Send final success response to properly close the generator
final_response = httpx.Response(200, request=final_request)
try:
await auth_flow.asend(final_response)
except StopAsyncIteration:
pass # Expected - generator should complete

# Verify tokens were stored
assert oauth_provider.context.current_tokens is not None
assert oauth_provider.context.current_tokens.access_token == "new_access_token"
assert oauth_provider.context.token_expiry_time is not None

@pytest.mark.anyio
async def test_403_insufficient_scope_updates_scope_from_header(
self,
Expand Down