Skip to content

Commit 686f8f8

Browse files
committed
feat(hooks): add BeforeStreamChunkEvent for true stream chunk interception
Add BeforeStreamChunkEvent hook that fires BEFORE each stream chunk is processed, enabling true interception capabilities: - Monitor streaming progress in real-time - Modify chunk content before processing (affects final message) - Skip chunks entirely by setting skip=True (excluded from final message) - Implement content transformation (e.g., redaction, translation) Implementation details: - Added ChunkInterceptor callback type to streaming.py - Modified process_stream() to invoke interceptor BEFORE processing - Modified stream_messages() to accept chunk_interceptor parameter - event_loop.py creates interceptor that invokes BeforeStreamChunkEvent When skip=True: - The chunk is not processed at all - No events (ModelStreamChunkEvent, TextStreamEvent) are yielded - The chunk does not contribute to the final message When chunk is modified: - The modified chunk is used for all downstream processing - TextStreamEvent will contain the modified text - The final message will contain the modified content
1 parent a110149 commit 686f8f8

7 files changed

Lines changed: 433 additions & 5 deletions

File tree

src/strands/event_loop/event_loop.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from opentelemetry import trace as trace_api
1717

18-
from ..hooks import AfterModelCallEvent, BeforeModelCallEvent, MessageAddedEvent
18+
from ..hooks import AfterModelCallEvent, BeforeModelCallEvent, BeforeStreamChunkEvent, MessageAddedEvent
1919
from ..telemetry.metrics import Trace
2020
from ..telemetry.tracer import Tracer, get_tracer
2121
from ..tools._validator import validate_and_prepare_tools
@@ -39,7 +39,7 @@
3939
MaxTokensReachedException,
4040
StructuredOutputException,
4141
)
42-
from ..types.streaming import StopReason
42+
from ..types.streaming import StopReason, StreamEvent
4343
from ..types.tools import ToolResult, ToolUse
4444
from ._recover_message_on_max_tokens_reached import recover_message_on_max_tokens_reached
4545
from ._retry import ModelRetryStrategy
@@ -327,6 +327,18 @@ async def _handle_model_execution(
327327
tool_specs = [tool_spec] if tool_spec else []
328328
else:
329329
tool_specs = agent.tool_registry.get_all_tool_specs()
330+
331+
# Create chunk interceptor that invokes BeforeStreamChunkEvent hook
332+
async def chunk_interceptor(chunk: StreamEvent) -> tuple[StreamEvent, bool]:
333+
"""Intercept chunks and invoke BeforeStreamChunkEvent hook."""
334+
stream_chunk_event = BeforeStreamChunkEvent(
335+
agent=agent,
336+
chunk=chunk,
337+
invocation_state=invocation_state,
338+
)
339+
await agent.hooks.invoke_callbacks_async(stream_chunk_event)
340+
return stream_chunk_event.chunk, stream_chunk_event.skip
341+
330342
try:
331343
async for event in stream_messages(
332344
agent.model,
@@ -337,6 +349,7 @@ async def _handle_model_execution(
337349
tool_choice=structured_output_context.tool_choice,
338350
invocation_state=invocation_state,
339351
cancel_signal=agent._cancel_signal,
352+
chunk_interceptor=chunk_interceptor,
340353
):
341354
yield event
342355

src/strands/event_loop/streaming.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import threading
66
import time
77
import warnings
8-
from collections.abc import AsyncGenerator, AsyncIterable
8+
from collections.abc import AsyncGenerator, AsyncIterable, Awaitable, Callable
99
from typing import Any
1010

1111
from ..models.model import Model
@@ -42,6 +42,10 @@
4242

4343
logger = logging.getLogger(__name__)
4444

45+
# Type for chunk interceptor callback
46+
# Takes a chunk and returns (modified_chunk, skip) where skip=True means don't process this chunk
47+
ChunkInterceptor = Callable[[StreamEvent], Awaitable[tuple[StreamEvent, bool]]]
48+
4549

4650
def _normalize_messages(messages: Messages) -> Messages:
4751
"""Remove or replace blank text in message content.
@@ -387,13 +391,17 @@ async def process_stream(
387391
chunks: AsyncIterable[StreamEvent],
388392
start_time: float | None = None,
389393
cancel_signal: threading.Event | None = None,
394+
chunk_interceptor: ChunkInterceptor | None = None,
390395
) -> AsyncGenerator[TypedEvent, None]:
391396
"""Processes the response stream from the API, constructing the final message and extracting usage metrics.
392397
393398
Args:
394399
chunks: The chunks of the response stream from the model.
395400
start_time: Time when the model request is initiated
396401
cancel_signal: Optional threading.Event to check for cancellation during streaming.
402+
chunk_interceptor: Optional callback to intercept and modify chunks before processing.
403+
The callback receives a chunk and returns (modified_chunk, skip). If skip is True,
404+
the chunk is not processed or yielded.
397405
398406
Yields:
399407
The reason for stopping, the constructed message, and the usage metrics.
@@ -427,6 +435,12 @@ async def process_stream(
427435
)
428436
return
429437

438+
# Invoke chunk interceptor BEFORE processing if provided
439+
if chunk_interceptor is not None:
440+
chunk, skip = await chunk_interceptor(chunk)
441+
if skip:
442+
continue
443+
430444
# Track first byte time when we get first content
431445
if first_byte_time is None and ("contentBlockDelta" in chunk or "contentBlockStart" in chunk):
432446
first_byte_time = time.time()
@@ -464,6 +478,7 @@ async def stream_messages(
464478
system_prompt_content: list[SystemContentBlock] | None = None,
465479
invocation_state: dict[str, Any] | None = None,
466480
cancel_signal: threading.Event | None = None,
481+
chunk_interceptor: ChunkInterceptor | None = None,
467482
**kwargs: Any,
468483
) -> AsyncGenerator[TypedEvent, None]:
469484
"""Streams messages to the model and processes the response.
@@ -478,6 +493,9 @@ async def stream_messages(
478493
system prompt data.
479494
invocation_state: Caller-provided state/context that was passed to the agent when it was invoked.
480495
cancel_signal: Optional threading.Event to check for cancellation during streaming.
496+
chunk_interceptor: Optional callback to intercept and modify chunks before processing.
497+
The callback receives a chunk and returns (modified_chunk, skip). If skip is True,
498+
the chunk is not processed or yielded.
481499
**kwargs: Additional keyword arguments for future extensibility.
482500
483501
Yields:
@@ -497,5 +515,5 @@ async def stream_messages(
497515
invocation_state=invocation_state,
498516
)
499517

500-
async for event in process_stream(chunks, start_time, cancel_signal):
518+
async for event in process_stream(chunks, start_time, cancel_signal, chunk_interceptor):
501519
yield event

src/strands/hooks/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def log_end(self, event: AfterInvocationEvent) -> None:
4141
BeforeModelCallEvent,
4242
BeforeMultiAgentInvocationEvent,
4343
BeforeNodeCallEvent,
44+
BeforeStreamChunkEvent,
4445
BeforeToolCallEvent,
4546
MessageAddedEvent,
4647
MultiAgentInitializedEvent,
@@ -50,6 +51,7 @@ def log_end(self, event: AfterInvocationEvent) -> None:
5051
__all__ = [
5152
"AgentInitializedEvent",
5253
"BeforeInvocationEvent",
54+
"BeforeStreamChunkEvent",
5355
"BeforeToolCallEvent",
5456
"AfterToolCallEvent",
5557
"BeforeModelCallEvent",

src/strands/hooks/events.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from ..types.agent import AgentInput
1616
from ..types.content import Message, Messages
1717
from ..types.interrupt import _Interruptible
18-
from ..types.streaming import StopReason
18+
from ..types.streaming import StopReason, StreamEvent
1919
from ..types.tools import AgentTool, ToolResult, ToolUse
2020
from .registry import BaseHookEvent, HookEvent
2121

@@ -303,6 +303,48 @@ def should_reverse_callbacks(self) -> bool:
303303
return True
304304

305305

306+
@dataclass
307+
class BeforeStreamChunkEvent(HookEvent):
308+
"""Event triggered before each stream chunk is processed.
309+
310+
This event is fired for each chunk received from the model BEFORE the chunk
311+
is processed for message building or yielded as stream events. Hook providers
312+
can use this event to:
313+
314+
- Monitor streaming progress in real-time
315+
- Modify chunk content before processing (affects final message and all events)
316+
- Filter/skip chunks entirely by setting skip=True
317+
- Implement content transformation (e.g., redaction, translation)
318+
319+
When skip=True:
320+
- The chunk is not processed at all
321+
- No events (ModelStreamChunkEvent, TextStreamEvent, etc.) are yielded
322+
- The chunk does not contribute to the final message
323+
324+
When chunk is modified:
325+
- The modified chunk is used for all downstream processing
326+
- TextStreamEvent will contain the modified text
327+
- The final message will contain the modified content
328+
329+
Performance Note:
330+
This event fires for every stream chunk, so callbacks should execute
331+
quickly to avoid impacting streaming latency.
332+
333+
Attributes:
334+
chunk: The raw stream event from the model. Can be modified by hooks
335+
to transform content before processing.
336+
skip: When True, the chunk is skipped entirely (not processed or yielded).
337+
invocation_state: State passed through agent invocation.
338+
"""
339+
340+
chunk: StreamEvent
341+
invocation_state: dict[str, Any] = field(default_factory=dict)
342+
skip: bool = False
343+
344+
def _can_write(self, name: str) -> bool:
345+
return name in ["chunk", "skip"]
346+
347+
306348
# Multiagent hook events start here
307349
@dataclass
308350
class MultiAgentInitializedEvent(BaseHookEvent):

tests/strands/agent/hooks/test_events.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
AgentInitializedEvent,
1111
BeforeInvocationEvent,
1212
BeforeModelCallEvent,
13+
BeforeStreamChunkEvent,
1314
BeforeToolCallEvent,
1415
MessageAddedEvent,
1516
)
@@ -260,3 +261,58 @@ def test_after_invocation_event_resume_accepts_various_input_types(agent):
260261
# None to stop
261262
event.resume = None
262263
assert event.resume is None
264+
265+
266+
@pytest.fixture
267+
def before_stream_chunk_event(agent):
268+
chunk = {"contentBlockDelta": {"delta": {"text": "Hello"}}}
269+
return BeforeStreamChunkEvent(
270+
agent=agent,
271+
chunk=chunk,
272+
invocation_state={"test": "state"},
273+
)
274+
275+
276+
def test_before_stream_chunk_event_should_not_reverse_callbacks(before_stream_chunk_event):
277+
"""Test that BeforeStreamChunkEvent does not reverse callbacks."""
278+
assert before_stream_chunk_event.should_reverse_callbacks is False
279+
280+
281+
def test_before_stream_chunk_event_can_write_chunk(before_stream_chunk_event):
282+
"""Test that BeforeStreamChunkEvent.chunk is writable."""
283+
new_chunk = {"contentBlockDelta": {"delta": {"text": "Modified"}}}
284+
before_stream_chunk_event.chunk = new_chunk
285+
assert before_stream_chunk_event.chunk == new_chunk
286+
287+
288+
def test_before_stream_chunk_event_can_write_skip(before_stream_chunk_event):
289+
"""Test that BeforeStreamChunkEvent.skip is writable."""
290+
assert before_stream_chunk_event.skip is False
291+
before_stream_chunk_event.skip = True
292+
assert before_stream_chunk_event.skip is True
293+
294+
295+
def test_before_stream_chunk_event_cannot_write_agent(before_stream_chunk_event):
296+
"""Test that BeforeStreamChunkEvent.agent is not writable."""
297+
with pytest.raises(AttributeError, match="Property agent is not writable"):
298+
before_stream_chunk_event.agent = Mock()
299+
300+
301+
def test_before_stream_chunk_event_cannot_write_invocation_state(before_stream_chunk_event):
302+
"""Test that BeforeStreamChunkEvent.invocation_state is not writable."""
303+
with pytest.raises(AttributeError, match="Property invocation_state is not writable"):
304+
before_stream_chunk_event.invocation_state = {}
305+
306+
307+
def test_before_stream_chunk_event_skip_defaults_to_false(agent):
308+
"""Test that BeforeStreamChunkEvent.skip defaults to False."""
309+
chunk = {"contentBlockDelta": {"delta": {"text": "Hello"}}}
310+
event = BeforeStreamChunkEvent(agent=agent, chunk=chunk)
311+
assert event.skip is False
312+
313+
314+
def test_before_stream_chunk_event_invocation_state_defaults_to_empty(agent):
315+
"""Test that BeforeStreamChunkEvent.invocation_state defaults to empty dict."""
316+
chunk = {"contentBlockDelta": {"delta": {"text": "Hello"}}}
317+
event = BeforeStreamChunkEvent(agent=agent, chunk=chunk)
318+
assert event.invocation_state == {}

0 commit comments

Comments
 (0)