Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
5556199
initial rust initialization
virajmehta Dec 1, 2025
c1279f4
added pre-commit
virajmehta Dec 2, 2025
b470f98
set up sqlx and rust
virajmehta Dec 2, 2025
27879a0
added license files and initial migration;
virajmehta Dec 2, 2025
c04415c
added initial impl of client
virajmehta Dec 2, 2025
4464259
tests compile
virajmehta Dec 2, 2025
3372b8a
fixed bugs with sqlx types
virajmehta Dec 2, 2025
78b1c1f
added documentation
virajmehta Dec 3, 2025
510d09c
added tests
virajmehta Dec 3, 2025
e5dff40
updated tasks
virajmehta Dec 3, 2025
2c771d8
added a test that mocks the example in README
virajmehta Dec 3, 2025
3d623e7
removed todos
virajmehta Dec 3, 2025
ae77c98
added convenience methods for uuid, rand, now
virajmehta Dec 6, 2025
9ea4009
Merge pull request #9 from tensorzero/viraj/add-convenience-functions
virajmehta Dec 6, 2025
b26b98a
added handling for spawning and joining subtasks from workflows
virajmehta Dec 7, 2025
f6ab2bb
added missing sql and test files
virajmehta Dec 7, 2025
e37aac7
merged migrations
virajmehta Dec 7, 2025
55a11d2
cleaned up bad code
virajmehta Dec 7, 2025
57f809c
Merge branch 'viraj/draft-client' of github.com:tensorzero/durable in…
virajmehta Dec 7, 2025
2e18b25
cleaned up json handling
virajmehta Dec 7, 2025
6147df1
made process exit optional on too-long tasks
virajmehta Dec 7, 2025
fc2c0ab
fixed semaphore ordering
virajmehta Dec 7, 2025
b8d4fd5
Merge branch 'viraj/draft-client' of github.com:tensorzero/durable in…
virajmehta Dec 7, 2025
9c43681
Merge pull request #10 from tensorzero/viraj/fanout
virajmehta Dec 7, 2025
d2076ea
fixed issues with clock skew
virajmehta Dec 7, 2025
08adcf8
Merge branch 'viraj/draft-client' of github.com:tensorzero/durable in…
virajmehta Dec 7, 2025
9d11dc9
improved handling of leases
virajmehta Dec 7, 2025
c5d5a5a
added comments on sql schema
virajmehta Dec 8, 2025
435db57
enforced that claim timeouts must be set
virajmehta Dec 8, 2025
e7d51ca
cleaned up and documented sql
virajmehta Dec 8, 2025
1741a18
added support for transactions that enqueue tasks
virajmehta Dec 8, 2025
693c527
added a bunch of tests
virajmehta Dec 9, 2025
1c563cb
documented and tested event semantics
virajmehta Dec 9, 2025
3c1af99
added benchmarks
virajmehta Dec 9, 2025
ab05d1b
initial implementation of telemetry
virajmehta Dec 9, 2025
23ba9e4
Merge branch 'main' of github.com:tensorzero/durable into viraj/obser…
virajmehta Dec 9, 2025
9bc14de
removed exporter setup from crate
virajmehta Dec 9, 2025
f71f1a5
telemetry tests pass
virajmehta Dec 10, 2025
c817921
removed extra license file
virajmehta Dec 11, 2025
42d8f9a
Merge branch 'main' of github.com:tensorzero/durable into viraj/obser…
virajmehta Dec 11, 2025
c8c7572
inject otel context as a string for key durable::otel_context
virajmehta Dec 11, 2025
4d8e2b2
protect durable:: headers for internal use
virajmehta Dec 11, 2025
51ed409
Merge branch 'main' of github.com:tensorzero/durable into viraj/obser…
virajmehta Dec 11, 2025
1eece0f
addressed PR comment
virajmehta Dec 12, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
339 changes: 318 additions & 21 deletions Cargo.lock

Large diffs are not rendered by default.

19 changes: 19 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,27 @@ hostname = "0.4"
rand = "0.9"
futures = "0.3.31"

# Optional telemetry dependencies
tracing-opentelemetry = { version = "0.28", optional = true }
opentelemetry = { version = "0.27", optional = true }
opentelemetry_sdk = { version = "0.27", optional = true }
metrics = { version = "0.24", optional = true }

[features]
default = []
telemetry = [
"dep:tracing-opentelemetry",
"dep:opentelemetry",
"dep:opentelemetry_sdk",
"dep:metrics",
]

[dev-dependencies]
criterion = { version = "0.5", features = ["async_tokio", "html_reports"] }
tracing-fluent-assertions = "0.3"
metrics-util = { version = "0.18", features = ["debugging"] }
tracing-subscriber = { version = "0.3", features = ["registry"] }
ordered-float = "4"

[[bench]]
name = "throughput"
Expand Down
57 changes: 56 additions & 1 deletion src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,21 @@ impl CancellationPolicyDb {

use crate::worker::Worker;

/// Validates that user-provided headers don't use reserved prefixes.
fn validate_headers(headers: &Option<HashMap<String, JsonValue>>) -> anyhow::Result<()> {
if let Some(headers) = headers {
for key in headers.keys() {
if key.starts_with("durable::") {
anyhow::bail!(
"Header key '{}' uses reserved prefix 'durable::'. User headers cannot start with 'durable::'.",
key
);
}
}
}
Ok(())
}

/// The main client for interacting with durable workflows.
///
/// Use this client to:
Expand Down Expand Up @@ -355,6 +370,14 @@ where
/// Spawn a task by name using a custom executor.
///
/// The task must be registered before spawning.
#[cfg_attr(
feature = "telemetry",
tracing::instrument(
name = "durable.client.spawn",
skip(self, executor, params, options),
fields(queue, task_name = %task_name)
)
)]
pub async fn spawn_by_name_with<'e, E>(
&self,
executor: E,
Expand All @@ -380,16 +403,30 @@ where
}

/// Internal spawn implementation without registry validation.
#[allow(unused_mut)] // mut is needed when telemetry feature is enabled
async fn spawn_by_name_internal<'e, E>(
&self,
executor: E,
task_name: &str,
params: JsonValue,
options: SpawnOptions,
mut options: SpawnOptions,
) -> anyhow::Result<SpawnResult>
where
E: Executor<'e, Database = Postgres>,
{
// Validate user headers don't use reserved prefix
validate_headers(&options.headers)?;

// Inject trace context into headers for distributed tracing
#[cfg(feature = "telemetry")]
{
let headers = options.headers.get_or_insert_with(HashMap::new);
crate::telemetry::inject_trace_context(headers);
}

#[cfg(feature = "telemetry")]
tracing::Span::current().record("queue", &self.queue_name);

let max_attempts = options.max_attempts.unwrap_or(self.default_max_attempts);

let db_options = Self::serialize_spawn_options(&options, max_attempts)?;
Expand All @@ -405,6 +442,9 @@ where
.fetch_one(executor)
.await?;

#[cfg(feature = "telemetry")]
crate::telemetry::record_task_spawned(&self.queue_name, task_name);

Ok(SpawnResult {
task_id: row.task_id,
run_id: row.run_id,
Expand Down Expand Up @@ -452,6 +492,14 @@ where
}

/// Emit an event to a queue (defaults to this client's queue)
#[cfg_attr(
feature = "telemetry",
tracing::instrument(
name = "durable.client.emit_event",
skip(self, payload),
fields(queue, event_name = %event_name)
)
)]
pub async fn emit_event<T: Serialize>(
&self,
event_name: &str,
Expand All @@ -461,6 +509,10 @@ where
anyhow::ensure!(!event_name.is_empty(), "event_name must be non-empty");

let queue = queue_name.unwrap_or(&self.queue_name);

#[cfg(feature = "telemetry")]
tracing::Span::current().record("queue", queue);

let payload_json = serde_json::to_value(payload)?;

let query = "SELECT durable.emit_event($1, $2, $3)";
Expand All @@ -471,6 +523,9 @@ where
.execute(&self.pool)
.await?;

#[cfg(feature = "telemetry")]
crate::telemetry::record_event_emitted(queue, event_name);

Ok(())
}

Expand Down
85 changes: 83 additions & 2 deletions src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,14 @@ where
/// stripe::charge(amount, &idempotency_key).await
/// }).await?;
/// ```
#[cfg_attr(
feature = "telemetry",
tracing::instrument(
name = "durable.task.step",
skip(self, f),
fields(task_id = %self.task_id, step_name = %name)
)
)]
pub async fn step<T, F, Fut>(&mut self, name: &str, f: F) -> TaskResult<T>
where
T: Serialize + DeserializeOwned + Send,
Expand All @@ -174,8 +182,21 @@ where
let result = f().await?;

// Persist checkpoint (also extends claim lease)
#[cfg(feature = "telemetry")]
let checkpoint_start = std::time::Instant::now();

self.persist_checkpoint(&checkpoint_name, &result).await?;

#[cfg(feature = "telemetry")]
{
let duration = checkpoint_start.elapsed().as_secs_f64();
crate::telemetry::record_checkpoint_duration(
&self.queue_name,
&self.task.task_name,
duration,
);
}

Ok(result)
}

Expand Down Expand Up @@ -225,6 +246,14 @@ where
///
/// Wake time is computed using the database clock to ensure consistency
/// with the scheduler and enable deterministic testing via `durable.fake_now`.
#[cfg_attr(
feature = "telemetry",
tracing::instrument(
name = "durable.task.sleep_for",
skip(self),
fields(task_id = %self.task_id, duration_ms = duration.as_millis() as u64)
)
)]
pub async fn sleep_for(&mut self, name: &str, duration: std::time::Duration) -> TaskResult<()> {
validate_user_name(name)?;
let checkpoint_name = self.get_checkpoint_name(name);
Expand Down Expand Up @@ -268,6 +297,14 @@ where
/// Some(Duration::from_secs(7 * 24 * 3600)),
/// ).await?;
/// ```
#[cfg_attr(
feature = "telemetry",
tracing::instrument(
name = "durable.task.await_event",
skip(self, timeout),
fields(task_id = %self.task_id, event_name = %event_name)
)
)]
pub async fn await_event<T: DeserializeOwned>(
&mut self,
event_name: &str,
Expand Down Expand Up @@ -323,6 +360,14 @@ where
/// updates the payload (last write wins). Tasks waiting for this event
/// are woken with the payload at the time of the write that woke them;
/// subsequent writes do not propagate to already-woken tasks.
#[cfg_attr(
feature = "telemetry",
tracing::instrument(
name = "durable.task.emit_event",
skip(self, payload),
fields(task_id = %self.task_id, event_name = %event_name)
)
)]
pub async fn emit_event<T: Serialize>(&self, event_name: &str, payload: &T) -> TaskResult<()> {
if event_name.is_empty() {
return Err(TaskError::Failed(anyhow::anyhow!(
Expand Down Expand Up @@ -352,6 +397,14 @@ where
///
/// # Errors
/// Returns `TaskError::Control(Cancelled)` if the task was cancelled.
#[cfg_attr(
feature = "telemetry",
tracing::instrument(
name = "durable.task.heartbeat",
skip(self),
fields(task_id = %self.task_id)
)
)]
pub async fn heartbeat(&self, duration: Option<std::time::Duration>) -> TaskResult<()> {
let extend_by = duration
.map(|d| d.as_secs() as i32)
Expand Down Expand Up @@ -454,6 +507,14 @@ where
/// let r1: ItemResult = ctx.join("item-1", h1).await?;
/// let r2: ItemResult = ctx.join("item-2", h2).await?;
/// ```
#[cfg_attr(
feature = "telemetry",
tracing::instrument(
name = "durable.task.spawn",
skip(self, params, options),
fields(task_id = %self.task_id, subtask_name = T::NAME)
)
)]
pub async fn spawn<T>(
&mut self,
name: &str,
Expand Down Expand Up @@ -515,6 +576,19 @@ where
options: SpawnOptions,
) -> TaskResult<TaskHandle<T>> {
validate_user_name(name)?;

// Validate headers don't use reserved prefix
if let Some(ref headers) = options.headers {
for key in headers.keys() {
if key.starts_with("durable::") {
return Err(TaskError::Failed(anyhow::anyhow!(
"Header key '{}' uses reserved prefix 'durable::'. User headers cannot start with 'durable::'.",
key
)));
}
}
}

let checkpoint_name = self.get_checkpoint_name(&format!("$spawn:{name}"));

// Return cached task_id if already spawned
Expand Down Expand Up @@ -574,8 +648,7 @@ where
///
/// # Arguments
///
/// * `name` - Unique name for this join operation (used for checkpointing).
/// Uniqueness is constrained just within this task, not globally or for child tasks.
/// * `name` - Unique name for this join operation (used for checkpointing)
/// * `handle` - The [`TaskHandle`] returned by [`spawn`](Self::spawn)
///
/// # Errors
Expand All @@ -590,6 +663,14 @@ where
/// // ... do other work ...
/// let result: ComputeResult = ctx.join("compute", handle).await?;
/// ```
#[cfg_attr(
feature = "telemetry",
tracing::instrument(
name = "durable.task.join",
skip(self, handle),
fields(task_id = %self.task_id, child_task_id = %handle.task_id)
)
)]
pub async fn join<T: DeserializeOwned>(
&mut self,
name: &str,
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ mod client;
mod context;
mod error;
mod task;
#[cfg(feature = "telemetry")]
pub mod telemetry;
mod types;
mod worker;

Expand Down
Loading