Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 77 additions & 7 deletions internal/consumer/consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,22 @@ package consumer

import (
"context"
"fmt"
"sync"
"time"

"github.com/hookdeck/outpost/internal/logging"
"github.com/hookdeck/outpost/internal/mqs"
"go.opentelemetry.io/otel"
"go.uber.org/zap"
)

// Defaults that New assigns to the retry-related options before any
// caller-supplied option functions are applied.
const (
// defaultMaxConsecutiveErrors is the default for the maxConsecutiveErrors option.
defaultMaxConsecutiveErrors = 5
// defaultInitialBackoff is the default for the initialBackoff option.
defaultInitialBackoff = 200 * time.Millisecond
// defaultMaxBackoff is the default for the maxBackoff option.
defaultMaxBackoff = 5 * time.Second
)

// Consumer is the interface returned by New.
type Consumer interface {
// Run takes the context to run under and reports an error when it returns.
Run(context.Context) error
}
Expand All @@ -19,9 +27,12 @@ type MessageHandler interface {
}

// consumerImplOptions collects the settings that New seeds with defaults and
// then hands to each functional option in turn.
//
// NOTE(review): name, concurrency and logger appear twice in this view, which
// looks like the old and new sides of a diff rendered together — confirm the
// real file declares each field exactly once.
type consumerImplOptions struct {
name string
concurrency int
logger *logging.Logger
name string
concurrency int
logger *logging.Logger
// maxConsecutiveErrors is set by WithMaxConsecutiveErrors; New defaults it
// to defaultMaxConsecutiveErrors.
maxConsecutiveErrors int
// initialBackoff is set by WithInitialBackoff; New defaults it to
// defaultInitialBackoff.
initialBackoff time.Duration
// maxBackoff is set by WithMaxBackoff; New defaults it to defaultMaxBackoff.
maxBackoff time.Duration
}

func WithName(name string) func(*consumerImplOptions) {
Expand All @@ -42,10 +53,31 @@ func WithLogger(logger *logging.Logger) func(*consumerImplOptions) {
}
}

// WithMaxConsecutiveErrors returns an option that stores n in the
// maxConsecutiveErrors setting.
func WithMaxConsecutiveErrors(n int) func(*consumerImplOptions) {
	apply := func(opts *consumerImplOptions) {
		opts.maxConsecutiveErrors = n
	}
	return apply
}

// WithInitialBackoff returns an option that stores d in the initialBackoff
// setting.
func WithInitialBackoff(d time.Duration) func(*consumerImplOptions) {
	apply := func(opts *consumerImplOptions) {
		opts.initialBackoff = d
	}
	return apply
}

// WithMaxBackoff returns an option that stores d in the maxBackoff setting.
func WithMaxBackoff(d time.Duration) func(*consumerImplOptions) {
	apply := func(opts *consumerImplOptions) {
		opts.maxBackoff = d
	}
	return apply
}

func New(subscription mqs.Subscription, handler MessageHandler, opts ...func(*consumerImplOptions)) Consumer {
options := &consumerImplOptions{
name: "",
concurrency: 1,
name: "",
concurrency: 1,
maxConsecutiveErrors: defaultMaxConsecutiveErrors,
initialBackoff: defaultInitialBackoff,
maxBackoff: defaultMaxBackoff,
}
for _, opt := range opts {
opt(options)
Expand Down Expand Up @@ -76,17 +108,54 @@ func (c *consumerImpl) Run(ctx context.Context) error {
return c.runWithSemaphore(ctx)
}

// receiveWithRetry wraps subscription.Receive with capped exponential backoff.
//
// consecutiveErrors is a caller-owned failure counter. It is reset to zero on a
// successful receive and incremented on every failed one, so a failure streak
// carries across calls.
//
// It returns:
//   - (msg, nil) as soon as Receive succeeds;
//   - (nil, err wrapping the last receive error) once the counter reaches
//     c.maxConsecutiveErrors;
//   - (nil, ctx.Err()) if ctx is done while waiting between attempts.
func (c *consumerImpl) receiveWithRetry(ctx context.Context, consecutiveErrors *int) (*mqs.Message, error) {
	for {
		msg, err := c.subscription.Receive(ctx)
		if err == nil {
			*consecutiveErrors = 0
			return msg, nil
		}

		*consecutiveErrors++
		if *consecutiveErrors >= c.maxConsecutiveErrors {
			return nil, fmt.Errorf("max consecutive receive errors reached (%d): %w", c.maxConsecutiveErrors, err)
		}

		// Compute min(initialBackoff * 2^(attempt-1), maxBackoff) by repeated
		// doubling. A direct `initialBackoff * (1 << (attempt-1))` overflows
		// int64 for large attempt counts and wraps to zero or a negative
		// duration, which would slip past the cap and retry with no delay.
		backoff := c.initialBackoff
		for i := 1; i < *consecutiveErrors && backoff > 0; i++ {
			if backoff > c.maxBackoff/2 {
				// The next doubling would exceed the cap; stop here.
				backoff = c.maxBackoff
				break
			}
			backoff *= 2
		}
		if backoff > c.maxBackoff {
			backoff = c.maxBackoff
		}

		if c.logger != nil {
			c.logger.Ctx(ctx).Warn("consumer receive error, retrying",
				zap.String("name", c.name),
				zap.Error(err),
				zap.Int("attempt", *consecutiveErrors),
				zap.Duration("backoff", backoff))
		}

		// Use an explicit timer so it can be released promptly when the
		// context is cancelled instead of lingering until it fires.
		timer := time.NewTimer(backoff)
		select {
		case <-timer.C:
		case <-ctx.Done():
			timer.Stop()
			return nil, ctx.Err()
		}
	}
}

// runConcurrent is used when the subscription manages flow control internally.
// A WaitGroup tracks in-flight handlers for graceful shutdown.
func (c *consumerImpl) runConcurrent(ctx context.Context) error {
tracer := otel.GetTracerProvider().Tracer("github.com/hookdeck/outpost/internal/consumer")

var wg sync.WaitGroup
var subscriptionErr error
consecutiveErrors := 0

recvLoop:
for {
msg, err := c.subscription.Receive(ctx)
msg, err := c.receiveWithRetry(ctx, &consecutiveErrors)
if err != nil {
subscriptionErr = err
break recvLoop
Expand Down Expand Up @@ -117,11 +186,12 @@ func (c *consumerImpl) runWithSemaphore(ctx context.Context) error {
tracer := otel.GetTracerProvider().Tracer("github.com/hookdeck/outpost/internal/consumer")

var subscriptionErr error
consecutiveErrors := 0

sem := make(chan struct{}, c.concurrency)
recvLoop:
for {
msg, err := c.subscription.Receive(ctx)
msg, err := c.receiveWithRetry(ctx, &consecutiveErrors)
if err != nil {
subscriptionErr = err
break recvLoop
Expand Down
91 changes: 91 additions & 0 deletions internal/consumer/consumer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,85 @@ func TestConsumer_ConcurrentHandler(t *testing.T) {
test.run(t)
}

func TestConsumer_RetryTransientReceiveError(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

// Create a subscription that fails twice then succeeds.
errorCount := 0
messagesDelivered := 0
sub := &fakeSubscription{
receive: func(ctx context.Context) (*mqs.Message, error) {
if errorCount < 2 {
errorCount++
return nil, assert.AnError
}
if messagesDelivered >= 1 {
// Block until context is cancelled (no more messages).
<-ctx.Done()
return nil, ctx.Err()
}
messagesDelivered++
return &mqs.Message{Body: []byte("ok")}, nil
},
}

handled := make(chan string, 1)
handler := &handlerImpl{
handle: func(ctx context.Context, msg *mqs.Message) error {
handled <- string(msg.Body)
return nil
},
}

csm := consumer.New(sub, handler,
consumer.WithConcurrency(1),
consumer.WithMaxConsecutiveErrors(5),
consumer.WithInitialBackoff(10*time.Millisecond),
consumer.WithMaxBackoff(50*time.Millisecond),
)

go csm.Run(ctx)

select {
case body := <-handled:
assert.Equal(t, "ok", body)
assert.Equal(t, 2, errorCount, "should have retried through 2 transient errors")
case <-ctx.Done():
t.Fatal("timed out waiting for message to be handled")
}
}

// TestConsumer_ExhaustsRetriesOnPersistentError checks that Run gives up and
// returns the "max consecutive receive errors" error when every receive fails.
func TestConsumer_ExhaustsRetriesOnPersistentError(t *testing.T) {
	t.Parallel()

	// Bound the run so a regression in the retry limit fails the assertion
	// below instead of hanging the test binary.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	sub := &fakeSubscription{
		receive: func(ctx context.Context) (*mqs.Message, error) {
			return nil, assert.AnError
		},
	}

	handler := &handlerImpl{
		handle: func(ctx context.Context, msg *mqs.Message) error {
			// The handler may run on a goroutine other than the test's, where
			// t.Fatal (FailNow) is not allowed; t.Error is safe from any goroutine.
			t.Error("handler should not be called")
			return nil
		},
	}

	csm := consumer.New(sub, handler,
		consumer.WithConcurrency(1),
		consumer.WithMaxConsecutiveErrors(3),
		consumer.WithInitialBackoff(10*time.Millisecond),
		consumer.WithMaxBackoff(50*time.Millisecond),
	)

	err := csm.Run(ctx)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "max consecutive receive errors reached (3)")
}

// ==================================== Mock ====================================

type Message struct {
Expand All @@ -178,6 +257,18 @@ func (m *Message) FromMessage(msg *mqs.Message) error {
return nil
}

// fakeSubscription is a test double whose Receive behaviour is supplied by the
// test through the receive callback.
type fakeSubscription struct {
	// receive is invoked on every Receive call.
	receive func(context.Context) (*mqs.Message, error)
}

// Receive delegates to the receive callback and passes its result through
// unchanged.
func (f *fakeSubscription) Receive(ctx context.Context) (*mqs.Message, error) {
	msg, err := f.receive(ctx)
	return msg, err
}

// Shutdown does nothing and always reports success.
func (*fakeSubscription) Shutdown(context.Context) error {
	return nil
}

// handlerImpl is a test helper that carries a caller-supplied handle callback.
type handlerImpl struct {
// handle is the function each test provides to receive messages.
handle func(context.Context, *mqs.Message) error
}
Expand Down
Loading