Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 35 additions & 31 deletions src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ use uuid::Uuid;

use crate::Durable;
use crate::error::{ControlFlow, TaskError, TaskResult};
use std::sync::Arc;

use crate::heartbeat::{HeartbeatHandle, Heartbeater, StepState};
use crate::task::Task;
use crate::types::DurableEventPayload;
use crate::types::{
Expand Down Expand Up @@ -63,6 +66,9 @@ where

/// Notifies the worker when the lease is extended via step() or heartbeat().
lease_extender: LeaseExtender,

/// Cloneable heartbeat handle for use in step closures.
heartbeat_handle: HeartbeatHandle,
}

/// Validate that a user-provided step name doesn't use reserved prefix.
Expand Down Expand Up @@ -103,6 +109,14 @@ where
cache.insert(row.checkpoint_name, row.state);
}

let heartbeat_handle = HeartbeatHandle::new(
durable.pool().clone(),
durable.queue_name().to_string(),
task.run_id,
claim_timeout,
lease_extender.clone(),
);

Ok(Self {
task_id: task.task_id,
run_id: task.run_id,
Expand All @@ -113,6 +127,7 @@ where
checkpoint_cache: cache,
step_counters: HashMap::new(),
lease_extender,
heartbeat_handle,
})
}

Expand Down Expand Up @@ -152,9 +167,9 @@ where
/// # Example
///
/// ```ignore
/// let payment_id = ctx.step("charge-payment", ctx.task_id, |task_id, _state| async {
/// let payment_id = ctx.step("charge-payment", ctx.task_id, |task_id, step_state| async {
/// let idempotency_key = format!("{}:charge", task_id);
/// stripe::charge(amount, &idempotency_key).await
/// stripe::charge(amount, &idempotency_key, &step_state.state).await
/// }).await?;
/// ```
#[cfg_attr(
Expand All @@ -169,7 +184,7 @@ where
&mut self,
base_name: &str,
params: P,
f: fn(P, State) -> Fut,
f: fn(P, StepState<State>) -> Fut,
) -> TaskResult<T>
where
P: Serialize,
Expand All @@ -193,13 +208,14 @@ where
span.record("cached", false);

// Execute the step
let result =
f(params, self.durable.state().clone())
.await
.map_err(|e| TaskError::Step {
base_name: base_name.to_string(),
error: e,
})?;
let step_state = StepState {
state: self.durable.state().clone(),
heartbeater: Arc::new(self.heartbeat_handle.clone()),
};
let result = f(params, step_state).await.map_err(|e| TaskError::Step {
base_name: base_name.to_string(),
error: e,
})?;

// Persist checkpoint (also extends claim lease)
#[cfg(feature = "telemetry")]
Expand Down Expand Up @@ -461,6 +477,14 @@ where
})
}

/// Get a cloneable heartbeat handle for use in step closures or `SimpleTool`s.
///
/// The returned [`HeartbeatHandle`] can be passed into contexts that need to
/// extend the task lease without access to the full `TaskContext`.
pub fn heartbeat_handle(&self) -> HeartbeatHandle {
self.heartbeat_handle.clone()
}

/// Extend the task's lease to prevent timeout.
///
/// Use this for long-running operations that don't naturally checkpoint.
Expand All @@ -482,27 +506,7 @@ where
)
)]
pub async fn heartbeat(&self, duration: Option<std::time::Duration>) -> TaskResult<()> {
let extend_by = duration.unwrap_or(self.claim_timeout);

if extend_by < std::time::Duration::from_secs(1) {
return Err(TaskError::Validation {
message: "heartbeat duration must be at least 1 second".to_string(),
});
}

let query = "SELECT durable.extend_claim($1, $2, $3)";
sqlx::query(query)
.bind(self.durable.queue_name())
.bind(self.run_id)
.bind(extend_by.as_secs() as i32)
.execute(self.durable.pool())
.await
.map_err(TaskError::from_sqlx_error)?;

// Notify worker that lease was extended so it can reset timers
self.lease_extender.notify(extend_by);

Ok(())
self.heartbeat_handle.heartbeat(duration).await
}

/// Generate a durable random value in [0, 1).
Expand Down
133 changes: 133 additions & 0 deletions src/heartbeat.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
use std::sync::Arc;
use std::time::Duration;

use async_trait::async_trait;
use uuid::Uuid;

use crate::error::{TaskError, TaskResult};
use crate::worker::LeaseExtender;

/// Abstraction over "keep this task's lease alive".
///
/// Code running inside a `step()` closure can call [`Heartbeater::heartbeat`]
/// to signal the durable framework that work is still in progress, so the
/// worker does not treat the task as dead during a long-running operation.
///
/// This module ships two implementations:
/// - [`HeartbeatHandle`] — extends the lease through the database (used in
///   durable workers).
/// - [`NoopHeartbeater`] — succeeds without doing anything (used in tests and
///   non-durable contexts).
#[async_trait]
pub trait Heartbeater: Send + Sync {
    /// Extend the task's lease.
    ///
    /// # Arguments
    ///
    /// * `duration` - How long to extend the lease by. `None` means "use the
    ///   original claim timeout". When `Some`, it must be at least 1 second.
    ///
    /// # Errors
    ///
    /// Implementation-defined; [`HeartbeatHandle`] reports validation and
    /// database failures, [`NoopHeartbeater`] never fails.
    async fn heartbeat(&self, duration: Option<Duration>) -> TaskResult<()>;
}

/// Database-backed [`Heartbeater`] for a single task run.
///
/// Obtained from a [`TaskContext`](crate::TaskContext) through
/// [`heartbeat_handle()`](crate::TaskContext::heartbeat_handle). The handle is
/// `Clone`, so it can be handed to step closures or any other code that needs
/// to extend the task lease.
#[derive(Clone)]
pub struct HeartbeatHandle {
    /// Connection pool the `durable.extend_claim` query is executed on.
    pool: sqlx::PgPool,
    /// Queue name passed as the first argument to `durable.extend_claim`.
    queue_name: String,
    /// Run identifier passed as the second argument to `durable.extend_claim`.
    run_id: Uuid,
    /// Extension used when `heartbeat` is called with `None`.
    claim_timeout: Duration,
    /// Notified after each successful extension so the worker can reset timers.
    lease_extender: LeaseExtender,
}

impl HeartbeatHandle {
pub(crate) fn new(
pool: sqlx::PgPool,
queue_name: String,
run_id: Uuid,
claim_timeout: Duration,
lease_extender: LeaseExtender,
) -> Self {
Self {
pool,
queue_name,
run_id,
claim_timeout,
lease_extender,
}
}
}

#[async_trait]
impl Heartbeater for HeartbeatHandle {
    /// Extend this run's claim in the database, then notify the worker.
    ///
    /// `duration` falls back to the handle's original claim timeout when it is
    /// `None`.
    ///
    /// # Errors
    ///
    /// - [`TaskError::Validation`] if the effective duration is shorter than one
    ///   second, or if its whole-second count does not fit in an `i32`.
    /// - Whatever [`TaskError::from_sqlx_error`] produces if the
    ///   `durable.extend_claim` query fails. The worker is not notified in
    ///   that case.
    async fn heartbeat(&self, duration: Option<Duration>) -> TaskResult<()> {
        let extend_by = duration.unwrap_or(self.claim_timeout);

        if extend_by < Duration::from_secs(1) {
            return Err(TaskError::Validation {
                message: "heartbeat duration must be at least 1 second".to_string(),
            });
        }

        // The interval is bound as an `i32`. `Duration::as_secs` returns a `u64`,
        // and a plain `as i32` cast would silently wrap for very large durations
        // (possibly to a negative number), so convert with an explicit check.
        // NOTE(review): only whole seconds reach the database, while the worker
        // below is notified with the untruncated `extend_by` — confirm whether
        // the sub-second difference matters.
        let extend_secs =
            i32::try_from(extend_by.as_secs()).map_err(|_| TaskError::Validation {
                message: "heartbeat duration is too large (seconds must fit in an i32)"
                    .to_string(),
            })?;

        let query = "SELECT durable.extend_claim($1, $2, $3)";
        sqlx::query(query)
            .bind(&self.queue_name)
            .bind(self.run_id)
            .bind(extend_secs)
            .execute(&self.pool)
            .await
            .map_err(TaskError::from_sqlx_error)?;

        // Notify worker that lease was extended so it can reset timers
        self.lease_extender.notify(extend_by);

        Ok(())
    }
}

/// [`Heartbeater`] that does nothing, for tests and non-durable contexts.
///
/// Every call to `heartbeat` returns `Ok(())` immediately and has no side
/// effects.
#[derive(Clone, Default)]
pub struct NoopHeartbeater;

#[async_trait]
impl Heartbeater for NoopHeartbeater {
    /// Accept any requested extension and report success without acting on it.
    async fn heartbeat(&self, _extend_by: Option<Duration>) -> TaskResult<()> {
        Ok(())
    }
}

/// Bundle handed to every `step()` closure as its second argument.
///
/// It pairs the user's application state with a [`Heartbeater`], so a step can
/// extend the task lease without the caller having to thread a handle through
/// manually.
///
/// # Example
///
/// ```ignore
/// ctx.step("long-operation", params, |params, step_state| async move {
///     for item in &params.items {
///         process(item, &step_state.state).await?;
///         // Extend lease during long-running work
///         let _ = step_state.heartbeater.heartbeat(None).await;
///     }
///     Ok(result)
/// }).await?;
/// ```
///
/// To exercise a step closure in isolation, build the value by hand with a
/// [`NoopHeartbeater`]:
///
/// ```ignore
/// let step_state = StepState {
///     state: my_test_state,
///     heartbeater: Arc::new(NoopHeartbeater),
/// };
/// ```
pub struct StepState<State> {
    /// The user's application state.
    pub state: State,
    /// Shared heartbeater used to extend the task lease during long-running
    /// operations.
    pub heartbeater: Arc<dyn Heartbeater>,
}
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ mod client;
mod context;
mod cron;
mod error;
mod heartbeat;
mod task;
#[cfg(feature = "telemetry")]
pub mod telemetry;
Expand All @@ -109,6 +110,7 @@ pub use client::{Durable, DurableBuilder};
pub use context::TaskContext;
pub use cron::{ScheduleFilter, ScheduleInfo, ScheduleOptions, setup_pgcron};
pub use error::{ControlFlow, DurableError, DurableResult, TaskError, TaskResult};
pub use heartbeat::{HeartbeatHandle, Heartbeater, NoopHeartbeater, StepState};
pub use task::{ErasedTask, Task, TaskWrapper};
pub use types::{
CancellationPolicy, ClaimedTask, DurableEventPayload, RetryStrategy, SpawnDefaults,
Expand Down
4 changes: 2 additions & 2 deletions tests/execution_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -745,13 +745,13 @@ impl durable::Task<AppState> for WriteToDbTask {
) -> durable::TaskResult<Self::Output> {
// Use the app state's db pool to write to a table
let row_id: i64 = ctx
.step("insert", params, |params, state| async move {
.step("insert", params, |params, step_state| async move {
let (id,): (i64,) = sqlx::query_as(
"INSERT INTO test_state_table (key, value) VALUES ($1, $2) RETURNING id",
)
.bind(&params.key)
.bind(&params.value)
.fetch_one(&state.db_pool)
.fetch_one(&step_state.state.db_pool)
.await
.map_err(|e| anyhow::anyhow!("DB error: {}", e))?;
Ok(id)
Expand Down