From aad17d1c4f57d38533a1a31b4dee7acdd280ac80 Mon Sep 17 00:00:00 2001
From: Nikolay Denev
Date: Mon, 9 Mar 2026 07:49:55 -0500
Subject: [PATCH] Fix client creation race in K8s client pool

Client creation checked the shared cache under a read lock and then
built a new client with no further exclusion, so concurrent callers for
the same context could each construct their own client and race to
insert it into the cache. Serialize creation behind a per-context async
lock, re-check the cache once the lock is held, and move the actual
construction into build_client(). A test-only client factory hook lets
a new concurrency test assert that parallel callers share a single
client construction.
---
 Cargo.lock               |   4 +
 Cargo.toml               |   7 +-
 src/kubernetes/client.rs | 177 +++++++++++++++++++++++++++++++++++----
 3 files changed, 170 insertions(+), 18 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 8c8d2c2..215cd9e 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2871,6 +2871,7 @@ dependencies = [
  "anyhow",
  "async-stream",
  "async-trait",
+ "bytes",
  "chrono",
  "clap",
  "comfy-table",
@@ -2881,6 +2882,8 @@ dependencies = [
  "dirs",
  "fastrand",
  "futures",
+ "http",
+ "http-body-util",
  "indicatif",
  "k8s-metrics",
  "k8s-openapi",
@@ -2895,6 +2898,7 @@ dependencies = [
  "serde_yaml",
  "tempfile",
  "tokio",
+ "tower",
  "tracing",
  "tracing-rolling-file",
  "tracing-subscriber",
diff --git a/Cargo.toml b/Cargo.toml
index e53585c..9dade65 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -88,9 +88,14 @@ lazy-regex = "3.5.1"
 
 # Atomic file operations
 tempfile = "3.24.0"
+[dev-dependencies]
+bytes = "1.11.0"
+http = "1.3.1"
+http-body-util = "0.1.3"
+tower = { version = "0.5.2", features = ["util"] }
+
 [profile.release]
 strip = true
 lto = "thin"
 codegen-units = 1
 opt-level = "z" # Optimize for size
-
diff --git a/src/kubernetes/client.rs b/src/kubernetes/client.rs
index 15debb4..3d70a92 100644
--- a/src/kubernetes/client.rs
+++ b/src/kubernetes/client.rs
@@ -6,10 +6,10 @@ use kube::config::{KubeConfigOptions, Kubeconfig};
 use kube::{Api, Client, Config, api::ListParams};
 use std::collections::HashMap;
 use std::pin::Pin;
-use std::sync::Arc;
+use std::sync::{Arc, Mutex as StdMutex};
 use std::time::{Duration, Instant};
 
-use tokio::sync::RwLock;
+use tokio::sync::{Mutex, RwLock};
 use tracing::{debug, info, trace, warn};
 
 use super::ApiFilters;
@@ -145,6 +145,7 @@ impl CachedRegistry {
 pub struct K8sClientPool {
     kubeconfig: Kubeconfig,
     clients: Arc<RwLock<HashMap<String, Client>>>,
+    client_creation_locks: Arc<StdMutex<HashMap<String, Arc<Mutex<()>>>>>,
     registries: Arc<RwLock<HashMap<String, CachedRegistry>>>,
     /// Current active contexts (supports multi-USE)
     current_contexts: Arc<RwLock<Vec<String>>>,
@@ -152,8 +153,15 @@
     progress: ProgressHandle,
     /// Local disk cache for CRD discovery results
     resource_cache: ResourceCache,
+    #[cfg(test)]
+    client_factory: Arc<RwLock<Option<Arc<TestClientFactory>>>>,
 }
 
+#[cfg(test)]
+type TestClientFactory = dyn Fn(String) -> Pin<Box<dyn Future<Output = Result<Client>> + Send>>
+    + Send
+    + Sync;
+
 /// Extract (group, version, kind) tuples from CRD list
 /// Only includes the storage version (the canonical/preferred version)
 /// This matches kubectl's behavior of showing one version per CRD
@@ -181,13 +189,28 @@ impl K8sClientPool {
     /// Create a minimal client pool for unit tests (no kubeconfig required)
     #[cfg(test)]
     pub fn new_for_test(progress: crate::progress::ProgressHandle) -> Self {
+        let test_context = "test-context".to_string();
         Self {
-            kubeconfig: Kubeconfig::default(),
+            kubeconfig: Kubeconfig {
+                contexts: vec![kube::config::NamedContext {
+                    name: test_context.clone(),
+                    context: Some(kube::config::Context {
+                        cluster: "test-cluster".to_string(),
+                        user: None,
+                        namespace: None,
+                        extensions: None,
+                    }),
+                }],
+                current_context: Some(test_context.clone()),
+                ..Kubeconfig::default()
+            },
             clients: Arc::new(RwLock::new(HashMap::new())),
+            client_creation_locks: Arc::new(StdMutex::new(HashMap::new())),
             registries: Arc::new(RwLock::new(HashMap::new())),
-            current_contexts: Arc::new(RwLock::new(vec!["test-context".to_string()])),
+            current_contexts: Arc::new(RwLock::new(vec![test_context])),
             progress,
             resource_cache: ResourceCache::new().expect("Failed to create test cache"),
cache"), + client_factory: Arc::new(RwLock::new(None)), } } @@ -212,10 +235,13 @@ impl K8sClientPool { Ok(Self { kubeconfig, clients: Arc::new(RwLock::new(HashMap::new())), + client_creation_locks: Arc::new(StdMutex::new(HashMap::new())), registries: Arc::new(RwLock::new(HashMap::new())), current_contexts: Arc::new(RwLock::new(vec![context_name])), progress: crate::progress::create_progress_handle(), resource_cache: ResourceCache::new()?, + #[cfg(test)] + client_factory: Arc::new(RwLock::new(None)), }) } @@ -627,11 +653,56 @@ impl K8sClientPool { } } + let client_creation_lock = self.client_creation_lock(context); + let _client_creation_guard = client_creation_lock.lock().await; + + { + let clients = self.clients.read().await; + if let Some(client) = clients.get(context) { + self.progress + .connected(context, start.elapsed().as_millis() as u64); + return Ok(client.clone()); + } + } + // Verify context exists if !self.kubeconfig.contexts.iter().any(|c| c.name == context) { return Err(anyhow!("Context '{}' not found in kubeconfig", context)); } + let client = self.build_client(context).await?; + + // Report connected + self.progress + .connected(context, start.elapsed().as_millis() as u64); + + // Cache it + { + let mut clients = self.clients.write().await; + clients.insert(context.to_string(), client.clone()); + } + + Ok(client) + } + + fn client_creation_lock(&self, context: &str) -> Arc> { + let mut locks = self + .client_creation_locks + .lock() + .expect("client creation locks mutex poisoned"); + + locks + .entry(context.to_string()) + .or_insert_with(|| Arc::new(Mutex::new(()))) + .clone() + } + + async fn build_client(&self, context: &str) -> Result { + #[cfg(test)] + if let Some(factory) = self.client_factory.read().await.clone() { + return factory(context.to_string()).await; + } + // Create new client with timeouts let mut config = Config::from_custom_kubeconfig( self.kubeconfig.clone(), @@ -647,20 +718,13 @@ impl K8sClientPool { config.connect_timeout = Some(CONNECT_TIMEOUT); config.read_timeout = Some(READ_TIMEOUT); - let client = Client::try_from(config) - .with_context(|| format!("Failed to create client for context '{}'", context))?; - - // Report connected - self.progress - .connected(context, start.elapsed().as_millis() as u64); - - // Cache it - { - let mut clients = self.clients.write().await; - clients.insert(context.to_string(), client.clone()); - } + Client::try_from(config) + .with_context(|| format!("Failed to create client for context '{}'", context)) + } - Ok(client) + #[cfg(test)] + async fn set_client_factory_for_test(&self, factory: Arc) { + *self.client_factory.write().await = Some(factory); } /// Get client for a specific context, or current context if None @@ -1046,6 +1110,19 @@ impl K8sClientPool { #[cfg(test)] mod tests { + use super::K8sClientPool; + use crate::progress::create_progress_handle; + use bytes::Bytes; + use futures::future::join_all; + use http::{Request, Response, StatusCode}; + use http_body_util::Full; + use std::convert::Infallible; + use std::sync::Arc; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::time::Duration; + use tokio::sync::Barrier; + use tower::service_fn; + /// Helper function to test alias building logic in isolation /// This replicates the logic from process_discovered_crds fn build_aliases( @@ -1151,4 +1228,70 @@ mod tests { assert!(aliases.contains(&"cert".to_string())); assert!(aliases.contains(&"certificate".to_string())); } + + fn make_test_client() -> kube::Client { + kube::Client::new( + 
+            service_fn(|_request: Request<kube::client::Body>| async move {
+                Ok::<_, Infallible>(
+                    Response::builder()
+                        .status(StatusCode::OK)
+                        .body(Full::new(Bytes::from_static(b"{}")))
+                        .expect("test response"),
+                )
+            }),
+            "default",
+        )
+    }
+
+    #[tokio::test]
+    async fn get_or_create_client_creates_once_per_context_under_concurrency() {
+        let pool = Arc::new(K8sClientPool::new_for_test(create_progress_handle()));
+        let create_count = Arc::new(AtomicUsize::new(0));
+
+        pool.set_client_factory_for_test(Arc::new({
+            let create_count = create_count.clone();
+            move |_context| {
+                let create_count = create_count.clone();
+                Box::pin(async move {
+                    create_count.fetch_add(1, Ordering::SeqCst);
+                    tokio::time::sleep(Duration::from_millis(50)).await;
+                    Ok(make_test_client())
+                })
+            }
+        }))
+        .await;
+
+        let callers = 16;
+        let start_barrier = Arc::new(Barrier::new(callers + 1));
+
+        let tasks = (0..callers)
+            .map(|_| {
+                let pool = pool.clone();
+                let start_barrier = start_barrier.clone();
+                tokio::spawn(async move {
+                    start_barrier.wait().await;
+                    pool.get_client(Some("test-context")).await
+                })
+            })
+            .collect::<Vec<_>>();
+
+        start_barrier.wait().await;
+
+        let results = join_all(tasks).await;
+        for result in results {
+            result
+                .expect("task join should succeed")
+                .expect("client creation should succeed");
+        }
+
+        assert_eq!(
+            create_count.load(Ordering::SeqCst),
+            1,
+            "parallel callers should share a single client construction"
+        );
+
+        let clients = pool.clients.read().await;
+        assert_eq!(clients.len(), 1);
+        assert!(clients.contains_key("test-context"));
+    }
 }