Skip to content

Commit 5963af6

Browse files
authored
Made durable generic over a context type (#20)
* wip * made durable generic over a context type * renamed app context to state * made TaskContext generic over app state
1 parent f7c57e1 commit 5963af6

File tree

8 files changed

+543
-112
lines changed

8 files changed

+543
-112
lines changed

benches/common/tasks.rs

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@ use serde::{Deserialize, Serialize};
99
pub struct NoOpTask;
1010

1111
#[async_trait]
12-
impl Task for NoOpTask {
12+
impl Task<()> for NoOpTask {
1313
const NAME: &'static str = "bench-noop";
1414
type Params = ();
1515
type Output = ();
1616

17-
async fn run(_params: Self::Params, _ctx: TaskContext) -> TaskResult<Self::Output> {
17+
async fn run(_params: Self::Params, _ctx: TaskContext, _state: ()) -> TaskResult<Self::Output> {
1818
Ok(())
1919
}
2020
}
@@ -33,12 +33,12 @@ pub struct QuickParams {
3333
}
3434

3535
#[async_trait]
36-
impl Task for QuickTask {
36+
impl Task<()> for QuickTask {
3737
const NAME: &'static str = "bench-quick";
3838
type Params = QuickParams;
3939
type Output = u32;
4040

41-
async fn run(params: Self::Params, _ctx: TaskContext) -> TaskResult<Self::Output> {
41+
async fn run(params: Self::Params, _ctx: TaskContext, _state: ()) -> TaskResult<Self::Output> {
4242
Ok(params.task_num)
4343
}
4444
}
@@ -57,12 +57,16 @@ pub struct MultiStepParams {
5757
}
5858

5959
#[async_trait]
60-
impl Task for MultiStepBenchTask {
60+
impl Task<()> for MultiStepBenchTask {
6161
const NAME: &'static str = "bench-multi-step";
6262
type Params = MultiStepParams;
6363
type Output = u32;
6464

65-
async fn run(params: Self::Params, mut ctx: TaskContext) -> TaskResult<Self::Output> {
65+
async fn run(
66+
params: Self::Params,
67+
mut ctx: TaskContext,
68+
_state: (),
69+
) -> TaskResult<Self::Output> {
6670
for i in 0..params.num_steps {
6771
let _: u32 = ctx
6872
.step(&format!("step-{}", i), || async move { Ok(i) })
@@ -86,12 +90,16 @@ pub struct LargePayloadParams {
8690
}
8791

8892
#[async_trait]
89-
impl Task for LargePayloadBenchTask {
93+
impl Task<()> for LargePayloadBenchTask {
9094
const NAME: &'static str = "bench-large-payload";
9195
type Params = LargePayloadParams;
9296
type Output = usize;
9397

94-
async fn run(params: Self::Params, mut ctx: TaskContext) -> TaskResult<Self::Output> {
98+
async fn run(
99+
params: Self::Params,
100+
mut ctx: TaskContext,
101+
_state: (),
102+
) -> TaskResult<Self::Output> {
95103
let payload = "x".repeat(params.payload_size);
96104
let _: String = ctx
97105
.step("large-step", || async move { Ok(payload) })

src/client.rs

Lines changed: 92 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -57,37 +57,67 @@ use crate::worker::Worker;
5757
/// - Emit events with [`emit_event`](Self::emit_event)
5858
/// - Cancel tasks with [`cancel_task`](Self::cancel_task)
5959
///
60+
/// # Type Parameter
61+
///
62+
/// * `State` - Application state type passed to task handlers. Use `()` if you
63+
/// don't need any state. The state must implement `Clone + Send + Sync + 'static`.
64+
///
6065
/// # Example
6166
///
6267
/// ```ignore
68+
/// // Without state (default)
6369
/// let client = Durable::builder()
6470
/// .database_url("postgres://localhost/myapp")
6571
/// .queue_name("tasks")
6672
/// .build()
6773
/// .await?;
6874
///
75+
/// // With application state
76+
/// #[derive(Clone)]
77+
/// struct AppState {
78+
/// http_client: reqwest::Client,
79+
/// }
80+
///
81+
/// let app_state = AppState { http_client: reqwest::Client::new() };
82+
/// let client = Durable::builder()
83+
/// .database_url("postgres://localhost/myapp")
84+
/// .queue_name("tasks")
85+
/// .build_with_state(app_state)
86+
/// .await?;
87+
///
6988
/// client.register::<MyTask>().await;
7089
/// client.spawn::<MyTask>(params).await?;
7190
/// ```
72-
pub struct Durable {
91+
pub struct Durable<State = ()>
92+
where
93+
State: Clone + Send + Sync + 'static,
94+
{
7395
pool: PgPool,
7496
owns_pool: bool,
7597
queue_name: String,
7698
default_max_attempts: u32,
77-
registry: Arc<RwLock<TaskRegistry>>,
99+
registry: Arc<RwLock<TaskRegistry<State>>>,
100+
state: State,
78101
}
79102

80103
/// Builder for configuring a [`Durable`] client.
81104
///
82105
/// # Example
83106
///
84107
/// ```ignore
108+
/// // Without state
85109
/// let client = Durable::builder()
86110
/// .database_url("postgres://localhost/myapp")
87111
/// .queue_name("orders")
88112
/// .default_max_attempts(3)
89113
/// .build()
90114
/// .await?;
115+
///
116+
/// // With state
117+
/// let client = Durable::builder()
118+
/// .database_url("postgres://localhost/myapp")
119+
/// .build_with_state(my_app_state)
120+
/// .await?;
91121
/// ```
92122
pub struct DurableBuilder {
93123
database_url: Option<String>,
@@ -130,8 +160,43 @@ impl DurableBuilder {
130160
self
131161
}
132162

133-
/// Build the Durable client
134-
pub async fn build(self) -> anyhow::Result<Durable> {
163+
/// Build the Durable client without application state.
164+
///
165+
/// Use this when your tasks don't need access to shared resources
166+
/// like HTTP clients or database pools.
167+
pub async fn build(self) -> anyhow::Result<Durable<()>> {
168+
self.build_with_state(()).await
169+
}
170+
171+
/// Build the Durable client with application state.
172+
///
173+
/// The state will be cloned and passed to each task execution.
174+
/// Use this to provide shared resources like HTTP clients, database pools,
175+
/// or other application state to your tasks.
176+
///
177+
/// # Example
178+
///
179+
/// ```ignore
180+
/// #[derive(Clone)]
181+
/// struct AppState {
182+
/// http_client: reqwest::Client,
183+
/// db_pool: PgPool,
184+
/// }
185+
///
186+
/// let state = AppState {
187+
/// http_client: reqwest::Client::new(),
188+
/// db_pool: pool.clone(),
189+
/// };
190+
///
191+
/// let client = Durable::builder()
192+
/// .database_url("postgres://localhost/myapp")
193+
/// .build_with_state(state)
194+
/// .await?;
195+
/// ```
196+
pub async fn build_with_state<State>(self, state: State) -> anyhow::Result<Durable<State>>
197+
where
198+
State: Clone + Send + Sync + 'static,
199+
{
135200
let (pool, owns_pool) = if let Some(pool) = self.pool {
136201
(pool, false)
137202
} else {
@@ -148,6 +213,7 @@ impl DurableBuilder {
148213
queue_name: self.queue_name,
149214
default_max_attempts: self.default_max_attempts,
150215
registry: Arc::new(RwLock::new(HashMap::new())),
216+
state,
151217
})
152218
}
153219
}
@@ -158,8 +224,8 @@ impl Default for DurableBuilder {
158224
}
159225
}
160226

161-
impl Durable {
162-
/// Create a new client with default settings
227+
impl Durable<()> {
228+
/// Create a new client with default settings (no application state).
163229
pub async fn new(database_url: &str) -> anyhow::Result<Self> {
164230
DurableBuilder::new()
165231
.database_url(database_url)
@@ -171,7 +237,12 @@ impl Durable {
171237
pub fn builder() -> DurableBuilder {
172238
DurableBuilder::new()
173239
}
240+
}
174241

242+
impl<State> Durable<State>
243+
where
244+
State: Clone + Send + Sync + 'static,
245+
{
175246
/// Get a reference to the underlying connection pool
176247
pub fn pool(&self) -> &PgPool {
177248
&self.pool
@@ -182,21 +253,29 @@ impl Durable {
182253
&self.queue_name
183254
}
184255

256+
/// Get a reference to the application state
257+
pub fn state(&self) -> &State {
258+
&self.state
259+
}
260+
185261
/// Register a task type. Required before spawning or processing.
186-
pub async fn register<T: Task>(&self) -> &Self {
262+
pub async fn register<T: Task<State>>(&self) -> &Self {
187263
let mut registry = self.registry.write().await;
188-
registry.insert(T::NAME.to_string(), Arc::new(TaskWrapper::<T>::new()));
264+
registry.insert(
265+
T::NAME.to_string(),
266+
Arc::new(TaskWrapper::<T, State>::new()),
267+
);
189268
self
190269
}
191270

192271
/// Spawn a task (type-safe version)
193-
pub async fn spawn<T: Task>(&self, params: T::Params) -> anyhow::Result<SpawnResult> {
272+
pub async fn spawn<T: Task<State>>(&self, params: T::Params) -> anyhow::Result<SpawnResult> {
194273
self.spawn_with_options::<T>(params, SpawnOptions::default())
195274
.await
196275
}
197276

198277
/// Spawn a task with options (type-safe version)
199-
pub async fn spawn_with_options<T: Task>(
278+
pub async fn spawn_with_options<T: Task<State>>(
200279
&self,
201280
params: T::Params,
202281
options: SpawnOptions,
@@ -240,7 +319,7 @@ impl Durable {
240319
params: T::Params,
241320
) -> anyhow::Result<SpawnResult>
242321
where
243-
T: Task,
322+
T: Task<State>,
244323
E: Executor<'e, Database = Postgres>,
245324
{
246325
self.spawn_with_options_with::<T, E>(executor, params, SpawnOptions::default())
@@ -255,7 +334,7 @@ impl Durable {
255334
options: SpawnOptions,
256335
) -> anyhow::Result<SpawnResult>
257336
where
258-
T: Task,
337+
T: Task<State>,
259338
E: Executor<'e, Database = Postgres>,
260339
{
261340
self.spawn_by_name_with(executor, T::NAME, serde_json::to_value(&params)?, options)
@@ -377,6 +456,7 @@ impl Durable {
377456
self.queue_name.clone(),
378457
self.registry.clone(),
379458
options,
459+
self.state.clone(),
380460
)
381461
.await
382462
}

src/context.rs

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use serde::{Serialize, de::DeserializeOwned};
33
use serde_json::Value as JsonValue;
44
use sqlx::PgPool;
55
use std::collections::HashMap;
6+
use std::marker::PhantomData;
67
use std::time::Duration;
78
use uuid::Uuid;
89

@@ -27,12 +28,20 @@ use crate::worker::LeaseExtender;
2728
/// - **Lease management** via [`heartbeat`](Self::heartbeat) - Extend the task lease
2829
/// for long-running operations
2930
///
31+
/// # Type Parameter
32+
///
33+
/// * `State` - The application state type. This allows [`spawn`](Self::spawn) to
34+
/// automatically infer the correct state type for child tasks.
35+
///
3036
/// # Public Fields
3137
///
3238
/// - `task_id` - Unique identifier for this task (use as idempotency key)
3339
/// - `run_id` - Identifier for the current execution attempt
3440
/// - `attempt` - Current attempt number (starts at 1)
35-
pub struct TaskContext {
41+
pub struct TaskContext<State = ()>
42+
where
43+
State: Clone + Send + Sync + 'static,
44+
{
3645
/// Unique identifier for this task. Use this as an idempotency key for
3746
/// external API calls to achieve "exactly-once" semantics.
3847
pub task_id: Uuid,
@@ -57,6 +66,9 @@ pub struct TaskContext {
5766

5867
/// Notifies the worker when the lease is extended via step() or heartbeat().
5968
lease_extender: LeaseExtender,
69+
70+
/// Phantom data to carry the State type parameter.
71+
_state: PhantomData<State>,
6072
}
6173

6274
/// Validate that a user-provided step name doesn't use reserved prefix.
@@ -69,7 +81,10 @@ fn validate_user_name(name: &str) -> TaskResult<()> {
6981
Ok(())
7082
}
7183

72-
impl TaskContext {
84+
impl<State> TaskContext<State>
85+
where
86+
State: Clone + Send + Sync + 'static,
87+
{
7388
/// Create a new TaskContext. Called by the worker before executing a task.
7489
/// Loads all existing checkpoints into the cache.
7590
pub(crate) async fn create(
@@ -105,6 +120,7 @@ impl TaskContext {
105120
checkpoint_cache: cache,
106121
step_counters: HashMap::new(),
107122
lease_extender,
123+
_state: PhantomData,
108124
})
109125
}
110126

@@ -431,12 +447,15 @@ impl TaskContext {
431447
/// let r1: ItemResult = ctx.join("item-1", h1).await?;
432448
/// let r2: ItemResult = ctx.join("item-2", h2).await?;
433449
/// ```
434-
pub async fn spawn<T: Task>(
450+
pub async fn spawn<T>(
435451
&mut self,
436452
name: &str,
437453
params: T::Params,
438454
options: crate::SpawnOptions,
439-
) -> TaskResult<TaskHandle<T::Output>> {
455+
) -> TaskResult<TaskHandle<T::Output>>
456+
where
457+
T: Task<State>,
458+
{
440459
validate_user_name(name)?;
441460
let checkpoint_name = self.get_checkpoint_name(&format!("$spawn:{name}"));
442461

0 commit comments

Comments
 (0)