diff --git a/src/context.rs b/src/context.rs index effb3e5..cd4bb40 100644 --- a/src/context.rs +++ b/src/context.rs @@ -7,6 +7,9 @@ use uuid::Uuid; use crate::Durable; use crate::error::{ControlFlow, TaskError, TaskResult}; +use std::sync::Arc; + +use crate::heartbeat::{HeartbeatHandle, Heartbeater, StepState}; use crate::task::Task; use crate::types::DurableEventPayload; use crate::types::{ @@ -63,6 +66,9 @@ where /// Notifies the worker when the lease is extended via step() or heartbeat(). lease_extender: LeaseExtender, + + /// Cloneable heartbeat handle for use in step closures. + heartbeat_handle: HeartbeatHandle, } /// Validate that a user-provided step name doesn't use reserved prefix. @@ -103,6 +109,14 @@ where cache.insert(row.checkpoint_name, row.state); } + let heartbeat_handle = HeartbeatHandle::new( + durable.pool().clone(), + durable.queue_name().to_string(), + task.run_id, + claim_timeout, + lease_extender.clone(), + ); + Ok(Self { task_id: task.task_id, run_id: task.run_id, @@ -113,6 +127,7 @@ where checkpoint_cache: cache, step_counters: HashMap::new(), lease_extender, + heartbeat_handle, }) } @@ -152,9 +167,9 @@ where /// # Example /// /// ```ignore - /// let payment_id = ctx.step("charge-payment", ctx.task_id, |task_id, _state| async { + /// let payment_id = ctx.step("charge-payment", ctx.task_id, |task_id, step_state| async { /// let idempotency_key = format!("{}:charge", task_id); - /// stripe::charge(amount, &idempotency_key).await + /// stripe::charge(amount, &idempotency_key, &step_state.state).await /// }).await?; /// ``` #[cfg_attr( @@ -169,7 +184,7 @@ where &mut self, base_name: &str, params: P, - f: fn(P, State) -> Fut, + f: fn(P, StepState) -> Fut, ) -> TaskResult where P: Serialize, @@ -193,13 +208,14 @@ where span.record("cached", false); // Execute the step - let result = - f(params, self.durable.state().clone()) - .await - .map_err(|e| TaskError::Step { - base_name: base_name.to_string(), - error: e, - })?; + let step_state = StepState { + state: self.durable.state().clone(), + heartbeater: Arc::new(self.heartbeat_handle.clone()), + }; + let result = f(params, step_state).await.map_err(|e| TaskError::Step { + base_name: base_name.to_string(), + error: e, + })?; // Persist checkpoint (also extends claim lease) #[cfg(feature = "telemetry")] @@ -461,6 +477,14 @@ where }) } + /// Get a cloneable heartbeat handle for use in step closures or `SimpleTool`s. + /// + /// The returned [`HeartbeatHandle`] can be passed into contexts that need to + /// extend the task lease without access to the full `TaskContext`. + pub fn heartbeat_handle(&self) -> HeartbeatHandle { + self.heartbeat_handle.clone() + } + /// Extend the task's lease to prevent timeout. /// /// Use this for long-running operations that don't naturally checkpoint. @@ -482,27 +506,7 @@ where ) )] pub async fn heartbeat(&self, duration: Option) -> TaskResult<()> { - let extend_by = duration.unwrap_or(self.claim_timeout); - - if extend_by < std::time::Duration::from_secs(1) { - return Err(TaskError::Validation { - message: "heartbeat duration must be at least 1 second".to_string(), - }); - } - - let query = "SELECT durable.extend_claim($1, $2, $3)"; - sqlx::query(query) - .bind(self.durable.queue_name()) - .bind(self.run_id) - .bind(extend_by.as_secs() as i32) - .execute(self.durable.pool()) - .await - .map_err(TaskError::from_sqlx_error)?; - - // Notify worker that lease was extended so it can reset timers - self.lease_extender.notify(extend_by); - - Ok(()) + self.heartbeat_handle.heartbeat(duration).await } /// Generate a durable random value in [0, 1). diff --git a/src/heartbeat.rs b/src/heartbeat.rs new file mode 100644 index 0000000..b0543c5 --- /dev/null +++ b/src/heartbeat.rs @@ -0,0 +1,133 @@ +use std::sync::Arc; +use std::time::Duration; + +use async_trait::async_trait; +use uuid::Uuid; + +use crate::error::{TaskError, TaskResult}; +use crate::worker::LeaseExtender; + +/// Trait for extending task leases during long-running operations. +/// +/// Implementations allow code running inside a `step()` closure to send heartbeats +/// back to the durable framework, preventing the worker from considering the task +/// dead during long-running operations. +/// +/// Two implementations are provided: +/// - [`HeartbeatHandle`] — extends leases via the database (used in durable workers) +/// - [`NoopHeartbeater`] — does nothing (used in tests and non-durable contexts) +#[async_trait] +pub trait Heartbeater: Send + Sync { + /// Extend the task's lease. + /// + /// # Arguments + /// * `duration` - Extension duration. If `None`, uses the original claim timeout. + /// Must be at least 1 second when `Some`. + async fn heartbeat(&self, duration: Option) -> TaskResult<()>; +} + +/// Real heartbeat handle that extends leases via the database. +/// +/// Created from a [`TaskContext`](crate::TaskContext) via +/// [`heartbeat_handle()`](crate::TaskContext::heartbeat_handle) and can be +/// passed into step closures or other contexts that need to extend the task lease. +#[derive(Clone)] +pub struct HeartbeatHandle { + pool: sqlx::PgPool, + queue_name: String, + run_id: Uuid, + claim_timeout: Duration, + lease_extender: LeaseExtender, +} + +impl HeartbeatHandle { + pub(crate) fn new( + pool: sqlx::PgPool, + queue_name: String, + run_id: Uuid, + claim_timeout: Duration, + lease_extender: LeaseExtender, + ) -> Self { + Self { + pool, + queue_name, + run_id, + claim_timeout, + lease_extender, + } + } +} + +#[async_trait] +impl Heartbeater for HeartbeatHandle { + async fn heartbeat(&self, duration: Option) -> TaskResult<()> { + let extend_by = duration.unwrap_or(self.claim_timeout); + + if extend_by < Duration::from_secs(1) { + return Err(TaskError::Validation { + message: "heartbeat duration must be at least 1 second".to_string(), + }); + } + + let query = "SELECT durable.extend_claim($1, $2, $3)"; + sqlx::query(query) + .bind(&self.queue_name) + .bind(self.run_id) + .bind(extend_by.as_secs() as i32) + .execute(&self.pool) + .await + .map_err(TaskError::from_sqlx_error)?; + + // Notify worker that lease was extended so it can reset timers + self.lease_extender.notify(extend_by); + + Ok(()) + } +} + +/// No-op heartbeater for testing and non-durable contexts. +/// +/// All heartbeat calls succeed immediately without any side effects. +#[derive(Clone, Default)] +pub struct NoopHeartbeater; + +#[async_trait] +impl Heartbeater for NoopHeartbeater { + async fn heartbeat(&self, _duration: Option) -> TaskResult<()> { + Ok(()) + } +} + +/// State provided to `step()` closures, wrapping the user's application state +/// alongside a heartbeater for extending the task lease. +/// +/// This is passed as the second argument to every `step()` closure, making +/// heartbeating available without the consumer needing to thread it manually. +/// +/// # Example +/// +/// ```ignore +/// ctx.step("long-operation", params, |params, step_state| async move { +/// for item in ¶ms.items { +/// process(item, &step_state.state).await?; +/// // Extend lease during long-running work +/// let _ = step_state.heartbeater.heartbeat(None).await; +/// } +/// Ok(result) +/// }).await?; +/// ``` +/// +/// For testing step closures in isolation, construct with [`NoopHeartbeater`]: +/// +/// ```ignore +/// let step_state = StepState { +/// state: my_test_state, +/// heartbeater: Arc::new(NoopHeartbeater), +/// }; +/// ``` +pub struct StepState { + /// The user's application state. + pub state: State, + /// Handle for extending the task lease during long-running operations. + pub heartbeater: Arc, +} diff --git a/src/lib.rs b/src/lib.rs index fed9b2e..fca9476 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -98,6 +98,7 @@ mod client; mod context; mod cron; mod error; +mod heartbeat; mod task; #[cfg(feature = "telemetry")] pub mod telemetry; @@ -109,6 +110,7 @@ pub use client::{Durable, DurableBuilder}; pub use context::TaskContext; pub use cron::{ScheduleFilter, ScheduleInfo, ScheduleOptions, setup_pgcron}; pub use error::{ControlFlow, DurableError, DurableResult, TaskError, TaskResult}; +pub use heartbeat::{HeartbeatHandle, Heartbeater, NoopHeartbeater, StepState}; pub use task::{ErasedTask, Task, TaskWrapper}; pub use types::{ CancellationPolicy, ClaimedTask, DurableEventPayload, RetryStrategy, SpawnDefaults, diff --git a/tests/execution_test.rs b/tests/execution_test.rs index da4beb7..f0d0884 100644 --- a/tests/execution_test.rs +++ b/tests/execution_test.rs @@ -745,13 +745,13 @@ impl durable::Task for WriteToDbTask { ) -> durable::TaskResult { // Use the app state's db pool to write to a table let row_id: i64 = ctx - .step("insert", params, |params, state| async move { + .step("insert", params, |params, step_state| async move { let (id,): (i64,) = sqlx::query_as( "INSERT INTO test_state_table (key, value) VALUES ($1, $2) RETURNING id", ) .bind(¶ms.key) .bind(¶ms.value) - .fetch_one(&state.db_pool) + .fetch_one(&step_state.state.db_pool) .await .map_err(|e| anyhow::anyhow!("DB error: {}", e))?; Ok(id)