Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,34 +23,34 @@
* Responsible for retrieving, saving, and updating the task based on
* events received from the agent.
*/
public class ClientTaskManager {
class ClientTaskManager {

private @Nullable Task currentTask;
private @Nullable String taskId;
private @Nullable String contextId;

public ClientTaskManager() {
ClientTaskManager() {
this.currentTask = null;
this.taskId = null;
this.contextId = null;
}

public synchronized Task getCurrentTask() throws A2AClientInvalidStateError {
synchronized Task getCurrentTask() throws A2AClientInvalidStateError {
if (currentTask == null) {
throw new A2AClientInvalidStateError("No current task");
}
return currentTask;
}

public synchronized Task saveTaskEvent(Task task) throws A2AClientInvalidArgsError {
synchronized Task saveTaskEvent(Task task) throws A2AClientInvalidArgsError {
if (currentTask != null) {
throw new A2AClientInvalidArgsError("Task is already set, create new manager for new tasks.");
}
saveTask(task);
return task;
}

public synchronized Task saveTaskEvent(TaskStatusUpdateEvent taskStatusUpdateEvent) throws A2AClientError {
synchronized Task saveTaskEvent(TaskStatusUpdateEvent taskStatusUpdateEvent) throws A2AClientError {
if (taskId == null) {
taskId = taskStatusUpdateEvent.taskId();
}
Expand Down Expand Up @@ -86,7 +86,7 @@ public synchronized Task saveTaskEvent(TaskStatusUpdateEvent taskStatusUpdateEve
return currentTask;
}

public synchronized Task saveTaskEvent(TaskArtifactUpdateEvent taskArtifactUpdateEvent) {
synchronized Task saveTaskEvent(TaskArtifactUpdateEvent taskArtifactUpdateEvent) {
if (taskId == null) {
taskId = taskArtifactUpdateEvent.taskId();
}
Expand All @@ -113,7 +113,7 @@ public synchronized Task saveTaskEvent(TaskArtifactUpdateEvent taskArtifactUpdat
* @param task the task to update
* @return the updated task
*/
public synchronized Task updateWithMessage(Message message, Task task) {
synchronized Task updateWithMessage(Message message, Task task) {
Task.Builder taskBuilder = Task.builder(task);
List<Message> history = new ArrayList<>(task.history());
if (task.status().message() != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public SSEEventListener(Consumer<StreamingEventKind> eventHandler,
@Override
public void onMessage(ServerSentEvent event, @Nullable Future<Void> completableFuture) {
try {
log.fine("Streaming message received: " + event.data());
log.fine("REST SSE raw data: " + event.data());
org.a2aproject.sdk.grpc.StreamResponse.Builder builder = org.a2aproject.sdk.grpc.StreamResponse.newBuilder();
JsonFormat.parser().merge(event.data(), builder);
parseAndHandleMessage(builder.build(), completableFuture);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,46 +11,46 @@
import org.a2aproject.sdk.compat03.spec.A2AClientInvalidArgsError_v0_3;
import org.a2aproject.sdk.compat03.spec.A2AClientInvalidStateError_v0_3;
import org.a2aproject.sdk.compat03.spec.Message_v0_3;
import org.a2aproject.sdk.compat03.spec.Task_v0_3;
import org.a2aproject.sdk.compat03.spec.TaskArtifactUpdateEvent_v0_3;
import org.a2aproject.sdk.compat03.spec.TaskState_v0_3;
import org.a2aproject.sdk.compat03.spec.TaskStatus_v0_3;
import org.a2aproject.sdk.compat03.spec.TaskStatusUpdateEvent_v0_3;
import org.a2aproject.sdk.compat03.spec.TaskStatus_v0_3;
import org.a2aproject.sdk.compat03.spec.Task_v0_3;
import org.jspecify.annotations.Nullable;

/**
* Helps manage a task's lifecycle during the execution of a request.
* Responsible for retrieving, saving, and updating the task based on
* events received from the agent.
*/
public class ClientTaskManager_v0_3 {
class ClientTaskManager_v0_3 {

private @Nullable Task_v0_3 currentTask;
private @Nullable String taskId;
private @Nullable String contextId;

public ClientTaskManager_v0_3() {
ClientTaskManager_v0_3() {
this.currentTask = null;
this.taskId = null;
this.contextId = null;
}

public Task_v0_3 getCurrentTask() throws A2AClientInvalidStateError_v0_3 {
Task_v0_3 getCurrentTask() throws A2AClientInvalidStateError_v0_3 {
if (currentTask == null) {
throw new A2AClientInvalidStateError_v0_3("No current task");
}
return currentTask;
}

public Task_v0_3 saveTaskEvent(Task_v0_3 task) throws A2AClientInvalidArgsError_v0_3 {
if (currentTask != null) {
Task_v0_3 saveTaskEvent(Task_v0_3 task) throws A2AClientInvalidArgsError_v0_3 {
if (currentTask != null && !currentTask.id().equals(task.id())) {
throw new A2AClientInvalidArgsError_v0_3("Task is already set, create new manager for new tasks.");
}
saveTask(task);
return task;
}

public Task_v0_3 saveTaskEvent(TaskStatusUpdateEvent_v0_3 taskStatusUpdateEvent) throws A2AClientError_v0_3 {
Task_v0_3 saveTaskEvent(TaskStatusUpdateEvent_v0_3 taskStatusUpdateEvent) throws A2AClientError_v0_3 {
if (taskId == null) {
taskId = taskStatusUpdateEvent.taskId();
}
Expand Down Expand Up @@ -82,7 +82,7 @@ public Task_v0_3 saveTaskEvent(TaskStatusUpdateEvent_v0_3 taskStatusUpdateEvent)
return currentTask;
}

public Task_v0_3 saveTaskEvent(TaskArtifactUpdateEvent_v0_3 taskArtifactUpdateEvent) {
Task_v0_3 saveTaskEvent(TaskArtifactUpdateEvent_v0_3 taskArtifactUpdateEvent) {
if (taskId == null) {
taskId = taskArtifactUpdateEvent.taskId();
}
Expand All @@ -109,7 +109,7 @@ public Task_v0_3 saveTaskEvent(TaskArtifactUpdateEvent_v0_3 taskArtifactUpdateEv
* @param task the task to update
* @return the updated task
*/
public Task_v0_3 updateWithMessage(Message_v0_3 message, Task_v0_3 task) {
Task_v0_3 updateWithMessage(Message_v0_3 message, Task_v0_3 task) {
Task_v0_3.Builder taskBuilder = new Task_v0_3.Builder(task);
List<Message_v0_3> history = new ArrayList<>(task.history());
if (task.status().message() != null) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package org.a2aproject.sdk.compat03.server.grpc.quarkus;

import java.util.concurrent.Executor;

import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;

import org.a2aproject.sdk.server.util.async.Internal;
import io.grpc.Context;
import io.grpc.ForwardingServerCallListener.SimpleForwardingServerCallListener;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;

/**
* v0.3 variant of {@code BlockingOffloadInterceptor} in the reference/grpc module.
*/
@ApplicationScoped
public class BlockingOffloadInterceptor_v0_3 implements ServerInterceptor {

private final Executor executor;

@Inject
public BlockingOffloadInterceptor_v0_3(@Internal Executor executor) {
this.executor = executor;
}

@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
ServerCall<ReqT, RespT> call,
Metadata headers,
ServerCallHandler<ReqT, RespT> next) {

MethodDescriptor.MethodType type = call.getMethodDescriptor().getType();
if (type == MethodDescriptor.MethodType.CLIENT_STREAMING
|| type == MethodDescriptor.MethodType.BIDI_STREAMING) {
return next.startCall(call, headers);
}

ServerCall.Listener<ReqT> delegate = next.startCall(call, headers);

return new SimpleForwardingServerCallListener<ReqT>(delegate) {
@Override
public void onHalfClose() {
Context grpcContext = Context.current().fork();
try {
executor.execute(() -> {
Context previous = grpcContext.attach();
try {
super.onHalfClose();
} catch (Exception e) {
call.close(Status.INTERNAL.withDescription("Error during execution: " + e.getMessage()), new Metadata());
} finally {
grpcContext.detach(previous);
}
});
Comment thread
kabir marked this conversation as resolved.
} catch (Exception e) {
call.close(Status.INTERNAL.withDescription("Failed to offload to worker thread: " + e.getMessage()), new Metadata());
}
}
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import io.quarkus.grpc.GrpcService;
import io.quarkus.grpc.RegisterInterceptor;
import io.quarkus.security.Authenticated;
import io.smallrye.common.annotation.Blocking;
import org.a2aproject.sdk.compat03.conversion.Convert_v0_3_To10RequestHandler;
import org.a2aproject.sdk.compat03.spec.AgentCard_v0_3;
import org.a2aproject.sdk.compat03.transport.grpc.handler.CallContextFactory_v0_3;
Expand All @@ -17,7 +18,9 @@

@GrpcService
@RegisterInterceptor(A2AExtensionsInterceptor_v0_3.class)
@RegisterInterceptor(BlockingOffloadInterceptor_v0_3.class)
@Authenticated
@Blocking
public class QuarkusGrpcHandler_v0_3 extends GrpcHandler_v0_3 {

private final AgentCard_v0_3 agentCard;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ void setupRoutes(@Observes Router router) {
} catch (Exception e) {
VertxSecurityHelper.handleGenericError(ctx);
}
});
}, false);

// Only register v0.3 agent card if no real v1.0 agent card producer exists.
// DefaultProducers provides a @DefaultBean AgentCard fallback that is always
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,52 +75,52 @@ void setupRouter(@Observes @Priority(10) Router router) {
.handler(BodyHandler.create())
.blockingHandler(authenticated(ctx -> {
sendMessage(extractBody(ctx), ctx);
}));
}), false);

// POST /v1/message:stream
router.postWithRegex("^\\/v1\\/message:stream$")
.handler(BodyHandler.create())
.blockingHandler(authenticatedStreaming(ctx -> {
sendMessageStreaming(extractBody(ctx), ctx);
}));
}), false);

// GET /v1/tasks/:id
router.get("/v1/tasks/:id")
.order(1)
.blockingHandler(authenticated(this::getTask));
.blockingHandler(authenticated(this::getTask), false);

// POST /v1/tasks/{id}:cancel
router.postWithRegex("^\\/v1\\/tasks\\/([^/]+):cancel$")
.order(1)
.blockingHandler(authenticated(this::cancelTask));
.blockingHandler(authenticated(this::cancelTask), false);

// POST /v1/tasks/{id}:subscribe
router.postWithRegex("^\\/v1\\/tasks\\/([^/]+):subscribe$")
.order(1)
.blockingHandler(authenticatedStreaming(this::resubscribeTask));
.blockingHandler(authenticatedStreaming(this::resubscribeTask), false);

// POST /v1/tasks/:id/pushNotificationConfigs
router.post("/v1/tasks/:id/pushNotificationConfigs")
.order(1)
.handler(BodyHandler.create())
.blockingHandler(authenticated(ctx -> {
setTaskPushNotificationConfiguration(extractBody(ctx), ctx);
}));
}), false);

// GET /v1/tasks/:id/pushNotificationConfigs/:configId
router.get("/v1/tasks/:id/pushNotificationConfigs/:configId")
.order(1)
.blockingHandler(authenticated(this::getTaskPushNotificationConfiguration));
.blockingHandler(authenticated(this::getTaskPushNotificationConfiguration), false);

// GET /v1/tasks/:id/pushNotificationConfigs
router.get("/v1/tasks/:id/pushNotificationConfigs")
.order(2)
.blockingHandler(authenticated(this::listTaskPushNotificationConfigurations));
.blockingHandler(authenticated(this::listTaskPushNotificationConfigurations), false);

// DELETE /v1/tasks/:id/pushNotificationConfigs/:configId
router.delete("/v1/tasks/:id/pushNotificationConfigs/:configId")
.order(1)
.blockingHandler(authenticated(this::deleteTaskPushNotificationConfiguration));
.blockingHandler(authenticated(this::deleteTaskPushNotificationConfiguration), false);

// Only register v0.3 agent card if no real v1.0 agent card producer exists.
// DefaultProducers provides a @DefaultBean AgentCard fallback that is always
Expand All @@ -136,7 +136,7 @@ void setupRouter(@Observes @Priority(10) Router router) {
router.get("/v1/card")
.order(1)
.produces(APPLICATION_JSON)
.blockingHandler(authenticated(this::getAuthenticatedExtendedCard));
.blockingHandler(authenticated(this::getAuthenticatedExtendedCard), false);
}

private Handler<RoutingContext> authenticated(Consumer<RoutingContext> action) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,11 @@ default FileContent toV10(FileContent_v0_3 v03) {
}

if (v03 instanceof FileWithBytes_v0_3 v03Bytes) {
return new FileWithBytes(v03Bytes.mimeType(), v03Bytes.name(), v03Bytes.bytes());
String name = v03Bytes.name() != null ? v03Bytes.name() : "";
return new FileWithBytes(v03Bytes.mimeType(), name, v03Bytes.bytes());
} else if (v03 instanceof FileWithUri_v0_3 v03Uri) {
return new FileWithUri(v03Uri.mimeType(), v03Uri.name(), v03Uri.uri());
String name = v03Uri.name() != null ? v03Uri.name() : "";
return new FileWithUri(v03Uri.mimeType(), name, v03Uri.uri());
}
Comment thread
kabir marked this conversation as resolved.

throw new InvalidRequestError(null, "Unrecognized FileContent type: " + v03.getClass().getName(), null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,14 @@
import org.a2aproject.sdk.server.auth.UnauthenticatedUser;
import org.a2aproject.sdk.server.auth.User;
import org.a2aproject.sdk.spec.A2AError;
import io.grpc.Context;
import io.grpc.Status;
import io.grpc.stub.StreamObserver;

import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.function.UnaryOperator;

/**
* Abstract gRPC handler for v0.3 protocol with translation layer to v1.0.
Expand Down Expand Up @@ -235,6 +237,7 @@ public void sendStreamingMessage(org.a2aproject.sdk.compat03.grpc.SendMessageReq

try {
ServerCallContext context = createCallContext(responseObserver);
installForkedContextWrapper(context);
MessageSendParams_v0_3 params = FromProto.messageSendParams(request);
Flow.Publisher<StreamingEventKind_v0_3> publisher = requestHandler.onMessageSendStream(params, context);
convertToStreamResponse(publisher, responseObserver);
Expand All @@ -259,6 +262,7 @@ public void taskSubscription(org.a2aproject.sdk.compat03.grpc.TaskSubscriptionRe

try {
ServerCallContext context = createCallContext(responseObserver);
installForkedContextWrapper(context);
TaskIdParams_v0_3 params = FromProto.taskIdParams(request);
Flow.Publisher<StreamingEventKind_v0_3> publisher = requestHandler.onResubscribeToTask(params, context);
convertToStreamResponse(publisher, responseObserver);
Expand All @@ -273,6 +277,19 @@ public void taskSubscription(org.a2aproject.sdk.compat03.grpc.TaskSubscriptionRe
}
}

private void installForkedContextWrapper(ServerCallContext context) {
Context forked = Context.current().fork();
context.getState().put(ServerCallContext.EXECUTION_WRAPPER_KEY,
(UnaryOperator<Runnable>) (runnable -> () -> {
Context prev = forked.attach();
try {
runnable.run();
} finally {
forked.detach(prev);
}
}));
}

private void convertToStreamResponse(Flow.Publisher<StreamingEventKind_v0_3> publisher,
StreamObserver<org.a2aproject.sdk.compat03.grpc.StreamResponse> responseObserver) {
CompletableFuture.runAsync(() -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.JsonElement;
import com.google.gson.JsonParser;
import com.google.gson.JsonSyntaxException;
import com.google.gson.ToNumberPolicy;
Expand Down Expand Up @@ -474,7 +475,11 @@ public static Map<String, Object> readMetadata(@Nullable String json) throws Jso
return Collections.emptyMap();
}
try {
return readMetadata(JsonParser.parseString(json).getAsJsonObject());
JsonElement element = JsonParser.parseString(json);
if (!element.isJsonObject()) {
return Collections.emptyMap();
}
return readMetadata(element.getAsJsonObject());
} catch (JsonSyntaxException e) {
throw new JsonProcessingException("Failed to parse metadata JSON", e);
}
Expand Down
Loading
Loading