From dc8c2aacd06b2f8d553893b9ac64b879545105cd Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Tue, 24 Feb 2026 12:28:27 -0500 Subject: [PATCH 1/2] feat(hooks): add BeforeStreamChunkEvent for true stream chunk interception Add BeforeStreamChunkEvent hook that fires BEFORE each stream chunk is processed, enabling true interception capabilities: - Monitor streaming progress in real-time - Modify chunk content before processing (affects final message) - Skip chunks entirely by setting skip=True (excluded from final message) - Implement content transformation (e.g., redaction, translation) Implementation details: - Added ChunkInterceptor callback type to streaming.py - Modified process_stream() to invoke interceptor BEFORE processing - Modified stream_messages() to accept chunk_interceptor parameter - event_loop.py creates interceptor that invokes BeforeStreamChunkEvent When skip=True: - The chunk is not processed at all - No events (ModelStreamChunkEvent, TextStreamEvent) are yielded - The chunk does not contribute to the final message When chunk is modified: - The modified chunk is used for all downstream processing - TextStreamEvent will contain the modified text - The final message will contain the modified content --- src/strands/event_loop/event_loop.py | 16 +- src/strands/event_loop/streaming.py | 22 ++- src/strands/hooks/__init__.py | 2 + src/strands/hooks/events.py | 44 ++++- tests/strands/agent/hooks/test_events.py | 56 ++++++ tests/strands/agent/test_agent_hooks.py | 175 +++++++++++++++++++ tests_integ/hooks/test_stream_chunk_event.py | 122 +++++++++++++ 7 files changed, 432 insertions(+), 5 deletions(-) create mode 100644 tests_integ/hooks/test_stream_chunk_event.py diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index b4af16058..b1d3bb788 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -15,7 +15,7 @@ from opentelemetry import trace as trace_api -from ..hooks import AfterModelCallEvent, BeforeModelCallEvent, MessageAddedEvent +from ..hooks import AfterModelCallEvent, BeforeModelCallEvent, BeforeStreamChunkEvent, MessageAddedEvent from ..telemetry.metrics import Trace from ..telemetry.tracer import Tracer, get_tracer from ..tools._validator import validate_and_prepare_tools @@ -39,7 +39,7 @@ MaxTokensReachedException, StructuredOutputException, ) -from ..types.streaming import StopReason +from ..types.streaming import StopReason, StreamEvent from ..types.tools import ToolResult, ToolUse from ._recover_message_on_max_tokens_reached import recover_message_on_max_tokens_reached from ._retry import ModelRetryStrategy @@ -338,6 +338,17 @@ async def _handle_model_execution( else: tool_specs = agent.tool_registry.get_all_tool_specs() + # Create chunk interceptor that invokes BeforeStreamChunkEvent hook + async def chunk_interceptor(chunk: StreamEvent) -> tuple[StreamEvent, bool]: + """Intercept chunks and invoke BeforeStreamChunkEvent hook.""" + stream_chunk_event = BeforeStreamChunkEvent( + agent=agent, + chunk=chunk, + invocation_state=invocation_state, + ) + await agent.hooks.invoke_callbacks_async(stream_chunk_event) + return stream_chunk_event.chunk, stream_chunk_event.skip + async for event in stream_messages( agent.model, agent.system_prompt, @@ -348,6 +359,7 @@ async def _handle_model_execution( invocation_state=invocation_state, model_state=agent._model_state, cancel_signal=agent._cancel_signal, + chunk_interceptor=chunk_interceptor, ): yield event diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index 0a1161135..a1dc79e65 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -5,7 +5,7 @@ import threading import time import warnings -from collections.abc import AsyncGenerator, AsyncIterable +from collections.abc import AsyncGenerator, AsyncIterable, Awaitable, Callable from typing import Any from ..models.model import Model @@ -42,6 +42,10 @@ logger = logging.getLogger(__name__) +# Type for chunk interceptor callback +# Takes a chunk and returns (modified_chunk, skip) where skip=True means don't process this chunk +ChunkInterceptor = Callable[[StreamEvent], Awaitable[tuple[StreamEvent, bool]]] + def _normalize_messages(messages: Messages) -> Messages: """Remove or replace blank text in message content. @@ -387,6 +391,7 @@ async def process_stream( chunks: AsyncIterable[StreamEvent], start_time: float | None = None, cancel_signal: threading.Event | None = None, + chunk_interceptor: ChunkInterceptor | None = None, ) -> AsyncGenerator[TypedEvent, None]: """Processes the response stream from the API, constructing the final message and extracting usage metrics. @@ -394,6 +399,9 @@ async def process_stream( chunks: The chunks of the response stream from the model. start_time: Time when the model request is initiated cancel_signal: Optional threading.Event to check for cancellation during streaming. + chunk_interceptor: Optional callback to intercept and modify chunks before processing. + The callback receives a chunk and returns (modified_chunk, skip). If skip is True, + the chunk is not processed or yielded. Yields: The reason for stopping, the constructed message, and the usage metrics. @@ -427,6 +435,12 @@ async def process_stream( ) return + # Invoke chunk interceptor BEFORE processing if provided + if chunk_interceptor is not None: + chunk, skip = await chunk_interceptor(chunk) + if skip: + continue + # Track first byte time when we get first content if first_byte_time is None and ("contentBlockDelta" in chunk or "contentBlockStart" in chunk): first_byte_time = time.time() @@ -465,6 +479,7 @@ async def stream_messages( invocation_state: dict[str, Any] | None = None, model_state: dict[str, Any] | None = None, cancel_signal: threading.Event | None = None, + chunk_interceptor: ChunkInterceptor | None = None, **kwargs: Any, ) -> AsyncGenerator[TypedEvent, None]: """Streams messages to the model and processes the response. @@ -480,6 +495,9 @@ async def stream_messages( invocation_state: Caller-provided state/context that was passed to the agent when it was invoked. model_state: Runtime state for model providers (e.g., server-side response ids). cancel_signal: Optional threading.Event to check for cancellation during streaming. + chunk_interceptor: Optional callback to intercept and modify chunks before processing. + The callback receives a chunk and returns (modified_chunk, skip). If skip is True, + the chunk is not processed or yielded. **kwargs: Additional keyword arguments for future extensibility. Yields: @@ -500,5 +518,5 @@ async def stream_messages( model_state=model_state, ) - async for event in process_stream(chunks, start_time, cancel_signal): + async for event in process_stream(chunks, start_time, cancel_signal, chunk_interceptor): yield event diff --git a/src/strands/hooks/__init__.py b/src/strands/hooks/__init__.py index 96c7f577b..6b21390aa 100644 --- a/src/strands/hooks/__init__.py +++ b/src/strands/hooks/__init__.py @@ -41,6 +41,7 @@ def log_end(self, event: AfterInvocationEvent) -> None: BeforeModelCallEvent, BeforeMultiAgentInvocationEvent, BeforeNodeCallEvent, + BeforeStreamChunkEvent, BeforeToolCallEvent, MessageAddedEvent, MultiAgentInitializedEvent, @@ -50,6 +51,7 @@ def log_end(self, event: AfterInvocationEvent) -> None: __all__ = [ "AgentInitializedEvent", "BeforeInvocationEvent", + "BeforeStreamChunkEvent", "BeforeToolCallEvent", "AfterToolCallEvent", "BeforeModelCallEvent", diff --git a/src/strands/hooks/events.py b/src/strands/hooks/events.py index 9186e0e70..7088a91a5 100644 --- a/src/strands/hooks/events.py +++ b/src/strands/hooks/events.py @@ -15,7 +15,7 @@ from ..types.agent import AgentInput from ..types.content import Message, Messages from ..types.interrupt import _Interruptible -from ..types.streaming import StopReason +from ..types.streaming import StopReason, StreamEvent from ..types.tools import AgentTool, ToolResult, ToolUse from .registry import BaseHookEvent, HookEvent @@ -303,6 +303,48 @@ def should_reverse_callbacks(self) -> bool: return True +@dataclass +class BeforeStreamChunkEvent(HookEvent): + """Event triggered before each stream chunk is processed. + + This event is fired for each chunk received from the model BEFORE the chunk + is processed for message building or yielded as stream events. Hook providers + can use this event to: + + - Monitor streaming progress in real-time + - Modify chunk content before processing (affects final message and all events) + - Filter/skip chunks entirely by setting skip=True + - Implement content transformation (e.g., redaction, translation) + + When skip=True: + - The chunk is not processed at all + - No events (ModelStreamChunkEvent, TextStreamEvent, etc.) are yielded + - The chunk does not contribute to the final message + + When chunk is modified: + - The modified chunk is used for all downstream processing + - TextStreamEvent will contain the modified text + - The final message will contain the modified content + + Performance Note: + This event fires for every stream chunk, so callbacks should execute + quickly to avoid impacting streaming latency. + + Attributes: + chunk: The raw stream event from the model. Can be modified by hooks + to transform content before processing. + skip: When True, the chunk is skipped entirely (not processed or yielded). + invocation_state: State passed through agent invocation. + """ + + chunk: StreamEvent + invocation_state: dict[str, Any] = field(default_factory=dict) + skip: bool = False + + def _can_write(self, name: str) -> bool: + return name in ["chunk", "skip"] + + # Multiagent hook events start here @dataclass class MultiAgentInitializedEvent(BaseHookEvent): diff --git a/tests/strands/agent/hooks/test_events.py b/tests/strands/agent/hooks/test_events.py index 0e03fbbcd..55d72e2a9 100644 --- a/tests/strands/agent/hooks/test_events.py +++ b/tests/strands/agent/hooks/test_events.py @@ -10,6 +10,7 @@ AgentInitializedEvent, BeforeInvocationEvent, BeforeModelCallEvent, + BeforeStreamChunkEvent, BeforeToolCallEvent, MessageAddedEvent, ) @@ -260,3 +261,58 @@ def test_after_invocation_event_resume_accepts_various_input_types(agent): # None to stop event.resume = None assert event.resume is None + + +@pytest.fixture +def before_stream_chunk_event(agent): + chunk = {"contentBlockDelta": {"delta": {"text": "Hello"}}} + return BeforeStreamChunkEvent( + agent=agent, + chunk=chunk, + invocation_state={"test": "state"}, + ) + + +def test_before_stream_chunk_event_should_not_reverse_callbacks(before_stream_chunk_event): + """Test that BeforeStreamChunkEvent does not reverse callbacks.""" + assert before_stream_chunk_event.should_reverse_callbacks is False + + +def test_before_stream_chunk_event_can_write_chunk(before_stream_chunk_event): + """Test that BeforeStreamChunkEvent.chunk is writable.""" + new_chunk = {"contentBlockDelta": {"delta": {"text": "Modified"}}} + before_stream_chunk_event.chunk = new_chunk + assert before_stream_chunk_event.chunk == new_chunk + + +def test_before_stream_chunk_event_can_write_skip(before_stream_chunk_event): + """Test that BeforeStreamChunkEvent.skip is writable.""" + assert before_stream_chunk_event.skip is False + before_stream_chunk_event.skip = True + assert before_stream_chunk_event.skip is True + + +def test_before_stream_chunk_event_cannot_write_agent(before_stream_chunk_event): + """Test that BeforeStreamChunkEvent.agent is not writable.""" + with pytest.raises(AttributeError, match="Property agent is not writable"): + before_stream_chunk_event.agent = Mock() + + +def test_before_stream_chunk_event_cannot_write_invocation_state(before_stream_chunk_event): + """Test that BeforeStreamChunkEvent.invocation_state is not writable.""" + with pytest.raises(AttributeError, match="Property invocation_state is not writable"): + before_stream_chunk_event.invocation_state = {} + + +def test_before_stream_chunk_event_skip_defaults_to_false(agent): + """Test that BeforeStreamChunkEvent.skip defaults to False.""" + chunk = {"contentBlockDelta": {"delta": {"text": "Hello"}}} + event = BeforeStreamChunkEvent(agent=agent, chunk=chunk) + assert event.skip is False + + +def test_before_stream_chunk_event_invocation_state_defaults_to_empty(agent): + """Test that BeforeStreamChunkEvent.invocation_state defaults to empty dict.""" + chunk = {"contentBlockDelta": {"delta": {"text": "Hello"}}} + event = BeforeStreamChunkEvent(agent=agent, chunk=chunk) + assert event.invocation_state == {} diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py index 3a40d69a8..d05f6df6b 100644 --- a/tests/strands/agent/test_agent_hooks.py +++ b/tests/strands/agent/test_agent_hooks.py @@ -12,6 +12,7 @@ AgentInitializedEvent, BeforeInvocationEvent, BeforeModelCallEvent, + BeforeStreamChunkEvent, BeforeToolCallEvent, MessageAddedEvent, ) @@ -1073,3 +1074,177 @@ def on_before(event: BeforeInvocationEvent) -> None: assert len(before_events) == 1 assert isinstance(before_events[0], BeforeInvocationEvent) + + +def test_before_stream_chunk_event_fires_for_each_chunk(): + """Test that BeforeStreamChunkEvent fires for each stream chunk.""" + mock_provider = MockedModelProvider( + [ + { + "role": "assistant", + "content": [{"text": "Hello world"}], + }, + ] + ) + + chunk_events = [] + + async def capture_stream_chunks(event: BeforeStreamChunkEvent): + chunk_events.append(event.chunk) + + agent = Agent(model=mock_provider) + agent.hooks.add_callback(BeforeStreamChunkEvent, capture_stream_chunks) + + agent("Test message") + + # Should have received multiple chunk events (messageStart, contentBlockStart, delta, stop, etc.) + assert len(chunk_events) > 0 + # Should include content block delta with text + text_deltas = [c for c in chunk_events if "contentBlockDelta" in c] + assert len(text_deltas) == 1 + assert text_deltas[0]["contentBlockDelta"]["delta"]["text"] == "Hello world" + + +@pytest.mark.asyncio +async def test_before_stream_chunk_event_can_skip_chunks(): + """Test that setting skip=True prevents chunks from being processed entirely. + + When skip=True, the chunk is not processed at all: + - No ModelStreamChunkEvent is yielded + - No TextStreamEvent is yielded + - The chunk does not contribute to the final message + """ + mock_provider = MockedModelProvider( + [ + { + "role": "assistant", + "content": [{"text": "Hello world"}], + }, + ] + ) + + async def skip_text_chunks(event: BeforeStreamChunkEvent): + # Skip only contentBlockDelta chunks (text content) + if "contentBlockDelta" in event.chunk: + event.skip = True + + agent = Agent(model=mock_provider) + agent.hooks.add_callback(BeforeStreamChunkEvent, skip_text_chunks) + + # Collect all yielded events + text_events = [] + result = None + async for event in agent.stream_async("Test message"): + if "data" in event: # TextStreamEvent + text_events.append(event["data"]) + if "result" in event: + result = event["result"] + + # Verify no text events were yielded (because we skipped content deltas) + assert len(text_events) == 0 + + # Verify final message has no text content (skipped chunks don't contribute) + # The message will have empty content since we skipped the text blocks + assert result.message["content"] == [] + + +@pytest.mark.asyncio +async def test_before_stream_chunk_event_can_modify_chunks(): + """Test that chunk modifications affect both stream events and final message. + + The BeforeStreamChunkEvent hook intercepts chunks BEFORE processing, so + modifications affect: + - The yielded ModelStreamChunkEvent (raw chunk) + - The yielded TextStreamEvent (processed text) + - The final message content + """ + mock_provider = MockedModelProvider( + [ + { + "role": "assistant", + "content": [{"text": "secret"}], + }, + ] + ) + + async def redact_text_chunks(event: BeforeStreamChunkEvent): + # Modify text delta chunks + if "contentBlockDelta" in event.chunk: + delta = event.chunk["contentBlockDelta"]["delta"] + if "text" in delta: + # Create modified chunk with redacted text + event.chunk = {"contentBlockDelta": {"delta": {"text": "[REDACTED]"}}} + + agent = Agent(model=mock_provider) + agent.hooks.add_callback(BeforeStreamChunkEvent, redact_text_chunks) + + # Collect yielded events + text_events = [] + result = None + async for event in agent.stream_async("Test message"): + if "data" in event: # TextStreamEvent + text_events.append(event["data"]) + if "result" in event: + result = event["result"] + + # Verify TextStreamEvent contains modified text + assert len(text_events) == 1 + assert text_events[0] == "[REDACTED]" + + # Verify final message contains modified text + assert result.message["content"][0]["text"] == "[REDACTED]" + + +@pytest.mark.asyncio +async def test_before_stream_chunk_event_with_stream_async(): + """Test that BeforeStreamChunkEvent works with stream_async.""" + mock_provider = MockedModelProvider( + [ + { + "role": "assistant", + "content": [{"text": "Hello"}], + }, + ] + ) + + chunk_events = [] + + async def capture_stream_chunks(event: BeforeStreamChunkEvent): + chunk_events.append(event.chunk) + + agent = Agent(model=mock_provider) + agent.hooks.add_callback(BeforeStreamChunkEvent, capture_stream_chunks) + + async for _ in agent.stream_async("Test message"): + pass + + # Should have received chunk events + assert len(chunk_events) > 0 + + +def test_before_stream_chunk_event_has_invocation_state(): + """Test that BeforeStreamChunkEvent includes invocation_state.""" + mock_provider = MockedModelProvider( + [ + { + "role": "assistant", + "content": [{"text": "Hello"}], + }, + ] + ) + + received_states = [] + + async def capture_invocation_state(event: BeforeStreamChunkEvent): + received_states.append(event.invocation_state.copy()) + + agent = Agent(model=mock_provider) + agent.hooks.add_callback(BeforeStreamChunkEvent, capture_invocation_state) + + agent("Test message", invocation_state={"custom_key": "custom_value"}) + + # All captured states should have the custom key + assert len(received_states) > 0 + for state in received_states: + assert "custom_key" in state + assert state["custom_key"] == "custom_value" diff --git a/tests_integ/hooks/test_stream_chunk_event.py b/tests_integ/hooks/test_stream_chunk_event.py new file mode 100644 index 000000000..60e485e98 --- /dev/null +++ b/tests_integ/hooks/test_stream_chunk_event.py @@ -0,0 +1,122 @@ +"""Integration tests for BeforeStreamChunkEvent hook.""" + +import pytest + +from strands import Agent +from strands.hooks import BeforeStreamChunkEvent + + +@pytest.fixture +def chunks_intercepted(): + return [] + + +@pytest.fixture +def agent_with_stream_hook(chunks_intercepted): + """Create an agent with BeforeStreamChunkEvent hook registered.""" + + async def intercept_chunks(event: BeforeStreamChunkEvent): + chunks_intercepted.append(event.chunk.copy()) + + agent = Agent(system_prompt="Be very brief. Reply with one word only.") + agent.hooks.add_callback(BeforeStreamChunkEvent, intercept_chunks) + return agent + + +def test_before_stream_chunk_event_fires(agent_with_stream_hook, chunks_intercepted): + """Test that BeforeStreamChunkEvent fires for each stream chunk.""" + agent_with_stream_hook("Say hello") + + # Should have intercepted multiple chunks + assert len(chunks_intercepted) > 0 + + # Should have message start, content blocks, and message stop + chunk_types = set() + for chunk in chunks_intercepted: + chunk_types.update(chunk.keys()) + + assert "messageStart" in chunk_types + assert "messageStop" in chunk_types + + +@pytest.mark.asyncio +async def test_before_stream_chunk_event_modification(): + """Test that chunk modifications affect both stream events and final message.""" + modified_text = "[REDACTED]" + + async def redact_chunks(event: BeforeStreamChunkEvent): + if "contentBlockDelta" in event.chunk: + delta = event.chunk.get("contentBlockDelta", {}).get("delta", {}) + if "text" in delta: + event.chunk = {"contentBlockDelta": {"delta": {"text": modified_text}}} + + agent = Agent(system_prompt="Say exactly: secret123") + agent.hooks.add_callback(BeforeStreamChunkEvent, redact_chunks) + + text_events = [] + result = None + + async for event in agent.stream_async("go"): + if "data" in event: + text_events.append(event["data"]) + if "result" in event: + result = event["result"] + + # All text events should be the modified text + assert all(text == modified_text for text in text_events) + + # Final message should only contain modified text + final_text = result.message["content"][0].get("text", "") + assert modified_text in final_text + assert "secret" not in final_text.lower() + + +@pytest.mark.asyncio +async def test_before_stream_chunk_event_skip(): + """Test that skip=True excludes chunks from processing and final message.""" + + async def skip_content_deltas(event: BeforeStreamChunkEvent): + # Skip all content block deltas (text content) + if "contentBlockDelta" in event.chunk: + event.skip = True + + agent = Agent(system_prompt="Say hello") + agent.hooks.add_callback(BeforeStreamChunkEvent, skip_content_deltas) + + text_events = [] + result = None + + async for event in agent.stream_async("go"): + if "data" in event: + text_events.append(event["data"]) + if "result" in event: + result = event["result"] + + # No text events should be yielded + assert len(text_events) == 0 + + # Final message should have no content (all text was skipped) + assert result.message["content"] == [] + + +@pytest.mark.asyncio +async def test_before_stream_chunk_event_has_invocation_state(): + """Test that invocation_state is accessible in BeforeStreamChunkEvent.""" + received_states = [] + + async def capture_state(event: BeforeStreamChunkEvent): + received_states.append(event.invocation_state.copy()) + + agent = Agent(system_prompt="Be brief") + agent.hooks.add_callback(BeforeStreamChunkEvent, capture_state) + + custom_state = {"session_id": "test-123", "user_id": "user-456"} + + async for _ in agent.stream_async("hi", invocation_state=custom_state): + pass + + # All captured states should have our custom keys + assert len(received_states) > 0 + for state in received_states: + assert state.get("session_id") == "test-123" + assert state.get("user_id") == "user-456" From 4ebf6e33452cfe58ccfbe23aaa0457f87cdc91f7 Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Sat, 11 Apr 2026 15:40:01 -0400 Subject: [PATCH 2/2] test: add real-model validation for BeforeStreamChunkEvent Exercises all three hook capabilities (capture, modify, skip) against OpenAI gpt-4o-mini to confirm the feature works end-to-end with a live streaming provider. --- examples/test_stream_chunk_hook.py | 141 +++++++++++++++++++++++++++++ 1 file changed, 141 insertions(+) create mode 100644 examples/test_stream_chunk_hook.py diff --git a/examples/test_stream_chunk_hook.py b/examples/test_stream_chunk_hook.py new file mode 100644 index 000000000..77a132fb8 --- /dev/null +++ b/examples/test_stream_chunk_hook.py @@ -0,0 +1,141 @@ +"""Test BeforeStreamChunkEvent with a real OpenAI model. + +Tests three scenarios: +1. Chunk capture - intercept and log all stream chunks +2. Chunk modification - redact text in-flight +3. Chunk skipping - suppress content deltas entirely +""" + +import asyncio +import os + +from strands import Agent +from strands.hooks import BeforeStreamChunkEvent +from strands.models.openai import OpenAIModel + + +def make_model(): + return OpenAIModel( + model_id="gpt-4o-mini", + client_args={"api_key": os.environ["OPENAI_API_KEY"]}, + ) + + +# ---------- Test 1: capture chunks ---------- +def test_capture(): + print("=" * 60) + print("TEST 1: Capture every stream chunk") + print("=" * 60) + + chunks = [] + + async def on_chunk(event: BeforeStreamChunkEvent): + chunks.append(event.chunk.copy()) + + agent = Agent( + model=make_model(), + system_prompt="Reply in exactly 5 words.", + ) + agent.hooks.add_callback(BeforeStreamChunkEvent, on_chunk) + + result = agent("Say hello") + + print(f"\nFinal text : {result.message['content'][0].get('text', '')}") + print(f"Chunks seen: {len(chunks)}") + + chunk_types = set() + for c in chunks: + chunk_types.update(c.keys()) + print(f"Chunk types: {sorted(chunk_types)}") + + assert len(chunks) > 0, "Expected at least one chunk" + assert "messageStart" in chunk_types, "Missing messageStart" + assert "contentBlockDelta" in chunk_types, "Missing contentBlockDelta" + print("PASSED\n") + + +# ---------- Test 2: modify (redact) chunks ---------- +def test_modify(): + print("=" * 60) + print("TEST 2: Redact text in-flight") + print("=" * 60) + + async def redact(event: BeforeStreamChunkEvent): + if "contentBlockDelta" in event.chunk: + delta = event.chunk.get("contentBlockDelta", {}).get("delta", {}) + if "text" in delta: + event.chunk = {"contentBlockDelta": {"delta": {"text": "[REDACTED]"}}} + + agent = Agent( + model=make_model(), + system_prompt="Say exactly: the secret code is 12345", + ) + agent.hooks.add_callback(BeforeStreamChunkEvent, redact) + + text_events = [] + result = None + + async def run(): + nonlocal result + async for event in agent.stream_async("go"): + if "data" in event: + text_events.append(event["data"]) + if "result" in event: + result = event["result"] + + asyncio.run(run()) + + final_text = result.message["content"][0].get("text", "") + print(f"\nStreamed text events: {text_events}") + print(f"Final message text : {final_text}") + + assert all(t == "[REDACTED]" for t in text_events), "Some text events were not redacted" + assert "secret" not in final_text.lower(), "Final message still contains secret" + assert "[REDACTED]" in final_text, "Final message missing redaction marker" + print("PASSED\n") + + +# ---------- Test 3: skip content deltas ---------- +def test_skip(): + print("=" * 60) + print("TEST 3: Skip content deltas entirely") + print("=" * 60) + + async def skip_deltas(event: BeforeStreamChunkEvent): + if "contentBlockDelta" in event.chunk: + event.skip = True + + agent = Agent( + model=make_model(), + system_prompt="Reply with a greeting", + ) + agent.hooks.add_callback(BeforeStreamChunkEvent, skip_deltas) + + text_events = [] + result = None + + async def run(): + nonlocal result + async for event in agent.stream_async("hi"): + if "data" in event: + text_events.append(event["data"]) + if "result" in event: + result = event["result"] + + asyncio.run(run()) + + print(f"\nText events received: {len(text_events)}") + print(f"Final content : {result.message['content']}") + + assert len(text_events) == 0, f"Expected 0 text events, got {len(text_events)}" + assert result.message["content"] == [], "Expected empty content" + print("PASSED\n") + + +if __name__ == "__main__": + test_capture() + test_modify() + test_skip() + print("=" * 60) + print("ALL TESTS PASSED") + print("=" * 60)