From e316d7b5e23cce12bdd778bbdf5306e1c61ffca4 Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Fri, 9 Jan 2026 14:30:21 +0100 Subject: [PATCH 1/7] A solution that should be all wired up, not tested yet --- bytebuddy-proxy-support/build.gradle.kts | 18 + .../proxysupport/ByteBuddyProxyFactory.java | 117 ++ ...mmon.reflections.ProxySupport.ProxyFactory | 1 + client/build.gradle.kts | 1 + .../main/java/dev/restate/client/Client.java | 23 + .../client/ClientServiceReference.java | 481 ++++++ .../client/ClientServiceReferenceImpl.java | 211 +++ common/build.gradle.kts | 2 + .../dev/restate/common/InvocationOptions.java | 154 ++ .../ConcurrentReferenceHashMap.java | 1316 +++++++++++++++++ .../common/reflections/MethodInfo.java | 46 + .../reflections/MethodInfoCollector.java | 68 + .../common/reflections/ProxySupport.java | 137 ++ .../common/reflections/ReflectionUtils.java | 1201 +++++++++++++++ .../common/reflections/RestateUtils.java | 55 + .../dev/restate/sdk/annotation/Accept.java | 2 +- .../sdk/annotation/CustomSerdeFactory.java | 2 +- .../dev/restate/sdk/annotation/Exclusive.java | 2 +- .../dev/restate/sdk/annotation/Handler.java | 2 +- .../java/dev/restate/sdk/annotation/Json.java | 2 +- .../java/dev/restate/sdk/annotation/Name.java | 2 +- .../java/dev/restate/sdk/annotation/Raw.java | 2 +- .../dev/restate/sdk/annotation/Service.java | 6 +- .../dev/restate/sdk/annotation/Shared.java | 2 +- .../restate/sdk/annotation/VirtualObject.java | 6 +- .../dev/restate/sdk/annotation/Workflow.java | 6 +- gradle/libs.versions.toml | 2 + .../sdk/MalformedRestateServiceException.java | 21 + .../ReflectionServiceDefinitionFactory.java | 227 +++ .../main/java/dev/restate/sdk/Restate.java | 181 +++ .../sdk/RestateThreadLocalContext.java | 40 + .../dev/restate/sdk/ServiceReference.java | 198 +++ .../dev/restate/sdk/ServiceReferenceImpl.java | 210 +++ ...dpoint.definition.ServiceDefinitionFactory | 1 + .../ServiceDefinitionFactories.java | 2 + .../definition/ServiceDefinitionFactory.java | 16 + .../jackson/JacksonSerdeFactoryProvider.java | 2 +- settings.gradle.kts | 1 + 38 files changed, 4743 insertions(+), 23 deletions(-) create mode 100644 bytebuddy-proxy-support/build.gradle.kts create mode 100644 bytebuddy-proxy-support/src/main/java/dev/restate/bytebuddy/proxysupport/ByteBuddyProxyFactory.java create mode 100644 bytebuddy-proxy-support/src/main/resources/META-INF/services/dev.restate.common.reflections.ProxySupport.ProxyFactory create mode 100644 client/src/main/java/dev/restate/client/ClientServiceReference.java create mode 100644 client/src/main/java/dev/restate/client/ClientServiceReferenceImpl.java create mode 100644 common/src/main/java/dev/restate/common/InvocationOptions.java create mode 100644 common/src/main/java/dev/restate/common/reflections/ConcurrentReferenceHashMap.java create mode 100644 common/src/main/java/dev/restate/common/reflections/MethodInfo.java create mode 100644 common/src/main/java/dev/restate/common/reflections/MethodInfoCollector.java create mode 100644 common/src/main/java/dev/restate/common/reflections/ProxySupport.java create mode 100644 common/src/main/java/dev/restate/common/reflections/ReflectionUtils.java create mode 100644 common/src/main/java/dev/restate/common/reflections/RestateUtils.java rename {sdk-common => common}/src/main/java/dev/restate/sdk/annotation/Accept.java (95%) rename {sdk-common => common}/src/main/java/dev/restate/sdk/annotation/CustomSerdeFactory.java (96%) rename {sdk-common => common}/src/main/java/dev/restate/sdk/annotation/Exclusive.java (95%) rename {sdk-common => common}/src/main/java/dev/restate/sdk/annotation/Handler.java (96%) rename {sdk-common => common}/src/main/java/dev/restate/sdk/annotation/Json.java (95%) rename {sdk-common => common}/src/main/java/dev/restate/sdk/annotation/Name.java (95%) rename {sdk-common => common}/src/main/java/dev/restate/sdk/annotation/Raw.java (96%) rename {sdk-common => common}/src/main/java/dev/restate/sdk/annotation/Service.java (75%) rename {sdk-common => common}/src/main/java/dev/restate/sdk/annotation/Shared.java (96%) rename {sdk-common => common}/src/main/java/dev/restate/sdk/annotation/VirtualObject.java (75%) rename {sdk-common => common}/src/main/java/dev/restate/sdk/annotation/Workflow.java (69%) create mode 100644 sdk-api/src/main/java/dev/restate/sdk/MalformedRestateServiceException.java create mode 100644 sdk-api/src/main/java/dev/restate/sdk/ReflectionServiceDefinitionFactory.java create mode 100644 sdk-api/src/main/java/dev/restate/sdk/Restate.java create mode 100644 sdk-api/src/main/java/dev/restate/sdk/RestateThreadLocalContext.java create mode 100644 sdk-api/src/main/java/dev/restate/sdk/ServiceReference.java create mode 100644 sdk-api/src/main/java/dev/restate/sdk/ServiceReferenceImpl.java create mode 100644 sdk-api/src/main/resources/META-INF/services/dev.restate.sdk.endpoint.definition.ServiceDefinitionFactory diff --git a/bytebuddy-proxy-support/build.gradle.kts b/bytebuddy-proxy-support/build.gradle.kts new file mode 100644 index 00000000..902703ce --- /dev/null +++ b/bytebuddy-proxy-support/build.gradle.kts @@ -0,0 +1,18 @@ +plugins { + `java-conventions` + `kotlin-conventions` + `java-library` + `library-publishing-conventions` +} + +description = "ByteBuddy proxy support" + +dependencies { + compileOnly(libs.jspecify) + + implementation(project(":common")) + implementation(libs.bytebuddy) + implementation(libs.objenesis) +} + +tasks.withType { isFailOnError = false } diff --git a/bytebuddy-proxy-support/src/main/java/dev/restate/bytebuddy/proxysupport/ByteBuddyProxyFactory.java b/bytebuddy-proxy-support/src/main/java/dev/restate/bytebuddy/proxysupport/ByteBuddyProxyFactory.java new file mode 100644 index 00000000..879351da --- /dev/null +++ b/bytebuddy-proxy-support/src/main/java/dev/restate/bytebuddy/proxysupport/ByteBuddyProxyFactory.java @@ -0,0 +1,117 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.bytebuddy.proxysupport; + +import dev.restate.common.reflections.ProxySupport; +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import net.bytebuddy.ByteBuddy; +import net.bytebuddy.TypeCache; +import net.bytebuddy.description.modifier.Visibility; +import net.bytebuddy.dynamic.scaffold.TypeValidation; +import net.bytebuddy.implementation.InvocationHandlerAdapter; +import net.bytebuddy.matcher.ElementMatchers; +import org.jspecify.annotations.Nullable; +import org.objenesis.Objenesis; +import org.objenesis.ObjenesisStd; + +/** + * ByteBuddy-based proxy factory that supports both interfaces and concrete classes. This + * implementation can create proxies for any class that is not final. Uses Objenesis to instantiate + * objects without calling constructors, which allows proxying classes that don't have a no-arg + * constructor. Uses TypeCache to cache generated proxy classes for better performance + * (thread-safe). + */ +public final class ByteBuddyProxyFactory implements ProxySupport.ProxyFactory { + + private static final String INTERCEPTOR_FIELD_NAME = "$$interceptor$$"; + + private final Objenesis objenesis = new ObjenesisStd(); + private final TypeCache> proxyClassCache = + new TypeCache.WithInlineExpunction<>(TypeCache.Sort.SOFT); + + @Override + @SuppressWarnings("unchecked") + public @Nullable T createProxy(Class clazz, ProxySupport.MethodInterceptor interceptor) { + // Cannot proxy final classes + if (Modifier.isFinal(clazz.getModifiers())) { + return null; + } + + try { + // Find or create the proxy class (cached) + Class proxyClass = + (Class) + proxyClassCache.findOrInsert( + clazz.getClassLoader(), clazz, () -> generateProxyClass(clazz), proxyClassCache); + + // Instantiate the proxy class using Objenesis (no constructor call) + T proxyInstance = objenesis.newInstance(proxyClass); + + // Set the interceptor field + Field interceptorField = proxyClass.getDeclaredField(INTERCEPTOR_FIELD_NAME); + interceptorField.setAccessible(true); + interceptorField.set(proxyInstance, interceptor); + + return proxyInstance; + + } catch (Exception e) { + // Could not create or instantiate the proxy + return null; + } + } + + private Class generateProxyClass(Class clazz) { + ByteBuddy byteBuddy = new ByteBuddy().with(TypeValidation.DISABLED); + + var builder = + clazz.isInterface() + ? byteBuddy.subclass(Object.class).implement(clazz) + : byteBuddy.subclass(clazz); + + try (var unloaded = + builder + // Add a field to store the interceptor + .defineField( + INTERCEPTOR_FIELD_NAME, ProxySupport.MethodInterceptor.class, Visibility.PUBLIC) + // Intercept all methods + .method(ElementMatchers.any()) + .intercept( + InvocationHandlerAdapter.of( + (proxy, method, args) -> { + // Get the interceptor from the field + Field field = proxy.getClass().getDeclaredField(INTERCEPTOR_FIELD_NAME); + field.setAccessible(true); + ProxySupport.MethodInterceptor interceptor = + (ProxySupport.MethodInterceptor) field.get(proxy); + + if (interceptor == null) { + throw new IllegalStateException("Interceptor not set on proxy instance"); + } + + ProxySupport.MethodInvocation invocation = + new ProxySupport.MethodInvocation() { + @Override + public Object[] getArguments() { + return args != null ? args : new Object[0]; + } + + @Override + public Method getMethod() { + return method; + } + }; + return interceptor.invoke(invocation); + })) + .make()) { + return unloaded.load(clazz.getClassLoader()).getLoaded(); + } + } +} diff --git a/bytebuddy-proxy-support/src/main/resources/META-INF/services/dev.restate.common.reflections.ProxySupport.ProxyFactory b/bytebuddy-proxy-support/src/main/resources/META-INF/services/dev.restate.common.reflections.ProxySupport.ProxyFactory new file mode 100644 index 00000000..f205102d --- /dev/null +++ b/bytebuddy-proxy-support/src/main/resources/META-INF/services/dev.restate.common.reflections.ProxySupport.ProxyFactory @@ -0,0 +1 @@ +dev.restate.bytebuddy.proxysupport.ByteBuddyProxyFactory \ No newline at end of file diff --git a/client/build.gradle.kts b/client/build.gradle.kts index 026d6b65..b86c0067 100644 --- a/client/build.gradle.kts +++ b/client/build.gradle.kts @@ -9,6 +9,7 @@ description = "Restate Client to interact with services from within other Java a dependencies { compileOnly(libs.jspecify) + compileOnly(libs.jetbrains.annotations) api(project(":common")) api(project(":sdk-serde-jackson")) diff --git a/client/src/main/java/dev/restate/client/Client.java b/client/src/main/java/dev/restate/client/Client.java index 16543ab4..286570d8 100644 --- a/client/src/main/java/dev/restate/client/Client.java +++ b/client/src/main/java/dev/restate/client/Client.java @@ -8,10 +8,15 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.client; +import static dev.restate.common.reflections.ReflectionUtils.mustHaveAnnotation; + import dev.restate.common.Output; import dev.restate.common.Request; import dev.restate.common.Target; import dev.restate.common.WorkflowRequest; +import dev.restate.sdk.annotation.Service; +import dev.restate.sdk.annotation.VirtualObject; +import dev.restate.sdk.annotation.Workflow; import dev.restate.serde.SerdeFactory; import dev.restate.serde.TypeTag; import java.time.Duration; @@ -525,6 +530,24 @@ default Response> getOutput() throws IngressException { } } + @org.jetbrains.annotations.ApiStatus.Experimental + default ClientServiceReference service(Class clazz) { + mustHaveAnnotation(clazz, Service.class); + return new ClientServiceReferenceImpl<>(this, clazz, null); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default ClientServiceReference virtualObject(Class clazz, String key) { + mustHaveAnnotation(clazz, VirtualObject.class); + return new ClientServiceReferenceImpl<>(this, clazz, key); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default ClientServiceReference workflow(Class clazz, String key) { + mustHaveAnnotation(clazz, Workflow.class); + return new ClientServiceReferenceImpl<>(this, clazz, key); + } + /** * Create a default JDK client. * diff --git a/client/src/main/java/dev/restate/client/ClientServiceReference.java b/client/src/main/java/dev/restate/client/ClientServiceReference.java new file mode 100644 index 00000000..38dda052 --- /dev/null +++ b/client/src/main/java/dev/restate/client/ClientServiceReference.java @@ -0,0 +1,481 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.client; + +import dev.restate.common.InvocationOptions; +import java.time.Duration; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.function.BiConsumer; +import java.util.function.BiFunction; +import java.util.function.Consumer; +import java.util.function.Function; + +@org.jetbrains.annotations.ApiStatus.Experimental +public interface ClientServiceReference { + @org.jetbrains.annotations.ApiStatus.Experimental + SVC client(); + + // call - BiFunction variants + @org.jetbrains.annotations.ApiStatus.Experimental + default Response call(BiFunction s, I input) { + return call(s, input, InvocationOptions.DEFAULT); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default Response call( + BiFunction s, I input, InvocationOptions.Builder options) { + return call(s, input, options.build()); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default Response call( + BiFunction s, I input, InvocationOptions invocationOptions) { + try { + return callAsync(s, input, invocationOptions).join(); + } catch (CompletionException e) { + if (e.getCause() instanceof RuntimeException) { + throw (RuntimeException) e.getCause(); + } + throw new RuntimeException(e.getCause()); + } + } + + // call - BiConsumer variants + @org.jetbrains.annotations.ApiStatus.Experimental + default Response call(BiConsumer s, I input) { + return call(s, input, InvocationOptions.DEFAULT); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default Response call( + BiConsumer s, I input, InvocationOptions.Builder options) { + return call(s, input, options.build()); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default Response call( + BiConsumer s, I input, InvocationOptions invocationOptions) { + try { + return callAsync(s, input, invocationOptions).join(); + } catch (CompletionException e) { + if (e.getCause() instanceof RuntimeException) { + throw (RuntimeException) e.getCause(); + } + throw new RuntimeException(e.getCause()); + } + } + + // call - Function variants + @org.jetbrains.annotations.ApiStatus.Experimental + default Response call(Function s) { + return call(s, InvocationOptions.DEFAULT); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default Response call(Function s, InvocationOptions.Builder options) { + return call(s, options.build()); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default Response call(Function s, InvocationOptions invocationOptions) { + try { + return callAsync(s, invocationOptions).join(); + } catch (CompletionException e) { + if (e.getCause() instanceof RuntimeException) { + throw (RuntimeException) e.getCause(); + } + throw new RuntimeException(e.getCause()); + } + } + + // call - Consumer variants + @org.jetbrains.annotations.ApiStatus.Experimental + default Response call(Consumer s) { + return call(s, InvocationOptions.DEFAULT); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default Response call(Consumer s, InvocationOptions.Builder options) { + return call(s, options.build()); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default Response call(Consumer s, InvocationOptions invocationOptions) { + try { + return callAsync(s, invocationOptions).join(); + } catch (CompletionException e) { + if (e.getCause() instanceof RuntimeException) { + throw (RuntimeException) e.getCause(); + } + throw new RuntimeException(e.getCause()); + } + } + + // callAsync - BiFunction variants + @org.jetbrains.annotations.ApiStatus.Experimental + default CompletableFuture> callAsync(BiFunction s, I input) { + return callAsync(s, input, InvocationOptions.DEFAULT); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default CompletableFuture> callAsync( + BiFunction s, I input, InvocationOptions.Builder options) { + return callAsync(s, input, options.build()); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + CompletableFuture> callAsync( + BiFunction s, I input, InvocationOptions invocationOptions); + + // callAsync - BiConsumer variants + @org.jetbrains.annotations.ApiStatus.Experimental + default CompletableFuture> callAsync(BiConsumer s, I input) { + return callAsync(s, input, InvocationOptions.DEFAULT); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default CompletableFuture> callAsync( + BiConsumer s, I input, InvocationOptions.Builder options) { + return callAsync(s, input, options.build()); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + CompletableFuture> callAsync( + BiConsumer s, I input, InvocationOptions invocationOptions); + + // callAsync - Function variants + @org.jetbrains.annotations.ApiStatus.Experimental + default CompletableFuture> callAsync(Function s) { + return callAsync(s, InvocationOptions.DEFAULT); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default CompletableFuture> callAsync( + Function s, InvocationOptions.Builder options) { + return callAsync(s, options.build()); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + CompletableFuture> callAsync( + Function s, InvocationOptions invocationOptions); + + // callAsync - Consumer variants + @org.jetbrains.annotations.ApiStatus.Experimental + default CompletableFuture> callAsync(Consumer s) { + return callAsync(s, InvocationOptions.DEFAULT); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default CompletableFuture> callAsync( + Consumer s, InvocationOptions.Builder options) { + return callAsync(s, options.build()); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + CompletableFuture> callAsync(Consumer s, InvocationOptions invocationOptions); + + // send - BiFunction variants + @org.jetbrains.annotations.ApiStatus.Experimental + default SendResponse send(BiFunction s, I input) { + return send(s, input, InvocationOptions.DEFAULT); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default SendResponse send( + BiFunction s, I input, InvocationOptions.Builder options) { + return send(s, input, options.build()); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default SendResponse send( + BiFunction s, I input, InvocationOptions invocationOptions) { + return send(s, input, null, invocationOptions); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default SendResponse send(BiFunction s, I input, Duration delay) { + return send(s, input, delay, InvocationOptions.DEFAULT); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default SendResponse send( + BiFunction s, I input, Duration delay, InvocationOptions.Builder options) { + return send(s, input, delay, options.build()); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default SendResponse send( + BiFunction s, I input, Duration delay, InvocationOptions invocationOptions) { + try { + return sendAsync(s, input, delay, invocationOptions).join(); + } catch (CompletionException e) { + if (e.getCause() instanceof RuntimeException) { + throw (RuntimeException) e.getCause(); + } + throw new RuntimeException(e.getCause()); + } + } + + // send - BiConsumer variants + @org.jetbrains.annotations.ApiStatus.Experimental + default SendResponse send(BiConsumer s, I input) { + return send(s, input, InvocationOptions.DEFAULT); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default SendResponse send( + BiConsumer s, I input, InvocationOptions.Builder options) { + return send(s, input, options.build()); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default SendResponse send( + BiConsumer s, I input, InvocationOptions invocationOptions) { + return send(s, input, null, invocationOptions); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default SendResponse send(BiConsumer s, I input, Duration delay) { + return send(s, input, delay, InvocationOptions.DEFAULT); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default SendResponse send( + BiConsumer s, I input, Duration delay, InvocationOptions.Builder options) { + return send(s, input, delay, options.build()); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default SendResponse send( + BiConsumer s, I input, Duration delay, InvocationOptions invocationOptions) { + try { + return sendAsync(s, input, delay, invocationOptions).join(); + } catch (CompletionException e) { + if (e.getCause() instanceof RuntimeException) { + throw (RuntimeException) e.getCause(); + } + throw new RuntimeException(e.getCause()); + } + } + + // send - Function variants + @org.jetbrains.annotations.ApiStatus.Experimental + default SendResponse send(Function s) { + return send(s, InvocationOptions.DEFAULT); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default SendResponse send(Function s, InvocationOptions.Builder options) { + return send(s, options.build()); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default SendResponse send(Function s, InvocationOptions invocationOptions) { + return send(s, null, invocationOptions); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default SendResponse send(Function s, Duration delay) { + return send(s, delay, InvocationOptions.DEFAULT); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default SendResponse send( + Function s, Duration delay, InvocationOptions.Builder options) { + return send(s, delay, options.build()); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default SendResponse send( + Function s, Duration delay, InvocationOptions invocationOptions) { + try { + return sendAsync(s, delay, invocationOptions).join(); + } catch (CompletionException e) { + if (e.getCause() instanceof RuntimeException) { + throw (RuntimeException) e.getCause(); + } + throw new RuntimeException(e.getCause()); + } + } + + // send - Consumer variants + @org.jetbrains.annotations.ApiStatus.Experimental + default SendResponse send(Consumer s) { + return send(s, InvocationOptions.DEFAULT); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default SendResponse send(Consumer s, InvocationOptions.Builder options) { + return send(s, options.build()); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default SendResponse send(Consumer s, InvocationOptions invocationOptions) { + return send(s, null, invocationOptions); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default SendResponse send(Consumer s, Duration delay) { + return send(s, delay, InvocationOptions.DEFAULT); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default SendResponse send( + Consumer s, Duration delay, InvocationOptions.Builder options) { + return send(s, delay, options.build()); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default SendResponse send( + Consumer s, Duration delay, InvocationOptions invocationOptions) { + try { + return sendAsync(s, delay, invocationOptions).join(); + } catch (CompletionException e) { + if (e.getCause() instanceof RuntimeException) { + throw (RuntimeException) e.getCause(); + } + throw new RuntimeException(e.getCause()); + } + } + + // sendAsync - BiFunction variants + @org.jetbrains.annotations.ApiStatus.Experimental + default CompletableFuture> sendAsync(BiFunction s, I input) { + return sendAsync(s, input, InvocationOptions.DEFAULT); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default CompletableFuture> sendAsync( + BiFunction s, I input, InvocationOptions.Builder options) { + return sendAsync(s, input, options.build()); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default CompletableFuture> sendAsync( + BiFunction s, I input, InvocationOptions invocationOptions) { + return sendAsync(s, input, null, invocationOptions); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default CompletableFuture> sendAsync( + BiFunction s, I input, Duration delay) { + return sendAsync(s, input, delay, InvocationOptions.DEFAULT); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default CompletableFuture> sendAsync( + BiFunction s, I input, Duration delay, InvocationOptions.Builder options) { + return sendAsync(s, input, delay, options.build()); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + CompletableFuture> sendAsync( + BiFunction s, I input, Duration delay, InvocationOptions invocationOptions); + + // sendAsync - BiConsumer variants + @org.jetbrains.annotations.ApiStatus.Experimental + default CompletableFuture> sendAsync(BiConsumer s, I input) { + return sendAsync(s, input, InvocationOptions.DEFAULT); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default CompletableFuture> sendAsync( + BiConsumer s, I input, InvocationOptions.Builder options) { + return sendAsync(s, input, options.build()); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default CompletableFuture> sendAsync( + BiConsumer s, I input, InvocationOptions invocationOptions) { + return sendAsync(s, input, null, invocationOptions); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default CompletableFuture> sendAsync( + BiConsumer s, I input, Duration delay) { + return sendAsync(s, input, delay, InvocationOptions.DEFAULT); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default CompletableFuture> sendAsync( + BiConsumer s, I input, Duration delay, InvocationOptions.Builder options) { + return sendAsync(s, input, delay, options.build()); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + CompletableFuture> sendAsync( + BiConsumer s, I input, Duration delay, InvocationOptions invocationOptions); + + // sendAsync - Function variants + @org.jetbrains.annotations.ApiStatus.Experimental + default CompletableFuture> sendAsync(Function s) { + return sendAsync(s, InvocationOptions.DEFAULT); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default CompletableFuture> sendAsync( + Function s, InvocationOptions.Builder options) { + return sendAsync(s, options.build()); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default CompletableFuture> sendAsync( + Function s, InvocationOptions invocationOptions) { + return sendAsync(s, null, invocationOptions); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default CompletableFuture> sendAsync(Function s, Duration delay) { + return sendAsync(s, delay, InvocationOptions.DEFAULT); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default CompletableFuture> sendAsync( + Function s, Duration delay, InvocationOptions.Builder options) { + return sendAsync(s, delay, options.build()); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + CompletableFuture> sendAsync( + Function s, Duration delay, InvocationOptions invocationOptions); + + // sendAsync - Consumer variants + @org.jetbrains.annotations.ApiStatus.Experimental + default CompletableFuture> sendAsync(Consumer s) { + return sendAsync(s, InvocationOptions.DEFAULT); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default CompletableFuture> sendAsync( + Consumer s, InvocationOptions.Builder options) { + return sendAsync(s, options.build()); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default CompletableFuture> sendAsync( + Consumer s, InvocationOptions invocationOptions) { + return sendAsync(s, null, invocationOptions); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default CompletableFuture> sendAsync(Consumer s, Duration delay) { + return sendAsync(s, delay, InvocationOptions.DEFAULT); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default CompletableFuture> sendAsync( + Consumer s, Duration delay, InvocationOptions.Builder options) { + return sendAsync(s, delay, options.build()); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + CompletableFuture> sendAsync( + Consumer s, Duration delay, InvocationOptions invocationOptions); +} diff --git a/client/src/main/java/dev/restate/client/ClientServiceReferenceImpl.java b/client/src/main/java/dev/restate/client/ClientServiceReferenceImpl.java new file mode 100644 index 00000000..7730a4b6 --- /dev/null +++ b/client/src/main/java/dev/restate/client/ClientServiceReferenceImpl.java @@ -0,0 +1,211 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.client; + +import static dev.restate.common.reflections.RestateUtils.toRequest; + +import dev.restate.common.InvocationOptions; +import dev.restate.common.Request; +import dev.restate.common.Target; +import dev.restate.common.reflections.*; +import dev.restate.serde.Serde; +import dev.restate.serde.TypeTag; +import java.time.Duration; +import java.util.concurrent.CompletableFuture; +import java.util.function.BiConsumer; +import java.util.function.BiFunction; +import java.util.function.Consumer; +import java.util.function.Function; +import org.jspecify.annotations.Nullable; + +final class ClientServiceReferenceImpl implements ClientServiceReference { + + private final Client innerClient; + + private final Class clazz; + private final String serviceName; + private final @Nullable String key; + + // The simple proxy for users + private SVC proxyClient; + + // To use call/send + private MethodInfoCollector methodInfoCollector; + + ClientServiceReferenceImpl(Client innerClient, Class clazz, @Nullable String key) { + this.innerClient = innerClient; + this.clazz = clazz; + this.serviceName = ReflectionUtils.extractServiceName(clazz); + this.key = key; + } + + @Override + public SVC client() { + if (proxyClient == null) { + this.proxyClient = + ProxySupport.createProxy( + clazz, + invocation -> { + var methodInfo = MethodInfo.fromMethod(invocation.getMethod()); + + //noinspection unchecked + return innerClient + .call( + Request.of( + Target.virtualObject(serviceName, key, methodInfo.getHandlerName()), + (TypeTag) + RestateUtils.typeTag(methodInfo.getInputType()), + (TypeTag) + RestateUtils.typeTag(methodInfo.getOutputType()), + invocation.getArguments().length == 0 + ? null + : invocation.getArguments()[0])) + .response(); + }); + } + return this.proxyClient; + } + + @SuppressWarnings("unchecked") + @Override + public CompletableFuture> callAsync( + BiFunction s, I input, InvocationOptions invocationOptions) { + MethodInfo methodInfo = getMethodInfoCollector().resolve(s, input); + return innerClient.callAsync( + toRequest( + serviceName, + key, + methodInfo.getHandlerName(), + (TypeTag) RestateUtils.typeTag(methodInfo.getInputType()), + (TypeTag) RestateUtils.typeTag(methodInfo.getOutputType()), + input, + invocationOptions)); + } + + @SuppressWarnings("unchecked") + @Override + public CompletableFuture> callAsync( + BiConsumer s, I input, InvocationOptions invocationOptions) { + MethodInfo methodInfo = getMethodInfoCollector().resolve(s, input); + return innerClient.callAsync( + toRequest( + serviceName, + key, + methodInfo.getHandlerName(), + (TypeTag) RestateUtils.typeTag(methodInfo.getInputType()), + Serde.VOID, + input, + invocationOptions)); + } + + @SuppressWarnings("unchecked") + @Override + public CompletableFuture> callAsync( + Function s, InvocationOptions invocationOptions) { + MethodInfo methodInfo = getMethodInfoCollector().resolve(s); + return innerClient.callAsync( + toRequest( + serviceName, + key, + methodInfo.getHandlerName(), + Serde.VOID, + (TypeTag) RestateUtils.typeTag(methodInfo.getOutputType()), + null, + invocationOptions)); + } + + @Override + public CompletableFuture> callAsync( + Consumer s, InvocationOptions invocationOptions) { + MethodInfo methodInfo = getMethodInfoCollector().resolve(s); + return innerClient.callAsync( + toRequest( + serviceName, + key, + methodInfo.getHandlerName(), + Serde.VOID, + Serde.VOID, + null, + invocationOptions)); + } + + @SuppressWarnings("unchecked") + @Override + public CompletableFuture> sendAsync( + BiFunction s, I input, Duration delay, InvocationOptions invocationOptions) { + MethodInfo methodInfo = getMethodInfoCollector().resolve(s, input); + return innerClient.sendAsync( + toRequest( + serviceName, + key, + methodInfo.getHandlerName(), + (TypeTag) RestateUtils.typeTag(methodInfo.getInputType()), + (TypeTag) RestateUtils.typeTag(methodInfo.getOutputType()), + input, + invocationOptions), + delay); + } + + @SuppressWarnings("unchecked") + @Override + public CompletableFuture> sendAsync( + BiConsumer s, I input, Duration delay, InvocationOptions invocationOptions) { + MethodInfo methodInfo = getMethodInfoCollector().resolve(s, input); + return innerClient.sendAsync( + toRequest( + serviceName, + key, + methodInfo.getHandlerName(), + (TypeTag) RestateUtils.typeTag(methodInfo.getInputType()), + Serde.VOID, + input, + invocationOptions), + delay); + } + + @SuppressWarnings("unchecked") + @Override + public CompletableFuture> sendAsync( + Function s, Duration delay, InvocationOptions invocationOptions) { + MethodInfo methodInfo = getMethodInfoCollector().resolve(s); + return innerClient.sendAsync( + toRequest( + serviceName, + key, + methodInfo.getHandlerName(), + Serde.VOID, + (TypeTag) RestateUtils.typeTag(methodInfo.getOutputType()), + null, + invocationOptions), + delay); + } + + @Override + public CompletableFuture> sendAsync( + Consumer s, Duration delay, InvocationOptions invocationOptions) { + MethodInfo methodInfo = getMethodInfoCollector().resolve(s); + return innerClient.sendAsync( + toRequest( + serviceName, + key, + methodInfo.getHandlerName(), + Serde.VOID, + Serde.VOID, + null, + invocationOptions), + delay); + } + + private MethodInfoCollector getMethodInfoCollector() { + if (this.methodInfoCollector == null) { + this.methodInfoCollector = new MethodInfoCollector<>(this.clazz); + } + return this.methodInfoCollector; + } +} diff --git a/common/build.gradle.kts b/common/build.gradle.kts index cb46d2f5..27cfb5f9 100644 --- a/common/build.gradle.kts +++ b/common/build.gradle.kts @@ -10,6 +10,8 @@ description = "Common types used by different Restate Java modules" dependencies { compileOnly(libs.jspecify) + implementation(libs.log4j.api) + testImplementation(libs.junit.jupiter) testImplementation(libs.assertj) } diff --git a/common/src/main/java/dev/restate/common/InvocationOptions.java b/common/src/main/java/dev/restate/common/InvocationOptions.java new file mode 100644 index 00000000..fd989caa --- /dev/null +++ b/common/src/main/java/dev/restate/common/InvocationOptions.java @@ -0,0 +1,154 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.common; + +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Objects; +import org.jspecify.annotations.Nullable; + +public class InvocationOptions { + + public static final InvocationOptions DEFAULT = new InvocationOptions(null, null); + + private final @Nullable String idempotencyKey; + private final @Nullable LinkedHashMap headers; + + InvocationOptions( + @Nullable String idempotencyKey, @Nullable LinkedHashMap headers) { + this.idempotencyKey = idempotencyKey; + this.headers = headers; + } + + public @Nullable String getIdempotencyKey() { + return idempotencyKey; + } + + public @Nullable Map getHeaders() { + return headers; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof InvocationOptions that)) return false; + return Objects.equals(getIdempotencyKey(), that.getIdempotencyKey()) + && Objects.equals(getHeaders(), that.getHeaders()); + } + + @Override + public int hashCode() { + return Objects.hash(getIdempotencyKey(), getHeaders()); + } + + @Override + public String toString() { + return "RequestOptions{" + + "idempotencyKey='" + + idempotencyKey + + '\'' + + ", headers=" + + headers + + '}'; + } + + public static Builder idempotencyKey(String idempotencyKey) { + return new Builder(null, null).idempotencyKey(idempotencyKey); + } + + public static Builder header(String key, String value) { + return new Builder(null, null).header(key, value); + } + + public static Builder headers(Map newHeaders) { + return new Builder(null, null).headers(newHeaders); + } + + public static final class Builder { + @Nullable private String idempotencyKey; + @Nullable private LinkedHashMap headers; + + private Builder( + @Nullable String idempotencyKey, @Nullable LinkedHashMap headers) { + this.idempotencyKey = idempotencyKey; + this.headers = headers; + } + + /** + * @param idempotencyKey Idempotency key to attach in the request. + * @return this instance, so the builder can be used fluently. + */ + public Builder idempotencyKey(String idempotencyKey) { + this.idempotencyKey = idempotencyKey; + return this; + } + + /** + * Append this header to the list of configured headers. + * + * @param key header key + * @param value header value + * @return this instance, so the builder can be used fluently. + */ + public Builder header(String key, String value) { + if (this.headers == null) { + this.headers = new LinkedHashMap<>(); + } + this.headers.put(key, value); + return this; + } + + /** + * Append the given header map to the list of headers. + * + * @param newHeaders headers to send together with the request. + * @return this instance, so the builder can be used fluently. + */ + public Builder headers(Map newHeaders) { + if (this.headers == null) { + this.headers = new LinkedHashMap<>(); + } + this.headers.putAll(newHeaders); + return this; + } + + public @Nullable String getIdempotencyKey() { + return idempotencyKey; + } + + /** + * @param idempotencyKey Idempotency key to attach in the request. + */ + public void setIdempotencyKey(@Nullable String idempotencyKey) { + idempotencyKey(idempotencyKey); + } + + public @Nullable Map getHeaders() { + return headers; + } + + /** + * @param headers headers to send together with the request. This will overwrite the already + * configured headers + */ + public void setHeaders(@Nullable Map headers) { + headers(headers); + } + + /** + * @return build the request + */ + public InvocationOptions build() { + return new InvocationOptions(this.idempotencyKey, this.headers); + } + } + + public Builder toBuilder() { + return new Builder(this.idempotencyKey, this.headers); + } +} diff --git a/common/src/main/java/dev/restate/common/reflections/ConcurrentReferenceHashMap.java b/common/src/main/java/dev/restate/common/reflections/ConcurrentReferenceHashMap.java new file mode 100644 index 00000000..c1e196fa --- /dev/null +++ b/common/src/main/java/dev/restate/common/reflections/ConcurrentReferenceHashMap.java @@ -0,0 +1,1316 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.common.reflections; + +import java.lang.ref.ReferenceQueue; +import java.lang.ref.SoftReference; +import java.lang.ref.WeakReference; +import java.lang.reflect.Array; +import java.util.*; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.BiFunction; +import java.util.function.Function; +import org.jspecify.annotations.Nullable; + +/** + * A {@link ConcurrentHashMap} variant that uses {@link ReferenceType#SOFT soft} or {@linkplain + * ReferenceType#WEAK weak} references for both {@code keys} and {@code values}. + * + *

This class can be used as an alternative to {@code Collections.synchronizedMap(new + * WeakHashMap>())} in order to support better performance when accessed + * concurrently. This implementation follows the same design constraints as {@link + * ConcurrentHashMap} with the exception that {@code null} values and {@code null} keys are + * supported. + * + *

NOTE: The use of references means that there is no guarantee that items placed into the + * map will be subsequently available. The garbage collector may discard references at any time, so + * it may appear that an unknown thread is silently removing entries. + * + *

If not explicitly specified, this implementation will use {@linkplain SoftReference soft entry + * references}. + * + * @author Phillip Webb + * @author Juergen Hoeller + * @author Brian Clozel + * @since 3.2 + * @param the key type + * @param the value type + */ +public class ConcurrentReferenceHashMap extends AbstractMap + implements ConcurrentMap { + + private static final int DEFAULT_INITIAL_CAPACITY = 16; + + private static final float DEFAULT_LOAD_FACTOR = 0.75f; + + private static final int DEFAULT_CONCURRENCY_LEVEL = 16; + + private static final ReferenceType DEFAULT_REFERENCE_TYPE = ReferenceType.SOFT; + + private static final int MAXIMUM_CONCURRENCY_LEVEL = 1 << 16; + + private static final int MAXIMUM_SEGMENT_SIZE = 1 << 30; + + /** Array of segments indexed using the high order bits from the hash. */ + private final Segment[] segments; + + /** + * When the average number of references per table exceeds this value resize will be attempted. + */ + private final float loadFactor; + + /** The reference type: SOFT or WEAK. */ + private final ReferenceType referenceType; + + /** + * The shift value used to calculate the size of the segments array and an index from the hash. + */ + private final int shift; + + /** Late binding entry set. */ + private @Nullable Set> entrySet; + + /** Late binding key set. */ + private @Nullable Set keySet; + + /** Late binding values collection. */ + private @Nullable Collection values; + + /** Create a new {@code ConcurrentReferenceHashMap} instance. */ + public ConcurrentReferenceHashMap() { + this( + DEFAULT_INITIAL_CAPACITY, + DEFAULT_LOAD_FACTOR, + DEFAULT_CONCURRENCY_LEVEL, + DEFAULT_REFERENCE_TYPE); + } + + /** + * Create a new {@code ConcurrentReferenceHashMap} instance. + * + * @param initialCapacity the initial capacity of the map + */ + public ConcurrentReferenceHashMap(int initialCapacity) { + this(initialCapacity, DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL, DEFAULT_REFERENCE_TYPE); + } + + /** + * Create a new {@code ConcurrentReferenceHashMap} instance. + * + * @param initialCapacity the initial capacity of the map + * @param loadFactor the load factor. When the average number of references per table exceeds this + * value resize will be attempted + */ + public ConcurrentReferenceHashMap(int initialCapacity, float loadFactor) { + this(initialCapacity, loadFactor, DEFAULT_CONCURRENCY_LEVEL, DEFAULT_REFERENCE_TYPE); + } + + /** + * Create a new {@code ConcurrentReferenceHashMap} instance. + * + * @param initialCapacity the initial capacity of the map + * @param concurrencyLevel the expected number of threads that will concurrently write to the map + */ + public ConcurrentReferenceHashMap(int initialCapacity, int concurrencyLevel) { + this(initialCapacity, DEFAULT_LOAD_FACTOR, concurrencyLevel, DEFAULT_REFERENCE_TYPE); + } + + /** + * Create a new {@code ConcurrentReferenceHashMap} instance. + * + * @param initialCapacity the initial capacity of the map + * @param referenceType the reference type used for entries (soft or weak) + */ + public ConcurrentReferenceHashMap(int initialCapacity, ReferenceType referenceType) { + this(initialCapacity, DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL, referenceType); + } + + /** + * Create a new {@code ConcurrentReferenceHashMap} instance. + * + * @param initialCapacity the initial capacity of the map + * @param loadFactor the load factor. When the average number of references per table exceeds this + * value, resize will be attempted. + * @param concurrencyLevel the expected number of threads that will concurrently write to the map + */ + public ConcurrentReferenceHashMap(int initialCapacity, float loadFactor, int concurrencyLevel) { + this(initialCapacity, loadFactor, concurrencyLevel, DEFAULT_REFERENCE_TYPE); + } + + /** + * Create a new {@code ConcurrentReferenceHashMap} instance. + * + * @param initialCapacity the initial capacity of the map + * @param loadFactor the load factor. When the average number of references per table exceeds this + * value, resize will be attempted. + * @param concurrencyLevel the expected number of threads that will concurrently write to the map + * @param referenceType the reference type used for entries (soft or weak) + */ + @SuppressWarnings("unchecked") + public ConcurrentReferenceHashMap( + int initialCapacity, float loadFactor, int concurrencyLevel, ReferenceType referenceType) { + this.loadFactor = loadFactor; + this.shift = calculateShift(concurrencyLevel, MAXIMUM_CONCURRENCY_LEVEL); + int size = 1 << this.shift; + this.referenceType = referenceType; + int roundedUpSegmentCapacity = (int) ((initialCapacity + size - 1L) / size); + int initialSize = 1 << calculateShift(roundedUpSegmentCapacity, MAXIMUM_SEGMENT_SIZE); + Segment[] segments = (Segment[]) Array.newInstance(Segment.class, size); + int resizeThreshold = (int) (initialSize * getLoadFactor()); + for (int i = 0; i < segments.length; i++) { + segments[i] = new Segment(initialSize, resizeThreshold); + } + this.segments = segments; + } + + protected final float getLoadFactor() { + return this.loadFactor; + } + + protected final int getSegmentsSize() { + return this.segments.length; + } + + protected final Segment getSegment(int index) { + return this.segments[index]; + } + + /** + * Factory method that returns the {@link ReferenceManager}. This method will be called once for + * each {@link Segment}. + * + * @return a new reference manager + */ + protected ReferenceManager createReferenceManager() { + return new ReferenceManager(); + } + + /** + * Get the hash for a given object, apply an additional hash function to reduce collisions. This + * implementation uses the same Wang/Jenkins algorithm as {@link ConcurrentHashMap}. Subclasses + * can override to provide alternative hashing. + * + * @param o the object to hash (may be null) + * @return the resulting hash code + */ + protected int getHash(@Nullable Object o) { + int hash = (o != null ? o.hashCode() : 0); + hash += (hash << 15) ^ 0xffffcd7d; + hash ^= (hash >>> 10); + hash += (hash << 3); + hash ^= (hash >>> 6); + hash += (hash << 2) + (hash << 14); + hash ^= (hash >>> 16); + return hash; + } + + @Override + public @Nullable V get(@Nullable Object key) { + Reference ref = getReference(key, Restructure.WHEN_NECESSARY); + Entry entry = (ref != null ? ref.get() : null); + return (entry != null ? entry.getValue() : null); + } + + @Override + public @Nullable V getOrDefault(@Nullable Object key, @Nullable V defaultValue) { + Reference ref = getReference(key, Restructure.WHEN_NECESSARY); + Entry entry = (ref != null ? ref.get() : null); + return (entry != null ? entry.getValue() : defaultValue); + } + + @Override + public boolean containsKey(@Nullable Object key) { + Reference ref = getReference(key, Restructure.WHEN_NECESSARY); + Entry entry = (ref != null ? ref.get() : null); + return (entry != null && Objects.equals(entry.getKey(), key)); + } + + /** + * Return a {@link Reference} to the {@link Entry} for the specified {@code key}, or {@code null} + * if not found. + * + * @param key the key (can be {@code null}) + * @param restructure types of restructure allowed during this call + * @return the reference, or {@code null} if not found + */ + protected final @Nullable Reference getReference( + @Nullable Object key, Restructure restructure) { + int hash = getHash(key); + return getSegmentForHash(hash).getReference(key, hash, restructure); + } + + @Override + public @Nullable V put(@Nullable K key, @Nullable V value) { + return put(key, value, true); + } + + @Override + public @Nullable V putIfAbsent(@Nullable K key, @Nullable V value) { + return put(key, value, false); + } + + private @Nullable V put( + final @Nullable K key, final @Nullable V value, final boolean overwriteExisting) { + return doTask( + key, + new Task(TaskOption.RESTRUCTURE_BEFORE, TaskOption.RESIZE) { + @Override + protected @Nullable V execute( + @Nullable Reference ref, + @Nullable Entry entry, + @Nullable Entries entries) { + if (entry != null) { + V oldValue = entry.getValue(); + if (overwriteExisting) { + entry.setValue(value); + } + return oldValue; + } + entries.add(value); + return null; + } + }); + } + + @Override + public @Nullable V remove(@Nullable Object key) { + return doTask( + key, + new Task(TaskOption.RESTRUCTURE_AFTER, TaskOption.SKIP_IF_EMPTY) { + @Override + protected @Nullable V execute( + @Nullable Reference ref, @Nullable Entry entry) { + if (entry != null) { + if (ref != null) { + ref.release(); + } + return entry.value; + } + return null; + } + }); + } + + @Override + public boolean remove(@Nullable Object key, final @Nullable Object value) { + Boolean result = + doTask( + key, + new Task(TaskOption.RESTRUCTURE_AFTER, TaskOption.SKIP_IF_EMPTY) { + @Override + protected Boolean execute( + @Nullable Reference ref, @Nullable Entry entry) { + if (entry != null && Objects.equals(entry.getValue(), value)) { + if (ref != null) { + ref.release(); + } + return true; + } + return false; + } + }); + return Boolean.TRUE.equals(result); + } + + @Override + public boolean replace(@Nullable K key, final @Nullable V oldValue, final @Nullable V newValue) { + Boolean result = + doTask( + key, + new Task(TaskOption.RESTRUCTURE_BEFORE, TaskOption.SKIP_IF_EMPTY) { + @Override + protected Boolean execute( + @Nullable Reference ref, @Nullable Entry entry) { + if (entry != null && Objects.equals(entry.getValue(), oldValue)) { + entry.setValue(newValue); + return true; + } + return false; + } + }); + return Boolean.TRUE.equals(result); + } + + @Override + public @Nullable V replace(@Nullable K key, final @Nullable V value) { + return doTask( + key, + new Task(TaskOption.RESTRUCTURE_BEFORE, TaskOption.SKIP_IF_EMPTY) { + @Override + protected @Nullable V execute( + @Nullable Reference ref, @Nullable Entry entry) { + if (entry != null) { + V oldValue = entry.getValue(); + entry.setValue(value); + return oldValue; + } + return null; + } + }); + } + + @Override + public @Nullable V computeIfAbsent( + @Nullable K key, Function<@Nullable ? super K, @Nullable ? extends V> mappingFunction) { + return doTask( + key, + new Task(TaskOption.RESTRUCTURE_BEFORE, TaskOption.RESIZE) { + @Override + protected @Nullable V execute( + @Nullable Reference ref, + @Nullable Entry entry, + @Nullable Entries entries) { + if (entry != null) { + return entry.getValue(); + } + V value = mappingFunction.apply(key); + // Add entry only if not null + if (value != null) { + entries.add(value); + } + return value; + } + }); + } + + @Override + public @Nullable V computeIfPresent( + @Nullable K key, + BiFunction<@Nullable ? super K, @Nullable ? super V, @Nullable ? extends V> + remappingFunction) { + return doTask( + key, + new Task(TaskOption.RESTRUCTURE_BEFORE, TaskOption.RESIZE) { + @Override + protected @Nullable V execute( + @Nullable Reference ref, + @Nullable Entry entry, + @Nullable Entries entries) { + if (entry != null) { + V oldValue = entry.getValue(); + V value = remappingFunction.apply(key, oldValue); + if (value != null) { + // Replace entry + entry.setValue(value); + return value; + } else { + // Remove entry + if (ref != null) { + ref.release(); + } + } + } + return null; + } + }); + } + + @Override + public @Nullable V compute( + @Nullable K key, + BiFunction<@Nullable ? super K, @Nullable ? super V, @Nullable ? extends V> + remappingFunction) { + return doTask( + key, + new Task(TaskOption.RESTRUCTURE_BEFORE, TaskOption.RESIZE) { + @Override + protected @Nullable V execute( + @Nullable Reference ref, + @Nullable Entry entry, + @Nullable Entries entries) { + V oldValue = null; + if (entry != null) { + oldValue = entry.getValue(); + } + V value = remappingFunction.apply(key, oldValue); + if (value != null) { + if (entry != null) { + // Replace entry + entry.setValue(value); + } else { + // Add entry + entries.add(value); + } + return value; + } else { + // Remove entry + if (ref != null) { + ref.release(); + } + } + return null; + } + }); + } + + @Override + public @Nullable V merge( + @Nullable K key, + @Nullable V value, + BiFunction<@Nullable ? super V, @Nullable ? super V, @Nullable ? extends V> + remappingFunction) { + return doTask( + key, + new Task(TaskOption.RESTRUCTURE_BEFORE, TaskOption.RESIZE) { + @Override + protected @Nullable V execute( + @Nullable Reference ref, + @Nullable Entry entry, + @Nullable Entries entries) { + if (entry != null) { + V oldValue = entry.getValue(); + V newValue = remappingFunction.apply(oldValue, value); + if (newValue != null) { + // Replace entry + entry.setValue(newValue); + return newValue; + } else { + // Remove entry + if (ref != null) { + ref.release(); + } + return null; + } + } else { + // Add entry + entries.add(value); + return value; + } + } + }); + } + + @Override + public void clear() { + for (Segment segment : this.segments) { + segment.clear(); + } + } + + /** + * Remove any entries that have been garbage collected and are no longer referenced. Under normal + * circumstances garbage collected entries are automatically purged as items are added or removed + * from the Map. This method can be used to force a purge, and is useful when the Map is read + * frequently but updated less often. + */ + public void purgeUnreferencedEntries() { + for (Segment segment : this.segments) { + segment.restructureIfNecessary(false); + } + } + + @Override + public int size() { + int size = 0; + for (Segment segment : this.segments) { + size += segment.getCount(); + } + return size; + } + + @Override + public boolean isEmpty() { + for (Segment segment : this.segments) { + if (segment.getCount() > 0) { + return false; + } + } + return true; + } + + @Override + public Set> entrySet() { + Set> entrySet = this.entrySet; + if (entrySet == null) { + entrySet = new EntrySet(); + this.entrySet = entrySet; + } + return entrySet; + } + + @Override + public Set keySet() { + Set keySet = this.keySet; + if (keySet == null) { + keySet = new KeySet(); + this.keySet = keySet; + } + return keySet; + } + + @Override + public Collection values() { + Collection values = this.values; + if (values == null) { + values = new Values(); + this.values = values; + } + return values; + } + + private @Nullable T doTask(@Nullable Object key, Task task) { + int hash = getHash(key); + return getSegmentForHash(hash).doTask(hash, key, task); + } + + private Segment getSegmentForHash(int hash) { + return this.segments[(hash >>> (32 - this.shift)) & (this.segments.length - 1)]; + } + + /** + * Calculate a shift value that can be used to create a power-of-two value between the specified + * maximum and minimum values. + * + * @param minimumValue the minimum value + * @param maximumValue the maximum value + * @return the calculated shift (use {@code 1 << shift} to obtain a value) + */ + protected static int calculateShift(int minimumValue, int maximumValue) { + int shift = 0; + int value = 1; + while (value < minimumValue && value < maximumValue) { + value <<= 1; + shift++; + } + return shift; + } + + /** Various reference types supported by this map. */ + public enum ReferenceType { + + /** Use {@link SoftReference SoftReferences}. */ + SOFT, + + /** Use {@link WeakReference WeakReferences}. */ + WEAK + } + + /** A single segment used to divide the map to allow better concurrent performance. */ + @SuppressWarnings("serial") + protected final class Segment extends ReentrantLock { + + private final ReferenceManager referenceManager; + + private final int initialSize; + + /** + * Array of references indexed using the low order bits from the hash. This property should only + * be set along with {@code resizeThreshold}. + */ + private volatile @Nullable Reference[] references; + + /** + * The total number of references contained in this segment. This includes chained references + * and references that have been garbage collected but not purged. + */ + private final AtomicInteger count = new AtomicInteger(); + + /** + * The threshold when resizing of the references should occur. When {@code count} exceeds this + * value references will be resized. + */ + private int resizeThreshold; + + public Segment(int initialSize, int resizeThreshold) { + this.referenceManager = createReferenceManager(); + this.initialSize = initialSize; + this.references = createReferenceArray(initialSize); + this.resizeThreshold = resizeThreshold; + } + + public @Nullable Reference getReference( + @Nullable Object key, int hash, Restructure restructure) { + if (restructure == Restructure.WHEN_NECESSARY) { + restructureIfNecessary(false); + } + if (this.count.get() == 0) { + return null; + } + // Use a local copy to protect against other threads writing + @Nullable Reference[] references = this.references; + int index = getIndex(hash, references); + Reference head = references[index]; + return findInChain(head, key, hash); + } + + /** + * Apply an update operation to this segment. The segment will be locked during the update. + * + * @param hash the hash of the key + * @param key the key + * @param task the update operation + * @return the result of the operation + */ + private @Nullable T doTask(final int hash, final @Nullable Object key, final Task task) { + boolean resize = task.hasOption(TaskOption.RESIZE); + if (task.hasOption(TaskOption.RESTRUCTURE_BEFORE)) { + restructureIfNecessary(resize); + } + if (task.hasOption(TaskOption.SKIP_IF_EMPTY) && this.count.get() == 0) { + return task.execute(null, null, null); + } + lock(); + try { + final int index = getIndex(hash, this.references); + final Reference head = this.references[index]; + Reference ref = findInChain(head, key, hash); + Entry entry = (ref != null ? ref.get() : null); + Entries entries = + value -> { + @SuppressWarnings("unchecked") + Entry newEntry = new Entry<>((K) key, value); + Reference newReference = + Segment.this.referenceManager.createReference(newEntry, hash, head); + Segment.this.references[index] = newReference; + Segment.this.count.incrementAndGet(); + }; + return task.execute(ref, entry, entries); + } finally { + unlock(); + if (task.hasOption(TaskOption.RESTRUCTURE_AFTER)) { + restructureIfNecessary(resize); + } + } + } + + /** Clear all items from this segment. */ + public void clear() { + if (this.count.get() == 0) { + return; + } + lock(); + try { + this.references = createReferenceArray(this.initialSize); + this.resizeThreshold = (int) (this.references.length * getLoadFactor()); + this.count.set(0); + } finally { + unlock(); + } + } + + /** + * Restructure the underlying data structure when it becomes necessary. This method can increase + * the size of the references table as well as purge any references that have been garbage + * collected. + * + * @param allowResize if resizing is permitted + */ + void restructureIfNecessary(boolean allowResize) { + int currCount = this.count.get(); + boolean needsResize = allowResize && (currCount > 0 && currCount >= this.resizeThreshold); + Reference ref = this.referenceManager.pollForPurge(); + if (ref != null || (needsResize)) { + restructure(allowResize, ref); + } + } + + private void restructure(boolean allowResize, @Nullable Reference ref) { + lock(); + try { + int expectedCount = this.count.get(); + Set> toPurge = Collections.emptySet(); + if (ref != null) { + toPurge = new HashSet<>(); + while (ref != null) { + toPurge.add(ref); + ref = this.referenceManager.pollForPurge(); + } + } + expectedCount -= toPurge.size(); + + // Estimate new count, taking into account count inside lock and items that + // will be purged. + boolean needsResize = (expectedCount > 0 && expectedCount >= this.resizeThreshold); + boolean resizing = false; + int restructureSize = this.references.length; + if (allowResize && needsResize && restructureSize < MAXIMUM_SEGMENT_SIZE) { + restructureSize <<= 1; + resizing = true; + } + + int newCount = 0; + // Restructure the resized reference array + if (resizing) { + Reference[] restructured = createReferenceArray(restructureSize); + for (Reference reference : this.references) { + ref = reference; + while (ref != null) { + if (!toPurge.contains(ref)) { + Entry entry = ref.get(); + // Also filter out null references that are now null + // they should be polled from the queue in a later restructure call. + if (entry != null) { + int index = getIndex(ref.getHash(), restructured); + restructured[index] = + this.referenceManager.createReference( + entry, ref.getHash(), restructured[index]); + newCount++; + } + } + ref = ref.getNext(); + } + } + // Replace volatile members + this.references = restructured; + this.resizeThreshold = (int) (this.references.length * getLoadFactor()); + } + // Restructure the existing reference array "in place" + else { + for (int i = 0; i < this.references.length; i++) { + Reference purgedRef = null; + ref = this.references[i]; + while (ref != null) { + if (!toPurge.contains(ref)) { + Entry entry = ref.get(); + // Also filter out null references that are now null: + // They should be polled from the queue in a later restructure call. + if (entry != null) { + purgedRef = + this.referenceManager.createReference(entry, ref.getHash(), purgedRef); + } + newCount++; + } + ref = ref.getNext(); + } + this.references[i] = purgedRef; + } + } + this.count.set(newCount); + } finally { + unlock(); + } + } + + private @Nullable Reference findInChain( + @Nullable Reference ref, @Nullable Object key, int hash) { + Reference currRef = ref; + while (currRef != null) { + if (currRef.getHash() == hash) { + Entry entry = currRef.get(); + if (entry != null) { + K entryKey = entry.getKey(); + if (Objects.equals(entryKey, key)) { + return currRef; + } + } + } + currRef = currRef.getNext(); + } + return null; + } + + @SuppressWarnings({"rawtypes", "unchecked"}) + private Reference[] createReferenceArray(int size) { + return new Reference[size]; + } + + private int getIndex(int hash, @Nullable Reference[] references) { + return (hash & (references.length - 1)); + } + + /** Return the size of the current references array. */ + public int getSize() { + return this.references.length; + } + + /** Return the total number of references in this segment. */ + public int getCount() { + return this.count.get(); + } + } + + /** + * A reference to an {@link Entry} contained in the map. Implementations are usually wrappers + * around specific Java reference implementations (for example, {@link SoftReference}). + * + * @param the key type + * @param the value type + */ + protected interface Reference { + + /** Return the referenced entry, or {@code null} if the entry is no longer available. */ + @Nullable Entry get(); + + /** Return the hash for the reference. */ + int getHash(); + + /** Return the next reference in the chain, or {@code null} if none. */ + @Nullable Reference getNext(); + + /** + * Release this entry and ensure that it will be returned from {@code + * ReferenceManager#pollForPurge()}. + */ + void release(); + } + + /** + * A single map entry. + * + * @param the key type + * @param the value type + */ + protected static final class Entry implements Map.Entry { + + private final @Nullable K key; + + private volatile @Nullable V value; + + public Entry(@Nullable K key, @Nullable V value) { + this.key = key; + this.value = value; + } + + @Override + public @Nullable K getKey() { + return this.key; + } + + @Override + public @Nullable V getValue() { + return this.value; + } + + @Override + public @Nullable V setValue(@Nullable V value) { + V previous = this.value; + this.value = value; + return previous; + } + + @Override + public boolean equals(@Nullable Object other) { + return (this == other + || (other instanceof Map.Entry that + && Objects.equals(getKey(), that.getKey()) + && Objects.equals(getValue(), that.getValue()))); + } + + @Override + public int hashCode() { + return (Objects.hashCode(this.key) ^ Objects.hashCode(this.value)); + } + + @Override + public String toString() { + return (this.key + "=" + this.value); + } + } + + /** A task that can be {@link Segment#doTask run} against a {@link Segment}. */ + private abstract class Task { + + private final EnumSet options; + + public Task(TaskOption... options) { + this.options = + (options.length == 0 + ? EnumSet.noneOf(TaskOption.class) + : EnumSet.of(options[0], options)); + } + + public boolean hasOption(TaskOption option) { + return this.options.contains(option); + } + + /** + * Execute the task. + * + * @param ref the found reference (or {@code null}) + * @param entry the found entry (or {@code null}) + * @param entries access to the underlying entries + * @return the result of the task + * @see #execute(Reference, Entry) + */ + protected @Nullable T execute( + @Nullable Reference ref, @Nullable Entry entry, @Nullable Entries entries) { + return execute(ref, entry); + } + + /** + * Convenience method that can be used for tasks that do not need access to {@link Entries}. + * + * @param ref the found reference (or {@code null}) + * @param entry the found entry (or {@code null}) + * @return the result of the task + * @see #execute(Reference, Entry, Entries) + */ + protected @Nullable T execute(@Nullable Reference ref, @Nullable Entry entry) { + return null; + } + } + + /** Various options supported by a {@code Task}. */ + private enum TaskOption { + RESTRUCTURE_BEFORE, + RESTRUCTURE_AFTER, + SKIP_IF_EMPTY, + RESIZE + } + + /** Allows a task access to {@link ConcurrentReferenceHashMap.Segment} entries. */ + private interface Entries { + + /** + * Add a new entry with the specified value. + * + * @param value the value to add + */ + void add(@Nullable V value); + } + + /** Internal entry-set implementation. */ + private final class EntrySet extends AbstractSet> { + + @Override + public Iterator> iterator() { + return new EntryIterator(); + } + + @Override + public boolean contains(@Nullable Object o) { + if (o instanceof Map.Entry entry) { + Reference ref = + ConcurrentReferenceHashMap.this.getReference(entry.getKey(), Restructure.NEVER); + Entry otherEntry = (ref != null ? ref.get() : null); + if (otherEntry != null) { + return Objects.equals(entry.getValue(), otherEntry.getValue()); + } + } + return false; + } + + @Override + public boolean remove(Object o) { + if (o instanceof Map.Entry entry) { + return ConcurrentReferenceHashMap.this.remove(entry.getKey(), entry.getValue()); + } + return false; + } + + @Override + public int size() { + return ConcurrentReferenceHashMap.this.size(); + } + + @Override + public void clear() { + ConcurrentReferenceHashMap.this.clear(); + } + + @Override + public Spliterator> spliterator() { + return Spliterators.spliterator(this, Spliterator.DISTINCT | Spliterator.CONCURRENT); + } + } + + /** Internal key-set implementation. */ + private final class KeySet extends AbstractSet { + + @Override + public Iterator iterator() { + return new KeyIterator(); + } + + @Override + public int size() { + return ConcurrentReferenceHashMap.this.size(); + } + + @Override + public boolean isEmpty() { + return ConcurrentReferenceHashMap.this.isEmpty(); + } + + @Override + public void clear() { + ConcurrentReferenceHashMap.this.clear(); + } + + @Override + public boolean contains(Object k) { + return ConcurrentReferenceHashMap.this.containsKey(k); + } + + @Override + public Spliterator spliterator() { + return Spliterators.spliterator(this, Spliterator.DISTINCT | Spliterator.CONCURRENT); + } + } + + /** Internal key iterator implementation. */ + private final class KeyIterator implements Iterator { + + private final Iterator> iterator = entrySet().iterator(); + + @Override + public boolean hasNext() { + return this.iterator.hasNext(); + } + + @Override + public void remove() { + this.iterator.remove(); + } + + @Override + public K next() { + return this.iterator.next().getKey(); + } + } + + /** Internal values collection implementation. */ + private final class Values extends AbstractCollection { + + @Override + public Iterator iterator() { + return new ValueIterator(); + } + + @Override + public int size() { + return ConcurrentReferenceHashMap.this.size(); + } + + @Override + public boolean isEmpty() { + return ConcurrentReferenceHashMap.this.isEmpty(); + } + + @Override + public void clear() { + ConcurrentReferenceHashMap.this.clear(); + } + + @Override + public boolean contains(Object v) { + return ConcurrentReferenceHashMap.this.containsValue(v); + } + + @Override + public Spliterator spliterator() { + return Spliterators.spliterator(this, Spliterator.CONCURRENT); + } + } + + /** Internal value iterator implementation. */ + private final class ValueIterator implements Iterator { + + private final Iterator> iterator = entrySet().iterator(); + + @Override + public boolean hasNext() { + return this.iterator.hasNext(); + } + + @Override + public void remove() { + this.iterator.remove(); + } + + @Override + public V next() { + return this.iterator.next().getValue(); + } + } + + /** Internal entry iterator implementation. */ + private final class EntryIterator implements Iterator> { + + private int segmentIndex; + + private int referenceIndex; + + private @Nullable Reference @Nullable [] references; + + private @Nullable Reference reference; + + private @Nullable Entry next; + + private @Nullable Entry last; + + public EntryIterator() { + moveToNextSegment(); + } + + @Override + public boolean hasNext() { + getNextIfNecessary(); + return (this.next != null); + } + + @Override + public Entry next() { + getNextIfNecessary(); + if (this.next == null) { + throw new NoSuchElementException(); + } + this.last = this.next; + this.next = null; + return this.last; + } + + private void getNextIfNecessary() { + while (this.next == null) { + moveToNextReference(); + if (this.reference == null) { + return; + } + this.next = this.reference.get(); + } + } + + private void moveToNextReference() { + if (this.reference != null) { + this.reference = this.reference.getNext(); + } + while (this.reference == null && this.references != null) { + if (this.referenceIndex >= this.references.length) { + moveToNextSegment(); + this.referenceIndex = 0; + } else { + this.reference = this.references[this.referenceIndex]; + this.referenceIndex++; + } + } + } + + private void moveToNextSegment() { + this.reference = null; + this.references = null; + if (this.segmentIndex < ConcurrentReferenceHashMap.this.segments.length) { + this.references = ConcurrentReferenceHashMap.this.segments[this.segmentIndex].references; + this.segmentIndex++; + } + } + + @Override + public void remove() { + ConcurrentReferenceHashMap.this.remove(this.last.getKey()); + this.last = null; + } + } + + /** The types of restructuring that can be performed. */ + protected enum Restructure { + WHEN_NECESSARY, + NEVER + } + + /** + * Strategy class used to manage {@link Reference References}. This class can be overridden if + * alternative reference types need to be supported. + */ + protected class ReferenceManager { + + private final ReferenceQueue> queue = new ReferenceQueue<>(); + + /** + * Factory method used to create a new {@link Reference}. + * + * @param entry the entry contained in the reference + * @param hash the hash + * @param next the next reference in the chain, or {@code null} if none + * @return a new {@link Reference} + */ + public Reference createReference( + Entry entry, int hash, @Nullable Reference next) { + if (ConcurrentReferenceHashMap.this.referenceType == ReferenceType.WEAK) { + return new WeakEntryReference<>(entry, hash, next, this.queue); + } + return new SoftEntryReference<>(entry, hash, next, this.queue); + } + + /** + * Return any reference that has been garbage collected and can be purged from the underlying + * structure or {@code null} if no references need purging. This method must be thread safe and + * ideally should not block when returning {@code null}. References should be returned once and + * only once. + * + * @return a reference to purge or {@code null} + */ + @SuppressWarnings("unchecked") + public @Nullable Reference pollForPurge() { + return (Reference) this.queue.poll(); + } + } + + /** Internal {@link Reference} implementation for {@link SoftReference SoftReferences}. */ + private static final class SoftEntryReference extends SoftReference> + implements Reference { + + private final int hash; + + private final @Nullable Reference nextReference; + + public SoftEntryReference( + Entry entry, + int hash, + @Nullable Reference next, + ReferenceQueue> queue) { + + super(entry, queue); + this.hash = hash; + this.nextReference = next; + } + + @Override + public int getHash() { + return this.hash; + } + + @Override + public @Nullable Reference getNext() { + return this.nextReference; + } + + @Override + public void release() { + enqueue(); + } + } + + /** Internal {@link Reference} implementation for {@link WeakReference WeakReferences}. */ + private static final class WeakEntryReference extends WeakReference> + implements Reference { + + private final int hash; + + private final @Nullable Reference nextReference; + + public WeakEntryReference( + Entry entry, + int hash, + @Nullable Reference next, + ReferenceQueue> queue) { + + super(entry, queue); + this.hash = hash; + this.nextReference = next; + } + + @Override + public int getHash() { + return this.hash; + } + + @Override + public @Nullable Reference getNext() { + return this.nextReference; + } + + @Override + public void release() { + enqueue(); + } + } +} diff --git a/common/src/main/java/dev/restate/common/reflections/MethodInfo.java b/common/src/main/java/dev/restate/common/reflections/MethodInfo.java new file mode 100644 index 00000000..e2598e2f --- /dev/null +++ b/common/src/main/java/dev/restate/common/reflections/MethodInfo.java @@ -0,0 +1,46 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.common.reflections; + +import java.lang.reflect.Method; +import java.lang.reflect.Type; + +public class MethodInfo extends RuntimeException { + private final String handlerName; + private final Type inputType; + private final Type outputType; + + private MethodInfo(String handlerName, Type inputType, Type outputType) { + this.inputType = inputType; + this.outputType = outputType; + this.handlerName = handlerName; + } + + public String getHandlerName() { + return handlerName; + } + + public Type getInputType() { + return inputType; + } + + public Type getOutputType() { + return outputType; + } + + public static MethodInfo fromMethod(Method method) { + var handlerInfo = ReflectionUtils.mustHaveHandlerAnnotation(method); + var genericParameters = method.getGenericParameterTypes(); + var inputType = genericParameters.length == 0 ? Void.TYPE : genericParameters[0]; + var outputType = method.getGenericReturnType(); + var handlerName = handlerInfo.name(); + + return new MethodInfo(handlerName, inputType, outputType); + } +} diff --git a/common/src/main/java/dev/restate/common/reflections/MethodInfoCollector.java b/common/src/main/java/dev/restate/common/reflections/MethodInfoCollector.java new file mode 100644 index 00000000..9f5d2196 --- /dev/null +++ b/common/src/main/java/dev/restate/common/reflections/MethodInfoCollector.java @@ -0,0 +1,68 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.common.reflections; + +import java.util.function.BiConsumer; +import java.util.function.BiFunction; +import java.util.function.Consumer; +import java.util.function.Function; + +public final class MethodInfoCollector { + + private final SVC infoCollectorProxy; + + public MethodInfoCollector(Class svcClass) { + this.infoCollectorProxy = ProxySupport.createProxy(svcClass, METHOD_INFO_COLLECTOR_INTERCEPTOR); + } + + public MethodInfo resolve(Function s) { + try { + s.apply(this.infoCollectorProxy); + throw new UnsupportedOperationException( + "The provided lambda MUST contain ONLY a method reference to the service method"); + } catch (MethodInfo e) { + return e; + } + } + + public MethodInfo resolve(BiFunction s, I input) { + try { + s.apply(this.infoCollectorProxy, input); + throw new UnsupportedOperationException( + "The provided lambda MUST contain ONLY a method reference to the service method"); + } catch (MethodInfo e) { + return e; + } + } + + public MethodInfo resolve(BiConsumer s, I input) { + try { + s.accept(this.infoCollectorProxy, input); + throw new UnsupportedOperationException( + "The provided lambda MUST contain ONLY a method reference to a service method"); + } catch (MethodInfo e) { + return e; + } + } + + public MethodInfo resolve(Consumer s) { + try { + s.accept(this.infoCollectorProxy); + throw new UnsupportedOperationException( + "The provided lambda MUST contain ONLY a method reference to a service method"); + } catch (MethodInfo e) { + return e; + } + } + + private static final ProxySupport.MethodInterceptor METHOD_INFO_COLLECTOR_INTERCEPTOR = + invocation -> { + throw MethodInfo.fromMethod(invocation.getMethod()); + }; +} diff --git a/common/src/main/java/dev/restate/common/reflections/ProxySupport.java b/common/src/main/java/dev/restate/common/reflections/ProxySupport.java new file mode 100644 index 00000000..4b0749c6 --- /dev/null +++ b/common/src/main/java/dev/restate/common/reflections/ProxySupport.java @@ -0,0 +1,137 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.common.reflections; + +import java.lang.reflect.InvocationHandler; +import java.lang.reflect.Method; +import java.lang.reflect.Proxy; +import java.util.ArrayList; +import java.util.List; +import java.util.ServiceConfigurationError; +import java.util.ServiceLoader; +import java.util.stream.Collectors; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.jspecify.annotations.Nullable; + +public final class ProxySupport { + + private static class ProxySupportSingleton { + private static final ProxySupport INSTANCE = new ProxySupport(); + } + + private static final Logger LOG = LogManager.getLogger(ProxySupport.class); + + private final List factories; + + public ProxySupport() { + this.factories = new ArrayList<>(2); + this.factories.add(new JdkProxyFactory()); + + var serviceLoaderIterator = ServiceLoader.load(ProxyFactory.class).iterator(); + while (serviceLoaderIterator.hasNext()) { + try { + this.factories.add(serviceLoaderIterator.next()); + } catch (ServiceConfigurationError | Exception e) { + LOG.error( + "Found proxy factory that cannot be loaded using service provider. Proxy clients might not work correctly.", + e); + throw e; + } + } + } + + /** Resolve the code generated {@link ProxyFactory} */ + public static T createProxy(Class clazz, MethodInterceptor interceptor) { + ProxySupport proxySupport = ProxySupportSingleton.INSTANCE; + + for (ProxyFactory proxyFactory : proxySupport.factories) { + T proxy = proxyFactory.createProxy(clazz, interceptor); + if (proxy != null) { + return proxy; + } + } + + throw new IllegalStateException( + "Class " + + clazz.toString() + + " cannot be proxied. If the type is a concrete class, make sure to have sdk-proxy-bytebuddy in your dependencies. Registered proxies: " + + proxySupport.factories.stream() + .map(pf -> pf.getClass().toString()) + .collect(Collectors.joining(", "))); + } + + public interface MethodInvocation { + Object[] getArguments(); + + Method getMethod(); + } + + @FunctionalInterface + public interface MethodInterceptor { + @Nullable Object invoke(MethodInvocation invocation) throws Throwable; + } + + @FunctionalInterface + public interface ProxyFactory { + /** If returns null, it's not supported. */ + @Nullable T createProxy(Class clazz, MethodInterceptor interceptor); + } + + private static final class JdkProxyFactory implements ProxyFactory { + + /** + * Mutable InvocationHandler wrapper that holds the interceptor. This provides consistency with + * ByteBuddy proxies where the interceptor is set via a field after instantiation. + */ + private static class InterceptorHolder implements InvocationHandler { + MethodInterceptor interceptor; + + @Override + public Object invoke(Object proxy, Method method, Object[] args) throws Throwable { + if (interceptor == null) { + throw new IllegalStateException("Interceptor not set on JDK proxy instance"); + } + + MethodInvocation invocation = + new MethodInvocation() { + @Override + public Object[] getArguments() { + return args != null ? args : new Object[0]; + } + + @Override + public Method getMethod() { + return method; + } + }; + return interceptor.invoke(invocation); + } + } + + @Override + @SuppressWarnings("unchecked") + public @Nullable T createProxy(Class clazz, MethodInterceptor interceptor) { + if (!clazz.isInterface()) { + return null; + } + + // Create holder with interceptor field (similar to ByteBuddy approach) + InterceptorHolder holder = new InterceptorHolder(); + + // Create proxy with the holder (JDK caches proxy class automatically) + T proxy = (T) Proxy.newProxyInstance(clazz.getClassLoader(), new Class[] {clazz}, holder); + + // Set the interceptor after proxy creation (consistent with ByteBuddy) + holder.interceptor = interceptor; + + return proxy; + } + } +} diff --git a/common/src/main/java/dev/restate/common/reflections/ReflectionUtils.java b/common/src/main/java/dev/restate/common/reflections/ReflectionUtils.java new file mode 100644 index 00000000..1270a61a --- /dev/null +++ b/common/src/main/java/dev/restate/common/reflections/ReflectionUtils.java @@ -0,0 +1,1201 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.common.reflections; + +import dev.restate.sdk.annotation.*; +import java.lang.annotation.Annotation; +import java.lang.reflect.*; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import org.jspecify.annotations.NonNull; +import org.jspecify.annotations.Nullable; + +public class ReflectionUtils { + + /** Record containing handler information extracted from annotations. */ + public record HandlerInfo(String name, boolean shared) {} + + /** + * Find a single {@link Annotation} of {@code annotationType} on the supplied {@link Class}, + * traversing its interfaces, annotations, and superclasses if the annotation is not directly + * present on the given class itself. + * + *

This method explicitly handles class-level annotations which are not declared as {@linkplain + * java.lang.annotation.Inherited inherited} as well as meta-annotations and annotations on + * interfaces. + * + *

The algorithm operates as follows: + * + *

    + *
  1. Search for the annotation on the given class and return it if found. + *
  2. Recursively search through all interfaces that the given class declares. + *
  3. Recursively search through the superclass hierarchy of the given class. + *
+ * + *

Note: in this context, the term recursively means that the search process continues + * by returning to step #1 with the current interface, annotation, or superclass as the class to + * look for annotations on. + * + * @param clazz the class to look for annotations on + * @param annotationType the type of annotation to look for + * @return the first matching annotation, or {@code null} if not found + */ + @Nullable + public static A findAnnotation( + Class clazz, @Nullable Class annotationType) { + if (annotationType == null) { + return null; + } + return findAnnotation(clazz, annotationType, new java.util.HashSet<>()); + } + + @Nullable + private static A findAnnotation( + Class clazz, Class annotationType, java.util.Set visited) { + + if (clazz == null || clazz == Object.class) { + return null; + } + + // Check if the annotation is directly present on the class + A annotation = clazz.getDeclaredAnnotation(annotationType); + if (annotation != null) { + return annotation; + } + + // Search on interfaces + for (Class ifc : clazz.getInterfaces()) { + annotation = findAnnotation(ifc, annotationType, visited); + if (annotation != null) { + return annotation; + } + } + + // Search on superclass + return findAnnotation(clazz.getSuperclass(), annotationType, visited); + } + + /** + * Find a single {@link Annotation} of {@code annotationType} on the supplied {@link Method}, + * traversing its super methods if the annotation is not directly present on the given + * method itself. + * + *

Annotations on methods are not inherited by default, so we need to handle this explicitly. + * + * @param method the method to look for annotations on + * @param annotationType the type of annotation to look for + * @return the first matching annotation, or {@code null} if not found + */ + @Nullable + public static A findAnnotation( + Method method, @Nullable Class annotationType) { + if (annotationType == null) { + return null; + } + + // Check if the annotation is directly present on the method + A annotation = method.getDeclaredAnnotation(annotationType); + if (annotation != null) { + return annotation; + } + + // Search through the type hierarchy + Class clazz = method.getDeclaringClass(); + return findAnnotationInTypeHierarchy(clazz, method, annotationType, new java.util.HashSet<>()); + } + + @Nullable + private static A findAnnotationInTypeHierarchy( + Class clazz, Method method, Class annotationType, java.util.Set> visited) { + + if (clazz == null || clazz == Object.class || !visited.add(clazz)) { + return null; + } + + // Try to find an equivalent method in this class/interface + Method equivalentMethod = null; + try { + equivalentMethod = clazz.getDeclaredMethod(method.getName(), method.getParameterTypes()); + } catch (NoSuchMethodException ex) { + // No such method in this class, continue searching + } + + if (equivalentMethod != null) { + A annotation = equivalentMethod.getDeclaredAnnotation(annotationType); + if (annotation != null) { + return annotation; + } + } + + // Search in interfaces + for (Class ifc : clazz.getInterfaces()) { + A annotation = findAnnotationInTypeHierarchy(ifc, method, annotationType, visited); + if (annotation != null) { + return annotation; + } + } + + // Search in superclass + return findAnnotationInTypeHierarchy(clazz.getSuperclass(), method, annotationType, visited); + } + + public static String extractServiceName(Class clazz) { + // Fallback: infer from hierarchy against known Restate markers + String inferred = inferRestateNameFromHierarchy(clazz); + if (inferred != null) { + return inferred; + } + + throw new IllegalArgumentException("Cannot infer Restate name from type: " + clazz.getName()); + } + + private static String inferRestateNameFromHierarchy(Class type) { + if (type == null || Object.class.equals(type)) { + return null; + } + + // Check if the type has any of the Restate component annotations + var restateServiceAnnotation = type.getAnnotation(Service.class); + if (restateServiceAnnotation != null) { + return extractNameFromAnnotations(type); + } + var restateVirtualObjectAnnotation = type.getAnnotation(VirtualObject.class); + if (restateVirtualObjectAnnotation != null) { + return extractNameFromAnnotations(type); + } + var restateWorkflowAnnotation = type.getAnnotation(Workflow.class); + if (restateWorkflowAnnotation != null) { + return extractNameFromAnnotations(type); + } + + // Check parent interfaces + for (Class parent : type.getInterfaces()) { + String res = inferRestateNameFromHierarchy(parent); + if (res != null) { + return res; + } + } + + // Recurse into superclass + return inferRestateNameFromHierarchy(type.getSuperclass()); + } + + private static String extractNameFromAnnotations(Class type) { + // Check for @Name annotation first + var nameAnnotation = type.getAnnotation(Name.class); + if (nameAnnotation != null + && nameAnnotation.value() != null + && !nameAnnotation.value().isEmpty()) { + return nameAnnotation.value(); + } + // Default to simple class name + return type.getSimpleName(); + } + + public static A mustHaveAnnotation( + Class clazz, Class annotationClazz) { + A annotation = findAnnotation(clazz, annotationClazz); + if (annotation == null) { + throw new IllegalArgumentException( + "The given class " + + clazz.getName() + + " is not annotated with @" + + annotationClazz.getSimpleName()); + } + return annotation; + } + + public static HandlerInfo mustHaveHandlerAnnotation(@NonNull Method method) { + // Check for @Handler or @Shared annotation (Shared implies Handler) + var handlerAnnotation = findAnnotation(method, Handler.class); + var sharedAnnotation = findAnnotation(method, Shared.class); + + if (handlerAnnotation == null && sharedAnnotation == null) { + throw new IllegalArgumentException( + "The invoked method '" + + method.getName() + + "' is not annotated with @" + + Handler.class.getSimpleName() + + " or @" + + Shared.class.getSimpleName()); + } + + // Extract the name from @Name annotation, or default to method name + var nameAnnotation = findAnnotation(method, Name.class); + String handlerName; + if (nameAnnotation != null + && nameAnnotation.value() != null + && !nameAnnotation.value().isEmpty()) { + handlerName = nameAnnotation.value(); + } else { + handlerName = method.getName(); + } + + // Determine if it's shared + boolean isShared = sharedAnnotation != null; + + return new HandlerInfo(handlerName, isShared); + } + + /** + * Walks the type hierarchy to find where the given rawType interface was parameterized. This + * handles inheritance chains and multiple interfaces correctly. + * + * @param concreteClass The concrete class to start searching from + * @param rawType The raw interface type to find (e.g., Function.class) + * @return The ParameterizedType with resolved type arguments, or null if not found + */ + public static ParameterizedType findParameterizedType(Class concreteClass, Class rawType) { + if (concreteClass == null || Object.class.equals(concreteClass)) { + return null; + } + + // Check direct interfaces + for (Type genericInterface : concreteClass.getGenericInterfaces()) { + ParameterizedType result = findParameterizedTypeInType(genericInterface, rawType); + if (result != null) { + return result; + } + } + + // Check superclass + Type genericSuperclass = concreteClass.getGenericSuperclass(); + if (genericSuperclass != null) { + ParameterizedType result = findParameterizedTypeInType(genericSuperclass, rawType); + if (result != null) { + return result; + } + } + + // Recurse up the hierarchy + return findParameterizedType(concreteClass.getSuperclass(), rawType); + } + + private static ParameterizedType findParameterizedTypeInType(Type type, Class rawType) { + if (type instanceof ParameterizedType paramType) { + if (paramType.getRawType().equals(rawType)) { + return paramType; + } + // Check if this parameterized type extends/implements the target + if (paramType.getRawType() instanceof Class clazz) { + return findParameterizedType(clazz, rawType); + } + } else if (type instanceof Class clazz) { + return findParameterizedType(clazz, rawType); + } + return null; + } + + public static boolean isKotlinClass(Class clazz) { + return Arrays.stream(clazz.getDeclaredAnnotations()) + .anyMatch(annotation -> annotation.annotationType().getName().equals("kotlin.Metadata")); + } + + // From Spring's ReflectionUtils + // License Apache 2.0 + + /** + * Pre-built {@link MethodFilter} that matches all non-bridge non-synthetic methods which are not + * declared on {@code java.lang.Object}. + * + * @since 3.0.5 + */ + public static final MethodFilter USER_DECLARED_METHODS = + (method -> + !method.isBridge() + && !method.isSynthetic() + && (method.getDeclaringClass() != Object.class)); + + /** Pre-built FieldFilter that matches all non-static, non-final fields. */ + public static final FieldFilter COPYABLE_FIELDS = + (field -> + !(Modifier.isStatic(field.getModifiers()) || Modifier.isFinal(field.getModifiers()))); + + /** + * Naming prefix for CGLIB-renamed methods. + * + * @see #isCglibRenamedMethod + */ + private static final String CGLIB_RENAMED_METHOD_PREFIX = "CGLIB$"; + + private static final Class[] EMPTY_CLASS_ARRAY = new Class[0]; + + private static final Method[] EMPTY_METHOD_ARRAY = new Method[0]; + + private static final Field[] EMPTY_FIELD_ARRAY = new Field[0]; + + private static final Object[] EMPTY_OBJECT_ARRAY = new Object[0]; + + /** + * Cache for {@link Class#getDeclaredMethods()} plus equivalent default methods from Java 8 based + * interfaces, allowing for fast iteration. + */ + private static final Map, Method[]> declaredMethodsCache = + new ConcurrentReferenceHashMap<>(256); + + /** Cache for {@link Class#getDeclaredFields()}, allowing for fast iteration. */ + private static final Map, Field[]> declaredFieldsCache = + new ConcurrentReferenceHashMap<>(256); + + // Exception handling + + /** + * Handle the given reflection exception. + * + *

Should only be called if no checked exception is expected to be thrown by a target method, + * or if an error occurs while accessing a method or field. + * + *

Throws the underlying RuntimeException or Error in case of an InvocationTargetException with + * such a root cause. Throws an IllegalStateException with an appropriate message or + * UndeclaredThrowableException otherwise. + * + * @param ex the reflection exception to handle + */ + public static void handleReflectionException(Exception ex) { + if (ex instanceof NoSuchMethodException) { + throw new IllegalStateException("Method not found: " + ex.getMessage()); + } + if (ex instanceof IllegalAccessException) { + throw new IllegalStateException("Could not access method or field: " + ex.getMessage()); + } + if (ex instanceof InvocationTargetException invocationTargetException) { + handleInvocationTargetException(invocationTargetException); + } + if (ex instanceof RuntimeException runtimeException) { + throw runtimeException; + } + throw new UndeclaredThrowableException(ex); + } + + /** + * Handle the given invocation target exception. Should only be called if no checked exception is + * expected to be thrown by the target method. + * + *

Throws the underlying RuntimeException or Error in case of such a root cause. Throws an + * UndeclaredThrowableException otherwise. + * + * @param ex the invocation target exception to handle + */ + public static void handleInvocationTargetException(InvocationTargetException ex) { + rethrowRuntimeException(ex.getTargetException()); + } + + /** + * Rethrow the given {@link Throwable exception}, which is presumably the target + * exception of an {@link InvocationTargetException}. Should only be called if no checked + * exception is expected to be thrown by the target method. + * + *

Rethrows the underlying exception cast to a {@link RuntimeException} or {@link Error} if + * appropriate; otherwise, throws an {@link UndeclaredThrowableException}. + * + * @param ex the exception to rethrow + * @throws RuntimeException the rethrown exception + */ + public static void rethrowRuntimeException(@Nullable Throwable ex) { + if (ex instanceof RuntimeException runtimeException) { + throw runtimeException; + } + if (ex instanceof Error error) { + throw error; + } + throw new UndeclaredThrowableException(ex); + } + + /** + * Rethrow the given {@link Throwable exception}, which is presumably the target + * exception of an {@link InvocationTargetException}. Should only be called if no checked + * exception is expected to be thrown by the target method. + * + *

Rethrows the underlying exception cast to an {@link Exception} or {@link Error} if + * appropriate; otherwise, throws an {@link UndeclaredThrowableException}. + * + * @param throwable the exception to rethrow + * @throws Exception the rethrown exception (in case of a checked exception) + */ + public static void rethrowException(@Nullable Throwable throwable) throws Exception { + if (throwable instanceof Exception exception) { + throw exception; + } + if (throwable instanceof Error error) { + throw error; + } + throw new UndeclaredThrowableException(throwable); + } + + // Constructor handling + + /** + * Obtain an accessible constructor for the given class and parameters. + * + * @param clazz the clazz to check + * @param parameterTypes the parameter types of the desired constructor + * @return the constructor reference + * @throws NoSuchMethodException if no such constructor exists + * @since 5.0 + */ + public static Constructor accessibleConstructor(Class clazz, Class... parameterTypes) + throws NoSuchMethodException { + + Constructor ctor = clazz.getDeclaredConstructor(parameterTypes); + makeAccessible(ctor); + return ctor; + } + + /** + * Make the given constructor accessible, explicitly setting it accessible if necessary. The + * {@code setAccessible(true)} method is only called when actually necessary, to avoid unnecessary + * conflicts. + * + * @param ctor the constructor to make accessible + * @see Constructor#setAccessible + */ + @SuppressWarnings("deprecation") + public static void makeAccessible(Constructor ctor) { + if ((!Modifier.isPublic(ctor.getModifiers()) + || !Modifier.isPublic(ctor.getDeclaringClass().getModifiers())) + && !ctor.isAccessible()) { + ctor.setAccessible(true); + } + } + + // Method handling + + /** + * Attempt to find a {@link Method} on the supplied class with the supplied name and no + * parameters. Searches all superclasses up to {@code Object}. + * + *

Returns {@code null} if no {@link Method} can be found. + * + * @param clazz the class to introspect + * @param name the name of the method + * @return the Method object, or {@code null} if none found + */ + public static @Nullable Method findMethod(Class clazz, String name) { + return findMethod(clazz, name, EMPTY_CLASS_ARRAY); + } + + /** + * Attempt to find a {@link Method} on the supplied class with the supplied name and parameter + * types. Searches all superclasses up to {@code Object}. + * + *

Returns {@code null} if no {@link Method} can be found. + * + * @param clazz the class to introspect + * @param name the name of the method + * @param paramTypes the parameter types of the method (may be {@code null} to indicate any + * signature) + * @return the Method object, or {@code null} if none found + */ + public static @Nullable Method findMethod( + Class clazz, String name, Class @Nullable ... paramTypes) { + Class searchType = clazz; + while (searchType != null) { + Method[] methods = + (searchType.isInterface() + ? searchType.getMethods() + : getDeclaredMethods(searchType, false)); + for (Method method : methods) { + if (name.equals(method.getName()) + && (paramTypes == null || hasSameParams(method, paramTypes))) { + return method; + } + } + searchType = searchType.getSuperclass(); + } + return null; + } + + private static boolean hasSameParams(Method method, Class[] paramTypes) { + return (paramTypes.length == method.getParameterCount() + && Arrays.equals(paramTypes, method.getParameterTypes())); + } + + /** + * Invoke the specified {@link Method} against the supplied target object with no arguments. The + * target object can be {@code null} when invoking a static {@link Method}. + * + *

Thrown exceptions are handled via a call to {@link #handleReflectionException}. + * + * @param method the method to invoke + * @param target the target object to invoke the method on + * @return the invocation result, if any + * @see #invokeMethod(Method, Object, Object[]) + */ + public static @Nullable Object invokeMethod(Method method, @Nullable Object target) { + return invokeMethod(method, target, EMPTY_OBJECT_ARRAY); + } + + /** + * Invoke the specified {@link Method} against the supplied target object with the supplied + * arguments. The target object can be {@code null} when invoking a static {@link Method}. + * + *

Thrown exceptions are handled via a call to {@link #handleReflectionException}. + * + * @param method the method to invoke + * @param target the target object to invoke the method on + * @param args the invocation arguments (may be {@code null}) + * @return the invocation result, if any + */ + public static @Nullable Object invokeMethod( + Method method, @Nullable Object target, @Nullable Object... args) { + try { + return method.invoke(target, args); + } catch (Exception ex) { + handleReflectionException(ex); + } + throw new IllegalStateException("Should never get here"); + } + + /** + * Determine whether the given method explicitly declares the given exception or one of its + * superclasses, which means that an exception of that type can be propagated as-is within a + * reflective invocation. + * + * @param method the declaring method + * @param exceptionType the exception to throw + * @return {@code true} if the exception can be thrown as-is; {@code false} if it needs to be + * wrapped + */ + public static boolean declaresException(Method method, Class exceptionType) { + Class[] declaredExceptions = method.getExceptionTypes(); + for (Class declaredException : declaredExceptions) { + if (declaredException.isAssignableFrom(exceptionType)) { + return true; + } + } + return false; + } + + /** + * Perform the given callback operation on all matching methods of the given class, as locally + * declared or equivalent thereof (such as default methods on Java 8 based interfaces that the + * given class implements). + * + * @param clazz the class to introspect + * @param mc the callback to invoke for each method + * @throws IllegalStateException if introspection fails + * @see #doWithMethods + * @since 4.2 + */ + public static void doWithLocalMethods(Class clazz, MethodCallback mc) { + Method[] methods = getDeclaredMethods(clazz, false); + for (Method method : methods) { + try { + mc.doWith(method); + } catch (IllegalAccessException ex) { + throw new IllegalStateException( + "Not allowed to access method '" + method.getName() + "': " + ex); + } + } + } + + /** + * Perform the given callback operation on all matching methods of the given class and + * superclasses. + * + *

The same named method occurring on subclass and superclass will appear twice, unless + * excluded by a {@link MethodFilter}. + * + * @param clazz the class to introspect + * @param mc the callback to invoke for each method + * @throws IllegalStateException if introspection fails + * @see #doWithMethods(Class, MethodCallback, MethodFilter) + */ + public static void doWithMethods(Class clazz, MethodCallback mc) { + doWithMethods(clazz, mc, null); + } + + /** + * Perform the given callback operation on all matching methods of the given class and + * superclasses (or given interface and super-interfaces). + * + *

The same named method occurring on subclass and superclass will appear twice, unless + * excluded by the specified {@link MethodFilter}. + * + * @param clazz the class to introspect + * @param mc the callback to invoke for each method + * @param mf the filter that determines the methods to apply the callback to + * @throws IllegalStateException if introspection fails + */ + public static void doWithMethods(Class clazz, MethodCallback mc, @Nullable MethodFilter mf) { + if (mf == USER_DECLARED_METHODS && clazz == Object.class) { + // nothing to introspect + return; + } + Method[] methods = getDeclaredMethods(clazz, false); + for (Method method : methods) { + if (mf != null && !mf.matches(method)) { + continue; + } + try { + mc.doWith(method); + } catch (IllegalAccessException ex) { + throw new IllegalStateException( + "Not allowed to access method '" + method.getName() + "': " + ex); + } + } + // Keep backing up the inheritance hierarchy. + if (clazz.getSuperclass() != null + && (mf != USER_DECLARED_METHODS || clazz.getSuperclass() != Object.class)) { + doWithMethods(clazz.getSuperclass(), mc, mf); + } else if (clazz.isInterface()) { + for (Class superIfc : clazz.getInterfaces()) { + doWithMethods(superIfc, mc, mf); + } + } + } + + /** + * Get all declared methods on the leaf class and all superclasses. Leaf class methods are + * included first. + * + * @param leafClass the class to introspect + * @throws IllegalStateException if introspection fails + */ + public static Method[] getAllDeclaredMethods(Class leafClass) { + final List methods = new ArrayList<>(20); + doWithMethods(leafClass, methods::add); + return methods.toArray(EMPTY_METHOD_ARRAY); + } + + /** + * Get the unique set of declared methods on the leaf class and all superclasses. Leaf class + * methods are included first and while traversing the superclass hierarchy any methods found with + * signatures matching a method already included are filtered out. + * + * @param leafClass the class to introspect + * @throws IllegalStateException if introspection fails + */ + public static Method[] getUniqueDeclaredMethods(Class leafClass) { + return getUniqueDeclaredMethods(leafClass, null); + } + + /** + * Get the unique set of declared methods on the leaf class and all superclasses. Leaf class + * methods are included first and while traversing the superclass hierarchy any methods found with + * signatures matching a method already included are filtered out. + * + * @param leafClass the class to introspect + * @param mf the filter that determines the methods to take into account + * @throws IllegalStateException if introspection fails + * @since 5.2 + */ + public static Method[] getUniqueDeclaredMethods(Class leafClass, @Nullable MethodFilter mf) { + final List methods = new ArrayList<>(20); + doWithMethods( + leafClass, + method -> { + boolean knownSignature = false; + Method methodBeingOverriddenWithCovariantReturnType = null; + for (Method existingMethod : methods) { + if (method.getName().equals(existingMethod.getName()) + && method.getParameterCount() == existingMethod.getParameterCount() + && Arrays.equals(method.getParameterTypes(), existingMethod.getParameterTypes())) { + // Is this a covariant return type situation? + if (existingMethod.getReturnType() != method.getReturnType() + && existingMethod.getReturnType().isAssignableFrom(method.getReturnType())) { + methodBeingOverriddenWithCovariantReturnType = existingMethod; + } else { + knownSignature = true; + } + break; + } + } + if (methodBeingOverriddenWithCovariantReturnType != null) { + methods.remove(methodBeingOverriddenWithCovariantReturnType); + } + if (!knownSignature && !isCglibRenamedMethod(method)) { + methods.add(method); + } + }, + mf); + return methods.toArray(EMPTY_METHOD_ARRAY); + } + + /** + * Variant of {@link Class#getDeclaredMethods()} that uses a local cache in order to avoid new + * Method instances. In addition, it also includes Java 8 default methods from locally implemented + * interfaces, since those are effectively to be treated just like declared methods. + * + * @param clazz the class to introspect + * @return the cached array of methods + * @throws IllegalStateException if introspection fails + * @see Class#getDeclaredMethods() + * @since 5.2 + */ + public static Method[] getDeclaredMethods(Class clazz) { + return getDeclaredMethods(clazz, true); + } + + private static Method[] getDeclaredMethods(Class clazz, boolean defensive) { + Method[] result = declaredMethodsCache.get(clazz); + if (result == null) { + try { + Method[] declaredMethods = clazz.getDeclaredMethods(); + List defaultMethods = findDefaultMethodsOnInterfaces(clazz); + if (defaultMethods != null) { + result = new Method[declaredMethods.length + defaultMethods.size()]; + System.arraycopy(declaredMethods, 0, result, 0, declaredMethods.length); + int index = declaredMethods.length; + for (Method defaultMethod : defaultMethods) { + result[index] = defaultMethod; + index++; + } + } else { + result = declaredMethods; + } + declaredMethodsCache.put(clazz, (result.length == 0 ? EMPTY_METHOD_ARRAY : result)); + } catch (Throwable ex) { + throw new IllegalStateException( + "Failed to introspect Class [" + + clazz.getName() + + "] from ClassLoader [" + + clazz.getClassLoader() + + "]", + ex); + } + } + return (result.length == 0 || !defensive) ? result : result.clone(); + } + + private static @Nullable List findDefaultMethodsOnInterfaces(Class clazz) { + List result = null; + for (Class ifc : clazz.getInterfaces()) { + for (Method method : ifc.getMethods()) { + if (method.isDefault()) { + if (result == null) { + result = new ArrayList<>(); + } + result.add(method); + } + } + } + return result; + } + + /** + * Determine whether the given method is an "equals" method. + * + * @see Object#equals(Object) + */ + public static boolean isEqualsMethod(@Nullable Method method) { + return (method != null + && method.getParameterCount() == 1 + && method.getName().equals("equals") + && method.getParameterTypes()[0] == Object.class); + } + + /** + * Determine whether the given method is a "hashCode" method. + * + * @see Object#hashCode() + */ + public static boolean isHashCodeMethod(@Nullable Method method) { + return (method != null + && method.getParameterCount() == 0 + && method.getName().equals("hashCode")); + } + + /** + * Determine whether the given method is a "toString" method. + * + * @see Object#toString() + */ + public static boolean isToStringMethod(@Nullable Method method) { + return (method != null + && method.getParameterCount() == 0 + && method.getName().equals("toString")); + } + + /** Determine whether the given method is originally declared by {@link Object}. */ + public static boolean isObjectMethod(@Nullable Method method) { + return (method != null + && (method.getDeclaringClass() == Object.class + || isEqualsMethod(method) + || isHashCodeMethod(method) + || isToStringMethod(method))); + } + + /** + * Determine whether the given method is a CGLIB 'renamed' method, following the pattern + * "CGLIB$methodName$0". + * + * @param renamedMethod the method to check + */ + public static boolean isCglibRenamedMethod(Method renamedMethod) { + String name = renamedMethod.getName(); + if (name.startsWith(CGLIB_RENAMED_METHOD_PREFIX)) { + int i = name.length() - 1; + while (i >= 0 && Character.isDigit(name.charAt(i))) { + i--; + } + return (i > CGLIB_RENAMED_METHOD_PREFIX.length() + && (i < name.length() - 1) + && name.charAt(i) == '$'); + } + return false; + } + + /** + * Make the given method accessible, explicitly setting it accessible if necessary. The {@code + * setAccessible(true)} method is only called when actually necessary, to avoid unnecessary + * conflicts. + * + * @param method the method to make accessible + * @see Method#setAccessible + */ + @SuppressWarnings("deprecation") + public static void makeAccessible(Method method) { + if ((!Modifier.isPublic(method.getModifiers()) + || !Modifier.isPublic(method.getDeclaringClass().getModifiers())) + && !method.isAccessible()) { + method.setAccessible(true); + } + } + + // Field handling + + /** + * Attempt to find a {@link Field field} on the supplied {@link Class} with the supplied {@code + * name}. Searches all superclasses up to {@link Object}. + * + * @param clazz the class to introspect + * @param name the name of the field + * @return the corresponding Field object, or {@code null} if not found + */ + public static @Nullable Field findField(Class clazz, String name) { + return findField(clazz, name, null); + } + + /** + * Attempt to find a {@link Field field} on the supplied {@link Class} with the supplied {@code + * name} and/or {@link Class type}. Searches all superclasses up to {@link Object}. + * + * @param clazz the class to introspect + * @param name the name of the field (may be {@code null} if type is specified) + * @param type the type of the field (may be {@code null} if name is specified) + * @return the corresponding Field object, or {@code null} if not found + */ + public static @Nullable Field findField( + Class clazz, @Nullable String name, @Nullable Class type) { + Class searchType = clazz; + while (Object.class != searchType && searchType != null) { + Field[] fields = getDeclaredFields(searchType); + for (Field field : fields) { + if ((name == null || name.equals(field.getName())) + && (type == null || type.equals(field.getType()))) { + return field; + } + } + searchType = searchType.getSuperclass(); + } + return null; + } + + /** + * Attempt to find a {@link Field field} on the supplied {@link Class} with the supplied {@code + * name}. Searches all superclasses up to {@link Object}. + * + * @param clazz the class to introspect + * @param name the name of the field (with upper/lower case to be ignored) + * @return the corresponding Field object, or {@code null} if not found + * @since 6.1 + */ + public static @Nullable Field findFieldIgnoreCase(Class clazz, String name) { + Class searchType = clazz; + while (Object.class != searchType && searchType != null) { + Field[] fields = getDeclaredFields(searchType); + for (Field field : fields) { + if (name.equalsIgnoreCase(field.getName())) { + return field; + } + } + searchType = searchType.getSuperclass(); + } + return null; + } + + /** + * Set the field represented by the supplied {@linkplain Field field object} on the specified + * {@linkplain Object target object} to the specified {@code value}. + * + *

In accordance with {@link Field#set(Object, Object)} semantics, the new value is + * automatically unwrapped if the underlying field has a primitive type. + * + *

This method does not support setting {@code static final} fields. + * + *

Thrown exceptions are handled via a call to {@link #handleReflectionException(Exception)}. + * + * @param field the field to set + * @param target the target object on which to set the field (or {@code null} for a static field) + * @param value the value to set (may be {@code null}) + */ + public static void setField(Field field, @Nullable Object target, @Nullable Object value) { + try { + field.set(target, value); + } catch (IllegalAccessException ex) { + handleReflectionException(ex); + } + } + + /** + * Get the field represented by the supplied {@link Field field object} on the specified {@link + * Object target object}. In accordance with {@link Field#get(Object)} semantics, the returned + * value is automatically wrapped if the underlying field has a primitive type. + * + *

Thrown exceptions are handled via a call to {@link #handleReflectionException(Exception)}. + * + * @param field the field to get + * @param target the target object from which to get the field (or {@code null} for a static + * field) + * @return the field's current value + */ + public static @Nullable Object getField(Field field, @Nullable Object target) { + try { + return field.get(target); + } catch (IllegalAccessException ex) { + handleReflectionException(ex); + } + throw new IllegalStateException("Should never get here"); + } + + /** + * Invoke the given callback on all locally declared fields in the given class. + * + * @param clazz the target class to analyze + * @param fc the callback to invoke for each field + * @throws IllegalStateException if introspection fails + * @see #doWithFields + * @since 4.2 + */ + public static void doWithLocalFields(Class clazz, FieldCallback fc) { + for (Field field : getDeclaredFields(clazz)) { + try { + fc.doWith(field); + } catch (IllegalAccessException ex) { + throw new IllegalStateException( + "Not allowed to access field '" + field.getName() + "': " + ex); + } + } + } + + /** + * Invoke the given callback on all fields in the target class, going up the class hierarchy to + * get all declared fields. + * + * @param clazz the target class to analyze + * @param fc the callback to invoke for each field + * @throws IllegalStateException if introspection fails + */ + public static void doWithFields(Class clazz, FieldCallback fc) { + doWithFields(clazz, fc, null); + } + + /** + * Invoke the given callback on all fields in the target class, going up the class hierarchy to + * get all declared fields. + * + * @param clazz the target class to analyze + * @param fc the callback to invoke for each field + * @param ff the filter that determines the fields to apply the callback to + * @throws IllegalStateException if introspection fails + */ + public static void doWithFields(Class clazz, FieldCallback fc, @Nullable FieldFilter ff) { + // Keep backing up the inheritance hierarchy. + Class targetClass = clazz; + do { + for (Field field : getDeclaredFields(targetClass)) { + if (ff != null && !ff.matches(field)) { + continue; + } + try { + fc.doWith(field); + } catch (IllegalAccessException ex) { + throw new IllegalStateException( + "Not allowed to access field '" + field.getName() + "': " + ex); + } + } + targetClass = targetClass.getSuperclass(); + } while (targetClass != null && targetClass != Object.class); + } + + /** + * This variant retrieves {@link Class#getDeclaredFields()} from a local cache in order to avoid + * defensive array copying. + * + * @param clazz the class to introspect + * @return the cached array of fields + * @throws IllegalStateException if introspection fails + * @see Class#getDeclaredFields() + */ + private static Field[] getDeclaredFields(Class clazz) { + Field[] result = declaredFieldsCache.get(clazz); + if (result == null) { + try { + result = clazz.getDeclaredFields(); + declaredFieldsCache.put(clazz, (result.length == 0 ? EMPTY_FIELD_ARRAY : result)); + } catch (Throwable ex) { + throw new IllegalStateException( + "Failed to introspect Class [" + + clazz.getName() + + "] from ClassLoader [" + + clazz.getClassLoader() + + "]", + ex); + } + } + return result; + } + + /** + * Given the source object and the destination, which must be the same class or a subclass, copy + * all fields, including inherited fields. Designed to work on objects with public no-arg + * constructors. + * + * @throws IllegalStateException if introspection fails + */ + public static void shallowCopyFieldState(final Object src, final Object dest) { + if (!src.getClass().isAssignableFrom(dest.getClass())) { + throw new IllegalArgumentException( + "Destination class [" + + dest.getClass().getName() + + "] must be same or subclass as source class [" + + src.getClass().getName() + + "]"); + } + doWithFields( + src.getClass(), + field -> { + makeAccessible(field); + Object srcValue = field.get(src); + field.set(dest, srcValue); + }, + COPYABLE_FIELDS); + } + + /** + * Determine whether the given field is a "public static final" constant. + * + * @param field the field to check + */ + public static boolean isPublicStaticFinal(Field field) { + int modifiers = field.getModifiers(); + return (Modifier.isPublic(modifiers) + && Modifier.isStatic(modifiers) + && Modifier.isFinal(modifiers)); + } + + /** + * Make the given field accessible, explicitly setting it accessible if necessary. The {@code + * setAccessible(true)} method is only called when actually necessary, to avoid unnecessary + * conflicts. + * + * @param field the field to make accessible + * @see Field#setAccessible + */ + @SuppressWarnings("deprecation") + public static void makeAccessible(Field field) { + if ((!Modifier.isPublic(field.getModifiers()) + || !Modifier.isPublic(field.getDeclaringClass().getModifiers()) + || Modifier.isFinal(field.getModifiers())) + && !field.isAccessible()) { + field.setAccessible(true); + } + } + + // Cache handling + + /** + * Clear the internal method/field cache. + * + * @since 4.2.4 + */ + public static void clearCache() { + declaredMethodsCache.clear(); + declaredFieldsCache.clear(); + } + + /** Action to take on each method. */ + @FunctionalInterface + public interface MethodCallback { + + /** + * Perform an operation using the given method. + * + * @param method the method to operate on + */ + void doWith(Method method) throws IllegalArgumentException, IllegalAccessException; + } + + /** Callback optionally used to filter methods to be operated on by a method callback. */ + @FunctionalInterface + public interface MethodFilter { + + /** + * Determine whether the given method matches. + * + * @param method the method to check + */ + boolean matches(Method method); + + /** + * Create a composite filter based on this filter and the provided filter. + * + *

If this filter does not match, the next filter will not be applied. + * + * @param next the next {@code MethodFilter} + * @return a composite {@code MethodFilter} + * @throws IllegalArgumentException if the MethodFilter argument is {@code null} + * @since 5.3.2 + */ + default MethodFilter and(MethodFilter next) { + return method -> matches(method) && next.matches(method); + } + } + + /** Callback interface invoked on each field in the hierarchy. */ + @FunctionalInterface + public interface FieldCallback { + + /** + * Perform an operation using the given field. + * + * @param field the field to operate on + */ + void doWith(Field field) throws IllegalArgumentException, IllegalAccessException; + } + + /** Callback optionally used to filter fields to be operated on by a field callback. */ + @FunctionalInterface + public interface FieldFilter { + + /** + * Determine whether the given field matches. + * + * @param field the field to check + */ + boolean matches(Field field); + + /** + * Create a composite filter based on this filter and the provided filter. + * + *

If this filter does not match, the next filter will not be applied. + * + * @param next the next {@code FieldFilter} + * @return a composite {@code FieldFilter} + * @throws IllegalArgumentException if the FieldFilter argument is {@code null} + * @since 5.3.2 + */ + default FieldFilter and(FieldFilter next) { + return field -> matches(field) && next.matches(field); + } + } +} diff --git a/common/src/main/java/dev/restate/common/reflections/RestateUtils.java b/common/src/main/java/dev/restate/common/reflections/RestateUtils.java new file mode 100644 index 00000000..4f578638 --- /dev/null +++ b/common/src/main/java/dev/restate/common/reflections/RestateUtils.java @@ -0,0 +1,55 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.common.reflections; + +import dev.restate.common.InvocationOptions; +import dev.restate.common.Request; +import dev.restate.common.Target; +import dev.restate.serde.Serde; +import dev.restate.serde.TypeRef; +import dev.restate.serde.TypeTag; +import java.lang.reflect.Type; +import org.jspecify.annotations.Nullable; + +public final class RestateUtils { + + public static Request toRequest( + String serviceName, + @Nullable String key, + String handlerName, + TypeTag reqTypeTag, + TypeTag resTypeTag, + Req request, + @Nullable InvocationOptions options) { + var builder = + Request.of( + Target.virtualObject(serviceName, key, handlerName), reqTypeTag, resTypeTag, request); + if (options != null) { + builder.setIdempotencyKey(options.getIdempotencyKey()); + if (options.getHeaders() != null) { + builder.setHeaders(options.getHeaders()); + } + } + + return builder.build(); + } + + public static TypeTag typeTag(Type type) { + if (type.equals(Void.TYPE)) { + return Serde.VOID; + } + return TypeTag.of( + new TypeRef<>() { + @Override + public Type getType() { + return type; + } + }); + } +} diff --git a/sdk-common/src/main/java/dev/restate/sdk/annotation/Accept.java b/common/src/main/java/dev/restate/sdk/annotation/Accept.java similarity index 95% rename from sdk-common/src/main/java/dev/restate/sdk/annotation/Accept.java rename to common/src/main/java/dev/restate/sdk/annotation/Accept.java index 5c2d1b46..c2bfefe0 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/annotation/Accept.java +++ b/common/src/main/java/dev/restate/sdk/annotation/Accept.java @@ -18,7 +18,7 @@ * * / *} */ @Target(ElementType.PARAMETER) -@Retention(RetentionPolicy.SOURCE) +@Retention(RetentionPolicy.RUNTIME) public @interface Accept { String value(); } diff --git a/sdk-common/src/main/java/dev/restate/sdk/annotation/CustomSerdeFactory.java b/common/src/main/java/dev/restate/sdk/annotation/CustomSerdeFactory.java similarity index 96% rename from sdk-common/src/main/java/dev/restate/sdk/annotation/CustomSerdeFactory.java rename to common/src/main/java/dev/restate/sdk/annotation/CustomSerdeFactory.java index 922a03d5..b8b61aa4 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/annotation/CustomSerdeFactory.java +++ b/common/src/main/java/dev/restate/sdk/annotation/CustomSerdeFactory.java @@ -21,7 +21,7 @@ * annotation. */ @Target(ElementType.TYPE) -@Retention(RetentionPolicy.SOURCE) +@Retention(RetentionPolicy.RUNTIME) public @interface CustomSerdeFactory { Class value(); } diff --git a/sdk-common/src/main/java/dev/restate/sdk/annotation/Exclusive.java b/common/src/main/java/dev/restate/sdk/annotation/Exclusive.java similarity index 95% rename from sdk-common/src/main/java/dev/restate/sdk/annotation/Exclusive.java rename to common/src/main/java/dev/restate/sdk/annotation/Exclusive.java index 24c0cb86..c0c42517 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/annotation/Exclusive.java +++ b/common/src/main/java/dev/restate/sdk/annotation/Exclusive.java @@ -18,5 +18,5 @@ * only on methods of {@link VirtualObject}. This implies the annotation {@link Handler}. */ @Target(ElementType.METHOD) -@Retention(RetentionPolicy.SOURCE) +@Retention(RetentionPolicy.RUNTIME) public @interface Exclusive {} diff --git a/sdk-common/src/main/java/dev/restate/sdk/annotation/Handler.java b/common/src/main/java/dev/restate/sdk/annotation/Handler.java similarity index 96% rename from sdk-common/src/main/java/dev/restate/sdk/annotation/Handler.java rename to common/src/main/java/dev/restate/sdk/annotation/Handler.java index 4177154d..c3d3df76 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/annotation/Handler.java +++ b/common/src/main/java/dev/restate/sdk/annotation/Handler.java @@ -20,5 +20,5 @@ * handlers. */ @Target(ElementType.METHOD) -@Retention(RetentionPolicy.SOURCE) +@Retention(RetentionPolicy.RUNTIME) public @interface Handler {} diff --git a/sdk-common/src/main/java/dev/restate/sdk/annotation/Json.java b/common/src/main/java/dev/restate/sdk/annotation/Json.java similarity index 95% rename from sdk-common/src/main/java/dev/restate/sdk/annotation/Json.java rename to common/src/main/java/dev/restate/sdk/annotation/Json.java index f7fd1b63..ff587ab5 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/annotation/Json.java +++ b/common/src/main/java/dev/restate/sdk/annotation/Json.java @@ -15,7 +15,7 @@ /** Serialize/Deserialize the annotated element as Json */ @Target({ElementType.METHOD, ElementType.PARAMETER}) -@Retention(RetentionPolicy.SOURCE) +@Retention(RetentionPolicy.RUNTIME) public @interface Json { /** Content-type to use in request/responses. */ String contentType() default "application/json"; diff --git a/sdk-common/src/main/java/dev/restate/sdk/annotation/Name.java b/common/src/main/java/dev/restate/sdk/annotation/Name.java similarity index 95% rename from sdk-common/src/main/java/dev/restate/sdk/annotation/Name.java rename to common/src/main/java/dev/restate/sdk/annotation/Name.java index 4cf2f5a9..089046df 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/annotation/Name.java +++ b/common/src/main/java/dev/restate/sdk/annotation/Name.java @@ -20,7 +20,7 @@ * When not provided for a handler, it will be the annotated method name. */ @Target({ElementType.METHOD, ElementType.TYPE}) -@Retention(RetentionPolicy.SOURCE) +@Retention(RetentionPolicy.RUNTIME) public @interface Name { String value() default ""; } diff --git a/sdk-common/src/main/java/dev/restate/sdk/annotation/Raw.java b/common/src/main/java/dev/restate/sdk/annotation/Raw.java similarity index 96% rename from sdk-common/src/main/java/dev/restate/sdk/annotation/Raw.java rename to common/src/main/java/dev/restate/sdk/annotation/Raw.java index 65ad8c95..327fe67a 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/annotation/Raw.java +++ b/common/src/main/java/dev/restate/sdk/annotation/Raw.java @@ -18,7 +18,7 @@ * parameter/return type to be {@code byte[]} */ @Target({ElementType.METHOD, ElementType.PARAMETER}) -@Retention(RetentionPolicy.SOURCE) +@Retention(RetentionPolicy.RUNTIME) public @interface Raw { /** Content-type to use in request/responses. */ String contentType() default "application/octet-stream"; diff --git a/sdk-common/src/main/java/dev/restate/sdk/annotation/Service.java b/common/src/main/java/dev/restate/sdk/annotation/Service.java similarity index 75% rename from sdk-common/src/main/java/dev/restate/sdk/annotation/Service.java rename to common/src/main/java/dev/restate/sdk/annotation/Service.java index 99f5c306..287e82d6 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/annotation/Service.java +++ b/common/src/main/java/dev/restate/sdk/annotation/Service.java @@ -8,13 +8,9 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.annotation; -import dev.restate.sdk.endpoint.definition.ServiceDefinitionFactory; import java.lang.annotation.*; -/** - * Annotation to define a class/interface as Restate Service. This triggers the code generation of - * the related Client class and the {@link ServiceDefinitionFactory}. - */ +/** Annotation to define a class/interface as Restate Service. */ @Target(ElementType.TYPE) @Retention(RetentionPolicy.RUNTIME) @Documented diff --git a/sdk-common/src/main/java/dev/restate/sdk/annotation/Shared.java b/common/src/main/java/dev/restate/sdk/annotation/Shared.java similarity index 96% rename from sdk-common/src/main/java/dev/restate/sdk/annotation/Shared.java rename to common/src/main/java/dev/restate/sdk/annotation/Shared.java index dda03ba0..498eb76c 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/annotation/Shared.java +++ b/common/src/main/java/dev/restate/sdk/annotation/Shared.java @@ -23,5 +23,5 @@ *

This implies the annotation {@link Handler}. */ @Target(ElementType.METHOD) -@Retention(RetentionPolicy.SOURCE) +@Retention(RetentionPolicy.RUNTIME) public @interface Shared {} diff --git a/sdk-common/src/main/java/dev/restate/sdk/annotation/VirtualObject.java b/common/src/main/java/dev/restate/sdk/annotation/VirtualObject.java similarity index 75% rename from sdk-common/src/main/java/dev/restate/sdk/annotation/VirtualObject.java rename to common/src/main/java/dev/restate/sdk/annotation/VirtualObject.java index 23a6f32f..6d77623d 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/annotation/VirtualObject.java +++ b/common/src/main/java/dev/restate/sdk/annotation/VirtualObject.java @@ -8,13 +8,9 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.annotation; -import dev.restate.sdk.endpoint.definition.ServiceDefinitionFactory; import java.lang.annotation.*; -/** - * Annotation to define a class/interface as Restate VirtualObject. This triggers the code - * generation of the related Client class and the {@link ServiceDefinitionFactory}. - */ +/** Annotation to define a class/interface as Restate VirtualObject. */ @Target(ElementType.TYPE) @Retention(RetentionPolicy.RUNTIME) @Documented diff --git a/sdk-common/src/main/java/dev/restate/sdk/annotation/Workflow.java b/common/src/main/java/dev/restate/sdk/annotation/Workflow.java similarity index 69% rename from sdk-common/src/main/java/dev/restate/sdk/annotation/Workflow.java rename to common/src/main/java/dev/restate/sdk/annotation/Workflow.java index 841a6e08..108be88f 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/annotation/Workflow.java +++ b/common/src/main/java/dev/restate/sdk/annotation/Workflow.java @@ -8,13 +8,11 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.annotation; -import dev.restate.sdk.endpoint.definition.ServiceDefinitionFactory; import java.lang.annotation.*; /** - * Annotation to define a class/interface as Restate Workflow. This triggers the code generation of - * the related Client class and the {@link ServiceDefinitionFactory}. When defining a - * class/interface as workflow, you must annotate one of its methods too as {@link Workflow}. + * Annotation to define a class/interface as Restate Workflow. When defining a class/interface as + * workflow, you must annotate one of its methods too as {@link Workflow}. */ @Target({ElementType.METHOD, ElementType.TYPE}) @Retention(RetentionPolicy.RUNTIME) diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 4a45b155..3c63e9f0 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -13,6 +13,8 @@ testcontainers = 'org.testcontainers:testcontainers:1.20.4' tink = 'com.google.crypto.tink:tink:1.18.0' tomcat-annotations = 'org.apache.tomcat:annotations-api:6.0.53' + bytebuddy = "net.bytebuddy:byte-buddy:1.18.3" + objenesis = "org.objenesis:objenesis:3.4" [libraries.jackson-annotations] module = 'com.fasterxml.jackson.core:jackson-annotations' diff --git a/sdk-api/src/main/java/dev/restate/sdk/MalformedRestateServiceException.java b/sdk-api/src/main/java/dev/restate/sdk/MalformedRestateServiceException.java new file mode 100644 index 00000000..d23c4719 --- /dev/null +++ b/sdk-api/src/main/java/dev/restate/sdk/MalformedRestateServiceException.java @@ -0,0 +1,21 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk; + +public class MalformedRestateServiceException extends RuntimeException { + + public MalformedRestateServiceException(String serviceName, String message) { + super("Failed to instantiate Restate service '" + serviceName + "'.\nReason: " + message); + } + + public MalformedRestateServiceException(String serviceName, String message, Throwable cause) { + super( + "Failed to instantiate Restate service '" + serviceName + "'.\nReason: " + message, cause); + } +} diff --git a/sdk-api/src/main/java/dev/restate/sdk/ReflectionServiceDefinitionFactory.java b/sdk-api/src/main/java/dev/restate/sdk/ReflectionServiceDefinitionFactory.java new file mode 100644 index 00000000..8bd1d9f5 --- /dev/null +++ b/sdk-api/src/main/java/dev/restate/sdk/ReflectionServiceDefinitionFactory.java @@ -0,0 +1,227 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk; + +import dev.restate.common.function.ThrowingBiFunction; +import dev.restate.common.reflections.ReflectionUtils; +import dev.restate.common.reflections.RestateUtils; +import dev.restate.sdk.annotation.*; +import dev.restate.sdk.endpoint.definition.*; +import dev.restate.serde.Serde; +import dev.restate.serde.SerdeFactory; +import dev.restate.serde.provider.DefaultSerdeFactoryProvider; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.ServiceLoader; +import java.util.stream.Collectors; +import org.jspecify.annotations.Nullable; + +@org.jetbrains.annotations.ApiStatus.Experimental +final class ReflectionServiceDefinitionFactory implements ServiceDefinitionFactory { + + private volatile SerdeFactory cachedDefaultSerdeFactory; + + @Override + public ServiceDefinition create( + Object serviceInstance, + dev.restate.sdk.endpoint.definition.HandlerRunner.Options overrideHandlerOptions) { + dev.restate.sdk.HandlerRunner.Options handlerRunnerOptions; + if (overrideHandlerOptions == null + || overrideHandlerOptions instanceof dev.restate.sdk.HandlerRunner.Options) { + handlerRunnerOptions = (dev.restate.sdk.HandlerRunner.Options) overrideHandlerOptions; + } else { + throw new IllegalArgumentException( + "The provided options class MUST be instance of dev.restate.sdk.HandlerRunner.Options, but was " + + overrideHandlerOptions.getClass()); + } + + Class serviceClazz = serviceInstance.getClass(); + + boolean hasServiceAnnotation = + ReflectionUtils.findAnnotation(serviceClazz, Service.class) != null; + boolean hasVirtualObjectAnnotation = + ReflectionUtils.findAnnotation(serviceClazz, VirtualObject.class) != null; + boolean hasWorkflowAnnotation = + ReflectionUtils.findAnnotation(serviceClazz, Workflow.class) != null; + + boolean hasAnyAnnotation = + hasServiceAnnotation || hasVirtualObjectAnnotation || hasWorkflowAnnotation; + if (!hasAnyAnnotation) { + throw new MalformedRestateServiceException( + serviceClazz.getSimpleName(), + "A restate component MUST be annotated with " + + "exactly one annotation between @Service/@VirtualObject/@Workflow, no annotation was found"); + } + boolean hasExactlyOneAnnotation = + Boolean.logicalXor( + hasServiceAnnotation, + Boolean.logicalXor(hasVirtualObjectAnnotation, hasWorkflowAnnotation)); + + if (!hasExactlyOneAnnotation) { + throw new MalformedRestateServiceException( + serviceClazz.getSimpleName(), + "A restate component MUST be annotated with " + + "exactly one annotation between @Service/@VirtualObject/@Workflow, more than one annotation found"); + } + + var serviceName = ReflectionUtils.extractServiceName(serviceClazz); + var serviceType = + hasServiceAnnotation + ? ServiceType.SERVICE + : hasVirtualObjectAnnotation ? ServiceType.VIRTUAL_OBJECT : ServiceType.WORKFLOW; + var serdeFactory = resolveSerdeFactory(serviceClazz); + + var methods = + ReflectionUtils.getUniqueDeclaredMethods( + serviceClazz, + method -> + ReflectionUtils.findAnnotation(method, Handler.class) != null + || ReflectionUtils.findAnnotation(method, Shared.class) != null); + if (methods.length == 0) { + throw new MalformedRestateServiceException(serviceName, "No @Handler method found"); + } + return ServiceDefinition.of( + serviceName, + serviceType, + Arrays.stream(methods) + .map( + method -> + createHandlerDefinition( + serviceInstance, + method, + serviceName, + serviceType, + serdeFactory, + handlerRunnerOptions)) + .collect(Collectors.toUnmodifiableList())); + } + + private HandlerDefinition createHandlerDefinition( + Object serviceInstance, + Method method, + String serviceName, + ServiceType serviceType, + SerdeFactory serdeFactory, + HandlerRunner.@Nullable Options overrideHandlerOptions) { + var handlerInfo = ReflectionUtils.mustHaveHandlerAnnotation(method); + var handlerName = handlerInfo.name(); + var genericParameterTypes = method.getGenericParameterTypes(); + if (genericParameterTypes.length > 1) { + throw new MalformedRestateServiceException( + serviceName, + "More than one parameter found in method " + + method.getName() + + ". Only one parameter is supported."); + } + var inputType = genericParameterTypes.length == 0 ? Void.TYPE : genericParameterTypes[0]; + var outputType = method.getGenericReturnType(); + + if (serviceType == ServiceType.SERVICE && handlerInfo.shared()) { + throw new MalformedRestateServiceException( + serviceName, "@Shared is only supported on virtual objects and workflow handlers"); + } + var handlerType = + serviceType == ServiceType.SERVICE || handlerInfo.shared() + ? HandlerType.SHARED + : serviceType == ServiceType.VIRTUAL_OBJECT + ? HandlerType.EXCLUSIVE + : HandlerType.WORKFLOW; + + var parameterCount = method.getParameterCount(); + + // TODO here we should add some code to handle handling Context in method definition. + // This is because we want to make sure people declaring the handlers with the Context in the method works + // providing a smoother path to transition from code generation + // Plus plus plus important bit -> we need to validate the input paramters can be one and only one (OBV)! + + var runner = + dev.restate.sdk.HandlerRunner.of( + (ThrowingBiFunction) + (ctx, in) -> + RestateThreadLocalContext.wrap( + ctx, + () -> { + try { + if (parameterCount == 0) { + return method.invoke(serviceInstance); + } else { + return method.invoke(serviceInstance, in); + } + } catch (InvocationTargetException e) { + throw e.getCause(); + } + }), + serdeFactory, + overrideHandlerOptions); + + //noinspection unchecked + return HandlerDefinition.of( + handlerName, + handlerType, + (Serde) serdeFactory.create(RestateUtils.typeTag(inputType)), + (Serde) serdeFactory.create(RestateUtils.typeTag(outputType)), + runner); + } + + private SerdeFactory resolveSerdeFactory(Class serviceClazz) { + // Check for CustomSerdeFactory annotation + CustomSerdeFactory customSerdeFactoryAnnotation = + ReflectionUtils.findAnnotation(serviceClazz, CustomSerdeFactory.class); + + if (customSerdeFactoryAnnotation != null) { + try { + return customSerdeFactoryAnnotation.value().getDeclaredConstructor().newInstance(); + } catch (Exception e) { + throw new MalformedRestateServiceException( + serviceClazz.getSimpleName(), + "Failed to instantiate custom SerdeFactory: " + + customSerdeFactoryAnnotation.value().getName(), + e); + } + } + + // Try DefaultSerdeFactoryProvider -> if there's one, it's an easy pick! + if (this.cachedDefaultSerdeFactory != null) { + return this.cachedDefaultSerdeFactory; + } + + var loadedFactories = ServiceLoader.load(DefaultSerdeFactoryProvider.class).stream().toList(); + if (loadedFactories.size() == 1) { + this.cachedDefaultSerdeFactory = loadedFactories.get(0).get().create(); + return this.cachedDefaultSerdeFactory; + } + + // Load Jackson serde factory + try { + Class jacksonSerdeFactoryClass = + Class.forName("dev.restate.serde.jackson.JacksonSerdeFactory"); + Object defaultInstance = jacksonSerdeFactoryClass.getField("DEFAULT").get(null); + this.cachedDefaultSerdeFactory = (SerdeFactory) defaultInstance; + return this.cachedDefaultSerdeFactory; + } catch (Exception e) { + throw new MalformedRestateServiceException( + serviceClazz.getSimpleName(), + "Failed to load JacksonSerdeFactory for Java service. " + + "Make sure sdk-serde-jackson is on the classpath.", + e); + } + } + + @Override + public boolean supports(Object serviceObject) { + return true; + } + + @Override + public int priority() { + // Run last - after code-generated factories + return LOWEST_PRIORITY; + } +} diff --git a/sdk-api/src/main/java/dev/restate/sdk/Restate.java b/sdk-api/src/main/java/dev/restate/sdk/Restate.java new file mode 100644 index 00000000..51ef165e --- /dev/null +++ b/sdk-api/src/main/java/dev/restate/sdk/Restate.java @@ -0,0 +1,181 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk; + +import static dev.restate.common.reflections.ReflectionUtils.mustHaveAnnotation; + +import dev.restate.common.Slice; +import dev.restate.common.function.ThrowingRunnable; +import dev.restate.common.function.ThrowingSupplier; +import dev.restate.sdk.annotation.Service; +import dev.restate.sdk.annotation.VirtualObject; +import dev.restate.sdk.annotation.Workflow; +import dev.restate.sdk.common.HandlerRequest; +import dev.restate.sdk.common.RetryPolicy; +import dev.restate.sdk.common.TerminalException; +import dev.restate.serde.TypeTag; +import java.time.Duration; + +@org.jetbrains.annotations.ApiStatus.Experimental +public final class Restate { + @org.jetbrains.annotations.ApiStatus.Experimental + public static Context get() { + return RestateThreadLocalContext.getContext(); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + public static boolean isInsideHandler() { + return RestateThreadLocalContext.CONTEXT_THREAD_LOCAL.get() != null; + } + + @org.jetbrains.annotations.ApiStatus.Experimental + public static HandlerRequest request() { + return get().request(); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + public static RestateRandom random() { + return get().random(); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + public static InvocationHandle invocationHandle( + String invocationId, TypeTag responseTypeTag) { + return get().invocationHandle(invocationId, responseTypeTag); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + public static InvocationHandle invocationHandle( + String invocationId, Class responseClazz) { + return get().invocationHandle(invocationId, responseClazz); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + public static InvocationHandle invocationHandle(String invocationId) { + return get().invocationHandle(invocationId); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + public static void sleep(Duration duration) { + get().sleep(duration); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + public static DurableFuture timer(String name, Duration duration) { + return get().timer(name, duration); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + public static T run(String name, Class clazz, ThrowingSupplier action) + throws TerminalException { + return get().run(name, clazz, action); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + public static T run( + String name, TypeTag typeTag, RetryPolicy retryPolicy, ThrowingSupplier action) + throws TerminalException { + return get().run(name, typeTag, retryPolicy, action); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + public static T run( + String name, Class clazz, RetryPolicy retryPolicy, ThrowingSupplier action) + throws TerminalException { + return get().run(name, clazz, retryPolicy, action); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + public static T run(String name, TypeTag typeTag, ThrowingSupplier action) + throws TerminalException { + return get().run(name, typeTag, action); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + public static void run(String name, RetryPolicy retryPolicy, ThrowingRunnable runnable) + throws TerminalException { + get().run(name, retryPolicy, runnable); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + public static void run(String name, ThrowingRunnable runnable) throws TerminalException { + get().run(name, runnable); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + public static DurableFuture runAsync( + String name, Class clazz, ThrowingSupplier action) throws TerminalException { + return get().runAsync(name, clazz, action); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + public static DurableFuture runAsync( + String name, TypeTag typeTag, ThrowingSupplier action) throws TerminalException { + return get().runAsync(name, typeTag, action); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + public static DurableFuture runAsync( + String name, Class clazz, RetryPolicy retryPolicy, ThrowingSupplier action) + throws TerminalException { + return get().runAsync(name, clazz, retryPolicy, action); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + public static DurableFuture runAsync( + String name, TypeTag typeTag, RetryPolicy retryPolicy, ThrowingSupplier action) + throws TerminalException { + return get().runAsync(name, typeTag, retryPolicy, action); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + public static DurableFuture runAsync( + String name, RetryPolicy retryPolicy, ThrowingRunnable runnable) throws TerminalException { + return get().runAsync(name, retryPolicy, runnable); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + public static DurableFuture runAsync(String name, ThrowingRunnable runnable) + throws TerminalException { + return get().runAsync(name, runnable); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + public static Awakeable awakeable(Class clazz) { + return get().awakeable(clazz); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + public static Awakeable awakeable(TypeTag typeTag) { + return get().awakeable(typeTag); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + public static AwakeableHandle awakeableHandle(String id) { + return get().awakeableHandle(id); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + public static ServiceReference service(Class clazz) { + mustHaveAnnotation(clazz, Service.class); + return new ServiceReferenceImpl<>(clazz, null); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + public static ServiceReference virtualObject(Class clazz, String key) { + mustHaveAnnotation(clazz, VirtualObject.class); + return new ServiceReferenceImpl<>(clazz, key); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + public static ServiceReference workflow(Class clazz, String key) { + mustHaveAnnotation(clazz, Workflow.class); + return new ServiceReferenceImpl<>(clazz, key); + } +} diff --git a/sdk-api/src/main/java/dev/restate/sdk/RestateThreadLocalContext.java b/sdk-api/src/main/java/dev/restate/sdk/RestateThreadLocalContext.java new file mode 100644 index 00000000..49fd2b92 --- /dev/null +++ b/sdk-api/src/main/java/dev/restate/sdk/RestateThreadLocalContext.java @@ -0,0 +1,40 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk; + +import dev.restate.common.function.ThrowingSupplier; +import java.util.Objects; + +final class RestateThreadLocalContext { + + static final ThreadLocal CONTEXT_THREAD_LOCAL = new ThreadLocal<>(); + + static Context getContext() { + return Objects.requireNonNull( + CONTEXT_THREAD_LOCAL.get(), + "Restate methods must be invoked from within a Restate handler"); + } + + static T wrap(Context context, ThrowingSupplier runnable) throws Throwable { + setContext(context); + try { + return runnable.get(); + } finally { + clearContext(); + } + } + + static void setContext(Context context) { + CONTEXT_THREAD_LOCAL.set(context); + } + + static void clearContext() { + CONTEXT_THREAD_LOCAL.remove(); + } +} diff --git a/sdk-api/src/main/java/dev/restate/sdk/ServiceReference.java b/sdk-api/src/main/java/dev/restate/sdk/ServiceReference.java new file mode 100644 index 00000000..71ba1aa8 --- /dev/null +++ b/sdk-api/src/main/java/dev/restate/sdk/ServiceReference.java @@ -0,0 +1,198 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk; + +import dev.restate.common.InvocationOptions; +import java.time.Duration; +import java.util.function.BiConsumer; +import java.util.function.BiFunction; +import java.util.function.Consumer; +import java.util.function.Function; + +@org.jetbrains.annotations.ApiStatus.Experimental +public interface ServiceReference { + @org.jetbrains.annotations.ApiStatus.Experimental + SVC client(); + + @org.jetbrains.annotations.ApiStatus.Experimental + default DurableFuture call(BiFunction s, I input) { + return call(s, input, InvocationOptions.DEFAULT); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default DurableFuture call( + BiFunction s, I input, InvocationOptions.Builder options) { + return call(s, input, options.build()); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + DurableFuture call(BiFunction s, I input, InvocationOptions options); + + @org.jetbrains.annotations.ApiStatus.Experimental + default DurableFuture call(BiConsumer s, I input) { + return call(s, input, InvocationOptions.DEFAULT); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default DurableFuture call( + BiConsumer s, I input, InvocationOptions.Builder options) { + return call(s, input, options.build()); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + DurableFuture call(BiConsumer s, I input, InvocationOptions options); + + @org.jetbrains.annotations.ApiStatus.Experimental + default DurableFuture call(Function s) { + return call(s, InvocationOptions.DEFAULT); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default DurableFuture call(Function s, InvocationOptions.Builder options) { + return call(s, options.build()); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + DurableFuture call(Function s, InvocationOptions options); + + @org.jetbrains.annotations.ApiStatus.Experimental + default DurableFuture call(Consumer s) { + return call(s, InvocationOptions.DEFAULT); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default DurableFuture call(Consumer s, InvocationOptions.Builder options) { + return call(s, options.build()); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + DurableFuture call(Consumer s, InvocationOptions options); + + @org.jetbrains.annotations.ApiStatus.Experimental + default InvocationHandle send(BiFunction s, I input) { + return send(s, input, InvocationOptions.DEFAULT); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default InvocationHandle send( + BiFunction s, I input, InvocationOptions.Builder options) { + return send(s, input, options.build()); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default InvocationHandle send( + BiFunction s, I input, InvocationOptions options) { + return send(s, input, null, options); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default InvocationHandle send(BiFunction s, I input, Duration delay) { + return send(s, input, delay, InvocationOptions.DEFAULT); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default InvocationHandle send( + BiFunction s, I input, Duration delay, InvocationOptions.Builder options) { + return send(s, input, delay, options.build()); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + InvocationHandle send( + BiFunction s, I input, Duration delay, InvocationOptions options); + + @org.jetbrains.annotations.ApiStatus.Experimental + default InvocationHandle send(BiConsumer s, I input) { + return send(s, input, InvocationOptions.DEFAULT); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default InvocationHandle send( + BiConsumer s, I input, InvocationOptions.Builder options) { + return send(s, input, options.build()); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default InvocationHandle send( + BiConsumer s, I input, InvocationOptions options) { + return send(s, input, null, options); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default InvocationHandle send(BiConsumer s, I input, Duration delay) { + return send(s, input, delay, InvocationOptions.DEFAULT); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default InvocationHandle send( + BiConsumer s, I input, Duration delay, InvocationOptions.Builder options) { + return send(s, input, delay, options.build()); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + InvocationHandle send( + BiConsumer s, I input, Duration delay, InvocationOptions options); + + @org.jetbrains.annotations.ApiStatus.Experimental + default InvocationHandle send(Function s) { + return send(s, InvocationOptions.DEFAULT); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default InvocationHandle send(Function s, InvocationOptions.Builder options) { + return send(s, options.build()); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default InvocationHandle send(Function s, InvocationOptions options) { + return send(s, null, options); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default InvocationHandle send(Function s, Duration delay) { + return send(s, delay, InvocationOptions.DEFAULT); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default InvocationHandle send( + Function s, Duration delay, InvocationOptions.Builder options) { + return send(s, delay, options.build()); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + InvocationHandle send(Function s, Duration delay, InvocationOptions options); + + @org.jetbrains.annotations.ApiStatus.Experimental + default InvocationHandle send(Consumer s) { + return send(s, InvocationOptions.DEFAULT); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default InvocationHandle send(Consumer s, InvocationOptions.Builder options) { + return send(s, options.build()); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default InvocationHandle send(Consumer s, InvocationOptions options) { + return send(s, null, options); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default InvocationHandle send(Consumer s, Duration delay) { + return send(s, delay, InvocationOptions.DEFAULT); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + default InvocationHandle send( + Consumer s, Duration delay, InvocationOptions.Builder options) { + return send(s, delay, options.build()); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + InvocationHandle send(Consumer s, Duration delay, InvocationOptions options); +} diff --git a/sdk-api/src/main/java/dev/restate/sdk/ServiceReferenceImpl.java b/sdk-api/src/main/java/dev/restate/sdk/ServiceReferenceImpl.java new file mode 100644 index 00000000..4ce4c0a3 --- /dev/null +++ b/sdk-api/src/main/java/dev/restate/sdk/ServiceReferenceImpl.java @@ -0,0 +1,210 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk; + +import static dev.restate.common.reflections.RestateUtils.toRequest; + +import dev.restate.common.InvocationOptions; +import dev.restate.common.Request; +import dev.restate.common.Target; +import dev.restate.common.reflections.*; +import dev.restate.serde.Serde; +import dev.restate.serde.TypeTag; +import java.time.Duration; +import java.util.function.BiConsumer; +import java.util.function.BiFunction; +import java.util.function.Consumer; +import java.util.function.Function; +import org.jspecify.annotations.Nullable; + +final class ServiceReferenceImpl implements ServiceReference { + + private final Class clazz; + private final String serviceName; + private final @Nullable String key; + + // The simple proxy for users + private SVC proxyClient; + + // To use call/send + private MethodInfoCollector methodInfoCollector; + + ServiceReferenceImpl(Class clazz, @Nullable String key) { + this.clazz = clazz; + this.serviceName = ReflectionUtils.extractServiceName(clazz); + this.key = key; + } + + @Override + public SVC client() { + if (proxyClient == null) { + this.proxyClient = + ProxySupport.createProxy( + clazz, + invocation -> { + var methodInfo = MethodInfo.fromMethod(invocation.getMethod()); + + //noinspection unchecked + return Restate.get() + .call( + Request.of( + Target.virtualObject(serviceName, key, methodInfo.getHandlerName()), + (TypeTag) + RestateUtils.typeTag(methodInfo.getInputType()), + (TypeTag) + RestateUtils.typeTag(methodInfo.getOutputType()), + invocation.getArguments().length == 0 + ? null + : invocation.getArguments()[0])) + .await(); + }); + } + return this.proxyClient; + } + + @SuppressWarnings("unchecked") + @Override + public DurableFuture call(BiFunction s, I input, InvocationOptions options) { + MethodInfo methodInfo = getMethodInfoCollector().resolve(s, input); + return Restate.get() + .call( + toRequest( + serviceName, + key, + methodInfo.getHandlerName(), + (TypeTag) RestateUtils.typeTag(methodInfo.getInputType()), + (TypeTag) RestateUtils.typeTag(methodInfo.getOutputType()), + input, + options)); + } + + @SuppressWarnings("unchecked") + @Override + public DurableFuture call(BiConsumer s, I input, InvocationOptions options) { + MethodInfo methodInfo = getMethodInfoCollector().resolve(s, input); + return Restate.get() + .call( + toRequest( + serviceName, + key, + methodInfo.getHandlerName(), + (TypeTag) RestateUtils.typeTag(methodInfo.getInputType()), + Serde.VOID, + input, + options)); + } + + @SuppressWarnings("unchecked") + @Override + public DurableFuture call(Function s, InvocationOptions options) { + MethodInfo methodInfo = getMethodInfoCollector().resolve(s); + return Restate.get() + .call( + toRequest( + serviceName, + key, + methodInfo.getHandlerName(), + Serde.VOID, + (TypeTag) RestateUtils.typeTag(methodInfo.getOutputType()), + null, + options)); + } + + @Override + public DurableFuture call(Consumer s, InvocationOptions options) { + MethodInfo methodInfo = getMethodInfoCollector().resolve(s); + return Restate.get() + .call( + toRequest( + serviceName, + key, + methodInfo.getHandlerName(), + Serde.VOID, + Serde.VOID, + null, + options)); + } + + @SuppressWarnings("unchecked") + @Override + public InvocationHandle send( + BiFunction s, I input, Duration delay, InvocationOptions options) { + MethodInfo methodInfo = getMethodInfoCollector().resolve(s, input); + return Restate.get() + .send( + toRequest( + serviceName, + key, + methodInfo.getHandlerName(), + (TypeTag) RestateUtils.typeTag(methodInfo.getInputType()), + (TypeTag) RestateUtils.typeTag(methodInfo.getOutputType()), + input, + options), + delay); + } + + @SuppressWarnings("unchecked") + @Override + public InvocationHandle send( + BiConsumer s, I input, Duration delay, InvocationOptions options) { + MethodInfo methodInfo = getMethodInfoCollector().resolve(s, input); + return Restate.get() + .send( + toRequest( + serviceName, + key, + methodInfo.getHandlerName(), + (TypeTag) RestateUtils.typeTag(methodInfo.getInputType()), + Serde.VOID, + input, + options), + delay); + } + + @SuppressWarnings("unchecked") + @Override + public InvocationHandle send( + Function s, Duration delay, InvocationOptions options) { + MethodInfo methodInfo = getMethodInfoCollector().resolve(s); + return Restate.get() + .send( + toRequest( + serviceName, + key, + methodInfo.getHandlerName(), + Serde.VOID, + (TypeTag) RestateUtils.typeTag(methodInfo.getOutputType()), + null, + options), + delay); + } + + @Override + public InvocationHandle send(Consumer s, Duration delay, InvocationOptions options) { + MethodInfo methodInfo = getMethodInfoCollector().resolve(s); + return Restate.get() + .send( + toRequest( + serviceName, + key, + methodInfo.getHandlerName(), + Serde.VOID, + Serde.VOID, + null, + options), + delay); + } + + private MethodInfoCollector getMethodInfoCollector() { + if (this.methodInfoCollector == null) { + this.methodInfoCollector = new MethodInfoCollector<>(this.clazz); + } + return this.methodInfoCollector; + } +} diff --git a/sdk-api/src/main/resources/META-INF/services/dev.restate.sdk.endpoint.definition.ServiceDefinitionFactory b/sdk-api/src/main/resources/META-INF/services/dev.restate.sdk.endpoint.definition.ServiceDefinitionFactory new file mode 100644 index 00000000..40c361d0 --- /dev/null +++ b/sdk-api/src/main/resources/META-INF/services/dev.restate.sdk.endpoint.definition.ServiceDefinitionFactory @@ -0,0 +1 @@ +dev.restate.sdk.ReflectionServiceDefinitionFactory \ No newline at end of file diff --git a/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/ServiceDefinitionFactories.java b/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/ServiceDefinitionFactories.java index 19b27053..b2466531 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/ServiceDefinitionFactories.java +++ b/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/ServiceDefinitionFactories.java @@ -40,6 +40,8 @@ public ServiceDefinitionFactories() { e); } } + + this.factories.sort(Comparator.comparingInt(ServiceDefinitionFactory::priority)); } /** Resolve the code generated {@link ServiceDefinitionFactory} */ diff --git a/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/ServiceDefinitionFactory.java b/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/ServiceDefinitionFactory.java index fe457eda..55141c5a 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/ServiceDefinitionFactory.java +++ b/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/ServiceDefinitionFactory.java @@ -15,4 +15,20 @@ public interface ServiceDefinitionFactory { ServiceDefinition create(T serviceObject, HandlerRunner.@Nullable Options overrideHandlerOptions); boolean supports(Object serviceObject); + + /** + * Get the priority of this factory. Lower values are tried first. The default priority is + * HIGHEST_PRIORITY. + * + *

Code-generated factories should use the default priority so they are tried first. + * + * @return the priority value + */ + default int priority() { + return HIGHEST_PRIORITY; + } + + int HIGHEST_PRIORITY = Integer.MIN_VALUE; + + int LOWEST_PRIORITY = Integer.MAX_VALUE; } diff --git a/sdk-serde-jackson/src/main/java/dev/restate/serde/jackson/JacksonSerdeFactoryProvider.java b/sdk-serde-jackson/src/main/java/dev/restate/serde/jackson/JacksonSerdeFactoryProvider.java index 8ccba50e..d3653033 100644 --- a/sdk-serde-jackson/src/main/java/dev/restate/serde/jackson/JacksonSerdeFactoryProvider.java +++ b/sdk-serde-jackson/src/main/java/dev/restate/serde/jackson/JacksonSerdeFactoryProvider.java @@ -14,6 +14,6 @@ public class JacksonSerdeFactoryProvider implements DefaultSerdeFactoryProvider { @Override public SerdeFactory create() { - return new JacksonSerdeFactory(); + return JacksonSerdeFactory.DEFAULT; } } diff --git a/settings.gradle.kts b/settings.gradle.kts index 23a1c617..7513cfa0 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -13,6 +13,7 @@ plugins { id("org.gradle.toolchains.foojay-resolver-convention") version "0.9.0" include( "admin-client", + "bytebuddy-proxy-support", "common", "client", "client-kotlin", From 29d83059d7bfdb8233c3d62ed68eaa3f32768615 Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Fri, 9 Jan 2026 17:03:37 +0100 Subject: [PATCH 2/7] Working solution + testing this in the simple exmamples here --- examples/build.gradle.kts | 1 - .../java/my/restate/sdk/examples/Counter.java | 17 +++++-- .../java/my/restate/sdk/examples/Greeter.java | 3 +- .../my/restate/sdk/examples/LoanWorkflow.java | 49 ++++++++++-------- .../java/dev/restate/sdk/HandlerRunner.java | 30 ++++++++++- .../main/java/dev/restate/sdk/Restate.java | 50 +++++++++---------- .../sdk/RestateThreadLocalContext.java | 40 --------------- .../dev/restate/sdk/ServiceReferenceImpl.java | 18 +++---- .../ReflectionServiceDefinitionFactory.java | 39 ++++++++------- ...dpoint.definition.ServiceDefinitionFactory | 2 +- 10 files changed, 126 insertions(+), 123 deletions(-) delete mode 100644 sdk-api/src/main/java/dev/restate/sdk/RestateThreadLocalContext.java rename sdk-api/src/main/java/dev/restate/sdk/{ => internal}/ReflectionServiceDefinitionFactory.java (90%) diff --git a/examples/build.gradle.kts b/examples/build.gradle.kts index f430faf0..86530b6f 100644 --- a/examples/build.gradle.kts +++ b/examples/build.gradle.kts @@ -10,7 +10,6 @@ plugins { dependencies { ksp(project(":sdk-api-kotlin-gen")) - annotationProcessor(project(":sdk-api-gen")) implementation(project(":client")) implementation(project(":client-kotlin")) diff --git a/examples/src/main/java/my/restate/sdk/examples/Counter.java b/examples/src/main/java/my/restate/sdk/examples/Counter.java index 6bba0933..c9cf6d65 100644 --- a/examples/src/main/java/my/restate/sdk/examples/Counter.java +++ b/examples/src/main/java/my/restate/sdk/examples/Counter.java @@ -8,7 +8,9 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package my.restate.sdk.examples; -import dev.restate.sdk.*; +import dev.restate.sdk.ObjectContext; +import dev.restate.sdk.Restate; +import dev.restate.sdk.SharedObjectContext; import dev.restate.sdk.annotation.Handler; import dev.restate.sdk.annotation.Name; import dev.restate.sdk.annotation.Shared; @@ -22,7 +24,7 @@ /** Counter virtual object */ @VirtualObject -@Name("Counter") +@Name("BroCounter") public class Counter { private static final Logger LOG = LogManager.getLogger(Counter.class); @@ -37,7 +39,9 @@ public void reset(ObjectContext ctx) { /** Add the given value to the count. */ @Handler - public void add(ObjectContext ctx, long request) { + public void add(long request) { + var ctx = (ObjectContext) Restate.context(); + long currentValue = ctx.get(TOTAL).orElse(0L); long newValue = currentValue + request; ctx.sleep(Duration.ofSeconds(120)); @@ -47,13 +51,16 @@ public void add(ObjectContext ctx, long request) { /** Get the current counter value. */ @Shared @Handler - public long get(SharedObjectContext ctx) { + public long get() { + var ctx = (SharedObjectContext) Restate.context(); return ctx.get(TOTAL).orElse(0L); } /** Add a value, and get both the previous value and the new value. */ @Handler - public CounterUpdateResult getAndAdd(ObjectContext ctx, long request) { + public CounterUpdateResult getAndAdd(long request) { + var ctx = (ObjectContext) Restate.context(); + LOG.info("Invoked get and add with {}", request); long currentValue = ctx.get(TOTAL).orElse(0L); diff --git a/examples/src/main/java/my/restate/sdk/examples/Greeter.java b/examples/src/main/java/my/restate/sdk/examples/Greeter.java index 8bd6fe11..3e599623 100644 --- a/examples/src/main/java/my/restate/sdk/examples/Greeter.java +++ b/examples/src/main/java/my/restate/sdk/examples/Greeter.java @@ -8,7 +8,6 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package my.restate.sdk.examples; -import dev.restate.sdk.Context; import dev.restate.sdk.annotation.Handler; import dev.restate.sdk.annotation.Service; import dev.restate.sdk.endpoint.Endpoint; @@ -22,7 +21,7 @@ public record Greeting(String name) {} public record GreetingResponse(String message) {} @Handler - public GreetingResponse greet(Context ctx, Greeting req) { + public GreetingResponse greet(Greeting req) { // Respond to caller return new GreetingResponse("You said hi to " + req.name + "!"); } diff --git a/examples/src/main/java/my/restate/sdk/examples/LoanWorkflow.java b/examples/src/main/java/my/restate/sdk/examples/LoanWorkflow.java index 8529e276..f1a6012d 100644 --- a/examples/src/main/java/my/restate/sdk/examples/LoanWorkflow.java +++ b/examples/src/main/java/my/restate/sdk/examples/LoanWorkflow.java @@ -8,9 +8,9 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package my.restate.sdk.examples; -import dev.restate.sdk.Context; -import dev.restate.sdk.SharedWorkflowContext; -import dev.restate.sdk.WorkflowContext; +import dev.restate.client.Client; +import dev.restate.client.ClientServiceReference; +import dev.restate.sdk.*; import dev.restate.sdk.annotation.Handler; import dev.restate.sdk.annotation.Service; import dev.restate.sdk.annotation.Shared; @@ -57,7 +57,9 @@ public record LoanRequest( // --- The main workflow method @Workflow - public String run(WorkflowContext ctx, LoanRequest loanRequest) { + public String run(LoanRequest loanRequest) { + var ctx = (WorkflowContext) Restate.context(); + // 1. Set status ctx.set(STATUS, Status.SUBMITTED); ctx.set(LOAN_REQUEST, loanRequest); @@ -79,12 +81,12 @@ public String run(WorkflowContext ctx, LoanRequest loanRequest) { ctx.set(STATUS, Status.APPROVED); // 4. Request money transaction to the bank - var bankClient = LoanWorkflowMockBankClient.fromContext(ctx); Instant executionTime; try { executionTime = - bankClient - .transfer( + Restate.service(MockBank.class) + .call( + MockBank::transfer, new TransferRequest(loanRequest.customerBankAccount(), loanRequest.amount())) .await(Duration.ofDays(7)); } catch (TerminalException e) { @@ -105,18 +107,24 @@ public String run(WorkflowContext ctx, LoanRequest loanRequest) { // --- Methods to approve/reject loan @Shared - public String approveLoan(SharedWorkflowContext ctx) { + public String approveLoan() { + var ctx = (SharedWorkflowContext) Restate.context(); + ctx.promiseHandle(HUMAN_APPROVAL).resolve(true); return "Approved"; } @Shared - public void rejectLoan(SharedWorkflowContext ctx) { + public void rejectLoan() { + var ctx = (SharedWorkflowContext) Restate.context(); + ctx.promiseHandle(HUMAN_APPROVAL).resolve(false); } @Shared - public Status getStatus(SharedWorkflowContext ctx) { + public Status getStatus() { + var ctx = (SharedWorkflowContext) Restate.context(); + return ctx.get(STATUS).orElse(Status.UNKNOWN); } @@ -139,11 +147,12 @@ public static void main(String[] args) { } // To invoke the workflow: - LoanWorkflowClient.IngressClient client = - LoanWorkflowClient.connect("http://127.0.0.1:8080", "my-loan"); - - var state = - client.submit( + Client restateClient = Client.connect("http://127.0.0.1:8080"); + ClientServiceReference loanWorkflow = + restateClient.workflow(LoanWorkflow.class, "my-loan"); + var handle = + loanWorkflow.send( + LoanWorkflow::run, new LoanRequest( "Francesco", "slinkydeveloper", "DE1234", new BigDecimal("1000000000"))); @@ -159,12 +168,12 @@ public static void main(String[] args) { LOG.info("We took the decision to approve your loan! You can now achieve your dreams!"); // Now approve it - client.approveLoan(); + loanWorkflow.client().approveLoan(); // Wait for output - client.workflowHandle().attach(); + handle.attach(); - LOG.info("Loan workflow completed, now in status {}", client.getStatus()); + LOG.info("Loan workflow completed, now in status {}", loanWorkflow.client().getStatus()); } // -- Some mocks @@ -177,8 +186,8 @@ private static void askHumanApproval(String workflowKey) throws InterruptedExcep @Service static class MockBank { @Handler - public Instant transfer(Context context, TransferRequest request) throws TerminalException { - boolean shouldAccept = context.random().nextInt(3) != 1; + public Instant transfer(TransferRequest request) throws TerminalException { + boolean shouldAccept = Restate.random().nextInt(3) != 1; if (shouldAccept) { return Instant.now(); } else { diff --git a/sdk-api/src/main/java/dev/restate/sdk/HandlerRunner.java b/sdk-api/src/main/java/dev/restate/sdk/HandlerRunner.java index f111c1d5..f49b4d9d 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/HandlerRunner.java +++ b/sdk-api/src/main/java/dev/restate/sdk/HandlerRunner.java @@ -18,6 +18,7 @@ import dev.restate.serde.Serde; import dev.restate.serde.SerdeFactory; import io.opentelemetry.context.Scope; +import java.util.Objects; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; import java.util.concurrent.Executors; @@ -87,11 +88,20 @@ public CompletableFuture run( } // Execute user code - RES res; + RES res = null; + Throwable error = null; try { + setContext(ctx); res = this.runner.apply(ctx, req); } catch (Throwable e) { - returnFuture.completeExceptionally(e); + error = e; + } finally { + clearContext(); + } + + // If error, just return now + if (error != null) { + returnFuture.completeExceptionally(error); return; } @@ -191,4 +201,20 @@ public static Options withExecutor(Executor executor) { return new Options(executor); } } + + static final ThreadLocal CONTEXT_THREAD_LOCAL = new ThreadLocal<>(); + + static Context getContext() { + return Objects.requireNonNull( + CONTEXT_THREAD_LOCAL.get(), + "Restate methods must be invoked from within a Restate handler"); + } + + static void setContext(Context context) { + CONTEXT_THREAD_LOCAL.set(context); + } + + static void clearContext() { + CONTEXT_THREAD_LOCAL.remove(); + } } diff --git a/sdk-api/src/main/java/dev/restate/sdk/Restate.java b/sdk-api/src/main/java/dev/restate/sdk/Restate.java index 51ef165e..19d2590a 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/Restate.java +++ b/sdk-api/src/main/java/dev/restate/sdk/Restate.java @@ -25,140 +25,140 @@ @org.jetbrains.annotations.ApiStatus.Experimental public final class Restate { @org.jetbrains.annotations.ApiStatus.Experimental - public static Context get() { - return RestateThreadLocalContext.getContext(); + public static Context context() { + return HandlerRunner.getContext(); } @org.jetbrains.annotations.ApiStatus.Experimental public static boolean isInsideHandler() { - return RestateThreadLocalContext.CONTEXT_THREAD_LOCAL.get() != null; + return HandlerRunner.CONTEXT_THREAD_LOCAL.get() != null; } @org.jetbrains.annotations.ApiStatus.Experimental public static HandlerRequest request() { - return get().request(); + return context().request(); } @org.jetbrains.annotations.ApiStatus.Experimental public static RestateRandom random() { - return get().random(); + return context().random(); } @org.jetbrains.annotations.ApiStatus.Experimental public static InvocationHandle invocationHandle( String invocationId, TypeTag responseTypeTag) { - return get().invocationHandle(invocationId, responseTypeTag); + return context().invocationHandle(invocationId, responseTypeTag); } @org.jetbrains.annotations.ApiStatus.Experimental public static InvocationHandle invocationHandle( String invocationId, Class responseClazz) { - return get().invocationHandle(invocationId, responseClazz); + return context().invocationHandle(invocationId, responseClazz); } @org.jetbrains.annotations.ApiStatus.Experimental public static InvocationHandle invocationHandle(String invocationId) { - return get().invocationHandle(invocationId); + return context().invocationHandle(invocationId); } @org.jetbrains.annotations.ApiStatus.Experimental public static void sleep(Duration duration) { - get().sleep(duration); + context().sleep(duration); } @org.jetbrains.annotations.ApiStatus.Experimental public static DurableFuture timer(String name, Duration duration) { - return get().timer(name, duration); + return context().timer(name, duration); } @org.jetbrains.annotations.ApiStatus.Experimental public static T run(String name, Class clazz, ThrowingSupplier action) throws TerminalException { - return get().run(name, clazz, action); + return context().run(name, clazz, action); } @org.jetbrains.annotations.ApiStatus.Experimental public static T run( String name, TypeTag typeTag, RetryPolicy retryPolicy, ThrowingSupplier action) throws TerminalException { - return get().run(name, typeTag, retryPolicy, action); + return context().run(name, typeTag, retryPolicy, action); } @org.jetbrains.annotations.ApiStatus.Experimental public static T run( String name, Class clazz, RetryPolicy retryPolicy, ThrowingSupplier action) throws TerminalException { - return get().run(name, clazz, retryPolicy, action); + return context().run(name, clazz, retryPolicy, action); } @org.jetbrains.annotations.ApiStatus.Experimental public static T run(String name, TypeTag typeTag, ThrowingSupplier action) throws TerminalException { - return get().run(name, typeTag, action); + return context().run(name, typeTag, action); } @org.jetbrains.annotations.ApiStatus.Experimental public static void run(String name, RetryPolicy retryPolicy, ThrowingRunnable runnable) throws TerminalException { - get().run(name, retryPolicy, runnable); + context().run(name, retryPolicy, runnable); } @org.jetbrains.annotations.ApiStatus.Experimental public static void run(String name, ThrowingRunnable runnable) throws TerminalException { - get().run(name, runnable); + context().run(name, runnable); } @org.jetbrains.annotations.ApiStatus.Experimental public static DurableFuture runAsync( String name, Class clazz, ThrowingSupplier action) throws TerminalException { - return get().runAsync(name, clazz, action); + return context().runAsync(name, clazz, action); } @org.jetbrains.annotations.ApiStatus.Experimental public static DurableFuture runAsync( String name, TypeTag typeTag, ThrowingSupplier action) throws TerminalException { - return get().runAsync(name, typeTag, action); + return context().runAsync(name, typeTag, action); } @org.jetbrains.annotations.ApiStatus.Experimental public static DurableFuture runAsync( String name, Class clazz, RetryPolicy retryPolicy, ThrowingSupplier action) throws TerminalException { - return get().runAsync(name, clazz, retryPolicy, action); + return context().runAsync(name, clazz, retryPolicy, action); } @org.jetbrains.annotations.ApiStatus.Experimental public static DurableFuture runAsync( String name, TypeTag typeTag, RetryPolicy retryPolicy, ThrowingSupplier action) throws TerminalException { - return get().runAsync(name, typeTag, retryPolicy, action); + return context().runAsync(name, typeTag, retryPolicy, action); } @org.jetbrains.annotations.ApiStatus.Experimental public static DurableFuture runAsync( String name, RetryPolicy retryPolicy, ThrowingRunnable runnable) throws TerminalException { - return get().runAsync(name, retryPolicy, runnable); + return context().runAsync(name, retryPolicy, runnable); } @org.jetbrains.annotations.ApiStatus.Experimental public static DurableFuture runAsync(String name, ThrowingRunnable runnable) throws TerminalException { - return get().runAsync(name, runnable); + return context().runAsync(name, runnable); } @org.jetbrains.annotations.ApiStatus.Experimental public static Awakeable awakeable(Class clazz) { - return get().awakeable(clazz); + return context().awakeable(clazz); } @org.jetbrains.annotations.ApiStatus.Experimental public static Awakeable awakeable(TypeTag typeTag) { - return get().awakeable(typeTag); + return context().awakeable(typeTag); } @org.jetbrains.annotations.ApiStatus.Experimental public static AwakeableHandle awakeableHandle(String id) { - return get().awakeableHandle(id); + return context().awakeableHandle(id); } @org.jetbrains.annotations.ApiStatus.Experimental diff --git a/sdk-api/src/main/java/dev/restate/sdk/RestateThreadLocalContext.java b/sdk-api/src/main/java/dev/restate/sdk/RestateThreadLocalContext.java deleted file mode 100644 index 49fd2b92..00000000 --- a/sdk-api/src/main/java/dev/restate/sdk/RestateThreadLocalContext.java +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk; - -import dev.restate.common.function.ThrowingSupplier; -import java.util.Objects; - -final class RestateThreadLocalContext { - - static final ThreadLocal CONTEXT_THREAD_LOCAL = new ThreadLocal<>(); - - static Context getContext() { - return Objects.requireNonNull( - CONTEXT_THREAD_LOCAL.get(), - "Restate methods must be invoked from within a Restate handler"); - } - - static T wrap(Context context, ThrowingSupplier runnable) throws Throwable { - setContext(context); - try { - return runnable.get(); - } finally { - clearContext(); - } - } - - static void setContext(Context context) { - CONTEXT_THREAD_LOCAL.set(context); - } - - static void clearContext() { - CONTEXT_THREAD_LOCAL.remove(); - } -} diff --git a/sdk-api/src/main/java/dev/restate/sdk/ServiceReferenceImpl.java b/sdk-api/src/main/java/dev/restate/sdk/ServiceReferenceImpl.java index 4ce4c0a3..bada53cc 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/ServiceReferenceImpl.java +++ b/sdk-api/src/main/java/dev/restate/sdk/ServiceReferenceImpl.java @@ -51,7 +51,7 @@ public SVC client() { var methodInfo = MethodInfo.fromMethod(invocation.getMethod()); //noinspection unchecked - return Restate.get() + return Restate.context() .call( Request.of( Target.virtualObject(serviceName, key, methodInfo.getHandlerName()), @@ -72,7 +72,7 @@ public SVC client() { @Override public DurableFuture call(BiFunction s, I input, InvocationOptions options) { MethodInfo methodInfo = getMethodInfoCollector().resolve(s, input); - return Restate.get() + return Restate.context() .call( toRequest( serviceName, @@ -88,7 +88,7 @@ public DurableFuture call(BiFunction s, I input, Invocation @Override public DurableFuture call(BiConsumer s, I input, InvocationOptions options) { MethodInfo methodInfo = getMethodInfoCollector().resolve(s, input); - return Restate.get() + return Restate.context() .call( toRequest( serviceName, @@ -104,7 +104,7 @@ public DurableFuture call(BiConsumer s, I input, InvocationOpt @Override public DurableFuture call(Function s, InvocationOptions options) { MethodInfo methodInfo = getMethodInfoCollector().resolve(s); - return Restate.get() + return Restate.context() .call( toRequest( serviceName, @@ -119,7 +119,7 @@ public DurableFuture call(Function s, InvocationOptions options) @Override public DurableFuture call(Consumer s, InvocationOptions options) { MethodInfo methodInfo = getMethodInfoCollector().resolve(s); - return Restate.get() + return Restate.context() .call( toRequest( serviceName, @@ -136,7 +136,7 @@ public DurableFuture call(Consumer s, InvocationOptions options) { public InvocationHandle send( BiFunction s, I input, Duration delay, InvocationOptions options) { MethodInfo methodInfo = getMethodInfoCollector().resolve(s, input); - return Restate.get() + return Restate.context() .send( toRequest( serviceName, @@ -154,7 +154,7 @@ public InvocationHandle send( public InvocationHandle send( BiConsumer s, I input, Duration delay, InvocationOptions options) { MethodInfo methodInfo = getMethodInfoCollector().resolve(s, input); - return Restate.get() + return Restate.context() .send( toRequest( serviceName, @@ -172,7 +172,7 @@ public InvocationHandle send( public InvocationHandle send( Function s, Duration delay, InvocationOptions options) { MethodInfo methodInfo = getMethodInfoCollector().resolve(s); - return Restate.get() + return Restate.context() .send( toRequest( serviceName, @@ -188,7 +188,7 @@ public InvocationHandle send( @Override public InvocationHandle send(Consumer s, Duration delay, InvocationOptions options) { MethodInfo methodInfo = getMethodInfoCollector().resolve(s); - return Restate.get() + return Restate.context() .send( toRequest( serviceName, diff --git a/sdk-api/src/main/java/dev/restate/sdk/ReflectionServiceDefinitionFactory.java b/sdk-api/src/main/java/dev/restate/sdk/internal/ReflectionServiceDefinitionFactory.java similarity index 90% rename from sdk-api/src/main/java/dev/restate/sdk/ReflectionServiceDefinitionFactory.java rename to sdk-api/src/main/java/dev/restate/sdk/internal/ReflectionServiceDefinitionFactory.java index 8bd1d9f5..0238aa4f 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/ReflectionServiceDefinitionFactory.java +++ b/sdk-api/src/main/java/dev/restate/sdk/internal/ReflectionServiceDefinitionFactory.java @@ -6,11 +6,14 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk; +package dev.restate.sdk.internal; import dev.restate.common.function.ThrowingBiFunction; import dev.restate.common.reflections.ReflectionUtils; import dev.restate.common.reflections.RestateUtils; +import dev.restate.sdk.Context; +import dev.restate.sdk.HandlerRunner; +import dev.restate.sdk.MalformedRestateServiceException; import dev.restate.sdk.annotation.*; import dev.restate.sdk.endpoint.definition.*; import dev.restate.serde.Serde; @@ -24,7 +27,8 @@ import org.jspecify.annotations.Nullable; @org.jetbrains.annotations.ApiStatus.Experimental -final class ReflectionServiceDefinitionFactory implements ServiceDefinitionFactory { +@org.jetbrains.annotations.ApiStatus.Internal +public final class ReflectionServiceDefinitionFactory implements ServiceDefinitionFactory { private volatile SerdeFactory cachedDefaultSerdeFactory; @@ -137,27 +141,26 @@ public ServiceDefinition create( var parameterCount = method.getParameterCount(); // TODO here we should add some code to handle handling Context in method definition. - // This is because we want to make sure people declaring the handlers with the Context in the method works + // This is because we want to make sure people declaring the handlers with the Context in the + // method works // providing a smoother path to transition from code generation - // Plus plus plus important bit -> we need to validate the input paramters can be one and only one (OBV)! + // Plus plus plus important bit -> we need to validate the input paramters can be one and only + // one (OBV)! var runner = dev.restate.sdk.HandlerRunner.of( (ThrowingBiFunction) - (ctx, in) -> - RestateThreadLocalContext.wrap( - ctx, - () -> { - try { - if (parameterCount == 0) { - return method.invoke(serviceInstance); - } else { - return method.invoke(serviceInstance, in); - } - } catch (InvocationTargetException e) { - throw e.getCause(); - } - }), + (ctx, in) -> { + try { + if (parameterCount == 0) { + return method.invoke(serviceInstance); + } else { + return method.invoke(serviceInstance, in); + } + } catch (InvocationTargetException e) { + throw e.getCause(); + } + }, serdeFactory, overrideHandlerOptions); diff --git a/sdk-api/src/main/resources/META-INF/services/dev.restate.sdk.endpoint.definition.ServiceDefinitionFactory b/sdk-api/src/main/resources/META-INF/services/dev.restate.sdk.endpoint.definition.ServiceDefinitionFactory index 40c361d0..9bb86d5a 100644 --- a/sdk-api/src/main/resources/META-INF/services/dev.restate.sdk.endpoint.definition.ServiceDefinitionFactory +++ b/sdk-api/src/main/resources/META-INF/services/dev.restate.sdk.endpoint.definition.ServiceDefinitionFactory @@ -1 +1 @@ -dev.restate.sdk.ReflectionServiceDefinitionFactory \ No newline at end of file +dev.restate.sdk.internal.ReflectionServiceDefinitionFactory \ No newline at end of file From 0abc032ef3ce3c5c413ec4f34bfe496f0846eddb Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Fri, 9 Jan 2026 17:52:03 +0100 Subject: [PATCH 3/7] Expose the other context's here --- .../java/my/restate/sdk/examples/Counter.java | 14 ++--- .../java/dev/restate/sdk/HandlerRunner.java | 10 +++- .../main/java/dev/restate/sdk/Restate.java | 56 +++++++++++++++++++ .../ReflectionServiceDefinitionFactory.java | 4 +- .../endpoint/definition/HandlerContext.java | 8 +++ .../definition/HandlerDefinition.java | 6 +- .../sdk/core/EndpointRequestHandler.java | 1 + .../ExecutorSwitchingHandlerContextImpl.java | 6 +- .../restate/sdk/core/HandlerContextImpl.java | 28 ++++++++++ .../sdk/core/RequestProcessorImpl.java | 22 +++++++- 10 files changed, 135 insertions(+), 20 deletions(-) diff --git a/examples/src/main/java/my/restate/sdk/examples/Counter.java b/examples/src/main/java/my/restate/sdk/examples/Counter.java index c9cf6d65..c6090f00 100644 --- a/examples/src/main/java/my/restate/sdk/examples/Counter.java +++ b/examples/src/main/java/my/restate/sdk/examples/Counter.java @@ -8,9 +8,7 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package my.restate.sdk.examples; -import dev.restate.sdk.ObjectContext; import dev.restate.sdk.Restate; -import dev.restate.sdk.SharedObjectContext; import dev.restate.sdk.annotation.Handler; import dev.restate.sdk.annotation.Name; import dev.restate.sdk.annotation.Shared; @@ -18,7 +16,6 @@ import dev.restate.sdk.common.StateKey; import dev.restate.sdk.endpoint.Endpoint; import dev.restate.sdk.http.vertx.RestateHttpServer; -import java.time.Duration; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -33,18 +30,17 @@ public class Counter { /** Reset the counter. */ @Handler - public void reset(ObjectContext ctx) { - ctx.clearAll(); + public void reset() { + Restate.objectContext().clearAll(); } /** Add the given value to the count. */ @Handler public void add(long request) { - var ctx = (ObjectContext) Restate.context(); + var ctx = Restate.objectContext(); long currentValue = ctx.get(TOTAL).orElse(0L); long newValue = currentValue + request; - ctx.sleep(Duration.ofSeconds(120)); ctx.set(TOTAL, newValue); } @@ -52,14 +48,14 @@ public void add(long request) { @Shared @Handler public long get() { - var ctx = (SharedObjectContext) Restate.context(); + var ctx = Restate.sharedObjectContext(); return ctx.get(TOTAL).orElse(0L); } /** Add a value, and get both the previous value and the new value. */ @Handler public CounterUpdateResult getAndAdd(long request) { - var ctx = (ObjectContext) Restate.context(); + var ctx = Restate.objectContext(); LOG.info("Invoked get and add with {}", request); diff --git a/sdk-api/src/main/java/dev/restate/sdk/HandlerRunner.java b/sdk-api/src/main/java/dev/restate/sdk/HandlerRunner.java index f49b4d9d..a9a913f5 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/HandlerRunner.java +++ b/sdk-api/src/main/java/dev/restate/sdk/HandlerRunner.java @@ -210,11 +210,17 @@ static Context getContext() { "Restate methods must be invoked from within a Restate handler"); } - static void setContext(Context context) { + static HandlerContext getHandlerContext() { + return Objects.requireNonNull( + HANDLER_CONTEXT_THREAD_LOCAL.get(), + "Restate methods must be invoked from within a Restate handler"); + } + + private static void setContext(Context context) { CONTEXT_THREAD_LOCAL.set(context); } - static void clearContext() { + private static void clearContext() { CONTEXT_THREAD_LOCAL.remove(); } } diff --git a/sdk-api/src/main/java/dev/restate/sdk/Restate.java b/sdk-api/src/main/java/dev/restate/sdk/Restate.java index 19d2590a..94c31420 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/Restate.java +++ b/sdk-api/src/main/java/dev/restate/sdk/Restate.java @@ -29,6 +29,62 @@ public static Context context() { return HandlerRunner.getContext(); } + @org.jetbrains.annotations.ApiStatus.Experimental + public static ObjectContext objectContext() { + var handlerContext = HandlerRunner.getHandlerContext(); + + if (handlerContext.canReadState() && handlerContext.canWriteState()) { + return (ObjectContext) context(); + } + if (handlerContext.canReadState()) { + throw new IllegalStateException( + "Calling objectContext() from a Virtual object shared handler. You must use Restate.sharedObjectContext() instead."); + } + + throw new IllegalStateException( + "Calling objectContext() from a non Virtual object handler. You can use Restate.objectContext() only inside a Restate Virtual Object handler."); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + public static SharedObjectContext sharedObjectContext() { + var handlerContext = HandlerRunner.getHandlerContext(); + + if (handlerContext.canReadState()) { + return (SharedObjectContext) context(); + } + + throw new IllegalStateException( + "Calling objectContext() from a non Virtual object handler. You can use Restate.objectContext() only inside a Restate Virtual Object handler."); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + public static WorkflowContext workflowContext() { + var handlerContext = HandlerRunner.getHandlerContext(); + + if (handlerContext.canReadPromises() && handlerContext.canWritePromises()) { + return (WorkflowContext) context(); + } + if (handlerContext.canReadPromises()) { + throw new IllegalStateException( + "Calling workflowContext() from a Workflow shared handler. You must use Restate.sharedWorkflowContext() instead."); + } + + throw new IllegalStateException( + "Calling workflowContext() from a non Workflow handler. You can use Restate.workflowContext() only inside a Restate Workflow handler."); + } + + @org.jetbrains.annotations.ApiStatus.Experimental + public static SharedWorkflowContext sharedWorkflowContext() { + var handlerContext = HandlerRunner.getHandlerContext(); + + if (handlerContext.canReadPromises()) { + return (SharedWorkflowContext) context(); + } + + throw new IllegalStateException( + "Calling workflowContext() from a non Workflow handler. You can use Restate.workflowContext() only inside a Restate Workflow handler."); + } + @org.jetbrains.annotations.ApiStatus.Experimental public static boolean isInsideHandler() { return HandlerRunner.CONTEXT_THREAD_LOCAL.get() != null; diff --git a/sdk-api/src/main/java/dev/restate/sdk/internal/ReflectionServiceDefinitionFactory.java b/sdk-api/src/main/java/dev/restate/sdk/internal/ReflectionServiceDefinitionFactory.java index 0238aa4f..2a753858 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/internal/ReflectionServiceDefinitionFactory.java +++ b/sdk-api/src/main/java/dev/restate/sdk/internal/ReflectionServiceDefinitionFactory.java @@ -132,11 +132,11 @@ public ServiceDefinition create( serviceName, "@Shared is only supported on virtual objects and workflow handlers"); } var handlerType = - serviceType == ServiceType.SERVICE || handlerInfo.shared() + handlerInfo.shared() ? HandlerType.SHARED : serviceType == ServiceType.VIRTUAL_OBJECT ? HandlerType.EXCLUSIVE - : HandlerType.WORKFLOW; + : serviceType == ServiceType.WORKFLOW ? HandlerType.WORKFLOW : null; var parameterCount = method.getParameterCount(); diff --git a/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/HandlerContext.java b/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/HandlerContext.java index 575da8bf..ba564e4b 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/HandlerContext.java +++ b/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/HandlerContext.java @@ -29,6 +29,14 @@ public interface HandlerContext { HandlerRequest request(); + boolean canReadState(); + + boolean canWriteState(); + + boolean canReadPromises(); + + boolean canWritePromises(); + // ----- IO // Note: These are not supposed to be exposed in the user's facing Context API. diff --git a/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/HandlerDefinition.java b/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/HandlerDefinition.java index 19c2f655..3a3f9510 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/HandlerDefinition.java +++ b/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/HandlerDefinition.java @@ -21,7 +21,7 @@ public final class HandlerDefinition { private final String name; - private final HandlerType handlerType; + private final @Nullable HandlerType handlerType; private final @Nullable String acceptContentType; private final Serde requestSerde; private final Serde responseSerde; @@ -39,7 +39,7 @@ public final class HandlerDefinition { HandlerDefinition( String name, - HandlerType handlerType, + @Nullable HandlerType handlerType, @Nullable String acceptContentType, Serde requestSerde, Serde responseSerde, @@ -82,7 +82,7 @@ public String getName() { /** * @return handler type. */ - public HandlerType getHandlerType() { + public @Nullable HandlerType getHandlerType() { return handlerType; } diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/EndpointRequestHandler.java b/sdk-core/src/main/java/dev/restate/sdk/core/EndpointRequestHandler.java index aeb8043d..261b14a4 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/EndpointRequestHandler.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/EndpointRequestHandler.java @@ -216,6 +216,7 @@ public RequestProcessor processorForRequest( return new RequestProcessorImpl( fullyQualifiedServiceMethod, stateMachine, + svc.getServiceType(), handler, otelContext, loggingContextSetter, diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/ExecutorSwitchingHandlerContextImpl.java b/sdk-core/src/main/java/dev/restate/sdk/core/ExecutorSwitchingHandlerContextImpl.java index ff9f2e3d..257e951a 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/ExecutorSwitchingHandlerContextImpl.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/ExecutorSwitchingHandlerContextImpl.java @@ -14,6 +14,8 @@ import dev.restate.sdk.common.*; import dev.restate.sdk.core.statemachine.StateMachine; import dev.restate.sdk.endpoint.definition.AsyncResult; +import dev.restate.sdk.endpoint.definition.HandlerType; +import dev.restate.sdk.endpoint.definition.ServiceType; import io.opentelemetry.context.Context; import java.time.Duration; import java.util.Collection; @@ -31,11 +33,13 @@ final class ExecutorSwitchingHandlerContextImpl extends HandlerContextImpl { ExecutorSwitchingHandlerContextImpl( String fullyQualifiedHandlerName, + ServiceType serviceType, + @Nullable HandlerType handlerType, StateMachine stateMachine, Context otelContext, StateMachine.Input input, Executor coreExecutor) { - super(fullyQualifiedHandlerName, stateMachine, otelContext, input); + super(fullyQualifiedHandlerName, serviceType, handlerType, stateMachine, otelContext, input); this.coreExecutor = coreExecutor; } diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextImpl.java b/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextImpl.java index ff6910f6..57906ef8 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextImpl.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextImpl.java @@ -19,6 +19,8 @@ import dev.restate.sdk.core.statemachine.NotificationValue; import dev.restate.sdk.core.statemachine.StateMachine; import dev.restate.sdk.endpoint.definition.AsyncResult; +import dev.restate.sdk.endpoint.definition.HandlerType; +import dev.restate.sdk.endpoint.definition.ServiceType; import io.opentelemetry.context.Context; import java.time.Duration; import java.time.Instant; @@ -41,12 +43,16 @@ class HandlerContextImpl implements HandlerContextInternal { private final StateMachine stateMachine; private final @Nullable String objectKey; private final String fullyQualifiedHandlerName; + private final ServiceType serviceType; + private final @Nullable HandlerType handlerType; private final List> invocationIdsToCancel; private final HashMap> scheduledRuns; HandlerContextImpl( String fullyQualifiedHandlerName, + ServiceType serviceType, + @Nullable HandlerType handlerType, StateMachine stateMachine, Context otelContext, StateMachine.Input input) { @@ -55,6 +61,8 @@ class HandlerContextImpl implements HandlerContextInternal { this.objectKey = input.key(); this.stateMachine = stateMachine; this.fullyQualifiedHandlerName = fullyQualifiedHandlerName; + this.serviceType = serviceType; + this.handlerType = handlerType; this.invocationIdsToCancel = new ArrayList<>(); this.scheduledRuns = new HashMap<>(); } @@ -102,6 +110,26 @@ public HandlerRequest request() { return this.handlerRequest; } + @Override + public boolean canReadState() { + return serviceType == ServiceType.VIRTUAL_OBJECT || serviceType == ServiceType.WORKFLOW; + } + + @Override + public boolean canWriteState() { + return handlerType == HandlerType.EXCLUSIVE || handlerType == HandlerType.WORKFLOW; + } + + @Override + public boolean canReadPromises() { + return serviceType == ServiceType.WORKFLOW; + } + + @Override + public boolean canWritePromises() { + return serviceType == ServiceType.WORKFLOW; + } + @Override public String getFullyQualifiedMethodName() { return this.fullyQualifiedHandlerName; diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/RequestProcessorImpl.java b/sdk-core/src/main/java/dev/restate/sdk/core/RequestProcessorImpl.java index 0c202e41..7f7137e5 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/RequestProcessorImpl.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/RequestProcessorImpl.java @@ -13,6 +13,7 @@ import dev.restate.sdk.core.statemachine.InvocationState; import dev.restate.sdk.core.statemachine.StateMachine; import dev.restate.sdk.endpoint.definition.HandlerDefinition; +import dev.restate.sdk.endpoint.definition.ServiceType; import io.opentelemetry.context.Context; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; @@ -28,6 +29,7 @@ final class RequestProcessorImpl implements RequestProcessor { private final String fullyQualifiedHandlerName; private final StateMachine stateMachine; + private final ServiceType serviceType; private final HandlerDefinition handlerDefinition; private final Context otelContext; private final EndpointRequestHandler.LoggingContextSetter loggingContextSetter; @@ -35,15 +37,17 @@ final class RequestProcessorImpl implements RequestProcessor { private final AtomicReference onHandlerTaskCancellation; @SuppressWarnings("unchecked") - public RequestProcessorImpl( + RequestProcessorImpl( String fullyQualifiedHandlerName, StateMachine stateMachine, + ServiceType serviceType, HandlerDefinition handlerDefinition, Context otelContext, EndpointRequestHandler.LoggingContextSetter loggingContextSetter, Executor syscallExecutor) { this.fullyQualifiedHandlerName = fullyQualifiedHandlerName; this.stateMachine = stateMachine; + this.serviceType = serviceType; this.otelContext = otelContext; this.loggingContextSetter = loggingContextSetter; this.handlerDefinition = (HandlerDefinition) handlerDefinition; @@ -143,8 +147,20 @@ private CompletableFuture onReady() { HandlerContextInternal contextInternal = this.syscallsExecutor != null ? new ExecutorSwitchingHandlerContextImpl( - fullyQualifiedHandlerName, stateMachine, otelContext, input, this.syscallsExecutor) - : new HandlerContextImpl(fullyQualifiedHandlerName, stateMachine, otelContext, input); + fullyQualifiedHandlerName, + serviceType, + handlerDefinition.getHandlerType(), + stateMachine, + otelContext, + input, + this.syscallsExecutor) + : new HandlerContextImpl( + fullyQualifiedHandlerName, + serviceType, + handlerDefinition.getHandlerType(), + stateMachine, + otelContext, + input); CompletableFuture userCodeFuture = this.handlerDefinition From 96adacc9682834d0201dcb51614cc2850c755251 Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Mon, 12 Jan 2026 08:55:54 +0100 Subject: [PATCH 4/7] Add option to disable annotation processing for specific classes --- .../model/AnnotationProcessingOptions.java | 21 ++++++++++++++++--- .../dev/restate/sdk/gen/ServiceProcessor.java | 5 +++-- .../sdk/kotlin/gen/ServiceProcessor.kt | 3 ++- .../restate/sdk/fake/FakeHandlerContext.java | 20 ++++++++++++++++++ 4 files changed, 43 insertions(+), 6 deletions(-) diff --git a/sdk-api-gen-common/src/main/java/dev/restate/sdk/gen/model/AnnotationProcessingOptions.java b/sdk-api-gen-common/src/main/java/dev/restate/sdk/gen/model/AnnotationProcessingOptions.java index c6a594a9..35328c36 100644 --- a/sdk-api-gen-common/src/main/java/dev/restate/sdk/gen/model/AnnotationProcessingOptions.java +++ b/sdk-api-gen-common/src/main/java/dev/restate/sdk/gen/model/AnnotationProcessingOptions.java @@ -15,30 +15,45 @@ public class AnnotationProcessingOptions { private static final String DISABLED_CLIENT_GENERATION = "dev.restate.codegen.disabledClientGeneration"; + private static final String DISABLED_CLASSES = "dev.restate.codegen.disabledClasses"; + private final Set disabledClientGenFQCN; + private final Set disabledClasses; public AnnotationProcessingOptions(Map options) { this.disabledClientGenFQCN = new HashSet<>(List.of(options.getOrDefault(DISABLED_CLIENT_GENERATION, "").split("[,|]"))); + this.disabledClasses = + new HashSet<>(List.of(options.getOrDefault(DISABLED_CLASSES, "").split("[,|]"))); } public boolean isClientGenDisabled(String fqcn) { return this.disabledClientGenFQCN.contains(fqcn); } + public boolean isClassDisabled(String fqcn) { + return this.disabledClasses.contains(fqcn); + } + @Override public boolean equals(Object o) { if (!(o instanceof AnnotationProcessingOptions that)) return false; - return Objects.equals(disabledClientGenFQCN, that.disabledClientGenFQCN); + return Objects.equals(disabledClientGenFQCN, that.disabledClientGenFQCN) + && Objects.equals(disabledClasses, that.disabledClasses); } @Override public int hashCode() { - return Objects.hashCode(disabledClientGenFQCN); + return Objects.hash(disabledClientGenFQCN, disabledClasses); } @Override public String toString() { - return "AnnotationProcessingOptions{" + "disabledClientGenFQCN=" + disabledClientGenFQCN + '}'; + return "AnnotationProcessingOptions{" + + "disabledClientGenFQCN=" + + disabledClientGenFQCN + + ", disabledClasses=" + + disabledClasses + + '}'; } } diff --git a/sdk-api-gen/src/main/java/dev/restate/sdk/gen/ServiceProcessor.java b/sdk-api-gen/src/main/java/dev/restate/sdk/gen/ServiceProcessor.java index b4ffafea..46320b60 100644 --- a/sdk-api-gen/src/main/java/dev/restate/sdk/gen/ServiceProcessor.java +++ b/sdk-api-gen/src/main/java/dev/restate/sdk/gen/ServiceProcessor.java @@ -103,11 +103,12 @@ public boolean process(Set annotations, RoundEnvironment .getElementsAnnotatedWith(metaAnnotation.getAnnotationTypeElement()) .stream() .filter(e -> e.getKind().isClass() || e.getKind().isInterface()) + .map(e -> (TypeElement) e) + .filter(e -> !this.options.isClassDisabled(e.getQualifiedName().toString())) .map( e -> Map.entry( - (Element) e, - converter.fromTypeElement(metaAnnotation, (TypeElement) e)))) + (Element) e, converter.fromTypeElement(metaAnnotation, e)))) .collect(Collectors.toList()); Filer filer = processingEnv.getFiler(); diff --git a/sdk-api-kotlin-gen/src/main/kotlin/dev/restate/sdk/kotlin/gen/ServiceProcessor.kt b/sdk-api-kotlin-gen/src/main/kotlin/dev/restate/sdk/kotlin/gen/ServiceProcessor.kt index 269b3a6b..b4bf50dc 100644 --- a/sdk-api-kotlin-gen/src/main/kotlin/dev/restate/sdk/kotlin/gen/ServiceProcessor.kt +++ b/sdk-api-kotlin-gen/src/main/kotlin/dev/restate/sdk/kotlin/gen/ServiceProcessor.kt @@ -175,7 +175,8 @@ class ServiceProcessor( when (annotatedElement.classKind) { ClassKind.INTERFACE, ClassKind.CLASS -> { - if (annotatedElement.containingFile!!.origin != Origin.KOTLIN) { + if (annotatedElement.containingFile!!.origin != Origin.KOTLIN || + options.isClassDisabled(annotatedElement.qualifiedName!!.asString())) { // Skip if it's not kotlin continue } diff --git a/sdk-fake-api/src/main/java/dev/restate/sdk/fake/FakeHandlerContext.java b/sdk-fake-api/src/main/java/dev/restate/sdk/fake/FakeHandlerContext.java index 419b5162..75087414 100644 --- a/sdk-fake-api/src/main/java/dev/restate/sdk/fake/FakeHandlerContext.java +++ b/sdk-fake-api/src/main/java/dev/restate/sdk/fake/FakeHandlerContext.java @@ -66,6 +66,26 @@ public String toString() { expectations.requestHeaders()); } + @Override + public boolean canReadState() { + return true; + } + + @Override + public boolean canWriteState() { + return true; + } + + @Override + public boolean canReadPromises() { + return true; + } + + @Override + public boolean canWritePromises() { + return true; + } + @Override public CompletableFuture writeOutput(Slice slice) { throw new UnsupportedOperationException( From a4acd4661c49948d068d1dc5b6e3a29830f4496e Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Mon, 12 Jan 2026 09:38:23 +0100 Subject: [PATCH 5/7] Improve error message for bad handler signature --- .../MalformedRestateServiceException.java | 3 +- .../ReflectionServiceDefinitionFactory.java | 106 +++++++++++++----- 2 files changed, 80 insertions(+), 29 deletions(-) rename sdk-api/src/main/java/dev/restate/sdk/{ => internal}/MalformedRestateServiceException.java (90%) diff --git a/sdk-api/src/main/java/dev/restate/sdk/MalformedRestateServiceException.java b/sdk-api/src/main/java/dev/restate/sdk/internal/MalformedRestateServiceException.java similarity index 90% rename from sdk-api/src/main/java/dev/restate/sdk/MalformedRestateServiceException.java rename to sdk-api/src/main/java/dev/restate/sdk/internal/MalformedRestateServiceException.java index d23c4719..a2cc5e50 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/MalformedRestateServiceException.java +++ b/sdk-api/src/main/java/dev/restate/sdk/internal/MalformedRestateServiceException.java @@ -6,8 +6,9 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk; +package dev.restate.sdk.internal; +@org.jetbrains.annotations.ApiStatus.Internal public class MalformedRestateServiceException extends RuntimeException { public MalformedRestateServiceException(String serviceName, String message) { diff --git a/sdk-api/src/main/java/dev/restate/sdk/internal/ReflectionServiceDefinitionFactory.java b/sdk-api/src/main/java/dev/restate/sdk/internal/ReflectionServiceDefinitionFactory.java index 2a753858..537adffc 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/internal/ReflectionServiceDefinitionFactory.java +++ b/sdk-api/src/main/java/dev/restate/sdk/internal/ReflectionServiceDefinitionFactory.java @@ -8,12 +8,10 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.internal; -import dev.restate.common.function.ThrowingBiFunction; import dev.restate.common.reflections.ReflectionUtils; import dev.restate.common.reflections.RestateUtils; -import dev.restate.sdk.Context; +import dev.restate.sdk.*; import dev.restate.sdk.HandlerRunner; -import dev.restate.sdk.MalformedRestateServiceException; import dev.restate.sdk.annotation.*; import dev.restate.sdk.endpoint.definition.*; import dev.restate.serde.Serde; @@ -21,8 +19,8 @@ import dev.restate.serde.provider.DefaultSerdeFactoryProvider; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; -import java.util.Arrays; -import java.util.ServiceLoader; +import java.text.MessageFormat; +import java.util.*; import java.util.stream.Collectors; import org.jspecify.annotations.Nullable; @@ -117,12 +115,62 @@ public ServiceDefinition create( var handlerInfo = ReflectionUtils.mustHaveHandlerAnnotation(method); var handlerName = handlerInfo.name(); var genericParameterTypes = method.getGenericParameterTypes(); - if (genericParameterTypes.length > 1) { + var parameterCount = method.getParameterCount(); + + if ((parameterCount == 1 || parameterCount == 2) + && (genericParameterTypes[0].equals(Context.class) + || genericParameterTypes[0].equals(SharedObjectContext.class) + || genericParameterTypes[0].equals(ObjectContext.class) + || genericParameterTypes[0].equals(WorkflowContext.class) + || genericParameterTypes[0].equals(SharedWorkflowContext.class))) { + var ctxTypeName = ((Class) genericParameterTypes[0]).getSimpleName(); + var returnTypeName = + !method.getGenericReturnType().equals(Void.TYPE) + ? method.getGenericReturnType().toString() + : null; + var restateCtxGetter = ctxTypeName.substring(0, 1).toLowerCase() + ctxTypeName.substring(1); + throw new MalformedRestateServiceException( + serviceName, + MessageFormat.format( + """ + The service is being loaded with the new Reflection based API, but handler ''{0}'' contains {1} as first parameter. Suggestions: + * If you want to use the new Reflection based API, remove {2} from the method definition and use {3} inside the handler: + - {4} '{' + - // code + - '}' + Replace with: + + {5} '{' + + var ctx = Restate.{6}(); + + // code + + '} + * If you''re still using the annotation processor based API, make sure the ServiceDefinitionFactory class was correctly generated.""", + handlerName, + ctxTypeName, + ctxTypeName, + Restate.class.getName(), + renderSignature( + handlerName, + parameterCount == 1 + ? List.of(Map.entry(ctxTypeName, "ctx")) + : List.of( + Map.entry(ctxTypeName, "ctx"), + Map.entry(genericParameterTypes[1].getTypeName(), "input")), + returnTypeName), + renderSignature( + handlerName, + parameterCount == 1 + ? List.of() + : List.of(Map.entry(genericParameterTypes[1].getTypeName(), "input")), + returnTypeName), + restateCtxGetter)); + } + + if (parameterCount > 1) { throw new MalformedRestateServiceException( serviceName, "More than one parameter found in method " + method.getName() - + ". Only one parameter is supported."); + + ". Only zero or one parameter is supported."); } var inputType = genericParameterTypes.length == 0 ? Void.TYPE : genericParameterTypes[0]; var outputType = method.getGenericReturnType(); @@ -138,29 +186,19 @@ public ServiceDefinition create( ? HandlerType.EXCLUSIVE : serviceType == ServiceType.WORKFLOW ? HandlerType.WORKFLOW : null; - var parameterCount = method.getParameterCount(); - - // TODO here we should add some code to handle handling Context in method definition. - // This is because we want to make sure people declaring the handlers with the Context in the - // method works - // providing a smoother path to transition from code generation - // Plus plus plus important bit -> we need to validate the input paramters can be one and only - // one (OBV)! - var runner = dev.restate.sdk.HandlerRunner.of( - (ThrowingBiFunction) - (ctx, in) -> { - try { - if (parameterCount == 0) { - return method.invoke(serviceInstance); - } else { - return method.invoke(serviceInstance, in); - } - } catch (InvocationTargetException e) { - throw e.getCause(); - } - }, + (ctx, in) -> { + try { + if (parameterCount == 0) { + return method.invoke(serviceInstance); + } else { + return method.invoke(serviceInstance, in); + } + } catch (InvocationTargetException e) { + throw e.getCause(); + } + }, serdeFactory, overrideHandlerOptions); @@ -217,6 +255,18 @@ private SerdeFactory resolveSerdeFactory(Class serviceClazz) { } } + private String renderSignature( + String name, List> inputTypes, @Nullable String outputType) { + return Objects.requireNonNullElse(outputType, "void") + + " " + + name + + "(" + + inputTypes.stream() + .map(e -> e.getKey() + " " + e.getValue()) + .collect(Collectors.joining(", ")) + + ")"; + } + @Override public boolean supports(Object serviceObject) { return true; From fa2cffe533ee1a557fa76e5826ad10ae3bae2224 Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Mon, 12 Jan 2026 10:24:25 +0100 Subject: [PATCH 6/7] Javadocs --- .../main/java/dev/restate/client/Client.java | 77 ++++ .../client/ClientServiceReference.java | 131 ++++++ .../java/dev/restate/sdk/HandlerRunner.java | 32 +- .../main/java/dev/restate/sdk/Restate.java | 382 +++++++++++++++++- .../dev/restate/sdk/ServiceReference.java | 154 +++++++ .../sdk/internal/ContextThreadLocal.java | 25 ++ .../dev/restate/sdk/fake/FakeRestate.java | 122 ++++++ 7 files changed, 898 insertions(+), 25 deletions(-) create mode 100644 sdk-api/src/main/java/dev/restate/sdk/internal/ContextThreadLocal.java create mode 100644 sdk-fake-api/src/main/java/dev/restate/sdk/fake/FakeRestate.java diff --git a/client/src/main/java/dev/restate/client/Client.java b/client/src/main/java/dev/restate/client/Client.java index 286570d8..db98dca7 100644 --- a/client/src/main/java/dev/restate/client/Client.java +++ b/client/src/main/java/dev/restate/client/Client.java @@ -530,18 +530,95 @@ default Response> getOutput() throws IngressException { } } + /** + * EXPERIMENTAL API: Create a reference to invoke a Restate service from the ingress. This + * API may change in future releases. + * + *

You can invoke the service in three ways: + * + *

{@code
+   * Client client = Client.connect("http://localhost:8080");
+   *
+   * // 1. Create a client proxy and call it directly (returns output directly)
+   * var greeterProxy = client.service(Greeter.class).client();
+   * GreetingResponse output = greeterProxy.greet(new Greeting("Alice"));
+   *
+   * // 2. Use call() with method reference and wait for the result
+   * Response response = client.service(Greeter.class)
+   *   .call(Greeter::greet, new Greeting("Alice"));
+   *
+   * // 3. Use send() for one-way invocation without waiting
+   * SendResponse sendResponse = client.service(Greeter.class)
+   *   .send(Greeter::greet, new Greeting("Alice"));
+   * }
+ * + * @param clazz the service class annotated with {@link Service} + * @return a reference to invoke the service + */ @org.jetbrains.annotations.ApiStatus.Experimental default ClientServiceReference service(Class clazz) { mustHaveAnnotation(clazz, Service.class); return new ClientServiceReferenceImpl<>(this, clazz, null); } + /** + * EXPERIMENTAL API: Create a reference to invoke a Restate Virtual Object from the + * ingress. This API may change in future releases. + * + *

You can invoke the virtual object in three ways: + * + *

{@code
+   * Client client = Client.connect("http://localhost:8080");
+   *
+   * // 1. Create a client proxy and call it directly (returns output directly)
+   * var counterProxy = client.virtualObject(Counter.class, "my-counter").client();
+   * int count = counterProxy.increment();
+   *
+   * // 2. Use call() with method reference and wait for the result
+   * Response response = client.virtualObject(Counter.class, "my-counter")
+   *   .call(Counter::increment);
+   *
+   * // 3. Use send() for one-way invocation without waiting
+   * SendResponse sendResponse = client.virtualObject(Counter.class, "my-counter")
+   *   .send(Counter::increment);
+   * }
+ * + * @param clazz the virtual object class annotated with {@link VirtualObject} + * @param key the key identifying the specific virtual object instance + * @return a reference to invoke the virtual object + */ @org.jetbrains.annotations.ApiStatus.Experimental default ClientServiceReference virtualObject(Class clazz, String key) { mustHaveAnnotation(clazz, VirtualObject.class); return new ClientServiceReferenceImpl<>(this, clazz, key); } + /** + * EXPERIMENTAL API: Create a reference to invoke a Restate Workflow from the ingress. + * This API may change in future releases. + * + *

You can invoke the workflow in three ways: + * + *

{@code
+   * Client client = Client.connect("http://localhost:8080");
+   *
+   * // 1. Create a client proxy and call it directly (returns output directly)
+   * var workflowProxy = client.workflow(OrderWorkflow.class, "order-123").client();
+   * OrderResult result = workflowProxy.start(new OrderRequest(...));
+   *
+   * // 2. Use call() with method reference and wait for the result
+   * Response response = client.workflow(OrderWorkflow.class, "order-123")
+   *   .call(OrderWorkflow::start, new OrderRequest(...));
+   *
+   * // 3. Use send() for one-way invocation without waiting
+   * SendResponse sendResponse = client.workflow(OrderWorkflow.class, "order-123")
+   *   .send(OrderWorkflow::start, new OrderRequest(...));
+   * }
+ * + * @param clazz the workflow class annotated with {@link Workflow} + * @param key the key identifying the specific workflow instance + * @return a reference to invoke the workflow + */ @org.jetbrains.annotations.ApiStatus.Experimental default ClientServiceReference workflow(Class clazz, String key) { mustHaveAnnotation(clazz, Workflow.class); diff --git a/client/src/main/java/dev/restate/client/ClientServiceReference.java b/client/src/main/java/dev/restate/client/ClientServiceReference.java index 38dda052..41a50bb7 100644 --- a/client/src/main/java/dev/restate/client/ClientServiceReference.java +++ b/client/src/main/java/dev/restate/client/ClientServiceReference.java @@ -17,12 +17,67 @@ import java.util.function.Consumer; import java.util.function.Function; +/** + * EXPERIMENTAL API: This interface is part of the new reflection-based API and may change + * in future releases. + * + *

A reference to a Restate service, virtual object, or workflow that can be invoked from the + * ingress (outside of a handler). Provides three ways to invoke methods: + * + *

{@code
+ * Client client = Client.connect("http://localhost:8080");
+ *
+ * // 1. Create a client proxy and call it directly (returns output directly)
+ * var greeterProxy = client.service(Greeter.class).client();
+ * GreetingResponse output = greeterProxy.greet(new Greeting("Alice"));
+ *
+ * // 2. Use call() with method reference and wait for the result
+ * Response response = client.service(Greeter.class)
+ *   .call(Greeter::greet, new Greeting("Alice"));
+ *
+ * // 3. Use send() for one-way invocation without waiting
+ * SendResponse sendResponse = client.service(Greeter.class)
+ *   .send(Greeter::greet, new Greeting("Alice"));
+ * }
+ * + *

Create instances using {@link Client#service(Class)}, {@link + * Client#virtualObject(Class, String)}, or {@link Client#workflow(Class, String)}. + * + * @param the service interface type + */ @org.jetbrains.annotations.ApiStatus.Experimental public interface ClientServiceReference { + /** + * EXPERIMENTAL API: Get a client proxy to call methods directly (returns output directly, + * not wrapped in Response). + * + *

{@code
+   * Client client = Client.connect("http://localhost:8080");
+   *
+   * // Get a proxy and call methods on it (returns output directly)
+   * var greeterProxy = client.service(Greeter.class).client();
+   * GreetingResponse output = greeterProxy.greet(new Greeting("Alice"));
+   * }
+ * + * @return a proxy instance of the service interface + */ @org.jetbrains.annotations.ApiStatus.Experimental SVC client(); // call - BiFunction variants + /** + * EXPERIMENTAL API: Invoke a service method with input and wait for the response. + * + *
{@code
+   * // Call with method reference and input
+   * Response response = client.service(Greeter.class)
+   *   .call(Greeter::greet, new Greeting("Alice"));
+   * }
+ * + * @param s method reference (e.g., {@code Greeter::greet}) + * @param input the input parameter to pass to the method + * @return a {@link Response} wrapping the result + */ @org.jetbrains.annotations.ApiStatus.Experimental default Response call(BiFunction s, I input) { return call(s, input, InvocationOptions.DEFAULT); @@ -182,34 +237,54 @@ default CompletableFuture> callAsync( CompletableFuture> callAsync(Consumer s, InvocationOptions invocationOptions); // send - BiFunction variants + /** + * EXPERIMENTAL API: Send a one-way invocation without waiting for the response. + * + *
{@code
+   * // Send without waiting for response
+   * SendResponse sendResponse = client.service(Greeter.class)
+   *   .send(Greeter::greet, new Greeting("Alice"));
+   * }
+ */ @org.jetbrains.annotations.ApiStatus.Experimental default SendResponse send(BiFunction s, I input) { return send(s, input, InvocationOptions.DEFAULT); } + /** EXPERIMENTAL API: Like {@link #send(BiFunction, Object)}, with invocation options. */ @org.jetbrains.annotations.ApiStatus.Experimental default SendResponse send( BiFunction s, I input, InvocationOptions.Builder options) { return send(s, input, options.build()); } + /** EXPERIMENTAL API: Like {@link #send(BiFunction, Object)}, with invocation options. */ @org.jetbrains.annotations.ApiStatus.Experimental default SendResponse send( BiFunction s, I input, InvocationOptions invocationOptions) { return send(s, input, null, invocationOptions); } + /** EXPERIMENTAL API: Like {@link #send(BiFunction, Object)}, with a delay. */ @org.jetbrains.annotations.ApiStatus.Experimental default SendResponse send(BiFunction s, I input, Duration delay) { return send(s, input, delay, InvocationOptions.DEFAULT); } + /** + * EXPERIMENTAL API: Like {@link #send(BiFunction, Object)}, with a delay and invocation + * options. + */ @org.jetbrains.annotations.ApiStatus.Experimental default SendResponse send( BiFunction s, I input, Duration delay, InvocationOptions.Builder options) { return send(s, input, delay, options.build()); } + /** + * EXPERIMENTAL API: Like {@link #send(BiFunction, Object)}, with a delay and invocation + * options. + */ @org.jetbrains.annotations.ApiStatus.Experimental default SendResponse send( BiFunction s, I input, Duration delay, InvocationOptions invocationOptions) { @@ -224,34 +299,49 @@ default SendResponse send( } // send - BiConsumer variants + /** + * EXPERIMENTAL API: Like {@link #send(BiFunction, Object)}, for methods without a return + * value. + */ @org.jetbrains.annotations.ApiStatus.Experimental default SendResponse send(BiConsumer s, I input) { return send(s, input, InvocationOptions.DEFAULT); } + /** EXPERIMENTAL API: Like {@link #send(BiConsumer, Object)}, with invocation options. */ @org.jetbrains.annotations.ApiStatus.Experimental default SendResponse send( BiConsumer s, I input, InvocationOptions.Builder options) { return send(s, input, options.build()); } + /** EXPERIMENTAL API: Like {@link #send(BiConsumer, Object)}, with invocation options. */ @org.jetbrains.annotations.ApiStatus.Experimental default SendResponse send( BiConsumer s, I input, InvocationOptions invocationOptions) { return send(s, input, null, invocationOptions); } + /** EXPERIMENTAL API: Like {@link #send(BiConsumer, Object)}, with a delay. */ @org.jetbrains.annotations.ApiStatus.Experimental default SendResponse send(BiConsumer s, I input, Duration delay) { return send(s, input, delay, InvocationOptions.DEFAULT); } + /** + * EXPERIMENTAL API: Like {@link #send(BiConsumer, Object)}, with a delay and invocation + * options. + */ @org.jetbrains.annotations.ApiStatus.Experimental default SendResponse send( BiConsumer s, I input, Duration delay, InvocationOptions.Builder options) { return send(s, input, delay, options.build()); } + /** + * EXPERIMENTAL API: Like {@link #send(BiConsumer, Object)}, with a delay and invocation + * options. + */ @org.jetbrains.annotations.ApiStatus.Experimental default SendResponse send( BiConsumer s, I input, Duration delay, InvocationOptions invocationOptions) { @@ -266,32 +356,40 @@ default SendResponse send( } // send - Function variants + /** + * EXPERIMENTAL API: Like {@link #send(BiFunction, Object)}, for methods without input. + */ @org.jetbrains.annotations.ApiStatus.Experimental default SendResponse send(Function s) { return send(s, InvocationOptions.DEFAULT); } + /** EXPERIMENTAL API: Like {@link #send(Function)}, with invocation options. */ @org.jetbrains.annotations.ApiStatus.Experimental default SendResponse send(Function s, InvocationOptions.Builder options) { return send(s, options.build()); } + /** EXPERIMENTAL API: Like {@link #send(Function)}, with invocation options. */ @org.jetbrains.annotations.ApiStatus.Experimental default SendResponse send(Function s, InvocationOptions invocationOptions) { return send(s, null, invocationOptions); } + /** EXPERIMENTAL API: Like {@link #send(Function)}, with a delay. */ @org.jetbrains.annotations.ApiStatus.Experimental default SendResponse send(Function s, Duration delay) { return send(s, delay, InvocationOptions.DEFAULT); } + /** EXPERIMENTAL API: Like {@link #send(Function)}, with a delay and invocation options. */ @org.jetbrains.annotations.ApiStatus.Experimental default SendResponse send( Function s, Duration delay, InvocationOptions.Builder options) { return send(s, delay, options.build()); } + /** EXPERIMENTAL API: Like {@link #send(Function)}, with a delay and invocation options. */ @org.jetbrains.annotations.ApiStatus.Experimental default SendResponse send( Function s, Duration delay, InvocationOptions invocationOptions) { @@ -306,32 +404,41 @@ default SendResponse send( } // send - Consumer variants + /** + * EXPERIMENTAL API: Like {@link #send(BiFunction, Object)}, for methods without input or + * return value. + */ @org.jetbrains.annotations.ApiStatus.Experimental default SendResponse send(Consumer s) { return send(s, InvocationOptions.DEFAULT); } + /** EXPERIMENTAL API: Like {@link #send(Consumer)}, with invocation options. */ @org.jetbrains.annotations.ApiStatus.Experimental default SendResponse send(Consumer s, InvocationOptions.Builder options) { return send(s, options.build()); } + /** EXPERIMENTAL API: Like {@link #send(Consumer)}, with invocation options. */ @org.jetbrains.annotations.ApiStatus.Experimental default SendResponse send(Consumer s, InvocationOptions invocationOptions) { return send(s, null, invocationOptions); } + /** EXPERIMENTAL API: Like {@link #send(Consumer)}, with a delay. */ @org.jetbrains.annotations.ApiStatus.Experimental default SendResponse send(Consumer s, Duration delay) { return send(s, delay, InvocationOptions.DEFAULT); } + /** EXPERIMENTAL API: Like {@link #send(Consumer)}, with a delay and invocation options. */ @org.jetbrains.annotations.ApiStatus.Experimental default SendResponse send( Consumer s, Duration delay, InvocationOptions.Builder options) { return send(s, delay, options.build()); } + /** EXPERIMENTAL API: Like {@link #send(Consumer)}, with a delay and invocation options. */ @org.jetbrains.annotations.ApiStatus.Experimental default SendResponse send( Consumer s, Duration delay, InvocationOptions invocationOptions) { @@ -346,135 +453,159 @@ default SendResponse send( } // sendAsync - BiFunction variants + /** EXPERIMENTAL API: Async version of {@link #send(BiFunction, Object)}. */ @org.jetbrains.annotations.ApiStatus.Experimental default CompletableFuture> sendAsync(BiFunction s, I input) { return sendAsync(s, input, InvocationOptions.DEFAULT); } + /** EXPERIMENTAL API: Like {@link #sendAsync(BiFunction, Object)}, with options. */ @org.jetbrains.annotations.ApiStatus.Experimental default CompletableFuture> sendAsync( BiFunction s, I input, InvocationOptions.Builder options) { return sendAsync(s, input, options.build()); } + /** EXPERIMENTAL API: Like {@link #sendAsync(BiFunction, Object)}, with options. */ @org.jetbrains.annotations.ApiStatus.Experimental default CompletableFuture> sendAsync( BiFunction s, I input, InvocationOptions invocationOptions) { return sendAsync(s, input, null, invocationOptions); } + /** EXPERIMENTAL API: Like {@link #sendAsync(BiFunction, Object)}, with a delay. */ @org.jetbrains.annotations.ApiStatus.Experimental default CompletableFuture> sendAsync( BiFunction s, I input, Duration delay) { return sendAsync(s, input, delay, InvocationOptions.DEFAULT); } + /** EXPERIMENTAL API: Like {@link #sendAsync(BiFunction, Object)}, with delay and options. */ @org.jetbrains.annotations.ApiStatus.Experimental default CompletableFuture> sendAsync( BiFunction s, I input, Duration delay, InvocationOptions.Builder options) { return sendAsync(s, input, delay, options.build()); } + /** EXPERIMENTAL API: Like {@link #sendAsync(BiFunction, Object)}, with delay and options. */ @org.jetbrains.annotations.ApiStatus.Experimental CompletableFuture> sendAsync( BiFunction s, I input, Duration delay, InvocationOptions invocationOptions); // sendAsync - BiConsumer variants + /** EXPERIMENTAL API: Like {@link #sendAsync(BiFunction, Object)}, for void methods. */ @org.jetbrains.annotations.ApiStatus.Experimental default CompletableFuture> sendAsync(BiConsumer s, I input) { return sendAsync(s, input, InvocationOptions.DEFAULT); } + /** EXPERIMENTAL API: Like {@link #sendAsync(BiConsumer, Object)}, with options. */ @org.jetbrains.annotations.ApiStatus.Experimental default CompletableFuture> sendAsync( BiConsumer s, I input, InvocationOptions.Builder options) { return sendAsync(s, input, options.build()); } + /** EXPERIMENTAL API: Like {@link #sendAsync(BiConsumer, Object)}, with options. */ @org.jetbrains.annotations.ApiStatus.Experimental default CompletableFuture> sendAsync( BiConsumer s, I input, InvocationOptions invocationOptions) { return sendAsync(s, input, null, invocationOptions); } + /** EXPERIMENTAL API: Like {@link #sendAsync(BiConsumer, Object)}, with a delay. */ @org.jetbrains.annotations.ApiStatus.Experimental default CompletableFuture> sendAsync( BiConsumer s, I input, Duration delay) { return sendAsync(s, input, delay, InvocationOptions.DEFAULT); } + /** EXPERIMENTAL API: Like {@link #sendAsync(BiConsumer, Object)}, with delay and options. */ @org.jetbrains.annotations.ApiStatus.Experimental default CompletableFuture> sendAsync( BiConsumer s, I input, Duration delay, InvocationOptions.Builder options) { return sendAsync(s, input, delay, options.build()); } + /** EXPERIMENTAL API: Like {@link #sendAsync(BiConsumer, Object)}, with delay and options. */ @org.jetbrains.annotations.ApiStatus.Experimental CompletableFuture> sendAsync( BiConsumer s, I input, Duration delay, InvocationOptions invocationOptions); // sendAsync - Function variants + /** EXPERIMENTAL API: Like {@link #sendAsync(BiFunction, Object)}, for no-input methods. */ @org.jetbrains.annotations.ApiStatus.Experimental default CompletableFuture> sendAsync(Function s) { return sendAsync(s, InvocationOptions.DEFAULT); } + /** EXPERIMENTAL API: Like {@link #sendAsync(Function)}, with options. */ @org.jetbrains.annotations.ApiStatus.Experimental default CompletableFuture> sendAsync( Function s, InvocationOptions.Builder options) { return sendAsync(s, options.build()); } + /** EXPERIMENTAL API: Like {@link #sendAsync(Function)}, with options. */ @org.jetbrains.annotations.ApiStatus.Experimental default CompletableFuture> sendAsync( Function s, InvocationOptions invocationOptions) { return sendAsync(s, null, invocationOptions); } + /** EXPERIMENTAL API: Like {@link #sendAsync(Function)}, with a delay. */ @org.jetbrains.annotations.ApiStatus.Experimental default CompletableFuture> sendAsync(Function s, Duration delay) { return sendAsync(s, delay, InvocationOptions.DEFAULT); } + /** EXPERIMENTAL API: Like {@link #sendAsync(Function)}, with delay and options. */ @org.jetbrains.annotations.ApiStatus.Experimental default CompletableFuture> sendAsync( Function s, Duration delay, InvocationOptions.Builder options) { return sendAsync(s, delay, options.build()); } + /** EXPERIMENTAL API: Like {@link #sendAsync(Function)}, with delay and options. */ @org.jetbrains.annotations.ApiStatus.Experimental CompletableFuture> sendAsync( Function s, Duration delay, InvocationOptions invocationOptions); // sendAsync - Consumer variants + /** EXPERIMENTAL API: Like {@link #sendAsync(BiFunction, Object)}, for no-input/void methods. */ @org.jetbrains.annotations.ApiStatus.Experimental default CompletableFuture> sendAsync(Consumer s) { return sendAsync(s, InvocationOptions.DEFAULT); } + /** EXPERIMENTAL API: Like {@link #sendAsync(Consumer)}, with options. */ @org.jetbrains.annotations.ApiStatus.Experimental default CompletableFuture> sendAsync( Consumer s, InvocationOptions.Builder options) { return sendAsync(s, options.build()); } + /** EXPERIMENTAL API: Like {@link #sendAsync(Consumer)}, with options. */ @org.jetbrains.annotations.ApiStatus.Experimental default CompletableFuture> sendAsync( Consumer s, InvocationOptions invocationOptions) { return sendAsync(s, null, invocationOptions); } + /** EXPERIMENTAL API: Like {@link #sendAsync(Consumer)}, with a delay. */ @org.jetbrains.annotations.ApiStatus.Experimental default CompletableFuture> sendAsync(Consumer s, Duration delay) { return sendAsync(s, delay, InvocationOptions.DEFAULT); } + /** EXPERIMENTAL API: Like {@link #sendAsync(Consumer)}, with delay and options. */ @org.jetbrains.annotations.ApiStatus.Experimental default CompletableFuture> sendAsync( Consumer s, Duration delay, InvocationOptions.Builder options) { return sendAsync(s, delay, options.build()); } + /** EXPERIMENTAL API: Like {@link #sendAsync(Consumer)}, with delay and options. */ @org.jetbrains.annotations.ApiStatus.Experimental CompletableFuture> sendAsync( Consumer s, Duration delay, InvocationOptions invocationOptions); diff --git a/sdk-api/src/main/java/dev/restate/sdk/HandlerRunner.java b/sdk-api/src/main/java/dev/restate/sdk/HandlerRunner.java index a9a913f5..48854481 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/HandlerRunner.java +++ b/sdk-api/src/main/java/dev/restate/sdk/HandlerRunner.java @@ -15,9 +15,11 @@ import dev.restate.common.function.ThrowingFunction; import dev.restate.sdk.common.TerminalException; import dev.restate.sdk.endpoint.definition.HandlerContext; +import dev.restate.sdk.internal.ContextThreadLocal; import dev.restate.serde.Serde; import dev.restate.serde.SerdeFactory; import io.opentelemetry.context.Scope; + import java.util.Objects; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; @@ -91,12 +93,12 @@ public CompletableFuture run( RES res = null; Throwable error = null; try { - setContext(ctx); + ContextThreadLocal.setContext(ctx); res = this.runner.apply(ctx, req); } catch (Throwable e) { error = e; } finally { - clearContext(); + ContextThreadLocal.clearContext(); } // If error, just return now @@ -202,25 +204,9 @@ public static Options withExecutor(Executor executor) { } } - static final ThreadLocal CONTEXT_THREAD_LOCAL = new ThreadLocal<>(); - - static Context getContext() { - return Objects.requireNonNull( - CONTEXT_THREAD_LOCAL.get(), - "Restate methods must be invoked from within a Restate handler"); - } - - static HandlerContext getHandlerContext() { - return Objects.requireNonNull( - HANDLER_CONTEXT_THREAD_LOCAL.get(), - "Restate methods must be invoked from within a Restate handler"); - } - - private static void setContext(Context context) { - CONTEXT_THREAD_LOCAL.set(context); - } - - private static void clearContext() { - CONTEXT_THREAD_LOCAL.remove(); - } + static HandlerContext getHandlerContext() { + return Objects.requireNonNull( + dev.restate.sdk.endpoint.definition.HandlerRunner.HANDLER_CONTEXT_THREAD_LOCAL.get(), + "Restate methods must be invoked from within a Restate handler"); + } } diff --git a/sdk-api/src/main/java/dev/restate/sdk/Restate.java b/sdk-api/src/main/java/dev/restate/sdk/Restate.java index 94c31420..85e6b896 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/Restate.java +++ b/sdk-api/src/main/java/dev/restate/sdk/Restate.java @@ -16,19 +16,93 @@ import dev.restate.sdk.annotation.Service; import dev.restate.sdk.annotation.VirtualObject; import dev.restate.sdk.annotation.Workflow; +import dev.restate.sdk.common.AbortedExecutionException; import dev.restate.sdk.common.HandlerRequest; import dev.restate.sdk.common.RetryPolicy; import dev.restate.sdk.common.TerminalException; +import dev.restate.sdk.internal.ContextThreadLocal; import dev.restate.serde.TypeTag; import java.time.Duration; +/** + * This class exposes the Restate functionalities to Restate services using the reflection-based + * API. It can be used to interact with other Restate services, record non-deterministic closures, + * execute timers, and synchronize with external systems. + * + *

This is the entry point for the new reflection-based API where services are defined using + * annotations and methods can access Restate features through static methods on this class. + * + *

Example Usage

+ * + *
{@code
+ * @Service
+ * public class Greeter {
+ *
+ *   @Handler
+ *   public String greet(String input) {
+ *     // Use Restate features via static methods
+ *     String result = Restate.run(
+ *       "external-call",
+ *       String.class,
+ *       () -> externalService.call(input)
+ *     );
+ *
+ *     return "You said hi to " + req.name + "!";
+ *   }
+ * }
+ * }
+ * + *

Error handling

+ * + * All methods of this class throws either {@link TerminalException} or {@link + * AbortedExecutionException}, where the former can be caught and acted upon, while the latter MUST + * NOT be caught, but simply propagated for clean up purposes. + * + *

Serialization and Deserialization

+ * + * The methods of this class that need to serialize or deserialize payloads have an overload both + * accepting {@link Class} or {@link TypeTag}. Depending on your case, you might use the {@link + * Class} overload for simple types, and {@link dev.restate.serde.TypeRef} for generic types. + * + *

By default, Jackson Databind will be used for all serialization/deserialization. Check {@link + * dev.restate.serde.SerdeFactory} for more details on how to customize that. + * + *

Thread safety

+ * + * This class MUST NOT be accessed concurrently since it can lead to different orderings of + * user actions, corrupting the execution of the invocation. + * + * @see Context + */ @org.jetbrains.annotations.ApiStatus.Experimental public final class Restate { + /** + * Get the base {@link Context} for the current handler invocation. + * + *

This method is safe to call from any Restate handler (Service, Virtual Object, or Workflow). + * + *

For handlers requiring access to state or promises, prefer using the specialized context + * getters: {@link #objectContext()}, {@link #sharedObjectContext()}, {@link #workflowContext()}, + * or {@link #sharedWorkflowContext()}. + * + * @return the current context + * @throws IllegalStateException if called outside a Restate handler + */ @org.jetbrains.annotations.ApiStatus.Experimental public static Context context() { - return HandlerRunner.getContext(); + return ContextThreadLocal.getContext(); } + /** + * Get the {@link ObjectContext} for the current Virtual Object handler invocation. + * + *

This context provides access to read and write state operations for Virtual Objects. It is + * safe to call this method only from exclusive Virtual Object handlers (non-shared handlers). + * + * @return the current object context + * @throws IllegalStateException if called from a shared Virtual Object handler (use {@link + * #sharedObjectContext()} instead) or from a non-Virtual Object handler + */ @org.jetbrains.annotations.ApiStatus.Experimental public static ObjectContext objectContext() { var handlerContext = HandlerRunner.getHandlerContext(); @@ -45,6 +119,16 @@ public static ObjectContext objectContext() { "Calling objectContext() from a non Virtual object handler. You can use Restate.objectContext() only inside a Restate Virtual Object handler."); } + /** + * Get the {@link SharedObjectContext} for the current Virtual Object shared handler invocation. + * + *

This context provides read-only access to state operations for Virtual Objects. It is safe + * to call this method from shared Virtual Object handlers that need to read state but not modify + * it. + * + * @return the current shared object context + * @throws IllegalStateException if called from a non-Virtual Object handler + */ @org.jetbrains.annotations.ApiStatus.Experimental public static SharedObjectContext sharedObjectContext() { var handlerContext = HandlerRunner.getHandlerContext(); @@ -57,6 +141,16 @@ public static SharedObjectContext sharedObjectContext() { "Calling objectContext() from a non Virtual object handler. You can use Restate.objectContext() only inside a Restate Virtual Object handler."); } + /** + * Get the {@link WorkflowContext} for the current Workflow handler invocation. + * + *

This context provides access to read and write promise operations for Workflows. It is safe + * to call this method only from exclusive Workflow handlers (non-shared handlers). + * + * @return the current workflow context + * @throws IllegalStateException if called from a shared Workflow handler (use {@link + * #sharedWorkflowContext()} instead) or from a non-Workflow handler + */ @org.jetbrains.annotations.ApiStatus.Experimental public static WorkflowContext workflowContext() { var handlerContext = HandlerRunner.getHandlerContext(); @@ -73,6 +167,15 @@ public static WorkflowContext workflowContext() { "Calling workflowContext() from a non Workflow handler. You can use Restate.workflowContext() only inside a Restate Workflow handler."); } + /** + * Get the {@link SharedWorkflowContext} for the current Workflow shared handler invocation. + * + *

This context provides read-only access to promise operations for Workflows. It is safe to + * call this method from shared Workflow handlers that need to read promises but not modify them. + * + * @return the current shared workflow context + * @throws IllegalStateException if called from a non-Workflow handler + */ @org.jetbrains.annotations.ApiStatus.Experimental public static SharedWorkflowContext sharedWorkflowContext() { var handlerContext = HandlerRunner.getHandlerContext(); @@ -85,54 +188,132 @@ public static SharedWorkflowContext sharedWorkflowContext() { "Calling workflowContext() from a non Workflow handler. You can use Restate.workflowContext() only inside a Restate Workflow handler."); } + /** + * Check if the current code is executing inside a Restate handler. + * + * @return true if currently inside a handler, false otherwise + */ @org.jetbrains.annotations.ApiStatus.Experimental public static boolean isInsideHandler() { - return HandlerRunner.CONTEXT_THREAD_LOCAL.get() != null; + return ContextThreadLocal.CONTEXT_THREAD_LOCAL.get() != null; } + /** @see Context#request() */ @org.jetbrains.annotations.ApiStatus.Experimental public static HandlerRequest request() { return context().request(); } + /** + * Returns a deterministic random. + * + * @see RestateRandom + * @see Context#random() + */ @org.jetbrains.annotations.ApiStatus.Experimental public static RestateRandom random() { return context().random(); } + /** @see Context#invocationHandle(String, TypeTag) */ @org.jetbrains.annotations.ApiStatus.Experimental public static InvocationHandle invocationHandle( String invocationId, TypeTag responseTypeTag) { return context().invocationHandle(invocationId, responseTypeTag); } + /** + * Get an {@link InvocationHandle} for an already existing invocation. This will let you interact + * with a running invocation, for example to cancel it or retrieve its result. + * + * @param invocationId The invocation to interact with. + * @param responseClazz The response class. + * @see Context#invocationHandle(String, Class) + */ @org.jetbrains.annotations.ApiStatus.Experimental public static InvocationHandle invocationHandle( String invocationId, Class responseClazz) { return context().invocationHandle(invocationId, responseClazz); } + /** + * Like {@link #invocationHandle(String, Class)}, without providing a response parser + * + * @see Context#invocationHandle(String) + */ @org.jetbrains.annotations.ApiStatus.Experimental public static InvocationHandle invocationHandle(String invocationId) { return context().invocationHandle(invocationId); } + /** + * Causes the current execution of the function invocation to sleep for the given duration. + * + * @param duration for which to sleep. + * @see Context#sleep(Duration) + */ @org.jetbrains.annotations.ApiStatus.Experimental public static void sleep(Duration duration) { context().sleep(duration); } + /** + * Causes the start of a timer for the given duration. You can await on the timer end by invoking + * {@link DurableFuture#await()}. + * + * @param name name used for observability + * @param duration for which to sleep. + * @see Context#timer(String, Duration) + */ @org.jetbrains.annotations.ApiStatus.Experimental public static DurableFuture timer(String name, Duration duration) { return context().timer(name, duration); } + /** + * Execute a closure, recording the result value in the journal. The result value will be + * re-played in case of re-invocation (e.g. because of failure recovery or suspension point) + * without re-executing the closure. + * + *

If the result type contains generic types, e.g. a {@code List}, you should use + * {@link #run(String, TypeTag, ThrowingSupplier)}. See {@link Context} for more details about + * serialization and deserialization. + * + *

You can name this closure using the {@code name} parameter. This name will be available in + * the observability tools. + * + *

The closure should tolerate retries, that is Restate might re-execute the closure multiple + * times until it records a result. You can control and limit the amount of retries using {@link + * #run(String, Class, RetryPolicy, ThrowingSupplier)}. + * + *

Error handling: Errors occurring within this closure won't be propagated to the + * caller, unless they are {@link TerminalException}. To propagate run failures to the call-site, + * make sure to wrap them in {@link TerminalException}. + * + * @param name name of the side effect. + * @param clazz the class of the return value, used to serialize/deserialize it. + * @param action closure to execute. + * @param type of the return value. + * @return value of the run operation. + * @see Context#run(String, Class, ThrowingSupplier) + */ @org.jetbrains.annotations.ApiStatus.Experimental public static T run(String name, Class clazz, ThrowingSupplier action) throws TerminalException { return context().run(name, clazz, action); } + /** + * Like {@link #run(String, TypeTag, ThrowingSupplier)}, but using a custom retry policy. + * + *

When a retry policy is not specified, the {@code run} will be retried using the Restate invoker retry policy, + * which by default retries indefinitely. + * + * @see #run(String, Class, ThrowingSupplier) + * @see RetryPolicy + * @see Context#run(String, TypeTag, RetryPolicy, ThrowingSupplier) + */ @org.jetbrains.annotations.ApiStatus.Experimental public static T run( String name, TypeTag typeTag, RetryPolicy retryPolicy, ThrowingSupplier action) @@ -140,6 +321,17 @@ public static T run( return context().run(name, typeTag, retryPolicy, action); } + /** + * Like {@link #run(String, Class, ThrowingSupplier)}, but using a custom retry policy. + * + *

When a retry policy is not specified, the {@code run} will be retried using the Restate invoker retry policy, + * which by default retries indefinitely. + * + * @see #run(String, Class, ThrowingSupplier) + * @see RetryPolicy + * @see Context#run(String, Class, RetryPolicy, ThrowingSupplier) + */ @org.jetbrains.annotations.ApiStatus.Experimental public static T run( String name, Class clazz, RetryPolicy retryPolicy, ThrowingSupplier action) @@ -147,35 +339,87 @@ public static T run( return context().run(name, clazz, retryPolicy, action); } + /** + * Like {@link #run(String, Class, ThrowingSupplier)}, but providing a {@link TypeTag}. + * + *

See {@link Context} for more details about serialization and deserialization. + * + * @see #run(String, Class, ThrowingSupplier) + * @see Context#run(String, TypeTag, ThrowingSupplier) + */ @org.jetbrains.annotations.ApiStatus.Experimental public static T run(String name, TypeTag typeTag, ThrowingSupplier action) throws TerminalException { return context().run(name, typeTag, action); } + /** + * Like {@link #run(String, ThrowingRunnable)}, but without a return value and using a custom + * retry policy. + * + *

When a retry policy is not specified, the {@code run} will be retried using the Restate invoker retry policy, + * which by default retries indefinitely. + * + * @see #run(String, Class, ThrowingSupplier) + * @see RetryPolicy + * @see Context#run(String, RetryPolicy, ThrowingRunnable) + */ @org.jetbrains.annotations.ApiStatus.Experimental public static void run(String name, RetryPolicy retryPolicy, ThrowingRunnable runnable) throws TerminalException { context().run(name, retryPolicy, runnable); } + /** + * Like {@link #run(String, Class, ThrowingSupplier)} without output. + * + * @see #run(String, Class, ThrowingSupplier) + * @see Context#run(String, ThrowingRunnable) + */ @org.jetbrains.annotations.ApiStatus.Experimental public static void run(String name, ThrowingRunnable runnable) throws TerminalException { context().run(name, runnable); } + /** + * Execute a closure asynchronously. This is like {@link #run(String, Class, ThrowingSupplier)}, + * but it returns a {@link DurableFuture} that you can combine and select. + * + * @see #run(String, Class, ThrowingSupplier) + * @see Context#runAsync(String, Class, ThrowingSupplier) + */ @org.jetbrains.annotations.ApiStatus.Experimental public static DurableFuture runAsync( String name, Class clazz, ThrowingSupplier action) throws TerminalException { return context().runAsync(name, clazz, action); } + /** + * Like {@link #runAsync(String, Class, ThrowingSupplier)}, but providing a {@link TypeTag}. + * + *

See {@link Context} for more details about serialization and deserialization. + * + * @see #runAsync(String, Class, ThrowingSupplier) + * @see Context#runAsync(String, TypeTag, ThrowingSupplier) + */ @org.jetbrains.annotations.ApiStatus.Experimental public static DurableFuture runAsync( String name, TypeTag typeTag, ThrowingSupplier action) throws TerminalException { return context().runAsync(name, typeTag, action); } + /** + * Like {@link #runAsync(String, Class, ThrowingSupplier)}, but using a custom retry policy. + * + *

When a retry policy is not specified, the {@code run} will be retried using the Restate invoker retry policy, + * which by default retries indefinitely. + * + * @see #runAsync(String, Class, ThrowingSupplier) + * @see RetryPolicy + * @see Context#runAsync(String, Class, RetryPolicy, ThrowingSupplier) + */ @org.jetbrains.annotations.ApiStatus.Experimental public static DurableFuture runAsync( String name, Class clazz, RetryPolicy retryPolicy, ThrowingSupplier action) @@ -183,6 +427,17 @@ public static DurableFuture runAsync( return context().runAsync(name, clazz, retryPolicy, action); } + /** + * Like {@link #runAsync(String, TypeTag, ThrowingSupplier)}, but using a custom retry policy. + * + *

When a retry policy is not specified, the {@code run} will be retried using the Restate invoker retry policy, + * which by default retries indefinitely. + * + * @see #runAsync(String, Class, ThrowingSupplier) + * @see RetryPolicy + * @see Context#runAsync(String, TypeTag, RetryPolicy, ThrowingSupplier) + */ @org.jetbrains.annotations.ApiStatus.Experimental public static DurableFuture runAsync( String name, TypeTag typeTag, RetryPolicy retryPolicy, ThrowingSupplier action) @@ -190,45 +445,168 @@ public static DurableFuture runAsync( return context().runAsync(name, typeTag, retryPolicy, action); } + /** + * Like {@link #runAsync(String, Class, ThrowingSupplier)}, but without an output and using a + * custom retry policy. + * + *

When a retry policy is not specified, the {@code run} will be retried using the Restate invoker retry policy, + * which by default retries indefinitely. + * + * @see #runAsync(String, Class, ThrowingSupplier) + * @see RetryPolicy + * @see Context#runAsync(String, RetryPolicy, ThrowingRunnable) + */ @org.jetbrains.annotations.ApiStatus.Experimental public static DurableFuture runAsync( String name, RetryPolicy retryPolicy, ThrowingRunnable runnable) throws TerminalException { return context().runAsync(name, retryPolicy, runnable); } + /** + * Like {@link #runAsync(String, Class, ThrowingSupplier)} without output. + * + * @see Context#runAsync(String, ThrowingRunnable) + */ @org.jetbrains.annotations.ApiStatus.Experimental public static DurableFuture runAsync(String name, ThrowingRunnable runnable) throws TerminalException { return context().runAsync(name, runnable); } + /** + * Create an {@link Awakeable}, addressable through {@link Awakeable#id()}. + * + *

You can use this feature to implement external asynchronous systems interactions, for + * example you can send a Kafka record including the {@link Awakeable#id()}, and then let another + * service consume from Kafka the responses of given external system interaction by using {@link + * #awakeableHandle(String)}. + * + * @param clazz the response type to use for deserializing the {@link Awakeable} result. When + * using generic types, use {@link #awakeable(TypeTag)} instead. + * @return the {@link Awakeable} to await on. + * @see Awakeable + * @see Context#awakeable(Class) + */ @org.jetbrains.annotations.ApiStatus.Experimental public static Awakeable awakeable(Class clazz) { return context().awakeable(clazz); } + /** + * Create an {@link Awakeable}, addressable through {@link Awakeable#id()}. + * + *

You can use this feature to implement external asynchronous systems interactions, for + * example you can send a Kafka record including the {@link Awakeable#id()}, and then let another + * service consume from Kafka the responses of given external system interaction by using {@link + * #awakeableHandle(String)}. + * + * @param typeTag the response type tag to use for deserializing the {@link Awakeable} result. + * @return the {@link Awakeable} to await on. + * @see Awakeable + * @see Context#awakeable(TypeTag) + */ @org.jetbrains.annotations.ApiStatus.Experimental public static Awakeable awakeable(TypeTag typeTag) { return context().awakeable(typeTag); } + /** + * Create a new {@link AwakeableHandle} for the provided identifier. You can use it to {@link + * AwakeableHandle#resolve(TypeTag, Object)} or {@link AwakeableHandle#reject(String)} the linked + * {@link Awakeable}. + * + * @see Awakeable + * @see Context#awakeableHandle(String) + */ @org.jetbrains.annotations.ApiStatus.Experimental public static AwakeableHandle awakeableHandle(String id) { return context().awakeableHandle(id); } + /** + * EXPERIMENTAL API: Create a reference to invoke a Restate service. + * + *

You can invoke the service in three ways: + * + *

{@code
+   * // 1. Create a client proxy and call it directly
+   * var greeterProxy = Restate.service(Greeter.class).client();
+   * GreetingResponse response = greeterProxy.greet(new Greeting("Alice"));
+   *
+   * // 2. Use call() with method reference and await the result
+   * GreetingResponse response = Restate.service(Greeter.class)
+   *   .call(Greeter::greet, new Greeting("Alice"))
+   *   .await();
+   *
+   * // 3. Use send() for one-way invocation without waiting
+   * InvocationHandle handle = Restate.service(Greeter.class)
+   *   .send(Greeter::greet, new Greeting("Alice"));
+   * }
+ * + * @param clazz the service class annotated with {@link Service} + * @return a reference to invoke the service + */ @org.jetbrains.annotations.ApiStatus.Experimental public static ServiceReference service(Class clazz) { mustHaveAnnotation(clazz, Service.class); return new ServiceReferenceImpl<>(clazz, null); } + /** + * EXPERIMENTAL API: Create a reference to invoke a Restate Virtual Object. + * + *

You can invoke the virtual object in three ways: + * + *

{@code
+   * // 1. Create a client proxy and call it directly
+   * var counterProxy = Restate.virtualObject(Counter.class, "my-counter").client();
+   * int count = counterProxy.increment();
+   *
+   * // 2. Use call() with method reference and await the result
+   * int count = Restate.virtualObject(Counter.class, "my-counter")
+   *   .call(Counter::increment)
+   *   .await();
+   *
+   * // 3. Use send() for one-way invocation without waiting
+   * InvocationHandle handle = Restate.virtualObject(Counter.class, "my-counter")
+   *   .send(Counter::increment);
+   * }
+ * + * @param clazz the virtual object class annotated with {@link VirtualObject} + * @param key the key identifying the specific virtual object instance + * @return a reference to invoke the virtual object + */ @org.jetbrains.annotations.ApiStatus.Experimental public static ServiceReference virtualObject(Class clazz, String key) { mustHaveAnnotation(clazz, VirtualObject.class); return new ServiceReferenceImpl<>(clazz, key); } + /** + * EXPERIMENTAL API: Create a reference to invoke a Restate Workflow. + * + *

You can invoke the workflow in three ways: + * + *

{@code
+   * // 1. Create a client proxy and call it directly
+   * var workflowProxy = Restate.workflow(OrderWorkflow.class, "order-123").client();
+   * workflowProxy.start(new OrderRequest(...));
+   *
+   * // 2. Use call() with method reference and await the result
+   * Restate.workflow(OrderWorkflow.class, "order-123")
+   *   .call(OrderWorkflow::start, new OrderRequest(...))
+   *   .await();
+   *
+   * // 3. Use send() for one-way invocation without waiting
+   * InvocationHandle handle = Restate.workflow(OrderWorkflow.class, "order-123")
+   *   .send(OrderWorkflow::start, new OrderRequest(...));
+   * }
+ * + * @param clazz the workflow class annotated with {@link Workflow} + * @param key the key identifying the specific workflow instance + * @return a reference to invoke the workflow + */ @org.jetbrains.annotations.ApiStatus.Experimental public static ServiceReference workflow(Class clazz, String key) { mustHaveAnnotation(clazz, Workflow.class); diff --git a/sdk-api/src/main/java/dev/restate/sdk/ServiceReference.java b/sdk-api/src/main/java/dev/restate/sdk/ServiceReference.java index 71ba1aa8..664644ba 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/ServiceReference.java +++ b/sdk-api/src/main/java/dev/restate/sdk/ServiceReference.java @@ -15,184 +15,338 @@ import java.util.function.Consumer; import java.util.function.Function; +/** + * EXPERIMENTAL API: This interface is part of the new reflection-based API and may change + * in future releases. + * + *

A reference to a Restate service, virtual object, or workflow that can be invoked from within + * a handler. Provides three ways to invoke methods: + * + *

{@code
+ * // 1. Create a client proxy and call it directly
+ * var greeterProxy = Restate.service(Greeter.class).client();
+ * GreetingResponse response = greeterProxy.greet(new Greeting("Alice"));
+ *
+ * // 2. Use call() with method reference and await the result
+ * GreetingResponse response = Restate.service(Greeter.class)
+ *   .call(Greeter::greet, new Greeting("Alice"))
+ *   .await();
+ *
+ * // 3. Use send() for one-way invocation without waiting
+ * InvocationHandle handle = Restate.service(Greeter.class)
+ *   .send(Greeter::greet, new Greeting("Alice"));
+ * }
+ * + *

Create instances using {@link Restate#service(Class)}, {@link + * Restate#virtualObject(Class, String)}, or {@link Restate#workflow(Class, String)}. + * + * @param the service interface type + */ @org.jetbrains.annotations.ApiStatus.Experimental public interface ServiceReference { + /** + * EXPERIMENTAL API: Get a client proxy to call methods directly. + * + *

{@code
+   * // Get a proxy and call methods on it
+   * var greeterProxy = Restate.service(Greeter.class).client();
+   * GreetingResponse response = greeterProxy.greet(new Greeting("Alice"));
+   * }
+ * + * @return a proxy instance of the service interface + */ @org.jetbrains.annotations.ApiStatus.Experimental SVC client(); + /** + * EXPERIMENTAL API: Invoke a service method with input and return a future for the result. + * + *
{@code
+   * // Call with method reference and input
+   * GreetingResponse response = Restate.service(Greeter.class)
+   *   .call(Greeter::greet, new Greeting("Alice"))
+   *   .await();
+   * }
+ * + * @param s method reference (e.g., {@code Greeter::greet}) + * @param input the input parameter to pass to the method + * @return a {@link DurableFuture} wrapping the result + */ @org.jetbrains.annotations.ApiStatus.Experimental default DurableFuture call(BiFunction s, I input) { return call(s, input, InvocationOptions.DEFAULT); } + /** + * EXPERIMENTAL API: Like {@link #call(BiFunction, Object)}, with invocation options. + * + *
{@code
+   * // Call with custom options
+   * var options = InvocationOptions.builder()
+   *   .idempotencyKey("unique-key")
+   *   .build();
+   * GreetingResponse response = Restate.service(Greeter.class)
+   *   .call(Greeter::greet, new Greeting("Alice"), options)
+   *   .await();
+   * }
+ */ @org.jetbrains.annotations.ApiStatus.Experimental default DurableFuture call( BiFunction s, I input, InvocationOptions.Builder options) { return call(s, input, options.build()); } + /** + * EXPERIMENTAL API: Like {@link #call(BiFunction, Object)}, with invocation options. + */ @org.jetbrains.annotations.ApiStatus.Experimental DurableFuture call(BiFunction s, I input, InvocationOptions options); + /** + * EXPERIMENTAL API: Like {@link #call(BiFunction, Object)}, for methods without a return + * value. + */ @org.jetbrains.annotations.ApiStatus.Experimental default DurableFuture call(BiConsumer s, I input) { return call(s, input, InvocationOptions.DEFAULT); } + /** EXPERIMENTAL API: Like {@link #call(BiConsumer, Object)}, with invocation options. */ @org.jetbrains.annotations.ApiStatus.Experimental default DurableFuture call( BiConsumer s, I input, InvocationOptions.Builder options) { return call(s, input, options.build()); } + /** EXPERIMENTAL API: Like {@link #call(BiConsumer, Object)}, with invocation options. */ @org.jetbrains.annotations.ApiStatus.Experimental DurableFuture call(BiConsumer s, I input, InvocationOptions options); + /** + * EXPERIMENTAL API: Invoke a service method without input and return a future for the + * result. + * + *
{@code
+   * // Call method without input
+   * int count = Restate.virtualObject(Counter.class, "my-counter")
+   *   .call(Counter::get)
+   *   .await();
+   * }
+ */ @org.jetbrains.annotations.ApiStatus.Experimental default DurableFuture call(Function s) { return call(s, InvocationOptions.DEFAULT); } + /** EXPERIMENTAL API: Like {@link #call(Function)}, with invocation options. */ @org.jetbrains.annotations.ApiStatus.Experimental default DurableFuture call(Function s, InvocationOptions.Builder options) { return call(s, options.build()); } + /** EXPERIMENTAL API: Like {@link #call(Function)}, with invocation options. */ @org.jetbrains.annotations.ApiStatus.Experimental DurableFuture call(Function s, InvocationOptions options); + /** + * EXPERIMENTAL API: Like {@link #call(BiFunction, Object)}, for methods without input or + * return value. + */ @org.jetbrains.annotations.ApiStatus.Experimental default DurableFuture call(Consumer s) { return call(s, InvocationOptions.DEFAULT); } + /** EXPERIMENTAL API: Like {@link #call(Consumer)}, with invocation options. */ @org.jetbrains.annotations.ApiStatus.Experimental default DurableFuture call(Consumer s, InvocationOptions.Builder options) { return call(s, options.build()); } + /** EXPERIMENTAL API: Like {@link #call(Consumer)}, with invocation options. */ @org.jetbrains.annotations.ApiStatus.Experimental DurableFuture call(Consumer s, InvocationOptions options); + /** + * EXPERIMENTAL API: Send a one-way invocation without waiting for the response. + * + *
{@code
+   * // Send without waiting for response
+   * InvocationHandle handle = Restate.service(Greeter.class)
+   *   .send(Greeter::greet, new Greeting("Alice"));
+   * String invocationId = handle.invocationId();
+   *
+   * // Send with a delay
+   * InvocationHandle handle = Restate.service(Greeter.class)
+   *   .send(Greeter::greet, new Greeting("Alice"), Duration.ofMinutes(5));
+   * }
+ * + * @param s method reference (e.g., {@code Greeter::greet}) + * @param input the input parameter to pass to the method + * @return an {@link InvocationHandle} for the invocation + */ @org.jetbrains.annotations.ApiStatus.Experimental default InvocationHandle send(BiFunction s, I input) { return send(s, input, InvocationOptions.DEFAULT); } + /** EXPERIMENTAL API: Like {@link #send(BiFunction, Object)}, with invocation options. */ @org.jetbrains.annotations.ApiStatus.Experimental default InvocationHandle send( BiFunction s, I input, InvocationOptions.Builder options) { return send(s, input, options.build()); } + /** EXPERIMENTAL API: Like {@link #send(BiFunction, Object)}, with invocation options. */ @org.jetbrains.annotations.ApiStatus.Experimental default InvocationHandle send( BiFunction s, I input, InvocationOptions options) { return send(s, input, null, options); } + /** EXPERIMENTAL API: Like {@link #send(BiFunction, Object)}, with a delay. */ @org.jetbrains.annotations.ApiStatus.Experimental default InvocationHandle send(BiFunction s, I input, Duration delay) { return send(s, input, delay, InvocationOptions.DEFAULT); } + /** + * EXPERIMENTAL API: Like {@link #send(BiFunction, Object)}, with a delay and invocation + * options. + */ @org.jetbrains.annotations.ApiStatus.Experimental default InvocationHandle send( BiFunction s, I input, Duration delay, InvocationOptions.Builder options) { return send(s, input, delay, options.build()); } + /** + * EXPERIMENTAL API: Like {@link #send(BiFunction, Object)}, with a delay and invocation + * options. + */ @org.jetbrains.annotations.ApiStatus.Experimental InvocationHandle send( BiFunction s, I input, Duration delay, InvocationOptions options); + /** + * EXPERIMENTAL API: Like {@link #send(BiFunction, Object)}, for methods without a return + * value. + */ @org.jetbrains.annotations.ApiStatus.Experimental default InvocationHandle send(BiConsumer s, I input) { return send(s, input, InvocationOptions.DEFAULT); } + /** EXPERIMENTAL API: Like {@link #send(BiConsumer, Object)}, with invocation options. */ @org.jetbrains.annotations.ApiStatus.Experimental default InvocationHandle send( BiConsumer s, I input, InvocationOptions.Builder options) { return send(s, input, options.build()); } + /** EXPERIMENTAL API: Like {@link #send(BiConsumer, Object)}, with invocation options. */ @org.jetbrains.annotations.ApiStatus.Experimental default InvocationHandle send( BiConsumer s, I input, InvocationOptions options) { return send(s, input, null, options); } + /** EXPERIMENTAL API: Like {@link #send(BiConsumer, Object)}, with a delay. */ @org.jetbrains.annotations.ApiStatus.Experimental default InvocationHandle send(BiConsumer s, I input, Duration delay) { return send(s, input, delay, InvocationOptions.DEFAULT); } + /** + * EXPERIMENTAL API: Like {@link #send(BiConsumer, Object)}, with a delay and invocation + * options. + */ @org.jetbrains.annotations.ApiStatus.Experimental default InvocationHandle send( BiConsumer s, I input, Duration delay, InvocationOptions.Builder options) { return send(s, input, delay, options.build()); } + /** + * EXPERIMENTAL API: Like {@link #send(BiConsumer, Object)}, with a delay and invocation + * options. + */ @org.jetbrains.annotations.ApiStatus.Experimental InvocationHandle send( BiConsumer s, I input, Duration delay, InvocationOptions options); + /** + * EXPERIMENTAL API: Like {@link #send(BiFunction, Object)}, for methods without input. + */ @org.jetbrains.annotations.ApiStatus.Experimental default InvocationHandle send(Function s) { return send(s, InvocationOptions.DEFAULT); } + /** EXPERIMENTAL API: Like {@link #send(Function)}, with invocation options. */ @org.jetbrains.annotations.ApiStatus.Experimental default InvocationHandle send(Function s, InvocationOptions.Builder options) { return send(s, options.build()); } + /** EXPERIMENTAL API: Like {@link #send(Function)}, with invocation options. */ @org.jetbrains.annotations.ApiStatus.Experimental default InvocationHandle send(Function s, InvocationOptions options) { return send(s, null, options); } + /** EXPERIMENTAL API: Like {@link #send(Function)}, with a delay. */ @org.jetbrains.annotations.ApiStatus.Experimental default InvocationHandle send(Function s, Duration delay) { return send(s, delay, InvocationOptions.DEFAULT); } + /** EXPERIMENTAL API: Like {@link #send(Function)}, with a delay and invocation options. */ @org.jetbrains.annotations.ApiStatus.Experimental default InvocationHandle send( Function s, Duration delay, InvocationOptions.Builder options) { return send(s, delay, options.build()); } + /** EXPERIMENTAL API: Like {@link #send(Function)}, with a delay and invocation options. */ @org.jetbrains.annotations.ApiStatus.Experimental InvocationHandle send(Function s, Duration delay, InvocationOptions options); + /** + * EXPERIMENTAL API: Like {@link #send(BiFunction, Object)}, for methods without input or + * return value. + */ @org.jetbrains.annotations.ApiStatus.Experimental default InvocationHandle send(Consumer s) { return send(s, InvocationOptions.DEFAULT); } + /** EXPERIMENTAL API: Like {@link #send(Consumer)}, with invocation options. */ @org.jetbrains.annotations.ApiStatus.Experimental default InvocationHandle send(Consumer s, InvocationOptions.Builder options) { return send(s, options.build()); } + /** EXPERIMENTAL API: Like {@link #send(Consumer)}, with invocation options. */ @org.jetbrains.annotations.ApiStatus.Experimental default InvocationHandle send(Consumer s, InvocationOptions options) { return send(s, null, options); } + /** EXPERIMENTAL API: Like {@link #send(Consumer)}, with a delay. */ @org.jetbrains.annotations.ApiStatus.Experimental default InvocationHandle send(Consumer s, Duration delay) { return send(s, delay, InvocationOptions.DEFAULT); } + /** EXPERIMENTAL API: Like {@link #send(Consumer)}, with a delay and invocation options. */ @org.jetbrains.annotations.ApiStatus.Experimental default InvocationHandle send( Consumer s, Duration delay, InvocationOptions.Builder options) { return send(s, delay, options.build()); } + /** EXPERIMENTAL API: Like {@link #send(Consumer)}, with a delay and invocation options. */ @org.jetbrains.annotations.ApiStatus.Experimental InvocationHandle send(Consumer s, Duration delay, InvocationOptions options); } diff --git a/sdk-api/src/main/java/dev/restate/sdk/internal/ContextThreadLocal.java b/sdk-api/src/main/java/dev/restate/sdk/internal/ContextThreadLocal.java new file mode 100644 index 00000000..867bef0b --- /dev/null +++ b/sdk-api/src/main/java/dev/restate/sdk/internal/ContextThreadLocal.java @@ -0,0 +1,25 @@ +package dev.restate.sdk.internal; + +import dev.restate.sdk.Context; + +import java.util.Objects; + +@org.jetbrains.annotations.ApiStatus.Internal +@org.jetbrains.annotations.ApiStatus.Experimental +public final class ContextThreadLocal { + public static final ThreadLocal CONTEXT_THREAD_LOCAL = new ThreadLocal<>(); + + public static Context getContext() { + return Objects.requireNonNull( + CONTEXT_THREAD_LOCAL.get(), + "Restate methods must be invoked from within a Restate handler"); + } + + public static void setContext(Context context) { + CONTEXT_THREAD_LOCAL.set(context); + } + + public static void clearContext() { + CONTEXT_THREAD_LOCAL.remove(); + } +} diff --git a/sdk-fake-api/src/main/java/dev/restate/sdk/fake/FakeRestate.java b/sdk-fake-api/src/main/java/dev/restate/sdk/fake/FakeRestate.java new file mode 100644 index 00000000..3d71d43a --- /dev/null +++ b/sdk-fake-api/src/main/java/dev/restate/sdk/fake/FakeRestate.java @@ -0,0 +1,122 @@ +package dev.restate.sdk.fake; + +import dev.restate.common.function.ThrowingRunnable; +import dev.restate.common.function.ThrowingSupplier; +import dev.restate.sdk.ContextInternal; +import dev.restate.sdk.endpoint.definition.HandlerRunner; +import dev.restate.sdk.internal.ContextThreadLocal; + +/** + * Fake Restate environment for testing handlers using the new reflection API. + * + *

This class provides utility methods to execute service methods that use the new reflection + * API (without explicit Context parameters) in a fake Restate context for testing purposes. + * + *

Example usage: + * + *

{@code
+ * @Test
+ * public void testGreeter() {
+ *     GreeterService greeter = new GreeterService();
+ *
+ *     // Execute the service method in a fake Restate context
+ *     String response = FakeRestate.execute(() -> greeter.greet(new Greeting("Francesco")));
+ *
+ *     assertEquals("You said hi to Francesco!", response);
+ * }
+ * }
+ * + *

For advanced scenarios, you can customize the context behavior using {@link + * ContextExpectations}: + * + *

{@code
+ * @Test
+ * public void testWithExpectations() {
+ *     GreeterService greeter = new GreeterService();
+ *
+ *     ContextExpectations expectations = new ContextExpectations()
+ *         .withRandom(new Random(42));
+ *
+ *     String response = FakeRestate.execute(expectations, () -> greeter.greet(new Greeting("Alice")));
+ *
+ *     assertEquals("Expected response", response);
+ * }
+ * }
+ */ +@org.jetbrains.annotations.ApiStatus.Experimental +public final class FakeRestate { + + /** + * Execute a runnable in a fake Restate context with default expectations. + * + * @param runnable the code to execute + */ + public static void execute(ThrowingRunnable runnable) { + execute(new ContextExpectations(), runnable); + } + + /** + * Execute a runnable in a fake Restate context with custom expectations. + * + * @param expectations the context expectations to use + * @param runnable the code to execute + */ + public static void execute(ContextExpectations expectations, ThrowingRunnable runnable) { + var fakeHandlerContext = new FakeHandlerContext(expectations); + var fakeContext = + ContextInternal.createContext( + fakeHandlerContext, Runnable::run, expectations.serdeFactory()); + HandlerRunner.HANDLER_CONTEXT_THREAD_LOCAL.set(fakeHandlerContext); + ContextThreadLocal.setContext(fakeContext); + try { +runnable.run(); + } catch (Throwable e) { + sneakyThrow(e); + } finally { + ContextThreadLocal.clearContext(); + HandlerRunner.HANDLER_CONTEXT_THREAD_LOCAL.remove(); + } + } + + /** + * Execute a supplier in a fake Restate context with default expectations and return the result. + * + * @param runnable the code to execute + * @param the return type + * @return the result of the supplier + */ + public static T execute( ThrowingSupplier runnable) { + return execute(new ContextExpectations(), runnable); + } + + /** + * Execute a supplier in a fake Restate context with custom expectations and return the result. + * + * @param expectations the context expectations to use + * @param runnable the code to execute + * @param the return type + * @return the result of the supplier + */ + public static T execute(ContextExpectations expectations, ThrowingSupplier runnable) { + var fakeHandlerContext = new FakeHandlerContext(expectations); + var fakeContext = + ContextInternal.createContext( + fakeHandlerContext, Runnable::run, expectations.serdeFactory()); + HandlerRunner.HANDLER_CONTEXT_THREAD_LOCAL.set(fakeHandlerContext); + ContextThreadLocal.setContext(fakeContext); + try { + return runnable.get(); + } catch (Throwable e) { + sneakyThrow(e); + return null; + } finally { + ContextThreadLocal.clearContext(); + HandlerRunner.HANDLER_CONTEXT_THREAD_LOCAL.remove(); + } + } + + @SuppressWarnings("unchecked") + private static void sneakyThrow(Throwable e) throws E { + throw (E) e; + } +} From a0cf91a90ee271880be2c18aa1a890b18118c749 Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Mon, 12 Jan 2026 12:40:58 +0100 Subject: [PATCH 7/7] Lil refactor of ProxySupport, added test with spring --- .../proxysupport/ByteBuddyProxyFactory.java | 16 +- ...v.restate.common.reflections.ProxyFactory} | 0 .../main/java/dev/restate/client/Client.java | 4 +- .../client/ClientServiceReference.java | 33 ++-- .../reflections/MethodInfoCollector.java | 2 +- .../common/reflections/ProxyFactory.java | 30 ++++ .../common/reflections/ProxySupport.java | 21 +-- .../model/AnnotationProcessingOptions.java | 4 +- .../dev/restate/sdk/gen/ServiceProcessor.java | 7 + sdk-api/build.gradle.kts | 2 + .../java/dev/restate/sdk/HandlerRunner.java | 11 +- .../main/java/dev/restate/sdk/Restate.java | 8 +- .../dev/restate/sdk/ServiceReference.java | 16 +- .../sdk/internal/ContextThreadLocal.java | 33 ++-- .../restate/sdk/core/EndpointManifest.java | 6 +- .../dev/restate/sdk/fake/FakeRestate.java | 146 +++++++++--------- sdk-spring-boot-starter/build.gradle.kts | 11 +- .../sdk/springboot/java/GreeterNewApi.java | 29 ++++ .../java/SdkTestingIntegrationTest.java | 11 +- 19 files changed, 242 insertions(+), 148 deletions(-) rename bytebuddy-proxy-support/src/main/resources/META-INF/services/{dev.restate.common.reflections.ProxySupport.ProxyFactory => dev.restate.common.reflections.ProxyFactory} (100%) create mode 100644 common/src/main/java/dev/restate/common/reflections/ProxyFactory.java create mode 100644 sdk-spring-boot-starter/src/test/java/dev/restate/sdk/springboot/java/GreeterNewApi.java diff --git a/bytebuddy-proxy-support/src/main/java/dev/restate/bytebuddy/proxysupport/ByteBuddyProxyFactory.java b/bytebuddy-proxy-support/src/main/java/dev/restate/bytebuddy/proxysupport/ByteBuddyProxyFactory.java index 879351da..f2252d3a 100644 --- a/bytebuddy-proxy-support/src/main/java/dev/restate/bytebuddy/proxysupport/ByteBuddyProxyFactory.java +++ b/bytebuddy-proxy-support/src/main/java/dev/restate/bytebuddy/proxysupport/ByteBuddyProxyFactory.java @@ -8,7 +8,7 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.bytebuddy.proxysupport; -import dev.restate.common.reflections.ProxySupport; +import dev.restate.common.reflections.ProxyFactory; import java.lang.reflect.Field; import java.lang.reflect.Method; import java.lang.reflect.Modifier; @@ -29,7 +29,7 @@ * constructor. Uses TypeCache to cache generated proxy classes for better performance * (thread-safe). */ -public final class ByteBuddyProxyFactory implements ProxySupport.ProxyFactory { +public final class ByteBuddyProxyFactory implements ProxyFactory { private static final String INTERCEPTOR_FIELD_NAME = "$$interceptor$$"; @@ -39,7 +39,7 @@ public final class ByteBuddyProxyFactory implements ProxySupport.ProxyFactory { @Override @SuppressWarnings("unchecked") - public @Nullable T createProxy(Class clazz, ProxySupport.MethodInterceptor interceptor) { + public @Nullable T createProxy(Class clazz, MethodInterceptor interceptor) { // Cannot proxy final classes if (Modifier.isFinal(clazz.getModifiers())) { return null; @@ -79,8 +79,7 @@ private Class generateProxyClass(Class clazz) { try (var unloaded = builder // Add a field to store the interceptor - .defineField( - INTERCEPTOR_FIELD_NAME, ProxySupport.MethodInterceptor.class, Visibility.PUBLIC) + .defineField(INTERCEPTOR_FIELD_NAME, MethodInterceptor.class, Visibility.PUBLIC) // Intercept all methods .method(ElementMatchers.any()) .intercept( @@ -89,15 +88,14 @@ private Class generateProxyClass(Class clazz) { // Get the interceptor from the field Field field = proxy.getClass().getDeclaredField(INTERCEPTOR_FIELD_NAME); field.setAccessible(true); - ProxySupport.MethodInterceptor interceptor = - (ProxySupport.MethodInterceptor) field.get(proxy); + MethodInterceptor interceptor = (MethodInterceptor) field.get(proxy); if (interceptor == null) { throw new IllegalStateException("Interceptor not set on proxy instance"); } - ProxySupport.MethodInvocation invocation = - new ProxySupport.MethodInvocation() { + MethodInvocation invocation = + new MethodInvocation() { @Override public Object[] getArguments() { return args != null ? args : new Object[0]; diff --git a/bytebuddy-proxy-support/src/main/resources/META-INF/services/dev.restate.common.reflections.ProxySupport.ProxyFactory b/bytebuddy-proxy-support/src/main/resources/META-INF/services/dev.restate.common.reflections.ProxyFactory similarity index 100% rename from bytebuddy-proxy-support/src/main/resources/META-INF/services/dev.restate.common.reflections.ProxySupport.ProxyFactory rename to bytebuddy-proxy-support/src/main/resources/META-INF/services/dev.restate.common.reflections.ProxyFactory diff --git a/client/src/main/java/dev/restate/client/Client.java b/client/src/main/java/dev/restate/client/Client.java index db98dca7..28df265b 100644 --- a/client/src/main/java/dev/restate/client/Client.java +++ b/client/src/main/java/dev/restate/client/Client.java @@ -594,8 +594,8 @@ default ClientServiceReference virtualObject(Class clazz, String } /** - * EXPERIMENTAL API: Create a reference to invoke a Restate Workflow from the ingress. - * This API may change in future releases. + * EXPERIMENTAL API: Create a reference to invoke a Restate Workflow from the ingress. This + * API may change in future releases. * *

You can invoke the workflow in three ways: * diff --git a/client/src/main/java/dev/restate/client/ClientServiceReference.java b/client/src/main/java/dev/restate/client/ClientServiceReference.java index 41a50bb7..5e2c4ae3 100644 --- a/client/src/main/java/dev/restate/client/ClientServiceReference.java +++ b/client/src/main/java/dev/restate/client/ClientServiceReference.java @@ -18,8 +18,8 @@ import java.util.function.Function; /** - * EXPERIMENTAL API: This interface is part of the new reflection-based API and may change - * in future releases. + * EXPERIMENTAL API: This interface is part of the new reflection-based API and may change in + * future releases. * *

A reference to a Restate service, virtual object, or workflow that can be invoked from the * ingress (outside of a handler). Provides three ways to invoke methods: @@ -40,8 +40,8 @@ * .send(Greeter::greet, new Greeting("Alice")); * } * - *

Create instances using {@link Client#service(Class)}, {@link - * Client#virtualObject(Class, String)}, or {@link Client#workflow(Class, String)}. + *

Create instances using {@link Client#service(Class)}, {@link Client#virtualObject(Class, + * String)}, or {@link Client#workflow(Class, String)}. * * @param the service interface type */ @@ -356,9 +356,7 @@ default SendResponse send( } // send - Function variants - /** - * EXPERIMENTAL API: Like {@link #send(BiFunction, Object)}, for methods without input. - */ + /** EXPERIMENTAL API: Like {@link #send(BiFunction, Object)}, for methods without input. */ @org.jetbrains.annotations.ApiStatus.Experimental default SendResponse send(Function s) { return send(s, InvocationOptions.DEFAULT); @@ -480,14 +478,18 @@ default CompletableFuture> sendAsync( return sendAsync(s, input, delay, InvocationOptions.DEFAULT); } - /** EXPERIMENTAL API: Like {@link #sendAsync(BiFunction, Object)}, with delay and options. */ + /** + * EXPERIMENTAL API: Like {@link #sendAsync(BiFunction, Object)}, with delay and options. + */ @org.jetbrains.annotations.ApiStatus.Experimental default CompletableFuture> sendAsync( BiFunction s, I input, Duration delay, InvocationOptions.Builder options) { return sendAsync(s, input, delay, options.build()); } - /** EXPERIMENTAL API: Like {@link #sendAsync(BiFunction, Object)}, with delay and options. */ + /** + * EXPERIMENTAL API: Like {@link #sendAsync(BiFunction, Object)}, with delay and options. + */ @org.jetbrains.annotations.ApiStatus.Experimental CompletableFuture> sendAsync( BiFunction s, I input, Duration delay, InvocationOptions invocationOptions); @@ -520,14 +522,18 @@ default CompletableFuture> sendAsync( return sendAsync(s, input, delay, InvocationOptions.DEFAULT); } - /** EXPERIMENTAL API: Like {@link #sendAsync(BiConsumer, Object)}, with delay and options. */ + /** + * EXPERIMENTAL API: Like {@link #sendAsync(BiConsumer, Object)}, with delay and options. + */ @org.jetbrains.annotations.ApiStatus.Experimental default CompletableFuture> sendAsync( BiConsumer s, I input, Duration delay, InvocationOptions.Builder options) { return sendAsync(s, input, delay, options.build()); } - /** EXPERIMENTAL API: Like {@link #sendAsync(BiConsumer, Object)}, with delay and options. */ + /** + * EXPERIMENTAL API: Like {@link #sendAsync(BiConsumer, Object)}, with delay and options. + */ @org.jetbrains.annotations.ApiStatus.Experimental CompletableFuture> sendAsync( BiConsumer s, I input, Duration delay, InvocationOptions invocationOptions); @@ -572,7 +578,10 @@ CompletableFuture> sendAsync( Function s, Duration delay, InvocationOptions invocationOptions); // sendAsync - Consumer variants - /** EXPERIMENTAL API: Like {@link #sendAsync(BiFunction, Object)}, for no-input/void methods. */ + /** + * EXPERIMENTAL API: Like {@link #sendAsync(BiFunction, Object)}, for no-input/void + * methods. + */ @org.jetbrains.annotations.ApiStatus.Experimental default CompletableFuture> sendAsync(Consumer s) { return sendAsync(s, InvocationOptions.DEFAULT); diff --git a/common/src/main/java/dev/restate/common/reflections/MethodInfoCollector.java b/common/src/main/java/dev/restate/common/reflections/MethodInfoCollector.java index 9f5d2196..797e92a7 100644 --- a/common/src/main/java/dev/restate/common/reflections/MethodInfoCollector.java +++ b/common/src/main/java/dev/restate/common/reflections/MethodInfoCollector.java @@ -61,7 +61,7 @@ public MethodInfo resolve(Consumer s) { } } - private static final ProxySupport.MethodInterceptor METHOD_INFO_COLLECTOR_INTERCEPTOR = + private static final ProxyFactory.MethodInterceptor METHOD_INFO_COLLECTOR_INTERCEPTOR = invocation -> { throw MethodInfo.fromMethod(invocation.getMethod()); }; diff --git a/common/src/main/java/dev/restate/common/reflections/ProxyFactory.java b/common/src/main/java/dev/restate/common/reflections/ProxyFactory.java new file mode 100644 index 00000000..b862aaea --- /dev/null +++ b/common/src/main/java/dev/restate/common/reflections/ProxyFactory.java @@ -0,0 +1,30 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.common.reflections; + +import java.lang.reflect.Method; +import org.jspecify.annotations.Nullable; + +@FunctionalInterface +public interface ProxyFactory { + + interface MethodInvocation { + Object[] getArguments(); + + Method getMethod(); + } + + @FunctionalInterface + interface MethodInterceptor { + @Nullable Object invoke(MethodInvocation invocation) throws Throwable; + } + + /** If returns null, it's not supported. */ + @Nullable T createProxy(Class clazz, MethodInterceptor interceptor); +} diff --git a/common/src/main/java/dev/restate/common/reflections/ProxySupport.java b/common/src/main/java/dev/restate/common/reflections/ProxySupport.java index 4b0749c6..24cacb13 100644 --- a/common/src/main/java/dev/restate/common/reflections/ProxySupport.java +++ b/common/src/main/java/dev/restate/common/reflections/ProxySupport.java @@ -48,7 +48,7 @@ public ProxySupport() { } /** Resolve the code generated {@link ProxyFactory} */ - public static T createProxy(Class clazz, MethodInterceptor interceptor) { + public static T createProxy(Class clazz, ProxyFactory.MethodInterceptor interceptor) { ProxySupport proxySupport = ProxySupportSingleton.INSTANCE; for (ProxyFactory proxyFactory : proxySupport.factories) { @@ -61,29 +61,12 @@ public static T createProxy(Class clazz, MethodInterceptor interceptor) { throw new IllegalStateException( "Class " + clazz.toString() - + " cannot be proxied. If the type is a concrete class, make sure to have sdk-proxy-bytebuddy in your dependencies. Registered proxies: " + + " cannot be proxied. If the type is a concrete class, make sure to have bytebuddy-proxy-support in your dependencies. Registered ProxyFactory: " + proxySupport.factories.stream() .map(pf -> pf.getClass().toString()) .collect(Collectors.joining(", "))); } - public interface MethodInvocation { - Object[] getArguments(); - - Method getMethod(); - } - - @FunctionalInterface - public interface MethodInterceptor { - @Nullable Object invoke(MethodInvocation invocation) throws Throwable; - } - - @FunctionalInterface - public interface ProxyFactory { - /** If returns null, it's not supported. */ - @Nullable T createProxy(Class clazz, MethodInterceptor interceptor); - } - private static final class JdkProxyFactory implements ProxyFactory { /** diff --git a/sdk-api-gen-common/src/main/java/dev/restate/sdk/gen/model/AnnotationProcessingOptions.java b/sdk-api-gen-common/src/main/java/dev/restate/sdk/gen/model/AnnotationProcessingOptions.java index 35328c36..61ba4285 100644 --- a/sdk-api-gen-common/src/main/java/dev/restate/sdk/gen/model/AnnotationProcessingOptions.java +++ b/sdk-api-gen-common/src/main/java/dev/restate/sdk/gen/model/AnnotationProcessingOptions.java @@ -12,10 +12,10 @@ public class AnnotationProcessingOptions { - private static final String DISABLED_CLIENT_GENERATION = + public static final String DISABLED_CLIENT_GENERATION = "dev.restate.codegen.disabledClientGeneration"; - private static final String DISABLED_CLASSES = "dev.restate.codegen.disabledClasses"; + public static final String DISABLED_CLASSES = "dev.restate.codegen.disabledClasses"; private final Set disabledClientGenFQCN; private final Set disabledClasses; diff --git a/sdk-api-gen/src/main/java/dev/restate/sdk/gen/ServiceProcessor.java b/sdk-api-gen/src/main/java/dev/restate/sdk/gen/ServiceProcessor.java index 46320b60..b2fadc2b 100644 --- a/sdk-api-gen/src/main/java/dev/restate/sdk/gen/ServiceProcessor.java +++ b/sdk-api-gen/src/main/java/dev/restate/sdk/gen/ServiceProcessor.java @@ -163,6 +163,13 @@ public SourceVersion getSupportedSourceVersion() { return SourceVersion.latestSupported(); } + @Override + public Set getSupportedOptions() { + return Set.of( + AnnotationProcessingOptions.DISABLED_CLASSES, + AnnotationProcessingOptions.DISABLED_CLIENT_GENERATION); + } + public static Path readOrCreateResource(Filer filer, String file) throws IOException { try { FileObject fileObject = filer.getResource(StandardLocation.CLASS_OUTPUT, "", file); diff --git a/sdk-api/build.gradle.kts b/sdk-api/build.gradle.kts index 672b6dad..88fe65a8 100644 --- a/sdk-api/build.gradle.kts +++ b/sdk-api/build.gradle.kts @@ -14,4 +14,6 @@ dependencies { api(project(":sdk-serde-jackson")) implementation(libs.log4j.api) + + runtimeOnly(project(":bytebuddy-proxy-support")) { isTransitive = true } } diff --git a/sdk-api/src/main/java/dev/restate/sdk/HandlerRunner.java b/sdk-api/src/main/java/dev/restate/sdk/HandlerRunner.java index 48854481..fbb1fcbf 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/HandlerRunner.java +++ b/sdk-api/src/main/java/dev/restate/sdk/HandlerRunner.java @@ -19,7 +19,6 @@ import dev.restate.serde.Serde; import dev.restate.serde.SerdeFactory; import io.opentelemetry.context.Scope; - import java.util.Objects; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; @@ -204,9 +203,9 @@ public static Options withExecutor(Executor executor) { } } - static HandlerContext getHandlerContext() { - return Objects.requireNonNull( - dev.restate.sdk.endpoint.definition.HandlerRunner.HANDLER_CONTEXT_THREAD_LOCAL.get(), - "Restate methods must be invoked from within a Restate handler"); - } + static HandlerContext getHandlerContext() { + return Objects.requireNonNull( + dev.restate.sdk.endpoint.definition.HandlerRunner.HANDLER_CONTEXT_THREAD_LOCAL.get(), + "Restate methods must be invoked from within a Restate handler"); + } } diff --git a/sdk-api/src/main/java/dev/restate/sdk/Restate.java b/sdk-api/src/main/java/dev/restate/sdk/Restate.java index 85e6b896..9132de99 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/Restate.java +++ b/sdk-api/src/main/java/dev/restate/sdk/Restate.java @@ -198,7 +198,9 @@ public static boolean isInsideHandler() { return ContextThreadLocal.CONTEXT_THREAD_LOCAL.get() != null; } - /** @see Context#request() */ + /** + * @see Context#request() + */ @org.jetbrains.annotations.ApiStatus.Experimental public static HandlerRequest request() { return context().request(); @@ -215,7 +217,9 @@ public static RestateRandom random() { return context().random(); } - /** @see Context#invocationHandle(String, TypeTag) */ + /** + * @see Context#invocationHandle(String, TypeTag) + */ @org.jetbrains.annotations.ApiStatus.Experimental public static InvocationHandle invocationHandle( String invocationId, TypeTag responseTypeTag) { diff --git a/sdk-api/src/main/java/dev/restate/sdk/ServiceReference.java b/sdk-api/src/main/java/dev/restate/sdk/ServiceReference.java index 664644ba..b43a6cbb 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/ServiceReference.java +++ b/sdk-api/src/main/java/dev/restate/sdk/ServiceReference.java @@ -16,8 +16,8 @@ import java.util.function.Function; /** - * EXPERIMENTAL API: This interface is part of the new reflection-based API and may change - * in future releases. + * EXPERIMENTAL API: This interface is part of the new reflection-based API and may change in + * future releases. * *

A reference to a Restate service, virtual object, or workflow that can be invoked from within * a handler. Provides three ways to invoke methods: @@ -37,8 +37,8 @@ * .send(Greeter::greet, new Greeting("Alice")); * } * - *

Create instances using {@link Restate#service(Class)}, {@link - * Restate#virtualObject(Class, String)}, or {@link Restate#workflow(Class, String)}. + *

Create instances using {@link Restate#service(Class)}, {@link Restate#virtualObject(Class, + * String)}, or {@link Restate#workflow(Class, String)}. * * @param the service interface type */ @@ -96,9 +96,7 @@ default DurableFuture call( return call(s, input, options.build()); } - /** - * EXPERIMENTAL API: Like {@link #call(BiFunction, Object)}, with invocation options. - */ + /** EXPERIMENTAL API: Like {@link #call(BiFunction, Object)}, with invocation options. */ @org.jetbrains.annotations.ApiStatus.Experimental DurableFuture call(BiFunction s, I input, InvocationOptions options); @@ -275,9 +273,7 @@ default InvocationHandle send( InvocationHandle send( BiConsumer s, I input, Duration delay, InvocationOptions options); - /** - * EXPERIMENTAL API: Like {@link #send(BiFunction, Object)}, for methods without input. - */ + /** EXPERIMENTAL API: Like {@link #send(BiFunction, Object)}, for methods without input. */ @org.jetbrains.annotations.ApiStatus.Experimental default InvocationHandle send(Function s) { return send(s, InvocationOptions.DEFAULT); diff --git a/sdk-api/src/main/java/dev/restate/sdk/internal/ContextThreadLocal.java b/sdk-api/src/main/java/dev/restate/sdk/internal/ContextThreadLocal.java index 867bef0b..4a761951 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/internal/ContextThreadLocal.java +++ b/sdk-api/src/main/java/dev/restate/sdk/internal/ContextThreadLocal.java @@ -1,25 +1,32 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.internal; import dev.restate.sdk.Context; - import java.util.Objects; @org.jetbrains.annotations.ApiStatus.Internal @org.jetbrains.annotations.ApiStatus.Experimental public final class ContextThreadLocal { - public static final ThreadLocal CONTEXT_THREAD_LOCAL = new ThreadLocal<>(); + public static final ThreadLocal CONTEXT_THREAD_LOCAL = new ThreadLocal<>(); - public static Context getContext() { - return Objects.requireNonNull( - CONTEXT_THREAD_LOCAL.get(), - "Restate methods must be invoked from within a Restate handler"); - } + public static Context getContext() { + return Objects.requireNonNull( + CONTEXT_THREAD_LOCAL.get(), + "Restate methods must be invoked from within a Restate handler"); + } - public static void setContext(Context context) { - CONTEXT_THREAD_LOCAL.set(context); - } + public static void setContext(Context context) { + CONTEXT_THREAD_LOCAL.set(context); + } - public static void clearContext() { - CONTEXT_THREAD_LOCAL.remove(); - } + public static void clearContext() { + CONTEXT_THREAD_LOCAL.remove(); + } } diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/EndpointManifest.java b/sdk-core/src/main/java/dev/restate/sdk/core/EndpointManifest.java index a9d1fabd..1b6b513f 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/EndpointManifest.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/EndpointManifest.java @@ -22,6 +22,7 @@ import java.util.function.Predicate; import java.util.stream.Collectors; import java.util.stream.Stream; +import org.jspecify.annotations.Nullable; final class EndpointManifest { @@ -278,7 +279,10 @@ private static Output convertHandlerOutput(HandlerDefinition def) { return output; } - private static Handler.Ty convertHandlerType(HandlerType handlerType) { + private static Handler.Ty convertHandlerType(@Nullable HandlerType handlerType) { + if (handlerType == null) { + return null; + } return switch (handlerType) { case WORKFLOW -> Handler.Ty.WORKFLOW; case EXCLUSIVE -> Handler.Ty.EXCLUSIVE; diff --git a/sdk-fake-api/src/main/java/dev/restate/sdk/fake/FakeRestate.java b/sdk-fake-api/src/main/java/dev/restate/sdk/fake/FakeRestate.java index 3d71d43a..afda2f2e 100644 --- a/sdk-fake-api/src/main/java/dev/restate/sdk/fake/FakeRestate.java +++ b/sdk-fake-api/src/main/java/dev/restate/sdk/fake/FakeRestate.java @@ -1,3 +1,11 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.fake; import dev.restate.common.function.ThrowingRunnable; @@ -9,8 +17,8 @@ /** * Fake Restate environment for testing handlers using the new reflection API. * - *

This class provides utility methods to execute service methods that use the new reflection - * API (without explicit Context parameters) in a fake Restate context for testing purposes. + *

This class provides utility methods to execute service methods that use the new reflection API + * (without explicit Context parameters) in a fake Restate context for testing purposes. * *

Example usage: * @@ -46,77 +54,77 @@ @org.jetbrains.annotations.ApiStatus.Experimental public final class FakeRestate { - /** - * Execute a runnable in a fake Restate context with default expectations. - * - * @param runnable the code to execute - */ - public static void execute(ThrowingRunnable runnable) { - execute(new ContextExpectations(), runnable); - } + /** + * Execute a runnable in a fake Restate context with default expectations. + * + * @param runnable the code to execute + */ + public static void execute(ThrowingRunnable runnable) { + execute(new ContextExpectations(), runnable); + } - /** - * Execute a runnable in a fake Restate context with custom expectations. - * - * @param expectations the context expectations to use - * @param runnable the code to execute - */ - public static void execute(ContextExpectations expectations, ThrowingRunnable runnable) { - var fakeHandlerContext = new FakeHandlerContext(expectations); - var fakeContext = - ContextInternal.createContext( - fakeHandlerContext, Runnable::run, expectations.serdeFactory()); - HandlerRunner.HANDLER_CONTEXT_THREAD_LOCAL.set(fakeHandlerContext); - ContextThreadLocal.setContext(fakeContext); - try { -runnable.run(); - } catch (Throwable e) { - sneakyThrow(e); - } finally { - ContextThreadLocal.clearContext(); - HandlerRunner.HANDLER_CONTEXT_THREAD_LOCAL.remove(); - } + /** + * Execute a runnable in a fake Restate context with custom expectations. + * + * @param expectations the context expectations to use + * @param runnable the code to execute + */ + public static void execute(ContextExpectations expectations, ThrowingRunnable runnable) { + var fakeHandlerContext = new FakeHandlerContext(expectations); + var fakeContext = + ContextInternal.createContext( + fakeHandlerContext, Runnable::run, expectations.serdeFactory()); + HandlerRunner.HANDLER_CONTEXT_THREAD_LOCAL.set(fakeHandlerContext); + ContextThreadLocal.setContext(fakeContext); + try { + runnable.run(); + } catch (Throwable e) { + sneakyThrow(e); + } finally { + ContextThreadLocal.clearContext(); + HandlerRunner.HANDLER_CONTEXT_THREAD_LOCAL.remove(); } + } - /** - * Execute a supplier in a fake Restate context with default expectations and return the result. - * - * @param runnable the code to execute - * @param the return type - * @return the result of the supplier - */ - public static T execute( ThrowingSupplier runnable) { - return execute(new ContextExpectations(), runnable); - } + /** + * Execute a supplier in a fake Restate context with default expectations and return the result. + * + * @param runnable the code to execute + * @param the return type + * @return the result of the supplier + */ + public static T execute(ThrowingSupplier runnable) { + return execute(new ContextExpectations(), runnable); + } - /** - * Execute a supplier in a fake Restate context with custom expectations and return the result. - * - * @param expectations the context expectations to use - * @param runnable the code to execute - * @param the return type - * @return the result of the supplier - */ - public static T execute(ContextExpectations expectations, ThrowingSupplier runnable) { - var fakeHandlerContext = new FakeHandlerContext(expectations); - var fakeContext = - ContextInternal.createContext( - fakeHandlerContext, Runnable::run, expectations.serdeFactory()); - HandlerRunner.HANDLER_CONTEXT_THREAD_LOCAL.set(fakeHandlerContext); - ContextThreadLocal.setContext(fakeContext); - try { - return runnable.get(); - } catch (Throwable e) { - sneakyThrow(e); - return null; - } finally { - ContextThreadLocal.clearContext(); - HandlerRunner.HANDLER_CONTEXT_THREAD_LOCAL.remove(); - } + /** + * Execute a supplier in a fake Restate context with custom expectations and return the result. + * + * @param expectations the context expectations to use + * @param runnable the code to execute + * @param the return type + * @return the result of the supplier + */ + public static T execute(ContextExpectations expectations, ThrowingSupplier runnable) { + var fakeHandlerContext = new FakeHandlerContext(expectations); + var fakeContext = + ContextInternal.createContext( + fakeHandlerContext, Runnable::run, expectations.serdeFactory()); + HandlerRunner.HANDLER_CONTEXT_THREAD_LOCAL.set(fakeHandlerContext); + ContextThreadLocal.setContext(fakeContext); + try { + return runnable.get(); + } catch (Throwable e) { + sneakyThrow(e); + return null; + } finally { + ContextThreadLocal.clearContext(); + HandlerRunner.HANDLER_CONTEXT_THREAD_LOCAL.remove(); } + } - @SuppressWarnings("unchecked") - private static void sneakyThrow(Throwable e) throws E { - throw (E) e; - } + @SuppressWarnings("unchecked") + private static void sneakyThrow(Throwable e) throws E { + throw (E) e; + } } diff --git a/sdk-spring-boot-starter/build.gradle.kts b/sdk-spring-boot-starter/build.gradle.kts index 58543d7d..8f263b0e 100644 --- a/sdk-spring-boot-starter/build.gradle.kts +++ b/sdk-spring-boot-starter/build.gradle.kts @@ -23,6 +23,7 @@ dependencies { api(project(":sdk-api"), excludeJackson) api(project(":client"), excludeJackson) api(project(":sdk-serde-jackson"), excludeJackson) + runtimeOnly(project(":bytebuddy-proxy-support")) { isTransitive = true } // Spring boot starter brought in here for convenience api(libs.spring.boot.starter) @@ -38,4 +39,12 @@ dependencies { testImplementation(project(":sdk-testing")) } -tasks.withType { options.compilerArgs.add("-parameters") } +tasks.withType { + val disabledClassesCodegen = listOf("dev.restate.sdk.springboot.java.GreeterNewApi") + + options.compilerArgs.addAll( + listOf( + "-parameters", + "-Adev.restate.codegen.disabledClasses=${disabledClassesCodegen.joinToString(",")}", + )) +} diff --git a/sdk-spring-boot-starter/src/test/java/dev/restate/sdk/springboot/java/GreeterNewApi.java b/sdk-spring-boot-starter/src/test/java/dev/restate/sdk/springboot/java/GreeterNewApi.java new file mode 100644 index 00000000..ad64db27 --- /dev/null +++ b/sdk-spring-boot-starter/src/test/java/dev/restate/sdk/springboot/java/GreeterNewApi.java @@ -0,0 +1,29 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.springboot.java; + +import dev.restate.sdk.annotation.Handler; +import dev.restate.sdk.annotation.Name; +import dev.restate.sdk.annotation.Service; +import dev.restate.sdk.springboot.RestateComponent; +import org.springframework.beans.factory.annotation.Value; + +@Service +@RestateComponent +@Name("greeterNewApi") +public class GreeterNewApi { + + @Value("${greetingPrefix}") + private String greetingPrefix; + + @Handler + public String greet(String person) { + return greetingPrefix + person; + } +} diff --git a/sdk-spring-boot-starter/src/test/java/dev/restate/sdk/springboot/java/SdkTestingIntegrationTest.java b/sdk-spring-boot-starter/src/test/java/dev/restate/sdk/springboot/java/SdkTestingIntegrationTest.java index 5c5fe57e..20e2d86c 100644 --- a/sdk-spring-boot-starter/src/test/java/dev/restate/sdk/springboot/java/SdkTestingIntegrationTest.java +++ b/sdk-spring-boot-starter/src/test/java/dev/restate/sdk/springboot/java/SdkTestingIntegrationTest.java @@ -18,12 +18,13 @@ import org.springframework.boot.test.context.SpringBootTest; @SpringBootTest( - classes = Greeter.class, + classes = {Greeter.class, GreeterNewApi.class}, properties = {"greetingPrefix=Something something "}) @RestateTest(containerImage = "ghcr.io/restatedev/restate:main") public class SdkTestingIntegrationTest { @Autowired @BindService private Greeter greeter; + @Autowired @BindService private GreeterNewApi greeterNewApi; @Test @Timeout(value = 10) @@ -32,4 +33,12 @@ void greet(@RestateClient Client ingressClient) { assertThat(client.greet("Francesco")).isEqualTo("Something something Francesco"); } + + @Test + @Timeout(value = 10) + void greetNewApi(@RestateClient Client ingressClient) { + var client = ingressClient.service(GreeterNewApi.class).client(); + + assertThat(client.greet("Francesco")).isEqualTo("Something something Francesco"); + } }