Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 123 additions & 75 deletions utils/py-utils/dl_utils/event_publisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,23 @@

import json
import logging
import time
from typing import List, Dict, Any, Optional, Literal, Callable
from uuid import uuid4
import boto3
from botocore.config import Config
from botocore.exceptions import ClientError
from pydantic import ValidationError


# Reason attached to each DLQ message so consumers can tell why an event was dead-lettered.
DlqReason = Literal['INVALID_EVENT', 'EVENTBRIDGE_FAILURE']
# Batch size used when chunking both EventBridge put_events and SQS send_message_batch calls.
MAX_BATCH_SIZE = 10
# Application-level retry attempts per batch (on top of the botocore client's own retry config).
MAX_PUBLISHER_RETRIES = 3
# EventBridge entry-level error codes considered transient and therefore worth retrying.
TRANSIENT_ERROR_CODES = {
    'ThrottlingException',
    'InternalFailure',
    'ServiceUnavailable'
}


class EventPublisher:
Expand Down Expand Up @@ -44,7 +52,10 @@ def __init__(
self.event_bus_arn = event_bus_arn
self.dlq_url = dlq_url
self.logger = logger or logging.getLogger(__name__)
self.events_client = events_client or boto3.client('events')
self.events_client = events_client or boto3.client(
'events',
config=Config(retries={'max_attempts': 3, 'mode': 'standard'})
)
self.sqs_client = sqs_client or boto3.client('sqs')

def _validate_cloud_event(self, event: Dict[str, Any], validator: Callable[..., Any]) -> tuple[bool, Optional[str]]:
Expand All @@ -54,9 +65,70 @@ def _validate_cloud_event(self, event: Dict[str, Any], validator: Callable[...,
try:
validator(**event)
return (True, None)
except Exception as e:
except ValidationError as e:
return (False, str(e))

def _classify_failed_entries(
self,
response: Dict[str, Any],
events: List[Dict[str, Any]]
) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
transient = []
permanent = []

for entry, event in zip(response.get("Entries", []), events):
error_code = entry.get("ErrorCode")
if not error_code:
continue

if error_code in TRANSIENT_ERROR_CODES:
transient.append(event)
else:
permanent.append(event)

return transient, permanent

def _send_batch_with_retry(
    self, batch: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """
    Send a single batch to EventBridge with retries for transient errors.

    Entry-level failures whose ErrorCode is in TRANSIENT_ERROR_CODES are
    retried with exponential backoff (1s, 2s, ...); permanent entry
    failures are accumulated across attempts. A request-level ClientError
    whose error code is transient is retried the same way; any other
    ClientError fails the whole remaining batch.

    Args:
        batch: CloudEvent dicts with at least 'source' and 'type' keys,
            JSON-serializable.

    Returns:
        A list of events that permanently failed.
    """
    events_to_retry = batch
    # Permanent failures must survive across attempts: a later retry that
    # succeeds must not erase events that permanently failed earlier.
    permanently_failed: List[Dict[str, Any]] = []

    for attempt in range(MAX_PUBLISHER_RETRIES):
        last_attempt = attempt == MAX_PUBLISHER_RETRIES - 1
        entries = [
            {
                "Source": event["source"],
                "DetailType": event["type"],
                "Detail": json.dumps(event),
                "EventBusName": self.event_bus_arn,
            }
            for event in events_to_retry
        ]

        try:
            response = self.events_client.put_events(Entries=entries)
        except ClientError as error:
            error_code = error.response.get('Error', {}).get('Code')
            # A request-level throttle/5xx is just as retryable as an
            # entry-level one; anything else fails the remaining batch.
            if error_code in TRANSIENT_ERROR_CODES and not last_attempt:
                time.sleep(2 ** attempt)
                continue
            return permanently_failed + events_to_retry

        transient, permanent = self._classify_failed_entries(
            response, events_to_retry
        )
        permanently_failed.extend(permanent)

        if not transient:
            return permanently_failed
        if last_attempt:
            return permanently_failed + transient

        events_to_retry = transient
        time.sleep(2 ** attempt)

    # Only reachable if MAX_PUBLISHER_RETRIES <= 0; treat the batch as
    # never sent.
    return permanently_failed + events_to_retry

def _send_to_event_bridge(self, events: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Send events to EventBridge in batches.
Expand All @@ -82,56 +154,62 @@ def _send_to_event_bridge(self, events: List[Dict[str, Any]]) -> List[Dict[str,
}
)

try:
entries = [
{
'Source': event['source'],
'DetailType': event['type'],
'Detail': json.dumps(event),
'EventBusName': self.event_bus_arn
}
for event in batch
]
batch_failures = self._send_batch_with_retry(batch)

response = self.events_client.put_events(Entries=entries)
if batch_failures:
for event in batch_failures:
self.logger.warning(
'Event failed to send to EventBridge',
extra={'event_id': event.get('id')}
)
failed_events.extend(batch_failures)

failed_count = response.get('FailedEntryCount', 0)
success_count = len(batch) - failed_count
return failed_events

self.logger.info(
'EventBridge batch sent',
extra={
'batch_size': len(batch),
'failed_entry_count': failed_count,
'successful_count': success_count
def _build_dlq_entries(
self,
events: List[Dict[str, Any]],
reason: DlqReason
) -> tuple[List[Dict[str, Any]], Dict[str, Any]]:
"""Build SQS batch entries for the DLQ and a mapping of entry IDs to events"""
id_to_event_map = {}
entries = []
for event in events:
entry_id = str(uuid4())
id_to_event_map[entry_id] = event
entries.append({
'Id': entry_id,
'MessageBody': json.dumps(event),
'MessageAttributes': {
'DlqReason': {
'DataType': 'String',
'StringValue': reason
}
)

# Track failed entries
if failed_count > 0 and 'Entries' in response:
for idx, entry in enumerate(response['Entries']):
if 'ErrorCode' in entry:
self.logger.warning(
'Event failed to send to EventBridge',
extra={
'error_code': entry.get('ErrorCode'),
'error_message': entry.get('ErrorMessage'),
'event_id': batch[idx].get('id')
}
)
failed_events.append(batch[idx])
}
})
return entries, id_to_event_map

except ClientError as error:
def _extract_failed_dlq_events(
self,
response: Dict[str, Any],
id_to_event_map: Dict[str, Any]
) -> List[Dict[str, Any]]:
"""Extract events that failed to send to the DLQ from a send_message_batch response."""
failed = []
for failed_entry in response.get('Failed', []):
entry_id = failed_entry.get('Id')
if entry_id and entry_id in id_to_event_map:
failed_event = id_to_event_map[entry_id]
self.logger.warning(
'EventBridge send error',
'Event failed to send to DLQ',
extra={
'error': str(error),
'batch_size': len(batch)
'error_code': failed_entry.get('Code'),
'error_message': failed_entry.get('Message'),
'event_id': failed_event.get('id')
}
)
failed_events.extend(batch)

return failed_events
failed.append(failed_event)
return failed

def _send_to_dlq(
self,
Expand All @@ -154,44 +232,14 @@ def _send_to_dlq(

for i in range(0, len(events), MAX_BATCH_SIZE):
batch = events[i:i + MAX_BATCH_SIZE]
id_to_event_map = {}

entries = []
for event in batch:
entry_id = str(uuid4())
id_to_event_map[entry_id] = event
entries.append({
'Id': entry_id,
'MessageBody': json.dumps(event),
'MessageAttributes': {
'DlqReason': {
'DataType': 'String',
'StringValue': reason
}
}
})
entries, id_to_event_map = self._build_dlq_entries(batch, reason)

try:
response = self.sqs_client.send_message_batch(
QueueUrl=self.dlq_url,
Entries=entries
)

# Track failed DLQ sends
if 'Failed' in response:
for failed_entry in response['Failed']:
entry_id = failed_entry.get('Id')
if entry_id and entry_id in id_to_event_map:
failed_event = id_to_event_map[entry_id]
self.logger.warning(
'Event failed to send to DLQ',
extra={
'error_code': failed_entry.get('Code'),
'error_message': failed_entry.get('Message'),
'event_id': failed_event.get('id')
}
)
failed_dlqs.append(failed_event)
failed_dlqs.extend(self._extract_failed_dlq_events(response, id_to_event_map))

except ClientError as error:
self.logger.warning(
Expand Down
Loading