From f23ddbaf9ea5e8ce399aebcaf079b080f98df16a Mon Sep 17 00:00:00 2001 From: Evan Hicks Date: Mon, 17 Mar 2025 15:07:09 -0400 Subject: [PATCH] feat: Allow fetching multiple tasks at once Change the store to allow fetching multiple pending tasks at once. This is the first step towards being able to batch fetch tasks from the worker. This adds a `get_pending_activations` function, and points the existing function to call that function with a limit of 1. In the future new endpoints can leverage the pending activations function directly. That function accepts the namespaces that are being requested, and the number of pending tasks that can be returned. --- src/grpc/server.rs | 8 +--- src/grpc/server_tests.rs | 4 +- src/store/inflight_activation.rs | 65 ++++++++++++++++++++------ src/store/inflight_activation_tests.rs | 34 +++++++++++--- src/test_utils.rs | 2 +- 5 files changed, 85 insertions(+), 28 deletions(-) diff --git a/src/grpc/server.rs b/src/grpc/server.rs index fb522839..0fdd9cb6 100644 --- a/src/grpc/server.rs +++ b/src/grpc/server.rs @@ -25,11 +25,8 @@ impl ConsumerService for TaskbrokerServer { request: Request, ) -> Result, Status> { let start_time = Instant::now(); - let namespace = &request.get_ref().namespace; - let inflight = self - .store - .get_pending_activation(namespace.as_deref()) - .await; + let namespace = request.get_ref().namespace.as_deref(); + let inflight = self.store.get_pending_activation(namespace).await; match inflight { Ok(Some(inflight)) => { @@ -123,7 +120,6 @@ impl ConsumerService for TaskbrokerServer { }; let start_time = Instant::now(); - let res = match self .store .get_pending_activation(namespace.as_deref()) diff --git a/src/grpc/server_tests.rs b/src/grpc/server_tests.rs index b1c6ad1a..44e6f62c 100644 --- a/src/grpc/server_tests.rs +++ b/src/grpc/server_tests.rs @@ -92,7 +92,9 @@ async fn test_set_task_status_success() { let request = SetTaskStatusRequest { id: "id_0".to_string(), status: 5, // Complete - fetch_next_task: Some(FetchNextTask { namespace: None }), + fetch_next_task: Some(FetchNextTask { + namespace: Some("namespace".to_string()), + }), }; let response = service.set_task_status(Request::new(request)).await; assert!(response.is_ok()); diff --git a/src/store/inflight_activation.rs b/src/store/inflight_activation.rs index 0f2e9b3b..7180e9f6 100644 --- a/src/store/inflight_activation.rs +++ b/src/store/inflight_activation.rs @@ -326,11 +326,13 @@ impl InflightActivationStore { } #[instrument(skip_all)] - pub async fn get_pending_activation( + pub async fn get_pending_activations( &self, - namespace: Option<&str>, - ) -> Result, Error> { + namespaces: Option>, + limit: Option, + ) -> Result>, Error> { let now = Utc::now(); + let to_return = limit.unwrap_or(1); let mut query_builder = QueryBuilder::new( " @@ -344,7 +346,7 @@ impl InflightActivationStore { query_builder.push_bind(InflightActivationStatus::Processing); query_builder.push( " - WHERE id = ( + WHERE id IN ( SELECT id FROM inflight_taskactivations WHERE status = ", @@ -354,19 +356,54 @@ impl InflightActivationStore { query_builder.push_bind(now.timestamp()); query_builder.push(")"); - if let Some(namespace) = namespace { - query_builder.push(" AND namespace = "); - query_builder.push_bind(namespace); + let namespaces_vec: Vec; + if let Some(namespaces) = namespaces { + query_builder.push(" AND namespace IN ("); + let mut separated = query_builder.separated(", "); + namespaces_vec = namespaces.iter().map(|ns| ns.to_string()).collect(); + for namespace in namespaces_vec.iter() { + separated.push_bind(namespace); + } + separated.push_unseparated(")"); } - query_builder.push(" ORDER BY added_at LIMIT 1) RETURNING *"); + query_builder.push(" ORDER BY added_at LIMIT "); + query_builder.push_bind(to_return); + query_builder.push(") RETURNING *"); + + let rows: Vec = query_builder + .build_query_as() + .fetch_all(&self.write_pool) + .await? + .into_iter() + .map(|row: TableRow| row.into()) + .collect(); - let result: Option = query_builder - .build_query_as::() - .fetch_optional(&self.write_pool) - .await?; - let Some(row) = result else { return Ok(None) }; + if rows.is_empty() { + return Ok(None); + } - Ok(Some(row.into())) + Ok(Some(rows)) + } + + #[instrument(skip_all)] + pub async fn get_pending_activation( + &self, + namespace: Option<&str>, + ) -> Result, Error> { + if let Some(namespace) = namespace { + match self + .get_pending_activations(Some(vec![namespace]), Some(1)) + .await? + { + Some(rows) => Ok(Some(rows[0].clone())), + None => Ok(None), + } + } else { + match self.get_pending_activations(None, Some(1)).await? { + Some(rows) => Ok(Some(rows[0].clone())), + None => Ok(None), + } + } } #[instrument(skip_all)] diff --git a/src/store/inflight_activation_tests.rs b/src/store/inflight_activation_tests.rs index d6dbcc57..2d3d3d47 100644 --- a/src/store/inflight_activation_tests.rs +++ b/src/store/inflight_activation_tests.rs @@ -147,11 +147,7 @@ async fn test_get_pending_activation_with_race() { let store = store.clone(); join_set.spawn(async move { rx.recv().await.unwrap(); - store - .get_pending_activation(Some("namespace")) - .await - .unwrap() - .unwrap() + store.get_pending_activation(None).await.unwrap().unwrap() }); } @@ -218,6 +214,26 @@ async fn test_get_pending_activation_earliest() { ); } +#[tokio::test] +async fn test_get_pending_activations() { + let store = create_test_store().await; + + let mut batch = make_activations(5); + batch[1].namespace = "other_namespace".into(); + assert!(store.store(batch.clone()).await.is_ok()); + + let result = store + .get_pending_activations(Some(vec!["namespace", "other_namespace"]), Some(2)) + .await + .unwrap() + .unwrap(); + + assert_eq!(result.len(), 2); + assert_eq!(result[0].activation.id, "id_0"); + assert_eq!(result[1].activation.id, "id_1"); + assert_count_by_status(&store, InflightActivationStatus::Pending, 3).await; +} + #[tokio::test] async fn test_count_pending_activations() { let store = create_test_store().await; @@ -266,7 +282,13 @@ async fn set_activation_status() { .is_ok() ); assert_eq!(store.count_pending_activations().await.unwrap(), 0); - assert!(store.get_pending_activation(None).await.unwrap().is_none()); + assert!( + store + .get_pending_activation(Some("namespace")) + .await + .unwrap() + .is_none() + ); } #[tokio::test] diff --git a/src/test_utils.rs b/src/test_utils.rs index 1a456804..7cfcf4c3 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -49,7 +49,7 @@ pub fn make_activations(count: u32) -> Vec { status: InflightActivationStatus::Pending, partition: 0, offset: i as i64, - added_at: Utc::now(), + added_at: Utc::now() + chrono::Duration::seconds(i as i64), processing_attempts: 0, expires_at: None, processing_deadline: None,