diff --git a/src/auth/src/credentials/external_account.rs b/src/auth/src/credentials/external_account.rs index f1a74aa1bb..753e1be54b 100644 --- a/src/auth/src/credentials/external_account.rs +++ b/src/auth/src/credentials/external_account.rs @@ -1959,7 +1959,7 @@ mod tests { .unwrap(); let err = creds.headers(Extensions::new()).await.unwrap_err(); - assert!(!err.is_transient()); + assert!(err.is_transient()); sts_server.verify_and_clear(); subject_token_server.verify_and_clear(); } @@ -2084,7 +2084,7 @@ mod tests { .unwrap(); let err = creds.headers(Extensions::new()).await.unwrap_err(); - assert!(!err.is_transient()); + assert!(err.is_transient()); sts_server.verify_and_clear(); } diff --git a/src/auth/src/credentials/idtoken/user_account.rs b/src/auth/src/credentials/idtoken/user_account.rs index 6ff45142eb..75a559d508 100644 --- a/src/auth/src/credentials/idtoken/user_account.rs +++ b/src/auth/src/credentials/idtoken/user_account.rs @@ -422,7 +422,7 @@ mod tests { .build()?; let err = credentials.id_token().await.unwrap_err(); - assert!(!err.is_transient()); + assert!(err.is_transient()); Ok(()) } diff --git a/src/auth/src/credentials/mds.rs b/src/auth/src/credentials/mds.rs index 864f0e8c01..84c0e14918 100644 --- a/src/auth/src/credentials/mds.rs +++ b/src/auth/src/credentials/mds.rs @@ -483,7 +483,7 @@ mod tests { .build_token_provider(); let err = provider.token().await.unwrap_err(); - assert!(!err.is_transient()); + assert!(err.is_transient()); server.verify_and_clear(); Ok(()) } diff --git a/src/auth/src/credentials/user_account.rs b/src/auth/src/credentials/user_account.rs index df0858d309..e7d6d37f6d 100644 --- a/src/auth/src/credentials/user_account.rs +++ b/src/auth/src/credentials/user_account.rs @@ -611,7 +611,7 @@ mod tests { .build()?; let err = credentials.headers(Extensions::new()).await.unwrap_err(); - assert!(!err.is_transient()); + assert!(err.is_transient()); server.verify_and_clear(); Ok(()) } diff --git a/src/auth/src/retry.rs 
b/src/auth/src/retry.rs index dbc6534d12..85e57d4831 100644 --- a/src/auth/src/retry.rs +++ b/src/auth/src/retry.rs @@ -133,14 +133,15 @@ where return CredentialsError::from_source(false, e); } - let msg = match e + match e .source() .and_then(|s| s.downcast_ref::<CredentialsError>()) { - Some(cred_error) if cred_error.is_transient() => constants::RETRY_EXHAUSTED_ERROR, - _ => constants::TOKEN_FETCH_FAILED_ERROR, - }; - CredentialsError::new(false, msg, e) + Some(cred_error) if cred_error.is_transient() => { + CredentialsError::new(true, constants::RETRY_EXHAUSTED_ERROR, e) + } + _ => CredentialsError::new(false, constants::TOKEN_FETCH_FAILED_ERROR, e), + } } } @@ -286,7 +287,7 @@ mod tests { .build(mock_provider); let error = provider.token().await.unwrap_err(); - assert!(!error.is_transient()); + assert!(error.is_transient()); let original_error = find_source_error::<CredentialsError>(&error).unwrap(); assert!(original_error.is_transient()); assert!(error.to_string().contains(constants::RETRY_EXHAUSTED_ERROR)); @@ -350,7 +351,7 @@ mod tests { let provider = Builder::default().build(mock_provider); let error = provider.token().await.unwrap_err(); - assert!(!error.is_transient()); + assert!(error.is_transient()); let original_error = find_source_error::<CredentialsError>(&error).unwrap(); assert!(original_error.is_transient()); } @@ -493,6 +494,26 @@ mod tests { ); } + #[test_case(false, "invalid credentials"; "permanent auth error")] + #[test_case(true, "transient network error"; "transient auth error")] + fn test_map_retry_error_auth_error(transient: bool, message: &str) { + // 1. Create an authentication error. + let error = CredentialsError::from_msg(transient, message); + let error = gax::error::Error::authentication(error); + let error_string = error.to_string(); + + // 2. Call the function under test. + let credentials_error = + TokenProviderWithRetry::<MockTokenProvider>::map_retry_error(error); + + // 3. Assert that the resulting error is transient or not like the original error and wraps the original error. 
+ assert_eq!(credentials_error.is_transient(), transient); + assert_eq!( + credentials_error.source().unwrap().to_string(), + error_string + ); + } + #[test] fn test_unwind_safe() { assert_impl_all!(Builder: std::panic::UnwindSafe, std::panic::RefUnwindSafe); diff --git a/src/auth/src/token_cache.rs b/src/auth/src/token_cache.rs index 1aa8af2c75..a665c2f6e0 100644 --- a/src/auth/src/token_cache.rs +++ b/src/auth/src/token_cache.rs @@ -101,8 +101,8 @@ async fn refresh_task( { loop { let token_result = token_provider.token().await; - let expiry = token_result.as_ref().ok().map(|t| t.expires_at); - let tagged = token_result.map(|token| { + let expiry = token_result.as_ref().map(|t| t.expires_at); + let tagged = token_result.clone().map(|token| { let entity_tag = EntityTag::new(); (token, entity_tag) }); @@ -110,7 +110,7 @@ async fn refresh_task( let _ = tx_token.send(Some(tagged)); match expiry { - Some(Some(expiry)) => { + Ok(Some(expiry)) => { let time_until_expiry = expiry.checked_duration_since(Instant::now()); match time_until_expiry { @@ -129,15 +129,20 @@ async fn refresh_task( } } } - Some(None) => { + Ok(None) => { // If there is no expiry, the token is valid forever, so no need to refresh // TODO(#1553): Validate that all auth backends provide expiry and make expiry not optional. break; } - None => { + Err(err) => { // The retry policy has been used already by the inner token provider. - // If it ended in an error, just quit the background task. - break; + // If it ended in an error, the background task will wait for a while + // and try again. This allows the task to eventually recover if the + // error was transient but exhausted all the retries. 
+ if !err.is_transient() { + break; + } + sleep(SHORT_REFRESH_SLACK).await; } } } @@ -596,6 +601,64 @@ mod tests { assert_eq!(actual, token2); } + #[tokio::test(start_paused = true)] + async fn refresh_task_sleeps_on_transient_error_and_recovers_on_next_loop() -> TestResult { + let now = Instant::now(); + + let token = Token { + token: "token-1".to_string(), + token_type: "Bearer".to_string(), + expires_at: Some(now + TOKEN_VALID_DURATION), + metadata: None, + }; + + let mut mock = MockTokenProvider::new(); + // 1st request succeeds + mock.expect_token() + .times(1) + .return_once(move || Ok(token.clone())); + + // 2nd request (triggered by refresh loop) fails with transient error + mock.expect_token() + .times(1) + .return_once(|| Err(CredentialsError::from_msg(true, "transient error"))); + + let token = Token { + token: "token-2".to_string(), + token_type: "Bearer".to_string(), + expires_at: Some(now + 2 * TOKEN_VALID_DURATION), + metadata: None, + }; + + // 3rd request (triggered by next loop) succeeds + mock.expect_token() + .times(1) + .return_once(move || Ok(token.clone())); + + let cache = TokenCache::new(mock); + + // fetch an initial token + let actual = get_cached_token(cache.token(Extensions::new()).await.unwrap())?; + assert_eq!(actual.token, "token-1"); + + // advance time to force expiration, which wakes up the background task. 
+ let sleep = TOKEN_VALID_DURATION.add(Duration::from_secs(10)); + tokio::time::advance(sleep).await; + + let result = cache.token(Extensions::new()).await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("transient error")); + + // Wait another SHORT_REFRESH_SLACK + buffer for the background loop to try again and recover + tokio::time::advance(SHORT_REFRESH_SLACK.add(Duration::from_secs(10))).await; + tokio::task::yield_now().await; + + let actual = get_cached_token(cache.token(Extensions::new()).await.unwrap())?; + assert_eq!(actual.token, "token-2"); + + Ok(()) + } + #[derive(Clone, Debug)] struct FakeTokenProvider { result: Result, diff --git a/src/auth/tests/credentials.rs b/src/auth/tests/credentials.rs index 2d02ccb70a..c827398b16 100644 --- a/src/auth/tests/credentials.rs +++ b/src/auth/tests/credentials.rs @@ -965,4 +965,92 @@ mod tests { Ok(()) } + + #[tokio::test(start_paused = true)] + #[serial] + async fn test_credentials_refresh_recovers_after_outage() -> TestResult { + let server = Server::run(); + let addr = server.addr().to_string(); + let _e = ScopedEnv::set("GCE_METADATA_HOST", &addr); + + // initial token request: success (200 OK) + server.expect( + Expectation::matching(request::path( + "/computeMetadata/v1/instance/service-accounts/default/token", + )) + .times(1) + .respond_with(json_encoded(json!({ + "access_token": "token-1", + "expires_in": 1, // short lived to trigger refresh soon + "token_type": "Bearer" + }))), + ); + + let creds = MdsBuilder::default().build_access_token_credentials()?; + + // get initial token, this starts the refresh_task + let access_token = creds.access_token().await?; + assert!( + access_token.token.contains("token-1"), + "Expected token-1, got {}", + access_token.token + ); + + // set up outage: fail (503 Service Unavailable) + server.expect( + Expectation::matching(request::path( + "/computeMetadata/v1/instance/service-accounts/default/token", + )) + .times(1..) 
// called at least once + .respond_with(status_code(503)), + ); + + // wait for the token to be refreshed + // it should exhaust retries, fail, and (with the fix) wait for SHORT_REFRESH_SLACK (10s) + tokio::time::advance(std::time::Duration::from_millis(3000)).await; + tokio::task::yield_now().await; + + // trying to get a token now should fail because retry was exhausted + let result = creds.headers(Extensions::new()).await; + assert!( + result.is_err(), + "expected error due to exhausted retries during outage, but got: {:?}", + result + ); + + // set up recovery: success (200 OK) + server.expect( + Expectation::matching(request::path( + "/computeMetadata/v1/instance/service-accounts/default/token", + )) + .respond_with(json_encoded(json!({ + "access_token": "token-2", + "expires_in": 3600, + "token_type": "Bearer" + }))), + ); + + // advance time long enough to pass through SHORT_REFRESH_SLACK (10s + some buffer) + tokio::time::advance(std::time::Duration::from_secs(12)).await; + + // yield tasks to let the refresh task run and http request layers work + for _ in 0..100 { + tokio::task::yield_now().await; + let result = creds.headers(Extensions::new()).await; + if result.is_ok() { + break; + } + } + + let result = creds.access_token().await; + // if it recovered, we should get token-2 + let access_token = result.expect("the credential should have recovered from the outage!"); + assert!( + access_token.token.contains("token-2"), + "Expected token-2 after recovery, but got: {}", + access_token.token + ); + + Ok(()) + } }