diff --git a/client/base/src/main/java/org/a2aproject/sdk/client/ClientTaskManager.java b/client/base/src/main/java/org/a2aproject/sdk/client/ClientTaskManager.java index 94c9c21a5..0a962eab7 100644 --- a/client/base/src/main/java/org/a2aproject/sdk/client/ClientTaskManager.java +++ b/client/base/src/main/java/org/a2aproject/sdk/client/ClientTaskManager.java @@ -23,26 +23,26 @@ * Responsible for retrieving, saving, and updating the task based on * events received from the agent. */ -public class ClientTaskManager { +class ClientTaskManager { private @Nullable Task currentTask; private @Nullable String taskId; private @Nullable String contextId; - public ClientTaskManager() { + ClientTaskManager() { this.currentTask = null; this.taskId = null; this.contextId = null; } - public synchronized Task getCurrentTask() throws A2AClientInvalidStateError { + synchronized Task getCurrentTask() throws A2AClientInvalidStateError { if (currentTask == null) { throw new A2AClientInvalidStateError("No current task"); } return currentTask; } - public synchronized Task saveTaskEvent(Task task) throws A2AClientInvalidArgsError { + synchronized Task saveTaskEvent(Task task) throws A2AClientInvalidArgsError { if (currentTask != null) { throw new A2AClientInvalidArgsError("Task is already set, create new manager for new tasks."); } @@ -50,7 +50,7 @@ public synchronized Task saveTaskEvent(Task task) throws A2AClientInvalidArgsErr return task; } - public synchronized Task saveTaskEvent(TaskStatusUpdateEvent taskStatusUpdateEvent) throws A2AClientError { + synchronized Task saveTaskEvent(TaskStatusUpdateEvent taskStatusUpdateEvent) throws A2AClientError { if (taskId == null) { taskId = taskStatusUpdateEvent.taskId(); } @@ -86,7 +86,7 @@ public synchronized Task saveTaskEvent(TaskStatusUpdateEvent taskStatusUpdateEve return currentTask; } - public synchronized Task saveTaskEvent(TaskArtifactUpdateEvent taskArtifactUpdateEvent) { + synchronized Task saveTaskEvent(TaskArtifactUpdateEvent taskArtifactUpdateEvent) { if (taskId == null) { taskId = taskArtifactUpdateEvent.taskId(); } @@ -113,7 +113,7 @@ public synchronized Task saveTaskEvent(TaskArtifactUpdateEvent taskArtifactUpdat * @param task the task to update * @return the updated task */ - public synchronized Task updateWithMessage(Message message, Task task) { + synchronized Task updateWithMessage(Message message, Task task) { Task.Builder taskBuilder = Task.builder(task); List history = new ArrayList<>(task.history()); if (task.status().message() != null) { diff --git a/client/transport/rest/src/main/java/org/a2aproject/sdk/client/transport/rest/sse/SSEEventListener.java b/client/transport/rest/src/main/java/org/a2aproject/sdk/client/transport/rest/sse/SSEEventListener.java index a8265d5b4..8ee78d0a9 100644 --- a/client/transport/rest/src/main/java/org/a2aproject/sdk/client/transport/rest/sse/SSEEventListener.java +++ b/client/transport/rest/src/main/java/org/a2aproject/sdk/client/transport/rest/sse/SSEEventListener.java @@ -30,7 +30,7 @@ public SSEEventListener(Consumer eventHandler, @Override public void onMessage(ServerSentEvent event, @Nullable Future completableFuture) { try { - log.fine("Streaming message received: " + event.data()); + log.fine("REST SSE raw data: " + event.data()); org.a2aproject.sdk.grpc.StreamResponse.Builder builder = org.a2aproject.sdk.grpc.StreamResponse.newBuilder(); JsonFormat.parser().merge(event.data(), builder); parseAndHandleMessage(builder.build(), completableFuture); diff --git a/compat-0.3/client/base/src/main/java/org/a2aproject/sdk/compat03/client/ClientTaskManager_v0_3.java b/compat-0.3/client/base/src/main/java/org/a2aproject/sdk/compat03/client/ClientTaskManager_v0_3.java index 1d8bc1de7..5d240a63c 100644 --- a/compat-0.3/client/base/src/main/java/org/a2aproject/sdk/compat03/client/ClientTaskManager_v0_3.java +++ b/compat-0.3/client/base/src/main/java/org/a2aproject/sdk/compat03/client/ClientTaskManager_v0_3.java @@ -11,11 +11,11 @@ import org.a2aproject.sdk.compat03.spec.A2AClientInvalidArgsError_v0_3; import org.a2aproject.sdk.compat03.spec.A2AClientInvalidStateError_v0_3; import org.a2aproject.sdk.compat03.spec.Message_v0_3; -import org.a2aproject.sdk.compat03.spec.Task_v0_3; import org.a2aproject.sdk.compat03.spec.TaskArtifactUpdateEvent_v0_3; import org.a2aproject.sdk.compat03.spec.TaskState_v0_3; -import org.a2aproject.sdk.compat03.spec.TaskStatus_v0_3; import org.a2aproject.sdk.compat03.spec.TaskStatusUpdateEvent_v0_3; +import org.a2aproject.sdk.compat03.spec.TaskStatus_v0_3; +import org.a2aproject.sdk.compat03.spec.Task_v0_3; import org.jspecify.annotations.Nullable; /** @@ -23,34 +23,34 @@ * Responsible for retrieving, saving, and updating the task based on * events received from the agent. */ -public class ClientTaskManager_v0_3 { +class ClientTaskManager_v0_3 { private @Nullable Task_v0_3 currentTask; private @Nullable String taskId; private @Nullable String contextId; - public ClientTaskManager_v0_3() { + ClientTaskManager_v0_3() { this.currentTask = null; this.taskId = null; this.contextId = null; } - public Task_v0_3 getCurrentTask() throws A2AClientInvalidStateError_v0_3 { + Task_v0_3 getCurrentTask() throws A2AClientInvalidStateError_v0_3 { if (currentTask == null) { throw new A2AClientInvalidStateError_v0_3("No current task"); } return currentTask; } - public Task_v0_3 saveTaskEvent(Task_v0_3 task) throws A2AClientInvalidArgsError_v0_3 { - if (currentTask != null) { + Task_v0_3 saveTaskEvent(Task_v0_3 task) throws A2AClientInvalidArgsError_v0_3 { + if (currentTask != null && !currentTask.id().equals(task.id())) { throw new A2AClientInvalidArgsError_v0_3("Task is already set, create new manager for new tasks."); } saveTask(task); return task; } - public Task_v0_3 saveTaskEvent(TaskStatusUpdateEvent_v0_3 taskStatusUpdateEvent) throws A2AClientError_v0_3 { + Task_v0_3 saveTaskEvent(TaskStatusUpdateEvent_v0_3 taskStatusUpdateEvent) throws A2AClientError_v0_3 { if (taskId == null) { taskId = taskStatusUpdateEvent.taskId(); } @@ -82,7 +82,7 @@ public Task_v0_3 saveTaskEvent(TaskStatusUpdateEvent_v0_3 taskStatusUpdateEvent) return currentTask; } - public Task_v0_3 saveTaskEvent(TaskArtifactUpdateEvent_v0_3 taskArtifactUpdateEvent) { + Task_v0_3 saveTaskEvent(TaskArtifactUpdateEvent_v0_3 taskArtifactUpdateEvent) { if (taskId == null) { taskId = taskArtifactUpdateEvent.taskId(); } @@ -109,7 +109,7 @@ public Task_v0_3 saveTaskEvent(TaskArtifactUpdateEvent_v0_3 taskArtifactUpdateEv * @param task the task to update * @return the updated task */ - public Task_v0_3 updateWithMessage(Message_v0_3 message, Task_v0_3 task) { + Task_v0_3 updateWithMessage(Message_v0_3 message, Task_v0_3 task) { Task_v0_3.Builder taskBuilder = new Task_v0_3.Builder(task); List history = new ArrayList<>(task.history()); if (task.status().message() != null) { diff --git a/compat-0.3/reference/grpc/src/main/java/org/a2aproject/sdk/compat03/server/grpc/quarkus/BlockingOffloadInterceptor_v0_3.java b/compat-0.3/reference/grpc/src/main/java/org/a2aproject/sdk/compat03/server/grpc/quarkus/BlockingOffloadInterceptor_v0_3.java new file mode 100644 index 000000000..3c79f8474 --- /dev/null +++ b/compat-0.3/reference/grpc/src/main/java/org/a2aproject/sdk/compat03/server/grpc/quarkus/BlockingOffloadInterceptor_v0_3.java @@ -0,0 +1,66 @@ +package org.a2aproject.sdk.compat03.server.grpc.quarkus; + +import java.util.concurrent.Executor; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.inject.Inject; + +import org.a2aproject.sdk.server.util.async.Internal; +import io.grpc.Context; +import io.grpc.ForwardingServerCallListener.SimpleForwardingServerCallListener; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.Status; + +/** + * v0.3 variant of {@code BlockingOffloadInterceptor} in the reference/grpc module. + */ +@ApplicationScoped +public class BlockingOffloadInterceptor_v0_3 implements ServerInterceptor { + + private final Executor executor; + + @Inject + public BlockingOffloadInterceptor_v0_3(@Internal Executor executor) { + this.executor = executor; + } + + @Override + public ServerCall.Listener interceptCall( + ServerCall call, + Metadata headers, + ServerCallHandler next) { + + MethodDescriptor.MethodType type = call.getMethodDescriptor().getType(); + if (type == MethodDescriptor.MethodType.CLIENT_STREAMING + || type == MethodDescriptor.MethodType.BIDI_STREAMING) { + return next.startCall(call, headers); + } + + ServerCall.Listener delegate = next.startCall(call, headers); + + return new SimpleForwardingServerCallListener(delegate) { + @Override + public void onHalfClose() { + Context grpcContext = Context.current().fork(); + try { + executor.execute(() -> { + Context previous = grpcContext.attach(); + try { + super.onHalfClose(); + } catch (Exception e) { + call.close(Status.INTERNAL.withDescription("Error during execution: " + e.getMessage()), new Metadata()); + } finally { + grpcContext.detach(previous); + } + }); + } catch (Exception e) { + call.close(Status.INTERNAL.withDescription("Failed to offload to worker thread: " + e.getMessage()), new Metadata()); + } + } + }; + } +} diff --git a/compat-0.3/reference/grpc/src/main/java/org/a2aproject/sdk/compat03/server/grpc/quarkus/QuarkusGrpcHandler_v0_3.java b/compat-0.3/reference/grpc/src/main/java/org/a2aproject/sdk/compat03/server/grpc/quarkus/QuarkusGrpcHandler_v0_3.java index 3364f72c9..d79871387 100644 --- a/compat-0.3/reference/grpc/src/main/java/org/a2aproject/sdk/compat03/server/grpc/quarkus/QuarkusGrpcHandler_v0_3.java +++ b/compat-0.3/reference/grpc/src/main/java/org/a2aproject/sdk/compat03/server/grpc/quarkus/QuarkusGrpcHandler_v0_3.java @@ -8,6 +8,7 @@ import io.quarkus.grpc.GrpcService; import io.quarkus.grpc.RegisterInterceptor; import io.quarkus.security.Authenticated; +import io.smallrye.common.annotation.Blocking; import org.a2aproject.sdk.compat03.conversion.Convert_v0_3_To10RequestHandler; import org.a2aproject.sdk.compat03.spec.AgentCard_v0_3; import org.a2aproject.sdk.compat03.transport.grpc.handler.CallContextFactory_v0_3; @@ -17,7 +18,9 @@ @GrpcService @RegisterInterceptor(A2AExtensionsInterceptor_v0_3.class) +@RegisterInterceptor(BlockingOffloadInterceptor_v0_3.class) @Authenticated +@Blocking public class QuarkusGrpcHandler_v0_3 extends GrpcHandler_v0_3 { private final AgentCard_v0_3 agentCard; diff --git a/compat-0.3/reference/jsonrpc/src/main/java/org/a2aproject/sdk/compat03/server/apps/quarkus/A2AServerRoutes_v0_3.java b/compat-0.3/reference/jsonrpc/src/main/java/org/a2aproject/sdk/compat03/server/apps/quarkus/A2AServerRoutes_v0_3.java index 6264e7091..e4bf2434c 100644 --- a/compat-0.3/reference/jsonrpc/src/main/java/org/a2aproject/sdk/compat03/server/apps/quarkus/A2AServerRoutes_v0_3.java +++ b/compat-0.3/reference/jsonrpc/src/main/java/org/a2aproject/sdk/compat03/server/apps/quarkus/A2AServerRoutes_v0_3.java @@ -97,7 +97,7 @@ void setupRoutes(@Observes Router router) { } catch (Exception e) { VertxSecurityHelper.handleGenericError(ctx); } - }); + }, false); // Only register v0.3 agent card if no real v1.0 agent card producer exists. // DefaultProducers provides a @DefaultBean AgentCard fallback that is always diff --git a/compat-0.3/reference/rest/src/main/java/org/a2aproject/sdk/compat03/server/rest/quarkus/A2AServerRoutes_v0_3.java b/compat-0.3/reference/rest/src/main/java/org/a2aproject/sdk/compat03/server/rest/quarkus/A2AServerRoutes_v0_3.java index f00795328..da53ed42d 100644 --- a/compat-0.3/reference/rest/src/main/java/org/a2aproject/sdk/compat03/server/rest/quarkus/A2AServerRoutes_v0_3.java +++ b/compat-0.3/reference/rest/src/main/java/org/a2aproject/sdk/compat03/server/rest/quarkus/A2AServerRoutes_v0_3.java @@ -75,29 +75,29 @@ void setupRouter(@Observes @Priority(10) Router router) { .handler(BodyHandler.create()) .blockingHandler(authenticated(ctx -> { sendMessage(extractBody(ctx), ctx); - })); + }), false); // POST /v1/message:stream router.postWithRegex("^\\/v1\\/message:stream$") .handler(BodyHandler.create()) .blockingHandler(authenticatedStreaming(ctx -> { sendMessageStreaming(extractBody(ctx), ctx); - })); + }), false); // GET /v1/tasks/:id router.get("/v1/tasks/:id") .order(1) - .blockingHandler(authenticated(this::getTask)); + .blockingHandler(authenticated(this::getTask), false); // POST /v1/tasks/{id}:cancel router.postWithRegex("^\\/v1\\/tasks\\/([^/]+):cancel$") .order(1) - .blockingHandler(authenticated(this::cancelTask)); + .blockingHandler(authenticated(this::cancelTask), false); // POST /v1/tasks/{id}:subscribe router.postWithRegex("^\\/v1\\/tasks\\/([^/]+):subscribe$") .order(1) - .blockingHandler(authenticatedStreaming(this::resubscribeTask)); + .blockingHandler(authenticatedStreaming(this::resubscribeTask), false); // POST /v1/tasks/:id/pushNotificationConfigs router.post("/v1/tasks/:id/pushNotificationConfigs") @@ -105,22 +105,22 @@ void setupRouter(@Observes @Priority(10) Router router) { .handler(BodyHandler.create()) .blockingHandler(authenticated(ctx -> { setTaskPushNotificationConfiguration(extractBody(ctx), ctx); - })); + }), false); // GET /v1/tasks/:id/pushNotificationConfigs/:configId router.get("/v1/tasks/:id/pushNotificationConfigs/:configId") .order(1) - .blockingHandler(authenticated(this::getTaskPushNotificationConfiguration)); + .blockingHandler(authenticated(this::getTaskPushNotificationConfiguration), false); // GET /v1/tasks/:id/pushNotificationConfigs router.get("/v1/tasks/:id/pushNotificationConfigs") .order(2) - .blockingHandler(authenticated(this::listTaskPushNotificationConfigurations)); + .blockingHandler(authenticated(this::listTaskPushNotificationConfigurations), false); // DELETE /v1/tasks/:id/pushNotificationConfigs/:configId router.delete("/v1/tasks/:id/pushNotificationConfigs/:configId") .order(1) - .blockingHandler(authenticated(this::deleteTaskPushNotificationConfiguration)); + .blockingHandler(authenticated(this::deleteTaskPushNotificationConfiguration), false); // Only register v0.3 agent card if no real v1.0 agent card producer exists. // DefaultProducers provides a @DefaultBean AgentCard fallback that is always @@ -136,7 +136,7 @@ void setupRouter(@Observes @Priority(10) Router router) { router.get("/v1/card") .order(1) .produces(APPLICATION_JSON) - .blockingHandler(authenticated(this::getAuthenticatedExtendedCard)); + .blockingHandler(authenticated(this::getAuthenticatedExtendedCard), false); } private Handler authenticated(Consumer action) { diff --git a/compat-0.3/server-conversion/src/main/java/org/a2aproject/sdk/compat03/conversion/mappers/domain/FileContentMapper_v0_3.java b/compat-0.3/server-conversion/src/main/java/org/a2aproject/sdk/compat03/conversion/mappers/domain/FileContentMapper_v0_3.java index c77998042..a39f840cf 100644 --- a/compat-0.3/server-conversion/src/main/java/org/a2aproject/sdk/compat03/conversion/mappers/domain/FileContentMapper_v0_3.java +++ b/compat-0.3/server-conversion/src/main/java/org/a2aproject/sdk/compat03/conversion/mappers/domain/FileContentMapper_v0_3.java @@ -51,9 +51,11 @@ default FileContent toV10(FileContent_v0_3 v03) { } if (v03 instanceof FileWithBytes_v0_3 v03Bytes) { - return new FileWithBytes(v03Bytes.mimeType(), v03Bytes.name(), v03Bytes.bytes()); + String name = v03Bytes.name() != null ? v03Bytes.name() : ""; + return new FileWithBytes(v03Bytes.mimeType(), name, v03Bytes.bytes()); } else if (v03 instanceof FileWithUri_v0_3 v03Uri) { - return new FileWithUri(v03Uri.mimeType(), v03Uri.name(), v03Uri.uri()); + String name = v03Uri.name() != null ? v03Uri.name() : ""; + return new FileWithUri(v03Uri.mimeType(), name, v03Uri.uri()); } throw new InvalidRequestError(null, "Unrecognized FileContent type: " + v03.getClass().getName(), null); diff --git a/compat-0.3/transport/grpc/src/main/java/org/a2aproject/sdk/compat03/transport/grpc/handler/GrpcHandler_v0_3.java b/compat-0.3/transport/grpc/src/main/java/org/a2aproject/sdk/compat03/transport/grpc/handler/GrpcHandler_v0_3.java index 2f3795a81..72ce9faf0 100644 --- a/compat-0.3/transport/grpc/src/main/java/org/a2aproject/sdk/compat03/transport/grpc/handler/GrpcHandler_v0_3.java +++ b/compat-0.3/transport/grpc/src/main/java/org/a2aproject/sdk/compat03/transport/grpc/handler/GrpcHandler_v0_3.java @@ -43,12 +43,14 @@ import org.a2aproject.sdk.server.auth.UnauthenticatedUser; import org.a2aproject.sdk.server.auth.User; import org.a2aproject.sdk.spec.A2AError; +import io.grpc.Context; import io.grpc.Status; import io.grpc.stub.StreamObserver; import java.util.HashSet; import java.util.List; import java.util.Set; +import java.util.function.UnaryOperator; /** * Abstract gRPC handler for v0.3 protocol with translation layer to v1.0. @@ -235,6 +237,7 @@ public void sendStreamingMessage(org.a2aproject.sdk.compat03.grpc.SendMessageReq try { ServerCallContext context = createCallContext(responseObserver); + installForkedContextWrapper(context); MessageSendParams_v0_3 params = FromProto.messageSendParams(request); Flow.Publisher publisher = requestHandler.onMessageSendStream(params, context); convertToStreamResponse(publisher, responseObserver); @@ -259,6 +262,7 @@ public void taskSubscription(org.a2aproject.sdk.compat03.grpc.TaskSubscriptionRe try { ServerCallContext context = createCallContext(responseObserver); + installForkedContextWrapper(context); TaskIdParams_v0_3 params = FromProto.taskIdParams(request); Flow.Publisher publisher = requestHandler.onResubscribeToTask(params, context); convertToStreamResponse(publisher, responseObserver); @@ -273,6 +277,19 @@ public void taskSubscription(org.a2aproject.sdk.compat03.grpc.TaskSubscriptionRe } } + private void installForkedContextWrapper(ServerCallContext context) { + Context forked = Context.current().fork(); + context.getState().put(ServerCallContext.EXECUTION_WRAPPER_KEY, + (UnaryOperator) (runnable -> () -> { + Context prev = forked.attach(); + try { + runnable.run(); + } finally { + forked.detach(prev); + } + })); + } + private void convertToStreamResponse(Flow.Publisher publisher, StreamObserver responseObserver) { CompletableFuture.runAsync(() -> { diff --git a/jsonrpc-common/src/main/java/org/a2aproject/sdk/jsonrpc/common/json/JsonUtil.java b/jsonrpc-common/src/main/java/org/a2aproject/sdk/jsonrpc/common/json/JsonUtil.java index 01666237f..8a43260e0 100644 --- a/jsonrpc-common/src/main/java/org/a2aproject/sdk/jsonrpc/common/json/JsonUtil.java +++ b/jsonrpc-common/src/main/java/org/a2aproject/sdk/jsonrpc/common/json/JsonUtil.java @@ -20,6 +20,7 @@ import com.google.gson.Gson; import com.google.gson.GsonBuilder; +import com.google.gson.JsonElement; import com.google.gson.JsonParser; import com.google.gson.JsonSyntaxException; import com.google.gson.ToNumberPolicy; @@ -474,7 +475,11 @@ public static Map readMetadata(@Nullable String json) throws Jso return Collections.emptyMap(); } try { - return readMetadata(JsonParser.parseString(json).getAsJsonObject()); + JsonElement element = JsonParser.parseString(json); + if (!element.isJsonObject()) { + return Collections.emptyMap(); + } + return readMetadata(element.getAsJsonObject()); } catch (JsonSyntaxException e) { throw new JsonProcessingException("Failed to parse metadata JSON", e); } diff --git a/reference/grpc/src/main/java/org/a2aproject/sdk/server/grpc/quarkus/BlockingOffloadInterceptor.java b/reference/grpc/src/main/java/org/a2aproject/sdk/server/grpc/quarkus/BlockingOffloadInterceptor.java new file mode 100644 index 000000000..8ba6f2939 --- /dev/null +++ b/reference/grpc/src/main/java/org/a2aproject/sdk/server/grpc/quarkus/BlockingOffloadInterceptor.java @@ -0,0 +1,81 @@ +package org.a2aproject.sdk.server.grpc.quarkus; + +import java.util.concurrent.Executor; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.inject.Inject; + +import org.a2aproject.sdk.server.util.async.Internal; +import io.grpc.Context; +import io.grpc.ForwardingServerCallListener.SimpleForwardingServerCallListener; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.Status; + +/** + * gRPC server interceptor that offloads handler execution from the Vert.x event loop + * to a worker thread, using a forked gRPC {@link Context}. + * + *

In Quarkus separate-server mode ({@code quarkus.grpc.server.use-separate-server=true}), + * the {@code @Blocking} annotation is ignored and gRPC handlers run on the Vert.x event loop. + * Synchronous operations like {@code sendMessage()} deadlock the event loop. This interceptor + * wraps the {@code onHalfClose()} callback to run the handler on a worker thread instead. + * + *

The context is forked ({@link Context#fork()}) so that the handler's outbound gRPC calls + * do not inherit the inbound call's cancellation signal. Without this, Quarkus' + * {@code ContextStorageOverride} propagates the server context through the + * {@code ManagedExecutor}, causing outbound client calls to be cancelled when the + * inbound caller disconnects. + * + *

This applies to both unary and server-streaming methods (which both have a single + * inbound request). Client-streaming and bidi-streaming methods are excluded. + */ +@ApplicationScoped +public class BlockingOffloadInterceptor implements ServerInterceptor { + + private final Executor executor; + + @Inject + public BlockingOffloadInterceptor(@Internal Executor executor) { + this.executor = executor; + } + + @Override + public ServerCall.Listener interceptCall( + ServerCall call, + Metadata headers, + ServerCallHandler next) { + + MethodDescriptor.MethodType type = call.getMethodDescriptor().getType(); + if (type == MethodDescriptor.MethodType.CLIENT_STREAMING + || type == MethodDescriptor.MethodType.BIDI_STREAMING) { + return next.startCall(call, headers); + } + + ServerCall.Listener delegate = next.startCall(call, headers); + + return new SimpleForwardingServerCallListener(delegate) { + @Override + public void onHalfClose() { + Context grpcContext = Context.current().fork(); + try { + executor.execute(() -> { + Context previous = grpcContext.attach(); + try { + super.onHalfClose(); + } catch (Exception e) { + call.close(Status.INTERNAL.withDescription("Error during execution: " + e.getMessage()), new Metadata()); + } finally { + grpcContext.detach(previous); + } + }); + } catch (Exception e) { + call.close(Status.INTERNAL.withDescription("Failed to offload to worker thread: " + e.getMessage()), new Metadata()); + } + } + }; + } +} diff --git a/reference/grpc/src/main/java/org/a2aproject/sdk/server/grpc/quarkus/QuarkusGrpcHandler.java b/reference/grpc/src/main/java/org/a2aproject/sdk/server/grpc/quarkus/QuarkusGrpcHandler.java index e23656b8a..8df6e0699 100644 --- a/reference/grpc/src/main/java/org/a2aproject/sdk/server/grpc/quarkus/QuarkusGrpcHandler.java +++ b/reference/grpc/src/main/java/org/a2aproject/sdk/server/grpc/quarkus/QuarkusGrpcHandler.java @@ -15,6 +15,7 @@ import io.quarkus.grpc.GrpcService; import io.quarkus.grpc.RegisterInterceptor; import io.quarkus.security.Authenticated; +import io.smallrye.common.annotation.Blocking; import org.jspecify.annotations.Nullable; /** @@ -68,7 +69,9 @@ */ @GrpcService @RegisterInterceptor(A2AExtensionsInterceptor.class) +@RegisterInterceptor(BlockingOffloadInterceptor.class) @Authenticated +@Blocking public class QuarkusGrpcHandler extends GrpcHandler { private final AgentCard agentCard; diff --git a/reference/multiversion-jsonrpc/src/main/java/org/a2aproject/sdk/server/multiversion/jsonrpc/MultiVersionJSONRPCRoutes.java b/reference/multiversion-jsonrpc/src/main/java/org/a2aproject/sdk/server/multiversion/jsonrpc/MultiVersionJSONRPCRoutes.java index e8bc6c405..aaf90f42e 100644 --- a/reference/multiversion-jsonrpc/src/main/java/org/a2aproject/sdk/server/multiversion/jsonrpc/MultiVersionJSONRPCRoutes.java +++ b/reference/multiversion-jsonrpc/src/main/java/org/a2aproject/sdk/server/multiversion/jsonrpc/MultiVersionJSONRPCRoutes.java @@ -62,6 +62,6 @@ void setupRoutes(@Observes Router router) { } catch (Exception e) { VertxSecurityHelper.handleGenericError(ctx); } - }); + }, false); } } diff --git a/reference/multiversion-rest/src/main/java/org/a2aproject/sdk/server/multiversion/rest/MultiVersionRestRoutes.java b/reference/multiversion-rest/src/main/java/org/a2aproject/sdk/server/multiversion/rest/MultiVersionRestRoutes.java index f607d0437..05a39a480 100644 --- a/reference/multiversion-rest/src/main/java/org/a2aproject/sdk/server/multiversion/rest/MultiVersionRestRoutes.java +++ b/reference/multiversion-rest/src/main/java/org/a2aproject/sdk/server/multiversion/rest/MultiVersionRestRoutes.java @@ -45,7 +45,7 @@ void setupRoutes(@Observes @Priority(5) Router router) { .blockingHandler(versionDispatch( MultiVersionRestRoutes::bridgeTenant, (body, ctx) -> v10Routes.sendMessage(body, ctx), - (body, ctx) -> v03Routes.sendMessage(body, ctx))); + (body, ctx) -> v03Routes.sendMessage(body, ctx)), false); // POST /v1/message:stream router.postWithRegex("^\\/v1\\/message:stream$") @@ -54,7 +54,7 @@ void setupRoutes(@Observes @Priority(5) Router router) { .blockingHandler(versionDispatch( MultiVersionRestRoutes::bridgeTenant, (body, ctx) -> v10Routes.sendMessageStreaming(body, ctx), - (body, ctx) -> v03Routes.sendMessageStreaming(body, ctx))); + (body, ctx) -> v03Routes.sendMessageStreaming(body, ctx)), false); // GET /v1/tasks/{taskId} router.getWithRegex("^\\/v1\\/tasks\\/(?[^:^/]+)$") @@ -62,7 +62,7 @@ void setupRoutes(@Observes @Priority(5) Router router) { .blockingHandler(versionDispatchNoBody( ctx -> { bridgeTenant(ctx); bridgeTaskId(ctx); }, ctx -> v10Routes.getTask(ctx), - ctx -> v03Routes.getTask(ctx))); + ctx -> v03Routes.getTask(ctx)), false); // POST /v1/tasks/{taskId}:cancel router.postWithRegex("^\\/v1\\/tasks\\/(?[^/]+):cancel$") @@ -71,7 +71,7 @@ void setupRoutes(@Observes @Priority(5) Router router) { .blockingHandler(versionDispatch( ctx -> { bridgeTenant(ctx); bridgeTaskId(ctx); }, (body, ctx) -> v10Routes.cancelTask(body, ctx), - (body, ctx) -> v03Routes.cancelTask(ctx))); + (body, ctx) -> v03Routes.cancelTask(ctx)), false); // POST /v1/tasks/{taskId}:subscribe router.postWithRegex("^\\/v1\\/tasks\\/(?[^/]+):subscribe$") @@ -79,7 +79,7 @@ void setupRoutes(@Observes @Priority(5) Router router) { .blockingHandler(versionDispatchNoBody( ctx -> { bridgeTenant(ctx); bridgeTaskId(ctx); }, ctx -> v10Routes.subscribeToTask(ctx), - ctx -> v03Routes.resubscribeTask(ctx))); + ctx -> v03Routes.resubscribeTask(ctx)), false); // POST /v1/tasks/{taskId}/pushNotificationConfigs router.postWithRegex("^\\/v1\\/tasks\\/(?[^/]+)\\/pushNotificationConfigs$") @@ -88,7 +88,7 @@ void setupRoutes(@Observes @Priority(5) Router router) { .blockingHandler(versionDispatch( ctx -> { bridgeTenant(ctx); bridgeTaskId(ctx); }, (body, ctx) -> v10Routes.createTaskPushNotificationConfiguration(body, ctx), - (body, ctx) -> v03Routes.setTaskPushNotificationConfiguration(body, ctx))); + (body, ctx) -> v03Routes.setTaskPushNotificationConfiguration(body, ctx)), false); // GET /v1/tasks/{taskId}/pushNotificationConfigs/{configId} router.getWithRegex("^\\/v1\\/tasks\\/(?[^/]+)\\/pushNotificationConfigs\\/(?[^\\/]+)") @@ -96,7 +96,7 @@ void setupRoutes(@Observes @Priority(5) Router router) { .blockingHandler(versionDispatchNoBody( ctx -> { bridgeTenant(ctx); bridgeTaskId(ctx); }, ctx -> v10Routes.getTaskPushNotificationConfiguration(ctx), - ctx -> v03Routes.getTaskPushNotificationConfiguration(ctx))); + ctx -> v03Routes.getTaskPushNotificationConfiguration(ctx)), false); // GET /v1/tasks/{taskId}/pushNotificationConfigs router.getWithRegex("^\\/v1\\/tasks\\/(?[^/]+)\\/pushNotificationConfigs\\/?$") @@ -104,7 +104,7 @@ void setupRoutes(@Observes @Priority(5) Router router) { .blockingHandler(versionDispatchNoBody( ctx -> { bridgeTenant(ctx); bridgeTaskId(ctx); }, ctx -> v10Routes.listTaskPushNotificationConfigurations(ctx), - ctx -> v03Routes.listTaskPushNotificationConfigurations(ctx))); + ctx -> v03Routes.listTaskPushNotificationConfigurations(ctx)), false); // DELETE /v1/tasks/{taskId}/pushNotificationConfigs/{configId} router.deleteWithRegex("^\\/v1\\/tasks\\/(?[^/]+)\\/pushNotificationConfigs\\/(?[^/]+)") @@ -112,7 +112,7 @@ void setupRoutes(@Observes @Priority(5) Router router) { .blockingHandler(versionDispatchNoBody( ctx -> { bridgeTenant(ctx); bridgeTaskId(ctx); }, ctx -> v10Routes.deleteTaskPushNotificationConfiguration(ctx), - ctx -> v03Routes.deleteTaskPushNotificationConfiguration(ctx))); + ctx -> v03Routes.deleteTaskPushNotificationConfiguration(ctx)), false); // GET /v1/card — v0.3 only (v1.0 uses /{tenant}/extendedAgentCard) router.get("/v1/card") @@ -127,7 +127,7 @@ void setupRoutes(@Observes @Priority(5) Router router) { } catch (Exception e) { VertxSecurityHelper.handleGenericError(ctx); } - }); + }, false); } private static void bridgeTenant(RoutingContext ctx) { diff --git a/server-common/src/main/java/org/a2aproject/sdk/server/ServerCallContext.java b/server-common/src/main/java/org/a2aproject/sdk/server/ServerCallContext.java index ff4652158..e7e819098 100644 --- a/server-common/src/main/java/org/a2aproject/sdk/server/ServerCallContext.java +++ b/server-common/src/main/java/org/a2aproject/sdk/server/ServerCallContext.java @@ -23,6 +23,14 @@ public class ServerCallContext { */ public static final String STRICT_CONTEXT_VALIDATION_KEY = "strictContextValidation"; + /** + * Key for an execution wrapper in the state map. + * Value should be a {@link java.util.function.UnaryOperator UnaryOperator<Runnable>} that wraps + * the agent execution runnable. Used by gRPC transport to fork the inbound call's context + * so that agent outbound calls are isolated from inbound cancellation. + */ + public static final String EXECUTION_WRAPPER_KEY = "executionWrapper"; + // TODO Not totally sure yet about these field types private final Map modelConfig = new ConcurrentHashMap<>(); private final Map state; diff --git a/server-common/src/main/java/org/a2aproject/sdk/server/requesthandlers/DefaultRequestHandler.java b/server-common/src/main/java/org/a2aproject/sdk/server/requesthandlers/DefaultRequestHandler.java index d2c3aa028..8287a07a8 100644 --- a/server-common/src/main/java/org/a2aproject/sdk/server/requesthandlers/DefaultRequestHandler.java +++ b/server-common/src/main/java/org/a2aproject/sdk/server/requesthandlers/DefaultRequestHandler.java @@ -962,7 +962,14 @@ public void run() { // Mark as started to prevent further callback additions (enforced by runtime check) runnable.markStarted(); - CompletableFuture cf = CompletableFuture.runAsync(runnable, executor) + // Apply transport-provided execution wrapper (e.g. gRPC context fork for cancellation isolation) + @SuppressWarnings("unchecked") + java.util.function.UnaryOperator wrapper = requestContext.getCallContext() != null + ? (java.util.function.UnaryOperator) requestContext.getCallContext().getState().get(ServerCallContext.EXECUTION_WRAPPER_KEY) + : null; + Runnable wrappedRunnable = wrapper != null ? wrapper.apply(runnable) : runnable; + + CompletableFuture cf = CompletableFuture.runAsync(wrappedRunnable, executor) .whenComplete((v, err) -> { if (err != null) { LOGGER.error("Agent execution failed for task {}", taskId, err); diff --git a/transport/grpc/src/main/java/org/a2aproject/sdk/transport/grpc/handler/GrpcHandler.java b/transport/grpc/src/main/java/org/a2aproject/sdk/transport/grpc/handler/GrpcHandler.java index 178a32adf..70185f32d 100644 --- a/transport/grpc/src/main/java/org/a2aproject/sdk/transport/grpc/handler/GrpcHandler.java +++ b/transport/grpc/src/main/java/org/a2aproject/sdk/transport/grpc/handler/GrpcHandler.java @@ -394,6 +394,7 @@ public void sendStreamingMessage(org.a2aproject.sdk.grpc.SendMessageRequest requ try { ServerCallContext context = createCallContext(responseObserver); + installForkedContextWrapper(context); A2AVersionValidator.validateProtocolVersion(getAgentCardInternal(), context); A2AExtensions.validateRequiredExtensions(getAgentCardInternal(), context); MessageSendParams params = FromProto.messageSendParams(request); @@ -418,6 +419,7 @@ public void subscribeToTask(org.a2aproject.sdk.grpc.SubscribeToTaskRequest reque try { ServerCallContext context = createCallContext(responseObserver); + installForkedContextWrapper(context); TaskIdParams params = FromProto.taskIdParams(request); Flow.Publisher publisher = getRequestHandler().onSubscribeToTask(params, context); convertToStreamResponse(publisher, responseObserver, context); @@ -705,6 +707,24 @@ private ServerCallContext createCallContext(StreamObserver responseObserv } } + /** + * Forks the current gRPC context and installs it as an execution wrapper on the + * server call context. This isolates agent outbound calls from the inbound call's + * cancellation signal. + */ + private void installForkedContextWrapper(ServerCallContext context) { + Context forked = Context.current().fork(); + context.getState().put(ServerCallContext.EXECUTION_WRAPPER_KEY, + (java.util.function.UnaryOperator) (runnable -> () -> { + Context prev = forked.attach(); + try { + runnable.run(); + } finally { + forked.detach(prev); + } + })); + } + /** * Handles A2A protocol errors by mapping them to appropriate gRPC status codes. *