diff --git a/contributing/samples/human_tool_confirmation/__init__.py b/contributing/samples/human_tool_confirmation/__init__.py new file mode 100644 index 0000000000..c48963cdc7 --- /dev/null +++ b/contributing/samples/human_tool_confirmation/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent diff --git a/contributing/samples/human_tool_confirmation/agent.py b/contributing/samples/human_tool_confirmation/agent.py new file mode 100644 index 0000000000..b49078dae4 --- /dev/null +++ b/contributing/samples/human_tool_confirmation/agent.py @@ -0,0 +1,80 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.adk import Agent +from google.adk.tools.function_tool import FunctionTool +from google.adk.tools.tool_confirmation import ToolConfirmation +from google.adk.tools.tool_context import ToolContext +from google.genai import types + + +def reimburse(amount: int, tool_context: ToolContext) -> str: + """Reimburse the employee for the given amount.""" + return {'status': 'ok'} + + +def request_time_off(days: int, tool_context: ToolContext): + """Request day off for the employee.""" + if days <= 0: + return {'status': 'Invalid days to request.'} + + if days <= 2: + return { + 'status': 'ok', + 'approved_days': days, + } + + tool_confirmation = tool_context.tool_confirmation + if not tool_confirmation: + tool_context.request_confirmation( + hint=( + 'Please approve or reject the tool call request_time_off() by' + ' responding with a FunctionResponse with an expected' + ' ToolConfirmation payload.' + ), + payload={ + 'approved_days': 0, + }, + ) + return {'status': 'Manager approval is required.'} + + approved_days = tool_confirmation.payload['approved_days'] + approved_days = min(approved_days, days) + if approved_days == 0: + return {'status': 'The time off request is rejected.', 'approved_days': 0} + return { + 'status': 'ok', + 'approved_days': approved_days, + } + + +root_agent = Agent( + model='gemini-2.5-flash', + name='time_off_agent', + instruction=""" + You are a helpful assistant that can help employees with reimbursement and time off requests. + - Use the `reimburse` tool for reimbursement requests. + - Use the `request_time_off` tool for time off requests. + - Prioritize using tools to fulfill the user's request. + - Always respond to the user with the tool results. + """, + tools=[ + # Set require_confirmation to True to require user confirmation for the + # tool call. This is an easier way to get user confirmation if the tool + # just need a boolean confirmation. + FunctionTool(reimburse, require_confirmation=True), + request_time_off, + ], + generate_content_config=types.GenerateContentConfig(temperature=0.1), +) diff --git a/src/google/adk/events/event_actions.py b/src/google/adk/events/event_actions.py index 994a7900b9..a46ad16e98 100644 --- a/src/google/adk/events/event_actions.py +++ b/src/google/adk/events/event_actions.py @@ -14,6 +14,7 @@ from __future__ import annotations +from typing import Any from typing import Optional from pydantic import alias_generators @@ -22,6 +23,7 @@ from pydantic import Field from ..auth.auth_tool import AuthConfig +from ..tools.tool_confirmation import ToolConfirmation class EventActions(BaseModel): @@ -64,3 +66,9 @@ class EventActions(BaseModel): identify the function call. - Values: The requested auth config. """ + + requested_tool_confirmations: dict[str, ToolConfirmation] = Field( + default_factory=dict + ) + """A dict of tool confirmation requested by this event, keyed by + function call id.""" diff --git a/src/google/adk/flows/llm_flows/__init__.py b/src/google/adk/flows/llm_flows/__init__.py index 6dbd22f5d0..4a916554d0 100644 --- a/src/google/adk/flows/llm_flows/__init__.py +++ b/src/google/adk/flows/llm_flows/__init__.py @@ -18,3 +18,4 @@ from . import functions from . import identity from . import instructions +from . import request_confirmation diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 2d8fd15920..2785abea28 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -638,6 +638,12 @@ async def _postprocess_handle_function_calls_async( if auth_event: yield auth_event + tool_confirmation_event = functions.generate_request_confirmation_event( + invocation_context, function_call_event, function_response_event + ) + if tool_confirmation_event: + yield tool_confirmation_event + # Always yield the function response event first yield function_response_event diff --git a/src/google/adk/flows/llm_flows/contents.py b/src/google/adk/flows/llm_flows/contents.py index ae1bd44ad9..f05a821d94 100644 --- a/src/google/adk/flows/llm_flows/contents.py +++ b/src/google/adk/flows/llm_flows/contents.py @@ -27,6 +27,7 @@ from ...models.llm_request import LlmRequest from ._base_llm_processor import BaseLlmRequestProcessor from .functions import remove_client_function_call_id +from .functions import REQUEST_CONFIRMATION_FUNCTION_CALL_NAME from .functions import REQUEST_EUC_FUNCTION_CALL_NAME @@ -238,6 +239,9 @@ def _get_contents( if _is_auth_event(event): # Skip auth events. continue + if _is_request_confirmation_event(event): + # Skip request confirmation events. + continue filtered_events.append( _convert_foreign_event(event) if _is_other_agent_reply(agent_name, event) @@ -431,18 +435,23 @@ def _is_event_belongs_to_branch( return invocation_branch.startswith(event.branch) -def _is_auth_event(event: Event) -> bool: - if not event.content.parts: +def _is_function_call_event(event: Event, function_name: str) -> bool: + """Checks if an event is a function call/response for a given function name.""" + if not event.content or not event.content.parts: return False for part in event.content.parts: - if ( - part.function_call - and part.function_call.name == REQUEST_EUC_FUNCTION_CALL_NAME - ): + if part.function_call and part.function_call.name == function_name: return True - if ( - part.function_response - and part.function_response.name == REQUEST_EUC_FUNCTION_CALL_NAME - ): + if part.function_response and part.function_response.name == function_name: return True return False + + +def _is_auth_event(event: Event) -> bool: + """Checks if the event is an authentication event.""" + return _is_function_call_event(event, REQUEST_EUC_FUNCTION_CALL_NAME) + + +def _is_request_confirmation_event(event: Event) -> bool: + """Checks if the event is a request confirmation event.""" + return _is_function_call_event(event, REQUEST_CONFIRMATION_FUNCTION_CALL_NAME) diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index b0700270f1..72d5211a1b 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -39,6 +39,7 @@ from ...telemetry import trace_tool_call from ...telemetry import tracer from ...tools.base_tool import BaseTool +from ...tools.tool_confirmation import ToolConfirmation from ...tools.tool_context import ToolContext from ...utils.context_utils import Aclosing @@ -47,6 +48,7 @@ AF_FUNCTION_CALL_ID_PREFIX = 'adk-' REQUEST_EUC_FUNCTION_CALL_NAME = 'adk_request_credential' +REQUEST_CONFIRMATION_FUNCTION_CALL_NAME = 'adk_request_confirmation' logger = logging.getLogger('google_adk.' + __name__) @@ -130,11 +132,76 @@ def generate_auth_event( ) +def generate_request_confirmation_event( + invocation_context: InvocationContext, + function_call_event: Event, + function_response_event: Event, +) -> Optional[Event]: + """Generates a request confirmation event from a function response event.""" + if not function_response_event.actions.requested_tool_confirmations: + return None + parts = [] + long_running_tool_ids = set() + function_calls = function_call_event.get_function_calls() + for ( + function_call_id, + tool_confirmation, + ) in function_response_event.actions.requested_tool_confirmations.items(): + original_function_call = next( + (fc for fc in function_calls if fc.id == function_call_id), None + ) + if not original_function_call: + continue + request_confirmation_function_call = types.FunctionCall( + name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME, + args={ + 'originalFunctionCall': original_function_call.model_dump( + exclude_none=True, by_alias=True + ), + 'toolConfirmation': tool_confirmation.model_dump( + by_alias=True, exclude_none=True + ), + }, + ) + request_confirmation_function_call.id = generate_client_function_call_id() + long_running_tool_ids.add(request_confirmation_function_call.id) + parts.append(types.Part(function_call=request_confirmation_function_call)) + + return Event( + invocation_id=invocation_context.invocation_id, + author=invocation_context.agent.name, + branch=invocation_context.branch, + content=types.Content( + parts=parts, role=function_response_event.content.role + ), + long_running_tool_ids=long_running_tool_ids, + ) + + async def handle_function_calls_async( invocation_context: InvocationContext, function_call_event: Event, tools_dict: dict[str, BaseTool], filters: Optional[set[str]] = None, + tool_confirmation_dict: Optional[dict[str, ToolConfirmation]] = None, +) -> Optional[Event]: + """Calls the functions and returns the function response event.""" + function_calls = function_call_event.get_function_calls() + return await handle_function_call_list_async( + invocation_context, + function_calls, + tools_dict, + filters, + tool_confirmation_dict, + ) + + +async def handle_function_call_list_async( + invocation_context: InvocationContext, + function_calls: list[types.FunctionCall], + tools_dict: dict[str, BaseTool], + filters: Optional[set[str]] = None, + tool_confirmation_dict: Optional[dict[str, ToolConfirmation]] = None, ) -> Optional[Event]: """Calls the functions and returns the function response event.""" from ...agents.llm_agent import LlmAgent @@ -143,8 +210,6 @@ async def handle_function_calls_async( if not isinstance(agent, LlmAgent): return None - function_calls = function_call_event.get_function_calls() - # Filter function calls filtered_calls = [ fc for fc in function_calls if not filters or fc.id in filters @@ -161,6 +226,9 @@ async def handle_function_calls_async( function_call, tools_dict, agent, + tool_confirmation_dict[function_call.id] + if tool_confirmation_dict + else None, ) ) for function_call in filtered_calls @@ -198,12 +266,14 @@ async def _execute_single_function_call_async( function_call: types.FunctionCall, tools_dict: dict[str, BaseTool], agent: LlmAgent, + tool_confirmation: Optional[ToolConfirmation] = None, ) -> Optional[Event]: """Execute a single function call with thread safety for state modifications.""" tool, tool_context = _get_tool_and_context( invocation_context, function_call, tools_dict, + tool_confirmation, ) with tracer.start_as_current_span(f'execute_tool {tool.name}'): @@ -567,6 +637,7 @@ def _get_tool_and_context( invocation_context: InvocationContext, function_call: types.FunctionCall, tools_dict: dict[str, BaseTool], + tool_confirmation: Optional[ToolConfirmation] = None, ): if function_call.name not in tools_dict: raise ValueError( @@ -576,6 +647,7 @@ def _get_tool_and_context( tool_context = ToolContext( invocation_context=invocation_context, function_call_id=function_call.id, + tool_confirmation=tool_confirmation, ) tool = tools_dict[function_call.name] diff --git a/src/google/adk/flows/llm_flows/request_confirmation.py b/src/google/adk/flows/llm_flows/request_confirmation.py new file mode 100644 index 0000000000..7bf9759563 --- /dev/null +++ b/src/google/adk/flows/llm_flows/request_confirmation.py @@ -0,0 +1,169 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import json +import logging +from typing import AsyncGenerator +from typing import TYPE_CHECKING + +from google.genai import types +from typing_extensions import override + +from . import functions +from ...agents.invocation_context import InvocationContext +from ...agents.readonly_context import ReadonlyContext +from ...events.event import Event +from ...models.llm_request import LlmRequest +from ...tools.tool_confirmation import ToolConfirmation +from ._base_llm_processor import BaseLlmRequestProcessor +from .functions import REQUEST_CONFIRMATION_FUNCTION_CALL_NAME + +if TYPE_CHECKING: + from ...agents.llm_agent import LlmAgent + + +logger = logging.getLogger('google_adk.' + __name__) + + +class _RequestConfirmationLlmRequestProcessor(BaseLlmRequestProcessor): + """Handles tool confirmation information to build the LLM request.""" + + @override + async def run_async( + self, invocation_context: InvocationContext, llm_request: LlmRequest + ) -> AsyncGenerator[Event, None]: + from ...agents.llm_agent import LlmAgent + + agent = invocation_context.agent + if not isinstance(agent, LlmAgent): + return + events = invocation_context.session.events + if not events: + return + + request_confirmation_function_responses = ( + dict() + ) # {function call id, tool confirmation} + + confirmation_event_index = -1 + for k in range(len(events) - 1, -1, -1): + event = events[k] + # Find the first event authored by user + if not event.author or event.author != 'user': + continue + responses = event.get_function_responses() + if not responses: + return + + for function_response in responses: + if function_response.name != REQUEST_CONFIRMATION_FUNCTION_CALL_NAME: + continue + + # Find the FunctionResponse event that contains the user provided tool + # confirmation + if ( + function_response.response + and len(function_response.response.values()) == 1 + and 'response' in function_response.response.keys() + ): + # ADK web client will send a request that is always encapted in a + # 'response' key. + tool_confirmation = ToolConfirmation.model_validate( + json.loads(function_response.response['response']) + ) + else: + tool_confirmation = ToolConfirmation.model_validate( + function_response.response + ) + request_confirmation_function_responses[function_response.id] = ( + tool_confirmation + ) + confirmation_event_index = k + break + + if not request_confirmation_function_responses: + return + + for i in range(len(events) - 2, -1, -1): + event = events[i] + # Find the system generated FunctionCall event requesting the tool + # confirmation + function_calls = event.get_function_calls() + if not function_calls: + continue + + tools_to_resume_with_confirmation = ( + dict() + ) # {Function call id, tool confirmation} + tools_to_resume_with_args = dict() # {Function call id, function calls} + + for function_call in function_calls: + if ( + function_call.id + not in request_confirmation_function_responses.keys() + ): + continue + + args = function_call.args + if 'originalFunctionCall' not in args: + continue + original_function_call = types.FunctionCall( + **args['originalFunctionCall'] + ) + tools_to_resume_with_confirmation[original_function_call.id] = ( + request_confirmation_function_responses[function_call.id] + ) + tools_to_resume_with_args[original_function_call.id] = ( + original_function_call + ) + if not tools_to_resume_with_confirmation: + continue + + # Remove the tools that have already been confirmed. + for i in range(len(events) - 1, confirmation_event_index, -1): + event = events[i] + function_response = event.get_function_responses() + if not function_response: + continue + + for function_response in event.get_function_responses(): + if function_response.id in tools_to_resume_with_confirmation: + tools_to_resume_with_confirmation.pop(function_response.id) + tools_to_resume_with_args.pop(function_response.id) + if not tools_to_resume_with_confirmation: + break + + if not tools_to_resume_with_confirmation: + continue + + if function_response_event := await functions.handle_function_call_list_async( + invocation_context, + tools_to_resume_with_args.values(), + { + tool.name: tool + for tool in await agent.canonical_tools( + ReadonlyContext(invocation_context) + ) + }, + # There could be parallel function calls that require input + # response would be a dict keyed by function call id + tools_to_resume_with_confirmation.keys(), + tools_to_resume_with_confirmation, + ): + yield function_response_event + return + + +request_processor = _RequestConfirmationLlmRequestProcessor() diff --git a/src/google/adk/flows/llm_flows/single_flow.py b/src/google/adk/flows/llm_flows/single_flow.py index 5b398b52b7..2644ebc046 100644 --- a/src/google/adk/flows/llm_flows/single_flow.py +++ b/src/google/adk/flows/llm_flows/single_flow.py @@ -25,6 +25,7 @@ from . import contents from . import identity from . import instructions +from . import request_confirmation from ...auth import auth_preprocessor from .base_llm_flow import BaseLlmFlow @@ -43,6 +44,7 @@ def __init__(self): self.request_processors += [ basic.request_processor, auth_preprocessor.request_processor, + request_confirmation.request_processor, instructions.request_processor, identity.request_processor, contents.request_processor, diff --git a/src/google/adk/tools/function_tool.py b/src/google/adk/tools/function_tool.py index 69f5934b2c..a3a580662a 100644 --- a/src/google/adk/tools/function_tool.py +++ b/src/google/adk/tools/function_tool.py @@ -15,6 +15,7 @@ from __future__ import annotations import inspect +import logging from typing import Any from typing import Callable from typing import Optional @@ -25,8 +26,11 @@ from ..utils.context_utils import Aclosing from ._automatic_function_calling_util import build_function_declaration from .base_tool import BaseTool +from .tool_confirmation import ToolConfirmation from .tool_context import ToolContext +logger = logging.getLogger('google_adk.' + __name__) + class FunctionTool(BaseTool): """A tool that wraps a user-defined Python function. @@ -35,8 +39,15 @@ class FunctionTool(BaseTool): func: The function to wrap. """ - def __init__(self, func: Callable[..., Any]): - """Extract metadata from a callable object.""" + def __init__( + self, func: Callable[..., Any], *, require_confirmation: bool = False + ): + """Initializes the FunctionTool. Extracts metadata from a callable object. + + Args: + func: The function to wrap. + require_confirmation: Whether the tool call requires user confirmation. + """ name = '' doc = '' # Handle different types of callables @@ -61,6 +72,7 @@ def __init__(self, func: Callable[..., Any]): super().__init__(name=name, description=doc) self.func = func self._ignore_params = ['tool_context', 'input_stream'] + self._require_confirmation = require_confirmation @override def _get_declaration(self) -> Optional[types.FunctionDeclaration]: @@ -106,6 +118,29 @@ async def run_async( You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters.""" return {'error': error_str} + if self._require_confirmation: + if not tool_context.tool_confirmation: + args_to_show = args_to_call.copy() + if 'tool_context' in args_to_show: + args_to_show.pop('tool_context') + + tool_context.request_confirmation( + hint=( + f'Please approve or reject the tool call {self.name}() by' + ' responding with a FunctionResponse with an expected' + ' ToolConfirmation payload.' + ), + ) + return { + 'error': ( + 'This tool call requires confirmation, please approve or' + ' reject.' + ) + } + else: + if not tool_context.tool_confirmation.confirmed: + return {'error': 'This tool call is rejected.'} + # Functions are callable objects, but not all callable objects are functions # checking coroutine function is not enough. We also need to check whether # Callable's __call__ function is a coroutine funciton @@ -137,6 +172,8 @@ async def _call_live( ].stream if 'tool_context' in signature.parameters: args_to_call['tool_context'] = tool_context + + # TODO: support tool confirmation for live mode. async with Aclosing(self.func(**args_to_call)) as agen: async for item in agen: yield item diff --git a/src/google/adk/tools/tool_confirmation.py b/src/google/adk/tools/tool_confirmation.py new file mode 100644 index 0000000000..df14ff5026 --- /dev/null +++ b/src/google/adk/tools/tool_confirmation.py @@ -0,0 +1,45 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any +from typing import Optional + +from pydantic import alias_generators +from pydantic import BaseModel +from pydantic import ConfigDict +from pydantic import Field + +from ..utils.feature_decorator import experimental + + +@experimental +class ToolConfirmation(BaseModel): + """Represents a tool confirmation configuration.""" + + model_config = ConfigDict( + extra="forbid", + alias_generator=alias_generators.to_camel, + populate_by_name=True, + ) + """The pydantic model config.""" + + hint: str = "" + """The hint text for why the input is needed.""" + confirmed: bool = False + """Whether the tool excution is confirmed.""" + payload: Optional[Any] = None + """The custom data payload needed from the user to continue the flow. + It should be JSON serializable.""" diff --git a/src/google/adk/tools/tool_context.py b/src/google/adk/tools/tool_context.py index 3a1c2a8ddf..91d6116631 100644 --- a/src/google/adk/tools/tool_context.py +++ b/src/google/adk/tools/tool_context.py @@ -14,6 +14,7 @@ from __future__ import annotations +from typing import Any from typing import Optional from typing import TYPE_CHECKING @@ -21,6 +22,7 @@ from ..auth.auth_credential import AuthCredential from ..auth.auth_handler import AuthHandler from ..auth.auth_tool import AuthConfig +from .tool_confirmation import ToolConfirmation if TYPE_CHECKING: from ..agents.invocation_context import InvocationContext @@ -43,6 +45,7 @@ class ToolContext(CallbackContext): If LLM didn't return this id, ADK will assign one to it. This id is used to map function call response to the original function call. event_actions: The event actions of the current tool call. + tool_confirmation: The tool confirmation of the current tool call. """ def __init__( @@ -51,9 +54,11 @@ def __init__( *, function_call_id: Optional[str] = None, event_actions: Optional[EventActions] = None, + tool_confirmation: Optional[ToolConfirmation] = None, ): super().__init__(invocation_context, event_actions=event_actions) self.function_call_id = function_call_id + self.tool_confirmation = tool_confirmation @property def actions(self) -> EventActions: @@ -69,6 +74,27 @@ def request_credential(self, auth_config: AuthConfig) -> None: def get_auth_response(self, auth_config: AuthConfig) -> AuthCredential: return AuthHandler(auth_config).get_auth_response(self.state) + def request_confirmation( + self, + *, + hint: Optional[str] = None, + payload: Optional[Any] = None, + ) -> None: + """Requests confirmation for the given function call. + + Args: + hint: A hint to the user on how to confirm the tool call. + payload: The payload used to confirm the tool call. + """ + if not self.function_call_id: + raise ValueError('function_call_id is not set.') + self._event_actions.requested_tool_confirmations[self.function_call_id] = ( + ToolConfirmation( + hint=hint, + payload=payload, + ) + ) + async def search_memory(self, query: str) -> SearchMemoryResponse: """Searches the memory of the current user.""" if self._invocation_context.memory_service is None: diff --git a/tests/unittests/flows/llm_flows/test_contents.py b/tests/unittests/flows/llm_flows/test_contents.py index fae62d3534..8d9464b07b 100644 --- a/tests/unittests/flows/llm_flows/test_contents.py +++ b/tests/unittests/flows/llm_flows/test_contents.py @@ -162,6 +162,60 @@ def test_get_contents_filters_empty_events(): assert contents_result[0].parts[0].text == "Hello" +def test_get_contents_filters_auth_and_confirmation_events(): + """Test _get_contents filters out auth and request confirmation events.""" + auth_event = Event( + invocation_id="test_inv", + author="agent", + content=types.Content( + role="model", + parts=[ + types.Part( + function_call=types.FunctionCall( + id="auth_func", + name=contents.REQUEST_EUC_FUNCTION_CALL_NAME, + args={}, + ) + ) + ], + ), + ) + + confirmation_event = Event( + invocation_id="test_inv", + author="agent", + content=types.Content( + role="model", + parts=[ + types.Part( + function_call=types.FunctionResponse( + id="confirm_func", + name=contents.REQUEST_CONFIRMATION_FUNCTION_CALL_NAME, + response={ + "confirmed": True, + }, + ) + ) + ], + ), + ) + + valid_event = Event( + invocation_id="test_inv", + author="user", + content=types.Content( + role="user", parts=[types.Part.from_text(text="Hello")] + ), + ) + + contents_result = _get_contents( + None, [auth_event, confirmation_event, valid_event], "test_agent" + ) + assert len(contents_result) == 1 + assert contents_result[0].role == "user" + assert contents_result[0].parts[0].text == "Hello" + + def test_convert_foreign_event(): """Test _convert_foreign_event function.""" agent_event = Event( diff --git a/tests/unittests/flows/llm_flows/test_request_confirmation.py b/tests/unittests/flows/llm_flows/test_request_confirmation.py new file mode 100644 index 0000000000..bd36e83c79 --- /dev/null +++ b/tests/unittests/flows/llm_flows/test_request_confirmation.py @@ -0,0 +1,302 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from unittest.mock import patch + +from google.adk.agents.llm_agent import LlmAgent +from google.adk.events.event import Event +from google.adk.flows.llm_flows import functions +from google.adk.flows.llm_flows.request_confirmation import request_processor +from google.adk.models.llm_request import LlmRequest +from google.adk.tools.tool_confirmation import ToolConfirmation +from google.genai import types +import pytest + +from ... import testing_utils + +MOCK_TOOL_NAME = "mock_tool" +MOCK_FUNCTION_CALL_ID = "mock_function_call_id" +MOCK_CONFIRMATION_FUNCTION_CALL_ID = "mock_confirmation_function_call_id" + + +def mock_tool(param1: str): + """Mock tool function.""" + return f"Mock tool result with {param1}" + + +@pytest.mark.asyncio +async def test_request_confirmation_processor_no_events(): + """Test that the processor returns None when there are no events.""" + agent = LlmAgent(name="test_agent", tools=[mock_tool]) + invocation_context = await testing_utils.create_invocation_context( + agent=agent + ) + llm_request = LlmRequest() + + events = [] + async for event in request_processor.run_async( + invocation_context, llm_request + ): + events.append(event) + + assert not events + + +@pytest.mark.asyncio +async def test_request_confirmation_processor_no_function_responses(): + """Test that the processor returns None when the user event has no function responses.""" + agent = LlmAgent(name="test_agent", tools=[mock_tool]) + invocation_context = await testing_utils.create_invocation_context( + agent=agent + ) + llm_request = LlmRequest() + + invocation_context.session.events.append( + Event(author="user", content=types.Content()) + ) + + events = [] + async for event in request_processor.run_async( + invocation_context, llm_request + ): + events.append(event) + + assert not events + + +@pytest.mark.asyncio +async def test_request_confirmation_processor_no_confirmation_function_response(): + """Test that the processor returns None when no confirmation function response is present.""" + agent = LlmAgent(name="test_agent", tools=[mock_tool]) + invocation_context = await testing_utils.create_invocation_context( + agent=agent + ) + llm_request = LlmRequest() + + invocation_context.session.events.append( + Event( + author="user", + content=types.Content( + parts=[ + types.Part( + function_response=types.FunctionResponse( + name="other_function", response={} + ) + ) + ] + ), + ) + ) + + events = [] + async for event in request_processor.run_async( + invocation_context, llm_request + ): + events.append(event) + + assert not events + + +@pytest.mark.asyncio +async def test_request_confirmation_processor_success(): + """Test the successful processing of a tool confirmation.""" + agent = LlmAgent(name="test_agent", tools=[mock_tool]) + invocation_context = await testing_utils.create_invocation_context( + agent=agent + ) + llm_request = LlmRequest() + + original_function_call = types.FunctionCall( + name=MOCK_TOOL_NAME, args={"param1": "test"}, id=MOCK_FUNCTION_CALL_ID + ) + + tool_confirmation = ToolConfirmation(confirmed=False, hint="test hint") + tool_confirmation_args = { + "originalFunctionCall": original_function_call.model_dump( + exclude_none=True, by_alias=True + ), + "toolConfirmation": tool_confirmation.model_dump( + by_alias=True, exclude_none=True + ), + } + + # Event with the request for confirmation + invocation_context.session.events.append( + Event( + author="agent", + content=types.Content( + parts=[ + types.Part( + function_call=types.FunctionCall( + name=functions.REQUEST_CONFIRMATION_FUNCTION_CALL_NAME, + args=tool_confirmation_args, + id=MOCK_CONFIRMATION_FUNCTION_CALL_ID, + ) + ) + ] + ), + ) + ) + + # Event with the user's confirmation + user_confirmation = ToolConfirmation(confirmed=True) + invocation_context.session.events.append( + Event( + author="user", + content=types.Content( + parts=[ + types.Part( + function_response=types.FunctionResponse( + name=functions.REQUEST_CONFIRMATION_FUNCTION_CALL_NAME, + id=MOCK_CONFIRMATION_FUNCTION_CALL_ID, + response={ + "response": user_confirmation.model_dump_json() + }, + ) + ) + ] + ), + ) + ) + + expected_event = Event( + author="agent", + content=types.Content( + parts=[ + types.Part( + function_response=types.FunctionResponse( + name=MOCK_TOOL_NAME, + id=MOCK_FUNCTION_CALL_ID, + response={"result": "Mock tool result with test"}, + ) + ) + ] + ), + ) + + with patch( + "google.adk.flows.llm_flows.functions.handle_function_call_list_async" + ) as mock_handle_function_call_list_async: + mock_handle_function_call_list_async.return_value = expected_event + + events = [] + async for event in request_processor.run_async( + invocation_context, llm_request + ): + events.append(event) + + assert len(events) == 1 + assert events[0] == expected_event + + mock_handle_function_call_list_async.assert_called_once() + args, _ = mock_handle_function_call_list_async.call_args + + assert list(args[1]) == [original_function_call] # function_calls + assert args[3] == {MOCK_FUNCTION_CALL_ID} # tools_to_confirm + assert ( + args[4][MOCK_FUNCTION_CALL_ID] == user_confirmation + ) # tool_confirmation_dict + + +@pytest.mark.asyncio +async def test_request_confirmation_processor_tool_not_confirmed(): + """Test when the tool execution is not confirmed by the user.""" + agent = LlmAgent(name="test_agent", tools=[mock_tool]) + invocation_context = await testing_utils.create_invocation_context( + agent=agent + ) + llm_request = LlmRequest() + + original_function_call = types.FunctionCall( + name=MOCK_TOOL_NAME, args={"param1": "test"}, id=MOCK_FUNCTION_CALL_ID + ) + + tool_confirmation = ToolConfirmation(confirmed=False, hint="test hint") + tool_confirmation_args = { + "originalFunctionCall": original_function_call.model_dump( + exclude_none=True, by_alias=True + ), + "toolConfirmation": tool_confirmation.model_dump( + by_alias=True, exclude_none=True + ), + } + + invocation_context.session.events.append( + Event( + author="agent", + content=types.Content( + parts=[ + types.Part( + function_call=types.FunctionCall( + name=functions.REQUEST_CONFIRMATION_FUNCTION_CALL_NAME, + args=tool_confirmation_args, + id=MOCK_CONFIRMATION_FUNCTION_CALL_ID, + ) + ) + ] + ), + ) + ) + + user_confirmation = ToolConfirmation(confirmed=False) + invocation_context.session.events.append( + Event( + author="user", + content=types.Content( + parts=[ + types.Part( + function_response=types.FunctionResponse( + name=functions.REQUEST_CONFIRMATION_FUNCTION_CALL_NAME, + id=MOCK_CONFIRMATION_FUNCTION_CALL_ID, + response={ + "response": user_confirmation.model_dump_json() + }, + ) + ) + ] + ), + ) + ) + + with patch( + "google.adk.flows.llm_flows.functions.handle_function_call_list_async" + ) as mock_handle_function_call_list_async: + mock_handle_function_call_list_async.return_value = Event( + author="agent", + content=types.Content( + parts=[ + types.Part( + function_response=types.FunctionResponse( + name=MOCK_TOOL_NAME, + id=MOCK_FUNCTION_CALL_ID, + response={"error": "Tool execution not confirmed"}, + ) + ) + ] + ), + ) + + events = [] + async for event in request_processor.run_async( + invocation_context, llm_request + ): + events.append(event) + + assert len(events) == 1 + mock_handle_function_call_list_async.assert_called_once() + args, _ = mock_handle_function_call_list_async.call_args + assert ( + args[4][MOCK_FUNCTION_CALL_ID] == user_confirmation + ) # tool_confirmation_dict diff --git a/tests/unittests/tools/test_function_tool.py b/tests/unittests/tools/test_function_tool.py index 871f58dcb8..e7854a2c87 100644 --- a/tests/unittests/tools/test_function_tool.py +++ b/tests/unittests/tools/test_function_tool.py @@ -17,6 +17,7 @@ from google.adk.agents.invocation_context import InvocationContext from google.adk.sessions.session import Session from google.adk.tools.function_tool import FunctionTool +from google.adk.tools.tool_confirmation import ToolConfirmation from google.adk.tools.tool_context import ToolContext import pytest @@ -345,3 +346,51 @@ def sample_func_with_context(expected_arg: str, tool_context: ToolContext): "received_arg": "world", "context_present": True, } + + +@pytest.mark.asyncio +async def test_run_async_with_require_confirmation(): + """Test that run_async handles require_confirmation flag.""" + + def sample_func(arg1: str): + return {"received_arg": arg1} + + tool = FunctionTool(sample_func, require_confirmation=True) + mock_invocation_context = MagicMock(spec=InvocationContext) + mock_invocation_context.session = MagicMock(spec=Session) + mock_invocation_context.session.state = MagicMock() + mock_invocation_context.agent = MagicMock() + mock_invocation_context.agent.name = "test_agent" + tool_context_mock = ToolContext(invocation_context=mock_invocation_context) + tool_context_mock.function_call_id = "test_function_call_id" + + # First call, should request confirmation + result = await tool.run_async( + args={"arg1": "hello"}, + tool_context=tool_context_mock, + ) + assert result == { + "error": "This tool call requires confirmation, please approve or reject." + } + assert tool_context_mock._event_actions.requested_tool_confirmations[ + "test_function_call_id" + ].hint == ( + "Please approve or reject the tool call sample_func() by responding with" + " a FunctionResponse with an expected ToolConfirmation payload." + ) + + # Second call, user rejects + tool_context_mock.tool_confirmation = ToolConfirmation(confirmed=False) + result = await tool.run_async( + args={"arg1": "hello"}, + tool_context=tool_context_mock, + ) + assert result == {"error": "This tool call is rejected."} + + # Third call, user approves + tool_context_mock.tool_confirmation = ToolConfirmation(confirmed=True) + result = await tool.run_async( + args={"arg1": "hello"}, + tool_context=tool_context_mock, + ) + assert result == {"received_arg": "hello"}