diff --git a/src/google/adk/flows/llm_flows/request_confirmation.py b/src/google/adk/flows/llm_flows/request_confirmation.py index 3cb92bf22b..a252d89db1 100644 --- a/src/google/adk/flows/llm_flows/request_confirmation.py +++ b/src/google/adk/flows/llm_flows/request_confirmation.py @@ -53,11 +53,27 @@ async def run_async( if not events: return - request_confirmation_function_responses = ( - dict() - ) # {function call id, tool confirmation} + request_confirmation_function_responses = dict() confirmation_event_index = -1 + + # Helper to unwrap redundant response envelopes and decode the innermost JSON. + def _parse_tool_confirmation_payload(payload: 'Any') -> 'Any': + while ( + isinstance(payload, dict) + and len(payload) == 1 + and 'response' in payload + ): + payload = payload['response'] + if isinstance(payload, str): + try: + payload = json.loads(payload) + except json.JSONDecodeError as exc: + raise ValueError( + 'Failed to decode tool confirmation payload.' + ) from exc + return payload + for k in range(len(events) - 1, -1, -1): event = events[k] # Find the first event authored by user @@ -71,22 +87,13 @@ async def run_async( if function_response.name != REQUEST_CONFIRMATION_FUNCTION_CALL_NAME: continue - # Find the FunctionResponse event that contains the user provided tool - # confirmation - if ( + confirmation_payload = _parse_tool_confirmation_payload( function_response.response - and len(function_response.response.values()) == 1 - and 'response' in function_response.response.keys() - ): - # ADK web client will send a request that is always encapsulated in a - # 'response' key. - tool_confirmation = ToolConfirmation.model_validate( - json.loads(function_response.response['response']) - ) - else: - tool_confirmation = ToolConfirmation.model_validate( - function_response.response - ) + ) + + tool_confirmation = ToolConfirmation.model_validate( + confirmation_payload + ) request_confirmation_function_responses[function_response.id] = ( tool_confirmation ) @@ -104,10 +111,8 @@ async def run_async( if not function_calls: continue - tools_to_resume_with_confirmation = ( - dict() - ) # {Function call id, tool confirmation} - tools_to_resume_with_args = dict() # {Function call id, function calls} + tools_to_resume_with_confirmation = dict() + tools_to_resume_with_args = dict() for function_call in function_calls: if ( diff --git a/tests/unittests/flows/llm_flows/test_request_confirmation.py b/tests/unittests/flows/llm_flows/test_request_confirmation.py index bd36e83c79..3b8a78cad0 100644 --- a/tests/unittests/flows/llm_flows/test_request_confirmation.py +++ b/tests/unittests/flows/llm_flows/test_request_confirmation.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json from unittest.mock import patch from google.adk.agents.llm_agent import LlmAgent @@ -210,6 +209,103 @@ async def test_request_confirmation_processor_success(): ) # tool_confirmation_dict +@pytest.mark.asyncio +async def test_request_confirmation_processor_doubly_wrapped_response(): + """Test confirmation parsing when responses are nested under multiple keys.""" + agent = LlmAgent(name="test_agent", tools=[mock_tool]) + invocation_context = await testing_utils.create_invocation_context( + agent=agent + ) + llm_request = LlmRequest() + + original_function_call = types.FunctionCall( + name=MOCK_TOOL_NAME, args={"param1": "test"}, id=MOCK_FUNCTION_CALL_ID + ) + + tool_confirmation = ToolConfirmation(confirmed=False, hint="test hint") + tool_confirmation_args = { + "originalFunctionCall": original_function_call.model_dump( + exclude_none=True, by_alias=True + ), + "toolConfirmation": tool_confirmation.model_dump( + by_alias=True, exclude_none=True + ), + } + + invocation_context.session.events.append( + Event( + author="agent", + content=types.Content( + parts=[ + types.Part( + function_call=types.FunctionCall( + name=functions.REQUEST_CONFIRMATION_FUNCTION_CALL_NAME, + args=tool_confirmation_args, + id=MOCK_CONFIRMATION_FUNCTION_CALL_ID, + ) + ) + ] + ), + ) + ) + + user_confirmation = ToolConfirmation(confirmed=True) + invocation_context.session.events.append( + Event( + author="user", + content=types.Content( + parts=[ + types.Part( + function_response=types.FunctionResponse( + name=functions.REQUEST_CONFIRMATION_FUNCTION_CALL_NAME, + id=MOCK_CONFIRMATION_FUNCTION_CALL_ID, + response={ + "response": { + "response": user_confirmation.model_dump_json() + } + }, + ) + ) + ] + ), + ) + ) + + expected_event = Event( + author="agent", + content=types.Content( + parts=[ + types.Part( + function_response=types.FunctionResponse( + name=MOCK_TOOL_NAME, + id=MOCK_FUNCTION_CALL_ID, + response={"result": "Mock tool result with test"}, + ) + ) + ] + ), + ) + + with patch( + "google.adk.flows.llm_flows.functions.handle_function_call_list_async" + ) as mock_handle_function_call_list_async: + mock_handle_function_call_list_async.return_value = expected_event + + events = [] + async for event in request_processor.run_async( + invocation_context, llm_request + ): + events.append(event) + + assert len(events) == 1 + assert events[0] == expected_event + + args, _ = mock_handle_function_call_list_async.call_args + assert ( + args[4][MOCK_FUNCTION_CALL_ID] == user_confirmation + ) # tool_confirmation_dict + + @pytest.mark.asyncio async def test_request_confirmation_processor_tool_not_confirmed(): """Test when the tool execution is not confirmed by the user."""