diff --git a/a2a/pom.xml b/a2a/pom.xml
index dc606afa..88014840 100644
--- a/a2a/pom.xml
+++ b/a2a/pom.xml
@@ -106,6 +106,11 @@
${truth.version}
test
+
+ org.mockito
+ mockito-core
+ test
+
diff --git a/a2a/src/main/java/com/google/adk/a2a/RemoteA2AAgent.java b/a2a/src/main/java/com/google/adk/a2a/RemoteA2AAgent.java
index 5e6e341d..bc6b9065 100644
--- a/a2a/src/main/java/com/google/adk/a2a/RemoteA2AAgent.java
+++ b/a2a/src/main/java/com/google/adk/a2a/RemoteA2AAgent.java
@@ -2,7 +2,11 @@
import static com.google.common.base.Strings.nullToEmpty;
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import com.google.adk.a2a.common.A2AClientError;
+import com.google.adk.a2a.common.A2AMetadata;
import com.google.adk.a2a.converters.EventConverter;
import com.google.adk.a2a.converters.ResponseConverter;
import com.google.adk.agents.BaseAgent;
@@ -11,6 +15,9 @@
import com.google.adk.events.Event;
import com.google.common.collect.ImmutableList;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
+import com.google.genai.types.Content;
+import com.google.genai.types.CustomMetadata;
+import com.google.genai.types.Part;
import io.a2a.client.Client;
import io.a2a.client.ClientEvent;
import io.a2a.client.TaskEvent;
@@ -22,8 +29,11 @@
import io.reactivex.rxjava3.core.BackpressureStrategy;
import io.reactivex.rxjava3.core.Flowable;
import io.reactivex.rxjava3.core.FlowableEmitter;
+import java.time.Instant;
+import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
+import java.util.UUID;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;
import org.slf4j.Logger;
@@ -54,6 +64,8 @@
public class RemoteA2AAgent extends BaseAgent {
private static final Logger logger = LoggerFactory.getLogger(RemoteA2AAgent.class);
+ private static final ObjectMapper objectMapper =
+ new ObjectMapper().registerModule(new JavaTimeModule());
private final AgentCard agentCard;
private final Client a2aClient;
@@ -173,60 +185,189 @@ protected Flowable runAsyncImpl(InvocationContext invocationContext) {
}
Message originalMessage = a2aMessageOpt.get();
+ String requestJson;
+ try {
+ requestJson = objectMapper.writeValueAsString(originalMessage);
+ } catch (JsonProcessingException e) {
+ logger.warn("Failed to serialize request", e);
+ requestJson = null;
+ }
+ String finalRequestJson = requestJson;
return Flowable.create(
emitter -> {
- FlowableEmitter flowableEmitter = emitter.serialize();
- AtomicBoolean done = new AtomicBoolean(false);
+ StreamHandler handler =
+ new StreamHandler(emitter.serialize(), invocationContext, finalRequestJson);
ImmutableList> consumers =
- ImmutableList.of(
- (event, unused) ->
- handleClientEvent(event, flowableEmitter, invocationContext, done));
- a2aClient.sendMessage(
- originalMessage, consumers, e -> handleClientError(e, flowableEmitter, done), null);
+ ImmutableList.of(handler::handleEvent);
+ a2aClient.sendMessage(originalMessage, consumers, handler::handleError, null);
},
BackpressureStrategy.BUFFER);
}
- private void handleClientError(Throwable e, FlowableEmitter emitter, AtomicBoolean done) {
- // Mark the flow as done if it is already cancelled.
- done.compareAndSet(false, emitter.isCancelled());
+ private class StreamHandler {
+ private final FlowableEmitter emitter;
+ private final InvocationContext invocationContext;
+ private final String requestJson;
+ private final AtomicBoolean done = new AtomicBoolean(false);
+ private final StringBuilder textBuffer = new StringBuilder();
+ private final StringBuilder thoughtsBuffer = new StringBuilder();
+
+ StreamHandler(
+ FlowableEmitter emitter, InvocationContext invocationContext, String requestJson) {
+ this.emitter = emitter;
+ this.invocationContext = invocationContext;
+ this.requestJson = requestJson;
+ }
+
+ void handleError(Throwable e) {
+ // Mark the flow as done if it is already cancelled.
+ done.compareAndSet(false, emitter.isCancelled());
- // If the flow is already done, stop processing and exit the consumer.
- if (done.get()) {
- return;
+ // If the flow is already done, stop processing.
+ if (done.get()) {
+ return;
+ }
+ // If the error is raised, complete the flow with an error.
+ if (!done.getAndSet(true)) {
+ emitter.tryOnError(new A2AClientError("Failed to communicate with the remote agent", e));
+ }
}
- // If the error is raised, complete the flow with an error.
- if (!done.getAndSet(true)) {
- emitter.tryOnError(new A2AClientError("Failed to communicate with the remote agent", e));
+
+ void handleEvent(ClientEvent clientEvent, AgentCard unused) {
+ // Mark the flow as done if it is already cancelled.
+ done.compareAndSet(false, emitter.isCancelled());
+
+ // If the flow is already done, stop processing.
+ if (done.get()) {
+ return;
+ }
+
+ Optional eventOpt =
+ ResponseConverter.clientEventToEvent(clientEvent, invocationContext);
+ if (eventOpt.isPresent()) {
+ Event event = eventOpt.get();
+ enrichWithMetadata(event, clientEvent);
+ boolean consumed = processContent(event);
+ if (!consumed) {
+ emitEvents(event);
+ }
+ }
+
+ // For non-streaming communication, complete the flow; for streaming, wait until the client
+ // marks the completion.
+ if (isCompleted(clientEvent) || !streaming) {
+ // Only complete the flow once.
+ if (!done.getAndSet(true)) {
+ emitter.onComplete();
+ }
+ }
}
- }
- private void handleClientEvent(
- ClientEvent clientEvent,
- FlowableEmitter emitter,
- InvocationContext invocationContext,
- AtomicBoolean done) {
- // Mark the flow as done if it is already cancelled.
- done.compareAndSet(false, emitter.isCancelled());
-
- // If the flow is already done, stop processing and exit the consumer.
- if (done.get()) {
- return;
+ private void enrichWithMetadata(Event event, ClientEvent clientEvent) {
+ List eventMetadata =
+ new ArrayList<>(event.customMetadata().orElse(ImmutableList.of()));
+ if (requestJson != null) {
+ eventMetadata.add(
+ CustomMetadata.builder()
+ .key(A2AMetadata.toA2AMetaKey(A2AMetadata.Key.REQUEST))
+ .stringValue(requestJson)
+ .build());
+ }
+ try {
+ if (clientEvent != null) {
+ eventMetadata.add(
+ CustomMetadata.builder()
+ .key(A2AMetadata.toA2AMetaKey(A2AMetadata.Key.RESPONSE))
+ .stringValue(objectMapper.writeValueAsString(clientEvent))
+ .build());
+ }
+ } catch (JsonProcessingException e) {
+ logger.warn("Failed to serialize response metadata", e);
+ }
+ event.setCustomMetadata(Optional.of(ImmutableList.copyOf(eventMetadata)));
}
- Optional event = ResponseConverter.clientEventToEvent(clientEvent, invocationContext);
- if (event.isPresent()) {
- emitter.onNext(event.get());
+ private boolean processContent(Event event) {
+ if (!event.partial().orElse(false)) {
+ return false;
+ }
+
+ List nonTextParts = new ArrayList<>();
+ for (Part part : event.content().flatMap(Content::parts).orElse(ImmutableList.of())) {
+ if (part.text().isPresent()) {
+ String t = part.text().get();
+ if (part.thought().orElse(false)) {
+ thoughtsBuffer.append(t);
+ } else {
+ textBuffer.append(t);
+ }
+ } else {
+ nonTextParts.add(part);
+ }
+ }
+
+ if (nonTextParts.isEmpty()) {
+ return true;
+ }
+
+ Content nonTextContent = Content.builder().role("model").parts(nonTextParts).build();
+ event.setContent(Optional.of(nonTextContent));
+ return false;
}
- // For non-streaming communication, complete the flow; for streaming, wait until the client
- // marks the completion.
- if (isCompleted(clientEvent) || !streaming) {
- // Only complete the flow once.
- if (!done.getAndSet(true)) {
- emitter.onComplete();
+ private void emitEvents(Event event) {
+ List parts = new ArrayList<>();
+ if (thoughtsBuffer.length() > 0) {
+ parts.add(Part.builder().thought(true).text(thoughtsBuffer.toString()).build());
+ thoughtsBuffer.setLength(0);
}
+ if (textBuffer.length() > 0) {
+ parts.add(Part.builder().text(textBuffer.toString()).build());
+ textBuffer.setLength(0);
+ }
+
+ if (!parts.isEmpty()) {
+ Content aggregatedContent = Content.builder().role("model").parts(parts).build();
+
+ if (event.content().flatMap(Content::parts).orElse(ImmutableList.of()).isEmpty()) {
+ // Reuse empty event for aggregated content.
+ event.setContent(Optional.of(aggregatedContent));
+ emitter.onNext(event);
+ } else {
+ // Emit separate aggregated event first.
+ Event aggEvent = createAggregatedEvent(aggregatedContent);
+ emitter.onNext(aggEvent);
+ emitter.onNext(event);
+ }
+ } else {
+ emitter.onNext(event);
+ }
+ }
+
+ private Event createAggregatedEvent(Content content) {
+ List aggMetadata = new ArrayList<>();
+ aggMetadata.add(
+ CustomMetadata.builder()
+ .key(A2AMetadata.toA2AMetaKey(A2AMetadata.Key.AGGREGATED))
+ .stringValue("true")
+ .build());
+ if (requestJson != null) {
+ aggMetadata.add(
+ CustomMetadata.builder()
+ .key(A2AMetadata.toA2AMetaKey(A2AMetadata.Key.REQUEST))
+ .stringValue(requestJson)
+ .build());
+ }
+
+ return Event.builder()
+ .id(UUID.randomUUID().toString())
+ .invocationId(invocationContext.invocationId())
+ .author("agent")
+ .content(Optional.of(content))
+ .timestamp(Instant.now().toEpochMilli())
+ .customMetadata(Optional.of(ImmutableList.copyOf(aggMetadata)))
+ .build();
}
}
diff --git a/a2a/src/main/java/com/google/adk/a2a/common/A2AMetadata.java b/a2a/src/main/java/com/google/adk/a2a/common/A2AMetadata.java
new file mode 100644
index 00000000..ff32fcaf
--- /dev/null
+++ b/a2a/src/main/java/com/google/adk/a2a/common/A2AMetadata.java
@@ -0,0 +1,28 @@
+package com.google.adk.a2a.common;
+
+/** Constants and utilities for A2A metadata keys. */
+public final class A2AMetadata {
+
+ /** Enum for A2A custom metadata keys. */
+ public enum Key {
+ REQUEST("request"),
+ RESPONSE("response"),
+ AGGREGATED("aggregated");
+
+ private final String value;
+
+ Key(String value) {
+ this.value = value;
+ }
+
+ public String getValue() {
+ return value;
+ }
+ }
+
+ public static String toA2AMetaKey(Key key) {
+ return "a2a:" + key.value;
+ }
+
+ private A2AMetadata() {}
+}
diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java b/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java
index c6ef0640..5513bff0 100644
--- a/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java
+++ b/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java
@@ -7,10 +7,14 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.genai.types.Blob;
+import com.google.genai.types.CodeExecutionResult;
+import com.google.genai.types.ExecutableCode;
import com.google.genai.types.Content;
import com.google.genai.types.FileData;
import com.google.genai.types.FunctionCall;
import com.google.genai.types.FunctionResponse;
+import com.google.genai.types.Language;
+import com.google.genai.types.Outcome;
import com.google.genai.types.Part;
import io.a2a.spec.DataPart;
import io.a2a.spec.FileContent;
@@ -22,6 +26,7 @@
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
+import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import org.slf4j.Logger;
@@ -95,6 +100,10 @@ public static Optional convertGenaiPartToA2aPart(Part part) {
return createDataPartFromFunctionCall(part.functionCall().get());
} else if (part.functionResponse().isPresent()) {
return createDataPartFromFunctionResponse(part.functionResponse().get());
+ } else if (part.executableCode().isPresent()) {
+ return createDataPartFromExecutableCode(part.executableCode().get());
+ } else if (part.codeExecutionResult().isPresent()) {
+ return createDataPartFromCodeExecutionResult(part.codeExecutionResult().get());
}
logger.warn("Cannot convert unsupported part for Google GenAI part: " + part);
@@ -174,6 +183,33 @@ private static Optional convertDataPartToGenAiPart(
.build());
}
+ if ((data.containsKey("code") && data.containsKey("language"))
+ || metadataType.equals(A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE)) {
+ String code = String.valueOf(data.getOrDefault("code", ""));
+ String language =
+ String.valueOf(data.getOrDefault("language", "PYTHON")).toUpperCase(Locale.getDefault());
+ return Optional.of(
+ com.google.genai.types.Part.builder()
+ .executableCode(
+ ExecutableCode.builder().code(code).language(new Language(language)).build())
+ .build());
+ }
+
+ if ((data.containsKey("outcome") && data.containsKey("output"))
+ || metadataType.equals(A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT)) {
+ String outcome =
+ String.valueOf(data.getOrDefault("outcome", "OK")).toUpperCase(Locale.getDefault());
+ String output = String.valueOf(data.getOrDefault("output", ""));
+ return Optional.of(
+ com.google.genai.types.Part.builder()
+ .codeExecutionResult(
+ CodeExecutionResult.builder()
+ .outcome(new Outcome(outcome))
+ .output(output)
+ .build())
+ .build());
+ }
+
try {
String json = objectMapper.writeValueAsString(data);
return Optional.of(com.google.genai.types.Part.builder().text(json).build());
@@ -231,6 +267,32 @@ private static Optional createDataPartFromFunctionResponse(
return Optional.of(new DataPart(data, metadata));
}
+ private static Optional createDataPartFromExecutableCode(
+ ExecutableCode executableCode) {
+ Map data = new HashMap<>();
+ data.put("code", executableCode.code().orElse(""));
+ data.put("language", executableCode.language().map(Language::toString).orElse("PYTHON"));
+
+ ImmutableMap metadata =
+ ImmutableMap.of(
+ A2A_DATA_PART_METADATA_TYPE_KEY, A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE);
+
+ return Optional.of(new DataPart(data, metadata));
+ }
+
+ private static Optional createDataPartFromCodeExecutionResult(
+ CodeExecutionResult result) {
+ Map data = new HashMap<>();
+ data.put("outcome", result.outcome().map(Outcome::toString).orElse("OK"));
+ data.put("output", result.output().orElse(""));
+
+ ImmutableMap metadata =
+ ImmutableMap.of(
+ A2A_DATA_PART_METADATA_TYPE_KEY, A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT);
+
+ return Optional.of(new DataPart(data, metadata));
+ }
+
private PartConverter() {}
/** Convert a GenAI part into the A2A JSON representation. */
@@ -260,7 +322,10 @@ public static Optional> fromGenaiPart(Part part) {
return Optional.of(new FilePart(new FileWithBytes(mime, name, encoded), new HashMap<>()));
}
- if (part.functionCall().isPresent() || part.functionResponse().isPresent()) {
+ if (part.functionCall().isPresent()
+ || part.functionResponse().isPresent()
+ || part.executableCode().isPresent()
+ || part.codeExecutionResult().isPresent()) {
return convertGenaiPartToA2aPart(part).map(data -> data);
}
diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java b/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java
index 785ce6f3..73e53cea 100644
--- a/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java
+++ b/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java
@@ -13,6 +13,7 @@
import io.a2a.client.MessageEvent;
import io.a2a.client.TaskEvent;
import io.a2a.client.TaskUpdateEvent;
+import io.a2a.spec.Artifact;
import io.a2a.spec.EventKind;
import io.a2a.spec.JSONRPCError;
import io.a2a.spec.Message;
@@ -189,8 +190,11 @@ private static Optional handleTaskUpdate(
var updateEvent = event.getUpdateEvent();
if (updateEvent instanceof TaskArtifactUpdateEvent artifactEvent) {
- if (Objects.equals(artifactEvent.isAppend(), false)
- || Objects.equals(artifactEvent.isLastChunk(), true)) {
+ if (Objects.equals(artifactEvent.isAppend(), true)) {
+ Event eventPart = artifactToEvent(artifactEvent.getArtifact(), context);
+ eventPart.setPartial(Optional.of(true));
+ return Optional.of(eventPart);
+ } else if (Objects.equals(artifactEvent.isLastChunk(), true)) {
return Optional.of(taskToEvent(event.getTask(), context));
}
return Optional.empty();
@@ -209,6 +213,13 @@ private static Optional handleTaskUpdate(
"Unsupported TaskUpdateEvent type: " + updateEvent.getClass());
}
+ /** Converts an artifact to an ADK event. */
+ public static Event artifactToEvent(Artifact artifact, InvocationContext invocationContext) {
+ Message message =
+ new Message.Builder().role(Message.Role.AGENT).parts(artifact.parts()).build();
+ return messageToEvent(message, invocationContext);
+ }
+
/** Converts an A2A message back to ADK events. */
public static Event messageToEvent(Message message, InvocationContext invocationContext) {
return remoteAgentEventBuilder(invocationContext)
diff --git a/a2a/src/test/java/com/google/adk/a2a/RemoteA2AAgentTest.java b/a2a/src/test/java/com/google/adk/a2a/RemoteA2AAgentTest.java
new file mode 100644
index 00000000..7d6af9bf
--- /dev/null
+++ b/a2a/src/test/java/com/google/adk/a2a/RemoteA2AAgentTest.java
@@ -0,0 +1,768 @@
+package com.google.adk.a2a;
+
+import static com.google.common.truth.Truth.assertThat;
+import static java.util.concurrent.TimeUnit.SECONDS;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoInteractions;
+import static org.mockito.Mockito.when;
+
+import com.google.adk.a2a.common.A2AMetadata;
+import com.google.adk.agents.BaseAgent;
+import com.google.adk.agents.CallbackContext;
+import com.google.adk.agents.Callbacks;
+import com.google.adk.agents.InvocationContext;
+import com.google.adk.agents.RunConfig;
+import com.google.adk.artifacts.InMemoryArtifactService;
+import com.google.adk.events.Event;
+import com.google.adk.plugins.PluginManager;
+import com.google.adk.sessions.InMemorySessionService;
+import com.google.adk.sessions.Session;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.genai.types.Content;
+import com.google.genai.types.CustomMetadata;
+import com.google.genai.types.FunctionResponse;
+import com.google.genai.types.Part;
+import io.a2a.client.Client;
+import io.a2a.client.ClientEvent;
+import io.a2a.client.TaskEvent;
+import io.a2a.client.TaskUpdateEvent;
+import io.a2a.spec.AgentCapabilities;
+import io.a2a.spec.AgentCard;
+import io.a2a.spec.Artifact;
+import io.a2a.spec.DataPart;
+import io.a2a.spec.FilePart;
+import io.a2a.spec.FileWithUri;
+import io.a2a.spec.Message;
+import io.a2a.spec.Task;
+import io.a2a.spec.TaskArtifactUpdateEvent;
+import io.a2a.spec.TaskState;
+import io.a2a.spec.TaskStatus;
+import io.a2a.spec.TaskStatusUpdateEvent;
+import io.a2a.spec.TextPart;
+import io.reactivex.rxjava3.core.Flowable;
+import io.reactivex.rxjava3.core.Maybe;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.function.BiConsumer;
+import java.util.function.Consumer;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.ArgumentCaptor;
+
+@RunWith(JUnit4.class)
+public final class RemoteA2AAgentTest {
+
+ private Client mockClient;
+ private AgentCard agentCard;
+ private InvocationContext invocationContext;
+ private Session session;
+
+ @Before
+ public void setUp() {
+ mockClient = mock(Client.class);
+ agentCard =
+ new AgentCard.Builder()
+ .name("remote-agent")
+ .description("Remote Agent")
+ .version("1.0.0")
+ .url("http://example.com")
+ .capabilities(new AgentCapabilities.Builder().streaming(true).build())
+ .defaultInputModes(ImmutableList.of("text"))
+ .defaultOutputModes(ImmutableList.of("text"))
+ .skills(ImmutableList.of())
+ .build();
+
+ when(mockClient.getAgentCard()).thenReturn(agentCard);
+
+ session =
+ Session.builder("session-1")
+ .appName("demo")
+ .userId("user")
+ .events(
+ ImmutableList.of(
+ Event.builder()
+ .id("event-1")
+ .author("user")
+ .content(
+ Content.builder()
+ .role("user")
+ .parts(ImmutableList.of(Part.builder().text("Hello").build()))
+ .build())
+ .build()))
+ .build();
+
+ invocationContext =
+ InvocationContext.builder()
+ .sessionService(new InMemorySessionService())
+ .artifactService(new InMemoryArtifactService())
+ .pluginManager(new PluginManager())
+ .invocationId("invocation-1")
+ .agent(new TestAgent())
+ .session(session)
+ .runConfig(RunConfig.builder().build())
+ .endInvocation(false)
+ .build();
+ }
+
+ @Test
+ public void runAsync_aggregatesPartialEvents() {
+ RemoteA2AAgent agent = createAgent();
+
+ mockStreamResponse(
+ consumer -> {
+ consumer.accept(createPartialEvent("Hello ", true, false), agentCard);
+ consumer.accept(createPartialEvent("World!", true, false), agentCard);
+ consumer.accept(createFinalEvent("Final artifact content"), agentCard);
+ });
+
+ List events = agent.runAsync(invocationContext).toList().blockingGet();
+
+ assertThat(events).hasSize(2);
+
+ Event aggregatedEvent = events.get(0);
+ assertThat(aggregatedEvent.content().get().parts().get()).hasSize(1);
+ assertThought(aggregatedEvent, false);
+ assertText(aggregatedEvent, "Hello World!");
+ assertAggregated(aggregatedEvent);
+
+ Event finalEvent = events.get(1);
+ assertText(finalEvent, "Final artifact content");
+ assertRequestMetadata(finalEvent);
+ assertResponseMetadata(finalEvent);
+ }
+
+ @Test
+ public void runAsync_aggregatesInterleavedFunctionCalls() {
+ RemoteA2AAgent agent = createAgent();
+
+ mockStreamResponse(
+ consumer -> {
+ consumer.accept(createPartialEvent("Hello ", true, false), agentCard);
+ consumer.accept(createPartialFunctionCallEvent("get_weather", "call_1"), agentCard);
+ consumer.accept(createPartialEvent("World!", true, false), agentCard);
+ consumer.accept(createFinalEvent("Final"), agentCard);
+ });
+
+ List events = agent.runAsync(invocationContext).toList().blockingGet();
+
+ assertThat(events).hasSize(4);
+
+ assertText(events.get(0), "Hello ");
+ assertAggregated(events.get(0));
+
+ assertThat(events.get(1).content().get().parts().get().get(0).functionCall()).isPresent();
+ assertThat(
+ events
+ .get(1)
+ .content()
+ .get()
+ .parts()
+ .get()
+ .get(0)
+ .functionCall()
+ .get()
+ .name()
+ .orElse(""))
+ .isEqualTo("get_weather");
+
+ assertText(events.get(2), "World!");
+ assertAggregated(events.get(2));
+ assertText(events.get(3), "Final");
+ assertRequestMetadata(events.get(3));
+ assertResponseMetadata(events.get(3));
+ }
+
+ @Test
+ public void runAsync_aggregatesFiles() {
+ RemoteA2AAgent agent = createAgent();
+
+ mockStreamResponse(
+ consumer -> {
+ consumer.accept(createPartialEvent("Here is a file: ", true, false), agentCard);
+ consumer.accept(
+ createPartialFileEvent("http://example.com/file.txt", "text/plain"), agentCard);
+ consumer.accept(createFinalEvent("Done"), agentCard);
+ });
+
+ List events = agent.runAsync(invocationContext).toList().blockingGet();
+
+ assertThat(events).hasSize(3);
+ assertText(events.get(0), "Here is a file: ");
+
+ Part filePart = events.get(1).content().get().parts().get().get(0);
+ assertThat(filePart.fileData()).isPresent();
+ assertThat(filePart.fileData().get().fileUri().orElse(""))
+ .isEqualTo("http://example.com/file.txt");
+ assertRequestMetadata(events.get(1));
+ assertResponseMetadata(events.get(1));
+
+ assertText(events.get(2), "Done");
+ assertRequestMetadata(events.get(2));
+ assertResponseMetadata(events.get(2));
+ }
+
+ @Test
+ public void runAsync_handlesTasksWithStatusMessage() {
+ RemoteA2AAgent agent = createAgent();
+
+ mockStreamResponse(
+ consumer -> {
+ Task task =
+ new Task.Builder()
+ .id("task-1")
+ .contextId("context-1")
+ .status(
+ new TaskStatus(
+ TaskState.COMPLETED,
+ new Message.Builder()
+ .role(Message.Role.AGENT)
+ .parts(ImmutableList.of(new TextPart("hello")))
+ .build(),
+ null))
+ .build();
+
+ consumer.accept(new TaskEvent(task), agentCard);
+ });
+
+ List events = agent.runAsync(invocationContext).toList().blockingGet();
+
+ assertThat(events).hasSize(1);
+ assertText(events.get(0), "hello");
+ assertRequestMetadata(events.get(0));
+ assertResponseMetadata(events.get(0));
+ }
+
+ @Test
+ public void runAsync_handlesTasksWithMultipartArtifact() {
+ RemoteA2AAgent agent = createAgent();
+
+ mockStreamResponse(
+ consumer -> {
+ Artifact artifact =
+ new Artifact.Builder()
+ .artifactId("artifact-1")
+ .parts(ImmutableList.of(new TextPart("hello"), new TextPart("world")))
+ .build();
+ Task task =
+ new Task.Builder()
+ .id("task-1")
+ .contextId("context-1")
+ .status(new TaskStatus(TaskState.COMPLETED))
+ .artifacts(ImmutableList.of(artifact))
+ .build();
+
+ consumer.accept(new TaskEvent(task), agentCard);
+ });
+
+ List events = agent.runAsync(invocationContext).toList().blockingGet();
+
+ assertThat(events).hasSize(1);
+ assertThat(events.get(0).content().get().parts().get()).hasSize(2);
+ assertText(events.get(0), 0, "hello");
+ assertText(events.get(0), 1, "world");
+ assertRequestMetadata(events.get(0));
+ assertResponseMetadata(events.get(0));
+ }
+
+ @Test
+ public void runAsync_handlesNonFinalStatusUpdatesAsThoughts() {
+ RemoteA2AAgent agent = createAgent();
+
+ mockStreamResponse(
+ consumer -> {
+ Task task1 =
+ new Task.Builder()
+ .id("task-1")
+ .contextId("context-1")
+ .status(new TaskStatus(TaskState.SUBMITTED))
+ .build();
+ consumer.accept(
+ new TaskUpdateEvent(
+ task1,
+ new TaskStatusUpdateEvent.Builder()
+ .taskId("task-1")
+ .contextId("context-1")
+ .status(
+ new TaskStatus(
+ TaskState.SUBMITTED,
+ new Message.Builder()
+ .role(Message.Role.AGENT)
+ .parts(ImmutableList.of(new TextPart("submitted...")))
+ .build(),
+ null))
+ .build()),
+ agentCard);
+
+ Task task2 =
+ new Task.Builder()
+ .id("task-1")
+ .contextId("context-1")
+ .status(new TaskStatus(TaskState.WORKING))
+ .build();
+ consumer.accept(
+ new TaskUpdateEvent(
+ task2,
+ new TaskStatusUpdateEvent.Builder()
+ .taskId("task-1")
+ .contextId("context-1")
+ .status(
+ new TaskStatus(
+ TaskState.WORKING,
+ new Message.Builder()
+ .role(Message.Role.AGENT)
+ .parts(ImmutableList.of(new TextPart("working...")))
+ .build(),
+ null))
+ .build()),
+ agentCard);
+
+ Task task3 =
+ new Task.Builder()
+ .id("task-1")
+ .contextId("context-1")
+ .status(new TaskStatus(TaskState.COMPLETED))
+ .artifacts(
+ ImmutableList.of(
+ new Artifact.Builder()
+ .artifactId("a1")
+ .parts(ImmutableList.of(new TextPart("done")))
+ .build()))
+ .build();
+ consumer.accept(new TaskEvent(task3), agentCard);
+ });
+
+ List events = agent.runAsync(invocationContext).toList().blockingGet();
+
+ assertThat(events).hasSize(3);
+
+ assertText(events.get(0), "submitted...");
+ assertThought(events.get(0), true);
+ assertRequestMetadata(events.get(0));
+ assertResponseMetadata(events.get(0));
+
+ assertText(events.get(1), "working...");
+ assertThought(events.get(1), true);
+ assertRequestMetadata(events.get(1));
+ assertResponseMetadata(events.get(1));
+
+ assertText(events.get(2), "done");
+ assertThought(events.get(2), false);
+ assertRequestMetadata(events.get(2));
+ assertResponseMetadata(events.get(2));
+ }
+
+ @Test
+ @SuppressWarnings("unchecked") // cast for Mockito
+ public void runAsync_constructsRequestWithHistory() {
+ RemoteA2AAgent agent = createAgent();
+
+ Session historySession =
+ Session.builder("session-2")
+ .appName("demo")
+ .userId("user")
+ .events(
+ ImmutableList.of(
+ Event.builder()
+ .id("e1")
+ .author("user")
+ .content(
+ Content.builder()
+ .role("user")
+ .parts(ImmutableList.of(Part.builder().text("hello").build()))
+ .build())
+ .build(),
+ Event.builder()
+ .id("e2")
+ .author("model")
+ .content(
+ Content.builder()
+ .role("model")
+ .parts(ImmutableList.of(Part.builder().text("hi").build()))
+ .build())
+ .build(),
+ Event.builder()
+ .id("e3")
+ .author("user")
+ .content(
+ Content.builder()
+ .role("user")
+ .parts(
+ ImmutableList.of(Part.builder().text("how are you?").build()))
+ .build())
+ .build()))
+ .build();
+
+ InvocationContext context =
+ InvocationContext.builder()
+ .sessionService(new InMemorySessionService())
+ .artifactService(new InMemoryArtifactService())
+ .pluginManager(new PluginManager())
+ .invocationId("invocation-2")
+ .agent(new TestAgent())
+ .session(historySession)
+ .runConfig(RunConfig.builder().build())
+ .build();
+
+ mockStreamResponse(
+ consumer -> {
+ consumer.accept(createFinalEvent("fine"), agentCard);
+ });
+
+ var unused = agent.runAsync(context).toList().blockingGet();
+
+ ArgumentCaptor messageCaptor = ArgumentCaptor.forClass(Message.class);
+ verify(mockClient)
+ .sendMessage(messageCaptor.capture(), any(List.class), any(Consumer.class), any());
+ Message message = messageCaptor.getValue();
+
+ assertThat(message.getRole()).isEqualTo(Message.Role.USER);
+
+ assertThat(message.getParts()).hasSize(3);
+ assertThat(((TextPart) message.getParts().get(0)).getText()).isEqualTo("hello");
+ assertThat(((TextPart) message.getParts().get(1)).getText()).isEqualTo("hi");
+ assertThat(((TextPart) message.getParts().get(2)).getText()).isEqualTo("how are you?");
+ }
+
+ @Test
+ @SuppressWarnings("unchecked") // cast for Mockito
+ public void runAsync_constructsRequestWithFunctionResponse() {
+ RemoteA2AAgent agent = createAgent();
+
+ Session session =
+ Session.builder("session-3")
+ .appName("demo")
+ .userId("user")
+ .events(
+ ImmutableList.of(
+ Event.builder()
+ .id("e1")
+ .author("user")
+ .content(
+ Content.builder()
+ .role("user")
+ .parts(
+ ImmutableList.of(
+ Part.builder()
+ .functionResponse(
+ FunctionResponse.builder()
+ .name("fn")
+ .id("call-1")
+ .response(ImmutableMap.of("status", "ok"))
+ .build())
+ .build()))
+ .build())
+ .build()))
+ .build();
+
+ InvocationContext context =
+ InvocationContext.builder()
+ .sessionService(new InMemorySessionService())
+ .artifactService(new InMemoryArtifactService())
+ .pluginManager(new PluginManager())
+ .invocationId("invocation-3")
+ .agent(new TestAgent())
+ .session(session)
+ .runConfig(RunConfig.builder().build())
+ .build();
+
+ mockStreamResponse(
+ consumer -> {
+ consumer.accept(createFinalEvent("ok"), agentCard);
+ });
+
+ var unused = agent.runAsync(context).toList().blockingGet();
+
+ ArgumentCaptor messageCaptor = ArgumentCaptor.forClass(Message.class);
+ verify(mockClient)
+ .sendMessage(messageCaptor.capture(), any(List.class), any(Consumer.class), any());
+ Message message = messageCaptor.getValue();
+
+ assertThat(message.getParts()).hasSize(1);
+ io.a2a.spec.Part> part = message.getParts().get(0);
+ assertThat(part).isInstanceOf(DataPart.class);
+ DataPart dataPart = (DataPart) part;
+ assertThat(dataPart.getData().get("name")).isEqualTo("fn");
+ assertThat(dataPart.getData().get("id")).isEqualTo("call-1");
+ assertThat(dataPart.getMetadata().get("adk_type")).isEqualTo("function_response");
+ }
+
+ @Test
+ public void runAsync_invokesBeforeAndAfterCallbacks() {
+ AtomicBoolean beforeCalled = new AtomicBoolean(false);
+ AtomicBoolean afterCalled = new AtomicBoolean(false);
+
+ RemoteA2AAgent agent =
+ getAgentBuilder()
+ .beforeAgentCallback(
+ ImmutableList.of(
+ (CallbackContext ctx) -> {
+ beforeCalled.set(true);
+ return Maybe.empty();
+ }))
+ .afterAgentCallback(
+ ImmutableList.of(
+ (CallbackContext ctx) -> {
+ afterCalled.set(true);
+ return Maybe.empty();
+ }))
+ .build();
+
+ mockStreamResponse(
+ consumer -> {
+ consumer.accept(createFinalEvent("done"), agentCard);
+ });
+
+ var unused = agent.runAsync(invocationContext).toList().blockingGet();
+
+ assertThat(beforeCalled.get()).isTrue();
+ assertThat(afterCalled.get()).isTrue();
+ }
+
+ @Test
+ public void runAsync_aggregatesCodeExecution() {
+ RemoteA2AAgent agent = createAgent();
+
+ mockStreamResponse(
+ consumer -> {
+ consumer.accept(createPartialCodeEvent("print('hello')", "python"), agentCard);
+ consumer.accept(createPartialCodeResultEvent("hello\n", "ok"), agentCard);
+ consumer.accept(createFinalEvent("Done"), agentCard);
+ });
+
+ List events = agent.runAsync(invocationContext).toList().blockingGet();
+
+ assertThat(events).hasSize(3);
+
+ Part codePart = events.get(0).content().get().parts().get().get(0);
+ assertThat(codePart.executableCode()).isPresent();
+ assertThat(codePart.executableCode().get().code()).hasValue("print('hello')");
+ assertThat(codePart.executableCode().get().language().get().toString()).isEqualTo("PYTHON");
+
+ Part resultPart = events.get(1).content().get().parts().get().get(0);
+ assertThat(resultPart.codeExecutionResult()).isPresent();
+ assertThat(resultPart.codeExecutionResult().get().output()).hasValue("hello\n");
+
+ assertText(events.get(2), "Done");
+ assertRequestMetadata(events.get(2));
+ assertResponseMetadata(events.get(2));
+ }
+
+ @Test
+ public void runAsync_beforeCallbackCanShortCircuit() {
+ Content shortCircuitContent =
+ Content.builder()
+ .role("model")
+ .parts(ImmutableList.of(Part.builder().text("short circuit").build()))
+ .build();
+
+ RemoteA2AAgent agent =
+ getAgentBuilder()
+ .beforeAgentCallback(
+ ImmutableList.of(
+ (CallbackContext ctx) -> Maybe.just(shortCircuitContent)))
+ .build();
+
+ List events = agent.runAsync(invocationContext).toList().blockingGet();
+
+ assertThat(events).hasSize(1);
+ assertText(events.get(0), "short circuit");
+ verifyNoInteractions(mockClient);
+ }
+
+ @Test
+ public void runAsync_handlesClientError() {
+ RemoteA2AAgent agent = createAgent();
+
+ mockStreamError(new RuntimeException("Connection failed"));
+
+ agent
+ .runAsync(invocationContext)
+ .test()
+ .awaitDone(5, SECONDS)
+ .assertError(RuntimeException.class)
+ .assertError(
+ e -> e.getCause() != null && e.getCause().getMessage().contains("Connection failed"));
+ }
+
+ private ClientEvent createPartialEvent(String text, boolean append, boolean lastChunk) {
+ return createTestEvent(new TextPart(text), TaskState.WORKING, append, lastChunk);
+ }
+
+ private ClientEvent createPartialFunctionCallEvent(String name, String id) {
+ Map data = new HashMap<>();
+ data.put("name", name);
+ data.put("id", id);
+ data.put("args", new HashMap<>());
+ Map metadata = new HashMap<>();
+ metadata.put("adk_type", "function_call");
+
+ return createTestEvent(new DataPart(data, metadata), TaskState.WORKING, true, false);
+ }
+
+ private ClientEvent createPartialCodeEvent(String code, String language) {
+ Map data = new HashMap<>();
+ data.put("code", code);
+ data.put("language", language);
+ Map metadata = new HashMap<>();
+ metadata.put("adk_type", "executable_code");
+
+ return createTestEvent(new DataPart(data, metadata), TaskState.WORKING, true, false);
+ }
+
+ private ClientEvent createPartialCodeResultEvent(String output, String outcome) {
+ Map data = new HashMap<>();
+ data.put("output", output);
+ data.put("outcome", outcome);
+ Map metadata = new HashMap<>();
+ metadata.put("adk_type", "code_execution_result");
+
+ return createTestEvent(new DataPart(data, metadata), TaskState.WORKING, true, false);
+ }
+
+ private ClientEvent createPartialFileEvent(String uri, String mimeType) {
+ return createTestEvent(
+ new FilePart(new FileWithUri(mimeType, "file", uri)), TaskState.WORKING, true, false);
+ }
+
+ private ClientEvent createFinalEvent(String text) {
+ return createTestEvent(new TextPart(text), TaskState.COMPLETED, false, false);
+ }
+
+ private ClientEvent createTestEvent(
+ io.a2a.spec.Part> part, TaskState state, boolean append, boolean lastChunk) {
+ Artifact artifact =
+ new Artifact.Builder().artifactId("artifact-1").parts(ImmutableList.of(part)).build();
+ Task task =
+ new Task.Builder()
+ .id("task-1")
+ .contextId("context-1")
+ .status(new TaskStatus(state))
+ .artifacts(ImmutableList.of(artifact))
+ .build();
+
+ if (state == TaskState.COMPLETED && !append && !lastChunk) {
+ return new TaskEvent(task);
+ }
+
+ TaskArtifactUpdateEvent updateEvent =
+ new TaskArtifactUpdateEvent.Builder()
+ .lastChunk(lastChunk)
+ .append(append)
+ .contextId("context-1")
+ .artifact(artifact)
+ .taskId("task-id-1")
+ .build();
+ return new TaskUpdateEvent(task, updateEvent);
+ }
+
+ private RemoteA2AAgent.Builder getAgentBuilder() {
+ return RemoteA2AAgent.builder().name("remote-agent").a2aClient(mockClient).agentCard(agentCard);
+ }
+
+ private RemoteA2AAgent createAgent() {
+ return getAgentBuilder().build();
+ }
+
+ @SuppressWarnings("unchecked") // cast for Mockito
+ private void mockStreamResponse(Consumer> responseProducer) {
+ doAnswer(
+ invocation -> {
+ List> consumers = invocation.getArgument(1);
+ BiConsumer consumer = consumers.get(0);
+ responseProducer.accept(consumer);
+ return null;
+ })
+ .when(mockClient)
+ .sendMessage(any(Message.class), any(List.class), any(Consumer.class), any());
+ }
+
+ @SuppressWarnings("unchecked") // cast for Mockito
+ private void mockStreamError(Throwable error) {
+ doAnswer(
+ invocation -> {
+ Consumer errorConsumer = invocation.getArgument(2);
+ errorConsumer.accept(error);
+ return null;
+ })
+ .when(mockClient)
+ .sendMessage(any(Message.class), any(List.class), any(Consumer.class), any());
+ }
+
+ private void assertText(Event event, String expectedText) {
+ assertText(event, 0, expectedText);
+ }
+
+ private void assertText(Event event, int partIndex, String expectedText) {
+ assertThat(event.content().get().parts().get().get(partIndex).text().orElse(""))
+ .isEqualTo(expectedText);
+ }
+
+ private void assertThought(Event event, boolean expected) {
+ assertThat(event.content().get().parts().get().get(0).thought().orElse(false))
+ .isEqualTo(expected);
+ }
+
+ private void assertAggregated(Event event) {
+ assertThat(event.customMetadata()).isPresent();
+ List metadata = event.customMetadata().get();
+
+ boolean hasAggregated =
+ metadata.stream()
+ .anyMatch(
+ m ->
+ A2AMetadata.toA2AMetaKey(A2AMetadata.Key.AGGREGATED).equals(m.key().orElse(""))
+ && Objects.equals(m.stringValue().orElse(""), "true"));
+ boolean hasRequest =
+ metadata.stream()
+ .anyMatch(
+ m -> A2AMetadata.toA2AMetaKey(A2AMetadata.Key.REQUEST).equals(m.key().orElse("")));
+
+ assertThat(hasAggregated).isTrue();
+ assertThat(hasRequest).isTrue();
+ }
+
+ private void assertRequestMetadata(Event event) {
+ assertThat(event.customMetadata()).isPresent();
+ List metadata = event.customMetadata().get();
+ boolean hasRequest =
+ metadata.stream()
+ .anyMatch(
+ m -> A2AMetadata.toA2AMetaKey(A2AMetadata.Key.REQUEST).equals(m.key().orElse("")));
+ assertThat(hasRequest).isTrue();
+ }
+
+ private void assertResponseMetadata(Event event) {
+ assertThat(event.customMetadata()).isPresent();
+ List metadata = event.customMetadata().get();
+ boolean hasResponse =
+ metadata.stream()
+ .anyMatch(
+ m -> A2AMetadata.toA2AMetaKey(A2AMetadata.Key.RESPONSE).equals(m.key().orElse("")));
+ assertThat(hasResponse).isTrue();
+ }
+
+ private static final class TestAgent extends BaseAgent {
+ TestAgent() {
+ super("test_agent", "test", ImmutableList.of(), null, null);
+ }
+
+ @Override
+ protected Flowable runAsyncImpl(InvocationContext invocationContext) {
+ return Flowable.empty();
+ }
+
+ @Override
+ protected Flowable runLiveImpl(InvocationContext invocationContext) {
+ return Flowable.empty();
+ }
+ }
+}
diff --git a/core/src/main/java/com/google/adk/events/Event.java b/core/src/main/java/com/google/adk/events/Event.java
index 9e05918b..9cbaeddb 100644
--- a/core/src/main/java/com/google/adk/events/Event.java
+++ b/core/src/main/java/com/google/adk/events/Event.java
@@ -27,6 +27,7 @@
import com.google.common.collect.Iterables;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import com.google.genai.types.Content;
+import com.google.genai.types.CustomMetadata;
import com.google.genai.types.FinishReason;
import com.google.genai.types.FunctionCall;
import com.google.genai.types.FunctionResponse;
@@ -61,6 +62,7 @@ public class Event extends JsonBaseModel {
private Optional interrupted = Optional.empty();
private Optional branch = Optional.empty();
private Optional groundingMetadata = Optional.empty();
+ private Optional> customMetadata = Optional.empty();
private Optional modelVersion = Optional.empty();
private long timestamp;
@@ -242,6 +244,16 @@ public void setGroundingMetadata(Optional groundingMetadata)
this.groundingMetadata = groundingMetadata;
}
+ /** The custom metadata of the event. */
+ @JsonProperty("customMetadata")
+ public Optional> customMetadata() {
+ return customMetadata;
+ }
+
+ public void setCustomMetadata(Optional> customMetadata) {
+ this.customMetadata = customMetadata;
+ }
+
/** The model version used to generate the response. */
@JsonProperty("modelVersion")
public Optional modelVersion() {
@@ -347,6 +359,7 @@ public static class Builder {
private Optional interrupted = Optional.empty();
private Optional branch = Optional.empty();
private Optional groundingMetadata = Optional.empty();
+ private Optional> customMetadata = Optional.empty();
private Optional modelVersion = Optional.empty();
private Optional timestamp = Optional.empty();
@@ -570,6 +583,23 @@ Optional groundingMetadata() {
return groundingMetadata;
}
+ @CanIgnoreReturnValue
+ @JsonProperty("customMetadata")
+ public Builder customMetadata(@Nullable List value) {
+ this.customMetadata = Optional.ofNullable(value);
+ return this;
+ }
+
+ @CanIgnoreReturnValue
+ public Builder customMetadata(Optional> value) {
+ this.customMetadata = value;
+ return this;
+ }
+
+ Optional> customMetadata() {
+ return customMetadata;
+ }
+
@CanIgnoreReturnValue
@JsonProperty("modelVersion")
public Builder modelVersion(@Nullable String value) {
@@ -604,6 +634,7 @@ public Event build() {
event.setInterrupted(interrupted);
event.branch(branch);
event.setGroundingMetadata(groundingMetadata);
+ event.setCustomMetadata(customMetadata);
event.setModelVersion(modelVersion);
event.setActions(actions().orElseGet(() -> EventActions.builder().build()));
event.setTimestamp(timestamp().orElseGet(() -> Instant.now().toEpochMilli()));
@@ -640,6 +671,7 @@ public Builder toBuilder() {
.interrupted(this.interrupted)
.branch(this.branch)
.groundingMetadata(this.groundingMetadata)
+ .customMetadata(this.customMetadata)
.modelVersion(this.modelVersion);
if (this.timestamp != 0) {
builder.timestamp(this.timestamp);
@@ -672,6 +704,7 @@ public boolean equals(Object obj) {
&& Objects.equals(interrupted, other.interrupted)
&& Objects.equals(branch, other.branch)
&& Objects.equals(groundingMetadata, other.groundingMetadata)
+ && Objects.equals(customMetadata, other.customMetadata)
&& Objects.equals(modelVersion, other.modelVersion);
}
@@ -699,6 +732,7 @@ public int hashCode() {
interrupted,
branch,
groundingMetadata,
+ customMetadata,
modelVersion,
timestamp);
}
diff --git a/core/src/main/java/com/google/adk/models/LlmResponse.java b/core/src/main/java/com/google/adk/models/LlmResponse.java
index 6f8f3d78..5fe885e4 100644
--- a/core/src/main/java/com/google/adk/models/LlmResponse.java
+++ b/core/src/main/java/com/google/adk/models/LlmResponse.java
@@ -25,6 +25,7 @@
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import com.google.genai.types.Candidate;
import com.google.genai.types.Content;
+import com.google.genai.types.CustomMetadata;
import com.google.genai.types.FinishReason;
import com.google.genai.types.GenerateContentResponse;
import com.google.genai.types.GenerateContentResponsePromptFeedback;
@@ -59,6 +60,14 @@ public abstract class LlmResponse extends JsonBaseModel {
@JsonProperty("groundingMetadata")
public abstract Optional groundingMetadata();
+ /**
+ * Returns the custom metadata of the response, if available.
+ *
+ * @return An {@link Optional} containing a list of {@link CustomMetadata} or empty.
+ */
+ @JsonProperty("customMetadata")
+ public abstract Optional> customMetadata();
+
/**
* Indicates whether the text content is part of a unfinished text stream.
*
@@ -133,6 +142,11 @@ static LlmResponse.Builder jacksonBuilder() {
public abstract Builder groundingMetadata(Optional groundingMetadata);
+ @JsonProperty("customMetadata")
+ public abstract Builder customMetadata(List customMetadata);
+
+ public abstract Builder customMetadata(Optional> customMetadata);
+
@JsonProperty("partial")
public abstract Builder partial(@Nullable Boolean partial);