Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/auth/src/credentials/external_account.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1959,7 +1959,7 @@ mod tests {
.unwrap();

let err = creds.headers(Extensions::new()).await.unwrap_err();
assert!(!err.is_transient());
assert!(err.is_transient());
sts_server.verify_and_clear();
subject_token_server.verify_and_clear();
}
Expand Down Expand Up @@ -2084,7 +2084,7 @@ mod tests {
.unwrap();

let err = creds.headers(Extensions::new()).await.unwrap_err();
assert!(!err.is_transient());
assert!(err.is_transient());
sts_server.verify_and_clear();
}

Expand Down
2 changes: 1 addition & 1 deletion src/auth/src/credentials/idtoken/user_account.rs
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ mod tests {
.build()?;

let err = credentials.id_token().await.unwrap_err();
assert!(!err.is_transient());
assert!(err.is_transient());

Ok(())
}
Expand Down
2 changes: 1 addition & 1 deletion src/auth/src/credentials/mds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ mod tests {
.build_token_provider();

let err = provider.token().await.unwrap_err();
assert!(!err.is_transient());
assert!(err.is_transient());
server.verify_and_clear();
Ok(())
}
Expand Down
2 changes: 1 addition & 1 deletion src/auth/src/credentials/user_account.rs
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,7 @@ mod tests {
.build()?;

let err = credentials.headers(Extensions::new()).await.unwrap_err();
assert!(!err.is_transient());
assert!(err.is_transient());
server.verify_and_clear();
Ok(())
}
Expand Down
15 changes: 8 additions & 7 deletions src/auth/src/retry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,15 @@ where
return CredentialsError::from_source(false, e);
}

let msg = match e
match e
.source()
.and_then(|s| s.downcast_ref::<CredentialsError>())
{
Some(cred_error) if cred_error.is_transient() => constants::RETRY_EXHAUSTED_ERROR,
_ => constants::TOKEN_FETCH_FAILED_ERROR,
};
CredentialsError::new(false, msg, e)
Some(cred_error) if cred_error.is_transient() => {
CredentialsError::new(true, constants::RETRY_EXHAUSTED_ERROR, e)
}
_ => CredentialsError::new(false, constants::TOKEN_FETCH_FAILED_ERROR, e),
}
}
}

Expand Down Expand Up @@ -286,7 +287,7 @@ mod tests {
.build(mock_provider);

let error = provider.token().await.unwrap_err();
assert!(!error.is_transient());
assert!(error.is_transient());
let original_error = find_source_error::<CredentialsError>(&error).unwrap();
assert!(original_error.is_transient());
assert!(error.to_string().contains(constants::RETRY_EXHAUSTED_ERROR));
Expand Down Expand Up @@ -350,7 +351,7 @@ mod tests {
let provider = Builder::default().build(mock_provider);

let error = provider.token().await.unwrap_err();
assert!(!error.is_transient());
assert!(error.is_transient());
let original_error = find_source_error::<CredentialsError>(&error).unwrap();
assert!(original_error.is_transient());
}
Expand Down
77 changes: 70 additions & 7 deletions src/auth/src/token_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,16 +101,16 @@ async fn refresh_task<T>(
{
loop {
let token_result = token_provider.token().await;
let expiry = token_result.as_ref().ok().map(|t| t.expires_at);
let tagged = token_result.map(|token| {
let expiry = token_result.as_ref().map(|t| t.expires_at);
let tagged = token_result.clone().map(|token| {
let entity_tag = EntityTag::new();
(token, entity_tag)
});

let _ = tx_token.send(Some(tagged));

match expiry {
Some(Some(expiry)) => {
Ok(Some(expiry)) => {
let time_until_expiry = expiry.checked_duration_since(Instant::now());

match time_until_expiry {
Expand All @@ -129,15 +129,20 @@ async fn refresh_task<T>(
}
}
}
Some(None) => {
Ok(None) => {
// If there is no expiry, the token is valid forever, so no need to refresh
// TODO(#1553): Validate that all auth backends provide expiry and make expiry not optional.
break;
}
None => {
Err(err) => {
// The retry policy has been used already by the inner token provider.
// If it ended in an error, just quit the background task.
break;
// If it ended in an error, the background task will wait for a while
// and try again. This allows the task to eventually recover if the
// error was transient but exhausted all the retries.
if !err.is_transient() {
break;
}
sleep(SHORT_REFRESH_SLACK).await;
}
}
}
Expand Down Expand Up @@ -596,6 +601,64 @@ mod tests {
assert_eq!(actual, token2);
}

#[tokio::test(start_paused = true)]
async fn refresh_task_sleeps_on_transient_error_and_recovers_on_next_loop() -> TestResult {
let now = Instant::now();

let token = Token {
token: "token-1".to_string(),
token_type: "Bearer".to_string(),
expires_at: Some(now + TOKEN_VALID_DURATION),
metadata: None,
};

let mut mock = MockTokenProvider::new();
// 1st request succeeds
mock.expect_token()
.times(1)
.return_once(move || Ok(token.clone()));

// 2nd request (triggered by refresh loop) fails with transient error
mock.expect_token()
.times(1)
.return_once(|| Err(CredentialsError::from_msg(true, "transient error")));

let token = Token {
token: "token-2".to_string(),
token_type: "Bearer".to_string(),
expires_at: Some(now + 2 * TOKEN_VALID_DURATION),
metadata: None,
};

// 3rd request (triggered by next loop) succeeds
mock.expect_token()
.times(1)
.return_once(move || Ok(token.clone()));

let cache = TokenCache::new(mock);

// fetch an initial token
let actual = get_cached_token(cache.token(Extensions::new()).await.unwrap())?;
assert_eq!(actual.token, "token-1");

// advance time to force expiration, which wakes up the background task.
let sleep = TOKEN_VALID_DURATION.add(Duration::from_secs(10));
tokio::time::advance(sleep).await;

let result = cache.token(Extensions::new()).await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("transient error"));

// Wait another SHORT_REFRESH_SLACK + buffer for the background loop to try again and recover
tokio::time::advance(SHORT_REFRESH_SLACK.add(Duration::from_secs(10))).await;
tokio::task::yield_now().await;

let actual = get_cached_token(cache.token(Extensions::new()).await.unwrap())?;
assert_eq!(actual.token, "token-2");

Ok(())
}

#[derive(Clone, Debug)]
struct FakeTokenProvider {
result: Result<Token>,
Expand Down
77 changes: 77 additions & 0 deletions src/auth/tests/credentials.rs
Original file line number Diff line number Diff line change
Expand Up @@ -965,4 +965,81 @@ mod tests {

Ok(())
}

#[tokio::test]
#[serial]
async fn test_credentials_refresh_recovers_after_outage() -> TestResult {
let server = Server::run();
let addr = server.addr().to_string();
let _e = ScopedEnv::set("GCE_METADATA_HOST", &addr);

// initial token request: success (200 OK)
server.expect(
Expectation::matching(request::path(
"/computeMetadata/v1/instance/service-accounts/default/token",
))
.times(1)
.respond_with(json_encoded(json!({
"access_token": "token-1",
"expires_in": 1, // short lived to trigger refresh soon
"token_type": "Bearer"
}))),
);

let creds = MdsBuilder::default().build_access_token_credentials()?;

// get initial token, this starts the refresh_task
let access_token = creds.access_token().await?;
assert!(
access_token.token.contains("token-1"),
"Expected token-1, got {}",
access_token.token
);

// set up outage: fail (503 Service Unavailable)
server.expect(
Expectation::matching(request::path(
"/computeMetadata/v1/instance/service-accounts/default/token",
))
.times(1..) // called at least once
.respond_with(status_code(503)),
);

// wait for the token to be refreshed
// it should exhaust retries, fail, and (with the fix) wait for SHORT_REFRESH_SLACK (10s)
tokio::time::sleep(std::time::Duration::from_millis(3000)).await;

// trying to get a token now should fail because retry was exhausted
let result = creds.headers(Extensions::new()).await;
assert!(
result.is_err(),
"expected error due to exhausted retries during outage, but got: {:?}",
result
);

// set up recovery: success (200 OK)
server.expect(
Expectation::matching(request::path(
"/computeMetadata/v1/instance/service-accounts/default/token",
))
.respond_with(json_encoded(json!({
"access_token": "token-2",
"expires_in": 3600,
"token_type": "Bearer"
}))),
);

// advance time long enough to pass through SHORT_REFRESH_SLACK (10s + some buffer)
tokio::time::sleep(std::time::Duration::from_secs(12)).await;
let result = creds.access_token().await;
// if it recovered, we should get token-2
let access_token = result.expect("the credential should have recovered from the outage!");
assert!(
access_token.token.contains("token-2"),
"Expected token-2 after recovery, but got: {}",
access_token.token
);

Ok(())
}
}
Loading