diff --git a/.github/workflows/e2e-java-tracer.yaml b/.github/workflows/e2e-java-tracer.yaml new file mode 100644 index 000000000..7e92e9eee --- /dev/null +++ b/.github/workflows/e2e-java-tracer.yaml @@ -0,0 +1,98 @@ +name: E2E - Java Tracer + +on: + pull_request: + paths: + - 'codeflash/languages/java/**' + - 'codeflash/languages/base.py' + - 'codeflash/languages/registry.py' + - 'codeflash/tracer.py' + - 'codeflash/benchmarking/function_ranker.py' + - 'codeflash/discovery/functions_to_optimize.py' + - 'codeflash/optimization/**' + - 'codeflash/verification/**' + - 'codeflash-java-runtime/**' + - 'tests/test_languages/fixtures/java_tracer_e2e/**' + - 'tests/scripts/end_to_end_test_java_tracer.py' + - '.github/workflows/e2e-java-tracer.yaml' + + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref_name }} + cancel-in-progress: true + +jobs: + java-tracer-e2e: + environment: ${{ (github.event_name == 'workflow_dispatch' || (contains(toJSON(github.event.pull_request.files.*.filename), '.github/workflows/') && github.event.pull_request.user.login != 'misrasaurabh1' && github.event.pull_request.user.login != 'KRRT7')) && 'external-trusted-contributors' || '' }} + + runs-on: ubuntu-latest + env: + CODEFLASH_AIS_SERVER: prod + POSTHOG_API_KEY: ${{ secrets.POSTHOG_API_KEY }} + CODEFLASH_API_KEY: ${{ secrets.CODEFLASH_API_KEY }} + COLUMNS: 110 + MAX_RETRIES: 3 + RETRY_DELAY: 5 + EXPECTED_IMPROVEMENT_PCT: 10 + CODEFLASH_END_TO_END: 1 + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + ref: ${{ github.event.pull_request.head.ref }} + repository: ${{ github.event.pull_request.head.repo.full_name }} + fetch-depth: 0 + token: ${{ secrets.GITHUB_TOKEN }} + + - name: Validate PR + env: + PR_AUTHOR: ${{ github.event.pull_request.user.login }} + PR_STATE: ${{ github.event.pull_request.state }} + BASE_SHA: ${{ github.event.pull_request.base.sha }} + HEAD_SHA: ${{ github.event.pull_request.head.sha }} + run: | + if git diff --name-only "$BASE_SHA" "$HEAD_SHA" | grep -q "^.github/workflows/"; then + echo "⚠️ Workflow changes detected." + echo "PR Author: $PR_AUTHOR" + if [[ "$PR_AUTHOR" == "misrasaurabh1" || "$PR_AUTHOR" == "KRRT7" ]]; then + echo "✅ Authorized user ($PR_AUTHOR). Proceeding." + elif [[ "$PR_STATE" == "open" ]]; then + echo "✅ PR is open. Proceeding." + else + echo "⛔ Unauthorized user ($PR_AUTHOR) attempting to modify workflows. Exiting." + exit 1 + fi + else + echo "✅ No workflow file changes detected. Proceeding." + fi + + - name: Set up JDK 11 + uses: actions/setup-java@v4 + with: + java-version: '11' + distribution: 'temurin' + cache: maven + + - name: Set up Python 3.11 for CLI + uses: astral-sh/setup-uv@v6 + with: + python-version: 3.11.6 + + - name: Install dependencies (CLI) + run: uv sync + + - name: Build codeflash-runtime JAR + run: | + cd codeflash-java-runtime + mvn clean package -q -DskipTests + mvn install -q -DskipTests + + - name: Verify Java installation + run: | + java -version + mvn --version + + - name: Run Java tracer e2e test + run: | + uv run python tests/scripts/end_to_end_test_java_tracer.py diff --git a/codeflash-java-runtime/src/main/java/com/codeflash/AgentDispatcher.java b/codeflash-java-runtime/src/main/java/com/codeflash/AgentDispatcher.java index 4eb1eef84..6c06e9ad1 100644 --- a/codeflash-java-runtime/src/main/java/com/codeflash/AgentDispatcher.java +++ b/codeflash-java-runtime/src/main/java/com/codeflash/AgentDispatcher.java @@ -19,13 +19,20 @@ */ public class AgentDispatcher { + static boolean isTracerMode(String agentArgs) { + return agentArgs != null + && (agentArgs.startsWith("trace=") || agentArgs.contains(",trace=")); + } + static boolean isProfilerMode(String agentArgs) { return agentArgs != null && (agentArgs.startsWith("config=") || agentArgs.contains(",config=")); } public static void premain(String agentArgs, Instrumentation inst) throws Exception { - if (isProfilerMode(agentArgs)) { + if (isTracerMode(agentArgs)) { + com.codeflash.tracer.TracerAgent.premain(agentArgs, inst); + } else if (isProfilerMode(agentArgs)) { com.codeflash.profiler.ProfilerAgent.premain(agentArgs, inst); } else { org.jacoco.agent.rt.internal_0e20598.PreMain.premain(agentArgs, inst); diff --git a/codeflash-java-runtime/src/main/java/com/codeflash/ReplayHelper.java b/codeflash-java-runtime/src/main/java/com/codeflash/ReplayHelper.java new file mode 100644 index 000000000..f4b9ec453 --- /dev/null +++ b/codeflash-java-runtime/src/main/java/com/codeflash/ReplayHelper.java @@ -0,0 +1,116 @@ +package com.codeflash; + +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; + +import org.objectweb.asm.Type; + +public class ReplayHelper { + + private final Connection db; + + public ReplayHelper(String traceDbPath) { + try { + this.db = DriverManager.getConnection("jdbc:sqlite:" + traceDbPath); + } catch (SQLException e) { + throw new RuntimeException("Failed to open trace database: " + traceDbPath, e); + } + } + + public void replay(String className, String methodName, String descriptor, int invocationIndex) throws Exception { + // Query the function_calls table for this method at the given index + byte[] argsBlob; + try (PreparedStatement stmt = db.prepareStatement( + "SELECT args FROM function_calls " + + "WHERE classname = ? AND function = ? AND descriptor = ? " + + "ORDER BY time_ns LIMIT 1 OFFSET ?")) { + stmt.setString(1, className); + stmt.setString(2, methodName); + stmt.setString(3, descriptor); + stmt.setInt(4, invocationIndex); + + try (ResultSet rs = stmt.executeQuery()) { + if (!rs.next()) { + throw new RuntimeException("No invocation found at index " + invocationIndex + + " for " + className + "." + methodName + descriptor); + } + argsBlob = rs.getBytes("args"); + } + } + + // Deserialize args + Object deserialized = Serializer.deserialize(argsBlob); + if (!(deserialized instanceof Object[])) { + throw new RuntimeException("Deserialized args is not Object[], got: " + + (deserialized == null ? "null" : deserialized.getClass().getName())); + } + Object[] allArgs = (Object[]) deserialized; + + // Load the target class + Class targetClass = Class.forName(className); + + // Parse descriptor to find parameter types + Type[] paramTypes = Type.getArgumentTypes(descriptor); + Class[] paramClasses = new Class[paramTypes.length]; + for (int i = 0; i < paramTypes.length; i++) { + paramClasses[i] = typeToClass(paramTypes[i]); + } + + // Find the method + Method method = targetClass.getDeclaredMethod(methodName, paramClasses); + method.setAccessible(true); + + boolean isStatic = Modifier.isStatic(method.getModifiers()); + + if (isStatic) { + method.invoke(null, allArgs); + } else { + // Args contain only explicit parameters (no 'this'). + // Create a default instance via no-arg constructor or Kryo. + Object instance; + try { + java.lang.reflect.Constructor ctor = targetClass.getDeclaredConstructor(); + ctor.setAccessible(true); + instance = ctor.newInstance(); + } catch (NoSuchMethodException e) { + // Fall back to Objenesis instantiation (no constructor needed) + instance = new org.objenesis.ObjenesisStd().newInstance(targetClass); + } + method.invoke(instance, allArgs); + } + } + + private static Class typeToClass(Type type) throws ClassNotFoundException { + switch (type.getSort()) { + case Type.BOOLEAN: return boolean.class; + case Type.BYTE: return byte.class; + case Type.CHAR: return char.class; + case Type.SHORT: return short.class; + case Type.INT: return int.class; + case Type.LONG: return long.class; + case Type.FLOAT: return float.class; + case Type.DOUBLE: return double.class; + case Type.VOID: return void.class; + case Type.ARRAY: + Class elementClass = typeToClass(type.getElementType()); + return java.lang.reflect.Array.newInstance(elementClass, 0).getClass(); + case Type.OBJECT: + return Class.forName(type.getClassName()); + default: + throw new ClassNotFoundException("Unknown type: " + type); + } + } + + public void close() { + try { + if (db != null) db.close(); + } catch (SQLException e) { + System.err.println("Error closing ReplayHelper: " + e.getMessage()); + } + } +} diff --git a/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TraceRecorder.java b/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TraceRecorder.java new file mode 100644 index 000000000..2a22b74f4 --- /dev/null +++ b/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TraceRecorder.java @@ -0,0 +1,123 @@ +package com.codeflash.tracer; + +import com.codeflash.Serializer; + +import java.time.Instant; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicInteger; + +public final class TraceRecorder { + + private static volatile TraceRecorder instance; + + private static final long SERIALIZATION_TIMEOUT_MS = 500; + + private final TracerConfig config; + private final TraceWriter writer; + private final ConcurrentHashMap functionCounts = new ConcurrentHashMap<>(); + private final int maxFunctionCount; + private final ExecutorService serializerExecutor; + + // Reentrancy guard: prevent recursive tracing when serialization triggers class loading + private static final ThreadLocal RECORDING = ThreadLocal.withInitial(() -> Boolean.FALSE); + + private TraceRecorder(TracerConfig config) { + this.config = config; + this.writer = new TraceWriter(config.getDbPath()); + this.maxFunctionCount = config.getMaxFunctionCount(); + this.serializerExecutor = Executors.newCachedThreadPool(r -> { + Thread t = new Thread(r, "codeflash-serializer"); + t.setDaemon(true); + return t; + }); + } + + public static void initialize(TracerConfig config) { + instance = new TraceRecorder(config); + } + + public static TraceRecorder getInstance() { + return instance; + } + + public static boolean isRecording() { + return Boolean.TRUE.equals(RECORDING.get()); + } + + public void onEntry(String className, String methodName, String descriptor, + int lineNumber, String sourceFile, Object[] args) { + // Reentrancy guard + if (RECORDING.get()) { + return; + } + RECORDING.set(Boolean.TRUE); + try { + onEntryImpl(className, methodName, descriptor, lineNumber, sourceFile, args); + } finally { + RECORDING.set(Boolean.FALSE); + } + } + + private void onEntryImpl(String className, String methodName, String descriptor, + int lineNumber, String sourceFile, Object[] args) { + String qualifiedName = className + "." + methodName + descriptor; + + // Check per-method count limit + AtomicInteger count = functionCounts.computeIfAbsent(qualifiedName, k -> new AtomicInteger(0)); + if (count.get() >= maxFunctionCount) { + return; + } + + // Serialize args with timeout to prevent deep object graph traversal from blocking + byte[] argsBlob; + Future future = serializerExecutor.submit(() -> Serializer.serialize(args)); + try { + argsBlob = future.get(SERIALIZATION_TIMEOUT_MS, TimeUnit.MILLISECONDS); + } catch (TimeoutException e) { + future.cancel(true); + System.err.println("[codeflash-tracer] Serialization timed out for " + className + "." + + methodName); + return; + } catch (Exception e) { + Throwable cause = e.getCause() != null ? e.getCause() : e; + System.err.println("[codeflash-tracer] Serialization failed for " + className + "." + + methodName + ": " + cause.getClass().getSimpleName() + ": " + cause.getMessage()); + return; + } + + long timeNs = System.nanoTime(); + count.incrementAndGet(); + + writer.recordFunctionCall("call", methodName, className, sourceFile, + lineNumber, descriptor, timeNs, argsBlob); + } + + public void flush() { + serializerExecutor.shutdownNow(); + // Write metadata + Map metadata = new LinkedHashMap<>(); + metadata.put("projectRoot", config.getProjectRoot()); + metadata.put("timestamp", Instant.now().toString()); + metadata.put("totalFunctions", String.valueOf(functionCounts.size())); + + int totalCaptures = 0; + for (AtomicInteger count : functionCounts.values()) { + totalCaptures += count.get(); + } + metadata.put("totalCaptures", String.valueOf(totalCaptures)); + + writer.writeMetadata(metadata); + writer.flush(); + writer.close(); + + System.err.println("[codeflash-tracer] Captured " + totalCaptures + + " invocations across " + functionCounts.size() + " methods"); + } +} diff --git a/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TraceWriter.java b/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TraceWriter.java new file mode 100644 index 000000000..a9eeabf60 --- /dev/null +++ b/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TraceWriter.java @@ -0,0 +1,210 @@ +package com.codeflash.tracer; + +import java.nio.file.Path; +import java.nio.file.Paths; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.PreparedStatement; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.Map; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +public final class TraceWriter { + + private final Connection connection; + private final BlockingQueue writeQueue; + private final Thread writerThread; + private final AtomicBoolean running; + + private PreparedStatement insertFunctionCall; + private PreparedStatement insertMetadata; + + public TraceWriter(String dbPath) { + this.writeQueue = new LinkedBlockingQueue<>(); + this.running = new AtomicBoolean(true); + + try { + Path path = Paths.get(dbPath).toAbsolutePath(); + path.getParent().toFile().mkdirs(); + this.connection = DriverManager.getConnection("jdbc:sqlite:" + path); + initializeSchema(); + prepareStatements(); + + this.writerThread = new Thread(this::writerLoop, "codeflash-trace-writer"); + this.writerThread.setDaemon(true); + this.writerThread.start(); + + } catch (SQLException e) { + throw new RuntimeException("Failed to initialize TraceWriter: " + e.getMessage(), e); + } + } + + private void initializeSchema() throws SQLException { + try (Statement stmt = connection.createStatement()) { + stmt.execute("PRAGMA journal_mode=WAL"); + stmt.execute("PRAGMA synchronous=NORMAL"); + + stmt.execute( + "CREATE TABLE IF NOT EXISTS function_calls(" + + "id INTEGER PRIMARY KEY AUTOINCREMENT, " + + "type TEXT, " + + "function TEXT, " + + "classname TEXT, " + + "filename TEXT, " + + "line_number INTEGER, " + + "descriptor TEXT, " + + "time_ns INTEGER, " + + "args BLOB)" + ); + + stmt.execute( + "CREATE TABLE IF NOT EXISTS metadata(" + + "key TEXT PRIMARY KEY, " + + "value TEXT)" + ); + + stmt.execute("CREATE INDEX IF NOT EXISTS idx_fc_class_func ON function_calls(classname, function)"); + } + } + + private void prepareStatements() throws SQLException { + insertFunctionCall = connection.prepareStatement( + "INSERT INTO function_calls (type, function, classname, filename, line_number, descriptor, time_ns, args) " + + "VALUES (?, ?, ?, ?, ?, ?, ?, ?)" + ); + insertMetadata = connection.prepareStatement( + "INSERT OR REPLACE INTO metadata (key, value) VALUES (?, ?)" + ); + } + + public void recordFunctionCall(String type, String function, String classname, + String filename, int lineNumber, String descriptor, + long timeNs, byte[] argsBlob) { + writeQueue.offer(new FunctionCallTask(type, function, classname, filename, + lineNumber, descriptor, timeNs, argsBlob)); + } + + public void writeMetadata(Map metadata) { + for (Map.Entry entry : metadata.entrySet()) { + writeQueue.offer(new MetadataTask(entry.getKey(), entry.getValue())); + } + } + + private void writerLoop() { + while (running.get() || !writeQueue.isEmpty()) { + try { + WriteTask task = writeQueue.poll(100, TimeUnit.MILLISECONDS); + if (task != null) { + task.execute(this); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + break; + } catch (SQLException e) { + System.err.println("[codeflash-tracer] Write error: " + e.getMessage()); + } + } + + // Drain remaining + WriteTask task; + while ((task = writeQueue.poll()) != null) { + try { + task.execute(this); + } catch (SQLException e) { + System.err.println("[codeflash-tracer] Write error: " + e.getMessage()); + } + } + } + + public void flush() { + while (!writeQueue.isEmpty()) { + try { + Thread.sleep(10); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + break; + } + } + } + + public void close() { + running.set(false); + try { + writerThread.join(5000); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + + try { + if (insertFunctionCall != null) insertFunctionCall.close(); + if (insertMetadata != null) insertMetadata.close(); + if (connection != null) connection.close(); + } catch (SQLException e) { + System.err.println("[codeflash-tracer] Error closing TraceWriter: " + e.getMessage()); + } + } + + // Task types + + private interface WriteTask { + void execute(TraceWriter writer) throws SQLException; + } + + private static class FunctionCallTask implements WriteTask { + final String type; + final String function; + final String classname; + final String filename; + final int lineNumber; + final String descriptor; + final long timeNs; + final byte[] argsBlob; + + FunctionCallTask(String type, String function, String classname, + String filename, int lineNumber, String descriptor, + long timeNs, byte[] argsBlob) { + this.type = type; + this.function = function; + this.classname = classname; + this.filename = filename; + this.lineNumber = lineNumber; + this.descriptor = descriptor; + this.timeNs = timeNs; + this.argsBlob = argsBlob; + } + + @Override + public void execute(TraceWriter writer) throws SQLException { + writer.insertFunctionCall.setString(1, type); + writer.insertFunctionCall.setString(2, function); + writer.insertFunctionCall.setString(3, classname); + writer.insertFunctionCall.setString(4, filename); + writer.insertFunctionCall.setInt(5, lineNumber); + writer.insertFunctionCall.setString(6, descriptor); + writer.insertFunctionCall.setLong(7, timeNs); + writer.insertFunctionCall.setBytes(8, argsBlob); + writer.insertFunctionCall.executeUpdate(); + } + } + + private static class MetadataTask implements WriteTask { + final String key; + final String value; + + MetadataTask(String key, String value) { + this.key = key; + this.value = value; + } + + @Override + public void execute(TraceWriter writer) throws SQLException { + writer.insertMetadata.setString(1, key); + writer.insertMetadata.setString(2, value); + writer.insertMetadata.executeUpdate(); + } + } +} diff --git a/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracerAgent.java b/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracerAgent.java new file mode 100644 index 000000000..4aa0458fa --- /dev/null +++ b/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracerAgent.java @@ -0,0 +1,26 @@ +package com.codeflash.tracer; + +import java.lang.instrument.Instrumentation; + +public class TracerAgent { + + public static void premain(String agentArgs, Instrumentation inst) { + TracerConfig config = TracerConfig.parse(agentArgs); + + if (config.getPackages().isEmpty()) { + System.err.println("[codeflash-tracer] Warning: no packages configured, will instrument all non-JDK classes"); + } + + // Register transformer BEFORE initializing TraceRecorder, to ensure + // classes loaded during initialization (SQLite, Kryo) are visible. + inst.addTransformer(new TracingTransformer(config), true); + + TraceRecorder.initialize(config); + + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + TraceRecorder.getInstance().flush(); + }, "codeflash-tracer-shutdown")); + + System.err.println("[codeflash-tracer] Agent loaded, tracing packages: " + config.getPackages()); + } +} diff --git a/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracerConfig.java b/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracerConfig.java new file mode 100644 index 000000000..8fe799d2f --- /dev/null +++ b/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracerConfig.java @@ -0,0 +1,113 @@ +package com.codeflash.tracer; + +import com.google.gson.Gson; +import com.google.gson.annotations.SerializedName; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.Collections; +import java.util.List; + +public final class TracerConfig { + + @SerializedName("dbPath") + private String dbPath = "codeflash_trace.db"; + + @SerializedName("packages") + private List packages = Collections.emptyList(); + + @SerializedName("excludePackages") + private List excludePackages = Collections.emptyList(); + + @SerializedName("maxFunctionCount") + private int maxFunctionCount = 256; + + @SerializedName("timeout") + private int timeout = 0; + + @SerializedName("projectRoot") + private String projectRoot = ""; + + private static final Gson GSON = new Gson(); + + public static TracerConfig parse(String agentArgs) { + if (agentArgs == null || agentArgs.isEmpty()) { + return new TracerConfig(); + } + + String configPath = null; + for (String part : agentArgs.split(",")) { + String trimmed = part.trim(); + if (trimmed.startsWith("trace=")) { + configPath = trimmed.substring("trace=".length()); + } + } + + if (configPath == null) { + System.err.println("[codeflash-tracer] No trace= in agent args: " + agentArgs); + return new TracerConfig(); + } + + try { + String json = new String(Files.readAllBytes(Paths.get(configPath)), StandardCharsets.UTF_8); + TracerConfig config = GSON.fromJson(json, TracerConfig.class); + if (config == null) { + return new TracerConfig(); + } + if (config.packages == null) config.packages = Collections.emptyList(); + if (config.excludePackages == null) config.excludePackages = Collections.emptyList(); + return config; + } catch (IOException e) { + System.err.println("[codeflash-tracer] Failed to read config: " + e.getMessage()); + return new TracerConfig(); + } + } + + public String getDbPath() { + return dbPath; + } + + public List getPackages() { + return packages; + } + + public List getExcludePackages() { + return excludePackages; + } + + public int getMaxFunctionCount() { + return maxFunctionCount; + } + + public int getTimeout() { + return timeout; + } + + public String getProjectRoot() { + return projectRoot; + } + + public boolean shouldInstrumentClass(String internalClassName) { + String dotName = internalClassName.replace('/', '.'); + + for (String excluded : excludePackages) { + if (dotName.startsWith(excluded)) { + return false; + } + } + + if (packages.isEmpty()) { + return true; + } + + for (String pkg : packages) { + if (dotName.startsWith(pkg)) { + return true; + } + } + + return false; + } +} diff --git a/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracingClassVisitor.java b/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracingClassVisitor.java new file mode 100644 index 000000000..c760ea636 --- /dev/null +++ b/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracingClassVisitor.java @@ -0,0 +1,43 @@ +package com.codeflash.tracer; + +import org.objectweb.asm.ClassVisitor; +import org.objectweb.asm.MethodVisitor; +import org.objectweb.asm.Opcodes; + +public class TracingClassVisitor extends ClassVisitor { + + private final String internalClassName; + private String sourceFile; + + public TracingClassVisitor(ClassVisitor classVisitor, String internalClassName) { + super(Opcodes.ASM9, classVisitor); + this.internalClassName = internalClassName; + } + + @Override + public void visitSource(String source, String debug) { + super.visitSource(source, debug); + this.sourceFile = source; + } + + @Override + public MethodVisitor visitMethod(int access, String name, String descriptor, + String signature, String[] exceptions) { + MethodVisitor mv = super.visitMethod(access, name, descriptor, signature, exceptions); + + // Skip static initializers, synthetic, and bridge methods + if (name.equals("") + || (access & Opcodes.ACC_SYNTHETIC) != 0 + || (access & Opcodes.ACC_BRIDGE) != 0) { + return mv; + } + + // Skip constructors for now (they have complex init semantics) + if (name.equals("")) { + return mv; + } + + return new TracingMethodAdapter(mv, access, name, descriptor, + internalClassName, 0, sourceFile != null ? sourceFile : ""); + } +} diff --git a/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracingMethodAdapter.java b/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracingMethodAdapter.java new file mode 100644 index 000000000..de71d4984 --- /dev/null +++ b/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracingMethodAdapter.java @@ -0,0 +1,132 @@ +package com.codeflash.tracer; + +import org.objectweb.asm.Label; +import org.objectweb.asm.MethodVisitor; +import org.objectweb.asm.Opcodes; +import org.objectweb.asm.Type; +import org.objectweb.asm.commons.AdviceAdapter; + +/** + * ASM AdviceAdapter that captures method arguments on entry. + * + *

On method entry, boxes all parameters into an Object[] array and calls + * {@link TraceRecorder#onEntry} to record the invocation. For instance methods, + * {@code this} is included as the first element. + */ +public class TracingMethodAdapter extends AdviceAdapter { + + private static final String TRACE_RECORDER = "com/codeflash/tracer/TraceRecorder"; + + private final String className; + private final String methodName; + private final String descriptor; + private final int lineNumber; + private final String sourceFile; + private final boolean isStatic; + + protected TracingMethodAdapter(MethodVisitor mv, int access, String name, String descriptor, + String className, int lineNumber, String sourceFile) { + super(Opcodes.ASM9, mv, access, name, descriptor); + this.className = className; + this.methodName = name; + this.descriptor = descriptor; + this.lineNumber = lineNumber; + this.sourceFile = sourceFile; + this.isStatic = (access & Opcodes.ACC_STATIC) != 0; + } + + @Override + protected void onMethodEnter() { + // Build Object[] containing explicit parameters only (skip 'this' to avoid + // expensive serialization of the receiver's full object graph) + Type[] argTypes = Type.getArgumentTypes(descriptor); + + // Push array size and create Object[] + pushInt(argTypes.length); + mv.visitTypeInsn(ANEWARRAY, "java/lang/Object"); + + int arrayIndex = 0; + int localIndex = isStatic ? 0 : 1; // skip 'this' slot for instance methods + + // Box and store each parameter + for (Type argType : argTypes) { + mv.visitInsn(DUP); + pushInt(arrayIndex); + loadAndBox(argType, localIndex); + mv.visitInsn(AASTORE); + arrayIndex++; + localIndex += argType.getSize(); + } + + // Stack now has: Object[] args on top + // Store in a local variable + int argsLocal = newLocal(Type.getType("[Ljava/lang/Object;")); + mv.visitVarInsn(ASTORE, argsLocal); + + // Call TraceRecorder.getInstance().onEntry(className, methodName, descriptor, lineNumber, sourceFile, args) + mv.visitMethodInsn(INVOKESTATIC, TRACE_RECORDER, "getInstance", + "()L" + TRACE_RECORDER + ";", false); + mv.visitLdcInsn(className.replace('/', '.')); + mv.visitLdcInsn(methodName); + mv.visitLdcInsn(descriptor); + pushInt(lineNumber); + mv.visitLdcInsn(sourceFile != null ? sourceFile : ""); + mv.visitVarInsn(ALOAD, argsLocal); + mv.visitMethodInsn(INVOKEVIRTUAL, TRACE_RECORDER, "onEntry", + "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;ILjava/lang/String;[Ljava/lang/Object;)V", + false); + } + + private void loadAndBox(Type type, int localIndex) { + switch (type.getSort()) { + case Type.BOOLEAN: + mv.visitVarInsn(ILOAD, localIndex); + mv.visitMethodInsn(INVOKESTATIC, "java/lang/Boolean", "valueOf", "(Z)Ljava/lang/Boolean;", false); + break; + case Type.BYTE: + mv.visitVarInsn(ILOAD, localIndex); + mv.visitMethodInsn(INVOKESTATIC, "java/lang/Byte", "valueOf", "(B)Ljava/lang/Byte;", false); + break; + case Type.CHAR: + mv.visitVarInsn(ILOAD, localIndex); + mv.visitMethodInsn(INVOKESTATIC, "java/lang/Character", "valueOf", "(C)Ljava/lang/Character;", false); + break; + case Type.SHORT: + mv.visitVarInsn(ILOAD, localIndex); + mv.visitMethodInsn(INVOKESTATIC, "java/lang/Short", "valueOf", "(S)Ljava/lang/Short;", false); + break; + case Type.INT: + mv.visitVarInsn(ILOAD, localIndex); + mv.visitMethodInsn(INVOKESTATIC, "java/lang/Integer", "valueOf", "(I)Ljava/lang/Integer;", false); + break; + case Type.LONG: + mv.visitVarInsn(LLOAD, localIndex); + mv.visitMethodInsn(INVOKESTATIC, "java/lang/Long", "valueOf", "(J)Ljava/lang/Long;", false); + break; + case Type.FLOAT: + mv.visitVarInsn(FLOAD, localIndex); + mv.visitMethodInsn(INVOKESTATIC, "java/lang/Float", "valueOf", "(F)Ljava/lang/Float;", false); + break; + case Type.DOUBLE: + mv.visitVarInsn(DLOAD, localIndex); + mv.visitMethodInsn(INVOKESTATIC, "java/lang/Double", "valueOf", "(D)Ljava/lang/Double;", false); + break; + default: + // Object or array — just load reference + mv.visitVarInsn(ALOAD, localIndex); + break; + } + } + + private void pushInt(int value) { + if (value >= -1 && value <= 5) { + mv.visitInsn(ICONST_0 + value); + } else if (value >= Byte.MIN_VALUE && value <= Byte.MAX_VALUE) { + mv.visitIntInsn(BIPUSH, value); + } else if (value >= Short.MIN_VALUE && value <= Short.MAX_VALUE) { + mv.visitIntInsn(SIPUSH, value); + } else { + mv.visitLdcInsn(value); + } + } +} diff --git a/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracingTransformer.java b/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracingTransformer.java new file mode 100644 index 000000000..974c767a9 --- /dev/null +++ b/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracingTransformer.java @@ -0,0 +1,65 @@ +package com.codeflash.tracer; + +import org.objectweb.asm.ClassReader; +import org.objectweb.asm.ClassWriter; + +import java.lang.instrument.ClassFileTransformer; +import java.security.ProtectionDomain; + +public class TracingTransformer implements ClassFileTransformer { + + private final TracerConfig config; + + public TracingTransformer(TracerConfig config) { + this.config = config; + } + + @Override + public byte[] transform(ClassLoader loader, String className, + Class classBeingRedefined, ProtectionDomain protectionDomain, + byte[] classfileBuffer) { + if (className == null || !config.shouldInstrumentClass(className)) { + return null; + } + + // Skip instrumentation if we're inside a recording call (e.g., during Kryo serialization) + if (TraceRecorder.isRecording()) { + return null; + } + + // Skip internal JDK, framework, and synthetic classes + if (className.startsWith("java/") + || className.startsWith("javax/") + || className.startsWith("jdk/") + || className.startsWith("sun/") + || className.startsWith("com/sun/") + || className.startsWith("com/codeflash/") + || className.contains("ConstructorAccess") + || className.contains("FieldAccess") + || className.contains("$$")) { + return null; + } + + try { + return instrumentClass(className, classfileBuffer); + } catch (Throwable e) { + System.err.println("[codeflash-tracer] Failed to instrument " + className + ": " + + e.getClass().getName() + ": " + e.getMessage()); + return null; + } + } + + private byte[] instrumentClass(String internalClassName, byte[] bytecode) { + ClassReader cr = new ClassReader(bytecode); + // Use COMPUTE_MAXS only (not COMPUTE_FRAMES) to preserve original stack map frames. + // COMPUTE_FRAMES recomputes all frames and calls getCommonSuperClass() which either + // triggers classloader deadlocks or produces incorrect frames when returning "java/lang/Object". + // With COMPUTE_MAXS + ClassReader passed to constructor, ASM copies original frames and + // adjusts offsets for injected code. Our AdviceAdapter only injects at method entry + // (before any branch points), so existing frames remain valid. + ClassWriter cw = new ClassWriter(cr, ClassWriter.COMPUTE_MAXS); + TracingClassVisitor cv = new TracingClassVisitor(cw, internalClassName); + cr.accept(cv, ClassReader.EXPAND_FRAMES); + return cw.toByteArray(); + } +} diff --git a/codeflash/benchmarking/function_ranker.py b/codeflash/benchmarking/function_ranker.py index 20c45f443..da565c6d7 100644 --- a/codeflash/benchmarking/function_ranker.py +++ b/codeflash/benchmarking/function_ranker.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from codeflash.cli_cmds.console import logger from codeflash.code_utils.config_consts import DEFAULT_IMPORTANCE_THRESHOLD @@ -11,6 +11,7 @@ from pathlib import Path from codeflash.discovery.functions_to_optimize import FunctionToOptimize + from codeflash.languages.java.jfr_parser import JfrProfile pytest_patterns = { " bool: """Check if a function is part of pytest infrastructure that should be excluded from ranking. @@ -38,6 +52,97 @@ def is_pytest_infrastructure(filename: str, function_name: str) -> bool: return any(pattern in function_name.lower() for pattern in pytest_func_patterns) +def is_java_infrastructure(class_name: str) -> bool: + return any(class_name.startswith(pattern) for pattern in java_infra_patterns) + + +class JavaFunctionRanker: + """Ranks Java functions using JFR profiling data.""" + + def __init__(self, jfr_profile: JfrProfile) -> None: + self._jfr_profile = jfr_profile + self._ranking = jfr_profile.get_method_ranking() + self._ranking_by_name: dict[str, dict[str, Any]] = {} + for entry in self._ranking: + name = entry["method_name"] + if name not in self._ranking_by_name: + self._ranking_by_name[name] = entry + + def get_function_stats_summary(self, function_to_optimize: FunctionToOptimize) -> dict[str, Any] | None: + for entry in self._ranking: + if entry["method_name"] == function_to_optimize.function_name: + return { + "filename": "", + "function_name": entry["method_name"], + "qualified_name": f"{entry['class_name']}.{entry['method_name']}", + "class_name": entry["class_name"], + "line_number": 0, + "call_count": entry["sample_count"], + "own_time_ns": self._jfr_profile.get_addressable_time_ns(entry["class_name"], entry["method_name"]), + "addressable_time_ns": self._jfr_profile.get_addressable_time_ns( + entry["class_name"], entry["method_name"] + ), + } + return None + + def get_function_addressable_time(self, function_to_optimize: FunctionToOptimize) -> float: + entry = self._ranking_by_name.get(function_to_optimize.function_name) + if entry is None: + return 0.0 + return self._jfr_profile.get_addressable_time_ns(entry["class_name"], entry["method_name"]) + + def rank_functions( + self, functions_to_optimize: list[FunctionToOptimize], min_functions: int = 5 + ) -> list[FunctionToOptimize]: + if not self._ranking: + logger.warning("No JFR profiling data available to rank functions.") + return functions_to_optimize + + total_time = sum( + self._jfr_profile.get_addressable_time_ns(e["class_name"], e["method_name"]) + for e in self._ranking + if not is_java_infrastructure(e["class_name"]) + ) + + if total_time == 0: + return functions_to_optimize + + functions_with_time = [] + functions_without_time = [] + for func in functions_to_optimize: + addr_time = self.get_function_addressable_time(func) + if addr_time > 0: + importance = addr_time / total_time + if importance >= DEFAULT_IMPORTANCE_THRESHOLD: + functions_with_time.append(func) + else: + logger.debug( + f"Filtering out Java function {func.qualified_name} with importance " + f"{importance:.2%} (below threshold {DEFAULT_IMPORTANCE_THRESHOLD:.2%})" + ) + functions_without_time.append(func) + else: + functions_without_time.append(func) + + ranked = sorted(functions_with_time, key=self.get_function_addressable_time, reverse=True) + + # Guarantee at least min_functions pass through even when JFR data is sparse. + # Functions without JFR samples may still benefit from optimization. + if len(ranked) < min_functions: + shortfall = min_functions - len(ranked) + ranked_set = {id(f) for f in ranked} + for func in functions_without_time[:shortfall]: + if id(func) not in ranked_set: + ranked.append(func) + if shortfall > 0: + logger.info( + f"JFR data only covered {len(functions_with_time)} functions; " + f"added {min(shortfall, len(functions_without_time))} more to meet minimum of {min_functions}" + ) + + return ranked + + class FunctionRanker: """Ranks and filters functions based on % of addressable time derived from profiling data. diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index ba3371a90..5283f31ac 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -648,26 +648,32 @@ def discover_tests_for_language( # Convert TestInfo back to FunctionCalledInTest format # Use the full qualified name (with modules) as the key for consistency with Python function_to_tests: dict[str, set[FunctionCalledInTest]] = defaultdict(set) - num_tests = 0 + num_unit_tests = 0 + num_replay_tests = 0 for qualified_name, test_infos in test_map.items(): # Convert simple qualified_name to full qualified_name_with_modules full_qualified_name = simple_to_full_name.get(qualified_name, qualified_name) for test_info in test_infos: + is_replay = getattr(test_info, "is_replay", False) + test_type = TestType.REPLAY_TEST if is_replay else TestType.EXISTING_UNIT_TEST function_to_tests[full_qualified_name].add( FunctionCalledInTest( tests_in_file=TestsInFile( test_file=test_info.test_file, test_class=test_info.test_class, test_function=test_info.test_name, - test_type=TestType.EXISTING_UNIT_TEST, + test_type=test_type, ), position=CodePosition(line_no=0, col_no=0), ) ) - num_tests += 1 + if is_replay: + num_replay_tests += 1 + else: + num_unit_tests += 1 - return dict(function_to_tests), num_tests, 0 + return dict(function_to_tests), num_unit_tests, num_replay_tests def discover_unit_tests( diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 9c8776f75..5780f4def 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -465,6 +465,10 @@ def find_all_functions_in_file(file_path: Path) -> dict[Path, list[FunctionToOpt def get_all_replay_test_functions( replay_test: list[Path], test_cfg: TestConfig, project_root_path: Path ) -> tuple[dict[Path, list[FunctionToOptimize]], Path]: + # Check if these are Java replay tests + if replay_test and replay_test[0].suffix == ".java": + return _get_java_replay_test_functions(replay_test, test_cfg, project_root_path) + trace_file_path: Path | None = None for replay_test_file in replay_test: try: @@ -549,6 +553,75 @@ def get_all_replay_test_functions( return dict(filtered_valid_functions), trace_file_path +def _get_java_replay_test_functions( + replay_test: list[Path], test_cfg: TestConfig, project_root_path: Path +) -> tuple[dict[Path, list[FunctionToOptimize]], Path]: + """Parse Java replay test files to extract functions and trace file path.""" + from codeflash.languages.java.replay_test import parse_replay_test_metadata + + trace_file_path: Path | None = None + functions: dict[Path, list[FunctionToOptimize]] = defaultdict(list) + + for test_file in replay_test: + metadata = parse_replay_test_metadata(test_file) + + if trace_file_path is None and "trace_file" in metadata: + trace_file_path = Path(metadata["trace_file"]) + + classname = metadata.get("classname", "") + function_names = [f.strip() for f in metadata.get("functions", "").split(",") if f.strip()] + + if not classname or not function_names: + continue + + # Resolve the source file from the classname (e.g., "com.aerospike.benchmarks.Main") + class_parts = classname.split(".") + # Try matching by full package path first (e.g., com/aerospike/benchmarks/Main.java) + expected_path_suffix = "/".join(class_parts) + ".java" + source_file = None + for java_file in project_root_path.rglob("*.java"): + if java_file.as_posix().endswith(expected_path_suffix): + source_file = java_file + break + # Fall back to simple name match + if source_file is None: + for java_file in project_root_path.rglob("*.java"): + if java_file.stem == class_parts[-1]: + source_file = java_file + break + + if source_file is None: + logger.warning(f"Could not find source file for class {classname}") + continue + + # Use Java discovery to find functions in the source file + from codeflash.languages.registry import get_language_support + + lang_support = get_language_support(source_file) + source_code = source_file.read_text(encoding="utf-8") + all_functions = lang_support.discover_functions(source_code, source_file) + + for func in all_functions: + if func.function_name in function_names: + functions[source_file].append(func) + + if trace_file_path is None: + logger.error("Could not find trace_file_path in Java replay test files.") + from codeflash.code_utils.code_utils import exit_with_message + + exit_with_message("Could not find trace_file_path in Java replay test files.") + raise AssertionError("Unreachable") + + if not trace_file_path.exists(): + from codeflash.code_utils.code_utils import exit_with_message + + exit_with_message( + f"Trace file not found: {trace_file_path}\nPlease regenerate the replay test by re-running the tracer." + ) + + return dict(functions), trace_file_path + + def is_git_repo(file_path: str) -> bool: try: git.Repo(file_path, search_parent_directories=True) diff --git a/codeflash/languages/base.py b/codeflash/languages/base.py index f8f890163..bcdabeb8d 100644 --- a/codeflash/languages/base.py +++ b/codeflash/languages/base.py @@ -115,6 +115,7 @@ class TestInfo: test_name: str test_file: Path test_class: str | None = None + is_replay: bool = False @property def full_test_path(self) -> str: diff --git a/codeflash/languages/java/jfr_parser.py b/codeflash/languages/java/jfr_parser.py new file mode 100644 index 000000000..7775378e6 --- /dev/null +++ b/codeflash/languages/java/jfr_parser.py @@ -0,0 +1,200 @@ +from __future__ import annotations + +import json +import logging +import shutil +import subprocess +from datetime import datetime +from pathlib import Path +from typing import Any + +logger = logging.getLogger(__name__) + + +class JfrProfile: + """Parses JFR (Java Flight Recorder) files for method-level profiling data. + + Uses the `jfr` CLI tool (ships with JDK 11+) to extract ExecutionSample events + and build method-level timing estimates from sampling data. + """ + + def __init__(self, jfr_file: Path, packages: list[str]) -> None: + self.jfr_file = jfr_file + self.packages = packages + self._method_samples: dict[str, int] = {} + self._method_info: dict[str, dict[str, str]] = {} + self._caller_map: dict[str, dict[str, int]] = {} + self._recording_duration_ns: int = 0 + self._total_samples: int = 0 + self._parse() + + def _find_jfr_tool(self) -> str | None: + jfr_path = shutil.which("jfr") + if jfr_path: + return jfr_path + + java_home = subprocess.run( + ["java", "-XshowSettings:property", "-version"], capture_output=True, text=True, check=False + ) + for line in java_home.stderr.splitlines(): + if "java.home" in line: + home = line.split("=", 1)[1].strip() + candidate = Path(home) / "bin" / "jfr" + if candidate.exists(): + return str(candidate) + return None + + def _parse(self) -> None: + if not self.jfr_file.exists(): + logger.warning("JFR file not found: %s", self.jfr_file) + return + + jfr_tool = self._find_jfr_tool() + if jfr_tool is None: + logger.warning("jfr CLI tool not found, cannot parse JFR profile") + return + + try: + result = subprocess.run( + [jfr_tool, "print", "--events", "jdk.ExecutionSample", "--json", str(self.jfr_file)], + capture_output=True, + text=True, + timeout=120, + check=False, + ) + if result.returncode != 0: + logger.warning("jfr print failed: %s", result.stderr) + return + self._parse_json(result.stdout) + except subprocess.TimeoutExpired: + logger.warning("jfr print timed out for %s", self.jfr_file) + except Exception: + logger.exception("Failed to parse JFR file %s", self.jfr_file) + + def _parse_json(self, json_str: str) -> None: + try: + data = json.loads(json_str) + except json.JSONDecodeError: + logger.warning("Failed to parse JFR JSON output") + return + + events = data.get("recording", {}).get("events", []) + if not events: + events = data.get("events", []) + + # Cache package matching results to avoid repeated checks + package_match_cache: dict[str, bool] = {} + + def matches_packages_cached(method_key: str | None) -> bool: + if method_key is None: + return False + if method_key not in package_match_cache: + package_match_cache[method_key] = self._matches_packages(method_key) + return package_match_cache[method_key] + + for event in events: + if event.get("type") != "jdk.ExecutionSample": + continue + + stack_trace = event.get("values", {}).get("stackTrace", {}) + frames = stack_trace.get("frames", []) + if not frames: + continue + + self._total_samples += 1 + + # Precompute keys for all frames in this stack to avoid repeated conversions + keys = [self._frame_to_key(f) for f in frames] + + # Top-of-stack = own time + top_method_key = keys[0] if keys else None + if matches_packages_cached(top_method_key): + self._method_samples[top_method_key] = self._method_samples.get(top_method_key, 0) + 1 + self._store_method_info(top_method_key, frames[0]) + + # Build caller-callee relationships from adjacent frames + for i in range(len(keys) - 1): + callee_key = keys[i] + caller_key = keys[i + 1] + if callee_key and caller_key and matches_packages_cached(callee_key): + callee_callers = self._caller_map.setdefault(callee_key, {}) + callee_callers[caller_key] = callee_callers.get(caller_key, 0) + 1 + + # Estimate recording duration from event timestamps + if events: + min_ts = None + max_ts = None + for event in events: + try: + start_time = event.get("values", {}).get("startTime") + if not start_time: + continue + # JFR timestamps are in ISO format or epoch nanos + if isinstance(start_time, str): + dt = datetime.fromisoformat(start_time.replace("Z", "+00:00")) + ts = int(dt.timestamp() * 1_000_000_000) + elif isinstance(start_time, (int, float)): + ts = int(start_time) + else: + continue + if min_ts is None or ts < min_ts: + min_ts = ts + if max_ts is None or ts > max_ts: + max_ts = ts + except (ValueError, TypeError): + continue + if min_ts is not None and max_ts is not None: + self._recording_duration_ns = max_ts - min_ts + + def _frame_to_key(self, frame: dict[str, Any]) -> str | None: + method = frame.get("method", {}) + class_name = method.get("type", {}).get("name", "") + method_name = method.get("name", "") + if not class_name or not method_name: + return None + return f"{class_name}.{method_name}" + + def _store_method_info(self, key: str, frame: dict[str, Any]) -> None: + if key in self._method_info: + return + method = frame.get("method", {}) + self._method_info[key] = { + "class_name": method.get("type", {}).get("name", ""), + "method_name": method.get("name", ""), + "descriptor": method.get("descriptor", ""), + "line_number": str(frame.get("lineNumber", 0)), + } + + def _matches_packages(self, method_key: str) -> bool: + if not self.packages: + return True + return any(method_key.startswith(pkg) for pkg in self.packages) + + def get_method_ranking(self) -> list[dict[str, Any]]: + if not self._method_samples or self._total_samples == 0: + return [] + + ranking = [] + for method_key, sample_count in sorted(self._method_samples.items(), key=lambda x: x[1], reverse=True): + info = self._method_info.get(method_key, {}) + ranking.append( + { + "class_name": info.get("class_name", method_key.rsplit(".", 1)[0]), + "method_name": info.get("method_name", method_key.rsplit(".", 1)[-1]), + "sample_count": sample_count, + "pct_of_total": (sample_count / self._total_samples) * 100, + } + ) + return ranking + + def get_addressable_time_ns(self, class_name: str, method_name: str) -> float: + method_key = f"{class_name}.{method_name}" + sample_count = self._method_samples.get(method_key, 0) + if sample_count == 0 or self._total_samples == 0: + return 0.0 + + if self._recording_duration_ns > 0: + return (sample_count / self._total_samples) * self._recording_duration_ns + + # Fallback: return sample count as a proxy (higher = more time) + return float(sample_count * 1_000_000) diff --git a/codeflash/languages/java/replay_test.py b/codeflash/languages/java/replay_test.py new file mode 100644 index 000000000..c753bf4fa --- /dev/null +++ b/codeflash/languages/java/replay_test.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +import logging +import re +import sqlite3 +from collections import defaultdict +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pathlib import Path + +logger = logging.getLogger(__name__) + + +def generate_replay_tests(trace_db_path: Path, output_dir: Path, project_root: Path, max_run_count: int = 256) -> int: + """Generate JUnit 5 replay test files from a trace SQLite database. + + Returns the number of test files generated. + """ + if not trace_db_path.exists(): + logger.error("Trace database not found: %s", trace_db_path) + return 0 + + output_dir.mkdir(parents=True, exist_ok=True) + + conn = sqlite3.connect(str(trace_db_path)) + try: + cursor = conn.execute( + "SELECT DISTINCT classname, function, descriptor FROM function_calls ORDER BY classname, function" + ) + methods = cursor.fetchall() + + # Group by class + class_methods: dict[str, list[tuple[str, str]]] = defaultdict(list) + for classname, function, descriptor in methods: + class_methods[classname].append((function, descriptor)) + + test_count = 0 + all_function_names: list[str] = [] + + for classname, method_list in class_methods.items(): + safe_class_name = _sanitize_identifier(classname.replace(".", "_")) + test_class_name = f"ReplayTest_{safe_class_name}" + + test_methods_code: list[str] = [] + class_function_names: list[str] = [] + + for method_name, descriptor in method_list: + # Count invocations for this method + count_result = conn.execute( + "SELECT COUNT(*) FROM function_calls WHERE classname = ? AND function = ? AND descriptor = ?", + (classname, method_name, descriptor), + ).fetchone() + invocation_count = min(count_result[0], max_run_count) + + class_function_names.append(method_name) + safe_method = _sanitize_identifier(method_name) + + for i in range(invocation_count): + escaped_descriptor = descriptor.replace('"', '\\"') + test_methods_code.append( + f" @Test void replay_{safe_method}_{i}() throws Exception {{\n" + f' helper.replay("{classname}", "{method_name}", ' + f'"{escaped_descriptor}", {i});\n' + f" }}" + ) + + all_function_names.extend(class_function_names) + + # Generate the test file + functions_comment = ",".join(class_function_names) + test_content = ( + f"// codeflash:functions={functions_comment}\n" + f"// codeflash:trace_file={trace_db_path.as_posix()}\n" + f"// codeflash:classname={classname}\n" + f"package codeflash.replay;\n\n" + f"import org.junit.jupiter.api.Test;\n" + f"import org.junit.jupiter.api.AfterAll;\n" + f"import com.codeflash.ReplayHelper;\n\n" + f"class {test_class_name} {{\n" + f" private static final ReplayHelper helper =\n" + f' new ReplayHelper("{trace_db_path.as_posix()}");\n\n' + f" @AfterAll static void cleanup() {{ helper.close(); }}\n\n" + "\n\n".join(test_methods_code) + "\n" + "}\n" + ) + + test_file = output_dir / f"{test_class_name}.java" + test_file.write_text(test_content, encoding="utf-8") + test_count += 1 + logger.info("Generated replay test: %s (%d test methods)", test_file.name, len(test_methods_code)) + + finally: + conn.close() + + return test_count + + +def _sanitize_identifier(name: str) -> str: + """Sanitize a string for use as a Java identifier.""" + return re.sub(r"[^a-zA-Z0-9_]", "_", name) + + +def parse_replay_test_metadata(test_file: Path) -> dict[str, str]: + """Parse codeflash metadata comments from a Java replay test file. + + Returns a dict with keys: functions, trace_file, classname. + """ + metadata: dict[str, str] = {} + try: + with test_file.open("r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line.startswith("// codeflash:"): + if line and not line.startswith("//"): + break + continue + key_value = line[len("// codeflash:") :] + if "=" in key_value: + key, value = key_value.split("=", 1) + metadata[key] = value + except Exception: + logger.exception("Failed to parse replay test metadata from %s", test_file) + return metadata diff --git a/codeflash/languages/java/resources/codeflash-runtime-1.0.0.jar b/codeflash/languages/java/resources/codeflash-runtime-1.0.0.jar index 842f2c19b..cfcee9390 100644 Binary files a/codeflash/languages/java/resources/codeflash-runtime-1.0.0.jar and b/codeflash/languages/java/resources/codeflash-runtime-1.0.0.jar differ diff --git a/codeflash/languages/java/test_discovery.py b/codeflash/languages/java/test_discovery.py index 5a31ff9ef..7db04298b 100644 --- a/codeflash/languages/java/test_discovery.py +++ b/codeflash/languages/java/test_discovery.py @@ -39,6 +39,8 @@ def discover_tests( Resolves method invocations in test code back to their declaring class by tracing variable types, field types, static imports, and constructor calls. + Also handles replay test files (generated by the Java tracer) by parsing + their metadata comments. Args: test_root: Root directory containing tests. @@ -70,7 +72,10 @@ def discover_tests( # Find all test files (various naming conventions) test_files = ( - list(test_root.rglob("*Test.java")) + list(test_root.rglob("*Tests.java")) + list(test_root.rglob("Test*.java")) + list(test_root.rglob("*Test.java")) + + list(test_root.rglob("*Tests.java")) + + list(test_root.rglob("Test*.java")) + + list(test_root.rglob("ReplayTest_*.java")) ) # Deduplicate (a file like FooTest.java could match multiple patterns) test_files = list(dict.fromkeys(test_files)) @@ -79,6 +84,12 @@ def discover_tests( for test_file in test_files: try: + # Check if this is a replay test file (has codeflash metadata) + replay_metadata = _parse_replay_metadata(test_file) + if replay_metadata: + _discover_replay_tests(test_file, replay_metadata, function_map, result, analyzer) + continue + test_methods = discover_test_methods(test_file, analyzer) source = test_file.read_text(encoding="utf-8") @@ -104,6 +115,53 @@ def discover_tests( return dict(result) +def _parse_replay_metadata(test_file: Path) -> dict[str, str] | None: + """Check if a test file is a replay test and extract its metadata. + + Returns metadata dict if it's a replay test, None otherwise. + """ + from codeflash.languages.java.replay_test import parse_replay_test_metadata + + metadata = parse_replay_test_metadata(test_file) + return metadata if "functions" in metadata else None + + +def _discover_replay_tests( + test_file: Path, + metadata: dict[str, str], + function_map: dict[str, FunctionToOptimize], + result: dict[str, list[TestInfo]], + analyzer: JavaAnalyzer, +) -> None: + """Map replay test methods to source functions using metadata comments.""" + function_names = [f.strip() for f in metadata.get("functions", "").split(",") if f.strip()] + test_methods = discover_test_methods(test_file, analyzer) + + # Extract test class name from the file + test_class = None + if test_methods: + test_class = test_methods[0].class_name + + for test_method in test_methods: + # Each replay test method is named replay__ + # Map it to the source function it exercises + for func_name in function_names: + if func_name in function_map: + qualified_name = function_map[func_name].qualified_name + result[qualified_name].append( + TestInfo( + test_name=test_method.function_name, + test_file=test_file, + test_class=test_class or test_method.class_name, + is_replay=True, + ) + ) + + logger.debug( + "Discovered %d replay test methods for functions %s in %s", len(test_methods), function_names, test_file.name + ) + + def _compute_file_context(test_source: str, analyzer: JavaAnalyzer) -> tuple: """Pre-compute per-file analysis data: parse tree and static imports. diff --git a/codeflash/languages/java/tracer.py b/codeflash/languages/java/tracer.py new file mode 100644 index 000000000..7b5a30421 --- /dev/null +++ b/codeflash/languages/java/tracer.py @@ -0,0 +1,175 @@ +from __future__ import annotations + +import json +import logging +import os +import subprocess +from typing import TYPE_CHECKING + +from codeflash.languages.java.line_profiler import find_agent_jar +from codeflash.languages.java.replay_test import generate_replay_tests + +if TYPE_CHECKING: + from pathlib import Path + +logger = logging.getLogger(__name__) + +# --add-opens flags needed for Kryo serialization on Java 16+ +ADD_OPENS_FLAGS = ( + "--add-opens=java.base/java.util=ALL-UNNAMED " + "--add-opens=java.base/java.lang=ALL-UNNAMED " + "--add-opens=java.base/java.lang.reflect=ALL-UNNAMED " + "--add-opens=java.base/java.math=ALL-UNNAMED " + "--add-opens=java.base/java.io=ALL-UNNAMED " + "--add-opens=java.base/java.net=ALL-UNNAMED " + "--add-opens=java.base/java.time=ALL-UNNAMED" +) + + +class JavaTracer: + """Orchestrates two-stage Java tracing: JFR profiling + argument capture.""" + + def trace( + self, + java_command: list[str], + trace_db_path: Path, + packages: list[str], + project_root: Path | None = None, + max_function_count: int = 256, + timeout: int = 0, + ) -> tuple[Path, Path]: + """Run the Java program twice: once for profiling, once for arg capture. + + Returns (trace_db_path, jfr_file_path). + """ + jfr_file = trace_db_path.with_suffix(".jfr") + trace_db_path.parent.mkdir(parents=True, exist_ok=True) + + # Stage 1: JFR Profiling + logger.info("Stage 1: Running JFR profiling...") + jfr_env = self.build_jfr_env(jfr_file) + try: + subprocess.run(java_command, env=jfr_env, check=False, timeout=timeout or None) + except subprocess.TimeoutExpired: + logger.warning("JFR profiling stage timed out after %d seconds", timeout) + + if not jfr_file.exists(): + logger.warning("JFR file was not created at %s", jfr_file) + + # Stage 2: Argument Capture via Tracing Agent + logger.info("Stage 2: Running argument capture...") + config_path = self.create_tracer_config( + trace_db_path, packages, project_root=project_root, max_function_count=max_function_count, timeout=timeout + ) + agent_env = self.build_agent_env(config_path) + try: + subprocess.run(java_command, env=agent_env, check=False, timeout=timeout or None) + except subprocess.TimeoutExpired: + logger.warning("Argument capture stage timed out after %d seconds", timeout) + + if not trace_db_path.exists(): + logger.error("Trace database was not created at %s", trace_db_path) + + return trace_db_path, jfr_file + + def create_tracer_config( + self, + trace_db_path: Path, + packages: list[str], + project_root: Path | None = None, + max_function_count: int = 256, + timeout: int = 0, + ) -> Path: + config = { + "dbPath": str(trace_db_path.resolve()), + "packages": packages, + "excludePackages": [], + "maxFunctionCount": max_function_count, + "timeout": timeout, + "projectRoot": str(project_root.resolve()) if project_root else "", + } + + config_path = trace_db_path.with_suffix(".config.json") + config_path.write_text(json.dumps(config, indent=2), encoding="utf-8") + return config_path + + def build_jfr_env(self, jfr_file: Path) -> dict[str, str]: + env = os.environ.copy() + jfr_opts = f"-XX:StartFlightRecording=filename={jfr_file.resolve()},settings=profile,dumponexit=true" + existing = env.get("JAVA_TOOL_OPTIONS", "") + env["JAVA_TOOL_OPTIONS"] = f"{existing} {jfr_opts}".strip() + return env + + def build_agent_env(self, config_path: Path) -> dict[str, str]: + env = os.environ.copy() + agent_jar = find_agent_jar() + if agent_jar is None: + msg = "codeflash-runtime JAR not found, cannot run tracing agent" + raise FileNotFoundError(msg) + + agent_opts = f"{ADD_OPENS_FLAGS} -javaagent:{agent_jar}=trace={config_path.resolve()}" + existing = env.get("JAVA_TOOL_OPTIONS", "") + env["JAVA_TOOL_OPTIONS"] = f"{existing} {agent_opts}".strip() + return env + + @staticmethod + def detect_packages_from_source(module_root: Path) -> list[str]: + """Scan Java files for package declarations and return unique package prefixes.""" + packages: set[str] = set() + for java_file in module_root.rglob("*.java"): + try: + in_block_comment = False + with java_file.open("r", encoding="utf-8") as f: + for line in f: + stripped = line.strip() + if in_block_comment: + if "*/" in stripped: + in_block_comment = False + continue + if stripped.startswith("/*"): + if "*/" not in stripped: + in_block_comment = True + continue + if stripped.startswith("package "): + pkg = stripped[8:].rstrip(";").strip() + parts = pkg.split(".") + prefix = ".".join(parts[: min(2, len(parts))]) + packages.add(prefix) + break + if stripped and not stripped.startswith("//"): + break + except (OSError, UnicodeDecodeError): + continue + + return sorted(packages) + + +def run_java_tracer( + java_command: list[str], + trace_db_path: Path, + packages: list[str], + project_root: Path, + output_dir: Path, + max_function_count: int = 256, + timeout: int = 0, + max_run_count: int = 256, +) -> tuple[Path, Path, int]: + """High-level entry point: trace a Java command and generate replay tests. + + Returns (trace_db_path, jfr_file, test_count). + """ + tracer = JavaTracer() + trace_db, jfr_file = tracer.trace( + java_command=java_command, + trace_db_path=trace_db_path, + packages=packages, + project_root=project_root, + max_function_count=max_function_count, + timeout=timeout, + ) + + test_count = generate_replay_tests( + trace_db_path=trace_db, output_dir=output_dir, project_root=project_root, max_run_count=max_run_count + ) + + return trace_db, jfr_file, test_count diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index 49eaa23e5..8e9c08ac2 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -41,6 +41,18 @@ from codeflash.models.models import BenchmarkKey, FunctionCalledInTest, ValidCode +def _extract_java_package_from_path(file_path: Path) -> str | None: + """Extract Java package from file path by finding src/main/java or src/test/java marker.""" + parts = file_path.parts + for i, part in enumerate(parts): + if part == "java" and i >= 2 and parts[i - 1] in ("main", "test") and parts[i - 2] == "src": + package_parts = parts[i + 1 : -1] # After java/, exclude filename + if package_parts: + return ".".join(package_parts) + return None + return None + + class Optimizer: def __init__(self, args: Namespace) -> None: self.args = args @@ -360,17 +372,40 @@ def rank_all_functions_globally( return all_functions try: - from codeflash.benchmarking.function_ranker import FunctionRanker + from codeflash.benchmarking.function_ranker import FunctionRanker, JavaFunctionRanker console.rule() logger.info("loading|Ranking functions globally by performance impact...") console.rule() - # Create ranker with trace data - ranker = FunctionRanker(trace_file_path) # Extract just the functions for ranking (without file paths) functions_only = [func for _, func in all_functions] + # Detect if functions are Java and use appropriate ranker + if functions_only and functions_only[0].language == "java": + from codeflash.languages.java.jfr_parser import JfrProfile + + # JFR file is alongside the trace DB with .jfr extension + jfr_file_path = trace_file_path.with_suffix(".jfr") + if not jfr_file_path.exists(): + logger.warning(f"JFR file not found: {jfr_file_path}, falling back to original order") + return all_functions + + # Extract packages from file paths (e.g., src/main/java/com/example/Workload.java → "com.example") + packages = set() + for func in functions_only: + package = _extract_java_package_from_path(func.file_path) + if package: + # Use top two levels as filter prefix (e.g., "com.example" from "com.example.sub") + parts = package.split(".") + packages.add(".".join(parts[: min(2, len(parts))])) + + jfr_profile = JfrProfile(jfr_file_path, list(packages)) + ranker = JavaFunctionRanker(jfr_profile) + else: + # Python ranker with trace data + ranker = FunctionRanker(trace_file_path) + # Rank globally ranked_functions = ranker.rank_functions(functions_only) diff --git a/codeflash/tracer.py b/codeflash/tracer.py index 199d07b6e..84f58e9da 100644 --- a/codeflash/tracer.py +++ b/codeflash/tracer.py @@ -32,42 +32,82 @@ if TYPE_CHECKING: from argparse import Namespace + from codeflash.languages import Language + logger = logging.getLogger(__name__) -def main(args: Namespace | None = None) -> ArgumentParser: - # For non-Python languages, detect early and route to Optimizer - # Java, JavaScript, and TypeScript use their own test runners (Maven/JUnit, Jest) - # and should not go through Python tracing - if args is None and "--file" in sys.argv: +def _detect_non_python_language(args: Namespace | None) -> Language | None: + """Detect if the project uses a non-Python language from --file or config. + + Returns a Language enum value if non-Python detected, None otherwise. + """ + from codeflash.languages import Language + + # Method 1: Check --file argument for non-Python file extension + file_path_to_check: Path | None = None + if args is not None and getattr(args, "file", None): + file_path_to_check = Path(args.file) + elif args is None and "--file" in sys.argv: try: file_idx = sys.argv.index("--file") if file_idx + 1 < len(sys.argv): - file_path = Path(sys.argv[file_idx + 1]) - if file_path.exists(): - from codeflash.languages import Language, get_language_support + file_path_to_check = Path(sys.argv[file_idx + 1]) + except (IndexError, ValueError): + pass - lang_support = get_language_support(file_path) - detected_language = lang_support.language + if file_path_to_check is not None and file_path_to_check.exists(): + try: + from codeflash.languages import get_language_support - if detected_language in (Language.JAVA, Language.JAVASCRIPT, Language.TYPESCRIPT): - # Parse and process args like main.py does - from codeflash.cli_cmds.cli import parse_args, process_pyproject_config + lang_support = get_language_support(file_path_to_check) + if lang_support.language != Language.PYTHON: + return lang_support.language + except Exception: + pass - full_args = parse_args() - full_args = process_pyproject_config(full_args) - # Set checkpoint functions to None (no checkpoint for single-file optimization) - full_args.previous_checkpoint_functions = None + # Method 2: Check project config for language field + try: + from codeflash.code_utils.config_parser import parse_config_file - from codeflash.optimization import optimizer + config_file = getattr(args, "config_file_path", None) if args else None + config, _ = parse_config_file(config_file) + lang_str = config.get("language", "") + if lang_str == "java": + return Language.JAVA + if lang_str in ("javascript", "typescript"): + return Language(lang_str) + except Exception: + pass - logger.info( - "Detected %s file, routing to Optimizer instead of Python tracer", detected_language.value - ) - optimizer.run_with_args(full_args) - return ArgumentParser() # Return dummy parser since we're done - except (IndexError, OSError, Exception): - pass # Fall through to normal tracing if detection fails + return None + + +def main(args: Namespace | None = None) -> ArgumentParser: + # For non-Python languages, detect early and route to the appropriate handler. + # Java, JavaScript, and TypeScript use their own test runners (Maven/JUnit, Jest) + # and should not go through Python tracing. + # + # Detection methods (in priority order): + # 1. --file pointing to a .java/.js/.ts file + # 2. language field in project config (codeflash.toml or pyproject.toml) + detected_language = _detect_non_python_language(args) + if detected_language is not None: + from codeflash.languages import Language + + if detected_language in (Language.JAVASCRIPT, Language.TYPESCRIPT): + from codeflash.cli_cmds.cli import parse_args, process_pyproject_config + from codeflash.optimization import optimizer + + full_args = parse_args() + full_args = process_pyproject_config(full_args) + full_args.previous_checkpoint_functions = None + logger.info("Detected %s project, routing to Optimizer instead of Python tracer", detected_language.value) + optimizer.run_with_args(full_args) + return ArgumentParser() + + if detected_language == Language.JAVA: + return _run_java_tracer(args) parser = ArgumentParser(allow_abbrev=False) parser.add_argument("-o", "--outfile", dest="outfile", help="Save trace to ", default="codeflash.trace") @@ -141,16 +181,18 @@ def main(args: Namespace | None = None) -> ArgumentParser: "module": parsed_args.module, } try: - pytest_splits = [] - test_paths = [] - replay_test_paths = [] + pytest_splits: list[list[str]] = [] + test_paths: list[str] = [] + replay_test_paths: list[str] = [] if parsed_args.module and unknown_args[0] == "pytest": - pytest_splits, test_paths = pytest_split(unknown_args[1:], limit=parsed_args.limit) - if pytest_splits is None or test_paths is None: + result_splits, result_paths = pytest_split(unknown_args[1:], limit=parsed_args.limit) + if result_splits is None or result_paths is None: console.print(f"❌ Could not find test files in the specified paths: {unknown_args[1:]}") console.print(f"Current working directory: {Path.cwd()}") console.print("Please ensure the test directory exists and contains test files.") sys.exit(1) + pytest_splits = result_splits + test_paths = result_paths if len(pytest_splits) > 1: processes = [] @@ -237,8 +279,7 @@ def main(args: Namespace | None = None) -> ArgumentParser: from codeflash.cli_cmds.cli import parse_args, process_pyproject_config from codeflash.cli_cmds.console import paneled_text from codeflash.cli_cmds.console_constants import CODEFLASH_LOGO - from codeflash.languages import set_current_language - from codeflash.languages.base import Language + from codeflash.languages import Language, set_current_language from codeflash.telemetry import posthog_cf from codeflash.telemetry.sentry import init_sentry @@ -278,5 +319,96 @@ def main(args: Namespace | None = None) -> ArgumentParser: return parser +def _run_java_tracer(existing_args: Namespace | None = None) -> ArgumentParser: + """Run the Java two-stage tracer (JFR + argument capture) and optionally optimize.""" + from codeflash.cli_cmds.cli import parse_args, process_pyproject_config + + if existing_args is not None: + full_args = process_pyproject_config(existing_args) + else: + full_args = parse_args() + full_args = process_pyproject_config(full_args) + config = full_args + + trace_only = getattr(config, "trace_only", False) + project_root = Path(getattr(config, "project_root", ".")).resolve() + module_root = Path(getattr(config, "module_root", project_root)).resolve() + max_function_count = getattr(config, "max_function_count", 256) + timeout = int(getattr(config, "timeout", None) or getattr(config, "tracer_timeout", 0) or 0) + + from codeflash.code_utils.code_utils import get_run_tmp_file + from codeflash.languages.java.build_tools import find_test_root + from codeflash.languages.java.tracer import JavaTracer, run_java_tracer + + tracer = JavaTracer() + packages = tracer.detect_packages_from_source(module_root) + if not packages: + logger.warning("No Java packages detected in %s, will trace all non-JDK classes", module_root) + + trace_db_path = get_run_tmp_file(Path("java_trace.db")) + + # Place replay tests in the project's test source tree so Maven/Gradle can compile them + test_root = find_test_root(project_root) + if test_root: + output_dir = test_root / "codeflash" / "replay" + else: + output_dir = project_root / "src" / "test" / "java" / "codeflash" / "replay" + output_dir.mkdir(parents=True, exist_ok=True) + + # Remaining args after our flags are the Java command + remaining = sys.argv[sys.argv.index("--file") + 2 :] if "--file" in sys.argv else sys.argv[1:] + if not remaining: + console.print("[bold red]Error:[/] No Java command provided.") + console.print("Usage: codeflash optimize java -jar target/my-app.jar [args...]") + console.print(" codeflash optimize java -cp target/classes com.example.Main [args...]") + sys.exit(1) + java_command = remaining + + trace_db, jfr_file, test_count = run_java_tracer( + java_command=java_command, + trace_db_path=trace_db_path, + packages=packages, + project_root=project_root, + output_dir=output_dir, + max_function_count=max_function_count, + timeout=timeout, + ) + + console.print(f"[bold green]Java tracing complete:[/] {test_count} replay test files generated") + if jfr_file.exists(): + console.print(f" JFR profile: {jfr_file}") + if trace_db.exists(): + console.print(f" Trace DB: {trace_db}") + + if not trace_only and test_count > 0: + from codeflash.code_utils.config_consts import EffortLevel + from codeflash.languages import Language, set_current_language + from codeflash.optimization import optimizer + + set_current_language(Language.JAVA) + + replay_test_paths = [p.resolve() for p in output_dir.glob("*.java")] + config.replay_test = replay_test_paths + config.previous_checkpoint_functions = None + config.effort = EffortLevel.HIGH.value + config.no_pr = True + config.file = None + config.function = None + config.test_project_root = project_root + optimizer.run_with_args(config) + + # Clean up generated replay tests + for replay_test_path in replay_test_paths: + Path(replay_test_path).unlink(missing_ok=True) + # Clean up codeflash/replay directory if empty + if output_dir.exists() and not any(output_dir.iterdir()): + output_dir.rmdir() + codeflash_dir = output_dir.parent + if codeflash_dir.exists() and codeflash_dir.name == "codeflash" and not any(codeflash_dir.iterdir()): + codeflash_dir.rmdir() + + return ArgumentParser() + + if __name__ == "__main__": main() diff --git a/docs/configuration/java.mdx b/docs/configuration/java.mdx new file mode 100644 index 000000000..9d110fc55 --- /dev/null +++ b/docs/configuration/java.mdx @@ -0,0 +1,153 @@ +--- +title: "Java Configuration" +description: "Configure Codeflash for Java projects using codeflash.toml" +icon: "java" +sidebarTitle: "Java (codeflash.toml)" +keywords: + [ + "configuration", + "codeflash.toml", + "java", + "maven", + "gradle", + "junit", + ] +--- + +# Java Configuration + +Codeflash stores its configuration in `codeflash.toml` under the `[tool.codeflash]` section. + +## Full Reference + +```toml +[tool.codeflash] +# Required +module-root = "src/main/java" +tests-root = "src/test/java" +language = "java" + +# Optional +test-framework = "junit5" # "junit5", "junit4", or "testng" +disable-telemetry = false +git-remote = "origin" +ignore-paths = ["src/main/java/generated/"] +``` + +All file paths are relative to the directory containing `codeflash.toml`. + + +Codeflash auto-detects most settings from your project structure. Running `codeflash init` will set up the correct config — manual configuration is usually not needed. + + +## Auto-Detection + +When you run `codeflash init`, Codeflash inspects your project and auto-detects: + +| Setting | Detection logic | +|---------|----------------| +| `module-root` | Looks for `src/main/java` (Maven/Gradle standard layout) | +| `tests-root` | Looks for `src/test/java`, `test/`, `tests/` | +| `language` | Detected from build files (`pom.xml`, `build.gradle`) and `.java` files | +| `test-framework` | Checks build file dependencies for JUnit 5, JUnit 4, or TestNG | + +## Required Options + +- **`module-root`**: The source directory to optimize. Only code under this directory is discovered for optimization. For standard Maven/Gradle projects, this is `src/main/java`. +- **`tests-root`**: The directory where your tests are located. Codeflash discovers existing tests and places generated replay tests here. +- **`language`**: Must be set to `"java"` for Java projects. + +## Optional Options + +- **`test-framework`**: Test framework. Auto-detected from build dependencies. Supported values: `"junit5"` (default), `"junit4"`, `"testng"`. +- **`disable-telemetry`**: Disable anonymized telemetry. Defaults to `false`. +- **`git-remote`**: Git remote for pull requests. Defaults to `"origin"`. +- **`ignore-paths`**: Paths within `module-root` to skip during optimization. + +## Multi-Module Projects + +For multi-module Maven/Gradle projects, place `codeflash.toml` at the project root and set `module-root` to the module you want to optimize: + +```text +my-project/ +|- client/ +| |- src/main/java/com/example/client/ +| |- src/test/java/com/example/client/ +|- server/ +| |- src/main/java/com/example/server/ +|- pom.xml +|- codeflash.toml +``` + +```toml +[tool.codeflash] +module-root = "client/src/main/java" +tests-root = "client/src/test/java" +language = "java" +``` + +For non-standard layouts (like the Aerospike client where source is under `client/src/`), adjust paths accordingly: + +```toml +[tool.codeflash] +module-root = "client/src" +tests-root = "test/src" +language = "java" +``` + +## Tracer Options + +When using `codeflash optimize` to trace a Java program, these CLI options are available: + +| Option | Description | Default | +|--------|------------|---------| +| `--timeout` | Maximum time (seconds) for each tracing stage | No limit | +| `--max-function-count` | Maximum captures per method | 100 | +| `--trace-only` | Trace and generate replay tests without optimizing | `false` | + +Example with timeout: + +```bash +codeflash optimize --timeout 30 java -jar target/my-app.jar --app-args +``` + +## Example + +### Standard Maven project + +```text +my-app/ +|- src/ +| |- main/java/com/example/ +| | |- App.java +| | |- Utils.java +| |- test/java/com/example/ +| |- AppTest.java +|- pom.xml +|- codeflash.toml +``` + +```toml +[tool.codeflash] +module-root = "src/main/java" +tests-root = "src/test/java" +language = "java" +``` + +### Gradle project + +```text +my-lib/ +|- src/ +| |- main/java/com/example/ +| |- test/java/com/example/ +|- build.gradle +|- codeflash.toml +``` + +```toml +[tool.codeflash] +module-root = "src/main/java" +tests-root = "src/test/java" +language = "java" +``` diff --git a/docs/docs.json b/docs/docs.json index fe0c23098..e3fead77f 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -26,7 +26,8 @@ "group": "Getting Started", "pages": [ "getting-started/local-installation", - "getting-started/javascript-installation" + "getting-started/javascript-installation", + "getting-started/java-installation" ] }, { @@ -45,6 +46,7 @@ "pages": [ "configuration/python", "configuration/javascript", + "configuration/java", "getting-the-best-out-of-codeflash" ] }, diff --git a/docs/getting-started/java-installation.mdx b/docs/getting-started/java-installation.mdx new file mode 100644 index 000000000..a75e1f0b7 --- /dev/null +++ b/docs/getting-started/java-installation.mdx @@ -0,0 +1,131 @@ +--- +title: "Java Installation" +description: "Install and configure Codeflash for your Java project" +icon: "java" +sidebarTitle: "Java Setup" +keywords: + [ + "installation", + "java", + "maven", + "gradle", + "junit", + "junit5", + "tracing", + ] +--- + +Codeflash supports Java projects using Maven or Gradle build systems. It uses a two-stage tracing approach to capture method arguments and profiling data from running Java programs, then optimizes the hottest functions. + +### Prerequisites + +Before installing Codeflash, ensure you have: + +1. **Java 11 or above** installed +2. **Maven or Gradle** as your build tool +3. **A Java project** with source code under a standard directory layout + +Good to have (optional): + +1. **Unit tests** (JUnit 5 or JUnit 4) — Codeflash uses them alongside traced replay tests to verify correctness + + + + +Codeflash CLI is a Python tool. Install it with pip: + +```bash +pip install codeflash +``` + +Or with uv: + +```bash +uv pip install codeflash +``` + + + + +Navigate to your Java project root (where `pom.xml` or `build.gradle` is) and run: + +```bash +codeflash init +``` + +This will: +- Detect your build tool (Maven/Gradle) +- Find your source and test directories +- Create a `codeflash.toml` configuration file + + + + +Check that the configuration looks correct: + +```bash +cat codeflash.toml +``` + +You should see something like: + +```toml +[tool.codeflash] +module-root = "src/main/java" +tests-root = "src/test/java" +language = "java" +``` + + + + +Trace and optimize a running Java program: + +```bash +codeflash optimize java -jar target/my-app.jar +``` + +Or with Maven: + +```bash +codeflash optimize mvn exec:java -Dexec.mainClass="com.example.Main" +``` + +Codeflash will: +1. Profile your program using JFR (Java Flight Recorder) +2. Capture method arguments using a bytecode instrumentation agent +3. Generate JUnit replay tests from the captured data +4. Rank functions by performance impact +5. Optimize the most impactful functions + + + + +## How it works + +Codeflash uses a **two-stage tracing** approach for Java: + +1. **Stage 1 — JFR Profiling**: Runs your program with Java Flight Recorder enabled to collect accurate method-level CPU profiling data. JFR has ~1% overhead and doesn't affect JIT compilation. + +2. **Stage 2 — Argument Capture**: Runs your program again with a bytecode instrumentation agent that captures method arguments using Kryo serialization. Arguments are stored in an SQLite database. + +The traced data is used to generate **JUnit replay tests** that exercise your functions with real-world inputs. Codeflash uses these tests alongside any existing unit tests to verify correctness and benchmark optimization candidates. + + +Your program runs **twice** — once for profiling, once for argument capture. This separation ensures profiling data isn't distorted by serialization overhead. + + +## Supported build tools + +| Build Tool | Detection | Test Execution | +|-----------|-----------|---------------| +| **Maven** | `pom.xml` | Maven Surefire plugin | +| **Gradle** | `build.gradle` / `build.gradle.kts` | Gradle test task | + +## Supported test frameworks + +| Framework | Support Level | +|-----------|-------------| +| **JUnit 5** | Full support (default) | +| **JUnit 4** | Full support | +| **TestNG** | Basic support | diff --git a/docs/optimizing-with-codeflash/trace-and-optimize.mdx b/docs/optimizing-with-codeflash/trace-and-optimize.mdx index fb62ea1c1..4c332a929 100644 --- a/docs/optimizing-with-codeflash/trace-and-optimize.mdx +++ b/docs/optimizing-with-codeflash/trace-and-optimize.mdx @@ -15,6 +15,10 @@ keywords: "typescript", "jest", "vitest", + "java", + "jfr", + "maven", + "gradle", ] --- @@ -52,6 +56,26 @@ codeflash optimize --vitest codeflash optimize --language javascript script.js ``` + +To trace and optimize a running Java program, replace your `java` command with `codeflash optimize java`: + +```bash +# JAR application +codeflash optimize java -jar target/my-app.jar --app-args + +# Class with classpath +codeflash optimize java -cp target/classes com.example.Main + +# Maven exec +codeflash optimize mvn exec:java -Dexec.mainClass="com.example.Main" +``` + +For long-running programs (servers, benchmarks), use `--timeout` to limit each tracing stage: + +```bash +codeflash optimize --timeout 30 java -jar target/my-app.jar +``` + The `codeflash optimize` command creates high-quality optimizations, making it ideal when you need to optimize a workflow or script. The initial tracing process can be slow, so try to limit your script's runtime to under 1 minute for best results. @@ -194,5 +218,57 @@ The JavaScript tracer uses Babel instrumentation to capture function calls durin - `--max-function-count`: Maximum traces per function (default: 256). - `--only-functions`: Comma-separated list of function names to trace. + + + +The Java tracer uses a **two-stage approach**: JFR (Java Flight Recorder) for accurate profiling, then a bytecode instrumentation agent for argument capture. + +1. **Trace and optimize a Java program** + + Replace your `java` command with `codeflash optimize java`: + + ```bash + # JAR application + codeflash optimize java -jar target/my-app.jar --app-args + + # Class with classpath + codeflash optimize java -cp target/classes com.example.Main + ``` + + Codeflash will run your program twice (once for profiling, once for argument capture), generate JUnit replay tests, then optimize the most impactful functions. + +2. **Long-running programs** + + For servers, benchmarks, or programs that don't terminate on their own, use `--timeout` to limit each tracing stage: + + ```bash + codeflash optimize --timeout 30 java -jar target/my-benchmark.jar + ``` + + Each stage runs for at most 30 seconds, then the program is terminated and captured data is processed. + +3. **Trace only (no optimization)** + + ```bash + codeflash optimize --trace-only java -jar target/my-app.jar + ``` + + This generates replay tests in `src/test/java/codeflash/replay/` without running the optimizer. + + More Options: + + - `--timeout`: Maximum time (seconds) for each tracing stage. + - `--max-function-count`: Maximum captures per method (default: 100). + + +**How the Java tracer works:** + +- **Stage 1 (JFR)**: Runs your program with Java Flight Recorder enabled. JFR is built into the JVM (Java 11+), has ~1% overhead, and doesn't interfere with JIT compilation. This produces accurate method-level CPU profiling data. + +- **Stage 2 (Agent)**: Runs your program with a bytecode instrumentation agent injected via `JAVA_TOOL_OPTIONS`. The agent intercepts method entry points, serializes arguments using Kryo, and writes them to an SQLite database. A 500ms timeout per serialization prevents hangs on complex object graphs. + +- **Replay Tests**: Generated JUnit 5 test classes that deserialize captured arguments and invoke the original methods via reflection. These tests exercise your code with real-world inputs. + + diff --git a/tests/scripts/end_to_end_test_java_tracer.py b/tests/scripts/end_to_end_test_java_tracer.py new file mode 100644 index 000000000..e904a4e98 --- /dev/null +++ b/tests/scripts/end_to_end_test_java_tracer.py @@ -0,0 +1,144 @@ +import logging +import os +import pathlib +import re +import shutil +import subprocess +import time + + +def run_test(expected_improvement_pct: int) -> bool: + logging.basicConfig(level=logging.INFO) + fixture_dir = (pathlib.Path(__file__).parent.parent / "test_languages" / "fixtures" / "java_tracer_e2e").resolve() + + # Ensure test directory exists (git doesn't track empty dirs) + test_java_dir = fixture_dir / "src" / "test" / "java" + test_java_dir.mkdir(parents=True, exist_ok=True) + + # Clean up leftover replay tests from previous runs + replay_dir = test_java_dir / "codeflash" / "replay" + if replay_dir.exists(): + shutil.rmtree(replay_dir, ignore_errors=True) + for f in test_java_dir.rglob("*__perfinstrumented*.java"): + f.unlink(missing_ok=True) + for f in test_java_dir.rglob("*__perfonlyinstrumented*.java"): + f.unlink(missing_ok=True) + + # Compile the workload + classes_dir = fixture_dir / "target" / "classes" + classes_dir.mkdir(parents=True, exist_ok=True) + compile_result = subprocess.run( + [ + "javac", + "--release", + "11", + "-d", + str(classes_dir), + str(fixture_dir / "src" / "main" / "java" / "com" / "example" / "Workload.java"), + ], + capture_output=True, + text=True, + ) + if compile_result.returncode != 0: + logging.error(f"javac failed: {compile_result.stderr}") + return False + + # Run the Java tracer + optimizer + command = [ + "uv", + "run", + "--no-project", + "-m", + "codeflash.main", + "optimize", + "java", + "-cp", + str(classes_dir), + "com.example.Workload", + ] + + env = os.environ.copy() + env["PYTHONIOENCODING"] = "utf-8" + logging.info(f"Running command: {' '.join(command)}") + logging.info(f"Working directory: {fixture_dir}") + process = subprocess.Popen( + command, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + cwd=str(fixture_dir), + env=env, + encoding="utf-8", + ) + + output = [] + for line in process.stdout: + logging.info(line.strip()) + output.append(line) + + return_code = process.wait() + stdout = "".join(output) + if return_code != 0: + logging.error(f"Full output:\n{stdout}") + + if return_code != 0: + logging.error(f"Command returned exit code {return_code}") + return False + + # Validate: replay tests were generated + if "replay test files generated" not in stdout: + logging.error("Failed to find replay test generation message") + return False + + # Validate: replay tests were discovered + replay_match = re.search(r"Discovered \d+ existing unit tests? and (\d+) replay tests?", stdout) + if not replay_match: + logging.error("Failed to find replay test discovery message") + return False + num_replay = int(replay_match.group(1)) + if num_replay == 0: + logging.error("No replay tests discovered") + return False + logging.info(f"Replay tests discovered: {num_replay}") + + # Validate: at least one optimization was found + if "⚡️ Optimization successful! 📄 " not in stdout: + logging.error("Failed to find optimization success message") + return False + + improvement_match = re.search(r"📈 ([\d,]+)% (?:(\w+) )?improvement", stdout) + if not improvement_match: + logging.error("Could not find improvement percentage in output") + return False + + improvement_pct = int(improvement_match.group(1).replace(",", "")) + logging.info(f"Performance improvement: {improvement_pct}%") + + if improvement_pct <= expected_improvement_pct: + logging.error(f"Performance improvement {improvement_pct}% not above {expected_improvement_pct}%") + return False + + logging.info(f"Success: Java tracer e2e passed with {improvement_pct}% improvement") + return True + + +def run_with_retries(test_func, *args) -> int: + max_retries = int(os.getenv("MAX_RETRIES", 3)) + retry_delay = int(os.getenv("RETRY_DELAY", 5)) + for attempt in range(1, max_retries + 1): + logging.info(f"\n=== Attempt {attempt} of {max_retries} ===") + if test_func(*args): + logging.info(f"Test passed on attempt {attempt}") + return 0 + logging.error(f"Test failed on attempt {attempt}") + if attempt < max_retries: + logging.info(f"Retrying in {retry_delay} seconds...") + time.sleep(retry_delay) + else: + logging.error("Test failed after all retries") + return 1 + return 1 + + +if __name__ == "__main__": + exit(run_with_retries(run_test, int(os.getenv("EXPECTED_IMPROVEMENT_PCT", 10)))) diff --git a/tests/test_languages/fixtures/java_tracer_e2e/codeflash.toml b/tests/test_languages/fixtures/java_tracer_e2e/codeflash.toml new file mode 100644 index 000000000..a501ef8cb --- /dev/null +++ b/tests/test_languages/fixtures/java_tracer_e2e/codeflash.toml @@ -0,0 +1,6 @@ +# Codeflash configuration for Java project + +[tool.codeflash] +module-root = "src/main/java" +tests-root = "src/test/java" +language = "java" diff --git a/tests/test_languages/fixtures/java_tracer_e2e/pom.xml b/tests/test_languages/fixtures/java_tracer_e2e/pom.xml new file mode 100644 index 000000000..7fffde8b2 --- /dev/null +++ b/tests/test_languages/fixtures/java_tracer_e2e/pom.xml @@ -0,0 +1,67 @@ + + 4.0.0 + + com.example + tracer-e2e + 1.0.0 + + + 11 + 11 + UTF-8 + + + + + org.junit.jupiter + junit-jupiter + 5.10.0 + test + + + com.codeflash + codeflash-runtime + 1.0.0 + test + + + + + + + org.apache.maven.plugins + maven-surefire-plugin + 3.2.2 + + + + org.jacoco + jacoco-maven-plugin + 0.8.13 + + + prepare-agent + + prepare-agent + + + + report + verify + + report + + + + + **/*.class + + + + + + + + diff --git a/tests/test_languages/fixtures/java_tracer_e2e/src/main/java/com/example/Workload.java b/tests/test_languages/fixtures/java_tracer_e2e/src/main/java/com/example/Workload.java new file mode 100644 index 000000000..9b6078000 --- /dev/null +++ b/tests/test_languages/fixtures/java_tracer_e2e/src/main/java/com/example/Workload.java @@ -0,0 +1,56 @@ +package com.example; + +import java.util.ArrayList; +import java.util.List; + +public class Workload { + + public static int computeSum(int n) { + int sum = 0; + for (int i = 0; i < n; i++) { + sum += i; + } + return sum; + } + + public static String repeatString(String s, int count) { + String result = ""; + for (int i = 0; i < count; i++) { + result = result + s; + } + return result; + } + + public static List filterEvens(List numbers) { + List result = new ArrayList<>(); + for (int n : numbers) { + if (n % 2 == 0) { + result.add(n); + } + } + return result; + } + + public int instanceMethod(int x, int y) { + return x * y + computeSum(x); + } + + public static void main(String[] args) { + // Exercise the methods so the tracer can capture invocations + System.out.println("computeSum(100) = " + computeSum(100)); + System.out.println("computeSum(50) = " + computeSum(50)); + + System.out.println("repeatString(\"ab\", 3) = " + repeatString("ab", 3)); + System.out.println("repeatString(\"x\", 5) = " + repeatString("x", 5)); + + List nums = new ArrayList<>(); + for (int i = 1; i <= 10; i++) nums.add(i); + System.out.println("filterEvens(1..10) = " + filterEvens(nums)); + + Workload w = new Workload(); + System.out.println("instanceMethod(5, 3) = " + w.instanceMethod(5, 3)); + System.out.println("instanceMethod(10, 2) = " + w.instanceMethod(10, 2)); + + System.out.println("Workload complete."); + } +} diff --git a/tests/test_languages/test_java/test_java_tracer_e2e.py b/tests/test_languages/test_java/test_java_tracer_e2e.py new file mode 100644 index 000000000..157f23eb6 --- /dev/null +++ b/tests/test_languages/test_java/test_java_tracer_e2e.py @@ -0,0 +1,304 @@ +from __future__ import annotations + +import sqlite3 +import subprocess +from pathlib import Path + +import pytest + +from codeflash.languages.java.line_profiler import find_agent_jar +from codeflash.languages.java.replay_test import generate_replay_tests, parse_replay_test_metadata +from codeflash.languages.java.tracer import ADD_OPENS_FLAGS, JavaTracer + +FIXTURE_DIR = Path(__file__).parent.parent / "fixtures" / "java_tracer_e2e" +WORKLOAD_SOURCE = FIXTURE_DIR / "src" / "main" / "java" / "com" / "example" / "Workload.java" +WORKLOAD_CLASS = "com.example.Workload" +WORKLOAD_PACKAGE = "com.example" + + +@pytest.fixture(scope="module") +def compiled_workload() -> Path: + """Compile the Java workload fixture (once per module).""" + classes_dir = FIXTURE_DIR / "target" / "classes" + classes_dir.mkdir(parents=True, exist_ok=True) + result = subprocess.run( + ["javac", "--release", "11", "-d", str(classes_dir), str(WORKLOAD_SOURCE)], + capture_output=True, + text=True, + check=False, + ) + assert result.returncode == 0, f"javac failed: {result.stderr}" + return classes_dir + + +@pytest.fixture +def trace_db(tmp_path: Path) -> Path: + return tmp_path / "trace.db" + + +class TestTracingAgent: + def test_agent_jar_found(self) -> None: + jar = find_agent_jar() + assert jar is not None, "codeflash-runtime JAR not found" + assert jar.exists() + + def test_agent_captures_invocations(self, compiled_workload: Path, trace_db: Path) -> None: + """Test that the tracing agent captures method invocations into SQLite.""" + agent_jar = find_agent_jar() + assert agent_jar is not None + + import json + + config = { + "dbPath": str(trace_db), + "packages": [WORKLOAD_PACKAGE], + "excludePackages": [], + "maxFunctionCount": 256, + "timeout": 0, + "projectRoot": str(FIXTURE_DIR), + } + config_path = trace_db.with_suffix(".config.json") + config_path.write_text(json.dumps(config), encoding="utf-8") + + result = subprocess.run( + [ + "java", + *ADD_OPENS_FLAGS.split(), + f"-javaagent:{agent_jar}=trace={config_path}", + "-cp", + str(compiled_workload), + WORKLOAD_CLASS, + ], + capture_output=True, + text=True, + check=False, + timeout=30, + ) + assert "Workload complete." in result.stdout, f"Workload failed to run: {result.stderr}" + assert trace_db.exists(), "Trace DB not created" + + # Verify database contents + conn = sqlite3.connect(str(trace_db)) + try: + rows = conn.execute("SELECT function, classname, descriptor, length(args) FROM function_calls").fetchall() + assert len(rows) >= 5, f"Expected at least 5 captured invocations, got {len(rows)}" + + # Check that specific methods were captured + functions = {row[0] for row in rows} + assert "computeSum" in functions + assert "repeatString" in functions + assert "filterEvens" in functions + assert "instanceMethod" in functions + + # Verify all rows have non-empty args blobs + for row in rows: + assert row[3] > 0, f"Empty args blob for {row[0]}" + + # Verify metadata + metadata = dict(conn.execute("SELECT key, value FROM metadata").fetchall()) + assert "totalCaptures" in metadata + assert int(metadata["totalCaptures"]) >= 5 + finally: + conn.close() + + def test_max_function_count_limit(self, compiled_workload: Path, trace_db: Path) -> None: + """Test that maxFunctionCount limits captures per method.""" + agent_jar = find_agent_jar() + assert agent_jar is not None + + import json + + config = { + "dbPath": str(trace_db), + "packages": [WORKLOAD_PACKAGE], + "excludePackages": [], + "maxFunctionCount": 2, + "timeout": 0, + "projectRoot": str(FIXTURE_DIR), + } + config_path = trace_db.with_suffix(".config.json") + config_path.write_text(json.dumps(config), encoding="utf-8") + + subprocess.run( + [ + "java", + *ADD_OPENS_FLAGS.split(), + f"-javaagent:{agent_jar}=trace={config_path}", + "-cp", + str(compiled_workload), + WORKLOAD_CLASS, + ], + capture_output=True, + text=True, + check=False, + timeout=30, + ) + + conn = sqlite3.connect(str(trace_db)) + try: + # computeSum is called 4 times (2 direct + 2 from instanceMethod) + compute_count = conn.execute( + "SELECT COUNT(*) FROM function_calls WHERE function = 'computeSum'" + ).fetchone()[0] + assert compute_count <= 2, f"Expected at most 2 computeSum captures, got {compute_count}" + finally: + conn.close() + + +class TestReplayTestGeneration: + def test_generates_test_files(self, compiled_workload: Path, trace_db: Path, tmp_path: Path) -> None: + """Test that replay test files are generated from trace DB.""" + # First, create a trace + agent_jar = find_agent_jar() + assert agent_jar is not None + + import json + + config = { + "dbPath": str(trace_db), + "packages": [WORKLOAD_PACKAGE], + "excludePackages": [], + "maxFunctionCount": 256, + "timeout": 0, + "projectRoot": str(FIXTURE_DIR), + } + config_path = trace_db.with_suffix(".config.json") + config_path.write_text(json.dumps(config), encoding="utf-8") + + subprocess.run( + [ + "java", + *ADD_OPENS_FLAGS.split(), + f"-javaagent:{agent_jar}=trace={config_path}", + "-cp", + str(compiled_workload), + WORKLOAD_CLASS, + ], + capture_output=True, + check=False, + timeout=30, + ) + + # Generate replay tests + output_dir = tmp_path / "replay_tests" + count = generate_replay_tests( + trace_db_path=trace_db, + output_dir=output_dir, + project_root=FIXTURE_DIR, + ) + + assert count >= 1, f"Expected at least 1 test file, got {count}" + test_files = list(output_dir.glob("*.java")) + assert len(test_files) >= 1 + + # Find the main workload test file + workload_files = [f for f in test_files if "Workload" in f.name and "ConstructorAccess" not in f.name] + assert len(workload_files) == 1 + content = workload_files[0].read_text(encoding="utf-8") + assert "package codeflash.replay;" in content + assert "import org.junit.jupiter.api.Test;" in content + assert "ReplayHelper" in content + assert "replay_computeSum_0" in content + assert "replay_repeatString_0" in content + + def test_metadata_parsing(self, compiled_workload: Path, trace_db: Path, tmp_path: Path) -> None: + """Test that metadata comments are correctly parsed from generated tests.""" + agent_jar = find_agent_jar() + assert agent_jar is not None + + import json + + config = { + "dbPath": str(trace_db), + "packages": [WORKLOAD_PACKAGE], + "excludePackages": [], + "maxFunctionCount": 256, + "timeout": 0, + "projectRoot": str(FIXTURE_DIR), + } + config_path = trace_db.with_suffix(".config.json") + config_path.write_text(json.dumps(config), encoding="utf-8") + + subprocess.run( + [ + "java", + *ADD_OPENS_FLAGS.split(), + f"-javaagent:{agent_jar}=trace={config_path}", + "-cp", + str(compiled_workload), + WORKLOAD_CLASS, + ], + capture_output=True, + check=False, + timeout=30, + ) + + output_dir = tmp_path / "replay_tests" + generate_replay_tests(trace_db_path=trace_db, output_dir=output_dir, project_root=FIXTURE_DIR) + + test_files = [f for f in output_dir.glob("*.java") if "ConstructorAccess" not in f.name] + test_file = test_files[0] + metadata = parse_replay_test_metadata(test_file) + + assert "functions" in metadata + assert "trace_file" in metadata + assert "classname" in metadata + assert "computeSum" in metadata["functions"] + assert metadata["classname"] == "com.example.Workload" + assert metadata["trace_file"] == trace_db.as_posix() + + +class TestJavaTracerOrchestration: + def test_two_stage_trace(self, compiled_workload: Path, tmp_path: Path) -> None: + """Test the full two-stage JavaTracer flow (JFR + agent).""" + trace_db_path = tmp_path / "trace.db" + tracer = JavaTracer() + + trace_db, _jfr_file = tracer.trace( + java_command=["java", "-cp", str(compiled_workload), WORKLOAD_CLASS], + trace_db_path=trace_db_path, + packages=[WORKLOAD_PACKAGE], + project_root=FIXTURE_DIR, + ) + + assert trace_db.exists(), "Trace DB not created by JavaTracer" + + # Verify trace DB has captures + conn = sqlite3.connect(str(trace_db)) + try: + count = conn.execute("SELECT COUNT(*) FROM function_calls").fetchone()[0] + assert count >= 5, f"Expected at least 5 captured invocations, got {count}" + finally: + conn.close() + + def test_full_trace_and_replay_generation(self, compiled_workload: Path, tmp_path: Path) -> None: + """Test the full flow: trace → generate replay tests.""" + from codeflash.languages.java.tracer import run_java_tracer + + trace_db_path = tmp_path / "trace.db" + output_dir = tmp_path / "replay_tests" + + trace_db, _jfr_file, test_count = run_java_tracer( + java_command=["java", "-cp", str(compiled_workload), WORKLOAD_CLASS], + trace_db_path=trace_db_path, + packages=[WORKLOAD_PACKAGE], + project_root=FIXTURE_DIR, + output_dir=output_dir, + ) + + assert trace_db.exists() + assert test_count >= 1 + + # Verify the generated test files + test_files = list(output_dir.glob("*.java")) + assert len(test_files) >= 1 + workload_files = [f for f in test_files if "Workload" in f.name and "ConstructorAccess" not in f.name] + assert len(workload_files) == 1 + content = workload_files[0].read_text(encoding="utf-8") + assert "replay_computeSum" in content + assert "replay_instanceMethod" in content + + def test_package_detection(self) -> None: + """Test that package detection finds Java packages from source files.""" + packages = JavaTracer.detect_packages_from_source(FIXTURE_DIR) + assert "com.example" in packages diff --git a/tests/test_languages/test_java/test_java_tracer_integration.py b/tests/test_languages/test_java/test_java_tracer_integration.py new file mode 100644 index 000000000..f6ffefdf2 --- /dev/null +++ b/tests/test_languages/test_java/test_java_tracer_integration.py @@ -0,0 +1,345 @@ +"""End-to-end integration test for the Java tracer → optimizer pipeline. + +Tests the full flow: trace → replay test generation → function discovery → +test discovery → function ranking, using the simple Workload fixture. +""" + +from __future__ import annotations + +import subprocess +from pathlib import Path + +import pytest + +from codeflash.languages.java.tracer import run_java_tracer + +FIXTURE_DIR = Path(__file__).parent.parent / "fixtures" / "java_tracer_e2e" +WORKLOAD_SOURCE = FIXTURE_DIR / "src" / "main" / "java" / "com" / "example" / "Workload.java" +WORKLOAD_CLASS = "com.example.Workload" +WORKLOAD_PACKAGE = "com.example" + + +@pytest.fixture(scope="module") +def compiled_workload() -> Path: + classes_dir = FIXTURE_DIR / "target" / "classes" + classes_dir.mkdir(parents=True, exist_ok=True) + result = subprocess.run( + ["javac", "--release", "11", "-d", str(classes_dir), str(WORKLOAD_SOURCE)], + capture_output=True, + text=True, + check=False, + ) + assert result.returncode == 0, f"javac failed: {result.stderr}" + return classes_dir + + +@pytest.fixture +def traced_workload(compiled_workload: Path, tmp_path: Path) -> tuple[Path, Path, Path, int]: + """Trace the workload and generate replay tests. Returns (trace_db, jfr_file, output_dir, test_count).""" + trace_db_path = tmp_path / "trace.db" + output_dir = tmp_path / "replay_tests" + + trace_db, jfr_file, test_count = run_java_tracer( + java_command=["java", "-cp", str(compiled_workload), WORKLOAD_CLASS], + trace_db_path=trace_db_path, + packages=[WORKLOAD_PACKAGE], + project_root=FIXTURE_DIR, + output_dir=output_dir, + ) + + assert trace_db.exists(), "Trace DB not created" + assert test_count >= 1, f"Expected at least 1 replay test file, got {test_count}" + return trace_db, jfr_file, output_dir, test_count + + +class TestFunctionDiscoveryFromReplayTests: + """Test that functions are correctly discovered from replay test metadata.""" + + def test_discover_functions_from_replay_tests(self, traced_workload: tuple) -> None: + _trace_db, _jfr_file, output_dir, _test_count = traced_workload + + from codeflash.discovery.functions_to_optimize import _get_java_replay_test_functions + from codeflash.verification.verification_utils import TestConfig + + replay_test_paths = list(output_dir.glob("*.java")) + assert len(replay_test_paths) >= 1 + + test_cfg = TestConfig( + tests_root=FIXTURE_DIR / "src" / "test" / "java", + tests_project_rootdir=FIXTURE_DIR, + project_root_path=FIXTURE_DIR, + pytest_cmd="pytest", + ) + + functions, trace_file_path = _get_java_replay_test_functions(replay_test_paths, test_cfg, FIXTURE_DIR) + + # Should have found functions in the Workload source file + assert len(functions) > 0, "No functions discovered from replay tests" + assert trace_file_path.exists(), f"Trace file not found: {trace_file_path}" + + # Collect all discovered function names + all_func_names = set() + for file_path, func_list in functions.items(): + assert file_path.exists(), f"Source file not found: {file_path}" + assert "Workload" in file_path.name + for func in func_list: + all_func_names.add(func.function_name) + assert func.language == "java", f"Expected language='java', got '{func.language}'" + assert func.file_path == file_path + + assert "computeSum" in all_func_names + assert "repeatString" in all_func_names + + def test_discover_tests_for_replay_tests(self, traced_workload: tuple) -> None: + """Test that test discovery maps replay tests to source functions.""" + _trace_db, _jfr_file, output_dir, _test_count = traced_workload + + from codeflash.languages.java.discovery import discover_functions_from_source + from codeflash.languages.java.test_discovery import discover_tests + + source_code = WORKLOAD_SOURCE.read_text(encoding="utf-8") + source_functions = discover_functions_from_source(source_code, file_path=WORKLOAD_SOURCE) + + result = discover_tests(output_dir, source_functions) + + # Replay tests should be mapped to source functions + assert len(result) > 0, "No test mappings found from replay tests" + + # Check specific functions are mapped + matched_func_names = set() + for qualified_name in result: + func_name = qualified_name.split(".")[-1] if "." in qualified_name else qualified_name + matched_func_names.add(func_name) + + assert "computeSum" in matched_func_names, f"computeSum not found in: {result.keys()}" + assert "repeatString" in matched_func_names, f"repeatString not found in: {result.keys()}" + + # Each function should have at least one test + for func_name, test_infos in result.items(): + assert len(test_infos) > 0, f"No tests for {func_name}" + for test_info in test_infos: + assert test_info.test_file.exists() + assert "ReplayTest" in test_info.test_file.name + + +class TestJfrProfiling: + """Test JFR profiling and function ranking.""" + + def test_jfr_parsing(self, traced_workload: tuple) -> None: + _trace_db, jfr_file, _output_dir, _test_count = traced_workload + + if not jfr_file.exists(): + pytest.skip("JFR file not created (JFR may not be available)") + + from codeflash.languages.java.jfr_parser import JfrProfile + + profile = JfrProfile(jfr_file, [WORKLOAD_PACKAGE]) + ranking = profile.get_method_ranking() + + # The workload is very short, so JFR might not capture many samples + # Just verify the parser doesn't crash and returns a list + assert isinstance(ranking, list) + + def test_java_function_ranker(self, traced_workload: tuple) -> None: + _trace_db, jfr_file, _output_dir, _test_count = traced_workload + + if not jfr_file.exists(): + pytest.skip("JFR file not created (JFR may not be available)") + + from codeflash.benchmarking.function_ranker import JavaFunctionRanker + from codeflash.languages.java.discovery import discover_functions_from_source + from codeflash.languages.java.jfr_parser import JfrProfile + + profile = JfrProfile(jfr_file, [WORKLOAD_PACKAGE]) + ranker = JavaFunctionRanker(profile) + + source_code = WORKLOAD_SOURCE.read_text(encoding="utf-8") + source_functions = discover_functions_from_source(source_code, file_path=WORKLOAD_SOURCE) + + # Rank functions - should not crash even with minimal JFR data + ranked = ranker.rank_functions(source_functions) + assert isinstance(ranked, list) + + +class TestFullDiscoveryPipeline: + """Test the complete discovery pipeline as the optimizer would run it.""" + + def test_full_pipeline(self, compiled_workload: Path, tmp_path: Path) -> None: + """Simulate what optimizer.run() does: discover functions, discover tests, rank. + + Uses the same directory layout as the real flow: replay tests go into + src/test/java/codeflash/replay/ so test discovery can find them. + """ + trace_db_path = tmp_path / "trace.db" + + # Generate replay tests into the project's test source tree (like _run_java_tracer does) + test_root = FIXTURE_DIR / "src" / "test" / "java" + output_dir = test_root / "codeflash" / "replay" + output_dir.mkdir(parents=True, exist_ok=True) + + try: + _trace_db, jfr_file, test_count = run_java_tracer( + java_command=["java", "-cp", str(compiled_workload), WORKLOAD_CLASS], + trace_db_path=trace_db_path, + packages=[WORKLOAD_PACKAGE], + project_root=FIXTURE_DIR, + output_dir=output_dir, + ) + assert test_count >= 1 + + # Step 1: Discover functions from replay tests (like get_optimizable_functions) + from codeflash.discovery.functions_to_optimize import _get_java_replay_test_functions + from codeflash.verification.verification_utils import TestConfig + + replay_test_paths = list(output_dir.glob("*.java")) + test_cfg = TestConfig( + tests_root=test_root, + tests_project_rootdir=FIXTURE_DIR, + project_root_path=FIXTURE_DIR, + pytest_cmd="pytest", + ) + + file_to_funcs, trace_file_path = _get_java_replay_test_functions(replay_test_paths, test_cfg, FIXTURE_DIR) + assert len(file_to_funcs) > 0 + assert trace_file_path.exists() + + # Step 2: Set language (like optimizer.run lines 496-502) + from codeflash.languages import set_current_language + from codeflash.languages.base import Language + + set_current_language(Language.JAVA) + + # Step 3: Discover tests (like optimizer.discover_tests) + from codeflash.discovery.discover_unit_tests import discover_tests_for_language + + all_functions = [func for funcs in file_to_funcs.values() for func in funcs] + function_to_tests, num_unit_tests, num_replay_tests = discover_tests_for_language( + test_cfg, "java", file_to_funcs + ) + + assert num_unit_tests + num_replay_tests > 0, "No tests discovered" + assert num_replay_tests > 0, f"Expected replay tests, got {num_replay_tests}" + assert len(function_to_tests) > 0, "No function-to-test mappings" + + # Verify function_to_tests has entries for our traced functions + has_compute_sum = any("computeSum" in key for key in function_to_tests) + assert has_compute_sum, f"computeSum not in function_to_tests keys: {list(function_to_tests.keys())}" + + # Step 4: Rank functions (like optimizer.rank_all_functions_globally) + if jfr_file.exists(): + from codeflash.benchmarking.function_ranker import JavaFunctionRanker + from codeflash.languages.java.jfr_parser import JfrProfile + + packages = set() + for func in all_functions: + parts = func.qualified_name.split(".") + if len(parts) >= 2: + packages.add(".".join(parts[:-1])) + + profile = JfrProfile(jfr_file, list(packages)) + ranker = JavaFunctionRanker(profile) + ranked = ranker.rank_functions(all_functions) + assert isinstance(ranked, list) + + finally: + # Clean up generated replay tests from fixture directory + for f in output_dir.glob("*.java"): + f.unlink() + if output_dir.exists() and not any(output_dir.iterdir()): + output_dir.rmdir() + codeflash_dir = output_dir.parent + if codeflash_dir.exists() and codeflash_dir.name == "codeflash" and not any(codeflash_dir.iterdir()): + codeflash_dir.rmdir() + + def test_instrument_and_compile_replay_tests(self, compiled_workload: Path, tmp_path: Path) -> None: + """Test that replay tests can be instrumented and compiled by Maven.""" + trace_db_path = tmp_path / "trace.db" + + test_root = FIXTURE_DIR / "src" / "test" / "java" + output_dir = test_root / "codeflash" / "replay" + output_dir.mkdir(parents=True, exist_ok=True) + + cleanup_paths: list[Path] = [] + try: + _trace_db, _jfr_file, test_count = run_java_tracer( + java_command=["java", "-cp", str(compiled_workload), WORKLOAD_CLASS], + trace_db_path=trace_db_path, + packages=[WORKLOAD_PACKAGE], + project_root=FIXTURE_DIR, + output_dir=output_dir, + ) + assert test_count >= 1 + + replay_test_paths = list(output_dir.glob("*.java")) + cleanup_paths.extend(replay_test_paths) + + # Instrument a replay test (like instrument_existing_tests does) + from codeflash.languages.java.discovery import discover_functions_from_source + from codeflash.languages.java.instrumentation import instrument_existing_test + + source_code = WORKLOAD_SOURCE.read_text(encoding="utf-8") + source_functions = discover_functions_from_source(source_code, file_path=WORKLOAD_SOURCE) + # Pick the first function with a return type for instrumentation + target_func = next(f for f in source_functions if f.function_name == "computeSum") + + replay_test_file = replay_test_paths[0] + test_source = replay_test_file.read_text(encoding="utf-8") + + # Instrument for behavior mode + success, instrumented_source = instrument_existing_test( + test_string=test_source, function_to_optimize=target_func, mode="behavior", test_path=replay_test_file + ) + assert success, "Failed to instrument replay test for behavior mode" + assert instrumented_source is not None + assert "__perfinstrumented" in instrumented_source + + # Write the instrumented test + instrumented_path = replay_test_file.parent / f"{replay_test_file.stem}__perfinstrumented.java" + instrumented_path.write_text(instrumented_source, encoding="utf-8") + cleanup_paths.append(instrumented_path) + + # Instrument for performance mode + success, perf_source = instrument_existing_test( + test_string=test_source, + function_to_optimize=target_func, + mode="performance", + test_path=replay_test_file, + ) + assert success, "Failed to instrument replay test for performance mode" + assert perf_source is not None + + perf_path = replay_test_file.parent / f"{replay_test_file.stem}__perfonlyinstrumented.java" + perf_path.write_text(perf_source, encoding="utf-8") + cleanup_paths.append(perf_path) + + # Install codeflash-runtime as Maven dependency and compile + from codeflash.languages.java.build_tool_strategy import get_strategy + + strategy = get_strategy(FIXTURE_DIR) + strategy.ensure_runtime(FIXTURE_DIR, None) + + import os + + compile_env = os.environ.copy() + compile_result = strategy.compile_tests(FIXTURE_DIR, compile_env, None, timeout=120) + + assert compile_result.returncode == 0, ( + f"Maven compilation failed (rc={compile_result.returncode}):\n" + f"stdout: {compile_result.stdout}\n" + f"stderr: {compile_result.stderr}" + ) + + finally: + for f in cleanup_paths: + f.unlink(missing_ok=True) + # Also clean up Maven build artifacts for the replay package + replay_classes = FIXTURE_DIR / "target" / "test-classes" / "codeflash" + if replay_classes.exists(): + import shutil + + shutil.rmtree(replay_classes, ignore_errors=True) + if output_dir.exists() and not any(output_dir.iterdir()): + output_dir.rmdir() + codeflash_dir = output_dir.parent + if codeflash_dir.exists() and codeflash_dir.name == "codeflash" and not any(codeflash_dir.iterdir()): + codeflash_dir.rmdir() diff --git a/tests/test_languages/test_java/test_test_discovery.py b/tests/test_languages/test_java/test_test_discovery.py index 1644a272a..d6830389c 100644 --- a/tests/test_languages/test_java/test_test_discovery.py +++ b/tests/test_languages/test_java/test_test_discovery.py @@ -557,3 +557,87 @@ def test_tests_suffix_pattern(self, tmp_path: Path): # CalculatorTests should match Calculator class assert len(result) > 0 assert "Calculator.add" in result + + +class TestReplayTestDiscovery: + """Tests for replay test file discovery.""" + + def test_discover_replay_tests(self, tmp_path: Path): + """Test that replay test files are discovered and mapped to source functions.""" + src_file = tmp_path / "Calculator.java" + src_file.write_text( + """ +public class Calculator { + public int add(int a, int b) { return a + b; } + public int multiply(int a, int b) { return a * b; } +} +""" + ) + + test_dir = tmp_path / "test" + test_dir.mkdir() + replay_test = test_dir / "ReplayTest_Calculator.java" + replay_test.write_text( + """// codeflash:functions=add,multiply +// codeflash:trace_file=/tmp/trace.db +// codeflash:classname=Calculator +package codeflash.replay; + +import org.junit.jupiter.api.Test; +import com.codeflash.ReplayHelper; + +class ReplayTest_Calculator { + private static final ReplayHelper helper = + new ReplayHelper("/tmp/trace.db"); + + @Test void replay_add_0() throws Exception { + helper.replay("Calculator", "add", "(II)I", 0); + } + + @Test void replay_multiply_0() throws Exception { + helper.replay("Calculator", "multiply", "(II)I", 0); + } +} +""" + ) + + source_functions = discover_functions_from_source(src_file.read_text(), file_path=src_file) + result = discover_tests(test_dir, source_functions) + + assert "Calculator.add" in result + assert "Calculator.multiply" in result + assert len(result["Calculator.add"]) == 2 # Both replay_add_0 and replay_multiply_0 mapped + assert len(result["Calculator.multiply"]) == 2 + + def test_replay_tests_not_confused_with_regular_tests(self, tmp_path: Path): + """Test that files without codeflash metadata are not treated as replay tests.""" + src_file = tmp_path / "Calculator.java" + src_file.write_text( + """ +public class Calculator { + public int add(int a, int b) { return a + b; } +} +""" + ) + + test_dir = tmp_path / "test" + test_dir.mkdir() + regular_test = test_dir / "CalculatorTest.java" + regular_test.write_text( + """ +import org.junit.jupiter.api.Test; +public class CalculatorTest { + @Test + public void testAdd() { + Calculator calc = new Calculator(); + calc.add(1, 2); + } +} +""" + ) + + source_functions = discover_functions_from_source(src_file.read_text(), file_path=src_file) + result = discover_tests(test_dir, source_functions) + + # Should find through regular static analysis, not replay metadata + assert "Calculator.add" in result