Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 59 additions & 16 deletions python/packages/core/agent_framework/_workflows/_agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import logging
import sys
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any, cast

Expand Down Expand Up @@ -292,10 +293,10 @@ async def _run_agent_and_emit(
# Non-streaming mode: use run() and emit single event
response = await self._run_agent(cast(WorkflowContext[Never, AgentResponse], ctx))

# Always extend full conversation with cached messages plus agent outputs
# (agent_response.messages) after each run. This is to avoid losing context
# when agent did not complete and the cache is cleared when responses come back.
self._full_conversation.extend(list(self._cache) + (list(response.messages) if response else []))
# Snapshot current conversation as cache + latest agent outputs.
# Do not append to prior snapshots: callers may provide full-history messages
# in request.messages, and extending would duplicate prior turns.
self._full_conversation = list(self._cache) + (list(response.messages) if response else [])

if response is None:
# Agent did not complete (e.g., waiting for user input); do not emit response
Expand All @@ -315,12 +316,7 @@ async def _run_agent(self, ctx: WorkflowContext[Never, AgentResponse]) -> AgentR
Returns:
The complete AgentResponse, or None if waiting for user input.
"""
run_kwargs: dict[str, Any] = ctx.get_state(WORKFLOW_RUN_KWARGS_KEY, {})

# Build options dict with additional_function_arguments for tool kwargs propagation
options: dict[str, Any] | None = None
if run_kwargs:
options = {"additional_function_arguments": run_kwargs}
run_kwargs, options = self._prepare_agent_run_args(ctx.get_state(WORKFLOW_RUN_KWARGS_KEY, {}))

response = await self._agent.run(
self._cache,
Expand Down Expand Up @@ -349,12 +345,7 @@ async def _run_agent_streaming(self, ctx: WorkflowContext[Never, AgentResponseUp
Returns:
The complete AgentResponse, or None if waiting for user input.
"""
run_kwargs: dict[str, Any] = ctx.get_state(WORKFLOW_RUN_KWARGS_KEY) or {}

# Build options dict with additional_function_arguments for tool kwargs propagation
options: dict[str, Any] | None = None
if run_kwargs:
options = {"additional_function_arguments": run_kwargs}
run_kwargs, options = self._prepare_agent_run_args(ctx.get_state(WORKFLOW_RUN_KWARGS_KEY) or {})

updates: list[AgentResponseUpdate] = []
user_input_requests: list[Content] = []
Expand Down Expand Up @@ -389,3 +380,55 @@ async def _run_agent_streaming(self, ctx: WorkflowContext[Never, AgentResponseUp
return None

return response

@staticmethod
def _prepare_agent_run_args(raw_run_kwargs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any] | None]:
"""Prepare kwargs and options for agent.run(), avoiding duplicate option passing.

Workflow-level kwargs are propagated to tool calls through
`options.additional_function_arguments`. If workflow kwargs include an
`options` key, merge it into the final options object and remove it from
kwargs before spreading `**run_kwargs`.
"""
run_kwargs = dict(raw_run_kwargs)
options_from_workflow = run_kwargs.pop("options", None)
workflow_additional_args = run_kwargs.pop("additional_function_arguments", None)

options: dict[str, Any] = {}
if options_from_workflow is not None:
if isinstance(options_from_workflow, Mapping):
for key, value in options_from_workflow.items():
if isinstance(key, str):
options[key] = value
else:
logger.warning(
"Ignoring non-mapping workflow 'options' kwarg of type %s for AgentExecutor %s.",
type(options_from_workflow).__name__,
AgentExecutor.__name__,
)

existing_additional_args = options.get("additional_function_arguments")
if isinstance(existing_additional_args, Mapping):
additional_args = {key: value for key, value in existing_additional_args.items() if isinstance(key, str)}
else:
additional_args = {}

if workflow_additional_args is not None:
if isinstance(workflow_additional_args, Mapping):
additional_args.update({
key: value for key, value in workflow_additional_args.items() if isinstance(key, str)
})
else:
logger.warning(
"Ignoring non-mapping workflow 'additional_function_arguments' kwarg of type %s for AgentExecutor %s.", # noqa: E501
type(workflow_additional_args).__name__,
AgentExecutor.__name__,
)

if run_kwargs:
additional_args.update(run_kwargs)

if additional_args:
options["additional_function_arguments"] = additional_args

return run_kwargs, options or None
4 changes: 4 additions & 0 deletions python/packages/core/agent_framework/_workflows/_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,10 @@ async def _create_checkpoint_if_enabled(self, previous_checkpoint_id: Checkpoint
# Save executor states into the shared state before creating the checkpoint,
# so that they are included in the checkpoint payload.
await self._save_executor_states()
# `on_checkpoint_save()` writes via State.set(), which stages values in the
# pending buffer. Checkpoints serialize committed state only, so commit here
# to ensure executor snapshots are captured in this checkpoint.
self._state.commit()

checkpoint_id = await self._ctx.create_checkpoint(
self._workflow_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -941,8 +941,16 @@ def _prepare_content_for_openai(
"""Prepare content for the OpenAI Responses API format."""
match content.type:
case "text":
if role == "assistant":
# Assistant history is represented as output text items; Azure validation
# requires `annotations` to be present for this type.
return {
"type": "output_text",
"text": content.text,
"annotations": [],
}
return {
"type": "output_text" if role == "assistant" else "input_text",
"type": "input_text",
"text": content.text,
}
case "text_reasoning":
Expand Down
34 changes: 34 additions & 0 deletions python/packages/core/tests/openai/test_openai_responses_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,40 @@ def test_prepare_content_for_openai_hosted_vector_store_content() -> None:
assert result == {}


def test_prepare_content_for_openai_text_uses_role_specific_type() -> None:
"""Text content should use input_text for user and output_text for assistant."""
client = OpenAIResponsesClient(model_id="test-model", api_key="test-key")

text_content = Content.from_text(text="hello")

user_result = client._prepare_content_for_openai("user", text_content, {})
assistant_result = client._prepare_content_for_openai("assistant", text_content, {})

assert user_result["type"] == "input_text"
assert assistant_result["type"] == "output_text"
assert assistant_result["annotations"] == []
assert user_result["text"] == "hello"
assert assistant_result["text"] == "hello"


def test_prepare_messages_for_openai_assistant_history_uses_output_text_with_annotations() -> None:
"""Assistant history should be output_text and include required annotations."""
client = OpenAIResponsesClient(model_id="test-model", api_key="test-key")

messages = [
Message(role="user", text="What is async/await?"),
Message(role="assistant", text="Async/await enables non-blocking concurrency."),
]

prepared = client._prepare_messages_for_openai(messages)

assert prepared[0]["role"] == "user"
assert prepared[0]["content"][0]["type"] == "input_text"
assert prepared[1]["role"] == "assistant"
assert prepared[1]["content"][0]["type"] == "output_text"
assert prepared[1]["content"][0]["annotations"] == []


def test_parse_response_from_openai_with_mcp_server_tool_result() -> None:
"""Test _parse_response_from_openai with MCP server tool result."""
client = OpenAIResponsesClient(model_id="test-model", api_key="test-key")
Expand Down
62 changes: 62 additions & 0 deletions python/packages/core/tests/workflow/test_full_conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from agent_framework import (
AgentExecutor,
AgentExecutorRequest,
AgentExecutorResponse,
AgentResponse,
AgentResponseUpdate,
Expand Down Expand Up @@ -150,3 +151,64 @@ async def test_sequential_adapter_uses_full_conversation() -> None:
assert len(seen) == 2
assert seen[0].role == "user" and "hello seq" in (seen[0].text or "")
assert seen[1].role == "assistant" and "A1 reply" in (seen[1].text or "")


class _RoundTripCoordinator(Executor):
"""Loops once back to the same agent with full conversation + feedback."""

def __init__(self, *, target_agent_id: str, id: str = "round_trip_coordinator") -> None:
super().__init__(id=id)
self._target_agent_id = target_agent_id
self._seen = 0

@handler
async def handle_response(
self,
response: AgentExecutorResponse,
ctx: WorkflowContext[Never, dict[str, Any]],
) -> None:
self._seen += 1
if self._seen == 1:
assert response.full_conversation is not None
await ctx.send_message(
AgentExecutorRequest(
messages=list(response.full_conversation) + [Message(role="user", text="apply feedback")],
should_respond=True,
),
target_id=self._target_agent_id,
)
return

assert response.full_conversation is not None
await ctx.yield_output({
"roles": [m.role for m in response.full_conversation],
"texts": [m.text for m in response.full_conversation],
})


async def test_agent_executor_full_conversation_round_trip_does_not_duplicate_history() -> None:
"""When full history is replayed, AgentExecutor should not duplicate prior turns."""
agent = _SimpleAgent(id="writer_agent", name="Writer", reply_text="draft reply")
agent_exec = AgentExecutor(agent, id="writer_agent")
coordinator = _RoundTripCoordinator(target_agent_id="writer_agent")

wf = (
WorkflowBuilder(start_executor=agent_exec, output_executors=[coordinator])
.add_edge(agent_exec, coordinator)
.add_edge(coordinator, agent_exec)
.build()
)

result = await wf.run("initial prompt")
outputs = result.get_outputs()
assert len(outputs) == 1
payload = outputs[0]
assert isinstance(payload, dict)

# Expected conversation after one loop:
# user(initial), assistant(first reply), user(feedback), assistant(second reply)
assert payload["roles"] == ["user", "assistant", "user", "assistant"]
assert payload["texts"][0] == "initial prompt"
assert payload["texts"][1] == "draft reply"
assert payload["texts"][2] == "apply feedback"
assert payload["texts"][3] == "draft reply"
135 changes: 135 additions & 0 deletions python/packages/core/tests/workflow/test_workflow_kwargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,41 @@ async def _run() -> AgentResponse:
return _run()


class _OptionsAwareAgent(BaseAgent):
"""Test agent that captures explicit `options` and kwargs passed to run()."""

captured_options: list[dict[str, Any] | None]
captured_kwargs: list[dict[str, Any]]

def __init__(self, name: str = "options_agent") -> None:
super().__init__(name=name, description="Test agent for options capture")
self.captured_options = []
self.captured_kwargs = []

def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
*,
stream: bool = False,
thread: AgentThread | None = None,
options: dict[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
self.captured_options.append(dict(options) if options is not None else None)
self.captured_kwargs.append(dict(kwargs))
if stream:

async def _stream() -> AsyncIterable[AgentResponseUpdate]:
yield AgentResponseUpdate(contents=[Content.from_text(text=f"{self.name} response")])

return ResponseStream(_stream(), finalizer=AgentResponse.from_updates)

async def _run() -> AgentResponse:
return AgentResponse(messages=[Message("assistant", [f"{self.name} response"])])

return _run()


# region Sequential Builder Tests


Expand Down Expand Up @@ -131,6 +166,106 @@ async def test_sequential_run_kwargs_flow() -> None:
assert agent.captured_kwargs[0].get("custom_data") == {"test": True}


async def test_sequential_run_options_does_not_conflict_with_agent_options() -> None:
"""Test workflow.run(options=...) does not conflict with Agent.run(options=...)."""
agent = _OptionsAwareAgent(name="options_agent")
workflow = SequentialBuilder(participants=[agent]).build()

custom_data = {"session_id": "abc123"}
user_token = {"user_name": "alice"}
provided_options = {
"store": False,
"additional_function_arguments": {"source": "workflow-options"},
}

async for event in workflow.run(
"test message",
stream=True,
options=provided_options,
custom_data=custom_data,
user_token=user_token,
):
if event.type == "status" and event.state == WorkflowRunState.IDLE:
break

assert len(agent.captured_options) >= 1
captured_options = agent.captured_options[0]
assert captured_options is not None
assert captured_options.get("store") is False

additional_args = captured_options.get("additional_function_arguments")
assert isinstance(additional_args, dict)
assert additional_args.get("source") == "workflow-options"
assert additional_args.get("custom_data") == custom_data
assert additional_args.get("user_token") == user_token

# "options" should be passed once via the dedicated options parameter,
# not duplicated in **kwargs.
assert len(agent.captured_kwargs) >= 1
captured_kwargs = agent.captured_kwargs[0]
assert "options" not in captured_kwargs
assert captured_kwargs.get("custom_data") == custom_data
assert captured_kwargs.get("user_token") == user_token


async def test_sequential_run_additional_function_arguments_flattened() -> None:
"""Test workflow.run(additional_function_arguments=...) maps directly to tool kwargs."""
agent = _OptionsAwareAgent(name="options_agent")
workflow = SequentialBuilder(participants=[agent]).build()

custom_data = {"session_id": "abc123"}
user_token = {"user_name": "alice"}

async for event in workflow.run(
"test message",
stream=True,
additional_function_arguments={"custom_data": custom_data, "user_token": user_token},
):
if event.type == "status" and event.state == WorkflowRunState.IDLE:
break

assert len(agent.captured_options) >= 1
captured_options = agent.captured_options[0]
assert captured_options is not None

additional_args = captured_options.get("additional_function_arguments")
assert isinstance(additional_args, dict)
assert additional_args.get("custom_data") == custom_data
assert additional_args.get("user_token") == user_token
assert "additional_function_arguments" not in additional_args

assert len(agent.captured_kwargs) >= 1
captured_kwargs = agent.captured_kwargs[0]
assert "additional_function_arguments" not in captured_kwargs


async def test_sequential_run_additional_function_arguments_merges_with_options() -> None:
"""Test workflow additional_function_arguments merges with workflow options."""
agent = _OptionsAwareAgent(name="options_agent")
workflow = SequentialBuilder(participants=[agent]).build()

async for event in workflow.run(
"test message",
stream=True,
options={"additional_function_arguments": {"source": "workflow-options"}},
additional_function_arguments={"custom_data": {"session_id": "abc123"}},
user_token={"user_name": "alice"},
):
if event.type == "status" and event.state == WorkflowRunState.IDLE:
break

assert len(agent.captured_options) >= 1
captured_options = agent.captured_options[0]
assert captured_options is not None

additional_args = captured_options.get("additional_function_arguments")
assert isinstance(additional_args, dict)
assert additional_args.get("source") == "workflow-options"
assert additional_args.get("custom_data") == {"session_id": "abc123"}
assert additional_args.get("user_token") == {"user_name": "alice"}
assert "additional_function_arguments" not in additional_args


# endregion


Expand Down
Loading
Loading