Skip to content

Commit 4eae5b8

Browse files
authored
Add the ability to spawn tasks by name (#22)
* wip * made durable generic over a context type * renamed app context to state * made TaskContext generic over app state * give callers the ability to spawn pre-registered tasks by name * added a test that covers spawning by name from context * addressed PR comment
1 parent 1eb20f0 commit 4eae5b8

File tree

7 files changed

+238
-32
lines changed

7 files changed

+238
-32
lines changed

src/client.rs

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -282,14 +282,26 @@ where
282282
.await
283283
}
284284

285-
/// Spawn a task by name (dynamic version for unregistered tasks)
285+
/// Spawn a task by name (dynamic version).
286+
///
287+
/// The task must be registered before spawning.
286288
pub async fn spawn_by_name(
287289
&self,
288290
task_name: &str,
289291
params: JsonValue,
290292
options: SpawnOptions,
291293
) -> anyhow::Result<SpawnResult> {
292-
self.spawn_by_name_with(&self.pool, task_name, params, options)
294+
// Validate that the task is registered
295+
{
296+
let registry = self.registry.read().await;
297+
anyhow::ensure!(
298+
registry.contains_key(task_name),
299+
"Unknown task: {}. Task must be registered before spawning.",
300+
task_name
301+
);
302+
}
303+
304+
self.spawn_by_name_internal(&self.pool, task_name, params, options)
293305
.await
294306
}
295307

@@ -335,18 +347,46 @@ where
335347
T: Task<State>,
336348
E: Executor<'e, Database = Postgres>,
337349
{
338-
self.spawn_by_name_with(executor, T::NAME, serde_json::to_value(&params)?, options)
350+
// Type-safe spawn uses T::NAME which is already registered
351+
self.spawn_by_name_internal(executor, T::NAME, serde_json::to_value(&params)?, options)
339352
.await
340353
}
341354

342355
/// Spawn a task by name using a custom executor.
356+
///
357+
/// The task must be registered before spawning.
343358
pub async fn spawn_by_name_with<'e, E>(
344359
&self,
345360
executor: E,
346361
task_name: &str,
347362
params: JsonValue,
348363
options: SpawnOptions,
349364
) -> anyhow::Result<SpawnResult>
365+
where
366+
E: Executor<'e, Database = Postgres>,
367+
{
368+
// Validate that the task is registered
369+
{
370+
let registry = self.registry.read().await;
371+
anyhow::ensure!(
372+
registry.contains_key(task_name),
373+
"Unknown task: {}. Task must be registered before spawning.",
374+
task_name
375+
);
376+
}
377+
378+
self.spawn_by_name_internal(executor, task_name, params, options)
379+
.await
380+
}
381+
382+
/// Internal spawn implementation without registry validation.
383+
async fn spawn_by_name_internal<'e, E>(
384+
&self,
385+
executor: E,
386+
task_name: &str,
387+
params: JsonValue,
388+
options: SpawnOptions,
389+
) -> anyhow::Result<SpawnResult>
350390
where
351391
E: Executor<'e, Database = Postgres>,
352392
{

src/context.rs

Lines changed: 72 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@ use serde_json::Value as JsonValue;
44
use sqlx::PgPool;
55
use std::collections::HashMap;
66
use std::marker::PhantomData;
7+
use std::sync::Arc;
78
use std::time::Duration;
9+
use tokio::sync::RwLock;
810
use uuid::Uuid;
911

1012
use crate::error::{ControlFlow, TaskError, TaskResult};
11-
use crate::task::Task;
13+
use crate::task::{Task, TaskRegistry};
1214
use crate::types::{
1315
AwaitEventResult, CheckpointRow, ChildCompletePayload, ChildStatus, ClaimedTask, SpawnOptions,
1416
SpawnResultRow, TaskHandle,
@@ -67,6 +69,9 @@ where
6769
/// Notifies the worker when the lease is extended via step() or heartbeat().
6870
lease_extender: LeaseExtender,
6971

72+
/// Task registry for validating spawn_by_name calls.
73+
registry: Arc<RwLock<TaskRegistry<State>>>,
74+
7075
/// Phantom data to carry the State type parameter.
7176
_state: PhantomData<State>,
7277
}
@@ -93,6 +98,7 @@ where
9398
task: ClaimedTask,
9499
claim_timeout: u64,
95100
lease_extender: LeaseExtender,
101+
registry: Arc<RwLock<TaskRegistry<State>>>,
96102
) -> Result<Self, sqlx::Error> {
97103
// Load all checkpoints for this task into cache
98104
let checkpoints: Vec<CheckpointRow> = sqlx::query_as(
@@ -120,6 +126,7 @@ where
120126
checkpoint_cache: cache,
121127
step_counters: HashMap::new(),
122128
lease_extender,
129+
registry,
123130
_state: PhantomData,
124131
})
125132
}
@@ -456,6 +463,57 @@ where
456463
where
457464
T: Task<State>,
458465
{
466+
let params_json = serde_json::to_value(&params)?;
467+
self.spawn_by_name(name, T::NAME, params_json, options)
468+
.await
469+
}
470+
471+
/// Spawn a subtask by task name (dynamic version).
472+
///
473+
/// This is similar to [`spawn`](Self::spawn) but works with task names
474+
/// instead of requiring a concrete type. Useful for dynamic task invocation
475+
/// where the task type isn't known at compile time.
476+
///
477+
/// The spawn is checkpointed - if this task retries after spawning, the
478+
/// same subtask ID is returned without spawning a duplicate.
479+
///
480+
/// # Arguments
481+
///
482+
/// * `name` - Unique name for this spawn operation (used for checkpointing)
483+
/// * `task_name` - The registered name of the task to spawn
484+
/// * `params` - JSON parameters to pass to the task
485+
/// * `options` - Spawn options (max_attempts, priority, etc.)
486+
///
487+
/// # Returns
488+
///
489+
/// A [`TaskHandle`] that can be passed to [`join`](Self::join) to wait for
490+
/// the result. The output type `T` must match the actual task's output type.
491+
///
492+
/// # Errors
493+
///
494+
/// * `TaskError::Failed` - If the task name is not registered in the registry
495+
///
496+
/// # Example
497+
///
498+
/// ```ignore
499+
/// // Spawn a task by name
500+
/// let handle: TaskHandle<ProcessResult> = ctx.spawn_by_name(
501+
/// "process-item",
502+
/// "process-item-task",
503+
/// serde_json::json!({ "item_id": 123 }),
504+
/// Default::default(),
505+
/// ).await?;
506+
///
507+
/// // Wait for result
508+
/// let result: ProcessResult = ctx.join("process-item", handle).await?;
509+
/// ```
510+
pub async fn spawn_by_name<T: DeserializeOwned>(
511+
&mut self,
512+
name: &str,
513+
task_name: &str,
514+
params: JsonValue,
515+
options: SpawnOptions,
516+
) -> TaskResult<TaskHandle<T>> {
459517
validate_user_name(name)?;
460518
let checkpoint_name = self.get_checkpoint_name(&format!("$spawn:{name}"));
461519

@@ -465,8 +523,18 @@ where
465523
return Ok(TaskHandle::new(task_id));
466524
}
467525

526+
// Validate that the task is registered
527+
{
528+
let registry = self.registry.read().await;
529+
if !registry.contains_key(task_name) {
530+
return Err(TaskError::Failed(anyhow::anyhow!(
531+
"Unknown task: {}. Task must be registered before spawning.",
532+
task_name
533+
)));
534+
}
535+
}
536+
468537
// Build options JSON, merging user options with parent_task_id
469-
let params_json = serde_json::to_value(&params)?;
470538
#[derive(Serialize)]
471539
struct SubtaskOptions<'a> {
472540
parent_task_id: Uuid,
@@ -482,8 +550,8 @@ where
482550
"SELECT task_id, run_id, attempt FROM durable.spawn_task($1, $2, $3, $4)",
483551
)
484552
.bind(&self.queue_name)
485-
.bind(T::NAME)
486-
.bind(&params_json)
553+
.bind(task_name)
554+
.bind(&params)
487555
.bind(&options_json)
488556
.fetch_one(&self.pool)
489557
.await?;

src/worker.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,7 @@ impl Worker {
247247
task.clone(),
248248
claim_timeout,
249249
lease_extender,
250+
registry.clone(),
250251
)
251252
.await
252253
{

tests/common/tasks.rs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1241,3 +1241,51 @@ impl Task<()> for SpawnThenFailTask {
12411241
}))
12421242
}
12431243
}
1244+
1245+
// ============================================================================
1246+
// SpawnByNameTask - Tests spawn_by_name on TaskContext
1247+
// ============================================================================
1248+
1249+
/// Parent task that spawns a child using spawn_by_name (dynamic version)
1250+
#[allow(dead_code)]
1251+
pub struct SpawnByNameTask;
1252+
1253+
#[allow(dead_code)]
1254+
#[derive(Debug, Clone, Serialize, Deserialize)]
1255+
pub struct SpawnByNameParams {
1256+
pub child_value: i32,
1257+
}
1258+
1259+
#[allow(dead_code)]
1260+
#[derive(Debug, Clone, Serialize, Deserialize)]
1261+
pub struct SpawnByNameOutput {
1262+
pub child_result: i32,
1263+
}
1264+
1265+
#[async_trait]
1266+
impl Task<()> for SpawnByNameTask {
1267+
const NAME: &'static str = "spawn-by-name";
1268+
type Params = SpawnByNameParams;
1269+
type Output = SpawnByNameOutput;
1270+
1271+
async fn run(
1272+
params: Self::Params,
1273+
mut ctx: TaskContext,
1274+
_state: (),
1275+
) -> TaskResult<Self::Output> {
1276+
// Spawn child task using spawn_by_name (dynamic version)
1277+
let handle: TaskHandle<i32> = ctx
1278+
.spawn_by_name(
1279+
"child",
1280+
"double", // task name string instead of type
1281+
serde_json::json!({ "value": params.child_value }),
1282+
Default::default(),
1283+
)
1284+
.await?;
1285+
1286+
// Join and get result
1287+
let child_result: i32 = ctx.join("child", handle).await?;
1288+
1289+
Ok(SpawnByNameOutput { child_result })
1290+
}
1291+
}

tests/execution_test.rs

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -349,32 +349,24 @@ async fn test_unregistered_task_fails(pool: PgPool) -> sqlx::Result<()> {
349349
client.create_queue(None).await.unwrap();
350350
// Note: We don't register any task handler
351351

352-
// Spawn a task by name
352+
// Spawn a task by name - should fail at spawn time because task is not registered
353353
let spawn_result = client
354354
.spawn_by_name(
355355
"unregistered-task",
356356
serde_json::json!({}),
357357
Default::default(),
358358
)
359-
.await
360-
.expect("Failed to spawn task");
361-
362-
let worker = client
363-
.start_worker(WorkerOptions {
364-
poll_interval: 0.05,
365-
claim_timeout: 30,
366-
..Default::default()
367-
})
368359
.await;
369360

370-
tokio::time::sleep(Duration::from_millis(500)).await;
371-
worker.shutdown().await;
372-
373-
// Task should have failed because handler is not registered
374-
let state = get_task_state(&pool, "exec_unreg", spawn_result.task_id).await;
375-
assert_eq!(
376-
state, "failed",
377-
"Task with unregistered handler should fail"
361+
assert!(
362+
spawn_result.is_err(),
363+
"Spawning an unregistered task should fail"
364+
);
365+
let err = spawn_result.unwrap_err();
366+
assert!(
367+
err.to_string().contains("Unknown task"),
368+
"Error should mention 'Unknown task': {}",
369+
err
378370
);
379371

380372
Ok(())

tests/fanout_test.rs

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@ mod common;
44

55
use common::tasks::{
66
DoubleTask, FailingChildTask, MultiSpawnOutput, MultiSpawnParams, MultiSpawnTask,
7-
SingleSpawnOutput, SingleSpawnParams, SingleSpawnTask, SlowChildTask, SpawnFailingChildTask,
8-
SpawnSlowChildParams, SpawnSlowChildTask,
7+
SingleSpawnOutput, SingleSpawnParams, SingleSpawnTask, SlowChildTask, SpawnByNameOutput,
8+
SpawnByNameParams, SpawnByNameTask, SpawnFailingChildTask, SpawnSlowChildParams,
9+
SpawnSlowChildTask,
910
};
1011
use durable::{Durable, MIGRATOR, WorkerOptions};
1112
use sqlx::{AssertSqlSafe, PgPool};
@@ -348,3 +349,53 @@ async fn test_cascade_cancel_when_parent_cancelled(pool: PgPool) -> sqlx::Result
348349

349350
Ok(())
350351
}
352+
353+
// ============================================================================
354+
// spawn_by_name Tests
355+
// ============================================================================
356+
357+
#[sqlx::test(migrator = "MIGRATOR")]
358+
async fn test_spawn_by_name_from_task_context(pool: PgPool) -> sqlx::Result<()> {
359+
let client = create_client(pool.clone(), "fanout_by_name").await;
360+
client.create_queue(None).await.unwrap();
361+
client.register::<SpawnByNameTask>().await;
362+
client.register::<DoubleTask>().await;
363+
364+
// Spawn parent task that will use spawn_by_name internally
365+
let spawn_result = client
366+
.spawn::<SpawnByNameTask>(SpawnByNameParams { child_value: 21 })
367+
.await
368+
.expect("Failed to spawn task");
369+
370+
// Start worker with concurrency to handle both parent and child
371+
let worker = client
372+
.start_worker(WorkerOptions {
373+
poll_interval: 0.05,
374+
claim_timeout: 30,
375+
concurrency: 2,
376+
..Default::default()
377+
})
378+
.await;
379+
380+
// Wait for tasks to complete
381+
tokio::time::sleep(Duration::from_millis(2000)).await;
382+
worker.shutdown().await;
383+
384+
// Verify parent task completed
385+
let state = get_task_state(&pool, "fanout_by_name", spawn_result.task_id).await;
386+
assert_eq!(state, "completed", "Parent task should be completed");
387+
388+
// Verify result
389+
let result = get_task_result(&pool, "fanout_by_name", spawn_result.task_id)
390+
.await
391+
.expect("Task should have a result");
392+
393+
let output: SpawnByNameOutput =
394+
serde_json::from_value(result).expect("Failed to deserialize result");
395+
assert_eq!(
396+
output.child_result, 42,
397+
"Child should have doubled 21 to 42 (spawned via spawn_by_name)"
398+
);
399+
400+
Ok(())
401+
}

0 commit comments

Comments
 (0)