From 8a5fe46da32e8567500dd534712b6d3c1c8b92e0 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Fri, 30 Jan 2026 00:59:32 -0800 Subject: [PATCH] feat: implement partial event aggregation in RemoteA2AAgent PiperOrigin-RevId: 863112193 --- a2a/pom.xml | 5 + .../com/google/adk/a2a/RemoteA2AAgent.java | 215 ++++- .../google/adk/a2a/common/A2AMetadata.java | 28 + .../adk/a2a/converters/PartConverter.java | 67 +- .../adk/a2a/converters/ResponseConverter.java | 15 +- .../google/adk/a2a/RemoteA2AAgentTest.java | 768 ++++++++++++++++++ .../java/com/google/adk/events/Event.java | 34 + .../com/google/adk/models/LlmResponse.java | 14 + 8 files changed, 1106 insertions(+), 40 deletions(-) create mode 100644 a2a/src/main/java/com/google/adk/a2a/common/A2AMetadata.java create mode 100644 a2a/src/test/java/com/google/adk/a2a/RemoteA2AAgentTest.java diff --git a/a2a/pom.xml b/a2a/pom.xml index dc606afa..88014840 100644 --- a/a2a/pom.xml +++ b/a2a/pom.xml @@ -106,6 +106,11 @@ ${truth.version} test + + org.mockito + mockito-core + test + diff --git a/a2a/src/main/java/com/google/adk/a2a/RemoteA2AAgent.java b/a2a/src/main/java/com/google/adk/a2a/RemoteA2AAgent.java index 5e6e341d..bc6b9065 100644 --- a/a2a/src/main/java/com/google/adk/a2a/RemoteA2AAgent.java +++ b/a2a/src/main/java/com/google/adk/a2a/RemoteA2AAgent.java @@ -2,7 +2,11 @@ import static com.google.common.base.Strings.nullToEmpty; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; import com.google.adk.a2a.common.A2AClientError; +import com.google.adk.a2a.common.A2AMetadata; import com.google.adk.a2a.converters.EventConverter; import com.google.adk.a2a.converters.ResponseConverter; import com.google.adk.agents.BaseAgent; @@ -11,6 +15,9 @@ import com.google.adk.events.Event; import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.CanIgnoreReturnValue; +import com.google.genai.types.Content; +import com.google.genai.types.CustomMetadata; +import com.google.genai.types.Part; import io.a2a.client.Client; import io.a2a.client.ClientEvent; import io.a2a.client.TaskEvent; @@ -22,8 +29,11 @@ import io.reactivex.rxjava3.core.BackpressureStrategy; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.FlowableEmitter; +import java.time.Instant; +import java.util.ArrayList; import java.util.List; import java.util.Optional; +import java.util.UUID; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.BiConsumer; import org.slf4j.Logger; @@ -54,6 +64,8 @@ public class RemoteA2AAgent extends BaseAgent { private static final Logger logger = LoggerFactory.getLogger(RemoteA2AAgent.class); + private static final ObjectMapper objectMapper = + new ObjectMapper().registerModule(new JavaTimeModule()); private final AgentCard agentCard; private final Client a2aClient; @@ -173,60 +185,189 @@ protected Flowable runAsyncImpl(InvocationContext invocationContext) { } Message originalMessage = a2aMessageOpt.get(); + String requestJson; + try { + requestJson = objectMapper.writeValueAsString(originalMessage); + } catch (JsonProcessingException e) { + logger.warn("Failed to serialize request", e); + requestJson = null; + } + String finalRequestJson = requestJson; return Flowable.create( emitter -> { - FlowableEmitter flowableEmitter = emitter.serialize(); - AtomicBoolean done = new AtomicBoolean(false); + StreamHandler handler = + new StreamHandler(emitter.serialize(), invocationContext, finalRequestJson); ImmutableList> consumers = - ImmutableList.of( - (event, unused) -> - handleClientEvent(event, flowableEmitter, invocationContext, done)); - a2aClient.sendMessage( - originalMessage, consumers, e -> handleClientError(e, flowableEmitter, done), null); + ImmutableList.of(handler::handleEvent); + a2aClient.sendMessage(originalMessage, consumers, handler::handleError, null); }, BackpressureStrategy.BUFFER); } - private void handleClientError(Throwable e, FlowableEmitter emitter, AtomicBoolean done) { - // Mark the flow as done if it is already cancelled. - done.compareAndSet(false, emitter.isCancelled()); + private class StreamHandler { + private final FlowableEmitter emitter; + private final InvocationContext invocationContext; + private final String requestJson; + private final AtomicBoolean done = new AtomicBoolean(false); + private final StringBuilder textBuffer = new StringBuilder(); + private final StringBuilder thoughtsBuffer = new StringBuilder(); + + StreamHandler( + FlowableEmitter emitter, InvocationContext invocationContext, String requestJson) { + this.emitter = emitter; + this.invocationContext = invocationContext; + this.requestJson = requestJson; + } + + void handleError(Throwable e) { + // Mark the flow as done if it is already cancelled. + done.compareAndSet(false, emitter.isCancelled()); - // If the flow is already done, stop processing and exit the consumer. - if (done.get()) { - return; + // If the flow is already done, stop processing. + if (done.get()) { + return; + } + // If the error is raised, complete the flow with an error. + if (!done.getAndSet(true)) { + emitter.tryOnError(new A2AClientError("Failed to communicate with the remote agent", e)); + } } - // If the error is raised, complete the flow with an error. - if (!done.getAndSet(true)) { - emitter.tryOnError(new A2AClientError("Failed to communicate with the remote agent", e)); + + void handleEvent(ClientEvent clientEvent, AgentCard unused) { + // Mark the flow as done if it is already cancelled. + done.compareAndSet(false, emitter.isCancelled()); + + // If the flow is already done, stop processing. + if (done.get()) { + return; + } + + Optional eventOpt = + ResponseConverter.clientEventToEvent(clientEvent, invocationContext); + if (eventOpt.isPresent()) { + Event event = eventOpt.get(); + enrichWithMetadata(event, clientEvent); + boolean consumed = processContent(event); + if (!consumed) { + emitEvents(event); + } + } + + // For non-streaming communication, complete the flow; for streaming, wait until the client + // marks the completion. + if (isCompleted(clientEvent) || !streaming) { + // Only complete the flow once. + if (!done.getAndSet(true)) { + emitter.onComplete(); + } + } } - } - private void handleClientEvent( - ClientEvent clientEvent, - FlowableEmitter emitter, - InvocationContext invocationContext, - AtomicBoolean done) { - // Mark the flow as done if it is already cancelled. - done.compareAndSet(false, emitter.isCancelled()); - - // If the flow is already done, stop processing and exit the consumer. - if (done.get()) { - return; + private void enrichWithMetadata(Event event, ClientEvent clientEvent) { + List eventMetadata = + new ArrayList<>(event.customMetadata().orElse(ImmutableList.of())); + if (requestJson != null) { + eventMetadata.add( + CustomMetadata.builder() + .key(A2AMetadata.toA2AMetaKey(A2AMetadata.Key.REQUEST)) + .stringValue(requestJson) + .build()); + } + try { + if (clientEvent != null) { + eventMetadata.add( + CustomMetadata.builder() + .key(A2AMetadata.toA2AMetaKey(A2AMetadata.Key.RESPONSE)) + .stringValue(objectMapper.writeValueAsString(clientEvent)) + .build()); + } + } catch (JsonProcessingException e) { + logger.warn("Failed to serialize response metadata", e); + } + event.setCustomMetadata(Optional.of(ImmutableList.copyOf(eventMetadata))); } - Optional event = ResponseConverter.clientEventToEvent(clientEvent, invocationContext); - if (event.isPresent()) { - emitter.onNext(event.get()); + private boolean processContent(Event event) { + if (!event.partial().orElse(false)) { + return false; + } + + List nonTextParts = new ArrayList<>(); + for (Part part : event.content().flatMap(Content::parts).orElse(ImmutableList.of())) { + if (part.text().isPresent()) { + String t = part.text().get(); + if (part.thought().orElse(false)) { + thoughtsBuffer.append(t); + } else { + textBuffer.append(t); + } + } else { + nonTextParts.add(part); + } + } + + if (nonTextParts.isEmpty()) { + return true; + } + + Content nonTextContent = Content.builder().role("model").parts(nonTextParts).build(); + event.setContent(Optional.of(nonTextContent)); + return false; } - // For non-streaming communication, complete the flow; for streaming, wait until the client - // marks the completion. - if (isCompleted(clientEvent) || !streaming) { - // Only complete the flow once. - if (!done.getAndSet(true)) { - emitter.onComplete(); + private void emitEvents(Event event) { + List parts = new ArrayList<>(); + if (thoughtsBuffer.length() > 0) { + parts.add(Part.builder().thought(true).text(thoughtsBuffer.toString()).build()); + thoughtsBuffer.setLength(0); } + if (textBuffer.length() > 0) { + parts.add(Part.builder().text(textBuffer.toString()).build()); + textBuffer.setLength(0); + } + + if (!parts.isEmpty()) { + Content aggregatedContent = Content.builder().role("model").parts(parts).build(); + + if (event.content().flatMap(Content::parts).orElse(ImmutableList.of()).isEmpty()) { + // Reuse empty event for aggregated content. + event.setContent(Optional.of(aggregatedContent)); + emitter.onNext(event); + } else { + // Emit separate aggregated event first. + Event aggEvent = createAggregatedEvent(aggregatedContent); + emitter.onNext(aggEvent); + emitter.onNext(event); + } + } else { + emitter.onNext(event); + } + } + + private Event createAggregatedEvent(Content content) { + List aggMetadata = new ArrayList<>(); + aggMetadata.add( + CustomMetadata.builder() + .key(A2AMetadata.toA2AMetaKey(A2AMetadata.Key.AGGREGATED)) + .stringValue("true") + .build()); + if (requestJson != null) { + aggMetadata.add( + CustomMetadata.builder() + .key(A2AMetadata.toA2AMetaKey(A2AMetadata.Key.REQUEST)) + .stringValue(requestJson) + .build()); + } + + return Event.builder() + .id(UUID.randomUUID().toString()) + .invocationId(invocationContext.invocationId()) + .author("agent") + .content(Optional.of(content)) + .timestamp(Instant.now().toEpochMilli()) + .customMetadata(Optional.of(ImmutableList.copyOf(aggMetadata))) + .build(); } } diff --git a/a2a/src/main/java/com/google/adk/a2a/common/A2AMetadata.java b/a2a/src/main/java/com/google/adk/a2a/common/A2AMetadata.java new file mode 100644 index 00000000..ff32fcaf --- /dev/null +++ b/a2a/src/main/java/com/google/adk/a2a/common/A2AMetadata.java @@ -0,0 +1,28 @@ +package com.google.adk.a2a.common; + +/** Constants and utilities for A2A metadata keys. */ +public final class A2AMetadata { + + /** Enum for A2A custom metadata keys. */ + public enum Key { + REQUEST("request"), + RESPONSE("response"), + AGGREGATED("aggregated"); + + private final String value; + + Key(String value) { + this.value = value; + } + + public String getValue() { + return value; + } + } + + public static String toA2AMetaKey(Key key) { + return "a2a:" + key.value; + } + + private A2AMetadata() {} +} diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java b/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java index c6ef0640..5513bff0 100644 --- a/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java +++ b/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java @@ -7,10 +7,14 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.genai.types.Blob; +import com.google.genai.types.CodeExecutionResult; +import com.google.genai.types.ExecutableCode; import com.google.genai.types.Content; import com.google.genai.types.FileData; import com.google.genai.types.FunctionCall; import com.google.genai.types.FunctionResponse; +import com.google.genai.types.Language; +import com.google.genai.types.Outcome; import com.google.genai.types.Part; import io.a2a.spec.DataPart; import io.a2a.spec.FileContent; @@ -22,6 +26,7 @@ import java.util.Base64; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Optional; import org.slf4j.Logger; @@ -95,6 +100,10 @@ public static Optional convertGenaiPartToA2aPart(Part part) { return createDataPartFromFunctionCall(part.functionCall().get()); } else if (part.functionResponse().isPresent()) { return createDataPartFromFunctionResponse(part.functionResponse().get()); + } else if (part.executableCode().isPresent()) { + return createDataPartFromExecutableCode(part.executableCode().get()); + } else if (part.codeExecutionResult().isPresent()) { + return createDataPartFromCodeExecutionResult(part.codeExecutionResult().get()); } logger.warn("Cannot convert unsupported part for Google GenAI part: " + part); @@ -174,6 +183,33 @@ private static Optional convertDataPartToGenAiPart( .build()); } + if ((data.containsKey("code") && data.containsKey("language")) + || metadataType.equals(A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE)) { + String code = String.valueOf(data.getOrDefault("code", "")); + String language = + String.valueOf(data.getOrDefault("language", "PYTHON")).toUpperCase(Locale.getDefault()); + return Optional.of( + com.google.genai.types.Part.builder() + .executableCode( + ExecutableCode.builder().code(code).language(new Language(language)).build()) + .build()); + } + + if ((data.containsKey("outcome") && data.containsKey("output")) + || metadataType.equals(A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT)) { + String outcome = + String.valueOf(data.getOrDefault("outcome", "OK")).toUpperCase(Locale.getDefault()); + String output = String.valueOf(data.getOrDefault("output", "")); + return Optional.of( + com.google.genai.types.Part.builder() + .codeExecutionResult( + CodeExecutionResult.builder() + .outcome(new Outcome(outcome)) + .output(output) + .build()) + .build()); + } + try { String json = objectMapper.writeValueAsString(data); return Optional.of(com.google.genai.types.Part.builder().text(json).build()); @@ -231,6 +267,32 @@ private static Optional createDataPartFromFunctionResponse( return Optional.of(new DataPart(data, metadata)); } + private static Optional createDataPartFromExecutableCode( + ExecutableCode executableCode) { + Map data = new HashMap<>(); + data.put("code", executableCode.code().orElse("")); + data.put("language", executableCode.language().map(Language::toString).orElse("PYTHON")); + + ImmutableMap metadata = + ImmutableMap.of( + A2A_DATA_PART_METADATA_TYPE_KEY, A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE); + + return Optional.of(new DataPart(data, metadata)); + } + + private static Optional createDataPartFromCodeExecutionResult( + CodeExecutionResult result) { + Map data = new HashMap<>(); + data.put("outcome", result.outcome().map(Outcome::toString).orElse("OK")); + data.put("output", result.output().orElse("")); + + ImmutableMap metadata = + ImmutableMap.of( + A2A_DATA_PART_METADATA_TYPE_KEY, A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT); + + return Optional.of(new DataPart(data, metadata)); + } + private PartConverter() {} /** Convert a GenAI part into the A2A JSON representation. */ @@ -260,7 +322,10 @@ public static Optional> fromGenaiPart(Part part) { return Optional.of(new FilePart(new FileWithBytes(mime, name, encoded), new HashMap<>())); } - if (part.functionCall().isPresent() || part.functionResponse().isPresent()) { + if (part.functionCall().isPresent() + || part.functionResponse().isPresent() + || part.executableCode().isPresent() + || part.codeExecutionResult().isPresent()) { return convertGenaiPartToA2aPart(part).map(data -> data); } diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java b/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java index 785ce6f3..73e53cea 100644 --- a/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java +++ b/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java @@ -13,6 +13,7 @@ import io.a2a.client.MessageEvent; import io.a2a.client.TaskEvent; import io.a2a.client.TaskUpdateEvent; +import io.a2a.spec.Artifact; import io.a2a.spec.EventKind; import io.a2a.spec.JSONRPCError; import io.a2a.spec.Message; @@ -189,8 +190,11 @@ private static Optional handleTaskUpdate( var updateEvent = event.getUpdateEvent(); if (updateEvent instanceof TaskArtifactUpdateEvent artifactEvent) { - if (Objects.equals(artifactEvent.isAppend(), false) - || Objects.equals(artifactEvent.isLastChunk(), true)) { + if (Objects.equals(artifactEvent.isAppend(), true)) { + Event eventPart = artifactToEvent(artifactEvent.getArtifact(), context); + eventPart.setPartial(Optional.of(true)); + return Optional.of(eventPart); + } else if (Objects.equals(artifactEvent.isLastChunk(), true)) { return Optional.of(taskToEvent(event.getTask(), context)); } return Optional.empty(); @@ -209,6 +213,13 @@ private static Optional handleTaskUpdate( "Unsupported TaskUpdateEvent type: " + updateEvent.getClass()); } + /** Converts an artifact to an ADK event. */ + public static Event artifactToEvent(Artifact artifact, InvocationContext invocationContext) { + Message message = + new Message.Builder().role(Message.Role.AGENT).parts(artifact.parts()).build(); + return messageToEvent(message, invocationContext); + } + /** Converts an A2A message back to ADK events. */ public static Event messageToEvent(Message message, InvocationContext invocationContext) { return remoteAgentEventBuilder(invocationContext) diff --git a/a2a/src/test/java/com/google/adk/a2a/RemoteA2AAgentTest.java b/a2a/src/test/java/com/google/adk/a2a/RemoteA2AAgentTest.java new file mode 100644 index 00000000..7d6af9bf --- /dev/null +++ b/a2a/src/test/java/com/google/adk/a2a/RemoteA2AAgentTest.java @@ -0,0 +1,768 @@ +package com.google.adk.a2a; + +import static com.google.common.truth.Truth.assertThat; +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; + +import com.google.adk.a2a.common.A2AMetadata; +import com.google.adk.agents.BaseAgent; +import com.google.adk.agents.CallbackContext; +import com.google.adk.agents.Callbacks; +import com.google.adk.agents.InvocationContext; +import com.google.adk.agents.RunConfig; +import com.google.adk.artifacts.InMemoryArtifactService; +import com.google.adk.events.Event; +import com.google.adk.plugins.PluginManager; +import com.google.adk.sessions.InMemorySessionService; +import com.google.adk.sessions.Session; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.genai.types.Content; +import com.google.genai.types.CustomMetadata; +import com.google.genai.types.FunctionResponse; +import com.google.genai.types.Part; +import io.a2a.client.Client; +import io.a2a.client.ClientEvent; +import io.a2a.client.TaskEvent; +import io.a2a.client.TaskUpdateEvent; +import io.a2a.spec.AgentCapabilities; +import io.a2a.spec.AgentCard; +import io.a2a.spec.Artifact; +import io.a2a.spec.DataPart; +import io.a2a.spec.FilePart; +import io.a2a.spec.FileWithUri; +import io.a2a.spec.Message; +import io.a2a.spec.Task; +import io.a2a.spec.TaskArtifactUpdateEvent; +import io.a2a.spec.TaskState; +import io.a2a.spec.TaskStatus; +import io.a2a.spec.TaskStatusUpdateEvent; +import io.a2a.spec.TextPart; +import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Maybe; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; + +@RunWith(JUnit4.class) +public final class RemoteA2AAgentTest { + + private Client mockClient; + private AgentCard agentCard; + private InvocationContext invocationContext; + private Session session; + + @Before + public void setUp() { + mockClient = mock(Client.class); + agentCard = + new AgentCard.Builder() + .name("remote-agent") + .description("Remote Agent") + .version("1.0.0") + .url("http://example.com") + .capabilities(new AgentCapabilities.Builder().streaming(true).build()) + .defaultInputModes(ImmutableList.of("text")) + .defaultOutputModes(ImmutableList.of("text")) + .skills(ImmutableList.of()) + .build(); + + when(mockClient.getAgentCard()).thenReturn(agentCard); + + session = + Session.builder("session-1") + .appName("demo") + .userId("user") + .events( + ImmutableList.of( + Event.builder() + .id("event-1") + .author("user") + .content( + Content.builder() + .role("user") + .parts(ImmutableList.of(Part.builder().text("Hello").build())) + .build()) + .build())) + .build(); + + invocationContext = + InvocationContext.builder() + .sessionService(new InMemorySessionService()) + .artifactService(new InMemoryArtifactService()) + .pluginManager(new PluginManager()) + .invocationId("invocation-1") + .agent(new TestAgent()) + .session(session) + .runConfig(RunConfig.builder().build()) + .endInvocation(false) + .build(); + } + + @Test + public void runAsync_aggregatesPartialEvents() { + RemoteA2AAgent agent = createAgent(); + + mockStreamResponse( + consumer -> { + consumer.accept(createPartialEvent("Hello ", true, false), agentCard); + consumer.accept(createPartialEvent("World!", true, false), agentCard); + consumer.accept(createFinalEvent("Final artifact content"), agentCard); + }); + + List events = agent.runAsync(invocationContext).toList().blockingGet(); + + assertThat(events).hasSize(2); + + Event aggregatedEvent = events.get(0); + assertThat(aggregatedEvent.content().get().parts().get()).hasSize(1); + assertThought(aggregatedEvent, false); + assertText(aggregatedEvent, "Hello World!"); + assertAggregated(aggregatedEvent); + + Event finalEvent = events.get(1); + assertText(finalEvent, "Final artifact content"); + assertRequestMetadata(finalEvent); + assertResponseMetadata(finalEvent); + } + + @Test + public void runAsync_aggregatesInterleavedFunctionCalls() { + RemoteA2AAgent agent = createAgent(); + + mockStreamResponse( + consumer -> { + consumer.accept(createPartialEvent("Hello ", true, false), agentCard); + consumer.accept(createPartialFunctionCallEvent("get_weather", "call_1"), agentCard); + consumer.accept(createPartialEvent("World!", true, false), agentCard); + consumer.accept(createFinalEvent("Final"), agentCard); + }); + + List events = agent.runAsync(invocationContext).toList().blockingGet(); + + assertThat(events).hasSize(4); + + assertText(events.get(0), "Hello "); + assertAggregated(events.get(0)); + + assertThat(events.get(1).content().get().parts().get().get(0).functionCall()).isPresent(); + assertThat( + events + .get(1) + .content() + .get() + .parts() + .get() + .get(0) + .functionCall() + .get() + .name() + .orElse("")) + .isEqualTo("get_weather"); + + assertText(events.get(2), "World!"); + assertAggregated(events.get(2)); + assertText(events.get(3), "Final"); + assertRequestMetadata(events.get(3)); + assertResponseMetadata(events.get(3)); + } + + @Test + public void runAsync_aggregatesFiles() { + RemoteA2AAgent agent = createAgent(); + + mockStreamResponse( + consumer -> { + consumer.accept(createPartialEvent("Here is a file: ", true, false), agentCard); + consumer.accept( + createPartialFileEvent("http://example.com/file.txt", "text/plain"), agentCard); + consumer.accept(createFinalEvent("Done"), agentCard); + }); + + List events = agent.runAsync(invocationContext).toList().blockingGet(); + + assertThat(events).hasSize(3); + assertText(events.get(0), "Here is a file: "); + + Part filePart = events.get(1).content().get().parts().get().get(0); + assertThat(filePart.fileData()).isPresent(); + assertThat(filePart.fileData().get().fileUri().orElse("")) + .isEqualTo("http://example.com/file.txt"); + assertRequestMetadata(events.get(1)); + assertResponseMetadata(events.get(1)); + + assertText(events.get(2), "Done"); + assertRequestMetadata(events.get(2)); + assertResponseMetadata(events.get(2)); + } + + @Test + public void runAsync_handlesTasksWithStatusMessage() { + RemoteA2AAgent agent = createAgent(); + + mockStreamResponse( + consumer -> { + Task task = + new Task.Builder() + .id("task-1") + .contextId("context-1") + .status( + new TaskStatus( + TaskState.COMPLETED, + new Message.Builder() + .role(Message.Role.AGENT) + .parts(ImmutableList.of(new TextPart("hello"))) + .build(), + null)) + .build(); + + consumer.accept(new TaskEvent(task), agentCard); + }); + + List events = agent.runAsync(invocationContext).toList().blockingGet(); + + assertThat(events).hasSize(1); + assertText(events.get(0), "hello"); + assertRequestMetadata(events.get(0)); + assertResponseMetadata(events.get(0)); + } + + @Test + public void runAsync_handlesTasksWithMultipartArtifact() { + RemoteA2AAgent agent = createAgent(); + + mockStreamResponse( + consumer -> { + Artifact artifact = + new Artifact.Builder() + .artifactId("artifact-1") + .parts(ImmutableList.of(new TextPart("hello"), new TextPart("world"))) + .build(); + Task task = + new Task.Builder() + .id("task-1") + .contextId("context-1") + .status(new TaskStatus(TaskState.COMPLETED)) + .artifacts(ImmutableList.of(artifact)) + .build(); + + consumer.accept(new TaskEvent(task), agentCard); + }); + + List events = agent.runAsync(invocationContext).toList().blockingGet(); + + assertThat(events).hasSize(1); + assertThat(events.get(0).content().get().parts().get()).hasSize(2); + assertText(events.get(0), 0, "hello"); + assertText(events.get(0), 1, "world"); + assertRequestMetadata(events.get(0)); + assertResponseMetadata(events.get(0)); + } + + @Test + public void runAsync_handlesNonFinalStatusUpdatesAsThoughts() { + RemoteA2AAgent agent = createAgent(); + + mockStreamResponse( + consumer -> { + Task task1 = + new Task.Builder() + .id("task-1") + .contextId("context-1") + .status(new TaskStatus(TaskState.SUBMITTED)) + .build(); + consumer.accept( + new TaskUpdateEvent( + task1, + new TaskStatusUpdateEvent.Builder() + .taskId("task-1") + .contextId("context-1") + .status( + new TaskStatus( + TaskState.SUBMITTED, + new Message.Builder() + .role(Message.Role.AGENT) + .parts(ImmutableList.of(new TextPart("submitted..."))) + .build(), + null)) + .build()), + agentCard); + + Task task2 = + new Task.Builder() + .id("task-1") + .contextId("context-1") + .status(new TaskStatus(TaskState.WORKING)) + .build(); + consumer.accept( + new TaskUpdateEvent( + task2, + new TaskStatusUpdateEvent.Builder() + .taskId("task-1") + .contextId("context-1") + .status( + new TaskStatus( + TaskState.WORKING, + new Message.Builder() + .role(Message.Role.AGENT) + .parts(ImmutableList.of(new TextPart("working..."))) + .build(), + null)) + .build()), + agentCard); + + Task task3 = + new Task.Builder() + .id("task-1") + .contextId("context-1") + .status(new TaskStatus(TaskState.COMPLETED)) + .artifacts( + ImmutableList.of( + new Artifact.Builder() + .artifactId("a1") + .parts(ImmutableList.of(new TextPart("done"))) + .build())) + .build(); + consumer.accept(new TaskEvent(task3), agentCard); + }); + + List events = agent.runAsync(invocationContext).toList().blockingGet(); + + assertThat(events).hasSize(3); + + assertText(events.get(0), "submitted..."); + assertThought(events.get(0), true); + assertRequestMetadata(events.get(0)); + assertResponseMetadata(events.get(0)); + + assertText(events.get(1), "working..."); + assertThought(events.get(1), true); + assertRequestMetadata(events.get(1)); + assertResponseMetadata(events.get(1)); + + assertText(events.get(2), "done"); + assertThought(events.get(2), false); + assertRequestMetadata(events.get(2)); + assertResponseMetadata(events.get(2)); + } + + @Test + @SuppressWarnings("unchecked") // cast for Mockito + public void runAsync_constructsRequestWithHistory() { + RemoteA2AAgent agent = createAgent(); + + Session historySession = + Session.builder("session-2") + .appName("demo") + .userId("user") + .events( + ImmutableList.of( + Event.builder() + .id("e1") + .author("user") + .content( + Content.builder() + .role("user") + .parts(ImmutableList.of(Part.builder().text("hello").build())) + .build()) + .build(), + Event.builder() + .id("e2") + .author("model") + .content( + Content.builder() + .role("model") + .parts(ImmutableList.of(Part.builder().text("hi").build())) + .build()) + .build(), + Event.builder() + .id("e3") + .author("user") + .content( + Content.builder() + .role("user") + .parts( + ImmutableList.of(Part.builder().text("how are you?").build())) + .build()) + .build())) + .build(); + + InvocationContext context = + InvocationContext.builder() + .sessionService(new InMemorySessionService()) + .artifactService(new InMemoryArtifactService()) + .pluginManager(new PluginManager()) + .invocationId("invocation-2") + .agent(new TestAgent()) + .session(historySession) + .runConfig(RunConfig.builder().build()) + .build(); + + mockStreamResponse( + consumer -> { + consumer.accept(createFinalEvent("fine"), agentCard); + }); + + var unused = agent.runAsync(context).toList().blockingGet(); + + ArgumentCaptor messageCaptor = ArgumentCaptor.forClass(Message.class); + verify(mockClient) + .sendMessage(messageCaptor.capture(), any(List.class), any(Consumer.class), any()); + Message message = messageCaptor.getValue(); + + assertThat(message.getRole()).isEqualTo(Message.Role.USER); + + assertThat(message.getParts()).hasSize(3); + assertThat(((TextPart) message.getParts().get(0)).getText()).isEqualTo("hello"); + assertThat(((TextPart) message.getParts().get(1)).getText()).isEqualTo("hi"); + assertThat(((TextPart) message.getParts().get(2)).getText()).isEqualTo("how are you?"); + } + + @Test + @SuppressWarnings("unchecked") // cast for Mockito + public void runAsync_constructsRequestWithFunctionResponse() { + RemoteA2AAgent agent = createAgent(); + + Session session = + Session.builder("session-3") + .appName("demo") + .userId("user") + .events( + ImmutableList.of( + Event.builder() + .id("e1") + .author("user") + .content( + Content.builder() + .role("user") + .parts( + ImmutableList.of( + Part.builder() + .functionResponse( + FunctionResponse.builder() + .name("fn") + .id("call-1") + .response(ImmutableMap.of("status", "ok")) + .build()) + .build())) + .build()) + .build())) + .build(); + + InvocationContext context = + InvocationContext.builder() + .sessionService(new InMemorySessionService()) + .artifactService(new InMemoryArtifactService()) + .pluginManager(new PluginManager()) + .invocationId("invocation-3") + .agent(new TestAgent()) + .session(session) + .runConfig(RunConfig.builder().build()) + .build(); + + mockStreamResponse( + consumer -> { + consumer.accept(createFinalEvent("ok"), agentCard); + }); + + var unused = agent.runAsync(context).toList().blockingGet(); + + ArgumentCaptor messageCaptor = ArgumentCaptor.forClass(Message.class); + verify(mockClient) + .sendMessage(messageCaptor.capture(), any(List.class), any(Consumer.class), any()); + Message message = messageCaptor.getValue(); + + assertThat(message.getParts()).hasSize(1); + io.a2a.spec.Part part = message.getParts().get(0); + assertThat(part).isInstanceOf(DataPart.class); + DataPart dataPart = (DataPart) part; + assertThat(dataPart.getData().get("name")).isEqualTo("fn"); + assertThat(dataPart.getData().get("id")).isEqualTo("call-1"); + assertThat(dataPart.getMetadata().get("adk_type")).isEqualTo("function_response"); + } + + @Test + public void runAsync_invokesBeforeAndAfterCallbacks() { + AtomicBoolean beforeCalled = new AtomicBoolean(false); + AtomicBoolean afterCalled = new AtomicBoolean(false); + + RemoteA2AAgent agent = + getAgentBuilder() + .beforeAgentCallback( + ImmutableList.of( + (CallbackContext ctx) -> { + beforeCalled.set(true); + return Maybe.empty(); + })) + .afterAgentCallback( + ImmutableList.of( + (CallbackContext ctx) -> { + afterCalled.set(true); + return Maybe.empty(); + })) + .build(); + + mockStreamResponse( + consumer -> { + consumer.accept(createFinalEvent("done"), agentCard); + }); + + var unused = agent.runAsync(invocationContext).toList().blockingGet(); + + assertThat(beforeCalled.get()).isTrue(); + assertThat(afterCalled.get()).isTrue(); + } + + @Test + public void runAsync_aggregatesCodeExecution() { + RemoteA2AAgent agent = createAgent(); + + mockStreamResponse( + consumer -> { + consumer.accept(createPartialCodeEvent("print('hello')", "python"), agentCard); + consumer.accept(createPartialCodeResultEvent("hello\n", "ok"), agentCard); + consumer.accept(createFinalEvent("Done"), agentCard); + }); + + List events = agent.runAsync(invocationContext).toList().blockingGet(); + + assertThat(events).hasSize(3); + + Part codePart = events.get(0).content().get().parts().get().get(0); + assertThat(codePart.executableCode()).isPresent(); + assertThat(codePart.executableCode().get().code()).hasValue("print('hello')"); + assertThat(codePart.executableCode().get().language().get().toString()).isEqualTo("PYTHON"); + + Part resultPart = events.get(1).content().get().parts().get().get(0); + assertThat(resultPart.codeExecutionResult()).isPresent(); + assertThat(resultPart.codeExecutionResult().get().output()).hasValue("hello\n"); + + assertText(events.get(2), "Done"); + assertRequestMetadata(events.get(2)); + assertResponseMetadata(events.get(2)); + } + + @Test + public void runAsync_beforeCallbackCanShortCircuit() { + Content shortCircuitContent = + Content.builder() + .role("model") + .parts(ImmutableList.of(Part.builder().text("short circuit").build())) + .build(); + + RemoteA2AAgent agent = + getAgentBuilder() + .beforeAgentCallback( + ImmutableList.of( + (CallbackContext ctx) -> Maybe.just(shortCircuitContent))) + .build(); + + List events = agent.runAsync(invocationContext).toList().blockingGet(); + + assertThat(events).hasSize(1); + assertText(events.get(0), "short circuit"); + verifyNoInteractions(mockClient); + } + + @Test + public void runAsync_handlesClientError() { + RemoteA2AAgent agent = createAgent(); + + mockStreamError(new RuntimeException("Connection failed")); + + agent + .runAsync(invocationContext) + .test() + .awaitDone(5, SECONDS) + .assertError(RuntimeException.class) + .assertError( + e -> e.getCause() != null && e.getCause().getMessage().contains("Connection failed")); + } + + private ClientEvent createPartialEvent(String text, boolean append, boolean lastChunk) { + return createTestEvent(new TextPart(text), TaskState.WORKING, append, lastChunk); + } + + private ClientEvent createPartialFunctionCallEvent(String name, String id) { + Map data = new HashMap<>(); + data.put("name", name); + data.put("id", id); + data.put("args", new HashMap<>()); + Map metadata = new HashMap<>(); + metadata.put("adk_type", "function_call"); + + return createTestEvent(new DataPart(data, metadata), TaskState.WORKING, true, false); + } + + private ClientEvent createPartialCodeEvent(String code, String language) { + Map data = new HashMap<>(); + data.put("code", code); + data.put("language", language); + Map metadata = new HashMap<>(); + metadata.put("adk_type", "executable_code"); + + return createTestEvent(new DataPart(data, metadata), TaskState.WORKING, true, false); + } + + private ClientEvent createPartialCodeResultEvent(String output, String outcome) { + Map data = new HashMap<>(); + data.put("output", output); + data.put("outcome", outcome); + Map metadata = new HashMap<>(); + metadata.put("adk_type", "code_execution_result"); + + return createTestEvent(new DataPart(data, metadata), TaskState.WORKING, true, false); + } + + private ClientEvent createPartialFileEvent(String uri, String mimeType) { + return createTestEvent( + new FilePart(new FileWithUri(mimeType, "file", uri)), TaskState.WORKING, true, false); + } + + private ClientEvent createFinalEvent(String text) { + return createTestEvent(new TextPart(text), TaskState.COMPLETED, false, false); + } + + private ClientEvent createTestEvent( + io.a2a.spec.Part part, TaskState state, boolean append, boolean lastChunk) { + Artifact artifact = + new Artifact.Builder().artifactId("artifact-1").parts(ImmutableList.of(part)).build(); + Task task = + new Task.Builder() + .id("task-1") + .contextId("context-1") + .status(new TaskStatus(state)) + .artifacts(ImmutableList.of(artifact)) + .build(); + + if (state == TaskState.COMPLETED && !append && !lastChunk) { + return new TaskEvent(task); + } + + TaskArtifactUpdateEvent updateEvent = + new TaskArtifactUpdateEvent.Builder() + .lastChunk(lastChunk) + .append(append) + .contextId("context-1") + .artifact(artifact) + .taskId("task-id-1") + .build(); + return new TaskUpdateEvent(task, updateEvent); + } + + private RemoteA2AAgent.Builder getAgentBuilder() { + return RemoteA2AAgent.builder().name("remote-agent").a2aClient(mockClient).agentCard(agentCard); + } + + private RemoteA2AAgent createAgent() { + return getAgentBuilder().build(); + } + + @SuppressWarnings("unchecked") // cast for Mockito + private void mockStreamResponse(Consumer> responseProducer) { + doAnswer( + invocation -> { + List> consumers = invocation.getArgument(1); + BiConsumer consumer = consumers.get(0); + responseProducer.accept(consumer); + return null; + }) + .when(mockClient) + .sendMessage(any(Message.class), any(List.class), any(Consumer.class), any()); + } + + @SuppressWarnings("unchecked") // cast for Mockito + private void mockStreamError(Throwable error) { + doAnswer( + invocation -> { + Consumer errorConsumer = invocation.getArgument(2); + errorConsumer.accept(error); + return null; + }) + .when(mockClient) + .sendMessage(any(Message.class), any(List.class), any(Consumer.class), any()); + } + + private void assertText(Event event, String expectedText) { + assertText(event, 0, expectedText); + } + + private void assertText(Event event, int partIndex, String expectedText) { + assertThat(event.content().get().parts().get().get(partIndex).text().orElse("")) + .isEqualTo(expectedText); + } + + private void assertThought(Event event, boolean expected) { + assertThat(event.content().get().parts().get().get(0).thought().orElse(false)) + .isEqualTo(expected); + } + + private void assertAggregated(Event event) { + assertThat(event.customMetadata()).isPresent(); + List metadata = event.customMetadata().get(); + + boolean hasAggregated = + metadata.stream() + .anyMatch( + m -> + A2AMetadata.toA2AMetaKey(A2AMetadata.Key.AGGREGATED).equals(m.key().orElse("")) + && Objects.equals(m.stringValue().orElse(""), "true")); + boolean hasRequest = + metadata.stream() + .anyMatch( + m -> A2AMetadata.toA2AMetaKey(A2AMetadata.Key.REQUEST).equals(m.key().orElse(""))); + + assertThat(hasAggregated).isTrue(); + assertThat(hasRequest).isTrue(); + } + + private void assertRequestMetadata(Event event) { + assertThat(event.customMetadata()).isPresent(); + List metadata = event.customMetadata().get(); + boolean hasRequest = + metadata.stream() + .anyMatch( + m -> A2AMetadata.toA2AMetaKey(A2AMetadata.Key.REQUEST).equals(m.key().orElse(""))); + assertThat(hasRequest).isTrue(); + } + + private void assertResponseMetadata(Event event) { + assertThat(event.customMetadata()).isPresent(); + List metadata = event.customMetadata().get(); + boolean hasResponse = + metadata.stream() + .anyMatch( + m -> A2AMetadata.toA2AMetaKey(A2AMetadata.Key.RESPONSE).equals(m.key().orElse(""))); + assertThat(hasResponse).isTrue(); + } + + private static final class TestAgent extends BaseAgent { + TestAgent() { + super("test_agent", "test", ImmutableList.of(), null, null); + } + + @Override + protected Flowable runAsyncImpl(InvocationContext invocationContext) { + return Flowable.empty(); + } + + @Override + protected Flowable runLiveImpl(InvocationContext invocationContext) { + return Flowable.empty(); + } + } +} diff --git a/core/src/main/java/com/google/adk/events/Event.java b/core/src/main/java/com/google/adk/events/Event.java index 9e05918b..9cbaeddb 100644 --- a/core/src/main/java/com/google/adk/events/Event.java +++ b/core/src/main/java/com/google/adk/events/Event.java @@ -27,6 +27,7 @@ import com.google.common.collect.Iterables; import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.genai.types.Content; +import com.google.genai.types.CustomMetadata; import com.google.genai.types.FinishReason; import com.google.genai.types.FunctionCall; import com.google.genai.types.FunctionResponse; @@ -61,6 +62,7 @@ public class Event extends JsonBaseModel { private Optional interrupted = Optional.empty(); private Optional branch = Optional.empty(); private Optional groundingMetadata = Optional.empty(); + private Optional> customMetadata = Optional.empty(); private Optional modelVersion = Optional.empty(); private long timestamp; @@ -242,6 +244,16 @@ public void setGroundingMetadata(Optional groundingMetadata) this.groundingMetadata = groundingMetadata; } + /** The custom metadata of the event. */ + @JsonProperty("customMetadata") + public Optional> customMetadata() { + return customMetadata; + } + + public void setCustomMetadata(Optional> customMetadata) { + this.customMetadata = customMetadata; + } + /** The model version used to generate the response. */ @JsonProperty("modelVersion") public Optional modelVersion() { @@ -347,6 +359,7 @@ public static class Builder { private Optional interrupted = Optional.empty(); private Optional branch = Optional.empty(); private Optional groundingMetadata = Optional.empty(); + private Optional> customMetadata = Optional.empty(); private Optional modelVersion = Optional.empty(); private Optional timestamp = Optional.empty(); @@ -570,6 +583,23 @@ Optional groundingMetadata() { return groundingMetadata; } + @CanIgnoreReturnValue + @JsonProperty("customMetadata") + public Builder customMetadata(@Nullable List value) { + this.customMetadata = Optional.ofNullable(value); + return this; + } + + @CanIgnoreReturnValue + public Builder customMetadata(Optional> value) { + this.customMetadata = value; + return this; + } + + Optional> customMetadata() { + return customMetadata; + } + @CanIgnoreReturnValue @JsonProperty("modelVersion") public Builder modelVersion(@Nullable String value) { @@ -604,6 +634,7 @@ public Event build() { event.setInterrupted(interrupted); event.branch(branch); event.setGroundingMetadata(groundingMetadata); + event.setCustomMetadata(customMetadata); event.setModelVersion(modelVersion); event.setActions(actions().orElseGet(() -> EventActions.builder().build())); event.setTimestamp(timestamp().orElseGet(() -> Instant.now().toEpochMilli())); @@ -640,6 +671,7 @@ public Builder toBuilder() { .interrupted(this.interrupted) .branch(this.branch) .groundingMetadata(this.groundingMetadata) + .customMetadata(this.customMetadata) .modelVersion(this.modelVersion); if (this.timestamp != 0) { builder.timestamp(this.timestamp); @@ -672,6 +704,7 @@ public boolean equals(Object obj) { && Objects.equals(interrupted, other.interrupted) && Objects.equals(branch, other.branch) && Objects.equals(groundingMetadata, other.groundingMetadata) + && Objects.equals(customMetadata, other.customMetadata) && Objects.equals(modelVersion, other.modelVersion); } @@ -699,6 +732,7 @@ public int hashCode() { interrupted, branch, groundingMetadata, + customMetadata, modelVersion, timestamp); } diff --git a/core/src/main/java/com/google/adk/models/LlmResponse.java b/core/src/main/java/com/google/adk/models/LlmResponse.java index 6f8f3d78..5fe885e4 100644 --- a/core/src/main/java/com/google/adk/models/LlmResponse.java +++ b/core/src/main/java/com/google/adk/models/LlmResponse.java @@ -25,6 +25,7 @@ import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.genai.types.Candidate; import com.google.genai.types.Content; +import com.google.genai.types.CustomMetadata; import com.google.genai.types.FinishReason; import com.google.genai.types.GenerateContentResponse; import com.google.genai.types.GenerateContentResponsePromptFeedback; @@ -59,6 +60,14 @@ public abstract class LlmResponse extends JsonBaseModel { @JsonProperty("groundingMetadata") public abstract Optional groundingMetadata(); + /** + * Returns the custom metadata of the response, if available. + * + * @return An {@link Optional} containing a list of {@link CustomMetadata} or empty. + */ + @JsonProperty("customMetadata") + public abstract Optional> customMetadata(); + /** * Indicates whether the text content is part of a unfinished text stream. * @@ -133,6 +142,11 @@ static LlmResponse.Builder jacksonBuilder() { public abstract Builder groundingMetadata(Optional groundingMetadata); + @JsonProperty("customMetadata") + public abstract Builder customMetadata(List customMetadata); + + public abstract Builder customMetadata(Optional> customMetadata); + @JsonProperty("partial") public abstract Builder partial(@Nullable Boolean partial);