Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 164 additions & 3 deletions src/agents/run_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,12 @@ def to_json(
result["current_step"] = self._serialize_current_step()
result["last_model_response"] = _serialize_last_model_response(model_responses)
result["last_processed_response"] = (
self._serialize_processed_response(self._last_processed_response)
self._serialize_processed_response(
self._last_processed_response,
context_serializer=context_serializer,
strict_context=strict_context,
include_tracing_api_key=include_tracing_api_key,
)
if self._last_processed_response
else None
)
Expand All @@ -556,7 +561,12 @@ def to_json(
return result

def _serialize_processed_response(
self, processed_response: ProcessedResponse
self,
processed_response: ProcessedResponse,
*,
context_serializer: ContextSerializer | None = None,
strict_context: bool = False,
include_tracing_api_key: bool = False,
) -> dict[str, Any]:
"""Serialize a ProcessedResponse to JSON format.

Expand All @@ -568,6 +578,14 @@ def _serialize_processed_response(
"""

action_groups = _serialize_tool_action_groups(processed_response)
_serialize_pending_nested_agent_tool_runs(
parent_state=self,
function_entries=action_groups.get("functions", []),
function_runs=processed_response.functions,
context_serializer=context_serializer,
strict_context=strict_context,
include_tracing_api_key=include_tracing_api_key,
)

interruptions_data = [
_serialize_tool_approval_interruption(interruption, include_tool_name=True)
Expand Down Expand Up @@ -1138,6 +1156,82 @@ def _serialize_tool_action_groups(
return serialized


def _serialize_pending_nested_agent_tool_runs(
    *,
    parent_state: RunState[Any, Any],
    function_entries: Sequence[dict[str, Any]],
    function_runs: Sequence[Any],
    context_serializer: ContextSerializer | None = None,
    strict_context: bool = False,
    include_tracing_api_key: bool = False,
) -> None:
    """Attach serialized nested run state for pending agent-as-tool interruptions.

    For each function tool call that has a cached nested agent run with live
    interruptions, capture that nested run's state and embed it in the matching
    serialized entry under the ``"agent_run_state"`` key. Entries without a
    pending nested run are left untouched. Capture/serialization failures are
    logged and skipped unless ``strict_context`` is set, in which case they
    propagate.
    """
    if not function_entries or not function_runs:
        return

    from .agent_tool_state import peek_agent_tool_run_result

    for entry, run in zip(function_entries, function_runs):
        call = getattr(run, "tool_call", None)
        if not isinstance(call, ResponseFunctionToolCall):
            continue

        pending = peek_agent_tool_run_result(call)
        if pending is None:
            continue

        # Only runs that are actually paused on interruptions need their
        # nested state persisted.
        pending_interruptions = getattr(pending, "interruptions", None)
        if not (isinstance(pending_interruptions, list) and pending_interruptions):
            continue

        state_factory = getattr(pending, "to_state", None)
        if not callable(state_factory):
            continue

        try:
            captured = state_factory()
        except Exception:
            if strict_context:
                raise
            logger.warning(
                "Failed to capture nested agent run state for tool call %s.",
                call.call_id,
            )
            continue

        # Embed only genuine RunState objects, and never the parent itself —
        # a defensive guard against accidental self-referential serialization loops.
        if not isinstance(captured, RunState) or captured is parent_state:
            continue

        try:
            entry["agent_run_state"] = captured.to_json(
                context_serializer=context_serializer,
                strict_context=strict_context,
                include_tracing_api_key=include_tracing_api_key,
            )
        except Exception:
            if strict_context:
                raise
            logger.warning(
                "Failed to serialize nested agent run state for tool call %s.",
                call.call_id,
            )


class _SerializedAgentToolRunResult:
    """Minimal run-result wrapper used to restore nested agent-as-tool resumptions.

    Mimics just enough of a real run result — ``interruptions``,
    ``final_output`` and ``to_state()`` — for the agent-as-tool cache to treat
    a deserialized nested ``RunState`` as a pending run.
    """

    def __init__(self, state: RunState[Any, Agent[Any]]) -> None:
        self._state = state
        # Snapshot pending interruptions eagerly; callers read this list
        # directly instead of re-querying the state.
        self.interruptions = [*state.get_interruptions()]
        # A run paused on interruptions has no final output yet.
        self.final_output = None

    def to_state(self) -> RunState[Any, Agent[Any]]:
        """Return the restored nested run state for resumption."""
        return self._state


def _serialize_guardrail_results(
results: Sequence[InputGuardrailResult | OutputGuardrailResult],
) -> list[dict[str, Any]]:
Expand Down Expand Up @@ -1215,11 +1309,65 @@ def _build_handoffs_map(current_agent: Agent[Any]) -> dict[str, Handoff[Any, Age
return handoffs_map


async def _restore_pending_nested_agent_tool_runs(
    *,
    current_agent: Agent[Any],
    function_entries: Sequence[Any],
    function_runs: Sequence[Any],
    context_deserializer: ContextDeserializer | None = None,
    strict_context: bool = False,
) -> None:
    """Rehydrate nested agent-as-tool run state into the ephemeral tool-call cache.

    For each serialized function entry carrying an ``"agent_run_state"``
    payload, rebuild the nested ``RunState`` and register it (wrapped in
    ``_SerializedAgentToolRunResult``) against its tool call so a resumed run
    can pick up the pending interruptions. Deserialization failures are logged
    and skipped unless ``strict_context`` is set, in which case they propagate.
    """
    if not function_entries or not function_runs:
        return

    from .agent_tool_state import drop_agent_tool_run_result, record_agent_tool_run_result

    for entry, run in zip(function_entries, function_runs):
        if not isinstance(entry, Mapping):
            continue
        serialized_nested = entry.get("agent_run_state")
        if not isinstance(serialized_nested, Mapping):
            continue

        call = getattr(run, "tool_call", None)
        if not isinstance(call, ResponseFunctionToolCall):
            continue

        try:
            restored_state = await _build_run_state_from_json(
                initial_agent=current_agent,
                state_json=dict(serialized_nested),
                context_deserializer=context_deserializer,
                strict_context=strict_context,
            )
        except Exception:
            if strict_context:
                raise
            logger.warning(
                "Failed to deserialize nested agent run state for tool call %s.",
                call.call_id,
            )
            continue

        wrapper = _SerializedAgentToolRunResult(restored_state)
        # Nothing to resume if the restored state carries no interruptions.
        if not wrapper.interruptions:
            continue

        # Replace any stale cache entry with the same signature so resumed runs do not read
        # older pending interruptions after consuming this restored entry.
        drop_agent_tool_run_result(call)
        record_agent_tool_run_result(call, cast(Any, wrapper))


async def _deserialize_processed_response(
processed_response_data: dict[str, Any],
current_agent: Agent[Any],
context: RunContextWrapper[Any],
agent_map: dict[str, Agent[Any]],
*,
context_deserializer: ContextDeserializer | None = None,
strict_context: bool = False,
) -> ProcessedResponse:
"""Deserialize a ProcessedResponse from JSON data.

Expand Down Expand Up @@ -1403,6 +1551,14 @@ def _deserialize_action_groups() -> dict[str, list[Any]]:
shell_actions = action_groups["shell_actions"]
apply_patch_actions = action_groups["apply_patch_actions"]

await _restore_pending_nested_agent_tool_runs(
current_agent=current_agent,
function_entries=processed_response_data.get("functions", []),
function_runs=functions,
context_deserializer=context_deserializer,
strict_context=strict_context,
)

mcp_approval_requests: list[ToolRunMCPApprovalRequest] = []
for request_data in processed_response_data.get("mcp_approval_requests", []):
request_item_data = request_data.get("request_item", {})
Expand Down Expand Up @@ -1824,7 +1980,12 @@ async def _build_run_state_from_json(
last_processed_response_data = state_json.get("last_processed_response")
if last_processed_response_data and state._context is not None:
state._last_processed_response = await _deserialize_processed_response(
last_processed_response_data, current_agent, state._context, agent_map
last_processed_response_data,
current_agent,
state._context,
agent_map,
context_deserializer=context_deserializer,
strict_context=strict_context,
)
else:
state._last_processed_response = None
Expand Down
Loading