diff --git a/buildSrc/call-site-instrumentation-plugin/build.gradle.kts b/buildSrc/call-site-instrumentation-plugin/build.gradle.kts index db2e49e26f5..cbb61de5512 100644 --- a/buildSrc/call-site-instrumentation-plugin/build.gradle.kts +++ b/buildSrc/call-site-instrumentation-plugin/build.gradle.kts @@ -1,6 +1,5 @@ plugins { java - groovy id("com.diffplug.spotless") version "8.2.1" id("com.gradleup.shadow") version "8.3.9" } @@ -34,8 +33,9 @@ dependencies { implementation(libs.javaparser.symbol.solver) testImplementation(libs.bytebuddy) - testImplementation(libs.bundles.groovy) - testImplementation(libs.bundles.spock) + testImplementation(libs.bundles.junit5) + testRuntimeOnly(libs.junit.platform.launcher) + testImplementation(libs.bundles.mockito) testImplementation("javax.servlet", "javax.servlet-api", "3.0.1") testImplementation(libs.spotbugs.annotations) } diff --git a/buildSrc/call-site-instrumentation-plugin/src/test/groovy/datadog/trace/plugin/csi/impl/AdviceGeneratorTest.groovy b/buildSrc/call-site-instrumentation-plugin/src/test/groovy/datadog/trace/plugin/csi/impl/AdviceGeneratorTest.groovy deleted file mode 100644 index 39b9dc3fb29..00000000000 --- a/buildSrc/call-site-instrumentation-plugin/src/test/groovy/datadog/trace/plugin/csi/impl/AdviceGeneratorTest.groovy +++ /dev/null @@ -1,518 +0,0 @@ -package datadog.trace.plugin.csi.impl - - -import datadog.trace.agent.tooling.csi.CallSite -import datadog.trace.agent.tooling.csi.CallSites -import datadog.trace.plugin.csi.AdviceGenerator -import datadog.trace.plugin.csi.impl.assertion.AssertBuilder -import datadog.trace.plugin.csi.impl.assertion.CallSiteAssert -import datadog.trace.plugin.csi.impl.ext.tests.IastCallSites -import datadog.trace.plugin.csi.impl.ext.tests.RaspCallSites -import groovy.transform.CompileDynamic -import spock.lang.Requires -import spock.lang.TempDir - -import javax.servlet.ServletRequest -import java.lang.invoke.MethodHandles -import java.lang.invoke.MethodType - -import static CallSiteFactory.pointcutParser - -@CompileDynamic -final class AdviceGeneratorTest extends BaseCsiPluginTest { - - @TempDir - private File buildDir - - @CallSite(spi = CallSites) - class BeforeAdvice { - @CallSite.Before('java.security.MessageDigest java.security.MessageDigest.getInstance(java.lang.String)') - static void before(@CallSite.Argument final String algorithm) {} - } - - void 'test before advice'() { - setup: - final spec = buildClassSpecification(BeforeAdvice) - final generator = buildAdviceGenerator(buildDir) - - when: - final result = generator.generate(spec) - - then: - assertNoErrors(result) - assertCallSites(result.file) { - interfaces(CallSites) - helpers(BeforeAdvice) - advices(0) { - type("BEFORE") - pointcut('java/security/MessageDigest', 'getInstance', '(Ljava/lang/String;)Ljava/security/MessageDigest;') - statements( - 'handler.dupParameters(descriptor, StackDupMode.COPY);', - 'handler.advice("datadog/trace/plugin/csi/impl/AdviceGeneratorTest$BeforeAdvice", "before", "(Ljava/lang/String;)V");', - 'handler.method(opcode, owner, name, descriptor, isInterface);' - ) - } - } - } - - @CallSite(spi = CallSites) - class AroundAdvice { - @CallSite.Around('java.lang.String java.lang.String.replaceAll(java.lang.String, java.lang.String)') - static String around(@CallSite.This final String self, @CallSite.Argument final String regexp, @CallSite.Argument final String replacement) { - return self.replaceAll(regexp, replacement) - } - } - - void 'test around advice'() { - setup: - final spec = buildClassSpecification(AroundAdvice) - final generator = buildAdviceGenerator(buildDir) - - when: - final result = generator.generate(spec) - - then: - assertNoErrors(result) - assertCallSites(result.file) { - interfaces(CallSites) - helpers(AroundAdvice) - advices(0) { - type("AROUND") - pointcut('java/lang/String', 'replaceAll', '(Ljava/lang/String;Ljava/lang/String;)Ljava/lang/String;') - statements( - 'handler.advice("datadog/trace/plugin/csi/impl/AdviceGeneratorTest$AroundAdvice", "around", "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)Ljava/lang/String;");' - ) - } - } - } - - @CallSite(spi = CallSites) - class AfterAdvice { - @CallSite.After('java.lang.String java.lang.String.concat(java.lang.String)') - static String after(@CallSite.This final String self, @CallSite.Argument final String param, @CallSite.Return final String result) { - return result - } - } - - void 'test after advice'() { - setup: - final spec = buildClassSpecification(AfterAdvice) - final generator = buildAdviceGenerator(buildDir) - - when: - final result = generator.generate(spec) - - then: - assertNoErrors(result) - assertCallSites(result.file) { - interfaces(CallSites) - helpers(AfterAdvice) - advices(0) { - type("AFTER") - pointcut('java/lang/String', 'concat', '(Ljava/lang/String;)Ljava/lang/String;') - statements( - 'handler.dupInvoke(owner, descriptor, StackDupMode.COPY);', - 'handler.method(opcode, owner, name, descriptor, isInterface);', - 'handler.advice("datadog/trace/plugin/csi/impl/AdviceGeneratorTest$AfterAdvice", "after", "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)Ljava/lang/String;");', - ) - } - } - } - - @CallSite(spi = CallSites) - class AfterAdviceCtor { - @CallSite.After('void java.net.URL.(java.lang.String)') - static URL after(@CallSite.AllArguments final Object[] args, @CallSite.Return final URL url) { - return url - } - } - - void 'test after advice ctor'() { - setup: - final spec = buildClassSpecification(AfterAdviceCtor) - final generator = buildAdviceGenerator(buildDir) - - when: - final result = generator.generate(spec) - - then: - assertNoErrors(result) - assertCallSites(result.file) { - interfaces(CallSites) - helpers(AfterAdviceCtor) - advices(0) { - pointcut('java/net/URL', '', '(Ljava/lang/String;)V') - statements( - 'handler.dupParameters(descriptor, StackDupMode.PREPEND_ARRAY_CTOR);', - 'handler.method(opcode, owner, name, descriptor, isInterface);', - 'handler.advice("datadog/trace/plugin/csi/impl/AdviceGeneratorTest$AfterAdviceCtor", "after", "([Ljava/lang/Object;Ljava/net/URL;)Ljava/net/URL;");', - ) - } - } - } - - @CallSite(spi = SampleSpi.class) - class SpiAdvice { - @CallSite.Before('java.security.MessageDigest java.security.MessageDigest.getInstance(java.lang.String)') - static void before(@CallSite.Argument final String algorithm) {} - - interface SampleSpi {} - } - - void 'test generator with spi'() { - setup: - final spec = buildClassSpecification(SpiAdvice) - final generator = buildAdviceGenerator(buildDir) - - when: - final result = generator.generate(spec) - - then: - assertNoErrors(result) - - assertCallSites(result.file) { - interfaces(CallSites, SpiAdvice.SampleSpi) - } - } - - @CallSite(spi = CallSites) - class InvokeDynamicAfterAdvice { - @CallSite.After( - value = 'java.lang.invoke.CallSite java.lang.invoke.StringConcatFactory.makeConcatWithConstants(java.lang.invoke.MethodHandles$Lookup, java.lang.String, java.lang.invoke.MethodType, java.lang.String, java.lang.Object[])', - invokeDynamic = true - ) - static String after(@CallSite.AllArguments final Object[] arguments, @CallSite.Return final String result) { - result - } - } - - @Requires({ - jvm.java9Compatible - }) - void 'test invoke dynamic after advice'() { - setup: - final spec = buildClassSpecification(InvokeDynamicAfterAdvice) - final generator = buildAdviceGenerator(buildDir) - - when: - final result = generator.generate(spec) - - then: - assertNoErrors(result) - assertCallSites(result.file) { - interfaces(CallSites) - helpers(InvokeDynamicAfterAdvice) - advices(0) { - pointcut( - 'java/lang/invoke/StringConcatFactory', - 'makeConcatWithConstants', - '(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;Ljava/lang/String;[Ljava/lang/Object;)Ljava/lang/invoke/CallSite;' - ) - statements( - 'handler.dupParameters(descriptor, StackDupMode.PREPEND_ARRAY);', - 'handler.invokeDynamic(name, descriptor, bootstrapMethodHandle, bootstrapMethodArguments);', - 'handler.advice("datadog/trace/plugin/csi/impl/AdviceGeneratorTest$InvokeDynamicAfterAdvice", "after", "([Ljava/lang/Object;Ljava/lang/String;)Ljava/lang/String;");' - ) - } - } - } - - @CallSite(spi = CallSites) - class InvokeDynamicAroundAdvice { - @CallSite.Around( - value = 'java.lang.invoke.CallSite java.lang.invoke.StringConcatFactory.makeConcatWithConstants(java.lang.invoke.MethodHandles$Lookup, java.lang.String, java.lang.invoke.MethodType, java.lang.String, java.lang.Object[])', - invokeDynamic = true - ) - static java.lang.invoke.CallSite around(@CallSite.Argument final MethodHandles.Lookup lookup, - @CallSite.Argument final String name, - @CallSite.Argument final MethodType concatType, - @CallSite.Argument final String recipe, - @CallSite.Argument final Object... constants) { - return null - } - } - - @Requires({ - jvm.java9Compatible - }) - void 'test invoke dynamic around advice'() { - setup: - final spec = buildClassSpecification(InvokeDynamicAroundAdvice) - final generator = buildAdviceGenerator(buildDir) - - when: - final result = generator.generate(spec) - - then: - assertNoErrors(result) - assertCallSites(result.file) { - interfaces(CallSites) - helpers(InvokeDynamicAroundAdvice) - advices(0) { - pointcut( - 'java/lang/invoke/StringConcatFactory', - 'makeConcatWithConstants', - '(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;Ljava/lang/String;[Ljava/lang/Object;)Ljava/lang/invoke/CallSite;' - ) - statements( - 'handler.invokeDynamic(name, descriptor, new Handle(Opcodes.H_INVOKESTATIC, "datadog/trace/plugin/csi/impl/AdviceGeneratorTest$InvokeDynamicAroundAdvice", "around", "(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;Ljava/lang/String;[Ljava/lang/Object;)Ljava/lang/invoke/CallSite;", false), bootstrapMethodArguments);', - ) - } - } - } - - @CallSite(spi = CallSites) - class InvokeDynamicWithConstantsAdvice { - @CallSite.After( - value = 'java.lang.invoke.CallSite java.lang.invoke.StringConcatFactory.makeConcatWithConstants(java.lang.invoke.MethodHandles$Lookup, java.lang.String, java.lang.invoke.MethodType, java.lang.String, java.lang.Object[])', - invokeDynamic = true - ) - static String after(@CallSite.AllArguments final Object[] arguments, - @CallSite.Return final String result, - @CallSite.InvokeDynamicConstants final Object[] constants) { - return result - } - } - - @Requires({ - jvm.java9Compatible - }) - void 'test invoke dynamic with constants advice'() { - setup: - final spec = buildClassSpecification(InvokeDynamicWithConstantsAdvice) - final generator = buildAdviceGenerator(buildDir) - - when: - final result = generator.generate(spec) - - then: - assertNoErrors(result) - assertCallSites(result.file) { - interfaces(CallSites) - helpers(InvokeDynamicWithConstantsAdvice) - advices(0) { - pointcut( - 'java/lang/invoke/StringConcatFactory', - 'makeConcatWithConstants', - '(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;Ljava/lang/String;[Ljava/lang/Object;)Ljava/lang/invoke/CallSite;' - ) - statements( - 'handler.dupParameters(descriptor, StackDupMode.PREPEND_ARRAY);', - 'handler.invokeDynamic(name, descriptor, bootstrapMethodHandle, bootstrapMethodArguments);', - 'handler.loadConstantArray(bootstrapMethodArguments);', - 'handler.advice("datadog/trace/plugin/csi/impl/AdviceGeneratorTest$InvokeDynamicWithConstantsAdvice", "after", "([Ljava/lang/Object;Ljava/lang/String;[Ljava/lang/Object;)Ljava/lang/String;");' - ) - } - } - } - - @CallSite(spi = CallSites) - class ArrayAdvice { - @CallSite.AfterArray([ - @CallSite.After('java.util.Map javax.servlet.ServletRequest.getParameterMap()'), - @CallSite.After('java.util.Map javax.servlet.ServletRequestWrapper.getParameterMap()') - ]) - static Map after(@CallSite.This final ServletRequest request, @CallSite.Return final Map parameters) { - return parameters - } - } - - void 'test array advice'() { - setup: - final spec = buildClassSpecification(ArrayAdvice) - final generator = buildAdviceGenerator(buildDir) - - when: - final result = generator.generate(spec) - - then: - assertNoErrors(result) - assertCallSites(result.file) { - advices(0) { - pointcut('javax/servlet/ServletRequest', 'getParameterMap', '()Ljava/util/Map;') - } - advices(1) { - pointcut('javax/servlet/ServletRequestWrapper', 'getParameterMap', '()Ljava/util/Map;') - } - } - } - - class MinJavaVersionCheck { - static boolean isAtLeast(final String version) { - return Integer.parseInt(version) >= 9 - } - } - - @CallSite(spi = CallSites, enabled = ['datadog.trace.plugin.csi.impl.AdviceGeneratorTest$MinJavaVersionCheck', 'isAtLeast', '18']) - class MinJavaVersionAdvice { - @CallSite.After('java.lang.String java.lang.String.concat(java.lang.String)') - static String after(@CallSite.This final String self, @CallSite.Argument final String param, @CallSite.Return final String result) { - return result - } - } - - void 'test custom enabled property'() { - setup: - final spec = buildClassSpecification(MinJavaVersionAdvice) - final generator = buildAdviceGenerator(buildDir) - - when: - final result = generator.generate(spec) - - then: - assertNoErrors(result) - assertCallSites(result.file) { callSites -> - interfaces(CallSites, CallSites.HasEnabledProperty) - enabled(MinJavaVersionCheck.getDeclaredMethod('isAtLeast', String), '18') - } - } - - - @CallSite(spi = CallSites) - class PartialArgumentsBeforeAdvice { - @CallSite.Before("int java.sql.Statement.executeUpdate(java.lang.String, java.lang.String[])") - static void before(@CallSite.Argument(0) String arg1) {} - - @CallSite.Before("java.lang.String java.lang.String.format(java.lang.String, java.lang.Object[])") - static void before(@CallSite.Argument(1) Object[] arg) {} - - @CallSite.Before("java.lang.CharSequence java.lang.String.subSequence(int, int)") - static void before(@CallSite.This String thiz, @CallSite.Argument(0) int arg) {} - } - - void 'partial arguments with before advice'() { - setup: - final spec = buildClassSpecification(PartialArgumentsBeforeAdvice) - final generator = buildAdviceGenerator(buildDir) - - when: - final result = generator.generate(spec) - - then: - assertNoErrors result - assertCallSites(result.file) { - advices(0) { - pointcut('java/sql/Statement', 'executeUpdate', '(Ljava/lang/String;[Ljava/lang/String;)I') - statements( - 'int[] parameterIndices = new int[] { 0 };', - 'handler.dupParameters(descriptor, parameterIndices, owner);', - 'handler.advice("datadog/trace/plugin/csi/impl/AdviceGeneratorTest$PartialArgumentsBeforeAdvice", "before", "(Ljava/lang/String;)V");', - 'handler.method(opcode, owner, name, descriptor, isInterface);', - ) - } - advices(1) { - pointcut('java/lang/String', 'format', '(Ljava/lang/String;[Ljava/lang/Object;)Ljava/lang/String;') - statements( - 'int[] parameterIndices = new int[] { 1 };', - 'handler.dupParameters(descriptor, parameterIndices, null);', - 'handler.advice("datadog/trace/plugin/csi/impl/AdviceGeneratorTest$PartialArgumentsBeforeAdvice", "before", "([Ljava/lang/Object;)V");', - 'handler.method(opcode, owner, name, descriptor, isInterface);', - ) - } - advices(2) { - pointcut('java/lang/String', 'subSequence', '(II)Ljava/lang/CharSequence;') - statements( - 'int[] parameterIndices = new int[] { 0 };', - 'handler.dupInvoke(owner, descriptor, parameterIndices);', - 'handler.advice("datadog/trace/plugin/csi/impl/AdviceGeneratorTest$PartialArgumentsBeforeAdvice", "before", "(Ljava/lang/String;I)V");', - 'handler.method(opcode, owner, name, descriptor, isInterface);', - ) - } - } - } - - - @CallSite(spi = CallSites) - class SuperTypeReturnAdvice { - @CallSite.After("void java.lang.StringBuilder.(java.lang.String)") - static Object after(@CallSite.AllArguments Object[] args, @CallSite.Return Object result) { - return result - } - } - - void 'test returning super type'() { - setup: - final spec = buildClassSpecification(SuperTypeReturnAdvice) - final generator = buildAdviceGenerator(buildDir) - - when: - final result = generator.generate(spec) - - then: - assertNoErrors result - assertCallSites(result.file) { - advices(0) { - pointcut('java/lang/StringBuilder', '', '(Ljava/lang/String;)V') - statements( - 'handler.dupParameters(descriptor, StackDupMode.PREPEND_ARRAY_CTOR);', - 'handler.method(opcode, owner, name, descriptor, isInterface);', - 'handler.advice("datadog/trace/plugin/csi/impl/AdviceGeneratorTest$SuperTypeReturnAdvice", "after", "([Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;");', - 'handler.instruction(Opcodes.CHECKCAST, "java/lang/StringBuilder");' - ) - } - } - } - - @CallSite(spi = [IastCallSites, RaspCallSites]) - class MultipleSpiClassesAdvice { - @CallSite.After("void java.lang.StringBuilder.(java.lang.String)") - static Object after(@CallSite.AllArguments Object[] args, @CallSite.Return Object result) { - return result - } - } - - void 'test multiple spi classes'() { - setup: - final spec = buildClassSpecification(MultipleSpiClassesAdvice) - final generator = buildAdviceGenerator(buildDir) - - when: - final result = generator.generate(spec) - - then: - assertNoErrors result - assertCallSites(result.file) { - spi(IastCallSites, RaspCallSites) - } - } - - - @CallSite(spi = CallSites) - class AfterAdviceWithVoidReturn { - @CallSite.After("void java.lang.StringBuilder.setLength(int)") - static void after(@CallSite.This StringBuilder self, @CallSite.Argument(0) int length) { - } - } - - void 'test after advice with void return'() { - setup: - final spec = buildClassSpecification(AfterAdviceWithVoidReturn) - final generator = buildAdviceGenerator(buildDir) - - when: - final result = generator.generate(spec) - - then: - assertNoErrors result - assertCallSites(result.file) { - advices(0) { - pointcut('java/lang/StringBuilder', 'setLength', '(I)V') - statements( - 'handler.dupInvoke(owner, descriptor, StackDupMode.COPY);', - 'handler.method(opcode, owner, name, descriptor, isInterface);', - 'handler.advice("datadog/trace/plugin/csi/impl/AdviceGeneratorTest$AfterAdviceWithVoidReturn", "after", "(Ljava/lang/StringBuilder;I)V");', - ) - } - } - } - - private static AdviceGenerator buildAdviceGenerator(final File targetFolder) { - return new AdviceGeneratorImpl(targetFolder, pointcutParser()) - } - - private static void assertCallSites(final File generated, @DelegatesTo(CallSiteAssert) final Closure closure) { - final asserter = new AssertBuilder(generated).build() - closure.delegate = asserter - closure(asserter) - } -} diff --git a/buildSrc/call-site-instrumentation-plugin/src/test/groovy/datadog/trace/plugin/csi/impl/AdviceSpecificationTest.groovy b/buildSrc/call-site-instrumentation-plugin/src/test/groovy/datadog/trace/plugin/csi/impl/AdviceSpecificationTest.groovy deleted file mode 100644 index fafca5757b4..00000000000 --- a/buildSrc/call-site-instrumentation-plugin/src/test/groovy/datadog/trace/plugin/csi/impl/AdviceSpecificationTest.groovy +++ /dev/null @@ -1,567 +0,0 @@ -package datadog.trace.plugin.csi.impl - -import datadog.trace.agent.tooling.csi.CallSite -import datadog.trace.agent.tooling.csi.CallSites -import datadog.trace.plugin.csi.HasErrors.Failure -import datadog.trace.plugin.csi.util.ErrorCode -import groovy.transform.CompileDynamic -import org.objectweb.asm.Type - -import datadog.trace.plugin.csi.impl.CallSiteSpecification.ThisSpecification as This -import datadog.trace.plugin.csi.impl.CallSiteSpecification.ReturnSpecification as Return -import datadog.trace.plugin.csi.impl.CallSiteSpecification.ArgumentSpecification as Arg -import datadog.trace.plugin.csi.impl.CallSiteSpecification.AllArgsSpecification as AllArgs -import datadog.trace.plugin.csi.impl.CallSiteSpecification.InvokeDynamicConstantsSpecification as DynConsts -import datadog.trace.plugin.csi.impl.CallSiteSpecification.ParameterSpecification -import spock.lang.Requires - -import javax.servlet.ServletRequest -import java.lang.invoke.MethodHandles -import java.lang.invoke.MethodType -import java.security.MessageDigest - -@CompileDynamic -class AdviceSpecificationTest extends BaseCsiPluginTest { - - @CallSite(spi = CallSites) - class EmptyAdvice {} - - void 'test class generator error, call site without advices'() { - setup: - final context = mockValidationContext() - final spec = buildClassSpecification(EmptyAdvice) - - when: - spec.validate(context) - - then: - 1 * context.addError(ErrorCode.CALL_SITE_SHOULD_HAVE_ADVICE_METHODS, _) - } - - @CallSite(spi = CallSites) - class NonPublicStaticMethodAdvice { - @CallSite.Before("void java.lang.Runnable.run()") - private void advice(@CallSite.This final Runnable run) {} - } - - void 'test class generator error, non public static method'() { - setup: - final context = mockValidationContext() - final spec = buildClassSpecification(NonPublicStaticMethodAdvice) - - when: - spec.advices.each { it.validate(context) } - - then: - 1 * context.addError(ErrorCode.ADVICE_METHOD_NOT_STATIC_AND_PUBLIC, _) - } - - class BeforeStringConcat { - static void concat(final String self, final String value) {} - } - - void 'test advice class should be on the classpath'(final Type type, final int errors) { - setup: - final context = mockValidationContext() - final spec = before { - advice { - method(BeforeStringConcat.getDeclaredMethod('concat', String, String)) - owner(type) // override owner - } - parameters(new This(), new Arg()) - signature('java.lang.String java.lang.String.concat(java.lang.String)') - } - - when: - spec.validate(context) - - then: - errors * context.addError { Failure failure -> failure.errorCode == ErrorCode.UNRESOLVED_TYPE } - 0 * context.addError(*_) - - where: - type | errors - Type.getType('Lfoo/bar/FooBar;') | 1 - Type.getType(BeforeStringConcat) | 0 - } - - void 'test before advice should return void'(final Class returnType, final int errors) { - setup: - final context = mockValidationContext() - final spec = before { - advice { - owner(BeforeStringConcat) - method('concat') - descriptor(returnType, String, String) // change return - } - parameters(new This(), new Arg()) - signature('java.lang.String java.lang.String.concat(java.lang.String)') - } - - when: - spec.validate(context) - - then: - errors * context.addError(ErrorCode.ADVICE_BEFORE_SHOULD_RETURN_VOID, _) - 0 * context.addError(*_) - - - where: - returnType || errors - String || 1 - void.class || 0 - } - - class AroundStringConcat { - static String concat(final String self, final String value) { - return self.concat(value) - } - } - - void 'test around advice should return type compatible with pointcut'(final Class returnType, final int errors) { - setup: - final context = mockValidationContext() - final spec = around { - advice { - owner(AroundStringConcat) - method('concat') - descriptor(returnType, String, String) // change return - } - parameters(new This(), new Arg()) - signature('java.lang.String java.lang.String.concat(java.lang.String)') - } - - when: - spec.validate(context) - - then: - errors * context.addError(ErrorCode.ADVICE_METHOD_RETURN_NOT_COMPATIBLE, _) - 0 * context.addError(*_) - - where: - returnType | errors - MessageDigest | 1 - Object | 0 - String | 0 - } - - class AfterStringConcat { - static String concat(final String self, final String value, final String result) { - return result - } - } - - void 'test after advice should return type compatible with pointcut'(final Class returnType, final int errors) { - setup: - final context = mockValidationContext() - final spec = after { - advice { - owner(AfterStringConcat) - method('concat') - descriptor(returnType, String, String, String) - // change return - } - parameters(new This(), new Arg(), new Return()) - signature('java.lang.String java.lang.String.concat(java.lang.String)') - } - - when: - spec.validate(context) - - then: - errors * context.addError(ErrorCode.ADVICE_METHOD_RETURN_NOT_COMPATIBLE, _) - 0 * context.addError(*_) - - where: - returnType | errors - MessageDigest | 1 - Object | 0 - String | 0 - } - - void 'test this parameter should always be the first'(final List params, final int errors) { - setup: - final context = mockValidationContext() - final spec = around { - advice { - method(AroundStringConcat.getDeclaredMethod('concat', String, String)) - } - parameters(params as ParameterSpecification[]) - signature('java.lang.String java.lang.String.concat(java.lang.String)') - } - - when: - spec.validate(context) - - then: - errors * context.addError(ErrorCode.ADVICE_PARAMETER_THIS_SHOULD_BE_FIRST, _) - 0 * context.addError(*_) - - where: - params | errors - [new This(), new Arg()] | 0 - [new Arg(), new This()] | 1 - } - - - void 'test this parameter should be compatible with pointcut'(final Class type, final int errors) { - setup: - final context = mockValidationContext() - final spec = around { - advice { - owner(AroundStringConcat) - method('concat') - descriptor(String, type, String) - } - parameters(new This(), new Arg()) - signature('java.lang.String java.lang.String.concat(java.lang.String)') - } - - when: - spec.validate(context) - - then: - errors * context.addError(ErrorCode.ADVICE_METHOD_PARAM_THIS_NOT_COMPATIBLE, _) - // advice returns String so other return types won't be able to find the method - if (type != String) { - 1 * context.addError { Failure failure -> failure.errorCode == ErrorCode.UNRESOLVED_METHOD } - } - 0 * context.addError(*_) - - where: - type | errors - MessageDigest | 1 - Object | 0 - String | 0 - } - - void 'test return parameter should always be the last'(final List params, final int errors) { - setup: - final context = mockValidationContext() - final spec = after { - advice { - method(AfterStringConcat.getDeclaredMethod('concat', String, String, String)) - } - parameters(params as ParameterSpecification[]) - signature('java.lang.String java.lang.String.concat(java.lang.String)') - } - - when: - spec.validate(context) - - then: - errors * context.addError(ErrorCode.ADVICE_PARAMETER_RETURN_SHOULD_BE_LAST, _) - // other errors are ignored - - where: - params | errors - [new This(), new Arg(), new Return()] | 0 - [new This(), new Return(), new Arg()] | 1 - } - - - void 'test return parameter should be compatible with pointcut'(final Class returnType, final int errors) { - setup: - final context = mockValidationContext() - final spec = after { - advice { - owner(AfterStringConcat) - method('concat') - descriptor(String, String, String, returnType) - } - parameters(new This(), new Arg(), new Return()) - signature('java.lang.String java.lang.String.concat(java.lang.String)') - } - - when: - spec.validate(context) - - then: - errors * context.addError(ErrorCode.ADVICE_METHOD_PARAM_RETURN_NOT_COMPATIBLE, _) - // advice returns String so other return types won't be able to find the method - if (returnType != String) { - 1 * context.addError { Failure failure -> failure.errorCode == ErrorCode.UNRESOLVED_METHOD } - } - 0 * context.addError(*_) - - where: - returnType | errors - MessageDigest | 1 - String | 0 - Object | 0 - } - - - void 'test argument parameter should be compatible with pointcut'(final Class parameterType, final int errors) { - setup: - final context = mockValidationContext() - final spec = after { - advice { - owner(AfterStringConcat) - method('concat') - descriptor(String, String, parameterType, String) - } - parameters(new This(), new Arg(), new Return()) - signature('java.lang.String java.lang.String.concat(java.lang.String)') - } - - when: - spec.validate(context) - - then: - errors * context.addError(ErrorCode.ADVICE_METHOD_PARAM_NOT_COMPATIBLE, _) - // advice parameter is a String so with other types won't be able to find the method - if (parameterType != String) { - 1 * context.addError { Failure failure -> failure.errorCode == ErrorCode.UNRESOLVED_METHOD } - } - 0 * context.addError(*_) - - where: - parameterType | errors - MessageDigest | 1 - String | 0 - Object | 0 - } - - class BadAfterStringConcat { - static String concat(final String param1, final String param2) { - return param2 - } - } - - void 'test after advice requires @This and @Return parameters'(final List params, final ErrorCode error) { - setup: - final context = mockValidationContext() - final spec = after { - advice { - method(BadAfterStringConcat.getDeclaredMethod('concat', String, String)) - } - parameters(params as ParameterSpecification[]) - signature('java.lang.String java.lang.String.concat(java.lang.String)') - } - - when: - spec.validate(context) - - then: - 1 * context.addError(error, _) - 0 * context.addError(*_) - - where: - params | error - [new Arg(), new Return()] | ErrorCode.ADVICE_AFTER_SHOULD_HAVE_THIS - [new This(), new Arg()] | ErrorCode.ADVICE_AFTER_SHOULD_HAVE_RETURN - } - - class BadAllArgsAfterStringConcat { - static String concat(final Object[] param1, final String param2, final String param3) { - return param3 - } - } - - void 'should not mix @AllArguments and @Argument'() { - setup: - final context = mockValidationContext() - final spec = after { - advice { - method(BadAllArgsAfterStringConcat.getDeclaredMethod('concat', Object[], String, String)) - } - parameters(new AllArgs(includeThis: true), new Arg(), new Return()) - signature('java.lang.String java.lang.String.concat(java.lang.String)') - } - - when: - spec.validate(context) - - then: - 1 * context.addError(ErrorCode.ADVICE_PARAMETER_ALL_ARGS_MIXED, _) - 1 * context.addError(ErrorCode.ADVICE_PARAMETER_ARGUMENT_OUT_OF_BOUNDS, _) // all args consumes all arguments - 0 * context.addError(*_) - } - - static class TestInheritedMethod { - static String after(final ServletRequest request, final String parameter, final String value) { - return value - } - } - - void 'test inherited methods'() { - setup: - final context = mockValidationContext() - final spec = after { - advice { - method(TestInheritedMethod.getDeclaredMethod('after', ServletRequest, String, String)) - } - parameters(new This(), new Arg(), new Return()) - signature('java.lang.String javax.servlet.http.HttpServletRequest.getParameter(java.lang.String)') - } - - when: - spec.validate(context) - - then: - 0 * context.addError(*_) - } - - static class TestInvokeDynamicConstants { - static Object after(final Object[] parameter, final Object result, final Object[] constants) { - return result - } - } - - @Requires({ - jvm.java9Compatible - }) - void 'test invoke dynamic constants'() { - setup: - final context = mockValidationContext() - final spec = after { - advice { - method(TestInvokeDynamicConstants.getDeclaredMethod('after', Object[], Object, Object[])) - } - parameters(new AllArgs(), new Return(), new DynConsts()) - signature('java.lang.invoke.CallSite java.lang.invoke.StringConcatFactory.makeConcatWithConstants(java.lang.invoke.MethodHandles$Lookup, java.lang.String, java.lang.invoke.MethodType, java.lang.String, java.lang.Object[])') - invokeDynamic(true) - } - - when: - spec.validate(context) - - then: - 0 * context.addError(*_) - } - - @Requires({ - jvm.java9Compatible - }) - void 'test invoke dynamic constants should be last'(final List params, final ErrorCode error) { - setup: - final context = mockValidationContext() - final spec = after { - advice { - method(TestInvokeDynamicConstants.getDeclaredMethod('after', Object[], Object, Object[])) - } - parameters(params as ParameterSpecification[]) - signature('java.lang.invoke.CallSite java.lang.invoke.StringConcatFactory.makeConcatWithConstants(java.lang.invoke.MethodHandles$Lookup, java.lang.String, java.lang.invoke.MethodType, java.lang.String, java.lang.Object[])') - invokeDynamic(true) - } - - when: - spec.validate(context) - - then: - if (error != null) { - 1 * context.addError(error, _) - } - 0 * context.addError(*_) - - where: - params | error - [new AllArgs(), new Return(), new DynConsts()] | null - [new AllArgs(), new DynConsts(), new Return()] | ErrorCode.ADVICE_PARAMETER_INVOKE_DYNAMIC_CONSTANTS_SHOULD_BE_LAST - } - - static class TestInvokeDynamicConstantsNonInvokeDynamic { - static Object after(final Object self, final Object[] parameter, final Object value, final Object[] constants) { - return value - } - } - - @Requires({ - jvm.java9Compatible - }) - void 'test invoke dynamic constants on non invoke dynamic pointcut'() { - setup: - final context = mockValidationContext() - final spec = after { - advice { - method(TestInvokeDynamicConstantsNonInvokeDynamic.getDeclaredMethod('after', Object, Object[], Object, Object[])) - } - parameters(new This(), new AllArgs(), new DynConsts(), new Return()) - signature('java.lang.String java.lang.String.concat(java.lang.String)') - } - - when: - spec.validate(context) - - then: - 1 * context.addError(ErrorCode.ADVICE_PARAMETER_INVOKE_DYNAMIC_CONSTANTS_ON_NON_INVOKE_DYNAMIC, _) - } - - static class TestInvokeDynamicConstantsBefore { - static void before(final Object[] parameter, final Object[] constants) { - } - } - - @Requires({ - jvm.java9Compatible - }) - void 'test invoke dynamic constants on non @After advice'() { - setup: - final context = mockValidationContext() - final spec = before { - advice { - method(TestInvokeDynamicConstantsBefore.getDeclaredMethod('before', Object[], Object[])) - } - parameters(new AllArgs(), new DynConsts()) - signature('java.lang.invoke.CallSite java.lang.invoke.StringConcatFactory.makeConcatWithConstants(java.lang.invoke.MethodHandles$Lookup, java.lang.String, java.lang.invoke.MethodType, java.lang.String, java.lang.Object[])') - invokeDynamic(true) - } - - when: - spec.validate(context) - - then: - 1 * context.addError(ErrorCode.ADVICE_PARAMETER_INVOKE_DYNAMIC_CONSTANTS_NON_AFTER_ADVICE, _) - } - - static class TestInvokeDynamicConstantsAround { - static java.lang.invoke.CallSite around(final MethodHandles.Lookup lookup, final String name, final MethodType concatType, final String recipe, final Object... constants) { - return null - } - } - - @Requires({ - jvm.java9Compatible - }) - void 'test invoke dynamic on @Around advice'() { - setup: - final context = mockValidationContext() - final spec = around { - advice { - method(TestInvokeDynamicConstantsAround.getDeclaredMethod('around', MethodHandles.Lookup, String, MethodType, String, Object[])) - } - parameters(new Arg(), new Arg(), new Arg(), new Arg(), new Arg()) - signature('java.lang.invoke.CallSite java.lang.invoke.StringConcatFactory.makeConcatWithConstants(java.lang.invoke.MethodHandles$Lookup, java.lang.String, java.lang.invoke.MethodType, java.lang.String, java.lang.Object[])') - invokeDynamic(true) - } - - when: - spec.validate(context) - - then: - 0 * context.addError(_, _) - } - - - @CallSite(spi = CallSites) - class AfterWithVoidWrongAdvice { - @CallSite.After("void java.lang.String.getChars(int, int, char[], int)") - static String after(@CallSite.AllArguments final Object[] args, @CallSite.Return final String result) { - return result - } - } - - void 'test after advice with void should not use @Return'() { - setup: - final context = mockValidationContext() - final spec = buildClassSpecification(AfterWithVoidWrongAdvice) - - when: - spec.advices.each { it.validate(context) } - - then: - 1 * context.addError(ErrorCode.ADVICE_AFTER_VOID_METHOD_SHOULD_RETURN_VOID, _) - 1 * context.addError(ErrorCode.ADVICE_AFTER_VOID_METHOD_SHOULD_NOT_HAVE_RETURN, _) - } -} diff --git a/buildSrc/call-site-instrumentation-plugin/src/test/groovy/datadog/trace/plugin/csi/impl/AsmSpecificationBuilderTest.groovy b/buildSrc/call-site-instrumentation-plugin/src/test/groovy/datadog/trace/plugin/csi/impl/AsmSpecificationBuilderTest.groovy deleted file mode 100644 index 6223b30afc3..00000000000 --- a/buildSrc/call-site-instrumentation-plugin/src/test/groovy/datadog/trace/plugin/csi/impl/AsmSpecificationBuilderTest.groovy +++ /dev/null @@ -1,516 +0,0 @@ -package datadog.trace.plugin.csi.impl - -import datadog.trace.agent.tooling.csi.CallSite -import datadog.trace.agent.tooling.csi.CallSites -import datadog.trace.plugin.csi.impl.CallSiteSpecification.AdviceSpecification -import datadog.trace.plugin.csi.impl.CallSiteSpecification.AfterSpecification -import datadog.trace.plugin.csi.impl.CallSiteSpecification.AroundSpecification -import datadog.trace.plugin.csi.impl.CallSiteSpecification.BeforeSpecification -import datadog.trace.plugin.csi.util.Types -import edu.umd.cs.findbugs.annotations.SuppressFBWarnings -import groovy.transform.CompileDynamic -import org.objectweb.asm.Type - -import javax.annotation.Nonnull -import javax.annotation.Nullable -import javax.servlet.ServletRequest -import java.lang.invoke.MethodHandles -import java.lang.invoke.MethodType -import java.util.stream.Collectors - -@CompileDynamic -final class AsmSpecificationBuilderTest extends BaseCsiPluginTest { - - static class NonCallSite {} - - void 'test specification builder for non call site'() { - setup: - final advice = fetchClass(NonCallSite) - final specificationBuilder = new AsmSpecificationBuilder() - - when: - final result = specificationBuilder.build(advice) - - then: - !result.present - } - - @CallSite(spi = Spi) - static class WithSpiClass { - interface Spi {} - } - - void 'test specification builder with custom spi class'() { - setup: - final advice = fetchClass(WithSpiClass) - final specificationBuilder = new AsmSpecificationBuilder() - - when: - final result = specificationBuilder.build(advice).orElseThrow(RuntimeException::new) - - then: - result.spi == [Type.getType(WithSpiClass.Spi)] as Type[] - } - - @CallSite(spi = CallSites, helpers = [SampleHelper1.class, SampleHelper2.class]) - static class HelpersAdvice { - static class SampleHelper1 {} - static class SampleHelper2 {} - } - - void 'test specification builder with custom helper classes'() { - setup: - final advice = fetchClass(HelpersAdvice) - final specificationBuilder = new AsmSpecificationBuilder() - - when: - final result = specificationBuilder.build(advice).orElseThrow(RuntimeException::new) - - then: - result.helpers.toList().containsAll([ - Type.getType(HelpersAdvice), - Type.getType(HelpersAdvice.SampleHelper1), - Type.getType(HelpersAdvice.SampleHelper2) - ]) - } - - @CallSite(spi = CallSites) - static class BeforeAdvice { - @CallSite.Before('java.lang.String java.lang.String.replaceAll(java.lang.String, java.lang.String)') - static void before(@CallSite.This final String self, @CallSite.Argument final String regexp, @CallSite.Argument final String replacement) { - } - } - - void 'test specification builder for before advice'() { - setup: - final advice = fetchClass(BeforeAdvice) - final specificationBuilder = new AsmSpecificationBuilder() - - when: - final result = specificationBuilder.build(advice).orElseThrow(RuntimeException::new) - - then: - result.clazz.className == BeforeAdvice.name - final beforeSpec = findAdvice(result, 'before') - beforeSpec instanceof BeforeSpecification - beforeSpec.advice.methodType.descriptor == '(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)V' - beforeSpec.signature == 'java.lang.String java.lang.String.replaceAll(java.lang.String, java.lang.String)' - beforeSpec.findThis() != null - beforeSpec.findReturn() == null - beforeSpec.findAllArguments() == null - beforeSpec.findInvokeDynamicConstants() == null - final arguments = getArguments(beforeSpec) - arguments == [0, 1] - } - - @CallSite(spi = CallSites) - static class AroundAdvice { - @CallSite.Around('java.lang.String java.lang.String.replaceAll(java.lang.String, java.lang.String)') - static String around(@CallSite.This final String self, @CallSite.Argument final String regexp, @CallSite.Argument final String replacement) { - return self.replaceAll(regexp, replacement) - } - } - - void 'test specification builder for around advice'() { - setup: - final advice = fetchClass(AroundAdvice) - final specificationBuilder = new AsmSpecificationBuilder() - - when: - final result = specificationBuilder.build(advice).orElseThrow(RuntimeException::new) - - then: - result.clazz.className == AroundAdvice.name - final aroundSpec = findAdvice(result, 'around') - aroundSpec instanceof AroundSpecification - aroundSpec.advice.methodType.descriptor == '(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)Ljava/lang/String;' - aroundSpec.signature == 'java.lang.String java.lang.String.replaceAll(java.lang.String, java.lang.String)' - aroundSpec.findThis() != null - aroundSpec.findReturn() == null - aroundSpec.findAllArguments() == null - aroundSpec.findInvokeDynamicConstants() == null - final arguments = getArguments(aroundSpec) - arguments == [0, 1] - } - - @CallSite(spi = CallSites) - static class AfterAdvice { - @CallSite.After('java.lang.String java.lang.String.replaceAll(java.lang.String, java.lang.String)') - static String after(@CallSite.This final String self, @CallSite.Argument final String regexp, @CallSite.Argument final String replacement, @CallSite.Return final String result) { - return result - } - } - - void 'test specification builder for after advice'() { - setup: - final advice = fetchClass(AfterAdvice) - final specificationBuilder = new AsmSpecificationBuilder() - - when: - final result = specificationBuilder.build(advice).orElseThrow(RuntimeException::new) - - then: - result.clazz.className == AfterAdvice.name - final afterSpec = findAdvice(result, 'after') - afterSpec instanceof AfterSpecification - afterSpec.advice.methodType.descriptor == '(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)Ljava/lang/String;' - afterSpec.signature == 'java.lang.String java.lang.String.replaceAll(java.lang.String, java.lang.String)' - afterSpec.findThis() != null - afterSpec.findReturn() != null - afterSpec.findAllArguments() == null - afterSpec.findInvokeDynamicConstants() == null - final arguments = getArguments(afterSpec) - arguments == [0, 1] - } - - @CallSite - static class AllArgsAdvice { - @CallSite.Around('java.lang.String java.lang.String.replaceAll(java.lang.String, java.lang.String)') - static String allArgs(@CallSite.AllArguments(includeThis = true) final Object[] arguments, @CallSite.Return final String result) { - return result - } - } - - void 'test specification builder for advice with @AllArguments'() { - setup: - final advice = fetchClass(AllArgsAdvice) - final specificationBuilder = new AsmSpecificationBuilder() - - when: - final result = specificationBuilder.build(advice).orElseThrow(RuntimeException::new) - - then: - result.clazz.className == AllArgsAdvice.name - final allArgsSpec = findAdvice(result, 'allArgs') - allArgsSpec instanceof AroundSpecification - allArgsSpec.advice.methodType.descriptor == '([Ljava/lang/Object;Ljava/lang/String;)Ljava/lang/String;' - allArgsSpec.signature == 'java.lang.String java.lang.String.replaceAll(java.lang.String, java.lang.String)' - allArgsSpec.findThis() == null - allArgsSpec.findReturn() != null - final allArguments = allArgsSpec.findAllArguments() - allArguments != null - allArguments.includeThis - allArgsSpec.findInvokeDynamicConstants() == null - final arguments = getArguments(allArgsSpec) - arguments == [] - } - - @CallSite(spi = CallSites) - static class InvokeDynamicBeforeAdvice { - @CallSite.After( - value = 'java.lang.invoke.CallSite java.lang.invoke.StringConcatFactory.makeConcatWithConstants(java.lang.invoke.MethodHandles$Lookup, java.lang.String, java.lang.invoke.MethodType, java.lang.String, java.lang.Object[])', - invokeDynamic = true - ) - static String invokeDynamic(@CallSite.AllArguments final Object[] arguments, @CallSite.Return final String result) { - return result - } - } - - void 'test specification builder for before invoke dynamic'() { - setup: - final advice = fetchClass(InvokeDynamicBeforeAdvice) - final specificationBuilder = new AsmSpecificationBuilder() - - when: - final result = specificationBuilder.build(advice).orElseThrow(RuntimeException::new) - - then: - result.clazz.className == InvokeDynamicBeforeAdvice.name - final invokeDynamicSpec = findAdvice(result, 'invokeDynamic') - invokeDynamicSpec instanceof AfterSpecification - invokeDynamicSpec.advice.methodType.descriptor == '([Ljava/lang/Object;Ljava/lang/String;)Ljava/lang/String;' - invokeDynamicSpec.signature == 'java.lang.invoke.CallSite java.lang.invoke.StringConcatFactory.makeConcatWithConstants(java.lang.invoke.MethodHandles$Lookup, java.lang.String, java.lang.invoke.MethodType, java.lang.String, java.lang.Object[])' - invokeDynamicSpec.findThis() == null - invokeDynamicSpec.findReturn() != null - final allArguments = invokeDynamicSpec.findAllArguments() - allArguments != null - !allArguments.includeThis - invokeDynamicSpec.findInvokeDynamicConstants() == null - final arguments = getArguments(invokeDynamicSpec) - arguments == [] - } - - @CallSite(spi = CallSites) - static class InvokeDynamicAroundAdvice { - @CallSite.Around( - value = 'java.lang.invoke.CallSite java.lang.invoke.StringConcatFactory.makeConcatWithConstants(java.lang.invoke.MethodHandles$Lookup, java.lang.String, java.lang.invoke.MethodType, java.lang.String, java.lang.Object[])', - invokeDynamic = true - ) - static java.lang.invoke.CallSite invokeDynamic(@CallSite.Argument final MethodHandles.Lookup lookup, - @CallSite.Argument final String name, - @CallSite.Argument final MethodType concatType, - @CallSite.Argument final String recipe, - @CallSite.Argument final Object... constants) { - return null - } - } - - void 'test specification builder for around invoke dynamic'() { - setup: - final advice = fetchClass(InvokeDynamicAroundAdvice) - final specificationBuilder = new AsmSpecificationBuilder() - - when: - final result = specificationBuilder.build(advice).orElseThrow(RuntimeException::new) - - then: - result.clazz.className == InvokeDynamicAroundAdvice.name - final invokeDynamicSpec = findAdvice(result, 'invokeDynamic') - invokeDynamicSpec instanceof AroundSpecification - invokeDynamicSpec.advice.methodType.descriptor == '(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;Ljava/lang/String;[Ljava/lang/Object;)Ljava/lang/invoke/CallSite;' - invokeDynamicSpec.signature == 'java.lang.invoke.CallSite java.lang.invoke.StringConcatFactory.makeConcatWithConstants(java.lang.invoke.MethodHandles$Lookup, java.lang.String, java.lang.invoke.MethodType, java.lang.String, java.lang.Object[])' - invokeDynamicSpec.findThis() == null - invokeDynamicSpec.findReturn() == null - invokeDynamicSpec.findAllArguments() == null - invokeDynamicSpec.findInvokeDynamicConstants() == null - final arguments = getArguments(invokeDynamicSpec) - arguments == [0, 1, 2, 3, 4] - } - - @CallSite(spi = CallSites) - static class TestInvokeDynamicConstants { - @CallSite.After( - value = 'java.lang.invoke.CallSite java.lang.invoke.StringConcatFactory.makeConcatWithConstants(java.lang.invoke.MethodHandles$Lookup, java.lang.String, java.lang.invoke.MethodType, java.lang.String, java.lang.Object[])', - invokeDynamic = true - ) - static String after(@CallSite.AllArguments final Object[] parameter, - @CallSite.InvokeDynamicConstants final Object[] constants, - @CallSite.Return final String value) { - return value - } - } - - void 'test invoke dynamic constants'() { - setup: - final advice = fetchClass(TestInvokeDynamicConstants) - final specificationBuilder = new AsmSpecificationBuilder() - - when: - final result = specificationBuilder.build(advice).orElseThrow(RuntimeException::new) - - then: - result.clazz.className == TestInvokeDynamicConstants.name - final inheritedSpec = findAdvice(result, 'after') - inheritedSpec instanceof AfterSpecification - inheritedSpec.advice.methodType.descriptor == '([Ljava/lang/Object;[Ljava/lang/Object;Ljava/lang/String;)Ljava/lang/String;' - inheritedSpec.signature == 'java.lang.invoke.CallSite java.lang.invoke.StringConcatFactory.makeConcatWithConstants(java.lang.invoke.MethodHandles$Lookup, java.lang.String, java.lang.invoke.MethodType, java.lang.String, java.lang.Object[])' - inheritedSpec.findThis() == null - inheritedSpec.findReturn() != null - inheritedSpec.findInvokeDynamicConstants() != null - final arguments = getArguments(inheritedSpec) - arguments == [] - } - - @CallSite(spi = CallSites) - static class TestBeforeArray { - - @CallSite.BeforeArray([ - @CallSite.Before('java.util.Map javax.servlet.ServletRequest.getParameterMap()'), - @CallSite.Before('java.util.Map javax.servlet.ServletRequestWrapper.getParameterMap()') - ]) - static void before(@CallSite.This final ServletRequest request) { } - } - - void 'test specification builder for before advice array'() { - setup: - final advice = fetchClass(TestBeforeArray) - final specificationBuilder = new AsmSpecificationBuilder() - - when: - final result = specificationBuilder.build(advice).orElseThrow(RuntimeException::new) - - then: - result.clazz.className == TestBeforeArray.name - final list = result.advices - list.size() == 2 - list.each { - assert it instanceof BeforeSpecification - assert it.advice.methodType.descriptor == '(Ljavax/servlet/ServletRequest;)V' - assert it.signature in [ - 'java.util.Map javax.servlet.ServletRequest.getParameterMap()', - 'java.util.Map javax.servlet.ServletRequestWrapper.getParameterMap()' - ] - assert it.findThis() != null - assert it.findReturn() == null - assert it.findAllArguments() == null - assert it.findInvokeDynamicConstants() == null - final arguments = getArguments(it) - assert arguments == [] - } - } - - @CallSite(spi = CallSites) - static class TestAroundArray { - - @CallSite.AroundArray([ - @CallSite.Around('java.util.Map javax.servlet.ServletRequest.getParameterMap()'), - @CallSite.Around('java.util.Map javax.servlet.ServletRequestWrapper.getParameterMap()') - ]) - static Map around(@CallSite.This final ServletRequest request) { - return request.getParameterMap() - } - } - - void 'test specification builder for around advice array'() { - setup: - final advice = fetchClass(TestAroundArray) - final specificationBuilder = new AsmSpecificationBuilder() - - when: - final result = specificationBuilder.build(advice).orElseThrow(RuntimeException::new) - - then: - result.clazz.className == TestAroundArray.name - final list = result.advices - list.size() == 2 - list.each { - assert it instanceof AroundSpecification - assert it.advice.methodType.descriptor == '(Ljavax/servlet/ServletRequest;)Ljava/util/Map;' - assert it.signature in [ - 'java.util.Map javax.servlet.ServletRequest.getParameterMap()', - 'java.util.Map javax.servlet.ServletRequestWrapper.getParameterMap()' - ] - assert it.findThis() != null - assert it.findReturn() == null - assert it.findAllArguments() == null - assert it.findInvokeDynamicConstants() == null - final arguments = getArguments(it) - assert arguments == [] - } - } - - @CallSite(spi = CallSites) - static class TestAfterArray { - - @CallSite.AfterArray([ - @CallSite.After('java.util.Map javax.servlet.ServletRequest.getParameterMap()'), - @CallSite.After('java.util.Map javax.servlet.ServletRequestWrapper.getParameterMap()') - ]) - static Map after(@CallSite.This final ServletRequest request, @CallSite.Return final Map parameters) { - return parameters - } - } - - void 'test specification builder for before advice array'() { - setup: - final advice = fetchClass(TestAfterArray) - final specificationBuilder = new AsmSpecificationBuilder() - - when: - final result = specificationBuilder.build(advice).orElseThrow(RuntimeException::new) - - then: - result.clazz.className == TestAfterArray.name - final list = result.advices - list.size() == 2 - list.each { - assert it instanceof AfterSpecification - assert it.advice.methodType.descriptor == '(Ljavax/servlet/ServletRequest;Ljava/util/Map;)Ljava/util/Map;' - assert it.signature in [ - 'java.util.Map javax.servlet.ServletRequest.getParameterMap()', - 'java.util.Map javax.servlet.ServletRequestWrapper.getParameterMap()' - ] - assert it.findThis() != null - assert it.findReturn() != null - assert it.findAllArguments() == null - assert it.findInvokeDynamicConstants() == null - final arguments = getArguments(it) - assert arguments == [] - } - } - - @CallSite(spi = CallSites) - static class TestInheritedMethod { - @CallSite.After('java.lang.String javax.servlet.http.HttpServletRequest.getParameter(java.lang.String)') - static String after(@CallSite.This final ServletRequest request, @CallSite.Argument final String parameter, @CallSite.Return final String value) { - return value - } - } - - void 'test specification builder for inherited methods'() { - setup: - final advice = fetchClass(TestInheritedMethod) - final specificationBuilder = new AsmSpecificationBuilder() - - when: - final result = specificationBuilder.build(advice).orElseThrow(RuntimeException::new) - - then: - result.clazz.className == TestInheritedMethod.name - final inheritedSpec = findAdvice(result, 'after') - inheritedSpec instanceof AfterSpecification - inheritedSpec.advice.methodType.descriptor == '(Ljavax/servlet/ServletRequest;Ljava/lang/String;Ljava/lang/String;)Ljava/lang/String;' - inheritedSpec.signature == 'java.lang.String javax.servlet.http.HttpServletRequest.getParameter(java.lang.String)' - inheritedSpec.findThis() != null - inheritedSpec.findReturn() != null - inheritedSpec.findAllArguments() == null - inheritedSpec.findInvokeDynamicConstants() == null - final arguments = getArguments(inheritedSpec) - arguments == [0] - } - - static class IsEnabled { - static boolean isEnabled(final String defaultValue) { - return true - } - } - - @CallSite(spi = CallSites, enabled = ['datadog.trace.plugin.csi.impl.AsmSpecificationBuilderTest$IsEnabled', 'isEnabled', 'true']) - static class TestEnablement { - @CallSite.After('java.lang.String javax.servlet.http.HttpServletRequest.getParameter(java.lang.String)') - static String after(@CallSite.This final ServletRequest request, @CallSite.Argument final String parameter, @CallSite.Return final String value) { - return value - } - } - - void 'test specification builder with enabled property'() { - setup: - final advice = fetchClass(TestEnablement) - final specificationBuilder = new AsmSpecificationBuilder() - - when: - final result = specificationBuilder.build(advice).orElseThrow(RuntimeException::new) - - then: - result.clazz.className == TestEnablement.name - result.enabled != null - result.enabled.method.owner == Type.getType(IsEnabled) - result.enabled.method.methodName == 'isEnabled' - result.enabled.method.methodType == Type.getMethodType(Types.BOOLEAN, Types.STRING) - result.enabled.arguments == ['true'] - } - - @CallSite(spi = CallSites) - static class TestWithOtherAnnotations { - @CallSite.Around("java.lang.StringBuilder java.lang.StringBuilder.append(java.lang.Object)") - @CallSite.Around("java.lang.StringBuffer java.lang.StringBuffer.append(java.lang.Object)") - @Nonnull - @SuppressFBWarnings( - "NP_PARAMETER_MUST_BE_NONNULL_BUT_MARKED_AS_NULLABLE") // we do check for null on self - // parameter - static Appendable aroundAppend(@CallSite.This @Nullable final Appendable self, @CallSite.Argument(0) @Nullable final Object param) throws Throwable { - return self.append(param.toString()) - } - } - - void 'test specification builder with multiple method annotations'() { - setup: - final advice = fetchClass(TestWithOtherAnnotations) - final specificationBuilder = new AsmSpecificationBuilder() - - when: - final result = specificationBuilder.build(advice).orElseThrow(RuntimeException::new) - - then: - result.clazz.className == TestWithOtherAnnotations.name - result.advices.size() == 2 - } - - private static List getArguments(final AdviceSpecification advice) { - return advice.arguments.map(it -> it.index).collect(Collectors.toList()) - } - - private static AdviceSpecification findAdvice(final CallSiteSpecification result, final String name) { - return result.advices.find { it.advice.methodName == name } - } -} diff --git a/buildSrc/call-site-instrumentation-plugin/src/test/groovy/datadog/trace/plugin/csi/impl/BaseCsiPluginTest.groovy b/buildSrc/call-site-instrumentation-plugin/src/test/groovy/datadog/trace/plugin/csi/impl/BaseCsiPluginTest.groovy deleted file mode 100644 index cae4c42ba96..00000000000 --- a/buildSrc/call-site-instrumentation-plugin/src/test/groovy/datadog/trace/plugin/csi/impl/BaseCsiPluginTest.groovy +++ /dev/null @@ -1,202 +0,0 @@ -package datadog.trace.plugin.csi.impl - -import datadog.trace.plugin.csi.HasErrors -import datadog.trace.plugin.csi.ValidationContext -import datadog.trace.plugin.csi.util.MethodType -import groovy.transform.CompileDynamic -import org.objectweb.asm.Type -import spock.lang.Specification - -import java.lang.reflect.Constructor -import java.lang.reflect.Executable -import java.lang.reflect.Method -import java.nio.file.Files -import java.nio.file.Paths -import java.util.stream.Collectors -import datadog.trace.plugin.csi.impl.CallSiteSpecification.ParameterSpecification -import datadog.trace.plugin.csi.impl.CallSiteSpecification.AdviceSpecification -import datadog.trace.plugin.csi.impl.CallSiteSpecification.BeforeSpecification -import datadog.trace.plugin.csi.impl.CallSiteSpecification.AroundSpecification -import datadog.trace.plugin.csi.impl.CallSiteSpecification.AfterSpecification -import datadog.trace.plugin.csi.impl.CallSiteSpecification.ArgumentSpecification - -import static datadog.trace.plugin.csi.impl.CallSiteFactory.pointcutParser -import static datadog.trace.plugin.csi.impl.CallSiteFactory.specificationBuilder -import static datadog.trace.plugin.csi.impl.CallSiteFactory.typeResolver -import static datadog.trace.plugin.csi.util.CallSiteConstants.TYPE_RESOLVER - -@CompileDynamic -abstract class BaseCsiPluginTest extends Specification { - - protected static void assertNoErrors(final HasErrors hasErrors) { - final errors = hasErrors.errors.collect { error -> - "${error.message}: ${error.cause == null ? '-' : error.causeString}" - } - assert errors == [] - } - - protected static File fetchClass(final Class clazz) { - final folder = Paths.get(clazz.getResource('/').toURI()).resolve('../../') - final fileSeparatorPattern = File.separator == "\\" ? "\\\\" : File.separator - final classFile = clazz.getName().replaceAll('\\.', fileSeparatorPattern) + '.class' - final groovy = folder.resolve('groovy/test').resolve(classFile) - if (Files.exists(groovy)) { - return groovy.toFile() - } - return folder.resolve('java/test').resolve(classFile).toFile() - } - - protected static CallSiteSpecification buildClassSpecification(final Class clazz) { - final classFile = fetchClass(clazz) - final spec = specificationBuilder().build(classFile).get() - final pointcutParser = pointcutParser() - spec.advices.each { it.parseSignature(pointcutParser) } - return spec - } - - protected ValidationContext mockValidationContext() { - return Mock(ValidationContext) { - mock -> - mock.getContextProperty(TYPE_RESOLVER) >> typeResolver() - } - } - - protected static BeforeSpecification before(@DelegatesTo(strategy = Closure.DELEGATE_ONLY, value = BeforeAdviceSpecificationBuilder) final Closure cl) { - final spec = new BeforeAdviceSpecificationBuilder() - final code = cl.rehydrate(spec, this, this) - code.resolveStrategy = Closure.DELEGATE_ONLY - code() - return spec.build() - } - - protected static AroundSpecification around(@DelegatesTo(strategy = Closure.DELEGATE_ONLY, value = AroundAdviceSpecificationBuilder) final Closure cl) { - final spec = new AroundAdviceSpecificationBuilder() - final code = cl.rehydrate(spec, this, this) - code.resolveStrategy = Closure.DELEGATE_ONLY - code() - return spec.build() - } - - protected static AfterSpecification after(@DelegatesTo(strategy = Closure.DELEGATE_ONLY, value = AfterAdviceSpecificationBuilder) final Closure cl) { - final spec = new AfterAdviceSpecificationBuilder() - final code = cl.rehydrate(spec, this, this) - code.resolveStrategy = Closure.DELEGATE_ONLY - code() - return spec.build() - } - - private static class BeforeAdviceSpecificationBuilder extends AdviceSpecificationBuilder { - @Override - protected AdviceSpecification build(final MethodType advice, - final Map parameters, - final String signature, - final boolean invokeDynamic) { - return new BeforeSpecification(advice, parameters, signature, invokeDynamic) - } - } - - private static class AroundAdviceSpecificationBuilder extends AdviceSpecificationBuilder { - @Override - protected AroundSpecification build(final MethodType advice, - final Map parameters, - final String signature, - final boolean invokeDynamic) { - return new AroundSpecification(advice, parameters, signature, invokeDynamic) - } - } - - private static class AfterAdviceSpecificationBuilder extends AdviceSpecificationBuilder { - @Override - protected AfterSpecification build(final MethodType advice, - final Map parameters, - final String signature, - final boolean invokeDynamic) { - return new AfterSpecification(advice, parameters, signature, invokeDynamic) - } - } - - private abstract static class AdviceSpecificationBuilder { - protected MethodType advice - protected Map parameters = [:] - protected String signature - protected boolean invokeDynamic - - void advice(@DelegatesTo(strategy = Closure.DELEGATE_ONLY, value = MethodTypeBuilder) final Closure body) { - final spec = new MethodTypeBuilder() - final code = body.rehydrate(spec, this, this) - code.resolveStrategy = Closure.DELEGATE_ONLY - code() - advice = spec.build() - } - - void parameters(final ParameterSpecification... parameters) { - parameters.eachWithIndex { entry, int i -> this.parameters.put(i, entry) } - parameters.grep { it instanceof ArgumentSpecification } - .collect { it as ArgumentSpecification} - .eachWithIndex{ spec, int i -> spec.index = i} - } - - void signature(final String signature) { - this.signature = signature - } - - void invokeDynamic(final boolean invokeDynamic) { - this.invokeDynamic = invokeDynamic - } - - E build() { - final result = build(advice, parameters, signature, invokeDynamic) as E - result.parseSignature(pointcutParser()) - return result - } - - - protected abstract AdviceSpecification build(final MethodType advice, - final Map parameters, - final String signature, - final boolean invokeDynamic) - } - - private static class MethodTypeBuilder { - protected Type owner - protected String methodName - protected Type methodType - - void owner(final Type value) { - owner = value - } - - void owner(final Class value) { - owner = Type.getType(value) - } - - void method(final String value) { - methodName = value - } - - void descriptor(final Type value) { - methodType = value - } - - void descriptor(final Class returnType, final Class... args) { - methodType = Type.getMethodType(Type.getType(returnType), args.collect { Type.getType(it) } as Type[]) - } - - void method(final Executable executable) { - owner = Type.getType(executable.declaringClass) - final args = executable.parameterTypes.collect { Type.getType(it) }.toArray(new Type[0]) as Type[] - if (executable instanceof Constructor) { - methodName = '' - methodType = Type.getMethodType(Type.VOID_TYPE, args) - } else { - final method = executable as Method - methodName = method.name - methodType = Type.getMethodType(Type.getType(method.getReturnType()), args) - } - } - - private MethodType build() { - return new MethodType(owner, methodName, methodType) - } - } -} diff --git a/buildSrc/call-site-instrumentation-plugin/src/test/groovy/datadog/trace/plugin/csi/impl/CallSiteSpecificationTest.groovy b/buildSrc/call-site-instrumentation-plugin/src/test/groovy/datadog/trace/plugin/csi/impl/CallSiteSpecificationTest.groovy deleted file mode 100644 index 25a44004759..00000000000 --- a/buildSrc/call-site-instrumentation-plugin/src/test/groovy/datadog/trace/plugin/csi/impl/CallSiteSpecificationTest.groovy +++ /dev/null @@ -1,45 +0,0 @@ -package datadog.trace.plugin.csi.impl - -import datadog.trace.agent.tooling.csi.CallSiteAdvice -import datadog.trace.plugin.csi.util.ErrorCode -import org.objectweb.asm.Type -import datadog.trace.plugin.csi.impl.CallSiteSpecification.AdviceSpecification - -class CallSiteSpecificationTest extends BaseCsiPluginTest { - - def 'test call site spi should be an interface'() { - setup: - final context = mockValidationContext() - final spec = new CallSiteSpecification(Type.getType(String), [Mock(AdviceSpecification)], [Type.getType(String)] as Set, [] as List, [] as Set) - - when: - spec.validate(context) - - then: - 1 * context.addError(ErrorCode.CALL_SITE_SPI_SHOULD_BE_AN_INTERFACE, _) - } - - def 'test call site spi should not define any methods'() { - setup: - final context = mockValidationContext() - final spec = new CallSiteSpecification(Type.getType(String), [Mock(AdviceSpecification)], [Type.getType(Comparable)] as Set, [] as List, [] as Set) - - when: - spec.validate(context) - - then: - 1 * context.addError(ErrorCode.CALL_SITE_SPI_SHOULD_BE_EMPTY, _) - } - - def 'test call site should have advices'() { - setup: - final context = mockValidationContext() - final spec = new CallSiteSpecification(Type.getType(String), [], [Type.getType(CallSiteAdvice)] as Set, [] as List, [] as Set) - - when: - spec.validate(context) - - then: - 1 * context.addError(ErrorCode.CALL_SITE_SHOULD_HAVE_ADVICE_METHODS, _) - } -} diff --git a/buildSrc/call-site-instrumentation-plugin/src/test/groovy/datadog/trace/plugin/csi/impl/RegexpAdvicePointcutParserTest.groovy b/buildSrc/call-site-instrumentation-plugin/src/test/groovy/datadog/trace/plugin/csi/impl/RegexpAdvicePointcutParserTest.groovy deleted file mode 100644 index e7fdb727b38..00000000000 --- a/buildSrc/call-site-instrumentation-plugin/src/test/groovy/datadog/trace/plugin/csi/impl/RegexpAdvicePointcutParserTest.groovy +++ /dev/null @@ -1,136 +0,0 @@ -package datadog.trace.plugin.csi.impl - -import spock.lang.Specification - -final class RegexpAdvicePointcutParserTest extends Specification { - - def 'resolve constructor'() { - setup: - final pointcutParser = new RegexpAdvicePointcutParser() - - when: - final signature = pointcutParser.parse("void datadog.trace.plugin.csi.samples.SignatureParserExample.()") - - then: - signature.owner.className == 'datadog.trace.plugin.csi.samples.SignatureParserExample' - signature.methodName == '' - signature.methodType.descriptor == '()V' - } - - def 'resolve constructor with args'() { - setup: - final pointcutParser = new RegexpAdvicePointcutParser() - - when: - final signature = pointcutParser.parse("void datadog.trace.plugin.csi.samples.SignatureParserExample.(java.lang.String)") - - then: - signature.owner.className == 'datadog.trace.plugin.csi.samples.SignatureParserExample' - signature.methodName == '' - signature.methodType.descriptor == '(Ljava/lang/String;)V' - } - - def 'resolve without args'() { - setup: - final pointcutParser = new RegexpAdvicePointcutParser() - - when: - final signature = pointcutParser.parse("java.lang.String datadog.trace.plugin.csi.samples.SignatureParserExample.noParams()") - - then: - signature.owner.className == 'datadog.trace.plugin.csi.samples.SignatureParserExample' - signature.methodName == 'noParams' - signature.methodType.descriptor == '()Ljava/lang/String;' - } - - def 'resolve one param'() { - setup: - final pointcutParser = new RegexpAdvicePointcutParser() - - when: - final signature = pointcutParser.parse("java.lang.String datadog.trace.plugin.csi.samples.SignatureParserExample.oneParam(java.util.Map)") - - then: - signature.owner.className == 'datadog.trace.plugin.csi.samples.SignatureParserExample' - signature.methodName == 'oneParam' - signature.methodType.descriptor == '(Ljava/util/Map;)Ljava/lang/String;' - } - - def 'resolve multiple params'() { - setup: - final pointcutParser = new RegexpAdvicePointcutParser() - - when: - final signature = pointcutParser.parse("java.lang.String datadog.trace.plugin.csi.samples.SignatureParserExample.multipleParams(java.lang.String, int, java.util.List)") - - then: - signature.owner.className == 'datadog.trace.plugin.csi.samples.SignatureParserExample' - signature.methodName == 'multipleParams' - signature.methodType.descriptor == '(Ljava/lang/String;ILjava/util/List;)Ljava/lang/String;' - } - - def 'resolve varargs'() { - setup: - final pointcutParser = new RegexpAdvicePointcutParser() - - when: - final signature = pointcutParser.parse("java.lang.String datadog.trace.plugin.csi.samples.SignatureParserExample.varargs(java.lang.String[])") - - then: - signature.owner.className == 'datadog.trace.plugin.csi.samples.SignatureParserExample' - signature.methodName == 'varargs' - signature.methodType.descriptor == '([Ljava/lang/String;)Ljava/lang/String;' - } - - def 'resolve primitive'() { - setup: - final pointcutParser = new RegexpAdvicePointcutParser() - - when: - final signature = pointcutParser.parse("int datadog.trace.plugin.csi.samples.SignatureParserExample.primitive()") - - then: - signature.owner.className == 'datadog.trace.plugin.csi.samples.SignatureParserExample' - signature.methodName == 'primitive' - signature.methodType.descriptor == '()I' - } - - def 'resolve primitive array type'() { - setup: - final pointcutParser = new RegexpAdvicePointcutParser() - - when: - final signature = pointcutParser.parse("byte[] datadog.trace.plugin.csi.samples.SignatureParserExample.primitiveArray()") - - then: - signature.owner.className == 'datadog.trace.plugin.csi.samples.SignatureParserExample' - signature.methodName == 'primitiveArray' - signature.methodType.descriptor == '()[B' - } - - def 'resolve object array type'() { - setup: - final pointcutParser = new RegexpAdvicePointcutParser() - - when: - final signature = pointcutParser.parse("java.lang.Object[] datadog.trace.plugin.csi.samples.SignatureParserExample.objectArray()") - - then: - signature.owner.className == 'datadog.trace.plugin.csi.samples.SignatureParserExample' - signature.methodName == 'objectArray' - signature.methodType.descriptor == '()[Ljava/lang/Object;' - } - - def 'resolve multi dimensional object array type'() { - setup: - final pointcutParser = new RegexpAdvicePointcutParser() - - when: - final signature = pointcutParser.parse("java.lang.Object[][][] datadog.trace.plugin.csi.samples.SignatureParserExample.objectArray()") - - then: - signature.owner.className == 'datadog.trace.plugin.csi.samples.SignatureParserExample' - signature.methodName == 'objectArray' - signature.methodType.descriptor == '()[[[Ljava/lang/Object;' - } -} diff --git a/buildSrc/call-site-instrumentation-plugin/src/test/groovy/datadog/trace/plugin/csi/impl/TypeResolverPoolTest.groovy b/buildSrc/call-site-instrumentation-plugin/src/test/groovy/datadog/trace/plugin/csi/impl/TypeResolverPoolTest.groovy deleted file mode 100644 index ffeeb6b7f40..00000000000 --- a/buildSrc/call-site-instrumentation-plugin/src/test/groovy/datadog/trace/plugin/csi/impl/TypeResolverPoolTest.groovy +++ /dev/null @@ -1,109 +0,0 @@ -package datadog.trace.plugin.csi.impl - -import datadog.trace.plugin.csi.util.MethodType -import org.objectweb.asm.Type -import spock.lang.Specification - -import javax.servlet.ServletRequest -import javax.servlet.http.HttpServletRequest - -class TypeResolverPoolTest extends Specification { - - def 'test resolve primitive'() { - setup: - final resolver = new TypeResolverPool() - - when: - final result = resolver.resolveType(Type.INT_TYPE) - - then: - result == int.class - } - - def 'test resolve primitive array'() { - setup: - final resolver = new TypeResolverPool() - final type = Type.getType('[I') - - when: - final result = resolver.resolveType(type) - - then: - result == int[].class - } - - def 'test resolve primitive multidimensional array'() { - setup: - final resolver = new TypeResolverPool() - final type = Type.getType('[[[I') - - when: - final result = resolver.resolveType(type) - - then: - result == int[][][].class - } - - def 'test resolve class'() { - setup: - final resolver = new TypeResolverPool() - final type = Type.getType(String) - - when: - final result = resolver.resolveType(type) - - then: - result == String - } - - - def 'test resolve class array'() { - setup: - final resolver = new TypeResolverPool() - final type = Type.getType(String[]) - - when: - final result = resolver.resolveType(type) - - then: - result == String[] - } - - def 'test resolve class multidimensional array'() { - setup: - final resolver = new TypeResolverPool() - final type = Type.getType(String[][][]) - - when: - final result = resolver.resolveType(type) - - then: - result == String[][][] - } - - def 'test type resolver from method'() { - setup: - final resolver = new TypeResolverPool() - final type = Type.getMethodType(Type.getType(String[]), Type.getType(String), Type.getType(String)) - - when: - final result = resolver.resolveType(type.getReturnType()) - - then: - result == String[] - } - - def 'test inherited methods'() { - setup: - final resolver = new TypeResolverPool() - final owner = Type.getType(HttpServletRequest) - final name = 'getParameter' - final descriptor = Type.getMethodType(Type.getType(String), Type.getType(String)) - - when: - final result = resolver.resolveMethod(new MethodType(owner, name, descriptor)) - - then: - result == ServletRequest.getDeclaredMethod('getParameter', String) - } -} diff --git a/buildSrc/call-site-instrumentation-plugin/src/test/groovy/datadog/trace/plugin/csi/impl/assertion/AdviceAssert.groovy b/buildSrc/call-site-instrumentation-plugin/src/test/groovy/datadog/trace/plugin/csi/impl/assertion/AdviceAssert.groovy deleted file mode 100644 index d83120a6c52..00000000000 --- a/buildSrc/call-site-instrumentation-plugin/src/test/groovy/datadog/trace/plugin/csi/impl/assertion/AdviceAssert.groovy +++ /dev/null @@ -1,23 +0,0 @@ -package datadog.trace.plugin.csi.impl.assertion - -class AdviceAssert { - protected String type - protected String owner - protected String method - protected String descriptor - protected Collection statements - - void type(String type) { - assert type == this.type - } - - void pointcut(String owner, String method, String descriptor) { - assert owner == this.owner - assert method == this.method - assert descriptor == this.descriptor - } - - void statements(String... values) { - assert values.toList() == statements - } -} diff --git a/buildSrc/call-site-instrumentation-plugin/src/test/groovy/datadog/trace/plugin/csi/impl/assertion/AssertBuilder.groovy b/buildSrc/call-site-instrumentation-plugin/src/test/groovy/datadog/trace/plugin/csi/impl/assertion/AssertBuilder.groovy deleted file mode 100644 index 35c2456ea85..00000000000 --- a/buildSrc/call-site-instrumentation-plugin/src/test/groovy/datadog/trace/plugin/csi/impl/assertion/AssertBuilder.groovy +++ /dev/null @@ -1,121 +0,0 @@ -package datadog.trace.plugin.csi.impl.assertion - -import com.github.javaparser.JavaParser -import com.github.javaparser.ParserConfiguration -import com.github.javaparser.ast.CompilationUnit -import com.github.javaparser.ast.Node -import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration -import com.github.javaparser.ast.body.MethodDeclaration -import com.github.javaparser.ast.expr.MethodCallExpr -import com.github.javaparser.symbolsolver.JavaSymbolSolver -import datadog.trace.agent.tooling.csi.CallSites - -import java.lang.reflect.Executable -import java.lang.reflect.Method - -import static datadog.trace.plugin.csi.impl.CallSiteFactory.typeResolver -import static datadog.trace.plugin.csi.util.CallSiteUtils.classNameToType - -class AssertBuilder { - private final File file - - AssertBuilder(final File file) { - this.file = file - } - - C build() { - final javaFile = parseJavaFile(file) - assert javaFile.parsed == Node.Parsedness.PARSED - final targetType = javaFile.primaryType.get().asClassOrInterfaceDeclaration() - final interfaces = getInterfaces(targetType) - def (enabled, enabledArgs) = getEnabledDeclaration(targetType, interfaces) - return (C) new CallSiteAssert([ - interfaces : getInterfaces(targetType), - spi : getSpi(targetType), - helpers : getHelpers(targetType), - advices : getAdvices(targetType), - enabled : enabled, - enabledArgs: enabledArgs - ]) - } - - protected Set> getSpi(final ClassOrInterfaceDeclaration type) { - return type.getAnnotationByName('AutoService').get().asNormalAnnotationExpr() - .collect { it.pairs.find { it.name.toString() == 'value' }.value.asArrayInitializerExpr() } - .collectMany { it.getValues() } - .collect { it.asClassExpr().getType().resolve().typeDeclaration.get().clazz } - .toSet() - } - - protected Set> getInterfaces(final ClassOrInterfaceDeclaration type) { - return type.asClassOrInterfaceDeclaration().implementedTypes.collect { - final resolved = it.asClassOrInterfaceType().resolve() - return resolved.typeDeclaration.get().clazz - }.toSet() - } - - protected def getEnabledDeclaration(final ClassOrInterfaceDeclaration type, final Set> interfaces) { - if (!interfaces.contains(CallSites.HasEnabledProperty)) { - return [null, null] - } - final isEnabled = type.getMethodsByName('isEnabled').first() - // JavaParser's NodeList has method getFirst() returning an Optional, however with Java 21's - // SequencedCollection, Groovy picks the getFirst() that returns the object itself. - // Using `first()` rather than `first` picks the groovy method instead, fixing the situation. - final returnStatement = isEnabled.body.get().statements.first().asReturnStmt() - final enabledMethodCall = returnStatement.expression.get().asMethodCallExpr() - final enabled = resolveMethod(enabledMethodCall) - final enabledArgs = enabledMethodCall.getArguments().collect { it.asStringLiteralExpr().asString() }.toSet() - return [enabled, enabledArgs] - } - - protected Set> getHelpers(final ClassOrInterfaceDeclaration type) { - final acceptMethod = type.getMethodsByName('accept').first() - final methodCalls = getMethodCalls(acceptMethod) - return methodCalls.findAll { - it.nameAsString == 'addHelpers' - }.collectMany { - it.arguments - }.collect { - typeResolver().resolveType(classNameToType(it.asStringLiteralExpr().asString())) - }.toSet() - } - - protected List getAdvices(final ClassOrInterfaceDeclaration type) { - final acceptMethod = type.getMethodsByName('accept').first() - return getMethodCalls(acceptMethod).findAll { - it.nameAsString == 'addAdvice' - }.collect { - final adviceType = it.arguments.get(0).asFieldAccessExpr().getName() - def (owner, method, descriptor) = it.arguments.subList(1, 4)*.asStringLiteralExpr()*.asString() - final handlerLambda = it.arguments[4].asLambdaExpr() - final advice = handlerLambda.body.asBlockStmt().statements*.toString() - return new AdviceAssert([ - type : adviceType, - owner : owner, - method : method, - descriptor: descriptor, - statements: advice - ]) - } - } - - protected static List getMethodCalls(final MethodDeclaration method) { - return method.body.get().asBlockStmt().getStatements().findAll { - it.isExpressionStmt() && it.asExpressionStmt().getExpression().isMethodCallExpr() - }.collect { - it.asExpressionStmt().getExpression().asMethodCallExpr() - } - } - - private static Executable resolveMethod(final MethodCallExpr methodCallExpr) { - final resolved = methodCallExpr.resolve() - return resolved.@method as Method - } - - private static CompilationUnit parseJavaFile(final File file) throws FileNotFoundException { - final JavaSymbolSolver solver = new JavaSymbolSolver(typeResolver()) - final JavaParser parser = new JavaParser(new ParserConfiguration().setSymbolResolver(solver)) - return parser.parse(file).getResult().get() - } -} diff --git a/buildSrc/call-site-instrumentation-plugin/src/test/groovy/datadog/trace/plugin/csi/impl/assertion/CallSiteAssert.groovy b/buildSrc/call-site-instrumentation-plugin/src/test/groovy/datadog/trace/plugin/csi/impl/assertion/CallSiteAssert.groovy deleted file mode 100644 index fdd557adc95..00000000000 --- a/buildSrc/call-site-instrumentation-plugin/src/test/groovy/datadog/trace/plugin/csi/impl/assertion/CallSiteAssert.groovy +++ /dev/null @@ -1,42 +0,0 @@ -package datadog.trace.plugin.csi.impl.assertion - -import java.lang.reflect.Method - -import static java.util.Arrays.asList - -class CallSiteAssert { - - protected Set> interfaces - protected Set> spi - protected Set> helpers - protected Collection advices - protected Method enabled - protected Set enabledArgs - - void interfaces(Class... values) { - assertSameElements(interfaces, values) - } - - void helpers(Class... values) { - assertSameElements(helpers, values) - } - - void spi(Class...values) { - assertSameElements(spi, values) - } - - void advices(int index, @DelegatesTo(AdviceAssert) Closure closure) { - final asserter = advices[index] - closure.delegate = asserter - closure(asserter) - } - - void enabled(Method method, String... args) { - assert method == enabled - assertSameElements(enabledArgs, args) - } - - private static void assertSameElements(final Set expected, final E...received) { - assert received.length == expected.size() && expected.containsAll(asList(received)) - } -} diff --git a/buildSrc/call-site-instrumentation-plugin/src/test/groovy/datadog/trace/plugin/csi/impl/ext/IastExtensionTest.groovy b/buildSrc/call-site-instrumentation-plugin/src/test/groovy/datadog/trace/plugin/csi/impl/ext/IastExtensionTest.groovy deleted file mode 100644 index bcd28fad4e4..00000000000 --- a/buildSrc/call-site-instrumentation-plugin/src/test/groovy/datadog/trace/plugin/csi/impl/ext/IastExtensionTest.groovy +++ /dev/null @@ -1,244 +0,0 @@ -package datadog.trace.plugin.csi.impl.ext - -import com.github.javaparser.JavaParser -import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration -import com.github.javaparser.ast.stmt.IfStmt -import datadog.trace.agent.tooling.csi.CallSites -import datadog.trace.plugin.csi.AdviceGenerator -import datadog.trace.plugin.csi.PluginApplication.Configuration -import datadog.trace.plugin.csi.impl.AdviceGeneratorImpl -import datadog.trace.plugin.csi.impl.BaseCsiPluginTest -import datadog.trace.plugin.csi.impl.CallSiteSpecification -import datadog.trace.plugin.csi.impl.assertion.AdviceAssert -import datadog.trace.plugin.csi.impl.assertion.AssertBuilder -import datadog.trace.plugin.csi.impl.assertion.CallSiteAssert -import datadog.trace.plugin.csi.impl.ext.tests.IastExtensionCallSite -import datadog.trace.plugin.csi.impl.ext.tests.SourceTypes -import groovy.transform.CompileDynamic -import spock.lang.TempDir - -import java.nio.file.Files -import java.nio.file.Path -import java.nio.file.Paths - -import static datadog.trace.plugin.csi.impl.CallSiteFactory.pointcutParser -import static datadog.trace.plugin.csi.util.CallSiteUtils.classNameToType - -@CompileDynamic -class IastExtensionTest extends BaseCsiPluginTest { - - @TempDir - private File buildDir - private Path targetFolder - private Path projectFolder - private Path srcFolder - - void setup() { - targetFolder = buildDir.toPath().resolve('target') - Files.createDirectories(targetFolder) - projectFolder = buildDir.toPath().resolve('project') - Files.createDirectories(projectFolder) - srcFolder = projectFolder.resolve('src/main/java') - Files.createDirectories(srcFolder) - } - - void 'test that extension only applies to iast advices'() { - setup: - final type = classNameToType(typeName) - final callSite = Mock(CallSiteSpecification) { - getSpi() >> type - } - final extension = new IastExtension() - - when: - final applies = extension.appliesTo(callSite) - - then: - applies == expected - - where: - typeName | expected - CallSites.name | false - IastExtension.IAST_CALL_SITES_FQCN | true - } - - void 'test that extension generates a call site with telemetry'() { - setup: - final config = Mock(Configuration) { - getTargetFolder() >> targetFolder - getSrcFolder() >> getCallSiteSrcFolder() - getClassPath() >> [] - } - final spec = buildClassSpecification(IastExtensionCallSite) - final generator = buildAdviceGenerator(buildDir) - final result = generator.generate(spec) - if (!result.success) { - throw new IllegalArgumentException("Error with call site ${IastExtensionCallSite}") - } - final extension = new IastExtension() - - when: - extension.apply(config, result) - - then: 'the call site provider is modified with telemetry' - assertNoErrors result - assertCallSites(result.file) { - advices(0) { - pointcut('javax/servlet/http/HttpServletRequest', 'getHeader', '(Ljava/lang/String;)Ljava/lang/String;') - instrumentedMetric('IastMetric.INSTRUMENTED_SOURCE') { - metricStatements('IastMetricCollector.add(IastMetric.INSTRUMENTED_SOURCE, (byte) 3, 1);') - } - executedMetric('IastMetric.EXECUTED_SOURCE') { - metricStatements( - 'handler.field(net.bytebuddy.jar.asm.Opcodes.GETSTATIC, "datadog/trace/api/iast/telemetry/IastMetric", "EXECUTED_SOURCE", "Ldatadog/trace/api/iast/telemetry/IastMetric;");', - 'handler.instruction(net.bytebuddy.jar.asm.Opcodes.ICONST_3);', - 'handler.instruction(net.bytebuddy.jar.asm.Opcodes.ICONST_1);', - 'handler.method(net.bytebuddy.jar.asm.Opcodes.INVOKESTATIC, "datadog/trace/api/iast/telemetry/IastMetricCollector", "add", "(Ldatadog/trace/api/iast/telemetry/IastMetric;BI)V", false);' - ) - } - } - advices(1) { - pointcut('javax/servlet/http/HttpServletRequest', 'getInputStream', '()Ljavax/servlet/ServletInputStream;') - instrumentedMetric('IastMetric.INSTRUMENTED_SOURCE') { - metricStatements('IastMetricCollector.add(IastMetric.INSTRUMENTED_SOURCE, (byte) 127, 1);') - } - executedMetric('IastMetric.EXECUTED_SOURCE') { - metricStatements( - 'handler.field(net.bytebuddy.jar.asm.Opcodes.GETSTATIC, "datadog/trace/api/iast/telemetry/IastMetric", "EXECUTED_SOURCE", "Ldatadog/trace/api/iast/telemetry/IastMetric;");', - 'handler.instruction(net.bytebuddy.jar.asm.Opcodes.BIPUSH, 127);', - 'handler.instruction(net.bytebuddy.jar.asm.Opcodes.ICONST_1);', - 'handler.method(net.bytebuddy.jar.asm.Opcodes.INVOKESTATIC, "datadog/trace/api/iast/telemetry/IastMetricCollector", "add", "(Ldatadog/trace/api/iast/telemetry/IastMetric;BI)V", false);' - ) - } - } - advices(2) { - pointcut('javax/servlet/ServletRequest', 'getReader', '()Ljava/io/BufferedReader;') - instrumentedMetric('IastMetric.INSTRUMENTED_PROPAGATION') { - metricStatements('IastMetricCollector.add(IastMetric.INSTRUMENTED_PROPAGATION, 1);') - } - executedMetric('IastMetric.EXECUTED_PROPAGATION') { - metricStatements( - 'handler.field(net.bytebuddy.jar.asm.Opcodes.GETSTATIC, "datadog/trace/api/iast/telemetry/IastMetric", "EXECUTED_PROPAGATION", "Ldatadog/trace/api/iast/telemetry/IastMetric;");', - 'handler.instruction(net.bytebuddy.jar.asm.Opcodes.ICONST_1);', - 'handler.method(net.bytebuddy.jar.asm.Opcodes.INVOKESTATIC, "datadog/trace/api/iast/telemetry/IastMetricCollector", "add", "(Ldatadog/trace/api/iast/telemetry/IastMetric;I)V", false);' - ) - } - } - } - } - - private static AdviceGenerator buildAdviceGenerator(final File targetFolder) { - return new AdviceGeneratorImpl(targetFolder, pointcutParser()) - } - - private static Path getCallSiteSrcFolder() { - final file = Thread.currentThread().contextClassLoader.getResource('') - return Paths.get(file.toURI()).resolve('../../../../src/test/java') - } - - private static ClassOrInterfaceDeclaration parse(final File path) { - final parsedAdvice = new JavaParser().parse(path).getResult().get() - return parsedAdvice.primaryType.get().asClassOrInterfaceDeclaration() - } - - private static void assertCallSites(final File generated, @DelegatesTo(IastExtensionCallSiteAssert) final Closure closure) { - final asserter = new IastExtensionAssertBuilder(generated).build() - closure.delegate = asserter - closure(asserter) - } - - static class IastExtensionCallSiteAssert extends CallSiteAssert { - - IastExtensionCallSiteAssert(CallSiteAssert base) { - interfaces = base.interfaces - spi = base.spi - helpers = base.helpers - advices = base.advices - enabled = base.enabled - enabledArgs = base.enabledArgs - } - - void advices(int index, @DelegatesTo(IastExtensionAdviceAssert) Closure closure) { - final asserter = advices[index] - closure.delegate = asserter - closure(asserter) - } - - void advices(@DelegatesTo(IastExtensionAdviceAssert) Closure closure) { - advices.each { - closure.delegate = it - closure(it) - } - } - } - - static class IastExtensionAdviceAssert extends AdviceAssert { - - protected IastExtensionMetricAsserter instrumented - protected IastExtensionMetricAsserter executed - - void instrumentedMetric(final String metric, @DelegatesTo(IastExtensionMetricAsserter) Closure closure) { - assert metric == instrumented.metric - closure.delegate = instrumented - closure(instrumented) - } - - void executedMetric(final String metric, @DelegatesTo(IastExtensionMetricAsserter) Closure closure) { - assert metric == executed.metric - closure.delegate = executed - closure(executed) - } - } - - static class IastExtensionMetricAsserter { - protected String metric - protected Collection statements - - void metricStatements(String... values) { - assert values.toList() == statements - } - } - - static class IastExtensionAssertBuilder extends AssertBuilder { - - IastExtensionAssertBuilder(File file) { - super(file) - } - - @Override - IastExtensionCallSiteAssert build() { - final base = super.build() - return new IastExtensionCallSiteAssert(base) - } - - @Override - protected List getAdvices(ClassOrInterfaceDeclaration type) { - final acceptMethod = type.getMethodsByName('accept').first() - return getMethodCalls(acceptMethod).findAll { - it.nameAsString == 'addAdvice' - }.collect { - def (owner, method, descriptor) = it.arguments.subList(1, 4)*.asStringLiteralExpr()*.asString() - final handlerLambda = it.arguments[4].asLambdaExpr() - final statements = handlerLambda.body.asBlockStmt().statements - final instrumentedStmt = statements.get(0).asIfStmt() - final executedStmt = statements.get(1).asIfStmt() - return new IastExtensionAdviceAssert([ - owner : owner, - method : method, - descriptor: descriptor, - instrumented : buildMetricAsserter(instrumentedStmt), - executed: buildMetricAsserter(executedStmt), - statements: statements.findAll { !it.isIfStmt() } - ]) - } - } - - protected IastExtensionMetricAsserter buildMetricAsserter(final IfStmt ifStmt) { - final condition = ifStmt.getCondition().asMethodCallExpr() - return new IastExtensionMetricAsserter( - metric: condition.getScope().get().toString(), - statements: ifStmt.getThenStmt().asBlockStmt().statements*.toString() - ) - } - } -} diff --git a/buildSrc/call-site-instrumentation-plugin/src/test/java/datadog/trace/plugin/csi/impl/AdviceGeneratorTest.java b/buildSrc/call-site-instrumentation-plugin/src/test/java/datadog/trace/plugin/csi/impl/AdviceGeneratorTest.java new file mode 100644 index 00000000000..aa5f0cd6467 --- /dev/null +++ b/buildSrc/call-site-instrumentation-plugin/src/test/java/datadog/trace/plugin/csi/impl/AdviceGeneratorTest.java @@ -0,0 +1,513 @@ +package datadog.trace.plugin.csi.impl; + +import static datadog.trace.plugin.csi.impl.CallSiteFactory.pointcutParser; + +import datadog.trace.agent.tooling.csi.CallSite; +import datadog.trace.agent.tooling.csi.CallSites; +import datadog.trace.plugin.csi.AdviceGenerator; +import datadog.trace.plugin.csi.AdviceGenerator.CallSiteResult; +import datadog.trace.plugin.csi.impl.assertion.AssertBuilder; +import datadog.trace.plugin.csi.impl.assertion.CallSiteAssert; +import datadog.trace.plugin.csi.impl.ext.tests.IastCallSites; +import datadog.trace.plugin.csi.impl.ext.tests.RaspCallSites; +import java.io.File; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodType; +import java.net.URL; +import java.util.Map; +import javax.servlet.ServletRequest; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledForJreRange; +import org.junit.jupiter.api.condition.JRE; +import org.junit.jupiter.api.io.TempDir; + +class AdviceGeneratorTest extends BaseCsiPluginTest { + + @TempDir private File buildDir; + + @CallSite(spi = CallSites.class) + public static class BeforeAdvice { + @CallSite.Before( + "java.security.MessageDigest java.security.MessageDigest.getInstance(java.lang.String)") + public static void before(@CallSite.Argument String algorithm) {} + } + + @Test + void testBeforeAdvice() { + CallSiteSpecification spec = buildClassSpecification(BeforeAdvice.class); + AdviceGenerator generator = buildAdviceGenerator(buildDir); + + CallSiteResult result = generator.generate(spec); + + assertNoErrors(result); + CallSiteAssert asserter = assertCallSites(result.getFile()); + asserter.interfaces(CallSites.class); + asserter.helpers(BeforeAdvice.class); + asserter.advices( + 0, + advice -> { + advice.type("BEFORE"); + advice.pointcut( + "java/security/MessageDigest", + "getInstance", + "(Ljava/lang/String;)Ljava/security/MessageDigest;"); + advice.statements( + "handler.dupParameters(descriptor, StackDupMode.COPY);", + "handler.advice(\"datadog/trace/plugin/csi/impl/AdviceGeneratorTest$BeforeAdvice\", \"before\", \"(Ljava/lang/String;)V\");", + "handler.method(opcode, owner, name, descriptor, isInterface);"); + }); + } + + @CallSite(spi = CallSites.class) + public static class AroundAdvice { + @CallSite.Around( + "java.lang.String java.lang.String.replaceAll(java.lang.String, java.lang.String)") + public static String around( + @CallSite.This String self, + @CallSite.Argument String regexp, + @CallSite.Argument String replacement) { + return self.replaceAll(regexp, replacement); + } + } + + @Test + void testAroundAdvice() { + CallSiteSpecification spec = buildClassSpecification(AroundAdvice.class); + AdviceGenerator generator = buildAdviceGenerator(buildDir); + + CallSiteResult result = generator.generate(spec); + + assertNoErrors(result); + CallSiteAssert asserter = assertCallSites(result.getFile()); + asserter.interfaces(CallSites.class); + asserter.helpers(AroundAdvice.class); + asserter.advices( + 0, + advice -> { + advice.type("AROUND"); + advice.pointcut( + "java/lang/String", + "replaceAll", + "(Ljava/lang/String;Ljava/lang/String;)Ljava/lang/String;"); + advice.statements( + "handler.advice(\"datadog/trace/plugin/csi/impl/AdviceGeneratorTest$AroundAdvice\", \"around\", \"(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)Ljava/lang/String;\");"); + }); + } + + @CallSite(spi = CallSites.class) + public static class AfterAdvice { + @CallSite.After("java.lang.String java.lang.String.concat(java.lang.String)") + public static String after( + @CallSite.This String self, + @CallSite.Argument String param, + @CallSite.Return String result) { + return result; + } + } + + @Test + void testAfterAdvice() { + CallSiteSpecification spec = buildClassSpecification(AfterAdvice.class); + AdviceGenerator generator = buildAdviceGenerator(buildDir); + + CallSiteResult result = generator.generate(spec); + + assertNoErrors(result); + CallSiteAssert asserter = assertCallSites(result.getFile()); + asserter.interfaces(CallSites.class); + asserter.helpers(AfterAdvice.class); + asserter.advices( + 0, + advice -> { + advice.type("AFTER"); + advice.pointcut("java/lang/String", "concat", "(Ljava/lang/String;)Ljava/lang/String;"); + advice.statements( + "handler.dupInvoke(owner, descriptor, StackDupMode.COPY);", + "handler.method(opcode, owner, name, descriptor, isInterface);", + "handler.advice(\"datadog/trace/plugin/csi/impl/AdviceGeneratorTest$AfterAdvice\", \"after\", \"(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)Ljava/lang/String;\");"); + }); + } + + @CallSite(spi = CallSites.class) + public static class AfterAdviceCtor { + @CallSite.After("void java.net.URL.(java.lang.String)") + public static URL after(@CallSite.AllArguments Object[] args, @CallSite.Return URL url) { + return url; + } + } + + @Test + void testAfterAdviceCtor() { + CallSiteSpecification spec = buildClassSpecification(AfterAdviceCtor.class); + AdviceGenerator generator = buildAdviceGenerator(buildDir); + + CallSiteResult result = generator.generate(spec); + + assertNoErrors(result); + CallSiteAssert asserter = assertCallSites(result.getFile()); + asserter.interfaces(CallSites.class); + asserter.helpers(AfterAdviceCtor.class); + asserter.advices( + 0, + advice -> { + advice.pointcut("java/net/URL", "", "(Ljava/lang/String;)V"); + advice.statements( + "handler.dupParameters(descriptor, StackDupMode.PREPEND_ARRAY_CTOR);", + "handler.method(opcode, owner, name, descriptor, isInterface);", + "handler.advice(\"datadog/trace/plugin/csi/impl/AdviceGeneratorTest$AfterAdviceCtor\", \"after\", \"([Ljava/lang/Object;Ljava/net/URL;)Ljava/net/URL;\");"); + }); + } + + @CallSite(spi = SpiAdvice.SampleSpi.class) + public static class SpiAdvice { + @CallSite.Before( + "java.security.MessageDigest java.security.MessageDigest.getInstance(java.lang.String)") + public static void before(@CallSite.Argument String algorithm) {} + + interface SampleSpi {} + } + + @Test + void testGeneratorWithSpi() { + CallSiteSpecification spec = buildClassSpecification(SpiAdvice.class); + AdviceGenerator generator = buildAdviceGenerator(buildDir); + + CallSiteResult result = generator.generate(spec); + + assertNoErrors(result); + CallSiteAssert asserter = assertCallSites(result.getFile()); + asserter.interfaces(CallSites.class, SpiAdvice.SampleSpi.class); + } + + @CallSite(spi = CallSites.class) + public static class InvokeDynamicAfterAdvice { + @CallSite.After( + value = + "java.lang.invoke.CallSite java.lang.invoke.StringConcatFactory.makeConcatWithConstants(java.lang.invoke.MethodHandles$Lookup, java.lang.String, java.lang.invoke.MethodType, java.lang.String, java.lang.Object[])", + invokeDynamic = true) + public static String after( + @CallSite.AllArguments Object[] arguments, @CallSite.Return String result) { + return result; + } + } + + @Test + @EnabledForJreRange(min = JRE.JAVA_9) + void testInvokeDynamicAfterAdvice() { + CallSiteSpecification spec = buildClassSpecification(InvokeDynamicAfterAdvice.class); + AdviceGenerator generator = buildAdviceGenerator(buildDir); + + CallSiteResult result = generator.generate(spec); + + assertNoErrors(result); + CallSiteAssert asserter = assertCallSites(result.getFile()); + asserter.interfaces(CallSites.class); + asserter.helpers(InvokeDynamicAfterAdvice.class); + asserter.advices( + 0, + advice -> { + advice.pointcut( + "java/lang/invoke/StringConcatFactory", + "makeConcatWithConstants", + "(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;Ljava/lang/String;[Ljava/lang/Object;)Ljava/lang/invoke/CallSite;"); + advice.statements( + "handler.dupParameters(descriptor, StackDupMode.PREPEND_ARRAY);", + "handler.invokeDynamic(name, descriptor, bootstrapMethodHandle, bootstrapMethodArguments);", + "handler.advice(\"datadog/trace/plugin/csi/impl/AdviceGeneratorTest$InvokeDynamicAfterAdvice\", \"after\", \"([Ljava/lang/Object;Ljava/lang/String;)Ljava/lang/String;\");"); + }); + } + + @CallSite(spi = CallSites.class) + public static class InvokeDynamicAroundAdvice { + @CallSite.Around( + value = + "java.lang.invoke.CallSite java.lang.invoke.StringConcatFactory.makeConcatWithConstants(java.lang.invoke.MethodHandles$Lookup, java.lang.String, java.lang.invoke.MethodType, java.lang.String, java.lang.Object[])", + invokeDynamic = true) + public static java.lang.invoke.CallSite around( + @CallSite.Argument MethodHandles.Lookup lookup, + @CallSite.Argument String name, + @CallSite.Argument MethodType concatType, + @CallSite.Argument String recipe, + @CallSite.Argument Object... constants) { + return null; + } + } + + @Test + @EnabledForJreRange(min = JRE.JAVA_9) + void testInvokeDynamicAroundAdvice() { + CallSiteSpecification spec = buildClassSpecification(InvokeDynamicAroundAdvice.class); + AdviceGenerator generator = buildAdviceGenerator(buildDir); + + CallSiteResult result = generator.generate(spec); + + assertNoErrors(result); + CallSiteAssert asserter = assertCallSites(result.getFile()); + asserter.interfaces(CallSites.class); + asserter.helpers(InvokeDynamicAroundAdvice.class); + asserter.advices( + 0, + advice -> { + advice.pointcut( + "java/lang/invoke/StringConcatFactory", + "makeConcatWithConstants", + "(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;Ljava/lang/String;[Ljava/lang/Object;)Ljava/lang/invoke/CallSite;"); + advice.statements( + "handler.invokeDynamic(name, descriptor, new Handle(Opcodes.H_INVOKESTATIC, \"datadog/trace/plugin/csi/impl/AdviceGeneratorTest$InvokeDynamicAroundAdvice\", \"around\", \"(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;Ljava/lang/String;[Ljava/lang/Object;)Ljava/lang/invoke/CallSite;\", false), bootstrapMethodArguments);"); + }); + } + + @CallSite(spi = CallSites.class) + public static class InvokeDynamicWithConstantsAdvice { + @CallSite.After( + value = + "java.lang.invoke.CallSite java.lang.invoke.StringConcatFactory.makeConcatWithConstants(java.lang.invoke.MethodHandles$Lookup, java.lang.String, java.lang.invoke.MethodType, java.lang.String, java.lang.Object[])", + invokeDynamic = true) + public static String after( + @CallSite.AllArguments Object[] arguments, + @CallSite.Return String result, + @CallSite.InvokeDynamicConstants Object[] constants) { + return result; + } + } + + @Test + @EnabledForJreRange(min = JRE.JAVA_9) + void testInvokeDynamicWithConstantsAdvice() { + CallSiteSpecification spec = buildClassSpecification(InvokeDynamicWithConstantsAdvice.class); + AdviceGenerator generator = buildAdviceGenerator(buildDir); + + CallSiteResult result = generator.generate(spec); + + assertNoErrors(result); + CallSiteAssert asserter = assertCallSites(result.getFile()); + asserter.interfaces(CallSites.class); + asserter.helpers(InvokeDynamicWithConstantsAdvice.class); + asserter.advices( + 0, + advice -> { + advice.pointcut( + "java/lang/invoke/StringConcatFactory", + "makeConcatWithConstants", + "(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;Ljava/lang/String;[Ljava/lang/Object;)Ljava/lang/invoke/CallSite;"); + advice.statements( + "handler.dupParameters(descriptor, StackDupMode.PREPEND_ARRAY);", + "handler.invokeDynamic(name, descriptor, bootstrapMethodHandle, bootstrapMethodArguments);", + "handler.loadConstantArray(bootstrapMethodArguments);", + "handler.advice(\"datadog/trace/plugin/csi/impl/AdviceGeneratorTest$InvokeDynamicWithConstantsAdvice\", \"after\", \"([Ljava/lang/Object;Ljava/lang/String;[Ljava/lang/Object;)Ljava/lang/String;\");"); + }); + } + + @CallSite(spi = CallSites.class) + public static class ArrayAdvice { + @CallSite.AfterArray({ + @CallSite.After("java.util.Map javax.servlet.ServletRequest.getParameterMap()"), + @CallSite.After("java.util.Map javax.servlet.ServletRequestWrapper.getParameterMap()") + }) + public static Map after( + @CallSite.This ServletRequest request, @CallSite.Return Map parameters) { + return parameters; + } + } + + @Test + void testArrayAdvice() { + CallSiteSpecification spec = buildClassSpecification(ArrayAdvice.class); + AdviceGenerator generator = buildAdviceGenerator(buildDir); + + CallSiteResult result = generator.generate(spec); + + assertNoErrors(result); + CallSiteAssert asserter = assertCallSites(result.getFile()); + asserter.advices( + 0, + advice -> { + advice.pointcut("javax/servlet/ServletRequest", "getParameterMap", "()Ljava/util/Map;"); + }); + asserter.advices( + 1, + advice -> { + advice.pointcut( + "javax/servlet/ServletRequestWrapper", "getParameterMap", "()Ljava/util/Map;"); + }); + } + + public static class MinJavaVersionCheck { + public static boolean isAtLeast(String version) { + return Integer.parseInt(version) >= 9; + } + } + + @CallSite( + spi = CallSites.class, + enabled = { + "datadog.trace.plugin.csi.impl.AdviceGeneratorTest$MinJavaVersionCheck", + "isAtLeast", + "18" + }) + public static class MinJavaVersionAdvice { + @CallSite.After("java.lang.String java.lang.String.concat(java.lang.String)") + public static String after( + @CallSite.This String self, + @CallSite.Argument String param, + @CallSite.Return String result) { + return result; + } + } + + @Test + void testCustomEnabledProperty() throws Exception { + CallSiteSpecification spec = buildClassSpecification(MinJavaVersionAdvice.class); + AdviceGenerator generator = buildAdviceGenerator(buildDir); + + CallSiteResult result = generator.generate(spec); + + assertNoErrors(result); + CallSiteAssert asserter = assertCallSites(result.getFile()); + asserter.interfaces(CallSites.class, CallSites.HasEnabledProperty.class); + asserter.enabled(MinJavaVersionCheck.class.getDeclaredMethod("isAtLeast", String.class), "18"); + } + + @CallSite(spi = CallSites.class) + public static class PartialArgumentsBeforeAdvice { + @CallSite.Before("int java.sql.Statement.executeUpdate(java.lang.String, java.lang.String[])") + public static void before(@CallSite.Argument(0) String arg1) {} + + @CallSite.Before( + "java.lang.String java.lang.String.format(java.lang.String, java.lang.Object[])") + public static void before(@CallSite.Argument(1) Object[] arg) {} + + @CallSite.Before("java.lang.CharSequence java.lang.String.subSequence(int, int)") + public static void before(@CallSite.This String thiz, @CallSite.Argument(0) int arg) {} + } + + @Test + void partialArgumentsWithBeforeAdvice() { + CallSiteSpecification spec = buildClassSpecification(PartialArgumentsBeforeAdvice.class); + AdviceGenerator generator = buildAdviceGenerator(buildDir); + + CallSiteResult result = generator.generate(spec); + + assertNoErrors(result); + CallSiteAssert asserter = assertCallSites(result.getFile()); + asserter.advices( + 0, + advice -> { + advice.pointcut( + "java/sql/Statement", "executeUpdate", "(Ljava/lang/String;[Ljava/lang/String;)I"); + advice.statements( + "int[] parameterIndices = new int[] { 0 };", + "handler.dupParameters(descriptor, parameterIndices, owner);", + "handler.advice(\"datadog/trace/plugin/csi/impl/AdviceGeneratorTest$PartialArgumentsBeforeAdvice\", \"before\", \"(Ljava/lang/String;)V\");", + "handler.method(opcode, owner, name, descriptor, isInterface);"); + }); + asserter.advices( + 1, + advice -> { + advice.pointcut( + "java/lang/String", + "format", + "(Ljava/lang/String;[Ljava/lang/Object;)Ljava/lang/String;"); + advice.statements( + "int[] parameterIndices = new int[] { 1 };", + "handler.dupParameters(descriptor, parameterIndices, null);", + "handler.advice(\"datadog/trace/plugin/csi/impl/AdviceGeneratorTest$PartialArgumentsBeforeAdvice\", \"before\", \"([Ljava/lang/Object;)V\");", + "handler.method(opcode, owner, name, descriptor, isInterface);"); + }); + asserter.advices( + 2, + advice -> { + advice.pointcut("java/lang/String", "subSequence", "(II)Ljava/lang/CharSequence;"); + advice.statements( + "int[] parameterIndices = new int[] { 0 };", + "handler.dupInvoke(owner, descriptor, parameterIndices);", + "handler.advice(\"datadog/trace/plugin/csi/impl/AdviceGeneratorTest$PartialArgumentsBeforeAdvice\", \"before\", \"(Ljava/lang/String;I)V\");", + "handler.method(opcode, owner, name, descriptor, isInterface);"); + }); + } + + @CallSite(spi = CallSites.class) + public static class SuperTypeReturnAdvice { + @CallSite.After("void java.lang.StringBuilder.(java.lang.String)") + public static Object after( + @CallSite.AllArguments Object[] args, @CallSite.Return Object result) { + return result; + } + } + + @Test + void testReturningSuperType() { + CallSiteSpecification spec = buildClassSpecification(SuperTypeReturnAdvice.class); + AdviceGenerator generator = buildAdviceGenerator(buildDir); + + CallSiteResult result = generator.generate(spec); + + assertNoErrors(result); + CallSiteAssert asserter = assertCallSites(result.getFile()); + asserter.advices( + 0, + advice -> { + advice.pointcut("java/lang/StringBuilder", "", "(Ljava/lang/String;)V"); + advice.statements( + "handler.dupParameters(descriptor, StackDupMode.PREPEND_ARRAY_CTOR);", + "handler.method(opcode, owner, name, descriptor, isInterface);", + "handler.advice(\"datadog/trace/plugin/csi/impl/AdviceGeneratorTest$SuperTypeReturnAdvice\", \"after\", \"([Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;\");", + "handler.instruction(Opcodes.CHECKCAST, \"java/lang/StringBuilder\");"); + }); + } + + @CallSite(spi = {IastCallSites.class, RaspCallSites.class}) + public static class MultipleSpiClassesAdvice { + @CallSite.After("void java.lang.StringBuilder.(java.lang.String)") + public static Object after( + @CallSite.AllArguments Object[] args, @CallSite.Return Object result) { + return result; + } + } + + @Test + void testMultipleSpiClasses() { + CallSiteSpecification spec = buildClassSpecification(MultipleSpiClassesAdvice.class); + AdviceGenerator generator = buildAdviceGenerator(buildDir); + + CallSiteResult result = generator.generate(spec); + + assertNoErrors(result); + CallSiteAssert asserter = assertCallSites(result.getFile()); + asserter.spi(IastCallSites.class, RaspCallSites.class); + } + + @CallSite(spi = CallSites.class) + public static class AfterAdviceWithVoidReturn { + @CallSite.After("void java.lang.StringBuilder.setLength(int)") + public static void after(@CallSite.This StringBuilder self, @CallSite.Argument(0) int length) {} + } + + @Test + void testAfterAdviceWithVoidReturn() { + CallSiteSpecification spec = buildClassSpecification(AfterAdviceWithVoidReturn.class); + AdviceGenerator generator = buildAdviceGenerator(buildDir); + + CallSiteResult result = generator.generate(spec); + + assertNoErrors(result); + CallSiteAssert asserter = assertCallSites(result.getFile()); + asserter.advices( + 0, + advice -> { + advice.pointcut("java/lang/StringBuilder", "setLength", "(I)V"); + advice.statements( + "handler.dupInvoke(owner, descriptor, StackDupMode.COPY);", + "handler.method(opcode, owner, name, descriptor, isInterface);", + "handler.advice(\"datadog/trace/plugin/csi/impl/AdviceGeneratorTest$AfterAdviceWithVoidReturn\", \"after\", \"(Ljava/lang/StringBuilder;I)V\");"); + }); + } + + private static AdviceGenerator buildAdviceGenerator(File targetFolder) { + return new AdviceGeneratorImpl(targetFolder, pointcutParser()); + } + + private static CallSiteAssert assertCallSites(File generated) { + return new AssertBuilder(generated).build(); + } +} diff --git a/buildSrc/call-site-instrumentation-plugin/src/test/java/datadog/trace/plugin/csi/impl/AdviceSpecificationTest.java b/buildSrc/call-site-instrumentation-plugin/src/test/java/datadog/trace/plugin/csi/impl/AdviceSpecificationTest.java new file mode 100644 index 00000000000..3c9fe3cacb5 --- /dev/null +++ b/buildSrc/call-site-instrumentation-plugin/src/test/java/datadog/trace/plugin/csi/impl/AdviceSpecificationTest.java @@ -0,0 +1,723 @@ +package datadog.trace.plugin.csi.impl; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import datadog.trace.agent.tooling.csi.CallSite; +import datadog.trace.agent.tooling.csi.CallSites; +import datadog.trace.plugin.csi.HasErrors.Failure; +import datadog.trace.plugin.csi.ValidationContext; +import datadog.trace.plugin.csi.impl.CallSiteSpecification.AfterSpecification; +import datadog.trace.plugin.csi.impl.CallSiteSpecification.AllArgsSpecification; +import datadog.trace.plugin.csi.impl.CallSiteSpecification.ArgumentSpecification; +import datadog.trace.plugin.csi.impl.CallSiteSpecification.AroundSpecification; +import datadog.trace.plugin.csi.impl.CallSiteSpecification.BeforeSpecification; +import datadog.trace.plugin.csi.impl.CallSiteSpecification.InvokeDynamicConstantsSpecification; +import datadog.trace.plugin.csi.impl.CallSiteSpecification.ParameterSpecification; +import datadog.trace.plugin.csi.impl.CallSiteSpecification.ReturnSpecification; +import datadog.trace.plugin.csi.impl.CallSiteSpecification.ThisSpecification; +import datadog.trace.plugin.csi.util.ErrorCode; +import datadog.trace.plugin.csi.util.MethodType; +import java.lang.invoke.MethodHandles; +import java.lang.reflect.Method; +import java.security.MessageDigest; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; +import javax.servlet.ServletRequest; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledForJreRange; +import org.junit.jupiter.api.condition.JRE; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.objectweb.asm.Type; + +class AdviceSpecificationTest extends BaseCsiPluginTest { + + @CallSite(spi = CallSites.class) + static class EmptyAdvice {} + + @Test + void testClassGeneratorErrorCallSiteWithoutAdvices() { + ValidationContext context = mockValidationContext(); + CallSiteSpecification spec = buildClassSpecification(EmptyAdvice.class); + + spec.validate(context); + verify(context).addError(eq(ErrorCode.CALL_SITE_SHOULD_HAVE_ADVICE_METHODS), any()); + } + + @CallSite(spi = CallSites.class) + static class NonPublicStaticMethodAdvice { + @CallSite.Before("void java.lang.Runnable.run()") + private void advice(@CallSite.This Runnable run) {} + } + + @Test + void testClassGeneratorErrorNonPublicStaticMethod() { + ValidationContext context = mockValidationContext(); + CallSiteSpecification spec = buildClassSpecification(NonPublicStaticMethodAdvice.class); + + spec.getAdvices().forEach(it -> it.validate(context)); + + verify(context).addError(eq(ErrorCode.ADVICE_METHOD_NOT_STATIC_AND_PUBLIC), any()); + } + + static class BeforeStringConcat { + static void concat(String self, String value) {} + } + + static Stream adviceClassShouldBeOnClasspathProvider() { + return Stream.of( + Arguments.of(Type.getType("Lfoo/bar/FooBar;"), 1), + Arguments.of(Type.getType(BeforeStringConcat.class), 0)); + } + + @ParameterizedTest + @MethodSource("adviceClassShouldBeOnClasspathProvider") + void testAdviceClassShouldBeOnTheClasspath(Type type, int errors) throws Exception { + ValidationContext context = mockValidationContext(); + BeforeSpecification spec = + createBeforeSpec( + BeforeStringConcat.class.getDeclaredMethod("concat", String.class, String.class), + type, + Arrays.asList(new ThisSpecification(), new ArgumentSpecification()), + "java.lang.String java.lang.String.concat(java.lang.String)"); + + spec.validate(context); + verify(context, times(errors)) + .addError( + argThat((Failure failure) -> failure.getErrorCode() == ErrorCode.UNRESOLVED_TYPE)); + } + + static Stream beforeAdviceShouldReturnVoidProvider() { + return Stream.of(Arguments.of(String.class, 1), Arguments.of(void.class, 0)); + } + + @ParameterizedTest + @MethodSource("beforeAdviceShouldReturnVoidProvider") + void testBeforeAdviceShouldReturnVoid(Class returnType, int errors) { + ValidationContext context = mockValidationContext(); + BeforeSpecification spec = + createBeforeSpec( + BeforeStringConcat.class, + "concat", + returnType, + new Class[] {String.class, String.class}, + Arrays.asList(new ThisSpecification(), new ArgumentSpecification()), + "java.lang.String java.lang.String.concat(java.lang.String)"); + + spec.validate(context); + verify(context, times(errors)).addError(eq(ErrorCode.ADVICE_BEFORE_SHOULD_RETURN_VOID), any()); + } + + static class AroundStringConcat { + static String concat(String self, String value) { + return self.concat(value); + } + } + + static Stream aroundAdviceReturnTypeProvider() { + return Stream.of( + Arguments.of(MessageDigest.class, 1), + Arguments.of(Object.class, 0), + Arguments.of(String.class, 0)); + } + + @ParameterizedTest + @MethodSource("aroundAdviceReturnTypeProvider") + void testAroundAdviceShouldReturnTypeCompatibleWithPointcut(Class returnType, int errors) { + ValidationContext context = mockValidationContext(); + AroundSpecification spec = + createAroundSpec( + AroundStringConcat.class, + "concat", + returnType, + new Class[] {String.class, String.class}, + Arrays.asList(new ThisSpecification(), new ArgumentSpecification()), + "java.lang.String java.lang.String.concat(java.lang.String)"); + + spec.validate(context); + verify(context, times(errors)) + .addError(eq(ErrorCode.ADVICE_METHOD_RETURN_NOT_COMPATIBLE), any()); + } + + static class AfterStringConcat { + static String concat(String self, String value, String result) { + return result; + } + } + + static Stream afterAdviceReturnTypeProvider() { + return Stream.of( + Arguments.of(MessageDigest.class, 1), + Arguments.of(Object.class, 0), + Arguments.of(String.class, 0)); + } + + @ParameterizedTest + @MethodSource("afterAdviceReturnTypeProvider") + void testAfterAdviceShouldReturnTypeCompatibleWithPointcut(Class returnType, int errors) { + ValidationContext context = mockValidationContext(); + AfterSpecification spec = + createAfterSpec( + AfterStringConcat.class, + "concat", + returnType, + new Class[] {String.class, String.class, String.class}, + Arrays.asList( + new ThisSpecification(), new ArgumentSpecification(), new ReturnSpecification()), + "java.lang.String java.lang.String.concat(java.lang.String)"); + + spec.validate(context); + verify(context, times(errors)) + .addError(eq(ErrorCode.ADVICE_METHOD_RETURN_NOT_COMPATIBLE), any()); + } + + static Stream thisParameterShouldBeFirstProvider() { + return Stream.of( + Arguments.of(Arrays.asList(new ThisSpecification(), new ArgumentSpecification()), 0), + Arguments.of(Arrays.asList(new ArgumentSpecification(), new ThisSpecification()), 1)); + } + + @ParameterizedTest + @MethodSource("thisParameterShouldBeFirstProvider") + void testThisParameterShouldAlwaysBeTheFirst(List params, int errors) + throws Exception { + ValidationContext context = mockValidationContext(); + AroundSpecification spec = + createAroundSpec( + AroundStringConcat.class.getDeclaredMethod("concat", String.class, String.class), + params, + "java.lang.String java.lang.String.concat(java.lang.String)"); + + spec.validate(context); + verify(context, times(errors)) + .addError(eq(ErrorCode.ADVICE_PARAMETER_THIS_SHOULD_BE_FIRST), any()); + } + + static Stream thisParameterCompatibilityProvider() { + return Stream.of( + Arguments.of(MessageDigest.class, 1), + Arguments.of(Object.class, 0), + Arguments.of(String.class, 0)); + } + + @ParameterizedTest + @MethodSource("thisParameterCompatibilityProvider") + void testThisParameterShouldBeCompatibleWithPointcut(Class type, int errors) { + ValidationContext context = mockValidationContext(); + AroundSpecification spec = + createAroundSpec( + AroundStringConcat.class, + "concat", + String.class, + new Class[] {type, String.class}, + Arrays.asList(new ThisSpecification(), new ArgumentSpecification()), + "java.lang.String java.lang.String.concat(java.lang.String)"); + + spec.validate(context); + verify(context, times(errors)) + .addError(eq(ErrorCode.ADVICE_METHOD_PARAM_THIS_NOT_COMPATIBLE), any()); + if (type != String.class) { + verify(context) + .addError( + argThat((Failure failure) -> failure.getErrorCode() == ErrorCode.UNRESOLVED_METHOD)); + } + } + + static Stream returnParameterShouldBeLastProvider() { + return Stream.of( + Arguments.of( + Arrays.asList( + new ThisSpecification(), new ArgumentSpecification(), new ReturnSpecification()), + 0), + Arguments.of( + Arrays.asList( + new ThisSpecification(), new ReturnSpecification(), new ArgumentSpecification()), + 1)); + } + + @ParameterizedTest + @MethodSource("returnParameterShouldBeLastProvider") + void testReturnParameterShouldAlwaysBeTheLast(List params, int errors) + throws Exception { + ValidationContext context = mockValidationContext(); + AfterSpecification spec = + createAfterSpec( + AfterStringConcat.class.getDeclaredMethod( + "concat", String.class, String.class, String.class), + params, + "java.lang.String java.lang.String.concat(java.lang.String)"); + + spec.validate(context); + verify(context, times(errors)) + .addError(eq(ErrorCode.ADVICE_PARAMETER_RETURN_SHOULD_BE_LAST), any()); + } + + static Stream returnParameterCompatibilityProvider() { + return Stream.of( + Arguments.of(MessageDigest.class, 1), + Arguments.of(String.class, 0), + Arguments.of(Object.class, 0)); + } + + @ParameterizedTest + @MethodSource("returnParameterCompatibilityProvider") + void testReturnParameterShouldBeCompatibleWithPointcut(Class returnType, int errors) { + ValidationContext context = mockValidationContext(); + AfterSpecification spec = + createAfterSpec( + AfterStringConcat.class, + "concat", + String.class, + new Class[] {String.class, String.class, returnType}, + Arrays.asList( + new ThisSpecification(), new ArgumentSpecification(), new ReturnSpecification()), + "java.lang.String java.lang.String.concat(java.lang.String)"); + + spec.validate(context); + verify(context, times(errors)) + .addError(eq(ErrorCode.ADVICE_METHOD_PARAM_RETURN_NOT_COMPATIBLE), any()); + if (returnType != String.class) { + verify(context) + .addError( + argThat((Failure failure) -> failure.getErrorCode() == ErrorCode.UNRESOLVED_METHOD)); + } + } + + static Stream argumentParameterCompatibilityProvider() { + return Stream.of( + Arguments.of(MessageDigest.class, 1), + Arguments.of(String.class, 0), + Arguments.of(Object.class, 0)); + } + + @ParameterizedTest + @MethodSource("argumentParameterCompatibilityProvider") + void testArgumentParameterShouldBeCompatibleWithPointcut(Class parameterType, int errors) { + ValidationContext context = mockValidationContext(); + AfterSpecification spec = + createAfterSpec( + AfterStringConcat.class, + "concat", + String.class, + new Class[] {String.class, parameterType, String.class}, + Arrays.asList( + new ThisSpecification(), new ArgumentSpecification(), new ReturnSpecification()), + "java.lang.String java.lang.String.concat(java.lang.String)"); + + spec.validate(context); + verify(context, times(errors)) + .addError(eq(ErrorCode.ADVICE_METHOD_PARAM_NOT_COMPATIBLE), any()); + if (parameterType != String.class) { + verify(context) + .addError( + argThat((Failure failure) -> failure.getErrorCode() == ErrorCode.UNRESOLVED_METHOD)); + } + } + + static class BadAfterStringConcat { + static String concat(String param1, String param2) { + return param2; + } + } + + static Stream afterAdviceRequiresThisAndReturnProvider() { + return Stream.of( + Arguments.of( + Arrays.asList(new ArgumentSpecification(), new ReturnSpecification()), + ErrorCode.ADVICE_AFTER_SHOULD_HAVE_THIS), + Arguments.of( + Arrays.asList(new ThisSpecification(), new ArgumentSpecification()), + ErrorCode.ADVICE_AFTER_SHOULD_HAVE_RETURN)); + } + + @ParameterizedTest + @MethodSource("afterAdviceRequiresThisAndReturnProvider") + void testAfterAdviceRequiresThisAndReturnParameters( + List params, ErrorCode error) throws Exception { + ValidationContext context = mockValidationContext(); + AfterSpecification spec = + createAfterSpec( + BadAfterStringConcat.class.getDeclaredMethod("concat", String.class, String.class), + params, + "java.lang.String java.lang.String.concat(java.lang.String)"); + + spec.validate(context); + verify(context).addError(eq(error), any()); + } + + static class BadAllArgsAfterStringConcat { + static String concat(Object[] param1, String param2, String param3) { + return param3; + } + } + + @Test + void shouldNotMixAllArgumentsAndArgument() throws Exception { + ValidationContext context = mockValidationContext(); + AllArgsSpecification allArgs = new AllArgsSpecification(); + allArgs.setIncludeThis(true); + AfterSpecification spec = + createAfterSpec( + BadAllArgsAfterStringConcat.class.getDeclaredMethod( + "concat", Object[].class, String.class, String.class), + Arrays.asList(allArgs, new ArgumentSpecification(), new ReturnSpecification()), + "java.lang.String java.lang.String.concat(java.lang.String)"); + + spec.validate(context); + verify(context).addError(eq(ErrorCode.ADVICE_PARAMETER_ALL_ARGS_MIXED), any()); + verify(context).addError(eq(ErrorCode.ADVICE_PARAMETER_ARGUMENT_OUT_OF_BOUNDS), any()); + } + + static class TestInheritedMethod { + static String after(ServletRequest request, String parameter, String value) { + return value; + } + } + + @Test + void testInheritedMethods() throws Exception { + ValidationContext context = mockValidationContext(); + AfterSpecification spec = + createAfterSpec( + TestInheritedMethod.class.getDeclaredMethod( + "after", ServletRequest.class, String.class, String.class), + Arrays.asList( + new ThisSpecification(), new ArgumentSpecification(), new ReturnSpecification()), + "java.lang.String javax.servlet.http.HttpServletRequest.getParameter(java.lang.String)"); + + spec.validate(context); + } + + static class TestInvokeDynamicConstants { + static Object after(Object[] parameter, Object result, Object[] constants) { + return result; + } + } + + @Test + @EnabledForJreRange(min = JRE.JAVA_9) + void testInvokeDynamicConstants() throws Exception { + ValidationContext context = mockValidationContext(); + AfterSpecification spec = + createAfterSpec( + TestInvokeDynamicConstants.class.getDeclaredMethod( + "after", Object[].class, Object.class, Object[].class), + Arrays.asList( + new AllArgsSpecification(), + new ReturnSpecification(), + new InvokeDynamicConstantsSpecification()), + "java.lang.invoke.CallSite java.lang.invoke.StringConcatFactory.makeConcatWithConstants(java.lang.invoke.MethodHandles$Lookup, java.lang.String, java.lang.invoke.MethodType, java.lang.String, java.lang.Object[])", + true); + + spec.validate(context); + } + + static Stream invokeDynamicConstantsShouldBeLastProvider() { + return Stream.of( + Arguments.of( + Arrays.asList( + new AllArgsSpecification(), + new ReturnSpecification(), + new InvokeDynamicConstantsSpecification()), + null), + Arguments.of( + Arrays.asList( + new AllArgsSpecification(), + new InvokeDynamicConstantsSpecification(), + new ReturnSpecification()), + ErrorCode.ADVICE_PARAMETER_INVOKE_DYNAMIC_CONSTANTS_SHOULD_BE_LAST)); + } + + @ParameterizedTest + @MethodSource("invokeDynamicConstantsShouldBeLastProvider") + @EnabledForJreRange(min = JRE.JAVA_9) + void testInvokeDynamicConstantsShouldBeLast(List params, ErrorCode error) + throws Exception { + ValidationContext context = mockValidationContext(); + AfterSpecification spec = + createAfterSpec( + TestInvokeDynamicConstants.class.getDeclaredMethod( + "after", Object[].class, Object.class, Object[].class), + params, + "java.lang.invoke.CallSite java.lang.invoke.StringConcatFactory.makeConcatWithConstants(java.lang.invoke.MethodHandles$Lookup, java.lang.String, java.lang.invoke.MethodType, java.lang.String, java.lang.Object[])", + true); + + spec.validate(context); + if (error != null) { + verify(context).addError(eq(error), any()); + } + } + + static class TestInvokeDynamicConstantsNonInvokeDynamic { + static Object after(Object self, Object[] parameter, Object value, Object[] constants) { + return value; + } + } + + @Test + @EnabledForJreRange(min = JRE.JAVA_9) + void testInvokeDynamicConstantsOnNonInvokeDynamicPointcut() throws Exception { + ValidationContext context = mockValidationContext(); + AfterSpecification spec = + createAfterSpec( + TestInvokeDynamicConstantsNonInvokeDynamic.class.getDeclaredMethod( + "after", Object.class, Object[].class, Object.class, Object[].class), + Arrays.asList( + new ThisSpecification(), + new AllArgsSpecification(), + new InvokeDynamicConstantsSpecification(), + new ReturnSpecification()), + "java.lang.String java.lang.String.concat(java.lang.String)"); + + spec.validate(context); + verify(context) + .addError( + eq(ErrorCode.ADVICE_PARAMETER_INVOKE_DYNAMIC_CONSTANTS_ON_NON_INVOKE_DYNAMIC), any()); + } + + static class TestInvokeDynamicConstantsBefore { + static void before(Object[] parameter, Object[] constants) {} + } + + @Test + @EnabledForJreRange(min = JRE.JAVA_9) + void testInvokeDynamicConstantsOnNonAfterAdvice() throws Exception { + ValidationContext context = mockValidationContext(); + BeforeSpecification spec = + createBeforeSpec( + TestInvokeDynamicConstantsBefore.class.getDeclaredMethod( + "before", Object[].class, Object[].class), + Arrays.asList(new AllArgsSpecification(), new InvokeDynamicConstantsSpecification()), + "java.lang.invoke.CallSite java.lang.invoke.StringConcatFactory.makeConcatWithConstants(java.lang.invoke.MethodHandles$Lookup, java.lang.String, java.lang.invoke.MethodType, java.lang.String, java.lang.Object[])", + true); + + spec.validate(context); + verify(context) + .addError(eq(ErrorCode.ADVICE_PARAMETER_INVOKE_DYNAMIC_CONSTANTS_NON_AFTER_ADVICE), any()); + } + + static class TestInvokeDynamicConstantsAround { + static java.lang.invoke.CallSite around( + MethodHandles.Lookup lookup, + String name, + java.lang.invoke.MethodType concatType, + String recipe, + Object... constants) { + return null; + } + } + + @Test + @EnabledForJreRange(min = JRE.JAVA_9) + void testInvokeDynamicOnAroundAdvice() throws Exception { + ValidationContext context = mockValidationContext(); + AroundSpecification spec = + createAroundSpec( + TestInvokeDynamicConstantsAround.class.getDeclaredMethod( + "around", + MethodHandles.Lookup.class, + String.class, + java.lang.invoke.MethodType.class, + String.class, + Object[].class), + Arrays.asList( + new ArgumentSpecification(), + new ArgumentSpecification(), + new ArgumentSpecification(), + new ArgumentSpecification(), + new ArgumentSpecification()), + "java.lang.invoke.CallSite java.lang.invoke.StringConcatFactory.makeConcatWithConstants(java.lang.invoke.MethodHandles$Lookup, java.lang.String, java.lang.invoke.MethodType, java.lang.String, java.lang.Object[])", + true); + + spec.validate(context); + } + + @CallSite(spi = CallSites.class) + static class AfterWithVoidWrongAdvice { + @CallSite.After("void java.lang.String.getChars(int, int, char[], int)") + static String after(@CallSite.AllArguments Object[] args, @CallSite.Return String result) { + return result; + } + } + + @Test + void testAfterAdviceWithVoidShouldNotUseReturn() { + ValidationContext context = mockValidationContext(); + CallSiteSpecification spec = buildClassSpecification(AfterWithVoidWrongAdvice.class); + + spec.getAdvices().forEach(it -> it.validate(context)); + + verify(context).addError(eq(ErrorCode.ADVICE_AFTER_VOID_METHOD_SHOULD_RETURN_VOID), any()); + verify(context).addError(eq(ErrorCode.ADVICE_AFTER_VOID_METHOD_SHOULD_NOT_HAVE_RETURN), any()); + } + + // Helper methods to create specifications + private BeforeSpecification createBeforeSpec( + Method method, List params, String signature) { + return createBeforeSpec(method, null, params, signature, false); + } + + private BeforeSpecification createBeforeSpec( + Method method, List params, String signature, boolean invokeDynamic) { + return createBeforeSpec(method, null, params, signature, invokeDynamic); + } + + private BeforeSpecification createBeforeSpec( + Method method, Type ownerOverride, List params, String signature) { + return createBeforeSpec(method, ownerOverride, params, signature, false); + } + + private BeforeSpecification createBeforeSpec( + Method method, + Type ownerOverride, + List params, + String signature, + boolean invokeDynamic) { + Type owner = ownerOverride != null ? ownerOverride : Type.getType(method.getDeclaringClass()); + Type[] argTypes = + Arrays.stream(method.getParameterTypes()).map(Type::getType).toArray(Type[]::new); + Type returnType = Type.getType(method.getReturnType()); + MethodType methodType = + new MethodType(owner, method.getName(), Type.getMethodType(returnType, argTypes)); + Map paramMap = new HashMap<>(); + for (int i = 0; i < params.size(); i++) { + paramMap.put(i, params.get(i)); + } + updateArgumentIndices(paramMap); + BeforeSpecification spec = + new BeforeSpecification(methodType, paramMap, signature, invokeDynamic); + spec.parseSignature(CallSiteFactory.pointcutParser()); + return spec; + } + + private BeforeSpecification createBeforeSpec( + Class clazz, + String methodName, + Class returnType, + Class[] argTypes, + List params, + String signature) { + Type owner = Type.getType(clazz); + Type[] argTypesAsm = Arrays.stream(argTypes).map(Type::getType).toArray(Type[]::new); + Type returnTypeAsm = Type.getType(returnType); + MethodType methodType = + new MethodType(owner, methodName, Type.getMethodType(returnTypeAsm, argTypesAsm)); + Map paramMap = new HashMap<>(); + for (int i = 0; i < params.size(); i++) { + paramMap.put(i, params.get(i)); + } + updateArgumentIndices(paramMap); + BeforeSpecification spec = new BeforeSpecification(methodType, paramMap, signature, false); + spec.parseSignature(CallSiteFactory.pointcutParser()); + return spec; + } + + private AroundSpecification createAroundSpec( + Method method, List params, String signature) { + return createAroundSpec(method, params, signature, false); + } + + private AroundSpecification createAroundSpec( + Method method, List params, String signature, boolean invokeDynamic) { + Type owner = Type.getType(method.getDeclaringClass()); + Type[] argTypes = + Arrays.stream(method.getParameterTypes()).map(Type::getType).toArray(Type[]::new); + Type returnType = Type.getType(method.getReturnType()); + MethodType methodType = + new MethodType(owner, method.getName(), Type.getMethodType(returnType, argTypes)); + Map paramMap = new HashMap<>(); + for (int i = 0; i < params.size(); i++) { + paramMap.put(i, params.get(i)); + } + updateArgumentIndices(paramMap); + AroundSpecification spec = + new AroundSpecification(methodType, paramMap, signature, invokeDynamic); + spec.parseSignature(CallSiteFactory.pointcutParser()); + return spec; + } + + private AroundSpecification createAroundSpec( + Class clazz, + String methodName, + Class returnType, + Class[] argTypes, + List params, + String signature) { + Type owner = Type.getType(clazz); + Type[] argTypesAsm = Arrays.stream(argTypes).map(Type::getType).toArray(Type[]::new); + Type returnTypeAsm = Type.getType(returnType); + MethodType methodType = + new MethodType(owner, methodName, Type.getMethodType(returnTypeAsm, argTypesAsm)); + Map paramMap = new HashMap<>(); + for (int i = 0; i < params.size(); i++) { + paramMap.put(i, params.get(i)); + } + updateArgumentIndices(paramMap); + AroundSpecification spec = new AroundSpecification(methodType, paramMap, signature, false); + spec.parseSignature(CallSiteFactory.pointcutParser()); + return spec; + } + + private AfterSpecification createAfterSpec( + Method method, List params, String signature) { + return createAfterSpec(method, params, signature, false); + } + + private AfterSpecification createAfterSpec( + Method method, List params, String signature, boolean invokeDynamic) { + Type owner = Type.getType(method.getDeclaringClass()); + Type[] argTypes = + Arrays.stream(method.getParameterTypes()).map(Type::getType).toArray(Type[]::new); + Type returnType = Type.getType(method.getReturnType()); + MethodType methodType = + new MethodType(owner, method.getName(), Type.getMethodType(returnType, argTypes)); + Map paramMap = new HashMap<>(); + for (int i = 0; i < params.size(); i++) { + paramMap.put(i, params.get(i)); + } + updateArgumentIndices(paramMap); + AfterSpecification spec = + new AfterSpecification(methodType, paramMap, signature, invokeDynamic); + spec.parseSignature(CallSiteFactory.pointcutParser()); + return spec; + } + + private AfterSpecification createAfterSpec( + Class clazz, + String methodName, + Class returnType, + Class[] argTypes, + List params, + String signature) { + Type owner = Type.getType(clazz); + Type[] argTypesAsm = Arrays.stream(argTypes).map(Type::getType).toArray(Type[]::new); + Type returnTypeAsm = Type.getType(returnType); + MethodType methodType = + new MethodType(owner, methodName, Type.getMethodType(returnTypeAsm, argTypesAsm)); + Map paramMap = new HashMap<>(); + for (int i = 0; i < params.size(); i++) { + paramMap.put(i, params.get(i)); + } + updateArgumentIndices(paramMap); + AfterSpecification spec = new AfterSpecification(methodType, paramMap, signature, false); + spec.parseSignature(CallSiteFactory.pointcutParser()); + return spec; + } + + private void updateArgumentIndices(Map paramMap) { + int index = 0; + for (ParameterSpecification param : paramMap.values()) { + if (param instanceof ArgumentSpecification) { + ((ArgumentSpecification) param).setIndex(index++); + } + } + } +} diff --git a/buildSrc/call-site-instrumentation-plugin/src/test/java/datadog/trace/plugin/csi/impl/AsmSpecificationBuilderTest.java b/buildSrc/call-site-instrumentation-plugin/src/test/java/datadog/trace/plugin/csi/impl/AsmSpecificationBuilderTest.java new file mode 100644 index 00000000000..ec37662ff58 --- /dev/null +++ b/buildSrc/call-site-instrumentation-plugin/src/test/java/datadog/trace/plugin/csi/impl/AsmSpecificationBuilderTest.java @@ -0,0 +1,591 @@ +package datadog.trace.plugin.csi.impl; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import datadog.trace.agent.tooling.csi.CallSite; +import datadog.trace.agent.tooling.csi.CallSites; +import datadog.trace.plugin.csi.impl.CallSiteSpecification.AdviceSpecification; +import datadog.trace.plugin.csi.impl.CallSiteSpecification.AfterSpecification; +import datadog.trace.plugin.csi.impl.CallSiteSpecification.AllArgsSpecification; +import datadog.trace.plugin.csi.impl.CallSiteSpecification.AroundSpecification; +import datadog.trace.plugin.csi.impl.CallSiteSpecification.BeforeSpecification; +import datadog.trace.plugin.csi.util.Types; +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; +import java.io.File; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodType; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import javax.servlet.ServletRequest; +import org.junit.jupiter.api.Test; +import org.objectweb.asm.Type; + +class AsmSpecificationBuilderTest extends BaseCsiPluginTest { + + static class NonCallSite {} + + @Test + void testSpecificationBuilderForNonCallSite() { + File advice = fetchClass(NonCallSite.class); + AsmSpecificationBuilder specificationBuilder = new AsmSpecificationBuilder(); + + Optional result = specificationBuilder.build(advice); + + assertFalse(result.isPresent()); + } + + @CallSite(spi = WithSpiClass.Spi.class) + static class WithSpiClass { + interface Spi {} + } + + @Test + void testSpecificationBuilderWithCustomSpiClass() { + File advice = fetchClass(WithSpiClass.class); + AsmSpecificationBuilder specificationBuilder = new AsmSpecificationBuilder(); + + CallSiteSpecification result = + specificationBuilder.build(advice).orElseThrow(RuntimeException::new); + + assertEquals( + Arrays.asList(Type.getType(WithSpiClass.Spi.class)), Arrays.asList(result.getSpi())); + } + + @CallSite( + spi = CallSites.class, + helpers = {HelpersAdvice.SampleHelper1.class, HelpersAdvice.SampleHelper2.class}) + static class HelpersAdvice { + static class SampleHelper1 {} + + static class SampleHelper2 {} + } + + @Test + void testSpecificationBuilderWithCustomHelperClasses() { + File advice = fetchClass(HelpersAdvice.class); + AsmSpecificationBuilder specificationBuilder = new AsmSpecificationBuilder(); + + CallSiteSpecification result = + specificationBuilder.build(advice).orElseThrow(RuntimeException::new); + + List helpers = Arrays.asList(result.getHelpers()); + assertTrue( + helpers.containsAll( + Arrays.asList( + Type.getType(HelpersAdvice.class), + Type.getType(HelpersAdvice.SampleHelper1.class), + Type.getType(HelpersAdvice.SampleHelper2.class)))); + } + + @CallSite(spi = CallSites.class) + static class BeforeAdvice { + @CallSite.Before( + "java.lang.String java.lang.String.replaceAll(java.lang.String, java.lang.String)") + static void before( + @CallSite.This String self, + @CallSite.Argument String regexp, + @CallSite.Argument String replacement) {} + } + + @Test + void testSpecificationBuilderForBeforeAdvice() { + File advice = fetchClass(BeforeAdvice.class); + AsmSpecificationBuilder specificationBuilder = new AsmSpecificationBuilder(); + + CallSiteSpecification result = + specificationBuilder.build(advice).orElseThrow(RuntimeException::new); + + assertEquals(BeforeAdvice.class.getName(), result.getClazz().getClassName()); + BeforeSpecification beforeSpec = (BeforeSpecification) findAdvice(result, "before"); + assertNotNull(beforeSpec); + assertEquals( + "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)V", + beforeSpec.getAdvice().getMethodType().getDescriptor()); + assertEquals( + "java.lang.String java.lang.String.replaceAll(java.lang.String, java.lang.String)", + beforeSpec.getSignature()); + assertNotNull(beforeSpec.findThis()); + assertNull(beforeSpec.findReturn()); + assertNull(beforeSpec.findAllArguments()); + assertNull(beforeSpec.findInvokeDynamicConstants()); + List arguments = getArguments(beforeSpec); + assertEquals(Arrays.asList(0, 1), arguments); + } + + @CallSite(spi = CallSites.class) + static class AroundAdvice { + @CallSite.Around( + "java.lang.String java.lang.String.replaceAll(java.lang.String, java.lang.String)") + static String around( + @CallSite.This String self, + @CallSite.Argument String regexp, + @CallSite.Argument String replacement) { + return self.replaceAll(regexp, replacement); + } + } + + @Test + void testSpecificationBuilderForAroundAdvice() { + File advice = fetchClass(AroundAdvice.class); + AsmSpecificationBuilder specificationBuilder = new AsmSpecificationBuilder(); + + CallSiteSpecification result = + specificationBuilder.build(advice).orElseThrow(RuntimeException::new); + + assertEquals(AroundAdvice.class.getName(), result.getClazz().getClassName()); + AroundSpecification aroundSpec = (AroundSpecification) findAdvice(result, "around"); + assertNotNull(aroundSpec); + assertEquals( + "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)Ljava/lang/String;", + aroundSpec.getAdvice().getMethodType().getDescriptor()); + assertEquals( + "java.lang.String java.lang.String.replaceAll(java.lang.String, java.lang.String)", + aroundSpec.getSignature()); + assertNotNull(aroundSpec.findThis()); + assertNull(aroundSpec.findReturn()); + assertNull(aroundSpec.findAllArguments()); + assertNull(aroundSpec.findInvokeDynamicConstants()); + List arguments = getArguments(aroundSpec); + assertEquals(Arrays.asList(0, 1), arguments); + } + + @CallSite(spi = CallSites.class) + static class AfterAdvice { + @CallSite.After( + "java.lang.String java.lang.String.replaceAll(java.lang.String, java.lang.String)") + static String after( + @CallSite.This String self, + @CallSite.Argument String regexp, + @CallSite.Argument String replacement, + @CallSite.Return String result) { + return result; + } + } + + @Test + void testSpecificationBuilderForAfterAdvice() { + File advice = fetchClass(AfterAdvice.class); + AsmSpecificationBuilder specificationBuilder = new AsmSpecificationBuilder(); + + CallSiteSpecification result = + specificationBuilder.build(advice).orElseThrow(RuntimeException::new); + + assertEquals(AfterAdvice.class.getName(), result.getClazz().getClassName()); + AfterSpecification afterSpec = (AfterSpecification) findAdvice(result, "after"); + assertNotNull(afterSpec); + assertEquals( + "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)Ljava/lang/String;", + afterSpec.getAdvice().getMethodType().getDescriptor()); + assertEquals( + "java.lang.String java.lang.String.replaceAll(java.lang.String, java.lang.String)", + afterSpec.getSignature()); + assertNotNull(afterSpec.findThis()); + assertNotNull(afterSpec.findReturn()); + assertNull(afterSpec.findAllArguments()); + assertNull(afterSpec.findInvokeDynamicConstants()); + List arguments = getArguments(afterSpec); + assertEquals(Arrays.asList(0, 1), arguments); + } + + @CallSite(spi = CallSites.class) + static class AllArgsAdvice { + @CallSite.Around( + "java.lang.String java.lang.String.replaceAll(java.lang.String, java.lang.String)") + static String allArgs( + @CallSite.AllArguments(includeThis = true) Object[] arguments, + @CallSite.Return String result) { + return result; + } + } + + @Test + void testSpecificationBuilderForAdviceWithAllArguments() { + File advice = fetchClass(AllArgsAdvice.class); + AsmSpecificationBuilder specificationBuilder = new AsmSpecificationBuilder(); + + CallSiteSpecification result = + specificationBuilder.build(advice).orElseThrow(RuntimeException::new); + + assertEquals(AllArgsAdvice.class.getName(), result.getClazz().getClassName()); + AroundSpecification allArgsSpec = (AroundSpecification) findAdvice(result, "allArgs"); + assertNotNull(allArgsSpec); + assertEquals( + "([Ljava/lang/Object;Ljava/lang/String;)Ljava/lang/String;", + allArgsSpec.getAdvice().getMethodType().getDescriptor()); + assertEquals( + "java.lang.String java.lang.String.replaceAll(java.lang.String, java.lang.String)", + allArgsSpec.getSignature()); + assertNull(allArgsSpec.findThis()); + assertNotNull(allArgsSpec.findReturn()); + AllArgsSpecification allArguments = allArgsSpec.findAllArguments(); + assertNotNull(allArguments); + assertTrue(allArguments.isIncludeThis()); + assertNull(allArgsSpec.findInvokeDynamicConstants()); + List arguments = getArguments(allArgsSpec); + assertEquals(Arrays.asList(), arguments); + } + + @CallSite(spi = CallSites.class) + static class InvokeDynamicBeforeAdvice { + @CallSite.After( + value = + "java.lang.invoke.CallSite java.lang.invoke.StringConcatFactory.makeConcatWithConstants(java.lang.invoke.MethodHandles$Lookup, java.lang.String, java.lang.invoke.MethodType, java.lang.String, java.lang.Object[])", + invokeDynamic = true) + static String invokeDynamic( + @CallSite.AllArguments Object[] arguments, @CallSite.Return String result) { + return result; + } + } + + @Test + void testSpecificationBuilderForBeforeInvokeDynamic() { + File advice = fetchClass(InvokeDynamicBeforeAdvice.class); + AsmSpecificationBuilder specificationBuilder = new AsmSpecificationBuilder(); + + CallSiteSpecification result = + specificationBuilder.build(advice).orElseThrow(RuntimeException::new); + + assertEquals(InvokeDynamicBeforeAdvice.class.getName(), result.getClazz().getClassName()); + AfterSpecification invokeDynamicSpec = (AfterSpecification) findAdvice(result, "invokeDynamic"); + assertNotNull(invokeDynamicSpec); + assertEquals( + "([Ljava/lang/Object;Ljava/lang/String;)Ljava/lang/String;", + invokeDynamicSpec.getAdvice().getMethodType().getDescriptor()); + assertEquals( + "java.lang.invoke.CallSite java.lang.invoke.StringConcatFactory.makeConcatWithConstants(java.lang.invoke.MethodHandles$Lookup, java.lang.String, java.lang.invoke.MethodType, java.lang.String, java.lang.Object[])", + invokeDynamicSpec.getSignature()); + assertNull(invokeDynamicSpec.findThis()); + assertNotNull(invokeDynamicSpec.findReturn()); + AllArgsSpecification allArguments = invokeDynamicSpec.findAllArguments(); + assertNotNull(allArguments); + assertFalse(allArguments.isIncludeThis()); + assertNull(invokeDynamicSpec.findInvokeDynamicConstants()); + List arguments = getArguments(invokeDynamicSpec); + assertEquals(Arrays.asList(), arguments); + } + + @CallSite(spi = CallSites.class) + static class InvokeDynamicAroundAdvice { + @CallSite.Around( + value = + "java.lang.invoke.CallSite java.lang.invoke.StringConcatFactory.makeConcatWithConstants(java.lang.invoke.MethodHandles$Lookup, java.lang.String, java.lang.invoke.MethodType, java.lang.String, java.lang.Object[])", + invokeDynamic = true) + static java.lang.invoke.CallSite invokeDynamic( + @CallSite.Argument MethodHandles.Lookup lookup, + @CallSite.Argument String name, + @CallSite.Argument MethodType concatType, + @CallSite.Argument String recipe, + @CallSite.Argument Object... constants) { + return null; + } + } + + @Test + void testSpecificationBuilderForAroundInvokeDynamic() { + File advice = fetchClass(InvokeDynamicAroundAdvice.class); + AsmSpecificationBuilder specificationBuilder = new AsmSpecificationBuilder(); + + CallSiteSpecification result = + specificationBuilder.build(advice).orElseThrow(RuntimeException::new); + + assertEquals(InvokeDynamicAroundAdvice.class.getName(), result.getClazz().getClassName()); + AroundSpecification invokeDynamicSpec = + (AroundSpecification) findAdvice(result, "invokeDynamic"); + assertNotNull(invokeDynamicSpec); + assertEquals( + "(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;Ljava/lang/String;[Ljava/lang/Object;)Ljava/lang/invoke/CallSite;", + invokeDynamicSpec.getAdvice().getMethodType().getDescriptor()); + assertEquals( + "java.lang.invoke.CallSite java.lang.invoke.StringConcatFactory.makeConcatWithConstants(java.lang.invoke.MethodHandles$Lookup, java.lang.String, java.lang.invoke.MethodType, java.lang.String, java.lang.Object[])", + invokeDynamicSpec.getSignature()); + assertNull(invokeDynamicSpec.findThis()); + assertNull(invokeDynamicSpec.findReturn()); + assertNull(invokeDynamicSpec.findAllArguments()); + assertNull(invokeDynamicSpec.findInvokeDynamicConstants()); + List arguments = getArguments(invokeDynamicSpec); + assertEquals(Arrays.asList(0, 1, 2, 3, 4), arguments); + } + + @CallSite(spi = CallSites.class) + static class TestInvokeDynamicConstants { + @CallSite.After( + value = + "java.lang.invoke.CallSite java.lang.invoke.StringConcatFactory.makeConcatWithConstants(java.lang.invoke.MethodHandles$Lookup, java.lang.String, java.lang.invoke.MethodType, java.lang.String, java.lang.Object[])", + invokeDynamic = true) + static String after( + @CallSite.AllArguments Object[] parameter, + @CallSite.InvokeDynamicConstants Object[] constants, + @CallSite.Return String value) { + return value; + } + } + + @Test + void testInvokeDynamicConstants() { + File advice = fetchClass(TestInvokeDynamicConstants.class); + AsmSpecificationBuilder specificationBuilder = new AsmSpecificationBuilder(); + + CallSiteSpecification result = + specificationBuilder.build(advice).orElseThrow(RuntimeException::new); + + assertEquals(TestInvokeDynamicConstants.class.getName(), result.getClazz().getClassName()); + AfterSpecification inheritedSpec = (AfterSpecification) findAdvice(result, "after"); + assertNotNull(inheritedSpec); + assertEquals( + "([Ljava/lang/Object;[Ljava/lang/Object;Ljava/lang/String;)Ljava/lang/String;", + inheritedSpec.getAdvice().getMethodType().getDescriptor()); + assertEquals( + "java.lang.invoke.CallSite java.lang.invoke.StringConcatFactory.makeConcatWithConstants(java.lang.invoke.MethodHandles$Lookup, java.lang.String, java.lang.invoke.MethodType, java.lang.String, java.lang.Object[])", + inheritedSpec.getSignature()); + assertNull(inheritedSpec.findThis()); + assertNotNull(inheritedSpec.findReturn()); + assertNotNull(inheritedSpec.findInvokeDynamicConstants()); + List arguments = getArguments(inheritedSpec); + assertEquals(Arrays.asList(), arguments); + } + + @CallSite(spi = CallSites.class) + static class TestBeforeArray { + + @CallSite.BeforeArray({ + @CallSite.Before("java.util.Map javax.servlet.ServletRequest.getParameterMap()"), + @CallSite.Before("java.util.Map javax.servlet.ServletRequestWrapper.getParameterMap()") + }) + static void before(@CallSite.This ServletRequest request) {} + } + + @Test + void testSpecificationBuilderForBeforeAdviceArray() { + File advice = fetchClass(TestBeforeArray.class); + AsmSpecificationBuilder specificationBuilder = new AsmSpecificationBuilder(); + + CallSiteSpecification result = + specificationBuilder.build(advice).orElseThrow(RuntimeException::new); + + assertEquals(TestBeforeArray.class.getName(), result.getClazz().getClassName()); + List list = result.getAdvices(); + assertEquals(2, list.size()); + for (AdviceSpecification spec : list) { + assertInstanceOf(BeforeSpecification.class, spec); + assertEquals( + "(Ljavax/servlet/ServletRequest;)V", spec.getAdvice().getMethodType().getDescriptor()); + assertTrue( + spec.getSignature().equals("java.util.Map javax.servlet.ServletRequest.getParameterMap()") + || spec.getSignature() + .equals("java.util.Map javax.servlet.ServletRequestWrapper.getParameterMap()")); + assertNotNull(spec.findThis()); + assertNull(spec.findReturn()); + assertNull(spec.findAllArguments()); + assertNull(spec.findInvokeDynamicConstants()); + List arguments = getArguments(spec); + assertEquals(Arrays.asList(), arguments); + } + } + + @CallSite(spi = CallSites.class) + static class TestAroundArray { + + @CallSite.AroundArray({ + @CallSite.Around("java.util.Map javax.servlet.ServletRequest.getParameterMap()"), + @CallSite.Around("java.util.Map javax.servlet.ServletRequestWrapper.getParameterMap()") + }) + static Map around(@CallSite.This ServletRequest request) { + return request.getParameterMap(); + } + } + + @Test + void testSpecificationBuilderForAroundAdviceArray() { + File advice = fetchClass(TestAroundArray.class); + AsmSpecificationBuilder specificationBuilder = new AsmSpecificationBuilder(); + + CallSiteSpecification result = + specificationBuilder.build(advice).orElseThrow(RuntimeException::new); + + assertEquals(TestAroundArray.class.getName(), result.getClazz().getClassName()); + List list = result.getAdvices(); + assertEquals(2, list.size()); + for (AdviceSpecification spec : list) { + assertInstanceOf(AroundSpecification.class, spec); + assertEquals( + "(Ljavax/servlet/ServletRequest;)Ljava/util/Map;", + spec.getAdvice().getMethodType().getDescriptor()); + assertTrue( + spec.getSignature().equals("java.util.Map javax.servlet.ServletRequest.getParameterMap()") + || spec.getSignature() + .equals("java.util.Map javax.servlet.ServletRequestWrapper.getParameterMap()")); + assertNotNull(spec.findThis()); + assertNull(spec.findReturn()); + assertNull(spec.findAllArguments()); + assertNull(spec.findInvokeDynamicConstants()); + List arguments = getArguments(spec); + assertEquals(Arrays.asList(), arguments); + } + } + + @CallSite(spi = CallSites.class) + static class TestAfterArray { + + @CallSite.AfterArray({ + @CallSite.After("java.util.Map javax.servlet.ServletRequest.getParameterMap()"), + @CallSite.After("java.util.Map javax.servlet.ServletRequestWrapper.getParameterMap()") + }) + static Map after(@CallSite.This ServletRequest request, @CallSite.Return Map parameters) { + return parameters; + } + } + + @Test + void testSpecificationBuilderForAfterAdviceArray() { + File advice = fetchClass(TestAfterArray.class); + AsmSpecificationBuilder specificationBuilder = new AsmSpecificationBuilder(); + + CallSiteSpecification result = + specificationBuilder.build(advice).orElseThrow(RuntimeException::new); + + assertEquals(TestAfterArray.class.getName(), result.getClazz().getClassName()); + List list = result.getAdvices(); + assertEquals(2, list.size()); + for (AdviceSpecification spec : list) { + assertInstanceOf(AfterSpecification.class, spec); + assertEquals( + "(Ljavax/servlet/ServletRequest;Ljava/util/Map;)Ljava/util/Map;", + spec.getAdvice().getMethodType().getDescriptor()); + assertTrue( + spec.getSignature().equals("java.util.Map javax.servlet.ServletRequest.getParameterMap()") + || spec.getSignature() + .equals("java.util.Map javax.servlet.ServletRequestWrapper.getParameterMap()")); + assertNotNull(spec.findThis()); + assertNotNull(spec.findReturn()); + assertNull(spec.findAllArguments()); + assertNull(spec.findInvokeDynamicConstants()); + List arguments = getArguments(spec); + assertEquals(Arrays.asList(), arguments); + } + } + + @CallSite(spi = CallSites.class) + static class TestInheritedMethod { + @CallSite.After( + "java.lang.String javax.servlet.http.HttpServletRequest.getParameter(java.lang.String)") + static String after( + @CallSite.This ServletRequest request, + @CallSite.Argument String parameter, + @CallSite.Return String value) { + return value; + } + } + + @Test + void testSpecificationBuilderForInheritedMethods() { + File advice = fetchClass(TestInheritedMethod.class); + AsmSpecificationBuilder specificationBuilder = new AsmSpecificationBuilder(); + + CallSiteSpecification result = + specificationBuilder.build(advice).orElseThrow(RuntimeException::new); + + assertEquals(TestInheritedMethod.class.getName(), result.getClazz().getClassName()); + AfterSpecification inheritedSpec = (AfterSpecification) findAdvice(result, "after"); + assertNotNull(inheritedSpec); + assertEquals( + "(Ljavax/servlet/ServletRequest;Ljava/lang/String;Ljava/lang/String;)Ljava/lang/String;", + inheritedSpec.getAdvice().getMethodType().getDescriptor()); + assertEquals( + "java.lang.String javax.servlet.http.HttpServletRequest.getParameter(java.lang.String)", + inheritedSpec.getSignature()); + assertNotNull(inheritedSpec.findThis()); + assertNotNull(inheritedSpec.findReturn()); + assertNull(inheritedSpec.findAllArguments()); + assertNull(inheritedSpec.findInvokeDynamicConstants()); + List arguments = getArguments(inheritedSpec); + assertEquals(Arrays.asList(0), arguments); + } + + static class IsEnabled { + static boolean isEnabled(String defaultValue) { + return true; + } + } + + @CallSite( + spi = CallSites.class, + enabled = { + "datadog.trace.plugin.csi.impl.AsmSpecificationBuilderTest$IsEnabled", + "isEnabled", + "true" + }) + static class TestEnablement { + @CallSite.After( + "java.lang.String javax.servlet.http.HttpServletRequest.getParameter(java.lang.String)") + static String after( + @CallSite.This ServletRequest request, + @CallSite.Argument String parameter, + @CallSite.Return String value) { + return value; + } + } + + @Test + void testSpecificationBuilderWithEnabledProperty() { + File advice = fetchClass(TestEnablement.class); + AsmSpecificationBuilder specificationBuilder = new AsmSpecificationBuilder(); + + CallSiteSpecification result = + specificationBuilder.build(advice).orElseThrow(RuntimeException::new); + + assertEquals(TestEnablement.class.getName(), result.getClazz().getClassName()); + assertNotNull(result.getEnabled()); + assertEquals(Type.getType(IsEnabled.class), result.getEnabled().getMethod().getOwner()); + assertEquals("isEnabled", result.getEnabled().getMethod().getMethodName()); + assertEquals( + Type.getMethodType(Types.BOOLEAN, Types.STRING), + result.getEnabled().getMethod().getMethodType()); + assertEquals(Arrays.asList("true"), result.getEnabled().getArguments()); + } + + @CallSite(spi = CallSites.class) + static class TestWithOtherAnnotations { + @CallSite.Around("java.lang.StringBuilder java.lang.StringBuilder.append(java.lang.Object)") + @CallSite.Around("java.lang.StringBuffer java.lang.StringBuffer.append(java.lang.Object)") + @Nonnull + @SuppressFBWarnings("NP_PARAMETER_MUST_BE_NONNULL_BUT_MARKED_AS_NULLABLE") + static Appendable aroundAppend( + @CallSite.This @Nullable Appendable self, @CallSite.Argument(0) @Nullable Object param) + throws Throwable { + return self.append(param.toString()); + } + } + + @Test + void testSpecificationBuilderWithMultipleMethodAnnotations() { + File advice = fetchClass(TestWithOtherAnnotations.class); + AsmSpecificationBuilder specificationBuilder = new AsmSpecificationBuilder(); + + CallSiteSpecification result = + specificationBuilder.build(advice).orElseThrow(RuntimeException::new); + + assertEquals(TestWithOtherAnnotations.class.getName(), result.getClazz().getClassName()); + assertEquals(2, result.getAdvices().size()); + } + + private static List getArguments(AdviceSpecification advice) { + return advice.getArguments().map(arg -> arg.getIndex()).collect(Collectors.toList()); + } + + private static AdviceSpecification findAdvice(CallSiteSpecification result, String name) { + return result.getAdvices().stream() + .filter(it -> it.getAdvice().getMethodName().equals(name)) + .findFirst() + .orElse(null); + } +} diff --git a/buildSrc/call-site-instrumentation-plugin/src/test/java/datadog/trace/plugin/csi/impl/BaseCsiPluginTest.java b/buildSrc/call-site-instrumentation-plugin/src/test/java/datadog/trace/plugin/csi/impl/BaseCsiPluginTest.java new file mode 100644 index 00000000000..a2bd158d20b --- /dev/null +++ b/buildSrc/call-site-instrumentation-plugin/src/test/java/datadog/trace/plugin/csi/impl/BaseCsiPluginTest.java @@ -0,0 +1,63 @@ +package datadog.trace.plugin.csi.impl; + +import static datadog.trace.plugin.csi.impl.CallSiteFactory.pointcutParser; +import static datadog.trace.plugin.csi.impl.CallSiteFactory.specificationBuilder; +import static datadog.trace.plugin.csi.impl.CallSiteFactory.typeResolver; +import static datadog.trace.plugin.csi.util.CallSiteConstants.TYPE_RESOLVER; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import datadog.trace.plugin.csi.HasErrors; +import datadog.trace.plugin.csi.ValidationContext; +import java.io.File; +import java.net.URISyntaxException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +public abstract class BaseCsiPluginTest { + + protected static void assertNoErrors(HasErrors hasErrors) { + List errors = + hasErrors.getErrors().stream() + .map( + error -> { + String causeString = error.getCause() == null ? "-" : error.getCauseString(); + return error.getMessage() + ": " + causeString; + }) + .collect(Collectors.toList()); + assertEquals(Collections.emptyList(), errors); + } + + protected static File fetchClass(Class clazz) { + try { + Path folder = Paths.get(clazz.getResource("/").toURI()).resolve("../../"); + String fileSeparator = File.separator.equals("\\") ? "\\\\" : File.separator; + String classFile = clazz.getName().replaceAll("\\.", fileSeparator) + ".class"; + Path groovy = folder.resolve("groovy/test").resolve(classFile); + if (Files.exists(groovy)) { + return groovy.toFile(); + } + return folder.resolve("java/test").resolve(classFile).toFile(); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + protected static CallSiteSpecification buildClassSpecification(Class clazz) { + File classFile = fetchClass(clazz); + CallSiteSpecification spec = specificationBuilder().build(classFile).get(); + spec.getAdvices().forEach(advice -> advice.parseSignature(pointcutParser())); + return spec; + } + + protected ValidationContext mockValidationContext() { + ValidationContext context = mock(ValidationContext.class); + when(context.getContextProperty(TYPE_RESOLVER)).thenReturn(typeResolver()); + return context; + } +} diff --git a/buildSrc/call-site-instrumentation-plugin/src/test/java/datadog/trace/plugin/csi/impl/CallSiteSpecificationTest.java b/buildSrc/call-site-instrumentation-plugin/src/test/java/datadog/trace/plugin/csi/impl/CallSiteSpecificationTest.java new file mode 100644 index 00000000000..a0d60d379c1 --- /dev/null +++ b/buildSrc/call-site-instrumentation-plugin/src/test/java/datadog/trace/plugin/csi/impl/CallSiteSpecificationTest.java @@ -0,0 +1,69 @@ +package datadog.trace.plugin.csi.impl; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +import datadog.trace.agent.tooling.csi.CallSiteAdvice; +import datadog.trace.plugin.csi.ValidationContext; +import datadog.trace.plugin.csi.impl.CallSiteSpecification.AdviceSpecification; +import datadog.trace.plugin.csi.util.ErrorCode; +import java.util.Collections; +import java.util.List; +import java.util.Set; +import org.junit.jupiter.api.Test; +import org.objectweb.asm.Type; + +class CallSiteSpecificationTest extends BaseCsiPluginTest { + + @Test + void testCallSiteSpiShouldBeAnInterface() { + ValidationContext context = mockValidationContext(); + AdviceSpecification mockAdvice = mock(AdviceSpecification.class); + List advices = Collections.singletonList(mockAdvice); + Set spiTypes = Collections.singleton(Type.getType(String.class)); + List helperClassNames = Collections.emptyList(); + Set constants = Collections.emptySet(); + CallSiteSpecification spec = + new CallSiteSpecification( + Type.getType(String.class), advices, spiTypes, helperClassNames, constants); + + spec.validate(context); + + verify(context).addError(eq(ErrorCode.CALL_SITE_SPI_SHOULD_BE_AN_INTERFACE), any()); + } + + @Test + void testCallSiteSpiShouldNotDefineAnyMethods() { + ValidationContext context = mockValidationContext(); + AdviceSpecification mockAdvice = mock(AdviceSpecification.class); + List advices = Collections.singletonList(mockAdvice); + Set spiTypes = Collections.singleton(Type.getType(Comparable.class)); + List helperClassNames = Collections.emptyList(); + Set constants = Collections.emptySet(); + CallSiteSpecification spec = + new CallSiteSpecification( + Type.getType(String.class), advices, spiTypes, helperClassNames, constants); + + spec.validate(context); + + verify(context).addError(eq(ErrorCode.CALL_SITE_SPI_SHOULD_BE_EMPTY), any()); + } + + @Test + void testCallSiteShouldHaveAdvices() { + ValidationContext context = mockValidationContext(); + List advices = Collections.emptyList(); + Set spiTypes = Collections.singleton(Type.getType(CallSiteAdvice.class)); + List helperClassNames = Collections.emptyList(); + Set constants = Collections.emptySet(); + CallSiteSpecification spec = + new CallSiteSpecification( + Type.getType(String.class), advices, spiTypes, helperClassNames, constants); + + spec.validate(context); + + verify(context).addError(eq(ErrorCode.CALL_SITE_SHOULD_HAVE_ADVICE_METHODS), any()); + } +} diff --git a/buildSrc/call-site-instrumentation-plugin/src/test/java/datadog/trace/plugin/csi/impl/RegexpAdvicePointcutParserTest.java b/buildSrc/call-site-instrumentation-plugin/src/test/java/datadog/trace/plugin/csi/impl/RegexpAdvicePointcutParserTest.java new file mode 100644 index 00000000000..550082ccb7d --- /dev/null +++ b/buildSrc/call-site-instrumentation-plugin/src/test/java/datadog/trace/plugin/csi/impl/RegexpAdvicePointcutParserTest.java @@ -0,0 +1,162 @@ +package datadog.trace.plugin.csi.impl; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import datadog.trace.plugin.csi.util.MethodType; +import org.junit.jupiter.api.Test; + +class RegexpAdvicePointcutParserTest { + + @Test + void resolveConstructor() { + RegexpAdvicePointcutParser pointcutParser = new RegexpAdvicePointcutParser(); + + MethodType signature = + pointcutParser.parse( + "void datadog.trace.plugin.csi.samples.SignatureParserExample.()"); + + assertEquals( + "datadog.trace.plugin.csi.samples.SignatureParserExample", + signature.getOwner().getClassName()); + assertEquals("", signature.getMethodName()); + assertEquals("()V", signature.getMethodType().getDescriptor()); + } + + @Test + void resolveConstructorWithArgs() { + RegexpAdvicePointcutParser pointcutParser = new RegexpAdvicePointcutParser(); + + MethodType signature = + pointcutParser.parse( + "void datadog.trace.plugin.csi.samples.SignatureParserExample.(java.lang.String)"); + + assertEquals( + "datadog.trace.plugin.csi.samples.SignatureParserExample", + signature.getOwner().getClassName()); + assertEquals("", signature.getMethodName()); + assertEquals("(Ljava/lang/String;)V", signature.getMethodType().getDescriptor()); + } + + @Test + void resolveWithoutArgs() { + RegexpAdvicePointcutParser pointcutParser = new RegexpAdvicePointcutParser(); + + MethodType signature = + pointcutParser.parse( + "java.lang.String datadog.trace.plugin.csi.samples.SignatureParserExample.noParams()"); + + assertEquals( + "datadog.trace.plugin.csi.samples.SignatureParserExample", + signature.getOwner().getClassName()); + assertEquals("noParams", signature.getMethodName()); + assertEquals("()Ljava/lang/String;", signature.getMethodType().getDescriptor()); + } + + @Test + void resolveOneParam() { + RegexpAdvicePointcutParser pointcutParser = new RegexpAdvicePointcutParser(); + + MethodType signature = + pointcutParser.parse( + "java.lang.String datadog.trace.plugin.csi.samples.SignatureParserExample.oneParam(java.util.Map)"); + + assertEquals( + "datadog.trace.plugin.csi.samples.SignatureParserExample", + signature.getOwner().getClassName()); + assertEquals("oneParam", signature.getMethodName()); + assertEquals("(Ljava/util/Map;)Ljava/lang/String;", signature.getMethodType().getDescriptor()); + } + + @Test + void resolveMultipleParams() { + RegexpAdvicePointcutParser pointcutParser = new RegexpAdvicePointcutParser(); + + MethodType signature = + pointcutParser.parse( + "java.lang.String datadog.trace.plugin.csi.samples.SignatureParserExample.multipleParams(java.lang.String, int, java.util.List)"); + + assertEquals( + "datadog.trace.plugin.csi.samples.SignatureParserExample", + signature.getOwner().getClassName()); + assertEquals("multipleParams", signature.getMethodName()); + assertEquals( + "(Ljava/lang/String;ILjava/util/List;)Ljava/lang/String;", + signature.getMethodType().getDescriptor()); + } + + @Test + void resolveVarargs() { + RegexpAdvicePointcutParser pointcutParser = new RegexpAdvicePointcutParser(); + + MethodType signature = + pointcutParser.parse( + "java.lang.String datadog.trace.plugin.csi.samples.SignatureParserExample.varargs(java.lang.String[])"); + + assertEquals( + "datadog.trace.plugin.csi.samples.SignatureParserExample", + signature.getOwner().getClassName()); + assertEquals("varargs", signature.getMethodName()); + assertEquals( + "([Ljava/lang/String;)Ljava/lang/String;", signature.getMethodType().getDescriptor()); + } + + @Test + void resolvePrimitive() { + RegexpAdvicePointcutParser pointcutParser = new RegexpAdvicePointcutParser(); + + MethodType signature = + pointcutParser.parse( + "int datadog.trace.plugin.csi.samples.SignatureParserExample.primitive()"); + + assertEquals( + "datadog.trace.plugin.csi.samples.SignatureParserExample", + signature.getOwner().getClassName()); + assertEquals("primitive", signature.getMethodName()); + assertEquals("()I", signature.getMethodType().getDescriptor()); + } + + @Test + void resolvePrimitiveArrayType() { + RegexpAdvicePointcutParser pointcutParser = new RegexpAdvicePointcutParser(); + + MethodType signature = + pointcutParser.parse( + "byte[] datadog.trace.plugin.csi.samples.SignatureParserExample.primitiveArray()"); + + assertEquals( + "datadog.trace.plugin.csi.samples.SignatureParserExample", + signature.getOwner().getClassName()); + assertEquals("primitiveArray", signature.getMethodName()); + assertEquals("()[B", signature.getMethodType().getDescriptor()); + } + + @Test + void resolveObjectArrayType() { + RegexpAdvicePointcutParser pointcutParser = new RegexpAdvicePointcutParser(); + + MethodType signature = + pointcutParser.parse( + "java.lang.Object[] datadog.trace.plugin.csi.samples.SignatureParserExample.objectArray()"); + + assertEquals( + "datadog.trace.plugin.csi.samples.SignatureParserExample", + signature.getOwner().getClassName()); + assertEquals("objectArray", signature.getMethodName()); + assertEquals("()[Ljava/lang/Object;", signature.getMethodType().getDescriptor()); + } + + @Test + void resolveMultiDimensionalObjectArrayType() { + RegexpAdvicePointcutParser pointcutParser = new RegexpAdvicePointcutParser(); + + MethodType signature = + pointcutParser.parse( + "java.lang.Object[][][] datadog.trace.plugin.csi.samples.SignatureParserExample.objectArray()"); + + assertEquals( + "datadog.trace.plugin.csi.samples.SignatureParserExample", + signature.getOwner().getClassName()); + assertEquals("objectArray", signature.getMethodName()); + assertEquals("()[[[Ljava/lang/Object;", signature.getMethodType().getDescriptor()); + } +} diff --git a/buildSrc/call-site-instrumentation-plugin/src/test/java/datadog/trace/plugin/csi/impl/TypeResolverPoolTest.java b/buildSrc/call-site-instrumentation-plugin/src/test/java/datadog/trace/plugin/csi/impl/TypeResolverPoolTest.java new file mode 100644 index 00000000000..13c85343c7b --- /dev/null +++ b/buildSrc/call-site-instrumentation-plugin/src/test/java/datadog/trace/plugin/csi/impl/TypeResolverPoolTest.java @@ -0,0 +1,96 @@ +package datadog.trace.plugin.csi.impl; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import datadog.trace.plugin.csi.util.MethodType; +import java.lang.reflect.Method; +import javax.servlet.ServletRequest; +import javax.servlet.http.HttpServletRequest; +import org.junit.jupiter.api.Test; +import org.objectweb.asm.Type; + +class TypeResolverPoolTest { + + @Test + void testResolvePrimitive() { + TypeResolverPool resolver = new TypeResolverPool(); + + Class result = resolver.resolveType(Type.INT_TYPE); + + assertEquals(int.class, result); + } + + @Test + void testResolvePrimitiveArray() { + TypeResolverPool resolver = new TypeResolverPool(); + Type type = Type.getType("[I"); + + Class result = resolver.resolveType(type); + + assertEquals(int[].class, result); + } + + @Test + void testResolvePrimitiveMultidimensionalArray() { + TypeResolverPool resolver = new TypeResolverPool(); + Type type = Type.getType("[[[I"); + + Class result = resolver.resolveType(type); + + assertEquals(int[][][].class, result); + } + + @Test + void testResolveClass() { + TypeResolverPool resolver = new TypeResolverPool(); + Type type = Type.getType(String.class); + + Class result = resolver.resolveType(type); + + assertEquals(String.class, result); + } + + @Test + void testResolveClassArray() { + TypeResolverPool resolver = new TypeResolverPool(); + Type type = Type.getType(String[].class); + + Class result = resolver.resolveType(type); + + assertEquals(String[].class, result); + } + + @Test + void testResolveClassMultidimensionalArray() { + TypeResolverPool resolver = new TypeResolverPool(); + Type type = Type.getType(String[][][].class); + + Class result = resolver.resolveType(type); + + assertEquals(String[][][].class, result); + } + + @Test + void testTypeResolverFromMethod() { + TypeResolverPool resolver = new TypeResolverPool(); + Type type = + Type.getMethodType( + Type.getType(String[].class), Type.getType(String.class), Type.getType(String.class)); + + Class result = resolver.resolveType(type.getReturnType()); + + assertEquals(String[].class, result); + } + + @Test + void testInheritedMethods() throws Exception { + TypeResolverPool resolver = new TypeResolverPool(); + Type owner = Type.getType(HttpServletRequest.class); + String name = "getParameter"; + Type descriptor = Type.getMethodType(Type.getType(String.class), Type.getType(String.class)); + + Method result = (Method) resolver.resolveMethod(new MethodType(owner, name, descriptor)); + + assertEquals(ServletRequest.class.getDeclaredMethod("getParameter", String.class), result); + } +} diff --git a/buildSrc/call-site-instrumentation-plugin/src/test/java/datadog/trace/plugin/csi/impl/assertion/AdviceAssert.java b/buildSrc/call-site-instrumentation-plugin/src/test/java/datadog/trace/plugin/csi/impl/assertion/AdviceAssert.java new file mode 100644 index 00000000000..9f814d7171d --- /dev/null +++ b/buildSrc/call-site-instrumentation-plugin/src/test/java/datadog/trace/plugin/csi/impl/assertion/AdviceAssert.java @@ -0,0 +1,37 @@ +package datadog.trace.plugin.csi.impl.assertion; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.List; + +public class AdviceAssert { + protected String type; + protected String owner; + protected String method; + protected String descriptor; + protected List statements; + + public AdviceAssert( + String type, String owner, String method, String descriptor, List statements) { + this.type = type; + this.owner = owner; + this.method = method; + this.descriptor = descriptor; + this.statements = statements; + } + + public void type(String type) { + assertEquals(type, this.type); + } + + public void pointcut(String owner, String method, String descriptor) { + assertEquals(owner, this.owner); + assertEquals(method, this.method); + assertEquals(descriptor, this.descriptor); + } + + public void statements(String... values) { + assertArrayEquals(values, statements.toArray(new String[0])); + } +} diff --git a/buildSrc/call-site-instrumentation-plugin/src/test/java/datadog/trace/plugin/csi/impl/assertion/AssertBuilder.java b/buildSrc/call-site-instrumentation-plugin/src/test/java/datadog/trace/plugin/csi/impl/assertion/AssertBuilder.java new file mode 100644 index 00000000000..42e87842adc --- /dev/null +++ b/buildSrc/call-site-instrumentation-plugin/src/test/java/datadog/trace/plugin/csi/impl/assertion/AssertBuilder.java @@ -0,0 +1,205 @@ +package datadog.trace.plugin.csi.impl.assertion; + +import static datadog.trace.plugin.csi.impl.CallSiteFactory.typeResolver; +import static datadog.trace.plugin.csi.util.CallSiteUtils.classNameToType; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import com.github.javaparser.JavaParser; +import com.github.javaparser.ParserConfiguration; +import com.github.javaparser.ast.CompilationUnit; +import com.github.javaparser.ast.Node; +import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration; +import com.github.javaparser.ast.body.MethodDeclaration; +import com.github.javaparser.ast.expr.MethodCallExpr; +import com.github.javaparser.resolution.declarations.ResolvedMethodDeclaration; +import com.github.javaparser.symbolsolver.JavaSymbolSolver; +import datadog.trace.agent.tooling.csi.CallSites; +import java.io.File; +import java.io.FileNotFoundException; +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.util.Collections; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; + +public class AssertBuilder { + protected final File file; + + public AssertBuilder(File file) { + this.file = file; + } + + public CallSiteAssert build() { + CompilationUnit javaFile; + javaFile = parseJavaFile(file); + assertEquals(Node.Parsedness.PARSED, javaFile.getParsed()); + ClassOrInterfaceDeclaration targetType = + javaFile.getPrimaryType().get().asClassOrInterfaceDeclaration(); + Set> interfaces = getInterfaces(targetType); + Method enabled = null; + Set enabledArgs = null; + Object[] enabledDeclaration = getEnabledDeclaration(targetType, interfaces); + enabled = (Method) enabledDeclaration[0]; + enabledArgs = (Set) enabledDeclaration[1]; + return new CallSiteAssert( + interfaces, + getSpi(targetType), + getHelpers(targetType), + getAdvices(targetType), + enabled, + enabledArgs); + } + + protected Set> getSpi(ClassOrInterfaceDeclaration type) { + return type.getAnnotationByName("AutoService") + .>>map( + annotation -> + annotation.asNormalAnnotationExpr().getPairs().stream() + .filter(pair -> pair.getNameAsString().equals("value")) + .flatMap( + pair -> + pair.getValue().asArrayInitializerExpr().getValues().stream() + .map( + value -> + value + .asClassExpr() + .getType() + .resolve() + .asReferenceType() + .getTypeDeclaration() + .get() + .getQualifiedName())) + .map(AssertBuilder::loadClass) + .collect(Collectors.toSet())) + .orElse(Collections.emptySet()); + } + + protected Set> getInterfaces(ClassOrInterfaceDeclaration type) { + return type.getImplementedTypes().stream() + .map( + implementedType -> { + String qualifiedName = + implementedType + .asClassOrInterfaceType() + .resolve() + .asReferenceType() + .getTypeDeclaration() + .get() + .getQualifiedName(); + return loadClass(qualifiedName); + }) + .collect(Collectors.toSet()); + } + + private static Class loadClass(String qualifiedName) { + // Try progressively replacing dots with $ from right to left for inner classes + String current = qualifiedName; + int lastDot = current.lastIndexOf('.'); + do { + try { + return Class.forName(current); + } catch (ClassNotFoundException e) { + if (lastDot <= 0) { + throw new RuntimeException(new ClassNotFoundException(qualifiedName)); + } + current = current.substring(0, lastDot) + "$" + current.substring(lastDot + 1); + lastDot = current.lastIndexOf('.', lastDot - 1); + } + } while (true); + } + + protected Object[] getEnabledDeclaration( + ClassOrInterfaceDeclaration type, Set> interfaces) { + if (!interfaces.contains(CallSites.HasEnabledProperty.class)) { + return new Object[] {null, null}; + } + MethodDeclaration isEnabled = type.getMethodsByName("isEnabled").get(0); + MethodCallExpr enabledMethodCall = + isEnabled + .getBody() + .get() + .getStatements() + .get(0) + .asReturnStmt() + .getExpression() + .get() + .asMethodCallExpr(); + Method enabled = resolveMethod(enabledMethodCall); + Set enabledArgs = + enabledMethodCall.getArguments().stream() + .map(arg -> arg.asStringLiteralExpr().asString()) + .collect(Collectors.toSet()); + return new Object[] {enabled, enabledArgs}; + } + + protected Set> getHelpers(ClassOrInterfaceDeclaration type) { + MethodDeclaration acceptMethod = type.getMethodsByName("accept").get(0); + List methodCalls = getMethodCalls(acceptMethod); + return methodCalls.stream() + .filter(methodCall -> methodCall.getNameAsString().equals("addHelpers")) + .flatMap(methodCall -> methodCall.getArguments().stream()) + .map( + arg -> { + String className = arg.asStringLiteralExpr().asString(); + return typeResolver().resolveType(classNameToType(className)); + }) + .collect(Collectors.toSet()); + } + + protected List getAdvices(ClassOrInterfaceDeclaration type) { + MethodDeclaration acceptMethod = type.getMethodsByName("accept").get(0); + return getMethodCalls(acceptMethod).stream() + .filter(methodCall -> methodCall.getNameAsString().equals("addAdvice")) + .map( + methodCall -> { + String adviceType = methodCall.getArgument(0).asFieldAccessExpr().getNameAsString(); + String owner = methodCall.getArgument(1).asStringLiteralExpr().asString(); + String method = methodCall.getArgument(2).asStringLiteralExpr().asString(); + String descriptor = methodCall.getArgument(3).asStringLiteralExpr().asString(); + List statements = + methodCall + .getArgument(4) + .asLambdaExpr() + .getBody() + .asBlockStmt() + .getStatements() + .stream() + .map(Object::toString) + .collect(Collectors.toList()); + return new AdviceAssert(adviceType, owner, method, descriptor, statements); + }) + .collect(Collectors.toList()); + } + + protected static List getMethodCalls(MethodDeclaration method) { + return method.getBody().get().asBlockStmt().getStatements().stream() + .filter( + stmt -> + stmt.isExpressionStmt() + && stmt.asExpressionStmt().getExpression().isMethodCallExpr()) + .map(stmt -> stmt.asExpressionStmt().getExpression().asMethodCallExpr()) + .collect(Collectors.toList()); + } + + private static Method resolveMethod(MethodCallExpr methodCallExpr) { + ResolvedMethodDeclaration resolved = methodCallExpr.resolve(); + try { + Field methodField = resolved.getClass().getDeclaredField("method"); + methodField.setAccessible(true); + return (Method) methodField.get(resolved); + } catch (IllegalAccessException | NoSuchFieldException e) { + throw new RuntimeException(e); + } + } + + private static CompilationUnit parseJavaFile(File file) { + JavaSymbolSolver solver = new JavaSymbolSolver(typeResolver()); + JavaParser parser = new JavaParser(new ParserConfiguration().setSymbolResolver(solver)); + try { + return parser.parse(file).getResult().get(); + } catch (FileNotFoundException e) { + throw new RuntimeException(e); + } + } +} diff --git a/buildSrc/call-site-instrumentation-plugin/src/test/java/datadog/trace/plugin/csi/impl/assertion/CallSiteAssert.java b/buildSrc/call-site-instrumentation-plugin/src/test/java/datadog/trace/plugin/csi/impl/assertion/CallSiteAssert.java new file mode 100644 index 00000000000..57ce0146921 --- /dev/null +++ b/buildSrc/call-site-instrumentation-plugin/src/test/java/datadog/trace/plugin/csi/impl/assertion/CallSiteAssert.java @@ -0,0 +1,88 @@ +package datadog.trace.plugin.csi.impl.assertion; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.function.Consumer; + +public class CallSiteAssert { + + protected Set> interfaces; + protected Set> spi; + protected Set> helpers; + protected List advices; + protected Method enabled; + protected Set enabledArgs; + + public CallSiteAssert( + Set> interfaces, + Set> spi, + Set> helpers, + List advices, + Method enabled, + Set enabledArgs) { + this.interfaces = interfaces; + this.spi = spi; + this.helpers = helpers; + this.advices = advices; + this.enabled = enabled; + this.enabledArgs = enabledArgs; + } + + public void interfaces(Class... values) { + assertSameElements(interfaces, values); + } + + public void helpers(Class... values) { + assertSameElements(helpers, values); + } + + public void spi(Class... values) { + assertSameElements(spi, values); + } + + public void advices(int index, Consumer assertions) { + AdviceAssert asserter = advices.get(index); + assertions.accept(asserter); + } + + public void enabled(Method method, String... args) { + assertEquals(method, enabled); + assertSameElements(enabledArgs, args); + } + + private static void assertSameElements(Set expected, E... received) { + assertEquals(received.length, expected.size()); + Set receivedSet = new HashSet<>(Arrays.asList(received)); + assertTrue(expected.containsAll(receivedSet) && receivedSet.containsAll(expected)); + } + + public Set> getInterfaces() { + return interfaces; + } + + public Set> getSpi() { + return spi; + } + + public Set> getHelpers() { + return helpers; + } + + public List getAdvices() { + return advices; + } + + public Method getEnabled() { + return enabled; + } + + public Set getEnabledArgs() { + return enabledArgs; + } +} diff --git a/buildSrc/call-site-instrumentation-plugin/src/test/java/datadog/trace/plugin/csi/impl/ext/IastExtensionTest.java b/buildSrc/call-site-instrumentation-plugin/src/test/java/datadog/trace/plugin/csi/impl/ext/IastExtensionTest.java new file mode 100644 index 00000000000..ff43401c8f2 --- /dev/null +++ b/buildSrc/call-site-instrumentation-plugin/src/test/java/datadog/trace/plugin/csi/impl/ext/IastExtensionTest.java @@ -0,0 +1,311 @@ +package datadog.trace.plugin.csi.impl.ext; + +import static datadog.trace.plugin.csi.impl.CallSiteFactory.pointcutParser; +import static datadog.trace.plugin.csi.util.CallSiteUtils.classNameToType; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.github.javaparser.JavaParser; +import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration; +import com.github.javaparser.ast.stmt.IfStmt; +import datadog.trace.agent.tooling.csi.CallSites; +import datadog.trace.plugin.csi.AdviceGenerator; +import datadog.trace.plugin.csi.AdviceGenerator.CallSiteResult; +import datadog.trace.plugin.csi.PluginApplication.Configuration; +import datadog.trace.plugin.csi.impl.AdviceGeneratorImpl; +import datadog.trace.plugin.csi.impl.BaseCsiPluginTest; +import datadog.trace.plugin.csi.impl.CallSiteSpecification; +import datadog.trace.plugin.csi.impl.assertion.AdviceAssert; +import datadog.trace.plugin.csi.impl.assertion.AssertBuilder; +import datadog.trace.plugin.csi.impl.assertion.CallSiteAssert; +import datadog.trace.plugin.csi.impl.ext.tests.IastExtensionCallSite; +import java.io.File; +import java.lang.reflect.Method; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Collections; +import java.util.List; +import java.util.Set; +import java.util.function.Consumer; +import java.util.stream.Collectors; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.objectweb.asm.Type; + +class IastExtensionTest extends BaseCsiPluginTest { + + @TempDir private File buildDir; + private Path targetFolder; + private Path projectFolder; + private Path srcFolder; + + @BeforeEach + void setup() throws Exception { + targetFolder = buildDir.toPath().resolve("target"); + Files.createDirectories(targetFolder); + projectFolder = buildDir.toPath().resolve("project"); + Files.createDirectories(projectFolder); + srcFolder = projectFolder.resolve("src/main/java"); + Files.createDirectories(srcFolder); + } + + @ParameterizedTest + @CsvSource( + delimiter = '|', + nullValues = "null", + value = { + "datadog.trace.agent.tooling.csi.CallSites | false", + "datadog.trace.agent.tooling.iast.IastCallSites | true" + }) + void testThatExtensionOnlyAppliesToIastAdvices(String typeName, boolean expected) { + Type type = classNameToType(typeName); + Type[] types = new Type[] {type}; + CallSiteSpecification callSite = mock(CallSiteSpecification.class); + when(callSite.getSpi()).thenReturn(types); + IastExtension extension = new IastExtension(); + + boolean applies = extension.appliesTo(callSite); + + assertEquals(expected, applies); + } + + @Test + void testThatExtensionGeneratesACallSiteWithTelemetry() throws Exception { + Configuration config = mock(Configuration.class); + when(config.getTargetFolder()).thenReturn(targetFolder); + when(config.getSrcFolder()).thenReturn(getCallSiteSrcFolder()); + when(config.getClassPath()).thenReturn(Collections.emptyList()); + CallSiteSpecification spec = buildClassSpecification(IastExtensionCallSite.class); + AdviceGenerator generator = buildAdviceGenerator(buildDir); + CallSiteResult result = generator.generate(spec); + assertTrue(result.isSuccess()); + IastExtension extension = new IastExtension(); + + extension.apply(config, result); + + assertNoErrors(result); + IastExtensionCallSiteAssert asserter = assertCallSites(result.getFile()); + asserter.iastAdvices( + 0, + advice -> { + advice.pointcut( + "javax/servlet/http/HttpServletRequest", + "getHeader", + "(Ljava/lang/String;)Ljava/lang/String;"); + advice.instrumentedMetric( + "IastMetric.INSTRUMENTED_SOURCE", + metric -> { + metric.metricStatements( + "IastMetricCollector.add(IastMetric.INSTRUMENTED_SOURCE, (byte) 3, 1);"); + }); + advice.executedMetric( + "IastMetric.EXECUTED_SOURCE", + metric -> { + metric.metricStatements( + "handler.field(net.bytebuddy.jar.asm.Opcodes.GETSTATIC, \"datadog/trace/api/iast/telemetry/IastMetric\", \"EXECUTED_SOURCE\", \"Ldatadog/trace/api/iast/telemetry/IastMetric;\");", + "handler.instruction(net.bytebuddy.jar.asm.Opcodes.ICONST_3);", + "handler.instruction(net.bytebuddy.jar.asm.Opcodes.ICONST_1);", + "handler.method(net.bytebuddy.jar.asm.Opcodes.INVOKESTATIC, \"datadog/trace/api/iast/telemetry/IastMetricCollector\", \"add\", \"(Ldatadog/trace/api/iast/telemetry/IastMetric;BI)V\", false);"); + }); + }); + asserter.iastAdvices( + 1, + advice -> { + advice.pointcut( + "javax/servlet/http/HttpServletRequest", + "getInputStream", + "()Ljavax/servlet/ServletInputStream;"); + advice.instrumentedMetric( + "IastMetric.INSTRUMENTED_SOURCE", + metric -> { + metric.metricStatements( + "IastMetricCollector.add(IastMetric.INSTRUMENTED_SOURCE, (byte) 127, 1);"); + }); + advice.executedMetric( + "IastMetric.EXECUTED_SOURCE", + metric -> { + metric.metricStatements( + "handler.field(net.bytebuddy.jar.asm.Opcodes.GETSTATIC, \"datadog/trace/api/iast/telemetry/IastMetric\", \"EXECUTED_SOURCE\", \"Ldatadog/trace/api/iast/telemetry/IastMetric;\");", + "handler.instruction(net.bytebuddy.jar.asm.Opcodes.BIPUSH, 127);", + "handler.instruction(net.bytebuddy.jar.asm.Opcodes.ICONST_1);", + "handler.method(net.bytebuddy.jar.asm.Opcodes.INVOKESTATIC, \"datadog/trace/api/iast/telemetry/IastMetricCollector\", \"add\", \"(Ldatadog/trace/api/iast/telemetry/IastMetric;BI)V\", false);"); + }); + }); + asserter.iastAdvices( + 2, + advice -> { + advice.pointcut( + "javax/servlet/ServletRequest", "getReader", "()Ljava/io/BufferedReader;"); + advice.instrumentedMetric( + "IastMetric.INSTRUMENTED_PROPAGATION", + metric -> { + metric.metricStatements( + "IastMetricCollector.add(IastMetric.INSTRUMENTED_PROPAGATION, 1);"); + }); + advice.executedMetric( + "IastMetric.EXECUTED_PROPAGATION", + metric -> { + metric.metricStatements( + "handler.field(net.bytebuddy.jar.asm.Opcodes.GETSTATIC, \"datadog/trace/api/iast/telemetry/IastMetric\", \"EXECUTED_PROPAGATION\", \"Ldatadog/trace/api/iast/telemetry/IastMetric;\");", + "handler.instruction(net.bytebuddy.jar.asm.Opcodes.ICONST_1);", + "handler.method(net.bytebuddy.jar.asm.Opcodes.INVOKESTATIC, \"datadog/trace/api/iast/telemetry/IastMetricCollector\", \"add\", \"(Ldatadog/trace/api/iast/telemetry/IastMetric;I)V\", false);"); + }); + }); + } + + private static AdviceGenerator buildAdviceGenerator(File targetFolder) { + return new AdviceGeneratorImpl(targetFolder, pointcutParser()); + } + + private static Path getCallSiteSrcFolder() throws Exception { + File file = new File(Thread.currentThread().getContextClassLoader().getResource("").toURI()); + return file.toPath().resolve("../../../../src/test/java"); + } + + private static ClassOrInterfaceDeclaration parse(File path) throws Exception { + return new JavaParser() + .parse(path) + .getResult() + .get() + .getPrimaryType() + .get() + .asClassOrInterfaceDeclaration(); + } + + private static IastExtensionCallSiteAssert assertCallSites(File generated) { + try { + return new IastExtensionAssertBuilder(generated).build(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + static class IastExtensionCallSiteAssert extends CallSiteAssert { + + IastExtensionCallSiteAssert( + Set> interfaces, + Set> spi, + Set> helpers, + List advices, + Method enabled, + Set enabledArgs) { + super(interfaces, spi, helpers, advices, enabled, enabledArgs); + } + + public void iastAdvices(int index, Consumer assertions) { + IastExtensionAdviceAssert asserter = (IastExtensionAdviceAssert) advices.get(index); + assertions.accept(asserter); + } + } + + static class IastExtensionAdviceAssert extends AdviceAssert { + + protected IastExtensionMetricAsserter instrumented; + protected IastExtensionMetricAsserter executed; + + IastExtensionAdviceAssert( + String owner, + String method, + String descriptor, + IastExtensionMetricAsserter instrumented, + IastExtensionMetricAsserter executed, + List statements) { + super(null, owner, method, descriptor, statements); + this.instrumented = instrumented; + this.executed = executed; + } + + public void instrumentedMetric( + String metric, Consumer assertions) { + assertEquals(metric, instrumented.metric); + assertions.accept(instrumented); + } + + public void executedMetric(String metric, Consumer assertions) { + assertEquals(metric, executed.metric); + assertions.accept(executed); + } + } + + static class IastExtensionMetricAsserter { + protected String metric; + protected List statements; + + IastExtensionMetricAsserter(String metric, List statements) { + this.metric = metric; + this.statements = statements; + } + + public void metricStatements(String... values) { + assertArrayEquals(values, statements.toArray(new String[0])); + } + } + + static class IastExtensionAssertBuilder extends AssertBuilder { + + IastExtensionAssertBuilder(File file) { + super(file); + } + + @Override + public IastExtensionCallSiteAssert build() { + CallSiteAssert base = super.build(); + return new IastExtensionCallSiteAssert( + base.getInterfaces(), + base.getSpi(), + base.getHelpers(), + base.getAdvices(), + base.getEnabled(), + base.getEnabledArgs()); + } + + @Override + protected List getAdvices(ClassOrInterfaceDeclaration type) { + return getMethodCalls(type.getMethodsByName("accept").get(0)).stream() + .filter(methodCall -> methodCall.getNameAsString().equals("addAdvice")) + .map( + methodCall -> { + String owner = methodCall.getArgument(1).asStringLiteralExpr().asString(); + String method = methodCall.getArgument(2).asStringLiteralExpr().asString(); + String descriptor = methodCall.getArgument(3).asStringLiteralExpr().asString(); + List statements = + methodCall + .getArgument(4) + .asLambdaExpr() + .getBody() + .asBlockStmt() + .getStatements(); + IfStmt instrumentedStmt = statements.get(0).asIfStmt(); + IfStmt executedStmt = statements.get(1).asIfStmt(); + List nonIfStatements = + statements.stream() + .filter(stmt -> !stmt.isIfStmt()) + .map(Object::toString) + .collect(Collectors.toList()); + return new IastExtensionAdviceAssert( + owner, + method, + descriptor, + buildMetricAsserter(instrumentedStmt), + buildMetricAsserter(executedStmt), + nonIfStatements); + }) + .collect(Collectors.toList()); + } + + protected IastExtensionMetricAsserter buildMetricAsserter(IfStmt ifStmt) { + String metric = ifStmt.getCondition().asMethodCallExpr().getScope().get().toString(); + List statements = + ifStmt.getThenStmt().asBlockStmt().getStatements().stream() + .map(Object::toString) + .collect(Collectors.toList()); + return new IastExtensionMetricAsserter(metric, statements); + } + } +} diff --git a/tooling/move-groovy-to-java.sh b/tooling/move-groovy-to-java.sh new file mode 100755 index 00000000000..6e1b44da488 --- /dev/null +++ b/tooling/move-groovy-to-java.sh @@ -0,0 +1,140 @@ +#!/usr/bin/env bash +# move-groovy-to-java.sh +# Usage: ./move-groovy-to-java.sh +# +# Finds all directories matching */src/test/groovy under the start folder, +# ensures corresponding src/test/java exists, mirrors missing subdirs, +# moves files with `git mv` (preserving history) and commits the changes. + +set -o pipefail + +if [ $# -ne 1 ]; then + echo "Usage: $0 " + exit 2 +fi + +START_DIR="$1" + +if [ ! -d "$START_DIR" ]; then + echo "Error: start folder '$START_DIR' does not exist or is not a directory." + exit 3 +fi + +# Resolve absolute path for START_DIR +START_DIR="$(cd "$START_DIR" && pwd)" + +# Determine git repo root (must be inside a git repo) +REPO_ROOT="$(git -C "$START_DIR" rev-parse --show-toplevel 2>/dev/null || true)" +if [ -z "$REPO_ROOT" ]; then + echo "Error: '$START_DIR' is not inside a git repository (or git not available)." + exit 4 +fi + +echo "Repository root: $REPO_ROOT" +echo "Scanning under: $START_DIR" +echo + +# Find all src/test/groovy directories (bash 3.2 compatible for macOS) +GROOVY_DIRS=() +while IFS= read -r -d '' dir; do + GROOVY_DIRS+=("$dir") +done < <(find "$START_DIR" -type d -path '*/src/test/groovy' -print0) + +if [ ${#GROOVY_DIRS[@]} -eq 0 ]; then + echo "No 'src/test/groovy' directories found under $START_DIR. Nothing to do." + exit 0 +fi + +echo "Found ${#GROOVY_DIRS[@]} groovy module(s)." +echo + +# Track whether we made staged changes +CHANGES_MADE=0 + +for GROOVY_DIR in "${GROOVY_DIRS[@]}"; do + echo "Processing: $GROOVY_DIR" + + # base is the parent '.../src/test' + BASE_DIR="$(dirname "$GROOVY_DIR")" # .../src/test + JAVA_DIR="$BASE_DIR/java" + + # Ensure src/test/java exists + if [ ! -d "$JAVA_DIR" ]; then + echo " Creating java dir: $JAVA_DIR" + mkdir -p "$JAVA_DIR" || { echo " Failed to create $JAVA_DIR"; continue; } + # stage new directories (optional, git will stage movements below) + git -C "$REPO_ROOT" add -- "$JAVA_DIR" >/dev/null 2>&1 || true + else + echo " java dir exists: $JAVA_DIR" + fi + + # Mirror missing subdirectories from groovy -> java + echo " Mirroring directory structure..." + # find all directories under groovy dir + while IFS= read -r -d '' subdir; do + # relative path inside groovy tree (empty for the top-level groovy dir) + rel="${subdir#$GROOVY_DIR}" + # remove leading slash if any + rel="${rel#/}" + target_dir="$JAVA_DIR/$rel" + if [ ! -d "$target_dir" ]; then + echo " mkdir -p $target_dir" + mkdir -p "$target_dir" || { echo " Failed to create $target_dir"; continue; } + git -C "$REPO_ROOT" add -- "$target_dir" >/dev/null 2>&1 || true + fi + done < <(find "$GROOVY_DIR" -type d -print0) + + # Move .groovy files recursively using git mv + echo " Moving .groovy files..." + while IFS= read -r -d '' groovy_file; do + # relative path inside groovy tree + rel_file="${groovy_file#$GROOVY_DIR/}" + dest_file="$JAVA_DIR/${rel_file%.groovy}.java" + + # Ensure destination dir exists (should from the mirroring step, but double-check) + dest_dir="$(dirname "$dest_file")" + if [ ! -d "$dest_dir" ]; then + echo " (creating dest dir) mkdir -p '$dest_dir'" + mkdir -p "$dest_dir" || { echo " Failed to create $dest_dir"; continue; } + fi + + if [ -e "$dest_file" ]; then + echo " SKIP: destination already exists: $dest_file" + continue + fi + + # Perform git mv (stages the rename). Use -f to overwrite if git allows (shouldn't be needed). + echo " git mv '$groovy_file' '$dest_file'" + if git -C "$REPO_ROOT" mv -- "$groovy_file" "$dest_file"; then + CHANGES_MADE=1 + else + echo " ERROR: git mv failed for: $groovy_file -> $dest_file" + # attempt fallback: plain mv then git add/rm (less ideal but tries to continue) + if mv -- "$groovy_file" "$dest_file"; then + git -C "$REPO_ROOT" add -- "$dest_file" >/dev/null 2>&1 || true + git -C "$REPO_ROOT" rm --cached --ignore-unmatch -- "$groovy_file" >/dev/null 2>&1 || true + CHANGES_MADE=1 + else + echo " FATAL: fallback mv failed for $groovy_file" + fi + fi + done < <(find "$GROOVY_DIR" -type f -name '*.groovy' -print0) + + echo +done + +# If there are staged changes, commit them +if [ "$CHANGES_MADE" -eq 1 ]; then + # check that there is something staged + if git -C "$REPO_ROOT" diff --cached --quiet; then + echo "No staged changes to commit." + else + echo "Committing changes..." + git -C "$REPO_ROOT" commit --no-verify -m "Moving groovy to java to keep history" + echo "Commit created." + fi +else + echo "No files moved; nothing to commit." +fi + +echo "Done."