-
Notifications
You must be signed in to change notification settings - Fork 3.2k
fix(agent-tool): preserve nested run cache across tool-call GC #2491
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,5 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import weakref | ||
| from typing import TYPE_CHECKING | ||
|
|
||
| if TYPE_CHECKING: | ||
|
|
@@ -19,7 +18,6 @@ | |
| int, | ||
| tuple[str, str, str, str, str | None, str | None], | ||
| ] = {} | ||
| _agent_tool_call_refs_by_obj: dict[int, weakref.ReferenceType[ResponseFunctionToolCall]] = {} | ||
|
|
||
|
|
||
| def _tool_call_signature( | ||
|
|
@@ -47,9 +45,6 @@ def _index_agent_tool_run_result( | |
|
|
||
| def _drop_agent_tool_run_result(tool_call_obj_id: int) -> None: | ||
| """Remove a tool call object from the fallback index.""" | ||
| tool_call_refs = _agent_tool_call_refs_by_obj | ||
| if isinstance(tool_call_refs, dict): | ||
| tool_call_refs.pop(tool_call_obj_id, None) | ||
| signature_by_obj = _agent_tool_run_result_signature_by_obj | ||
| if not isinstance(signature_by_obj, dict): | ||
| return | ||
|
|
@@ -67,26 +62,15 @@ def _drop_agent_tool_run_result(tool_call_obj_id: int) -> None: | |
| results_by_signature.pop(signature, None) | ||
|
|
||
|
|
||
| def _register_tool_call_ref(tool_call: ResponseFunctionToolCall, tool_call_obj_id: int) -> None: | ||
| """Tie cached nested run results to the tool call lifetime to avoid leaks.""" | ||
|
|
||
| def _on_tool_call_gc(_ref: weakref.ReferenceType[ResponseFunctionToolCall]) -> None: | ||
| run_results = _agent_tool_run_results_by_obj | ||
| if isinstance(run_results, dict): | ||
| run_results.pop(tool_call_obj_id, None) | ||
| _drop_agent_tool_run_result(tool_call_obj_id) | ||
|
|
||
| _agent_tool_call_refs_by_obj[tool_call_obj_id] = weakref.ref(tool_call, _on_tool_call_gc) | ||
|
|
||
|
|
||
| def record_agent_tool_run_result( | ||
| tool_call: ResponseFunctionToolCall, run_result: RunResult | RunResultStreaming | ||
| ) -> None: | ||
| """Store the nested agent run result by tool call identity.""" | ||
| tool_call_obj_id = id(tool_call) | ||
| # Clear any stale signature index if this object ID gets reused. | ||
| _drop_agent_tool_run_result(tool_call_obj_id) | ||
| _agent_tool_run_results_by_obj[tool_call_obj_id] = run_result | ||
| _index_agent_tool_run_result(tool_call, tool_call_obj_id) | ||
|
Comment on lines
72
to
73
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Useful? React with 👍 / 👎. |
||
| _register_tool_call_ref(tool_call, tool_call_obj_id) | ||
|
|
||
|
|
||
| def consume_agent_tool_run_result( | ||
|
|
@@ -109,7 +93,6 @@ def consume_agent_tool_run_result( | |
| candidate_id = next(iter(candidate_ids)) | ||
| _agent_tool_run_results_by_signature.pop(signature, None) | ||
| _agent_tool_run_result_signature_by_obj.pop(candidate_id, None) | ||
| _agent_tool_call_refs_by_obj.pop(candidate_id, None) | ||
| return _agent_tool_run_results_by_obj.pop(candidate_id, None) | ||
|
|
||
|
|
||
|
|
@@ -151,5 +134,4 @@ def drop_agent_tool_run_result(tool_call: ResponseFunctionToolCall) -> None: | |
| candidate_id = next(iter(candidate_ids)) | ||
| _agent_tool_run_results_by_signature.pop(signature, None) | ||
| _agent_tool_run_result_signature_by_obj.pop(candidate_id, None) | ||
| _agent_tool_call_refs_by_obj.pop(candidate_id, None) | ||
| _agent_tool_run_results_by_obj.pop(candidate_id, None) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,16 +1,180 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import gc | ||
| from typing import Any | ||
|
|
||
| import pytest | ||
|
|
||
| import agents.agent_tool_state as tool_state | ||
| from tests.utils.hitl import make_function_tool_call | ||
|
|
||
|
|
||
| def test_drop_agent_tool_run_result_handles_cleared_globals( | ||
| monkeypatch: pytest.MonkeyPatch, | ||
| ) -> None: | ||
| monkeypatch.setattr(tool_state, "_agent_tool_call_refs_by_obj", None) | ||
| monkeypatch.setattr(tool_state, "_agent_tool_run_result_signature_by_obj", None) | ||
| monkeypatch.setattr(tool_state, "_agent_tool_run_results_by_signature", None) | ||
|
|
||
| # Should not raise even if globals are cleared during interpreter shutdown. | ||
| tool_state._drop_agent_tool_run_result(123) | ||
|
|
||
|
|
||
| def test_agent_tool_result_survives_tool_call_gc_until_consumed() -> None: | ||
| tool_state._agent_tool_run_results_by_obj.clear() | ||
| tool_state._agent_tool_run_results_by_signature.clear() | ||
| tool_state._agent_tool_run_result_signature_by_obj.clear() | ||
|
|
||
| nested_result: Any = object() | ||
| tool_call = make_function_tool_call( | ||
| "inner_tool", | ||
| call_id="inner-1", | ||
| arguments='{"input":"hello"}', | ||
| ) | ||
| tool_state.record_agent_tool_run_result(tool_call, nested_result) | ||
|
|
||
| # Resume uses a reconstructed tool-call object with the same signature. | ||
| resume_tool_call = make_function_tool_call( | ||
| "inner_tool", | ||
| call_id="inner-1", | ||
| arguments='{"input":"hello"}', | ||
| ) | ||
|
|
||
| del tool_call | ||
| gc.collect() | ||
|
|
||
| assert tool_state.peek_agent_tool_run_result(resume_tool_call) is nested_result | ||
| assert tool_state.consume_agent_tool_run_result(resume_tool_call) is nested_result | ||
| assert tool_state.peek_agent_tool_run_result(resume_tool_call) is None | ||
|
|
||
|
|
||
| def test_record_clears_stale_signature_when_obj_id_is_reused() -> None: | ||
| tool_state._agent_tool_run_results_by_obj.clear() | ||
| tool_state._agent_tool_run_results_by_signature.clear() | ||
| tool_state._agent_tool_run_result_signature_by_obj.clear() | ||
|
|
||
| tool_call = make_function_tool_call( | ||
| "inner_tool", | ||
| call_id="new-call", | ||
| arguments='{"input":"hello"}', | ||
| ) | ||
| obj_id = id(tool_call) | ||
|
|
||
| stale_signature = ( | ||
| "old-call", | ||
| "inner_tool", | ||
| '{"input":"old"}', | ||
| "function_call", | ||
| "old-id", | ||
| "completed", | ||
| ) | ||
| stale_result: Any = object() | ||
| new_result: Any = object() | ||
|
|
||
| tool_state._agent_tool_run_results_by_obj[obj_id] = stale_result | ||
| tool_state._agent_tool_run_result_signature_by_obj[obj_id] = stale_signature | ||
| tool_state._agent_tool_run_results_by_signature[stale_signature] = {obj_id} | ||
|
|
||
| tool_state.record_agent_tool_run_result(tool_call, new_result) | ||
|
|
||
| assert obj_id in tool_state._agent_tool_run_results_by_obj | ||
| assert tool_state._agent_tool_run_results_by_obj[obj_id] is new_result | ||
| assert stale_signature not in tool_state._agent_tool_run_results_by_signature | ||
|
|
||
|
|
||
| def test_consume_peek_and_drop_direct_object_path() -> None: | ||
| tool_state._agent_tool_run_results_by_obj.clear() | ||
| tool_state._agent_tool_run_results_by_signature.clear() | ||
| tool_state._agent_tool_run_result_signature_by_obj.clear() | ||
|
|
||
| tool_call = make_function_tool_call( | ||
| "inner_tool", | ||
| call_id="direct-1", | ||
| arguments='{"input":"hello"}', | ||
| ) | ||
| nested_result: Any = object() | ||
|
|
||
| tool_state.record_agent_tool_run_result(tool_call, nested_result) | ||
| assert tool_state.peek_agent_tool_run_result(tool_call) is nested_result | ||
| assert tool_state.consume_agent_tool_run_result(tool_call) is nested_result | ||
| assert tool_state.consume_agent_tool_run_result(tool_call) is None | ||
|
|
||
| tool_state.record_agent_tool_run_result(tool_call, nested_result) | ||
| tool_state.drop_agent_tool_run_result(tool_call) | ||
| assert tool_state.peek_agent_tool_run_result(tool_call) is None | ||
|
|
||
|
|
||
| def test_signature_fallback_none_and_ambiguous_paths() -> None: | ||
| tool_state._agent_tool_run_results_by_obj.clear() | ||
| tool_state._agent_tool_run_results_by_signature.clear() | ||
| tool_state._agent_tool_run_result_signature_by_obj.clear() | ||
|
|
||
| tool_call = make_function_tool_call( | ||
| "inner_tool", | ||
| call_id="fallback-1", | ||
| arguments='{"input":"hello"}', | ||
| ) | ||
| signature = tool_state._tool_call_signature(tool_call) | ||
|
|
||
| # No candidate IDs -> None paths. | ||
| assert tool_state.consume_agent_tool_run_result(tool_call) is None | ||
| assert tool_state.peek_agent_tool_run_result(tool_call) is None | ||
| tool_state.drop_agent_tool_run_result(tool_call) | ||
|
|
||
| # Multiple candidate IDs -> ambiguous, should return/perform no-op. | ||
| tool_state._agent_tool_run_results_by_signature[signature] = {101, 202} | ||
| fake_result_1: Any = object() | ||
| fake_result_2: Any = object() | ||
| tool_state._agent_tool_run_results_by_obj[101] = fake_result_1 | ||
| tool_state._agent_tool_run_results_by_obj[202] = fake_result_2 | ||
| tool_state._agent_tool_run_result_signature_by_obj[101] = signature | ||
| tool_state._agent_tool_run_result_signature_by_obj[202] = signature | ||
|
|
||
| assert tool_state.consume_agent_tool_run_result(tool_call) is None | ||
| assert tool_state.peek_agent_tool_run_result(tool_call) is None | ||
| tool_state.drop_agent_tool_run_result(tool_call) | ||
|
|
||
|
|
||
| def test_drop_index_handles_missing_candidate_collection( | ||
| monkeypatch: pytest.MonkeyPatch, | ||
| ) -> None: | ||
| signature = ("call", "name", "{}", "function_call", "id", "completed") | ||
|
|
||
| signature_by_obj = {7: signature} | ||
| monkeypatch.setattr(tool_state, "_agent_tool_run_result_signature_by_obj", signature_by_obj) | ||
| monkeypatch.setattr(tool_state, "_agent_tool_run_results_by_signature", None) | ||
| tool_state._drop_agent_tool_run_result(7) | ||
|
|
||
| signature_by_obj = {9: signature} | ||
| monkeypatch.setattr(tool_state, "_agent_tool_run_result_signature_by_obj", signature_by_obj) | ||
| monkeypatch.setattr(tool_state, "_agent_tool_run_results_by_signature", {}) | ||
| tool_state._drop_agent_tool_run_result(9) | ||
|
|
||
|
|
||
| def test_drop_removes_single_fallback_candidate() -> None: | ||
| tool_state._agent_tool_run_results_by_obj.clear() | ||
| tool_state._agent_tool_run_results_by_signature.clear() | ||
| tool_state._agent_tool_run_result_signature_by_obj.clear() | ||
|
|
||
| stored_call = make_function_tool_call( | ||
| "inner_tool", | ||
| call_id="drop-fallback", | ||
| arguments='{"input":"hello"}', | ||
| ) | ||
| probe_call = make_function_tool_call( | ||
| "inner_tool", | ||
| call_id="drop-fallback", | ||
| arguments='{"input":"hello"}', | ||
| ) | ||
|
|
||
| stored_id = id(stored_call) | ||
| signature = tool_state._tool_call_signature(stored_call) | ||
| nested_result: Any = object() | ||
| tool_state._agent_tool_run_results_by_obj[stored_id] = nested_result | ||
| tool_state._agent_tool_run_result_signature_by_obj[stored_id] = signature | ||
| tool_state._agent_tool_run_results_by_signature[signature] = {stored_id} | ||
|
|
||
| tool_state.drop_agent_tool_run_result(probe_call) | ||
|
|
||
| assert signature not in tool_state._agent_tool_run_results_by_signature | ||
| assert stored_id not in tool_state._agent_tool_run_result_signature_by_obj | ||
| assert stored_id not in tool_state._agent_tool_run_results_by_obj |
Uh oh!
There was an error while loading. Please reload this page.