From b6a982ce3802e256f9ae33c112e42d3f2cb48dad Mon Sep 17 00:00:00 2001 From: Brandur Date: Wed, 22 Apr 2026 21:40:57 -0500 Subject: [PATCH] Implement resumable jobs Here, implement "resumable" jobs, which are jobs that can checkpoint their progress so that in case they have to stop early, they're picked up from a point that lets them skip work that's already been done. This is especially useful for long-running jobs that are at risk of being interrupted by something like a deploy. Here's roughly the shape of the API, with the same normal `Work` function that all jobs implement, and with a series of `ResumableStep` calls within, each of which takes a name for the step and a function representing it: func (w *ResumableWorker) Work(ctx context.Context, job *river.Job[ResumableArgs]) error { river.ResumableStep(ctx, "step1", func(ctx context.Context) error { fmt.Println("Step 1") return nil }) river.ResumableStep(ctx, "step2", func(ctx context.Context) error { fmt.Println("Step 2") return nil }) river.ResumableStep(ctx, "step3", func(ctx context.Context) error { fmt.Println("Step 3") return nil }) return nil } We also provide a cursor API for more granularity. This lets a step set an arbitrary cursor value periodically as it's doing something like looping over records in a set: river.ResumableStepCursor(ctx, "process_ids", func(ctx context.Context, cursor ResumableCursor) error { for _, id := range job.Args.IDs { if id <= cursor.LastProcessedID { continue } fmt.Printf("Processed %d\n", id) if err := river.ResumableSetCursor(ctx, ResumableCursor{LastProcessedID: id}); err != nil { return err } } return nil }) The function is `ResumableStepCursor[TCursor any]` where `TCursor` can be defined arbitrarily by the user. This could be a simple scalar value representing an ID, or a more complex `struct` value containing multiple IDs, enabling nested loops that set inner and outer IDs at the same time. 
`ResumableStep` and `ResumableStepCursor` steps can be freely intermingled, and multiple `ResumableStepCursor` steps with different cursor types are supported. Cursors must be JSON-marshalable because they're stored in a job's metadata. Lastly, we provide `ResumableSetStepTx` and `ResumableSetStepCursorTx` for cases where a transaction guarantee is necessary. Normally, the resumable step and cursor are persisted as a job is being completed, but there's a chance this never happens in case of sudden failure. `ResumableSetStepTx` (and its cursor version) is available to durably persist a step at the cost of an extra database operation, similar to how `JobCompleteTx` does the same for job completion. One neat aspect of the implementation here is that I was able to make it entirely middleware-only. So all the resumable job logic goes in an internal `ResumableMiddleware` (package `rivermiddleware`) that's included in all clients by default. This is kind of nice because it keeps its code highly modular and will hopefully act as a template for future features. 
--- CHANGELOG.md | 4 + client.go | 4 +- client_test.go | 114 +++++++++- example_resumable_cursor_job_test.go | 105 +++++++++ example_resumable_job_test.go | 98 ++++++++ example_resumable_set_step_tx_test.go | 126 ++++++++++ internal/maintenance/job_scheduler_test.go | 3 +- internal/rivercommon/river_common.go | 8 + internal/rivermiddleware/middleware.go | 84 +++++++ job_list_params.go | 5 +- resumable.go | 156 +++++++++++++ resumable_step_tx.go | 122 ++++++++++ resumable_step_tx_test.go | 169 ++++++++++++++ resumable_test.go | 253 +++++++++++++++++++++ rivertest/resumable.go | 70 ++++++ rivertest/resumable_test.go | 201 ++++++++++++++++ rivertest/worker.go | 3 +- 17 files changed, 1516 insertions(+), 9 deletions(-) create mode 100644 example_resumable_cursor_job_test.go create mode 100644 example_resumable_job_test.go create mode 100644 example_resumable_set_step_tx_test.go create mode 100644 internal/rivermiddleware/middleware.go create mode 100644 resumable.go create mode 100644 resumable_step_tx.go create mode 100644 resumable_step_tx_test.go create mode 100644 resumable_test.go create mode 100644 rivertest/resumable.go create mode 100644 rivertest/resumable_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index e8ca7545..1932d545 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +- Added "resumable jobs" that can be broken down into multiple steps, with each step's completion persisted so that later attempts can skip work that's already been done. This is particularly useful for long-running jobs that may experience a cancellation (like in the event of a deploy) during the span of their run. [PR #1226](https://github.com/riverqueue/river/pull/1226). 
+ ## [0.36.0] - 2026-05-09 ### Added diff --git a/client.go b/client.go index d6877f92..d7d46d53 100644 --- a/client.go +++ b/client.go @@ -24,6 +24,7 @@ import ( "github.com/riverqueue/river/internal/notifier" "github.com/riverqueue/river/internal/notifylimiter" "github.com/riverqueue/river/internal/rivercommon" + "github.com/riverqueue/river/internal/rivermiddleware" "github.com/riverqueue/river/internal/workunit" "github.com/riverqueue/river/riverdriver" "github.com/riverqueue/river/rivershared/baseservice" @@ -782,7 +783,8 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client // the more abstract config.Middleware for middleware are set, but not both, // so in practice we never append all three of these to each other. { - middleware := config.Middleware + middleware := rivermiddleware.DefaultMiddleware() + middleware = append(middleware, config.Middleware...) for _, jobInsertMiddleware := range config.JobInsertMiddleware { middleware = append(middleware, jobInsertMiddleware) } diff --git a/client_test.go b/client_test.go index 6148b584..9c1613c8 100644 --- a/client_test.go +++ b/client_test.go @@ -82,6 +82,62 @@ func (w *periodicJobWorker) Work(ctx context.Context, job *Job[periodicJobArgs]) return nil } +type resumableClientTestArgs struct{} + +func (resumableClientTestArgs) Kind() string { return "resumable_client_test" } + +type resumableClientTestWorker struct { + WorkerDefaults[resumableClientTestArgs] + + calls []string + callsMu sync.Mutex + failedOnce atomic.Bool +} + +func (w *resumableClientTestWorker) Calls() []string { + w.callsMu.Lock() + defer w.callsMu.Unlock() + + return append([]string(nil), w.calls...) 
+} + +func (w *resumableClientTestWorker) Work(ctx context.Context, job *Job[resumableClientTestArgs]) error { + appendCall := func(call string) { + w.callsMu.Lock() + defer w.callsMu.Unlock() + + w.calls = append(w.calls, call) + } + + ResumableStep(ctx, "step1", func(ctx context.Context) error { + appendCall("step1") + return nil + }) + + ResumableStepCursor(ctx, "step2", func(ctx context.Context, cursor int) error { + appendCall("step2:" + strconv.Itoa(cursor)) + + for itemID := cursor + 1; itemID <= 2; itemID++ { + appendCall("item:" + strconv.Itoa(itemID)) + if err := ResumableSetCursor(ctx, itemID); err != nil { + return err + } + if !w.failedOnce.Swap(true) { + return errors.New("retry me") + } + } + + return nil + }) + + ResumableStep(ctx, "step3", func(ctx context.Context) error { + appendCall("step3") + return nil + }) + + return nil +} + func makeAwaitWorker[T JobArgs](startedCh chan<- int64, doneCh chan struct{}) Worker[T] { return WorkFunc(func(ctx context.Context, job *Job[T]) error { client := ClientFromContext[pgx.Tx](ctx) @@ -7308,6 +7364,58 @@ func Test_Client_JobCompletion(t *testing.T) { require.Nil(t, reloadedJob.FinalizedAt) }) + t.Run("ResumableJobRetriesAndResumes", func(t *testing.T) { + t.Parallel() + + config := newTestConfig(t, "") + config.RetryPolicy = &retrypolicytest.RetryPolicyNoJitter{} + + worker := &resumableClientTestWorker{} + AddWorker(config.Workers, worker) + + client, bundle := setup(t, config) + + insertRes, err := client.Insert(ctx, resumableClientTestArgs{}, nil) + require.NoError(t, err) + + // Wait for the first attempt to fail after step2 checkpoints cursor + // progress and intentionally returns "retry me", leaving the job queued + // for retry. 
+ eventFailed := riversharedtest.WaitOrTimeout(t, bundle.subscribeChan) + require.Equal(t, EventKindJobFailed, eventFailed.Kind) + require.Equal(t, insertRes.Job.ID, eventFailed.Job.ID) + + var retryableMetadata map[string]any + require.Contains(t, []rivertype.JobState{rivertype.JobStateAvailable, rivertype.JobStateRetryable}, eventFailed.Job.State) + require.NoError(t, json.Unmarshal(eventFailed.Job.Metadata, &retryableMetadata)) + require.Equal(t, "step1", retryableMetadata["river:resumable_step"]) + require.Equal(t, map[string]any{"step2": float64(1)}, retryableMetadata["river:resumable_cursor"]) + + // Wait for the retried attempt to resume and then complete successfully. + eventCompleted := riversharedtest.WaitOrTimeout(t, bundle.subscribeChan) + require.Equal(t, EventKindJobCompleted, eventCompleted.Kind) + require.Equal(t, insertRes.Job.ID, eventCompleted.Job.ID) + + reloadedJob, err := client.JobGet(ctx, insertRes.Job.ID) + require.NoError(t, err) + require.Equal(t, rivertype.JobStateCompleted, reloadedJob.State) + require.Len(t, reloadedJob.Errors, 1) + + var metadata map[string]any + require.NoError(t, json.Unmarshal(reloadedJob.Metadata, &metadata)) + require.Equal(t, "step1", metadata["river:resumable_step"]) + require.Equal(t, map[string]any{"step2": float64(1)}, metadata["river:resumable_cursor"]) + + require.Equal(t, []string{ + "step1", + "step2:0", + "item:1", + "step2:1", + "item:2", + "step3", + }, worker.Calls()) + }) + t.Run("JobThatReturnsJobCancelErrorIsImmediatelyCancelled", func(t *testing.T) { t.Parallel() @@ -7974,7 +8082,7 @@ func Test_NewClient_Validations(t *testing.T) { }, validateResult: func(t *testing.T, client *Client[pgx.Tx]) { //nolint:thelper require.Len(t, client.middlewareLookupGlobal.ByMiddlewareKind(middlewarelookup.MiddlewareKindJobInsert), 1) - require.Len(t, client.middlewareLookupGlobal.ByMiddlewareKind(middlewarelookup.MiddlewareKindWorker), 1) + require.Len(t, 
client.middlewareLookupGlobal.ByMiddlewareKind(middlewarelookup.MiddlewareKindWorker), 2) }, }, { @@ -7985,7 +8093,7 @@ func Test_NewClient_Validations(t *testing.T) { }, validateResult: func(t *testing.T, client *Client[pgx.Tx]) { //nolint:thelper require.Len(t, client.middlewareLookupGlobal.ByMiddlewareKind(middlewarelookup.MiddlewareKindJobInsert), 2) - require.Len(t, client.middlewareLookupGlobal.ByMiddlewareKind(middlewarelookup.MiddlewareKindWorker), 2) + require.Len(t, client.middlewareLookupGlobal.ByMiddlewareKind(middlewarelookup.MiddlewareKindWorker), 3) }, }, { @@ -7997,7 +8105,7 @@ func Test_NewClient_Validations(t *testing.T) { }, validateResult: func(t *testing.T, client *Client[pgx.Tx]) { //nolint:thelper require.Len(t, client.middlewareLookupGlobal.ByMiddlewareKind(middlewarelookup.MiddlewareKindJobInsert), 1) - require.Len(t, client.middlewareLookupGlobal.ByMiddlewareKind(middlewarelookup.MiddlewareKindWorker), 1) + require.Len(t, client.middlewareLookupGlobal.ByMiddlewareKind(middlewarelookup.MiddlewareKindWorker), 2) }, }, { diff --git a/example_resumable_cursor_job_test.go b/example_resumable_cursor_job_test.go new file mode 100644 index 00000000..b03dac7c --- /dev/null +++ b/example_resumable_cursor_job_test.go @@ -0,0 +1,105 @@ +package river_test + +import ( + "context" + "fmt" + "log/slog" + "os" + + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/riverqueue/river" + "github.com/riverqueue/river/riverdbtest" + "github.com/riverqueue/river/riverdriver/riverpgxv5" + "github.com/riverqueue/river/rivershared/riversharedtest" + "github.com/riverqueue/river/rivershared/util/slogutil" + "github.com/riverqueue/river/rivershared/util/testutil" +) + +type ResumableCursorArgs struct { + IDs []int `json:"ids"` +} + +func (ResumableCursorArgs) Kind() string { return "resumable_cursor" } + +type ResumableCursor struct { + LastProcessedID int `json:"last_processed_id"` +} + +type ResumableCursorWorker struct { + river.WorkerDefaults[ResumableCursorArgs] 
+} + +func (w *ResumableCursorWorker) Work(ctx context.Context, job *river.Job[ResumableCursorArgs]) error { + river.ResumableStepCursor(ctx, "process_ids", func(ctx context.Context, cursor ResumableCursor) error { + for _, id := range job.Args.IDs { + if id <= cursor.LastProcessedID { + continue + } + + fmt.Printf("Processed %d\n", id) + + if err := river.ResumableSetCursor(ctx, ResumableCursor{LastProcessedID: id}); err != nil { + return err + } + } + + return nil + }) + + return nil +} + +// Example_resumableCursor demonstrates the use of a resumable cursor step, a +// step that can store arbitrary JSON state to resume a partially completed loop. +func Example_resumableCursor() { + ctx := context.Background() + + dbPool, err := pgxpool.New(ctx, riversharedtest.TestDatabaseURL()) + if err != nil { + panic(err) + } + defer dbPool.Close() + + workers := river.NewWorkers() + river.AddWorker(workers, &ResumableCursorWorker{}) + + riverClient, err := river.NewClient(riverpgxv5.New(dbPool), &river.Config{ + Logger: slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelWarn, ReplaceAttr: slogutil.NoLevelTime})), + Queues: map[string]river.QueueConfig{ + river.QueueDefault: {MaxWorkers: 100}, + }, + Schema: riverdbtest.TestSchema(ctx, testutil.PanicTB(), riverpgxv5.New(dbPool), nil), // only necessary for the example test + TestOnly: true, // suitable only for use in tests; remove for live environments + Workers: workers, + }) + if err != nil { + panic(err) + } + + // Out of example scope, but used to wait until a job is worked. + subscribeChan, subscribeCancel := riverClient.Subscribe(river.EventKindJobCompleted) + defer subscribeCancel() + + if err := riverClient.Start(ctx); err != nil { + panic(err) + } + + if _, err = riverClient.Insert(ctx, ResumableCursorArgs{ + IDs: []int{1, 2, 3}, + }, nil); err != nil { + panic(err) + } + + // Wait for jobs to complete. Only needed for purposes of the example test. 
+ riversharedtest.WaitOrTimeoutN(testutil.PanicTB(), subscribeChan, 1) + + if err := riverClient.Stop(ctx); err != nil { + panic(err) + } + + // Output: + // Processed 1 + // Processed 2 + // Processed 3 +} diff --git a/example_resumable_job_test.go b/example_resumable_job_test.go new file mode 100644 index 00000000..73b41fb5 --- /dev/null +++ b/example_resumable_job_test.go @@ -0,0 +1,98 @@ +package river_test + +import ( + "context" + "fmt" + "log/slog" + "os" + + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/riverqueue/river" + "github.com/riverqueue/river/riverdbtest" + "github.com/riverqueue/river/riverdriver/riverpgxv5" + "github.com/riverqueue/river/rivershared/riversharedtest" + "github.com/riverqueue/river/rivershared/util/slogutil" + "github.com/riverqueue/river/rivershared/util/testutil" +) + +type ResumableArgs struct{} + +func (ResumableArgs) Kind() string { return "resumable" } + +type ResumableWorker struct { + river.WorkerDefaults[ResumableArgs] +} + +func (w *ResumableWorker) Work(ctx context.Context, job *river.Job[ResumableArgs]) error { + river.ResumableStep(ctx, "step1", func(ctx context.Context) error { + fmt.Println("Step 1") + return nil + }) + + // If River was forced to stop between step1 and step2 or midway into step2, + // the next run of this job skips step1 and picks up from here. + river.ResumableStep(ctx, "step2", func(ctx context.Context) error { + fmt.Println("Step 2") + return nil + }) + + river.ResumableStep(ctx, "step3", func(ctx context.Context) error { + fmt.Println("Step 3") + return nil + }) + + return nil +} + +// Example_resumable demonstrates the use of a "resumable job", a job that has +// multiple steps, and which can be resumed after each one. 
+func Example_resumable() { + ctx := context.Background() + + dbPool, err := pgxpool.New(ctx, riversharedtest.TestDatabaseURL()) + if err != nil { + panic(err) + } + defer dbPool.Close() + + workers := river.NewWorkers() + river.AddWorker(workers, &ResumableWorker{}) + + riverClient, err := river.NewClient(riverpgxv5.New(dbPool), &river.Config{ + Logger: slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelWarn, ReplaceAttr: slogutil.NoLevelTime})), + Queues: map[string]river.QueueConfig{ + river.QueueDefault: {MaxWorkers: 100}, + }, + Schema: riverdbtest.TestSchema(ctx, testutil.PanicTB(), riverpgxv5.New(dbPool), nil), // only necessary for the example test + TestOnly: true, // suitable only for use in tests; remove for live environments + Workers: workers, + }) + if err != nil { + panic(err) + } + + // Out of example scope, but used to wait until a job is worked. + subscribeChan, subscribeCancel := riverClient.Subscribe(river.EventKindJobCompleted) + defer subscribeCancel() + + if err := riverClient.Start(ctx); err != nil { + panic(err) + } + + if _, err = riverClient.Insert(ctx, ResumableArgs{}, nil); err != nil { + panic(err) + } + + // Wait for jobs to complete. Only needed for purposes of the example test. 
+ riversharedtest.WaitOrTimeoutN(testutil.PanicTB(), subscribeChan, 1) + + if err := riverClient.Stop(ctx); err != nil { + panic(err) + } + + // Output: + // Step 1 + // Step 2 + // Step 3 +} diff --git a/example_resumable_set_step_tx_test.go b/example_resumable_set_step_tx_test.go new file mode 100644 index 00000000..767d996d --- /dev/null +++ b/example_resumable_set_step_tx_test.go @@ -0,0 +1,126 @@ +package river_test + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "os" + + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/riverqueue/river" + "github.com/riverqueue/river/internal/rivercommon" + "github.com/riverqueue/river/riverdbtest" + "github.com/riverqueue/river/riverdriver/riverpgxv5" + "github.com/riverqueue/river/rivershared/riversharedtest" + "github.com/riverqueue/river/rivershared/util/slogutil" + "github.com/riverqueue/river/rivershared/util/testutil" +) + +type ResumableStepTxArgs struct{} + +func (ResumableStepTxArgs) Kind() string { return "resumable_step_tx" } + +// ResumableStepTxWorker persists resumable step progress transactionally before +// failing the job. +type ResumableStepTxWorker struct { + river.WorkerDefaults[ResumableStepTxArgs] + + dbPool *pgxpool.Pool +} + +func (w *ResumableStepTxWorker) Work(ctx context.Context, job *river.Job[ResumableStepTxArgs]) error { + const durableStep = "durable_step" + + river.ResumableStep(ctx, durableStep, func(ctx context.Context) error { + tx, err := w.dbPool.Begin(ctx) + if err != nil { + return err + } + defer tx.Rollback(ctx) + + // Perform some kind database work in a transaction. + var result int + if err := tx.QueryRow(ctx, "SELECT 1").Scan(&result); err != nil { + return err + } + + // Then, record the step as completed in the same transaction. 
+ if _, err := river.ResumableSetStepTx[*riverpgxv5.Driver](ctx, tx, job); err != nil { + return err + } + + if err := tx.Commit(ctx); err != nil { + return err + } + + return errors.New("simulated failure after persisting step") + }) + + return nil +} + +// Example_resumableSetStepTx demonstrates how to transactionally persist a +// resumable step so it survives a failed attempt. +func Example_resumableSetStepTx() { + ctx := context.Background() + + dbPool, err := pgxpool.New(ctx, riversharedtest.TestDatabaseURL()) + if err != nil { + panic(err) + } + defer dbPool.Close() + + workers := river.NewWorkers() + river.AddWorker(workers, &ResumableStepTxWorker{dbPool: dbPool}) + + riverClient, err := river.NewClient(riverpgxv5.New(dbPool), &river.Config{ + Logger: slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelWarn, ReplaceAttr: slogutil.NoLevelTime})), + Queues: map[string]river.QueueConfig{ + river.QueueDefault: {MaxWorkers: 100}, + }, + Schema: riverdbtest.TestSchema(ctx, testutil.PanicTB(), riverpgxv5.New(dbPool), nil), // only necessary for the example test + TestOnly: true, // suitable only for use in tests; remove for live environments + Workers: workers, + }) + if err != nil { + panic(err) + } + + // Used only to help the example test wait for the failed attempt. + subscribeChan, subscribeCancel := riverClient.Subscribe(river.EventKindJobFailed) + defer subscribeCancel() + + if err := riverClient.Start(ctx); err != nil { + panic(err) + } + + insertRes, err := riverClient.Insert(ctx, ResumableStepTxArgs{}, nil) + if err != nil { + panic(err) + } + + // Wait for the failed attempt so the persisted step can be inspected. 
+ riversharedtest.WaitOrTimeoutN(testutil.PanicTB(), subscribeChan, 1) + + jobAfter, err := riverClient.JobGet(ctx, insertRes.Job.ID) + if err != nil { + panic(err) + } + + var metadata map[string]any + if err := json.Unmarshal(jobAfter.Metadata, &metadata); err != nil { + panic(err) + } + + fmt.Printf("Persisted resumable step: %s\n", metadata[rivercommon.MetadataKeyResumableStep]) + + if err := riverClient.Stop(ctx); err != nil { + panic(err) + } + + // Output: + // Persisted resumable step: durable_step +} diff --git a/internal/maintenance/job_scheduler_test.go b/internal/maintenance/job_scheduler_test.go index 25388042..ecd9009e 100644 --- a/internal/maintenance/job_scheduler_test.go +++ b/internal/maintenance/job_scheduler_test.go @@ -330,7 +330,8 @@ func TestJobScheduler(t *testing.T) { addJob := func(queue string, fromNow time.Duration, state rivertype.JobState) { t.Helper() var finalizedAt *time.Time - switch state { //nolint:exhaustive + switch state { + case rivertype.JobStateAvailable, rivertype.JobStatePending, rivertype.JobStateRetryable, rivertype.JobStateRunning, rivertype.JobStateScheduled: case rivertype.JobStateCompleted, rivertype.JobStateCancelled, rivertype.JobStateDiscarded: finalizedAt = ptrutil.Ptr(now.Add(fromNow)) } diff --git a/internal/rivercommon/river_common.go b/internal/rivercommon/river_common.go index d409ccf1..986583ed 100644 --- a/internal/rivercommon/river_common.go +++ b/internal/rivercommon/river_common.go @@ -24,6 +24,14 @@ const ( // them. MetadataKeyPeriodicJobID = "river:periodic_job_id" + // MetadataKeyResumableStep records the last successfully completed step for + // a resumable job so later attempts can skip ahead. + MetadataKeyResumableStep = "river:resumable_step" + + // MetadataKeyResumableCursor records a resumable step cursor so a later + // attempt can resume a partially completed step. 
+ MetadataKeyResumableCursor = "river:resumable_cursor" + // MetadataKeyRescueCount records how many times the job has been rescued. MetadataKeyRescueCount = "river:rescue_count" diff --git a/internal/rivermiddleware/middleware.go b/internal/rivermiddleware/middleware.go new file mode 100644 index 00000000..01f4a517 --- /dev/null +++ b/internal/rivermiddleware/middleware.go @@ -0,0 +1,84 @@ +package rivermiddleware + +import ( + "context" + "encoding/json" + "errors" + "fmt" + + "github.com/tidwall/gjson" + + "github.com/riverqueue/river/internal/jobexecutor" + "github.com/riverqueue/river/internal/rivercommon" + "github.com/riverqueue/river/rivertype" +) + +// DefaultMiddleware returns the default middleware that River applies to all +// jobs. This includes internal middleware like the resumable step middleware. +func DefaultMiddleware() []rivertype.Middleware { + return []rivertype.Middleware{&ResumableMiddleware{}} +} + +// ResumableMiddleware is internal middleware that enables resumable step +// functionality. It reads the last completed step and cursor data from job +// metadata, injects them into the context, and persists updated step/cursor +// state back to metadata when a job errors after making progress. 
+type ResumableMiddleware struct{} + +func (*ResumableMiddleware) IsMiddleware() bool { return true } + +func (*ResumableMiddleware) Work(ctx context.Context, job *rivertype.JobRow, doInner func(ctx context.Context) error) error { + metadataUpdates, hasMetadataUpdates := jobexecutor.MetadataUpdatesFromWorkContext(ctx) + if !hasMetadataUpdates { + return errors.New("expected to find metadata updates in context, but didn't") + } + + state := &ResumableState{ + Cursors: make(map[string]json.RawMessage), + ResumeMatched: true, + ResumeStep: gjson.GetBytes(job.Metadata, rivercommon.MetadataKeyResumableStep).Str, + } + if state.ResumeStep != "" { + state.ResumeMatched = false + } + if cursorJSON := gjson.GetBytes(job.Metadata, rivercommon.MetadataKeyResumableCursor); cursorJSON.Exists() && cursorJSON.Type == gjson.JSON { + if err := json.Unmarshal([]byte(cursorJSON.Raw), &state.Cursors); err != nil { + return fmt.Errorf("river: unmarshal resumable cursors: %w", err) + } + } + + ctx = context.WithValue(ctx, ResumableContextKey{}, state) + + err := doInner(ctx) + if err == nil { + switch { + case state.Err != nil: + err = state.Err + case state.ResumeStep != "" && !state.ResumeMatched: + err = fmt.Errorf("river: resumable step %q not found in Worker", state.ResumeStep) + } + } + + if err != nil && state.CompletedStep != "" { + if len(state.Cursors) > 0 { + metadataUpdates[rivercommon.MetadataKeyResumableCursor] = state.Cursors + } + metadataUpdates[rivercommon.MetadataKeyResumableStep] = state.CompletedStep + } + + return err +} + +// ResumableState holds the state for a resumable job execution. It is stored in +// the context and accessed by ResumableStep and ResumableStepCursor. +type ResumableState struct { + CompletedStep string + Cursors map[string]json.RawMessage + Err error + ResumeMatched bool + ResumeStep string + StepName string +} + +// ResumableContextKey is the context key for ResumableState. 
+type ResumableContextKey struct{} diff --git a/job_list_params.go b/job_list_params.go index 116a5c76..cebdf57b 100644 --- a/job_list_params.go +++ b/job_list_params.go @@ -234,11 +234,10 @@ func (p *JobListParams) toDBParams() (*dblist.JobListParams, error) { if p.sortField == JobListOrderByFinalizedAt { currentNonFinalizedStates := make([]rivertype.JobState, 0, len(p.states)) for _, state := range p.states { - //nolint:exhaustive switch state { - case rivertype.JobStateCancelled, rivertype.JobStateCompleted, rivertype.JobStateDiscarded: - default: + case rivertype.JobStateAvailable, rivertype.JobStatePending, rivertype.JobStateRetryable, rivertype.JobStateRunning, rivertype.JobStateScheduled: currentNonFinalizedStates = append(currentNonFinalizedStates, state) + case rivertype.JobStateCancelled, rivertype.JobStateCompleted, rivertype.JobStateDiscarded: } } // This indicates the user overrode the States list with only non-finalized diff --git a/resumable.go b/resumable.go new file mode 100644 index 00000000..a305eaa1 --- /dev/null +++ b/resumable.go @@ -0,0 +1,156 @@ +package river + +import ( + "context" + "encoding/json" + "errors" + "fmt" + + "github.com/riverqueue/river/internal/rivermiddleware" +) + +var ( + errResumableStepNotInWorker = errors.New("river: resumable step can only be used within a Worker") + errResumableCursorNotInStep = errors.New("river: resumable cursor can only be used within ResumableStepCursor") +) + +// ResumableSetCursor records a cursor for the current resumable cursor step. +// The cursor is stored only if the job attempt ends in an error, allowing a +// later retry to resume the same step from the recorded position. +// +// Alternatively, ResumableSetStepCursorTx is available to persist a step and +// cursor immediately as part of a transaction, guaranteeing that it's stored +// durably. 
+func ResumableSetCursor[TCursor any](ctx context.Context, cursor TCursor) error { + state := mustResumableState(ctx) + if state.StepName == "" { + return errResumableCursorNotInStep + } + + cursorBytes, err := json.Marshal(cursor) + if err != nil { + return err + } + + if state.Cursors == nil { + state.Cursors = make(map[string]json.RawMessage) + } + state.Cursors[state.StepName] = json.RawMessage(cursorBytes) + return nil +} + +// ResumableStep runs a resumable step, skipping the step on a later retry if +// an earlier attempt already completed it successfully. +// +// After a step returns an error, no subsequent steps will be run and the +// overall job will be marked as failed with that error. Be careful to put all +// executable code in steps, because any code outside of them will be run, even +// if a step returned an error. +func ResumableStep(ctx context.Context, name string, stepFunc func(ctx context.Context) error) { + state := mustResumableState(ctx) + if state.Err != nil { + return + } + + if !state.ResumeMatched { + if name == state.ResumeStep { + state.CompletedStep = name + state.ResumeMatched = true + } + return + } + + previousStepName := state.StepName + state.StepName = name + defer func() { state.StepName = previousStepName }() + + if err := stepFunc(ctx); err != nil { + state.Err = err + return + } + + state.CompletedStep = name +} + +// ResumableStepCursor runs a resumable step that also receives a persisted +// cursor value from an earlier failed attempt, if one was recorded with +// ResumableSetCursor. +// +// The cursor type T is user-specified. It may be a primitive value like an +// integer ID, or a more complex type like a struct with multiple fields. It's +// stored in a job's metadata, so it needs to be marshable and unmarshable to +// and from JSON. 
+// +// Notably, it's the responsibility of the step function to call +// ResumableSetCursor with an updated cursor value as progress is made, and to +// check the cursor value before running to determine where to resume from. +// +// After a step returns an error, no subsequent steps will be run and the +// overall job will be marked as failed with that error. Be careful to put all +// executable code in steps, because any code outside of them will be run, even +// if a step returned an error. +func ResumableStepCursor[TCursor any](ctx context.Context, name string, stepFunc func(ctx context.Context, cursor TCursor) error) { + state := mustResumableState(ctx) + if state.Err != nil { + return + } + + if !state.ResumeMatched { + if name == state.ResumeStep { + state.CompletedStep = name + state.ResumeMatched = true + + // If cursor data exists for this step, it was only partially + // completed on the previous attempt. Fall through to re-execute + // it with the cursor rather than skipping it. 
+ if _, hasCursor := state.Cursors[name]; !hasCursor { + return + } + } else { + return + } + } + + var cursor TCursor + if cursorBytes, ok := state.Cursors[name]; ok && len(cursorBytes) > 0 { + if err := json.Unmarshal(cursorBytes, &cursor); err != nil { + state.Err = fmt.Errorf("river: unmarshal resumable cursor for step %q: %w", name, err) + return + } + } + + previousStepName := state.StepName + state.StepName = name + defer func() { state.StepName = previousStepName }() + + if err := stepFunc(ctx, cursor); err != nil { + state.Err = err + return + } + + state.CompletedStep = name + delete(state.Cursors, name) +} + +func mustResumableState(ctx context.Context) *rivermiddleware.ResumableState { + state, ok := resumableStateFromContext(ctx) + if !ok { + panic(errResumableStepNotInWorker) + } + + return state +} + +func resumableStateFromContext(ctx context.Context) (*rivermiddleware.ResumableState, bool) { + state := ctx.Value(rivermiddleware.ResumableContextKey{}) + if state == nil { + return nil, false + } + + typedState, ok := state.(*rivermiddleware.ResumableState) + if !ok || typedState == nil { + return nil, false + } + + return typedState, true +} diff --git a/resumable_step_tx.go b/resumable_step_tx.go new file mode 100644 index 00000000..6d212733 --- /dev/null +++ b/resumable_step_tx.go @@ -0,0 +1,122 @@ +package river + +import ( + "context" + "encoding/json" + "errors" + + "github.com/riverqueue/river/internal/execution" + "github.com/riverqueue/river/internal/jobexecutor" + "github.com/riverqueue/river/internal/rivercommon" + "github.com/riverqueue/river/riverdriver" + "github.com/riverqueue/river/rivertype" +) + +// ResumableSetStepTx immediately persists the current resumable step as +// part of transaction tx. If tx is rolled back, the step update will be as +// well. +// +// Normally, a resumable job's step progress is recorded after it runs along +// with its result status. 
This is normally sufficient, but because it happens +// out-of-transaction, there's a chance that it doesn't happen in case of panic +// or other abrupt termination. This function is useful in cases where a +// resumable worker needs a guarantee of a checkpoint being recorded durably, at +// the cost of an extra database operation. +// +// Must be called from within a ResumableStep or ResumableStepCursor callback. +func ResumableSetStepTx[TDriver riverdriver.Driver[TTx], TTx any, TArgs JobArgs](ctx context.Context, tx TTx, job *Job[TArgs]) (*Job[TArgs], error) { + return resumableSetStepTx(ctx, tx, job, nil) +} + +// ResumableSetStepCursorTx immediately persists the current resumable step and +// cursor as part of transaction tx. If tx is rolled back, the step and cursor +// update will be as well. +// +// Normally, a resumable job's step progress is recorded after it runs along +// with its result status. This is normally sufficient, but because it happens +// out-of-transaction, there's a chance that it doesn't happen in case of panic +// or other abrupt termination. This function is useful in cases where a +// resumable worker needs a guarantee of a checkpoint being recorded durably, at +// the cost of an extra database operation. +// +// Must be called from within a ResumableStepCursor callback. 
+func ResumableSetStepCursorTx[TDriver riverdriver.Driver[TTx], TTx any, TArgs JobArgs, TCursor any](ctx context.Context, tx TTx, job *Job[TArgs], cursor TCursor) (*Job[TArgs], error) { + cursorBytes, err := json.Marshal(cursor) + if err != nil { + return nil, err + } + + return resumableSetStepTx(ctx, tx, job, json.RawMessage(cursorBytes)) +} + +func resumableSetStepTx[TTx any, TArgs JobArgs](ctx context.Context, tx TTx, job *Job[TArgs], cursor json.RawMessage) (*Job[TArgs], error) { + if job.State != rivertype.JobStateRunning { + return nil, errors.New("job must be running") + } + + state, ok := resumableStateFromContext(ctx) + if !ok { + return nil, errors.New("not inside a resumable step; must be called from within ResumableStep or ResumableStepCursor") + } + if state.StepName == "" { + return nil, errors.New("not inside a resumable step; must be called from within ResumableStep or ResumableStepCursor") + } + + step := state.StepName + + client := ClientFromContext[TTx](ctx) + if client == nil { + return nil, errors.New("client not found in context, can only work within a River worker") + } + + metadataUpdates := map[string]any{ + rivercommon.MetadataKeyResumableStep: step, + } + + state.CompletedStep = step + if cursor != nil { + if state.Cursors == nil { + state.Cursors = make(map[string]json.RawMessage) + } + state.Cursors[step] = cursor + } + if len(state.Cursors) > 0 { + metadataUpdates[rivercommon.MetadataKeyResumableCursor] = state.Cursors + } + + workMetadataUpdates, hasWorkMetadataUpdates := jobexecutor.MetadataUpdatesFromWorkContext(ctx) + if hasWorkMetadataUpdates { + workMetadataUpdates[rivercommon.MetadataKeyResumableStep] = step + if resumableCursorMetadata, ok := metadataUpdates[rivercommon.MetadataKeyResumableCursor]; ok { + workMetadataUpdates[rivercommon.MetadataKeyResumableCursor] = resumableCursorMetadata + } + } + + metadataUpdatesBytes, err := json.Marshal(metadataUpdates) + if err != nil { + return nil, err + } + + updatedJob, err := 
client.Driver().UnwrapExecutor(tx).JobUpdate(ctx, &riverdriver.JobUpdateParams{ + ID: job.ID, + MetadataDoMerge: true, + Metadata: metadataUpdatesBytes, + Schema: client.config.Schema, + }) + if err != nil { + if errors.Is(err, rivertype.ErrNotFound) { + if _, isInsideTestWorker := ctx.Value(execution.ContextKeyInsideTestWorker{}).(bool); isInsideTestWorker { + panic("to use ResumableSetStepTx or ResumableSetStepCursorTx in a rivertest.Worker, the job must be inserted into the database first") + } + } + + return nil, err + } + + result := &Job[TArgs]{JobRow: updatedJob} + if err := json.Unmarshal(result.EncodedArgs, &result.Args); err != nil { + return nil, err + } + + return result, nil +} diff --git a/resumable_step_tx_test.go b/resumable_step_tx_test.go new file mode 100644 index 00000000..3202759f --- /dev/null +++ b/resumable_step_tx_test.go @@ -0,0 +1,169 @@ +package river + +import ( + "context" + "encoding/json" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/stretchr/testify/require" + + "github.com/riverqueue/river/internal/execution" + "github.com/riverqueue/river/internal/jobexecutor" + "github.com/riverqueue/river/internal/rivercommon" + "github.com/riverqueue/river/internal/rivermiddleware" + "github.com/riverqueue/river/riverdbtest" + "github.com/riverqueue/river/riverdriver" + "github.com/riverqueue/river/riverdriver/riverpgxv5" + "github.com/riverqueue/river/rivershared/riversharedtest" + "github.com/riverqueue/river/rivershared/testfactory" + "github.com/riverqueue/river/rivershared/util/ptrutil" + "github.com/riverqueue/river/rivershared/util/testutil" + "github.com/riverqueue/river/rivertype" +) + +func TestResumableSetStepTx(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + type JobArgs struct { + testutil.JobArgsReflectKind[JobArgs] + } + + type testBundle struct { + client *Client[pgx.Tx] + exec riverdriver.Executor + tx pgx.Tx + } + + setup := func(ctx context.Context, t *testing.T, stepName string) (context.Context, 
*testBundle) { + t.Helper() + + tx := riverdbtest.TestTxPgx(ctx, t) + client, err := NewClient(riverpgxv5.New(nil), &Config{ + Logger: riversharedtest.Logger(t), + }) + require.NoError(t, err) + ctx = context.WithValue(ctx, rivercommon.ContextKeyClient{}, client) + ctx = context.WithValue(ctx, jobexecutor.ContextKeyMetadataUpdates, make(map[string]any)) + ctx = context.WithValue(ctx, rivermiddleware.ResumableContextKey{}, &rivermiddleware.ResumableState{ + Cursors: make(map[string]json.RawMessage), + StepName: stepName, + }) + + return ctx, &testBundle{ + client: client, + exec: riverpgxv5.New(nil).UnwrapExecutor(tx), + tx: tx, + } + } + + t.Run("SetsStep", func(t *testing.T) { + t.Parallel() + + ctx, bundle := setup(ctx, t, "step1") + + job := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{ + State: ptrutil.Ptr(rivertype.JobStateRunning), + }) + + updatedJob, err := ResumableSetStepTx[*riverpgxv5.Driver](ctx, bundle.tx, &Job[JobArgs]{JobRow: job}) + require.NoError(t, err) + require.Equal(t, rivertype.JobStateRunning, updatedJob.State) + + reloadedJob, err := bundle.exec.JobGetByID(ctx, &riverdriver.JobGetByIDParams{ID: job.ID}) + require.NoError(t, err) + + var metadata map[string]any + require.NoError(t, json.Unmarshal(reloadedJob.Metadata, &metadata)) + require.Equal(t, "step1", metadata[rivercommon.MetadataKeyResumableStep]) + }) + + t.Run("SetsStepAndCursor", func(t *testing.T) { + t.Parallel() + + ctx, bundle := setup(ctx, t, "step2") + + job := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{ + State: ptrutil.Ptr(rivertype.JobStateRunning), + }) + + type Cursor struct { + ID int `json:"id"` + } + + updatedJob, err := ResumableSetStepCursorTx[*riverpgxv5.Driver](ctx, bundle.tx, &Job[JobArgs]{JobRow: job}, Cursor{ID: 123}) + require.NoError(t, err) + require.Equal(t, rivertype.JobStateRunning, updatedJob.State) + + reloadedJob, err := bundle.exec.JobGetByID(ctx, &riverdriver.JobGetByIDParams{ID: job.ID}) + require.NoError(t, err) + + var 
metadata map[string]any + require.NoError(t, json.Unmarshal(reloadedJob.Metadata, &metadata)) + require.Equal(t, "step2", metadata[rivercommon.MetadataKeyResumableStep]) + require.Equal(t, map[string]any{"step2": map[string]any{"id": float64(123)}}, metadata[rivercommon.MetadataKeyResumableCursor]) + + metadataUpdates, ok := jobexecutor.MetadataUpdatesFromWorkContext(ctx) + require.True(t, ok) + require.Equal(t, "step2", metadataUpdates[rivercommon.MetadataKeyResumableStep]) + }) + + t.Run("ErrorIfNotRunning", func(t *testing.T) { + t.Parallel() + + ctx, bundle := setup(ctx, t, "step1") + + job := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{}) + + _, err := ResumableSetStepTx[*riverpgxv5.Driver](ctx, bundle.tx, &Job[JobArgs]{JobRow: job}) + require.EqualError(t, err, "job must be running") + }) + + t.Run("ErrorIfNotInStep", func(t *testing.T) { + t.Parallel() + + ctx, bundle := setup(ctx, t, "") // empty step name + + job := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{ + State: ptrutil.Ptr(rivertype.JobStateRunning), + }) + + _, err := ResumableSetStepTx[*riverpgxv5.Driver](ctx, bundle.tx, &Job[JobArgs]{JobRow: job}) + require.EqualError(t, err, "not inside a resumable step; must be called from within ResumableStep or ResumableStepCursor") + }) + + t.Run("ErrorIfJobDoesntExist", func(t *testing.T) { + t.Parallel() + + ctx, bundle := setup(ctx, t, "step1") + + job := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{ + State: ptrutil.Ptr(rivertype.JobStateAvailable), + }) + _, err := bundle.exec.JobDelete(ctx, &riverdriver.JobDeleteParams{ID: job.ID}) + require.NoError(t, err) + + job.State = rivertype.JobStateRunning + _, err = ResumableSetStepTx[*riverpgxv5.Driver](ctx, bundle.tx, &Job[JobArgs]{JobRow: job}) + require.ErrorIs(t, err, rivertype.ErrNotFound) + }) + + t.Run("PanicsIfCalledInTestWorkerWithoutInsertingJob", func(t *testing.T) { + t.Parallel() + + ctx, bundle := setup(ctx, t, "step1") + ctx = context.WithValue(ctx, 
// TestResumableStep exercises ResumableStep by running step sequences through
// rivermiddleware.ResumableMiddleware and checking which steps executed and
// what ended up in the metadata updates map.
func TestResumableStep(t *testing.T) {
	t.Parallel()

	// setup returns a context carrying a fresh metadata updates map, that same
	// map for assertions, and a job row whose metadata is the given JSON.
	setup := func(t *testing.T, metadata string) (context.Context, map[string]any, *rivertype.JobRow) {
		t.Helper()

		metadataUpdates := make(map[string]any)
		ctx := context.WithValue(context.Background(), jobexecutor.ContextKeyMetadataUpdates, metadataUpdates)

		return ctx, metadataUpdates, &rivertype.JobRow{Metadata: []byte(metadata)}
	}

	// Calling ResumableStep with a context that has no resumable state panics
	// with errResumableStepNotInWorker.
	t.Run("PanicsOutsideWorker", func(t *testing.T) {
		t.Parallel()

		require.PanicsWithError(t, errResumableStepNotInWorker.Error(), func() {
			ResumableStep(context.Background(), "step1", func(ctx context.Context) error { return nil })
		})
	})

	t.Run("ResumesFromLastCompletedStep", func(t *testing.T) {
		t.Parallel()

		ctx, metadataUpdates, job := setup(t, `{}`)

		// First attempt: step2 fails, so step3 must not run and step1 is
		// recorded as the last completed step.
		var ran []string
		err := (&rivermiddleware.ResumableMiddleware{}).Work(ctx, job, func(ctx context.Context) error {
			ResumableStep(ctx, "step1", func(ctx context.Context) error {
				ran = append(ran, "step1")
				return nil
			})
			ResumableStep(ctx, "step2", func(ctx context.Context) error {
				ran = append(ran, "step2")
				return errors.New("step2 failed")
			})
			ResumableStep(ctx, "step3", func(ctx context.Context) error {
				ran = append(ran, "step3")
				return nil
			})

			return nil
		})
		require.EqualError(t, err, "step2 failed")
		require.Equal(t, []string{"step1", "step2"}, ran)
		require.Equal(t, "step1", metadataUpdates[rivercommon.MetadataKeyResumableStep])

		// Second attempt: the job's metadata says step1 already completed, so
		// only step2 and step3 are expected to run.
		ctx, metadataUpdates, job = setup(t, `{"river:resumable_step":"step1"}`)
		ran = nil

		err = (&rivermiddleware.ResumableMiddleware{}).Work(ctx, job, func(ctx context.Context) error {
			ResumableStep(ctx, "step1", func(ctx context.Context) error {
				ran = append(ran, "step1")
				return nil
			})
			ResumableStep(ctx, "step2", func(ctx context.Context) error {
				ran = append(ran, "step2")
				return nil
			})
			ResumableStep(ctx, "step3", func(ctx context.Context) error {
				ran = append(ran, "step3")
				return nil
			})

			return nil
		})
		require.NoError(t, err)
		require.Equal(t, []string{"step2", "step3"}, ran)
		// A fully successful run leaves no metadata updates behind.
		require.Empty(t, metadataUpdates)
	})

	t.Run("SavesLastCompletedStepOnContextCancellation", func(t *testing.T) {
		t.Parallel()

		baseCtx, metadataUpdates, job := setup(t, `{}`)
		ctx, cancel := context.WithCancel(baseCtx)
		defer cancel()

		// step1 cancels the context and still returns nil; step2 then surfaces
		// the cancellation as its error, which stops step3 from running.
		var ran []string
		err := (&rivermiddleware.ResumableMiddleware{}).Work(ctx, job, func(ctx context.Context) error {
			ResumableStep(ctx, "step1", func(ctx context.Context) error {
				ran = append(ran, "step1")
				cancel()
				return nil
			})
			ResumableStep(ctx, "step2", func(ctx context.Context) error {
				ran = append(ran, "step2")
				return ctx.Err()
			})
			ResumableStep(ctx, "step3", func(ctx context.Context) error {
				ran = append(ran, "step3")
				return nil
			})

			return nil
		})
		require.ErrorIs(t, err, context.Canceled)
		require.Equal(t, []string{"step1", "step2"}, ran)
		// step1 finished before the cancellation was observed, so it is the
		// step that gets recorded.
		require.Equal(t, "step1", metadataUpdates[rivercommon.MetadataKeyResumableStep])
	})
}
{ + ResumableStep(ctx, "step1", func(ctx context.Context) error { + ran = append(ran, 1) + return nil + }) + ResumableStepCursor(ctx, "step2", func(ctx context.Context, cursor resumableCursor) error { + cursorResult = cursor + ran = append(ran, 2) + return nil + }) + ResumableStep(ctx, "step3", func(ctx context.Context) error { + ran = append(ran, 3) + return nil + }) + + return nil + }) + require.NoError(t, err) + require.Equal(t, resumableCursor{ID: 42}, cursorResult) + require.Equal(t, []int{2, 3}, ran) + require.Empty(t, metadataUpdates) + }) + + t.Run("SetCursorOutsideStep", func(t *testing.T) { + t.Parallel() + + ctx, _, _ := setup(t, `{}`) + + err := (&rivermiddleware.ResumableMiddleware{}).Work(ctx, &rivertype.JobRow{Metadata: []byte(`{}`)}, func(ctx context.Context) error { + return ResumableSetCursor(ctx, 1) + }) + require.ErrorIs(t, err, errResumableCursorNotInStep) + }) + + t.Run("SupportsMultipleCursorStepsWithDifferentTypes", func(t *testing.T) { + t.Parallel() + + type secondCursor struct { + ID string `json:"id"` + } + + ctx, metadataUpdates, job := setup(t, `{}`) + + err := (&rivermiddleware.ResumableMiddleware{}).Work(ctx, job, func(ctx context.Context) error { + ResumableStepCursor(ctx, "step1", func(ctx context.Context, cursor int) error { + require.Zero(t, cursor) + require.NoError(t, ResumableSetCursor(ctx, 123)) + return nil + }) + ResumableStepCursor(ctx, "step2", func(ctx context.Context, cursor secondCursor) error { + require.Equal(t, secondCursor{}, cursor) + require.NoError(t, ResumableSetCursor(ctx, secondCursor{ID: "abc"})) + return errors.New("step2 failed") + }) + + return nil + }) + require.EqualError(t, err, "step2 failed") + require.Equal(t, "step1", metadataUpdates[rivercommon.MetadataKeyResumableStep]) + cursorMetadata, err := json.Marshal(metadataUpdates[rivercommon.MetadataKeyResumableCursor]) + require.NoError(t, err) + require.JSONEq(t, `{"step2":{"id":"abc"}}`, string(cursorMetadata)) + + ctx, metadataUpdates, job = 
setup(t, `{"river:resumable_cursor":{"step1":123,"step2":{"id":"abc"}},"river:resumable_step":"step1"}`) + + err = (&rivermiddleware.ResumableMiddleware{}).Work(ctx, job, func(ctx context.Context) error { + ResumableStepCursor(ctx, "step1", func(ctx context.Context, cursor int) error { + require.Equal(t, 123, cursor) + return nil + }) + ResumableStepCursor(ctx, "step2", func(ctx context.Context, cursor secondCursor) error { + require.Equal(t, secondCursor{ID: "abc"}, cursor) + return nil + }) + + return nil + }) + require.NoError(t, err) + require.Empty(t, metadataUpdates) + }) +} diff --git a/rivertest/resumable.go b/rivertest/resumable.go new file mode 100644 index 00000000..220f000d --- /dev/null +++ b/rivertest/resumable.go @@ -0,0 +1,70 @@ +package rivertest + +import ( + "encoding/json" + "fmt" + + "github.com/riverqueue/river" + "github.com/riverqueue/river/internal/rivercommon" +) + +// ResumeAfterStep configures insert options so that the job resumes after the +// named step. When the job is worked, all steps up to and including the named +// step will be skipped, and execution will begin at the next step. +// +// Used with [Worker.Work], simulates a job that previously completed the named +// step, then failed on a subsequent step and is now being retried. +// +// result, err := testWorker.Work(ctx, t, tx, args, +// rivertest.ResumeAfterStep(&river.InsertOpts{}, "step1")) +func ResumeAfterStep(opts *river.InsertOpts, stepName string) *river.InsertOpts { + mergeResumableMetadata(opts, stepName, nil) + return opts +} + +// ResumeAtStepWithCursor configures insert options so that the job resumes at +// the named cursor step with the provided cursor data. The named step and all +// subsequent steps will execute; all steps before it are skipped. +// +// Used with [Worker.Work], simulates a job that was partway through a +// [river.ResumableStepCursor] step when it failed. On retry, the step runs +// again with the cursor so it can pick up where it left off. 
+// +// result, err := testWorker.Work(ctx, t, tx, args, +// rivertest.ResumeAtStepWithCursor(&river.InsertOpts{}, "process_ids", MyCursor{LastID: 42})) +func ResumeAtStepWithCursor[TCursor any](opts *river.InsertOpts, stepName string, cursor TCursor) *river.InsertOpts { + cursorBytes, err := json.Marshal(cursor) + if err != nil { + panic(fmt.Sprintf("rivertest: marshal resumable cursor: %s", err)) + } + + mergeResumableMetadata(opts, stepName, map[string]json.RawMessage{ + stepName: json.RawMessage(cursorBytes), + }) + return opts +} + +func mergeResumableMetadata(opts *river.InsertOpts, step string, cursors map[string]json.RawMessage) { + var existing map[string]any + if len(opts.Metadata) > 0 { + if err := json.Unmarshal(opts.Metadata, &existing); err != nil { + panic(fmt.Sprintf("rivertest: unmarshal existing metadata: %s", err)) + } + } else { + existing = make(map[string]any) + } + + if step != "" { + existing[rivercommon.MetadataKeyResumableStep] = step + } + if len(cursors) > 0 { + existing[rivercommon.MetadataKeyResumableCursor] = cursors + } + + metadataBytes, err := json.Marshal(existing) + if err != nil { + panic(fmt.Sprintf("rivertest: marshal resumable metadata: %s", err)) + } + + opts.Metadata = metadataBytes +} diff --git a/rivertest/resumable_test.go b/rivertest/resumable_test.go new file mode 100644 index 00000000..93fb8685 --- /dev/null +++ b/rivertest/resumable_test.go @@ -0,0 +1,201 @@ +package rivertest + +import ( + "context" + "encoding/json" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/stretchr/testify/require" + + "github.com/riverqueue/river" + "github.com/riverqueue/river/internal/rivercommon" + "github.com/riverqueue/river/riverdbtest" + "github.com/riverqueue/river/riverdriver/riverpgxv5" +) + +type resumableTestArgs struct{} + +func (resumableTestArgs) Kind() string { return "resumable_test" } + +func TestResumeAfterStep(t *testing.T) { + t.Parallel() + + t.Run("SetsStepInMetadata", func(t *testing.T) { + t.Parallel() + + 
opts := &river.InsertOpts{} + ResumeAfterStep(opts, "step2") + + var metadata map[string]any + require.NoError(t, json.Unmarshal(opts.Metadata, &metadata)) + require.Equal(t, "step2", metadata[rivercommon.MetadataKeyResumableStep]) + require.NotContains(t, metadata, rivercommon.MetadataKeyResumableCursor) + }) + + t.Run("PreservesExistingMetadata", func(t *testing.T) { + t.Parallel() + + opts := &river.InsertOpts{ + Metadata: []byte(`{"custom_key":"custom_value"}`), + } + ResumeAfterStep(opts, "step1") + + var metadata map[string]any + require.NoError(t, json.Unmarshal(opts.Metadata, &metadata)) + require.Equal(t, "step1", metadata[rivercommon.MetadataKeyResumableStep]) + require.Equal(t, "custom_value", metadata["custom_key"]) + }) +} + +func TestResumeAtStepWithCursor(t *testing.T) { + t.Parallel() + + type testCursor struct { + LastProcessedID int `json:"last_processed_id"` + } + + t.Run("SetsStepAndCursorInMetadata", func(t *testing.T) { + t.Parallel() + + opts := &river.InsertOpts{} + ResumeAtStepWithCursor(opts, "process_ids", testCursor{LastProcessedID: 42}) + + var metadata map[string]json.RawMessage + require.NoError(t, json.Unmarshal(opts.Metadata, &metadata)) + + var step string + require.NoError(t, json.Unmarshal(metadata[rivercommon.MetadataKeyResumableStep], &step)) + require.Equal(t, "process_ids", step) + + var cursors map[string]json.RawMessage + require.NoError(t, json.Unmarshal(metadata[rivercommon.MetadataKeyResumableCursor], &cursors)) + + var cursor testCursor + require.NoError(t, json.Unmarshal(cursors["process_ids"], &cursor)) + require.Equal(t, 42, cursor.LastProcessedID) + }) +} + +func TestResumableOptsIntegration(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + type testBundle struct { + driver *riverpgxv5.Driver + tx pgx.Tx + } + + setup := func(t *testing.T) *testBundle { + t.Helper() + + return &testBundle{ + driver: riverpgxv5.New(nil), + tx: riverdbtest.TestTxPgx(ctx, t), + } + } + + 
t.Run("ResumeAfterStepSkipsCompletedSteps", func(t *testing.T) { + t.Parallel() + + bundle := setup(t) + + var ran []string + worker := river.WorkFunc(func(ctx context.Context, job *river.Job[resumableTestArgs]) error { + river.ResumableStep(ctx, "step1", func(ctx context.Context) error { + ran = append(ran, "step1") + return nil + }) + river.ResumableStep(ctx, "step2", func(ctx context.Context) error { + ran = append(ran, "step2") + return nil + }) + river.ResumableStep(ctx, "step3", func(ctx context.Context) error { + ran = append(ran, "step3") + return nil + }) + return nil + }) + + config := &river.Config{ID: "rivertest-resumable"} + tw := NewWorker(t, bundle.driver, config, worker) + + opts := &river.InsertOpts{} + ResumeAfterStep(opts, "step1") + + result, err := tw.Work(ctx, t, bundle.tx, resumableTestArgs{}, opts) + require.NoError(t, err) + require.Equal(t, river.EventKindJobCompleted, result.EventKind) + require.Equal(t, []string{"step2", "step3"}, ran) + }) + + t.Run("ResumeAtCursorStepPassesCursor", func(t *testing.T) { + t.Parallel() + + bundle := setup(t) + + type testCursor struct { + LastProcessedID int `json:"last_processed_id"` + } + + var ( + ran []string + receivedCursor testCursor + ) + worker := river.WorkFunc(func(ctx context.Context, job *river.Job[resumableTestArgs]) error { + river.ResumableStep(ctx, "validate", func(ctx context.Context) error { + ran = append(ran, "validate") + return nil + }) + river.ResumableStepCursor(ctx, "process_ids", func(ctx context.Context, cursor testCursor) error { + ran = append(ran, "process_ids") + receivedCursor = cursor + return nil + }) + return nil + }) + + config := &river.Config{ID: "rivertest-resumable"} + tw := NewWorker(t, bundle.driver, config, worker) + + opts := &river.InsertOpts{} + ResumeAtStepWithCursor(opts, "process_ids", testCursor{LastProcessedID: 42}) + + result, err := tw.Work(ctx, t, bundle.tx, resumableTestArgs{}, opts) + require.NoError(t, err) + require.Equal(t, 
river.EventKindJobCompleted, result.EventKind) + require.Equal(t, []string{"process_ids"}, ran) + require.Equal(t, 42, receivedCursor.LastProcessedID) + }) + + t.Run("ResumeAtFirstCursorStep", func(t *testing.T) { + t.Parallel() + + bundle := setup(t) + + type testCursor struct { + Offset int `json:"offset"` + } + + var receivedCursor testCursor + worker := river.WorkFunc(func(ctx context.Context, job *river.Job[resumableTestArgs]) error { + river.ResumableStepCursor(ctx, "process", func(ctx context.Context, cursor testCursor) error { + receivedCursor = cursor + return nil + }) + return nil + }) + + config := &river.Config{ID: "rivertest-resumable"} + tw := NewWorker(t, bundle.driver, config, worker) + + opts := &river.InsertOpts{} + ResumeAtStepWithCursor(opts, "process", testCursor{Offset: 100}) + + result, err := tw.Work(ctx, t, bundle.tx, resumableTestArgs{}, opts) + require.NoError(t, err) + require.Equal(t, river.EventKindJobCompleted, result.EventKind) + require.Equal(t, 100, receivedCursor.Offset) + }) +} diff --git a/rivertest/worker.go b/rivertest/worker.go index 8ef06837..45c5fb8b 100644 --- a/rivertest/worker.go +++ b/rivertest/worker.go @@ -13,6 +13,7 @@ import ( "github.com/riverqueue/river/internal/jobexecutor" "github.com/riverqueue/river/internal/maintenance" "github.com/riverqueue/river/internal/middlewarelookup" + "github.com/riverqueue/river/internal/rivermiddleware" "github.com/riverqueue/river/riverdriver" "github.com/riverqueue/river/rivershared/baseservice" "github.com/riverqueue/river/rivershared/riversharedtest" @@ -206,7 +207,7 @@ func (w *Worker[T, TTx]) workJob(ctx context.Context, tb testing.TB, tx TTx, job HookLookupGlobal: hooklookup.NewHookLookup(w.config.Hooks), HookLookupByJob: hooklookup.NewJobHookLookup(), JobRow: job, - MiddlewareLookupGlobal: middlewarelookup.NewMiddlewareLookup(w.config.Middleware), + MiddlewareLookupGlobal: middlewarelookup.NewMiddlewareLookup(append(rivermiddleware.DefaultMiddleware(), 
w.config.Middleware...)), ProducerCallbacks: struct { JobDone func(jobRow *rivertype.JobRow) Stuck func()