Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions python/packages/core/agent_framework/_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,6 +935,13 @@ def _propagate_conversation_id(update: AgentResponseUpdate) -> AgentResponseUpda
session.service_session_id = conv_id
return update

def _stream_finalizer(updates: Sequence[AgentResponseUpdate]) -> AgentResponse:
# Read response_format from merged chat_options (includes default_options)
# so that response_format set via the constructor is respected.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This defensive if ctx else None check is not exercised by any test. Consider whether ctx can realistically be None here; if not, dropping the guard (or adding a test that triggers it) would improve clarity.

ctx = ctx_holder["ctx"]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor robustness note: if ctx is truthy but chat_options is unexpectedly absent, ctx["chat_options"] would raise KeyError. Using ctx.get("chat_options", {}).get("response_format") would be safer, though this is likely fine if the context schema guarantees the key.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit (non-blocking): ctx["chat_options"] will raise KeyError if the key is absent. Using ctx.get("chat_options", {}).get("response_format") would be more defensive, though the key is expected to always be present in practice.

response_format = ctx["chat_options"].get("response_format") if ctx else None
return self._finalize_response_updates(updates, response_format=response_format)

return (
ResponseStream
.from_awaitable(_get_stream())
Expand All @@ -943,9 +950,7 @@ def _propagate_conversation_id(update: AgentResponseUpdate) -> AgentResponseUpda
map_chat_to_agent_update,
agent_name=self.name,
),
finalizer=partial(
self._finalize_response_updates, response_format=options.get("response_format") if options else None
),
finalizer=_stream_finalizer,
)
.with_transform_hook(_propagate_conversation_id)
.with_result_hook(_post_hook)
Expand Down
26 changes: 26 additions & 0 deletions python/packages/core/tests/core/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,32 @@ async def test_chat_client_agent_run_streaming(client: SupportsChatGetResponse)
assert result.text == "test streaming response another update"


async def test_chat_client_agent_run_streaming_response_format_via_default_options(
client: SupportsChatGetResponse,
) -> None:
"""response_format set in default_options must be available on the streamed AgentResponse.value (issue #3970)."""
from pydantic import BaseModel

class Greeting(BaseModel):
greeting: str

json_text = '{"greeting":"hello"}'
client.streaming_responses = [ # type: ignore[attr-defined]
[ChatResponseUpdate(contents=[Content.from_text(json_text)], role="assistant")],
]

agent = Agent(client=client, default_options={"response_format": Greeting})
stream = agent.run("Hello", stream=True)
async for _ in stream:
pass
result = await stream.get_final_response()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good: the test covers the exact bug scenario (default_options response_format in streaming). Consider adding a sibling test that passes response_format directly in run() options to ensure the refactored finalizer handles both paths.


assert result.text == json_text
assert result.value is not None
assert isinstance(result.value, Greeting)
assert result.value.greeting == "hello"


async def test_chat_client_agent_create_session(client: SupportsChatGetResponse) -> None:
agent = Agent(client=client)
session = agent.create_session()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -937,8 +937,7 @@ def ai_func(arg1: str) -> str:

orphaned_calls = all_call_ids - all_result_ids
assert not orphaned_calls, (
f"Response contains orphaned FunctionCallContent without matching "
f"FunctionResultContent: {orphaned_calls}."
f"Response contains orphaned FunctionCallContent without matching FunctionResultContent: {orphaned_calls}."
)


Expand Down Expand Up @@ -1123,8 +1122,7 @@ def browser_snapshot(url: str) -> str:

orphaned_calls = all_call_ids - all_result_ids
assert not orphaned_calls, (
f"Response contains orphaned function calls {orphaned_calls}. "
f"This would cause API errors on the next call."
f"Response contains orphaned function calls {orphaned_calls}. This would cause API errors on the next call."
)


Expand Down
Loading