diff --git a/src/robust_provider.rs b/src/robust_provider.rs index be5c193..eb277e8 100644 --- a/src/robust_provider.rs +++ b/src/robust_provider.rs @@ -85,8 +85,8 @@ impl RobustProvider { /// Set the base delay for exponential backoff retries. #[must_use] - pub fn min_delay(mut self, retry_interval: Duration) -> Self { - self.min_delay = retry_interval; + pub fn min_delay(mut self, min_delay: Duration) -> Self { + self.min_delay = min_delay; self } @@ -105,8 +105,8 @@ impl RobustProvider { /// /// Fallback providers are used when the primary provider times out or fails. #[must_use] - pub fn fallback(mut self, provider: RootProvider) -> Self { - self.providers.push(provider); + pub fn fallback(mut self, provider: impl Provider) -> Self { + self.providers.push(provider.root().to_owned()); self } @@ -122,9 +122,10 @@ impl RobustProvider { ) -> Result { info!("eth_getBlockByNumber called"); let result = self - .retry_with_total_timeout(move |provider| async move { - provider.get_block_by_number(number).await - }) + .retry_with_total_timeout( + move |provider| async move { provider.get_block_by_number(number).await }, + false, + ) .await; if let Err(e) = &result { error!(error = %e, "eth_getByBlockNumber failed"); @@ -144,6 +145,7 @@ impl RobustProvider { let result = self .retry_with_total_timeout( move |provider| async move { provider.get_block_number().await }, + false, ) .await; if let Err(e) = &result { @@ -164,9 +166,10 @@ impl RobustProvider { ) -> Result { info!("eth_getBlockByHash called"); let result = self - .retry_with_total_timeout(move |provider| async move { - provider.get_block_by_hash(hash).await - }) + .retry_with_total_timeout( + move |provider| async move { provider.get_block_by_hash(hash).await }, + false, + ) .await; if let Err(e) = &result { error!(error = %e, "eth_getBlockByHash failed"); @@ -186,6 +189,7 @@ impl RobustProvider { let result = self .retry_with_total_timeout( move |provider| async move { provider.get_logs(filter).await }, + false, ) .await; if let Err(e) = &result { @@ -202,11 +206,12 @@ impl RobustProvider { /// after exhausting retries or if the call times out. pub async fn subscribe_blocks(&self) -> Result, Error> { info!("eth_subscribe called"); - // We need this otherwise error is not clear + // immediately fail if primary does not support pubsub self.root().client().expect_pubsub_frontend(); let result = self .retry_with_total_timeout( move |provider| async move { provider.subscribe_blocks().await }, + true, ) .await; if let Err(e) = &result { @@ -224,17 +229,27 @@ impl RobustProvider { /// If the timeout is exceeded and fallback providers are available, it will /// attempt to use each fallback provider in sequence. /// + /// If `require_pubsub` is true, providers that don't support pubsub will be skipped. + /// /// # Errors /// /// - Returns [`RpcError`] with message "total operation timeout exceeded /// and all fallback providers failed" if the overall timeout elapses and no fallback /// providers succeed. + /// - Returns [`RpcError::Transport(TransportErrorKind::PubsubUnavailable)`] if `require_pubsub` + /// is true and all providers don't support pubsub. /// - Propagates any [`RpcError`] from the underlying retries. - async fn retry_with_total_timeout(&self, operation: F) -> Result + async fn retry_with_total_timeout( + &self, + operation: F, + require_pubsub: bool, + ) -> Result where F: Fn(RootProvider) -> Fut, Fut: Future>>, { + let mut skipped_count = 0; + let mut providers = self.providers.iter(); let primary = providers.next().expect("should have primary provider"); @@ -253,6 +268,11 @@ impl RobustProvider { // This loop starts at index 1 automatically for (idx, provider) in providers.enumerate() { let fallback_num = idx + 1; + if require_pubsub && !Self::supports_pubsub(provider) { + info!("Fallback provider {} doesn't support pubsub, skipping", fallback_num); + skipped_count += 1; + continue; + } info!("Attempting fallback provider {}/{}", fallback_num, self.providers.len() - 1); match self.try_provider_with_timeout(provider, &operation).await { @@ -267,6 +287,13 @@ impl RobustProvider { } } + // If all providers were skipped due to pubsub requirement + if skipped_count == self.providers.len() { + error!("All providers skipped - none support pubsub"); + return Err(RpcError::Transport(TransportErrorKind::PubsubUnavailable).into()); + } + + // Return the last error encountered error!("All providers failed or timed out"); Err(last_error) } @@ -298,25 +325,30 @@ impl RobustProvider { .map_err(Error::from)? .map_err(Error::from) } + + /// Check if a provider supports pubsub + fn supports_pubsub(provider: &RootProvider) -> bool { + provider.client().pubsub_frontend().is_some() + } } #[cfg(test)] mod tests { use super::*; - use alloy::network::Ethereum; + use alloy::{ + network::Ethereum, + providers::{ProviderBuilder, WsConnect}, + }; + use alloy_node_bindings::Anvil; use std::sync::atomic::{AtomicUsize, Ordering}; use tokio::time::sleep; - fn test_provider( - timeout: u64, - max_retries: usize, - retry_interval: u64, - ) -> RobustProvider { + fn test_provider(timeout: u64, max_retries: usize, min_delay: u64) -> RobustProvider { RobustProvider { providers: vec![RootProvider::new_http("http://localhost:8545".parse().unwrap())], max_timeout: Duration::from_millis(timeout), max_retries, - min_delay: Duration::from_millis(retry_interval), + min_delay: Duration::from_millis(min_delay), } } @@ -327,11 +359,14 @@ mod tests { let call_count = AtomicUsize::new(0); let result = provider - .retry_with_total_timeout(|_| async { - call_count.fetch_add(1, Ordering::SeqCst); - let count = call_count.load(Ordering::SeqCst); - Ok(count) - }) + .retry_with_total_timeout( + |_| async { + call_count.fetch_add(1, Ordering::SeqCst); + let count = call_count.load(Ordering::SeqCst); + Ok(count) + }, + false, + ) .await; assert!(matches!(result, Ok(1))); @@ -344,14 +379,17 @@ mod tests { let call_count = AtomicUsize::new(0); let result = provider - .retry_with_total_timeout(|_| async { - call_count.fetch_add(1, Ordering::SeqCst); - let count = call_count.load(Ordering::SeqCst); - match count { - 3 => Ok(count), - _ => Err(TransportErrorKind::BackendGone.into()), - } - }) + .retry_with_total_timeout( + |_| async { + call_count.fetch_add(1, Ordering::SeqCst); + let count = call_count.load(Ordering::SeqCst); + match count { + 3 => Ok(count), + _ => Err(TransportErrorKind::BackendGone.into()), + } + }, + false, + ) .await; assert!(matches!(result, Ok(3))); @@ -364,10 +402,13 @@ mod tests { let call_count = AtomicUsize::new(0); let result: Result<(), Error> = provider - .retry_with_total_timeout(|_| async { - call_count.fetch_add(1, Ordering::SeqCst); - Err(TransportErrorKind::BackendGone.into()) - }) + .retry_with_total_timeout( + |_| async { + call_count.fetch_add(1, Ordering::SeqCst); + Err(TransportErrorKind::BackendGone.into()) + }, + false, + ) .await; assert!(matches!(result, Err(Error::RpcError(_)))); @@ -380,12 +421,98 @@ mod tests { let provider = test_provider(max_timeout, 10, 1); let result = provider - .retry_with_total_timeout(move |_provider| async move { - sleep(Duration::from_millis(max_timeout + 10)).await; - Ok(42) - }) + .retry_with_total_timeout( + move |_provider| async move { + sleep(Duration::from_millis(max_timeout + 10)).await; + Ok(42) + }, + false, + ) .await; assert!(matches!(result, Err(Error::Timeout))); } + + #[tokio::test] + async fn test_subscribe_fails_causes_backup_to_be_used() { + let anvil_1 = Anvil::new().port(2222_u16).try_spawn().expect("Failed to start anvil"); + + let ws_provider_1 = ProviderBuilder::new() + .connect_ws(WsConnect::new(anvil_1.ws_endpoint_url().as_str())) + .await + .expect("Failed to connect to WS"); + + let anvil_2 = Anvil::new().port(1111_u16).try_spawn().expect("Failed to start anvil"); + + let ws_provider_2 = ProviderBuilder::new() + .connect_ws(WsConnect::new(anvil_2.ws_endpoint_url().as_str())) + .await + .expect("Failed to connect to WS"); + + let robust = RobustProvider::new(ws_provider_1) + .fallback(ws_provider_2) + .max_timeout(Duration::from_secs(5)) + .max_retries(10) + .min_delay(Duration::from_millis(100)); + + drop(anvil_1); + + let result = robust.subscribe_blocks().await; + + assert!(result.is_ok(), "Expected subscribe blocks to work"); + } + + #[tokio::test] + #[should_panic(expected = "called pubsub_frontend on a non-pubsub transport")] + async fn test_subscribe_fails_if_primary_provider_lacks_pubsub() { + let anvil = Anvil::new().try_spawn().expect("Failed to start anvil"); + + let http_provider = ProviderBuilder::new().connect_http(anvil.endpoint_url()); + let ws_provider = ProviderBuilder::new() + .connect_ws(WsConnect::new(anvil.ws_endpoint_url().as_str())) + .await + .expect("Failed to connect to WS"); + + let robust = RobustProvider::new(http_provider) + .fallback(ws_provider) + .max_timeout(Duration::from_secs(5)) + .max_retries(10) + .min_delay(Duration::from_millis(100)); + + let _ = robust.subscribe_blocks().await; + } + + #[tokio::test] + async fn test_ws_fails_http_fallback_returns_primary_error() { + let anvil_1 = Anvil::new().try_spawn().expect("Failed to start anvil"); + + let ws_provider = ProviderBuilder::new() + .connect_ws(WsConnect::new(anvil_1.ws_endpoint_url().as_str())) + .await + .expect("Failed to connect to WS"); + + let anvil_2 = Anvil::new().port(8222_u16).try_spawn().expect("Failed to start anvil"); + let http_provider = ProviderBuilder::new().connect_http(anvil_2.endpoint_url()); + + let robust = RobustProvider::new(ws_provider.clone()) + .fallback(http_provider) + .max_timeout(Duration::from_millis(500)) + .max_retries(0) + .min_delay(Duration::from_millis(10)); + + // force ws_provider to fail and return BackendGone + drop(anvil_1); + + let err = robust.subscribe_blocks().await.unwrap_err(); + + // The error should be either a Timeout or BackendGone from the primary WS provider, + // NOT a PubsubUnavailable error (which would indicate HTTP fallback was attempted) + match err { + Error::Timeout => {} + Error::RpcError(e) => { + assert!(matches!(e.as_ref(), RpcError::Transport(TransportErrorKind::BackendGone))); + } + Error::BlockNotFound(id) => panic!("Unexpected error type: BlockNotFound({id})"), + } + } } diff --git a/tests/block_range_scanner.rs b/tests/block_range_scanner.rs index 1e30c5a..5e8c393 100644 --- a/tests/block_range_scanner.rs +++ b/tests/block_range_scanner.rs @@ -32,7 +32,7 @@ async fn live_mode_processes_all_blocks_respecting_block_confirmations() -> anyh robust_provider.root().anvil_mine(Some(1), None).await?; - assert_next!(stream, 6..=6); + assert_next!(stream, 6..=6, timeout = 10); assert_empty!(stream); // --- 1 block confirmation --- @@ -50,7 +50,7 @@ async fn live_mode_processes_all_blocks_respecting_block_confirmations() -> anyh robust_provider.root().anvil_mine(Some(1), None).await?; - assert_next!(stream, 11..=11); + assert_next!(stream, 11..=11, timeout = 10); assert_empty!(stream); Ok(())