diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 41122efc5..05b0eff50 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -403,7 +403,9 @@ async def _handle_model_execution( # Add the response message to the conversation agent.messages.append(message) - await agent.hooks.invoke_callbacks_async(MessageAddedEvent(agent=agent, message=message)) + await agent.hooks.invoke_callbacks_async( + MessageAddedEvent(agent=agent, message=message, usage=usage, metrics=metrics) + ) # Update metrics agent.event_loop_metrics.update_usage(usage) diff --git a/src/strands/hooks/events.py b/src/strands/hooks/events.py index ad40dfd7f..7a519e993 100644 --- a/src/strands/hooks/events.py +++ b/src/strands/hooks/events.py @@ -13,6 +13,7 @@ from ..agent.agent_result import AgentResult from ..types.content import Message, Messages +from ..types.event_loop import Metrics, Usage from ..types.interrupt import _Interruptible from ..types.streaming import StopReason from ..types.tools import AgentTool, ToolResult, ToolUse @@ -105,6 +106,8 @@ class MessageAddedEvent(HookEvent): """ message: Message + usage: Usage | None = None + metrics: Metrics | None = None @dataclass diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py index e8b7e5077..d7f45918d 100644 --- a/tests/strands/agent/test_agent_hooks.py +++ b/tests/strands/agent/test_agent_hooks.py @@ -178,7 +178,12 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_u exception=None, ) - assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[1]) + assert next(events) == MessageAddedEvent( + agent=agent, + message=agent.messages[1], + usage={"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, + metrics={"latencyMs": 0, "timeToFirstByteMs": 0}, + ) assert next(events) == BeforeToolCallEvent( agent=agent, selected_tool=agent_tool, @@ -202,7 +207,12 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_u ), exception=None, ) - assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[3]) + assert next(events) == MessageAddedEvent( + agent=agent, + message=agent.messages[3], + usage={"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, + metrics={"latencyMs": 0, "timeToFirstByteMs": 0}, + ) assert next(events) == AfterInvocationEvent(agent=agent, result=result) @@ -248,7 +258,12 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m exception=None, ) - assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[1]) + assert next(events) == MessageAddedEvent( + agent=agent, + message=agent.messages[1], + usage={"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, + metrics={"latencyMs": 0, "timeToFirstByteMs": 0}, + ) assert next(events) == BeforeToolCallEvent( agent=agent, selected_tool=agent_tool, @@ -272,7 +287,12 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m ), exception=None, ) - assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[3]) + assert next(events) == MessageAddedEvent( + agent=agent, + message=agent.messages[3], + usage={"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, + metrics={"latencyMs": 0, "timeToFirstByteMs": 0}, + ) assert next(events) == AfterInvocationEvent(agent=agent, result=result) diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index a76a5b6b5..9070200c8 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -884,7 +884,10 @@ async def test_event_loop_cycle_exception_model_hooks(mock_sleep, agent, model, # Final message assert next(events) == MessageAddedEvent( - agent=agent, message={"content": [{"text": "test text"}], "role": "assistant"} + agent=agent, + message={"content": [{"text": "test text"}], "role": "assistant"}, + usage={"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, + metrics={"latencyMs": 0, "timeToFirstByteMs": 0}, )