"""Publish CloudEvents to EventBridge, routing failures to an SQS DLQ.

NOTE(review): this module was reconstructed from a whitespace-mangled unified
diff (the post-image side of the patch). Spans that were hidden between diff
hunks are marked with NOTE(review) comments below — confirm them against the
original file before relying on them.
"""
import json
import logging
import time
from typing import List, Dict, Any, Optional, Literal, Callable
from uuid import uuid4

import boto3
from botocore.config import Config
from botocore.exceptions import ClientError
from pydantic import ValidationError

# Reason tag attached to messages routed to the DLQ.
DlqReason = Literal['INVALID_EVENT', 'EVENTBRIDGE_FAILURE']

# EventBridge PutEvents and SQS SendMessageBatch both cap a batch at 10 entries.
MAX_BATCH_SIZE = 10
# Application-level retry attempts for per-entry transient failures,
# on top of botocore's own whole-call retries.
MAX_PUBLISHER_RETRIES = 3
# Per-entry EventBridge error codes treated as retryable.
TRANSIENT_ERROR_CODES = {
    'ThrottlingException',
    'InternalFailure',
    'ServiceUnavailable'
}


class EventPublisher:
    """Publishes validated events to an EventBridge bus, with an SQS DLQ fallback."""

    # NOTE(review): the __init__ signature is not visible in the mangled
    # source; parameter names and defaults below are inferred from the
    # visible attribute assignments — confirm against the original file.
    def __init__(
        self,
        event_bus_arn: str,
        dlq_url: str,
        logger: Optional[logging.Logger] = None,
        events_client: Optional[Any] = None,
        sqs_client: Optional[Any] = None,
    ) -> None:
        self.event_bus_arn = event_bus_arn
        self.dlq_url = dlq_url
        self.logger = logger or logging.getLogger(__name__)
        # Standard-mode botocore retries cover whole-call transient errors;
        # per-entry partial failures are retried in _send_batch_with_retry.
        self.events_client = events_client or boto3.client(
            'events',
            config=Config(retries={'max_attempts': 3, 'mode': 'standard'})
        )
        self.sqs_client = sqs_client or boto3.client('sqs')

    def _validate_cloud_event(self, event: Dict[str, Any], validator: Callable[..., Any]) -> tuple[bool, Optional[str]]:
        """Validate *event* against a pydantic model constructor.

        Returns ``(True, None)`` on success, or ``(False, message)`` when the
        model raises a ``ValidationError``.
        """
        try:
            validator(**event)
            return (True, None)
        except ValidationError as e:
            return (False, str(e))

    def _classify_failed_entries(
        self,
        response: Dict[str, Any],
        events: List[Dict[str, Any]]
    ) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """Split failed PutEvents entries into (transient, permanent) event lists.

        PutEvents returns ``Entries`` in request order, so a positional zip maps
        each result entry back to its source event. Entries without an
        ``ErrorCode`` succeeded and are skipped.
        """
        transient = []
        permanent = []

        for entry, event in zip(response.get("Entries", []), events):
            error_code = entry.get("ErrorCode")
            if not error_code:
                continue  # this entry was accepted

            if error_code in TRANSIENT_ERROR_CODES:
                transient.append(event)
            else:
                permanent.append(event)

        return transient, permanent

    def _send_batch_with_retry(
        self, batch: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """
        Send a single batch to EventBridge with retries for transient errors.
        Returns a list of events that permanently failed.
        """
        events_to_retry = batch

        for attempt in range(MAX_PUBLISHER_RETRIES):
            entries = [
                {
                    "Source": event["source"],
                    "DetailType": event["type"],
                    "Detail": json.dumps(event),
                    "EventBusName": self.event_bus_arn,
                }
                for event in events_to_retry
            ]

            try:
                response = self.events_client.put_events(Entries=entries)

                transient, permanent = self._classify_failed_entries(
                    response, events_to_retry
                )

                if not transient:
                    return permanent

                if attempt == MAX_PUBLISHER_RETRIES - 1:
                    # Out of attempts: everything still pending counts as failed.
                    return transient + permanent

                events_to_retry = transient
                time.sleep(2 ** attempt)  # exponential backoff: 1s, 2s

            except ClientError:
                # Whole-call failure; botocore has already retried internally,
                # so treat all remaining events in this batch as failed.
                return events_to_retry

        return events_to_retry

    def _send_to_event_bridge(self, events: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Send events to EventBridge in batches.

        Returns the list of events that could not be delivered.
        """
        failed_events = []

        # NOTE(review): the batching loop header and an info-level log call in
        # this method fall between diff hunks in the mangled source and are
        # reconstructed here — confirm against the original file.
        for i in range(0, len(events), MAX_BATCH_SIZE):
            batch = events[i:i + MAX_BATCH_SIZE]

            self.logger.info(
                'Sending batch to EventBridge',
                extra={
                    'batch_size': len(batch)
                }
            )

            batch_failures = self._send_batch_with_retry(batch)

            if batch_failures:
                for event in batch_failures:
                    self.logger.warning(
                        'Event failed to send to EventBridge',
                        extra={'event_id': event.get('id')}
                    )
                failed_events.extend(batch_failures)

        return failed_events

    def _build_dlq_entries(
        self,
        events: List[Dict[str, Any]],
        reason: DlqReason
    ) -> tuple[List[Dict[str, Any]], Dict[str, Any]]:
        """Build SQS batch entries for the DLQ and a mapping of entry IDs to events"""
        id_to_event_map = {}
        entries = []
        for event in events:
            # SendMessageBatch requires a unique Id per entry; the map lets us
            # resolve failed entry IDs back to their source events.
            entry_id = str(uuid4())
            id_to_event_map[entry_id] = event
            entries.append({
                'Id': entry_id,
                'MessageBody': json.dumps(event),
                'MessageAttributes': {
                    'DlqReason': {
                        'DataType': 'String',
                        'StringValue': reason
                    }
                }
            })
        return entries, id_to_event_map

    def _extract_failed_dlq_events(
        self,
        response: Dict[str, Any],
        id_to_event_map: Dict[str, Any]
    ) -> List[Dict[str, Any]]:
        """Extract events that failed to send to the DLQ from a send_message_batch response."""
        failed = []
        for failed_entry in response.get('Failed', []):
            entry_id = failed_entry.get('Id')
            if entry_id and entry_id in id_to_event_map:
                failed_event = id_to_event_map[entry_id]
                self.logger.warning(
                    'Event failed to send to DLQ',
                    extra={
                        'error_code': failed_entry.get('Code'),
                        'error_message': failed_entry.get('Message'),
                        'event_id': failed_event.get('id')
                    }
                )
                failed.append(failed_event)
        return failed

    def _send_to_dlq(
        self,
        events: List[Dict[str, Any]],
        reason: DlqReason
    ) -> List[Dict[str, Any]]:
        """
        Send events to the DLQ in batches, tagged with *reason*.

        Returns the events that could not be enqueued.
        """
        # NOTE(review): the lines between this signature and the batching loop
        # (original docstring, accumulator init) are hidden between diff hunks
        # and reconstructed — confirm against the original file.
        failed_dlqs = []

        for i in range(0, len(events), MAX_BATCH_SIZE):
            batch = events[i:i + MAX_BATCH_SIZE]

            entries, id_to_event_map = self._build_dlq_entries(batch, reason)

            try:
                response = self.sqs_client.send_message_batch(
                    QueueUrl=self.dlq_url,
                    Entries=entries
                )
                failed_dlqs.extend(self._extract_failed_dlq_events(response, id_to_event_map))

            except ClientError as error:
                # NOTE(review): the source is truncated inside this handler;
                # the log text and fallback below mirror the removed
                # EventBridge ClientError handler visible in the diff — confirm.
                self.logger.warning(
                    'DLQ send error',
                    extra={'error': str(error), 'batch_size': len(batch)}
                )
                failed_dlqs.extend(batch)

        return failed_dlqs