Commit 161fee1

fix: finish up human-in-the-loop port
1 parent 9782265 commit 161fee1

File tree

11 files changed: +5062, -746 lines

src/agents/_run_impl.py

Lines changed: 417 additions & 5 deletions
Large diffs are not rendered by default.

src/agents/agent.py

Lines changed: 53 additions & 1 deletion
@@ -46,6 +46,35 @@
 from .run import RunConfig
 from .stream_events import StreamEvent

+# Per-process, ephemeral map linking a tool call ID to its nested
+# Agent run result within the same run; entry is removed after consumption.
+_agent_tool_run_results: dict[str, RunResult | RunResultStreaming] = {}
+
+
+def save_agent_tool_run_result(
+    tool_call: ResponseFunctionToolCall | None,
+    run_result: RunResult | RunResultStreaming,
+) -> None:
+    """Save the nested agent run result for later consumption.
+
+    This is used when an agent is used as a tool. The run result is stored
+    so that interruptions from the nested agent run can be collected.
+    """
+    if tool_call:
+        _agent_tool_run_results[tool_call.call_id] = run_result
+
+
+def consume_agent_tool_run_result(
+    tool_call: ResponseFunctionToolCall,
+) -> RunResult | RunResultStreaming | None:
+    """Consume and return the nested agent run result for a tool call.
+
+    This retrieves and removes the stored run result. Returns None if
+    no result was stored for this tool call.
+    """
+    run_result = _agent_tool_run_results.pop(tool_call.call_id, None)
+    return run_result
+

 @dataclass
 class ToolsToFinalOutputResult:
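
For orientation, here is a minimal sketch of the round trip these two helpers implement, assuming the package is importable as `agents`; the tool call and run result arguments are placeholders supplied by the caller, not values from this commit.

from typing import Any

from openai.types.responses import ResponseFunctionToolCall

from agents.agent import consume_agent_tool_run_result, save_agent_tool_run_result


def stash_and_collect(tool_call: ResponseFunctionToolCall, nested_run_result: Any):
    """Toy round trip: stash a nested run result, then consume it exactly once."""
    save_agent_tool_run_result(tool_call, nested_run_result)
    first = consume_agent_tool_run_result(tool_call)   # the stored result
    second = consume_agent_tool_run_result(tool_call)  # already popped -> None
    return first, second
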
@@ -412,6 +441,8 @@ def as_tool(
         is_enabled: bool
         | Callable[[RunContextWrapper[Any], AgentBase[Any]], MaybeAwaitable[bool]] = True,
         on_stream: Callable[[AgentToolStreamEvent], MaybeAwaitable[None]] | None = None,
+        needs_approval: bool
+        | Callable[[RunContextWrapper[Any], dict[str, Any], str], Awaitable[bool]] = False,
         run_config: RunConfig | None = None,
         max_turns: int | None = None,
         hooks: RunHooks[TContext] | None = None,
@@ -441,6 +472,13 @@ def as_tool(
                 agent run. The callback receives an `AgentToolStreamEvent` containing the nested
                 agent, the originating tool call (when available), and each stream event. When
                 provided, the nested agent is executed in streaming mode.
+            needs_approval: Whether the tool needs approval before execution.
+                If True, the run will be interrupted and the tool call will need
+                to be approved using RunState.approve() or rejected using
+                RunState.reject() before continuing. Can be a bool
+                (always/never needs approval) or a function that takes
+                (run_context, tool_parameters, call_id) and returns whether this
+                specific call needs approval.
             failure_error_function: If provided, generate an error message when the tool (agent) run
                 fails. The message is sent to the LLM. If None, the exception is raised instead.
         """
@@ -449,10 +487,12 @@ def as_tool(
             name_override=tool_name or _transforms.transform_string_function_style(self.name),
             description_override=tool_description or "",
             is_enabled=is_enabled,
+            needs_approval=needs_approval,
             failure_error_function=failure_error_function,
         )
         async def run_agent(context: ToolContext, input: str) -> Any:
             from .run import DEFAULT_MAX_TURNS, Runner
+            from .tool_context import ToolContext

             resolved_max_turns = max_turns if max_turns is not None else DEFAULT_MAX_TURNS
             run_result: RunResult | RunResultStreaming
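
Since `needs_approval` is forwarded straight into the underlying function tool here, the same flag should be usable on ordinary tools as well. A minimal sketch, assuming `function_tool` accepts the forwarded keyword; the tool itself is invented.

from agents import function_tool


@function_tool(needs_approval=True)  # every call pauses the run for approval
def delete_account(user_id: str) -> str:
    """Irreversible action, so a human must approve each call."""
    return f"Deleted account {user_id}"
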
@@ -530,12 +570,24 @@ async def dispatch_stream_events() -> None:
                    conversation_id=conversation_id,
                    session=session,
                )
+
+            # Store the run result keyed by tool_call_id so it can be retrieved later
+            # when the tool_call is available during result processing
+            # At runtime, context is actually a ToolContext which has tool_call_id
+            if isinstance(context, ToolContext):
+                _agent_tool_run_results[context.tool_call_id] = run_result
+
             if custom_output_extractor:
                 return await custom_output_extractor(run_result)

             return run_result.final_output

-        return run_agent
+        # Mark the function tool as an agent tool
+        run_agent_tool = run_agent
+        run_agent_tool._is_agent_tool = True
+        run_agent_tool._agent_instance = self
+
+        return run_agent_tool

     async def get_system_prompt(self, run_context: RunContextWrapper[TContext]) -> str | None:
         if isinstance(self.instructions, str):
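
The `_is_agent_tool` / `_agent_instance` markers enable duck-typed detection of agent tools downstream, for example when collecting interruptions from nested runs. An illustrative check, not the SDK's actual implementation:

from __future__ import annotations

from typing import Any, Optional

from agents import Agent


def nested_agent_for(tool_obj: Any) -> Optional[Agent[Any]]:
    """Return the nested agent if this tool was produced by Agent.as_tool(), else None."""
    if getattr(tool_obj, "_is_agent_tool", False):
        return getattr(tool_obj, "_agent_instance", None)
    return None
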

src/agents/memory/openai_conversations_session.py

Lines changed: 3 additions & 0 deletions
@@ -67,6 +67,9 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:

     async def add_items(self, items: list[TResponseInputItem]) -> None:
         session_id = await self._get_session_id()
+        if not items:
+            return
+
         await self._openai_client.conversations.items.create(
             conversation_id=session_id,
             items=items,
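
In practice this guard makes an empty batch a cheap no-op: the session ID is still resolved, but no `conversations.items.create` request is issued. A small sketch, assuming the session class is exported from `agents.memory` and a default OpenAI client is configured.

import asyncio

from agents.memory import OpenAIConversationsSession


async def main() -> None:
    session = OpenAIConversationsSession()
    # Returns before calling conversations.items.create when the batch is empty.
    await session.add_items([])


asyncio.run(main())
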

src/agents/result.py

Lines changed: 39 additions & 2 deletions
@@ -155,6 +155,13 @@ class RunResult(RunResultBase):
     )
     _last_processed_response: ProcessedResponse | None = field(default=None, repr=False)
     """The last processed model response. This is needed for resuming from interruptions."""
+    _tool_use_tracker_snapshot: dict[str, list[str]] = field(default_factory=dict, repr=False)
+    _current_turn_persisted_item_count: int = 0
+    """Number of items from new_items already persisted to session for the
+    current turn."""
+    _original_input: str | list[TResponseInputItem] | None = field(default=None, repr=False)
+    """The original input from the first turn. Unlike `input`, this is never updated during the run.
+    Used by to_state() to preserve the correct originalInput when serializing state."""

     def __post_init__(self) -> None:
         self._last_agent_ref = weakref.ref(self._last_agent)
@@ -204,9 +211,12 @@ def to_state(self) -> Any:
         ```
         """
         # Create a RunState from the current result
+        original_input_for_state = getattr(self, "_original_input", None)
         state = RunState(
             context=self.context_wrapper,
-            original_input=self.input,
+            original_input=original_input_for_state
+            if original_input_for_state is not None
+            else self.input,
             starting_agent=self.last_agent,
             max_turns=10,  # This will be overridden by the runner
         )
@@ -217,6 +227,8 @@ def to_state(self) -> Any:
         state._input_guardrail_results = self.input_guardrail_results
         state._output_guardrail_results = self.output_guardrail_results
         state._last_processed_response = self._last_processed_response
+        state._current_turn_persisted_item_count = self._current_turn_persisted_item_count
+        state.set_tool_use_tracker_snapshot(self._tool_use_tracker_snapshot)

         # If there are interruptions, set the current step
         if self.interruptions:
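
Taken together with the `RunState.approve()`/`RunState.reject()` flow referenced in the `as_tool` docstring, these fields support an approval round trip along these lines. A hedged sketch: the agent definition is a placeholder and the `Runner.run(agent, state)` resume signature is assumed from this port rather than quoted from it.

from agents import Agent, Runner

agent = Agent(name="Assistant", instructions="Use your tools.")  # tools omitted


async def run_with_approvals(user_input: str) -> str:
    result = await Runner.run(agent, user_input)

    # Keep resuming until the run finishes without pending approvals.
    while result.interruptions:
        state = result.to_state()
        for interruption in result.interruptions:
            state.approve(interruption)  # a real app would ask a human first
        result = await Runner.run(agent, state)  # assumed resume signature

    return result.final_output
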
@@ -279,11 +291,32 @@ class RunResultStreaming(RunResultBase):
     _output_guardrails_task: asyncio.Task[Any] | None = field(default=None, repr=False)
     _stored_exception: Exception | None = field(default=None, repr=False)

+    _current_turn_persisted_item_count: int = 0
+    """Number of items from new_items already persisted to session for the
+    current turn."""
+
+    _stream_input_persisted: bool = False
+    """Whether the input has been persisted to the session. Prevents double-saving."""
+
+    _original_input_for_persistence: list[TResponseInputItem] = field(default_factory=list)
+    """Original turn input before session history was merged, used for
+    persistence (matches JS sessionInputOriginalSnapshot)."""
+
     # Soft cancel state
     _cancel_mode: Literal["none", "immediate", "after_turn"] = field(default="none", repr=False)

+    _original_input: str | list[TResponseInputItem] | None = field(default=None, repr=False)
+    """The original input from the first turn. Unlike `input`, this is never updated during the run.
+    Used by to_state() to preserve the correct originalInput when serializing state."""
+    _tool_use_tracker_snapshot: dict[str, list[str]] = field(default_factory=dict, repr=False)
+    _state: Any = field(default=None, repr=False)
+    """Internal reference to the RunState for streaming results."""
+
     def __post_init__(self) -> None:
         self._current_agent_ref = weakref.ref(self.current_agent)
+        # Store the original input at creation time (it will be set via input field)
+        if self._original_input is None:
+            self._original_input = self.input

     @property
     def last_agent(self) -> Agent[Any]:
@@ -508,9 +541,11 @@ def to_state(self) -> Any:
         ```
         """
         # Create a RunState from the current result
+        # Use _original_input (the input from the first turn) instead of input
+        # (which may have been updated during the run)
         state = RunState(
             context=self.context_wrapper,
-            original_input=self.input,
+            original_input=self._original_input if self._original_input is not None else self.input,
             starting_agent=self.last_agent,
             max_turns=self.max_turns,
         )
@@ -522,6 +557,8 @@ def to_state(self) -> Any:
         state._output_guardrail_results = self.output_guardrail_results
         state._current_turn = self.current_turn
         state._last_processed_response = self._last_processed_response
+        state._current_turn_persisted_item_count = self._current_turn_persisted_item_count
+        state.set_tool_use_tracker_snapshot(self._tool_use_tracker_snapshot)

         # If there are interruptions, set the current step
         if self.interruptions:
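
The streaming result supports the same loop, with `_original_input` and the tool-use tracker snapshot carried across each `to_state()` call. Another hedged sketch, with the same assumptions about the resume signature.

from agents import Agent, Runner

agent = Agent(name="Assistant", instructions="Use your tools.")


async def run_streamed_with_approvals(user_input: str) -> str:
    result = Runner.run_streamed(agent, user_input)
    async for _event in result.stream_events():
        pass  # a real app would render these events

    while result.interruptions:
        state = result.to_state()
        for interruption in result.interruptions:
            state.approve(interruption)
        result = Runner.run_streamed(agent, state)  # assumed resume signature
        async for _event in result.stream_events():
            pass

    return result.final_output
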
