From 2bfc95d7c602e00935066b0d8d0dbde0fa7597be Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 4 Nov 2025 06:55:56 -0800 Subject: [PATCH] feat: HITL/Wire up tool confirmation support This is a port of the python implementation and part of the "human in the loop" workflow. PiperOrigin-RevId: 827953462 --- .../google/adk/flows/llmflows/Contents.java | 20 ++ .../google/adk/flows/llmflows/Functions.java | 110 +++++++++-- ...equestConfirmationLlmRequestProcessor.java | 185 ++++++++++++++++++ .../com/google/adk/tools/FunctionTool.java | 48 ++++- .../com/google/adk/tools/LoadMemoryTool.java | 6 +- .../adk/tools/LongRunningFunctionTool.java | 4 +- ...stConfirmationLlmRequestProcessorTest.java | 154 +++++++++++++++ .../google/adk/tools/FunctionToolTest.java | 160 +++++++++++++++ 8 files changed, 662 insertions(+), 25 deletions(-) create mode 100644 core/src/main/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessor.java create mode 100644 core/src/test/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessorTest.java diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Contents.java b/core/src/main/java/com/google/adk/flows/llmflows/Contents.java index e4452f4de..0289389b9 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Contents.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Contents.java @@ -109,6 +109,9 @@ private ImmutableList getContents( if (!isEventBelongsToBranch(currentBranch, event)) { continue; } + if (isRequestConfirmationEvent(event)) { + continue; + } // TODO: Skip auth events. @@ -511,4 +514,21 @@ private static boolean hasContentWithNonEmptyParts(Event event) { .map(list -> !list.isEmpty()) // Optional .orElse(false); } + + /** Checks if the event is a request confirmation event. */ + private static boolean isRequestConfirmationEvent(Event event) { + return event.content().flatMap(Content::parts).stream() + .flatMap(List::stream) + // return event.content().flatMap(Content::parts).orElse(ImmutableList.of()).stream() + .anyMatch( + part -> + part.functionCall() + .flatMap(FunctionCall::name) + .map(Functions.REQUEST_CONFIRMATION_FUNCTION_CALL_NAME::equals) + .orElse(false) + || part.functionResponse() + .flatMap(FunctionResponse::name) + .map(Functions.REQUEST_CONFIRMATION_FUNCTION_CALL_NAME::equals) + .orElse(false)); + } } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index 12d06b4cd..3c713ba1f 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -17,6 +17,8 @@ package com.google.adk.flows.llmflows; +import static com.google.common.collect.ImmutableMap.toImmutableMap; + import com.google.adk.Telemetry; import com.google.adk.agents.ActiveStreamingTool; import com.google.adk.agents.Callbacks.AfterToolCallback; @@ -27,6 +29,7 @@ import com.google.adk.events.EventActions; import com.google.adk.tools.BaseTool; import com.google.adk.tools.FunctionTool; +import com.google.adk.tools.ToolConfirmation; import com.google.adk.tools.ToolContext; import com.google.common.base.VerifyException; import com.google.common.collect.ImmutableList; @@ -52,7 +55,6 @@ import java.util.Optional; import java.util.Set; import java.util.UUID; -import org.jspecify.annotations.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -60,6 +62,7 @@ public final class Functions { private static final String AF_FUNCTION_CALL_ID_PREFIX = "adk-"; + static final String REQUEST_CONFIRMATION_FUNCTION_CALL_NAME = "adk_request_confirmation"; private static final Logger logger = LoggerFactory.getLogger(Functions.class); /** Generates a unique ID for a function call. */ @@ -122,6 +125,15 @@ public static void populateClientFunctionCallId(Event modelResponseEvent) { /** Handles standard, non-streaming function calls. */ public static Maybe handleFunctionCalls( InvocationContext invocationContext, Event functionCallEvent, Map tools) { + return handleFunctionCalls(invocationContext, functionCallEvent, tools, ImmutableMap.of()); + } + + /** Handles standard, non-streaming function calls with tool confirmations. */ + public static Maybe handleFunctionCalls( + InvocationContext invocationContext, + Event functionCallEvent, + Map tools, + Map toolConfirmations) { ImmutableList functionCalls = functionCallEvent.functionCalls(); List> functionResponseEvents = new ArrayList<>(); @@ -134,9 +146,10 @@ public static Maybe handleFunctionCalls( ToolContext toolContext = ToolContext.builder(invocationContext) .functionCallId(functionCall.id().orElse("")) + .toolConfirmation(toolConfirmations.get(functionCall.id().orElse(null))) .build(); - Map functionArgs = functionCall.args().orElse(new HashMap<>()); + Map functionArgs = functionCall.args().orElse(ImmutableMap.of()); Maybe> maybeFunctionResult = maybeInvokeBeforeToolCall(invocationContext, tool, functionArgs, toolContext) @@ -192,10 +205,12 @@ public static Maybe handleFunctionCalls( if (events.isEmpty()) { return Maybe.empty(); } - Event mergedEvent = Functions.mergeParallelFunctionResponseEvents(events); - if (mergedEvent == null) { + Optional maybeMergedEvent = + Functions.mergeParallelFunctionResponseEvents(events); + if (maybeMergedEvent.isEmpty()) { return Maybe.empty(); } + var mergedEvent = maybeMergedEvent.get(); if (events.size() > 1) { Tracer tracer = Telemetry.getTracer(); @@ -288,7 +303,7 @@ public static Maybe handleFunctionCallsLive( if (events.isEmpty()) { return Maybe.empty(); } - return Maybe.just(Functions.mergeParallelFunctionResponseEvents(events)); + return Maybe.just(Functions.mergeParallelFunctionResponseEvents(events).orElse(null)); }); } @@ -387,13 +402,13 @@ public static Set getLongRunningFunctionCalls( return longRunningFunctionCalls; } - private static @Nullable Event mergeParallelFunctionResponseEvents( + private static Optional mergeParallelFunctionResponseEvents( List functionResponseEvents) { if (functionResponseEvents.isEmpty()) { - return null; + return Optional.empty(); } if (functionResponseEvents.size() == 1) { - return functionResponseEvents.get(0); + return Optional.of(functionResponseEvents.get(0)); } // Use the first event as the base for common attributes Event baseEvent = functionResponseEvents.get(0); @@ -410,15 +425,16 @@ public static Set getLongRunningFunctionCalls( mergedActionsBuilder.merge(event.actions()); } - return Event.builder() - .id(Event.generateEventId()) - .invocationId(baseEvent.invocationId()) - .author(baseEvent.author()) - .branch(baseEvent.branch()) - .content(Optional.of(Content.builder().role("user").parts(mergedParts).build())) - .actions(mergedActionsBuilder.build()) - .timestamp(baseEvent.timestamp()) - .build(); + return Optional.of( + Event.builder() + .id(Event.generateEventId()) + .invocationId(baseEvent.invocationId()) + .author(baseEvent.author()) + .branch(baseEvent.branch()) + .content(Optional.of(Content.builder().role("user").parts(mergedParts).build())) + .actions(mergedActionsBuilder.build()) + .timestamp(baseEvent.timestamp()) + .build()); } private static Maybe> maybeInvokeBeforeToolCall( @@ -563,5 +579,65 @@ private static Event buildResponseEvent( } } + /** + * Generates a request confirmation event from a function response event. + * + * @param invocationContext The invocation context. + * @param functionCallEvent The event containing the original function call. + * @param functionResponseEvent The event containing the function response. + * @return An optional event containing the request confirmation function call. + */ + public static Optional generateRequestConfirmationEvent( + InvocationContext invocationContext, Event functionCallEvent, Event functionResponseEvent) { + if (functionResponseEvent.actions().requestedToolConfirmations().isEmpty()) { + return Optional.empty(); + } + + List parts = new ArrayList<>(); + Set longRunningToolIds = new HashSet<>(); + ImmutableMap functionCallsById = + functionCallEvent.functionCalls().stream() + .filter(fc -> fc.id().isPresent()) + .collect(toImmutableMap(fc -> fc.id().get(), fc -> fc)); + + for (Map.Entry entry : + functionResponseEvent.actions().requestedToolConfirmations().entrySet().stream() + .filter(fc -> functionCallsById.containsKey(fc.getKey())) + .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)) + .entrySet()) { + + FunctionCall requestConfirmationFunctionCall = + FunctionCall.builder() + .name(REQUEST_CONFIRMATION_FUNCTION_CALL_NAME) + .args( + ImmutableMap.of( + "originalFunctionCall", + functionCallsById.get(entry.getKey()), + "toolConfirmation", + entry.getValue())) + .id(generateClientFunctionCallId()) + .build(); + + longRunningToolIds.add(requestConfirmationFunctionCall.id().get()); + parts.add(Part.builder().functionCall(requestConfirmationFunctionCall).build()); + } + + if (parts.isEmpty()) { + return Optional.empty(); + } + + var contentBuilder = Content.builder().parts(parts); + functionResponseEvent.content().flatMap(Content::role).ifPresent(contentBuilder::role); + + return Optional.of( + Event.builder() + .invocationId(invocationContext.invocationId()) + .author(invocationContext.agent().name()) + .branch(invocationContext.branch()) + .content(contentBuilder.build()) + .longRunningToolIds(longRunningToolIds) + .build()); + } + private Functions() {} } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessor.java b/core/src/main/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessor.java new file mode 100644 index 000000000..4eb45eb77 --- /dev/null +++ b/core/src/main/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessor.java @@ -0,0 +1,185 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.flows.llmflows; + +import static com.google.adk.flows.llmflows.Functions.REQUEST_CONFIRMATION_FUNCTION_CALL_NAME; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.datatype.jdk8.Jdk8Module; +import com.google.adk.agents.InvocationContext; +import com.google.adk.agents.LlmAgent; +import com.google.adk.events.Event; +import com.google.adk.models.LlmRequest; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.ToolConfirmation; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.genai.types.Content; +import com.google.genai.types.FunctionCall; +import com.google.genai.types.FunctionResponse; +import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.Single; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Handles tool confirmation information to build the LLM request. */ +public class RequestConfirmationLlmRequestProcessor implements RequestProcessor { + private static final Logger logger = + LoggerFactory.getLogger(RequestConfirmationLlmRequestProcessor.class); + private final ObjectMapper objectMapper; + + public RequestConfirmationLlmRequestProcessor() { + objectMapper = new ObjectMapper().registerModule(new Jdk8Module()); + } + + @Override + public Single processRequest( + InvocationContext invocationContext, LlmRequest llmRequest) { + List events = invocationContext.session().events(); + if (events.isEmpty()) { + logger.info( + "No events are present in the session. Skipping request confirmation processing."); + return Single.just(RequestProcessingResult.create(llmRequest, ImmutableList.of())); + } + + ImmutableMap requestConfirmationFunctionResponses = + filterRequestConfirmationFunctionResponses(events); + if (requestConfirmationFunctionResponses.isEmpty()) { + logger.info("No request confirmation function responses found."); + return Single.just(RequestProcessingResult.create(llmRequest, ImmutableList.of())); + } + + for (ImmutableList functionCalls : + events.stream() + .map(Event::functionCalls) + .filter(fc -> !fc.isEmpty()) + .collect(toImmutableList())) { + + ImmutableMap toolsToResumeWithArgs = + filterToolsToResumeWithArgs(functionCalls, requestConfirmationFunctionResponses); + ImmutableMap toolsToResumeWithConfirmation = + toolsToResumeWithArgs.keySet().stream() + .filter( + id -> + events.stream() + .flatMap(e -> e.functionResponses().stream()) + .anyMatch(fr -> Objects.equals(fr.id().orElse(null), id))) + .collect(toImmutableMap(k -> k, requestConfirmationFunctionResponses::get)); + if (toolsToResumeWithConfirmation.isEmpty()) { + logger.info("No tools to resume with confirmation."); + continue; + } + + return assembleEvent( + invocationContext, toolsToResumeWithArgs.values(), toolsToResumeWithConfirmation) + .map(event -> RequestProcessingResult.create(llmRequest, ImmutableList.of(event))) + .toSingle(); + } + + return Single.just(RequestProcessingResult.create(llmRequest, ImmutableList.of())); + } + + private Maybe assembleEvent( + InvocationContext invocationContext, + Collection functionCalls, + Map toolConfirmations) { + ImmutableMap.Builder toolsBuilder = ImmutableMap.builder(); + if (invocationContext.agent() instanceof LlmAgent llmAgent) { + for (BaseTool tool : llmAgent.tools()) { + toolsBuilder.put(tool.name(), tool); + } + } + + var functionCallEvent = + Event.builder() + .content( + Content.builder() + .parts( + functionCalls.stream() + .map(fc -> Part.builder().functionCall(fc).build()) + .collect(toImmutableList())) + .build()) + .build(); + + return Functions.handleFunctionCalls( + invocationContext, functionCallEvent, toolsBuilder.buildOrThrow(), toolConfirmations); + } + + private ImmutableMap filterRequestConfirmationFunctionResponses( + List events) { + return events.stream() + .filter(event -> Objects.equals(event.author(), "user")) + .flatMap(event -> event.functionResponses().stream()) + .filter(functionResponse -> functionResponse.id().isPresent()) + .filter( + functionResponse -> + Objects.equals( + functionResponse.name().orElse(null), REQUEST_CONFIRMATION_FUNCTION_CALL_NAME)) + .map(this::maybeCreateToolConfirmationEntry) + .flatMap(Optional::stream) + .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)); + } + + private Optional> maybeCreateToolConfirmationEntry( + FunctionResponse functionResponse) { + Map responseMap = functionResponse.response().orElse(ImmutableMap.of()); + if (responseMap.size() != 1 || !responseMap.containsKey("response")) { + return Optional.of( + Map.entry( + functionResponse.id().get(), + objectMapper.convertValue(responseMap, ToolConfirmation.class))); + } + + try { + return Optional.of( + Map.entry( + functionResponse.id().get(), + objectMapper.readValue( + (String) responseMap.get("response"), ToolConfirmation.class))); + } catch (JsonProcessingException e) { + logger.error("Failed to parse tool confirmation response", e); + } + + return Optional.empty(); + } + + private ImmutableMap filterToolsToResumeWithArgs( + ImmutableList functionCalls, + Map requestConfirmationFunctionResponses) { + return functionCalls.stream() + .filter(fc -> fc.id().isPresent()) + .filter(fc -> requestConfirmationFunctionResponses.containsKey(fc.id().get())) + .filter( + fc -> Objects.equals(fc.name().orElse(null), REQUEST_CONFIRMATION_FUNCTION_CALL_NAME)) + .filter(fc -> fc.args().orElse(ImmutableMap.of()).containsKey("originalFunctionCall")) + .collect( + toImmutableMap( + fc -> fc.id().get(), + fc -> + objectMapper.convertValue( + fc.args().get().get("originalFunctionCall"), FunctionCall.class))); + } +} diff --git a/core/src/main/java/com/google/adk/tools/FunctionTool.java b/core/src/main/java/com/google/adk/tools/FunctionTool.java index aea543617..498ee8370 100644 --- a/core/src/main/java/com/google/adk/tools/FunctionTool.java +++ b/core/src/main/java/com/google/adk/tools/FunctionTool.java @@ -49,8 +49,13 @@ public class FunctionTool extends BaseTool { private final @Nullable Object instance; private final Method func; private final FunctionDeclaration funcDeclaration; + private final boolean requireConfirmation; public static FunctionTool create(Object instance, Method func) { + return create(instance, func, /* requireConfirmation= */ false); + } + + public static FunctionTool create(Object instance, Method func, boolean requireConfirmation) { if (!areParametersAnnotatedWithSchema(func) && wasCompiledWithDefaultParameterNames(func)) { logger.error( """ @@ -66,10 +71,15 @@ public static FunctionTool create(Object instance, Method func) { + " Expected: %s, Actual: %s", func.getDeclaringClass().getName(), instance.getClass().getName())); } - return new FunctionTool(instance, func, /* isLongRunning= */ false); + return new FunctionTool( + instance, func, /* isLongRunning= */ false, /* requireConfirmation= */ requireConfirmation); } public static FunctionTool create(Method func) { + return create(func, /* requireConfirmation= */ false); + } + + public static FunctionTool create(Method func, boolean requireConfirmation) { if (!areParametersAnnotatedWithSchema(func) && wasCompiledWithDefaultParameterNames(func)) { logger.error( """ @@ -81,13 +91,17 @@ public static FunctionTool create(Method func) { if (!Modifier.isStatic(func.getModifiers())) { throw new IllegalArgumentException("The method provided must be static."); } - return new FunctionTool(null, func, /* isLongRunning= */ false); + return new FunctionTool(null, func, /* isLongRunning= */ false, requireConfirmation); } public static FunctionTool create(Class cls, String methodName) { + return create(cls, methodName, /* requireConfirmation= */ false); + } + + public static FunctionTool create(Class cls, String methodName, boolean requireConfirmation) { for (Method method : cls.getMethods()) { if (method.getName().equals(methodName) && Modifier.isStatic(method.getModifiers())) { - return create(null, method); + return create(null, method, requireConfirmation); } } throw new IllegalArgumentException( @@ -95,10 +109,15 @@ public static FunctionTool create(Class cls, String methodName) { } public static FunctionTool create(Object instance, String methodName) { + return create(instance, methodName, /* requireConfirmation= */ false); + } + + public static FunctionTool create( + Object instance, String methodName, boolean requireConfirmation) { Class cls = instance.getClass(); for (Method method : cls.getMethods()) { if (method.getName().equals(methodName) && !Modifier.isStatic(method.getModifiers())) { - return create(instance, method); + return create(instance, method, requireConfirmation); } } throw new IllegalArgumentException( @@ -127,6 +146,11 @@ private static boolean wasCompiledWithDefaultParameterNames(Method func) { } protected FunctionTool(@Nullable Object instance, Method func, boolean isLongRunning) { + this(instance, func, isLongRunning, /* requireConfirmation= */ false); + } + + protected FunctionTool( + @Nullable Object instance, Method func, boolean isLongRunning, boolean requireConfirmation) { super( func.isAnnotationPresent(Annotations.Schema.class) && !func.getAnnotation(Annotations.Schema.class).name().isEmpty() @@ -148,6 +172,7 @@ protected FunctionTool(@Nullable Object instance, Method func, boolean isLongRun this.funcDeclaration = FunctionCallingUtils.buildFunctionDeclaration( this.func, ImmutableList.of("toolContext", "inputStream")); + this.requireConfirmation = requireConfirmation; } @Override @@ -174,6 +199,20 @@ public boolean isStreaming() { @Override public Single> runAsync(Map args, ToolContext toolContext) { try { + if (requireConfirmation) { + if (toolContext.toolConfirmation().isEmpty()) { + toolContext.requestConfirmation( + String.format( + "Please approve or reject the tool call %s() by responding with a" + + " FunctionResponse with an expected ToolConfirmation payload.", + name())); + return Single.just( + ImmutableMap.of( + "error", "This tool call requires confirmation, please approve or reject.")); + } else if (!toolContext.toolConfirmation().get().confirmed()) { + return Single.just(ImmutableMap.of("error", "This tool call is rejected.")); + } + } return this.call(args, toolContext).defaultIfEmpty(ImmutableMap.of()); } catch (Exception e) { logger.error("Exception occurred while calling function tool: " + func.getName(), e); @@ -182,7 +221,6 @@ public Single> runAsync(Map args, ToolContex } } - @SuppressWarnings("unchecked") // For tool parameter type casting. private Maybe> call(Map args, ToolContext toolContext) throws IllegalAccessException, InvocationTargetException { Object[] arguments = buildArguments(args, toolContext, null); diff --git a/core/src/main/java/com/google/adk/tools/LoadMemoryTool.java b/core/src/main/java/com/google/adk/tools/LoadMemoryTool.java index 638cd67cb..d597bea38 100644 --- a/core/src/main/java/com/google/adk/tools/LoadMemoryTool.java +++ b/core/src/main/java/com/google/adk/tools/LoadMemoryTool.java @@ -22,7 +22,11 @@ private static Method getLoadMemoryMethod() { } public LoadMemoryTool() { - super(null, getLoadMemoryMethod(), false); + super( + /* instance= */ null, + getLoadMemoryMethod(), + /* isLongRunning= */ false, + /* requireConfirmation= */ false); } /** diff --git a/core/src/main/java/com/google/adk/tools/LongRunningFunctionTool.java b/core/src/main/java/com/google/adk/tools/LongRunningFunctionTool.java index b645fe589..328be1968 100644 --- a/core/src/main/java/com/google/adk/tools/LongRunningFunctionTool.java +++ b/core/src/main/java/com/google/adk/tools/LongRunningFunctionTool.java @@ -47,10 +47,10 @@ public static LongRunningFunctionTool create(Object instance, String methodName) } private LongRunningFunctionTool(Method func) { - super(null, func, /* isLongRunning= */ true); + super(null, func, /* isLongRunning= */ true, /* requireConfirmation= */ false); } private LongRunningFunctionTool(Object instance, Method func) { - super(instance, func, /* isLongRunning= */ true); + super(instance, func, /* isLongRunning= */ true, /* requireConfirmation= */ false); } } diff --git a/core/src/test/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessorTest.java b/core/src/test/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessorTest.java new file mode 100644 index 000000000..5fec336f2 --- /dev/null +++ b/core/src/test/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessorTest.java @@ -0,0 +1,154 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.flows.llmflows; + +import static com.google.adk.flows.llmflows.Functions.REQUEST_CONFIRMATION_FUNCTION_CALL_NAME; +import static com.google.adk.testing.TestUtils.createLlmResponse; +import static com.google.adk.testing.TestUtils.createTestAgentBuilder; +import static com.google.adk.testing.TestUtils.createTestLlm; +import static com.google.common.truth.Truth.assertThat; + +import com.google.adk.agents.InvocationContext; +import com.google.adk.agents.LlmAgent; +import com.google.adk.agents.RunConfig; +import com.google.adk.events.Event; +import com.google.adk.models.LlmRequest; +import com.google.adk.plugins.PluginManager; +import com.google.adk.sessions.Session; +import com.google.adk.testing.TestLlm; +import com.google.adk.testing.TestUtils.EchoTool; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.genai.types.Content; +import com.google.genai.types.FunctionCall; +import com.google.genai.types.FunctionResponse; +import com.google.genai.types.Part; +import java.util.Optional; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class RequestConfirmationLlmRequestProcessorTest { + private static final String REQUEST_CONFIRMATION_FUNCTION_CALL_ID = "fc1"; + private static final String ECHO_TOOL_NAME = "echo_tool"; + + private static final FunctionCall ORIGINAL_FUNCTION_CALL = + FunctionCall.builder() + .id("fc0") + .name(ECHO_TOOL_NAME) + .args(ImmutableMap.of("say", "hello")) + .build(); + + private static final Event REQUEST_CONFIRMATION_EVENT = + Event.builder() + .author("model") + .content( + Content.fromParts( + Part.builder() + .functionCall( + FunctionCall.builder() + .id(REQUEST_CONFIRMATION_FUNCTION_CALL_ID) + .name(REQUEST_CONFIRMATION_FUNCTION_CALL_NAME) + .args(ImmutableMap.of("originalFunctionCall", ORIGINAL_FUNCTION_CALL)) + .build()) + .build())) + .build(); + + private static final Event USER_CONFIRMATION_EVENT = + Event.builder() + .author("user") + .content( + Content.fromParts( + Part.builder() + .functionResponse( + FunctionResponse.builder() + .id(REQUEST_CONFIRMATION_FUNCTION_CALL_ID) + .name(REQUEST_CONFIRMATION_FUNCTION_CALL_NAME) + .response(ImmutableMap.of("confirmed", true)) + .build()) + .build())) + .build(); + + private static final RequestConfirmationLlmRequestProcessor processor = + new RequestConfirmationLlmRequestProcessor(); + + @Test + public void runAsync_withConfirmation_callsOriginalFunction() { + LlmAgent agent = createAgentWithEchoTool(); + Session session = + Session.builder("session_id") + .events(ImmutableList.of(REQUEST_CONFIRMATION_EVENT, USER_CONFIRMATION_EVENT)) + .build(); + + InvocationContext context = createInvocationContext(agent, session); + + RequestProcessor.RequestProcessingResult result = + processor.processRequest(context, LlmRequest.builder().build()).blockingGet(); + + assertThat(result).isNotNull(); + assertThat(result.events()).hasSize(1); + Event event = result.events().iterator().next(); + assertThat(event.functionResponses()).hasSize(1); + FunctionResponse fr = event.functionResponses().get(0); + assertThat(fr.id()).hasValue("fc0"); + assertThat(fr.name()).hasValue(ECHO_TOOL_NAME); + assertThat(fr.response()).hasValue(ImmutableMap.of("result", ImmutableMap.of("say", "hello"))); + } + + @Test + public void runAsync_noEvents_empty() { + LlmAgent agent = createAgentWithEchoTool(); + Session session = Session.builder("session_id").events(ImmutableList.of()).build(); + + assertThat( + processor + .processRequest( + createInvocationContext(agent, session), LlmRequest.builder().build()) + .blockingGet() + .events()) + .isEmpty(); + } + + private static InvocationContext createInvocationContext(LlmAgent agent, Session session) { + return new InvocationContext( + /* sessionService= */ null, + /* artifactService= */ null, + /* memoryService= */ null, + /* pluginManager= */ new PluginManager(), + /* liveRequestQueue= */ Optional.empty(), + /* branch= */ Optional.empty(), + /* invocationId= */ InvocationContext.newInvocationContextId(), + /* agent= */ agent, + /* session= */ session, + /* userContent= */ Optional.empty(), + /* runConfig= */ RunConfig.builder().build(), + /* endInvocation= */ false); + } + + private static LlmAgent createAgentWithEchoTool() { + Content contentWithFunctionCall = + Content.fromParts( + Part.fromText("text"), + Part.fromFunctionCall(ECHO_TOOL_NAME, ImmutableMap.of("arg", "value"))); + Content unreachableContent = Content.fromParts(Part.fromText("This should never be returned.")); + TestLlm testLlm = + createTestLlm( + createLlmResponse(contentWithFunctionCall), createLlmResponse(unreachableContent)); + return createTestAgentBuilder(testLlm).tools(new EchoTool()).maxSteps(2).build(); + } +} diff --git a/core/src/test/java/com/google/adk/tools/FunctionToolTest.java b/core/src/test/java/com/google/adk/tools/FunctionToolTest.java index 6576431a9..1e717cf57 100644 --- a/core/src/test/java/com/google/adk/tools/FunctionToolTest.java +++ b/core/src/test/java/com/google/adk/tools/FunctionToolTest.java @@ -636,6 +636,166 @@ public void call_nonStaticWithAllSupportedParameterTypes() throws Exception { .buildOrThrow()); } + @Test + public void runAsync_withRequireConfirmation() throws Exception { + Method method = Functions.class.getMethod("returnsMap"); + FunctionTool tool = + new FunctionTool(null, method, /* isLongRunning= */ false, /* requireConfirmation= */ true); + ToolContext toolContext = + ToolContext.builder( + new InvocationContext( + /* sessionService= */ null, + /* artifactService= */ null, + /* memoryService= */ null, + /* liveRequestQueue= */ Optional.empty(), + /* branch= */ Optional.empty(), + /* invocationId= */ null, + /* agent= */ null, + /* session= */ Session.builder("123").build(), + /* userContent= */ Optional.empty(), + /* runConfig= */ null, + /* endInvocation= */ false)) + .functionCallId("functionCallId") + .build(); + + // First call, should request confirmation + Map result = tool.runAsync(ImmutableMap.of(), toolContext).blockingGet(); + assertThat(result) + .containsExactly( + "error", "This tool call requires confirmation, please approve or reject."); + assertThat(toolContext.actions().requestedToolConfirmations()).containsKey("functionCallId"); + assertThat(toolContext.actions().requestedToolConfirmations().get("functionCallId").hint()) + .isEqualTo( + "Please approve or reject the tool call returnsMap() by responding with a" + + " FunctionResponse with an expected ToolConfirmation payload."); + + // Second call, user rejects + toolContext.toolConfirmation(ToolConfirmation.builder().confirmed(false).build()); + result = tool.runAsync(ImmutableMap.of(), toolContext).blockingGet(); + assertThat(result).containsExactly("error", "This tool call is rejected."); + + // Third call, user approves + toolContext.toolConfirmation(ToolConfirmation.builder().confirmed(true).build()); + result = tool.runAsync(ImmutableMap.of(), toolContext).blockingGet(); + assertThat(result).containsExactly("key", "value"); + } + + @Test + public void create_instanceMethodWithConfirmation_requestsConfirmation() throws Exception { + Functions functions = new Functions(); + Method method = Functions.class.getMethod("nonStaticVoidReturnWithoutSchema"); + FunctionTool tool = FunctionTool.create(functions, method, /* requireConfirmation= */ true); + ToolContext toolContext = + ToolContext.builder( + new InvocationContext( + /* sessionService= */ null, + /* artifactService= */ null, + /* memoryService= */ null, + /* liveRequestQueue= */ Optional.empty(), + /* branch= */ Optional.empty(), + /* invocationId= */ null, + /* agent= */ null, + /* session= */ Session.builder("123").build(), + /* userContent= */ Optional.empty(), + /* runConfig= */ null, + /* endInvocation= */ false)) + .functionCallId("functionCallId") + .build(); + + Map result = tool.runAsync(ImmutableMap.of(), toolContext).blockingGet(); + assertThat(result) + .containsExactly( + "error", "This tool call requires confirmation, please approve or reject."); + assertThat(toolContext.actions().requestedToolConfirmations()).containsKey("functionCallId"); + } + + @Test + public void create_staticMethodWithConfirmation_requestsConfirmation() throws Exception { + Method method = Functions.class.getMethod("voidReturnWithoutSchema"); + FunctionTool tool = FunctionTool.create(method, /* requireConfirmation= */ true); + ToolContext toolContext = + ToolContext.builder( + new InvocationContext( + /* sessionService= */ null, + /* artifactService= */ null, + /* memoryService= */ null, + /* liveRequestQueue= */ Optional.empty(), + /* branch= */ Optional.empty(), + /* invocationId= */ null, + /* agent= */ null, + /* session= */ Session.builder("123").build(), + /* userContent= */ Optional.empty(), + /* runConfig= */ null, + /* endInvocation= */ false)) + .functionCallId("functionCallId") + .build(); + + Map result = tool.runAsync(ImmutableMap.of(), toolContext).blockingGet(); + assertThat(result) + .containsExactly( + "error", "This tool call requires confirmation, please approve or reject."); + assertThat(toolContext.actions().requestedToolConfirmations()).containsKey("functionCallId"); + } + + @Test + public void create_classMethodNameWithConfirmation_requestsConfirmation() throws Exception { + FunctionTool tool = + FunctionTool.create( + Functions.class, "voidReturnWithoutSchema", /* requireConfirmation= */ true); + ToolContext toolContext = + ToolContext.builder( + new InvocationContext( + /* sessionService= */ null, + /* artifactService= */ null, + /* memoryService= */ null, + /* liveRequestQueue= */ Optional.empty(), + /* branch= */ Optional.empty(), + /* invocationId= */ null, + /* agent= */ null, + /* session= */ Session.builder("123").build(), + /* userContent= */ Optional.empty(), + /* runConfig= */ null, + /* endInvocation= */ false)) + .functionCallId("functionCallId") + .build(); + + Map result = tool.runAsync(ImmutableMap.of(), toolContext).blockingGet(); + assertThat(result) + .containsExactly( + "error", "This tool call requires confirmation, please approve or reject."); + assertThat(toolContext.actions().requestedToolConfirmations()).containsKey("functionCallId"); + } + + @Test + public void create_instanceMethodNameWithConfirmation_requestsConfirmation() throws Exception { + Functions functions = new Functions(); + FunctionTool tool = + FunctionTool.create( + functions, "nonStaticVoidReturnWithoutSchema", /* requireConfirmation= */ true); + ToolContext toolContext = + ToolContext.builder( + new InvocationContext( + /* sessionService= */ null, + /* artifactService= */ null, + /* memoryService= */ null, + /* liveRequestQueue= */ Optional.empty(), + /* branch= */ Optional.empty(), + /* invocationId= */ null, + /* agent= */ null, + /* session= */ Session.builder("123").build(), + /* userContent= */ Optional.empty(), + /* runConfig= */ null, + /* endInvocation= */ false)) + .functionCallId("functionCallId") + .build(); + + Map result = tool.runAsync(ImmutableMap.of(), toolContext).blockingGet(); + assertThat(result) + .containsExactly( + "error", "This tool call requires confirmation, please approve or reject."); + assertThat(toolContext.actions().requestedToolConfirmations()).containsKey("functionCallId"); + } + static class Functions { @Annotations.Schema(