diff --git a/src/agents/agent.py b/src/agents/agent.py index bf9760e79..f28df1a14 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -22,8 +22,10 @@ ) from .agent_tool_state import ( consume_agent_tool_run_result, + get_agent_tool_state_scope, peek_agent_tool_run_result, record_agent_tool_run_result, + set_agent_tool_state_scope, ) from .exceptions import ModelBehaviorError, UserError from .guardrail import InputGuardrail, OutputGuardrail @@ -593,6 +595,7 @@ async def _run_agent_impl(context: ToolContext, input_json: str) -> Any: resolved_run_config = run_config if resolved_run_config is None and isinstance(context, ToolContext): resolved_run_config = context.run_config + tool_state_scope_id = get_agent_tool_state_scope(context) if isinstance(context, ToolContext): # Use a fresh ToolContext to avoid sharing approval state with parent runs. nested_context = ToolContext( @@ -605,17 +608,20 @@ async def _run_agent_impl(context: ToolContext, input_json: str) -> Any: agent=context.agent, run_config=resolved_run_config, ) + set_agent_tool_state_scope(nested_context, tool_state_scope_id) if should_capture_tool_input: nested_context.tool_input = params_data elif isinstance(context, RunContextWrapper): if should_capture_tool_input: nested_context = RunContextWrapper(context=context.context) + set_agent_tool_state_scope(nested_context, tool_state_scope_id) nested_context.tool_input = params_data else: nested_context = context.context else: if should_capture_tool_input: nested_context = RunContextWrapper(context=context) + set_agent_tool_state_scope(nested_context, tool_state_scope_id) nested_context.tool_input = params_data else: nested_context = context @@ -678,7 +684,10 @@ def _apply_nested_approvals( ) if isinstance(context, ToolContext) and context.tool_call is not None: - pending_run_result = peek_agent_tool_run_result(context.tool_call) + pending_run_result = peek_agent_tool_run_result( + context.tool_call, + scope_id=tool_state_scope_id, + ) if pending_run_result and getattr(pending_run_result, "interruptions", None): status = _nested_approvals_status(pending_run_result.interruptions) if status == "pending": @@ -693,7 +702,10 @@ def _apply_nested_approvals( context, pending_run_result.interruptions, ) - consume_agent_tool_run_result(context.tool_call) + consume_agent_tool_run_result( + context.tool_call, + scope_id=tool_state_scope_id, + ) if run_result is None: if on_stream is not None: @@ -780,7 +792,11 @@ async def dispatch_stream_events() -> None: interruptions = getattr(run_result, "interruptions", None) if isinstance(context, ToolContext) and context.tool_call is not None and interruptions: if should_record_run_result: - record_agent_tool_run_result(context.tool_call, run_result) + record_agent_tool_run_result( + context.tool_call, + run_result, + scope_id=tool_state_scope_id, + ) if custom_output_extractor: return await custom_output_extractor(run_result) diff --git a/src/agents/agent_tool_state.py b/src/agents/agent_tool_state.py index 28995ca82..2ddb2c988 100644 --- a/src/agents/agent_tool_state.py +++ b/src/agents/agent_tool_state.py @@ -1,30 +1,57 @@ from __future__ import annotations import weakref -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall from .result import RunResult, RunResultStreaming +ToolCallSignature = tuple[str, str, str, str, str | None, str | None] +ScopedToolCallSignature = tuple[str | None, ToolCallSignature] + +_AGENT_TOOL_STATE_SCOPE_ATTR = "_agent_tool_state_scope_id" + # Ephemeral maps linking tool call objects to nested agent results within the same run. # Store by object identity, and index by a stable signature to avoid call ID collisions. _agent_tool_run_results_by_obj: dict[int, RunResult | RunResultStreaming] = {} _agent_tool_run_results_by_signature: dict[ - tuple[str, str, str, str, str | None, str | None], + ScopedToolCallSignature, set[int], ] = {} _agent_tool_run_result_signature_by_obj: dict[ int, - tuple[str, str, str, str, str | None, str | None], + ScopedToolCallSignature, ] = {} _agent_tool_call_refs_by_obj: dict[int, weakref.ReferenceType[ResponseFunctionToolCall]] = {} +def get_agent_tool_state_scope(context: Any) -> str | None: + """Read the private agent-tool cache scope id from a context wrapper.""" + scope_id = getattr(context, _AGENT_TOOL_STATE_SCOPE_ATTR, None) + return scope_id if isinstance(scope_id, str) else None + + +def set_agent_tool_state_scope(context: Any, scope_id: str | None) -> None: + """Attach or clear the private agent-tool cache scope id on a context wrapper.""" + if context is None: + return + if scope_id is None: + try: + delattr(context, _AGENT_TOOL_STATE_SCOPE_ATTR) + except Exception: + return + return + try: + setattr(context, _AGENT_TOOL_STATE_SCOPE_ATTR, scope_id) + except Exception: + return + + def _tool_call_signature( tool_call: ResponseFunctionToolCall, -) -> tuple[str, str, str, str, str | None, str | None]: +) -> ToolCallSignature: """Build a stable signature for fallback lookup across tool call instances.""" return ( tool_call.call_id, @@ -36,11 +63,21 @@ def _tool_call_signature( ) +def _scoped_tool_call_signature( + tool_call: ResponseFunctionToolCall, *, scope_id: str | None +) -> ScopedToolCallSignature: + """Build a scope-qualified signature so independently restored states do not collide.""" + return (scope_id, _tool_call_signature(tool_call)) + + def _index_agent_tool_run_result( - tool_call: ResponseFunctionToolCall, tool_call_obj_id: int + tool_call: ResponseFunctionToolCall, + tool_call_obj_id: int, + *, + scope_id: str | None, ) -> None: """Track tool call objects by signature for fallback lookup.""" - signature = _tool_call_signature(tool_call) + signature = _scoped_tool_call_signature(tool_call, scope_id=scope_id) _agent_tool_run_result_signature_by_obj[tool_call_obj_id] = signature _agent_tool_run_results_by_signature.setdefault(signature, set()).add(tool_call_obj_id) @@ -80,26 +117,40 @@ def _on_tool_call_gc(_ref: weakref.ReferenceType[ResponseFunctionToolCall]) -> N def record_agent_tool_run_result( - tool_call: ResponseFunctionToolCall, run_result: RunResult | RunResultStreaming + tool_call: ResponseFunctionToolCall, + run_result: RunResult | RunResultStreaming, + *, + scope_id: str | None = None, ) -> None: """Store the nested agent run result by tool call identity.""" tool_call_obj_id = id(tool_call) _agent_tool_run_results_by_obj[tool_call_obj_id] = run_result - _index_agent_tool_run_result(tool_call, tool_call_obj_id) + _index_agent_tool_run_result(tool_call, tool_call_obj_id, scope_id=scope_id) _register_tool_call_ref(tool_call, tool_call_obj_id) +def _tool_call_obj_matches_scope(tool_call_obj_id: int, *, scope_id: str | None) -> bool: + scoped_signature = _agent_tool_run_result_signature_by_obj.get(tool_call_obj_id) + if scoped_signature is None: + # Fallback for unindexed entries. + return scope_id is None + return scoped_signature[0] == scope_id + + def consume_agent_tool_run_result( tool_call: ResponseFunctionToolCall, + *, + scope_id: str | None = None, ) -> RunResult | RunResultStreaming | None: """Return and drop the stored nested agent run result for the given tool call.""" obj_id = id(tool_call) - run_result = _agent_tool_run_results_by_obj.pop(obj_id, None) - if run_result is not None: - _drop_agent_tool_run_result(obj_id) - return run_result + if _tool_call_obj_matches_scope(obj_id, scope_id=scope_id): + run_result = _agent_tool_run_results_by_obj.pop(obj_id, None) + if run_result is not None: + _drop_agent_tool_run_result(obj_id) + return run_result - signature = _tool_call_signature(tool_call) + signature = _scoped_tool_call_signature(tool_call, scope_id=scope_id) candidate_ids = _agent_tool_run_results_by_signature.get(signature) if not candidate_ids: return None @@ -115,14 +166,17 @@ def consume_agent_tool_run_result( def peek_agent_tool_run_result( tool_call: ResponseFunctionToolCall, + *, + scope_id: str | None = None, ) -> RunResult | RunResultStreaming | None: """Return the stored nested agent run result without removing it.""" obj_id = id(tool_call) - run_result = _agent_tool_run_results_by_obj.get(obj_id) - if run_result is not None: - return run_result + if _tool_call_obj_matches_scope(obj_id, scope_id=scope_id): + run_result = _agent_tool_run_results_by_obj.get(obj_id) + if run_result is not None: + return run_result - signature = _tool_call_signature(tool_call) + signature = _scoped_tool_call_signature(tool_call, scope_id=scope_id) candidate_ids = _agent_tool_run_results_by_signature.get(signature) if not candidate_ids: return None @@ -133,15 +187,20 @@ def peek_agent_tool_run_result( return _agent_tool_run_results_by_obj.get(candidate_id) -def drop_agent_tool_run_result(tool_call: ResponseFunctionToolCall) -> None: +def drop_agent_tool_run_result( + tool_call: ResponseFunctionToolCall, + *, + scope_id: str | None = None, +) -> None: """Drop the stored nested agent run result, if present.""" obj_id = id(tool_call) - run_result = _agent_tool_run_results_by_obj.pop(obj_id, None) - if run_result is not None: - _drop_agent_tool_run_result(obj_id) - return + if _tool_call_obj_matches_scope(obj_id, scope_id=scope_id): + run_result = _agent_tool_run_results_by_obj.pop(obj_id, None) + if run_result is not None: + _drop_agent_tool_run_result(obj_id) + return - signature = _tool_call_signature(tool_call) + signature = _scoped_tool_call_signature(tool_call, scope_id=scope_id) candidate_ids = _agent_tool_run_results_by_signature.get(signature) if not candidate_ids: return diff --git a/src/agents/run.py b/src/agents/run.py index f239a1ef9..e7c1d421b 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -9,6 +9,7 @@ from . import _debug from .agent import Agent +from .agent_tool_state import set_agent_tool_state_scope from .exceptions import ( AgentsException, InputGuardrailTripwireTriggered, @@ -555,6 +556,7 @@ async def run( session_items = [] model_responses = [] context_wrapper = ensure_context_wrapper(context) + set_agent_tool_state_scope(context_wrapper, None) run_state = RunState( context=context_wrapper, original_input=original_input, @@ -1458,6 +1460,7 @@ def run_streamed( auto_previous_response_id=auto_previous_response_id, ) context_wrapper = ensure_context_wrapper(context) + set_agent_tool_state_scope(context_wrapper, None) # input_for_state is the same as input_for_result here input_for_state = input_for_result run_state = RunState( diff --git a/src/agents/run_internal/agent_runner_helpers.py b/src/agents/run_internal/agent_runner_helpers.py index 3498b4572..9523e80a0 100644 --- a/src/agents/run_internal/agent_runner_helpers.py +++ b/src/agents/run_internal/agent_runner_helpers.py @@ -5,6 +5,7 @@ from typing import Any, cast from ..agent import Agent +from ..agent_tool_state import set_agent_tool_state_scope from ..exceptions import UserError from ..guardrail import InputGuardrailResult from ..items import ModelResponse, RunItem, ToolApprovalItem, TResponseInputItem @@ -141,10 +142,12 @@ def resolve_resumed_context( """Return the context wrapper for a resumed run, overriding when provided.""" if context is not None: context_wrapper = ensure_context_wrapper(context) + set_agent_tool_state_scope(context_wrapper, run_state._agent_tool_state_scope_id) run_state._context = context_wrapper return context_wrapper if run_state._context is None: run_state._context = ensure_context_wrapper(context) + set_agent_tool_state_scope(run_state._context, run_state._agent_tool_state_scope_id) return run_state._context diff --git a/src/agents/run_internal/items.py b/src/agents/run_internal/items.py index 57b5b067b..5883b58d1 100644 --- a/src/agents/run_internal/items.py +++ b/src/agents/run_internal/items.py @@ -194,10 +194,11 @@ def function_rejection_item( tool_call: Any, *, rejection_message: str = REJECTION_MESSAGE, + scope_id: str | None = None, ) -> ToolCallOutputItem: """Build a ToolCallOutputItem representing a rejected function tool call.""" if isinstance(tool_call, ResponseFunctionToolCall): - drop_agent_tool_run_result(tool_call) + drop_agent_tool_run_result(tool_call, scope_id=scope_id) return ToolCallOutputItem( output=rejection_message, raw_item=ItemHelpers.tool_call_output_item(tool_call, rejection_message), diff --git a/src/agents/run_internal/tool_execution.py b/src/agents/run_internal/tool_execution.py index cdac898a3..45b32f418 100644 --- a/src/agents/run_internal/tool_execution.py +++ b/src/agents/run_internal/tool_execution.py @@ -20,7 +20,11 @@ from openai.types.responses.response_output_item import McpApprovalRequest from ..agent import Agent -from ..agent_tool_state import consume_agent_tool_run_result, peek_agent_tool_run_result +from ..agent_tool_state import ( + consume_agent_tool_run_result, + get_agent_tool_state_scope, + peek_agent_tool_run_result, +) from ..editor import ApplyPatchOperation, ApplyPatchResult from ..exceptions import ( AgentsException, @@ -840,6 +844,7 @@ async def execute_function_tool_calls( """Execute function tool calls with approvals, guardrails, and hooks.""" tool_input_guardrail_results: list[ToolInputGuardrailResult] = [] tool_output_guardrail_results: list[ToolOutputGuardrailResult] = [] + tool_state_scope_id = get_agent_tool_state_scope(context_wrapper) async def run_single_tool(func_tool: FunctionTool, tool_call: ResponseFunctionToolCall) -> Any: with function_span(func_tool.name) as span_fn: @@ -903,6 +908,7 @@ async def run_single_tool(func_tool: FunctionTool, tool_call: ResponseFunctionTo agent, tool_call, rejection_message=rejection_message, + scope_id=tool_state_scope_id, ), ) @@ -972,7 +978,10 @@ async def run_single_tool(func_tool: FunctionTool, tool_call: ResponseFunctionTo function_tool_results = [] for tool_run, result in zip(tool_runs, results): if isinstance(result, FunctionToolResult): - nested_run_result = consume_agent_tool_run_result(tool_run.tool_call) + nested_run_result = consume_agent_tool_run_result( + tool_run.tool_call, + scope_id=tool_state_scope_id, + ) if nested_run_result: result.agent_run_result = nested_run_result nested_interruptions_from_result: list[ToolApprovalItem] = ( @@ -985,7 +994,10 @@ async def run_single_tool(func_tool: FunctionTool, tool_call: ResponseFunctionTo function_tool_results.append(result) else: - nested_run_result = peek_agent_tool_run_result(tool_run.tool_call) + nested_run_result = peek_agent_tool_run_result( + tool_run.tool_call, + scope_id=tool_state_scope_id, + ) nested_interruptions: list[ToolApprovalItem] = [] if nested_run_result: nested_interruptions = ( @@ -994,9 +1006,15 @@ async def run_single_tool(func_tool: FunctionTool, tool_call: ResponseFunctionTo else [] ) if nested_run_result and not nested_interruptions: - nested_run_result = consume_agent_tool_run_result(tool_run.tool_call) + nested_run_result = consume_agent_tool_run_result( + tool_run.tool_call, + scope_id=tool_state_scope_id, + ) elif nested_run_result is None: - nested_run_result = consume_agent_tool_run_result(tool_run.tool_call) + nested_run_result = consume_agent_tool_run_result( + tool_run.tool_call, + scope_id=tool_state_scope_id, + ) if nested_run_result: nested_interruptions = ( nested_run_result.interruptions diff --git a/src/agents/run_internal/turn_resolution.py b/src/agents/run_internal/turn_resolution.py index e51fa801c..59996c6e0 100644 --- a/src/agents/run_internal/turn_resolution.py +++ b/src/agents/run_internal/turn_resolution.py @@ -29,7 +29,7 @@ from ..agent import Agent, ToolsToFinalOutputResult from ..agent_output import AgentOutputSchemaBase -from ..agent_tool_state import peek_agent_tool_run_result +from ..agent_tool_state import get_agent_tool_state_scope, peek_agent_tool_run_result from ..exceptions import ModelBehaviorError, UserError from ..handoffs import Handoff, HandoffInputData, nest_handoff_history from ..items import ( @@ -697,7 +697,12 @@ async def _record_function_rejection( call_id=call_id, ) rejected_function_outputs.append( - function_rejection_item(agent, tool_call, rejection_message=rejection_message) + function_rejection_item( + agent, + tool_call, + rejection_message=rejection_message, + scope_id=tool_state_scope_id, + ) ) if isinstance(call_id, str): rejected_function_call_ids.add(call_id) @@ -725,6 +730,7 @@ async def _function_requires_approval(run: ToolRunFunction) -> bool: pending_approval_items = _pending_approvals_from_state() approval_items_by_call_id = index_approval_items_by_call_id(pending_approval_items) + tool_state_scope_id = get_agent_tool_state_scope(context_wrapper) rejected_function_outputs: list[RunItem] = [] rejected_function_call_ids: set[str] = set() @@ -840,7 +846,10 @@ def _function_output_exists(run: ToolRunFunction) -> bool: if not call_id: return False - pending_run_result = peek_agent_tool_run_result(run.tool_call) + pending_run_result = peek_agent_tool_run_result( + run.tool_call, + scope_id=tool_state_scope_id, + ) if pending_run_result and getattr(pending_run_result, "interruptions", None): status = _nested_interruptions_status(pending_run_result.interruptions) if status in ("approved", "rejected"): diff --git a/src/agents/run_state.py b/src/agents/run_state.py index 4e4ab175b..1f5fb2f9a 100644 --- a/src/agents/run_state.py +++ b/src/agents/run_state.py @@ -8,6 +8,7 @@ from collections.abc import Callable, Mapping, Sequence from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Union, cast +from uuid import uuid4 from openai.types.responses import ( ResponseComputerToolCall, @@ -172,6 +173,9 @@ class RunState(Generic[TContext, TAgent]): _trace_state: TraceState | None = field(default=None, repr=False) """Serialized trace metadata for resuming tracing context.""" + _agent_tool_state_scope_id: str | None = field(default=None, repr=False) + """Private scope id used to isolate agent-tool pending state per RunState instance.""" + def __init__( self, context: RunContextWrapper[TContext], @@ -204,6 +208,9 @@ def __init__( self._current_turn_persisted_item_count = 0 self._tool_use_tracker_snapshot = {} self._trace_state = None + from .agent_tool_state import get_agent_tool_state_scope + + self._agent_tool_state_scope_id = get_agent_tool_state_scope(context) def get_interruptions(self) -> list[ToolApprovalItem]: """Return pending interruptions if the current step is an interruption.""" @@ -582,6 +589,7 @@ def _serialize_processed_response( parent_state=self, function_entries=action_groups.get("functions", []), function_runs=processed_response.functions, + scope_id=self._agent_tool_state_scope_id, context_serializer=context_serializer, strict_context=strict_context, include_tracing_api_key=include_tracing_api_key, @@ -1161,6 +1169,7 @@ def _serialize_pending_nested_agent_tool_runs( parent_state: RunState[Any, Any], function_entries: Sequence[dict[str, Any]], function_runs: Sequence[Any], + scope_id: str | None = None, context_serializer: ContextSerializer | None = None, strict_context: bool = False, include_tracing_api_key: bool = False, @@ -1176,7 +1185,7 @@ def _serialize_pending_nested_agent_tool_runs( if not isinstance(tool_call, ResponseFunctionToolCall): continue - pending_run_result = peek_agent_tool_run_result(tool_call) + pending_run_result = peek_agent_tool_run_result(tool_call, scope_id=scope_id) if pending_run_result is None: continue @@ -1314,6 +1323,7 @@ async def _restore_pending_nested_agent_tool_runs( current_agent: Agent[Any], function_entries: Sequence[Any], function_runs: Sequence[Any], + scope_id: str | None = None, context_deserializer: ContextDeserializer | None = None, strict_context: bool = False, ) -> None: @@ -1356,8 +1366,8 @@ async def _restore_pending_nested_agent_tool_runs( # Replace any stale cache entry with the same signature so resumed runs do not read # older pending interruptions after consuming this restored entry. - drop_agent_tool_run_result(tool_call) - record_agent_tool_run_result(tool_call, cast(Any, pending_result)) + drop_agent_tool_run_result(tool_call, scope_id=scope_id) + record_agent_tool_run_result(tool_call, cast(Any, pending_result), scope_id=scope_id) async def _deserialize_processed_response( @@ -1366,6 +1376,7 @@ async def _deserialize_processed_response( context: RunContextWrapper[Any], agent_map: dict[str, Agent[Any]], *, + scope_id: str | None = None, context_deserializer: ContextDeserializer | None = None, strict_context: bool = False, ) -> ProcessedResponse: @@ -1555,6 +1566,7 @@ def _deserialize_action_groups() -> dict[str, list[Any]]: current_agent=current_agent, function_entries=processed_response_data.get("functions", []), function_runs=functions, + scope_id=scope_id, context_deserializer=context_deserializer, strict_context=strict_context, ) @@ -1972,6 +1984,10 @@ async def _build_run_state_from_json( previous_response_id=state_json.get("previous_response_id"), auto_previous_response_id=bool(state_json.get("auto_previous_response_id", False)), ) + from .agent_tool_state import set_agent_tool_state_scope + + state._agent_tool_state_scope_id = uuid4().hex + set_agent_tool_state_scope(context, state._agent_tool_state_scope_id) state._current_turn = state_json["current_turn"] state._model_responses = _deserialize_model_responses(state_json.get("model_responses", [])) @@ -1984,6 +2000,7 @@ async def _build_run_state_from_json( current_agent, state._context, agent_map, + scope_id=state._agent_tool_state_scope_id, context_deserializer=context_deserializer, strict_context=strict_context, ) diff --git a/src/agents/tool_context.py b/src/agents/tool_context.py index fbfbf4f60..4156b32ad 100644 --- a/src/agents/tool_context.py +++ b/src/agents/tool_context.py @@ -5,6 +5,7 @@ from openai.types.responses import ResponseFunctionToolCall +from .agent_tool_state import get_agent_tool_state_scope, set_agent_tool_state_scope from .run_context import RunContextWrapper, TContext from .usage import Usage @@ -130,4 +131,5 @@ def from_agent_context( run_config=tool_run_config, **base_values, ) + set_agent_tool_state_scope(tool_context, get_agent_tool_state_scope(context)) return tool_context diff --git a/tests/test_agent_as_tool.py b/tests/test_agent_as_tool.py index d2cdafd15..d6c1edb82 100644 --- a/tests/test_agent_as_tool.py +++ b/tests/test_agent_as_tool.py @@ -25,7 +25,11 @@ TResponseInputItem, ) from agents.agent_tool_input import StructuredToolInputBuilderOptions -from agents.agent_tool_state import record_agent_tool_run_result +from agents.agent_tool_state import ( + get_agent_tool_state_scope, + record_agent_tool_run_result, + set_agent_tool_state_scope, +) from agents.stream_events import AgentUpdatedStreamEvent, RawResponsesStreamEvent from agents.tool_context import ToolContext from tests.utils.hitl import make_function_tool_call @@ -1000,6 +1004,80 @@ async def extractor(result: Any) -> str: assert run_inputs == [resume_state] +@pytest.mark.asyncio +async def test_agent_as_tool_preserves_scope_for_nested_tool_context( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Nested ToolContext instances should inherit the parent tool-state scope.""" + + class DummyResult: + def __init__(self) -> None: + self.final_output = "ok" + self.interruptions: list[ToolApprovalItem] = [] + + scope_id = "resume-scope" + agent = Agent(name="scope-agent") + tool = agent.as_tool(tool_name="scope_tool", tool_description="Scope tool") + + async def fake_run(cls, /, starting_agent, input, **kwargs) -> DummyResult: + del cls, starting_agent, input + nested_context = kwargs.get("context") + assert isinstance(nested_context, ToolContext) + assert get_agent_tool_state_scope(nested_context) == scope_id + return DummyResult() + + monkeypatch.setattr(Runner, "run", classmethod(fake_run)) + + tool_context = ToolContext( + context=None, + tool_name="scope_tool", + tool_call_id="scope-call", + tool_arguments='{"input":"hello"}', + ) + set_agent_tool_state_scope(tool_context, scope_id) + + output = await tool.on_invoke_tool(tool_context, '{"input":"hello"}') + assert output == "ok" + + +@pytest.mark.asyncio +async def test_agent_as_tool_preserves_scope_for_nested_run_context_wrapper( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Nested RunContextWrapper instances should inherit the parent tool-state scope.""" + + class Params(BaseModel): + text: str + + class DummyResult: + def __init__(self) -> None: + self.final_output = "ok" + self.interruptions: list[ToolApprovalItem] = [] + + scope_id = "resume-scope-wrapper" + agent = Agent(name="scope-agent-wrapper") + tool = agent.as_tool( + tool_name="scope_tool_wrapper", + tool_description="Scope tool wrapper", + parameters=Params, + ) + + async def fake_run(cls, /, starting_agent, input, **kwargs) -> DummyResult: + del cls, starting_agent, input + nested_context = kwargs.get("context") + assert isinstance(nested_context, RunContextWrapper) + assert get_agent_tool_state_scope(nested_context) == scope_id + return DummyResult() + + monkeypatch.setattr(Runner, "run", classmethod(fake_run)) + + parent_context = RunContextWrapper(context={"key": "value"}) + set_agent_tool_state_scope(parent_context, scope_id) + + output = await tool.on_invoke_tool(cast(Any, parent_context), '{"text":"hello"}') + assert output == "ok" + + @pytest.mark.asyncio async def test_agent_as_tool_streams_events_with_on_stream( monkeypatch: pytest.MonkeyPatch, diff --git a/tests/test_run_state.py b/tests/test_run_state.py index 9f3fe48a4..ef5e9091f 100644 --- a/tests/test_run_state.py +++ b/tests/test_run_state.py @@ -1840,16 +1840,24 @@ async def inner_sensitive_tool(text: str) -> str: del first_result gc.collect() - restored_state = await RunState.from_json(outer_agent, state_json) - restored_interruptions = restored_state.get_interruptions() - assert len(restored_interruptions) == 1 - restored_state.approve(restored_interruptions[0]) - - resumed_result = await Runner.run(outer_agent, restored_state) - - assert resumed_result.final_output == "outer-complete" - assert resumed_result.interruptions == [] - assert tool_calls == ["hello"] + restored_state_one = await RunState.from_json(outer_agent, state_json) + restored_state_two = await RunState.from_json(outer_agent, state_json) + + restored_interruptions_one = restored_state_one.get_interruptions() + restored_interruptions_two = restored_state_two.get_interruptions() + assert len(restored_interruptions_one) == 1 + assert len(restored_interruptions_two) == 1 + restored_state_one.approve(restored_interruptions_one[0]) + restored_state_two.approve(restored_interruptions_two[0]) + + resumed_result_one = await Runner.run(outer_agent, restored_state_one) + resumed_result_two = await Runner.run(outer_agent, restored_state_two) + + assert resumed_result_one.final_output == "outer-complete" + assert resumed_result_one.interruptions == [] + assert resumed_result_two.final_output == "outer-complete" + assert resumed_result_two.interruptions == [] + assert tool_calls == ["hello", "hello"] async def test_json_decode_error_handling(self): """Test that invalid JSON raises appropriate error."""