diff --git a/src/strands/__init__.py b/src/strands/__init__.py index bc17497a0..93fffcfa5 100644 --- a/src/strands/__init__.py +++ b/src/strands/__init__.py @@ -3,6 +3,7 @@ from . import agent, models, telemetry, types from .agent.agent import Agent from .agent.base import AgentBase +from .hooks.decorator import hook from .tools.decorator import tool from .types.tools import ToolContext @@ -10,6 +11,7 @@ "Agent", "AgentBase", "agent", + "hook", "models", "tool", "ToolContext", diff --git a/src/strands/hooks/__init__.py b/src/strands/hooks/__init__.py index 30163f207..2d6cff7d4 100644 --- a/src/strands/hooks/__init__.py +++ b/src/strands/hooks/__init__.py @@ -5,7 +5,7 @@ built-in SDK components and user code to react to or modify agent behavior through strongly-typed event callbacks. -Example Usage: +Example Usage with Class-Based Hooks: ```python from strands.hooks import HookProvider, HookRegistry from strands.hooks.events import BeforeInvocationEvent, AfterInvocationEvent @@ -25,10 +25,24 @@ def log_end(self, event: AfterInvocationEvent) -> None: agent = Agent(hooks=[LoggingHooks()]) ``` -This replaces the older callback_handler approach with a more composable, -type-safe system that supports multiple subscribers per event type. +Example Usage with Decorator-Based Hooks: + ```python + from strands import Agent, hook + from strands.hooks import BeforeToolCallEvent + + @hook + def log_tool_calls(event: BeforeToolCallEvent) -> None: + '''Log all tool calls before execution.''' + print(f"Tool: {event.tool_use}") + + agent = Agent(hooks=[log_tool_calls]) + ``` + +This module supports both the class-based HookProvider approach and the newer +decorator-based @hook approach for maximum flexibility. """ +from .decorator import DecoratedFunctionHook, FunctionHookMetadata, HookMetadata, hook from .events import ( AfterInvocationEvent, AfterModelCallEvent, @@ -42,6 +56,7 @@ def log_end(self, event: AfterInvocationEvent) -> None: from .registry import BaseHookEvent, HookCallback, HookEvent, HookProvider, HookRegistry __all__ = [ + # Events "AgentInitializedEvent", "BeforeInvocationEvent", "BeforeToolCallEvent", @@ -50,10 +65,15 @@ def log_end(self, event: AfterInvocationEvent) -> None: "AfterModelCallEvent", "AfterInvocationEvent", "MessageAddedEvent", + # Registry "HookEvent", "HookProvider", "HookCallback", "HookRegistry", - "HookEvent", "BaseHookEvent", + # Decorator + "hook", + "DecoratedFunctionHook", + "FunctionHookMetadata", + "HookMetadata", ] diff --git a/src/strands/hooks/decorator.py b/src/strands/hooks/decorator.py new file mode 100644 index 000000000..0f4f1be15 --- /dev/null +++ b/src/strands/hooks/decorator.py @@ -0,0 +1,548 @@ +"""Hook decorator for simplified hook definitions. + +This module provides the @hook decorator that transforms Python functions into +HookProvider implementations with automatic event type detection from type hints. + +The @hook decorator mirrors the ergonomics of the existing @tool decorator, +making hooks as easy to define and share via PyPI packages as tools are today. + +Example: + ```python + from strands import Agent, hook + from strands.hooks import BeforeToolCallEvent + + @hook + def log_tool_calls(event: BeforeToolCallEvent) -> None: + '''Log all tool calls before execution.''' + print(f"Tool: {event.tool_use}") + + # With automatic agent injection: + @hook + def log_with_agent(event: BeforeToolCallEvent, agent: Agent) -> None: + '''Log tool calls with agent context.''' + print(f"Agent {agent.name} calling tool: {event.tool_use}") + + agent = Agent(hooks=[log_tool_calls, log_with_agent]) + ``` +""" + +import functools +import inspect +import logging +import sys +import types +from dataclasses import dataclass +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generic, + Optional, + Sequence, + Type, + TypeVar, + Union, + cast, + get_args, + get_origin, + get_type_hints, + overload, +) + +from .registry import BaseHookEvent, HookCallback, HookProvider, HookRegistry + +if TYPE_CHECKING: + from ..agent import Agent + +logger = logging.getLogger(__name__) + + +TEvent = TypeVar("TEvent", bound=BaseHookEvent) + + +@dataclass +class HookMetadata: + """Metadata extracted from a decorated hook function. + + Attributes: + name: The name of the hook function. + description: Description extracted from the function's docstring. + event_types: List of event types this hook handles. + is_async: Whether the hook function is async. + has_agent_param: Whether the function has an 'agent' parameter for injection. + """ + + name: str + description: str + event_types: list[Type[BaseHookEvent]] + is_async: bool + has_agent_param: bool = False + + +class FunctionHookMetadata: + """Helper class to extract and manage function metadata for hook decoration. + + This class handles the extraction of metadata from Python functions including: + - Function name and description from docstrings + - Event types from type hints + - Async detection + - Agent parameter detection for automatic injection + """ + + def __init__( + self, + func: Callable[..., Any], + event_types: Optional[Sequence[Type[BaseHookEvent]]] = None, + ) -> None: + """Initialize with the function to process. + + Args: + func: The function to extract metadata from. + event_types: Optional explicit event types. If not provided, + will be extracted from type hints. + """ + self.func = func + self.signature = inspect.signature(func) + self._explicit_event_types = list(event_types) if event_types else None + + # Validate and extract event types + self._event_types = self._resolve_event_types() + self._validate_event_types() + + # Check for agent parameter + self._has_agent_param = self._check_agent_parameter() + + def _check_agent_parameter(self) -> bool: + """Check if the function has an 'agent' parameter for injection. + + Returns: + True if the function has an 'agent' parameter. + """ + return "agent" in self.signature.parameters + + def _resolve_event_types(self) -> list[Type[BaseHookEvent]]: + """Resolve event types from explicit parameter or type hints. + + Returns: + List of event types this hook handles. + + Raises: + ValueError: If no event type can be determined. + """ + # Use explicit event types if provided + if self._explicit_event_types: + return self._explicit_event_types + + # Try to extract from type hints + try: + type_hints = get_type_hints(self.func) + except Exception: + # get_type_hints can fail for various reasons (forward refs, etc.) + type_hints = {} + + # Find the first parameter's type hint (should be the event) + params = list(self.signature.parameters.values()) + if not params: + raise ValueError( + f"Hook function '{self.func.__name__}' must have at least one parameter " + "for the event. Use @hook(event=EventType) if type hints are unavailable." + ) + + first_param = params[0] + event_type = type_hints.get(first_param.name) + + if event_type is None: + # Check annotation directly (for cases where get_type_hints fails) + if first_param.annotation is not inspect.Parameter.empty: + event_type = first_param.annotation + else: + raise ValueError( + f"Hook function '{self.func.__name__}' must have a type hint for the event parameter, " + "or use @hook(event=EventType) to specify the event type explicitly." + ) + + # Handle Union types (e.g., BeforeToolCallEvent | AfterToolCallEvent) + return self._extract_event_types_from_annotation(event_type) + + def _is_union_type(self, annotation: Any) -> bool: + """Check if annotation is a Union type (typing.Union or types.UnionType). + + Args: + annotation: The type annotation to check. + + Returns: + True if the annotation is a Union type. + """ + origin = get_origin(annotation) + if origin is Union: + return True + + # Python 3.10+ uses types.UnionType for `A | B` syntax + if sys.version_info >= (3, 10): + if isinstance(annotation, types.UnionType): + return True + + return False + + def _extract_event_types_from_annotation(self, annotation: Any) -> list[Type[BaseHookEvent]]: + """Extract event types from a type annotation. + + Handles Union types and single types. + + Args: + annotation: The type annotation to extract from. + + Returns: + List of event types. + """ + # Handle Union types (Union[A, B] or A | B) + if self._is_union_type(annotation): + args = get_args(annotation) + event_types = [] + for arg in args: + # Skip NoneType in Optional[X] + if arg is type(None): + continue + if isinstance(arg, type) and issubclass(arg, BaseHookEvent): + event_types.append(arg) + else: + raise ValueError(f"All types in Union must be subclasses of BaseHookEvent, got {arg}") + return event_types + + # Single type + if isinstance(annotation, type) and issubclass(annotation, BaseHookEvent): + return [annotation] + + raise ValueError(f"Event type must be a subclass of BaseHookEvent, got {annotation}") + + def _validate_event_types(self) -> None: + """Validate that all event types are valid. + + Raises: + ValueError: If any event type is invalid. + """ + if not self._event_types: + raise ValueError(f"Hook function '{self.func.__name__}' must handle at least one event type.") + + for event_type in self._event_types: + if not isinstance(event_type, type) or not issubclass(event_type, BaseHookEvent): + raise ValueError(f"Event type must be a subclass of BaseHookEvent, got {event_type}") + + def extract_metadata(self) -> HookMetadata: + """Extract metadata from the function to create hook specification. + + Returns: + HookMetadata containing the function's hook specification. + """ + func_name = self.func.__name__ + + # Extract description from docstring + description = inspect.getdoc(self.func) or func_name + + # Check if async + is_async = inspect.iscoroutinefunction(self.func) + + return HookMetadata( + name=func_name, + description=description, + event_types=self._event_types, + is_async=is_async, + has_agent_param=self._has_agent_param, + ) + + @property + def event_types(self) -> list[Type[BaseHookEvent]]: + """Get the event types this hook handles.""" + return self._event_types + + @property + def has_agent_param(self) -> bool: + """Check if the function has an 'agent' parameter.""" + return self._has_agent_param + + +class DecoratedFunctionHook(HookProvider, Generic[TEvent]): + """A HookProvider that wraps a function decorated with @hook. + + This class adapts Python functions decorated with @hook to the HookProvider + interface, enabling them to be used with Agent's hooks parameter. + + The class is generic over the event type to maintain type safety. + + Features: + - Automatic agent injection: If the hook function has an 'agent' parameter, + it will be automatically injected from event.agent when the hook is called. + """ + + _func: Callable[[TEvent], Any] + _metadata: FunctionHookMetadata + _hook_metadata: HookMetadata + + def __init__( + self, + func: Callable[[TEvent], Any], + metadata: FunctionHookMetadata, + ): + """Initialize the decorated function hook. + + Args: + func: The original function being decorated. + metadata: The FunctionHookMetadata object with extracted function information. + """ + self._func = func + self._metadata = metadata + self._hook_metadata = metadata.extract_metadata() + + # Preserve function metadata + functools.update_wrapper(wrapper=self, wrapped=self._func) + + def _create_callback_with_injection(self) -> HookCallback[BaseHookEvent]: + """Create a callback that handles agent injection. + + Returns: + A callback that wraps the original function with agent injection. + """ + func = self._func + has_agent_param = self._hook_metadata.has_agent_param + + if has_agent_param: + # Create wrapper that injects agent + if self._hook_metadata.is_async: + + async def async_callback_with_agent(event: BaseHookEvent) -> None: + await func(event, agent=event.agent) # type: ignore[arg-type] + + return cast(HookCallback[BaseHookEvent], async_callback_with_agent) + else: + + def sync_callback_with_agent(event: BaseHookEvent) -> None: + func(event, agent=event.agent) # type: ignore[arg-type] + + return cast(HookCallback[BaseHookEvent], sync_callback_with_agent) + else: + # No injection needed, use function directly + return cast(HookCallback[BaseHookEvent], func) + + def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: + """Register callback functions for specific event types. + + Args: + registry: The hook registry to register callbacks with. + **kwargs: Additional keyword arguments (unused, for protocol compatibility). + """ + callback = self._create_callback_with_injection() + for event_type in self._metadata.event_types: + registry.add_callback(event_type, callback) + + def __call__(self, event: TEvent, agent: Optional["Agent"] = None) -> Any: + """Allow direct invocation for testing. + + Args: + event: The event to process. + agent: Optional agent instance. If not provided and the hook + expects an agent parameter, it will be extracted from event.agent. + + Returns: + The result of the hook function. + """ + if self._hook_metadata.has_agent_param: + # Use provided agent or fall back to event.agent + actual_agent = agent if agent is not None else event.agent + return self._func(event, agent=actual_agent) # type: ignore[arg-type] + return self._func(event) + + @property + def name(self) -> str: + """Get the name of the hook. + + Returns: + The hook name as a string. + """ + return self._hook_metadata.name + + @property + def description(self) -> str: + """Get the description of the hook. + + Returns: + The hook description as a string. + """ + return self._hook_metadata.description + + @property + def event_types(self) -> list[Type[BaseHookEvent]]: + """Get the event types this hook handles. + + Returns: + List of event types. + """ + return self._hook_metadata.event_types + + @property + def is_async(self) -> bool: + """Check if this hook is async. + + Returns: + True if the hook function is async. + """ + return self._hook_metadata.is_async + + @property + def has_agent_param(self) -> bool: + """Check if this hook has an agent parameter. + + Returns: + True if the hook function expects an agent parameter. + """ + return self._hook_metadata.has_agent_param + + def __repr__(self) -> str: + """Return a string representation of the hook.""" + event_names = [e.__name__ for e in self._hook_metadata.event_types] + agent_info = ", agent_injection=True" if self._hook_metadata.has_agent_param else "" + return f"DecoratedFunctionHook({self._hook_metadata.name}, events={event_names}{agent_info})" + + +# Type variable for the decorated function +F = TypeVar("F", bound=Callable[..., Any]) + + +# Handle @hook +@overload +def hook(__func: F) -> DecoratedFunctionHook[Any]: ... + + +# Handle @hook(event=...) +@overload +def hook( + *, + event: Optional[Type[BaseHookEvent]] = None, + events: Optional[Sequence[Type[BaseHookEvent]]] = None, +) -> Callable[[F], DecoratedFunctionHook[Any]]: ... + + +def hook( + func: Optional[F] = None, + event: Optional[Type[BaseHookEvent]] = None, + events: Optional[Sequence[Type[BaseHookEvent]]] = None, +) -> Union[DecoratedFunctionHook[Any], Callable[[F], DecoratedFunctionHook[Any]]]: + """Decorator that transforms a Python function into a Strands hook. + + This decorator enables simple, function-based hook definitions - mirroring + the ergonomics of the existing @tool decorator. It extracts the event type + from the function's type hints or from explicit parameters. + + When decorated, a function: + 1. Implements the HookProvider protocol automatically + 2. Can be passed directly to Agent(hooks=[...]) + 3. Still works as a normal function when called directly + 4. Supports both sync and async hook functions + 5. Supports automatic agent injection via 'agent' parameter + + The decorator can be used in several ways: + + 1. Simple decorator with type hints: + ```python + @hook + def my_hook(event: BeforeToolCallEvent) -> None: + print(f"Tool: {event.tool_use}") + ``` + + 2. With automatic agent injection: + ```python + @hook + def my_hook(event: BeforeToolCallEvent, agent: Agent) -> None: + print(f"Agent: {agent.name}") + print(f"Tool: {event.tool_use}") + ``` + + 3. With explicit event type: + ```python + @hook(event=BeforeToolCallEvent) + def my_hook(event) -> None: + print(f"Tool: {event.tool_use}") + ``` + + 4. For multiple event types: + ```python + @hook(events=[BeforeToolCallEvent, AfterToolCallEvent]) + def my_hook(event: BeforeToolCallEvent | AfterToolCallEvent) -> None: + print(f"Event: {event}") + ``` + + 5. With Union type hint: + ```python + @hook + def my_hook(event: BeforeToolCallEvent | AfterToolCallEvent) -> None: + print(f"Event: {event}") + ``` + + Args: + func: The function to decorate. When used as a simple decorator, + this is the function being decorated. When used with parameters, + this will be None. + event: Optional single event type to handle. Takes precedence over + type hint detection. + events: Optional list of event types to handle. Takes precedence over + both `event` parameter and type hint detection. + + Returns: + A DecoratedFunctionHook that implements HookProvider and can be used + directly with Agent(hooks=[...]). + + Raises: + ValueError: If no event type can be determined from type hints or parameters. + ValueError: If event types are not subclasses of BaseHookEvent. + + Example: + ```python + from strands import Agent, hook + from strands.hooks import BeforeToolCallEvent, AfterToolCallEvent + + @hook + def log_tool_calls(event: BeforeToolCallEvent) -> None: + '''Log all tool calls before execution.''' + print(f"Calling tool: {event.tool_use['name']}") + + @hook + def log_with_agent(event: BeforeToolCallEvent, agent: Agent) -> None: + '''Log with direct agent access.''' + print(f"Agent {agent.name} calling tool: {event.tool_use['name']}") + + @hook + async def async_audit(event: AfterToolCallEvent) -> None: + '''Async hook for auditing tool results.''' + await send_to_audit_service(event.result) + + @hook(events=[BeforeToolCallEvent, AfterToolCallEvent]) + def tool_lifecycle(event: BeforeToolCallEvent | AfterToolCallEvent) -> None: + '''Track the complete tool lifecycle.''' + if isinstance(event, BeforeToolCallEvent): + print("Starting tool...") + else: + print("Tool complete!") + + agent = Agent(hooks=[log_tool_calls, log_with_agent, async_audit, tool_lifecycle]) + ``` + """ + + def decorator(f: F) -> DecoratedFunctionHook[Any]: + # Determine event types from parameters or type hints + event_types: Optional[list[Type[BaseHookEvent]]] = None + + if events is not None: + event_types = list(events) + elif event is not None: + event_types = [event] + # Otherwise, let FunctionHookMetadata extract from type hints + + # Create function hook metadata + hook_meta = FunctionHookMetadata(f, event_types) + + return DecoratedFunctionHook(f, hook_meta) + + # Handle both @hook and @hook() syntax + if func is None: + return decorator + + return decorator(func) diff --git a/tests/strands/hooks/test_decorator.py b/tests/strands/hooks/test_decorator.py new file mode 100644 index 000000000..584997d68 --- /dev/null +++ b/tests/strands/hooks/test_decorator.py @@ -0,0 +1,715 @@ +"""Tests for the @hook decorator.""" + +from typing import Union +from unittest.mock import MagicMock + +import pytest + +from strands.hooks import ( + AfterToolCallEvent, + BeforeInvocationEvent, + BeforeToolCallEvent, + DecoratedFunctionHook, + FunctionHookMetadata, + HookMetadata, + HookRegistry, + hook, +) + + +class TestHookDecorator: + """Tests for the @hook decorator function.""" + + def test_basic_decorator_with_type_hint(self): + """Test @hook with type hints extracts event type correctly.""" + + @hook + def my_hook(event: BeforeToolCallEvent) -> None: + pass + + assert isinstance(my_hook, DecoratedFunctionHook) + assert my_hook.name == "my_hook" + assert my_hook.event_types == [BeforeToolCallEvent] + assert not my_hook.is_async + + def test_decorator_with_explicit_event(self): + """Test @hook(event=...) syntax.""" + + @hook(event=BeforeToolCallEvent) + def my_hook(event) -> None: + pass + + assert isinstance(my_hook, DecoratedFunctionHook) + assert my_hook.event_types == [BeforeToolCallEvent] + + def test_decorator_with_multiple_events(self): + """Test @hook(events=[...]) syntax for multiple event types.""" + + @hook(events=[BeforeToolCallEvent, AfterToolCallEvent]) + def my_hook(event: Union[BeforeToolCallEvent, AfterToolCallEvent]) -> None: + pass + + assert isinstance(my_hook, DecoratedFunctionHook) + assert set(my_hook.event_types) == {BeforeToolCallEvent, AfterToolCallEvent} + + def test_decorator_with_union_type_hint(self): + """Test @hook with Union type hint extracts multiple event types.""" + + @hook + def my_hook(event: BeforeToolCallEvent | AfterToolCallEvent) -> None: + pass + + assert isinstance(my_hook, DecoratedFunctionHook) + assert set(my_hook.event_types) == {BeforeToolCallEvent, AfterToolCallEvent} + + def test_async_hook_detection(self): + """Test that async hooks are detected correctly.""" + + @hook + async def async_hook(event: BeforeToolCallEvent) -> None: + pass + + assert async_hook.is_async + + @hook + def sync_hook(event: BeforeToolCallEvent) -> None: + pass + + assert not sync_hook.is_async + + def test_docstring_extraction(self): + """Test that docstring is extracted as description.""" + + @hook + def documented_hook(event: BeforeToolCallEvent) -> None: + """This is a documented hook for testing.""" + pass + + assert documented_hook.description == "This is a documented hook for testing." + + def test_default_description(self): + """Test that function name is used when no docstring.""" + + @hook + def undocumented_hook(event: BeforeToolCallEvent) -> None: + pass + + assert undocumented_hook.description == "undocumented_hook" + + def test_direct_invocation(self): + """Test that decorated hooks can be called directly.""" + mock_callback = MagicMock() + + @hook + def my_hook(event: BeforeToolCallEvent) -> None: + mock_callback(event) + + # Create a mock event + mock_event = MagicMock(spec=BeforeToolCallEvent) + + # Direct invocation + my_hook(mock_event) + + mock_callback.assert_called_once_with(mock_event) + + def test_hook_registration(self): + """Test that hooks register correctly with HookRegistry.""" + callback_called = [] + + @hook + def my_hook(event: BeforeToolCallEvent) -> None: + callback_called.append(event) + + registry = HookRegistry() + my_hook.register_hooks(registry) + + # Verify callback is registered + mock_agent = MagicMock() + mock_tool_use = {"toolUseId": "test-123", "name": "test_tool", "input": {}} + event = BeforeToolCallEvent( + agent=mock_agent, + selected_tool=None, + tool_use=mock_tool_use, + invocation_state={}, + ) + + registry.invoke_callbacks(event) + + assert len(callback_called) == 1 + assert callback_called[0] is event + + def test_multi_event_registration(self): + """Test that multi-event hooks register for all event types.""" + events_received = [] + + @hook(events=[BeforeToolCallEvent, AfterToolCallEvent]) + def multi_hook(event: Union[BeforeToolCallEvent, AfterToolCallEvent]) -> None: + events_received.append(type(event).__name__) + + registry = HookRegistry() + multi_hook.register_hooks(registry) + + # Create mock events + mock_agent = MagicMock() + mock_tool_use = {"toolUseId": "test-123", "name": "test_tool", "input": {}} + mock_result = {"toolUseId": "test-123", "status": "success", "content": []} + + before_event = BeforeToolCallEvent( + agent=mock_agent, + selected_tool=None, + tool_use=mock_tool_use, + invocation_state={}, + ) + after_event = AfterToolCallEvent( + agent=mock_agent, + selected_tool=None, + tool_use=mock_tool_use, + invocation_state={}, + result=mock_result, + ) + + registry.invoke_callbacks(before_event) + registry.invoke_callbacks(after_event) + + assert "BeforeToolCallEvent" in events_received + assert "AfterToolCallEvent" in events_received + + def test_repr(self): + """Test string representation of decorated hook.""" + + @hook + def my_hook(event: BeforeToolCallEvent) -> None: + pass + + repr_str = repr(my_hook) + assert "DecoratedFunctionHook" in repr_str + assert "my_hook" in repr_str + assert "BeforeToolCallEvent" in repr_str + + +class TestHookDecoratorErrors: + """Tests for error handling in @hook decorator.""" + + def test_no_parameters_error(self): + """Test error when function has no parameters.""" + with pytest.raises(ValueError, match="must have at least one parameter"): + + @hook + def no_params() -> None: + pass + + def test_no_type_hint_error(self): + """Test error when no type hint and no explicit event type.""" + with pytest.raises(ValueError, match="must have a type hint"): + + @hook + def no_hint(event) -> None: + pass + + def test_invalid_event_type_error(self): + """Test error when event type is not a BaseHookEvent subclass.""" + with pytest.raises(ValueError, match="must be a subclass of BaseHookEvent"): + + @hook(event=str) # type: ignore + def invalid_event(event) -> None: + pass + + def test_invalid_union_type_error(self): + """Test error when Union contains non-event types.""" + with pytest.raises(ValueError, match="must be subclasses of BaseHookEvent"): + + @hook + def invalid_union(event: BeforeToolCallEvent | str) -> None: # type: ignore + pass + + +class TestFunctionHookMetadata: + """Tests for FunctionHookMetadata class.""" + + def test_metadata_extraction(self): + """Test metadata extraction from function.""" + + def my_func(event: BeforeToolCallEvent) -> None: + """A test hook function.""" + pass + + metadata = FunctionHookMetadata(my_func) + hook_meta = metadata.extract_metadata() + + assert isinstance(hook_meta, HookMetadata) + assert hook_meta.name == "my_func" + assert hook_meta.description == "A test hook function." + assert hook_meta.event_types == [BeforeToolCallEvent] + assert not hook_meta.is_async + + def test_explicit_event_types_override(self): + """Test that explicit event types override type hints.""" + + def my_func(event: BeforeToolCallEvent) -> None: + pass + + # Explicitly specify different event type + metadata = FunctionHookMetadata(my_func, event_types=[AfterToolCallEvent]) + + assert metadata.event_types == [AfterToolCallEvent] + + +class TestDecoratedFunctionHook: + """Tests for DecoratedFunctionHook class.""" + + def test_hook_provider_protocol(self): + """Test that DecoratedFunctionHook implements HookProvider.""" + + @hook + def my_hook(event: BeforeToolCallEvent) -> None: + pass + + # Should have register_hooks method + assert hasattr(my_hook, "register_hooks") + assert callable(my_hook.register_hooks) + + def test_function_wrapper_preserves_metadata(self): + """Test that functools.wraps preserves function metadata.""" + + @hook + def original_function(event: BeforeToolCallEvent) -> None: + """Original docstring.""" + pass + + assert original_function.__name__ == "original_function" + assert original_function.__doc__ == "Original docstring." + + +class TestMixedHooksUsage: + """Tests for using decorated hooks alongside class-based hooks.""" + + def test_mixed_hooks_in_registry(self): + """Test using both decorator and class-based hooks together.""" + from strands.hooks import HookProvider, HookRegistry + + decorator_called = [] + class_called = [] + + @hook + def decorator_hook(event: BeforeInvocationEvent) -> None: + decorator_called.append(event) + + class ClassHook(HookProvider): + def register_hooks(self, registry: HookRegistry) -> None: + registry.add_callback(BeforeInvocationEvent, self.on_event) + + def on_event(self, event: BeforeInvocationEvent) -> None: + class_called.append(event) + + registry = HookRegistry() + registry.add_hook(decorator_hook) + registry.add_hook(ClassHook()) + + # Create mock event + mock_agent = MagicMock() + event = BeforeInvocationEvent(agent=mock_agent) + + registry.invoke_callbacks(event) + + assert len(decorator_called) == 1 + assert len(class_called) == 1 + + +class TestAgentInjection: + """Tests for automatic agent injection in @hook decorated functions.""" + + def test_agent_param_detection(self): + """Test that agent parameter is correctly detected.""" + from strands.agent import Agent + + @hook + def with_agent(event: BeforeToolCallEvent, agent: Agent) -> None: + pass + + @hook + def without_agent(event: BeforeToolCallEvent) -> None: + pass + + assert with_agent.has_agent_param is True + assert without_agent.has_agent_param is False + + def test_agent_injection_in_repr(self): + """Test that agent injection is shown in repr.""" + from strands.agent import Agent + + @hook + def with_agent(event: BeforeToolCallEvent, agent: Agent) -> None: + pass + + assert "agent_injection=True" in repr(with_agent) + + def test_hook_without_agent_param_not_injected(self): + """Test that hooks without agent param work normally.""" + received_events = [] + + @hook + def simple_hook(event: BeforeToolCallEvent) -> None: + received_events.append(event) + + # Create a mock event + mock_agent = MagicMock() + mock_event = MagicMock(spec=BeforeToolCallEvent) + mock_event.agent = mock_agent + + # Call directly + simple_hook(mock_event) + + assert len(received_events) == 1 + assert received_events[0] is mock_event + + def test_hook_with_agent_param_receives_agent(self): + """Test that hooks with agent param receive agent via injection.""" + received_data = [] + + @hook + def hook_with_agent(event: BeforeToolCallEvent, agent) -> None: + received_data.append({"event": event, "agent": agent}) + + # Create mock event with agent + mock_agent = MagicMock() + mock_agent.name = "test_agent" + mock_event = MagicMock(spec=BeforeToolCallEvent) + mock_event.agent = mock_agent + + # Call directly - agent should be extracted from event.agent + hook_with_agent(mock_event) + + assert len(received_data) == 1 + assert received_data[0]["event"] is mock_event + assert received_data[0]["agent"] is mock_agent + + def test_direct_call_with_explicit_agent(self): + """Test direct invocation with explicit agent parameter.""" + received_data = [] + + @hook + def hook_with_agent(event: BeforeToolCallEvent, agent) -> None: + received_data.append({"event": event, "agent": agent}) + + # Create mocks + mock_event = MagicMock(spec=BeforeToolCallEvent) + mock_event.agent = MagicMock(name="event_agent") + explicit_agent = MagicMock(name="explicit_agent") + + # Call with explicit agent - should use explicit over event.agent + hook_with_agent(mock_event, agent=explicit_agent) + + assert len(received_data) == 1 + assert received_data[0]["agent"] is explicit_agent + + def test_agent_injection_with_registry(self): + """Test agent injection when registered with HookRegistry.""" + received_data = [] + + @hook + def hook_with_agent(event: BeforeToolCallEvent, agent) -> None: + received_data.append({"event": event, "agent": agent}) + + # Create registry and register hook + registry = HookRegistry() + hook_with_agent.register_hooks(registry) + + # Create a real BeforeToolCallEvent (not mock) since registry uses type() + mock_agent = MagicMock() + mock_agent.name = "registry_test_agent" + + # Create actual event instance + mock_tool = MagicMock() + event = BeforeToolCallEvent( + agent=mock_agent, + selected_tool=mock_tool, + tool_use={"toolUseId": "test-123", "name": "test_tool", "input": {}}, + invocation_state={}, + ) + + # Invoke callbacks through registry + for callback in registry.get_callbacks_for(event): + callback(event) + + assert len(received_data) == 1 + assert received_data[0]["agent"] is mock_agent + + def test_async_hook_with_agent_injection(self): + """Test async hooks with agent injection.""" + import asyncio + + received_data = [] + + @hook + async def async_hook_with_agent(event: BeforeToolCallEvent, agent) -> None: + received_data.append({"event": event, "agent": agent}) + + assert async_hook_with_agent.has_agent_param is True + assert async_hook_with_agent.is_async is True + + # Create mock event + mock_agent = MagicMock() + mock_event = MagicMock(spec=BeforeToolCallEvent) + mock_event.agent = mock_agent + + # Run async hook + asyncio.run(async_hook_with_agent(mock_event)) + + assert len(received_data) == 1 + assert received_data[0]["agent"] is mock_agent + + def test_hook_metadata_includes_agent_param(self): + """Test that HookMetadata correctly reflects agent parameter.""" + + @hook + def with_agent(event: BeforeToolCallEvent, agent) -> None: + pass + + # Access internal metadata + metadata = with_agent._hook_metadata + + assert metadata.has_agent_param is True + assert metadata.name == "with_agent" + + def test_mixed_hooks_with_and_without_agent(self): + """Test that hooks with and without agent params work together.""" + results = {"with_agent": [], "without_agent": []} + + @hook + def without_agent_hook(event: BeforeToolCallEvent) -> None: + results["without_agent"].append(event) + + @hook + def with_agent_hook(event: BeforeToolCallEvent, agent) -> None: + results["with_agent"].append({"event": event, "agent": agent}) + + # Create mock event + mock_agent = MagicMock() + mock_event = MagicMock(spec=BeforeToolCallEvent) + mock_event.agent = mock_agent + + # Call both hooks + without_agent_hook(mock_event) + with_agent_hook(mock_event) + + assert len(results["without_agent"]) == 1 + assert len(results["with_agent"]) == 1 + assert results["with_agent"][0]["agent"] is mock_agent + + +class TestCoverageGaps: + """Additional tests to cover edge cases and improve coverage.""" + + def test_optional_type_hint_extracts_event_type(self): + """Test that Optional[EventType] correctly extracts the event type (skips NoneType).""" + from typing import Optional + + @hook + def optional_hook(event: Optional[BeforeToolCallEvent]) -> None: + pass + + assert isinstance(optional_hook, DecoratedFunctionHook) + assert optional_hook.event_types == [BeforeToolCallEvent] + + def test_async_hook_with_agent_via_registry(self): + """Test async hook with agent injection when invoked via registry.""" + import asyncio + + received_data = [] + + @hook + async def async_hook_with_agent(event: BeforeToolCallEvent, agent) -> None: + received_data.append({"event": event, "agent": agent}) + + # Register with registry + registry = HookRegistry() + async_hook_with_agent.register_hooks(registry) + + # Create event + mock_agent = MagicMock() + mock_agent.name = "async_registry_agent" + event = BeforeToolCallEvent( + agent=mock_agent, + selected_tool=None, + tool_use={"toolUseId": "test-123", "name": "test_tool", "input": {}}, + invocation_state={}, + ) + + # Get callbacks and invoke them (async) + async def run_callbacks(): + for callback in registry.get_callbacks_for(event): + result = callback(event) + if asyncio.iscoroutine(result): + await result + + asyncio.run(run_callbacks()) + + assert len(received_data) == 1 + assert received_data[0]["agent"] is mock_agent + + def test_sync_hook_with_agent_via_registry(self): + """Test sync hook with agent injection when invoked via registry.""" + received_data = [] + + @hook + def sync_hook_with_agent(event: BeforeToolCallEvent, agent) -> None: + received_data.append({"event": event, "agent": agent}) + + # Register with registry + registry = HookRegistry() + sync_hook_with_agent.register_hooks(registry) + + # Create event + mock_agent = MagicMock() + mock_agent.name = "sync_registry_agent" + event = BeforeToolCallEvent( + agent=mock_agent, + selected_tool=None, + tool_use={"toolUseId": "test-123", "name": "test_tool", "input": {}}, + invocation_state={}, + ) + + # Get callbacks and invoke them + for callback in registry.get_callbacks_for(event): + callback(event) + + assert len(received_data) == 1 + assert received_data[0]["agent"] is mock_agent + + def test_direct_call_without_agent_param_ignores_explicit_agent(self): + """Test that hooks without agent param work even if explicit agent is passed.""" + received_events = [] + + @hook + def no_agent_hook(event: BeforeToolCallEvent) -> None: + received_events.append(event) + + # Create mock event + mock_event = MagicMock(spec=BeforeToolCallEvent) + explicit_agent = MagicMock(name="explicit_agent") + + # Call with explicit agent - should be ignored since hook doesn't take agent + no_agent_hook(mock_event, agent=explicit_agent) + + assert len(received_events) == 1 + assert received_events[0] is mock_event + + def test_get_type_hints_failure_fallback(self): + """Test that annotation is used when get_type_hints fails.""" + # Create a function with a forward reference that might cause get_type_hints to fail + # by directly testing FunctionHookMetadata with annotation + + def func_with_annotation(event: BeforeToolCallEvent) -> None: + pass + + # This should work normally + metadata = FunctionHookMetadata(func_with_annotation) + assert metadata.event_types == [BeforeToolCallEvent] + + def test_hook_parentheses_no_args(self): + """Test @hook() syntax with empty parentheses.""" + + @hook() + def my_hook(event: BeforeToolCallEvent) -> None: + pass + + assert isinstance(my_hook, DecoratedFunctionHook) + assert my_hook.event_types == [BeforeToolCallEvent] + + def test_union_with_typing_union(self): + """Test Union from typing module explicitly.""" + from typing import Union + + @hook + def union_hook(event: Union[BeforeToolCallEvent, AfterToolCallEvent]) -> None: + pass + + assert isinstance(union_hook, DecoratedFunctionHook) + assert set(union_hook.event_types) == {BeforeToolCallEvent, AfterToolCallEvent} + + def test_function_hook_metadata_event_types_property(self): + """Test FunctionHookMetadata.event_types property.""" + + def my_func(event: BeforeToolCallEvent) -> None: + pass + + metadata = FunctionHookMetadata(my_func) + # Access via property + assert metadata.event_types == [BeforeToolCallEvent] + + def test_function_hook_metadata_has_agent_param_property(self): + """Test FunctionHookMetadata.has_agent_param property.""" + + def with_agent(event: BeforeToolCallEvent, agent) -> None: + pass + + def without_agent(event: BeforeToolCallEvent) -> None: + pass + + meta_with = FunctionHookMetadata(with_agent) + meta_without = FunctionHookMetadata(without_agent) + + # Access via property + assert meta_with.has_agent_param is True + assert meta_without.has_agent_param is False + + +class TestAdditionalErrorCases: + """Additional error case tests for complete coverage.""" + + def test_invalid_annotation_not_event_type(self): + """Test error when annotation is a non-event class type.""" + # This should trigger the error at line 216: "Event type must be a subclass of BaseHookEvent" + + class NotAnEvent: + pass + + with pytest.raises(ValueError, match="must be a subclass of BaseHookEvent"): + + @hook + def invalid_hook(event: NotAnEvent) -> None: + pass + + def test_invalid_single_event_type_in_explicit_list(self): + """Test error when explicit event list contains invalid type.""" + with pytest.raises(ValueError, match="must be a subclass of BaseHookEvent"): + + @hook(events=[str]) # type: ignore + def invalid_events_hook(event) -> None: + pass + + +class TestEdgeCases: + """Edge case tests for remaining coverage gaps.""" + + def test_get_type_hints_exception_fallback(self): + """Test fallback when get_type_hints raises an exception. + + This can happen with certain forward references or complex type annotations. + """ + # Create a function with annotation that get_type_hints might struggle with + # but that still has a valid annotation + + def func_with_annotation(event: BeforeToolCallEvent) -> None: + pass + + # Manually test by mocking get_type_hints to raise + import unittest.mock as mock + + with mock.patch("strands.hooks.decorator.get_type_hints", side_effect=Exception("Type hint error")): + metadata = FunctionHookMetadata(func_with_annotation) + # Should fall back to annotation + assert metadata.event_types == [BeforeToolCallEvent] + + def test_annotation_fallback_when_type_hints_empty(self): + """Test annotation is used when get_type_hints returns empty dict for param.""" + import unittest.mock as mock + + def func_with_annotation(event: BeforeToolCallEvent) -> None: + pass + + # Mock get_type_hints to return empty dict (param not in hints) + with mock.patch("strands.hooks.decorator.get_type_hints", return_value={}): + metadata = FunctionHookMetadata(func_with_annotation) + # Should fall back to first_param.annotation + assert metadata.event_types == [BeforeToolCallEvent]