diff --git a/pyproject.toml b/pyproject.toml index 547fc9cf9..2627f762d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dependencies = [ "mcp[cli]>=1.4.1", "scale-gp>=0.1.0a59", "openai-agents==0.14.1", + "pydantic-ai-slim>=1.0,<2", "tzlocal>=5.3.1", "tzdata>=2025.2", "pytest>=8.4.0", diff --git a/requirements-dev.lock b/requirements-dev.lock index 62167cd44..a76d199b7 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -112,10 +112,13 @@ frozenlist==1.8.0 # via aiosignal fsspec==2026.3.0 # via huggingface-hub +genai-prices==0.0.59 + # via pydantic-ai-slim google-auth==2.49.1 # via kubernetes griffelib==2.0.2 # via openai-agents + # via pydantic-ai-slim h11==0.16.0 # via httpcore # via uvicorn @@ -126,12 +129,15 @@ httpcore==1.0.9 httpx==0.28.1 # via agentex-sdk # via anthropic + # via genai-prices # via httpx-aiohttp # via huggingface-hub # via langsmith # via litellm # via mcp # via openai + # via pydantic-ai-slim + # via pydantic-graph # via respx # via scale-gp # via scale-gp-beta @@ -196,6 +202,8 @@ langsmith==0.7.22 # via langchain-core litellm==1.83.7 # via agentex-sdk +logfire-api==4.32.1 + # via pydantic-graph markdown-it-py==3.0.0 # via rich markupsafe==3.0.3 @@ -236,6 +244,7 @@ opentelemetry-api==1.40.0 # via ddtrace # via opentelemetry-sdk # via opentelemetry-semantic-conventions + # via pydantic-ai-slim opentelemetry-sdk==1.40.0 # via agentex-sdk opentelemetry-semantic-conventions==0.61b0 @@ -287,18 +296,25 @@ pydantic==2.12.5 # via agentex-sdk # via anthropic # via fastapi + # via genai-prices # via langchain-core # via langsmith # via litellm # via mcp # via openai # via openai-agents + # via pydantic-ai-slim + # via pydantic-graph # via pydantic-settings # via python-on-whales # via scale-gp # via scale-gp-beta +pydantic-ai-slim==1.92.0 + # via agentex-sdk pydantic-core==2.41.5 # via pydantic +pydantic-graph==1.92.0 + # via pydantic-ai-slim pydantic-settings==2.13.1 # via mcp pygments==2.19.2 @@ -459,6 +475,8 @@ 
typing-inspection==0.4.2 # via fastapi # via mcp # via pydantic + # via pydantic-ai-slim + # via pydantic-graph # via pydantic-settings tzdata==2025.3 # via agentex-sdk diff --git a/requirements.lock b/requirements.lock index 414afb203..bd8318a93 100644 --- a/requirements.lock +++ b/requirements.lock @@ -99,10 +99,13 @@ frozenlist==1.8.0 # via aiosignal fsspec==2026.3.0 # via huggingface-hub +genai-prices==0.0.59 + # via pydantic-ai-slim google-auth==2.49.1 # via kubernetes griffelib==2.0.2 # via openai-agents + # via pydantic-ai-slim h11==0.16.0 # via httpcore # via uvicorn @@ -113,12 +116,15 @@ httpcore==1.0.9 httpx==0.28.1 # via agentex-sdk # via anthropic + # via genai-prices # via httpx-aiohttp # via huggingface-hub # via langsmith # via litellm # via mcp # via openai + # via pydantic-ai-slim + # via pydantic-graph # via scale-gp # via scale-gp-beta httpx-aiohttp==0.1.12 @@ -180,6 +186,8 @@ langsmith==0.7.22 # via langchain-core litellm==1.83.7 # via agentex-sdk +logfire-api==4.32.1 + # via pydantic-graph markdown-it-py==4.0.0 # via rich markupsafe==3.0.3 @@ -214,6 +222,7 @@ opentelemetry-api==1.40.0 # via ddtrace # via opentelemetry-sdk # via opentelemetry-semantic-conventions + # via pydantic-ai-slim opentelemetry-sdk==1.40.0 # via agentex-sdk opentelemetry-semantic-conventions==0.61b0 @@ -260,18 +269,25 @@ pydantic==2.12.5 # via agentex-sdk # via anthropic # via fastapi + # via genai-prices # via langchain-core # via langsmith # via litellm # via mcp # via openai # via openai-agents + # via pydantic-ai-slim + # via pydantic-graph # via pydantic-settings # via python-on-whales # via scale-gp # via scale-gp-beta +pydantic-ai-slim==1.92.0 + # via agentex-sdk pydantic-core==2.41.5 # via pydantic +pydantic-graph==1.92.0 + # via pydantic-ai-slim pydantic-settings==2.13.1 # via mcp pygments==2.20.0 @@ -424,6 +440,8 @@ typing-inspection==0.4.2 # via fastapi # via mcp # via pydantic + # via pydantic-ai-slim + # via pydantic-graph # via pydantic-settings 
tzdata==2025.3 # via agentex-sdk diff --git a/src/agentex/lib/adk/providers/_modules/pydantic_ai.py b/src/agentex/lib/adk/providers/_modules/pydantic_ai.py new file mode 100644 index 000000000..805cc773f --- /dev/null +++ b/src/agentex/lib/adk/providers/_modules/pydantic_ai.py @@ -0,0 +1,289 @@ +"""Pydantic AI streaming integration for Agentex. + +Converts a Pydantic AI ``AgentStreamEvent`` stream (as yielded by +``agent.run_stream_events(...)`` or via an ``event_stream_handler``) into the +Agentex ``StreamTaskMessage*`` events that the Agentex server understands. + +Typical usage: + + from pydantic_ai import Agent + from agentex.lib.adk.providers._modules.pydantic_ai import ( + convert_pydantic_ai_to_agentex_events, + ) + + agent = Agent("openai:gpt-4o", system_prompt="...") + + @acp.on_message_send + async def handle_message_send(params): + stream = agent.run_stream_events(params.content.content) + async for event in convert_pydantic_ai_to_agentex_events(stream): + yield event +""" + +from __future__ import annotations + +import json +from typing import Any, AsyncIterator + +from pydantic_ai.messages import ( + FinalResultEvent, + FunctionToolCallEvent, + FunctionToolResultEvent, + PartDeltaEvent, + PartEndEvent, + PartStartEvent, + TextPart, + TextPartDelta, + ThinkingPart, + ThinkingPartDelta, + ToolCallPart, + ToolCallPartDelta, + ToolReturnPart, +) +from pydantic_ai.run import AgentRunResultEvent + +from agentex.lib.utils.logging import make_logger +from agentex.types.reasoning_content_delta import ReasoningContentDelta +from agentex.types.task_message_content import TextContent +from agentex.types.task_message_delta import TextDelta +from agentex.types.task_message_update import ( + StreamTaskMessageDelta, + StreamTaskMessageDone, + StreamTaskMessageFull, + StreamTaskMessageStart, +) +from agentex.types.tool_request_content import ToolRequestContent +from agentex.types.tool_request_delta import ToolRequestDelta +from 
agentex.types.tool_response_content import ToolResponseContent + +logger = make_logger(__name__) + + +def _args_delta_to_str(args_delta: str | dict[str, Any] | None) -> str: + """Normalize a Pydantic AI ``ToolCallPartDelta.args_delta`` to a string fragment. + + Pydantic AI emits string fragments for providers that stream JSON tokens + (OpenAI, Anthropic) and dicts for providers that emit one-shot tool calls. + Agentex's ``ToolRequestDelta.arguments_delta`` is concatenated server-side + and parsed as a single JSON object on completion, so we always produce a + string. For dict deltas this is a one-shot dump; subsequent dict deltas + will not compose correctly, but in practice dict deltas arrive as a single + final fragment. + """ + if args_delta is None: + return "" + if isinstance(args_delta, str): + return args_delta + return json.dumps(args_delta) + + +def _tool_return_content(result: ToolReturnPart | Any) -> Any: + """Best-effort extraction of the user-visible content from a tool result. + + ``FunctionToolResultEvent.result`` is ``ToolReturnPart | RetryPromptPart``. + For ``ToolReturnPart`` we surface ``.content`` directly; for ``RetryPromptPart`` + (a retry signal back to the model) we surface a string description so the + UI sees the failure reason. + """ + content = getattr(result, "content", None) + if content is None: + return str(result) + if isinstance(content, (str, int, float, bool, list, dict)): + return content + if hasattr(content, "model_dump"): + try: + return content.model_dump() + except Exception: + return str(content) + return str(content) + + +async def convert_pydantic_ai_to_agentex_events( + stream_response: AsyncIterator[Any], +) -> AsyncIterator[ + StreamTaskMessageStart | StreamTaskMessageDelta | StreamTaskMessageFull | StreamTaskMessageDone +]: + """Convert a Pydantic AI agent event stream into Agentex stream events. 
+ + Mapping: + PartStartEvent(TextPart) -> StreamTaskMessageStart(TextContent) + PartStartEvent(ThinkingPart) -> StreamTaskMessageStart(TextContent) [reasoning channel] + PartStartEvent(ToolCallPart) -> StreamTaskMessageStart(ToolRequestContent) + PartDeltaEvent(TextPartDelta) -> StreamTaskMessageDelta(TextDelta) + PartDeltaEvent(ThinkingPart..) -> StreamTaskMessageDelta(ReasoningContentDelta) + PartDeltaEvent(ToolCallPart..) -> StreamTaskMessageDelta(ToolRequestDelta) + PartEndEvent -> StreamTaskMessageDone + FunctionToolResultEvent -> StreamTaskMessageFull(ToolResponseContent) + FunctionToolCallEvent -> (ignored — already covered by Start/Delta/End) + FinalResultEvent -> (ignored — informational; the run-level + AgentRunResultEvent terminates the stream) + AgentRunResultEvent -> (ignored — Agentex closes the per-message + stream via PartEndEvent already) + + Args: + stream_response: The async iterator returned by Pydantic AI's + ``agent.run_stream_events(...)`` (or a stream of + ``AgentStreamEvent`` items received in an ``event_stream_handler``). + + Yields: + Agentex ``StreamTaskMessage*`` events suitable for forwarding back over + the ACP streaming response. + """ + next_message_index = 0 + # Maps Pydantic AI's per-response part index to our absolute message index. + # Part indices restart at 0 on each new model response in a multi-step run, + # so we always overwrite the entry on PartStartEvent. + part_to_message_index: dict[int, int] = {} + # Tool-call metadata indexed by Pydantic AI part index (so deltas can + # surface the tool_call_id even when ToolCallPartDelta.tool_call_id is None). 
+ tool_call_meta: dict[int, tuple[str, str]] = {} + + async for event in stream_response: + if isinstance(event, PartStartEvent): + message_index = next_message_index + next_message_index += 1 + part_to_message_index[event.index] = message_index + + if isinstance(event.part, TextPart): + yield StreamTaskMessageStart( + type="start", + index=message_index, + content=TextContent( + type="text", + author="agent", + content=event.part.content or "", + ), + ) + elif isinstance(event.part, ThinkingPart): + yield StreamTaskMessageStart( + type="start", + index=message_index, + content=TextContent( + type="text", + author="agent", + content="", + ), + ) + if event.part.content: + yield StreamTaskMessageDelta( + type="delta", + index=message_index, + delta=ReasoningContentDelta( + type="reasoning_content", + content_index=0, + content_delta=event.part.content, + ), + ) + elif isinstance(event.part, ToolCallPart): + tool_call_meta[event.index] = (event.part.tool_call_id, event.part.tool_name) + # Pydantic AI may already have a fully-formed args dict at start + # when the provider returns the tool call in one shot; surface it + # directly so clients see the complete arguments without waiting + # for deltas. 
+ initial_args: dict[str, Any] = {} + if isinstance(event.part.args, dict): + initial_args = event.part.args + yield StreamTaskMessageStart( + type="start", + index=message_index, + content=ToolRequestContent( + type="tool_request", + author="agent", + tool_call_id=event.part.tool_call_id, + name=event.part.tool_name, + arguments=initial_args, + ), + ) + if isinstance(event.part.args, str) and event.part.args: + yield StreamTaskMessageDelta( + type="delta", + index=message_index, + delta=ToolRequestDelta( + type="tool_request", + tool_call_id=event.part.tool_call_id, + name=event.part.tool_name, + arguments_delta=event.part.args, + ), + ) + else: + logger.debug("Unhandled PartStartEvent part type: %r", type(event.part).__name__) + + elif isinstance(event, PartDeltaEvent): + message_index = part_to_message_index.get(event.index) + if message_index is None: + logger.debug("PartDeltaEvent for unknown part index %s; skipping", event.index) + continue + + if isinstance(event.delta, TextPartDelta): + yield StreamTaskMessageDelta( + type="delta", + index=message_index, + delta=TextDelta(type="text", text_delta=event.delta.content_delta), + ) + elif isinstance(event.delta, ThinkingPartDelta): + if event.delta.content_delta: + yield StreamTaskMessageDelta( + type="delta", + index=message_index, + delta=ReasoningContentDelta( + type="reasoning_content", + content_index=0, + content_delta=event.delta.content_delta, + ), + ) + elif isinstance(event.delta, ToolCallPartDelta): + meta = tool_call_meta.get(event.index) + if meta is None: + # No tool metadata was recorded at PartStart for this index + # (the start part was not a ToolCallPart), so fall back to + # whatever identifiers this delta itself carries. 
+ tool_call_id = event.delta.tool_call_id or "" + tool_name = event.delta.tool_name_delta or "" + tool_call_meta[event.index] = (tool_call_id, tool_name) + else: + tool_call_id, tool_name = meta + yield StreamTaskMessageDelta( + type="delta", + index=message_index, + delta=ToolRequestDelta( + type="tool_request", + tool_call_id=tool_call_id, + name=tool_name, + arguments_delta=_args_delta_to_str(event.delta.args_delta), + ), + ) + else: + logger.debug("Unhandled PartDeltaEvent delta type: %r", type(event.delta).__name__) + + elif isinstance(event, PartEndEvent): + message_index = part_to_message_index.get(event.index) + if message_index is None: + continue + yield StreamTaskMessageDone(type="done", index=message_index) + + elif isinstance(event, FunctionToolResultEvent): + result = event.result + tool_call_id = result.tool_call_id + tool_name = getattr(result, "tool_name", "") or "" + message_index = next_message_index + next_message_index += 1 + yield StreamTaskMessageFull( + type="full", + index=message_index, + content=ToolResponseContent( + type="tool_response", + author="agent", + tool_call_id=tool_call_id, + name=tool_name, + content=_tool_return_content(result), + ), + ) + + elif isinstance(event, (FunctionToolCallEvent, FinalResultEvent, AgentRunResultEvent)): + # Already covered by PartStart/PartDelta/PartEnd events above, or + # informational only (FinalResultEvent / AgentRunResultEvent signal + # run-level state, not new content to surface). 
+ continue + + else: + logger.debug("Unhandled Pydantic AI event type: %r", type(event).__name__) diff --git a/tests/lib/adk/providers/test_pydantic_ai.py b/tests/lib/adk/providers/test_pydantic_ai.py new file mode 100644 index 000000000..083420cda --- /dev/null +++ b/tests/lib/adk/providers/test_pydantic_ai.py @@ -0,0 +1,397 @@ +"""Tests for the Pydantic AI -> Agentex stream event converter.""" + +from __future__ import annotations + +import json +from typing import Any, AsyncIterator + +import pytest +from pydantic_ai.messages import ( + FinalResultEvent, + FunctionToolCallEvent, + FunctionToolResultEvent, + PartDeltaEvent, + PartEndEvent, + PartStartEvent, + RetryPromptPart, + TextPart, + TextPartDelta, + ThinkingPart, + ThinkingPartDelta, + ToolCallPart, + ToolCallPartDelta, + ToolReturnPart, +) + +from agentex.lib.adk.providers._modules.pydantic_ai import ( + _args_delta_to_str, + convert_pydantic_ai_to_agentex_events, +) +from agentex.types.reasoning_content_delta import ReasoningContentDelta +from agentex.types.task_message_content import TextContent +from agentex.types.task_message_delta import TextDelta +from agentex.types.task_message_update import ( + StreamTaskMessageDelta, + StreamTaskMessageDone, + StreamTaskMessageFull, + StreamTaskMessageStart, +) +from agentex.types.tool_request_content import ToolRequestContent +from agentex.types.tool_request_delta import ToolRequestDelta +from agentex.types.tool_response_content import ToolResponseContent + + +async def _aiter(events: list[Any]) -> AsyncIterator[Any]: + for e in events: + yield e + + +async def _collect(stream: AsyncIterator[Any]) -> list[Any]: + return [e async for e in stream] + + +class TestArgsDeltaToStr: + def test_none(self): + assert _args_delta_to_str(None) == "" + + def test_string_passthrough(self): + assert _args_delta_to_str('{"k":') == '{"k":' + + def test_dict_dumps_json(self): + assert json.loads(_args_delta_to_str({"city": "Paris"})) == {"city": "Paris"} + + +class 
TestTextStreaming: + async def test_plain_text_emits_start_deltas_done(self): + events = [ + PartStartEvent(index=0, part=TextPart(content="")), + PartDeltaEvent(index=0, delta=TextPartDelta(content_delta="Hello")), + PartDeltaEvent(index=0, delta=TextPartDelta(content_delta=", ")), + PartDeltaEvent(index=0, delta=TextPartDelta(content_delta="world!")), + PartEndEvent(index=0, part=TextPart(content="Hello, world!")), + ] + out = await _collect(convert_pydantic_ai_to_agentex_events(_aiter(events))) + + assert len(out) == 5 + assert isinstance(out[0], StreamTaskMessageStart) + assert isinstance(out[0].content, TextContent) + assert out[0].content.content == "" + assert out[0].index == 0 + + for i, expected in enumerate(["Hello", ", ", "world!"], start=1): + assert isinstance(out[i], StreamTaskMessageDelta) + assert isinstance(out[i].delta, TextDelta) + assert out[i].delta.text_delta == expected + assert out[i].index == 0 + + assert isinstance(out[4], StreamTaskMessageDone) + assert out[4].index == 0 + + async def test_text_with_initial_content_preserved(self): + events = [ + PartStartEvent(index=0, part=TextPart(content="Already there")), + PartEndEvent(index=0, part=TextPart(content="Already there")), + ] + out = await _collect(convert_pydantic_ai_to_agentex_events(_aiter(events))) + assert isinstance(out[0], StreamTaskMessageStart) + assert out[0].content.content == "Already there" + + +class TestThinkingStreaming: + async def test_thinking_emits_reasoning_deltas(self): + events = [ + PartStartEvent(index=0, part=ThinkingPart(content="")), + PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta="step 1...")), + PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta=" step 2.")), + PartEndEvent(index=0, part=ThinkingPart(content="step 1... 
step 2.")), + ] + out = await _collect(convert_pydantic_ai_to_agentex_events(_aiter(events))) + + assert isinstance(out[0], StreamTaskMessageStart) + assert isinstance(out[1], StreamTaskMessageDelta) + assert isinstance(out[1].delta, ReasoningContentDelta) + assert out[1].delta.content_delta == "step 1..." + assert out[1].delta.content_index == 0 + assert isinstance(out[2].delta, ReasoningContentDelta) + assert out[2].delta.content_delta == " step 2." + assert isinstance(out[3], StreamTaskMessageDone) + + async def test_thinking_with_initial_content_emits_delta(self): + events = [ + PartStartEvent(index=0, part=ThinkingPart(content="seed reasoning")), + ] + out = await _collect(convert_pydantic_ai_to_agentex_events(_aiter(events))) + assert isinstance(out[0], StreamTaskMessageStart) + assert isinstance(out[1], StreamTaskMessageDelta) + assert out[1].delta.content_delta == "seed reasoning" + + async def test_thinking_delta_skipped_when_empty(self): + events = [ + PartStartEvent(index=0, part=ThinkingPart(content="")), + PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta=None)), + PartEndEvent(index=0, part=ThinkingPart(content="")), + ] + out = await _collect(convert_pydantic_ai_to_agentex_events(_aiter(events))) + assert len(out) == 2 # Start + Done; no delta for None content + + +class TestToolCallStreaming: + async def test_tool_call_streamed_token_by_token(self): + """The headline use case: tool-call argument tokens streaming through to the client.""" + events = [ + PartStartEvent( + index=1, + part=ToolCallPart(tool_name="get_weather", args=None, tool_call_id="call_abc"), + ), + PartDeltaEvent( + index=1, + delta=ToolCallPartDelta(args_delta='{"city":', tool_call_id="call_abc"), + ), + PartDeltaEvent(index=1, delta=ToolCallPartDelta(args_delta='"Paris"}')), + PartEndEvent( + index=1, + part=ToolCallPart( + tool_name="get_weather", args='{"city":"Paris"}', tool_call_id="call_abc" + ), + ), + ] + out = await 
_collect(convert_pydantic_ai_to_agentex_events(_aiter(events))) + + assert len(out) == 4 + assert isinstance(out[0], StreamTaskMessageStart) + assert isinstance(out[0].content, ToolRequestContent) + assert out[0].content.tool_call_id == "call_abc" + assert out[0].content.name == "get_weather" + assert out[0].content.arguments == {} + + assert isinstance(out[1].delta, ToolRequestDelta) + assert out[1].delta.tool_call_id == "call_abc" + assert out[1].delta.name == "get_weather" + assert out[1].delta.arguments_delta == '{"city":' + + assert isinstance(out[2].delta, ToolRequestDelta) + assert out[2].delta.arguments_delta == '"Paris"}' + # tool_call_id is carried forward from the start even when the delta omits it + assert out[2].delta.tool_call_id == "call_abc" + + assert isinstance(out[3], StreamTaskMessageDone) + + async def test_tool_call_with_full_args_at_start(self): + """Some providers return a tool call in one shot — args dict is set at start.""" + events = [ + PartStartEvent( + index=0, + part=ToolCallPart( + tool_name="search", args={"query": "weather"}, tool_call_id="call_xyz" + ), + ), + PartEndEvent( + index=0, + part=ToolCallPart( + tool_name="search", args={"query": "weather"}, tool_call_id="call_xyz" + ), + ), + ] + out = await _collect(convert_pydantic_ai_to_agentex_events(_aiter(events))) + assert isinstance(out[0], StreamTaskMessageStart) + assert out[0].content.arguments == {"query": "weather"} + # No deltas emitted — args were already complete. 
+ assert len(out) == 2 + assert isinstance(out[1], StreamTaskMessageDone) + + async def test_tool_call_with_full_args_string_at_start(self): + """When args is a complete JSON string at start, surface it as a single delta.""" + events = [ + PartStartEvent( + index=0, + part=ToolCallPart( + tool_name="search", args='{"query":"weather"}', tool_call_id="call_z" + ), + ), + PartEndEvent( + index=0, + part=ToolCallPart( + tool_name="search", args='{"query":"weather"}', tool_call_id="call_z" + ), + ), + ] + out = await _collect(convert_pydantic_ai_to_agentex_events(_aiter(events))) + assert isinstance(out[0], StreamTaskMessageStart) + assert out[0].content.arguments == {} + assert isinstance(out[1], StreamTaskMessageDelta) + assert out[1].delta.arguments_delta == '{"query":"weather"}' + + async def test_tool_call_dict_args_delta_serialized(self): + events = [ + PartStartEvent( + index=0, + part=ToolCallPart(tool_name="t", args=None, tool_call_id="cid"), + ), + PartDeltaEvent( + index=0, + delta=ToolCallPartDelta(args_delta={"k": "v"}, tool_call_id="cid"), + ), + ] + out = await _collect(convert_pydantic_ai_to_agentex_events(_aiter(events))) + assert json.loads(out[1].delta.arguments_delta) == {"k": "v"} + + async def test_tool_result_emits_full(self): + events = [ + PartStartEvent( + index=0, + part=ToolCallPart(tool_name="get_weather", args=None, tool_call_id="call_abc"), + ), + PartEndEvent( + index=0, + part=ToolCallPart(tool_name="get_weather", args="{}", tool_call_id="call_abc"), + ), + FunctionToolResultEvent( + result=ToolReturnPart( + tool_name="get_weather", content="Sunny, 72F", tool_call_id="call_abc" + ), + ), + ] + out = await _collect(convert_pydantic_ai_to_agentex_events(_aiter(events))) + + # Last event is the tool result -> Full ToolResponseContent + assert isinstance(out[-1], StreamTaskMessageFull) + assert isinstance(out[-1].content, ToolResponseContent) + assert out[-1].content.tool_call_id == "call_abc" + assert out[-1].content.name == "get_weather" + 
assert out[-1].content.content == "Sunny, 72F" + + async def test_tool_retry_prompt_surfaces_as_response(self): + events = [ + FunctionToolResultEvent( + result=RetryPromptPart( + content="bad arguments", + tool_name="get_weather", + tool_call_id="call_abc", + ), + ), + ] + out = await _collect(convert_pydantic_ai_to_agentex_events(_aiter(events))) + assert isinstance(out[0], StreamTaskMessageFull) + assert isinstance(out[0].content, ToolResponseContent) + assert out[0].content.tool_call_id == "call_abc" + assert out[0].content.name == "get_weather" + # RetryPromptPart's content is the error message + assert out[0].content.content == "bad arguments" + + +class TestMultiStepRun: + async def test_text_then_tool_then_text_assigns_distinct_indices(self): + """A multi-step run: model emits text + tool call → tool runs → model emits more text. + + Pydantic AI restarts part indices at 0 for each new model response, so + the converter must assign fresh Agentex message indices. + """ + events = [ + # First model response: text at index 0, tool call at index 1 + PartStartEvent(index=0, part=TextPart(content="")), + PartDeltaEvent(index=0, delta=TextPartDelta(content_delta="Looking up...")), + PartEndEvent(index=0, part=TextPart(content="Looking up...")), + PartStartEvent( + index=1, + part=ToolCallPart(tool_name="get_weather", args=None, tool_call_id="c1"), + ), + PartDeltaEvent(index=1, delta=ToolCallPartDelta(args_delta="{}")), + PartEndEvent( + index=1, part=ToolCallPart(tool_name="get_weather", args="{}", tool_call_id="c1") + ), + FunctionToolResultEvent( + result=ToolReturnPart(tool_name="get_weather", content="Sunny", tool_call_id="c1"), + ), + # Second model response: text restarts at index 0 + PartStartEvent(index=0, part=TextPart(content="")), + PartDeltaEvent(index=0, delta=TextPartDelta(content_delta="It's sunny.")), + PartEndEvent(index=0, part=TextPart(content="It's sunny.")), + ] + out = await _collect(convert_pydantic_ai_to_agentex_events(_aiter(events))) + + 
# Pull every Start/Full event and check their assigned message indices + anchors = [ + e for e in out if isinstance(e, (StreamTaskMessageStart, StreamTaskMessageFull)) + ] + indices = [e.index for e in anchors] + assert indices == [0, 1, 2, 3], ( + f"Expected 4 distinct, monotonic message indices for: text1, tool_call, " + f"tool_result, text2 — got {indices}" + ) + + # And the second text's deltas should target the second text's message index. + text2_start = anchors[3] + text2_deltas = [ + e + for e in out + if isinstance(e, StreamTaskMessageDelta) + and isinstance(e.delta, TextDelta) + and e.index == text2_start.index + ] + assert len(text2_deltas) == 1 + assert text2_deltas[0].delta.text_delta == "It's sunny." + + +class TestIgnoredEvents: + async def test_function_tool_call_event_is_ignored(self): + """FunctionToolCallEvent is redundant with PartStart+Delta+End and should be skipped.""" + events = [ + PartStartEvent( + index=0, + part=ToolCallPart(tool_name="t", args=None, tool_call_id="c"), + ), + FunctionToolCallEvent( + part=ToolCallPart(tool_name="t", args="{}", tool_call_id="c"), + ), + PartEndEvent( + index=0, part=ToolCallPart(tool_name="t", args="{}", tool_call_id="c") + ), + ] + out = await _collect(convert_pydantic_ai_to_agentex_events(_aiter(events))) + # Start + Done only — no event from FunctionToolCallEvent + assert len(out) == 2 + assert isinstance(out[0], StreamTaskMessageStart) + assert isinstance(out[1], StreamTaskMessageDone) + + async def test_final_result_event_ignored(self): + events = [ + FinalResultEvent(tool_name=None, tool_call_id=None), + ] + out = await _collect(convert_pydantic_ai_to_agentex_events(_aiter(events))) + assert out == [] + + async def test_unknown_part_index_delta_skipped(self): + events = [ + PartDeltaEvent(index=99, delta=TextPartDelta(content_delta="orphan")), + ] + out = await _collect(convert_pydantic_ai_to_agentex_events(_aiter(events))) + assert out == [] + + +class TestStartingTextMatchesAuthor: + """Sanity 
check that all emitted content is authored by the agent.""" + + @pytest.mark.parametrize( + "events", + [ + [PartStartEvent(index=0, part=TextPart(content=""))], + [PartStartEvent(index=0, part=ThinkingPart(content=""))], + [ + PartStartEvent( + index=0, + part=ToolCallPart(tool_name="t", args=None, tool_call_id="c"), + ) + ], + [ + FunctionToolResultEvent( + result=ToolReturnPart(tool_name="t", content="ok", tool_call_id="c"), + ) + ], + ], + ) + async def test_author_is_agent(self, events: list[Any]): + out = await _collect(convert_pydantic_ai_to_agentex_events(_aiter(events))) + for e in out: + content = getattr(e, "content", None) + if content is not None and hasattr(content, "author"): + assert content.author == "agent"