diff --git a/aerospike-core/src/client.rs b/aerospike-core/src/client.rs index d7dc10c3..56b38cfe 100644 --- a/aerospike-core/src/client.rs +++ b/aerospike-core/src/client.rs @@ -36,9 +36,9 @@ use crate::{ }; use aerospike_rt::fs::File; #[cfg(all(any(feature = "rt-tokio"), not(feature = "rt-async-std")))] -use aerospike_rt::io::AsyncReadExt; +use aerospike_rt::{io::AsyncReadExt, Semaphore}; #[cfg(all(any(feature = "rt-async-std"), not(feature = "rt-tokio")))] -use futures::AsyncReadExt; +use futures::{AsyncReadExt, Semaphore}; /// Instantiate a Client instance to access an Aerospike database cluster and perform database /// operations. @@ -682,7 +682,19 @@ impl Client { let bins = bins.into(); let nodes = self.cluster.nodes().await; let recordset = Arc::new(Recordset::new(policy.record_queue_size, nodes.len())); + + let mut handles = Vec::new(); + let semaphore = Arc::new(Semaphore::new( + if policy.max_concurrent_nodes == 0 { + Semaphore::MAX_PERMITS + } else { + policy.max_concurrent_nodes + } + )); + for node in nodes { + let semaphore = semaphore.clone(); + let partitions = self.cluster.node_partitions(node.as_ref(), namespace).await; let node = node.clone(); let recordset = recordset.clone(); @@ -691,14 +703,21 @@ impl Client { let set_name = set_name.to_owned(); let bins = bins.clone(); - aerospike_rt::spawn(async move { + let handle = aerospike_rt::spawn(async move { + let permit = semaphore.acquire().await + .map_err(|e| Error::ClientError(format!("Failed to acquire semaphore: {}", e)))?; let mut command = ScanCommand::new( &policy, node, &namespace, &set_name, bins, recordset, partitions, ); - command.execute().await.unwrap(); - }) - .await?; + let result = command.execute().await; + drop(permit); + result + }); + + handles.push(handle); } + + futures::future::try_join_all(handles).await?; Ok(recordset) } @@ -729,7 +748,7 @@ impl Client { let namespace = namespace.to_owned(); let set_name = set_name.to_owned(); - aerospike_rt::spawn(async move { + let _ = aerospike_rt::spawn(async move { let mut command = ScanCommand::new( &policy, node, @@ -739,7 +758,8 @@ impl Client { t_recordset, partitions, ); - command.execute().await.unwrap(); + command.execute().await?; + Ok::<(), Error>(()) }) .await?; @@ -781,7 +801,19 @@ impl Client { let nodes = self.cluster.nodes().await; let recordset = Arc::new(Recordset::new(policy.record_queue_size, nodes.len())); + + let mut handles = Vec::new(); + let semaphore = Arc::new(Semaphore::new( + if policy.max_concurrent_nodes == 0 { + Semaphore::MAX_PERMITS + } else { + policy.max_concurrent_nodes + } + )); + for node in nodes { + let semaphore = semaphore.clone(); + let partitions = self .cluster .node_partitions(node.as_ref(), &statement.namespace) @@ -790,13 +822,20 @@ impl Client { let t_recordset = recordset.clone(); let policy = policy.clone(); let statement = statement.clone(); - aerospike_rt::spawn(async move { - let mut command = - QueryCommand::new(&policy, node, statement, t_recordset, partitions); - command.execute().await.unwrap(); - }) - .await?; + + let handle = aerospike_rt::spawn(async move { + let permit = semaphore.acquire().await + .map_err(|e| Error::ClientError(format!("Failed to acquire semaphore: {}", e)))?; + let mut command = QueryCommand::new(&policy, node, statement, t_recordset, partitions); + let result = command.execute().await; + drop(permit); + result + }); + + handles.push(handle); } + + futures::future::try_join_all(handles).await?; Ok(recordset) } @@ -823,9 +862,10 @@ impl Client { .node_partitions(node.as_ref(), &statement.namespace) .await; - aerospike_rt::spawn(async move { + let _ = aerospike_rt::spawn(async move { let mut command = QueryCommand::new(&policy, node, statement, t_recordset, partitions); - command.execute().await.unwrap(); + command.execute().await?; + Ok::<(), Error>(()) }) .await?; diff --git a/aerospike-rt/src/lib.rs b/aerospike-rt/src/lib.rs index 2d0e73f1..93f3c2de 100644 --- a/aerospike-rt/src/lib.rs +++ b/aerospike-rt/src/lib.rs @@ -6,10 +6,10 @@ compile_error!("Please select only one runtime"); #[cfg(all(any(feature = "rt-async-std"), not(feature = "rt-tokio")))] pub use async_std::{ - self, fs, future::timeout, io, net, sync::RwLock, task, task::sleep, task::spawn, + self, fs, future::timeout, io, net, sync::RwLock, sync::Semaphore, task, task::sleep, task::spawn, }; #[cfg(all(any(feature = "rt-tokio"), not(feature = "rt-async-std")))] -pub use tokio::{self, fs, io, net, spawn, sync::RwLock, task, time, time::sleep, time::timeout}; +pub use tokio::{self, fs, io, net, spawn, sync::RwLock, sync::Semaphore, task, time, time::sleep, time::timeout}; #[cfg(all(any(feature = "rt-async-std"), not(feature = "rt-tokio")))] pub use std::time;