From a16476039bacb1933644a3ed6c891d4dcba03b03 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Fri, 30 Jan 2026 10:24:16 -0800 Subject: [PATCH] refactor: refactors parts of the ADK codebase to improve null safety and consistency This CL refactors parts of the ADK codebase to improve null safety and consistency. The main changes include: 1. **`BaseAgent`**: * `beforeAgentCallback` and `afterAgentCallback` fields and their accessors now use `ImmutableList` (defaulting to empty) instead of `Optional`. * `findAgent` and `findSubAgent` now return `Optional`, with `findSubAgent` being reimplemented using Java Streams. 2. **`BaseAgentConfig`**: Getters for `subAgents`, `beforeAgentCallbacks`, and `afterAgentCallbacks` now return an empty list if the underlying field is null. 3. **`CallbackUtil`**: `getBeforeAgentCallbacks` and `getAfterAgentCallbacks` return `ImmutableList.of()` instead of `null` for null inputs. 4. **`LlmAgent`**: The `codeExecutor()` method now returns `Optional`. These changes necessitate updates in `BaseLlmFlow`, `CodeExecution`, and `Runner` to handle the new `Optional` return types. PiperOrigin-RevId: 863294916 --- .../java/com/google/adk/agents/BaseAgent.java | 67 ++++++++++--------- .../google/adk/agents/BaseAgentConfig.java | 7 +- .../com/google/adk/agents/CallbackUtil.java | 45 ++++++------- .../java/com/google/adk/agents/LlmAgent.java | 5 +- .../adk/flows/llmflows/BaseLlmFlow.java | 12 ++-- .../adk/flows/llmflows/CodeExecution.java | 22 +++--- .../java/com/google/adk/runner/Runner.java | 8 +-- .../com/google/adk/agents/BaseAgentTest.java | 8 +-- .../adk/agents/ConfigAgentUtilsTest.java | 6 +- 9 files changed, 91 insertions(+), 89 deletions(-) diff --git a/core/src/main/java/com/google/adk/agents/BaseAgent.java b/core/src/main/java/com/google/adk/agents/BaseAgent.java index 948d5eba..1c235d18 100644 --- a/core/src/main/java/com/google/adk/agents/BaseAgent.java +++ b/core/src/main/java/com/google/adk/agents/BaseAgent.java @@ -57,10 +57,10 @@ public abstract class BaseAgent { */ private BaseAgent parentAgent; - private final List subAgents; + private final ImmutableList subAgents; - private final Optional> beforeAgentCallback; - private final Optional> afterAgentCallback; + private final ImmutableList beforeAgentCallback; + private final ImmutableList afterAgentCallback; /** * Creates a new BaseAgent. @@ -82,9 +82,13 @@ public BaseAgent( this.name = name; this.description = description; this.parentAgent = null; - this.subAgents = subAgents != null ? subAgents : ImmutableList.of(); - this.beforeAgentCallback = Optional.ofNullable(beforeAgentCallback); - this.afterAgentCallback = Optional.ofNullable(afterAgentCallback); + this.subAgents = subAgents == null ? ImmutableList.of() : ImmutableList.copyOf(subAgents); + this.beforeAgentCallback = + beforeAgentCallback == null + ? ImmutableList.of() + : ImmutableList.copyOf(beforeAgentCallback); + this.afterAgentCallback = + afterAgentCallback == null ? ImmutableList.of() : ImmutableList.copyOf(afterAgentCallback); // Establish parent relationships for all sub-agents if needed. for (BaseAgent subAgent : this.subAgents) { @@ -144,38 +148,40 @@ public BaseAgent rootAgent() { /** * Finds an agent (this or descendant) by name. * - * @return the agent or descendant with the given name, or {@code null} if not found. + * @return an {@link Optional} containing the agent or descendant with the given name, or {@link + * Optional#empty()} if not found. */ - public BaseAgent findAgent(String name) { + public Optional findAgent(String name) { if (this.name().equals(name)) { - return this; + return Optional.of(this); } return findSubAgent(name); } - /** Recursively search sub agent by name. */ - public @Nullable BaseAgent findSubAgent(String name) { - for (BaseAgent subAgent : subAgents) { - if (subAgent.name().equals(name)) { - return subAgent; - } - BaseAgent result = subAgent.findSubAgent(name); - if (result != null) { - return result; - } - } - return null; + /** + * Recursively search sub agent by name. + * + * @return an {@link Optional} containing the sub agent with the given name, or {@link + * Optional#empty()} if not found. + */ + public Optional findSubAgent(String name) { + return subAgents.stream() + .map( + subAgent -> + subAgent.name().equals(name) ? Optional.of(subAgent) : subAgent.findSubAgent(name)) + .flatMap(Optional::stream) + .findFirst(); } - public List subAgents() { + public ImmutableList subAgents() { return subAgents; } - public Optional> beforeAgentCallback() { + public ImmutableList beforeAgentCallback() { return beforeAgentCallback; } - public Optional> afterAgentCallback() { + public ImmutableList afterAgentCallback() { return afterAgentCallback; } @@ -184,8 +190,8 @@ public Optional> afterAgentCallback() { * *

This method is only for use by Agent Development Kit. */ - public List canonicalBeforeAgentCallbacks() { - return beforeAgentCallback.orElse(ImmutableList.of()); + public ImmutableList canonicalBeforeAgentCallbacks() { + return beforeAgentCallback; } /** @@ -193,8 +199,8 @@ public List canonicalBeforeAgentCallbacks() { * *

This method is only for use by Agent Development Kit. */ - public List canonicalAfterAgentCallbacks() { - return afterAgentCallback.orElse(ImmutableList.of()); + public ImmutableList canonicalAfterAgentCallbacks() { + return afterAgentCallback; } /** @@ -239,8 +245,7 @@ public Flowable runAsync(InvocationContext parentContext) { () -> callCallback( beforeCallbacksToFunctions( - invocationContext.pluginManager(), - beforeAgentCallback.orElse(ImmutableList.of())), + invocationContext.pluginManager(), beforeAgentCallback), invocationContext) .flatMapPublisher( beforeEventOpt -> { @@ -257,7 +262,7 @@ public Flowable runAsync(InvocationContext parentContext) { callCallback( afterCallbacksToFunctions( invocationContext.pluginManager(), - afterAgentCallback.orElse(ImmutableList.of())), + afterAgentCallback), invocationContext) .flatMapPublisher(Flowable::fromOptional)); diff --git a/core/src/main/java/com/google/adk/agents/BaseAgentConfig.java b/core/src/main/java/com/google/adk/agents/BaseAgentConfig.java index e38895af..67b6ac40 100644 --- a/core/src/main/java/com/google/adk/agents/BaseAgentConfig.java +++ b/core/src/main/java/com/google/adk/agents/BaseAgentConfig.java @@ -16,6 +16,7 @@ package com.google.adk.agents; +import java.util.Collections; import java.util.List; /** @@ -132,7 +133,7 @@ public String agentClass() { } public List subAgents() { - return subAgents; + return subAgents == null ? Collections.emptyList() : subAgents; } public void setSubAgents(List subAgents) { @@ -140,7 +141,7 @@ public void setSubAgents(List subAgents) { } public List beforeAgentCallbacks() { - return beforeAgentCallbacks; + return beforeAgentCallbacks == null ? Collections.emptyList() : beforeAgentCallbacks; } public void setBeforeAgentCallbacks(List beforeAgentCallbacks) { @@ -148,7 +149,7 @@ public void setBeforeAgentCallbacks(List beforeAgentCallbacks) { } public List afterAgentCallbacks() { - return afterAgentCallbacks; + return afterAgentCallbacks == null ? Collections.emptyList() : afterAgentCallbacks; } public void setAfterAgentCallbacks(List afterAgentCallbacks) { diff --git a/core/src/main/java/com/google/adk/agents/CallbackUtil.java b/core/src/main/java/com/google/adk/agents/CallbackUtil.java index 728fd1d5..b22c699f 100644 --- a/core/src/main/java/com/google/adk/agents/CallbackUtil.java +++ b/core/src/main/java/com/google/adk/agents/CallbackUtil.java @@ -26,7 +26,6 @@ import com.google.errorprone.annotations.CanIgnoreReturnValue; import io.reactivex.rxjava3.core.Maybe; import java.util.List; -import org.jspecify.annotations.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -38,45 +37,43 @@ public final class CallbackUtil { * Normalizes before-agent callbacks. * * @param beforeAgentCallback Callback list (sync or async). - * @return normalized async callbacks, or null if input is null. + * @return normalized async callbacks, or empty list if input is null. */ @CanIgnoreReturnValue - public static @Nullable ImmutableList getBeforeAgentCallbacks( + public static ImmutableList getBeforeAgentCallbacks( List beforeAgentCallback) { - if (beforeAgentCallback == null) { - return null; - } else if (beforeAgentCallback.isEmpty()) { + if (beforeAgentCallback == null || beforeAgentCallback.isEmpty()) { return ImmutableList.of(); - } else { - ImmutableList.Builder builder = ImmutableList.builder(); - for (BeforeAgentCallbackBase callback : beforeAgentCallback) { - if (callback instanceof BeforeAgentCallback beforeAgentCallbackInstance) { - builder.add(beforeAgentCallbackInstance); - } else if (callback instanceof BeforeAgentCallbackSync beforeAgentCallbackSyncInstance) { - builder.add( - (callbackContext) -> - Maybe.fromOptional(beforeAgentCallbackSyncInstance.call(callbackContext))); - } else { - logger.warn( - "Invalid beforeAgentCallback callback type: {}. Ignoring this callback.", - callback.getClass().getName()); - } + } + + ImmutableList.Builder builder = ImmutableList.builder(); + for (BeforeAgentCallbackBase callback : beforeAgentCallback) { + if (callback instanceof BeforeAgentCallback beforeAgentCallbackInstance) { + builder.add(beforeAgentCallbackInstance); + } else if (callback instanceof BeforeAgentCallbackSync beforeAgentCallbackSyncInstance) { + builder.add( + (callbackContext) -> + Maybe.fromOptional(beforeAgentCallbackSyncInstance.call(callbackContext))); + } else { + logger.warn( + "Invalid beforeAgentCallback callback type: {}. Ignoring this callback.", + callback.getClass().getName()); } - return builder.build(); } + return builder.build(); } /** * Normalizes after-agent callbacks. * * @param afterAgentCallback Callback list (sync or async). - * @return normalized async callbacks, or null if input is null. + * @return normalized async callbacks, or empty list if input is null. */ @CanIgnoreReturnValue - public static @Nullable ImmutableList getAfterAgentCallbacks( + public static ImmutableList getAfterAgentCallbacks( List afterAgentCallback) { if (afterAgentCallback == null) { - return null; + return ImmutableList.of(); } else if (afterAgentCallback.isEmpty()) { return ImmutableList.of(); } else { diff --git a/core/src/main/java/com/google/adk/agents/LlmAgent.java b/core/src/main/java/com/google/adk/agents/LlmAgent.java index 1f16d7c0..87967bb6 100644 --- a/core/src/main/java/com/google/adk/agents/LlmAgent.java +++ b/core/src/main/java/com/google/adk/agents/LlmAgent.java @@ -935,9 +935,8 @@ public Optional outputKey() { return outputKey; } - @Nullable - public BaseCodeExecutor codeExecutor() { - return codeExecutor.orElse(null); + public Optional codeExecutor() { + return codeExecutor; } public Model resolvedModel() { diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index 46b3f195..cfbadb9f 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -388,15 +388,15 @@ private Flowable runOneStep(InvocationContext context) { String agentToTransfer = event.actions().transferToAgent().get(); logger.debug("Transferring to agent: {}", agentToTransfer); BaseAgent rootAgent = context.agent().rootAgent(); - BaseAgent nextAgent = rootAgent.findAgent(agentToTransfer); - if (nextAgent == null) { + Optional nextAgent = rootAgent.findAgent(agentToTransfer); + if (nextAgent.isEmpty()) { String errorMsg = "Agent not found for transfer: " + agentToTransfer; logger.error(errorMsg); return postProcessedEvents.concatWith( Flowable.error(new IllegalStateException(errorMsg))); } return postProcessedEvents.concatWith( - Flowable.defer(() -> nextAgent.runAsync(context))); + Flowable.defer(() -> nextAgent.get().runAsync(context))); } return postProcessedEvents; }); @@ -574,14 +574,14 @@ public void onError(Throwable e) { Flowable events = Flowable.just(event); if (event.actions().transferToAgent().isPresent()) { BaseAgent rootAgent = invocationContext.agent().rootAgent(); - BaseAgent nextAgent = + Optional nextAgent = rootAgent.findAgent(event.actions().transferToAgent().get()); - if (nextAgent == null) { + if (nextAgent.isEmpty()) { throw new IllegalStateException( "Agent not found: " + event.actions().transferToAgent().get()); } Flowable nextAgentEvents = - nextAgent.runLive(invocationContext); + nextAgent.get().runLive(invocationContext); events = Flowable.concat(events, nextAgentEvents); } return events; diff --git a/core/src/main/java/com/google/adk/flows/llmflows/CodeExecution.java b/core/src/main/java/com/google/adk/flows/llmflows/CodeExecution.java index 64d95cef..61b89e96 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/CodeExecution.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/CodeExecution.java @@ -108,12 +108,12 @@ private static class CodeExecutionRequestProcessor implements RequestProcessor { public Single processRequest( InvocationContext invocationContext, LlmRequest llmRequest) { if (!(invocationContext.agent() instanceof LlmAgent llmAgent) - || llmAgent.codeExecutor() == null) { + || llmAgent.codeExecutor().isEmpty()) { return Single.just( RequestProcessor.RequestProcessingResult.create(llmRequest, ImmutableList.of())); } - if (llmAgent.codeExecutor() instanceof BuiltInCodeExecutor builtInCodeExecutor) { + if (llmAgent.codeExecutor().get() instanceof BuiltInCodeExecutor builtInCodeExecutor) { var llmRequestBuilder = llmRequest.toBuilder(); builtInCodeExecutor.processLlmRequest(llmRequestBuilder); LlmRequest updatedLlmRequest = llmRequestBuilder.build(); @@ -124,8 +124,8 @@ public Single processRequest( Flowable preprocessorEvents = runPreProcessor(invocationContext, llmRequest); // Convert the code execution parts to text parts. - if (llmAgent.codeExecutor() != null) { - BaseCodeExecutor baseCodeExecutor = llmAgent.codeExecutor(); + if (llmAgent.codeExecutor().isPresent()) { + BaseCodeExecutor baseCodeExecutor = llmAgent.codeExecutor().get(); List updatedContents = new ArrayList<>(); for (Content content : llmRequest.contents()) { List delimiters = @@ -173,10 +173,11 @@ private static Flowable runPreProcessor( return Flowable.empty(); } - var codeExecutor = llmAgent.codeExecutor(); - if (codeExecutor == null) { + var codeExecutorOptional = llmAgent.codeExecutor(); + if (codeExecutorOptional.isEmpty()) { return Flowable.empty(); } + var codeExecutor = codeExecutorOptional.get(); if (codeExecutor instanceof BuiltInCodeExecutor) { return Flowable.empty(); @@ -268,10 +269,11 @@ private static Flowable runPostProcessor( if (!(invocationContext.agent() instanceof LlmAgent llmAgent)) { return Flowable.empty(); } - var codeExecutor = llmAgent.codeExecutor(); - if (codeExecutor == null) { + var codeExecutorOptional = llmAgent.codeExecutor(); + if (codeExecutorOptional.isEmpty()) { return Flowable.empty(); } + var codeExecutor = codeExecutorOptional.get(); if (llmResponse.content().isEmpty()) { return Flowable.empty(); } @@ -387,8 +389,8 @@ private static List extractAndReplaceInlineFiles( private static Optional getOrSetExecutionId( InvocationContext invocationContext, CodeExecutorContext codeExecutorContext) { if (!(invocationContext.agent() instanceof LlmAgent llmAgent) - || llmAgent.codeExecutor() == null - || !llmAgent.codeExecutor().stateful()) { + || llmAgent.codeExecutor().isEmpty() + || !llmAgent.codeExecutor().get().stateful()) { return Optional.empty(); } diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 66bb5860..4d21d615 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -768,14 +768,14 @@ private BaseAgent findAgentToRun(Session session, BaseAgent rootAgent) { return rootAgent; } - BaseAgent agent = rootAgent.findSubAgent(author); + Optional agent = rootAgent.findSubAgent(author); - if (agent == null) { + if (agent.isEmpty()) { continue; } - if (this.isTransferableAcrossAgentTree(agent)) { - return agent; + if (this.isTransferableAcrossAgentTree(agent.get())) { + return agent.get(); } } diff --git a/core/src/test/java/com/google/adk/agents/BaseAgentTest.java b/core/src/test/java/com/google/adk/agents/BaseAgentTest.java index 34543682..2ae53d0e 100644 --- a/core/src/test/java/com/google/adk/agents/BaseAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/BaseAgentTest.java @@ -59,10 +59,10 @@ public void findAgent_returnsCorrectAgent() { TestBaseAgent agent = new TestBaseAgent( TEST_AGENT_NAME, TEST_AGENT_DESCRIPTION, null, ImmutableList.of(subAgent), null, null); - assertThat(agent.findAgent("subSubAgent")).isEqualTo(subSubAgent); - assertThat(agent.findAgent("subAgent")).isEqualTo(subAgent); - assertThat(agent.findAgent(TEST_AGENT_NAME)).isEqualTo(agent); - assertThat(agent.findAgent("nonExistent")).isNull(); + assertThat(agent.findAgent("subSubAgent")).hasValue(subSubAgent); + assertThat(agent.findAgent("subAgent")).hasValue(subAgent); + assertThat(agent.findAgent(TEST_AGENT_NAME)).hasValue(agent); + assertThat(agent.findAgent("nonExistent")).isEmpty(); } @Test diff --git a/core/src/test/java/com/google/adk/agents/ConfigAgentUtilsTest.java b/core/src/test/java/com/google/adk/agents/ConfigAgentUtilsTest.java index 11e07a09..4f6ea610 100644 --- a/core/src/test/java/com/google/adk/agents/ConfigAgentUtilsTest.java +++ b/core/src/test/java/com/google/adk/agents/ConfigAgentUtilsTest.java @@ -1209,10 +1209,8 @@ public void fromConfig_withConfiguredCallbacks_resolvesCallbacks() assertThat(agent).isInstanceOf(LlmAgent.class); LlmAgent llm = (LlmAgent) agent; - assertThat(agent.beforeAgentCallback()).isPresent(); - assertThat(agent.beforeAgentCallback().get()).hasSize(2); - assertThat(agent.afterAgentCallback()).isPresent(); - assertThat(agent.afterAgentCallback().get()).hasSize(1); + assertThat(agent.beforeAgentCallback()).hasSize(2); + assertThat(agent.afterAgentCallback()).hasSize(1); assertThat(llm.beforeModelCallback()).hasSize(1); assertThat(llm.afterModelCallback()).hasSize(1);