diff --git a/server-common/src/main/java/io/a2a/server/events/EventConsumer.java b/server-common/src/main/java/io/a2a/server/events/EventConsumer.java index 7f1cbd4cc..d4fe5b395 100644 --- a/server-common/src/main/java/io/a2a/server/events/EventConsumer.java +++ b/server-common/src/main/java/io/a2a/server/events/EventConsumer.java @@ -6,6 +6,7 @@ import io.a2a.spec.Event; import io.a2a.spec.Message; import io.a2a.spec.Task; +import io.a2a.spec.TaskState; import io.a2a.spec.TaskStatusUpdateEvent; import mutiny.zero.BackpressureStrategy; import mutiny.zero.TubeConfiguration; @@ -77,7 +78,7 @@ public Flow.Publisher consumeAll() { } else if (event instanceof Message) { isFinalEvent = true; } else if (event instanceof Task task) { - isFinalEvent = task.status().state().isFinal(); + isFinalEvent = isStreamTerminatingTask(task); } else if (event instanceof QueueClosedEvent) { // Poison pill event - signals queue closure from remote node // Do NOT send to subscribers - just close the queue @@ -94,7 +95,7 @@ public Flow.Publisher consumeAll() { } if (isFinalEvent) { - LOGGER.debug("Final event detected, closing queue and breaking loop for queue {}", System.identityHashCode(queue)); + LOGGER.debug("Final or interrupted event detected, closing queue and breaking loop for queue {}", System.identityHashCode(queue)); queue.close(); LOGGER.debug("Queue closed, breaking loop for queue {}", System.identityHashCode(queue)); break; @@ -120,6 +121,21 @@ public Flow.Publisher consumeAll() { }); } + /** + * Determines if a task is in a state for terminating the stream. + *

A task is terminating if:

+ *
    + *
  • Its state is final (e.g., completed, canceled, rejected, failed), OR
  • + *
  • Its state is interrupted (e.g., input-required)
  • + *
+ * @param task the task to check + * @return true if the task has a final state or an interrupted state, false otherwise + */ + private boolean isStreamTerminatingTask(Task task) { + TaskState state = task.status().state(); + return state.isFinal() || state == TaskState.INPUT_REQUIRED; + } + public EnhancedRunnable.DoneCallback createAgentRunnableDoneCallback() { return agentRunnable -> { if (agentRunnable.getError() != null) { diff --git a/server-common/src/test/java/io/a2a/server/events/EventConsumerTest.java b/server-common/src/test/java/io/a2a/server/events/EventConsumerTest.java index 6114e8f21..4354f1639 100644 --- a/server-common/src/test/java/io/a2a/server/events/EventConsumerTest.java +++ b/server-common/src/test/java/io/a2a/server/events/EventConsumerTest.java @@ -114,32 +114,7 @@ public void testConsumeAllMultipleEvents() throws JsonProcessingException { final List receivedEvents = new ArrayList<>(); final AtomicReference error = new AtomicReference<>(); - publisher.subscribe(new Flow.Subscriber<>() { - private Flow.Subscription subscription; - - @Override - public void onSubscribe(Flow.Subscription subscription) { - this.subscription = subscription; - subscription.request(1); - } - - @Override - public void onNext(EventQueueItem item) { - receivedEvents.add(item.getEvent()); - subscription.request(1); - - } - - @Override - public void onError(Throwable throwable) { - error.set(throwable); - } - - @Override - public void onComplete() { - subscription.cancel(); - } - }); + publisher.subscribe(getSubscriber(receivedEvents, error)); assertNull(error.get()); assertEquals(events.size(), receivedEvents.size()); @@ -175,32 +150,7 @@ public void testConsumeUntilMessage() throws Exception { final List receivedEvents = new ArrayList<>(); final AtomicReference error = new AtomicReference<>(); - publisher.subscribe(new Flow.Subscriber<>() { - private Flow.Subscription subscription; - - @Override - public void onSubscribe(Flow.Subscription subscription) { - this.subscription = subscription; - subscription.request(1); - } - - @Override - public void onNext(EventQueueItem item) { - receivedEvents.add(item.getEvent()); - subscription.request(1); - - } - - @Override - public void onError(Throwable throwable) { - error.set(throwable); - } - - @Override - public void onComplete() { - subscription.cancel(); - } - }); + publisher.subscribe(getSubscriber(receivedEvents, error)); assertNull(error.get()); assertEquals(3, receivedEvents.size()); @@ -224,7 +174,55 @@ public void testConsumeMessageEvents() throws Exception { final List receivedEvents = new ArrayList<>(); final AtomicReference error = new AtomicReference<>(); - publisher.subscribe(new Flow.Subscriber<>() { + publisher.subscribe(getSubscriber(receivedEvents, error)); + + assertNull(error.get()); + // The stream is closed after the first Message + assertEquals(1, receivedEvents.size()); + assertSame(message, receivedEvents.get(0)); + } + + @Test + public void testConsumeTaskInputRequired() { + Task task = Task.builder() + .id("task-id") + .contextId("task-context") + .status(new TaskStatus(TaskState.INPUT_REQUIRED)) + .build(); + List events = List.of( + task, + TaskArtifactUpdateEvent.builder() + .taskId("task-123") + .contextId("session-xyz") + .artifact(Artifact.builder() + .artifactId("11") + .parts(new TextPart("text")) + .build()) + .build(), + TaskStatusUpdateEvent.builder() + .taskId("task-123") + .contextId("session-xyz") + .status(new TaskStatus(TaskState.WORKING)) + .isFinal(true) + .build()); + for (Event event : events) { + eventQueue.enqueueEvent(event); + } + + Flow.Publisher publisher = eventConsumer.consumeAll(); + final List receivedEvents = new ArrayList<>(); + final AtomicReference error = new AtomicReference<>(); + + publisher.subscribe(getSubscriber(receivedEvents, error)); + + assertNull(error.get()); + // The stream is closed after the input_required task + assertEquals(1, receivedEvents.size()); + assertSame(task, receivedEvents.get(0)); + } + + private Flow.Subscriber getSubscriber(List receivedEvents, AtomicReference error) { + return new Flow.Subscriber<>() { private Flow.Subscription subscription; @Override @@ -237,7 +235,6 @@ public void onSubscribe(Flow.Subscription subscription) { public void onNext(EventQueueItem item) { receivedEvents.add(item.getEvent()); subscription.request(1); - } @Override @@ -249,12 +246,7 @@ public void onError(Throwable throwable) { public void onComplete() { subscription.cancel(); } - }); - - assertNull(error.get()); - // The stream is closed after the first Message - assertEquals(1, receivedEvents.size()); - assertSame(message, receivedEvents.get(0)); + }; } @Test