Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions python/packages/core/agent_framework/_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,6 +935,11 @@ def _propagate_conversation_id(update: AgentResponseUpdate) -> AgentResponseUpda
session.service_session_id = conv_id
return update

def _finalize_with_context(updates: Sequence[AgentResponseUpdate]) -> AgentResponse:
ctx = ctx_holder["ctx"]
response_format = ctx["chat_options"].get("response_format") if ctx else None
Copy link

Copilot AI Feb 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Streaming finalizer pulls response_format from ctx["chat_options"] but doesn’t apply the same BaseModel subclass validation that the non-streaming path does (where non-BaseModel values are normalized to None). This can lead to inconsistent AgentResponse.response_format contents between streaming and non-streaming runs and can violate the annotated type. Consider reusing the non-streaming validation logic here (or inside _finalize_response_updates) so only type[BaseModel] is propagated for parsing.

Suggested change
response_format = ctx["chat_options"].get("response_format") if ctx else None
raw_response_format = ctx["chat_options"].get("response_format") if ctx else None
# Normalize response_format so only BaseModel subclasses are propagated, matching non-streaming behavior.
if isinstance(raw_response_format, type) and issubclass(raw_response_format, BaseModel):
response_format = raw_response_format
else:
response_format = None

Copilot uses AI. Check for mistakes.
return self._finalize_response_updates(updates, response_format=response_format)

return (
ResponseStream
.from_awaitable(_get_stream())
Expand All @@ -943,9 +948,7 @@ def _propagate_conversation_id(update: AgentResponseUpdate) -> AgentResponseUpda
map_chat_to_agent_update,
agent_name=self.name,
),
finalizer=partial(
self._finalize_response_updates, response_format=options.get("response_format") if options else None
),
finalizer=_finalize_with_context,
)
.with_transform_hook(_propagate_conversation_id)
.with_result_hook(_post_hook)
Expand Down
33 changes: 33 additions & 0 deletions python/packages/core/tests/core/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,39 @@ async def test_chat_client_agent_run_streaming(client: SupportsChatGetResponse)
assert result.text == "test streaming response another update"


async def test_chat_client_agent_run_streaming_response_format_from_default_options(
client: SupportsChatGetResponse,
) -> None:
"""Streaming with response_format in default_options should parse value correctly (#3970)."""
from pydantic import BaseModel

class Greeting(BaseModel):
greeting: str

json_text = '{"greeting": "Hello!"}'
client.streaming_responses = [
[
ChatResponseUpdate(
contents=[Content.from_text(json_text)],
role="assistant",
finish_reason="stop",
),
]
]

agent = Agent(client=client, default_options={"response_format": Greeting})

stream = agent.run("Hello", stream=True)
async for _ in stream:
pass
result = await stream.get_final_response()

assert result.text == json_text
assert result.value is not None
assert isinstance(result.value, Greeting)
assert result.value.greeting == "Hello!"


async def test_chat_client_agent_create_session(client: SupportsChatGetResponse) -> None:
agent = Agent(client=client)
session = agent.create_session()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -937,8 +937,7 @@ def ai_func(arg1: str) -> str:

orphaned_calls = all_call_ids - all_result_ids
assert not orphaned_calls, (
f"Response contains orphaned FunctionCallContent without matching "
f"FunctionResultContent: {orphaned_calls}."
f"Response contains orphaned FunctionCallContent without matching FunctionResultContent: {orphaned_calls}."
)


Expand Down Expand Up @@ -1123,8 +1122,7 @@ def browser_snapshot(url: str) -> str:

orphaned_calls = all_call_ids - all_result_ids
assert not orphaned_calls, (
f"Response contains orphaned function calls {orphaned_calls}. "
f"This would cause API errors on the next call."
f"Response contains orphaned function calls {orphaned_calls}. This would cause API errors on the next call."
)


Expand Down
Loading