diff --git a/transport/rest/src/main/java/io/a2a/transport/rest/handler/RestHandler.java b/transport/rest/src/main/java/io/a2a/transport/rest/handler/RestHandler.java index 8c4050fd9..a5758974f 100644 --- a/transport/rest/src/main/java/io/a2a/transport/rest/handler/RestHandler.java +++ b/transport/rest/src/main/java/io/a2a/transport/rest/handler/RestHandler.java @@ -5,6 +5,7 @@ import java.time.Instant; import java.time.format.DateTimeParseException; +import java.util.Arrays; import java.util.HashMap; import java.util.Map; import java.util.concurrent.CompletableFuture; @@ -12,6 +13,7 @@ import java.util.concurrent.Flow; import java.util.logging.Level; import java.util.logging.Logger; +import java.util.stream.Collectors; import jakarta.enterprise.context.ApplicationScoped; import jakarta.enterprise.inject.Instance; @@ -208,7 +210,21 @@ public HTTPRestResponse listTasks(@Nullable String contextId, @Nullable String s paramsBuilder.contextId(contextId); } if (status != null) { - paramsBuilder.status(TaskState.valueOf(status)); + try { + paramsBuilder.status(TaskState.fromString(status)); + } catch (IllegalArgumentException e) { + try { + paramsBuilder.status(TaskState.valueOf(status)); + } catch (IllegalArgumentException valueOfError) { + String validStates = Arrays.stream(TaskState.values()) + .map(TaskState::asString) + .collect(Collectors.joining(", ")); + Map errorData = new HashMap<>(); + errorData.put("parameter", "status"); + errorData.put("reason", "Must be one of: " + validStates); + throw new InvalidParamsError(null, "Invalid params", errorData); + } + } } if (pageSize != null) { paramsBuilder.pageSize(pageSize); diff --git a/transport/rest/src/test/java/io/a2a/transport/rest/handler/RestHandlerTest.java b/transport/rest/src/test/java/io/a2a/transport/rest/handler/RestHandlerTest.java index df2a5e7af..0487a6f54 100644 --- a/transport/rest/src/test/java/io/a2a/transport/rest/handler/RestHandlerTest.java +++ b/transport/rest/src/test/java/io/a2a/transport/rest/handler/RestHandlerTest.java @@ -58,6 +58,31 @@ public void testGetTaskNotFound() { Assertions.assertTrue(response.getBody().contains("TaskNotFoundError")); } + @Test + public void testListTasksStatusWireString() { + RestHandler handler = new RestHandler(CARD, requestHandler, internalExecutor); + taskStore.save(MINIMAL_TASK); + + RestHandler.HTTPRestResponse response = handler.listTasks(null, "submitted", null, null, + null, null, null, "", callContext); + + Assertions.assertEquals(200, response.getStatusCode()); + Assertions.assertEquals("application/json", response.getContentType()); + Assertions.assertTrue(response.getBody().contains(MINIMAL_TASK.id())); + } + + @Test + public void testListTasksInvalidStatus() { + RestHandler handler = new RestHandler(CARD, requestHandler, internalExecutor); + + RestHandler.HTTPRestResponse response = handler.listTasks(null, "not-a-status", null, null, + null, null, null, "", callContext); + + Assertions.assertEquals(422, response.getStatusCode()); + Assertions.assertEquals("application/json", response.getContentType()); + Assertions.assertTrue(response.getBody().contains("InvalidParamsError")); + } + @Test public void testSendMessage() throws InvalidProtocolBufferException { RestHandler handler = new RestHandler(CARD, requestHandler, internalExecutor);