Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions a2a/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@
<version>${truth.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
Expand Down
215 changes: 178 additions & 37 deletions a2a/src/main/java/com/google/adk/a2a/RemoteA2AAgent.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

import static com.google.common.base.Strings.nullToEmpty;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import com.google.adk.a2a.common.A2AClientError;
import com.google.adk.a2a.common.A2AMetadata;
import com.google.adk.a2a.converters.EventConverter;
import com.google.adk.a2a.converters.ResponseConverter;
import com.google.adk.agents.BaseAgent;
Expand All @@ -11,6 +15,9 @@
import com.google.adk.events.Event;
import com.google.common.collect.ImmutableList;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import com.google.genai.types.Content;
import com.google.genai.types.CustomMetadata;
import com.google.genai.types.Part;
import io.a2a.client.Client;
import io.a2a.client.ClientEvent;
import io.a2a.client.TaskEvent;
Expand All @@ -22,8 +29,11 @@
import io.reactivex.rxjava3.core.BackpressureStrategy;
import io.reactivex.rxjava3.core.Flowable;
import io.reactivex.rxjava3.core.FlowableEmitter;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;
import org.slf4j.Logger;
Expand Down Expand Up @@ -54,6 +64,8 @@
public class RemoteA2AAgent extends BaseAgent {

private static final Logger logger = LoggerFactory.getLogger(RemoteA2AAgent.class);
private static final ObjectMapper objectMapper =
new ObjectMapper().registerModule(new JavaTimeModule());

private final AgentCard agentCard;
private final Client a2aClient;
Expand Down Expand Up @@ -173,60 +185,189 @@ protected Flowable<Event> runAsyncImpl(InvocationContext invocationContext) {
}

Message originalMessage = a2aMessageOpt.get();
String requestJson;
try {
requestJson = objectMapper.writeValueAsString(originalMessage);
} catch (JsonProcessingException e) {
logger.warn("Failed to serialize request", e);
requestJson = null;
}
String finalRequestJson = requestJson;

return Flowable.create(
emitter -> {
FlowableEmitter<Event> flowableEmitter = emitter.serialize();
AtomicBoolean done = new AtomicBoolean(false);
StreamHandler handler =
new StreamHandler(emitter.serialize(), invocationContext, finalRequestJson);
ImmutableList<BiConsumer<ClientEvent, AgentCard>> consumers =
ImmutableList.of(
(event, unused) ->
handleClientEvent(event, flowableEmitter, invocationContext, done));
a2aClient.sendMessage(
originalMessage, consumers, e -> handleClientError(e, flowableEmitter, done), null);
ImmutableList.of(handler::handleEvent);
a2aClient.sendMessage(originalMessage, consumers, handler::handleError, null);
},
BackpressureStrategy.BUFFER);
}

private void handleClientError(Throwable e, FlowableEmitter<Event> emitter, AtomicBoolean done) {
// Mark the flow as done if it is already cancelled.
done.compareAndSet(false, emitter.isCancelled());
private class StreamHandler {
private final FlowableEmitter<Event> emitter;
private final InvocationContext invocationContext;
private final String requestJson;
private final AtomicBoolean done = new AtomicBoolean(false);
private final StringBuilder textBuffer = new StringBuilder();
private final StringBuilder thoughtsBuffer = new StringBuilder();

StreamHandler(
FlowableEmitter<Event> emitter, InvocationContext invocationContext, String requestJson) {
this.emitter = emitter;
this.invocationContext = invocationContext;
this.requestJson = requestJson;
}

void handleError(Throwable e) {
// Mark the flow as done if it is already cancelled.
done.compareAndSet(false, emitter.isCancelled());

// If the flow is already done, stop processing and exit the consumer.
if (done.get()) {
return;
// If the flow is already done, stop processing.
if (done.get()) {
return;
}
// If the error is raised, complete the flow with an error.
if (!done.getAndSet(true)) {
emitter.tryOnError(new A2AClientError("Failed to communicate with the remote agent", e));
}
}
// If the error is raised, complete the flow with an error.
if (!done.getAndSet(true)) {
emitter.tryOnError(new A2AClientError("Failed to communicate with the remote agent", e));

void handleEvent(ClientEvent clientEvent, AgentCard unused) {
// Mark the flow as done if it is already cancelled.
done.compareAndSet(false, emitter.isCancelled());

// If the flow is already done, stop processing.
if (done.get()) {
return;
}

Optional<Event> eventOpt =
ResponseConverter.clientEventToEvent(clientEvent, invocationContext);
if (eventOpt.isPresent()) {
Event event = eventOpt.get();
enrichWithMetadata(event, clientEvent);
boolean consumed = processContent(event);
if (!consumed) {
emitEvents(event);
}
}

// For non-streaming communication, complete the flow; for streaming, wait until the client
// marks the completion.
if (isCompleted(clientEvent) || !streaming) {
// Only complete the flow once.
if (!done.getAndSet(true)) {
emitter.onComplete();
}
}
}
}

private void handleClientEvent(
ClientEvent clientEvent,
FlowableEmitter<Event> emitter,
InvocationContext invocationContext,
AtomicBoolean done) {
// Mark the flow as done if it is already cancelled.
done.compareAndSet(false, emitter.isCancelled());

// If the flow is already done, stop processing and exit the consumer.
if (done.get()) {
return;
private void enrichWithMetadata(Event event, ClientEvent clientEvent) {
List<CustomMetadata> eventMetadata =
new ArrayList<>(event.customMetadata().orElse(ImmutableList.of()));
if (requestJson != null) {
eventMetadata.add(
CustomMetadata.builder()
.key(A2AMetadata.toA2AMetaKey(A2AMetadata.Key.REQUEST))
.stringValue(requestJson)
.build());
}
try {
if (clientEvent != null) {
eventMetadata.add(
CustomMetadata.builder()
.key(A2AMetadata.toA2AMetaKey(A2AMetadata.Key.RESPONSE))
.stringValue(objectMapper.writeValueAsString(clientEvent))
.build());
}
} catch (JsonProcessingException e) {
logger.warn("Failed to serialize response metadata", e);
}
event.setCustomMetadata(Optional.of(ImmutableList.copyOf(eventMetadata)));
}

Optional<Event> event = ResponseConverter.clientEventToEvent(clientEvent, invocationContext);
if (event.isPresent()) {
emitter.onNext(event.get());
private boolean processContent(Event event) {
if (!event.partial().orElse(false)) {
return false;
}

List<Part> nonTextParts = new ArrayList<>();
for (Part part : event.content().flatMap(Content::parts).orElse(ImmutableList.of())) {
if (part.text().isPresent()) {
String t = part.text().get();
if (part.thought().orElse(false)) {
thoughtsBuffer.append(t);
} else {
textBuffer.append(t);
}
} else {
nonTextParts.add(part);
}
}

if (nonTextParts.isEmpty()) {
return true;
}

Content nonTextContent = Content.builder().role("model").parts(nonTextParts).build();
event.setContent(Optional.of(nonTextContent));
return false;
}

// For non-streaming communication, complete the flow; for streaming, wait until the client
// marks the completion.
if (isCompleted(clientEvent) || !streaming) {
// Only complete the flow once.
if (!done.getAndSet(true)) {
emitter.onComplete();
private void emitEvents(Event event) {
List<Part> parts = new ArrayList<>();
if (thoughtsBuffer.length() > 0) {
parts.add(Part.builder().thought(true).text(thoughtsBuffer.toString()).build());
thoughtsBuffer.setLength(0);
}
if (textBuffer.length() > 0) {
parts.add(Part.builder().text(textBuffer.toString()).build());
textBuffer.setLength(0);
}

if (!parts.isEmpty()) {
Content aggregatedContent = Content.builder().role("model").parts(parts).build();

if (event.content().flatMap(Content::parts).orElse(ImmutableList.of()).isEmpty()) {
// Reuse empty event for aggregated content.
event.setContent(Optional.of(aggregatedContent));
emitter.onNext(event);
} else {
// Emit separate aggregated event first.
Event aggEvent = createAggregatedEvent(aggregatedContent);
emitter.onNext(aggEvent);
emitter.onNext(event);
}
} else {
emitter.onNext(event);
}
}

private Event createAggregatedEvent(Content content) {
List<CustomMetadata> aggMetadata = new ArrayList<>();
aggMetadata.add(
CustomMetadata.builder()
.key(A2AMetadata.toA2AMetaKey(A2AMetadata.Key.AGGREGATED))
.stringValue("true")
.build());
if (requestJson != null) {
aggMetadata.add(
CustomMetadata.builder()
.key(A2AMetadata.toA2AMetaKey(A2AMetadata.Key.REQUEST))
.stringValue(requestJson)
.build());
}

return Event.builder()
.id(UUID.randomUUID().toString())
.invocationId(invocationContext.invocationId())
.author("agent")
.content(Optional.of(content))
.timestamp(Instant.now().toEpochMilli())
.customMetadata(Optional.of(ImmutableList.copyOf(aggMetadata)))
.build();
}
}

Expand Down
28 changes: 28 additions & 0 deletions a2a/src/main/java/com/google/adk/a2a/common/A2AMetadata.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package com.google.adk.a2a.common;

/** Constants and utilities for A2A metadata keys. */
public final class A2AMetadata {

/** Enum for A2A custom metadata keys. */
public enum Key {
REQUEST("request"),
RESPONSE("response"),
AGGREGATED("aggregated");

private final String value;

Key(String value) {
this.value = value;
}

public String getValue() {
return value;
}
}

public static String toA2AMetaKey(Key key) {
return "a2a:" + key.value;
}

private A2AMetadata() {}
}
Loading
Loading