Skip to content

Commit b64aabd

Browse files
committed
fix: address issues around resuming run state with conversation history
1 parent e823904 commit b64aabd

File tree

1 file changed

+72
-3
lines changed

1 file changed

+72
-3
lines changed

src/agents/run.py

Lines changed: 72 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -744,12 +744,17 @@ async def run(
744744
# Check if we're resuming from a RunState
745745
is_resumed_state = isinstance(input, RunState)
746746
run_state: RunState[TContext] | None = None
747+
prepared_input: str | list[TResponseInputItem]
747748

748749
if is_resumed_state:
749750
# Resuming from a saved state
750751
run_state = cast(RunState[TContext], input)
751752
original_user_input = run_state._original_input
752-
prepared_input = run_state._original_input
753+
754+
if isinstance(run_state._original_input, list):
755+
prepared_input = self._merge_provider_data_in_items(run_state._original_input)
756+
else:
757+
prepared_input = run_state._original_input
753758

754759
# Override context with the state's context if not provided
755760
if context is None and run_state._context is not None:
@@ -826,6 +831,9 @@ async def run(
826831
# If resuming from an interrupted state, execute approved tools first
827832
if is_resumed_state and run_state is not None and run_state._current_step is not None:
828833
if isinstance(run_state._current_step, NextStepInterruption):
834+
# Track items before executing approved tools
835+
items_before_execution = len(generated_items)
836+
829837
# We're resuming from an interruption - execute approved tools
830838
await self._execute_approved_tools(
831839
agent=current_agent,
@@ -835,6 +843,16 @@ async def run(
835843
run_config=run_config,
836844
hooks=hooks,
837845
)
846+
847+
# Save the newly executed tool outputs to the session
848+
new_tool_outputs: list[RunItem] = [
849+
item
850+
for item in generated_items[items_before_execution:]
851+
if item.type == "tool_call_output_item"
852+
]
853+
if new_tool_outputs and session is not None:
854+
await self._save_result_to_session(session, [], new_tool_outputs)
855+
838856
# Clear the current step since we've handled it
839857
run_state._current_step = None
840858

@@ -1168,7 +1186,14 @@ def run_streamed(
11681186

11691187
if is_resumed_state:
11701188
run_state = cast(RunState[TContext], input)
1171-
input_for_result = run_state._original_input
1189+
1190+
if isinstance(run_state._original_input, list):
1191+
input_for_result = AgentRunner._merge_provider_data_in_items(
1192+
run_state._original_input
1193+
)
1194+
else:
1195+
input_for_result = run_state._original_input
1196+
11721197
# Use context from RunState if not provided
11731198
if context is None and run_state._context is not None:
11741199
context = run_state._context.context
@@ -1387,6 +1412,9 @@ async def _start_streaming(
13871412
# If resuming from an interrupted state, execute approved tools first
13881413
if run_state is not None and run_state._current_step is not None:
13891414
if isinstance(run_state._current_step, NextStepInterruption):
1415+
# Track items before executing approved tools
1416+
items_before_execution = len(streamed_result.new_items)
1417+
13901418
# We're resuming from an interruption - execute approved tools
13911419
await cls._execute_approved_tools_static(
13921420
agent=current_agent,
@@ -1396,6 +1424,16 @@ async def _start_streaming(
13961424
run_config=run_config,
13971425
hooks=hooks,
13981426
)
1427+
1428+
# Save the newly executed tool outputs to the session
1429+
new_tool_outputs: list[RunItem] = [
1430+
item
1431+
for item in streamed_result.new_items[items_before_execution:]
1432+
if item.type == "tool_call_output_item"
1433+
]
1434+
if new_tool_outputs and session is not None:
1435+
await cls._save_result_to_session(session, [], new_tool_outputs)
1436+
13991437
# Clear the current step since we've handled it
14001438
run_state._current_step = None
14011439

@@ -1698,6 +1736,8 @@ async def _run_single_turn_streamed(
16981736
input_item = item.to_input_item()
16991737
input.append(input_item)
17001738

1739+
input = cls._merge_provider_data_in_items(input)
1740+
17011741
# THIS IS THE RESOLVED CONFLICT BLOCK
17021742
filtered = await cls._maybe_filter_model_input(
17031743
agent=agent,
@@ -2038,6 +2078,8 @@ async def _run_single_turn(
20382078
input_item = generated_item.to_input_item()
20392079
input.append(input_item)
20402080

2081+
input = cls._merge_provider_data_in_items(input)
2082+
20412083
new_response = await cls._get_new_response(
20422084
agent,
20432085
system_prompt,
@@ -2375,6 +2417,30 @@ def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model:
23752417

23762418
return run_config.model_provider.get_model(agent.model)
23772419

2420+
@classmethod
2421+
def _merge_provider_data_in_items(
2422+
cls, items: list[TResponseInputItem]
2423+
) -> list[TResponseInputItem]:
2424+
"""Remove providerData fields from items."""
2425+
result = []
2426+
for item in items:
2427+
if isinstance(item, dict):
2428+
merged_item = dict(item)
2429+
# Pop both possible keys (providerData and provider_data)
2430+
provider_data = merged_item.pop("providerData", None)
2431+
if provider_data is None:
2432+
provider_data = merged_item.pop("provider_data", None)
2433+
# Merge contents if providerData exists and is a dict
2434+
if isinstance(provider_data, dict):
2435+
# Merge provider_data contents, with existing fields taking precedence
2436+
for key, value in provider_data.items():
2437+
if key not in merged_item:
2438+
merged_item[key] = value
2439+
result.append(cast(TResponseInputItem, merged_item))
2440+
else:
2441+
result.append(item)
2442+
return result
2443+
23782444
@classmethod
23792445
async def _prepare_input_with_session(
23802446
cls,
@@ -2398,6 +2464,7 @@ async def _prepare_input_with_session(
23982464

23992465
# Get previous conversation history
24002466
history = await session.get_items()
2467+
history = cls._merge_provider_data_in_items(history)
24012468

24022469
# Convert input to list format
24032470
new_input_list = ItemHelpers.input_to_new_input_list(input)
@@ -2407,7 +2474,9 @@ async def _prepare_input_with_session(
24072474
elif callable(session_input_callback):
24082475
res = session_input_callback(history, new_input_list)
24092476
if inspect.isawaitable(res):
2410-
return await res
2477+
res = await res
2478+
if isinstance(res, list):
2479+
res = cls._merge_provider_data_in_items(res)
24112480
return res
24122481
else:
24132482
raise UserError(

0 commit comments

Comments
 (0)