Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 34 additions & 9 deletions python/packages/core/agent_framework/_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -911,19 +911,21 @@ async def _post_hook(response: AgentResponse) -> None:
if ctx is None:
return # No context available (shouldn't happen in normal flow)

# Update thread with conversation_id derived from streaming raw updates.
# Using response_id here can break function-call continuation for APIs
# where response IDs are not valid conversation handles.
conversation_id = self._extract_conversation_id_from_streaming_response(response)
# Ensure author names are set for all messages
for message in response.messages:
if message.author_name is None:
message.author_name = ctx["agent_name"]

# Propagate conversation_id back to session from streaming updates
# Propagate conversation_id back to session from streaming updates.
# For Responses-style APIs this can rotate every turn (response_id-based continuation),
# so refresh when a newer value is returned.
sess = ctx["session"]
if sess and not sess.service_session_id and response.raw_representation:
raw_items = response.raw_representation if isinstance(response.raw_representation, list) else []
for item in raw_items:
if hasattr(item, "conversation_id") and item.conversation_id:
sess.service_session_id = item.conversation_id
break
if sess and conversation_id and sess.service_session_id != conversation_id:
sess.service_session_id = conversation_id

# Run after_run providers (reverse order)
session_context = ctx["session_context"]
Expand Down Expand Up @@ -974,6 +976,27 @@ def _finalize_response_updates(
output_format_type = response_format if isinstance(response_format, type) else None
return AgentResponse.from_updates(updates, output_format_type=output_format_type)

@staticmethod
def _extract_conversation_id_from_streaming_response(response: AgentResponse[Any]) -> str | None:
"""Extract conversation_id from streaming raw updates, if present."""
raw = response.raw_representation
if raw is None:
return None

raw_items: list[Any] = raw if isinstance(raw, list) else [raw]
for item in reversed(raw_items):
if isinstance(item, Mapping):
value = item.get("conversation_id")
if isinstance(value, str) and value:
return value
continue

value = getattr(item, "conversation_id", None)
if isinstance(value, str) and value:
return value

return None

async def _prepare_run_context(
self,
*,
Expand Down Expand Up @@ -1100,8 +1123,10 @@ async def _finalize_response(
if message.author_name is None:
message.author_name = agent_name

# Propagate conversation_id back to session (e.g. thread ID from Assistants API)
if session and response.conversation_id and not session.service_session_id:
# Propagate conversation_id back to session (e.g. thread ID from Assistants API).
# For Responses-style APIs this can rotate every turn (response_id-based continuation),
# so refresh when a newer value is returned.
if session and response.conversation_id and session.service_session_id != response.conversation_id:
session.service_session_id = response.conversation_id

# Set the response on the context for after_run providers
Expand Down
11 changes: 10 additions & 1 deletion python/packages/core/agent_framework/_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,7 +872,16 @@ async def call_tool(self, tool_name: str, **kwargs: Any) -> str:
k: v
for k, v in kwargs.items()
if k
not in {"chat_options", "tools", "tool_choice", "session", "thread", "conversation_id", "options", "response_format"}
not in {
"chat_options",
"tools",
"tool_choice",
"session",
"thread",
"conversation_id",
"options",
"response_format",
}
}

parser = self.parse_tool_results or _parse_tool_result_from_mcp
Expand Down
37 changes: 28 additions & 9 deletions python/packages/core/agent_framework/_workflows/_agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
import sys
from collections.abc import Mapping
from collections.abc import Awaitable, Callable, Mapping
from dataclasses import dataclass
from typing import Any, cast

Expand Down Expand Up @@ -358,22 +358,31 @@ async def _run_agent_streaming(self, ctx: WorkflowContext[Never, AgentResponseUp
run_kwargs, options = self._prepare_agent_run_args(ctx.get_state(WORKFLOW_RUN_KWARGS_KEY) or {})

updates: list[AgentResponseUpdate] = []
user_input_requests: list[Content] = []
async for update in self._agent.run(
streamed_user_input_requests: list[Content] = []
stream = self._agent.run(
self._cache,
stream=True,
session=self._session,
options=options,
**run_kwargs,
):
)
async for update in stream:
updates.append(update)
await ctx.yield_output(update)

if update.user_input_requests:
user_input_requests.extend(update.user_input_requests)

# Build the final AgentResponse from the collected updates
if is_chat_agent(self._agent):
streamed_user_input_requests.extend(update.user_input_requests)

# Prefer stream finalization when available so result hooks run
# (e.g., thread conversation updates). Fall back to reconstructing from updates
# for legacy/custom agents that return a plain async iterable.
# TODO(evmattso): Integrate workflow agent run handling around ResponseStream so
# AgentExecutor does not need this conditional stream-finalization branch.
maybe_get_final_response = getattr(stream, "get_final_response", None)
get_final_response = maybe_get_final_response if callable(maybe_get_final_response) else None
response: AgentResponse[Any]
if get_final_response is not None:
response = await cast(Callable[[], Awaitable[AgentResponse[Any]]], get_final_response)()
elif is_chat_agent(self._agent):
response_format = self._agent.default_options.get("response_format")
response = AgentResponse.from_updates(
updates,
Expand All @@ -383,6 +392,16 @@ async def _run_agent_streaming(self, ctx: WorkflowContext[Never, AgentResponseUp
response = AgentResponse.from_updates(updates)

# Handle any user input requests after the streaming completes
user_input_requests: list[Content] = []
seen_request_ids: set[str] = set()
for user_input_request in [*streamed_user_input_requests, *response.user_input_requests]:
request_id = getattr(user_input_request, "id", None)
if isinstance(request_id, str) and request_id:
if request_id in seen_request_ids:
continue
seen_request_ids.add(request_id)
user_input_requests.append(user_input_request)

if user_input_requests:
for user_input_request in user_input_requests:
self._pending_agent_requests[user_input_request.id] = user_input_request # type: ignore[index]
Expand Down
106 changes: 106 additions & 0 deletions python/packages/core/tests/core/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
BaseContextProvider,
ChatOptions,
ChatResponse,
ChatResponseUpdate,
Content,
FunctionTool,
Message,
Expand Down Expand Up @@ -154,6 +155,111 @@ async def test_chat_client_agent_run_with_session(chat_client_base: SupportsChat
assert session.service_session_id == "123"


async def test_chat_client_agent_updates_existing_session_id_non_streaming(
chat_client_base: SupportsChatGetResponse,
) -> None:
chat_client_base.run_responses = [
ChatResponse(
messages=[Message(role="assistant", contents=[Content.from_text("test response")])],
conversation_id="resp_new_123",
)
]

agent = Agent(client=chat_client_base)
session = agent.get_session(service_session_id="resp_old_123")

await agent.run("Hello", session=session)
assert session.service_session_id == "resp_new_123"


async def test_chat_client_agent_update_session_id_streaming_uses_conversation_id(
chat_client_base: SupportsChatGetResponse,
) -> None:
chat_client_base.streaming_responses = [
[
ChatResponseUpdate(
contents=[Content.from_text("stream part 1")],
role="assistant",
response_id="resp_stream_123",
conversation_id="conv_stream_456",
),
ChatResponseUpdate(
contents=[Content.from_text(" stream part 2")],
role="assistant",
response_id="resp_stream_123",
conversation_id="conv_stream_456",
finish_reason="stop",
),
]
]

agent = Agent(client=chat_client_base)
session = agent.create_session()

stream = agent.run("Hello", session=session, stream=True)
async for _ in stream:
pass
result = await stream.get_final_response()
assert result.text == "stream part 1 stream part 2"
assert session.service_session_id == "conv_stream_456"


async def test_chat_client_agent_updates_existing_session_id_streaming(
chat_client_base: SupportsChatGetResponse,
) -> None:
chat_client_base.streaming_responses = [
[
ChatResponseUpdate(
contents=[Content.from_text("stream part 1")],
role="assistant",
response_id="resp_stream_123",
conversation_id="resp_new_456",
),
ChatResponseUpdate(
contents=[Content.from_text(" stream part 2")],
role="assistant",
response_id="resp_stream_123",
conversation_id="resp_new_456",
finish_reason="stop",
),
]
]

agent = Agent(client=chat_client_base)
session = agent.get_session(service_session_id="resp_old_456")

stream = agent.run("Hello", session=session, stream=True)
async for _ in stream:
pass
await stream.get_final_response()
assert session.service_session_id == "resp_new_456"


async def test_chat_client_agent_update_session_id_streaming_does_not_use_response_id(
chat_client_base: SupportsChatGetResponse,
) -> None:
chat_client_base.streaming_responses = [
[
ChatResponseUpdate(
contents=[Content.from_text("stream response without conversation id")],
role="assistant",
response_id="resp_only_123",
finish_reason="stop",
),
]
]

agent = Agent(client=chat_client_base)
session = agent.create_session()

stream = agent.run("Hello", session=session, stream=True)
async for _ in stream:
pass
result = await stream.get_final_response()
assert result.text == "stream response without conversation id"
assert session.service_session_id is None


async def test_chat_client_agent_update_session_messages(client: SupportsChatGetResponse) -> None:
agent = Agent(client=client)
session = agent.create_session()
Expand Down
1 change: 1 addition & 0 deletions python/packages/core/tests/workflow/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Copyright (c) Microsoft. All rights reserved.
51 changes: 51 additions & 0 deletions python/packages/core/tests/workflow/test_agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,57 @@ async def _run() -> AgentResponse:
return _run()


class _StreamingHookAgent(BaseAgent):
"""Agent that exposes whether its streaming result hook was executed."""

def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
self.result_hook_called = False

def run(
self,
messages: str | Message | list[str] | list[Message] | None = None,
*,
stream: bool = False,
**kwargs: Any,
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
if stream:

async def _stream() -> AsyncIterable[AgentResponseUpdate]:
yield AgentResponseUpdate(
contents=[Content.from_text(text="hook test")],
role="assistant",
)

async def _mark_result_hook_called(response: AgentResponse) -> AgentResponse:
self.result_hook_called = True
return response

return ResponseStream(_stream(), finalizer=AgentResponse.from_updates).with_result_hook(
_mark_result_hook_called
)

async def _run() -> AgentResponse:
return AgentResponse(messages=[Message("assistant", ["hook test"])])

return _run()


async def test_agent_executor_streaming_finalizes_stream_and_runs_result_hooks() -> None:
"""AgentExecutor should call get_final_response() so stream result hooks execute."""
agent = _StreamingHookAgent(id="hook_agent", name="HookAgent")
executor = AgentExecutor(agent, id="hook_exec")
workflow = SequentialBuilder(participants=[executor]).build()

output_events: list[Any] = []
async for event in workflow.run("run hook test", stream=True):
if event.type == "output":
output_events.append(event)

assert output_events
assert agent.result_hook_called


async def test_agent_executor_checkpoint_stores_and_restores_state() -> None:
"""Test that workflow checkpoint stores AgentExecutor's cache and session states and restores them correctly."""
storage = InMemoryCheckpointStorage()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
WorkflowRunState,
)
from agent_framework._workflows._checkpoint_encoding import (
_PICKLE_MARKER,
_PICKLE_MARKER, # type: ignore
encode_checkpoint_value,
)
from agent_framework._workflows._events import WorkflowEvent
Expand Down
Loading