diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index a3f4570b6e..580b6e2c6d 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -935,6 +935,11 @@ def _propagate_conversation_id(update: AgentResponseUpdate) -> AgentResponseUpda session.service_session_id = conv_id return update + def _finalizer(updates: Sequence[AgentResponseUpdate]) -> AgentResponse[Any]: + ctx = ctx_holder["ctx"] + rf = ctx.get("chat_options", {}).get("response_format") if ctx else (options.get("response_format") if options else None) + return self._finalize_response_updates(updates, response_format=rf) + return ( ResponseStream .from_awaitable(_get_stream()) @@ -943,9 +948,7 @@ def _propagate_conversation_id(update: AgentResponseUpdate) -> AgentResponseUpda map_chat_to_agent_update, agent_name=self.name, ), - finalizer=partial( - self._finalize_response_updates, response_format=options.get("response_format") if options else None - ), + finalizer=_finalizer, ) .with_transform_hook(_propagate_conversation_id) .with_result_hook(_post_hook) diff --git a/python/packages/core/tests/core/test_agents.py b/python/packages/core/tests/core/test_agents.py index 627987a1f2..b6f84dc970 100644 --- a/python/packages/core/tests/core/test_agents.py +++ b/python/packages/core/tests/core/test_agents.py @@ -97,6 +97,58 @@ async def test_chat_client_agent_run_streaming(client: SupportsChatGetResponse) assert result.text == "test streaming response another update" +async def test_chat_client_agent_streaming_response_format_from_default_options( + client: SupportsChatGetResponse, +) -> None: + """AgentResponse.value must be parsed when response_format is set in default_options and streaming.""" + from pydantic import BaseModel + + class Greeting(BaseModel): + greeting: str + + json_text = '{"greeting": "Hello"}' + client.streaming_responses.append( # type: ignore[attr-defined] + [ChatResponseUpdate(contents=[Content.from_text(json_text)], role="assistant", finish_reason="stop")] + ) + + agent = Agent(client=client, default_options={"response_format": Greeting}) + stream = agent.run("Hello", stream=True) + async for _ in stream: + pass + result = await stream.get_final_response() + + assert result.text == json_text + assert result.value is not None + assert isinstance(result.value, Greeting) + assert result.value.greeting == "Hello" + + +async def test_chat_client_agent_streaming_response_format_from_run_options( + client: SupportsChatGetResponse, +) -> None: + """AgentResponse.value must be parsed when response_format is passed via run() options kwarg.""" + from pydantic import BaseModel + + class Greeting(BaseModel): + greeting: str + + json_text = '{"greeting": "Hi"}' + client.streaming_responses.append( # type: ignore[attr-defined] + [ChatResponseUpdate(contents=[Content.from_text(json_text)], role="assistant", finish_reason="stop")] + ) + + agent = Agent(client=client) + stream = agent.run("Hello", stream=True, options={"response_format": Greeting}) + async for _ in stream: + pass + result = await stream.get_final_response() + + assert result.text == json_text + assert result.value is not None + assert isinstance(result.value, Greeting) + assert result.value.greeting == "Hi" + + async def test_chat_client_agent_create_session(client: SupportsChatGetResponse) -> None: agent = Agent(client=client) session = agent.create_session()