11use std:: { collections:: HashMap , sync:: Arc , time:: Duration } ;
22
3+ use async_trait:: async_trait;
34use oauth2:: {
45 AuthUrl , AuthorizationCode , ClientId , ClientSecret , CsrfToken , EmptyExtraTokenFields ,
56 PkceCodeChallenge , PkceCodeVerifier , RedirectUrl , RefreshToken , RequestTokenError , Scope ,
@@ -17,6 +18,62 @@ use tracing::{debug, error, warn};
1718
1819const DEFAULT_EXCHANGE_URL : & str = "http://localhost" ;
1920
21+ /// Stored credentials for OAuth2 authorization
22+ #[ derive( Debug , Clone , Serialize , Deserialize ) ]
23+ pub struct StoredCredentials {
24+ pub client_id : String ,
25+ pub token_response : Option < OAuthTokenResponse > ,
26+ }
27+
28+ /// Trait for storing and retrieving OAuth2 credentials
29+ ///
30+ /// Implementations of this trait can provide custom storage backends
31+ /// for OAuth2 credentials, such as file-based storage, keychain integration,
32+ /// or database storage.
33+ #[ async_trait]
34+ pub trait CredentialStore : Send + Sync {
35+ async fn load ( & self ) -> Result < Option < StoredCredentials > , AuthError > ;
36+
37+ async fn save ( & self , credentials : StoredCredentials ) -> Result < ( ) , AuthError > ;
38+
39+ async fn clear ( & self ) -> Result < ( ) , AuthError > ;
40+ }
41+
42+ /// In-memory credential store (default implementation)
43+ ///
44+ /// This store keeps credentials in memory only and does not persist them
45+ /// between application restarts. This is the default behavior when no
46+ /// custom credential store is provided.
47+ #[ derive( Debug , Default , Clone ) ]
48+ pub struct InMemoryCredentialStore {
49+ credentials : Arc < RwLock < Option < StoredCredentials > > > ,
50+ }
51+
52+ impl InMemoryCredentialStore {
53+ pub fn new ( ) -> Self {
54+ Self {
55+ credentials : Arc :: new ( RwLock :: new ( None ) ) ,
56+ }
57+ }
58+ }
59+
60+ #[ async_trait:: async_trait]
61+ impl CredentialStore for InMemoryCredentialStore {
62+ async fn load ( & self ) -> Result < Option < StoredCredentials > , AuthError > {
63+ Ok ( self . credentials . read ( ) . await . clone ( ) )
64+ }
65+
66+ async fn save ( & self , credentials : StoredCredentials ) -> Result < ( ) , AuthError > {
67+ * self . credentials . write ( ) . await = Some ( credentials) ;
68+ Ok ( ( ) )
69+ }
70+
71+ async fn clear ( & self ) -> Result < ( ) , AuthError > {
72+ * self . credentials . write ( ) . await = None ;
73+ Ok ( ( ) )
74+ }
75+ }
76+
2077/// sse client with oauth2 authorization
2178#[ derive( Clone ) ]
2279pub struct AuthClient < C > {
@@ -151,7 +208,7 @@ pub struct AuthorizationManager {
151208 http_client : HttpClient ,
152209 metadata : Option < AuthorizationMetadata > ,
153210 oauth_client : Option < OAuthClient > ,
154- credentials : RwLock < Option < OAuthTokenResponse > > ,
211+ credential_store : Arc < dyn CredentialStore > ,
155212 state : RwLock < Option < AuthorizationState > > ,
156213 base_url : Url ,
157214}
@@ -222,14 +279,42 @@ impl AuthorizationManager {
222279 http_client,
223280 metadata : None ,
224281 oauth_client : None ,
225- credentials : RwLock :: new ( None ) ,
282+ credential_store : Arc :: new ( InMemoryCredentialStore :: new ( ) ) ,
226283 state : RwLock :: new ( None ) ,
227284 base_url,
228285 } ;
229286
230287 Ok ( manager)
231288 }
232289
290+ /// Set a custom credential store
291+ ///
292+ /// This allows you to provide your own implementation of credential storage,
293+ /// such as file-based storage, keychain integration, or database storage.
294+ /// This should be called before any operations that need credentials.
295+ pub fn set_credential_store < S : CredentialStore + ' static > ( & mut self , store : S ) {
296+ self . credential_store = Arc :: new ( store) ;
297+ }
298+
299+ /// Initialize from stored credentials if available
300+ ///
301+ /// This will load credentials from the credential store and configure
302+ /// the client if credentials are found.
303+ pub async fn initialize_from_store ( & mut self ) -> Result < bool , AuthError > {
304+ if let Some ( stored) = self . credential_store . load ( ) . await ? {
305+ if stored. token_response . is_some ( ) {
306+ if self . metadata . is_none ( ) {
307+ let metadata = self . discover_metadata ( ) . await ?;
308+ self . metadata = Some ( metadata) ;
309+ }
310+
311+ self . configure_client_id ( & stored. client_id ) ?;
312+ return Ok ( true ) ;
313+ }
314+ }
315+ Ok ( false )
316+ }
317+
233318 pub fn with_client ( & mut self , http_client : HttpClient ) -> Result < ( ) , AuthError > {
234319 self . http_client = http_client;
235320 Ok ( ( ) )
@@ -252,13 +337,16 @@ impl AuthorizationManager {
252337
253338 /// get client id and credentials
254339 pub async fn get_credentials ( & self ) -> Result < Credentials , AuthError > {
255- let credentials = self . credentials . read ( ) . await ;
256340 let client_id = self
257341 . oauth_client
258342 . as_ref ( )
259343 . ok_or_else ( || AuthError :: InternalError ( "OAuth client not configured" . to_string ( ) ) ) ?
260344 . client_id ( ) ;
261- Ok ( ( client_id. to_string ( ) , credentials. clone ( ) ) )
345+
346+ let stored = self . credential_store . load ( ) . await ?;
347+ let token_response = stored. and_then ( |s| s. token_response ) ;
348+
349+ Ok ( ( client_id. to_string ( ) , token_response) )
262350 }
263351
264352 /// configure oauth2 client with client credentials
@@ -309,7 +397,6 @@ impl AuthorizationManager {
309397 ) ) ;
310398 } ;
311399
312- // prepare registration request
313400 let registration_request = ClientRegistrationRequest {
314401 client_name : name. to_string ( ) ,
315402 redirect_uris : vec ! [ redirect_uri. to_string( ) ] ,
@@ -479,23 +566,28 @@ impl AuthorizationManager {
479566 } ;
480567
481568 debug ! ( "exchange token result: {:?}" , token_result) ;
482- // store credentials
483- * self . credentials . write ( ) . await = Some ( token_result. clone ( ) ) ;
569+
570+ // Store credentials in the credential store
571+ let client_id = oauth_client. client_id ( ) . to_string ( ) ;
572+ let stored = StoredCredentials {
573+ client_id,
574+ token_response : Some ( token_result. clone ( ) ) ,
575+ } ;
576+ self . credential_store . save ( stored) . await ?;
484577
485578 Ok ( token_result)
486579 }
487580
488581 /// get access token, if expired, refresh it automatically
489582 pub async fn get_access_token ( & self ) -> Result < String , AuthError > {
490- let credentials = self . credentials . read ( ) . await ;
583+ // Load credentials from store
584+ let stored = self . credential_store . load ( ) . await ?;
585+ let credentials = stored. and_then ( |s| s. token_response ) ;
491586
492587 if let Some ( creds) = credentials. as_ref ( ) {
493- // check if the token is expire
494588 let expires_in = creds. expires_in ( ) . unwrap_or ( Duration :: from_secs ( 0 ) ) ;
495589 if expires_in <= Duration :: from_secs ( 0 ) {
496590 tracing:: info!( "Access token expired, refreshing." ) ;
497- // token expired, try to refresh , release the lock
498- drop ( credentials) ;
499591
500592 let new_creds = self . refresh_token ( ) . await ?;
501593 tracing:: info!( "Refreshed access token." ) ;
@@ -517,26 +609,28 @@ impl AuthorizationManager {
517609 . as_ref ( )
518610 . ok_or_else ( || AuthError :: InternalError ( "OAuth client not configured" . to_string ( ) ) ) ?;
519611
520- let current_credentials = self
521- . credentials
522- . read ( )
523- . await
524- . clone ( )
612+ let stored = self . credential_store . load ( ) . await ?;
613+ let current_credentials = stored
614+ . and_then ( |s| s. token_response )
525615 . ok_or_else ( || AuthError :: AuthorizationRequired ) ?;
526616
527617 let refresh_token = current_credentials. refresh_token ( ) . ok_or_else ( || {
528618 AuthError :: TokenRefreshFailed ( "No refresh token available" . to_string ( ) )
529619 } ) ?;
530620 debug ! ( "refresh token: {:?}" , refresh_token) ;
531- // refresh token
621+
532622 let token_result = oauth_client
533623 . exchange_refresh_token ( & RefreshToken :: new ( refresh_token. secret ( ) . to_string ( ) ) )
534624 . request_async ( & self . http_client )
535625 . await
536626 . map_err ( |e| AuthError :: TokenRefreshFailed ( e. to_string ( ) ) ) ?;
537627
538- // store new credentials
539- * self . credentials . write ( ) . await = Some ( token_result. clone ( ) ) ;
628+ let client_id = oauth_client. client_id ( ) . to_string ( ) ;
629+ let stored = StoredCredentials {
630+ client_id,
631+ token_response : Some ( token_result. clone ( ) ) ,
632+ } ;
633+ self . credential_store . save ( stored) . await ?;
540634
541635 Ok ( token_result)
542636 }
@@ -1003,14 +1097,15 @@ impl OAuthState {
10031097 AuthorizationManager :: new ( DEFAULT_EXCHANGE_URL ) . await ?,
10041098 ) ;
10051099
1006- // write credentials
1007- * manager. credentials . write ( ) . await = Some ( credentials) ;
1100+ let stored = StoredCredentials {
1101+ client_id : client_id. to_string ( ) ,
1102+ token_response : Some ( credentials) ,
1103+ } ;
1104+ manager. credential_store . save ( stored) . await ?;
10081105
1009- // discover metadata
10101106 let metadata = manager. discover_metadata ( ) . await ?;
10111107 manager. metadata = Some ( metadata) ;
10121108
1013- // set client id and secret
10141109 manager. configure_client_id ( client_id) ?;
10151110
10161111 * self = OAuthState :: Authorized ( manager) ;
0 commit comments