Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions src/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@
)
from .agent_tool_state import (
consume_agent_tool_run_result,
get_agent_tool_state_scope,
peek_agent_tool_run_result,
record_agent_tool_run_result,
set_agent_tool_state_scope,
)
from .exceptions import ModelBehaviorError, UserError
from .guardrail import InputGuardrail, OutputGuardrail
Expand Down Expand Up @@ -593,6 +595,7 @@ async def _run_agent_impl(context: ToolContext, input_json: str) -> Any:
resolved_run_config = run_config
if resolved_run_config is None and isinstance(context, ToolContext):
resolved_run_config = context.run_config
tool_state_scope_id = get_agent_tool_state_scope(context)
if isinstance(context, ToolContext):
# Use a fresh ToolContext to avoid sharing approval state with parent runs.
nested_context = ToolContext(
Expand All @@ -605,17 +608,20 @@ async def _run_agent_impl(context: ToolContext, input_json: str) -> Any:
agent=context.agent,
run_config=resolved_run_config,
)
set_agent_tool_state_scope(nested_context, tool_state_scope_id)
if should_capture_tool_input:
nested_context.tool_input = params_data
elif isinstance(context, RunContextWrapper):
if should_capture_tool_input:
nested_context = RunContextWrapper(context=context.context)
set_agent_tool_state_scope(nested_context, tool_state_scope_id)
nested_context.tool_input = params_data
else:
nested_context = context.context
else:
if should_capture_tool_input:
nested_context = RunContextWrapper(context=context)
set_agent_tool_state_scope(nested_context, tool_state_scope_id)
nested_context.tool_input = params_data
else:
nested_context = context
Expand Down Expand Up @@ -678,7 +684,10 @@ def _apply_nested_approvals(
)

if isinstance(context, ToolContext) and context.tool_call is not None:
pending_run_result = peek_agent_tool_run_result(context.tool_call)
pending_run_result = peek_agent_tool_run_result(
context.tool_call,
scope_id=tool_state_scope_id,
)
if pending_run_result and getattr(pending_run_result, "interruptions", None):
status = _nested_approvals_status(pending_run_result.interruptions)
if status == "pending":
Expand All @@ -693,7 +702,10 @@ def _apply_nested_approvals(
context,
pending_run_result.interruptions,
)
consume_agent_tool_run_result(context.tool_call)
consume_agent_tool_run_result(
context.tool_call,
scope_id=tool_state_scope_id,
)

if run_result is None:
if on_stream is not None:
Expand Down Expand Up @@ -780,7 +792,11 @@ async def dispatch_stream_events() -> None:
interruptions = getattr(run_result, "interruptions", None)
if isinstance(context, ToolContext) and context.tool_call is not None and interruptions:
if should_record_run_result:
record_agent_tool_run_result(context.tool_call, run_result)
record_agent_tool_run_result(
context.tool_call,
run_result,
scope_id=tool_state_scope_id,
)

if custom_output_extractor:
return await custom_output_extractor(run_result)
Expand Down
105 changes: 82 additions & 23 deletions src/agents/agent_tool_state.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,57 @@
from __future__ import annotations

import weakref
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall

from .result import RunResult, RunResultStreaming

ToolCallSignature = tuple[str, str, str, str, str | None, str | None]
ScopedToolCallSignature = tuple[str | None, ToolCallSignature]

_AGENT_TOOL_STATE_SCOPE_ATTR = "_agent_tool_state_scope_id"

# Ephemeral maps linking tool call objects to nested agent results within the same run.
# Store by object identity, and index by a stable signature to avoid call ID collisions.
_agent_tool_run_results_by_obj: dict[int, RunResult | RunResultStreaming] = {}
_agent_tool_run_results_by_signature: dict[
tuple[str, str, str, str, str | None, str | None],
ScopedToolCallSignature,
set[int],
] = {}
_agent_tool_run_result_signature_by_obj: dict[
int,
tuple[str, str, str, str, str | None, str | None],
ScopedToolCallSignature,
] = {}
_agent_tool_call_refs_by_obj: dict[int, weakref.ReferenceType[ResponseFunctionToolCall]] = {}


def get_agent_tool_state_scope(context: Any) -> str | None:
"""Read the private agent-tool cache scope id from a context wrapper."""
scope_id = getattr(context, _AGENT_TOOL_STATE_SCOPE_ATTR, None)
return scope_id if isinstance(scope_id, str) else None


def set_agent_tool_state_scope(context: Any, scope_id: str | None) -> None:
"""Attach or clear the private agent-tool cache scope id on a context wrapper."""
if context is None:
return
if scope_id is None:
try:
delattr(context, _AGENT_TOOL_STATE_SCOPE_ATTR)
except Exception:
return
return
try:
setattr(context, _AGENT_TOOL_STATE_SCOPE_ATTR, scope_id)
except Exception:
return


def _tool_call_signature(
tool_call: ResponseFunctionToolCall,
) -> tuple[str, str, str, str, str | None, str | None]:
) -> ToolCallSignature:
"""Build a stable signature for fallback lookup across tool call instances."""
return (
tool_call.call_id,
Expand All @@ -36,11 +63,21 @@ def _tool_call_signature(
)


def _scoped_tool_call_signature(
tool_call: ResponseFunctionToolCall, *, scope_id: str | None
) -> ScopedToolCallSignature:
"""Build a scope-qualified signature so independently restored states do not collide."""
return (scope_id, _tool_call_signature(tool_call))


def _index_agent_tool_run_result(
tool_call: ResponseFunctionToolCall, tool_call_obj_id: int
tool_call: ResponseFunctionToolCall,
tool_call_obj_id: int,
*,
scope_id: str | None,
) -> None:
"""Track tool call objects by signature for fallback lookup."""
signature = _tool_call_signature(tool_call)
signature = _scoped_tool_call_signature(tool_call, scope_id=scope_id)
_agent_tool_run_result_signature_by_obj[tool_call_obj_id] = signature
_agent_tool_run_results_by_signature.setdefault(signature, set()).add(tool_call_obj_id)

Expand Down Expand Up @@ -80,26 +117,40 @@ def _on_tool_call_gc(_ref: weakref.ReferenceType[ResponseFunctionToolCall]) -> N


def record_agent_tool_run_result(
tool_call: ResponseFunctionToolCall, run_result: RunResult | RunResultStreaming
tool_call: ResponseFunctionToolCall,
run_result: RunResult | RunResultStreaming,
*,
scope_id: str | None = None,
) -> None:
"""Store the nested agent run result by tool call identity."""
tool_call_obj_id = id(tool_call)
_agent_tool_run_results_by_obj[tool_call_obj_id] = run_result
_index_agent_tool_run_result(tool_call, tool_call_obj_id)
_index_agent_tool_run_result(tool_call, tool_call_obj_id, scope_id=scope_id)
_register_tool_call_ref(tool_call, tool_call_obj_id)


def _tool_call_obj_matches_scope(tool_call_obj_id: int, *, scope_id: str | None) -> bool:
scoped_signature = _agent_tool_run_result_signature_by_obj.get(tool_call_obj_id)
if scoped_signature is None:
# Fallback for unindexed entries.
return scope_id is None
return scoped_signature[0] == scope_id


def consume_agent_tool_run_result(
tool_call: ResponseFunctionToolCall,
*,
scope_id: str | None = None,
) -> RunResult | RunResultStreaming | None:
"""Return and drop the stored nested agent run result for the given tool call."""
obj_id = id(tool_call)
run_result = _agent_tool_run_results_by_obj.pop(obj_id, None)
if run_result is not None:
_drop_agent_tool_run_result(obj_id)
return run_result
if _tool_call_obj_matches_scope(obj_id, scope_id=scope_id):
run_result = _agent_tool_run_results_by_obj.pop(obj_id, None)
if run_result is not None:
_drop_agent_tool_run_result(obj_id)
return run_result

signature = _tool_call_signature(tool_call)
signature = _scoped_tool_call_signature(tool_call, scope_id=scope_id)
candidate_ids = _agent_tool_run_results_by_signature.get(signature)
if not candidate_ids:
return None
Expand All @@ -115,14 +166,17 @@ def consume_agent_tool_run_result(

def peek_agent_tool_run_result(
tool_call: ResponseFunctionToolCall,
*,
scope_id: str | None = None,
) -> RunResult | RunResultStreaming | None:
"""Return the stored nested agent run result without removing it."""
obj_id = id(tool_call)
run_result = _agent_tool_run_results_by_obj.get(obj_id)
if run_result is not None:
return run_result
if _tool_call_obj_matches_scope(obj_id, scope_id=scope_id):
run_result = _agent_tool_run_results_by_obj.get(obj_id)
if run_result is not None:
return run_result

signature = _tool_call_signature(tool_call)
signature = _scoped_tool_call_signature(tool_call, scope_id=scope_id)
candidate_ids = _agent_tool_run_results_by_signature.get(signature)
if not candidate_ids:
return None
Expand All @@ -133,15 +187,20 @@ def peek_agent_tool_run_result(
return _agent_tool_run_results_by_obj.get(candidate_id)


def drop_agent_tool_run_result(tool_call: ResponseFunctionToolCall) -> None:
def drop_agent_tool_run_result(
tool_call: ResponseFunctionToolCall,
*,
scope_id: str | None = None,
) -> None:
"""Drop the stored nested agent run result, if present."""
obj_id = id(tool_call)
run_result = _agent_tool_run_results_by_obj.pop(obj_id, None)
if run_result is not None:
_drop_agent_tool_run_result(obj_id)
return
if _tool_call_obj_matches_scope(obj_id, scope_id=scope_id):
run_result = _agent_tool_run_results_by_obj.pop(obj_id, None)
if run_result is not None:
_drop_agent_tool_run_result(obj_id)
return

signature = _tool_call_signature(tool_call)
signature = _scoped_tool_call_signature(tool_call, scope_id=scope_id)
candidate_ids = _agent_tool_run_results_by_signature.get(signature)
if not candidate_ids:
return
Expand Down
3 changes: 3 additions & 0 deletions src/agents/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from . import _debug
from .agent import Agent
from .agent_tool_state import set_agent_tool_state_scope
from .exceptions import (
AgentsException,
InputGuardrailTripwireTriggered,
Expand Down Expand Up @@ -555,6 +556,7 @@ async def run(
session_items = []
model_responses = []
context_wrapper = ensure_context_wrapper(context)
set_agent_tool_state_scope(context_wrapper, None)
run_state = RunState(
context=context_wrapper,
original_input=original_input,
Expand Down Expand Up @@ -1458,6 +1460,7 @@ def run_streamed(
auto_previous_response_id=auto_previous_response_id,
)
context_wrapper = ensure_context_wrapper(context)
set_agent_tool_state_scope(context_wrapper, None)
# input_for_state is the same as input_for_result here
input_for_state = input_for_result
run_state = RunState(
Expand Down
3 changes: 3 additions & 0 deletions src/agents/run_internal/agent_runner_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any, cast

from ..agent import Agent
from ..agent_tool_state import set_agent_tool_state_scope
from ..exceptions import UserError
from ..guardrail import InputGuardrailResult
from ..items import ModelResponse, RunItem, ToolApprovalItem, TResponseInputItem
Expand Down Expand Up @@ -141,10 +142,12 @@ def resolve_resumed_context(
"""Return the context wrapper for a resumed run, overriding when provided."""
if context is not None:
context_wrapper = ensure_context_wrapper(context)
set_agent_tool_state_scope(context_wrapper, run_state._agent_tool_state_scope_id)
run_state._context = context_wrapper
return context_wrapper
if run_state._context is None:
run_state._context = ensure_context_wrapper(context)
set_agent_tool_state_scope(run_state._context, run_state._agent_tool_state_scope_id)
return run_state._context


Expand Down
3 changes: 2 additions & 1 deletion src/agents/run_internal/items.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,10 +194,11 @@ def function_rejection_item(
tool_call: Any,
*,
rejection_message: str = REJECTION_MESSAGE,
scope_id: str | None = None,
) -> ToolCallOutputItem:
"""Build a ToolCallOutputItem representing a rejected function tool call."""
if isinstance(tool_call, ResponseFunctionToolCall):
drop_agent_tool_run_result(tool_call)
drop_agent_tool_run_result(tool_call, scope_id=scope_id)
return ToolCallOutputItem(
output=rejection_message,
raw_item=ItemHelpers.tool_call_output_item(tool_call, rejection_message),
Expand Down
28 changes: 23 additions & 5 deletions src/agents/run_internal/tool_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@
from openai.types.responses.response_output_item import McpApprovalRequest

from ..agent import Agent
from ..agent_tool_state import consume_agent_tool_run_result, peek_agent_tool_run_result
from ..agent_tool_state import (
consume_agent_tool_run_result,
get_agent_tool_state_scope,
peek_agent_tool_run_result,
)
from ..editor import ApplyPatchOperation, ApplyPatchResult
from ..exceptions import (
AgentsException,
Expand Down Expand Up @@ -840,6 +844,7 @@ async def execute_function_tool_calls(
"""Execute function tool calls with approvals, guardrails, and hooks."""
tool_input_guardrail_results: list[ToolInputGuardrailResult] = []
tool_output_guardrail_results: list[ToolOutputGuardrailResult] = []
tool_state_scope_id = get_agent_tool_state_scope(context_wrapper)

async def run_single_tool(func_tool: FunctionTool, tool_call: ResponseFunctionToolCall) -> Any:
with function_span(func_tool.name) as span_fn:
Expand Down Expand Up @@ -903,6 +908,7 @@ async def run_single_tool(func_tool: FunctionTool, tool_call: ResponseFunctionTo
agent,
tool_call,
rejection_message=rejection_message,
scope_id=tool_state_scope_id,
),
)

Expand Down Expand Up @@ -972,7 +978,10 @@ async def run_single_tool(func_tool: FunctionTool, tool_call: ResponseFunctionTo
function_tool_results = []
for tool_run, result in zip(tool_runs, results):
if isinstance(result, FunctionToolResult):
nested_run_result = consume_agent_tool_run_result(tool_run.tool_call)
nested_run_result = consume_agent_tool_run_result(
tool_run.tool_call,
scope_id=tool_state_scope_id,
)
if nested_run_result:
result.agent_run_result = nested_run_result
nested_interruptions_from_result: list[ToolApprovalItem] = (
Expand All @@ -985,7 +994,10 @@ async def run_single_tool(func_tool: FunctionTool, tool_call: ResponseFunctionTo

function_tool_results.append(result)
else:
nested_run_result = peek_agent_tool_run_result(tool_run.tool_call)
nested_run_result = peek_agent_tool_run_result(
tool_run.tool_call,
scope_id=tool_state_scope_id,
)
nested_interruptions: list[ToolApprovalItem] = []
if nested_run_result:
nested_interruptions = (
Expand All @@ -994,9 +1006,15 @@ async def run_single_tool(func_tool: FunctionTool, tool_call: ResponseFunctionTo
else []
)
if nested_run_result and not nested_interruptions:
nested_run_result = consume_agent_tool_run_result(tool_run.tool_call)
nested_run_result = consume_agent_tool_run_result(
tool_run.tool_call,
scope_id=tool_state_scope_id,
)
elif nested_run_result is None:
nested_run_result = consume_agent_tool_run_result(tool_run.tool_call)
nested_run_result = consume_agent_tool_run_result(
tool_run.tool_call,
scope_id=tool_state_scope_id,
)
if nested_run_result:
nested_interruptions = (
nested_run_result.interruptions
Expand Down
Loading