Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 141 additions & 0 deletions examples/test_stream_chunk_hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
"""Test BeforeStreamChunkEvent with a real OpenAI model.

Tests three scenarios:
1. Chunk capture - intercept and log all stream chunks
2. Chunk modification - redact text in-flight
3. Chunk skipping - suppress content deltas entirely
"""

import asyncio
import os

from strands import Agent
from strands.hooks import BeforeStreamChunkEvent
from strands.models.openai import OpenAIModel


def make_model():
    """Construct the OpenAI model used by every scenario below.

    Raises:
        KeyError: if OPENAI_API_KEY is not set in the environment.
    """
    api_key = os.environ["OPENAI_API_KEY"]
    return OpenAIModel(
        client_args={"api_key": api_key},
        model_id="gpt-4o-mini",
    )


# ---------- Test 1: capture chunks ----------
def test_capture():
    """Scenario 1: a hook that records every raw stream chunk."""
    print("=" * 60)
    print("TEST 1: Capture every stream chunk")
    print("=" * 60)

    captured = []

    async def recorder(event: BeforeStreamChunkEvent):
        # Shallow-copy so later in-place edits by other hooks can't alias our log.
        captured.append(event.chunk.copy())

    agent = Agent(
        model=make_model(),
        system_prompt="Reply in exactly 5 words.",
    )
    agent.hooks.add_callback(BeforeStreamChunkEvent, recorder)

    result = agent("Say hello")

    print(f"\nFinal text : {result.message['content'][0].get('text', '')}")
    print(f"Chunks seen: {len(captured)}")

    seen_keys = {key for chunk in captured for key in chunk}
    print(f"Chunk types: {sorted(seen_keys)}")

    assert captured, "Expected at least one chunk"
    assert "messageStart" in seen_keys, "Missing messageStart"
    assert "contentBlockDelta" in seen_keys, "Missing contentBlockDelta"
    print("PASSED\n")


# ---------- Test 2: modify (redact) chunks ----------
def test_modify():
    """Scenario 2: a hook that rewrites text deltas in-flight.

    Asserts that every streamed text event and the final assembled message
    contain only the redaction marker, never the model's original text.
    """
    print("=" * 60)
    print("TEST 2: Redact text in-flight")
    print("=" * 60)

    async def redact(event: BeforeStreamChunkEvent):
        # Rewrite only text deltas. Rebuild the contentBlockDelta envelope
        # instead of replacing the whole chunk so sibling keys (e.g.
        # contentBlockIndex) survive and deltas stay routed to the right block.
        block = event.chunk.get("contentBlockDelta")
        if block and "text" in block.get("delta", {}):
            patched = dict(block)
            patched["delta"] = {"text": "[REDACTED]"}
            event.chunk = {"contentBlockDelta": patched}

    agent = Agent(
        model=make_model(),
        system_prompt="Say exactly: the secret code is 12345",
    )
    agent.hooks.add_callback(BeforeStreamChunkEvent, redact)

    text_events = []
    result = None

    async def run():
        nonlocal result
        async for event in agent.stream_async("go"):
            if "data" in event:
                text_events.append(event["data"])
            if "result" in event:
                result = event["result"]

    asyncio.run(run())

    # Guard against an empty content list before indexing into it.
    content = result.message["content"]
    final_text = content[0].get("text", "") if content else ""
    print(f"\nStreamed text events: {text_events}")
    print(f"Final message text : {final_text}")

    assert all(t == "[REDACTED]" for t in text_events), "Some text events were not redacted"
    assert "secret" not in final_text.lower(), "Final message still contains secret"
    assert "[REDACTED]" in final_text, "Final message missing redaction marker"
    print("PASSED\n")


# ---------- Test 3: skip content deltas ----------
def test_skip():
    """Scenario 3: a hook that suppresses every content delta via skip=True."""
    print("=" * 60)
    print("TEST 3: Skip content deltas entirely")
    print("=" * 60)

    async def drop_deltas(event: BeforeStreamChunkEvent):
        # Flag content deltas so the event loop drops them before processing.
        if "contentBlockDelta" in event.chunk:
            event.skip = True

    agent = Agent(
        model=make_model(),
        system_prompt="Reply with a greeting",
    )
    agent.hooks.add_callback(BeforeStreamChunkEvent, drop_deltas)

    result = None
    text_events = []

    async def consume():
        nonlocal result
        async for event in agent.stream_async("hi"):
            if "data" in event:
                text_events.append(event["data"])
            if "result" in event:
                result = event["result"]

    asyncio.run(consume())

    print(f"\nText events received: {len(text_events)}")
    print(f"Final content : {result.message['content']}")

    assert not text_events, f"Expected 0 text events, got {len(text_events)}"
    assert result.message["content"] == [], "Expected empty content"
    print("PASSED\n")


if __name__ == "__main__":
    # Run the three scenarios in order; each asserts its own expectations.
    for scenario in (test_capture, test_modify, test_skip):
        scenario()
    print("=" * 60)
    print("ALL TESTS PASSED")
    print("=" * 60)
16 changes: 14 additions & 2 deletions src/strands/event_loop/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from opentelemetry import trace as trace_api

from ..hooks import AfterModelCallEvent, BeforeModelCallEvent, MessageAddedEvent
from ..hooks import AfterModelCallEvent, BeforeModelCallEvent, BeforeStreamChunkEvent, MessageAddedEvent
from ..telemetry.metrics import Trace
from ..telemetry.tracer import Tracer, get_tracer
from ..tools._validator import validate_and_prepare_tools
Expand All @@ -39,7 +39,7 @@
MaxTokensReachedException,
StructuredOutputException,
)
from ..types.streaming import StopReason
from ..types.streaming import StopReason, StreamEvent
from ..types.tools import ToolResult, ToolUse
from ._recover_message_on_max_tokens_reached import recover_message_on_max_tokens_reached
from ._retry import ModelRetryStrategy
Expand Down Expand Up @@ -338,6 +338,17 @@ async def _handle_model_execution(
else:
tool_specs = agent.tool_registry.get_all_tool_specs()

# Create chunk interceptor that invokes BeforeStreamChunkEvent hook
async def chunk_interceptor(chunk: StreamEvent) -> tuple[StreamEvent, bool]:
"""Intercept chunks and invoke BeforeStreamChunkEvent hook."""
stream_chunk_event = BeforeStreamChunkEvent(
agent=agent,
chunk=chunk,
invocation_state=invocation_state,
)
await agent.hooks.invoke_callbacks_async(stream_chunk_event)
return stream_chunk_event.chunk, stream_chunk_event.skip

async for event in stream_messages(
agent.model,
agent.system_prompt,
Expand All @@ -348,6 +359,7 @@ async def _handle_model_execution(
invocation_state=invocation_state,
model_state=agent._model_state,
cancel_signal=agent._cancel_signal,
chunk_interceptor=chunk_interceptor,
):
yield event

Expand Down
22 changes: 20 additions & 2 deletions src/strands/event_loop/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import threading
import time
import warnings
from collections.abc import AsyncGenerator, AsyncIterable
from collections.abc import AsyncGenerator, AsyncIterable, Awaitable, Callable
from typing import Any

from ..models.model import Model
Expand Down Expand Up @@ -42,6 +42,10 @@

logger = logging.getLogger(__name__)

# Type for chunk interceptor callback
# Takes a chunk and returns (modified_chunk, skip) where skip=True means don't process this chunk
ChunkInterceptor = Callable[[StreamEvent], Awaitable[tuple[StreamEvent, bool]]]


def _normalize_messages(messages: Messages) -> Messages:
"""Remove or replace blank text in message content.
Expand Down Expand Up @@ -387,13 +391,17 @@ async def process_stream(
chunks: AsyncIterable[StreamEvent],
start_time: float | None = None,
cancel_signal: threading.Event | None = None,
chunk_interceptor: ChunkInterceptor | None = None,
) -> AsyncGenerator[TypedEvent, None]:
"""Processes the response stream from the API, constructing the final message and extracting usage metrics.

Args:
chunks: The chunks of the response stream from the model.
start_time: Time when the model request is initiated
cancel_signal: Optional threading.Event to check for cancellation during streaming.
chunk_interceptor: Optional callback to intercept and modify chunks before processing.
The callback receives a chunk and returns (modified_chunk, skip). If skip is True,
the chunk is not processed or yielded.

Yields:
The reason for stopping, the constructed message, and the usage metrics.
Expand Down Expand Up @@ -427,6 +435,12 @@ async def process_stream(
)
return

# Invoke chunk interceptor BEFORE processing if provided
if chunk_interceptor is not None:
chunk, skip = await chunk_interceptor(chunk)
if skip:
continue

# Track first byte time when we get first content
if first_byte_time is None and ("contentBlockDelta" in chunk or "contentBlockStart" in chunk):
first_byte_time = time.time()
Expand Down Expand Up @@ -465,6 +479,7 @@ async def stream_messages(
invocation_state: dict[str, Any] | None = None,
model_state: dict[str, Any] | None = None,
cancel_signal: threading.Event | None = None,
chunk_interceptor: ChunkInterceptor | None = None,
**kwargs: Any,
) -> AsyncGenerator[TypedEvent, None]:
"""Streams messages to the model and processes the response.
Expand All @@ -480,6 +495,9 @@ async def stream_messages(
invocation_state: Caller-provided state/context that was passed to the agent when it was invoked.
model_state: Runtime state for model providers (e.g., server-side response ids).
cancel_signal: Optional threading.Event to check for cancellation during streaming.
chunk_interceptor: Optional callback to intercept and modify chunks before processing.
The callback receives a chunk and returns (modified_chunk, skip). If skip is True,
the chunk is not processed or yielded.
**kwargs: Additional keyword arguments for future extensibility.

Yields:
Expand All @@ -500,5 +518,5 @@ async def stream_messages(
model_state=model_state,
)

async for event in process_stream(chunks, start_time, cancel_signal):
async for event in process_stream(chunks, start_time, cancel_signal, chunk_interceptor):
yield event
2 changes: 2 additions & 0 deletions src/strands/hooks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def log_end(self, event: AfterInvocationEvent) -> None:
BeforeModelCallEvent,
BeforeMultiAgentInvocationEvent,
BeforeNodeCallEvent,
BeforeStreamChunkEvent,
BeforeToolCallEvent,
MessageAddedEvent,
MultiAgentInitializedEvent,
Expand All @@ -50,6 +51,7 @@ def log_end(self, event: AfterInvocationEvent) -> None:
__all__ = [
"AgentInitializedEvent",
"BeforeInvocationEvent",
"BeforeStreamChunkEvent",
"BeforeToolCallEvent",
"AfterToolCallEvent",
"BeforeModelCallEvent",
Expand Down
44 changes: 43 additions & 1 deletion src/strands/hooks/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ..types.agent import AgentInput
from ..types.content import Message, Messages
from ..types.interrupt import _Interruptible
from ..types.streaming import StopReason
from ..types.streaming import StopReason, StreamEvent
from ..types.tools import AgentTool, ToolResult, ToolUse
from .registry import BaseHookEvent, HookEvent

Expand Down Expand Up @@ -303,6 +303,48 @@ def should_reverse_callbacks(self) -> bool:
return True


@dataclass
class BeforeStreamChunkEvent(HookEvent):
    """Fired for each raw model stream chunk before any processing occurs.

    The hook runs ahead of message construction and ahead of any downstream
    stream events being yielded, which lets callbacks:

    - observe streaming progress chunk-by-chunk in real time,
    - rewrite ``chunk`` so the modified payload drives all downstream
      processing (text stream events and the final message both reflect it),
    - drop a chunk outright by setting ``skip`` to True.

    Setting ``skip=True`` means the chunk is discarded entirely: nothing is
    yielded for it and it contributes nothing to the final message.

    Performance Note:
        This fires once per chunk, so keep callbacks fast to avoid adding
        latency to the stream.

    Attributes:
        chunk: The raw stream event from the model; writable so hooks can
            transform content before it is processed.
        skip: When True, the chunk is neither processed nor yielded.
        invocation_state: State passed through the agent invocation.
    """

    chunk: StreamEvent
    invocation_state: dict[str, Any] = field(default_factory=dict)
    skip: bool = False

    def _can_write(self, name: str) -> bool:
        # Hooks may replace the chunk payload or toggle skip; every other
        # attribute on the event is read-only.
        return name in ("chunk", "skip")


# Multiagent hook events start here
@dataclass
class MultiAgentInitializedEvent(BaseHookEvent):
Expand Down
Loading