diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 9b950db724..57973837a1 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -144,11 +144,8 @@ def get_resource_url(self) -> str: """ resource = resource_url_from_server_url(self.server_url) - # If PRM provides a resource that's a valid parent, use it if self.protected_resource_metadata and self.protected_resource_metadata.resource: - prm_resource = str(self.protected_resource_metadata.resource) - if check_resource_allowed(requested_resource=resource, configured_resource=prm_resource): - resource = prm_resource + resource = str(self.protected_resource_metadata.resource) return resource @@ -292,6 +289,13 @@ async def _handle_protected_resource_response(self, response: httpx.Response) -> try: content = await response.aread() metadata = ProtectedResourceMetadata.model_validate_json(content) + # Validate resource field BEFORE storing metadata per RFC 9728 Section 3.3. + if not check_resource_allowed( + requested_resource=self.context.server_url, + configured_resource=str(metadata.resource), + ): + return False + self.context.protected_resource_metadata = metadata if metadata.authorization_servers: self.context.auth_server_url = str(metadata.authorization_servers[0]) diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 8cea6cefd7..df51d9a3fe 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -642,6 +642,59 @@ async def test_resource_param_included_with_protected_resource_metadata(self, oa content = request.content.decode() assert "resource=" in content + @pytest.mark.anyio + async def test_reject_metadata_with_mismatched_origin(self, oauth_provider: OAuthClientProvider): + """Test RFC 9728 Section 3.3: reject metadata with different scheme, host, or port.""" + # Test different scheme + response_wrong_scheme = httpx.Response( + 200, + content=b'{"resource": "http://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}', + ) + result = await oauth_provider._handle_protected_resource_response(response_wrong_scheme) + assert result is False + assert oauth_provider.context.protected_resource_metadata is None + + # Test different host + response_wrong_host = httpx.Response( + 200, + content=b'{"resource": "https://evil.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}', + ) + result = await oauth_provider._handle_protected_resource_response(response_wrong_host) + assert result is False + assert oauth_provider.context.protected_resource_metadata is None + + # Test different port + response_wrong_port = httpx.Response( + 200, + content=b'{"resource": "https://api.example.com:8080/v1/mcp", "authorization_servers": ["https://auth.example.com"]}', + ) + result = await oauth_provider._handle_protected_resource_response(response_wrong_port) + assert result is False + assert oauth_provider.context.protected_resource_metadata is None + + @pytest.mark.anyio + async def test_reject_metadata_with_invalid_path_hierarchy(self, oauth_provider: OAuthClientProvider): + """Test RFC 9728 Section 3.3: reject metadata where resource is child of server URL.""" + + # Invalid: resource is child path + response_child_path = httpx.Response( + 200, + content=b'{"resource": "https://api.example.com/v1/mcp/subpath", "authorization_servers": ["https://auth.example.com"]}', + ) + result = await oauth_provider._handle_protected_resource_response(response_child_path) + assert result is False + assert oauth_provider.context.protected_resource_metadata is None + + # Valid: resource is parent path + response_parent_path = httpx.Response( + 200, + content=b'{"resource": "https://api.example.com/v1", "authorization_servers": ["https://auth.example.com"]}', + ) + result = await oauth_provider._handle_protected_resource_response(response_parent_path) + assert result is True + assert oauth_provider.context.protected_resource_metadata is not None + assert str(oauth_provider.context.protected_resource_metadata.resource) == "https://api.example.com/v1" + class TestRegistrationResponse: """Test client registration response handling.""" @@ -745,7 +798,7 @@ async def test_auth_flow_with_no_tokens(self, oauth_provider: OAuthClientProvide # Send a successful discovery response with minimal protected resource metadata discovery_response = httpx.Response( 200, - content=b'{"resource": "https://api.example.com/mcp", "authorization_servers": ["https://auth.example.com"]}', + content=b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}', request=discovery_request, )