diff --git a/reqwest-retry/CHANGELOG.md b/reqwest-retry/CHANGELOG.md index 23c1030..edb925e 100644 --- a/reqwest-retry/CHANGELOG.md +++ b/reqwest-retry/CHANGELOG.md @@ -6,6 +6,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added +- Report retry count on `Ok` results that underwent retries through a `RetryCount` response extension. + ## [0.7.0] - 2024-11-08 ### Breaking changes diff --git a/reqwest-retry/src/lib.rs b/reqwest-retry/src/lib.rs index 807ae9a..5921061 100644 --- a/reqwest-retry/src/lib.rs +++ b/reqwest-retry/src/lib.rs @@ -32,7 +32,7 @@ mod retryable_strategy; pub use retry_policies::{policies, Jitter, RetryDecision, RetryPolicy}; use thiserror::Error; -pub use middleware::RetryTransientMiddleware; +pub use middleware::{RetryCount, RetryTransientMiddleware}; pub use retryable::Retryable; pub use retryable_strategy::{ default_on_request_failure, default_on_request_success, DefaultRetryableStrategy, diff --git a/reqwest-retry/src/middleware.rs b/reqwest-retry/src/middleware.rs index a8339c1..f2b9201 100644 --- a/reqwest-retry/src/middleware.rs +++ b/reqwest-retry/src/middleware.rs @@ -184,20 +184,51 @@ where } }; - // Report whether we failed with or without retries. break if n_past_retries > 0 { - result.map_err(|err| { - Error::Middleware( + // Both `Ok` results (e.g. status code errors) and `Err` results (e.g. an + // `io::Error` for from connection reset), and we want to inform the user about the + // retries in both cases. + match result { + Ok(mut response) => { + response + .extensions_mut() + .insert(RetryCount::new(n_past_retries)); + Ok(response) + } + Err(err) => Err(Error::Middleware( RetryError::WithRetries { retries: n_past_retries, err, } .into(), - ) - }) + )), + } } else { result.map_err(|err| Error::Middleware(RetryError::Error(err).into())) }; } } } + +/// Extension type to store retry count in a response. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct RetryCount(u32); + +impl RetryCount { + /// Create a new retry count. + pub fn new(count: u32) -> Self { + Self(count) + } + + pub fn value(self) -> u32 { + self.0 + } +} + +impl std::ops::Deref for RetryCount { + type Target = u32; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} diff --git a/reqwest-retry/tests/all/helpers/simple_server.rs b/reqwest-retry/tests/all/helpers/simple_server.rs index 205aaff..da6d249 100644 --- a/reqwest-retry/tests/all/helpers/simple_server.rs +++ b/reqwest-retry/tests/all/helpers/simple_server.rs @@ -126,7 +126,7 @@ impl SimpleServer { /// Parses the request line and checks that it contains the method, uri and http_version parts. /// It does not check if the content of the checked parts is correct. It just checks the format (it contains enough parts) of the request. - fn parse_request_line(request: &str) -> Result> { + fn parse_request_line(request: &str) -> Result, Box> { let mut parts = request.split_whitespace(); let method = parts.next().ok_or("Method not specified")?;