diff --git a/examples/cor_request_fallback.py b/examples/cor_request_fallback.py new file mode 100644 index 0000000..0a193cc --- /dev/null +++ b/examples/cor_request_fallback.py @@ -0,0 +1,247 @@ +""" +Example: Chain of Responsibility with Request Handler Fallback + +This example shows how to combine a COR (Chain of Responsibility) handler +with RequestHandlerFallback. The primary handler is a RequestHandler that +delegates to a COR chain; when the chain raises (e.g. downstream failure), +the fallback handler is invoked. + +Use case: A request is first tried through a chain of handlers (e.g. try +cache, then DB, then external API). If the whole chain fails (e.g. connection +error), a fallback handler returns a default/cached response. + +================================================================================ +HOW TO RUN THIS EXAMPLE +================================================================================ + +Run the example: + python examples/cor_request_fallback.py + +The example will: +- Send a command that is handled by a COR chain (primary path) +- For source="error", the chain raises and fallback handler runs +- For source="a" or "b", the chain handles the request successfully + +================================================================================ +WHAT THIS EXAMPLE DEMONSTRATES +================================================================================ + +1. RequestHandlerFallback with COR as primary: + - Primary is a RequestHandler that delegates to a COR chain (injected via DI). + - Fallback is a simple RequestHandler used when the chain raises. + +2. Building the chain: + - Create COR handler instances, build_chain(), then bind the chain entry + (first handler) in the container so the wrapper can receive it. + +3. Flow: + - mediator.send(request) dispatches to primary (CORChainWrapperHandler). + - Wrapper calls the chain; if the chain raises, dispatcher catches and + invokes fallback. + +4. Optional failure_exceptions: + - Restrict fallback to specific exception types (e.g. ConnectionError). + +================================================================================ +REQUIREMENTS +================================================================================ + +Make sure you have installed: + - cqrs (this package) + - di (dependency injection) + +================================================================================ +""" + +import asyncio +import logging + +import di +from di import dependent + +import cqrs +from cqrs.requests import bootstrap +from cqrs.requests.cor_request_handler import ( + CORRequestHandler, + build_chain, +) + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +HANDLER_SOURCE: list[str] = [] # "chain" or "fallback" + + +# ----------------------------------------------------------------------------- +# Command and response +# ----------------------------------------------------------------------------- + + +class FetchDataCommand(cqrs.Request): + source: str # "a" | "b" | "error" + + +class FetchDataResult(cqrs.Response): + data: str + source: str # "chain" or "fallback" + + +# ----------------------------------------------------------------------------- +# COR handlers (chain) +# ----------------------------------------------------------------------------- + + +class SourceAHandler(CORRequestHandler[FetchDataCommand, FetchDataResult]): + @property + def events(self) -> list[cqrs.Event]: + return [] + + async def handle(self, request: FetchDataCommand) -> FetchDataResult | None: + if request.source == "a": + logger.info("COR chain: SourceAHandler handled source=a") + HANDLER_SOURCE.append("chain") + return FetchDataResult(data="data_from_a", source="chain") + return await self.next(request) + + +class SourceBHandler(CORRequestHandler[FetchDataCommand, FetchDataResult]): + @property + def events(self) -> list[cqrs.Event]: + return [] + + async def handle(self, request: FetchDataCommand) -> FetchDataResult | None: + if request.source == "b": + logger.info("COR chain: SourceBHandler handled source=b") + HANDLER_SOURCE.append("chain") + return FetchDataResult(data="data_from_b", source="chain") + return await self.next(request) + + +class DefaultChainHandler(CORRequestHandler[FetchDataCommand, FetchDataResult]): + """Last in chain: handles unknown or raises for source='error'.""" + + @property + def events(self) -> list[cqrs.Event]: + return [] + + async def handle(self, request: FetchDataCommand) -> FetchDataResult | None: + if request.source == "error": + logger.info("COR chain: DefaultChainHandler raising ConnectionError for source=error") + raise ConnectionError("Downstream service unavailable") + logger.info("COR chain: DefaultChainHandler handled (unknown source)") + HANDLER_SOURCE.append("chain") + return FetchDataResult(data="default_data", source="chain") + + +# ----------------------------------------------------------------------------- +# Wrapper: RequestHandler that delegates to the COR chain +# ----------------------------------------------------------------------------- + + +class CORChainWrapperHandler( + cqrs.RequestHandler[FetchDataCommand, FetchDataResult], +): + """Primary 'handler' that runs the COR chain; chain is injected as the first link.""" + + def __init__(self, chain_entry: SourceAHandler) -> None: + self._chain_entry = chain_entry + + @property + def events(self) -> list[cqrs.Event]: + return [] + + async def handle(self, request: FetchDataCommand) -> FetchDataResult: + result = await self._chain_entry.handle(request) + if result is None: + raise ValueError("COR chain did not handle the request") + return result + + +# ----------------------------------------------------------------------------- +# Fallback handler (used when the chain raises) +# ----------------------------------------------------------------------------- + + +class FallbackFetchDataHandler( + cqrs.RequestHandler[FetchDataCommand, FetchDataResult], +): + @property + def events(self) -> list[cqrs.Event]: + return [] + + async def handle(self, request: FetchDataCommand) -> FetchDataResult: + logger.info("Fallback handler: returning cached/default for source=%s", request.source) + HANDLER_SOURCE.append("fallback") + return FetchDataResult( + data="cached_or_default", + source="fallback", + ) + + +# ----------------------------------------------------------------------------- +# Mappers and bootstrap +# ----------------------------------------------------------------------------- + + +def commands_mapper(mapper: cqrs.RequestMap) -> None: + mapper.bind( + FetchDataCommand, + cqrs.RequestHandlerFallback( + primary=CORChainWrapperHandler, + fallback=FallbackFetchDataHandler, + failure_exceptions=(ConnectionError, TimeoutError), + ), + ) + + +async def main() -> None: + HANDLER_SOURCE.clear() + + # Build COR chain and inject the chain entry so CORChainWrapperHandler gets it + source_a = SourceAHandler() + source_b = SourceBHandler() + default = DefaultChainHandler() + build_chain([source_a, source_b, default]) + + di_container = di.Container() + di_container.bind( + di.bind_by_type( + dependent.Dependent(lambda: source_a, scope="request"), + SourceAHandler, + ), + ) + + mediator = bootstrap.bootstrap( + di_container=di_container, + commands_mapper=commands_mapper, + ) + + print("\n" + "=" * 60) + print("COR REQUEST HANDLER FALLBACK EXAMPLE") + print("=" * 60) + + # Case 1: chain handles (source=a) + print("\n1. Send FetchDataCommand(source='a') — chain handles") + result1: FetchDataResult = await mediator.send(FetchDataCommand(source="a")) + print(f" Result: data={result1.data}, source={result1.source}") + assert result1.source == "chain" and result1.data == "data_from_a" + + # Case 2: chain handles (source=b) + print("\n2. Send FetchDataCommand(source='b') — chain handles") + result2: FetchDataResult = await mediator.send(FetchDataCommand(source="b")) + print(f" Result: data={result2.data}, source={result2.source}") + assert result2.source == "chain" and result2.data == "data_from_b" + + # Case 3: chain raises (source=error) -> fallback runs + print("\n3. Send FetchDataCommand(source='error') — chain raises, fallback runs") + result3: FetchDataResult = await mediator.send(FetchDataCommand(source="error")) + print(f" Result: data={result3.data}, source={result3.source}") + assert result3.source == "fallback" and result3.data == "cached_or_default" + + print("\n Handlers that ran (in order): " + str(HANDLER_SOURCE)) + assert "chain" in HANDLER_SOURCE and "fallback" in HANDLER_SOURCE + print("\n" + "=" * 60 + "\n") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/event_fallback.py b/examples/event_fallback.py new file mode 100644 index 0000000..95ef7cb --- /dev/null +++ b/examples/event_fallback.py @@ -0,0 +1,222 @@ +""" +Example: Event Handler Fallback with Optional Circuit Breaker + +This example demonstrates the EventHandlerFallback pattern for domain event +handlers. When the primary event handler fails (or the circuit breaker is open), +the fallback handler is invoked. This is useful for resilient side effects +such as sending notifications or updating read models when the primary path +(e.g. external API) is unavailable. + +================================================================================ +HOW TO RUN THIS EXAMPLE +================================================================================ + +Run the example (without circuit breaker): + python examples/event_fallback.py + +With circuit breaker (optional dependency): + pip install aiobreaker + python examples/event_fallback.py + +The example will: +- Execute a command that emits a domain event +- Primary event handler fails (simulated external service failure) +- Fallback event handler runs and completes successfully +- With circuit breaker: after N failures the circuit opens and fallback is + used without calling the primary handler + +================================================================================ +WHAT THIS EXAMPLE DEMONSTRATES +================================================================================ + +1. EventHandlerFallback Registration: + - Bind event type to EventHandlerFallback(primary, fallback, ...) + - Optional failure_exceptions to trigger fallback only for specific errors + - Optional circuit_breaker (e.g. AioBreakerAdapter) per domain + +2. Primary and Fallback Handlers: + - Primary handler implements EventHandler[EventType]; can raise + - Fallback handler implements same event type; runs when primary fails + +3. Flow: + - Command handler emits domain event + - EventEmitter runs handlers; for EventHandlerFallback runs primary first + - On primary exception (or circuit open): fallback handler is invoked + - Events from the handler that actually ran are collected and returned + +4. Circuit Breaker (optional): + - Use one AioBreakerAdapter instance per domain (e.g. events) + - After fail_max failures, circuit opens; primary is not called, fallback runs + +================================================================================ +REQUIREMENTS +================================================================================ + +Make sure you have installed: + - cqrs (this package) + - di (dependency injection) + +Optional for circuit breaker: + pip install aiobreaker + or: pip install python-cqrs[aiobreaker] + +================================================================================ +""" + +import asyncio +import logging +import di + +import cqrs +from cqrs.adapters.circuit_breaker import AioBreakerAdapter +from cqrs.requests import bootstrap + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Track which handler ran for demo output +EVENTS_HANDLED_BY: list[str] = [] + + +# ----------------------------------------------------------------------------- +# Command and domain event +# ----------------------------------------------------------------------------- + + +class SendNotificationCommand(cqrs.Request): + user_id: str + message: str + + +class NotificationSent(cqrs.DomainEvent, frozen=True): + user_id: str + message: str + + +# ----------------------------------------------------------------------------- +# Command handler (emits domain event) +# ----------------------------------------------------------------------------- + + +class SendNotificationCommandHandler(cqrs.RequestHandler[SendNotificationCommand, None]): + @property + def events(self) -> list[cqrs.Event]: + return self._events + + def __init__(self) -> None: + self._events: list[cqrs.Event] = [] + + async def handle(self, request: SendNotificationCommand) -> None: + self._events.append( + NotificationSent(user_id=request.user_id, message=request.message), + ) + logger.info("Command: emitted NotificationSent for user %s", request.user_id) + + +# ----------------------------------------------------------------------------- +# Primary event handler (simulates failure – e.g. external notification API down) +# ----------------------------------------------------------------------------- + + +class PrimaryNotificationSentHandler(cqrs.EventHandler[NotificationSent]): + async def handle(self, event: NotificationSent) -> None: + logger.info( + "Primary handler: would send notification to user %s: %s", + event.user_id, + event.message, + ) + EVENTS_HANDLED_BY.append("primary") + raise RuntimeError("External notification service unavailable") + + +# ----------------------------------------------------------------------------- +# Fallback event handler (e.g. write to local queue or log) +# ----------------------------------------------------------------------------- + + +class FallbackNotificationSentHandler(cqrs.EventHandler[NotificationSent]): + async def handle(self, event: NotificationSent) -> None: + logger.info( + "Fallback handler: enqueue notification for user %s (primary failed): %s", + event.user_id, + event.message, + ) + EVENTS_HANDLED_BY.append("fallback") + + +# ----------------------------------------------------------------------------- +# Mappers and bootstrap +# ----------------------------------------------------------------------------- + + +def command_mapper(mapper: cqrs.RequestMap) -> None: + mapper.bind(SendNotificationCommand, SendNotificationCommandHandler) + + +def events_mapper(mapper: cqrs.EventMap) -> None: + # Without circuit breaker: any exception from primary triggers fallback + mapper.bind( + NotificationSent, + cqrs.EventHandlerFallback( + primary=PrimaryNotificationSentHandler, + fallback=FallbackNotificationSentHandler, + ), + ) + + +def events_mapper_with_circuit_breaker(mapper: cqrs.EventMap) -> None: + try: + event_cb = AioBreakerAdapter(fail_max=2, timeout_duration=60) + except ImportError: + # No aiobreaker: use same as without circuit breaker + events_mapper(mapper) + return + mapper.bind( + NotificationSent, + cqrs.EventHandlerFallback( + primary=PrimaryNotificationSentHandler, + fallback=FallbackNotificationSentHandler, + circuit_breaker=event_cb, + ), + ) + + +async def main() -> None: + EVENTS_HANDLED_BY.clear() + + use_circuit_breaker = False + try: + import aiobreaker # noqa: F401 + + use_circuit_breaker = True + except ImportError: + pass + + events_mapper_fn = events_mapper_with_circuit_breaker if use_circuit_breaker else events_mapper + + mediator = bootstrap.bootstrap( + di_container=di.Container(), + commands_mapper=command_mapper, + domain_events_mapper=events_mapper_fn, + ) + + print("\n" + "=" * 60) + print("EVENT HANDLER FALLBACK EXAMPLE") + print("=" * 60) + print("\nSending command that emits NotificationSent...") + print("Primary handler will fail; fallback handler will run.\n") + + await mediator.send( + SendNotificationCommand(user_id="user_1", message="Hello from CQRS"), + ) + + print("\nResult:") + print(f" Handlers that ran (in order): {EVENTS_HANDLED_BY}") + assert "primary" in EVENTS_HANDLED_BY and "fallback" in EVENTS_HANDLED_BY + assert EVENTS_HANDLED_BY[-1] == "fallback" + print(" ✓ Primary ran and failed; fallback ran and completed.") + print("\n" + "=" * 60 + "\n") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/request_fallback.py b/examples/request_fallback.py new file mode 100644 index 0000000..2f5a63c --- /dev/null +++ b/examples/request_fallback.py @@ -0,0 +1,213 @@ +""" +Example: Request Handler Fallback with Optional Circuit Breaker + +This example demonstrates the RequestHandlerFallback pattern for command/query +handlers. When the primary request handler fails (or the circuit breaker is +open), the fallback handler is invoked. This is useful for resilient reads +or writes when the primary path (e.g. database or external API) is unavailable. + +================================================================================ +HOW TO RUN THIS EXAMPLE +================================================================================ + +Run the example (without circuit breaker): + python examples/request_fallback.py + +With circuit breaker (optional dependency): + pip install aiobreaker + python examples/request_fallback.py + +The example will: +- Send a command that is handled by a primary handler (simulated to fail) +- Fallback handler runs and returns a valid response +- With circuit breaker: after N failures the circuit opens and requests are + dispatched to fallback without calling the primary handler + +================================================================================ +WHAT THIS EXAMPLE DEMONSTRATES +================================================================================ + +1. RequestHandlerFallback Registration: + - Bind request type to RequestHandlerFallback(primary, fallback, ...) + - Optional failure_exceptions to trigger fallback only for specific errors + - Optional circuit_breaker (e.g. AioBreakerAdapter) per domain + +2. Primary and Fallback Handlers: + - Both implement RequestHandler[Request, Response] + - Primary can raise; fallback provides alternative implementation (e.g. cache) + +3. Flow: + - mediator.send(request) dispatches to primary handler + - On primary exception (or circuit open): fallback handler is invoked + - Response and events from the handler that ran are returned + +4. Circuit Breaker (optional): + - Use one AioBreakerAdapter instance per domain (e.g. commands) + - After fail_max failures, circuit opens; primary is not called, fallback runs + +================================================================================ +REQUIREMENTS +================================================================================ + +Make sure you have installed: + - cqrs (this package) + - di (dependency injection) + +Optional for circuit breaker: + pip install aiobreaker + or: pip install python-cqrs[aiobreaker] + +================================================================================ +""" + +import asyncio +import logging +import di + +import cqrs +from cqrs.adapters.circuit_breaker import AioBreakerAdapter +from cqrs.requests import bootstrap + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +HANDLER_USED: list[str] = [] + + +# ----------------------------------------------------------------------------- +# Command and response +# ----------------------------------------------------------------------------- + + +class GetUserProfileCommand(cqrs.Request): + user_id: str + + +class UserProfileResult(cqrs.Response): + user_id: str + name: str + source: str # "primary" or "fallback" + + +# ----------------------------------------------------------------------------- +# Primary handler (simulates failure – e.g. database unavailable) +# ----------------------------------------------------------------------------- + + +class PrimaryGetUserProfileHandler( + cqrs.RequestHandler[GetUserProfileCommand, UserProfileResult], +): + @property + def events(self) -> list[cqrs.Event]: + return [] + + async def handle( + self, + request: GetUserProfileCommand, + ) -> UserProfileResult: + logger.info("Primary handler: fetching profile for user %s", request.user_id) + HANDLER_USED.append("primary") + raise ConnectionError("Database unavailable") + + +# ----------------------------------------------------------------------------- +# Fallback handler (e.g. return cached or default data) +# ----------------------------------------------------------------------------- + + +class FallbackGetUserProfileHandler( + cqrs.RequestHandler[GetUserProfileCommand, UserProfileResult], +): + @property + def events(self) -> list[cqrs.Event]: + return [] + + async def handle( + self, + request: GetUserProfileCommand, + ) -> UserProfileResult: + logger.info( + "Fallback handler: returning cached/default profile for user %s", + request.user_id, + ) + HANDLER_USED.append("fallback") + return UserProfileResult( + user_id=request.user_id, + name="Unknown User", + source="fallback", + ) + + +# ----------------------------------------------------------------------------- +# Mappers and bootstrap +# ----------------------------------------------------------------------------- + + +def commands_mapper(mapper: cqrs.RequestMap) -> None: + # Without circuit breaker: fallback on any exception (or restrict with failure_exceptions) + mapper.bind( + GetUserProfileCommand, + cqrs.RequestHandlerFallback( + primary=PrimaryGetUserProfileHandler, + fallback=FallbackGetUserProfileHandler, + failure_exceptions=(ConnectionError, TimeoutError), + ), + ) + + +def commands_mapper_with_circuit_breaker(mapper: cqrs.RequestMap) -> None: + try: + request_cb = AioBreakerAdapter(fail_max=2, timeout_duration=60) + except ImportError: + commands_mapper(mapper) + return + mapper.bind( + GetUserProfileCommand, + cqrs.RequestHandlerFallback( + primary=PrimaryGetUserProfileHandler, + fallback=FallbackGetUserProfileHandler, + failure_exceptions=(ConnectionError, TimeoutError), + circuit_breaker=request_cb, + ), + ) + + +async def main() -> None: + HANDLER_USED.clear() + + use_circuit_breaker = False + try: + import aiobreaker # noqa: F401 + + use_circuit_breaker = True + except ImportError: + pass + + commands_mapper_fn = commands_mapper_with_circuit_breaker if use_circuit_breaker else commands_mapper + + mediator = bootstrap.bootstrap( + di_container=di.Container(), + commands_mapper=commands_mapper_fn, + ) + + print("\n" + "=" * 60) + print("REQUEST HANDLER FALLBACK EXAMPLE") + print("=" * 60) + print("\nSending GetUserProfileCommand (primary will fail)...\n") + + result: UserProfileResult = await mediator.send( + GetUserProfileCommand(user_id="user_42"), + ) + + print("\nResult:") + print(f" Handlers that ran (in order): {HANDLER_USED}") + print(f" Response: user_id={result.user_id}, name={result.name}, source={result.source}") + assert result.source == "fallback" + assert "primary" in HANDLER_USED and "fallback" in HANDLER_USED + assert HANDLER_USED[-1] == "fallback" + print(" ✓ Primary ran and failed; fallback ran and returned response.") + print("\n" + "=" * 60 + "\n") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/streaming_handler_fallback.py b/examples/streaming_handler_fallback.py new file mode 100644 index 0000000..0820319 --- /dev/null +++ b/examples/streaming_handler_fallback.py @@ -0,0 +1,193 @@ +""" +Example: Streaming Request Handler Fallback + +This example demonstrates RequestHandlerFallback with StreamingRequestHandler. +When the primary streaming handler fails (e.g. raises after yielding some items), +the fallback streaming handler is used and its stream is consumed. + +Use case: Stream results from a primary source (e.g. live API); if the stream +fails mid-way, switch to a fallback stream (e.g. cached or degraded results). + +================================================================================ +HOW TO RUN THIS EXAMPLE +================================================================================ + +Run the example: + python examples/streaming_handler_fallback.py + +The example will: +- Start streaming from the primary handler (yields a few items then raises) +- After the exception, the fallback streaming handler runs and yields items +- Collect and print all results from both handlers + +================================================================================ +WHAT THIS EXAMPLE DEMONSTRATES +================================================================================ + +1. RequestHandlerFallback with streaming handlers: + - primary and fallback are both StreamingRequestHandler (async generators). + - Dispatcher runs primary.handle(request); if it raises, runs fallback.handle(request). + +2. Flow: + - mediator.stream(request) yields results from the primary handler. + - When the primary raises, the dispatcher catches and continues with the + fallback handler's stream. + +3. Optional failure_exceptions and circuit_breaker: + - Same as for non-streaming RequestHandlerFallback. + +================================================================================ +REQUIREMENTS +================================================================================ + +Make sure you have installed: + - cqrs (this package) + - di (dependency injection) + +================================================================================ +""" + +from collections.abc import AsyncIterator + +import asyncio +import logging +import typing + +import di + +import cqrs +from cqrs.requests import bootstrap + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +STREAM_SOURCE: list[str] = [] # "primary" or "fallback" per yield + + +# ----------------------------------------------------------------------------- +# Request and response (streaming) +# ----------------------------------------------------------------------------- + + +class StreamItemsCommand(cqrs.Request): + item_ids: list[str] + + +class StreamItemResult(cqrs.Response): + item_id: str + status: str + source: str # "primary" or "fallback" + + +# ----------------------------------------------------------------------------- +# Primary streaming handler (yields twice then raises) +# ----------------------------------------------------------------------------- + + +class PrimaryStreamItemsHandler( + cqrs.StreamingRequestHandler[StreamItemsCommand, StreamItemResult], +): + def __init__(self) -> None: + self._events: list[cqrs.Event] = [] + + @property + def events(self) -> list[cqrs.Event]: + return self._events.copy() + + def clear_events(self) -> None: + self._events.clear() + + async def handle( + self, + request: StreamItemsCommand, + ) -> AsyncIterator[StreamItemResult]: + for i, item_id in enumerate(request.item_ids): + if i >= 2: + logger.info("Primary streaming handler raising after 2 items") + raise ConnectionError("Stream connection lost") + STREAM_SOURCE.append("primary") + yield StreamItemResult( + item_id=item_id, + status="processed", + source="primary", + ) + + +# ----------------------------------------------------------------------------- +# Fallback streaming handler (yields all items) +# ----------------------------------------------------------------------------- + + +class FallbackStreamItemsHandler( + cqrs.StreamingRequestHandler[StreamItemsCommand, StreamItemResult], +): + def __init__(self) -> None: + self._events: list[cqrs.Event] = [] + + @property + def events(self) -> list[cqrs.Event]: + return self._events.copy() + + def clear_events(self) -> None: + self._events.clear() + + async def handle( + self, + request: StreamItemsCommand, + ) -> AsyncIterator[StreamItemResult]: + for item_id in request.item_ids: + STREAM_SOURCE.append("fallback") + yield StreamItemResult( + item_id=item_id, + status="from_fallback", + source="fallback", + ) + + +# ----------------------------------------------------------------------------- +# Mapper and bootstrap +# ----------------------------------------------------------------------------- + + +def commands_mapper(mapper: cqrs.RequestMap) -> None: + mapper.bind( + StreamItemsCommand, + cqrs.RequestHandlerFallback( + primary=PrimaryStreamItemsHandler, + fallback=FallbackStreamItemsHandler, + failure_exceptions=(ConnectionError, TimeoutError), + ), + ) + + +async def main() -> None: + STREAM_SOURCE.clear() + + mediator = bootstrap.bootstrap_streaming( + di_container=di.Container(), + commands_mapper=commands_mapper, + ) + + print("\n" + "=" * 60) + print("STREAMING HANDLER FALLBACK EXAMPLE") + print("=" * 60) + print("\nStreaming items (primary will fail after 2 items, then fallback runs)...\n") + + request = StreamItemsCommand(item_ids=["id1", "id2", "id3", "id4"]) + results: list[StreamItemResult] = [] + async for response in mediator.stream(request): + if response is not None: + r = typing.cast(StreamItemResult, response) + results.append(r) + print(f" Yield: item_id={r.item_id}, status={r.status}, source={r.source}") + + print("\n Handlers that yielded (in order): " + str(STREAM_SOURCE)) + assert "primary" in STREAM_SOURCE and "fallback" in STREAM_SOURCE + assert results[0].source == "primary" and results[1].source == "primary" + assert any(r.source == "fallback" for r in results) + print("\n ✓ Primary stream failed; fallback stream completed.") + print("=" * 60 + "\n") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/pyproject.toml b/pyproject.toml index fc0491a..5f74f6f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ maintainers = [{name = "Vadim Kozyrevskiy", email = "vadikko2@mail.ru"}] name = "python-cqrs" readme = "README.md" requires-python = ">=3.10" -version = "4.9.0" +version = "4.10.0" [project.optional-dependencies] aiobreaker = ["aiobreaker>=0.3.0"] diff --git a/src/cqrs/__init__.py b/src/cqrs/__init__.py index 319c7a1..396f670 100644 --- a/src/cqrs/__init__.py +++ b/src/cqrs/__init__.py @@ -1,7 +1,9 @@ from cqrs.compressors import Compressor, ZlibCompressor from cqrs.container.di import DIContainer from cqrs.container.protocol import Container +from cqrs.circuit_breaker import ICircuitBreaker from cqrs.events import EventMap +from cqrs.events.fallback import EventHandlerFallback from cqrs.events.event import ( DCEvent, DCDomainEvent, @@ -35,6 +37,7 @@ SqlAlchemyOutboxedEventRepository, ) from cqrs.producer import EventProducer +from cqrs.requests.fallback import RequestHandlerFallback from cqrs.requests.map import RequestMap, SagaMap from cqrs.requests.mermaid import CoRMermaid from cqrs.requests.request import DCRequest, IRequest, PydanticRequest, Request @@ -53,6 +56,9 @@ ) __all__ = ( + "ICircuitBreaker", + "EventHandlerFallback", + "RequestHandlerFallback", "RequestMediator", "SagaMediator", "StreamingRequestMediator", diff --git a/src/cqrs/adapters/circuit_breaker.py b/src/cqrs/adapters/circuit_breaker.py index 843edb7..46c5218 100644 --- a/src/cqrs/adapters/circuit_breaker.py +++ b/src/cqrs/adapters/circuit_breaker.py @@ -5,8 +5,8 @@ from datetime import timedelta from typing import TYPE_CHECKING, Callable +from cqrs.circuit_breaker import ICircuitBreaker from cqrs.saga.circuit_breaker import ISagaStepCircuitBreaker -from cqrs.saga.step import SagaStepHandler logger = logging.getLogger("cqrs.adapters.circuit_breaker") @@ -105,12 +105,23 @@ def default_memory_storage_factory(name: str) -> _CircuitBreakerStorage: return CircuitMemoryStorage(state=aiobreaker.CircuitBreakerState.CLOSED) -class AioBreakerAdapter(ISagaStepCircuitBreaker): +def _identifier_to_name(identifier: type | str) -> str: + """Build circuit breaker namespace from type or string.""" + if isinstance(identifier, str): + return identifier + module = getattr(identifier, "__module__", "") + name = getattr(identifier, "__name__", str(identifier)) + return f"{module}.{name}" if module else name + + +class AioBreakerAdapter(ICircuitBreaker, ISagaStepCircuitBreaker): """ - Adapter for aiobreaker circuit breaker. + Unified adapter for aiobreaker circuit breaker. - Manages circuit breaker instances per step type. Each step type gets its own - isolated circuit breaker with a namespace based on the step class name. + Implements ICircuitBreaker (and ISagaStepCircuitBreaker for backward + compatibility). Each fallback type (Saga, Request, Event) typically + uses its own adapter instance; identifier can be a step/handler type + or a string for namespace. Attributes: fail_max: Maximum number of failures before opening the circuit. @@ -121,11 +132,10 @@ class AioBreakerAdapter(ISagaStepCircuitBreaker): Defaults to in-memory storage. Example:: - adapter = AioBreakerAdapter( - fail_max=3, - timeout_duration=60, - exclude=[InventoryOutOfStockError] - ) + # One adapter per fallback domain (each has its own breaker namespaces) + saga_cb = AioBreakerAdapter(fail_max=3, timeout_duration=60) + request_cb = AioBreakerAdapter(fail_max=5, timeout_duration=30) + event_cb = AioBreakerAdapter(fail_max=5, timeout_duration=30) """ def __init__( @@ -145,45 +155,23 @@ def __init__( self._exclude = exclude or [] self._storage_factory = storage_factory or default_memory_storage_factory - # Dictionary to store circuit breakers per step type + # Dictionary to store circuit breakers per identifier (type or str) self._breakers: dict[str, typing.Any] = {} # type: ignore[type-arg] - def _get_step_name(self, step_type: type[SagaStepHandler]) -> str: - """ - Get name for a step type. - - Uses the fully qualified name of the step class as namespace. - - Args: - step_type: The step handler class type. - - Returns: - Name string for the circuit breaker (module.class_name). - """ - module = getattr(step_type, "__module__", "") - name = step_type.__name__ - return f"{module}.{name}" if module else name - - def _get_breaker(self, step_type: type[SagaStepHandler]) -> typing.Any: # type: ignore[return-type] + def _get_breaker(self, identifier: type | str) -> typing.Any: # type: ignore[return-type] """ - Get or create circuit breaker for a step type. - - Each step type gets its own isolated circuit breaker with a unique name - based on the step's fully qualified name. + Get or create circuit breaker for an identifier (type or string). Args: - step_type: The step handler class type. + identifier: Step/handler type or string for namespace. Returns: - CircuitBreaker instance for this step type. + CircuitBreaker instance for this identifier. """ - step_name = self._get_step_name(step_type) - - if step_name not in self._breakers: - breaker = self._create_breaker(step_name) - self._breakers[step_name] = breaker - - return self._breakers[step_name] + name = _identifier_to_name(identifier) + if name not in self._breakers: + self._breakers[name] = self._create_breaker(name) + return self._breakers[name] def _create_breaker(self, name: str) -> typing.Any: # type: ignore[return-type] """ @@ -220,7 +208,7 @@ def _create_breaker(self, name: str) -> typing.Any: # type: ignore[return-type] async def call( self, - step_type: type[SagaStepHandler], + identifier: type | str, func: typing.Callable[..., typing.Awaitable[typing.Any]], *args: typing.Any, **kwargs: typing.Any, @@ -229,7 +217,7 @@ async def call( Execute the function with circuit breaker protection. Args: - step_type: The step handler class type. Used to determine breaker name. + identifier: Step/handler type or string for breaker namespace. func: The async function to execute. *args: Positional arguments to pass to func. **kwargs: Keyword arguments to pass to func. @@ -241,7 +229,7 @@ async def call( CircuitBreakerError: If the circuit breaker is open. Exception: Any exception raised by func (if circuit is closed). """ - breaker = self._get_breaker(step_type) + breaker = self._get_breaker(identifier) return await breaker.call_async(func, *args, **kwargs) def is_circuit_breaker_error(self, exc: Exception) -> bool: diff --git a/src/cqrs/circuit_breaker.py b/src/cqrs/circuit_breaker.py new file mode 100644 index 0000000..172eab6 --- /dev/null +++ b/src/cqrs/circuit_breaker.py @@ -0,0 +1,77 @@ +"""Unified circuit breaker protocol for Saga, Request and Event fallbacks.""" + +import typing + + +class ICircuitBreaker(typing.Protocol): + """ + Unified interface for circuit breaker implementations. + + Used by Saga step fallbacks, Request handler fallbacks and Event handler + fallbacks. The same adapter class works for all with identifier-based + namespacing. + + Note: + Implementors should use a dedicated adapter instance per domain (events, + requests, saga) to keep circuit breaker state isolated. + """ + + async def call( + self, + identifier: type | str, + func: typing.Callable[..., typing.Awaitable[typing.Any]], + *args: typing.Any, + **kwargs: typing.Any, + ) -> typing.Any: + """ + Execute the function with circuit breaker protection. + + Args: + identifier: Handler/step type or string used as circuit breaker namespace. + func: The async function to execute. + *args: Positional arguments to pass to func. + **kwargs: Keyword arguments to pass to func. + + Returns: + The result of func execution. + + Raises: + CircuitBreakerError: If the circuit breaker is open. + Exception: Any exception raised by func (if circuit is closed). + """ + ... + + def is_circuit_breaker_error(self, exc: Exception) -> bool: + """ + Check if the given exception is a circuit breaker error. + + Args: + exc: The exception to check. + + Returns: + True if the exception is a circuit breaker error, False otherwise. + """ + ... + + +def should_use_fallback( + primary_error: Exception, + circuit_breaker: ICircuitBreaker | None, + failure_exceptions: tuple[type[Exception], ...], +) -> bool: + """ + Determine whether to invoke the fallback after primary handler failure. + + Returns True if the circuit breaker reports a breaker error, or the + exception matches failure_exceptions, or failure_exceptions is empty + (any exception triggers fallback). + """ + if circuit_breaker is not None and circuit_breaker.is_circuit_breaker_error( + primary_error, + ): + return True + if failure_exceptions and isinstance(primary_error, failure_exceptions): + return True + if not failure_exceptions: + return True + return False diff --git a/src/cqrs/dispatcher/event.py b/src/cqrs/dispatcher/event.py index e38fe34..a6b0169 100644 --- a/src/cqrs/dispatcher/event.py +++ b/src/cqrs/dispatcher/event.py @@ -4,6 +4,7 @@ from cqrs.container.protocol import Container from cqrs.events.event import IEvent from cqrs.events.event_handler import EventHandler +from cqrs.events.fallback import EventHandlerFallback from cqrs.events.map import EventMap from cqrs.middlewares.base import MiddlewareChain @@ -33,6 +34,53 @@ async def _handle_event( for follow_up in handler.events: await self.dispatch(follow_up) + async def _handle_event_fallback( + self, + event: IEvent, + fallback_config: EventHandlerFallback, + ) -> None: + """Run primary handler with fallback on failure; dispatch follow-up events from the handler that ran.""" + primary: _EventHandler = await self._container.resolve(fallback_config.primary) + try: + if fallback_config.circuit_breaker is not None: + await fallback_config.circuit_breaker.call( + fallback_config.primary, + primary.handle, + event, + ) + else: + await primary.handle(event) + for follow_up in primary.events: + await self.dispatch(follow_up) + except Exception as primary_error: + should_fallback = False + if fallback_config.circuit_breaker is not None and fallback_config.circuit_breaker.is_circuit_breaker_error( + primary_error, + ): + should_fallback = True + elif fallback_config.failure_exceptions and isinstance( + primary_error, + fallback_config.failure_exceptions, + ): + should_fallback = True + elif not fallback_config.failure_exceptions: + should_fallback = True + if should_fallback: + logger.warning( + "Primary event handler %s failed: %s. Switching to fallback %s.", + fallback_config.primary.__name__, + primary_error, + fallback_config.fallback.__name__, + ) + fallback_handler: _EventHandler = await self._container.resolve( + fallback_config.fallback, + ) + await fallback_handler.handle(event) + for follow_up in fallback_handler.events: + await self.dispatch(follow_up) + else: + raise primary_error + async def dispatch(self, event: IEvent) -> None: handler_types = self._event_map.get(type(event), []) if not handler_types: @@ -42,4 +90,7 @@ async def dispatch(self, event: IEvent) -> None: ) return for h_type in handler_types: - await self._handle_event(event, h_type) + if isinstance(h_type, EventHandlerFallback): + await self._handle_event_fallback(event, h_type) + else: + await self._handle_event(event, h_type) diff --git a/src/cqrs/dispatcher/request.py b/src/cqrs/dispatcher/request.py index 3f94f8f..8e9b5cd 100644 --- a/src/cqrs/dispatcher/request.py +++ b/src/cqrs/dispatcher/request.py @@ -3,6 +3,7 @@ import typing from collections import abc +from cqrs.circuit_breaker import should_use_fallback from cqrs.container.protocol import Container from cqrs.dispatcher.exceptions import ( RequestHandlerDoesNotExist, @@ -15,6 +16,7 @@ build_chain, CORRequestHandlerT as CORRequestHandlerType, ) +from cqrs.requests.fallback import RequestHandlerFallback from cqrs.requests.map import RequestMap, HandlerType from cqrs.requests.request import IRequest from cqrs.requests.request_handler import RequestHandler @@ -44,7 +46,12 @@ async def _resolve_handler( For single handlers, resolves them using the DI container. For lists of handlers, validates they are COR handlers and builds a chain. + RequestHandlerFallback is not resolved here; use dispatch fallback path. """ + if isinstance(handler_type, RequestHandlerFallback): + raise RequestHandlerTypeError( + "RequestHandlerFallback must be handled in dispatch, not _resolve_handler", + ) if isinstance(handler_type, abc.Iterable): if not all( issubclass( @@ -65,12 +72,66 @@ async def _resolve_handler( return typing.cast(_RequestHandler, await self._container.resolve(handler_type)) + async def _dispatch_fallback( + self, + request: IRequest, + fallback_config: RequestHandlerFallback, + ) -> RequestDispatchResult: + """Dispatch using primary handler with fallback on failure.""" + primary = await self._container.resolve(fallback_config.primary) + try: + wrapped_primary = self._middleware_chain.wrap(primary.handle) + if fallback_config.circuit_breaker is not None: + response = await fallback_config.circuit_breaker.call( + fallback_config.primary, + wrapped_primary, + request, + ) + else: + response = await wrapped_primary(request) + return RequestDispatchResult(response=response, events=primary.events) + except Exception as primary_error: + should_fallback = should_use_fallback( + primary_error, + fallback_config.circuit_breaker, + fallback_config.failure_exceptions, + ) + if should_fallback: + if ( + fallback_config.circuit_breaker is not None + and fallback_config.circuit_breaker.is_circuit_breaker_error( + primary_error, + ) + ): + logger.warning( + "Circuit breaker open for request handler %s, switching to fallback %s", + fallback_config.primary.__name__, + fallback_config.fallback.__name__, + ) + else: + logger.warning( + "Primary handler %s failed: %s. Switching to fallback %s.", + fallback_config.primary.__name__, + primary_error, + fallback_config.fallback.__name__, + ) + fallback_handler = await self._container.resolve(fallback_config.fallback) + wrapped_fallback = self._middleware_chain.wrap(fallback_handler.handle) + response = await wrapped_fallback(request) + return RequestDispatchResult( + response=response, + events=fallback_handler.events, + ) + raise primary_error + async def dispatch(self, request: IRequest) -> RequestDispatchResult: handler_type = self._request_map.get(type(request), None) if not handler_type: raise RequestHandlerDoesNotExist( f"RequestHandler not found matching Request type {type(request)}", ) + if isinstance(handler_type, RequestHandlerFallback): + return await self._dispatch_fallback(request, handler_type) handler: _RequestHandler = await self._resolve_handler(handler_type) wrapped_handle = self._middleware_chain.wrap(handler.handle) response = await wrapped_handle(request) diff --git a/src/cqrs/dispatcher/streaming.py b/src/cqrs/dispatcher/streaming.py index 4069d94..2caabec 100644 --- a/src/cqrs/dispatcher/streaming.py +++ b/src/cqrs/dispatcher/streaming.py @@ -1,14 +1,19 @@ import inspect +import logging import typing +from cqrs.circuit_breaker import should_use_fallback from cqrs.container.protocol import Container from cqrs.dispatcher.exceptions import RequestHandlerDoesNotExist from cqrs.dispatcher.models import RequestDispatchResult from cqrs.middlewares.base import MiddlewareChain +from cqrs.requests.fallback import RequestHandlerFallback from cqrs.requests.map import RequestMap from cqrs.requests.request import IRequest from cqrs.requests.request_handler import StreamingRequestHandler +logger = logging.getLogger("cqrs") + class StreamingRequestDispatcher: """ @@ -16,6 +21,13 @@ class StreamingRequestDispatcher: This dispatcher handles requests using handlers that yield responses as generators. After each yield, events are collected and can be emitted. + + When a primary streaming handler (used via RequestHandlerFallback) fails + mid-stream, already-yielded RequestDispatchResult items are not rolled + back and the fallback handler streams from its start. Results may + therefore be duplicated; callers can de-duplicate if needed. The + fallback path is driven by _stream_from_handler and handler_type + (primary vs fallback). """ def __init__( @@ -41,31 +53,86 @@ def dispatch( """ return self._dispatch_impl(request) + @staticmethod + async def _stream_from_handler( + request: IRequest, + handler: StreamingRequestHandler, + ) -> typing.AsyncIterator[RequestDispatchResult]: + async for response in handler.handle(request): + events = list(handler.events) + handler.clear_events() + yield RequestDispatchResult(response=response, events=events) + async def _dispatch_impl( self, request: IRequest, ) -> typing.AsyncIterator[RequestDispatchResult]: + """ + Dispatch to the mapped handler. For RequestHandlerFallback, on primary + failure the fallback streams from scratch (see class docstring). + """ handler_type = self._request_map.get(type(request), None) if handler_type is None: - raise RequestHandlerDoesNotExist( - f"StreamingRequestHandler not found matching Request type {type(request)}", - ) + raise RequestHandlerDoesNotExist(f"StreamingRequestHandler not found matching Request type {type(request)}") - # Streaming dispatcher only works with streaming handlers, not lists if isinstance(handler_type, list): raise TypeError( "StreamingRequestDispatcher does not support COR handler chains. " "Use RequestDispatcher for chain of responsibility pattern.", ) - # Type narrowing: handler_type is now a single handler type - handler_type_typed = typing.cast( - typing.Type[StreamingRequestHandler], - handler_type, - ) - handler: StreamingRequestHandler = await self._container.resolve( - handler_type_typed, - ) + if isinstance(handler_type, RequestHandlerFallback): + primary = await self._container.resolve(handler_type.primary) + fallback_handler = await self._container.resolve(handler_type.fallback) + if not inspect.isasyncgenfunction(primary.handle) or not inspect.isasyncgenfunction( + fallback_handler.handle, + ): + raise TypeError( + "RequestHandlerFallback with StreamingRequestDispatcher requires " + "both primary and fallback to be async generator handlers", + ) + try: + async for result in self._stream_from_handler( + request, + typing.cast(StreamingRequestHandler, primary), + ): + yield result + except Exception as primary_error: + should_fallback = should_use_fallback( + primary_error, + handler_type.circuit_breaker, + handler_type.failure_exceptions, + ) + if should_fallback: + if ( + handler_type.circuit_breaker is not None + and handler_type.circuit_breaker.is_circuit_breaker_error( + primary_error, + ) + ): + logger.warning( + "Circuit breaker open for streaming handler %s, switching to fallback %s", + handler_type.primary.__name__, + handler_type.fallback.__name__, + ) + else: + logger.warning( + "Primary streaming handler %s failed: %s. Switching to fallback %s.", + handler_type.primary.__name__, + primary_error, + handler_type.fallback.__name__, + ) + async for result in self._stream_from_handler( + request, + typing.cast(StreamingRequestHandler, fallback_handler), + ): + yield result + else: + raise primary_error + return + + handler_type_typed = typing.cast(typing.Type[StreamingRequestHandler], handler_type) + handler: StreamingRequestHandler = await self._container.resolve(handler_type_typed) if not inspect.isasyncgenfunction(handler.handle): handler_name = ( @@ -75,11 +142,5 @@ async def _dispatch_impl( f"Handler {handler_name}.handle must be an async generator function", ) - async_gen = handler.handle(request) - async for response in async_gen: - events = list(handler.events) - handler.clear_events() - yield RequestDispatchResult( - response=response, - events=events, - ) + async for result in self._stream_from_handler(request, handler): + yield result diff --git a/src/cqrs/events/__init__.py b/src/cqrs/events/__init__.py index f7c9782..b08c93d 100644 --- a/src/cqrs/events/__init__.py +++ b/src/cqrs/events/__init__.py @@ -8,6 +8,7 @@ - :class:`EventEmitter` — sends domain events to handlers and notification events to a message broker. - :class:`EventMap` — registry of event type -> handler types; use :meth:`EventMap.bind`. +- :class:`EventHandlerFallback` — fallback wrapper for event handlers with optional circuit breaker. """ from cqrs.events.event import ( @@ -26,6 +27,7 @@ ) from cqrs.events.event_emitter import EventEmitter from cqrs.events.event_handler import EventHandler +from cqrs.events.fallback import EventHandlerFallback from cqrs.events.map import EventMap __all__ = ( @@ -43,5 +45,6 @@ "PydanticNotificationEvent", "EventEmitter", "EventHandler", + "EventHandlerFallback", "EventMap", ) diff --git a/src/cqrs/events/event_emitter.py b/src/cqrs/events/event_emitter.py index 49d7088..e21af7a 100644 --- a/src/cqrs/events/event_emitter.py +++ b/src/cqrs/events/event_emitter.py @@ -3,8 +3,10 @@ import typing from cqrs import container as di_container, message_brokers +from cqrs.circuit_breaker import should_use_fallback from cqrs.events.event import IDomainEvent, IEvent, INotificationEvent from cqrs.events import event_handler, map +from cqrs.events.fallback import EventHandlerFallback logger = logging.getLogger("cqrs") @@ -110,6 +112,53 @@ async def _send_to_broker( await self._message_broker.send_message(message) + async def _handle_with_fallback( + self, + event: IDomainEvent, + fallback_config: EventHandlerFallback, + ) -> typing.Sequence[IEvent]: + """Run primary handler with fallback on failure; return events from the handler that ran.""" + primary: _H = await self._container.resolve(fallback_config.primary) + try: + if fallback_config.circuit_breaker is not None: + await fallback_config.circuit_breaker.call( + fallback_config.primary, + primary.handle, + event, + ) + else: + await primary.handle(event) + return list(primary.events) + except Exception as primary_error: + should_fallback = should_use_fallback( + primary_error, + fallback_config.circuit_breaker, + fallback_config.failure_exceptions, + ) + if should_fallback: + if ( + fallback_config.circuit_breaker is not None + and fallback_config.circuit_breaker.is_circuit_breaker_error( + primary_error, + ) + ): + logger.warning( + "Circuit breaker open for event handler %s, switching to fallback %s", + fallback_config.primary.__name__, + fallback_config.fallback.__name__, + ) + else: + logger.warning( + "Primary event handler %s failed: %s. Switching to fallback %s.", + fallback_config.primary.__name__, + primary_error, + fallback_config.fallback.__name__, + ) + fallback_handler: _H = await self._container.resolve(fallback_config.fallback) + await fallback_handler.handle(event) + return list(fallback_handler.events) + raise primary_error + @emit.register(IDomainEvent) async def _(self, event: IDomainEvent) -> typing.Sequence[IEvent]: """Emit domain event: run all registered handlers and return their follow-up events.""" @@ -121,17 +170,20 @@ async def _(self, event: IDomainEvent) -> typing.Sequence[IEvent]: ) return () follow_ups: list[IEvent] = [] - for handler_type in handlers_types: - handler: _H = await self._container.resolve( - handler_type, - ) + for handler_item in handlers_types: + if isinstance(handler_item, EventHandlerFallback): + follow_ups.extend( + await self._handle_with_fallback(event, handler_item), + ) + continue + handler_type = handler_item + handler: _H = await self._container.resolve(handler_type) logger.debug( "Handling Event(%s) via event handler(%s)", type(event).__name__, handler_type.__name__, ) await handler.handle(event) - # Snapshot follow-ups so shared handlers don't expose stale state follow_ups.extend(list(handler.events)) return follow_ups diff --git a/src/cqrs/events/fallback.py b/src/cqrs/events/fallback.py new file mode 100644 index 0000000..6c5ba19 --- /dev/null +++ b/src/cqrs/events/fallback.py @@ -0,0 +1,92 @@ +"""Fallback wrapper for event handlers with optional circuit breaker.""" + +import dataclasses +import typing + +from cqrs.circuit_breaker import ICircuitBreaker +from cqrs.events import event_handler +from cqrs.generic_utils import get_generic_args_for_origin + +EventHandlerT = typing.Type[event_handler.EventHandler] + +_EVENT_HANDLER_ORIGINS: tuple[type, ...] = (event_handler.EventHandler,) + + +def _event_type_name(t: type) -> str: + return getattr(t, "__name__", str(t)) + + +@dataclasses.dataclass(frozen=True) +class EventHandlerFallback: + """ + Fallback wrapper for event handlers. + + When the primary handler fails (or circuit breaker is open), the fallback + handler is invoked. Use a separate circuit breaker instance per domain + (e.g. one for events) that uses the same adapter class. + + Attributes: + primary: The primary event handler class. + fallback: The fallback handler class to execute if primary fails. + failure_exceptions: Exception types that trigger fallback; if empty, any exception. + circuit_breaker: Optional circuit breaker instance (e.g. AioBreakerAdapter). + + Example:: + event_cb = AioBreakerAdapter(fail_max=5, timeout_duration=60) + event_map.bind( + OrderCreatedEvent, + EventHandlerFallback( + SendEmailHandler, + SendEmailFallbackHandler, + circuit_breaker=event_cb, + ), + ) + """ + + primary: EventHandlerT + fallback: EventHandlerT + failure_exceptions: tuple[type[Exception], ...] = () + circuit_breaker: ICircuitBreaker | None = None + + def __post_init__(self) -> None: + if not isinstance(self.primary, type) or not isinstance(self.fallback, type): + raise TypeError( + "EventHandlerFallback primary and fallback must be handler classes", + ) + if not issubclass(self.primary, event_handler.EventHandler): + raise TypeError( + f"EventHandlerFallback primary ({self.primary.__name__}) " "must be a subclass of EventHandler", + ) + if not issubclass(self.fallback, event_handler.EventHandler): + raise TypeError( + f"EventHandlerFallback fallback ({self.fallback.__name__}) " "must be a subclass of EventHandler", + ) + # Validate that primary and fallback handle the same event type + primary_args = get_generic_args_for_origin( + self.primary, + _EVENT_HANDLER_ORIGINS, + min_args=1, + ) + fallback_args = get_generic_args_for_origin( + self.fallback, + _EVENT_HANDLER_ORIGINS, + min_args=1, + ) + if primary_args is not None and fallback_args is not None: + # Reject TypeVar (unparameterized) so we only allow concrete types + if any(isinstance(a, typing.TypeVar) for a in primary_args + fallback_args): + raise TypeError( + "EventHandlerFallback primary and fallback must be parameterized with a concrete event type " + "(e.g. EventHandler[MyEvent])", + ) + if primary_args[0] != fallback_args[0]: + raise TypeError( + "EventHandlerFallback primary and fallback must handle the same event type: " + f"primary {self.primary.__name__} handles {_event_type_name(primary_args[0])}, " + f"fallback {self.fallback.__name__} handles {_event_type_name(fallback_args[0])}", + ) + elif primary_args is None or fallback_args is None: + raise TypeError( + "EventHandlerFallback primary and fallback must be parameterized with a concrete event type " + "(e.g. EventHandler[MyEvent])", + ) diff --git a/src/cqrs/events/map.py b/src/cqrs/events/map.py index 687d91b..dbd5b22 100644 --- a/src/cqrs/events/map.py +++ b/src/cqrs/events/map.py @@ -2,17 +2,20 @@ from cqrs.events.event import IEvent from cqrs.events import event_handler +from cqrs.events.fallback import EventHandlerFallback _KT = typing.TypeVar("_KT", bound=typing.Type[IEvent]) -_VT: typing.TypeAlias = typing.List[typing.Type[event_handler.EventHandler]] +_HandlerItem = typing.Type[event_handler.EventHandler] | EventHandlerFallback +_VT: typing.TypeAlias = typing.List[_HandlerItem] class EventMap(typing.Dict[_KT, _VT]): """ - Registry mapping event types to one or more handler types. + Registry mapping event types to one or more handler types or fallbacks. Use :meth:`bind` to register handlers for an event type. Multiple handlers can be bound to the same event; all will be invoked when the event is emitted. + Handlers can be plain types or :class:`~cqrs.events.fallback.EventHandlerFallback`. Keys cannot be overwritten or deleted. Example:: @@ -20,30 +23,26 @@ class EventMap(typing.Dict[_KT, _VT]): event_map = EventMap() event_map.bind(OrderCreatedEvent, OrderCreatedEventHandler) event_map.bind(OrderCreatedEvent, SendEmailHandler) # second handler for same event - # event_map[OrderCreatedEvent] -> [OrderCreatedEventHandler, SendEmailHandler] + event_map.bind(OrderCreatedEvent, EventHandlerFallback(PrimaryHandler, FallbackHandler, circuit_breaker=cb)) """ def bind( self, event_type: _KT, - handler_type: typing.Type[event_handler.EventHandler], + handler_type: _HandlerItem, ) -> None: """ - Register a handler type for an event type. + Register a handler type or EventHandlerFallback for an event type. If the event type is new, creates a list with this handler. If the event type already exists, appends the handler (duplicates are rejected). Args: event_type: Event class (e.g. :class:`OrderCreatedEvent`). - handler_type: Handler class implementing :class:`~cqrs.events.event_handler.EventHandler`. + handler_type: Handler class or :class:`~cqrs.events.fallback.EventHandlerFallback`. Raises: - KeyError: If the same handler type is already bound to this event type. - - Example:: - - event_map.bind(OrderCreatedEvent, OrderCreatedEventHandler) + KeyError: If the same handler type or fallback is already bound to this event type. """ if event_type not in self: self[event_type] = [handler_type] diff --git a/src/cqrs/generic_utils.py b/src/cqrs/generic_utils.py new file mode 100644 index 0000000..438359b --- /dev/null +++ b/src/cqrs/generic_utils.py @@ -0,0 +1,43 @@ +"""Shared utilities for extracting generic type parameters from handler classes.""" + +import typing + + +def get_generic_args_for_origin( + klass: type, + origin_classes: tuple[type, ...], + min_args: int = 1, +) -> tuple[type, ...] | None: + """ + Extract generic type arguments from a class that inherits from a Generic base. + + Walks __orig_bases__ and __bases__ to find the first base whose origin is + one of the given origin_classes, then returns typing.get_args(base). + + Args: + klass: The handler class (e.g. a subclass of RequestHandler[Req, Res]). + origin_classes: Tuple of possible origin classes (e.g. (RequestHandler, StreamingRequestHandler)). + min_args: Minimum number of type arguments required to consider the result valid. + + Returns: + Tuple of type arguments (e.g. (ReqT, ResT) or (ET,)), or None if not found + or if the base has fewer than min_args concrete arguments. + """ + # Prefer __orig_bases__ (Python 3.12+ / generic subclass) + orig_bases = getattr(klass, "__orig_bases__", ()) + for base in orig_bases: + origin = typing.get_origin(base) + if origin in origin_classes: + args = typing.get_args(base) + if len(args) >= min_args: + return args + + # Fallback: __bases__ may contain the parameterized base + for base in klass.__bases__: + origin = typing.get_origin(base) + if origin in origin_classes: + args = typing.get_args(base) + if len(args) >= min_args: + return args + + return None diff --git a/src/cqrs/middlewares/base.py b/src/cqrs/middlewares/base.py index 1e4295b..60856f5 100644 --- a/src/cqrs/middlewares/base.py +++ b/src/cqrs/middlewares/base.py @@ -2,7 +2,7 @@ import typing from cqrs.saga.models import SagaContext -from cqrs.types import ReqT, ResT +from cqrs.requests.request import ReqT, ResT HandleType = typing.Callable[[ReqT], typing.Awaitable[ResT] | ResT] diff --git a/src/cqrs/requests/cor_request_handler.py b/src/cqrs/requests/cor_request_handler.py index 442fd1e..c102fe1 100644 --- a/src/cqrs/requests/cor_request_handler.py +++ b/src/cqrs/requests/cor_request_handler.py @@ -5,7 +5,7 @@ import typing from cqrs.events.event import IEvent -from cqrs.types import ReqT, ResT +from cqrs.requests.request import ReqT, ResT class CORRequestHandler(abc.ABC, typing.Generic[ReqT, ResT]): diff --git a/src/cqrs/requests/fallback.py b/src/cqrs/requests/fallback.py new file mode 100644 index 0000000..93f4677 --- /dev/null +++ b/src/cqrs/requests/fallback.py @@ -0,0 +1,98 @@ +"""Fallback wrapper for request handlers with optional circuit breaker.""" + +import dataclasses +import typing + +from cqrs.circuit_breaker import ICircuitBreaker +from cqrs.generic_utils import get_generic_args_for_origin +from cqrs.requests.request_handler import RequestHandler, StreamingRequestHandler + +RequestHandlerT = type[RequestHandler] | type[StreamingRequestHandler] + +_REQUEST_HANDLER_ORIGINS: tuple[type, ...] = (RequestHandler, StreamingRequestHandler) + + +def _type_name(t: type) -> str: + return getattr(t, "__name__", str(t)) + + +@dataclasses.dataclass(frozen=True) +class RequestHandlerFallback: + """ + Fallback wrapper for request handlers. + + When the primary handler fails (or circuit breaker is open), the fallback + handler is invoked. Use a separate circuit breaker instance per domain + (e.g. one for requests) that uses the same adapter class. + + Attributes: + primary: The primary request handler class. + fallback: The fallback handler class to execute if primary fails. + failure_exceptions: Exception types that trigger fallback; if empty, any exception. + circuit_breaker: Optional circuit breaker instance (e.g. AioBreakerAdapter). + + Example:: + request_cb = AioBreakerAdapter(fail_max=5, timeout_duration=60) + request_map.bind( + MyCommand, + RequestHandlerFallback( + MyCommandHandler, + MyCommandHandlerFallback, + failure_exceptions=(ConnectionError, TimeoutError), + circuit_breaker=request_cb, + ), + ) + """ + + primary: RequestHandlerT + fallback: RequestHandlerT + failure_exceptions: tuple[type[Exception], ...] = () + circuit_breaker: ICircuitBreaker | None = None + + def __post_init__(self) -> None: + if not isinstance(self.primary, type) or not isinstance(self.fallback, type): + raise TypeError( + "RequestHandlerFallback primary and fallback must be handler classes", + ) + primary_streaming = issubclass(self.primary, StreamingRequestHandler) + fallback_streaming = issubclass(self.fallback, StreamingRequestHandler) + if primary_streaming != fallback_streaming: + raise TypeError( + "RequestHandlerFallback primary and fallback must be the same handler base type: " + "both RequestHandler or both StreamingRequestHandler", + ) + # Validate that primary and fallback handle the same request and response types + primary_args = get_generic_args_for_origin( + self.primary, + _REQUEST_HANDLER_ORIGINS, + min_args=2, + ) + fallback_args = get_generic_args_for_origin( + self.fallback, + _REQUEST_HANDLER_ORIGINS, + min_args=2, + ) + if primary_args is not None and fallback_args is not None: + # Reject TypeVar (unparameterized) so we only allow concrete types + if any(isinstance(a, typing.TypeVar) for a in primary_args + fallback_args): + raise TypeError( + "RequestHandlerFallback primary and fallback must be parameterized with concrete types " + "(e.g. RequestHandler[MyCommand, MyResult] or StreamingRequestHandler[MyCommand, MyResult])", + ) + if primary_args[0] != fallback_args[0]: + raise TypeError( + "RequestHandlerFallback primary and fallback must handle the same request type: " + f"primary {self.primary.__name__} handles {_type_name(primary_args[0])}, " + f"fallback {self.fallback.__name__} handles {_type_name(fallback_args[0])}", + ) + if primary_args[1] != fallback_args[1]: + raise TypeError( + "RequestHandlerFallback primary and fallback must have the same response type: " + f"primary {self.primary.__name__} returns {_type_name(primary_args[1])}, " + f"fallback {self.fallback.__name__} returns {_type_name(fallback_args[1])}", + ) + elif primary_args is None or fallback_args is None: + raise TypeError( + "RequestHandlerFallback primary and fallback must be parameterized with concrete types " + "(e.g. RequestHandler[MyCommand, MyResult] or StreamingRequestHandler[MyCommand, MyResult])", + ) diff --git a/src/cqrs/requests/map.py b/src/cqrs/requests/map.py index 8a368c1..e4eabcd 100644 --- a/src/cqrs/requests/map.py +++ b/src/cqrs/requests/map.py @@ -1,6 +1,7 @@ import typing from cqrs.requests.cor_request_handler import CORRequestHandler +from cqrs.requests.fallback import RequestHandlerFallback from cqrs.requests.request import IRequest from cqrs.requests.request_handler import ( RequestHandler, @@ -12,7 +13,11 @@ _KT = typing.TypeVar("_KT", bound=typing.Type[IRequest]) # Type alias for handler types that can be bound to requests -HandlerType = typing.Type[RequestHandler | StreamingRequestHandler] | typing.List[typing.Type[CORRequestHandler]] +HandlerType = ( + typing.Type[RequestHandler | StreamingRequestHandler] + | typing.List[typing.Type[CORRequestHandler]] + | RequestHandlerFallback +) class RequestMap(typing.Dict[_KT, HandlerType]): diff --git a/src/cqrs/requests/request.py b/src/cqrs/requests/request.py index b220d3e..2a6d70f 100644 --- a/src/cqrs/requests/request.py +++ b/src/cqrs/requests/request.py @@ -1,9 +1,12 @@ import abc import dataclasses import sys +import typing import pydantic +from cqrs.response import IResponse + if sys.version_info >= (3, 11): from typing import Self # novm else: @@ -48,6 +51,12 @@ def from_dict(cls, **kwargs) -> Self: raise NotImplementedError +# Type variables for request/response (defined here to avoid circular import with +# cqrs.types <-> cqrs.requests.request_handler). Re-exported from cqrs.types for compatibility. +ReqT = typing.TypeVar("ReqT", bound=IRequest, contravariant=True) +ResT = typing.TypeVar("ResT", bound=IResponse | None, covariant=True) + + @dataclasses.dataclass class DCRequest(IRequest): """ @@ -140,4 +149,4 @@ def to_dict(self) -> dict: Request = PydanticRequest -__all__ = ("Request", "IRequest", "DCRequest", "PydanticRequest") +__all__ = ("Request", "IRequest", "DCRequest", "PydanticRequest", "ReqT", "ResT") diff --git a/src/cqrs/requests/request_handler.py b/src/cqrs/requests/request_handler.py index bfdef7c..751f47c 100644 --- a/src/cqrs/requests/request_handler.py +++ b/src/cqrs/requests/request_handler.py @@ -2,7 +2,7 @@ import typing from cqrs.events.event import IEvent -from cqrs.types import ReqT, ResT +from cqrs.requests.request import ReqT, ResT class RequestHandler(abc.ABC, typing.Generic[ReqT, ResT]): diff --git a/src/cqrs/saga/execution.py b/src/cqrs/saga/execution.py index 8b6e1d0..1c0b611 100644 --- a/src/cqrs/saga/execution.py +++ b/src/cqrs/saga/execution.py @@ -277,9 +277,9 @@ async def execute_fallback_step( # Execute primary step with circuit breaker if present if fallback_wrapper.circuit_breaker is not None: step_result = await fallback_wrapper.circuit_breaker.call( - step_type=fallback_wrapper.step, - func=primary_step.act, - context=self._context, + fallback_wrapper.step, + primary_step.act, + self._context, ) else: step_result = await primary_step.act(self._context) diff --git a/src/cqrs/saga/fallback.py b/src/cqrs/saga/fallback.py index d6233ae..03ab78d 100644 --- a/src/cqrs/saga/fallback.py +++ b/src/cqrs/saga/fallback.py @@ -2,11 +2,11 @@ import dataclasses -from cqrs.saga.circuit_breaker import ISagaStepCircuitBreaker +from cqrs.circuit_breaker import ICircuitBreaker from cqrs.saga.step import SagaStepHandler -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class Fallback: """ Fallback wrapper for Saga steps. @@ -32,4 +32,4 @@ class Fallback: step: type[SagaStepHandler] fallback: type[SagaStepHandler] failure_exceptions: tuple[type[Exception], ...] = () - circuit_breaker: ISagaStepCircuitBreaker | None = None + circuit_breaker: ICircuitBreaker | None = None diff --git a/src/cqrs/types.py b/src/cqrs/types.py index 6fdbc20..c4cfbb2 100644 --- a/src/cqrs/types.py +++ b/src/cqrs/types.py @@ -1,18 +1,11 @@ """ Type definitions for CQRS framework. -This module contains common type variables used throughout the framework. -It is placed at the bottom of the dependency hierarchy to avoid circular imports. +This module re-exports common type variables (ReqT, ResT) from +cqrs.requests.request for backward compatibility. Defining ReqT/ResT in +request.py avoids circular import with request_handler. """ -import typing +from cqrs.requests.request import ReqT, ResT -from cqrs.requests.request import IRequest -from cqrs.response import IResponse - -# Type variable for request types (contravariant - can accept subtypes) -ReqT = typing.TypeVar("ReqT", bound=IRequest, contravariant=True) - -# Type variable for response types (covariant - can return subtypes) -# Can be IResponse or None -ResT = typing.TypeVar("ResT", bound=IResponse | None, covariant=True) +__all__ = ("ReqT", "ResT") diff --git a/tests/integration/test_pybreaker_adapter.py b/tests/integration/test_pybreaker_adapter.py index 91dcf9b..0746784 100644 --- a/tests/integration/test_pybreaker_adapter.py +++ b/tests/integration/test_pybreaker_adapter.py @@ -84,7 +84,7 @@ async def test_successful_execution(self, adapter): # Act result = await adapter.call( - step_type=step_type, + identifier=step_type, func=successful_function, value=5, ) @@ -103,14 +103,14 @@ async def test_namespace_isolation(self, adapter): for _ in range(2): with pytest.raises(RuntimeError): await adapter.call( - step_type=step_type_1, + identifier=step_type_1, func=failing_function, error_type=RuntimeError, ) # Act - Step2 should still work (different namespace) result = await adapter.call( - step_type=step_type_2, + identifier=step_type_2, func=successful_function, value=3, ) @@ -121,7 +121,7 @@ async def test_namespace_isolation(self, adapter): # Act - Step1 circuit should still be closed (only 2 failures, need 3 to open) with pytest.raises(CircuitBreakerError): await adapter.call( - step_type=step_type_1, + identifier=step_type_1, func=failing_function, error_type=RuntimeError, ) @@ -129,7 +129,7 @@ async def test_namespace_isolation(self, adapter): # Now circuit should be open (3 failures reached) with pytest.raises(CircuitBreakerError): await adapter.call( - step_type=step_type_1, + identifier=step_type_1, func=failing_function, error_type=RuntimeError, ) @@ -162,7 +162,7 @@ async def test_circuit_reset_after_timeout(self, adapter): for _ in range(3): try: await adapter.call( - step_type=step_type, + identifier=step_type, func=failing_function, error_type=RuntimeError, ) @@ -172,7 +172,7 @@ async def test_circuit_reset_after_timeout(self, adapter): # Assert - Circuit should be open now with pytest.raises(CircuitBreakerError): await adapter.call( - step_type=step_type, + identifier=step_type, func=failing_function, error_type=RuntimeError, ) @@ -183,7 +183,7 @@ async def test_circuit_reset_after_timeout(self, adapter): # Assert - Circuit should be half-open, trial call fails and circuit opens again with pytest.raises(CircuitBreakerError): await adapter.call( - step_type=step_type, + identifier=step_type, func=failing_function, error_type=RuntimeError, ) @@ -198,7 +198,7 @@ async def test_is_circuit_breaker_error(self, adapter): for _ in range(3): try: await adapter.call( - step_type=step_type, + identifier=step_type, func=failing_function, error_type=RuntimeError, ) @@ -208,7 +208,7 @@ async def test_is_circuit_breaker_error(self, adapter): # Act - Try to call when circuit is open try: await adapter.call( - step_type=step_type, + identifier=step_type, func=failing_function, error_type=RuntimeError, ) @@ -227,7 +227,7 @@ async def test_circuit_opens_after_failures(self, adapter): for _ in range(2): with pytest.raises(RuntimeError): await adapter.call( - step_type=step_type, + identifier=step_type, func=failing_function, error_type=RuntimeError, ) @@ -246,7 +246,7 @@ async def test_circuit_opens_after_failures(self, adapter): # Call 3 raises CircuitBreakerError (not RuntimeError) with pytest.raises(CircuitBreakerError): # The 3rd failure await adapter.call( - step_type=step_type, + identifier=step_type, func=failing_function, error_type=RuntimeError, ) @@ -254,7 +254,7 @@ async def test_circuit_opens_after_failures(self, adapter): # Call 4 raises CircuitBreakerError with pytest.raises(CircuitBreakerError): await adapter.call( - step_type=step_type, + identifier=step_type, func=failing_function, error_type=RuntimeError, ) @@ -289,7 +289,7 @@ async def test_custom_configuration( for _ in range(3): with pytest.raises(BusinessException): await adapter.call( - step_type=step_type, + identifier=step_type, func=failing_function, error_type=BusinessException, ) @@ -298,7 +298,7 @@ async def test_custom_configuration( # 1st failure with pytest.raises(NetworkException): await adapter.call( - step_type=step_type, + identifier=step_type, func=failing_function, error_type=NetworkException, ) @@ -306,7 +306,7 @@ async def test_custom_configuration( # 2nd failure -> Open. Should raise CircuitBreakerError immediately with pytest.raises(CircuitBreakerError): await adapter.call( - step_type=step_type, + identifier=step_type, func=failing_function, error_type=NetworkException, ) @@ -314,7 +314,7 @@ async def test_custom_configuration( # 3rd call -> CircuitBreakerError with pytest.raises(CircuitBreakerError): await adapter.call( - step_type=step_type, + identifier=step_type, func=failing_function, error_type=NetworkException, ) @@ -329,7 +329,7 @@ async def test_concurrent_calls(self, adapter): for _ in range(3): try: await adapter.call( - step_type=step_type, + identifier=step_type, func=failing_function, error_type=RuntimeError, ) @@ -339,7 +339,7 @@ async def test_concurrent_calls(self, adapter): # Verify open with pytest.raises(CircuitBreakerError): await adapter.call( - step_type=step_type, + identifier=step_type, func=failing_function, error_type=RuntimeError, ) @@ -347,7 +347,7 @@ async def test_concurrent_calls(self, adapter): # Concurrent calls should all fail fast tasks = [ adapter.call( - step_type=step_type, + identifier=step_type, func=failing_function, error_type=RuntimeError, ) diff --git a/tests/unit/test_event_fallback.py b/tests/unit/test_event_fallback.py new file mode 100644 index 0000000..2a18f0d --- /dev/null +++ b/tests/unit/test_event_fallback.py @@ -0,0 +1,181 @@ +"""Tests for EventHandlerFallback (without circuit breaker).""" + +from collections.abc import Sequence +from typing import Any, TypeVar + +import pytest + +from cqrs import EventHandlerFallback +from cqrs.container.protocol import Container +from cqrs.events.event import DomainEvent, IEvent +from cqrs.events.event_emitter import EventEmitter +from cqrs.events.event_handler import EventHandler +from cqrs.events.map import EventMap + +T = TypeVar("T") + + +class SampleEvent(DomainEvent, frozen=True): + """Event type for fallback tests (name avoids pytest collecting it as a test class).""" + + id: str + + +class PrimaryEventHandler(EventHandler[SampleEvent]): + def __init__(self) -> None: + self._evs: list[IEvent] = [] + self.called = False + + @property + def events(self) -> Sequence[IEvent]: + return self._evs.copy() + + async def handle(self, event: SampleEvent) -> None: + self.called = True + raise RuntimeError("Primary failed") + + +class FallbackEventHandler(EventHandler[SampleEvent]): + def __init__(self) -> None: + self._evs: list[IEvent] = [] + self.called = False + + @property + def events(self) -> Sequence[IEvent]: + return self._evs.copy() + + async def handle(self, event: SampleEvent) -> None: + self.called = True + + +class _TestEventContainer: + """Minimal container for event fallback tests; implements Container protocol.""" + + def __init__(self) -> None: + self._primary = PrimaryEventHandler() + self._fallback = FallbackEventHandler() + self._external_container: Any = None + + @property + def external_container(self) -> Any: + return self._external_container + + def attach_external_container(self, container: Any) -> None: + self._external_container = container + + async def resolve(self, type_: type[T]) -> T: + if type_ is PrimaryEventHandler: + return self._primary # type: ignore[return-value] + if type_ is FallbackEventHandler: + return self._fallback # type: ignore[return-value] + raise KeyError(type_) + + +@pytest.mark.asyncio +async def test_event_fallback_no_cb_primary_fails_uses_fallback() -> None: + event_map: EventMap = EventMap() + event_map.bind( + SampleEvent, + EventHandlerFallback(PrimaryEventHandler, FallbackEventHandler), + ) + container: Container[Any] = _TestEventContainer() + emitter = EventEmitter(event_map=event_map, container=container) + + follow_ups = await emitter.emit(SampleEvent(id="e1")) + + assert container._primary.called + assert container._fallback.called + assert follow_ups == [] + + +@pytest.mark.asyncio +async def test_event_fallback_failure_exceptions_only_matching_triggers_fallback() -> None: + event_map = EventMap() + event_map.bind( + SampleEvent, + EventHandlerFallback( + PrimaryEventHandler, + FallbackEventHandler, + failure_exceptions=(ValueError,), + ), + ) + container: Container[Any] = _TestEventContainer() + emitter = EventEmitter(event_map=event_map, container=container) + + with pytest.raises(RuntimeError, match="Primary failed"): + await emitter.emit(SampleEvent(id="e1")) + + assert container._primary.called + assert not container._fallback.called + + +@pytest.mark.asyncio +async def test_event_fallback_matching_filter_triggers_fallback() -> None: + """When failure_exceptions matches the primary error, fallback is invoked.""" + event_map: EventMap = EventMap() + event_map.bind( + SampleEvent, + EventHandlerFallback( + PrimaryEventHandler, + FallbackEventHandler, + failure_exceptions=(RuntimeError,), + ), + ) + container: Container[Any] = _TestEventContainer() + emitter = EventEmitter(event_map=event_map, container=container) + + follow_ups = await emitter.emit(SampleEvent(id="e1")) + + assert container._primary.called + assert container._fallback.called + assert follow_ups == [] + + +# --- Validation tests --- + + +def test_event_fallback_validation_same_event_type_accepts() -> None: + """Same event type is accepted.""" + EventHandlerFallback(PrimaryEventHandler, FallbackEventHandler) + + +def test_event_fallback_validation_different_event_type_raises() -> None: + """Different event types raise TypeError.""" + from cqrs.events.event import DomainEvent + + class OtherEvent(DomainEvent, frozen=True): + num: int + + class HandlerOther(EventHandler[OtherEvent]): + async def handle(self, event: OtherEvent) -> None: + pass + + with pytest.raises(TypeError, match="same event type"): + EventHandlerFallback(PrimaryEventHandler, HandlerOther) + + +def test_event_fallback_validation_not_classes_raises() -> None: + """Passing non-classes raises TypeError.""" + with pytest.raises(TypeError, match="must be handler classes"): + EventHandlerFallback(PrimaryEventHandler, FallbackEventHandler()) # type: ignore[arg-type] + with pytest.raises(TypeError, match="must be handler classes"): + EventHandlerFallback(PrimaryEventHandler(), FallbackEventHandler) # type: ignore[arg-type] + + +def test_event_fallback_validation_primary_not_event_handler_raises() -> None: + """Primary that is not EventHandler subclass raises TypeError.""" + from cqrs.requests.request import Request + from cqrs.requests.request_handler import RequestHandler + from cqrs.response import Response + + class NotAnEventHandler: + pass + + class SomeHandler(RequestHandler[Request, Response]): + async def handle(self, request: Request) -> Response: + raise NotImplementedError + + with pytest.raises(TypeError, match="primary.*must be a subclass of EventHandler"): + EventHandlerFallback(NotAnEventHandler, FallbackEventHandler) # type: ignore[arg-type] + with pytest.raises(TypeError, match="primary.*must be a subclass of EventHandler"): + EventHandlerFallback(SomeHandler, FallbackEventHandler) # pyright: ignore[reportArgumentType] diff --git a/tests/unit/test_request_fallback.py b/tests/unit/test_request_fallback.py new file mode 100644 index 0000000..03b19d2 --- /dev/null +++ b/tests/unit/test_request_fallback.py @@ -0,0 +1,261 @@ +"""Tests for RequestHandlerFallback (without circuit breaker).""" + +from typing import Any, TypeVar + +import pytest + +from cqrs import RequestHandlerFallback +from cqrs.container.protocol import Container +from cqrs.dispatcher import RequestDispatcher +from cqrs.events.event import IEvent +from cqrs.requests.map import RequestMap +from cqrs.requests.request import Request +from cqrs.requests.request_handler import RequestHandler +from cqrs.response import Response + +T = TypeVar("T") + + +class SimpleCommand(Request): + value: str + + +class SimpleResult(Response): + value: str + + +class PrimaryHandler(RequestHandler[SimpleCommand, SimpleResult]): + def __init__(self) -> None: + self._events: list[IEvent] = [] + self.called = False + + @property + def events(self) -> list[IEvent]: + return self._events.copy() + + async def handle(self, request: SimpleCommand) -> SimpleResult: + self.called = True + raise RuntimeError("Primary failed") + + +class FallbackHandler(RequestHandler[SimpleCommand, SimpleResult]): + def __init__(self) -> None: + self._events: list[IEvent] = [] + self.called = False + + @property + def events(self) -> list[IEvent]: + return self._events.copy() + + async def handle(self, request: SimpleCommand) -> SimpleResult: + self.called = True + return SimpleResult(value=f"fallback:{request.value}") + + +class _TestRequestContainer(Container[Any]): + """Minimal container for request fallback tests.""" + + def __init__(self) -> None: + self._primary = PrimaryHandler() + self._fallback = FallbackHandler() + self._external_container: Any = None + + @property + def external_container(self) -> Any: + return self._external_container + + def attach_external_container(self, container: Any) -> None: + self._external_container = container + + async def resolve(self, type_: type[T]) -> T: + if type_ is PrimaryHandler: + return self._primary # type: ignore[return-value] + if type_ is FallbackHandler: + return self._fallback # type: ignore[return-value] + raise KeyError(type_) + + +@pytest.mark.asyncio +async def test_request_fallback_no_cb_primary_fails_uses_fallback() -> None: + request_map: RequestMap = RequestMap() + request_map.bind( + SimpleCommand, + RequestHandlerFallback(PrimaryHandler, FallbackHandler), + ) + container: Container[Any] = _TestRequestContainer() + dispatcher = RequestDispatcher(request_map=request_map, container=container) + + result = await dispatcher.dispatch(SimpleCommand(value="x")) + + assert result.response.value == "fallback:x" + assert container._primary.called + assert container._fallback.called + + +@pytest.mark.asyncio +async def test_request_fallback_failure_exceptions_only_matching_triggers_fallback() -> None: + request_map = RequestMap() + request_map.bind( + SimpleCommand, + RequestHandlerFallback( + PrimaryHandler, + FallbackHandler, + failure_exceptions=(ValueError,), + ), + ) + container: Container[Any] = _TestRequestContainer() + dispatcher = RequestDispatcher(request_map=request_map, container=container) + + with pytest.raises(RuntimeError, match="Primary failed"): + await dispatcher.dispatch(SimpleCommand(value="x")) + + assert container._primary.called + assert not container._fallback.called + + +@pytest.mark.asyncio +async def test_request_fallback_primary_succeeds_fallback_not_invoked() -> None: + """When the primary handler succeeds, the fallback is not invoked.""" + + class SuccessPrimaryHandler(RequestHandler[SimpleCommand, SimpleResult]): + def __init__(self) -> None: + self._events: list[IEvent] = [] + self.called = False + + @property + def events(self) -> list[IEvent]: + return self._events.copy() + + async def handle(self, request: SimpleCommand) -> SimpleResult: + self.called = True + return SimpleResult(value=f"primary:{request.value}") + + class UnusedFallbackHandler(RequestHandler[SimpleCommand, SimpleResult]): + def __init__(self) -> None: + self._events: list[IEvent] = [] + self.called = False + + @property + def events(self) -> list[IEvent]: + return self._events.copy() + + async def handle(self, request: SimpleCommand) -> SimpleResult: + self.called = True + return SimpleResult(value="unused") + + class SuccessContainer(Container[Any]): + def __init__(self) -> None: + self._primary = SuccessPrimaryHandler() + self._fallback = UnusedFallbackHandler() + self._external_container: Any = None + + @property + def external_container(self) -> Any: + return self._external_container + + def attach_external_container(self, container: Any) -> None: + self._external_container = container + + async def resolve(self, type_: type[T]) -> T: + if type_ is SuccessPrimaryHandler: + return self._primary # type: ignore[return-value] + if type_ is UnusedFallbackHandler: + return self._fallback # type: ignore[return-value] + raise KeyError(type_) + + request_map = RequestMap() + request_map.bind( + SimpleCommand, + RequestHandlerFallback(SuccessPrimaryHandler, UnusedFallbackHandler), + ) + container = SuccessContainer() + dispatcher = RequestDispatcher(request_map=request_map, container=container) + + result = await dispatcher.dispatch(SimpleCommand(value="ok")) + + assert result.response.value == "primary:ok" + assert container._primary.called + assert not container._fallback.called + + +# --- Validation tests --- + + +def test_request_fallback_validation_same_request_and_response_types_accepts() -> None: + """Same request and response types (including None) are accepted.""" + RequestHandlerFallback(PrimaryHandler, FallbackHandler) + + +def test_request_fallback_validation_different_request_type_raises() -> None: + """Different request types raise TypeError.""" + from cqrs.requests.request import Request + from cqrs.response import Response + + class OtherCommand(Request): + value: int + + class OtherResult(Response): + value: int + + class FallbackOther(RequestHandler[OtherCommand, OtherResult]): + async def handle(self, request: OtherCommand) -> OtherResult: + return OtherResult(value=0) + + with pytest.raises(TypeError, match="same request type"): + RequestHandlerFallback(PrimaryHandler, FallbackOther) + + +def test_request_fallback_validation_different_response_type_raises() -> None: + """Different response types raise TypeError.""" + from cqrs.response import Response + + class OtherResult(Response): + value: int + + class FallbackOtherResult(RequestHandler[SimpleCommand, OtherResult]): + async def handle(self, request: SimpleCommand) -> OtherResult: + return OtherResult(value=0) + + with pytest.raises(TypeError, match="same response type"): + RequestHandlerFallback(PrimaryHandler, FallbackOtherResult) + + +def test_request_fallback_validation_same_types_with_none_response_accepts() -> None: + """Both request and response (None) matching is accepted.""" + from cqrs.requests.request import Request + + class NoResultCommand(Request): + x: str + + class PrimaryNoRes(RequestHandler[NoResultCommand, None]): + async def handle(self, request: NoResultCommand) -> None: + return None + + class FallbackNoRes(RequestHandler[NoResultCommand, None]): + async def handle(self, request: NoResultCommand) -> None: + return None + + RequestHandlerFallback(PrimaryNoRes, FallbackNoRes) + + +def test_request_fallback_validation_not_classes_raises() -> None: + """Passing non-classes raises TypeError.""" + with pytest.raises(TypeError, match="must be handler classes"): + RequestHandlerFallback(PrimaryHandler, FallbackHandler()) # type: ignore[arg-type] + with pytest.raises(TypeError, match="must be handler classes"): + RequestHandlerFallback(PrimaryHandler(), FallbackHandler) # type: ignore[arg-type] + + +def test_request_fallback_validation_mixed_handler_base_raises() -> None: + """Mixing RequestHandler and StreamingRequestHandler raises TypeError.""" + from cqrs.requests.request_handler import StreamingRequestHandler + + class StreamingPrimary(StreamingRequestHandler[SimpleCommand, SimpleResult]): + async def handle(self, request: SimpleCommand): + yield SimpleResult(value=request.value) + + def clear_events(self) -> None: + pass + + with pytest.raises(TypeError, match="same handler base type"): + RequestHandlerFallback(PrimaryHandler, StreamingPrimary)