diff --git a/internal/consumer/consumer.go b/internal/consumer/consumer.go index 94c64c4a..b734e241 100644 --- a/internal/consumer/consumer.go +++ b/internal/consumer/consumer.go @@ -2,7 +2,9 @@ package consumer import ( "context" + "fmt" "sync" + "time" "github.com/hookdeck/outpost/internal/logging" "github.com/hookdeck/outpost/internal/mqs" @@ -10,6 +12,12 @@ import ( "go.uber.org/zap" ) +const ( + defaultMaxConsecutiveErrors = 5 + defaultInitialBackoff = 200 * time.Millisecond + defaultMaxBackoff = 5 * time.Second +) + type Consumer interface { Run(context.Context) error } @@ -19,9 +27,12 @@ type MessageHandler interface { } type consumerImplOptions struct { - name string - concurrency int - logger *logging.Logger + name string + concurrency int + logger *logging.Logger + maxConsecutiveErrors int + initialBackoff time.Duration + maxBackoff time.Duration } func WithName(name string) func(*consumerImplOptions) { @@ -42,10 +53,31 @@ func WithLogger(logger *logging.Logger) func(*consumerImplOptions) { } } +func WithMaxConsecutiveErrors(n int) func(*consumerImplOptions) { + return func(c *consumerImplOptions) { + c.maxConsecutiveErrors = n + } +} + +func WithInitialBackoff(d time.Duration) func(*consumerImplOptions) { + return func(c *consumerImplOptions) { + c.initialBackoff = d + } +} + +func WithMaxBackoff(d time.Duration) func(*consumerImplOptions) { + return func(c *consumerImplOptions) { + c.maxBackoff = d + } +} + func New(subscription mqs.Subscription, handler MessageHandler, opts ...func(*consumerImplOptions)) Consumer { options := &consumerImplOptions{ - name: "", - concurrency: 1, + name: "", + concurrency: 1, + maxConsecutiveErrors: defaultMaxConsecutiveErrors, + initialBackoff: defaultInitialBackoff, + maxBackoff: defaultMaxBackoff, } for _, opt := range opts { opt(options) @@ -76,6 +108,42 @@ func (c *consumerImpl) Run(ctx context.Context) error { return c.runWithSemaphore(ctx) } +// receiveWithRetry wraps subscription.Receive with exponential backoff on errors. +// Returns (nil, err) only after maxConsecutiveErrors consecutive failures. +func (c *consumerImpl) receiveWithRetry(ctx context.Context, consecutiveErrors *int) (*mqs.Message, error) { + for { + msg, err := c.subscription.Receive(ctx) + if err == nil { + *consecutiveErrors = 0 + return msg, nil + } + + *consecutiveErrors++ + if *consecutiveErrors >= c.maxConsecutiveErrors { + return nil, fmt.Errorf("max consecutive receive errors reached (%d): %w", c.maxConsecutiveErrors, err) + } + + backoff := c.initialBackoff * time.Duration(1<<(*consecutiveErrors-1)) + if backoff > c.maxBackoff { + backoff = c.maxBackoff + } + + if c.logger != nil { + c.logger.Ctx(ctx).Warn("consumer receive error, retrying", + zap.String("name", c.name), + zap.Error(err), + zap.Int("attempt", *consecutiveErrors), + zap.Duration("backoff", backoff)) + } + + select { + case <-time.After(backoff): + case <-ctx.Done(): + return nil, ctx.Err() + } + } +} + // runConcurrent is used when the subscription manages flow control internally. // A WaitGroup tracks in-flight handlers for graceful shutdown. func (c *consumerImpl) runConcurrent(ctx context.Context) error { @@ -83,10 +151,11 @@ func (c *consumerImpl) runConcurrent(ctx context.Context) error { var wg sync.WaitGroup var subscriptionErr error + consecutiveErrors := 0 recvLoop: for { - msg, err := c.subscription.Receive(ctx) + msg, err := c.receiveWithRetry(ctx, &consecutiveErrors) if err != nil { subscriptionErr = err break recvLoop @@ -117,11 +186,12 @@ func (c *consumerImpl) runWithSemaphore(ctx context.Context) error { tracer := otel.GetTracerProvider().Tracer("github.com/hookdeck/outpost/internal/consumer") var subscriptionErr error + consecutiveErrors := 0 sem := make(chan struct{}, c.concurrency) recvLoop: for { - msg, err := c.subscription.Receive(ctx) + msg, err := c.receiveWithRetry(ctx, &consecutiveErrors) if err != nil { subscriptionErr = err break recvLoop diff --git a/internal/consumer/consumer_test.go b/internal/consumer/consumer_test.go index 4688a087..c6773d92 100644 --- a/internal/consumer/consumer_test.go +++ b/internal/consumer/consumer_test.go @@ -161,6 +161,85 @@ func TestConsumer_ConcurrentHandler(t *testing.T) { test.run(t) } +func TestConsumer_RetryTransientReceiveError(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Create a subscription that fails twice then succeeds. + errorCount := 0 + messagesDelivered := 0 + sub := &fakeSubscription{ + receive: func(ctx context.Context) (*mqs.Message, error) { + if errorCount < 2 { + errorCount++ + return nil, assert.AnError + } + if messagesDelivered >= 1 { + // Block until context is cancelled (no more messages). + <-ctx.Done() + return nil, ctx.Err() + } + messagesDelivered++ + return &mqs.Message{Body: []byte("ok")}, nil + }, + } + + handled := make(chan string, 1) + handler := &handlerImpl{ + handle: func(ctx context.Context, msg *mqs.Message) error { + handled <- string(msg.Body) + return nil + }, + } + + csm := consumer.New(sub, handler, + consumer.WithConcurrency(1), + consumer.WithMaxConsecutiveErrors(5), + consumer.WithInitialBackoff(10*time.Millisecond), + consumer.WithMaxBackoff(50*time.Millisecond), + ) + + go csm.Run(ctx) + + select { + case body := <-handled: + assert.Equal(t, "ok", body) + assert.Equal(t, 2, errorCount, "should have retried through 2 transient errors") + case <-ctx.Done(): + t.Fatal("timed out waiting for message to be handled") + } +} + +func TestConsumer_ExhaustsRetriesOnPersistentError(t *testing.T) { + t.Parallel() + + sub := &fakeSubscription{ + receive: func(ctx context.Context) (*mqs.Message, error) { + return nil, assert.AnError + }, + } + + handler := &handlerImpl{ + handle: func(ctx context.Context, msg *mqs.Message) error { + t.Fatal("handler should not be called") + return nil + }, + } + + csm := consumer.New(sub, handler, + consumer.WithConcurrency(1), + consumer.WithMaxConsecutiveErrors(3), + consumer.WithInitialBackoff(10*time.Millisecond), + consumer.WithMaxBackoff(50*time.Millisecond), + ) + + err := csm.Run(context.Background()) + require.Error(t, err) + assert.Contains(t, err.Error(), "max consecutive receive errors reached (3)") +} + // ==================================== Mock ==================================== type Message struct { @@ -178,6 +257,18 @@ func (m *Message) FromMessage(msg *mqs.Message) error { return nil } +type fakeSubscription struct { + receive func(context.Context) (*mqs.Message, error) +} + +func (f *fakeSubscription) Receive(ctx context.Context) (*mqs.Message, error) { + return f.receive(ctx) +} + +func (f *fakeSubscription) Shutdown(ctx context.Context) error { + return nil +} + type handlerImpl struct { handle func(context.Context, *mqs.Message) error }