2
2
Tests for refactored OAuth client authentication implementation.
3
3
"""
4
4
5
+ import base64
6
+ import json
5
7
import time
6
8
from unittest import mock
9
+ from urllib .parse import unquote
7
10
8
11
import httpx
9
12
import pytest
10
13
from inline_snapshot import Is , snapshot
11
14
from pydantic import AnyHttpUrl , AnyUrl
12
15
13
- from mcp .client .auth import OAuthClientProvider , PKCEParameters
14
- from mcp .shared .auth import OAuthClientInformationFull , OAuthClientMetadata , OAuthToken , ProtectedResourceMetadata
16
+ from mcp .client .auth import OAuthClientProvider , OAuthRegistrationError , PKCEParameters
17
+ from mcp .shared .auth import (
18
+ OAuthClientInformationFull ,
19
+ OAuthClientMetadata ,
20
+ OAuthMetadata ,
21
+ OAuthToken ,
22
+ ProtectedResourceMetadata ,
23
+ )
15
24
16
25
17
26
class MockTokenStorage :
@@ -415,6 +424,43 @@ async def test_register_client_skip_if_registered(self, oauth_provider: OAuthCli
415
424
request = await oauth_provider ._register_client ()
416
425
assert request is None
417
426
427
+ @pytest .mark .anyio
428
+ async def test_register_client_none_auth_method_with_server_metadata (self , oauth_provider : OAuthClientProvider ):
429
+ """Test that token_endpoint_auth_method=None selects from server's supported methods."""
430
+ # Set server metadata with specific supported methods
431
+ oauth_provider .context .oauth_metadata = OAuthMetadata (
432
+ issuer = AnyHttpUrl ("https://auth.example.com" ),
433
+ authorization_endpoint = AnyHttpUrl ("https://auth.example.com/authorize" ),
434
+ token_endpoint = AnyHttpUrl ("https://auth.example.com/token" ),
435
+ token_endpoint_auth_methods_supported = ["client_secret_post" ],
436
+ )
437
+ # Ensure client_metadata has None for token_endpoint_auth_method
438
+ assert oauth_provider .context .client_metadata .token_endpoint_auth_method is None
439
+
440
+ request = await oauth_provider ._register_client ()
441
+ assert request is not None
442
+
443
+ body = json .loads (request .content )
444
+ assert body ["token_endpoint_auth_method" ] == "client_secret_post"
445
+
446
+ @pytest .mark .anyio
447
+ async def test_register_client_none_auth_method_no_compatible (self , oauth_provider : OAuthClientProvider ):
448
+ """Test that registration raises error when no compatible auth methods."""
449
+ # Set server metadata with unsupported methods only
450
+ oauth_provider .context .oauth_metadata = OAuthMetadata (
451
+ issuer = AnyHttpUrl ("https://auth.example.com" ),
452
+ authorization_endpoint = AnyHttpUrl ("https://auth.example.com/authorize" ),
453
+ token_endpoint = AnyHttpUrl ("https://auth.example.com/token" ),
454
+ token_endpoint_auth_methods_supported = ["private_key_jwt" , "client_secret_jwt" ],
455
+ )
456
+ assert oauth_provider .context .client_metadata .token_endpoint_auth_method is None
457
+
458
+ with pytest .raises (OAuthRegistrationError ) as exc_info :
459
+ await oauth_provider ._register_client ()
460
+
461
+ assert "No compatible authentication methods" in str (exc_info .value )
462
+ assert "private_key_jwt" in str (exc_info .value )
463
+
418
464
@pytest .mark .anyio
419
465
async def test_token_exchange_request (self , oauth_provider : OAuthClientProvider ):
420
466
"""Test token exchange request building."""
@@ -423,6 +469,7 @@ async def test_token_exchange_request(self, oauth_provider: OAuthClientProvider)
423
469
client_id = "test_client" ,
424
470
client_secret = "test_secret" ,
425
471
redirect_uris = [AnyUrl ("http://localhost:3030/callback" )],
472
+ token_endpoint_auth_method = "client_secret_post" ,
426
473
)
427
474
428
475
request = await oauth_provider ._exchange_token ("test_auth_code" , "test_verifier" )
@@ -448,6 +495,7 @@ async def test_refresh_token_request(self, oauth_provider: OAuthClientProvider,
448
495
client_id = "test_client" ,
449
496
client_secret = "test_secret" ,
450
497
redirect_uris = [AnyUrl ("http://localhost:3030/callback" )],
498
+ token_endpoint_auth_method = "client_secret_post" ,
451
499
)
452
500
453
501
request = await oauth_provider ._refresh_token ()
@@ -463,6 +511,114 @@ async def test_refresh_token_request(self, oauth_provider: OAuthClientProvider,
463
511
assert "client_id=test_client" in content
464
512
assert "client_secret=test_secret" in content
465
513
514
+ @pytest .mark .anyio
515
+ async def test_basic_auth_token_exchange (self , oauth_provider : OAuthClientProvider ):
516
+ """Test token exchange with client_secret_basic authentication."""
517
+ # Set up OAuth metadata to support basic auth
518
+ oauth_provider .context .oauth_metadata = OAuthMetadata (
519
+ issuer = AnyHttpUrl ("https://auth.example.com" ),
520
+ authorization_endpoint = AnyHttpUrl ("https://auth.example.com/authorize" ),
521
+ token_endpoint = AnyHttpUrl ("https://auth.example.com/token" ),
522
+ token_endpoint_auth_methods_supported = ["client_secret_basic" , "client_secret_post" ],
523
+ )
524
+
525
+ client_id_raw = "test@client" # Include special character to test URL encoding
526
+ client_secret_raw = "test:secret" # Include colon to test URL encoding
527
+
528
+ oauth_provider .context .client_info = OAuthClientInformationFull (
529
+ client_id = client_id_raw ,
530
+ client_secret = client_secret_raw ,
531
+ redirect_uris = [AnyUrl ("http://localhost:3030/callback" )],
532
+ token_endpoint_auth_method = "client_secret_basic" ,
533
+ )
534
+
535
+ request = await oauth_provider ._exchange_token ("test_auth_code" , "test_verifier" )
536
+
537
+ # Should use basic auth (registered method)
538
+ assert "Authorization" in request .headers
539
+ assert request .headers ["Authorization" ].startswith ("Basic " )
540
+
541
+ # Decode and verify credentials are properly URL-encoded
542
+ encoded_creds = request .headers ["Authorization" ][6 :] # Remove "Basic " prefix
543
+ decoded = base64 .b64decode (encoded_creds ).decode ()
544
+ client_id , client_secret = decoded .split (":" , 1 )
545
+
546
+ # Check URL encoding was applied
547
+ assert client_id == "test%40client" # @ should be encoded as %40
548
+ assert client_secret == "test%3Asecret" # : should be encoded as %3A
549
+
550
+ # Verify decoded values match original
551
+ assert unquote (client_id ) == client_id_raw
552
+ assert unquote (client_secret ) == client_secret_raw
553
+
554
+ # client_secret should NOT be in body for basic auth
555
+ content = request .content .decode ()
556
+ assert "client_secret=" not in content
557
+ assert "client_id=test%40client" in content # client_id still in body
558
+
559
+ @pytest .mark .anyio
560
+ async def test_basic_auth_refresh_token (self , oauth_provider : OAuthClientProvider , valid_tokens : OAuthToken ):
561
+ """Test token refresh with client_secret_basic authentication."""
562
+ oauth_provider .context .current_tokens = valid_tokens
563
+
564
+ # Set up OAuth metadata to only support basic auth
565
+ oauth_provider .context .oauth_metadata = OAuthMetadata (
566
+ issuer = AnyHttpUrl ("https://auth.example.com" ),
567
+ authorization_endpoint = AnyHttpUrl ("https://auth.example.com/authorize" ),
568
+ token_endpoint = AnyHttpUrl ("https://auth.example.com/token" ),
569
+ token_endpoint_auth_methods_supported = ["client_secret_basic" ],
570
+ )
571
+
572
+ client_id = "test_client"
573
+ client_secret = "test_secret"
574
+ oauth_provider .context .client_info = OAuthClientInformationFull (
575
+ client_id = client_id ,
576
+ client_secret = client_secret ,
577
+ redirect_uris = [AnyUrl ("http://localhost:3030/callback" )],
578
+ token_endpoint_auth_method = "client_secret_basic" ,
579
+ )
580
+
581
+ request = await oauth_provider ._refresh_token ()
582
+
583
+ assert "Authorization" in request .headers
584
+ assert request .headers ["Authorization" ].startswith ("Basic " )
585
+
586
+ encoded_creds = request .headers ["Authorization" ][6 :]
587
+ decoded = base64 .b64decode (encoded_creds ).decode ()
588
+ assert decoded == f"{ client_id } :{ client_secret } "
589
+
590
+ # client_secret should NOT be in body
591
+ content = request .content .decode ()
592
+ assert "client_secret=" not in content
593
+
594
+ @pytest .mark .anyio
595
+ async def test_none_auth_method (self , oauth_provider : OAuthClientProvider ):
596
+ """Test 'none' authentication method (public client)."""
597
+ oauth_provider .context .oauth_metadata = OAuthMetadata (
598
+ issuer = AnyHttpUrl ("https://auth.example.com" ),
599
+ authorization_endpoint = AnyHttpUrl ("https://auth.example.com/authorize" ),
600
+ token_endpoint = AnyHttpUrl ("https://auth.example.com/token" ),
601
+ token_endpoint_auth_methods_supported = ["none" ],
602
+ )
603
+
604
+ client_id = "public_client"
605
+ oauth_provider .context .client_info = OAuthClientInformationFull (
606
+ client_id = client_id ,
607
+ client_secret = None , # No secret for public client
608
+ redirect_uris = [AnyUrl ("http://localhost:3030/callback" )],
609
+ token_endpoint_auth_method = "none" ,
610
+ )
611
+
612
+ request = await oauth_provider ._exchange_token ("test_auth_code" , "test_verifier" )
613
+
614
+ # Should NOT have Authorization header
615
+ assert "Authorization" not in request .headers
616
+
617
+ # Should NOT have client_secret in body
618
+ content = request .content .decode ()
619
+ assert "client_secret=" not in content
620
+ assert "client_id=public_client" in content
621
+
466
622
467
623
class TestProtectedResourceMetadata :
468
624
"""Test protected resource handling."""
0 commit comments