Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions python/packages/core/agent_framework/_workflows/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,11 @@ async def _run_impl(
output_events.append(event)

result = self._convert_workflow_events_to_agent_response(response_id, output_events)

# Set the response on the context so after_run providers (e.g. InMemoryHistoryProvider)
# can persist the response messages alongside input messages.
session_context._response = result # type: ignore[assignment]

await self._run_after_providers(session=provider_session, context=session_context)
return result

Expand Down Expand Up @@ -322,12 +327,20 @@ async def _run_stream_impl(
# combine the messages

session_messages: list[Message] = session_context.get_messages(include_input=True)
all_updates: list[AgentResponseUpdate] = []
async for event in self._run_core(
session_messages, checkpoint_id, checkpoint_storage, streaming=True, **kwargs
):
updates = self._convert_workflow_event_to_agent_response_updates(response_id, event)
for update in updates:
all_updates.append(update)
yield update

# Build the final response from collected updates so after_run providers
# (e.g. InMemoryHistoryProvider) can persist the response messages.
if all_updates:
session_context._response = AgentResponse.from_updates(all_updates) # type: ignore[assignment]

await self._run_after_providers(session=provider_session, context=session_context)

async def _run_core(
Expand Down
86 changes: 86 additions & 0 deletions python/packages/core/tests/workflow/test_workflow_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,92 @@ async def test_workflow_as_agent_adds_default_history_provider(self) -> None:
assert "first message" in texts
assert "second message" in texts

async def test_multi_turn_session_stores_responses(self) -> None:
"""Test that WorkflowAgent stores response messages in session history (issue #1694).

Previously, session_context._response was not set before running after_run
providers, so InMemoryHistoryProvider never persisted response messages.
On subsequent runs the workflow only received prior user inputs, not prior
assistant responses, breaking multi-turn conversations.
"""
capturing_executor = ConversationHistoryCapturingExecutor(id="multi_turn_test", streaming=False)
workflow = WorkflowBuilder(start_executor=capturing_executor).build()
agent = workflow.as_agent(name="Multi Turn Agent")
session = AgentSession()

# First turn
await agent.run("My name is Bob", session=session)

# Second turn — the executor should see prior user+assistant messages plus new input
await agent.run("What is my name?", session=session)

received = capturing_executor.received_messages
roles = [m.role for m in received]
texts = [m.text for m in received]

# History should include: user("My name is Bob"), assistant(response), user("What is my name?")
assert len(received) == 3, f"Expected 3 messages (user, assistant, user), got {len(received)}: {roles}"
assert roles[0] == "user"
assert "My name is Bob" in (texts[0] or "")
assert roles[1] == "assistant"
assert roles[2] == "user"
assert "What is my name?" in (texts[2] or "")

async def test_multi_turn_session_stores_responses_streaming(self) -> None:
"""Streaming variant: WorkflowAgent stores response messages in session history."""
capturing_executor = ConversationHistoryCapturingExecutor(id="multi_turn_stream_test", streaming=True)
workflow = WorkflowBuilder(start_executor=capturing_executor).build()
agent = workflow.as_agent(name="Multi Turn Stream Agent")
session = AgentSession()

# First turn (streaming)
stream = agent.run("Hello", stream=True, session=session)
async for _ in stream:
pass
await stream.get_final_response()

# Second turn — should include prior history
stream2 = agent.run("Follow up", stream=True, session=session)
async for _ in stream2:
pass
await stream2.get_final_response()

received = capturing_executor.received_messages
roles = [m.role for m in received]

assert len(received) == 3, f"Expected 3 messages, got {len(received)}: {roles}"
assert roles[0] == "user"
assert roles[1] == "assistant"
assert roles[2] == "user"

async def test_multi_turn_session_roundtrip_serialization(self) -> None:
"""Test that session can be serialized/deserialized and multi-turn still works."""
capturing_executor = ConversationHistoryCapturingExecutor(id="roundtrip_test", streaming=False)
workflow = WorkflowBuilder(start_executor=capturing_executor).build()
agent = workflow.as_agent(name="Roundtrip Agent")
session = AgentSession()

# First turn
await agent.run("My name is Bob", session=session)

# Serialize and deserialize the session
serialized = session.to_dict()
restored_session = AgentSession.from_dict(serialized)

# Second turn with restored session
await agent.run("What is my name?", session=restored_session)

received = capturing_executor.received_messages
roles = [m.role for m in received]
texts = [m.text for m in received]

assert len(received) == 3, f"Expected 3 messages, got {len(received)}: {roles}"
assert roles[0] == "user"
assert "My name is Bob" in (texts[0] or "")
assert roles[1] == "assistant"
assert roles[2] == "user"
assert "What is my name?" in (texts[2] or "")

async def test_workflow_agent_keeps_explicit_context_providers(self) -> None:
"""Test that WorkflowAgent does not append defaults when context providers are explicitly provided."""
workflow = WorkflowBuilder(
Expand Down