Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions bytebuddy-proxy-support/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ dependencies {
implementation(project(":common"))
implementation(libs.bytebuddy)
implementation(libs.objenesis)

testImplementation(libs.junit.jupiter)
testImplementation(libs.assertj)
}

tasks.withType<Javadoc> { isFailOnError = false }
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,20 @@
import static net.bytebuddy.matcher.ElementMatchers.*;

import dev.restate.common.reflections.ProxyFactory;
import dev.restate.common.reflections.ReflectionUtils;
import dev.restate.sdk.annotation.Exclusive;
import dev.restate.sdk.annotation.Handler;
import dev.restate.sdk.annotation.Shared;
import dev.restate.sdk.annotation.Workflow;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import net.bytebuddy.ByteBuddy;
import net.bytebuddy.TypeCache;
import net.bytebuddy.description.modifier.Visibility;
import net.bytebuddy.dynamic.scaffold.TypeValidation;
import net.bytebuddy.implementation.ExceptionMethod;
import net.bytebuddy.implementation.InvocationHandlerAdapter;
import org.jspecify.annotations.Nullable;
import org.objenesis.Objenesis;
Expand All @@ -39,6 +42,7 @@ public final class ByteBuddyProxyFactory implements ProxyFactory {
private static final String INTERCEPTOR_FIELD_NAME = "$$interceptor$$";

private final Objenesis objenesis = new ObjenesisStd();
private final ByteBuddy byteBuddy = new ByteBuddy().with(TypeValidation.ENABLED);
private final TypeCache<Class<?>> proxyClassCache =
new TypeCache.WithInlineExpunction<>(TypeCache.Sort.SOFT);

Expand All @@ -62,64 +66,106 @@ public final class ByteBuddyProxyFactory implements ProxyFactory {

// Set the interceptor field
Field interceptorField = proxyClass.getDeclaredField(INTERCEPTOR_FIELD_NAME);
interceptorField.setAccessible(true);
interceptorField.set(proxyInstance, interceptor);
interceptorField.set(
proxyInstance,
(InvocationHandler)
(proxy, method, args) -> {
MethodInvocation invocation =
new MethodInvocation() {
@Override
public Object[] getArguments() {
return args != null ? args : new Object[0];
}

@Override
public Method getMethod() {
return method;
}
};
return interceptor.invoke(invocation);
});

return proxyInstance;
} catch (Exception e) {
throw new IllegalArgumentException("Cannot create proxy for class " + clazz, e);
}
}

private <T> Class<?> generateProxyClass(Class<T> clazz) {
ByteBuddy byteBuddy = new ByteBuddy().with(TypeValidation.ENABLED);
private <T> Class<?> generateProxyClass(Class<T> clazz) throws NoSuchFieldException {
if (!clazz.isInterface()) {
// We perform here some additional validation of the handlers that won't be executed by
// bytebuddy and can easily lead to strange behavior
var methods =
ReflectionUtils.getUniqueDeclaredMethods(
clazz,
method ->
ReflectionUtils.findAnnotation(method, Handler.class) != null
|| ReflectionUtils.findAnnotation(method, Shared.class) != null
|| ReflectionUtils.findAnnotation(method, Workflow.class) != null
|| ReflectionUtils.findAnnotation(method, Exclusive.class) != null);
for (var method : methods) {
validateMethod(method);
}
}

var builder =
clazz.isInterface()
? byteBuddy.subclass(Object.class).implement(clazz)
: byteBuddy.subclass(clazz);

var annotationMatcher =
isAnnotatedWith(Handler.class)
.or(isAnnotatedWith(Exclusive.class))
.or(isAnnotatedWith(Shared.class))
.or(isAnnotatedWith(Workflow.class));
try (var unloaded =
builder
// Add a field to store the interceptor
.defineField(INTERCEPTOR_FIELD_NAME, MethodInterceptor.class, Visibility.PUBLIC)
.defineField(INTERCEPTOR_FIELD_NAME, InvocationHandler.class, Visibility.PUBLIC)
// Intercept all methods
.method(
isMethod()
.and(
isAnnotatedWith(Handler.class)
.or(isAnnotatedWith(Exclusive.class))
.or(isAnnotatedWith(Shared.class))
.or(isAnnotatedWith(Workflow.class))))
.method(annotationMatcher)
.intercept(InvocationHandlerAdapter.toField(INTERCEPTOR_FIELD_NAME))
.method(not(annotationMatcher))
.intercept(
InvocationHandlerAdapter.of(
(proxy, method, args) -> {
// Get the interceptor from the field
Field field = proxy.getClass().getDeclaredField(INTERCEPTOR_FIELD_NAME);
field.setAccessible(true);
MethodInterceptor interceptor = (MethodInterceptor) field.get(proxy);

if (interceptor == null) {
throw new IllegalStateException(
"Interceptor not set on proxy instance. This is a bug, please contact the developers.");
}

MethodInvocation invocation =
new MethodInvocation() {
@Override
public Object[] getArguments() {
return args != null ? args : new Object[0];
}

@Override
public Method getMethod() {
return method;
}
};
return interceptor.invoke(invocation);
}))
ExceptionMethod.throwing(
UnsupportedOperationException.class,
"Calling a method not annotated with a Restate handler annotation on the proxy class"))
.make()) {
return unloaded.load(clazz.getClassLoader()).getLoaded();

var proxyClazz = unloaded.load(clazz.getClassLoader()).getLoaded();

// Make sure the field is accessible
Field interceptorField = proxyClazz.getDeclaredField(INTERCEPTOR_FIELD_NAME);
interceptorField.setAccessible(true);
return proxyClazz;
}
}

private static void validateMethod(Method method) {
if (!Modifier.isPublic(method.getModifiers())) {
throw new IllegalArgumentException(
"Method '"
+ method.getDeclaringClass().getSimpleName()
+ "#"
+ method.getName()
+ "' MUST be public to be used as Restate handler. Modifiers:"
+ Modifier.toString(method.getModifiers()));
}
if (Modifier.isStatic(method.getModifiers())) {
throw new IllegalArgumentException(
"Method '"
+ method.getDeclaringClass().getSimpleName()
+ "#"
+ method.getName()
+ "' is static, cannot be used as Restate handler");
}
if (Modifier.isFinal(method.getModifiers())) {
throw new IllegalArgumentException(
"Method '"
+ method.getDeclaringClass().getSimpleName()
+ "#"
+ method.getName()
+ "' is final, cannot be used as Restate handler");
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH
//
// This file is part of the Restate Java SDK,
// which is released under the MIT license.
//
// You can find a copy of the license in file LICENSE in the root
// directory of this repository or package, or at
// https://github.com/restatedev/sdk-java/blob/main/LICENSE
package dev.restate.bytebuddy.proxysupport;

import static org.assertj.core.api.Assertions.assertThatCode;
import static org.assertj.core.api.Fail.fail;

import dev.restate.sdk.annotation.Handler;
import dev.restate.sdk.annotation.Service;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;

public class ByteBuddyProxyFactoryTest {

@Service
public static class InvokeNonRestateMethod {
public void somethingElse() {}
}

@Test
@DisplayName("Invoking non restate method should fail")
public void badCallToNonRestateMethod() {
var proxyFactory = new ByteBuddyProxyFactory();
var proxy =
proxyFactory.createProxy(
InvokeNonRestateMethod.class,
invocation -> fail("Unexpected call to method interceptor"));

assertThatCode(() -> proxy.somethingElse())
.hasMessageContaining(
"Calling a method not annotated with a Restate handler annotation on the proxy class")
.isInstanceOf(UnsupportedOperationException.class);
}

@Service
public static class PackagePrivateMethod {
@Handler
void handler() {
fail("This code should not be executed");
}
}

@Test
@DisplayName("Package private method should fail")
public void packagePrivateMethod() {
var proxyFactory = new ByteBuddyProxyFactory();
assertThatCode(
() ->
proxyFactory.createProxy(
PackagePrivateMethod.class,
invocation -> fail("Unexpected call to method interceptor")))
.cause()
.cause()
.hasMessageContaining("MUST be public to be used as Restate handler");
}

@Service
public static class FinalMethod {
@Handler
public final void handler() {
fail("This code should not be executed");
}
}

@Test
@DisplayName("Final method should fail")
public void finalMethod() {
var proxyFactory = new ByteBuddyProxyFactory();
assertThatCode(
() ->
proxyFactory.createProxy(
FinalMethod.class, invocation -> fail("Unexpected call to method interceptor")))
.cause()
.cause()
.hasMessageContaining("is final");
}

@Service
public static final class FinalClass {
@Handler
public void handler() {
fail("This code should not be executed");
}
}

@Test
@DisplayName("Final class should fail")
public void finalClass() {
var proxyFactory = new ByteBuddyProxyFactory();
assertThatCode(
() ->
proxyFactory.createProxy(
FinalClass.class, invocation -> fail("Unexpected call to method interceptor")))
.hasMessageContaining("is final, cannot be proxied");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -119,17 +119,7 @@ public ServiceDefinition create(
var handlerName = handlerInfo.name();
var genericParameterTypes = method.getGenericParameterTypes();
var parameterCount = method.getParameterCount();

if (!Modifier.isPublic(method.getModifiers())) {
throw new MalformedRestateServiceException(
serviceName,
"Handler method '"
+ handlerName
+ "' MUST be public, but method '"
+ method.getName()
+ "' has modifiers: "
+ Modifier.toString(method.getModifiers()));
}
validateMethod(method, serviceName);

if ((parameterCount == 1 || parameterCount == 2)
&& (genericParameterTypes[0].equals(Context.class)
Expand Down Expand Up @@ -231,6 +221,22 @@ public ServiceDefinition create(
return handlerDefinition;
}

private static void validateMethod(Method method, String serviceName) {
if (!Modifier.isPublic(method.getModifiers())) {
throw new MalformedRestateServiceException(
serviceName,
"Method '"
+ method.getName()
+ "' MUST be public to be used as Restate handler. Modifiers:"
+ Modifier.toString(method.getModifiers()));
}
if (Modifier.isStatic(method.getModifiers())) {
throw new MalformedRestateServiceException(
serviceName,
"Method '" + method.getName() + "' is static, cannot be used as Restate handler");
}
}

@SuppressWarnings({"unchecked", "rawtypes"})
private Serde<Object> resolveInputSerde(
Method method, SerdeFactory serdeFactory, String serviceName) {
Expand Down