From 44a4a18636305ebf895d6603e604a3edc19323b8 Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Tue, 13 Jan 2026 14:35:54 +0100 Subject: [PATCH] Support Spring specific service/vobj/workflow annotations --- .../main/java/dev/restate/client/Client.java | 14 ++-- .../common/reflections/ReflectionUtils.java | 84 +++++++++++++++---- .../main/java/dev/restate/sdk/Restate.java | 14 ++-- .../ReflectionServiceDefinitionFactory.java | 9 +- .../sdk/springboot/java/GreeterNewApi.java | 6 +- 5 files changed, 84 insertions(+), 43 deletions(-) diff --git a/client/src/main/java/dev/restate/client/Client.java b/client/src/main/java/dev/restate/client/Client.java index cf3145b7..3d7b9d75 100644 --- a/client/src/main/java/dev/restate/client/Client.java +++ b/client/src/main/java/dev/restate/client/Client.java @@ -8,8 +8,6 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.client; -import static dev.restate.common.reflections.ReflectionUtils.mustHaveAnnotation; - import dev.restate.common.Output; import dev.restate.common.Request; import dev.restate.common.Target; @@ -556,7 +554,7 @@ default Response> getOutput() throws IngressException { */ @org.jetbrains.annotations.ApiStatus.Experimental default SVC service(Class clazz) { - mustHaveAnnotation(clazz, Service.class); + ReflectionUtils.mustHaveServiceAnnotation(clazz); var serviceName = ReflectionUtils.extractServiceName(clazz); return ProxySupport.createProxy( clazz, @@ -607,7 +605,7 @@ default SVC service(Class clazz) { */ @org.jetbrains.annotations.ApiStatus.Experimental default ClientServiceHandle serviceHandle(Class clazz) { - mustHaveAnnotation(clazz, Service.class); + ReflectionUtils.mustHaveServiceAnnotation(clazz); return new ClientServiceHandleImpl<>(this, clazz, null); } @@ -635,7 +633,7 @@ default ClientServiceHandle serviceHandle(Class clazz) { */ @org.jetbrains.annotations.ApiStatus.Experimental default SVC virtualObject(Class clazz, String key) { - mustHaveAnnotation(clazz, VirtualObject.class); + ReflectionUtils.mustHaveVirtualObjectAnnotation(clazz); var serviceName = ReflectionUtils.extractServiceName(clazz); return ProxySupport.createProxy( clazz, @@ -687,7 +685,7 @@ default SVC virtualObject(Class clazz, String key) { */ @org.jetbrains.annotations.ApiStatus.Experimental default ClientServiceHandle virtualObjectHandle(Class clazz, String key) { - mustHaveAnnotation(clazz, VirtualObject.class); + ReflectionUtils.mustHaveVirtualObjectAnnotation(clazz); return new ClientServiceHandleImpl<>(this, clazz, key); } @@ -715,7 +713,7 @@ default ClientServiceHandle virtualObjectHandle(Class clazz, Str */ @org.jetbrains.annotations.ApiStatus.Experimental default SVC workflow(Class clazz, String key) { - mustHaveAnnotation(clazz, Workflow.class); + ReflectionUtils.mustHaveWorkflowAnnotation(clazz); var serviceName = ReflectionUtils.extractServiceName(clazz); return ProxySupport.createProxy( clazz, @@ -767,7 +765,7 @@ default SVC workflow(Class clazz, String key) { */ @org.jetbrains.annotations.ApiStatus.Experimental default ClientServiceHandle workflowHandle(Class clazz, String key) { - mustHaveAnnotation(clazz, Workflow.class); + ReflectionUtils.mustHaveWorkflowAnnotation(clazz); return new ClientServiceHandleImpl<>(this, clazz, key); } diff --git a/common/src/main/java/dev/restate/common/reflections/ReflectionUtils.java b/common/src/main/java/dev/restate/common/reflections/ReflectionUtils.java index 8c118b63..ca55efb1 100644 --- a/common/src/main/java/dev/restate/common/reflections/ReflectionUtils.java +++ b/common/src/main/java/dev/restate/common/reflections/ReflectionUtils.java @@ -20,6 +20,14 @@ public class ReflectionUtils { + private static final @Nullable Class RESTATE_SPRING_SERVICE_ANNOTATION = + tryLoadClass("dev.restate.sdk.springboot.RestateService"); + private static final @Nullable Class + RESTATE_SPRING_VIRTUAL_OBJECT_ANNOTATION = + tryLoadClass("dev.restate.sdk.springboot.RestateVirtualObject"); + private static final @Nullable Class RESTATE_SPRING_WORKFLOW_ANNOTATION = + tryLoadClass("dev.restate.sdk.springboot.RestateWorkflow"); + /** Record containing handler information extracted from annotations. */ public record HandlerInfo(String name, boolean shared) {} @@ -163,16 +171,17 @@ private static String inferRestateNameFromHierarchy(Class type) { } // Check if the type has any of the Restate component annotations - var restateServiceAnnotation = type.getAnnotation(Service.class); - if (restateServiceAnnotation != null) { - return extractNameFromAnnotations(type); - } - var restateVirtualObjectAnnotation = type.getAnnotation(VirtualObject.class); - if (restateVirtualObjectAnnotation != null) { - return extractNameFromAnnotations(type); - } - var restateWorkflowAnnotation = type.getAnnotation(Workflow.class); - if (restateWorkflowAnnotation != null) { + var isRestateAnnotated = + type.getAnnotation(Service.class) != null + || type.getAnnotation(VirtualObject.class) != null + || type.getAnnotation(Workflow.class) != null + || (RESTATE_SPRING_SERVICE_ANNOTATION != null + && type.getAnnotation(RESTATE_SPRING_SERVICE_ANNOTATION) != null) + || (RESTATE_SPRING_VIRTUAL_OBJECT_ANNOTATION != null + && type.getAnnotation(RESTATE_SPRING_VIRTUAL_OBJECT_ANNOTATION) != null) + || (RESTATE_SPRING_WORKFLOW_ANNOTATION != null + && type.getAnnotation(RESTATE_SPRING_WORKFLOW_ANNOTATION) != null); + if (isRestateAnnotated) { return extractNameFromAnnotations(type); } @@ -200,17 +209,49 @@ private static String extractNameFromAnnotations(Class type) { return type.getSimpleName(); } - public static A mustHaveAnnotation( - Class clazz, Class annotationClazz) { - A annotation = findAnnotation(clazz, annotationClazz); - if (annotation == null) { + public static boolean hasServiceAnnotation(Class clazz) { + return findAnnotation(clazz, Service.class) != null + || (RESTATE_SPRING_SERVICE_ANNOTATION != null + && findAnnotation(clazz, RESTATE_SPRING_SERVICE_ANNOTATION) != null); + } + + public static void mustHaveServiceAnnotation(Class clazz) { + if (!hasServiceAnnotation(clazz)) { + throw new IllegalArgumentException( + "The given class " + + clazz.getName() + + " is not annotated with the Restate service annotation"); + } + } + + public static boolean hasVirtualObjectAnnotation(Class clazz) { + return findAnnotation(clazz, VirtualObject.class) != null + || (RESTATE_SPRING_VIRTUAL_OBJECT_ANNOTATION != null + && findAnnotation(clazz, RESTATE_SPRING_VIRTUAL_OBJECT_ANNOTATION) != null); + } + + public static void mustHaveVirtualObjectAnnotation(Class clazz) { + if (!hasVirtualObjectAnnotation(clazz)) { + throw new IllegalArgumentException( + "The given class " + + clazz.getName() + + " is not annotated with the Restate virtualObject annotation"); + } + } + + public static boolean hasWorkflowAnnotation(Class clazz) { + return findAnnotation(clazz, Workflow.class) != null + || (RESTATE_SPRING_WORKFLOW_ANNOTATION != null + && findAnnotation(clazz, RESTATE_SPRING_WORKFLOW_ANNOTATION) != null); + } + + public static void mustHaveWorkflowAnnotation(Class clazz) { + if (!hasWorkflowAnnotation(clazz)) { throw new IllegalArgumentException( "The given class " + clazz.getName() - + " is not annotated with @" - + annotationClazz.getSimpleName()); + + " is not annotated with the Restate workflow annotation"); } - return annotation; } public static HandlerInfo mustHaveHandlerAnnotation(@NonNull Method method) { @@ -308,6 +349,15 @@ public static boolean isKotlinClass(Class clazz) { .anyMatch(annotation -> annotation.annotationType().getName().equals("kotlin.Metadata")); } + @SuppressWarnings("unchecked") + private static @Nullable Class tryLoadClass(String className) { + try { + return (Class) Class.forName(className); + } catch (ClassNotFoundException e) { + return null; + } + } + // From Spring's ReflectionUtils // License Apache 2.0 diff --git a/sdk-api/src/main/java/dev/restate/sdk/Restate.java b/sdk-api/src/main/java/dev/restate/sdk/Restate.java index 6286ecd8..fd8ff1e4 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/Restate.java +++ b/sdk-api/src/main/java/dev/restate/sdk/Restate.java @@ -8,8 +8,6 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk; -import static dev.restate.common.reflections.ReflectionUtils.mustHaveAnnotation; - import dev.restate.common.Request; import dev.restate.common.Slice; import dev.restate.common.Target; @@ -430,7 +428,7 @@ public static AwakeableHandle awakeableHandle(String id) { */ @org.jetbrains.annotations.ApiStatus.Experimental public static SVC service(Class clazz) { - mustHaveAnnotation(clazz, Service.class); + ReflectionUtils.mustHaveServiceAnnotation(clazz); String serviceName = ReflectionUtils.extractServiceName(clazz); return ProxySupport.createProxy( clazz, @@ -481,7 +479,7 @@ public static SVC service(Class clazz) { */ @org.jetbrains.annotations.ApiStatus.Experimental public static ServiceHandle serviceHandle(Class clazz) { - mustHaveAnnotation(clazz, Service.class); + ReflectionUtils.mustHaveServiceAnnotation(clazz); return new ServiceHandleImpl<>(clazz, null); } @@ -506,7 +504,7 @@ public static ServiceHandle serviceHandle(Class clazz) { */ @org.jetbrains.annotations.ApiStatus.Experimental public static SVC virtualObject(Class clazz, String key) { - mustHaveAnnotation(clazz, VirtualObject.class); + ReflectionUtils.mustHaveVirtualObjectAnnotation(clazz); String serviceName = ReflectionUtils.extractServiceName(clazz); return ProxySupport.createProxy( clazz, @@ -558,7 +556,7 @@ public static SVC virtualObject(Class clazz, String key) { */ @org.jetbrains.annotations.ApiStatus.Experimental public static ServiceHandle virtualObjectHandle(Class clazz, String key) { - mustHaveAnnotation(clazz, VirtualObject.class); + ReflectionUtils.mustHaveVirtualObjectAnnotation(clazz); return new ServiceHandleImpl<>(clazz, key); } @@ -583,7 +581,7 @@ public static ServiceHandle virtualObjectHandle(Class clazz, Str */ @org.jetbrains.annotations.ApiStatus.Experimental public static SVC workflow(Class clazz, String key) { - mustHaveAnnotation(clazz, Workflow.class); + ReflectionUtils.mustHaveWorkflowAnnotation(clazz); String serviceName = ReflectionUtils.extractServiceName(clazz); return ProxySupport.createProxy( clazz, @@ -635,7 +633,7 @@ public static SVC workflow(Class clazz, String key) { */ @org.jetbrains.annotations.ApiStatus.Experimental public static ServiceHandle workflowHandle(Class clazz, String key) { - mustHaveAnnotation(clazz, Workflow.class); + ReflectionUtils.mustHaveWorkflowAnnotation(clazz); return new ServiceHandleImpl<>(clazz, key); } diff --git a/sdk-api/src/main/java/dev/restate/sdk/internal/ReflectionServiceDefinitionFactory.java b/sdk-api/src/main/java/dev/restate/sdk/internal/ReflectionServiceDefinitionFactory.java index a250298f..67351e55 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/internal/ReflectionServiceDefinitionFactory.java +++ b/sdk-api/src/main/java/dev/restate/sdk/internal/ReflectionServiceDefinitionFactory.java @@ -47,12 +47,9 @@ public ServiceDefinition create( Class serviceClazz = serviceInstance.getClass(); - boolean hasServiceAnnotation = - ReflectionUtils.findAnnotation(serviceClazz, Service.class) != null; - boolean hasVirtualObjectAnnotation = - ReflectionUtils.findAnnotation(serviceClazz, VirtualObject.class) != null; - boolean hasWorkflowAnnotation = - ReflectionUtils.findAnnotation(serviceClazz, Workflow.class) != null; + boolean hasServiceAnnotation = ReflectionUtils.hasServiceAnnotation(serviceClazz); + boolean hasVirtualObjectAnnotation = ReflectionUtils.hasVirtualObjectAnnotation(serviceClazz); + boolean hasWorkflowAnnotation = ReflectionUtils.hasWorkflowAnnotation(serviceClazz); boolean hasAnyAnnotation = hasServiceAnnotation || hasVirtualObjectAnnotation || hasWorkflowAnnotation; diff --git a/sdk-spring-boot-starter/src/test/java/dev/restate/sdk/springboot/java/GreeterNewApi.java b/sdk-spring-boot-starter/src/test/java/dev/restate/sdk/springboot/java/GreeterNewApi.java index ad64db27..cc6aff24 100644 --- a/sdk-spring-boot-starter/src/test/java/dev/restate/sdk/springboot/java/GreeterNewApi.java +++ b/sdk-spring-boot-starter/src/test/java/dev/restate/sdk/springboot/java/GreeterNewApi.java @@ -10,12 +10,10 @@ import dev.restate.sdk.annotation.Handler; import dev.restate.sdk.annotation.Name; -import dev.restate.sdk.annotation.Service; -import dev.restate.sdk.springboot.RestateComponent; +import dev.restate.sdk.springboot.RestateService; import org.springframework.beans.factory.annotation.Value; -@Service -@RestateComponent +@RestateService @Name("greeterNewApi") public class GreeterNewApi {