Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 41 additions & 23 deletions src/google/adk/flows/llm_flows/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,18 +484,13 @@ async def _run_on_tool_error_callbacks(
tool = _get_tool(function_call, tools_dict)
except ValueError as tool_error:
tool = BaseTool(name=function_call.name, description='Tool not found')
error_response = await _run_on_tool_error_callbacks(
tool=tool,
tool_args=function_args,
tool_context=tool_context,
error=tool_error,
)
if error_response is not None:
return __build_response_event(
tool, error_response, tool_context, invocation_context
)
else:
raise tool_error
# Fall through to _run_with_trace so that before_tool_callback and the
# OTel span are created *before* on_tool_error_callback fires. This
# keeps the callback lifecycle balanced (push/pop) and prevents plugins
# like BigQueryAgentAnalyticsPlugin from corrupting their span stacks.
_tool_lookup_error: Exception = tool_error
else:
_tool_lookup_error = None

async def _run_with_trace():
nonlocal function_args
Expand All @@ -520,6 +515,22 @@ async def _run_with_trace():
if function_response:
break

# Step 2.5: If the tool was not found (hallucinated), surface the error
# *after* before_tool_callback so the lifecycle stays balanced.
if _tool_lookup_error is not None:
error_response = await _run_on_tool_error_callbacks(
tool=tool,
tool_args=function_args,
tool_context=tool_context,
error=_tool_lookup_error,
)
if error_response is not None:
return __build_response_event(
tool, error_response, tool_context, invocation_context
)
else:
raise _tool_lookup_error

# Step 3: Otherwise, proceed calling the tool normally.
if function_response is None:
try:
Expand Down Expand Up @@ -715,17 +726,9 @@ async def _run_on_tool_error_callbacks(
tool = _get_tool(function_call, tools_dict)
except ValueError as tool_error:
tool = BaseTool(name=function_call.name, description='Tool not found')
error_response = await _run_on_tool_error_callbacks(
tool=tool,
tool_args=function_args,
tool_context=tool_context,
error=tool_error,
)
if error_response is not None:
return __build_response_event(
tool, error_response, tool_context, invocation_context
)
raise tool_error
_tool_lookup_error: Exception = tool_error
else:
_tool_lookup_error = None

async def _run_with_trace():
nonlocal function_args
Expand Down Expand Up @@ -755,6 +758,21 @@ async def _run_with_trace():
if function_response:
break

# Step 2.5: If the tool was not found (hallucinated), surface the error
# *after* before_tool_callback so the lifecycle stays balanced.
if _tool_lookup_error is not None:
error_response = await _run_on_tool_error_callbacks(
tool=tool,
tool_args=function_args,
tool_context=tool_context,
error=_tool_lookup_error,
)
if error_response is not None:
return __build_response_event(
tool, error_response, tool_context, invocation_context
)
raise _tool_lookup_error

# Step 3: Otherwise, proceed calling the tool normally.
if function_response is None:
try:
Expand Down
108 changes: 108 additions & 0 deletions tests/unittests/flows/llm_flows/test_plugin_tool_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,5 +340,113 @@ def agent_after_cb(tool, args, tool_context, tool_response):
assert part.function_response.response == mock_plugin.after_tool_response


@pytest.mark.asyncio
async def test_hallucinated_tool_fires_before_and_error_callbacks(
mock_tool, mock_plugin
):
"""Regression test for https://github.com/google/adk-python/issues/4775.

When the LLM hallucinates a tool name, on_tool_error_callback used to fire
*before* before_tool_callback, corrupting plugin span stacks (e.g.
BigQueryAgentAnalyticsPlugin's TraceManager). After the fix, both
callbacks should fire in order: before_tool → on_tool_error.
"""
mock_plugin.enable_before_tool_callback = True
mock_plugin.enable_on_tool_error_callback = True

# Track callback invocation order
call_order = []
original_before = mock_plugin.before_tool_callback
original_error = mock_plugin.on_tool_error_callback

async def tracking_before(**kwargs):
call_order.append("before_tool")
return await original_before(**kwargs)

async def tracking_error(**kwargs):
call_order.append("on_tool_error")
return await original_error(**kwargs)

mock_plugin.before_tool_callback = tracking_before
mock_plugin.on_tool_error_callback = tracking_error

model = testing_utils.MockModel.create(responses=[])
agent = Agent(
name="agent",
model=model,
tools=[mock_tool],
)
invocation_context = await testing_utils.create_invocation_context(
agent=agent, user_content="", plugins=[mock_plugin]
)

# Build function call for a non-existent tool (hallucinated name)
function_call = types.FunctionCall(
name="hallucinated_tool_xyz", args={"query": "test"}
)
content = types.Content(parts=[types.Part(function_call=function_call)])
event = Event(
invocation_id=invocation_context.invocation_id,
author=agent.name,
content=content,
)
tools_dict = {mock_tool.name: mock_tool}

result_event = await handle_function_calls_async(
invocation_context,
event,
tools_dict,
)

# on_tool_error_callback returned a response, so we should get an event
assert result_event is not None
part = result_event.content.parts[0]
assert part.function_response.response == mock_plugin.on_tool_error_response

# Verify that before_tool fired BEFORE on_tool_error
assert "before_tool" in call_order
assert "on_tool_error" in call_order
assert call_order.index("before_tool") < call_order.index("on_tool_error")


@pytest.mark.asyncio
async def test_hallucinated_tool_raises_when_no_error_callback(
mock_tool, mock_plugin
):
"""When a tool is hallucinated and no error callback handles it, ValueError
should propagate — but only after before_tool_callback has had a chance to
run (so plugin stacks remain balanced)."""
mock_plugin.enable_before_tool_callback = False
mock_plugin.enable_on_tool_error_callback = False

model = testing_utils.MockModel.create(responses=[])
agent = Agent(
name="agent",
model=model,
tools=[mock_tool],
)
invocation_context = await testing_utils.create_invocation_context(
agent=agent, user_content="", plugins=[mock_plugin]
)

function_call = types.FunctionCall(
name="nonexistent_tool", args={}
)
content = types.Content(parts=[types.Part(function_call=function_call)])
event = Event(
invocation_id=invocation_context.invocation_id,
author=agent.name,
content=content,
)
tools_dict = {mock_tool.name: mock_tool}

with pytest.raises(ValueError, match="nonexistent_tool"):
await handle_function_calls_async(
invocation_context,
event,
tools_dict,
)


if __name__ == "__main__":
pytest.main([__file__])