diff --git a/src/agents/run_state.py b/src/agents/run_state.py index 6df92ebb5..4e4ab175b 100644 --- a/src/agents/run_state.py +++ b/src/agents/run_state.py @@ -544,7 +544,12 @@ def to_json( result["current_step"] = self._serialize_current_step() result["last_model_response"] = _serialize_last_model_response(model_responses) result["last_processed_response"] = ( - self._serialize_processed_response(self._last_processed_response) + self._serialize_processed_response( + self._last_processed_response, + context_serializer=context_serializer, + strict_context=strict_context, + include_tracing_api_key=include_tracing_api_key, + ) if self._last_processed_response else None ) @@ -556,7 +561,12 @@ def to_json( return result def _serialize_processed_response( - self, processed_response: ProcessedResponse + self, + processed_response: ProcessedResponse, + *, + context_serializer: ContextSerializer | None = None, + strict_context: bool = False, + include_tracing_api_key: bool = False, ) -> dict[str, Any]: """Serialize a ProcessedResponse to JSON format. @@ -568,6 +578,14 @@ def _serialize_processed_response( """ action_groups = _serialize_tool_action_groups(processed_response) + _serialize_pending_nested_agent_tool_runs( + parent_state=self, + function_entries=action_groups.get("functions", []), + function_runs=processed_response.functions, + context_serializer=context_serializer, + strict_context=strict_context, + include_tracing_api_key=include_tracing_api_key, + ) interruptions_data = [ _serialize_tool_approval_interruption(interruption, include_tool_name=True) @@ -1138,6 +1156,82 @@ def _serialize_tool_action_groups( return serialized +def _serialize_pending_nested_agent_tool_runs( + *, + parent_state: RunState[Any, Any], + function_entries: Sequence[dict[str, Any]], + function_runs: Sequence[Any], + context_serializer: ContextSerializer | None = None, + strict_context: bool = False, + include_tracing_api_key: bool = False, +) -> None: + """Attach serialized nested run state for pending agent-as-tool interruptions.""" + if not function_entries or not function_runs: + return + + from .agent_tool_state import peek_agent_tool_run_result + + for entry, function_run in zip(function_entries, function_runs): + tool_call = getattr(function_run, "tool_call", None) + if not isinstance(tool_call, ResponseFunctionToolCall): + continue + + pending_run_result = peek_agent_tool_run_result(tool_call) + if pending_run_result is None: + continue + + interruptions = getattr(pending_run_result, "interruptions", None) + if not isinstance(interruptions, list) or not interruptions: + continue + + to_state = getattr(pending_run_result, "to_state", None) + if not callable(to_state): + continue + + try: + nested_state = to_state() + except Exception: + if strict_context: + raise + logger.warning( + "Failed to capture nested agent run state for tool call %s.", + tool_call.call_id, + ) + continue + + if not isinstance(nested_state, RunState): + continue + if nested_state is parent_state: + # Defensive guard against accidental self-referential serialization loops. + continue + + try: + entry["agent_run_state"] = nested_state.to_json( + context_serializer=context_serializer, + strict_context=strict_context, + include_tracing_api_key=include_tracing_api_key, + ) + except Exception: + if strict_context: + raise + logger.warning( + "Failed to serialize nested agent run state for tool call %s.", + tool_call.call_id, + ) + + +class _SerializedAgentToolRunResult: + """Minimal run-result wrapper used to restore nested agent-as-tool resumptions.""" + + def __init__(self, state: RunState[Any, Agent[Any]]) -> None: + self._state = state + self.interruptions = list(state.get_interruptions()) + self.final_output = None + + def to_state(self) -> RunState[Any, Agent[Any]]: + return self._state + + def _serialize_guardrail_results( results: Sequence[InputGuardrailResult | OutputGuardrailResult], ) -> list[dict[str, Any]]: @@ -1215,11 +1309,65 @@ def _build_handoffs_map(current_agent: Agent[Any]) -> dict[str, Handoff[Any, Age return handoffs_map +async def _restore_pending_nested_agent_tool_runs( + *, + current_agent: Agent[Any], + function_entries: Sequence[Any], + function_runs: Sequence[Any], + context_deserializer: ContextDeserializer | None = None, + strict_context: bool = False, +) -> None: + """Rehydrate nested agent-as-tool run state into the ephemeral tool-call cache.""" + if not function_entries or not function_runs: + return + + from .agent_tool_state import drop_agent_tool_run_result, record_agent_tool_run_result + + for entry, function_run in zip(function_entries, function_runs): + if not isinstance(entry, Mapping): + continue + nested_state_data = entry.get("agent_run_state") + if not isinstance(nested_state_data, Mapping): + continue + + tool_call = getattr(function_run, "tool_call", None) + if not isinstance(tool_call, ResponseFunctionToolCall): + continue + + try: + nested_state = await _build_run_state_from_json( + initial_agent=current_agent, + state_json=dict(nested_state_data), + context_deserializer=context_deserializer, + strict_context=strict_context, + ) + except Exception: + if strict_context: + raise + logger.warning( + "Failed to deserialize nested agent run state for tool call %s.", + tool_call.call_id, + ) + continue + + pending_result = _SerializedAgentToolRunResult(nested_state) + if not pending_result.interruptions: + continue + + # Replace any stale cache entry with the same signature so resumed runs do not read + # older pending interruptions after consuming this restored entry. + drop_agent_tool_run_result(tool_call) + record_agent_tool_run_result(tool_call, cast(Any, pending_result)) + + async def _deserialize_processed_response( processed_response_data: dict[str, Any], current_agent: Agent[Any], context: RunContextWrapper[Any], agent_map: dict[str, Agent[Any]], + *, + context_deserializer: ContextDeserializer | None = None, + strict_context: bool = False, ) -> ProcessedResponse: """Deserialize a ProcessedResponse from JSON data. @@ -1403,6 +1551,14 @@ def _deserialize_action_groups() -> dict[str, list[Any]]: shell_actions = action_groups["shell_actions"] apply_patch_actions = action_groups["apply_patch_actions"] + await _restore_pending_nested_agent_tool_runs( + current_agent=current_agent, + function_entries=processed_response_data.get("functions", []), + function_runs=functions, + context_deserializer=context_deserializer, + strict_context=strict_context, + ) + mcp_approval_requests: list[ToolRunMCPApprovalRequest] = [] for request_data in processed_response_data.get("mcp_approval_requests", []): request_item_data = request_data.get("request_item", {}) @@ -1824,7 +1980,12 @@ async def _build_run_state_from_json( last_processed_response_data = state_json.get("last_processed_response") if last_processed_response_data and state._context is not None: state._last_processed_response = await _deserialize_processed_response( - last_processed_response_data, current_agent, state._context, agent_map + last_processed_response_data, + current_agent, + state._context, + agent_map, + context_deserializer=context_deserializer, + strict_context=strict_context, ) else: state._last_processed_response = None diff --git a/tests/test_run_state.py b/tests/test_run_state.py index 2fd1dabd3..9f3fe48a4 100644 --- a/tests/test_run_state.py +++ b/tests/test_run_state.py @@ -2,9 +2,10 @@ from __future__ import annotations +import gc import json import logging -from collections.abc import Mapping +from collections.abc import AsyncIterator, Mapping from dataclasses import dataclass from datetime import datetime from typing import Any, Callable, TypeVar, cast @@ -23,7 +24,7 @@ from openai.types.responses.tool_param import Mcp from pydantic import BaseModel -from agents import Agent, Runner, handoff, trace +from agents import Agent, Model, ModelSettings, Runner, handoff, trace from agents.computer import Computer from agents.exceptions import UserError from agents.guardrail import ( @@ -44,6 +45,7 @@ ToolCallItem, ToolCallOutputItem, TResponseInputItem, + TResponseStreamEvent, ) from agents.run_context import RunContextWrapper from agents.run_internal.run_loop import ( @@ -1701,6 +1703,154 @@ async def test_nested_agent_tool_interruptions_roundtrip(self): assert interruptions[0].agent.name == "InnerAgent" assert interruptions[0].raw_item.name == "sensitive_tool" # type: ignore[union-attr] + @pytest.mark.asyncio + async def test_nested_agent_tool_hitl_resume_survives_json_round_trip_after_gc(self) -> None: + """Nested agent-tool resumptions should survive RunState JSON round-trips.""" + + def _has_function_call_output(input_data: str | list[TResponseInputItem]) -> bool: + if not isinstance(input_data, list): + return False + for item in input_data: + if isinstance(item, dict): + if item.get("type") == "function_call_output": + return True + continue + if getattr(item, "type", None) == "function_call_output": + return True + return False + + class ResumeAwareToolModel(Model): + def __init__( + self, *, tool_name: str, tool_arguments: str, final_text: str, call_prefix: str + ) -> None: + self.tool_name = tool_name + self.tool_arguments = tool_arguments + self.final_text = final_text + self.call_prefix = call_prefix + self.call_count = 0 + + async def get_response( + self, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Any], + output_schema: Any, + handoffs: list[Any], + tracing: Any, + *, + previous_response_id: str | None, + conversation_id: str | None, + prompt: Any | None, + ) -> ModelResponse: + del ( + system_instructions, + model_settings, + tools, + output_schema, + handoffs, + tracing, + previous_response_id, + conversation_id, + prompt, + ) + if _has_function_call_output(input): + return ModelResponse( + output=[get_text_message(self.final_text)], + usage=Usage(), + response_id=f"{self.call_prefix}-done", + ) + + self.call_count += 1 + return ModelResponse( + output=[ + ResponseFunctionToolCall( + type="function_call", + name=self.tool_name, + call_id=f"{self.call_prefix}-{id(self)}-{self.call_count}", + arguments=self.tool_arguments, + ) + ], + usage=Usage(), + response_id=f"{self.call_prefix}-call-{self.call_count}", + ) + + async def stream_response( + self, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Any], + output_schema: Any, + handoffs: list[Any], + tracing: Any, + *, + previous_response_id: str | None, + conversation_id: str | None, + prompt: Any | None, + ) -> AsyncIterator[TResponseStreamEvent]: + del ( + system_instructions, + input, + model_settings, + tools, + output_schema, + handoffs, + tracing, + previous_response_id, + conversation_id, + prompt, + ) + if False: + yield cast(TResponseStreamEvent, {}) + raise RuntimeError("Streaming is not supported in this test.") + + tool_calls: list[str] = [] + + @function_tool(name_override="inner_sensitive_tool", needs_approval=True) + async def inner_sensitive_tool(text: str) -> str: + tool_calls.append(text) + return f"approved:{text}" + + inner_model = ResumeAwareToolModel( + tool_name="inner_sensitive_tool", + tool_arguments=json.dumps({"text": "hello"}), + final_text="inner-complete", + call_prefix="inner", + ) + inner_agent = Agent(name="InnerAgent", model=inner_model, tools=[inner_sensitive_tool]) + + outer_tool = inner_agent.as_tool( + tool_name="inner_agent_tool", + tool_description="Inner agent tool", + ) + outer_model = ResumeAwareToolModel( + tool_name="inner_agent_tool", + tool_arguments=json.dumps({"input": "hello"}), + final_text="outer-complete", + call_prefix="outer", + ) + outer_agent = Agent(name="OuterAgent", model=outer_model, tools=[outer_tool]) + + first_result = await Runner.run(outer_agent, "start") + assert first_result.final_output is None + assert first_result.interruptions + + state_json = first_result.to_state().to_json() + del first_result + gc.collect() + + restored_state = await RunState.from_json(outer_agent, state_json) + restored_interruptions = restored_state.get_interruptions() + assert len(restored_interruptions) == 1 + restored_state.approve(restored_interruptions[0]) + + resumed_result = await Runner.run(outer_agent, restored_state) + + assert resumed_result.final_output == "outer-complete" + assert resumed_result.interruptions == [] + assert tool_calls == ["hello"] + async def test_json_decode_error_handling(self): """Test that invalid JSON raises appropriate error.""" agent = Agent(name="TestAgent")