diff --git a/Cargo.lock b/Cargo.lock index ff3c643..40d1d4c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -419,13 +419,15 @@ dependencies = [ ] [[package]] -name = "aws-smithy-mocks-experimental" -version = "0.2.4" +name = "aws-smithy-mocks" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ce8a35535906a8a9ceadbe7ff70ae8686a36f7df03b288b1256c084a5c45c69" +checksum = "178b1ad961028a58d48ce857f86ffbd5233a4b7e2c7b56d026fb1c1afe46696e" dependencies = [ + "aws-smithy-http-client", "aws-smithy-runtime-api", "aws-smithy-types", + "http 1.3.1", ] [[package]] @@ -590,7 +592,7 @@ version = "2.0.0" dependencies = [ "aws-config", "aws-sdk-secretsmanager", - "aws-smithy-mocks-experimental", + "aws-smithy-mocks", "aws-smithy-runtime", "aws-smithy-runtime-api", "aws-smithy-types", diff --git a/aws_secretsmanager_caching/Cargo.toml b/aws_secretsmanager_caching/Cargo.toml index 631ab7c..b2b4e8e 100644 --- a/aws_secretsmanager_caching/Cargo.toml +++ b/aws_secretsmanager_caching/Cargo.toml @@ -23,7 +23,7 @@ rustls = "0" log = "0.4.20" [dev-dependencies] -aws-smithy-mocks-experimental = "0" +aws-smithy-mocks = "0.1" aws-smithy-runtime = { version = "1", features = ["test-util", "wire-mock"] } aws-sdk-secretsmanager = { version = "1", features = ["test-util"] } tokio = { version = "1", features = ["macros", "rt", "sync", "test-util"] } diff --git a/aws_secretsmanager_caching/src/lib.rs b/aws_secretsmanager_caching/src/lib.rs index 0704f74..25070e2 100644 --- a/aws_secretsmanager_caching/src/lib.rs +++ b/aws_secretsmanager_caching/src/lib.rs @@ -430,37 +430,49 @@ impl SecretsManagerCachingClient { counter.load(Ordering::Relaxed) } } - #[cfg(test)] mod tests { + use aws_sdk_secretsmanager::{ + config::http::HttpResponse, + operation::{ + describe_secret::{DescribeSecretError, DescribeSecretOutput}, + get_secret_value::{GetSecretValueError, GetSecretValueOutput}, + }, + types::error::ResourceNotFoundException, + }; + use aws_smithy_mocks::{mock, mock_client, RuleMode}; + use aws_smithy_types::{body::SdkBody, error::ErrorMetadata}; use tokio::time::sleep; use super::*; - use aws_smithy_runtime_api::client::http::SharedHttpClient; - - fn fake_client( - ttl: Option, - ignore_transient_errors: bool, - http_client: Option, - endpoint_url: Option, - ) -> SecretsManagerCachingClient { - SecretsManagerCachingClient::new( - asm_mock::def_fake_client(http_client, endpoint_url), - NonZeroUsize::new(1000).unwrap(), - match ttl { - Some(ttl) => ttl, - None => Duration::from_secs(1000), - }, - ignore_transient_errors, - ) - .expect("client should create") - } + use aws_smithy_runtime_api::{client::result::SdkError, http::StatusCode}; #[tokio::test] async fn test_get_secret_value() { - let client = fake_client(None, false, None, None); let secret_id = "test_secret"; + let arn = "arn"; + + let gsv = mock!(aws_sdk_secretsmanager::Client::get_secret_value) + .match_requests(|req| req.secret_id() == Some(secret_id)) + .then_output(move || { + GetSecretValueOutput::builder() + .name(secret_id) + .arn(arn) + .secret_string("hunter2") + .version_stages("AWSCURRENT") + .build() + }); + + let asm_mock = mock_client!(aws_sdk_secretsmanager, [&gsv]); + + let client = SecretsManagerCachingClient::new( + asm_mock, + NonZeroUsize::new(1000).unwrap(), + Duration::from_secs(1000), + true, + ) + .unwrap(); let response = client .get_secret_value(secret_id, None, None, false) @@ -469,53 +481,87 @@ mod tests { assert_eq!(response.name, Some(secret_id.to_string())); assert_eq!(response.secret_string, Some("hunter2".to_string())); - assert_eq!( - response.arn, - Some( - asm_mock::FAKE_ARN - .replace("{{name}}", secret_id) - .to_string() - ) - ); + assert_eq!(response.arn, Some(arn.into())); assert_eq!( response.version_stages, Some(vec!["AWSCURRENT".to_string()]) ); + assert_eq!(gsv.num_calls(), 1) } #[tokio::test] async fn test_get_secret_value_version_id() { - let client = fake_client(None, false, None, None); let secret_id = "test_secret"; let version_id = "test_version"; + let arn = "arn"; + + let gsv = mock!(aws_sdk_secretsmanager::Client::get_secret_value) + .match_requests(|req| { + req.secret_id() == Some(secret_id) && req.version_id() == Some(version_id) + }) + .then_output(move || { + GetSecretValueOutput::builder() + .name(secret_id) + .arn(arn) + .secret_string("hunter2") + .version_id(version_id) + .version_stages("AWSCURRENT") + .build() + }); + + let asm_mock = mock_client!(aws_sdk_secretsmanager, [&gsv]); + + let client = SecretsManagerCachingClient::new( + asm_mock, + NonZeroUsize::new(1000).unwrap(), + Duration::from_secs(1000), + true, + ) + .unwrap(); let response = client .get_secret_value(secret_id, Some(version_id), None, false) .await .unwrap(); - assert_eq!(response.name, Some(secret_id.to_string())); assert_eq!(response.secret_string, Some("hunter2".to_string())); assert_eq!(response.version_id, Some(version_id.to_string())); - assert_eq!( - response.arn, - Some( - asm_mock::FAKE_ARN - .replace("{{name}}", secret_id) - .to_string() - ) - ); + assert_eq!(response.arn, Some(arn.into())); assert_eq!( response.version_stages, Some(vec!["AWSCURRENT".to_string()]) ); + assert_eq!(gsv.num_calls(), 1) } #[tokio::test] async fn test_get_secret_value_version_stage() { - let client = fake_client(None, false, None, None); let secret_id = "test_secret"; let stage_label = "STAGEHERE"; + let arn = "arn"; + + let gsv = mock!(aws_sdk_secretsmanager::Client::get_secret_value) + .match_requests(|req| { + req.secret_id() == Some(secret_id) && req.version_stage() == Some(stage_label) + }) + .then_output(move || { + GetSecretValueOutput::builder() + .name(secret_id) + .arn(arn) + .secret_string("hunter2") + .version_stages(stage_label) + .build() + }); + + let asm_mock = mock_client!(aws_sdk_secretsmanager, [&gsv]); + + let client = SecretsManagerCachingClient::new( + asm_mock, + NonZeroUsize::new(1000).unwrap(), + Duration::from_secs(1000), + true, + ) + .unwrap(); let response = client .get_secret_value(secret_id, None, Some(stage_label), false) @@ -524,47 +570,97 @@ mod tests { assert_eq!(response.name, Some(secret_id.to_string())); assert_eq!(response.secret_string, Some("hunter2".to_string())); - assert_eq!( - response.arn, - Some( - asm_mock::FAKE_ARN - .replace("{{name}}", secret_id) - .to_string() - ) - ); + assert_eq!(response.arn, Some(arn.into())); assert_eq!(response.version_stages, Some(vec![stage_label.to_string()])); + assert_eq!(gsv.num_calls(), 1) } #[tokio::test] async fn test_get_secret_value_version_id_and_stage() { - let client = fake_client(None, false, None, None); let secret_id = "test_secret"; let version_id = "test_version"; let stage_label = "STAGEHERE"; + let arn = "arn"; + + let gsv = mock!(aws_sdk_secretsmanager::Client::get_secret_value) + .match_requests(|req| { + req.secret_id() == Some(secret_id) + && req.version_stage() == Some(stage_label) + && req.version_id() == Some(version_id) + }) + .then_output(move || { + GetSecretValueOutput::builder() + .name(secret_id) + .arn(arn) + .secret_string("hunter2") + .version_stages(stage_label) + .version_id(version_id) + .build() + }); + + let asm_mock = mock_client!(aws_sdk_secretsmanager, [&gsv]); + + let client = SecretsManagerCachingClient::new( + asm_mock, + NonZeroUsize::new(1000).unwrap(), + Duration::from_secs(1000), + true, + ) + .unwrap(); let response = client .get_secret_value(secret_id, Some(version_id), Some(stage_label), false) .await .unwrap(); - assert_eq!(response.name, Some(secret_id.to_string())); assert_eq!(response.secret_string, Some("hunter2".to_string())); assert_eq!(response.version_id, Some(version_id.to_string())); - assert_eq!( - response.arn, - Some( - asm_mock::FAKE_ARN - .replace("{{name}}", secret_id) - .to_string() - ) - ); + assert_eq!(response.arn, Some(arn.into())); assert_eq!(response.version_stages, Some(vec![stage_label.to_string()])); + assert_eq!(gsv.num_calls(), 1) } #[tokio::test] async fn test_get_cache_expired() { - let client = fake_client(Some(Duration::from_secs(0)), false, None, None); let secret_id = "test_secret"; + let version_id = "version_id"; + let version_stage = "AWSCURRENT"; + let arn = "arn"; + + let gsv = mock!(aws_sdk_secretsmanager::Client::get_secret_value) + .match_requests(|req| req.secret_id() == Some(secret_id)) + .then_output(move || { + GetSecretValueOutput::builder() + .name(secret_id) + .arn(arn) + .secret_string("hunter2") + .version_id(version_id) + .version_stages(version_stage) + .build() + }); + + let describe_secret = + mock!(aws_sdk_secretsmanager::Client::describe_secret).then_output(move || { + // Don't serve the same value + DescribeSecretOutput::builder() + .name(secret_id) + .version_ids_to_stages("different_version_id", vec![version_stage.into()]) + .build() + }); + + let asm_mock = mock_client!( + aws_sdk_secretsmanager, + RuleMode::MatchAny, + [&gsv, &describe_secret] + ); + + let client = SecretsManagerCachingClient::new( + asm_mock, + NonZeroUsize::new(1000).unwrap(), + Duration::from_secs(0), + true, + ) + .unwrap(); // Run through this twice to test the cache expiration for i in 0..2 { @@ -575,14 +671,7 @@ mod tests { assert_eq!(response.name, Some(secret_id.to_string())); assert_eq!(response.secret_string, Some("hunter2".to_string())); - assert_eq!( - response.arn, - Some( - asm_mock::FAKE_ARN - .replace("{{name}}", secret_id) - .to_string() - ) - ); + assert_eq!(response.arn, Some(arn.into())); assert_eq!( response.version_stages, Some(vec!["AWSCURRENT".to_string()]) @@ -592,229 +681,518 @@ mod tests { sleep(Duration::from_millis(50)).await; } } + + assert_eq!(gsv.num_calls(), 2) } #[tokio::test] - #[should_panic] async fn test_get_secret_value_kms_access_denied() { - let client = fake_client(None, false, None, None); + let gsv = + mock!(aws_sdk_secretsmanager::Client::get_secret_value).then_http_response(|| { + HttpResponse::new( + StatusCode::try_from(400).unwrap(), + SdkBody::from( + r##"{ + "__type":"AccessDeniedException", + "message":"Access to KMS is not allowed" + }"##, + ), + ) + }); + + let asm_mock = mock_client!(aws_sdk_secretsmanager, &[gsv]); + + let client = SecretsManagerCachingClient::new( + asm_mock, + NonZeroUsize::new(1000).unwrap(), + Duration::from_secs(1000), + true, + ) + .unwrap(); let secret_id = "KMSACCESSDENIEDabcdef"; - client - .get_secret_value(secret_id, None, None, false) - .await - .unwrap(); + match client.get_secret_value(secret_id, None, None, false).await { + Ok(_) => panic!(), + Err(e) => e.to_string().contains("Access to KMS is not allowed"), + }; } #[tokio::test] - #[should_panic] async fn test_get_secret_value_resource_not_found() { - let client = fake_client(None, false, None, None); + let gsv = mock!(aws_sdk_secretsmanager::Client::get_secret_value).then_error(|| { + GetSecretValueError::ResourceNotFoundException( + ResourceNotFoundException::builder() + .message("Secrets Manager can't find the specified secret.") + .build(), + ) + }); + + let asm_mock = mock_client!(aws_sdk_secretsmanager, &[gsv]); + + let client = SecretsManagerCachingClient::new( + asm_mock, + NonZeroUsize::new(1000).unwrap(), + Duration::from_secs(1000), + true, + ) + .unwrap(); + let secret_id = "NOTFOUNDfasefasef"; - client - .get_secret_value(secret_id, None, None, false) - .await - .unwrap(); + match client.get_secret_value(secret_id, None, None, false).await { + Ok(_) => panic!(), + Err(e) => assert!(e + .downcast::>() + .unwrap() + .into_service_error() + .is_resource_not_found_exception()), + }; } #[tokio::test] - async fn test_is_current_default_succeeds() { - let client = fake_client(Some(Duration::from_secs(0)), false, None, None); + async fn test_get_cache_is_current_fast_refreshes() { let secret_id = "test_secret"; + let version_id = "version_id"; + let version_stage = "AWSCURRENT"; + let arn = "arn"; + + let gsv = mock!(aws_sdk_secretsmanager::Client::get_secret_value) + .match_requests(|req| req.secret_id() == Some(secret_id)) + .then_output(move || { + GetSecretValueOutput::builder() + .name(secret_id) + .arn(arn) + .secret_string("hunter2") + .version_id(version_id) + .version_stages(version_stage) + .build() + }); + + let describe_secret = + mock!(aws_sdk_secretsmanager::Client::describe_secret).then_output(move || { + // Cache is current. We fast-refresh + DescribeSecretOutput::builder() + .name(secret_id) + .version_ids_to_stages(version_id, vec![version_stage.into()]) + .build() + }); + + let asm_mock = mock_client!( + aws_sdk_secretsmanager, + RuleMode::MatchAny, + [&gsv, &describe_secret] + ); - let res1 = client - .get_secret_value(secret_id, None, None, false) - .await - .unwrap(); + let client = SecretsManagerCachingClient::new( + asm_mock, + NonZeroUsize::new(1000).unwrap(), + Duration::from_secs(0), + true, + ) + .unwrap(); - let res2 = client - .get_secret_value(secret_id, None, None, false) - .await - .unwrap(); + // Run through this twice to test the cache expiration + for i in 0..2 { + let response = client + .get_secret_value(secret_id, None, None, false) + .await + .unwrap(); - assert_eq!(res1, res2) + assert_eq!(response.name, Some(secret_id.to_string())); + assert_eq!(response.secret_string, Some("hunter2".to_string())); + assert_eq!(response.arn, Some(arn.into())); + assert_eq!( + response.version_stages, + Some(vec!["AWSCURRENT".to_string()]) + ); + // let the entry expire + if i == 0 { + sleep(Duration::from_millis(50)).await; + } + } + + assert_eq!(gsv.num_calls(), 1); + assert_eq!(describe_secret.num_calls(), 1); } #[tokio::test] async fn test_is_current_version_id_succeeds() { - let client = fake_client(Some(Duration::from_secs(0)), false, None, None); let secret_id = "test_secret"; - let version_id = Some("test_version"); + let version_id = "version_id"; + let version_stage = "AWSCURRENT"; + let arn = "arn"; + + let gsv = mock!(aws_sdk_secretsmanager::Client::get_secret_value) + .match_requests(|req| { + req.secret_id() == Some(secret_id) && req.version_id() == Some(version_id) + }) + .then_output(move || { + GetSecretValueOutput::builder() + .name(secret_id) + .arn(arn) + .secret_string("hunter2") + .version_id(version_id) + .version_stages(version_stage) + .build() + }); + + let describe_secret = + mock!(aws_sdk_secretsmanager::Client::describe_secret).then_output(move || { + // Cache is current. We fast-refresh + DescribeSecretOutput::builder() + .name(secret_id) + .version_ids_to_stages(version_id, vec![version_stage.into()]) + .build() + }); + + let asm_mock = mock_client!( + aws_sdk_secretsmanager, + RuleMode::MatchAny, + [&gsv, &describe_secret] + ); - let res1 = client - .get_secret_value(secret_id, version_id, None, false) - .await - .unwrap(); + let client = SecretsManagerCachingClient::new( + asm_mock, + NonZeroUsize::new(1000).unwrap(), + Duration::from_secs(0), + true, + ) + .unwrap(); - let res2 = client - .get_secret_value(secret_id, version_id, None, false) - .await - .unwrap(); + // Run through this twice to test the cache expiration + for i in 0..2 { + let response = client + .get_secret_value(secret_id, Some(version_id), None, false) + .await + .unwrap(); - assert_eq!(res1, res2) + assert_eq!(response.name, Some(secret_id.to_string())); + assert_eq!(response.secret_string, Some("hunter2".to_string())); + assert_eq!(response.arn, Some(arn.into())); + assert_eq!( + response.version_stages, + Some(vec!["AWSCURRENT".to_string()]) + ); + // let the entry expire + if i == 0 { + sleep(Duration::from_millis(50)).await; + } + } + + assert_eq!(gsv.num_calls(), 1); + assert_eq!(describe_secret.num_calls(), 1); } #[tokio::test] async fn test_is_current_version_stage_succeeds() { - let client = fake_client(Some(Duration::from_secs(0)), false, None, None); let secret_id = "test_secret"; - let version_stage = Some("VERSIONSTAGE"); + let version_id = "version_id"; + let version_stage = "VERSIONSTAGE"; + let arn = "arn"; + + let gsv = mock!(aws_sdk_secretsmanager::Client::get_secret_value) + .match_requests(|req| { + req.secret_id() == Some(secret_id) && req.version_stage() == Some(version_stage) + }) + .then_output(move || { + GetSecretValueOutput::builder() + .name(secret_id) + .arn(arn) + .secret_string("hunter2") + .version_id(version_id) + .version_stages(version_stage) + .build() + }); + + let describe_secret = + mock!(aws_sdk_secretsmanager::Client::describe_secret).then_output(move || { + // Cache is current. We fast-refresh + DescribeSecretOutput::builder() + .name(secret_id) + .version_ids_to_stages(version_id, vec![version_stage.into()]) + .build() + }); + + let asm_mock = mock_client!( + aws_sdk_secretsmanager, + RuleMode::MatchAny, + [&gsv, &describe_secret] + ); - let res1 = client - .get_secret_value(secret_id, None, version_stage, false) - .await - .unwrap(); + let client = SecretsManagerCachingClient::new( + asm_mock, + NonZeroUsize::new(1000).unwrap(), + Duration::from_secs(0), + true, + ) + .unwrap(); - let res2 = client - .get_secret_value(secret_id, None, version_stage, false) - .await - .unwrap(); + // Run through this twice to test the cache expiration + for i in 0..2 { + let response = client + .get_secret_value(secret_id, None, Some(version_stage), false) + .await + .unwrap(); - assert_eq!(res1, res2) + assert_eq!(response.name, Some(secret_id.to_string())); + assert_eq!(response.secret_string, Some("hunter2".to_string())); + assert_eq!(response.arn, Some(arn.into())); + assert_eq!( + response.version_stages, + Some(vec![version_stage.to_string()]) + ); + // let the entry expire + if i == 0 { + sleep(Duration::from_millis(50)).await; + } + } + + assert_eq!(gsv.num_calls(), 1); + assert_eq!(describe_secret.num_calls(), 1); } #[tokio::test] - async fn test_is_current_both_version_id_and_version_stage_succeeds() { - let client = fake_client(Some(Duration::from_secs(0)), false, None, None); + async fn test_is_current_both_version_id_and_version_stage_succeed() { let secret_id = "test_secret"; - let version_id = Some("test_version"); - let version_stage = Some("VERSIONSTAGE"); + let version_id = "version_id"; + let version_stage = "VERSIONSTAGE"; + let arn = "arn"; + + let gsv = mock!(aws_sdk_secretsmanager::Client::get_secret_value) + .match_requests(|req| { + req.secret_id() == Some(secret_id) + && req.version_stage() == Some(version_stage) + && req.version_id() == Some(version_id) + }) + .then_output(move || { + GetSecretValueOutput::builder() + .name(secret_id) + .arn(arn) + .secret_string("hunter2") + .version_id(version_id) + .version_stages(version_stage) + .build() + }); + + let describe_secret = + mock!(aws_sdk_secretsmanager::Client::describe_secret).then_output(move || { + // Cache is current. We fast-refresh + DescribeSecretOutput::builder() + .name(secret_id) + .version_ids_to_stages(version_id, vec![version_stage.into()]) + .build() + }); + + let asm_mock = mock_client!( + aws_sdk_secretsmanager, + RuleMode::MatchAny, + [&gsv, &describe_secret] + ); - let res1 = client - .get_secret_value(secret_id, version_id, version_stage, false) - .await - .unwrap(); + let client = SecretsManagerCachingClient::new( + asm_mock, + NonZeroUsize::new(1000).unwrap(), + Duration::from_secs(0), + true, + ) + .unwrap(); - let res2 = client - .get_secret_value(secret_id, version_id, version_stage, false) - .await - .unwrap(); + // Run through this twice to test the cache expiration + for i in 0..2 { + let response = client + .get_secret_value(secret_id, Some(version_id), Some(version_stage), false) + .await + .unwrap(); - assert_eq!(res1, res2) + assert_eq!(response.name, Some(secret_id.to_string())); + assert_eq!(response.secret_string, Some("hunter2".to_string())); + assert_eq!(response.arn, Some(arn.into())); + assert_eq!( + response.version_stages, + Some(vec![version_stage.to_string()]) + ); + // let the entry expire + if i == 0 { + sleep(Duration::from_millis(50)).await; + } + } + + assert_eq!(gsv.num_calls(), 1); + assert_eq!(describe_secret.num_calls(), 1); } #[tokio::test] async fn test_is_current_describe_access_denied_fails() { - let client = fake_client(Some(Duration::from_secs(0)), false, None, None); - let secret_id = "DESCRIBEACCESSDENIED_test_secret"; - let version_id = Some("test_version"); - - client - .get_secret_value(secret_id, version_id, None, false) - .await - .unwrap(); + let secret_id = "test_secret"; + let version_id = "version_id"; + let version_stage = "VERSIONSTAGE"; + let arn = "arn"; + let secret_string = "hunter2"; + + let gsv = mock!(aws_sdk_secretsmanager::Client::get_secret_value) + .match_requests(|req| { + req.secret_id() == Some(secret_id) + && req.version_stage() == Some(version_stage) + && req.version_id() == Some(version_id) + }) + .then_output(move || { + GetSecretValueOutput::builder() + .name(secret_id) + .arn(arn) + .secret_string(secret_string) + .version_id(version_id) + .version_stages(version_stage) + .build() + }); + + let describe_secret = + mock!(aws_sdk_secretsmanager::Client::describe_secret).then_error(|| { + // TODO: Figure out how to set __type + DescribeSecretError::generic( + ErrorMetadata::builder() + .code("400") + .message("is not authorized to perform: secretsmanager:DescribeSecret on resource: XXXXXXXX") + .build(), + ) + }); - if (client - .get_secret_value(secret_id, version_id, None, false) - .await) - .is_ok() - { - panic!("Expected failure") - } - } + let asm_mock = mock_client!( + aws_sdk_secretsmanager, + RuleMode::MatchAny, + [&gsv, &describe_secret] + ); - #[tokio::test] - async fn test_is_current_describe_timeout_error_succeeds() { - use asm_mock::GSV_BODY; - use aws_smithy_runtime::client::http::test_util::wire::{ReplayedEvent, WireMockServer}; - - let mock = WireMockServer::start(vec![ - ReplayedEvent::with_body(GSV_BODY), - ReplayedEvent::Timeout, - ]) - .await; - let client = fake_client( - Some(Duration::from_secs(0)), + let client = SecretsManagerCachingClient::new( + asm_mock, + NonZeroUsize::new(1000).unwrap(), + Duration::from_secs(0), true, - Some(mock.http_client()), - Some(mock.endpoint_url()), - ); - let secret_id = "DESCRIBETIMEOUT_test_secret"; - let version_id = Some("test_version"); + ) + .unwrap(); - let res1 = client - .get_secret_value(secret_id, version_id, None, false) + // Run through this twice to test the cache expiration + let response = client + .get_secret_value(secret_id, Some(version_id), Some(version_stage), false) .await .unwrap(); - let res2 = client - .get_secret_value(secret_id, version_id, None, false) - .await - .unwrap(); + assert_eq!(response.name, Some(secret_id.to_string())); + assert_eq!(response.secret_string, Some(secret_string.to_string())); + assert_eq!(response.arn, Some(arn.into())); + assert_eq!( + response.version_stages, + Some(vec![version_stage.to_string()]) + ); + // let the entry expire + sleep(Duration::from_millis(50)).await; - mock.shutdown(); + if client + .get_secret_value(secret_id, Some(version_id), Some(version_stage), false) + .await + .is_ok() + { + panic!("Expected failure") + } - assert_eq!(res1, res2) + assert_eq!(gsv.num_calls(), 1) } #[tokio::test] async fn test_is_current_describe_service_error_succeeds() { - let client = fake_client(Some(Duration::from_secs(0)), true, None, None); let secret_id = "DESCRIBESERVICEERROR_test_secret"; - let version_id = Some("test_version"); - let version_stage = Some("VERSIONSTAGE"); - - let res1 = client - .get_secret_value(secret_id, version_id, version_stage, false) - .await - .unwrap(); - - let res2 = client - .get_secret_value(secret_id, version_id, version_stage, false) - .await - .unwrap(); - - assert_eq!(res1, res2) - } + let version_id = "test_version"; + let version_stage = "VERSIONSTAGE"; + let arn = "arn"; + + let gsv = mock!(aws_sdk_secretsmanager::Client::get_secret_value) + .match_requests(|req| { + req.secret_id() == Some(secret_id) + && req.version_stage() == Some(version_stage) + && req.version_id() == Some(version_id) + }) + .then_output(move || { + GetSecretValueOutput::builder() + .name(secret_id) + .arn(arn) + .secret_string("hunter2") + .version_id(version_id) + .version_stages(version_stage) + .build() + }); + + let describe_secret = mock!(aws_sdk_secretsmanager::Client::describe_secret) + .then_http_response(|| { + HttpResponse::new( + StatusCode::try_from(500).unwrap(), + SdkBody::from( + r##"{ + "__type": "InternalServiceError", + "message": "Internal service error" + }"##, + ), + ) + }); - #[tokio::test] - async fn test_is_current_gsv_timeout_error_succeeds() { - use asm_mock::DESC_BODY; - use asm_mock::GSV_BODY; - use aws_smithy_runtime::client::http::test_util::wire::{ReplayedEvent, WireMockServer}; - - let mock = WireMockServer::start(vec![ - ReplayedEvent::with_body( - GSV_BODY - .replace("{{version}}", "old_version") - .replace("{{label}}", "AWSCURRENT"), - ), - ReplayedEvent::with_body( - DESC_BODY - .replace("{{version}}", "new_version") - .replace("{{label}}", "AWSCURRENT"), - ), - ReplayedEvent::Timeout, - ]) - .await; - let client = fake_client( - Some(Duration::from_secs(0)), - true, - Some(mock.http_client()), - Some(mock.endpoint_url()), + let asm_mock = mock_client!( + aws_sdk_secretsmanager, + RuleMode::MatchAny, + [&gsv, &describe_secret] ); - let secret_id = "GSVTIMEOUT_test_secret"; + let client = SecretsManagerCachingClient::new( + asm_mock, + NonZeroUsize::new(1000).unwrap(), + Duration::ZERO, + true, + ) + .unwrap(); let res1 = client - .get_secret_value(secret_id, None, None, false) + .get_secret_value(secret_id, Some(version_id), Some(version_stage), false) .await .unwrap(); let res2 = client - .get_secret_value(secret_id, None, None, false) + .get_secret_value(secret_id, Some(version_id), Some(version_stage), false) .await .unwrap(); - mock.shutdown(); - assert_eq!(res1, res2) } #[tokio::test] async fn test_get_secret_value_refresh_now_true() { - let client = fake_client(Some(Duration::from_secs(30)), false, None, None); let secret_id = "REFRESHNOW_test_secret"; + let arn = "arn"; + + let gsv = mock!(aws_sdk_secretsmanager::Client::get_secret_value) + .match_requests(|req| req.secret_id() == Some(secret_id)) + .sequence() + .output(move || { + GetSecretValueOutput::builder() + .name(secret_id) + .arn(arn) + .secret_string("hunter2") + .version_stages("AWSCURRENT") + .build() + }) + .output(move || { + GetSecretValueOutput::builder() + .name(secret_id) + .arn(arn) + .secret_string("some other string") + .version_stages("AWSCURRENT") + .build() + }) + .build(); + + let asm_mock = mock_client!(aws_sdk_secretsmanager, [&gsv]); + let client = SecretsManagerCachingClient::new( + asm_mock, + NonZeroUsize::new(1000).unwrap(), + Duration::from_secs(30), + true, + ) + .unwrap(); let response1 = client .get_secret_value(secret_id, None, None, false) @@ -822,35 +1200,57 @@ mod tests { .unwrap(); assert_eq!(response1.name, Some(secret_id.to_string())); - assert_eq!( - response1.arn, - Some( - asm_mock::FAKE_ARN - .replace("{{name}}", secret_id) - .to_string() - ) - ); + assert_eq!(response1.arn, Some(arn.into())); assert_eq!( response1.version_stages, Some(vec!["AWSCURRENT".to_string()]) ); - sleep(Duration::from_millis(1)).await; - let response2 = client .get_secret_value(secret_id, None, None, true) .await .unwrap(); - assert_ne!(response1.secret_string, response2.secret_string); assert_eq!(response1.arn, response2.arn); assert_eq!(response1.version_stages, response2.version_stages); + + assert_eq!(gsv.num_calls(), 2) } #[tokio::test] async fn test_get_secret_value_refresh_now_false() { - let client = fake_client(Some(Duration::from_secs(30)), false, None, None); let secret_id = "REFRESHNOW_test_secret"; + let arn = "arn"; + + let gsv = mock!(aws_sdk_secretsmanager::Client::get_secret_value) + .match_requests(|req| req.secret_id() == Some(secret_id)) + .sequence() + .output(move || { + GetSecretValueOutput::builder() + .name(secret_id) + .arn(arn) + .secret_string("hunter2") + .version_stages("AWSCURRENT") + .build() + }) + .output(move || { + GetSecretValueOutput::builder() + .name(secret_id) + .arn(arn) + .secret_string("some other string") + .version_stages("AWSCURRENT") + .build() + }) + .build(); + + let asm_mock = mock_client!(aws_sdk_secretsmanager, [&gsv]); + let client = SecretsManagerCachingClient::new( + asm_mock, + NonZeroUsize::new(1000).unwrap(), + Duration::from_secs(30), + true, + ) + .unwrap(); let response1 = client .get_secret_value(secret_id, None, None, false) @@ -858,69 +1258,88 @@ mod tests { .unwrap(); assert_eq!(response1.name, Some(secret_id.to_string())); - assert_eq!( - response1.arn, - Some( - asm_mock::FAKE_ARN - .replace("{{name}}", secret_id) - .to_string() - ) - ); + assert_eq!(response1.arn, Some(arn.into())); assert_eq!( response1.version_stages, Some(vec!["AWSCURRENT".to_string()]) ); - sleep(Duration::from_millis(1)).await; - let response2 = client .get_secret_value(secret_id, None, None, false) .await .unwrap(); assert_eq!(response1, response2); + + assert_eq!(gsv.num_calls(), 1) } #[tokio::test] async fn test_get_secret_value_version_id_and_stage_refresh_now() { - let client = fake_client(Some(Duration::from_secs(30)), false, None, None); let secret_id = "REFRESHNOW_test_secret"; let version_id = "test_version"; let stage_label = "STAGEHERE"; + let arn = "arn"; + + let gsv = mock!(aws_sdk_secretsmanager::Client::get_secret_value) + .match_requests(|req| req.secret_id() == Some(secret_id)) + .sequence() + .output(move || { + GetSecretValueOutput::builder() + .name(secret_id) + .arn(arn) + .secret_string("hunter2") + .version_stages("AWSCURRENT") + .build() + }) + .output(move || { + GetSecretValueOutput::builder() + .name(secret_id) + .arn(arn) + .secret_string("some other string") + .version_stages("AWSCURRENT") + .build() + }) + .build(); + + let asm_mock = mock_client!(aws_sdk_secretsmanager, [&gsv]); + let client = SecretsManagerCachingClient::new( + asm_mock, + NonZeroUsize::new(1000).unwrap(), + Duration::from_secs(30), + true, + ) + .unwrap(); let response1 = client .get_secret_value(secret_id, Some(version_id), Some(stage_label), false) .await .unwrap(); - sleep(Duration::from_millis(1)).await; - let response2 = client .get_secret_value(secret_id, Some(version_id), Some(stage_label), true) .await .unwrap(); - assert_ne!(response1.secret_string, response2.secret_string); assert_eq!(response1.arn, response2.arn); assert_eq!(response1.version_stages, response2.version_stages); + + assert_eq!(gsv.num_calls(), 2); } - mod asm_mock { + mod wire_tests { use aws_sdk_secretsmanager as secretsmanager; - use aws_smithy_runtime::client::http::test_util::infallible_client_fn; + + use aws_smithy_runtime::client::http::test_util::wire::WireMockServer; use aws_smithy_runtime_api::client::http::SharedHttpClient; - use aws_smithy_types::body::SdkBody; + use aws_smithy_types::timeout::TimeoutConfig; - use http::{Request, Response}; + use secretsmanager::config::BehaviorVersion; - use serde_json::Value; - use std::time::{Duration, SystemTime, UNIX_EPOCH}; - pub const FAKE_ARN: &str = - "arn:aws:secretsmanager:us-west-2:123456789012:secret:{{name}}-NhBWsc"; - pub const DEFAULT_VERSION: &str = "5767290c-d089-49ed-b97c-17086f8c9d79"; - pub const DEFAULT_LABEL: &str = "AWSCURRENT"; - pub const DEFAULT_SECRET_STRING: &str = "hunter2"; + use std::{num::NonZeroUsize, time::Duration}; + + use crate::SecretsManagerCachingClient; // Template GetSecretValue responses for testing pub const GSV_BODY: &str = r###"{ @@ -933,7 +1352,6 @@ mod tests { ], "CreatedDate": 1569534789.046 }"###; - // Template DescribeSecret responses for testing pub const DESC_BODY: &str = r###"{ "ARN": "{{arn}}", @@ -950,83 +1368,10 @@ mod tests { "CreatedDate": 1569534789.046 }"###; - // Template for access denied testing - const KMS_ACCESS_DENIED_BODY: &str = r###"{ - "__type":"AccessDeniedException", - "Message":"Access to KMS is not allowed" - }"###; - - // Template for testing resource not found with DescribeSecret - const NOT_FOUND_EXCEPTION_BODY: &str = r###"{ - "__type":"ResourceNotFoundException", - "message":"Secrets Manager can't find the specified secret." - }"###; - - const SECRETSMANAGER_ACCESS_DENIED_BODY: &str = r###"{ - "__type:"AccessDeniedException", - "Message": "is not authorized to perform: secretsmanager:DescribeSecret on resource: XXXXXXXX" - }"###; - - const SECRETSMANAGER_INTERNAL_SERVICE_ERROR_BODY: &str = r###"{ - "__type:"InternalServiceError", - "Message": "Internal service error" - }"###; - - // Private helper to look at the request and provide the correct response. - fn format_rsp(req: Request) -> (u16, String) { - let (parts, body) = req.into_parts(); - - let req_map: serde_json::Map = - serde_json::from_slice(body.bytes().unwrap()).unwrap(); - let version = req_map - .get("VersionId") - .map_or(DEFAULT_VERSION, |x| x.as_str().unwrap()); - let label = req_map - .get("VersionStage") - .map_or(DEFAULT_LABEL, |x| x.as_str().unwrap()); - let name = req_map.get("SecretId").unwrap().as_str().unwrap(); // Does not handle full ARN case. - - let secret_string = match name { - secret if secret.starts_with("REFRESHNOW") => SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_millis() - .to_string(), - _ => DEFAULT_SECRET_STRING.to_string(), - }; - - let (code, template) = match parts.headers["x-amz-target"].to_str().unwrap() { - "secretsmanager.GetSecretValue" if name.starts_with("KMSACCESSDENIED") => { - (400, KMS_ACCESS_DENIED_BODY) - } - "secretsmanager.GetSecretValue" if name.starts_with("NOTFOUND") => { - (400, NOT_FOUND_EXCEPTION_BODY) - } - "secretsmanager.GetSecretValue" => (200, GSV_BODY), - "secretsmanager.DescribeSecret" if name.contains("DESCRIBEACCESSDENIED") => { - (400, SECRETSMANAGER_ACCESS_DENIED_BODY) - } - "secretsmanager.DescribeSecret" if name.contains("DESCRIBESERVICEERROR") => { - (500, SECRETSMANAGER_INTERNAL_SERVICE_ERROR_BODY) - } - "secretsmanager.DescribeSecret" => (200, DESC_BODY), - _ => panic!("Unknown operation"), - }; - - // Fill in the template and return the response. - let rsp = template - .replace("{{arn}}", FAKE_ARN) - .replace("{{name}}", name) - .replace("{{version}}", version) - .replace("{{secret}}", &secret_string) - .replace("{{label}}", label); - (code, rsp) - } - // Test client that stubs off network call and provides a canned response. pub fn def_fake_client( - http_client: Option, - endpoint_url: Option, + http_client: SharedHttpClient, + endpoint_url: String, ) -> secretsmanager::Client { let fake_creds = secretsmanager::config::Credentials::new( "AKIDTESTKEY", @@ -1035,7 +1380,6 @@ mod tests { None, "", ); - let mut config_builder = secretsmanager::Config::builder() .behavior_version(BehaviorVersion::latest()) .credentials_provider(fake_creds) @@ -1045,22 +1389,98 @@ mod tests { .operation_attempt_timeout(Duration::from_millis(100)) .build(), ) - .http_client(match http_client { - Some(custom_client) => custom_client, - None => infallible_client_fn(|_req| { - let (code, rsp) = format_rsp(_req); - Response::builder() - .status(code) - .body(SdkBody::from(rsp)) - .unwrap() - }), - }); - config_builder = match endpoint_url { - Some(endpoint_url) => config_builder.endpoint_url(endpoint_url), - None => config_builder, - }; + .http_client(http_client); + config_builder = config_builder.endpoint_url(endpoint_url); secretsmanager::Client::from_conf(config_builder.build()) } + + fn fake_client( + ttl: Option, + ignore_transient_errors: bool, + wire_server: &WireMockServer, + ) -> SecretsManagerCachingClient { + SecretsManagerCachingClient::new( + def_fake_client(wire_server.http_client(), wire_server.endpoint_url()), + NonZeroUsize::new(1000).unwrap(), + match ttl { + Some(ttl) => ttl, + None => Duration::from_secs(1000), + }, + ignore_transient_errors, + ) + .expect("client should create") + } + + #[tokio::test] + async fn test_is_current_gsv_timeout_error_succeeds() { + use aws_smithy_runtime::client::http::test_util::wire::{ + ReplayedEvent, WireMockServer, + }; + + let mock = WireMockServer::start(vec![ + ReplayedEvent::with_body( + GSV_BODY + .replace("{{version}}", "old_version") + .replace("{{label}}", "AWSCURRENT"), + ), + ReplayedEvent::with_body( + DESC_BODY + .replace("{{version}}", "new_version") + .replace("{{label}}", "AWSCURRENT"), + ), + ReplayedEvent::Timeout, + ]) + .await; + + let client = fake_client(Some(Duration::from_secs(0)), true, &mock); + + let secret_id = "GSVTIMEOUT_test_secret"; + + let res1 = client + .get_secret_value(secret_id, None, None, false) + .await + .unwrap(); + + let res2 = client + .get_secret_value(secret_id, None, None, false) + .await + .unwrap(); + + mock.shutdown(); + + assert_eq!(res1, res2) + } + + #[tokio::test] + async fn test_is_current_describe_timeout_error_succeeds() { + // TODO: Figure out how to do this with mocks + use aws_smithy_runtime::client::http::test_util::wire::{ + ReplayedEvent, WireMockServer, + }; + + let mock = WireMockServer::start(vec![ + ReplayedEvent::with_body(GSV_BODY), + ReplayedEvent::Timeout, + ]) + .await; + let client = fake_client(Some(Duration::from_secs(0)), true, &mock); + let secret_id = "DESCRIBETIMEOUT_test_secret"; + let version_id = Some("test_version"); + + let res1 = client + .get_secret_value(secret_id, version_id, None, false) + .await + .unwrap(); + + let res2 = client + .get_secret_value(secret_id, version_id, None, false) + .await + .unwrap(); + + mock.shutdown(); + + assert_eq!(res1, res2) + } } }