diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index a3f4570b6e..47be14653d 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -935,6 +935,13 @@ def _propagate_conversation_id(update: AgentResponseUpdate) -> AgentResponseUpda session.service_session_id = conv_id return update + def _stream_finalizer(updates: Sequence[AgentResponseUpdate]) -> AgentResponse: + # Read response_format from merged chat_options (includes default_options) + # so that response_format set via the constructor is respected. + ctx = ctx_holder["ctx"] + response_format = ctx["chat_options"].get("response_format") if ctx else None + return self._finalize_response_updates(updates, response_format=response_format) + return ( ResponseStream .from_awaitable(_get_stream()) @@ -943,9 +950,7 @@ def _propagate_conversation_id(update: AgentResponseUpdate) -> AgentResponseUpda map_chat_to_agent_update, agent_name=self.name, ), - finalizer=partial( - self._finalize_response_updates, response_format=options.get("response_format") if options else None - ), + finalizer=_stream_finalizer, ) .with_transform_hook(_propagate_conversation_id) .with_result_hook(_post_hook) diff --git a/python/packages/core/tests/core/test_agents.py b/python/packages/core/tests/core/test_agents.py index 627987a1f2..63628b206d 100644 --- a/python/packages/core/tests/core/test_agents.py +++ b/python/packages/core/tests/core/test_agents.py @@ -97,6 +97,32 @@ async def test_chat_client_agent_run_streaming(client: SupportsChatGetResponse) assert result.text == "test streaming response another update" +async def test_chat_client_agent_run_streaming_response_format_via_default_options( + client: SupportsChatGetResponse, +) -> None: + """response_format set in default_options must be available on the streamed AgentResponse.value (issue #3970).""" + from pydantic import BaseModel + + class Greeting(BaseModel): + greeting: str + + json_text = '{"greeting":"hello"}' + client.streaming_responses = [ # type: ignore[attr-defined] + [ChatResponseUpdate(contents=[Content.from_text(json_text)], role="assistant")], + ] + + agent = Agent(client=client, default_options={"response_format": Greeting}) + stream = agent.run("Hello", stream=True) + async for _ in stream: + pass + result = await stream.get_final_response() + + assert result.text == json_text + assert result.value is not None + assert isinstance(result.value, Greeting) + assert result.value.greeting == "hello" + + async def test_chat_client_agent_create_session(client: SupportsChatGetResponse) -> None: agent = Agent(client=client) session = agent.create_session() diff --git a/python/packages/core/tests/core/test_function_invocation_logic.py b/python/packages/core/tests/core/test_function_invocation_logic.py index f278afaeac..319d35f152 100644 --- a/python/packages/core/tests/core/test_function_invocation_logic.py +++ b/python/packages/core/tests/core/test_function_invocation_logic.py @@ -937,8 +937,7 @@ def ai_func(arg1: str) -> str: orphaned_calls = all_call_ids - all_result_ids assert not orphaned_calls, ( - f"Response contains orphaned FunctionCallContent without matching " - f"FunctionResultContent: {orphaned_calls}." + f"Response contains orphaned FunctionCallContent without matching FunctionResultContent: {orphaned_calls}." ) @@ -1123,8 +1122,7 @@ def browser_snapshot(url: str) -> str: orphaned_calls = all_call_ids - all_result_ids assert not orphaned_calls, ( - f"Response contains orphaned function calls {orphaned_calls}. " - f"This would cause API errors on the next call." + f"Response contains orphaned function calls {orphaned_calls}. This would cause API errors on the next call." )