Skip to content

Commit 31be9b2

Browse files
authored
feat(auth): implement CredentialStore trait (#542)
1 parent 03040f8 commit 31be9b2

File tree

2 files changed

+119
-23
lines changed

2 files changed

+119
-23
lines changed

crates/rmcp/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ all-features = true
1414
rustdoc-args = ["--cfg", "docsrs"]
1515

1616
[dependencies]
17+
async-trait = "0.1.89"
1718
serde = { version = "1.0", features = ["derive", "rc"] }
1819
serde_json = "1.0"
1920
thiserror = "2"

crates/rmcp/src/transport/auth.rs

Lines changed: 118 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use std::{collections::HashMap, sync::Arc, time::Duration};
22

3+
use async_trait::async_trait;
34
use oauth2::{
45
AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, EmptyExtraTokenFields,
56
PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RefreshToken, RequestTokenError, Scope,
@@ -17,6 +18,62 @@ use tracing::{debug, error, warn};
1718

1819
const DEFAULT_EXCHANGE_URL: &str = "http://localhost";
1920

21+
/// Stored credentials for OAuth2 authorization
22+
#[derive(Debug, Clone, Serialize, Deserialize)]
23+
pub struct StoredCredentials {
24+
pub client_id: String,
25+
pub token_response: Option<OAuthTokenResponse>,
26+
}
27+
28+
/// Trait for storing and retrieving OAuth2 credentials
29+
///
30+
/// Implementations of this trait can provide custom storage backends
31+
/// for OAuth2 credentials, such as file-based storage, keychain integration,
32+
/// or database storage.
33+
#[async_trait]
34+
pub trait CredentialStore: Send + Sync {
35+
async fn load(&self) -> Result<Option<StoredCredentials>, AuthError>;
36+
37+
async fn save(&self, credentials: StoredCredentials) -> Result<(), AuthError>;
38+
39+
async fn clear(&self) -> Result<(), AuthError>;
40+
}
41+
42+
/// In-memory credential store (default implementation)
43+
///
44+
/// This store keeps credentials in memory only and does not persist them
45+
/// between application restarts. This is the default behavior when no
46+
/// custom credential store is provided.
47+
#[derive(Debug, Default, Clone)]
48+
pub struct InMemoryCredentialStore {
49+
credentials: Arc<RwLock<Option<StoredCredentials>>>,
50+
}
51+
52+
impl InMemoryCredentialStore {
53+
pub fn new() -> Self {
54+
Self {
55+
credentials: Arc::new(RwLock::new(None)),
56+
}
57+
}
58+
}
59+
60+
#[async_trait::async_trait]
61+
impl CredentialStore for InMemoryCredentialStore {
62+
async fn load(&self) -> Result<Option<StoredCredentials>, AuthError> {
63+
Ok(self.credentials.read().await.clone())
64+
}
65+
66+
async fn save(&self, credentials: StoredCredentials) -> Result<(), AuthError> {
67+
*self.credentials.write().await = Some(credentials);
68+
Ok(())
69+
}
70+
71+
async fn clear(&self) -> Result<(), AuthError> {
72+
*self.credentials.write().await = None;
73+
Ok(())
74+
}
75+
}
76+
2077
/// sse client with oauth2 authorization
2178
#[derive(Clone)]
2279
pub struct AuthClient<C> {
@@ -151,7 +208,7 @@ pub struct AuthorizationManager {
151208
http_client: HttpClient,
152209
metadata: Option<AuthorizationMetadata>,
153210
oauth_client: Option<OAuthClient>,
154-
credentials: RwLock<Option<OAuthTokenResponse>>,
211+
credential_store: Arc<dyn CredentialStore>,
155212
state: RwLock<Option<AuthorizationState>>,
156213
base_url: Url,
157214
}
@@ -222,14 +279,42 @@ impl AuthorizationManager {
222279
http_client,
223280
metadata: None,
224281
oauth_client: None,
225-
credentials: RwLock::new(None),
282+
credential_store: Arc::new(InMemoryCredentialStore::new()),
226283
state: RwLock::new(None),
227284
base_url,
228285
};
229286

230287
Ok(manager)
231288
}
232289

290+
/// Set a custom credential store
291+
///
292+
/// This allows you to provide your own implementation of credential storage,
293+
/// such as file-based storage, keychain integration, or database storage.
294+
/// This should be called before any operations that need credentials.
295+
pub fn set_credential_store<S: CredentialStore + 'static>(&mut self, store: S) {
296+
self.credential_store = Arc::new(store);
297+
}
298+
299+
/// Initialize from stored credentials if available
300+
///
301+
/// This will load credentials from the credential store and configure
302+
/// the client if credentials are found.
303+
pub async fn initialize_from_store(&mut self) -> Result<bool, AuthError> {
304+
if let Some(stored) = self.credential_store.load().await? {
305+
if stored.token_response.is_some() {
306+
if self.metadata.is_none() {
307+
let metadata = self.discover_metadata().await?;
308+
self.metadata = Some(metadata);
309+
}
310+
311+
self.configure_client_id(&stored.client_id)?;
312+
return Ok(true);
313+
}
314+
}
315+
Ok(false)
316+
}
317+
233318
pub fn with_client(&mut self, http_client: HttpClient) -> Result<(), AuthError> {
234319
self.http_client = http_client;
235320
Ok(())
@@ -252,13 +337,16 @@ impl AuthorizationManager {
252337

253338
/// get client id and credentials
254339
pub async fn get_credentials(&self) -> Result<Credentials, AuthError> {
255-
let credentials = self.credentials.read().await;
256340
let client_id = self
257341
.oauth_client
258342
.as_ref()
259343
.ok_or_else(|| AuthError::InternalError("OAuth client not configured".to_string()))?
260344
.client_id();
261-
Ok((client_id.to_string(), credentials.clone()))
345+
346+
let stored = self.credential_store.load().await?;
347+
let token_response = stored.and_then(|s| s.token_response);
348+
349+
Ok((client_id.to_string(), token_response))
262350
}
263351

264352
/// configure oauth2 client with client credentials
@@ -309,7 +397,6 @@ impl AuthorizationManager {
309397
));
310398
};
311399

312-
// prepare registration request
313400
let registration_request = ClientRegistrationRequest {
314401
client_name: name.to_string(),
315402
redirect_uris: vec![redirect_uri.to_string()],
@@ -479,23 +566,28 @@ impl AuthorizationManager {
479566
};
480567

481568
debug!("exchange token result: {:?}", token_result);
482-
// store credentials
483-
*self.credentials.write().await = Some(token_result.clone());
569+
570+
// Store credentials in the credential store
571+
let client_id = oauth_client.client_id().to_string();
572+
let stored = StoredCredentials {
573+
client_id,
574+
token_response: Some(token_result.clone()),
575+
};
576+
self.credential_store.save(stored).await?;
484577

485578
Ok(token_result)
486579
}
487580

488581
/// get access token, if expired, refresh it automatically
489582
pub async fn get_access_token(&self) -> Result<String, AuthError> {
490-
let credentials = self.credentials.read().await;
583+
// Load credentials from store
584+
let stored = self.credential_store.load().await?;
585+
let credentials = stored.and_then(|s| s.token_response);
491586

492587
if let Some(creds) = credentials.as_ref() {
493-
// check if the token is expire
494588
let expires_in = creds.expires_in().unwrap_or(Duration::from_secs(0));
495589
if expires_in <= Duration::from_secs(0) {
496590
tracing::info!("Access token expired, refreshing.");
497-
// token expired, try to refresh , release the lock
498-
drop(credentials);
499591

500592
let new_creds = self.refresh_token().await?;
501593
tracing::info!("Refreshed access token.");
@@ -517,26 +609,28 @@ impl AuthorizationManager {
517609
.as_ref()
518610
.ok_or_else(|| AuthError::InternalError("OAuth client not configured".to_string()))?;
519611

520-
let current_credentials = self
521-
.credentials
522-
.read()
523-
.await
524-
.clone()
612+
let stored = self.credential_store.load().await?;
613+
let current_credentials = stored
614+
.and_then(|s| s.token_response)
525615
.ok_or_else(|| AuthError::AuthorizationRequired)?;
526616

527617
let refresh_token = current_credentials.refresh_token().ok_or_else(|| {
528618
AuthError::TokenRefreshFailed("No refresh token available".to_string())
529619
})?;
530620
debug!("refresh token: {:?}", refresh_token);
531-
// refresh token
621+
532622
let token_result = oauth_client
533623
.exchange_refresh_token(&RefreshToken::new(refresh_token.secret().to_string()))
534624
.request_async(&self.http_client)
535625
.await
536626
.map_err(|e| AuthError::TokenRefreshFailed(e.to_string()))?;
537627

538-
// store new credentials
539-
*self.credentials.write().await = Some(token_result.clone());
628+
let client_id = oauth_client.client_id().to_string();
629+
let stored = StoredCredentials {
630+
client_id,
631+
token_response: Some(token_result.clone()),
632+
};
633+
self.credential_store.save(stored).await?;
540634

541635
Ok(token_result)
542636
}
@@ -1003,14 +1097,15 @@ impl OAuthState {
10031097
AuthorizationManager::new(DEFAULT_EXCHANGE_URL).await?,
10041098
);
10051099

1006-
// write credentials
1007-
*manager.credentials.write().await = Some(credentials);
1100+
let stored = StoredCredentials {
1101+
client_id: client_id.to_string(),
1102+
token_response: Some(credentials),
1103+
};
1104+
manager.credential_store.save(stored).await?;
10081105

1009-
// discover metadata
10101106
let metadata = manager.discover_metadata().await?;
10111107
manager.metadata = Some(metadata);
10121108

1013-
// set client id and secret
10141109
manager.configure_client_id(client_id)?;
10151110

10161111
*self = OAuthState::Authorized(manager);

0 commit comments

Comments
 (0)