Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions python/packages/core/agent_framework/_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,6 +935,11 @@ def _propagate_conversation_id(update: AgentResponseUpdate) -> AgentResponseUpda
session.service_session_id = conv_id
return update

def _finalizer(updates: Sequence[AgentResponseUpdate]) -> AgentResponse[Any]:
ctx = ctx_holder["ctx"]
rf = ctx.get("chat_options", {}).get("response_format") if ctx else (options.get("response_format") if options else None)
return self._finalize_response_updates(updates, response_format=rf)

return (
ResponseStream
.from_awaitable(_get_stream())
Expand All @@ -943,9 +948,7 @@ def _propagate_conversation_id(update: AgentResponseUpdate) -> AgentResponseUpda
map_chat_to_agent_update,
agent_name=self.name,
),
finalizer=partial(
self._finalize_response_updates, response_format=options.get("response_format") if options else None
),
finalizer=_finalizer,
)
.with_transform_hook(_propagate_conversation_id)
.with_result_hook(_post_hook)
Expand Down
52 changes: 52 additions & 0 deletions python/packages/core/tests/core/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,58 @@ async def test_chat_client_agent_run_streaming(client: SupportsChatGetResponse)
assert result.text == "test streaming response another update"


async def test_chat_client_agent_streaming_response_format_from_default_options(
client: SupportsChatGetResponse,
) -> None:
"""AgentResponse.value must be parsed when response_format is set in default_options and streaming."""
from pydantic import BaseModel

class Greeting(BaseModel):
greeting: str

json_text = '{"greeting": "Hello"}'
client.streaming_responses.append( # type: ignore[attr-defined]
[ChatResponseUpdate(contents=[Content.from_text(json_text)], role="assistant", finish_reason="stop")]
)

agent = Agent(client=client, default_options={"response_format": Greeting})
stream = agent.run("Hello", stream=True)
async for _ in stream:
pass
result = await stream.get_final_response()

assert result.text == json_text
assert result.value is not None
assert isinstance(result.value, Greeting)
assert result.value.greeting == "Hello"


async def test_chat_client_agent_streaming_response_format_from_run_options(
client: SupportsChatGetResponse,
) -> None:
"""AgentResponse.value must be parsed when response_format is passed via run() options kwarg."""
from pydantic import BaseModel

class Greeting(BaseModel):
greeting: str

json_text = '{"greeting": "Hi"}'
client.streaming_responses.append( # type: ignore[attr-defined]
[ChatResponseUpdate(contents=[Content.from_text(json_text)], role="assistant", finish_reason="stop")]
)

agent = Agent(client=client)
stream = agent.run("Hello", stream=True, options={"response_format": Greeting})
async for _ in stream:
pass
result = await stream.get_final_response()

assert result.text == json_text
assert result.value is not None
assert isinstance(result.value, Greeting)
assert result.value.greeting == "Hi"


async def test_chat_client_agent_create_session(client: SupportsChatGetResponse) -> None:
agent = Agent(client=client)
session = agent.create_session()
Expand Down